1use crate::core_types::RasterType;
6use crate::gdal_utils::create_rayon_pool;
7use crate::rasterdataset::builder::n_block_cols;
8use crate::types::{Coordinates, Index2d, Rectangle, SamplingMethod};
9use crate::rasterdataset::RasterDataset;
10
11use anyhow::Result;
12use gdal::Dataset;
13use gdal::vector::{Geometry, LayerAccess};
14use itertools::Itertools;
15use kdam::par_tqdm;
16use ndarray::{s, Array2};
17use rayon::prelude::*;
18use std::collections::BTreeMap;
19use std::hash::Hash;
20use std::path::Path;
21
22fn sample_value(
27 band_data: &ndarray::ArrayView2<i16>,
28 rect: Rectangle,
29 point: Index2d,
30 method: SamplingMethod,
31) -> i16 {
32 match method {
33 SamplingMethod::Value => band_data[(point.row, point.col)],
34 SamplingMethod::Avg => {
35 let window_data: Vec<i16> =
36 crate::array_ops::rect_view(band_data, rect, point)
37 .iter()
38 .copied()
39 .collect();
40 let window_size = window_data.len();
41 let avg: f32 = window_data
42 .iter()
43 .map(|v| *v as f32 / window_size as f32)
44 .sum();
45 avg.round() as i16
46 }
47 SamplingMethod::Mode => {
48 let mut window_data: Vec<i16> =
49 crate::array_ops::rect_view(band_data, rect, point)
50 .iter()
51 .copied()
52 .collect();
53 window_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
54 window_data[window_data.len() / 2]
55 }
56 SamplingMethod::Min => {
57 let window_data: Vec<i16> =
58 crate::array_ops::rect_view(band_data, rect, point)
59 .iter()
60 .copied()
61 .collect();
62 *window_data.iter().min().unwrap()
63 }
64 SamplingMethod::StdDev => {
65 let window_data: Vec<i16> =
66 crate::array_ops::rect_view(band_data, rect, point)
67 .iter()
68 .copied()
69 .collect();
70 let sum: i32 = window_data.iter().map(|&x| x as i32).sum();
71 let mean = sum as f64 / window_data.len() as f64;
72 let variance: f64 = window_data
73 .iter()
74 .map(|&x| (x as f64 - mean).powi(2))
75 .sum::<f64>()
76 / window_data.len() as f64;
77 variance.sqrt().round() as i16
78 }
79 }
80}
81
/// Checks that the sampling buffer fits inside the block overlap.
///
/// A window of `buffer_size` pixels around a point near a block edge reaches
/// past the block proper; presumably the overlap read with each block is what
/// supplies those pixels, so the buffer may not exceed it (TODO confirm).
///
/// # Panics
///
/// Panics if `buffer_size > overlap_size`.
fn validate_buffer_size(buffer_size: usize, overlap_size: usize) {
    assert!(
        buffer_size <= overlap_size,
        "Buffer size ({}) must not exceed overlap size ({})",
        buffer_size,
        overlap_size
    );
}
89
90fn make_rectangle(buffer_size: usize) -> Rectangle {
92 Rectangle {
93 left: buffer_size,
94 top: buffer_size,
95 right: buffer_size,
96 bottom: buffer_size,
97 }
98}
99
100fn build_block_index_pipeline<R: RasterType>(
103 raster: &RasterDataset<R>,
104 geoms: &BTreeMap<i64, Vec<(f64, f64, f64)>>,
105) -> (BTreeMap<i64, Index2d>, Vec<(usize, (i64, Index2d))>, Vec<usize>) {
106 let idx_global = raster.geoms_to_global_indices(geoms.clone());
107
108 let id_indices: Vec<(usize, (i64, Index2d))> = idx_global
109 .par_iter()
110 .map(|(pid, index)| raster.block_id_rowcol(*pid, *index))
111 .collect();
112
113 let block_ids: Vec<_> = idx_global
114 .par_iter()
115 .map(|(_, index)| raster.id_from_indices(*index))
116 .collect();
117
118 let blocks_to_process: Vec<usize> = block_ids.iter().unique().copied().collect();
119
120 (idx_global, id_indices, blocks_to_process)
121}
122
123fn collect_points_for_block(
125 id_indices: &[(usize, (i64, Index2d))],
126 block_id: usize,
127) -> (Vec<Index2d>, Vec<usize>, Vec<usize>) {
128 let mut pos: Vec<Index2d> = Vec::new();
129 let mut idx: Vec<usize> = Vec::new();
130 let mut pids: Vec<usize> = Vec::new();
131 for (pid, p) in id_indices.iter().enumerate() {
132 if p.0 == block_id {
133 pos.push(Index2d {
134 col: p.1 .1.col,
135 row: p.1 .1.row,
136 });
137 pids.push(pid);
138 idx.push(p.1 .0 as usize);
139 }
140 }
141 (pos, idx, pids)
142}
143
/// Flattens per-block sampling output into a map keyed by point id.
///
/// Each entry of `collected` describes one block as
/// `(entry_positions, point_ids, values)`, where `values[band][i]` is the
/// sample of band `band` for the block's `i`-th point. The first element is
/// not used here. The result maps `key_converter(point_ids[i])` to that
/// point's values across all of its block's bands, in band order. If two
/// points produce the same key, the later one wins.
///
/// An empty `collected` yields an empty map.
fn assemble_block_results<K>(
    collected: &[(Vec<usize>, Vec<usize>, Vec<Vec<i16>>)],
    key_converter: fn(usize) -> K,
) -> BTreeMap<K, Vec<i16>>
where
    K: Ord + Hash,
{
    let mut results = BTreeMap::new();
    for (_entry_positions, point_ids, bands) in collected {
        for (i, &id) in point_ids.iter().enumerate() {
            // Transpose: gather the i-th sample of every band of this block.
            let point_values: Vec<i16> = bands.iter().map(|band| band[i]).collect();
            results.insert(key_converter(id), point_values);
        }
    }
    results
}
174
175impl<R> RasterDataset<R>
176where
177 R: RasterType,
178{
179 pub fn geoms_to_global_indices(
181 &self,
182 geoms: BTreeMap<i64, Vec<(f64, f64, f64)>>,
183 ) -> BTreeMap<i64, Index2d> {
184 let idx_global: BTreeMap<_, _> = geoms
185 .par_iter()
186 .map(|(pid, p)| {
187 let point: Coordinates = Coordinates {
188 x: p[0].0,
189 y: p[0].1,
190 };
191 (*pid, self.geo_to_global_rc(point))
192 })
193 .collect();
194 idx_global
195 }
196
197 fn geo_to_global_rc(&self, point: Coordinates) -> Index2d {
198 let gt = self.metadata.geo_transform.to_array();
199 let row = ((point.y - gt[3]) / gt[5]) as usize;
200 let col = ((point.x - gt[0]) / gt[1]) as usize;
201 Index2d { col, row }
202 }
203
204 pub fn block_id_rowcol(&self, pid: i64, index: Index2d) -> (usize, (i64, Index2d)) {
206 let id = self.id_from_indices(index);
207 let row_col = self.global_rc_to_block_rc(index);
208 (id, (pid, row_col))
209 }
210
211 fn global_rc_to_block_rc(&self, global_index: Index2d) -> Index2d {
212 let mut block_col = global_index.col % self.metadata.block_size.cols;
213 let mut block_row = global_index.row % self.metadata.block_size.rows;
214
215 let block_col_ov = block_col + self.metadata.overlap_size;
216 let block_row_ov = block_row + self.metadata.overlap_size;
217
218 if (global_index.col as i16 - block_col_ov as i16) > 0 {
219 block_col = block_col_ov;
220 };
221
222 if global_index.row as i16 - block_row_ov as i16 > 0 {
223 block_row = block_row_ov;
224 };
225
226 Index2d {
227 col: block_col,
228 row: block_row,
229 }
230 }
231
232 fn id_from_indices(&self, index: Index2d) -> usize {
233 let n_block_cols = self.n_block_cols();
234 (index.col / self.metadata.block_size.cols)
235 + (index.row / self.metadata.block_size.rows) * n_block_cols
236 }
237
238 fn n_block_cols(&self) -> usize {
239 let image_size = crate::types::ImageSize {
240 rows: self.metadata.shape.rows,
241 cols: self.metadata.shape.cols,
242 };
243 n_block_cols(image_size, self.metadata.block_size)
244 }
245
246 pub fn extract_blockwise(
248 &self,
249 vector_path: &std::path::PathBuf,
250 id_col_name: &str,
251 method: SamplingMethod,
252 buffer_size: Option<usize>,
253 ) -> BTreeMap<i16, Vec<i16>> {
254 log::debug!("Starting extract.");
255 let buffer_size = buffer_size.unwrap_or(0);
256 validate_buffer_size(buffer_size, self.metadata.overlap_size);
257
258 let vector_dataset = Dataset::open(Path::new(vector_path)).unwrap();
259 let mut layer = vector_dataset.layer(0).unwrap();
260 let mut geoms = BTreeMap::new();
261
262 for feature in layer.features() {
263 let mut geom = Vec::new();
264 feature
265 .geometry()
266 .expect("Geometries")
267 .get_points(&mut geom);
268 let field_index = feature.field_index(id_col_name).expect("Bad column name.");
269 let pid_filed = feature.field(field_index).unwrap().unwrap();
270 let pid = pid_filed.into_int64().unwrap();
271 geoms.insert(pid, geom);
272 }
273
274 let (_idx_global, id_indices, blocks_to_process) =
275 build_block_index_pipeline(self, &geoms);
276 drop(geoms);
277
278 let pool = create_rayon_pool(1);
279 let handle = pool.install(|| {
280 par_tqdm!(blocks_to_process
281 .into_par_iter())
282 .map(|id| -> (Vec<usize>, Vec<usize>, Vec<Vec<i16>>) {
283 let (pos, idx, pids) = collect_points_for_block(&id_indices, id);
284
285 let mut res = Vec::new();
286 let rect = make_rectangle(buffer_size);
287
288 let data = self.read_block(id);
289
290 let bands = data.shape()[1];
291 log::debug!("Bands {:?}", bands);
292 for band_n in 0..bands {
293 let mut res_band = Vec::new();
294 let band_data = data.slice(s![0_i32, band_n, .., ..]);
295 for point in pos.iter() {
296 let val = sample_value(&band_data, rect, *point, method);
297 res_band.push(val);
298 }
299 res.push(res_band);
300 }
301 (pids, idx, res)
302 })
303 });
304
305 let collected: Vec<_> = handle.collect();
306 assemble_block_results(&collected, |id| id as i16)
307 }
308
309 pub fn extract(
311 &self,
312 geometries: &[Geometry],
313 point_ids: &[i64],
314 method: SamplingMethod,
315 buffer_size: Option<usize>,
316 ) -> Result<(Array2<i16>, Vec<i64>)> {
317 let buffer_size = buffer_size.unwrap_or(0);
318 validate_buffer_size(buffer_size, self.metadata.overlap_size);
319
320 let mut geoms = BTreeMap::new();
321 for (idx, point_id) in point_ids.iter().enumerate() {
322 let geometry = &geometries[idx];
323 let point = geometry.get_point(0);
324 let (x, y, z) = point;
325 geoms.insert(*point_id, vec![(x, y, z)]);
326 }
327
328 let (_idx_global, id_indices, blocks_to_process) =
329 build_block_index_pipeline(self, &geoms);
330 drop(geoms);
331
332 let blocks_to_process: Vec<usize> = blocks_to_process;
333
334 let pool = create_rayon_pool(1);
335 let handle = pool.install(|| {
336 par_tqdm!(blocks_to_process
337 .into_par_iter())
338 .map(|id| -> (Vec<usize>, Vec<usize>, Vec<Vec<i16>>) {
339 let (pos, idx, pids) = collect_points_for_block(&id_indices, id);
340 log::debug!("Extracting {} points, from block: {}", pos.len(), id);
341
342 let mut res = Vec::new();
343 let rect = make_rectangle(buffer_size);
344
345 let data = self.read_block(id);
346 let n_times = data.shape()[0];
347 let n_layers = data.shape()[1];
348 for time in 0..n_times {
349 for layer in 0..n_layers {
350 let mut res_band = Vec::new();
351 let band_data = data.slice(s![time, layer, .., ..]);
352 for point in pos.iter() {
353 let col = point.col.checked_sub(self.blocks[id].overlap.left);
354 let row = point.row.checked_sub(self.blocks[id].overlap.top);
355 let col = col.unwrap_or(point.col);
356 let row = row.unwrap_or(point.row);
357 let val = sample_value(&band_data, rect, Index2d { col, row }, method);
358 res_band.push(val);
359 }
360 res.push(res_band);
361 }
362 }
363 (pids, idx, res)
364 })
365 });
366
367 let collected: Vec<_> = handle.collect();
368 let results = assemble_block_results(&collected, |id| id as i64);
369
370 let k = results.keys().next().unwrap();
371 let n_rows = results.len();
372 let n_cols = results[k].len();
373 let mut array: Array2<i16> = ndarray::Array::zeros((n_rows, n_cols));
374 for (row_index, values) in results.values().enumerate() {
375 for (col_index, value) in values.iter().enumerate() {
376 array[[row_index, col_index]] = *value;
377 }
378 }
379 let pids: Vec<i64> = results.into_keys().collect();
380 Ok((array, pids))
381 }
382}