eorst/
parallel_writer.rs

//! Parallel GeoTIFF writer inspired by Dask/rioxarray's RasterioWriter.
//!
//! Pre-creates an output GeoTIFF and lets multiple threads write their
//! blocks directly to windows in the file; writes are serialized by a Mutex.
//!
//! Unlike the rioxarray approach, which opens and closes the file per chunk,
//! this writer keeps the GDAL dataset open across writes (protected by a
//! Mutex), avoiding repeated header-parsing overhead.
//!
//! This replaces the mosaic phase (gdalwarp → gdalbuildvrt → gdal_translate
//! subprocess chain) with direct windowed writes, eliminating subprocess
//! spawning and intermediate files.
13
14use crate::core_types::RasterType;
15use crate::types::{GeoTransform, ReadWindow};
16use anyhow::{Context, Result};
17use gdal::{
18    config, Dataset, DatasetOptions, DriverManager, GdalOpenFlags,
19    raster::{Buffer, RasterCreationOptions},
20    spatial_ref::SpatialRef,
21};
22use std::path::{Path, PathBuf};
23use std::sync::Mutex;
24
25/// Writer that allows parallel block writes to a pre-created GeoTIFF.
26///
27/// Thread-safe via internal Mutex. The GDAL dataset is opened once and
28/// kept alive for the writer's lifetime. Each `write_block()` call
29/// acquires the Mutex, writes to the appropriate window, and returns.
30///
31/// The dataset is automatically closed when the writer is dropped.
32pub struct ParallelGeoTiffWriter {
33    /// Path to the output GeoTIFF file
34    pub output_path: PathBuf,
35    /// Geographic transformation parameters
36    pub geo_transform: GeoTransform,
37    /// EPSG coordinate reference system code
38    pub epsg_code: u32,
39    /// Total image size in pixels
40    pub total_cols: usize,
41    pub total_rows: usize,
42    /// Number of bands in the output
43    pub n_bands: usize,
44    /// Mutex-guarded cached GDAL dataset (opened on first write)
45    pub dataset: Mutex<Option<Dataset>>,
46}
47
48impl ParallelGeoTiffWriter {
49    /// Creates a new writer for the given output path and parameters.
50    ///
51    /// The dataset is not opened until the first `write_block()` call.
52    pub fn new(
53        output_path: PathBuf,
54        geo_transform: GeoTransform,
55        epsg_code: u32,
56        total_cols: usize,
57        total_rows: usize,
58        n_bands: usize,
59    ) -> Self {
60        Self {
61            output_path,
62            geo_transform,
63            epsg_code,
64            total_cols,
65            total_rows,
66            n_bands,
67            dataset: Mutex::new(None),
68        }
69    }
70
71    /// Builds overviews on the written GeoTIFF.
72    ///
73    /// This must be called after all `write_block()` calls are complete.
74    /// It opens the dataset in update mode and calls GDAL's `build_overviews`
75    /// with `GDAL_NUM_THREADS=ALL_CPUS` for multithreaded processing.
76    ///
77    /// # Arguments
78    /// * `resampling` - Resampling method, e.g. "CUBIC", "NEAREST", "AVERAGE"
79    /// * `levels` - Overview decimation factors, e.g. \[2, 4, 8, 16, 32\]
80    pub fn build_overviews(&self, resampling: &str, levels: &[i32]) -> Result<()> {
81        let mut guard = self.dataset.lock().expect("Mutex poisoned");
82
83        // Open dataset if not already open
84        if guard.is_none() {
85            let opts = DatasetOptions {
86                open_flags: GdalOpenFlags::GDAL_OF_UPDATE,
87                ..DatasetOptions::default()
88            };
89            let dataset = Dataset::open_ex(&self.output_path, opts)
90                .with_context(|| format!("Failed to open {:?} for overview building", self.output_path))?;
91            *guard = Some(dataset);
92        }
93
94        // Enable multithreaded overview building — uses all available CPUs
95        config::set_config_option("GDAL_NUM_THREADS", "ALL_CPUS")
96            .context("Failed to set GDAL_NUM_THREADS")?;
97
98        let dataset = guard.as_mut().unwrap();
99        dataset
100            .build_overviews(resampling, levels, &[])
101            .with_context(|| format!("Failed to build overviews on {:?}", self.output_path))?;
102
103        Ok(())
104    }
105}
106
107/// Pre-creates a GeoTIFF file with the given parameters.
108///
109/// Creates the file with GTiff driver, LZW compression, tiled 512x512 blocks,
110/// correct geotransform, CRS, and no-data values.
111pub fn create_output_geotiff<T: RasterType>(
112    path: &Path,
113    geo_transform: &GeoTransform,
114    epsg_code: u32,
115    total_cols: usize,
116    total_rows: usize,
117    n_bands: usize,
118    na_value: T,
119) -> Result<()> {
120    if let Some(parent) = path.parent() {
121        std::fs::create_dir_all(parent)
122            .with_context(|| format!("Failed to create parent directory for {:?}", path))?;
123    }
124
125    let driver = DriverManager::get_driver_by_name("GTIFF")
126        .context("GTiff driver not available")?;
127
128    let options = RasterCreationOptions::from_iter([
129        "COMPRESS=LZW",
130        "TILED=YES",
131        "BLOCKXSIZE=512",
132        "BLOCKYSIZE=512",
133        "BIGTIFF=YES",
134    ]);
135
136    let mut dataset = driver
137        .create_with_band_type_with_options::<T, _>(
138            path,
139            total_cols,
140            total_rows,
141            n_bands,
142            &options,
143        )
144        .with_context(|| format!("Failed to create GeoTIFF at {:?}", path))?;
145
146    dataset
147        .set_geo_transform(&geo_transform.to_array())
148        .context("Failed to set geo transform")?;
149
150    let srs = SpatialRef::from_epsg(epsg_code)
151        .context(format!("Invalid EPSG code: {}", epsg_code))?;
152    dataset
153        .set_spatial_ref(&srs)
154        .context("Failed to set spatial reference")?;
155
156    for band_idx in 1..=n_bands {
157        let mut band = dataset
158            .rasterband(band_idx)
159            .context(format!("Failed to access band {}", band_idx))?;
160        if let Some(na_f64) = na_value.to_f64() {
161            band.set_no_data_value(Some(na_f64))
162                .context("Failed to set no-data value")?;
163        }
164    }
165
166    Ok(())
167}
168
169/// Writes a single block's data (all bands) to the pre-created GeoTIFF.
170///
171/// This method is thread-safe: it acquires the writer's mutex, opens the
172/// dataset on first call (then reuses it), writes all bands to the
173/// specified window, and returns. The dataset stays open for subsequent writes.
174///
175/// # Arguments
176/// * `data` - 3D array with shape (bands, rows, cols)
177/// * `window` - The read window defining where this block belongs in the output
178pub fn write_block<T: RasterType>(
179    writer: &ParallelGeoTiffWriter,
180    data: ndarray::ArrayView3<T>,
181    window: ReadWindow,
182) -> Result<()> {
183    let mut guard = writer.dataset.lock().expect("Mutex poisoned");
184
185    // Open dataset on first write
186    if guard.is_none() {
187        let opts = DatasetOptions {
188            open_flags: GdalOpenFlags::GDAL_OF_UPDATE,
189            ..DatasetOptions::default()
190        };
191        let dataset = Dataset::open_ex(&writer.output_path, opts)
192            .with_context(|| format!("Failed to open {:?} for update", writer.output_path))?;
193        *guard = Some(dataset);
194    }
195
196    let dataset = guard.as_mut().unwrap();
197    let block_rows = data.shape()[1];
198    let block_cols = data.shape()[2];
199
200    for band_idx in 0..data.shape()[0] {
201        let mut band = dataset
202            .rasterband(band_idx + 1)
203            .with_context(|| format!("Failed to access band {}", band_idx + 1))?;
204
205        let band_data = data.index_axis(ndarray::Axis(0), band_idx);
206        let data_vec: Vec<T> = band_data.into_iter().copied().collect();
207        let mut buffer = Buffer::new((block_cols, block_rows), data_vec);
208
209        band.write(
210            (window.offset.cols, window.offset.rows),
211            (block_cols, block_rows),
212            &mut buffer,
213        )
214        .with_context(|| {
215            format!(
216                "Failed to write band {} to window {:?}",
217                band_idx + 1,
218                window
219            )
220        })?;
221    }
222
223    Ok(())
224}