-
Notifications
You must be signed in to change notification settings - Fork 3k
Switch Tensor to use ArcArray instead of Array #16256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,11 +10,14 @@ | |
| // copyright notice, and modified files need to carry a notice indicating | ||
| // that they have been altered from the originals. | ||
|
|
||
| use ndarray::{ArrayD, IxDyn, Zip}; | ||
| use ndarray::{ArcArray, ArrayD, IxDyn, Zip}; | ||
| use num_complex::{Complex32, Complex64}; | ||
| use std::fmt; | ||
| use thiserror::Error; | ||
|
|
||
| /// Dynamic-dimensional [`ArcArray`]; the storage type for every [`Tensor`] variant. | ||
| type ArcArrayD<T> = ArcArray<T, IxDyn>; | ||
|
|
||
| /// Errors returned by [`Tensor`] operations. | ||
| #[derive(Debug, Clone, PartialEq, Eq, Error)] | ||
| pub enum TensorError { | ||
|
|
@@ -234,50 +237,68 @@ impl TensorType { | |
| } | ||
|
|
||
| /// A tensor of one of the supported dtypes. | ||
| /// | ||
| /// Each variant wraps a reference-counted dynamic ndarray ([`ArcArray`]). | ||
| /// | ||
| /// This allows [`Tensor::clone`] to cause a refcount bump rather than a copy of | ||
| /// underlying data. Note that mutating the underlying buffer in place (via ndarray | ||
| /// methods that require `DataMut`) clones-on-write when the buffer is shared. | ||
| #[derive(Debug, Clone)] | ||
| pub enum Tensor { | ||
| C64(ArrayD<Complex32>), // complex | ||
| C128(ArrayD<Complex64>), | ||
| F32(ArrayD<f32>), // real | ||
| F64(ArrayD<f64>), | ||
| I8(ArrayD<i8>), // signed integer | ||
| I16(ArrayD<i16>), | ||
| I32(ArrayD<i32>), | ||
| I64(ArrayD<i64>), | ||
| U8(ArrayD<u8>), // unsigned integer | ||
| U16(ArrayD<u16>), | ||
| U32(ArrayD<u32>), | ||
| U64(ArrayD<u64>), | ||
| Bit(ArrayD<u8>), // bool | ||
| C64(ArcArrayD<Complex32>), // complex | ||
| C128(ArcArrayD<Complex64>), | ||
| F32(ArcArrayD<f32>), // real | ||
| F64(ArcArrayD<f64>), | ||
| I8(ArcArrayD<i8>), // signed integer | ||
| I16(ArcArrayD<i16>), | ||
| I32(ArcArrayD<i32>), | ||
| I64(ArcArrayD<i64>), | ||
| U8(ArcArrayD<u8>), // unsigned integer | ||
| U16(ArcArrayD<u16>), | ||
| U32(ArcArrayD<u32>), | ||
| U64(ArcArrayD<u64>), | ||
| Bit(ArcArrayD<u8>), // bool | ||
| } | ||
|
|
||
| /// Cast an `ArrayD` of a real numeric type to any supported dtype. | ||
| /// Cast an array of a real numeric type to any supported dtype. | ||
| macro_rules! cast_real { | ||
| ($arr:expr, $src:ty, $target:expr) => { | ||
| match $target { | ||
| DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8)), | ||
| DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8)), | ||
| DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16)), | ||
| DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32)), | ||
| DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64)), | ||
| DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8)), | ||
| DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16)), | ||
| DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32)), | ||
| DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64)), | ||
| DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32)), | ||
| DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64)), | ||
| DType::C64 => Tensor::C64($arr.mapv(|x: $src| Complex32::new(x as f32, 0.0))), | ||
| DType::C128 => Tensor::C128($arr.mapv(|x: $src| Complex64::new(x as f64, 0.0))), | ||
| DType::Bit => Tensor::Bit($arr.mapv(|x: $src| x as u8).into_shared()), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this need to be
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since The docs do recommend using This is all to say: I'm pretty sure all of the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, you're right. Looking deeper at the ndarray code, the so it's doing the same thing as I think it's unlikely we'll deal with a lot of |
||
| DType::U8 => Tensor::U8($arr.mapv(|x: $src| x as u8).into_shared()), | ||
| DType::U16 => Tensor::U16($arr.mapv(|x: $src| x as u16).into_shared()), | ||
| DType::U32 => Tensor::U32($arr.mapv(|x: $src| x as u32).into_shared()), | ||
| DType::U64 => Tensor::U64($arr.mapv(|x: $src| x as u64).into_shared()), | ||
| DType::I8 => Tensor::I8($arr.mapv(|x: $src| x as i8).into_shared()), | ||
| DType::I16 => Tensor::I16($arr.mapv(|x: $src| x as i16).into_shared()), | ||
| DType::I32 => Tensor::I32($arr.mapv(|x: $src| x as i32).into_shared()), | ||
| DType::I64 => Tensor::I64($arr.mapv(|x: $src| x as i64).into_shared()), | ||
| DType::F32 => Tensor::F32($arr.mapv(|x: $src| x as f32).into_shared()), | ||
| DType::F64 => Tensor::F64($arr.mapv(|x: $src| x as f64).into_shared()), | ||
| DType::C64 => Tensor::C64( | ||
| $arr.mapv(|x: $src| Complex32::new(x as f32, 0.0)) | ||
| .into_shared(), | ||
| ), | ||
| DType::C128 => Tensor::C128( | ||
| $arr.mapv(|x: $src| Complex64::new(x as f64, 0.0)) | ||
| .into_shared(), | ||
| ), | ||
| } | ||
| }; | ||
| } | ||
|
|
||
| /// Cast an `ArrayD` of a complex type to a complex dtype (panics for real targets). | ||
| /// Cast an array of a complex type to a complex dtype (panics for real targets). | ||
| macro_rules! cast_complex { | ||
| ($arr:expr, $target:expr) => { | ||
| match $target { | ||
| DType::C64 => Tensor::C64($arr.mapv(|x| Complex32::new(x.re as f32, x.im as f32))), | ||
| DType::C128 => Tensor::C128($arr.mapv(|x| Complex64::new(x.re as f64, x.im as f64))), | ||
| DType::C64 => Tensor::C64( | ||
| $arr.mapv(|x| Complex32::new(x.re as f32, x.im as f32)) | ||
| .into_shared(), | ||
| ), | ||
| DType::C128 => Tensor::C128( | ||
| $arr.mapv(|x| Complex64::new(x.re as f64, x.im as f64)) | ||
| .into_shared(), | ||
| ), | ||
| _ => panic!("cannot cast complex tensor to a real dtype"), | ||
| } | ||
| }; | ||
|
|
@@ -318,10 +339,10 @@ fn broadcast_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, TensorError> | |
| /// this helper is needed for operations without a Rust operator (e.g. `pow`). Returns | ||
| /// [`TensorError::ShapeMismatch`] if the operand shapes are not broadcast-compatible. | ||
| fn broadcast_elementwise<T, F>( | ||
| a: &ArrayD<T>, | ||
| b: &ArrayD<T>, | ||
| a: &ArcArrayD<T>, | ||
| b: &ArcArrayD<T>, | ||
| op: F, | ||
| ) -> Result<ArrayD<T>, TensorError> | ||
| ) -> Result<ArcArrayD<T>, TensorError> | ||
| where | ||
| T: Clone, | ||
| F: Fn(&T, &T) -> T, | ||
|
|
@@ -330,7 +351,7 @@ where | |
| let out_ix = IxDyn(&out_shape); | ||
| let a_bc = a.broadcast(out_ix.clone()).expect("broadcast failed"); | ||
| let b_bc = b.broadcast(out_ix).expect("broadcast failed"); | ||
| Ok(Zip::from(a_bc).and(b_bc).map_collect(op)) | ||
| Ok(Zip::from(a_bc).and(b_bc).map_collect(op).into_shared()) | ||
| } | ||
|
|
||
| impl Tensor { | ||
|
|
@@ -455,21 +476,27 @@ impl Tensor { | |
| } | ||
| } | ||
|
|
||
| /// Implement `From<&[T]>`, `From<&[T; N]>`, and `From<ArrayD<T>>` for a given `Tensor` variant. | ||
| /// Implement `From<&[T]>`, `From<&[T; N]>`, `From<ArrayD<T>>`, and | ||
| /// `From<ArcArrayD<T>>` for a given `Tensor` variant. | ||
| macro_rules! impl_tensor_from { | ||
| ($variant:ident, $t:ty) => { | ||
| impl From<&[$t]> for Tensor { | ||
| fn from(data: &[$t]) -> Self { | ||
| Tensor::$variant(ndarray::arr1(data).into_dyn()) | ||
| Tensor::$variant(ndarray::arr1(data).into_dyn().into_shared()) | ||
| } | ||
| } | ||
| impl<const N: usize> From<[$t; N]> for Tensor { | ||
| fn from(data: [$t; N]) -> Self { | ||
| Tensor::$variant(ndarray::arr1(&data).into_dyn()) | ||
| Tensor::$variant(ndarray::arr1(&data).into_dyn().into_shared()) | ||
| } | ||
| } | ||
| impl From<ArrayD<$t>> for Tensor { | ||
| fn from(data: ArrayD<$t>) -> Self { | ||
| Tensor::$variant(data.into_shared()) | ||
| } | ||
| } | ||
| impl From<ArcArrayD<$t>> for Tensor { | ||
| fn from(data: ArcArrayD<$t>) -> Self { | ||
| Tensor::$variant(data) | ||
| } | ||
| } | ||
|
|
@@ -508,18 +535,18 @@ macro_rules! impl_tensor_binop { | |
| pub fn $tensor_method(&self, rhs: &Tensor) -> Result<Tensor, TensorError> { | ||
| broadcast_shape(self.shape(), rhs.shape())?; | ||
| match (self, rhs) { | ||
| (Tensor::C128(a), Tensor::C128(b)) => Ok(Tensor::C128(a $op b)), | ||
| (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64(a $op b)), | ||
| (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(a $op b)), | ||
| (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(a $op b)), | ||
| (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(a $op b)), | ||
| (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(a $op b)), | ||
| (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16(a $op b)), | ||
| (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8(a $op b)), | ||
| (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64(a $op b)), | ||
| (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32(a $op b)), | ||
| (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16(a $op b)), | ||
| (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8(a $op b)), | ||
| (Tensor::C128(a), Tensor::C128(b)) => Ok(Tensor::C128((a $op b).into_shared())), | ||
| (Tensor::C64(a), Tensor::C64(b)) => Ok(Tensor::C64((a $op b).into_shared())), | ||
| (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64((a $op b).into_shared())), | ||
| (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32((a $op b).into_shared())), | ||
| (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64((a $op b).into_shared())), | ||
| (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32((a $op b).into_shared())), | ||
| (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16((a $op b).into_shared())), | ||
| (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8((a $op b).into_shared())), | ||
| (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64((a $op b).into_shared())), | ||
| (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32((a $op b).into_shared())), | ||
| (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16((a $op b).into_shared())), | ||
| (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8((a $op b).into_shared())), | ||
| _ => Err(TensorError::DTypeMismatch { | ||
| op: $op_name, | ||
| lhs: self.dtype(), | ||
|
|
@@ -557,16 +584,16 @@ impl Tensor { | |
| pub fn rem_tensor(&self, rhs: &Tensor) -> Result<Tensor, TensorError> { | ||
| broadcast_shape(self.shape(), rhs.shape())?; | ||
| match (self, rhs) { | ||
| (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64(a % b)), | ||
| (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32(a % b)), | ||
| (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64(a % b)), | ||
| (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32(a % b)), | ||
| (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16(a % b)), | ||
| (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8(a % b)), | ||
| (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64(a % b)), | ||
| (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32(a % b)), | ||
| (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16(a % b)), | ||
| (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8(a % b)), | ||
| (Tensor::F64(a), Tensor::F64(b)) => Ok(Tensor::F64((a % b).into_shared())), | ||
| (Tensor::F32(a), Tensor::F32(b)) => Ok(Tensor::F32((a % b).into_shared())), | ||
| (Tensor::I64(a), Tensor::I64(b)) => Ok(Tensor::I64((a % b).into_shared())), | ||
| (Tensor::I32(a), Tensor::I32(b)) => Ok(Tensor::I32((a % b).into_shared())), | ||
| (Tensor::I16(a), Tensor::I16(b)) => Ok(Tensor::I16((a % b).into_shared())), | ||
| (Tensor::I8(a), Tensor::I8(b)) => Ok(Tensor::I8((a % b).into_shared())), | ||
| (Tensor::U64(a), Tensor::U64(b)) => Ok(Tensor::U64((a % b).into_shared())), | ||
| (Tensor::U32(a), Tensor::U32(b)) => Ok(Tensor::U32((a % b).into_shared())), | ||
| (Tensor::U16(a), Tensor::U16(b)) => Ok(Tensor::U16((a % b).into_shared())), | ||
| (Tensor::U8(a), Tensor::U8(b)) => Ok(Tensor::U8((a % b).into_shared())), | ||
| _ => Err(TensorError::DTypeMismatch { | ||
| op: "rem", | ||
| lhs: self.dtype(), | ||
|
|
@@ -770,6 +797,22 @@ mod test { | |
| assert_eq!(t.shape(), &[4]); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_clone_shares_buffer() { | ||
| // ArcArray storage means Tensor::clone() is a refcount bump, not a deep | ||
| // copy. Verify by comparing the underlying buffer pointer between the | ||
| // original and a clone. | ||
| let t = Tensor::from([1.0_f64, 2.0, 3.0]); | ||
| let cloned = t.clone(); | ||
| let Tensor::F64(orig) = &t else { | ||
| panic!("expected F64 tensor") | ||
| }; | ||
| let Tensor::F64(copy) = &cloned else { | ||
| panic!("expected F64 tensor") | ||
| }; | ||
| assert_eq!(orig.as_ptr(), copy.as_ptr()); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_from_arrayd() { | ||
| let arr = ndarray::Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0f64; 6]).unwrap(); | ||
|
|
@@ -1390,17 +1433,17 @@ mod test { | |
| DType::C128, | ||
| ]; | ||
| let sources = [ | ||
| Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)), | ||
| Tensor::U8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)), | ||
| Tensor::U16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u16)), | ||
| Tensor::U32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u32)), | ||
| Tensor::U64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u64)), | ||
| Tensor::I8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i8)), | ||
| Tensor::I16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i16)), | ||
| Tensor::I32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i32)), | ||
| Tensor::I64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i64)), | ||
| Tensor::F32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f32)), | ||
| Tensor::F64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f64)), | ||
| Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()), | ||
| Tensor::U8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()), | ||
| Tensor::U16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u16).into_shared()), | ||
| Tensor::U32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u32).into_shared()), | ||
| Tensor::U64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u64).into_shared()), | ||
| Tensor::I8(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i8).into_shared()), | ||
| Tensor::I16(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i16).into_shared()), | ||
| Tensor::I32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i32).into_shared()), | ||
| Tensor::I64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1i64).into_shared()), | ||
| Tensor::F32(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f32).into_shared()), | ||
| Tensor::F64(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1.0f64).into_shared()), | ||
| ]; | ||
| for src in sources { | ||
| let src_dtype = src.dtype(); | ||
|
|
@@ -1425,7 +1468,8 @@ mod test { | |
| } | ||
|
|
||
| // Spot-check a numeric value (Bit(1) -> F64 -> 1.0). | ||
| let bit_to_f64 = Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8)).cast(DType::F64); | ||
| let bit_to_f64 = Tensor::Bit(ndarray::ArrayD::from_elem(IxDyn(&[2]), 1u8).into_shared()) | ||
| .cast(DType::F64); | ||
| if let Tensor::F64(arr) = bit_to_f64 { | ||
| assert_eq!(arr.as_slice().unwrap(), &[1.0_f64, 1.0]); | ||
| } else { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually why do we need this type? Isn't it in ndarray already? https://docs.rs/ndarray/latest/ndarray/type.ArcArrayD.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, looks like it was introduced in ndarary 0.17, but we're on 0.16. Can I upgrade to that version?
rust-ndarray/ndarray#1561
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're stuck on 0.16 for right now because of rustworkx-core IIRC. We get arrays from rustworkx-core (like adjacency and distance matrices) so the version we use internally for Qiskit has to match the version rustworkx-core uses which is unfortunately pinned to 0.16 right now. It's something we want to fix for the next rustworkx-core release