diff --git a/tests/test_axis.py b/tests/test_axis.py index b85fb344..df604f2b 100644 --- a/tests/test_axis.py +++ b/tests/test_axis.py @@ -247,6 +247,32 @@ def test_returns_dataarray_dimension_coordinate_var_using_standard_name_attr(sel assert result.identical(expected) + def test_multidim_finds_coords_only_in_cf_axes(self): + """Test that multidim=True finds coords discoverable via cf axes.""" + # Create a dataset where 'lat' has axis="Y" attribute but no + # standard_name or recognized coordinate name that cf_xarray would + # put in obj.cf.coordinates. + lat = xr.DataArray( + data=np.array([[0, 1], [2, 3]]), + dims=["y", "x"], + attrs={"axis": "Y"}, + ) + lon = xr.DataArray( + data=np.array([[10, 11], [12, 13]]), + dims=["y", "x"], + attrs={"axis": "X"}, + ) + ds = xr.Dataset(coords={"lat": lat, "lon": lon}) + + # Without multidim, this would raise KeyError because lat/lon are + # multidimensional and not in indexes. + with pytest.raises(KeyError): + get_dim_coords(ds, "Y", multidim=False) + + # With multidim=True, it should find lat via cf axes. + result = get_dim_coords(ds, "Y", multidim=True) + xr.testing.assert_identical(result, ds["lat"]) + class TestGetCoordsByName: def test_raises_error_if_coordinate_not_found(self): diff --git a/tests/test_regrid.py b/tests/test_regrid.py index ab9b619e..b0c251fc 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -1617,6 +1617,96 @@ def test_vertical_tool_check(self, _get_input_grid): ): self.ac.vertical("ts", mock_data, tool="dummy", target_data=None) # type: ignore + def test_horizontal_with_multidim_coords_issue_816(self): + """Regression test for #816: dataset with 2D lat/lon and no 1D coord vars. + + Ensures that ds.regridder.horizontal(..., tool="xesmf") works on + datasets shaped like curvilinear/unstructured grids where lat and lon + are multidimensional coordinates (not index dimensions). + """ + ny, nx = 10, 20 + + lat_2d = np.linspace(-90, 90, ny * nx).reshape(ny, nx) + lon_2d = np.linspace(-180, 180, ny * nx).reshape(ny, nx) + + lat = xr.DataArray( + data=lat_2d, + dims=["y", "x"], + attrs={"units": "degrees_north", "axis": "Y"}, + ) + lon = xr.DataArray( + data=lon_2d, + dims=["y", "x"], + attrs={"units": "degrees_east", "axis": "X"}, + ) + + ts = xr.DataArray( + data=np.random.default_rng(42).random((ny, nx)), + dims=["y", "x"], + attrs={"units": "K"}, + ) + + ds = xr.Dataset( + data_vars={"ts": ts}, + coords={"lat": lat, "lon": lon}, + ) + + output_grid = grid.create_uniform_grid(-90, 90, 30.0, -180, 180, 60.0) + + # This should not raise KeyError; the multidim path must discover + # lat/lon via CF axis metadata. + output = ds.regridder.horizontal( + "ts", output_grid, tool="xesmf", method="bilinear" + ) + + assert "ts" in output + assert output.ts.dims == ("lat", "lon") + + def test_horizontal_with_nonstandard_multidim_coord_names_issue_816(self): + """Regression test for #816: multidim coords with non-standard names. + + Ensures that ds.regridder.horizontal(..., tool="xesmf") works when + 2D coordinates use non-standard names (e.g., nav_lat/nav_lon) that + are not in VAR_NAME_MAP but are discoverable via CF axis attributes. + """ + ny, nx = 10, 20 + + lat_2d = np.linspace(-90, 90, ny * nx).reshape(ny, nx) + lon_2d = np.linspace(-180, 180, ny * nx).reshape(ny, nx) + + nav_lat = xr.DataArray( + data=lat_2d, + dims=["y", "x"], + attrs={"units": "degrees_north", "axis": "Y"}, + ) + nav_lon = xr.DataArray( + data=lon_2d, + dims=["y", "x"], + attrs={"units": "degrees_east", "axis": "X"}, + ) + + ts = xr.DataArray( + data=np.random.default_rng(42).random((ny, nx)), + dims=["y", "x"], + attrs={"units": "K"}, + ) + + ds = xr.Dataset( + data_vars={"ts": ts}, + coords={"nav_lat": nav_lat, "nav_lon": nav_lon}, + ) + + output_grid = grid.create_uniform_grid(-90, 90, 30.0, -180, 180, 60.0) + + # This should not raise KeyError; the multidim flag must be plumbed + # through _get_axis_coord_and_bounds to get_dim_coords. + output = ds.regridder.horizontal( + "ts", output_grid, tool="xesmf", method="bilinear" + ) + + assert "ts" in output + assert output.ts.dims == ("lat", "lon") + class TestBase: def test_preserve_bounds(self): @@ -1677,3 +1767,25 @@ def horizontal(self, data_var, ds): ds_out = regridder.horizontal("ts", ds_in) assert ds_in == ds_out + + def test_supports_multidim_defaults_to_false(self): + """Test that BaseRegridder subclasses default supports_multidim to False.""" + + class MinimalRegridder(base.BaseRegridder): + def horizontal(self, data_var, ds): + return ds + + assert MinimalRegridder.supports_multidim is False + assert MinimalRegridder.can_handle_multidim() is False + + def test_supports_multidim_can_be_overridden(self): + """Test that subclasses can override supports_multidim to True.""" + + class MultidimRegridder(base.BaseRegridder): + supports_multidim = True + + def horizontal(self, data_var, ds): + return ds + + assert MultidimRegridder.supports_multidim is True + assert MultidimRegridder.can_handle_multidim() is True diff --git a/xcdat/axis.py b/xcdat/axis.py index 8cab0657..ff88ebf4 100644 --- a/xcdat/axis.py +++ b/xcdat/axis.py @@ -5,6 +5,7 @@ from typing import Literal +import cf_xarray as cfxr # noqa: F401 import numpy as np import xarray as xr @@ -73,7 +74,7 @@ def get_dim_keys(obj: xr.Dataset | xr.DataArray, axis: CFAxisKey) -> str | list[ def get_dim_coords( - obj: xr.Dataset | xr.DataArray, axis: CFAxisKey + obj: xr.Dataset | xr.DataArray, axis: CFAxisKey, multidim: bool = False ) -> xr.Dataset | xr.DataArray: """Gets the dimension coordinates for an axis. @@ -117,10 +118,21 @@ def get_dim_coords( ---------- .. [1] https://cf-xarray.readthedocs.io/en/latest/coord_axes.html#axes-and-coordinates """ - # Get the object's index keys, with each being a dimension. - # NOTE: xarray does not include multidimensional coordinates as index keys. - # Example: ["lat", "lon", "time"] - index_keys = obj.indexes.keys() + if multidim: + # multidimensional coordinates cannot be indexes, use all coords. + # Combine both obj.cf.coordinates and obj.cf.axes to avoid missing + # coordinates that are only discoverable via CF axis metadata. + cf_keys: set[str] = set() + for keys in obj.cf.coordinates.values(): + cf_keys.update(keys) + for keys in obj.cf.axes.values(): + cf_keys.update(keys) + index_keys = list(cf_keys) + else: + # Get the object's index keys, with each being a dimension. + # NOTE: xarray does not include multidimensional coordinates as index keys. + # Example: ["lat", "lon", "time"] + index_keys = list(obj.indexes.keys()) # Attempt to map the axis it all of its coordinate variable(s) using the # axis and coordinate names in the object attributes (if they are set). diff --git a/xcdat/regridder/accessor.py b/xcdat/regridder/accessor.py index 40062aed..132d8d9d 100644 --- a/xcdat/regridder/accessor.py +++ b/xcdat/regridder/accessor.py @@ -5,16 +5,17 @@ from xcdat.axis import CFAxisKey, get_coords_by_name, get_dim_coords from xcdat.bounds import create_bounds from xcdat.regridder import regrid2, xesmf, xgcm +from xcdat.regridder.base import BaseRegridder from xcdat.regridder.grid import _validate_grid_has_single_axis_dim HorizontalRegridTools = Literal["xesmf", "regrid2"] -HORIZONTAL_REGRID_TOOLS = { +HORIZONTAL_REGRID_TOOLS: dict[str, type[BaseRegridder]] = { "regrid2": regrid2.Regrid2Regridder, "xesmf": xesmf.XESMFRegridder, } VerticalRegridTools = Literal["xgcm"] -VERTICAL_REGRID_TOOLS = {"xgcm": xgcm.XGCMRegridder} +VERTICAL_REGRID_TOOLS: dict[str, type[BaseRegridder]] = {"xgcm": xgcm.XGCMRegridder} @xr.register_dataset_accessor(name="regridder") @@ -166,7 +167,9 @@ def horizontal( f"Tool {e!s} does not exist, valid choices {list(HORIZONTAL_REGRID_TOOLS)}" ) from e - input_grid = _get_input_grid(self._ds, data_var, ["X", "Y"]) + input_grid = _get_input_grid( + self._ds, data_var, ["X", "Y"], multidim=regrid_tool.can_handle_multidim() + ) regridder = regrid_tool(input_grid, output_grid, **options) output_ds = regridder.horizontal(data_var, self._ds) @@ -236,20 +239,17 @@ def vertical( f"Tool {e!s} does not exist, valid choices " f"{list(VERTICAL_REGRID_TOOLS)}" ) from e - input_grid = _get_input_grid( - self._ds, - data_var, - [ - "Z", - ], - ) + + input_grid = _get_input_grid(self._ds, data_var, ["Z"]) regridder = regrid_tool(input_grid, output_grid, **options) output_ds = regridder.vertical(data_var, self._ds) return output_ds -def _obj_to_grid_ds(obj: xr.Dataset | xr.DataArray) -> xr.Dataset: +def _obj_to_grid_ds( + obj: xr.Dataset | xr.DataArray, multidim: bool = False +) -> xr.Dataset: """ Convert an xarray object to a new Dataset containing axis coordinates and bounds. @@ -286,7 +286,7 @@ def _obj_to_grid_ds(obj: xr.Dataset | xr.DataArray) -> xr.Dataset: with xr.set_options(keep_attrs=True): for axis in axis_names: - coord, bounds = _get_axis_coord_and_bounds(obj, axis) + coord, bounds = _get_axis_coord_and_bounds(obj, axis, multidim=multidim) if coord is not None: axis_coords[str(coord.name)] = coord @@ -304,12 +304,17 @@ def _obj_to_grid_ds(obj: xr.Dataset | xr.DataArray) -> xr.Dataset: attrs=obj.attrs, ) + # Multidimensional coordinates bounds generation is not supported + if multidim: + return output_ds + # Add bounds only for axes that do not already have them. This # prevents multiple sets of bounds being added for the same axis. # For example, curvilinear grids can have multiple coordinates for the # same axis (e.g., (nlat, lat) for X and (nlon, lon) for Y). We only # need lat_bnds and lon_bnds for the X and Y axes, respectively, and not # nlat_bnds and nlon_bnds. + for axis, has_bounds in axis_has_bounds.items(): if not has_bounds: output_ds = output_ds.bounds.add_bounds(axis=axis) @@ -318,7 +323,7 @@ def _obj_to_grid_ds(obj: xr.Dataset | xr.DataArray) -> xr.Dataset: def _get_axis_coord_and_bounds( - obj: xr.Dataset | xr.DataArray, axis: CFAxisKey + obj: xr.Dataset | xr.DataArray, axis: CFAxisKey, multidim: bool = False ) -> tuple[xr.DataArray | None, xr.DataArray | None]: try: coord_var = get_coords_by_name(obj, axis) @@ -328,7 +333,7 @@ def _get_axis_coord_and_bounds( ) except (ValueError, KeyError): try: - coord_var = get_dim_coords(obj, axis) # type: ignore + coord_var = get_dim_coords(obj, axis, multidim=multidim) # type: ignore _validate_grid_has_single_axis_dim(axis, coord_var) except KeyError: coord_var = None @@ -347,7 +352,12 @@ def _get_axis_coord_and_bounds( return coord_var, bounds_var -def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: list[CFAxisKey]): +def _get_input_grid( + ds: xr.Dataset, + data_var: str, + dup_check_dims: list[CFAxisKey], + multidim: bool = False, +): """ Extract the grid from ``ds``. @@ -374,10 +384,12 @@ def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: list[CFAxisKe all_coords = set(ds.coords.keys()) for dimension in dup_check_dims: - coords = get_dim_coords(ds, dimension) + coords = get_dim_coords(ds, dimension, multidim=multidim) if isinstance(coords, xr.Dataset): - coord = set([get_dim_coords(ds[data_var], dimension).name]) + coord = set( + [get_dim_coords(ds[data_var], dimension, multidim=multidim).name] + ) dimension_coords = set(ds.cf[[dimension]].coords.keys()) @@ -387,7 +399,7 @@ def _get_input_grid(ds: xr.Dataset, data_var: str, dup_check_dims: list[CFAxisKe input_grid = ds.drop_dims(to_drop) # drops extra dimensions on input grid - grid = input_grid.regridder.grid + grid = _obj_to_grid_ds(input_grid, multidim=multidim) # preserve mask on grid if "mask" in ds: diff --git a/xcdat/regridder/base.py b/xcdat/regridder/base.py index 0458abad..571e1255 100644 --- a/xcdat/regridder/base.py +++ b/xcdat/regridder/base.py @@ -95,6 +95,12 @@ def _drop_axis(ds: xr.Dataset, axis: list[CFAxisKey]) -> xr.Dataset: class BaseRegridder(abc.ABC): """BaseRegridder.""" + supports_multidim: bool = False + + @classmethod + def can_handle_multidim(cls) -> bool: + return cls.supports_multidim + def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: Any): self._input_grid = input_grid self._output_grid = output_grid diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index bb702c4b..dfa90eab 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -11,6 +11,8 @@ class Regrid2Regridder(BaseRegridder): + supports_multidim = False + def __init__( self, input_grid: xr.Dataset, diff --git a/xcdat/regridder/xesmf.py b/xcdat/regridder/xesmf.py index f4fa8961..325da558 100644 --- a/xcdat/regridder/xesmf.py +++ b/xcdat/regridder/xesmf.py @@ -19,6 +19,8 @@ class XESMFRegridder(BaseRegridder): + supports_multidim = True + def __init__( self, input_grid: xr.Dataset, diff --git a/xcdat/regridder/xgcm.py b/xcdat/regridder/xgcm.py index c0619755..337c4788 100644 --- a/xcdat/regridder/xgcm.py +++ b/xcdat/regridder/xgcm.py @@ -13,6 +13,8 @@ class XGCMRegridder(BaseRegridder): + supports_multidim = False + def __init__( self, input_grid: xr.Dataset,