eorst/
selection.rs

1//! Selection and aggregation traits for raster data.
2//!
3//! This module provides traits for selecting layers and time slices from raster data,
4//! stacking shapes, and computing aggregations along dimensions.
5
6use crate::core_types::{RasterData, RasterType};
7use crate::data_sources::DateType;
8use crate::metadata::{RasterDataBlock, RasterMetadata};
9use crate::types::{Dimension, RasterDataShape};
10use anyhow::Result;
11use gdal::raster::GdalType;
12use ndarray::{Array2, Array3, Axis};
13use num_traits::{FromPrimitive, ToPrimitive};
14use std::fmt::{self, Debug};
15use std::ops::{Add, Div, Index, IndexMut};
16
17/// Trait for stacking RasterDataShapes.
18pub trait Stack {
19    /// Stacks another shape along a dimension.
20    fn stack(&mut self, other: RasterDataShape, dim_to_stack: Dimension) -> &mut RasterDataShape;
21    /// Extends the time dimension.
22    fn extend(&mut self, other: RasterDataShape) -> &mut RasterDataShape;
23}
24
25impl Stack for RasterDataShape {
26    fn extend(&mut self, other: RasterDataShape) -> &mut RasterDataShape {
27        let mut extendable = true;
28
29        for dim_loc in 1..4 {
30            if self[dim_loc] != other[dim_loc] {
31                extendable = false
32            }
33        }
34
35        if extendable {
36            self[0] += other[0];
37            self
38        } else {
39            panic!("Unable to extend layers");
40        }
41    }
42
43    fn stack(&mut self, other: RasterDataShape, dim_to_stack: Dimension) -> &mut RasterDataShape {
44        let dimension_axis = dim_to_stack.get_axis();
45        let mut stackable = true;
46        for dim_loc in 0..4 {
47            if dim_loc != dimension_axis && self[dim_loc] != other[dim_loc] {
48                stackable = false;
49            }
50        }
51
52        if stackable {
53            self[dimension_axis] += other[dimension_axis];
54            self
55        } else {
56            panic!("Unable to stack layers");
57        }
58    }
59}
60
61impl Dimension {
62    /// Returns the axis index for this dimension.
63    pub fn get_axis(&self) -> usize {
64        match self {
65            Dimension::Layer => 1,
66            Dimension::Time => 0,
67        }
68    }
69}
70
71impl Index<usize> for RasterDataShape {
72    type Output = usize;
73
74    fn index(&self, index: usize) -> &usize {
75        match index {
76            0 => &self.times,
77            1 => &self.layers,
78            2 => &self.rows,
79            3 => &self.cols,
80            n => panic!("Invalid index: {}", n),
81        }
82    }
83}
84
85impl IndexMut<usize> for RasterDataShape {
86    fn index_mut(&mut self, index: usize) -> &mut usize {
87        match index {
88            0 => &mut self.times,
89            1 => &mut self.layers,
90            2 => &mut self.rows,
91            3 => &mut self.cols,
92            n => panic!("Invalid index: {}", n),
93        }
94    }
95}
96
97/// Trait for summing values along a specific dimension.
98pub trait SumDimension<T>
99where
100    T: GdalType + num_traits::identities::Zero + Copy + FromPrimitive + Add<Output = T> + Div<Output = T>,
101{
102    /// Sums all values along the given dimension.
103    fn sum_dimension(&self, dimension: Dimension) -> Array3<T>;
104}
105
106impl<T> SumDimension<T> for RasterData<T>
107where
108    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
109{
110    fn sum_dimension(&self, dimension: Dimension) -> Array3<T> {
111        match dimension {
112            Dimension::Layer => self.sum_axis(Axis(1)),
113            Dimension::Time => self.sum_axis(Axis(0)),
114        }
115    }
116}
117
118#[allow(dead_code)]
119pub(crate) trait MeanDimension<T>
120where
121    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
122{
123    fn mean_dimension(&self, dimension: Dimension) -> Array3<T>;
124}
125
126impl<T> MeanDimension<T> for RasterData<T>
127where
128    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
129{
130    fn mean_dimension(&self, dimension: Dimension) -> Array3<T> {
131        let mean = match dimension {
132            Dimension::Layer => self.mean_axis(Axis(1)),
133            Dimension::Time => self.mean_axis(Axis(0)),
134        };
135        mean.unwrap()
136    }
137}
138
139/// Trait for computing variance along a specific dimension.
140pub trait VarDimension<T>
141where
142    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
143{
144    /// Computes variance along the given dimension.
145    /// `ddof` is the delta degrees of freedom (0 for population, 1 for sample).
146    fn var_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T>;
147}
148
149impl<T> VarDimension<T> for RasterData<T>
150where
151    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
152{
153    fn var_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T> {
154        match dimension {
155            Dimension::Layer => self.var_axis(Axis(1), ddof),
156            Dimension::Time => self.var_axis(Axis(0), ddof),
157        }
158    }
159}
160
161#[allow(dead_code)]
162pub(crate) trait StdDimension<T>
163where
164    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
165{
166    fn std_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T>;
167}
168
169impl<T> StdDimension<T> for RasterData<T>
170where
171    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
172{
173    fn std_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T> {
174        match dimension {
175            Dimension::Layer => self.std_axis(Axis(1), ddof),
176            Dimension::Time => self.std_axis(Axis(0), ddof),
177        }
178    }
179}
180
/// Error type for layer/time selection operations.
///
/// Returned by the `Select` trait methods in this module. Each variant
/// carries enough context to build the message produced by its `Display`
/// implementation.
#[derive(Debug)]
pub enum SelectError {
    /// Requested layer name was not found.
    LayerNotFound {
        /// The layer name that was requested.
        requested: String,
        /// Layer names that were available at the time of the lookup.
        available: Vec<String>,
    },
    /// Requested time index was not found.
    TimeNotFound {
        /// The date that was requested.
        requested: DateType,
        /// Dates that were available at the time of the lookup.
        available: Vec<DateType>,
    },
    /// An empty selection was requested (no layer names or dates supplied).
    EmptySelection,
    /// Array concatenation/stacking failed; the payload is the underlying
    /// error rendered as text.
    ConcatenationError(String),
}
203
204impl fmt::Display for SelectError {
205    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206        match self {
207            SelectError::LayerNotFound { requested, available } => {
208                write!(f, "Layer '{}' not found. Available: {:?}", requested, available)
209            }
210            SelectError::TimeNotFound { requested, available } => {
211                write!(f, "Time {:?} not found. Available: {:?}", requested, available)
212            }
213            SelectError::EmptySelection => {
214                write!(f, "Empty selection requested")
215            }
216            SelectError::ConcatenationError(msg) => {
217                write!(f, "Array concatenation failed: {}", msg)
218            }
219        }
220    }
221}
222
223impl std::error::Error for SelectError {}
224
225/// Trait for selecting layers and time slices from raster data by name.
226pub trait Select<T>
227where
228    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
229{
230    /// Select multiple layers by name.
231    fn select_layers(&self, layer_names: &[&str]) -> Result<RasterDataBlock<T>, SelectError>;
232
233    /// Select multiple time slices by date index.
234    fn select_times(&self, dates: &[DateType]) -> Result<RasterDataBlock<T>, SelectError>;
235
236    /// Find the index of a layer by name.
237    fn find_layer_index(&self, name: &str) -> Result<usize, SelectError>;
238
239    /// Find the index of a time slice by date type.
240    fn find_time_index(&self, date: &DateType) -> Result<usize, SelectError>;
241}
242
243impl<T> Select<T> for RasterDataBlock<T>
244where
245    T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
246{
247    fn find_layer_index(&self, name: &str) -> Result<usize, SelectError> {
248        self.metadata
249            .layer_indices
250            .iter()
251            .position(|s| s.as_str() == name)
252            .ok_or(SelectError::LayerNotFound {
253                requested: name.to_string(),
254                available: self.metadata.layer_indices.clone(),
255            })
256    }
257
258    fn find_time_index(&self, date: &DateType) -> Result<usize, SelectError> {
259        self.metadata
260            .date_indices
261            .iter()
262            .position(|d| d == date)
263            .ok_or(SelectError::TimeNotFound {
264                requested: date.clone(),
265                available: self.metadata.date_indices.clone(),
266            })
267    }
268
269    fn select_layers(&self, layer_names: &[&str]) -> Result<RasterDataBlock<T>, SelectError> {
270        if layer_names.is_empty() {
271            return Err(SelectError::EmptySelection);
272        }
273
274        let indices: Vec<usize> = layer_names
275            .iter()
276            .map(|name| self.find_layer_index(name))
277            .collect::<Result<_, _>>()?;
278
279        let views: Vec<_> = indices
280            .iter()
281            .map(|&idx| self.data.index_axis(Axis(1), idx))
282            .collect();
283
284        let data = ndarray::stack(Axis(1), &views)
285            .map_err(|e| SelectError::ConcatenationError(e.to_string()))?;
286
287        let new_metadata = RasterMetadata {
288            layer_indices: layer_names.iter().map(|s| s.to_string()).collect(),
289            shape: RasterDataShape {
290                layers: layer_names.len(),
291                ..self.metadata.shape
292            },
293            ..self.metadata.clone()
294        };
295
296        Ok(RasterDataBlock {
297            data,
298            metadata: new_metadata,
299            no_data: self.no_data,
300        })
301    }
302
303    fn select_times(&self, dates: &[DateType]) -> Result<RasterDataBlock<T>, SelectError> {
304        if dates.is_empty() {
305            return Err(SelectError::EmptySelection);
306        }
307
308        let indices: Vec<usize> = dates
309            .iter()
310            .map(|d| self.find_time_index(d))
311            .collect::<Result<_, _>>()?;
312
313        let views: Vec<_> = indices
314            .iter()
315            .map(|&idx| self.data.index_axis(Axis(0), idx))
316            .collect();
317
318        let data = ndarray::stack(Axis(0), &views)
319            .map_err(|e| SelectError::ConcatenationError(e.to_string()))?;
320
321        let new_metadata = RasterMetadata {
322            date_indices: dates.to_vec(),
323            shape: RasterDataShape {
324                times: dates.len(),
325                ..self.metadata.shape
326            },
327            ..self.metadata.clone()
328        };
329
330        Ok(RasterDataBlock {
331            data,
332            metadata: new_metadata,
333            no_data: self.no_data,
334        })
335    }
336}
337
338/// Convenience methods for `RasterDataBlock`.
339impl<T> RasterDataBlock<T>
340where
341    T: RasterType,
342{
343    /// Returns the available layer names for this block.
344    pub fn available_layer_names(&self) -> &[String] {
345        &self.metadata.layer_indices
346    }
347
348    /// Returns the available time indices for this block.
349    pub fn available_time_indices(&self) -> &[DateType] {
350        &self.metadata.date_indices
351    }
352}
353
354/// RasterBlock trait for block-level operations.
355pub trait RasterBlockTrait<U>
356where
357    U: RasterType,
358{
359    /// Converts feature-request-count data to full resolution count.
360    fn into_frc(&self, data: &Array2<U>) -> Array3<U>;
361
362    /// Writes samples for a feature.
363    fn write_samples_feature<T>(&self, data: &Array2<T>, file_name: &std::path::PathBuf, na: T)
364    where
365        T: RasterType + ToPrimitive;
366
367    /// Writes 3D data.
368    fn write3<T>(&self, data: Array3<T>, out_fn: &std::path::PathBuf)
369    where
370        T: RasterType + ToPrimitive;
371}
372
#[cfg(test)]
mod tests {
    use crate::data_sources::DateType;
    use crate::metadata::{RasterDataBlock, RasterMetadata};
    use crate::types::RasterDataShape;
    use ndarray::Array4;
    use num_traits::NumCast;

