1use crate::core_types::{RasterData, RasterType};
7use crate::data_sources::DateType;
8use crate::metadata::{RasterDataBlock, RasterMetadata};
9use crate::types::{Dimension, RasterDataShape};
10use anyhow::Result;
11use gdal::raster::GdalType;
12use ndarray::{Array2, Array3, Axis};
13use num_traits::{FromPrimitive, ToPrimitive};
14use std::fmt::{self, Debug};
15use std::ops::{Add, Div, Index, IndexMut};
16
17pub trait Stack {
19 fn stack(&mut self, other: RasterDataShape, dim_to_stack: Dimension) -> &mut RasterDataShape;
21 fn extend(&mut self, other: RasterDataShape) -> &mut RasterDataShape;
23}
24
25impl Stack for RasterDataShape {
26 fn extend(&mut self, other: RasterDataShape) -> &mut RasterDataShape {
27 let mut extendable = true;
28
29 for dim_loc in 1..4 {
30 if self[dim_loc] != other[dim_loc] {
31 extendable = false
32 }
33 }
34
35 if extendable {
36 self[0] += other[0];
37 self
38 } else {
39 panic!("Unable to extend layers");
40 }
41 }
42
43 fn stack(&mut self, other: RasterDataShape, dim_to_stack: Dimension) -> &mut RasterDataShape {
44 let dimension_axis = dim_to_stack.get_axis();
45 let mut stackable = true;
46 for dim_loc in 0..4 {
47 if dim_loc != dimension_axis && self[dim_loc] != other[dim_loc] {
48 stackable = false;
49 }
50 }
51
52 if stackable {
53 self[dimension_axis] += other[dimension_axis];
54 self
55 } else {
56 panic!("Unable to stack layers");
57 }
58 }
59}
60
61impl Dimension {
62 pub fn get_axis(&self) -> usize {
64 match self {
65 Dimension::Layer => 1,
66 Dimension::Time => 0,
67 }
68 }
69}
70
71impl Index<usize> for RasterDataShape {
72 type Output = usize;
73
74 fn index(&self, index: usize) -> &usize {
75 match index {
76 0 => &self.times,
77 1 => &self.layers,
78 2 => &self.rows,
79 3 => &self.cols,
80 n => panic!("Invalid index: {}", n),
81 }
82 }
83}
84
85impl IndexMut<usize> for RasterDataShape {
86 fn index_mut(&mut self, index: usize) -> &mut usize {
87 match index {
88 0 => &mut self.times,
89 1 => &mut self.layers,
90 2 => &mut self.rows,
91 3 => &mut self.cols,
92 n => panic!("Invalid index: {}", n),
93 }
94 }
95}
96
97pub trait SumDimension<T>
99where
100 T: GdalType + num_traits::identities::Zero + Copy + FromPrimitive + Add<Output = T> + Div<Output = T>,
101{
102 fn sum_dimension(&self, dimension: Dimension) -> Array3<T>;
104}
105
106impl<T> SumDimension<T> for RasterData<T>
107where
108 T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
109{
110 fn sum_dimension(&self, dimension: Dimension) -> Array3<T> {
111 match dimension {
112 Dimension::Layer => self.sum_axis(Axis(1)),
113 Dimension::Time => self.sum_axis(Axis(0)),
114 }
115 }
116}
117
118#[allow(dead_code)]
119pub(crate) trait MeanDimension<T>
120where
121 T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
122{
123 fn mean_dimension(&self, dimension: Dimension) -> Array3<T>;
124}
125
126impl<T> MeanDimension<T> for RasterData<T>
127where
128 T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
129{
130 fn mean_dimension(&self, dimension: Dimension) -> Array3<T> {
131 let mean = match dimension {
132 Dimension::Layer => self.mean_axis(Axis(1)),
133 Dimension::Time => self.mean_axis(Axis(0)),
134 };
135 mean.unwrap()
136 }
137}
138
139pub trait VarDimension<T>
141where
142 T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
143{
144 fn var_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T>;
147}
148
149impl<T> VarDimension<T> for RasterData<T>
150where
151 T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
152{
153 fn var_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T> {
154 match dimension {
155 Dimension::Layer => self.var_axis(Axis(1), ddof),
156 Dimension::Time => self.var_axis(Axis(0), ddof),
157 }
158 }
159}
160
161#[allow(dead_code)]
162pub(crate) trait StdDimension<T>
163where
164 T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
165{
166 fn std_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T>;
167}
168
169impl<T> StdDimension<T> for RasterData<T>
170where
171 T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T> + num_traits::Float,
172{
173 fn std_dimension(&self, ddof: T, dimension: Dimension) -> Array3<T> {
174 match dimension {
175 Dimension::Layer => self.std_axis(Axis(1), ddof),
176 Dimension::Time => self.std_axis(Axis(0), ddof),
177 }
178 }
179}
180
181#[derive(Debug)]
183pub enum SelectError {
184 LayerNotFound {
186 requested: String,
188 available: Vec<String>,
190 },
191 TimeNotFound {
193 requested: DateType,
195 available: Vec<DateType>,
197 },
198 EmptySelection,
200 ConcatenationError(String),
202}
203
204impl fmt::Display for SelectError {
205 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 match self {
207 SelectError::LayerNotFound { requested, available } => {
208 write!(f, "Layer '{}' not found. Available: {:?}", requested, available)
209 }
210 SelectError::TimeNotFound { requested, available } => {
211 write!(f, "Time {:?} not found. Available: {:?}", requested, available)
212 }
213 SelectError::EmptySelection => {
214 write!(f, "Empty selection requested")
215 }
216 SelectError::ConcatenationError(msg) => {
217 write!(f, "Array concatenation failed: {}", msg)
218 }
219 }
220 }
221}
222
223impl std::error::Error for SelectError {}
224
225pub trait Select<T>
227where
228 T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
229{
230 fn select_layers(&self, layer_names: &[&str]) -> Result<RasterDataBlock<T>, SelectError>;
232
233 fn select_times(&self, dates: &[DateType]) -> Result<RasterDataBlock<T>, SelectError>;
235
236 fn find_layer_index(&self, name: &str) -> Result<usize, SelectError>;
238
239 fn find_time_index(&self, date: &DateType) -> Result<usize, SelectError>;
241}
242
243impl<T> Select<T> for RasterDataBlock<T>
244where
245 T: RasterType + FromPrimitive + Add<Output = T> + Div<Output = T>,
246{
247 fn find_layer_index(&self, name: &str) -> Result<usize, SelectError> {
248 self.metadata
249 .layer_indices
250 .iter()
251 .position(|s| s.as_str() == name)
252 .ok_or(SelectError::LayerNotFound {
253 requested: name.to_string(),
254 available: self.metadata.layer_indices.clone(),
255 })
256 }
257
258 fn find_time_index(&self, date: &DateType) -> Result<usize, SelectError> {
259 self.metadata
260 .date_indices
261 .iter()
262 .position(|d| d == date)
263 .ok_or(SelectError::TimeNotFound {
264 requested: date.clone(),
265 available: self.metadata.date_indices.clone(),
266 })
267 }
268
269 fn select_layers(&self, layer_names: &[&str]) -> Result<RasterDataBlock<T>, SelectError> {
270 if layer_names.is_empty() {
271 return Err(SelectError::EmptySelection);
272 }
273
274 let indices: Vec<usize> = layer_names
275 .iter()
276 .map(|name| self.find_layer_index(name))
277 .collect::<Result<_, _>>()?;
278
279 let views: Vec<_> = indices
280 .iter()
281 .map(|&idx| self.data.index_axis(Axis(1), idx))
282 .collect();
283
284 let data = ndarray::stack(Axis(1), &views)
285 .map_err(|e| SelectError::ConcatenationError(e.to_string()))?;
286
287 let new_metadata = RasterMetadata {
288 layer_indices: layer_names.iter().map(|s| s.to_string()).collect(),
289 shape: RasterDataShape {
290 layers: layer_names.len(),
291 ..self.metadata.shape
292 },
293 ..self.metadata.clone()
294 };
295
296 Ok(RasterDataBlock {
297 data,
298 metadata: new_metadata,
299 no_data: self.no_data,
300 })
301 }
302
303 fn select_times(&self, dates: &[DateType]) -> Result<RasterDataBlock<T>, SelectError> {
304 if dates.is_empty() {
305 return Err(SelectError::EmptySelection);
306 }
307
308 let indices: Vec<usize> = dates
309 .iter()
310 .map(|d| self.find_time_index(d))
311 .collect::<Result<_, _>>()?;
312
313 let views: Vec<_> = indices
314 .iter()
315 .map(|&idx| self.data.index_axis(Axis(0), idx))
316 .collect();
317
318 let data = ndarray::stack(Axis(0), &views)
319 .map_err(|e| SelectError::ConcatenationError(e.to_string()))?;
320
321 let new_metadata = RasterMetadata {
322 date_indices: dates.to_vec(),
323 shape: RasterDataShape {
324 times: dates.len(),
325 ..self.metadata.shape
326 },
327 ..self.metadata.clone()
328 };
329
330 Ok(RasterDataBlock {
331 data,
332 metadata: new_metadata,
333 no_data: self.no_data,
334 })
335 }
336}
337
338impl<T> RasterDataBlock<T>
340where
341 T: RasterType,
342{
343 pub fn available_layer_names(&self) -> &[String] {
345 &self.metadata.layer_indices
346 }
347
348 pub fn available_time_indices(&self) -> &[DateType] {
350 &self.metadata.date_indices
351 }
352}
353
354pub trait RasterBlockTrait<U>
356where
357 U: RasterType,
358{
359 fn into_frc(&self, data: &Array2<U>) -> Array3<U>;
361
362 fn write_samples_feature<T>(&self, data: &Array2<T>, file_name: &std::path::PathBuf, na: T)
364 where
365 T: RasterType + ToPrimitive;
366
367 fn write3<T>(&self, data: Array3<T>, out_fn: &std::path::PathBuf)
369 where
370 T: RasterType + ToPrimitive;
371}
372
#[cfg(test)]
mod tests {
    use crate::data_sources::DateType;
    use crate::metadata::{RasterDataBlock, RasterMetadata};
    use crate::types::RasterDataShape;
    use ndarray::Array4;
    use num_traits::NumCast;

