diff --git a/crates/providers/src/store.rs b/crates/providers/src/store.rs index b115efc12f29..57ea160a4a6a 100644 --- a/crates/providers/src/store.rs +++ b/crates/providers/src/store.rs @@ -89,7 +89,9 @@ mod tests { #[test] fn test_store_output_types_2d() { use ndarray::arr2; - let data = DataTree::new_leaf(Tensor::F64(arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]).into_dyn())); + let data = DataTree::new_leaf(Tensor::F64( + arr2(&[[1.0_f64, 2.0], [3.0, 4.0]]).into_dyn().into_shared(), + )); let store = Store::new(data); let DataTree::Leaf(tt) = store.output_types() else { panic!("expected leaf output type"); diff --git a/crates/providers/src/tensor.rs b/crates/providers/src/tensor.rs index e631bac1607b..5d11571d397c 100644 --- a/crates/providers/src/tensor.rs +++ b/crates/providers/src/tensor.rs @@ -10,11 +10,14 @@ // copyright notice, and modified files need to carry a notice indicating // that they have been altered from the originals. -use ndarray::{ArrayD, IxDyn, Zip}; +use ndarray::{ArcArray, ArrayD, IxDyn, Zip}; use num_complex::{Complex32, Complex64}; use std::fmt; use thiserror::Error; +/// Dynamic-dimensional [`ArcArray`]; the storage type for every [`Tensor`] variant. +type ArcArrayD = ArcArray; + /// Errors returned by [`Tensor`] operations. #[derive(Debug, Clone, PartialEq, Eq, Error)] pub enum TensorError { @@ -234,50 +237,68 @@ impl TensorType { } /// A tensor of one of the supported dtypes. +/// +/// Each variant wraps a reference-counted dynamic ndarray ([`ArcArray`]). +/// +/// This allows [`Tensor::clone`] to cause a refcount bump rather than a copy of +/// underlying data. Note that mutating the underlying buffer in place (via ndarray +/// methods that require `DataMut`) clones-on-write when the buffer is shared. #[derive(Debug, Clone)] pub enum Tensor { - C64(ArrayD), // complex - C128(ArrayD), - F32(ArrayD), // real - F64(ArrayD), - I8(ArrayD), // signed integer - I16(ArrayD), - I32(ArrayD), - I64(ArrayD), - U8(ArrayD), // unsigned integer - U16(ArrayD), - U32(ArrayD), - U64(ArrayD), - Bit(ArrayD), // bool + C64(ArcArrayD), // complex + C128(ArcArrayD), + F32(ArcArrayD), // real + F64(ArcArrayD), + I8(ArcArrayD), // signed integer + I16(ArcArrayD), + I32(ArcArrayD), + I64(ArcArrayD), + U8(ArcArrayD), // unsigned integer + U16(ArcArrayD), + U32(ArcArrayD), + U64(ArcArrayD), + Bit(ArcArrayD), // bool } -/// Cast an `ArrayD` of a real numeric type to any supported dtype. +/// Cast an array of a real numeric type to any supported dtype. macro_rules! cast_real { ($arr:expr, $src:ty, $target:expr) => { match $target { - DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8)), - DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8)), - DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16)), - DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32)), - DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64)), - DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8)), - DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16)), - DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32)), - DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64)), - DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32)), - DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64)), - DType::C64 => Tensor::C64($arr.mapv(|x: $src| Complex32::new(x as f32, 0.0))), - DType::C128 => Tensor::C128($arr.mapv(|x: $src| Complex64::new(x as f64, 0.0))), + DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8).into_shared()), + DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8).into_shared()), + DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16).into_shared()), + DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32).into_shared()), + DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64).into_shared()), + DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8).into_shared()), + DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16).into_shared()), + DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32).into_shared()), + DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64).into_shared()), + DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32).into_shared()), + DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64).into_shared()), + DType::C64 => Tensor::C64( + $arr.mapv(|x: $src| Complex32::new(x as f32, 0.0)) + .into_shared(), + ), + DType::C128 => Tensor::C128( + $arr.mapv(|x: $src| Complex64::new(x as f64, 0.0)) + .into_shared(), + ), } }; } -/// Cast an `ArrayD` of a complex type to a complex dtype (panics for real targets). +/// Cast an array of a complex type to a complex dtype (panics for real targets). macro_rules! cast_complex { ($arr:expr, $target:expr) => { match $target { - DType::C64 => Tensor::C64($arr.mapv(|x| Complex32::new(x.re as f32, x.im as f32))), - DType::C128 => Tensor::C128($arr.mapv(|x| Complex64::new(x.re as f64, x.im as f64))), + DType::C64 => Tensor::C64( + $arr.mapv(|x| Complex32::new(x.re as f32, x.im as f32)) + .into_shared(), + ), + DType::C128 => Tensor::C128( + $arr.mapv(|x| Complex64::new(x.re as f64, x.im as f64)) + .into_shared(), + ), _ => panic!("cannot cast complex tensor to a real dtype"), } }; @@ -318,10 +339,10 @@ fn broadcast_shape(a: &[usize], b: &[usize]) -> Result, TensorError> /// this helper is needed for operations without a Rust operator (e.g. `pow`). Returns /// [`TensorError::ShapeMismatch`] if the operand shapes are not broadcast-compatible. fn broadcast_elementwise( - a: &ArrayD, - b: &ArrayD, + a: &ArcArrayD, + b: &ArcArrayD, op: F, -) -> Result, TensorError> +) -> Result, TensorError> where T: Clone, F: Fn(&T, &T) -> T, @@ -330,7 +351,7 @@ where let out_ix = IxDyn(&out_shape); let a_bc = a.broadcast(out_ix.clone()).expect("broadcast failed"); let b_bc = b.broadcast(out_ix).expect("broadcast failed"); - Ok(Zip::from(a_bc).and(b_bc).map_collect(op)) + Ok(Zip::from(a_bc).and(b_bc).map_collect(op).into_shared()) } impl Tensor { @@ -455,21 +476,27 @@ impl Tensor { } } -/// Implement `From<&[T]>`, `From<&[T; N]>`, and `From>` for a given `Tensor` variant. +/// Implement `From<&[T]>`, `From<&[T; N]>`, `From>`, and +/// `From>` for a given `Tensor` variant. macro_rules! impl_tensor_from { ($variant:ident, $t:ty) => { impl From<&[$t]> for Tensor { fn from(data: &[$t]) -> Self { - Tensor::$variant(ndarray::arr1(data).into_dyn()) + Tensor::$variant(ndarray::arr1(data).into_dyn().into_shared()) } } impl From<[$t; N]> for Tensor { fn from(data: [$t; N]) -> Self { - Tensor::$variant(ndarray::arr1(&data).into_dyn()) + Tensor::$variant(ndarray::arr1(&data).into_dyn().into_shared()) } } impl From> for Tensor { fn from(data: ArrayD<$t>) -> Self { + Tensor::$variant(data.into_shared()) + } + } + impl From> for Tensor { + fn from(data: ArcArrayD<$t>) -> Self { Tensor::$variant(data) } } @@ -508,18 +535,18 @@ macro_rules! impl_tensor_binop { pub fn $tensor_method(&self, rhs: &Tensor) -> Result { broadcast_shape(self.shape(), rhs.shape())?; match (self, rhs) { - (Tensor::C128(a), Tensor::C128(b)) => Ok(Tensor::C128(a $op b)), - (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(a $op b)), - (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(a $op b)), - (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(a $op b)), - (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(a $op b)), - (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(a $op b)), - (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16(a $op b)), - (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8(a $op b)), - (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64(a $op b)), - (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32(a $op b)), - (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16(a $op b)), - (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8(a $op b)), + (Tensor::C128(a), Tensor::C128(b)) => Ok(Tensor::C128((a $op b).into_shared())), + (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64((a $op b).into_shared())), + (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64((a $op b).into_shared())), + (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32((a $op b).into_shared())), + (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64((a $op b).into_shared())), + (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32((a $op b).into_shared())), + (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16((a $op b).into_shared())), + (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8((a $op b).into_shared())), + (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64((a $op b).into_shared())), + (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32((a $op b).into_shared())), + (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16((a $op b).into_shared())), + (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8((a $op b).into_shared())), _ => Err(TensorError::DTypeMismatch { op: $op_name, lhs: self.dtype(), @@ -557,16 +584,16 @@ impl Tensor { pub fn rem_tensor(&self, rhs: &Tensor) -> Result { broadcast_shape(self.shape(), rhs.shape())?; match (self, rhs) { - (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(a % b)), - (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(a % b)), - (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(a % b)), - (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(a % b)), - (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16(a % b)), - (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8(a % b)), - (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64(a % b)), - (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32(a % b)), - (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16(a % b)), - (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8(a % b)), + (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64((a % b).into_shared())), + (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32((a % b).into_shared())), + (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64((a % b).into_shared())), + (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32((a % b).into_shared())), + (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16((a % b).into_shared())), + (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8((a % b).into_shared())), + (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64((a % b).into_shared())), + (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32((a % b).into_shared())), + (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16((a % b).into_shared())), + (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8((a % b).into_shared())), _ => Err(TensorError::DTypeMismatch { op: "rem", lhs: self.dtype(), @@ -770,6 +797,22 @@ mod test { assert_eq!(t.shape(), &[4]); } + #[test] + fn test_clone_shares_buffer() { + // ArcArray storage means Tensor::clone() is a refcount bump, not a deep + // copy. Verify by comparing the underlying buffer pointer between the + // original and a clone. + let t = Tensor::from([1.0_f64, 2.0, 3.0]); + let cloned = t.clone(); + let Tensor::F64(orig) = &t else { + panic!("expected F64 tensor") + }; + let Tensor::F64(copy) = &cloned else { + panic!("expected F64 tensor") + }; + assert_eq!(orig.as_ptr(), copy.as_ptr()); + } + #[test] fn test_from_arrayd() { let arr = ndarray::Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0f64; 6]).unwrap(); @@ -1390,17 +1433,17 @@ mod test { DType::C128, ]; let sources = [ - Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)), - Tensor::U8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)), - Tensor::U16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u16)), - Tensor::U32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u32)), - Tensor::U64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u64)), - Tensor::I8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i8)), - Tensor::I16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i16)), - Tensor::I32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i32)), - Tensor::I64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i64)), - Tensor::F32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f32)), - Tensor::F64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f64)), + Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()), + Tensor::U8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()), + Tensor::U16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u16).into_shared()), + Tensor::U32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u32).into_shared()), + Tensor::U64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u64).into_shared()), + Tensor::I8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i8).into_shared()), + Tensor::I16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i16).into_shared()), + Tensor::I32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i32).into_shared()), + Tensor::I64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i64).into_shared()), + Tensor::F32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f32).into_shared()), + Tensor::F64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f64).into_shared()), ]; for src in sources { let src_dtype = src.dtype(); @@ -1425,7 +1468,8 @@ mod test { } // Spot-check a numeric value (Bit(1) -> F64 -> 1.0). - let bit_to_f64 = Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)).cast(DType::F64); + let bit_to_f64 = Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()) + .cast(DType::F64); if let Tensor::F64(arr) = bit_to_f64 { assert_eq!(arr.as_slice().unwrap(), &[1.0_f64, 1.0]); } else {