    use super::{Select, SelectError};

    /// Value stored in the fixture at time `t`, layer `l` (identical for every pixel).
    ///
    /// Encoding both indices in the value lets the tests verify *which* slices
    /// were selected and in what order, not just the resulting shape.
    fn cell_value(t: usize, l: usize) -> f32 {
        (t * 10 + l) as f32
    }

    /// Builds a block of shape (3 times, 4 layers, 2 rows, 2 cols) whose values
    /// are given by `cell_value`.
    fn make_test_block() -> RasterDataBlock<f32> {
        let data = Array4::from_shape_fn((3, 4, 2, 2), |(t, l, _, _)| cell_value(t, l));
        let metadata = RasterMetadata {
            layer_indices: vec!["red".into(), "green".into(), "nir".into(), "swir".into()],
            date_indices: vec![
                DateType::Index(0),
                DateType::Index(1),
                DateType::Index(2),
            ],
            shape: RasterDataShape {
                times: 3,
                layers: 4,
                rows: 2,
                cols: 2,
            },
            ..RasterMetadata::new()
        };
        RasterDataBlock {
            data,
            metadata,
            no_data: NumCast::from(0.0f32).unwrap(),
        }
    }

    #[test]
    fn test_select_layers_basic() {
        let block = make_test_block();
        let result = block.select_layers(&["red", "nir"]).unwrap();
        assert_eq!(result.metadata.shape.layers, 2);
        assert_eq!(result.data.shape(), &[3, 2, 2, 2]);
        assert_eq!(result.metadata.layer_indices, vec!["red", "nir"]);
        assert_eq!(result.metadata.shape.times, 3);
        assert_eq!(result.metadata.shape.rows, 2);
        assert_eq!(result.metadata.shape.cols, 2);
        // "red" is source layer 0, "nir" is source layer 2.
        assert_eq!(result.data[[0, 0, 0, 0]], cell_value(0, 0));
        assert_eq!(result.data[[2, 1, 1, 1]], cell_value(2, 2));
    }

