Skip to content

PR: Implement support for *Python Array API Standard*.#1406

Open
KelSolaar wants to merge 1 commit into
developfrom
feature/array-api-support
Open

PR: Implement support for *Python Array API Standard*.#1406
KelSolaar wants to merge 1 commit into
developfrom
feature/array-api-support

Conversation

@KelSolaar

@KelSolaar KelSolaar commented Jun 2, 2026

Copy link
Copy Markdown
Member

Summary

This PR implements support for the Python Array API Standard, enabling computations to dispatch onto alternative array backends:

  • NumPy (default)
  • JAX
  • PyTorch (including Apple MPS)

Dispatch is currently opt-in and NumPy-only behaviour is unchanged by default. Once enabled, the backend is selected from the type of the input array. It can be enabled three ways:

1. Environment variable, set before importing Colour:

import os

os.environ["COLOUR_SCIENCE__ARRAY_API"] = "1"

import colour
import jax.numpy as jnp

colour.XYZ_to_sRGB(jnp.array([0.20654008, 0.12197225, 0.05136952]))
# Array([0.7057394 , 0.19248262, 0.2235417 ], dtype=float32)

2. Programmatically, toggle the global state at runtime:

import colour
import torch
from colour.utilities import set_array_api_enabled

set_array_api_enabled(True)

colour.XYZ_to_sRGB(torch.tensor([0.20654008, 0.12197225, 0.05136952]))
# tensor([0.7057, 0.1925, 0.2235], dtype=torch.float64)

3. Scoped context manager (also usable as a decorator), enable for a block only:

import colour
import torch
from colour.utilities import array_api_enable

with array_api_enable(True):
    colour.XYZ_to_sRGB(
        torch.tensor([0.20654008, 0.12197225, 0.05136952], device="mps")
    )
# tensor([0.7057, 0.1925, 0.2235], device='mps:0')

What's added

  • Namespace-aware boundary helpers + a full xp_* operation surface in colour.utilities:
    • Namespace resolution: array_namespace, is_numpy_namespace, is_non_ndarray, trace_array_namespace
    • Boundary conversion: as_ndarray, cast_non_ndarray, xp_as_array / xp_as_float_array / xp_as_int_array, xp_astype, xp_ascontiguousarray
    • Shape & manipulation: xp_reshape, xp_squeeze, xp_atleast_1d / xp_atleast_2d, xp_broadcast_to, xp_matrix_transpose, xp_resize, xp_pad, xp_insert
    • Reductions & statistics: xp_average, xp_median, xp_nanmean, xp_trapezoid, xp_gradient
    • Element-wise math: xp_degrees / xp_radians, xp_sinc, xp_round, xp_nan_to_num
    • Linear algebra: xp_lstsq, xp_eig / xp_eigh, xp_create_diagonal
    • Sampling, interpolation & set operations: xp_linspace, xp_interp, xp_select, xp_isin, xp_setxor1d, xp_unique
    • Comparison & testing: xp_isclose, xp_assert_close, xp_assert_equal
  • contextvars-backed global state (Array API enablement, domain-range scale, ndarray copy, caching) for thread/async safety.
  • SciPy-free, dispatchable kernels replacing solver/interpolator hotspots: correlated colour temperature Gauss-Newton (colour.temperature.common), Jakob and Hanika (2019) trilinear interpolation, etc.
  • Default complex precision: COLOUR_SCIENCE__DEFAULT_COMPLEX_DTYPE / set_default_complex_dtype.
  • New public API: CIE_illuminant_D_series, msds_CIE_illuminant_D_series, msds_blackbody, msds_rayleigh_jeans.
  • Cross-backend testing: an xp pytest fixture parametrising numpy/jax/torch/torch-mps, with mps_tolerance_absolute and mps_xfail markers for float32 precision, plus a cross-backend benchmark suite (utilities/benchmark.py).
  • Documentation: a dedicated Array API Support section in advanced.rst.

Performance

Per-suite speed-up vs NumPy (best-of-3, HD inputs): speed-up = NumPy ÷ backend over cases succeeding on both, so higher = faster (e.g. 3.0× = 3× faster than NumPy; < 1.0× = slower). numpy (ms) is the summed best-of-3 over the suite's cases.

Suite cases numpy (ms) jax torch-cpu torch-mps
conversion_graph 207 31455.1 4.4× 3.5× 11×
conversion_graph_iterative 4 36874.8 3.3× 1.8× 12×
difference 17 2139.4 3.1× 3.0× 18×
integration_array 2 201.5 644× 1.00× 15×
integration_object 6 3.2 1.3× 0.77× 0.76×
transfer_function 114 3721.2 15× 2.3× 10×
adaptation 7 1066.4 4.4× 4.8× 6.0×
characterisation 3 249.9 6.4× 4.0× 9.5×
recovery_array 4 1878.3 2.7× 3.0× 7.2×
recovery_object 3 452.6 0.65× 0.66× 0.48×
quality_array 4 488.9 2.4× 2.8× 1.5×
quality_object 5 5.1 0.10× 0.39× 0.05×
volume 2 39.5 1.1× 0.98× 1.2×
volume_iterative 2 1856.4 1.00× 0.97× 16×
phenomena 5 172.8 3.0× 4.5× 27×
temperature_array 4 294.7 4.7× 1.6× 18×
temperature_iterative 4 1015.5 0.75× 1.3× 0.53×
blindness 3 163.4 271× 15× 8.1×
contrast 1 78.6 5.1× 3.5× 26×
generators_array 3 51.8 3.3× 1.7× 1.9×
generators_object 6 12.7 1.3× 1.3× 0.84×
photometry 3 0.1 0.08× 0.18× 0.02×
overall 409 82221.8 3.3× 2.2× 8.5×

NumPy is the baseline (1.00×). Measured on an Apple M1 Max (10-core, 32 GB), macOS 15.7, Python 3.13, NumPy 2.3, PyTorch 2.9, JAX 0.8; 409 cases across 22 suites.

Preflight

Code Style and Quality

  • Unit tests have been implemented and passed.
  • Pyright static checking has been run and passed.
  • Pre-commit hooks have been run and passed.
  • [N/A] New transformations have been added to the Automatic Colour Conversion Graph.
  • New transformations have been exported to the relevant namespaces, e.g. colour, colour.models.

Documentation

  • New features are documented along with examples if relevant.
  • The documentation is Sphinx and numpydoc compliant.

@KelSolaar KelSolaar force-pushed the feature/array-api-support branch 3 times, most recently from 68f59e9 to bc223a6 Compare June 3, 2026 08:28
@KelSolaar KelSolaar force-pushed the feature/array-api-support branch from bc223a6 to 3163508 Compare June 4, 2026 11:24
@KelSolaar KelSolaar changed the title Implement support for *Python Array API Standard*. PR: Implement support for *Python Array API Standard*. Jun 6, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant