# Source code for raster_tools.io

import os
import urllib
import warnings

import dask
import numpy as np
import rasterio as rio
import rioxarray as xrio
import xarray as xr
from affine import Affine
from dask.array.core import normalize_chunks as dask_chunks

from raster_tools.dtypes import F32, F64, I64, U8, is_bool, is_float, is_int
from raster_tools.exceptions import (
    AffineEncodingError,
    DimensionsError,
    RasterDataError,
    RasterIOError,
)
from raster_tools.masking import get_default_null_value
from raster_tools.utils import to_chunk_dict, validate_path


def _get_extension(path):
    return os.path.splitext(path)[-1].lower()


def _get_chunking_info_from_file(src_file):
    """Read tiling info from a raster file.

    Returns a ``(tile_shape, shape, dtype)`` tuple where `tile_shape` is the
    first band's block shape prefixed with a band dim of 1, `shape` is the
    full ``(bands, rows, cols)`` shape, and `dtype` is the first band's dtype.
    """
    with rio.open(src_file) as src:
        blocks = src.block_shapes[0]
        return (1, *blocks), (src.count, *src.shape), np.dtype(src.dtypes[0])


def _get_chunks(data=None, src_file=None):
    chunks = (1, "auto", "auto")
    if data is None:
        if src_file is None:
            return chunks
        tile_shape, shape, dtype = _get_chunking_info_from_file(src_file)
    else:
        shape = data.shape
        dtype = data.dtype
        tile_shape = None
        if dask.is_dask_collection(data):
            tile_shape = data.chunks
        elif src_file is not None:
            _, tile_shape, _ = _get_chunking_info_from_file(src_file)
    return dask_chunks(chunks, shape, dtype=dtype, previous_chunks=tile_shape)


def chunk(xrs, src_file=None):
    """Rechunk `xrs` with raster-appropriate chunking (band=1, spatial auto).

    If `xrs` is a Dataset, its ``raster`` variable drives the chunk
    computation; `src_file`, when given, supplies file tiling hints.
    """
    target = xrs.raster if isinstance(xrs, xr.Dataset) else xrs
    new_chunks = to_chunk_dict(_get_chunks(target, src_file))
    return xrs.chunk(new_chunks)
# Recognized file extensions, grouped by format family.
TIFF_EXTS = frozenset((".tif", ".tiff"))
NC_EXTS = frozenset((".cdf", ".nc", ".nc4"))
HDF_EXTS = frozenset((".hdf", ".h4", ".hdf4", ".he2", ".h5", ".hdf5", ".he5"))
GRIB_EXTS = frozenset((".grib", ".grib2", ".grb", ".grb2", ".gb", ".gb2"))
BATCH_EXTS = frozenset((".bch",))

# File extensions that can't be read in yet
READ_NOT_IMPLEMENTED_EXTS = NC_EXTS | HDF_EXTS | GRIB_EXTS
# File extensions that can't be written out yet
WRITE_NOT_IMPLEMENTED_EXTS = NC_EXTS | HDF_EXTS | GRIB_EXTS
def is_batch_file(path):
    """Return ``True`` when `path` has a batch (.bch) file extension."""
    return _get_extension(path) in BATCH_EXTS
def _require_backend(import_name, package, ext, extra):
    """Raise an informative ImportError if an optional IO backend is missing.

    `import_name` is checked with ``importlib.util.find_spec``; `package`,
    `ext`, and `extra` only shape the error message.
    """
    import importlib.util

    if importlib.util.find_spec(import_name) is not None:
        return
    raise ImportError(
        f"Reading {ext} files requires the '{package}' package, which is"
        " not installed. Install it with 'pip install"
        f" raster-tools[{extra}]' or 'conda install -c conda-forge"
        f" {package}'."
    )