    #[test]
    fn test_select_layers_single() {
        let block = make_test_block();
        let result = block.select_layers(&["nir"]).unwrap();
        assert_eq!(result.metadata.shape.layers, 1);
        assert_eq!(result.data.shape(), &[3, 1, 2, 2]);
        assert_eq!(result.metadata.layer_indices, vec!["nir"]);
        assert_eq!(result.data[[1, 0, 0, 1]], cell_value(1, 2));
    }

    #[test]
    fn test_select_layers_all() {
        let block = make_test_block();
        let result = block
            .select_layers(&["red", "green", "nir", "swir"])
            .unwrap();
        assert_eq!(result.metadata.shape.layers, 4);
        assert_eq!(result.data.shape(), &[3, 4, 2, 2]);
        // Selecting everything in the original order reproduces the data.
        assert_eq!(result.data, block.data);
    }

    #[test]
    fn test_select_layers_not_found() {
        let block = make_test_block();
        let err = block.select_layers(&["red", "blue"]).unwrap_err();
        // Exhaustive match: an unexpected variant fails loudly instead of
        // silently skipping the field assertions.
        match err {
            SelectError::LayerNotFound { requested, available } => {
                assert_eq!(requested, "blue");
                assert_eq!(available.len(), 4);
            }
            other => panic!("expected LayerNotFound, got {:?}", other),
        }
    }

