Skip to content

Commit

Permalink
TYPE: Standardize Resampling type
Browse files Browse the repository at this point in the history
  • Loading branch information
Alan Snow committed Apr 1, 2024
1 parent 2a4aefe commit 725c2ae
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
9 changes: 5 additions & 4 deletions datacube/api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from datacube.storage import reproject_and_fuse, BandInfo
from datacube.utils import ignore_exceptions_if
from odc.geo import CRS, yx_, res_, resyx_, Resolution, XY
from odc.geo.warp import Resampling
from odc.geo.xr import xr_coords
from datacube.utils.dates import normalise_dt
from odc.geo.geom import intersects, box, bbox_union, Geometry
Expand Down Expand Up @@ -244,7 +245,7 @@ def load(self,
measurements: str | list[str] | None = None,
output_crs: Any = None,
resolution: int | float | tuple[int | float, int | float] | Resolution | None = None,
resampling: str | dict[str, str] | None = None,
resampling: Resampling | dict[str, Resampling] | None = None,
align: XY[float] | Iterable[float] | None = None,
skip_broken_datasets: bool = False,
dask_chunks: dict[str, str | int] | None = None,
Expand Down Expand Up @@ -878,7 +879,7 @@ def _cbk(*ignored):
@staticmethod
def load_data(sources: xarray.DataArray, geobox: GeoBox,
measurements: Mapping[str, Measurement] | list[Measurement],
resampling: str | dict[str, str] | None = None,
resampling: Resampling | dict[str, Resampling] | None = None,
fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None,
dask_chunks: dict[str, str | int] | None = None,
skip_broken_datasets: bool = False,
Expand Down Expand Up @@ -969,7 +970,7 @@ def __exit__(self, type_, value, traceback):


def per_band_load_data_settings(measurements: list[Measurement] | Mapping[str, Measurement],
resampling: str | Mapping[str, str] | None = None,
resampling: Resampling | Mapping[str, Resampling] | None = None,
fuse_func: FuserFunction | Mapping[str, FuserFunction | None] | None = None
) -> list[Measurement]:
def with_resampling(m, resampling, default=None):
Expand All @@ -982,7 +983,7 @@ def with_fuser(m, fuser, default=None):
m['fuser'] = fuser.get(m.name, default)
return m

if isinstance(resampling, str):
if not isinstance(resampling, dict):
resampling = {'*': resampling}

if fuse_func is None or callable(fuse_func):
Expand Down
3 changes: 2 additions & 1 deletion datacube/storage/_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from odc.geo.geobox import GeoBox
from odc.geo.roi import roi_is_empty
from odc.geo.xr import xr_coords
from odc.geo.warp import Resampling
from datacube.model import Measurement
from datacube.drivers._types import ReaderDriver
from ..drivers.datasource import DataSource
Expand All @@ -47,7 +48,7 @@ def reproject_and_fuse(datasources: List[DataSource],
destination: np.ndarray,
dst_geobox: GeoBox,
dst_nodata: Optional[Union[int, float]],
resampling: str = 'nearest',
resampling: Resampling = 'nearest',
fuse_func: Optional[FuserFunction] = None,
skip_broken_datasets: bool = False,
progress_cbk: Optional[ProgressFunction] = None,
Expand Down
8 changes: 5 additions & 3 deletions datacube/utils/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import xarray as xr
import dask
from dask.delayed import Delayed
from odc.geo.warp import Resampling, resampling_s2rio
from pathlib import Path
from typing import Union, Optional, List, Any, Dict

Expand Down Expand Up @@ -38,7 +39,7 @@ def _write_cog(
nodata: Optional[float] = None,
overwrite: bool = False,
blocksize: Optional[int] = None,
overview_resampling: Optional[str] = None,
overview_resampling: Optional[Resampling] = None,
overview_levels: Optional[List[int]] = None,
ovr_blocksize: Optional[int] = None,
use_windowed_writes: bool = False,
Expand Down Expand Up @@ -118,7 +119,8 @@ def _write_cog(
fname, overwrite
) # aborts if overwrite=False and file exists already

resampling = rasterio.enums.Resampling[overview_resampling]
if isinstance(overview_resampling, str):
resampling = resampling_s2rio(overview_resampling)

if (blocksize % 16) != 0:
warnings.warn("Block size must be a multiple of 16, will be adjusted")
Expand Down Expand Up @@ -219,7 +221,7 @@ def write_cog(
overwrite: bool = False,
blocksize: Optional[int] = None,
ovr_blocksize: Optional[int] = None,
overview_resampling: Optional[str] = None,
overview_resampling: Optional[Resampling] = None,
overview_levels: Optional[List[int]] = None,
use_windowed_writes: bool = False,
intermediate_compression: Union[bool, str, Dict[str, Any]] = False,
Expand Down

0 comments on commit 725c2ae

Please sign in to comment.