# Default float32 nodata value written by ESRI software.
ESRI_DEFAULT_F32_NV = np.finfo(F32).min
def normalize_null_value(nv, dtype):
    """Coerce a null value `nv` to a form consistent with `dtype`.

    Parameters
    ----------
    nv : scalar or None
        The raster's null (nodata) value, possibly ``None``.
    dtype : numpy dtype
        The raster's dtype.

    Returns
    -------
    The normalized null value (possibly unchanged, possibly ``None``).
    """
    # Make sure that ESRI's default F32 null value is properly
    # registered as F32
    if dtype == F32 and nv is not None and np.isclose(nv, ESRI_DEFAULT_F32_NV):
        nv = F32.type(nv)
    # Some rasters have (u)int dtype and a null value that is a whole number
    # but it gets read in as a float. This can cause a lot of accidental type
    # promotions down the pipeline. Check for this case and correct it.
    # BUG FIX: guard against nv being None here. np.dtype(None) is float64,
    # so is_float(None) can be truthy, and float(None) then raises TypeError.
    if (
        nv is not None
        and is_int(dtype)
        and is_float(nv)
        and float(nv).is_integer()
    ):
        nv = int(nv)
    return nv
def open_raster_from_path_or_url(path):
    """Open a raster file path or URL as a dask-backed DataArray.

    Parameters
    ----------
    path : str or os.PathLike
        File path or URL pointing at a raster readable by GDAL.

    Returns
    -------
    xarray.DataArray
        Chunked, normalized data with the nodata value written via rioxarray.

    Raises
    ------
    RasterIOError
        If `path` is not a str/PathLike or cannot be opened as a raster.
    RasterDataError
        If the input resolves to a Dataset with multiple variables.
    NotImplementedError
        For NetCDF/HDF/GRIB extensions (use ``open_dataset`` instead).
    """
    from raster_tools.raster import (
        _try_to_get_null_value_xarray,
        normalize_xarray_data,
    )

    if isinstance(path, os.PathLike):
        ext = path.suffix
    elif isinstance(path, str):
        # BUG FIX: urlparse() returns a ParseResult object, which never
        # compares equal to "" — the original condition was always False, so
        # plain file paths skipped validation. Check the scheme instead: an
        # empty scheme means this is a local file path.
        if urllib.parse.urlparse(path).scheme == "":
            # Assume file path
            validate_path(path)
            ext = _get_extension(path)
        else:
            # Could be a URL or path
            ext = ""
    else:
        raise RasterIOError(
            f"Could not resolve input to a raster path or URL: '{path}'"
        )

    xrs = None
    # Try to let gdal open anything but NC, HDF, GRIB files
    if ext in READ_NOT_IMPLEMENTED_EXTS:
        raise NotImplementedError(
            "Reading of NetCDF, HDF, and GRIB files is not supported at this"
            " time. Try 'raster_tools.open_dataset'."
        )
    else:
        try:
            xrs = xrio.open_rasterio(
                path, chunks=to_chunk_dict(_get_chunks()), lock=False
            )
        except rio.errors.RasterioIOError as e:
            raise RasterIOError(
                "Could not open given path as a raster."
            ) from e
    if isinstance(xrs, xr.Dataset):
        raise RasterDataError("Too many data variables in input data")
    assert isinstance(
        xrs, xr.DataArray
    ), "Resulting data structure must be a DataArray"

    if not dask.is_dask_collection(xrs):
        xrs = chunk(xrs, path)
    xrs = normalize_xarray_data(xrs)
    nv = _try_to_get_null_value_xarray(xrs)
    nv = normalize_null_value(nv, xrs.dtype)
    xrs = xrs.rio.write_nodata(nv)
    return xrs
