diff --git a/xarray/compat/array_api_compat.py b/xarray/compat/array_api_compat.py index e1e5d5c5bdc..575f8cdf07d 100644 --- a/xarray/compat/array_api_compat.py +++ b/xarray/compat/array_api_compat.py @@ -1,3 +1,5 @@ +from types import ModuleType + import numpy as np from xarray.namedarray.pycompat import array_type @@ -46,7 +48,7 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype: return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) -def get_array_namespace(*values): +def get_array_namespace(*values) -> ModuleType: def _get_single_namespace(x): if hasattr(x, "__array_namespace__"): return x.__array_namespace__() diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index bb12704e55c..c0cb9a5777f 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -687,7 +687,7 @@ def __array__( else: return np.asarray(self.get_duck_array(), dtype=dtype) - def get_duck_array(self): + def get_duck_array(self) -> duckarray: return self.array.get_duck_array() def __getitem__(self, key: Any): diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 100c256fa9d..b9e0ebf5442 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -704,6 +704,11 @@ def dtype(self: Any) -> np.dtype: def shape(self: Any) -> tuple[int, ...]: return self.array.shape + def __array_namespace__(self: Any) -> ModuleType: + from xarray.compat.array_api_compat import get_array_namespace + + return get_array_namespace(self.array) + def __getitem__(self: Any, key): return self.array[key] diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index eb01a150c18..c03b9a4da13 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -68,8 +68,9 @@ def from_array( import dask.array as da if isinstance(data, ImplicitToExplicitIndexingAdapter): - # lazily loaded backend array classes should use NumPy array operations. - kwargs["meta"] = np.ndarray + # lazily loaded backend array classes should use NumPy or CuPy array operations. + xp = data.__array_namespace__() + kwargs["meta"] = xp.ndarray return da.from_array( data, diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py index 5832f7cc9e7..b41711e07ab 100644 --- a/xarray/namedarray/pycompat.py +++ b/xarray/namedarray/pycompat.py @@ -140,7 +140,7 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, return loaded_data if isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter): - return data.get_duck_array() # type: ignore[no-untyped-call, no-any-return] + return data.get_duck_array() elif is_duck_array(data): return data else: