# Source code for raster_tools._mosaic

import numbers

import dask.array as da
import geopandas as gpd
import numba as nb
import numpy as np
import rasterio as rio
import shapely
from odc.geo.geobox import GeoBox

import raster_tools as rts

# Public API of this module: ``mosaic`` is the only exported name.
__all__ = ["mosaic"]


# Geobox comparison tolerance, expressed as a fraction of the pixel size.
# Some published products share a grid but carry sub-pixel floating-point
# noise (observed up to ~1e-4 in CRS units). Strict odc-geo equality would
# treat those as different grids and force unnecessary reprojections.
_GRID_PIXEL_TOLERANCE = 1e-3


def _grids_close(a, b, pixel_tolerance=_GRID_PIXEL_TOLERANCE):
    """Return True if two grids match to within a sub-pixel tolerance.

    Parameters
    ----------
    a, b : grid-like
        Objects exposing ``crs``, ``shape`` and ``affine`` attributes, where
        ``affine`` has the six coefficients ``a`` through ``f``.
    pixel_tolerance : float, optional
        Allowed absolute difference per affine coefficient, as a fraction of
        the larger of ``|a.affine.a|`` and ``|a.affine.e|``.

    Returns
    -------
    bool
        False if the CRS or shape differ, or if any affine coefficient
        differs by more than the tolerance. True otherwise.

    """
    if a.crs != b.crs:
        return False
    if a.shape != b.shape:
        return False

    ref, other = a.affine, b.affine
    # The tolerance is scaled by the pixel size of the first grid only.
    atol = pixel_tolerance * max(abs(ref.a), abs(ref.e))
    for coeff in ("a", "b", "c", "d", "e", "f"):
        delta = abs(getattr(ref, coeff) - getattr(other, coeff))
        # Written as "not <=" so that a NaN difference counts as a mismatch.
        if not delta <= atol:
            return False
    return True


def _are_all_grids_same(grids):
    """Return True if every grid matches the first one within tolerance.

    Parameters
    ----------
    grids : sequence
        Grid objects, or objects carrying a ``geobox`` attribute (the
        attribute is used when present). An empty sequence counts as
        "all the same".

    Returns
    -------
    bool
        True when all grids after the first compare close to the first.

    """
    if not grids:
        return True

    # Unwrap anything that carries its grid in a ``geobox`` attribute.
    reference, *others = [getattr(g, "geobox", g) for g in grids]
    for candidate in others:
        if not _grids_close(reference, candidate):
            return False
    return True


def _build_dst_grid_from_inputs(rasters, dst_crs=None):
    """Build a destination grid that covers all of the input rasters.

    Parameters
    ----------
    rasters : list of Raster
        The input rasters. The first one supplies the default CRS and the
        resolution.
    dst_crs : CRS, optional
        The CRS for the destination grid. Defaults to the CRS of the first
        raster.

    Returns
    -------
    GeoBox
        If all inputs share a grid, that grid (reprojected to `dst_crs` when
        one is given). Otherwise a tight grid in `dst_crs` spanning the
        bounds of every input after reprojection.

    """
    first = rasters[0]
    if _are_all_grids_same(rasters):
        if dst_crs is None:
            return first.geobox
        return first.reproject(dst_crs).geobox

    if dst_crs is None:
        dst_crs = first.crs
    # The resolution must be expressed in the units of the destination CRS.
    # The first raster's native resolution is only meaningful when it is
    # already in that CRS (e.g. degrees must not be reused as metres), so
    # otherwise measure it on the first raster reprojected into dst_crs.
    reference = first if dst_crs == first.crs else first.reproject(dst_crs)
    resolution = np.abs(reference.resolution[0])

    rasters_dst = [
        r.reproject(dst_crs, resolution=resolution) for r in rasters
    ]
    bboxes_dst = [shapely.box(*r.bounds) for r in rasters_dst]
    total_bounds_dst = gpd.GeoSeries(bboxes_dst, crs=dst_crs).total_bounds
    return GeoBox.from_bbox(
        total_bounds_dst, crs=dst_crs, resolution=resolution, tight=True
    )


