1use crate::core_types::RasterType;
15use crate::types::{GeoTransform, ReadWindow};
16use anyhow::{Context, Result};
17use gdal::{
18 config, Dataset, DatasetOptions, DriverManager, GdalOpenFlags,
19 raster::{Buffer, RasterCreationOptions},
20 spatial_ref::SpatialRef,
21};
22use std::path::{Path, PathBuf};
23use std::sync::Mutex;
24
25pub struct ParallelGeoTiffWriter {
33 pub output_path: PathBuf,
35 pub geo_transform: GeoTransform,
37 pub epsg_code: u32,
39 pub total_cols: usize,
41 pub total_rows: usize,
42 pub n_bands: usize,
44 pub dataset: Mutex<Option<Dataset>>,
46}
47
48impl ParallelGeoTiffWriter {
49 pub fn new(
53 output_path: PathBuf,
54 geo_transform: GeoTransform,
55 epsg_code: u32,
56 total_cols: usize,
57 total_rows: usize,
58 n_bands: usize,
59 ) -> Self {
60 Self {
61 output_path,
62 geo_transform,
63 epsg_code,
64 total_cols,
65 total_rows,
66 n_bands,
67 dataset: Mutex::new(None),
68 }
69 }
70
71 pub fn build_overviews(&self, resampling: &str, levels: &[i32]) -> Result<()> {
81 let mut guard = self.dataset.lock().expect("Mutex poisoned");
82
83 if guard.is_none() {
85 let opts = DatasetOptions {
86 open_flags: GdalOpenFlags::GDAL_OF_UPDATE,
87 ..DatasetOptions::default()
88 };
89 let dataset = Dataset::open_ex(&self.output_path, opts)
90 .with_context(|| format!("Failed to open {:?} for overview building", self.output_path))?;
91 *guard = Some(dataset);
92 }
93
94 config::set_config_option("GDAL_NUM_THREADS", "ALL_CPUS")
96 .context("Failed to set GDAL_NUM_THREADS")?;
97
98 let dataset = guard.as_mut().unwrap();
99 dataset
100 .build_overviews(resampling, levels, &[])
101 .with_context(|| format!("Failed to build overviews on {:?}", self.output_path))?;
102
103 Ok(())
104 }
105}
106
107pub fn create_output_geotiff<T: RasterType>(
112 path: &Path,
113 geo_transform: &GeoTransform,
114 epsg_code: u32,
115 total_cols: usize,
116 total_rows: usize,
117 n_bands: usize,
118 na_value: T,
119) -> Result<()> {
120 if let Some(parent) = path.parent() {
121 std::fs::create_dir_all(parent)
122 .with_context(|| format!("Failed to create parent directory for {:?}", path))?;
123 }
124
125 let driver = DriverManager::get_driver_by_name("GTIFF")
126 .context("GTiff driver not available")?;
127
128 let options = RasterCreationOptions::from_iter([
129 "COMPRESS=LZW",
130 "TILED=YES",
131 "BLOCKXSIZE=512",
132 "BLOCKYSIZE=512",
133 "BIGTIFF=YES",
134 ]);
135
136 let mut dataset = driver
137 .create_with_band_type_with_options::<T, _>(
138 path,
139 total_cols,
140 total_rows,
141 n_bands,
142 &options,
143 )
144 .with_context(|| format!("Failed to create GeoTIFF at {:?}", path))?;
145
146 dataset
147 .set_geo_transform(&geo_transform.to_array())
148 .context("Failed to set geo transform")?;
149
150 let srs = SpatialRef::from_epsg(epsg_code)
151 .context(format!("Invalid EPSG code: {}", epsg_code))?;
152 dataset
153 .set_spatial_ref(&srs)
154 .context("Failed to set spatial reference")?;
155
156 for band_idx in 1..=n_bands {
157 let mut band = dataset
158 .rasterband(band_idx)
159 .context(format!("Failed to access band {}", band_idx))?;
160 if let Some(na_f64) = na_value.to_f64() {
161 band.set_no_data_value(Some(na_f64))
162 .context("Failed to set no-data value")?;
163 }
164 }
165
166 Ok(())
167}
168
169pub fn write_block<T: RasterType>(
179 writer: &ParallelGeoTiffWriter,
180 data: ndarray::ArrayView3<T>,
181 window: ReadWindow,
182) -> Result<()> {
183 let mut guard = writer.dataset.lock().expect("Mutex poisoned");
184
185 if guard.is_none() {
187 let opts = DatasetOptions {
188 open_flags: GdalOpenFlags::GDAL_OF_UPDATE,
189 ..DatasetOptions::default()
190 };
191 let dataset = Dataset::open_ex(&writer.output_path, opts)
192 .with_context(|| format!("Failed to open {:?} for update", writer.output_path))?;
193 *guard = Some(dataset);
194 }
195
196 let dataset = guard.as_mut().unwrap();
197 let block_rows = data.shape()[1];
198 let block_cols = data.shape()[2];
199
200 for band_idx in 0..data.shape()[0] {
201 let mut band = dataset
202 .rasterband(band_idx + 1)
203 .with_context(|| format!("Failed to access band {}", band_idx + 1))?;
204
205 let band_data = data.index_axis(ndarray::Axis(0), band_idx);
206 let data_vec: Vec<T> = band_data.into_iter().copied().collect();
207 let mut buffer = Buffer::new((block_cols, block_rows), data_vec);
208
209 band.write(
210 (window.offset.cols, window.offset.rows),
211 (block_cols, block_rows),
212 &mut buffer,
213 )
214 .with_context(|| {
215 format!(
216 "Failed to write band {} to window {:?}",
217 band_idx + 1,
218 window
219 )
220 })?;
221 }
222
223 Ok(())
224}