From 0a8629aa558753571419eba701c21e5ddba0f3ba Mon Sep 17 00:00:00 2001 From: snowman2 Date: Mon, 1 Apr 2024 14:47:18 -0500 Subject: [PATCH] TYPE: Standardize resampling type --- datacube/api/core.py | 9 +++++---- datacube/storage/_load.py | 3 ++- datacube/utils/cog.py | 10 +++++++--- docs/about/whats_new.rst | 3 +++ 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/datacube/api/core.py b/datacube/api/core.py index 322868d25..eb3201d3c 100644 --- a/datacube/api/core.py +++ b/datacube/api/core.py @@ -18,6 +18,7 @@ from datacube.storage import reproject_and_fuse, BandInfo from datacube.utils import ignore_exceptions_if from odc.geo import CRS, yx_, res_, resyx_, Resolution, XY +from odc.geo.warp import Resampling from odc.geo.xr import xr_coords from datacube.utils.dates import normalise_dt from odc.geo.geom import intersects, box, bbox_union, Geometry @@ -244,7 +245,7 @@ def load(self, measurements: str | list[str] | None = None, output_crs: Any = None, resolution: int | float | tuple[int | float, int | float] | Resolution | None = None, - resampling: str | dict[str, str] | None = None, + resampling: Resampling | dict[str, Resampling] | None = None, align: XY[float] | Iterable[float] | None = None, skip_broken_datasets: bool = False, dask_chunks: dict[str, str | int] | None = None, @@ -878,7 +879,7 @@ def _cbk(*ignored): @staticmethod def load_data(sources: xarray.DataArray, geobox: GeoBox, measurements: Mapping[str, Measurement] | list[Measurement], - resampling: str | dict[str, str] | None = None, + resampling: Resampling | dict[str, Resampling] | None = None, fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None, dask_chunks: dict[str, str | int] | None = None, skip_broken_datasets: bool = False, @@ -969,7 +970,7 @@ def __exit__(self, type_, value, traceback): def per_band_load_data_settings(measurements: list[Measurement] | Mapping[str, Measurement], - resampling: str | Mapping[str, str] | None = None, + resampling: Resampling | Mapping[str, Resampling] | None = None, fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None ) -> list[Measurement]: def with_resampling(m, resampling, default=None): @@ -982,7 +983,7 @@ def with_fuser(m, fuser, default=None): m['fuser'] = fuser.get(m.name, default) return m - if isinstance(resampling, str): + if resampling is not None and not isinstance(resampling, dict): resampling = {'*': resampling} if fuse_func is None or callable(fuse_func): diff --git a/datacube/storage/_load.py b/datacube/storage/_load.py index ed0b8f5f0..6c25666dc 100644 --- a/datacube/storage/_load.py +++ b/datacube/storage/_load.py @@ -23,6 +23,7 @@ from odc.geo.geobox import GeoBox from odc.geo.roi import roi_is_empty from odc.geo.xr import xr_coords +from odc.geo.warp import Resampling from datacube.model import Measurement from datacube.drivers._types import ReaderDriver from ..drivers.datasource import DataSource @@ -47,7 +48,7 @@ def reproject_and_fuse(datasources: List[DataSource], destination: np.ndarray, dst_geobox: GeoBox, dst_nodata: Optional[Union[int, float]], - resampling: str = 'nearest', + resampling: Resampling = 'nearest', fuse_func: Optional[FuserFunction] = None, skip_broken_datasets: bool = False, progress_cbk: Optional[ProgressFunction] = None, diff --git a/datacube/utils/cog.py b/datacube/utils/cog.py index 460ebc40c..9b4654a5f 100644 --- a/datacube/utils/cog.py +++ b/datacube/utils/cog.py @@ -18,6 +18,7 @@ from .io import check_write_path from odc.geo.geobox import GeoBox from odc.geo.math import align_up +from odc.geo.warp import Resampling, resampling_s2rio from deprecat import deprecat @@ -38,7 +39,7 @@ def _write_cog( nodata: Optional[float] = None, overwrite: bool = False, blocksize: Optional[int] = None, - overview_resampling: Optional[str] = None, + overview_resampling: Optional[Resampling] = None, overview_levels: Optional[List[int]] = None, ovr_blocksize: Optional[int] = None, use_windowed_writes: bool = False, @@ -118,7 +119,10 @@ def _write_cog( fname, overwrite ) # aborts if overwrite=False and file exists already - resampling = rasterio.enums.Resampling[overview_resampling] + if isinstance(overview_resampling, str): + resampling = resampling_s2rio(overview_resampling) + else: + resampling = overview_resampling if (blocksize % 16) != 0: warnings.warn("Block size must be a multiple of 16, will be adjusted") @@ -219,7 +223,7 @@ def write_cog( overwrite: bool = False, blocksize: Optional[int] = None, ovr_blocksize: Optional[int] = None, - overview_resampling: Optional[str] = None, + overview_resampling: Optional[Resampling] = None, overview_levels: Optional[List[int]] = None, use_windowed_writes: bool = False, intermediate_compression: Union[bool, str, Dict[str, Any]] = False, diff --git a/docs/about/whats_new.rst b/docs/about/whats_new.rst index 7c06c0ad0..afd67980d 100644 --- a/docs/about/whats_new.rst +++ b/docs/about/whats_new.rst @@ -8,6 +8,9 @@ What's New v1.9.next ========= +- Standardize resampling input supported to `odc.geo.warp.Resampling`. + + v1.9.0-rc3 (27th March 2024) ============================