def _build_raster_from_grid(grid, dtype, nodata):
    """Create a raster on `grid` whose every cell holds `nodata`.

    Parameters
    ----------
    grid : GeoBox
        The grid supplying the shape, coordinates and CRS.
    dtype : numpy.dtype
        The dtype of the fill array.
    nodata : scalar
        The fill value, also passed on as the raster's null value.

    Returns
    -------
    Raster
        The result of ``rts.data_to_raster`` for the filled array.

    """
    n_rows, n_cols = grid.shape.y, grid.shape.x
    # The coordinates mapping is ordered (y-axis, x-axis). Its key names
    # depend on the CRS ("x"/"y" when projected, "longitude"/"latitude" for
    # 4326), so unpack by position rather than by name.
    y_axis, x_axis = grid.coordinates.values()
    filled = da.full((n_rows, n_cols), nodata, dtype=dtype)
    return rts.data_to_raster(
        filled,
        x=x_axis.values,
        y=y_axis.values,
        crs=grid.crs,
        nv=nodata,
    )


@nb.jit(nopython=True, nogil=True)
def _nb_push(x, missing_value):
    """Forward-fill missing cells along the first axis, in place.

    Walking the first axis from index 1 upwards, every cell that is missing
    is overwritten with the value directly before it on that axis. The input
    array itself is modified and returned.

    Parameters
    ----------
    x : 3D array
        The array to fill. It is written to directly.
    missing_value : scalar
        The marker for missing cells. When it is NaN, cells are tested with
        ``np.isnan``; otherwise with ``==``.

    Returns
    -------
    array
        The same array object as `x`.

    """
    n_layers, n_rows, n_cols = x.shape
    missing_is_nan = np.isnan(missing_value)
    for layer in range(1, n_layers):
        for row in range(n_rows):
            for col in range(n_cols):
                value = x[layer, row, col]
                if missing_is_nan:
                    needs_fill = np.isnan(value)
                else:
                    needs_fill = value == missing_value
                if needs_fill:
                    x[layer, row, col] = x[layer - 1, row, col]
    return x


def _push(x, missing_value, axis=None, dtype=None):
    """Forward-fill a copy of `x` along its first axis.

    `axis` and `dtype` are accepted and ignored: dask's blelloch scan forces
    them into the calling interface.
    """
    # _nb_push writes into its argument, and the incoming block may be
    # read-only, so work on a private copy.
    writable = x.copy()
    return _nb_push(writable, missing_value)


def _push_binop_last(lhs, rhs, missing_value):
    """Combine two arrays, preferring `rhs` wherever it is not missing.

    Parameters
    ----------
    lhs, rhs : array-like
        The arrays to combine element-wise.
    missing_value : scalar
        The marker for missing cells in `rhs`. NaN is matched with
        ``np.isnan``, anything else with ``==``.

    Returns
    -------
    numpy.ndarray
        `rhs`, with its missing cells replaced by the matching `lhs` cells.

    """
    if np.isnan(missing_value):
        rhs_missing = np.isnan(rhs)
    else:
        rhs_missing = rhs == missing_value
    return np.where(rhs_missing, lhs, rhs)


def _push_take_last(
    x, axis=None, dtype=None, keepdims=None, missing_value=None, **kwargs
):
    """Forward-fill `x` and keep only its final layer.

    The leading axis is kept with length 1. All keyword arguments other than
    `missing_value` are accepted and ignored.
    """
    filled = _push(x, missing_value)
    return filled[-1:]


def _dask_push(x, missing_value):
    """Forward-fill missing cells along axis 0 of a dask array.

    Wires the block-level helpers into dask's ``cumreduction`` using the
    "blelloch" method, so the fill is carried across chunk boundaries.

    Parameters
    ----------
    x : dask.array.Array
        The array to fill along axis 0.
    missing_value : scalar
        The marker for missing cells. Also used as the reduction identity.

    Returns
    -------
    dask.array.Array
        The result of the cumulative reduction, with the dtype of `x`.

    """

    # The three closures below bind ``missing_value`` and adapt the helpers to
    # keyword-style signatures; any extra keywords they receive are ignored.
    # NOTE(review): the exact keywords dask passes are not visible here --
    # confirm against the dask cumreduction documentation before changing.
    def push_func(x, axis=None, dtype=None, **kwargs):
        # Per-block forward fill.
        return _push(x, missing_value=missing_value, axis=axis)

    def push_binop(a, b):
        # Combine two partial results, keeping ``b`` where it is not missing.
        return _push_binop_last(a, b, missing_value=missing_value)

    def push_preop(x, axis=None, dtype=None, keepdims=None, **kwargs):
        # Reduce a block to its forward-filled last layer.
        return _push_take_last(x, missing_value=missing_value)

    return da.reductions.cumreduction(
        func=push_func,
        binop=push_binop,
        ident=missing_value,
        x=x,
        axis=0,
        dtype=x.dtype,
        method="blelloch",
        preop=push_preop,
    )