_EXT_TO_DRIVER = {".tif": "GTiff", ".tiff": "GTiff"} # Drivers that build overviews as part of the write itself; the post-write # build_overviews pass is skipped for these. _DRIVERS_WITH_INTERNAL_OVERVIEWS = frozenset({"COG"}) def _resolve_driver(path, driver): if driver is not None: return driver return _EXT_TO_DRIVER.get(_get_extension(path)) def _gtiff_translate(opts): out = {} tiled = opts.get("tiled") if tiled is not None: out["tiled"] = bool(tiled) bs = opts.get("blocksize") if bs is not None: if isinstance(bs, int): h = w = bs else: h, w = bs out["blockxsize"] = int(w) out["blockysize"] = int(h) compress = opts.get("compress") if compress is None: out["compress"] = "none" else: out["compress"] = str(compress).lower() level = opts.get("compress_level") if level is not None: c = out["compress"] if c == "deflate": out["zlevel"] = int(level) elif c == "zstd": out["zstd_level"] = int(level) elif c == "jpeg": out["jpeg_quality"] = int(level) else: warnings.warn( f"compress_level has no effect with compress={compress!r}", stacklevel=4, ) predictor = opts.get("predictor") if predictor is not None: if out["compress"] == "jpeg": warnings.warn( "predictor is not valid with compress='jpeg'; ignoring", stacklevel=4, ) elif isinstance(predictor, int): # Backward compat with rasterio-style integer predictor values. out["predictor"] = predictor else: mapping = {"horizontal": 2, "float": 3} if predictor not in mapping: raise ValueError( "predictor must be 'horizontal' or 'float', got " f"{predictor!r}" ) out["predictor"] = mapping[predictor] bigtiff = opts.get("bigtiff") if isinstance(bigtiff, bool): out["bigtiff"] = "yes" if bigtiff else "no" elif bigtiff is not None: out["bigtiff"] = str(bigtiff).lower() return out def _cog_translate(opts): out = {} # COG is always tiled; the tiled kwarg is intentionally ignored. 
bs = opts.get("blocksize") if bs is not None: if isinstance(bs, int): size = bs else: h, w = bs if h != w: raise ValueError( f"COG driver requires a square blocksize; got {bs!r}" ) size = h out["blocksize"] = int(size) compress = opts.get("compress") if compress is None: out["compress"] = "none" else: out["compress"] = str(compress).lower() level = opts.get("compress_level") if level is not None: c = out["compress"] if c in ( "deflate", "zstd", "lzw", "lerc", "lerc_deflate", "lerc_zstd", ): out["level"] = int(level) elif c == "jpeg": out["quality"] = int(level) else: warnings.warn( f"compress_level has no effect with compress={compress!r}", stacklevel=4, ) predictor = opts.get("predictor") if predictor is not None: if out["compress"] == "jpeg": warnings.warn( "predictor is not valid with compress='jpeg'; ignoring", stacklevel=4, ) elif isinstance(predictor, int): # Backward compat with rasterio-style integer predictor values. out["predictor"] = predictor else: mapping = {"horizontal": "STANDARD", "float": "FLOATING_POINT"} if predictor not in mapping: raise ValueError( "predictor must be 'horizontal' or 'float', got " f"{predictor!r}" ) out["predictor"] = mapping[predictor] bigtiff = opts.get("bigtiff") if isinstance(bigtiff, bool): out["bigtiff"] = "yes" if bigtiff else "no" elif bigtiff is not None: out["bigtiff"] = str(bigtiff).lower() overviews = opts.get("overviews") if overviews is None or overviews is False: out["overviews"] = "none" elif isinstance(overviews, (list, tuple)): warnings.warn( "COG driver builds overviews with auto-selected factors; " "explicit overview list is ignored.", stacklevel=4, ) out["overviews"] = "auto" else: out["overviews"] = "auto" overview_resampling = opts.get("overview_resampling") if overview_resampling is not None: out["overview_resampling"] = str(overview_resampling).lower() return out _DRIVER_TRANSLATORS = {"GTiff": _gtiff_translate, "COG": _cog_translate} def _auto_overview_factors(height, width, min_size=256): factors = [] 
f = 2 while min(height, width) / f >= min_size: factors.append(f) f *= 2 return factors
def write_raster(
    xrs,
    path,
    *,
    driver=None,
    tiled=True,
    blocksize=None,
    compress=None,
    compress_level=None,
    predictor=None,
    bigtiff="if_safer",
    overviews=None,
    overview_resampling="average",
    overview_num_threads="all_cpus",
    **gdal_kwargs,
):
    """Write `xrs` out to `path` using rasterio/rioxarray.

    Parameters
    ----------
    xrs : xarray.DataArray
        The data to write.
    path : str
        Destination path. The extension may determine the driver.
    driver : str, optional
        Explicit GDAL driver name; inferred from the extension when omitted.
    tiled, blocksize, compress, compress_level, predictor, bigtiff :
        Generic creation options, translated per-driver.
    overviews : bool or sequence, optional
        ``True`` selects automatic factors; a sequence gives explicit ones.
    overview_resampling : str, optional
        Resampling method name used when building overviews.
    overview_num_threads : str or int, optional
        Passed to GDAL as GDAL_NUM_THREADS while building overviews.
    **gdal_kwargs
        Escape hatch: raw creation options that override translated ones.

    Raises
    ------
    NotImplementedError
        For NetCDF/HDF/GRIB destination extensions.
    """
    ext = _get_extension(path)
    if ext and ext in WRITE_NOT_IMPLEMENTED_EXTS:
        raise NotImplementedError(
            f"Writing files with extension {ext!r} is not supported"
        )

    rio_is_bool = False
    if xrs.dtype == I64:
        # GDAL doesn't support I64; cast up to F64 so to_raster won't reject.
        xrs = xrs.astype(F64)
    elif is_bool(xrs.dtype):
        # GDAL doesn't support bool; encode as uint8.
        rio_is_bool = True
        xrs = xrs.astype(U8)

    resolved_driver = _resolve_driver(path, driver)
    translator = _DRIVER_TRANSLATORS.get(resolved_driver)
    creation_opts = (
        translator(
            {
                "tiled": tiled,
                "blocksize": blocksize,
                "compress": compress,
                "compress_level": compress_level,
                "predictor": predictor,
                "bigtiff": bigtiff,
                "overviews": overviews,
                "overview_resampling": overview_resampling,
            }
        )
        if translator is not None
        else {}
    )
    if rio_is_bool and resolved_driver == "GTiff":
        creation_opts["nbits"] = 1
    # Escape hatch wins on collisions.
    creation_opts.update(gdal_kwargs)

    if resolved_driver == "COG":
        # rioxarray streams dask chunks by reopening the file in "r+", but
        # COG forbids updates after creation (it would break the layout).
        # Stage to a temporary GTiff, then translate to COG.
        import tempfile

        from rasterio.shutil import copy as rio_copy

        out_dir = os.path.dirname(os.path.abspath(path))
        with tempfile.NamedTemporaryFile(
            suffix=".tif", dir=out_dir, delete=False
        ) as tmpf:
            tmp_path = tmpf.name
        try:
            xrs.rio.to_raster(tmp_path, lock=True, compute=True)
            rio_copy(tmp_path, path, driver="COG", **creation_opts)
        finally:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)
    else:
        to_raster_kwargs = {"lock": True, "compute": True, **creation_opts}
        if driver is not None:
            to_raster_kwargs["driver"] = driver
        xrs.rio.to_raster(path, **to_raster_kwargs)

    # Post-write overview build for drivers that don't create them inline.
    if overviews and resolved_driver not in _DRIVERS_WITH_INTERNAL_OVERVIEWS:
        if overviews is True:
            factors = _auto_overview_factors(*xrs.shape[-2:])
        else:
            factors = list(overviews)
        if factors:
            from rasterio.enums import Resampling

            resampling = Resampling[overview_resampling]
            env_kwargs = {}
            if overview_num_threads is not None:
                env_kwargs["GDAL_NUM_THREADS"] = str(
                    overview_num_threads
                ).upper()
            with rio.Env(**env_kwargs), rio.open(path, "r+") as ds:
                ds.build_overviews(factors, resampling)
                ds.update_tags(
                    ns="rio_overview", resampling=overview_resampling
                )