    use super::{Select, SelectError};

    /// Fixture value at `(time, layer, row, col)`. Every cell is distinct, so tests can tell
    /// which source slice landed where. (The old all-zeros fixture could not.)
    fn cell_value(t: usize, l: usize, r: usize, c: usize) -> f32 {
        // All values are < 4000, so the conversion to f32 is exact.
        (t * 1000 + l * 100 + r * 10 + c) as f32
    }

    /// 3 dates x 4 named layers x 2 rows x 2 cols, filled with `cell_value`.
    fn make_test_block() -> RasterDataBlock<f32> {
        let data = Array4::<f32>::from_shape_fn((3, 4, 2, 2), |(t, l, r, c)| {
            cell_value(t, l, r, c)
        });

        let metadata = RasterMetadata {
            layer_indices: vec!["red".into(), "green".into(), "nir".into(), "swir".into()],
            date_indices: vec![
                DateType::Index(0),
                DateType::Index(1),
                DateType::Index(2),
            ],
            shape: RasterDataShape {
                times: 3,
                layers: 4,
                rows: 2,
                cols: 2,
            },
            ..RasterMetadata::new()
        };
        RasterDataBlock {
            data,
            metadata,
            no_data: NumCast::from(0.0f32).unwrap(),
        }
    }

    #[test]
    fn test_select_layers_basic() {
        let block = make_test_block();
        let result = block.select_layers(&["red", "nir"]).unwrap();
        assert_eq!(result.metadata.shape.layers, 2);
        assert_eq!(result.data.shape(), &[3, 2, 2, 2]);
        assert_eq!(result.metadata.layer_indices, vec!["red", "nir"]);
        assert_eq!(result.metadata.shape.times, 3);
        assert_eq!(result.metadata.shape.rows, 2);
        assert_eq!(result.metadata.shape.cols, 2);
    }