    #[test]
    fn test_select_layers_empty() {
        let block = make_test_block();
        let err = block.select_layers(&[]).unwrap_err();
        assert!(matches!(err, SelectError::EmptySelection));
    }

    #[test]
    fn test_select_times_basic() {
        let block = make_test_block();
        let dates = vec![DateType::Index(0), DateType::Index(2)];
        let result = block.select_times(&dates).unwrap();
        assert_eq!(result.metadata.shape.times, 2);
        assert_eq!(result.data.shape(), &[2, 4, 2, 2]);
        assert_eq!(result.metadata.shape.layers, 4);
        assert_eq!(result.metadata.date_indices, dates);
        // Output time 0 is source time 0, output time 1 is source time 2.
        assert_eq!(result.data[[0, 1, 0, 0]], cell_value(0, 1));
        assert_eq!(result.data[[1, 3, 1, 0]], cell_value(2, 3));
    }

    #[test]
    fn test_select_times_single() {
        let block = make_test_block();
        let result = block.select_times(&[DateType::Index(1)]).unwrap();
        assert_eq!(result.metadata.shape.times, 1);
        assert_eq!(result.data.shape(), &[1, 4, 2, 2]);
        assert_eq!(result.data[[0, 2, 0, 0]], cell_value(1, 2));
    }

