Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion xarray/compat/array_api_compat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from types import ModuleType

import numpy as np

from xarray.namedarray.pycompat import array_type
Expand Down Expand Up @@ -46,7 +48,7 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype:
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)


def get_array_namespace(*values):
def get_array_namespace(*values) -> ModuleType:
def _get_single_namespace(x):
if hasattr(x, "__array_namespace__"):
return x.__array_namespace__()
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def __array__(
else:
return np.asarray(self.get_duck_array(), dtype=dtype)

def get_duck_array(self):
def get_duck_array(self) -> duckarray:
return self.array.get_duck_array()

def __getitem__(self, key: Any):
Expand Down
5 changes: 5 additions & 0 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,11 @@ def dtype(self: Any) -> np.dtype:
def shape(self: Any) -> tuple[int, ...]:
return self.array.shape

def __array_namespace__(self: Any) -> ModuleType:
from xarray.compat.array_api_compat import get_array_namespace

return get_array_namespace(self.array)

def __getitem__(self: Any, key):
return self.array[key]

Expand Down
5 changes: 3 additions & 2 deletions xarray/namedarray/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def from_array(
import dask.array as da

if isinstance(data, ImplicitToExplicitIndexingAdapter):
# lazily loaded backend array classes should use NumPy array operations.
kwargs["meta"] = np.ndarray
# lazily loaded backend array classes should use NumPy or CuPy array operations.
xp = data.__array_namespace__()
kwargs["meta"] = xp.ndarray

return da.from_array(
data,
Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType,
return loaded_data

if isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter):
return data.get_duck_array() # type: ignore[no-untyped-call, no-any-return]
return data.get_duck_array()
elif is_duck_array(data):
return data
else:
Expand Down
Loading