def _dask_push_take_last(data, missing_value):
    """Forward-fill `data` along axis 0 and return only the final layer.

    The leading axis is kept with length 1.
    """
    pushed = _dask_push(data, missing_value)
    return pushed[-1:]


def _paint(stacked_data, nodata, mosaic_method):
    """Paint the stacked data onto a result array.

    This does not operate in place. A new array is returned.

    Parameters
    ----------
    stacked_data : da.Array
        The stacked data as (N, H, W) dask array.
    nodata : scalar
        The nodata value to fill missing cells with.
    mosaic_method : str
        The mosaic_method to use when painting `stacked_data` onto the result.
        Possible values: "first", "last", "min", "max", "sum".

    Returns
    -------
    da.Array
        The painted result as a dask array. The leading axis is kept with
        length 1.

    Raises
    ------
    ValueError
        If `mosaic_method` is not one of the supported values.

    """
    if mosaic_method == "first":
        painted_result = _dask_push_take_last(stacked_data, nodata)
    elif mosaic_method == "last":
        painted_result = _dask_push_take_last(stacked_data[::-1], nodata)
    elif mosaic_method in ("min", "max"):
        want_min = mosaic_method == "min"
        mask = rts.raster.get_mask_from_data(stacked_data, nodata)
        if np.issubdtype(stacked_data.dtype, np.floating):
            # Hide missing cells behind NaN and let the nan-aware reductions
            # skip them.
            filled = da.where(mask, np.nan, stacked_data)
            op = da.nanmin if want_min else da.nanmax
        else:
            # Replace missing cells with the value that can never win the
            # reduction: the dtype's maximum for "min", its minimum for "max".
            if stacked_data.dtype == np.bool_:
                # np.iinfo raises ValueError for booleans, so spell the
                # extremes out by hand.
                lowest, highest = False, True
            else:
                info = np.iinfo(stacked_data.dtype)
                lowest, highest = info.min, info.max
            sentinel = highest if want_min else lowest
            filled = da.where(mask, sentinel, stacked_data)
            op = da.min if want_min else da.max
        reduced = op(filled, axis=0, keepdims=True)
        # Cells with no valid contribution at all go back to nodata.
        all_missing = mask.all(axis=0, keepdims=True)
        painted_result = da.where(all_missing, nodata, reduced)
    elif mosaic_method == "sum":
        mask = rts.raster.get_mask_from_data(stacked_data, nodata)
        # Missing cells contribute nothing to the sum.
        stacked_data = da.where(mask, 0, stacked_data)
        # Make sure to set the dtype. Sum will upcast otherwise
        summed_data = da.sum(
            stacked_data, axis=0, keepdims=True, dtype=stacked_data.dtype
        )
        mask = mask.all(axis=0, keepdims=True)
        painted_result = da.where(mask, nodata, summed_data)
    else:
        raise ValueError("Invalid mosaic method")
    return painted_result


def _paint_recursive(stacked_data, dst_nodata, mosaic_method):
    """Paint a stack of layers by divide and conquer.

    Stacks of fewer than 8 layers are painted directly. Larger stacks are
    split in half, each half is painted on its own, and the two results are
    concatenated and painted again. Recursing this way greatly reduces dask's
    memory usage; 8 was found to be a good cutoff.
    """
    layer_count = len(stacked_data)
    if layer_count >= 8:
        midpoint = layer_count // 2
        painted_halves = [
            _paint_recursive(part, dst_nodata, mosaic_method)
            for part in (stacked_data[:midpoint], stacked_data[midpoint:])
        ]
        merged = da.concatenate(painted_halves, axis=0)
        return _paint_recursive(merged, dst_nodata, mosaic_method)
    return _paint(stacked_data, dst_nodata, mosaic_method)


