1use crate::types::{Index2d, Overlap, Rectangle};
7use csv::WriterBuilder;
8use ndarray::{s, Array, Array2, Array3, Array4, Axis, ArrayView2, ArrayView3, ArrayView4};
9use ndarray_csv::Array2Writer;
10use rand::Rng;
11use rand::SeedableRng;
12use rand_chacha::ChaCha8Rng;
13use std::cmp::max;
14use std::fs::File;
15use std::path::PathBuf;
16
17pub fn rect_view<'a, T>(
19 array: &'a ArrayView2<T>,
20 rect: Rectangle,
21 indices: Index2d,
22) -> ArrayView2<'a, T> {
23 let max_row = indices.row + rect.bottom;
24 let min_row = max(0, indices.row as i32 - rect.top as i32) as usize;
25
26 let max_col = indices.col + rect.right;
27 let min_col = max(0, indices.col as i32 - rect.left as i32) as usize;
28
29 array.slice(s![min_row..max_row - 1, min_col..max_col - 1])
30}
31
32pub fn write_csv_array(data: Vec<Vec<i16>>, csv_filename: PathBuf) {
34 let rows = data.len();
35 let cols = data[0].len();
36 let flat: Vec<i16> = data.into_iter().flatten().collect();
37 let array = Array::from(flat)
38 .into_shape_clone((rows, cols))
39 .expect("Unable to reshape");
40 {
41 let file = File::create(csv_filename).expect("Unable to create file");
42 let mut writer = WriterBuilder::new().has_headers(true).from_writer(file);
43 writer
44 .serialize_array2(&array)
45 .expect("Unable to serialize array to file");
46 }
47}
48
49pub fn trimm_array3<T>(array: &Array3<T>, overlap_size: usize) -> ArrayView3<'_, T> {
51 let min_row = overlap_size;
52 let max_row = array.shape()[1] - overlap_size;
53
54 let min_col = overlap_size;
55 let max_col = array.shape()[2] - overlap_size;
56
57 array.slice(s![.., min_row..max_row, min_col..max_col])
58}
59
/// Returns the index of the largest element of `xs`.
///
/// Ties resolve to the earliest index. Comparison uses `PartialOrd`, so an element that is
/// unordered relative to the running maximum (e.g. a float `NaN`) never replaces it; a
/// `NaN` in the first position is therefore returned as the maximum.
///
/// # Panics
///
/// Panics if `xs` is empty.
pub fn argmax<T: PartialOrd>(xs: &[T]) -> usize {
    assert!(!xs.is_empty(), "argmax of an empty slice is undefined");
    // Only a strictly greater element moves the maximum, which keeps the first of any ties;
    // no need to collect every tied index just to return the first one.
    xs.iter()
        .enumerate()
        .skip(1)
        .fold(0, |best, (i, x)| if *x > xs[best] { i } else { best })
}
78
79pub fn trimm_array4<T>(array: &Array4<T>, overlap_size: usize) -> ArrayView4<'_, T> {
82 let min_row = overlap_size;
83 let max_row = array.shape()[2] - overlap_size;
84
85 let min_col = overlap_size;
86 let max_col = array.shape()[3] - overlap_size;
87 let slice = s![.., .., min_row..max_row, min_col..max_col];
88 array.slice(slice)
89}
90
91pub fn trimm_array3_asymmetric<'a, T>(
93 array: &'a Array3<T>,
94 overlap: &Overlap,
95) -> ArrayView3<'a, T> {
96 let min_row = overlap.top;
97 let max_row = array.shape()[1] - overlap.bottom;
98
99 let min_col = overlap.left;
100 let max_col = array.shape()[2] - overlap.right;
101
102 array.slice(s![.., min_row..max_row, min_col..max_col])
103}
104
105pub fn trimm_array4_owned<T: std::clone::Clone + std::fmt::Debug>(
107 array: &Array4<T>,
108 overlap: &Overlap,
109) -> Array4<T> {
110 let min_row = overlap.top;
111 let max_row = array.shape()[2] - overlap.bottom;
112
113 let min_col = overlap.left;
114 let max_col = array.shape()[3] - overlap.right;
115 let slice = s![.., .., min_row..max_row, min_col..max_col];
116 let trimmed = array.slice(slice);
117 let result = trimmed.as_standard_layout().to_owned();
118 result
119}
120
121pub fn create_clustered_array(size: usize, num_values: u32, cluster_size: usize) -> Array2<u32> {
123 let mut array: Array2<u32> = Array::zeros((size, size));
124
125 let loop_size = size * size / cluster_size / cluster_size;
126 for idx in 0..loop_size {
127 let mut rng = ChaCha8Rng::seed_from_u64(idx.try_into().unwrap());
128 let value = rng.gen_range(1..=num_values);
129 let row = rng.gen_range(0..size);
130 let col = rng.gen_range(0..size);
131
132 for i in 0..cluster_size {
133 for j in 0..cluster_size {
134 let x = (row + i) % size;
135 let y = (col + j) % size;
136 array[[x, y]] = value;
137 }
138 }
139 }
140
141 array
142}
143
144#[allow(dead_code)]
145pub(crate) fn array2_to_nested_vec<T: std::clone::Clone>(arr: &Array2<T>) -> Vec<Vec<T>> {
146 let mut res: Vec<Vec<T>> = Vec::new();
147 arr.axis_iter(Axis(0)) .for_each(|row| res.push(row.to_vec()));
149 res
150}
151
152pub fn fill_nodata_simple(array: &mut Array3<f32>, nodata: f32) {
154 let (bands, rows, cols) = array.dim();
155
156 for b in 0..bands {
157 let mut band_view = array.slice_mut(s![b, .., ..]);
158 let mut changes = true;
159
160 while changes {
162 changes = false;
163 let copy = band_view.to_owned();
164 for r in 0..rows {
165 for c in 0..cols {
166 if band_view[[r, c]] == nodata {
167 let mut sum = 0.0;
168 let mut count = 0;
169 for dr in -1..=1 {
170 for dc in -1..=1 {
171 let nr = r as isize + dr;
172 let nc = c as isize + dc;
173 if nr >= 0 && nr < rows as isize && nc >= 0 && nc < cols as isize {
174 let val = copy[[nr as usize, nc as usize]];
175 if val != nodata {
176 sum += val;
177 count += 1;
178 }
179 }
180 }
181 }
182 if count > 0 {
183 band_view[[r, c]] = sum / count as f32;
184 changes = true;
185 }
186 }
187 }
188 }
189 }
190 }
191}