    #[test]
    fn test_select_layers_single() {
        let block = make_test_block();
        let result = block.select_layers(&["nir"]).unwrap();
        assert_eq!(result.metadata.shape.layers, 1);
        assert_eq!(result.data.shape(), &[3, 1, 2, 2]);
        assert_eq!(result.metadata.layer_indices, vec!["nir"]);
    }

    #[test]
    fn test_select_layers_all() {
        let block = make_test_block();
        let result = block
            .select_layers(&["red", "green", "nir", "swir"])
            .unwrap();
        assert_eq!(result.metadata.shape.layers, 4);
        assert_eq!(result.data.shape(), &[3, 4, 2, 2]);
        assert_eq!(result.data, block.data);
    }

    #[test]
    fn test_select_layers_not_found() {
        let block = make_test_block();
        let err = block.select_layers(&["red", "blue"]).unwrap_err();
        assert!(matches!(err, SelectError::LayerNotFound { .. }));
        if let SelectError::LayerNotFound { requested, available } = err {
            assert_eq!(requested, "blue");
            assert_eq!(available.len(), 4);
        }
    }

    #[test]
    fn test_select_layers_empty() {
        let block = make_test_block();
        let err = block.select_layers(&[]).unwrap_err();
        assert!(matches!(err, SelectError::EmptySelection));
    }

    #[test]
    fn test_select_layers_copies_requested_data() {
        let block = make_test_block();
        let result = block.select_layers(&["swir", "red"]).unwrap();
        // Output layer 0 comes from source layer 3 ("swir"); output layer 1 from layer 0.
        let source_layers = [3, 0];
        for ((t, l, r, c), &value) in result.data.indexed_iter() {
            assert_eq!(value, cell_value(t, source_layers[l], r, c));
        }
    }

