eorst/
array_ops.rs

1//! Array manipulation operations for raster processing.
2//!
3//! This module provides functions for slicing, trimming, and transforming
4//! multidimensional arrays used in raster data processing.
5
6use crate::types::{Index2d, Overlap, Rectangle};
7use csv::WriterBuilder;
8use ndarray::{s, Array, Array2, Array3, Array4, Axis, ArrayView2, ArrayView3, ArrayView4};
9use ndarray_csv::Array2Writer;
10use rand::Rng;
11use rand::SeedableRng;
12use rand_chacha::ChaCha8Rng;
13use std::cmp::max;
14use std::fs::File;
15use std::path::PathBuf;
16
17/// Extracts a rectangular view from a 2D array.
18pub fn rect_view<'a, T>(
19    array: &'a ArrayView2<T>,
20    rect: Rectangle,
21    indices: Index2d,
22) -> ArrayView2<'a, T> {
23    let max_row = indices.row + rect.bottom;
24    let min_row = max(0, indices.row as i32 - rect.top as i32) as usize;
25
26    let max_col = indices.col + rect.right;
27    let min_col = max(0, indices.col as i32 - rect.left as i32) as usize;
28
29    array.slice(s![min_row..max_row - 1, min_col..max_col - 1])
30}
31
32/// Writes a 2D vector to a CSV file.
33pub fn write_csv_array(data: Vec<Vec<i16>>, csv_filename: PathBuf) {
34    let rows = data.len();
35    let cols = data[0].len();
36    let flat: Vec<i16> = data.into_iter().flatten().collect();
37    let array = Array::from(flat)
38        .into_shape_clone((rows, cols))
39        .expect("Unable to reshape");
40    {
41        let file = File::create(csv_filename).expect("Unable to create file");
42        let mut writer = WriterBuilder::new().has_headers(true).from_writer(file);
43        writer
44            .serialize_array2(&array)
45            .expect("Unable to serialize array to file");
46    }
47}
48
49/// Trims the overlap border from a 3D array.
50pub fn trimm_array3<T>(array: &Array3<T>, overlap_size: usize) -> ArrayView3<'_, T> {
51    let min_row = overlap_size;
52    let max_row = array.shape()[1] - overlap_size;
53
54    let min_col = overlap_size;
55    let max_col = array.shape()[2] - overlap_size;
56
57    array.slice(s![.., min_row..max_row, min_col..max_col])
58}
59
/// Returns the index of the first maximum value in `xs`.
///
/// Ties resolve to the earliest index. Elements are compared with `>` only, so an
/// element that is unordered relative to the current maximum (e.g. a NaN) never
/// replaces it.
///
/// # Panics
///
/// Panics if `xs` is empty.
pub fn argmax<T: PartialOrd>(xs: &[T]) -> usize {
    assert!(!xs.is_empty(), "argmax called on an empty slice");
    // Track a single best index: only the first maximum is ever returned, so there is
    // no need to collect (and reallocate) the indices of every tie.
    let mut best = 0;
    for (i, x) in xs.iter().enumerate().skip(1) {
        if *x > xs[best] {
            best = i;
        }
    }
    best
}
78
79/// Trims the overlap border from a 4D array.
80/// The memory layout of the output is not guaranteed.
81pub fn trimm_array4<T>(array: &Array4<T>, overlap_size: usize) -> ArrayView4<'_, T> {
82    let min_row = overlap_size;
83    let max_row = array.shape()[2] - overlap_size;
84
85    let min_col = overlap_size;
86    let max_col = array.shape()[3] - overlap_size;
87    let slice = s![.., .., min_row..max_row, min_col..max_col];
88    array.slice(slice)
89}
90
91/// Trims the overlap border from a 3D array using asymmetric overlap.
92pub fn trimm_array3_asymmetric<'a, T>(
93    array: &'a Array3<T>,
94    overlap: &Overlap,
95) -> ArrayView3<'a, T> {
96    let min_row = overlap.top;
97    let max_row = array.shape()[1] - overlap.bottom;
98
99    let min_col = overlap.left;
100    let max_col = array.shape()[2] - overlap.right;
101
102    array.slice(s![.., min_row..max_row, min_col..max_col])
103}
104
105/// Trims the overlap border from a 4D array, returning an owned array with standard layout.
106pub fn trimm_array4_owned<T: std::clone::Clone + std::fmt::Debug>(
107    array: &Array4<T>,
108    overlap: &Overlap,
109) -> Array4<T> {
110    let min_row = overlap.top;
111    let max_row = array.shape()[2] - overlap.bottom;
112
113    let min_col = overlap.left;
114    let max_col = array.shape()[3] - overlap.right;
115    let slice = s![.., .., min_row..max_row, min_col..max_col];
116    let trimmed = array.slice(slice);
117    let result = trimmed.as_standard_layout().to_owned();
118    result
119}
120
121/// Creates an array with clustered values (for testing).
122pub fn create_clustered_array(size: usize, num_values: u32, cluster_size: usize) -> Array2<u32> {
123    let mut array: Array2<u32> = Array::zeros((size, size));
124
125    let loop_size = size * size / cluster_size / cluster_size;
126    for idx in 0..loop_size {
127        let mut rng = ChaCha8Rng::seed_from_u64(idx.try_into().unwrap());
128        let value = rng.gen_range(1..=num_values);
129        let row = rng.gen_range(0..size);
130        let col = rng.gen_range(0..size);
131
132        for i in 0..cluster_size {
133            for j in 0..cluster_size {
134                let x = (row + i) % size;
135                let y = (col + j) % size;
136                array[[x, y]] = value;
137            }
138        }
139    }
140
141    array
142}
143
144#[allow(dead_code)]
145pub(crate) fn array2_to_nested_vec<T: std::clone::Clone>(arr: &Array2<T>) -> Vec<Vec<T>> {
146    let mut res: Vec<Vec<T>> = Vec::new();
147    arr.axis_iter(Axis(0)) //
148        .for_each(|row| res.push(row.to_vec()));
149    res
150}
151
152/// Fills no-data values in a 3D array using simple neighbor propagation.
153pub fn fill_nodata_simple(array: &mut Array3<f32>, nodata: f32) {
154    let (bands, rows, cols) = array.dim();
155
156    for b in 0..bands {
157        let mut band_view = array.slice_mut(s![b, .., ..]);
158        let mut changes = true;
159
160        // repeat until all NaNs are filled
161        while changes {
162            changes = false;
163            let copy = band_view.to_owned();
164            for r in 0..rows {
165                for c in 0..cols {
166                    if band_view[[r, c]] == nodata {
167                        let mut sum = 0.0;
168                        let mut count = 0;
169                        for dr in -1..=1 {
170                            for dc in -1..=1 {
171                                let nr = r as isize + dr;
172                                let nc = c as isize + dc;
173                                if nr >= 0 && nr < rows as isize && nc >= 0 && nc < cols as isize {
174                                    let val = copy[[nr as usize, nc as usize]];
175                                    if val != nodata {
176                                        sum += val;
177                                        count += 1;
178                                    }
179                                }
180                            }
181                        }
182                        if count > 0 {
183                            band_view[[r, c]] = sum / count as f32;
184                            changes = true;
185                        }
186                    }
187                }
188            }
189        }
190    }
191}