1use crate::core_types::RasterData;
7use crate::core_types::RasterType;
8use crate::metadata::RasterDataBlock;
9use crate::rasterdataset::RasterDataset;
10use async_tiff::decoder::DecoderRegistry;
11use async_tiff::metadata::cache::ReadaheadMetadataCache;
12use async_tiff::metadata::TiffMetadataReader;
13use async_tiff::reader::ObjectReader;
14use async_tiff::TIFF;
15use ndarray::{s, Array2, Array4};
16use object_store::aws::AmazonS3Builder;
17
18
19const S3_REGION: &str = "ap-southeast-2";
21
22fn build_s3_store(bucket: &str) -> std::sync::Arc<dyn object_store::ObjectStore> {
24 let store = AmazonS3Builder::new()
25 .with_bucket_name(bucket)
26 .with_region(S3_REGION)
27 .with_skip_signature(true)
28 .build()
29 .expect("failed to build S3 store");
30 std::sync::Arc::new(store)
31}
32
33struct CachedTiffReader {
37 tiff: TIFF,
38 reader: ObjectReader,
39 decoder: DecoderRegistry,
40}
41
42impl CachedTiffReader {
43 async fn open(s3_url: &str) -> anyhow::Result<Self> {
45 let url = url::Url::parse(s3_url)?;
46 let bucket = url.host_str().ok_or_else(|| anyhow::anyhow!("No bucket in URL"))?;
47 let path = url.path().trim_start_matches('/');
48
49 let store = build_s3_store(bucket);
50 let reader = ObjectReader::new(store, path.into());
51 let cache = ReadaheadMetadataCache::new(reader.clone());
52
53 let mut meta = TiffMetadataReader::try_open(&cache).await?;
54 let ifds = meta.read_all_ifds(&cache).await?;
55 let tiff = TIFF::new(ifds, meta.endianness());
56
57 Ok(Self {
58 tiff,
59 reader,
60 decoder: DecoderRegistry::default(),
61 })
62 }
63
64 async fn read_window<T: RasterType>(
66 &self,
67 band_index: usize,
68 offset: (isize, isize),
69 window_size: (usize, usize),
70 ) -> anyhow::Result<Array2<T>> {
71 let ifd = &self.tiff.ifds()[band_index - 1];
72 let tile_h = ifd.tile_height().expect("not tiled") as usize;
73 let tile_w = ifd.tile_width().expect("not tiled") as usize;
74
75 let (x_off, y_off) = (offset.0 as usize, offset.1 as usize);
76 let (width, height) = window_size;
77
78 let start_ty = y_off / tile_h;
80 let start_tx = x_off / tile_w;
81 let end_ty = (y_off + height - 1) / tile_h;
82 let end_tx = (x_off + width - 1) / tile_w;
83
84 let tile_coords: Vec<_> = (start_ty..=end_ty)
85 .flat_map(|ty| (start_tx..=end_tx).map(move |tx| (tx as usize, ty as usize)))
86 .collect();
87
88 let tiles = ifd.fetch_tiles(&tile_coords, &self.reader).await?;
90
91 let mut output = Array2::<T>::zeros((height, width));
93
94 for (tile_idx, tile) in tiles.into_iter().enumerate() {
96 let (tx, ty) = tile_coords[tile_idx];
97 let array = tile.decode(&self.decoder)?;
98 let (typed, shape, _dtype) = array.into_inner();
99
100 let t_h = shape[0];
101 let t_w = shape[1];
102
103 let tile_pixel_y = ty * tile_h;
104 let tile_pixel_x = tx * tile_w;
105
106 let tile_row_start = y_off.saturating_sub(tile_pixel_y);
107 let tile_col_start = x_off.saturating_sub(tile_pixel_x);
108
109 let out_row_start = tile_pixel_y.saturating_sub(y_off);
110 let out_col_start = tile_pixel_x.saturating_sub(x_off);
111
112 let copy_rows = (out_row_start + (t_h - tile_row_start)).min(height) - out_row_start;
113 let copy_cols = (out_col_start + (t_w - tile_col_start)).min(width) - out_col_start;
114
115 if let async_tiff::TypedArray::UInt16(data) = typed {
116 let tile_arr = Array2::from_shape_vec((t_h, t_w), data)?;
117 let tile_slice = tile_arr.slice(s![
118 tile_row_start..tile_row_start + copy_rows,
119 tile_col_start..tile_col_start + copy_cols
120 ]);
121 for i in 0..copy_rows {
122 for j in 0..copy_cols {
123 if let Some(val) = num_traits::NumCast::from(tile_slice[[i, j]]) {
124 output[[out_row_start + i, out_col_start + j]] = val;
125 }
126 }
127 }
128 }
129 }
130
131 Ok(output)
132 }
133}
134
135pub async fn read_raster_band_async<T: RasterType>(
146 s3_url: &str,
147 band_index: usize,
148 offset: (isize, isize),
149 window_size: (usize, usize),
150) -> anyhow::Result<Array2<T>> {
151 let reader = CachedTiffReader::open(s3_url).await?;
153 reader.read_window(band_index, offset, window_size).await
154}
155
156impl<R> RasterDataset<R>
161where
162 R: RasterType,
163{
164 pub async fn read_block_async<T: RasterType>(
168 &self,
169 block_id: usize,
170 ) -> RasterData<T> {
171 let s3_urls: Vec<String> = self.metadata.layers
173 .iter()
174 .map(|layer| vsi_to_s3_url(layer.source.to_str().unwrap_or_default()))
175 .collect();
176
177 self.read_block_async_with_urls::<T>(&s3_urls, block_id).await
178 }
179
180 pub async fn read_block_async_with_urls<T: RasterType>(
184 &self,
185 s3_urls: &[String],
186 block_id: usize,
187 ) -> RasterData<T> {
188 let block = &self.blocks[block_id];
189 let read_window = block.read_window;
190
191 let rows = read_window.size.rows as usize;
192 let cols = read_window.size.cols as usize;
193 let data_shape = (
194 self.metadata.shape.times,
195 self.metadata.shape.layers,
196 rows,
197 cols,
198 );
199
200 let mut data: RasterData<T> = RasterData::zeros(data_shape);
201
202 let mut readers: std::collections::HashMap<usize, CachedTiffReader> = std::collections::HashMap::new();
204
205 for (idx, layer) in self.metadata.layers.iter().enumerate() {
206 if !readers.contains_key(&idx) {
208 let reader = CachedTiffReader::open(&s3_urls[idx]).await;
209 if let Ok(r) = reader {
210 readers.insert(idx, r);
211 } else {
212 continue;
213 }
214 }
215
216 let reader = readers.get(&idx).unwrap();
217 let window = (read_window.offset.cols, read_window.offset.rows);
218 let window_size = (cols, rows);
219
220 let layer_data = reader.read_window::<T>(1, window, window_size).await;
221 let layer_data = layer_data.expect("async read failed");
223
224 let slice = s![
225 layer.time_pos,
226 layer.layer_pos,
227 ..,
228 ..
229 ];
230 data.slice_mut(slice).assign(&layer_data);
231 }
232 data
233 }
234
235 pub async fn apply_async<U>(
246 &self,
247 worker: fn(&RasterDataBlock<R>) -> anyhow::Result<Array4<U>>,
248 n_cpus: usize,
249 out_file: &std::path::Path,
250 ) -> anyhow::Result<()>
251 where
252 U: RasterType,
253 {
254 let s3_urls: Vec<String> = self.metadata.layers
256 .iter()
257 .map(|layer| vsi_to_s3_url(layer.source.to_str().unwrap_or_default()))
258 .collect();
259
260 self.apply_async_with_urls(&s3_urls, worker, n_cpus, out_file).await
261 }
262
263 pub async fn apply_async_with_urls<U>(
267 &self,
268 s3_urls: &[String],
269 worker: fn(&RasterDataBlock<R>) -> anyhow::Result<Array4<U>>,
270 n_cpus: usize,
271 out_file: &std::path::Path,
272 ) -> anyhow::Result<()>
273 where
274 U: RasterType,
275 {
276 use crate::gdal_utils::{create_temp_file, file_stem_str, mosaic_translate_cleanup_time_steps};
277 use num_traits::NumCast;
278
279 use std::path::PathBuf;
280
281 let tmp_file = PathBuf::from(create_temp_file("vrt"));
282 let n_times = self.metadata.shape.times;
283 let epsg_code = self.metadata.epsg_code;
284
285 let block_futures: Vec<_> = self
287 .blocks
288 .iter()
289 .map(|raster_block| {
290 let tmp_file_clone = tmp_file.clone();
291 let s3_urls_clone = s3_urls.to_vec();
292 async move {
293 let bid = raster_block.block_index;
294 let file_stem = file_stem_str(&tmp_file_clone);
295
296 let block_data: RasterData<R> = self.read_block_async_with_urls::<R>(&s3_urls_clone, bid).await;
298
299 let raster_data_block = RasterDataBlock {
301 data: block_data,
302 metadata: self.metadata.clone(),
303 no_data: NumCast::from(0).unwrap(),
304 };
305
306 let result = worker(&raster_data_block)?;
308
309 let block_fns = raster_block.write_time_step_blocks(
311 &result,
312 &tmp_file_clone,
313 file_stem,
314 bid,
315 );
316
317 anyhow::Result::Ok(block_fns)
318 }
319 })
320 .collect();
321
322 let collected: Vec<anyhow::Result<Vec<PathBuf>>> =
324 futures::future::join_all(block_futures).await;
325
326 let collected: Vec<Vec<PathBuf>> = collected
327 .into_iter()
328 .collect::<anyhow::Result<_>>()?;
329
330 let pool = crate::gdal_utils::create_rayon_pool(n_cpus);
332 pool.install(|| {
333 mosaic_translate_cleanup_time_steps(&collected, out_file, epsg_code, n_times);
334 });
335
336 Ok(())
337 }
338}
339
340fn vsi_to_s3_url(vsi_path: &str) -> String {
344 if let Some(stripped) = vsi_path.strip_prefix("/vsis3/") {
346 return format!("s3://{}", stripped);
347 }
348 if let Some(stripped) = vsi_path.strip_prefix("/vsicurl/") {
350 if let Ok(url) = url::Url::parse(stripped) {
353 if let Some(host) = url.host_str() {
354 if let Some(bucket) = host.split('.').next() {
356 let path = url.path().trim_start_matches('/');
357 return format!("s3://{}/{}", bucket, path);
358 }
359 }
360 }
361 return stripped.to_string();
362 }
363 if vsi_path.starts_with("s3://") {
365 return vsi_path.to_string();
366 }
367 vsi_path.to_string()
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads a 512x512 window of band 1 from the S3 object below and checks
    /// the shape of the returned array.
    #[tokio::test]
    async fn test_read_raster_band_async() {
        const SCENE_URL: &str = "s3://dea-public-data/baseline/ga_s2bm_ard_3/56/JNS/2021/01/15/20210116T010541/ga_s2bm_nbart_3-2-1_56JNS_2021-01-15_final_band04.tif";

        match read_raster_band_async::<u16>(SCENE_URL, 1, (0, 0), (512, 512)).await {
            Ok(data) => assert_eq!(data.shape(), &[512, 512]),
            Err(err) => panic!("Failed to read raster band async: {:?}", Some(err)),
        }
    }
}
385}