eorst/rasterdataset/
sampling.rs

1//! Sampling and extraction methods for RasterDataset.
2//!
3//! This module contains methods for extracting values from raster blocks at vector points.
4
5use crate::core_types::RasterType;
6use crate::gdal_utils::create_rayon_pool;
7use crate::rasterdataset::builder::n_block_cols;
8use crate::types::{Coordinates, Index2d, Rectangle, SamplingMethod};
9use crate::rasterdataset::RasterDataset;
10
11use anyhow::Result;
12use gdal::Dataset;
13use gdal::vector::{Geometry, LayerAccess};
14use itertools::Itertools;
15use kdam::par_tqdm;
16use ndarray::{s, Array2};
17use rayon::prelude::*;
18use std::collections::BTreeMap;
19use std::hash::Hash;
20use std::path::Path;
21
22/// Samples a value from a 2D array at the given point using the specified method.
23///
24/// Replaces the duplicated `match method { SamplingMethod::* }` pattern found in
25/// `extract_blockwise()` and `extract()`.
26fn sample_value(
27    band_data: &ndarray::ArrayView2<i16>,
28    rect: Rectangle,
29    point: Index2d,
30    method: SamplingMethod,
31) -> i16 {
32    match method {
33        SamplingMethod::Value => band_data[(point.row, point.col)],
34        SamplingMethod::Avg => {
35            let window_data: Vec<i16> =
36                crate::array_ops::rect_view(band_data, rect, point)
37                    .iter()
38                    .copied()
39                    .collect();
40            let window_size = window_data.len();
41            let avg: f32 = window_data
42                .iter()
43                .map(|v| *v as f32 / window_size as f32)
44                .sum();
45            avg.round() as i16
46        }
47        SamplingMethod::Mode => {
48            let mut window_data: Vec<i16> =
49                crate::array_ops::rect_view(band_data, rect, point)
50                    .iter()
51                    .copied()
52                    .collect();
53            window_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
54            window_data[window_data.len() / 2]
55        }
56        SamplingMethod::Min => {
57            let window_data: Vec<i16> =
58                crate::array_ops::rect_view(band_data, rect, point)
59                    .iter()
60                    .copied()
61                    .collect();
62            *window_data.iter().min().unwrap()
63        }
64        SamplingMethod::StdDev => {
65            let window_data: Vec<i16> =
66                crate::array_ops::rect_view(band_data, rect, point)
67                    .iter()
68                    .copied()
69                    .collect();
70            let sum: i32 = window_data.iter().map(|&x| x as i32).sum();
71            let mean = sum as f64 / window_data.len() as f64;
72            let variance: f64 = window_data
73                .iter()
74                .map(|&x| (x as f64 - mean).powi(2))
75                .sum::<f64>()
76                / window_data.len() as f64;
77            variance.sqrt().round() as i16
78        }
79    }
80}
81
/// Checks that the sampling buffer fits within the block overlap.
///
/// A sampling window extends `buffer_size` pixels on every side of a point. This presumably
/// has to stay within the `overlap_size` margin read around each block, hence the bound.
///
/// # Panics
/// If `buffer_size > overlap_size`.
fn validate_buffer_size(buffer_size: usize, overlap_size: usize) {
    assert!(
        buffer_size <= overlap_size,
        "Buffer size ({}) has to be <= overlap size ({})",
        buffer_size,
        overlap_size
    );
}
89
90/// Shared helper: creates a Rectangle from buffer size.
91fn make_rectangle(buffer_size: usize) -> Rectangle {
92    Rectangle {
93        left: buffer_size,
94        top: buffer_size,
95        right: buffer_size,
96        bottom: buffer_size,
97    }
98}
99
100/// Shared helper: builds the index-to-blocks pipeline.
101/// Returns (id_indices, blocks_to_process) from geometry data.
102fn build_block_index_pipeline<R: RasterType>(
103    raster: &RasterDataset<R>,
104    geoms: &BTreeMap<i64, Vec<(f64, f64, f64)>>,
105) -> (BTreeMap<i64, Index2d>, Vec<(usize, (i64, Index2d))>, Vec<usize>) {
106    let idx_global = raster.geoms_to_global_indices(geoms.clone());
107
108    let id_indices: Vec<(usize, (i64, Index2d))> = idx_global
109        .par_iter()
110        .map(|(pid, index)| raster.block_id_rowcol(*pid, *index))
111        .collect();
112
113    let block_ids: Vec<_> = idx_global
114        .par_iter()
115        .map(|(_, index)| raster.id_from_indices(*index))
116        .collect();
117
118    let blocks_to_process: Vec<usize> = block_ids.iter().unique().copied().collect();
119
120    (idx_global, id_indices, blocks_to_process)
121}
122
123/// Shared helper: collects pos/idx/pids for a given block ID.
124fn collect_points_for_block(
125    id_indices: &[(usize, (i64, Index2d))],
126    block_id: usize,
127) -> (Vec<Index2d>, Vec<usize>, Vec<usize>) {
128    let mut pos: Vec<Index2d> = Vec::new();
129    let mut idx: Vec<usize> = Vec::new();
130    let mut pids: Vec<usize> = Vec::new();
131    for (pid, p) in id_indices.iter().enumerate() {
132        if p.0 == block_id {
133            pos.push(Index2d {
134                col: p.1 .1.col,
135                row: p.1 .1.row,
136            });
137            pids.push(pid);
138            idx.push(p.1 .0 as usize);
139        }
140    }
141    (pos, idx, pids)
142}
143
/// Merges per-block extraction results into a single map keyed by point id.
///
/// Each element of `collected` describes one block as `(_, point_ids, values)`. Here
/// `point_ids[i]` is the id of the block's `i`-th point and `values[band][i]` is that point's
/// sample in `band`. The first tuple element is not read.
///
/// `key_converter` turns each id into the map key. The value stored under it holds one sample
/// per band of that point's block. If the same key appears more than once, the later entry wins.
/// Empty input yields an empty map.
fn assemble_block_results<K>(
    collected: &[(Vec<usize>, Vec<usize>, Vec<Vec<i16>>)],
    key_converter: fn(usize) -> K,
) -> BTreeMap<K, Vec<i16>>
where
    K: Ord + Hash,
{
    let mut results = BTreeMap::new();
    for (_, point_ids, band_values) in collected {
        for (i, &point_id) in point_ids.iter().enumerate() {
            // Transpose: pick the i-th sample out of every band of this block.
            let values: Vec<i16> = band_values.iter().map(|band| band[i]).collect();
            results.insert(key_converter(point_id), values);
        }
    }
    results
}
174
175impl<R> RasterDataset<R>
176where
177    R: RasterType,
178{
179    /// Converts geometries to global array indices.
180    pub fn geoms_to_global_indices(
181        &self,
182        geoms: BTreeMap<i64, Vec<(f64, f64, f64)>>,
183    ) -> BTreeMap<i64, Index2d> {
184        let idx_global: BTreeMap<_, _> = geoms
185            .par_iter()
186            .map(|(pid, p)| {
187                let point: Coordinates = Coordinates {
188                    x: p[0].0,
189                    y: p[0].1,
190                };
191                (*pid, self.geo_to_global_rc(point))
192            })
193            .collect();
194        idx_global
195    }
196
197    fn geo_to_global_rc(&self, point: Coordinates) -> Index2d {
198        let gt = self.metadata.geo_transform.to_array();
199        let row = ((point.y - gt[3]) / gt[5]) as usize;
200        let col = ((point.x - gt[0]) / gt[1]) as usize;
201        Index2d { col, row }
202    }
203
204    /// Gets the block ID and local row/col for a global point ID.
205    pub fn block_id_rowcol(&self, pid: i64, index: Index2d) -> (usize, (i64, Index2d)) {
206        let id = self.id_from_indices(index);
207        let row_col = self.global_rc_to_block_rc(index);
208        (id, (pid, row_col))
209    }
210
211    fn global_rc_to_block_rc(&self, global_index: Index2d) -> Index2d {
212        let mut block_col = global_index.col % self.metadata.block_size.cols;
213        let mut block_row = global_index.row % self.metadata.block_size.rows;
214
215        let block_col_ov = block_col + self.metadata.overlap_size;
216        let block_row_ov = block_row + self.metadata.overlap_size;
217
218        if (global_index.col as i16 - block_col_ov as i16) > 0 {
219            block_col = block_col_ov;
220        };
221
222        if global_index.row as i16 - block_row_ov as i16 > 0 {
223            block_row = block_row_ov;
224        };
225
226        Index2d {
227            col: block_col,
228            row: block_row,
229        }
230    }
231
232    fn id_from_indices(&self, index: Index2d) -> usize {
233        let n_block_cols = self.n_block_cols();
234        (index.col / self.metadata.block_size.cols)
235            + (index.row / self.metadata.block_size.rows) * n_block_cols
236    }
237
238    fn n_block_cols(&self) -> usize {
239        let image_size = crate::types::ImageSize {
240            rows: self.metadata.shape.rows,
241            cols: self.metadata.shape.cols,
242        };
243        n_block_cols(image_size, self.metadata.block_size)
244    }
245
246    /// Extracts values from the raster dataset for vector features, block-wise.
247    pub fn extract_blockwise(
248        &self,
249        vector_path: &std::path::PathBuf,
250        id_col_name: &str,
251        method: SamplingMethod,
252        buffer_size: Option<usize>,
253    ) -> BTreeMap<i16, Vec<i16>> {
254        log::debug!("Starting extract.");
255        let buffer_size = buffer_size.unwrap_or(0);
256        validate_buffer_size(buffer_size, self.metadata.overlap_size);
257
258        let vector_dataset = Dataset::open(Path::new(vector_path)).unwrap();
259        let mut layer = vector_dataset.layer(0).unwrap();
260        let mut geoms = BTreeMap::new();
261
262        for feature in layer.features() {
263            let mut geom = Vec::new();
264            feature
265                .geometry()
266                .expect("Geometries")
267                .get_points(&mut geom);
268            let field_index = feature.field_index(id_col_name).expect("Bad column name.");
269            let pid_filed = feature.field(field_index).unwrap().unwrap();
270            let pid = pid_filed.into_int64().unwrap();
271            geoms.insert(pid, geom);
272        }
273
274        let (_idx_global, id_indices, blocks_to_process) =
275            build_block_index_pipeline(self, &geoms);
276        drop(geoms);
277
278        let pool = create_rayon_pool(1);
279        let handle = pool.install(|| {
280            par_tqdm!(blocks_to_process
281                .into_par_iter())
282                .map(|id| -> (Vec<usize>, Vec<usize>, Vec<Vec<i16>>) {
283                    let (pos, idx, pids) = collect_points_for_block(&id_indices, id);
284
285                    let mut res = Vec::new();
286                    let rect = make_rectangle(buffer_size);
287
288                    let data = self.read_block(id);
289
290                    let bands = data.shape()[1];
291                    log::debug!("Bands {:?}", bands);
292                    for band_n in 0..bands {
293                        let mut res_band = Vec::new();
294                        let band_data = data.slice(s![0_i32, band_n, .., ..]);
295                        for point in pos.iter() {
296                            let val = sample_value(&band_data, rect, *point, method);
297                            res_band.push(val);
298                        }
299                        res.push(res_band);
300                    }
301                    (pids, idx, res)
302                })
303        });
304
305        let collected: Vec<_> = handle.collect();
306        assemble_block_results(&collected, |id| id as i16)
307    }
308
309    /// Extracts values from the raster dataset for point geometries.
310    pub fn extract(
311        &self,
312        geometries: &[Geometry],
313        point_ids: &[i64],
314        method: SamplingMethod,
315        buffer_size: Option<usize>,
316    ) -> Result<(Array2<i16>, Vec<i64>)> {
317        let buffer_size = buffer_size.unwrap_or(0);
318        validate_buffer_size(buffer_size, self.metadata.overlap_size);
319
320        let mut geoms = BTreeMap::new();
321        for (idx, point_id) in point_ids.iter().enumerate() {
322            let geometry = &geometries[idx];
323            let point = geometry.get_point(0);
324            let (x, y, z) = point;
325            geoms.insert(*point_id, vec![(x, y, z)]);
326        }
327
328        let (_idx_global, id_indices, blocks_to_process) =
329            build_block_index_pipeline(self, &geoms);
330        drop(geoms);
331
332        let blocks_to_process: Vec<usize> = blocks_to_process;
333
334        let pool = create_rayon_pool(1);
335        let handle = pool.install(|| {
336            par_tqdm!(blocks_to_process
337                .into_par_iter())
338                .map(|id| -> (Vec<usize>, Vec<usize>, Vec<Vec<i16>>) {
339                    let (pos, idx, pids) = collect_points_for_block(&id_indices, id);
340                    log::debug!("Extracting {} points, from block: {}", pos.len(), id);
341
342                    let mut res = Vec::new();
343                    let rect = make_rectangle(buffer_size);
344
345                    let data = self.read_block(id);
346                    let n_times = data.shape()[0];
347                    let n_layers = data.shape()[1];
348                    for time in 0..n_times {
349                        for layer in 0..n_layers {
350                            let mut res_band = Vec::new();
351                            let band_data = data.slice(s![time, layer, .., ..]);
352                            for point in pos.iter() {
353                                let col = point.col.checked_sub(self.blocks[id].overlap.left);
354                                let row = point.row.checked_sub(self.blocks[id].overlap.top);
355                                let col = col.unwrap_or(point.col);
356                                let row = row.unwrap_or(point.row);
357                                let val = sample_value(&band_data, rect, Index2d { col, row }, method);
358                                res_band.push(val);
359                            }
360                            res.push(res_band);
361                        }
362                    }
363                    (pids, idx, res)
364                })
365        });
366
367        let collected: Vec<_> = handle.collect();
368        let results = assemble_block_results(&collected, |id| id as i64);
369
370        let k = results.keys().next().unwrap();
371        let n_rows = results.len();
372        let n_cols = results[k].len();
373        let mut array: Array2<i16> = ndarray::Array::zeros((n_rows, n_cols));
374        for (row_index, values) in results.values().enumerate() {
375            for (col_index, value) in values.iter().enumerate() {
376                array[[row_index, col_index]] = *value;
377            }
378        }
379        let pids: Vec<i64> = results.into_keys().collect();
380        Ok((array, pids))
381    }
382}