def _mosaic_single_band(src_rasters, dst_raster, mosaic_method):
    """Mosaic one band's worth of source rasters onto the destination.

    The sources are stacked in reverse order, followed by a layer filled with
    the destination's null value, and the stack is then painted with
    `mosaic_method`.
    """
    fill_value = dst_raster.null_value
    # Reverse so that the first raster's data sits just before the
    # destination layer.
    layers = [raster.data for raster in reversed(src_rasters)]
    layers.append(da.full_like(dst_raster.data, fill_value))
    stack = da.concatenate(layers, axis=0)
    return _paint_recursive(stack, fill_value, mosaic_method)


def _mosaic(src_rasters, dst_raster, mosaic_method):
    """Mosaic multi-band sources band by band and reassemble the result.

    The output has as many bands as the source with the most bands. A source
    only contributes to the bands it actually has.
    """
    band_count = max(src.nbands for src in src_rasters)

    # One bucket per output band, holding that band from each source that
    # has it.
    per_band_sources = [[] for _ in range(band_count)]
    for src in src_rasters:
        band_numbers = range(1, src.nbands + 1)
        for band_number, bucket in zip(band_numbers, per_band_sources):
            bucket.append(src.get_bands(band_number))

    painted_bands = [
        _mosaic_single_band(bucket, dst_raster, mosaic_method)
        for bucket in per_band_sources
    ]
    merged = da.concatenate(painted_bands, axis=0)
    return rts.data_to_raster_like(
        merged, dst_raster, nv=dst_raster.null_value
    )


# Overlap-resolution methods accepted by ``mosaic`` for ``mosaic_method``.
_MOSAIC_OPS = {"first", "last", "min", "max", "sum"}
# Names of all members of rasterio's ``Resampling`` enum; ``mosaic`` checks
# ``resampling_method`` against this set.
_RESAMPLING_METHODS = {v.name for v in rio.warp.Resampling}


