Support CuPy-backed arrays in DaskManager#11383

Draft
weiji14 wants to merge 4 commits into pydata:main from weiji14:dask_with_cupy

@weiji14 weiji14 commented Jun 13, 2026

Contributor

Description

The default Dask ChunkManagerEntrypoint appears to be hardcoded to return NumPy arrays, even when the underlying arrays are CuPy arrays.

TODO:

Probably needs #11381 to be merged first. Part of resolving xarray-contrib/cupy-xarray#81 (comment).

Checklist

  • Closes #xxxx
  • Tests added
  • User visible changes (including notable bug fixes) are documented in whats-new.rst
  • New functions/methods are listed in api.rst

AI Disclosure

The "meta" argument passed to dask.array.from_array should not be hardcoded to just `numpy.ndarray`, but allow for `cupy.ndarray` too.
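A minimal sketch of the idea in this commit message, assuming the array class of the wrapped data can be used directly as `meta`. `infer_meta` is a hypothetical helper for illustration, not part of xarray or Dask; the only real API referenced is the `meta` keyword of `dask.array.from_array`, shown in a comment so the sketch runs without Dask or CuPy installed.

```python
import numpy as np


def infer_meta(duck_array):
    """Return the array class to pass as ``meta`` (hypothetical helper)."""
    cls = type(duck_array)
    # Assumption: only NumPy and CuPy array classes are accepted as-is;
    # anything else falls back to numpy.ndarray, the previous hardcoded value.
    if cls.__module__.split(".")[0] in {"numpy", "cupy"}:
        return cls
    return np.ndarray


data = np.arange(6).reshape(2, 3)
kwargs = {"meta": infer_meta(data)}  # would be cupy.ndarray for a CuPy input

# With Dask installed, the keyword would be forwarded roughly like this:
#     dask.array.from_array(data, chunks=(1, 3), **kwargs)
```

Checking the top-level module name keeps the sketch free of a hard CuPy import, which matters on CPU-only machines.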
@github-actions (bot) added the topic-NamedArray label Jun 13, 2026
Not sure how to type-hint np | cp | ??, so just use Any for output of get_array_namespace.
Comment thread xarray/namedarray/daskmanager.py Outdated
- # lazily loaded backend array classes should use NumPy array operations.
- kwargs["meta"] = np.ndarray
+ # lazily loaded backend array classes should use NumPy or CuPy array operations.
+ xp = get_array_namespace(data.get_duck_array())

@keewis keewis Jun 13, 2026

Collaborator

We probably need to add some API that allows getting the underlying array type / library without actually fetching data. Something like a data.get_array_namespace() or data.get_meta()? Not sure how easy it would be to implement that, though.

(I think this is what causes the tests to fail)
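One way to picture the suggestion above is a lazy array that records its namespace when it is constructed, so the query never touches the data. This is only a sketch: `LazyBackendArray`, its `loader` argument, and the `loads` counter are hypothetical names for illustration, and NumPy stands in for whichever namespace (NumPy or CuPy) the backend would report.

```python
from types import ModuleType

import numpy as np


class LazyBackendArray:
    """Hypothetical lazily loaded array that knows its namespace up front."""

    def __init__(self, loader, shape, dtype, xp: ModuleType = np):
        self._loader = loader  # callable that performs the actual read
        self.shape = shape
        self.dtype = np.dtype(dtype)
        self._xp = xp  # recorded at construction, so no I/O is needed later
        self.loads = 0  # counts how often data was really fetched

    def get_array_namespace(self) -> ModuleType:
        # Cheap metadata query: never calls the loader.
        return self._xp

    def get_duck_array(self):
        self.loads += 1
        return self._loader()


lazy = LazyBackendArray(lambda: np.zeros((2, 2)), shape=(2, 2), dtype="float64")
xp = lazy.get_array_namespace()
# Build a zero-size array with the right type, dtype and number of dimensions.
meta = xp.empty((0,) * len(lazy.shape), dtype=lazy.dtype)
```

The zero-size `meta` array carries the array type and dtype without any data being loaded, which is the property the chunk manager would need.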

Contributor Author

Ideally we could just call xp = data.__array_namespace__() following the Array API spec - https://data-apis.org/array-api/2025.12/API_specification/generated/array_api.array.__array_namespace__.html, and it would propagate through all the subclassed layers to get the underlying array namespace (numpy or cupy). I thought this would work by putting it into the NDArrayMixin (b77cc57), but that breaks a lot of the lazy repls...

Might need to think this through a bit more. Wondering if there needs to be a cached .__cached_array_namespace__ attribute of some sort to work with the lazy objects...
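A rough sketch of the delegation-plus-caching idea, under stated assumptions: `NamespaceMixin` and `_cached_array_namespace` are hypothetical names, not the PR's actual implementation, and the NumPy fallback is an assumption for objects that do not implement `__array_namespace__`.

```python
from functools import cached_property

import numpy as np


class NamespaceMixin:
    """Hypothetical mixin: forward __array_namespace__ to the wrapped array."""

    def __init__(self, array):
        self.array = array

    @cached_property
    def _cached_array_namespace(self):
        inner = self.array
        # Delegate one layer down; nested wrappers repeat this step until an
        # object that really implements the protocol answers.
        if hasattr(inner, "__array_namespace__"):
            return inner.__array_namespace__()
        return np  # assumption: fall back to NumPy

    def __array_namespace__(self, *, api_version=None):
        return self._cached_array_namespace


# Two wrapper layers around a concrete NumPy array.
outer = NamespaceMixin(NamespaceMixin(np.ones(3)))
xp = outer.__array_namespace__()
```

Note that this still walks down to the innermost object on the first call, so it only avoids loading data if that innermost layer can answer without reading anything; the cache just makes repeated lookups cheap.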

Centralize retrieval of `__array_namespace__` through several subclassed layers, to avoid having to go through `.get_duck_array()`. The `from xarray.compat.array_api_compat import get_array_namespace` import needs to live inside the method to avoid a circular import.

Also type-hinted output of `get_array_namespace` as ModuleType following numpy/numpy#20719.
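For illustration, the two points in these commit messages (a `ModuleType` return annotation and an import deferred into the function body) could look roughly like the following. This `get_array_namespace` is a simplified stand-in written for this example, not xarray's helper, and the deferred `numpy` import merely plays the role of the import that would otherwise be circular.

```python
from types import ModuleType
from typing import Any


def get_array_namespace(x: Any) -> ModuleType:
    """Simplified stand-in for the helper named in the commit message."""
    # Deferred import: resolving the dependency at call time instead of at
    # module import time is a common way to sidestep a circular import.
    import numpy as np

    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np  # assumption: plain Python objects are treated as NumPy data


xp = get_array_namespace([1, 2, 3])
```

Annotating the return value as `ModuleType` lets type checkers accept any namespace module without having to enumerate NumPy, CuPy, or others.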

Labels

topic-indexing topic-NamedArray Lightweight version of Variable

2 participants