def _get_valid_variables(meta, ignore_too_many_dims):
    """Return dataset variable names that can map to rasters (2-3 dims after
    squeeze).

    Raises DimensionsError for variables with too many (>3) or too few (<2)
    dimensions — unless `ignore_too_many_dims` skips the >3 case — and
    ValueError when nothing valid remains.
    """
    data_vars = list(meta.data_vars)
    valid = []
    for v in data_vars:
        n = meta[v].squeeze().ndim
        if n > 3:
            if ignore_too_many_dims:
                continue
            else:
                raise DimensionsError(
                    f"Too many dimensions for variable {v!r} with "
                    f"{meta[v].ndim}."
                )
        elif n in (2, 3):
            valid.append(v)
        else:
            raise DimensionsError(
                f"Too few dimensions for variable {v!r} with {n}."
            )
    if not valid:
        raise ValueError("No valid raster variables found")
    return valid


def _build_raster(path, variable, affine, crs, xarray_kwargs):
    """Build a Raster from one variable of the dataset at `path`, reopening
    it lazily with auto chunking."""
    from raster_tools.raster import data_to_raster

    if affine is None:
        # BUG FIX: affine.Affine takes exactly 6 coefficients; the original
        # call passed 7 (an extra trailing 0), which raises at runtime. Use
        # the identity-like default with a negative y pixel size (north-up).
        affine = Affine(1, 0, 0, 0, -1, 0)
    kwargs = xarray_kwargs.copy()
    kwargs["chunks"] = "auto"
    var = xr.open_dataset(path, **kwargs)[variable].squeeze()
    x = var[var.rio.x_dim].to_numpy()
    y = var[var.rio.y_dim].to_numpy()
    # Prefer the _FillValue attribute; fall back to rioxarray's nodata.
    nv = var._FillValue if "_FillValue" in var.attrs else var.rio.nodata
    raster = data_to_raster(var.data, x=x, y=y, affine=affine, crs=crs, nv=nv)
    if nv is None or np.isnan(nv):
        raster = raster.set_null_value(get_default_null_value(raster.dtype))
    return raster