def mosaic(
    rasters,
    mosaic_method="last",
    dst_crs=None,
    dst_grid=None,
    resampling_method="nearest",
    dtype=None,
    null_value=None,
):
    """Mosaic multiple rasters into a new, single raster.

    The inputs can have multiple and differing numbers of bands. The number of
    bands in the output will be the same as the input with the largest number
    of bands.

    Parameters
    ----------
    rasters : list of raster_tools.Raster
        A list-like object containing the rasters to be mosaicked. These can
        have differing grids, resolutions, and projections.
    mosaic_method : str, optional
        The method to use when resolving overlap. Valid options are:

        'first'
            The final pixel will take its value from the first raster with a
            valid pixel at the pixel's location.
        'last'
            The final pixel will take its value from the last raster with a
            valid pixel at the pixel's location. Default.
        'min'
            The final pixel will take its value from the minimum valid value
            across all input rasters, at the given location.
        'max'
            The final pixel will take its value from the maximum valid value
            across all input rasters, at the given location.
        'sum'
            The final pixel will be the sum of all valid pixels at the given
            location. Note, this can lead to overflow issues for sufficiently
            large values and small enough dtypes.

    dst_crs : CRS-like, str, int, optional
        The destination CRS to use when building the destination grid. This
        can be anything that can be parsed by
        :py:meth:`rasterio.CRS.from_user_input`. This is only checked if
        `dst_grid` is not provided. The default is to take the CRS from the
        first raster in `rasters`.
    dst_grid : odc.geo.GeoBox, raster_tools.Raster, str, optional
        The definition for the destination grid to mosaic the `rasters` onto.
        This can be a :py:class:`odc.geo.GeoBox`,
        :py:class:`raster_tools.Raster` object, or a path str. If the input is
        a raster object or path, this function does NOT write to the given
        raster, it instead uses the raster as a reference for the grid. The
        default is to check `dst_crs` and construct a grid that encompasses
        all inputs, using the resolution from the first raster in `rasters`.
    resampling_method : str, optional
        Resampling method to use when reprojecting input rasters to the
        destination grid. The default is nearest. Valid options are:

        'nearest'
            Nearest neighbor resampling. This is the default.
        'bilinear'
            Bilinear resampling.
        'cubic'
            Cubic resampling.
        'cubic_spline'
            Cubic spline resampling.
        'lanczos'
            Lanczos windowed sinc resampling.
        'average'
            Average resampling, computes the weighted average of all
            contributing pixels.
        'mode'
            Mode resampling, selects the value which appears most often.
        'max'
            Maximum resampling. (GDAL 2.0+)
        'min'
            Minimum resampling. (GDAL 2.0+)
        'med'
            Median resampling. (GDAL 2.0+)
        'q1'
            Q1, first quartile resampling. (GDAL 2.0+)
        'q3'
            Q3, third quartile resampling. (GDAL 2.0+)
        'sum'
            Sum, compute the weighted sum. (GDAL 3.1+)
        'rms'
            RMS, root mean square/quadratic mean. (GDAL 3.3+)

    dtype : numpy.dtype, str, optional
        The dtype for the output raster. The default is to use
        :py:func:`numpy.result_type` on the dtypes from the input rasters.
    null_value : scalar, optional
        The nodata/null value to use for the output raster. The default is
        the first null value found among the input rasters. If none of the
        inputs has one, a default value based on the output raster's dtype
        is used.

    Returns
    -------
    Raster
        The resulting mosaicked raster.

    Raises
    ------
    ValueError
        If `rasters` is empty, or if `mosaic_method` or `resampling_method`
        is not a valid option.
    TypeError
        If `null_value` is neither None nor a number, or if `dst_grid` is not
        a GeoBox, Raster or path string.

    """
    if len(rasters) == 0:
        raise ValueError("No rasters provided")
    if mosaic_method not in _MOSAIC_OPS:
        raise ValueError("Invalid mosaic operation")
    resampling_method = resampling_method or "nearest"
    if resampling_method not in _RESAMPLING_METHODS:
        raise ValueError("Invalid resampling method")
    nodata = null_value

    src_rasters = [rts.get_raster(r) for r in rasters]
    dtype = (
        np.dtype(dtype)
        if dtype is not None
        else np.result_type(*[r.dtype for r in src_rasters])
    )

    if (nodata is not None) and (not isinstance(nodata, numbers.Number)):
        raise TypeError("nodata must be a scalar or None")
    if nodata is None:
        # Prefer a null value already used by one of the inputs; only fall
        # back to the dtype's default when none of them defines one.
        nodatas = [
            src.null_value for src in src_rasters if src.null_value is not None
        ]
        if nodatas:
            nodata = nodatas[0]
        else:
            nodata = rts.masking.get_default_null_value(dtype)

    if dst_grid is None:
        dst_crs = (
            dst_crs if dst_crs is None else rio.CRS.from_user_input(dst_crs)
        )
        dst_grid = _build_dst_grid_from_inputs(src_rasters, dst_crs=dst_crs)
    elif isinstance(dst_grid, (str, rts.Raster)):
        # A raster or path is only a reference for the grid; it is never
        # written to.
        dst_grid = rts.get_raster(dst_grid).geobox
    elif not isinstance(dst_grid, GeoBox):
        raise TypeError(
            f"Expected dst_grid to have type GeoBox. Got {type(dst_grid)}"
        )

    # Make sure inputs are on destination grid
    src_rasters_in_dst = [
        (
            r
            if _are_all_grids_same([r, dst_grid])
            else r.reproject(dst_grid, resample_method=resampling_method)
        ).astype(dtype, new_null_value=nodata)
        for r in src_rasters
    ]

    dst_raster = _build_raster_from_grid(dst_grid, dtype, nodata)
    # Rechunk all of the reprojected inputs so they are chunk aligned. This
    # greatly boosts the performance of the dask operations down the line,
    # such as da.concatenate.
    target_chunks_2d = dst_raster.data.chunks[1:]
    src_rasters_in_dst = [
        sr.chunk(((1,) * sr.nbands, *target_chunks_2d))
        for sr in src_rasters_in_dst
    ]

    # TODO: check for all boolean
    return _mosaic(src_rasters_in_dst, dst_raster, mosaic_method)