    #[test]
    fn test_select_times_not_found() {
        let block = make_test_block();
        let err = block
            .select_times(&[DateType::Index(99)])
            .unwrap_err();
        match err {
            SelectError::TimeNotFound { requested, available } => {
                assert_eq!(requested, DateType::Index(99));
                assert_eq!(available.len(), 3);
            }
            other => panic!("expected TimeNotFound, got {:?}", other),
        }
    }

    #[test]
    fn test_select_times_empty() {
        let block = make_test_block();
        let err = block.select_times(&[]).unwrap_err();
        assert!(matches!(err, SelectError::EmptySelection));
    }

    #[test]
    fn test_select_chaining_layers_then_times() {
        let block = make_test_block();
        let result = block
            .select_layers(&["red", "nir"])
            .unwrap()
            .select_times(&[DateType::Index(0)])
            .unwrap();
        assert_eq!(result.data.shape(), &[1, 2, 2, 2]);
        assert_eq!(result.metadata.layer_indices, vec!["red", "nir"]);
        assert_eq!(result.metadata.date_indices.len(), 1);
        assert_eq!(result.data[[0, 0, 0, 0]], cell_value(0, 0));
        assert_eq!(result.data[[0, 1, 0, 0]], cell_value(0, 2));
    }

    #[test]
    fn test_select_chaining_times_then_layers() {
        let block = make_test_block();
        let result = block
            .select_times(&[DateType::Index(1), DateType::Index(2)])
            .unwrap()
            .select_layers(&["swir"])
            .unwrap();
        assert_eq!(result.data.shape(), &[2, 1, 2, 2]);
        assert_eq!(result.metadata.date_indices.len(), 2);
        assert_eq!(result.metadata.layer_indices, vec!["swir"]);
        assert_eq!(result.data[[0, 0, 1, 1]], cell_value(1, 3));
        assert_eq!(result.data[[1, 0, 1, 1]], cell_value(2, 3));
    }

    #[test]
    fn test_available_layer_names() {
        let block = make_test_block();
        assert_eq!(
            block.available_layer_names(),
            &["red", "green", "nir", "swir"]
        );
    }

    #[test]
    fn test_available_time_indices() {
        let block = make_test_block();
        let times = block.available_time_indices();
        assert_eq!(times.len(), 3);
        assert_eq!(times[0], DateType::Index(0));
        assert_eq!(times[1], DateType::Index(1));
        assert_eq!(times[2], DateType::Index(2));
    }

    #[test]
    fn test_select_preserves_no_data() {
        let block = make_test_block();
        let result = block.select_layers(&["red"]).unwrap();
        assert_eq!(result.no_data, block.no_data);
    }

    #[test]
    fn test_select_error_display() {
        let block = make_test_block();
        let err = block.select_layers(&["missing"]).unwrap_err();
        let msg = format!("{}", err);
        assert!(msg.contains("missing"));
        assert!(msg.contains("not found"));
    }

    #[test]
    fn test_select_layers_order_preserved() {
        let block = make_test_block();
        let result = block.select_layers(&["swir", "red"]).unwrap();
        assert_eq!(result.metadata.layer_indices, vec!["swir", "red"]);
        // The data must follow the requested order, not the source order.
        assert_eq!(result.data[[1, 0, 0, 0]], cell_value(1, 3));
        assert_eq!(result.data[[1, 1, 0, 0]], cell_value(1, 0));
    }
}