def _get_affine(ds):
    """Extract the affine transform from `ds`, wrapping bad GeoTransform
    encodings in AffineEncodingError."""
    try:
        affine = ds.rio.transform()
    except TypeError as err:
        # Some datasets like gridMET improperly encode the transform.
        raise AffineEncodingError(
            "Error reading GeoTransform data:"
            f"{ds.coords[ds.rio.grid_mapping].attrs['GeoTransform']!r}"
        ) from err
    return affine
def open_dataset(
    path,
    crs=None,
    ignore_extra_dim_errors=False,
    xarray_kwargs=None,
):
    """Open a netCDF or GRIB dataset.

    This function opens a netCDF or GRIB dataset file and returns a
    dictionary of Raster objects where each raster corresponds to a variable
    in the file. netCDF/GRIB files can be N-dimensional, while rasters only
    comprehend 2 to 3 dimensions (band, y, x), so it may not be possible to
    map all variables in a file to a raster. See the
    `ignore_extra_dim_errors` option below for more information.

    Parameters
    ----------
    path : str
        The path to the netCDF or GRIB dataset file.
    crs : str, rasterio.crs.CRS, optional
        A coordinate reference system definition to attach to the dataset.
        This can be an EPSG, PROJ, or WKT string. It can also be a
        `rasterio.crs.CRS` object. netCDF/GRIB files do not always encode a
        CRS. This option allows a CRS to be supplied, if known ahead of
        time. It can also be used to override the CRS encoded in the file.
    ignore_extra_dim_errors : bool, optional
        If ``True``, ignore dataset variables that cannot be mapped to a
        raster. An error is raised, otherwise. netCDF/GRIB files allow
        N-dimensional data. Rasters only comprehend 2 or 3 dimensional data
        so it is not always possible to map a variable to a raster. The
        default is ``False``.
    xarray_kwargs : dict, optional
        Keyword arguments to supply to `xarray.open_dataset` when opening
        the file.

    Raises
    ------
    raster_tools.io.AffineEncodingError
        Raised if the affine matrix is improperly encoded.

    Returns
    -------
    dataset : dict of Raster
        A ``dict`` of Raster objects. The keys are the variable names in
        the dataset file and the values are the corresponding variable data
        as a raster.

    """
    if xarray_kwargs is None:
        xarray_kwargs = {}
    xarray_kwargs["decode_coords"] = "all"

    # Fail early, with a clear message, if the optional backend is absent.
    ext = _get_extension(path)
    if ext in NC_EXTS:
        _require_backend("netCDF4", "netcdf4", ext, extra="io")
    elif ext in GRIB_EXTS:
        _require_backend("cfgrib", "cfgrib", ext, extra="io")

    tmp_ds = xr.open_dataset(path, **xarray_kwargs)
    data_vars = _get_valid_variables(tmp_ds, ignore_extra_dim_errors)
    crs = crs or tmp_ds.rio.crs
    affine = _get_affine(tmp_ds)
    # Drop the eagerly-opened dataset; each raster reopens the file lazily.
    tmp_ds = None

    return {
        v: _build_raster(path, v, affine, crs, xarray_kwargs)
        for v in data_vars
    }