    #[test]
    fn test_select_layers_duplicates_repeat_data() {
        let block = make_test_block();
        let result = block.select_layers(&["nir", "nir"]).unwrap();
        assert_eq!(result.data.shape(), &[3, 2, 2, 2]);
        assert_eq!(result.metadata.layer_indices, vec!["nir", "nir"]);
        for ((t, _, r, c), &value) in result.data.indexed_iter() {
            assert_eq!(value, cell_value(t, 2, r, c));
        }
    }

    #[test]
    fn test_select_times_basic() {
        let block = make_test_block();
        let dates = vec![DateType::Index(0), DateType::Index(2)];
        let result = block.select_times(&dates).unwrap();
        assert_eq!(result.metadata.shape.times, 2);
        assert_eq!(result.data.shape(), &[2, 4, 2, 2]);
        assert_eq!(result.metadata.shape.layers, 4);
    }

    #[test]
    fn test_select_times_single() {
        let block = make_test_block();
        let result = block.select_times(&[DateType::Index(1)]).unwrap();
        assert_eq!(result.metadata.shape.times, 1);
        assert_eq!(result.data.shape(), &[1, 4, 2, 2]);
    }

    #[test]
    fn test_select_times_not_found() {
        let block = make_test_block();
        let err = block
            .select_times(&[DateType::Index(99)])
            .unwrap_err();
        assert!(matches!(err, SelectError::TimeNotFound { .. }));
    }

    #[test]
    fn test_select_times_empty() {
        let block = make_test_block();
        let err = block.select_times(&[]).unwrap_err();
        assert!(matches!(err, SelectError::EmptySelection));
    }

    #[test]
    fn test_select_times_copies_requested_data() {
        let block = make_test_block();
        let result = block
            .select_times(&[DateType::Index(2), DateType::Index(0)])
            .unwrap();
        assert_eq!(
            result.metadata.date_indices,
            vec![DateType::Index(2), DateType::Index(0)]
        );
        // Output time 0 comes from source time 2; output time 1 from source time 0.
        let source_times = [2, 0];
        for ((t, l, r, c), &value) in result.data.indexed_iter() {
            assert_eq!(value, cell_value(source_times[t], l, r, c));
        }
    }

    #[test]
    fn test_select_chaining_layers_then_times() {
        let block = make_test_block();
        let result = block
            .select_layers(&["red", "nir"])
            .unwrap()
            .select_times(&[DateType::Index(0)])
            .unwrap();
        assert_eq!(result.data.shape(), &[1, 2, 2, 2]);
        assert_eq!(result.metadata.layer_indices, vec!["red", "nir"]);
        assert_eq!(result.metadata.date_indices.len(), 1);
    }

    #[test]
    fn test_select_chaining_times_then_layers() {
        let block = make_test_block();
        let result = block
            .select_times(&[DateType::Index(1), DateType::Index(2)])
            .unwrap()
            .select_layers(&["swir"])
            .unwrap();
        assert_eq!(result.data.shape(), &[2, 1, 2, 2]);
        assert_eq!(result.metadata.date_indices.len(), 2);
        assert_eq!(result.metadata.layer_indices, vec!["swir"]);
    }

    #[test]
    fn test_select_chaining_copies_requested_data() {
        let block = make_test_block();
        let result = block
            .select_layers(&["green"])
            .unwrap()
            .select_times(&[DateType::Index(1)])
            .unwrap();
        assert_eq!(result.data.shape(), &[1, 1, 2, 2]);
        for ((_, _, r, c), &value) in result.data.indexed_iter() {
            assert_eq!(value, cell_value(1, 1, r, c));
        }
    }

    #[test]
    fn test_available_layer_names() {
        let block = make_test_block();
        assert_eq!(
            block.available_layer_names(),
            &["red", "green", "nir", "swir"]
        );
    }

    #[test]
    fn test_available_time_indices() {
        let block = make_test_block();
        let times = block.available_time_indices();
        assert_eq!(times.len(), 3);
        assert_eq!(times[0], DateType::Index(0));
        assert_eq!(times[1], DateType::Index(1));
        assert_eq!(times[2], DateType::Index(2));
    }

    #[test]
    fn test_select_preserves_no_data() {
        let block = make_test_block();
        let result = block.select_layers(&["red"]).unwrap();
        assert_eq!(result.no_data, block.no_data);
    }

    #[test]
    fn test_select_error_display() {
        let block = make_test_block();
        let err = block.select_layers(&["missing"]).unwrap_err();
        let msg = format!("{}", err);
        assert!(msg.contains("missing"));
        assert!(msg.contains("not found"));
    }

    #[test]
    fn test_select_layers_order_preserved() {
        let block = make_test_block();
        let result = block.select_layers(&["swir", "red"]).unwrap();
        assert_eq!(result.metadata.layer_indices, vec!["swir", "red"]);
    }
}
558}