Jelajahi Sumber

:sparkles: initial version of read-only matrix view. not that useful yet

Felix Bytow 2 tahun lalu
induk
melakukan
fd7b545d66
7 mengubah file dengan 296 tambahan dan 22 penghapusan
  1. 1 1
      src/lib.rs
  2. 131 0
      src/traits/numeric.rs
  3. 12 0
      src/types/complex.rs
  4. 20 20
      src/types/matrix/generic.rs
  5. 20 0
      src/types/matrix/mod.rs
  6. 111 0
      src/types/matrix/view.rs
  7. 1 1
      src/types/mod.rs

+ 1 - 1
src/lib.rs

@@ -13,7 +13,7 @@
 //! `Float` in our context.
 pub use self::traits::{Float, Numeric, Primitive};
 #[cfg(feature = "matrix")]
-pub use self::types::{ColumnVector, GenericMatrix, Matrix, RowVector, SquareMatrix};
+pub use self::types::{ColumnVector, GenericMatrix, Matrix, RowVector, SquareMatrix, View};
 #[cfg(feature = "angular")]
 pub use self::types::{Angular, Degree, Radiant};
 #[cfg(feature = "complex")]

+ 131 - 0
src/traits/numeric.rs

@@ -10,8 +10,15 @@ PartialEq {
     /// Function returning the whole number in the requested type.
     fn whole(value: u8) -> Self;
 
+    /// Default tolerance for comparisons.
+    fn epsilon() -> Self;
+
     /// Calculate the absolute value.
     fn abs(self) -> Self;
+
+    /// Check for equality.
+    /// For integrals, the tolerance is ignored.
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool;
 }
 
 impl Numeric for f32 {
@@ -19,9 +26,17 @@ impl Numeric for f32 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        1e-9
+    }
+
     fn abs(self) -> Self {
         self.abs()
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        (self - *other).abs() <= tolerance
+    }
 }
 
 impl Numeric for f64 {
@@ -29,9 +44,17 @@ impl Numeric for f64 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        1e-9
+    }
+
     fn abs(self) -> Self {
         self.abs()
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        (self - *other).abs() <= tolerance
+    }
 }
 
 impl Numeric for i8 {
@@ -39,9 +62,18 @@ impl Numeric for i8 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self.abs()
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for i16 {
@@ -49,9 +81,18 @@ impl Numeric for i16 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self.abs()
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for i32 {
@@ -59,9 +100,18 @@ impl Numeric for i32 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self.abs()
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for i64 {
@@ -69,9 +119,18 @@ impl Numeric for i64 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self.abs()
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for i128 {
@@ -79,9 +138,18 @@ impl Numeric for i128 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self.abs()
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for isize {
@@ -89,9 +157,18 @@ impl Numeric for isize {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self.abs()
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for u8 {
@@ -99,9 +176,18 @@ impl Numeric for u8 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for u16 {
@@ -109,9 +195,18 @@ impl Numeric for u16 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for u32 {
@@ -119,9 +214,18 @@ impl Numeric for u32 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for u64 {
@@ -129,9 +233,18 @@ impl Numeric for u64 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for u128 {
@@ -139,9 +252,18 @@ impl Numeric for u128 {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 impl Numeric for usize {
@@ -149,9 +271,18 @@ impl Numeric for usize {
         value as Self
     }
 
+    fn epsilon() -> Self {
+        0
+    }
+
     fn abs(self) -> Self {
         self
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance, 0);
+        self == *other
+    }
 }
 
 #[cfg(test)]

+ 12 - 0
src/types/complex.rs

@@ -262,12 +262,24 @@ impl<T: Float + Numeric + Primitive> Numeric for Complex<T> {
         Self { r: T::whole(value), i: T::whole(0) }
     }
 
+    fn epsilon() -> Self {
+        Self {
+            r: T::epsilon(),
+            i: T::whole(0),
+        }
+    }
+
     fn abs(self) -> Self {
         Self {
             r: (self.r * self.r + self.i * self.i).sqrt(),
             i: T::whole(0),
         }
     }
+
+    fn is_equal_to(self, other: &Self, tolerance: Self) -> bool {
+        debug_assert_eq!(tolerance.i, T::value(0.0));
+        (self - *other).abs().r <= tolerance.r
+    }
 }
 
 #[cfg(test)]

+ 20 - 20
src/types/matrix/generic.rs

@@ -1,13 +1,14 @@
 use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
 use std::ptr::swap;
 
-use crate::{Float, Matrix, Numeric, Primitive};
+use crate::{Float, Matrix, Numeric, Primitive, View};
 #[cfg(feature = "angular")]
 use crate::Angular;
+use crate::types::matrix::matrices_are_equal;
 
 /// Struct representing a dense matrix.
 #[repr(transparent)]
-#[derive(Copy, Clone, Debug, PartialEq)]
+#[derive(Copy, Clone, Debug)]
 pub struct GenericMatrix<T: Numeric, const ROWS: usize, const COLUMNS: usize> {
     pub data: [[T; COLUMNS]; ROWS],
 }
@@ -93,6 +94,10 @@ impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> GenericMatrix<T, ROWS,
         Self::default()
     }
 
+    pub fn view<const VIEW_ROWS: usize, const VIEW_COLUMNS: usize>(&self, origin: (usize, usize)) -> View<T, VIEW_ROWS, VIEW_COLUMNS, ROWS, COLUMNS, Self> {
+        View::new(self, origin)
+    }
+
     /// Create a transpose of the input matrix.
     ///
     /// # Example
@@ -511,24 +516,19 @@ impl<T: Numeric + Float + Primitive> SquareMatrix<T, 4> {
     }
 }
 
-#[cfg(test)]
-mod tests {
-    use crate::{ColumnVector, Complex, cplx, Float, GenericMatrix, Matrix, Numeric, RowVector, SquareMatrix};
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> PartialEq<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
+    fn eq(&self, other: &RHS) -> bool {
+        matrices_are_equal(self, other, T::epsilon())
+    }
 
-    fn matrices_are_equal<T: Numeric + Float + PartialOrd, const ROWS: usize, const COLUMNS: usize>(
-        a: &GenericMatrix<T, ROWS, COLUMNS>,
-        b: &GenericMatrix<T, ROWS, COLUMNS>,
-        tolerance: T,
-    ) -> bool {
-        for r in 0..ROWS {
-            for c in 0..COLUMNS {
-                if (a[(r, c)] - b[(r, c)]).abs() > tolerance {
-                    return false;
-                }
-            }
-        }
-        true
+    fn ne(&self, other: &RHS) -> bool {
+        !matrices_are_equal(self, other, T::epsilon())
     }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{ColumnVector, Complex, cplx, GenericMatrix, Matrix, RowVector, SquareMatrix};
 
     #[test]
     fn identity_matrix() {
@@ -644,8 +644,8 @@ mod tests {
             [-3.0, 1.0],
         ]);
 
-        assert!(matrices_are_equal(&inverted, &expected, tolerance));
-        assert!(matrices_are_equal(&(inverted * mat), &SquareMatrix::identity(), tolerance));
+        assert_eq!(inverted, expected);
+        assert_eq!(inverted * mat, SquareMatrix::identity());
     }
 
     #[test]

+ 20 - 0
src/types/matrix/mod.rs

@@ -1,7 +1,10 @@
 use crate::Numeric;
+
 pub use self::generic::{ColumnVector, GenericMatrix, RowVector, SquareMatrix};
+pub use self::view::{View};
 
 mod generic;
+mod view;
 
 /// Abstraction for all kinds of fixed size matrices.
 pub trait Matrix<T: Numeric, const ROWS: usize, const COLUMNS: usize> {
@@ -22,3 +25,20 @@ pub trait Matrix<T: Numeric, const ROWS: usize, const COLUMNS: usize> {
     /// E.g. cells outsize of the main diagonal of a `DiagonalMatrix` can only be set to zero.
     fn set(&mut self, row: usize, col: usize, value: T);
 }
+
+/// Compare two matrices for equality with a given tolerance.
+/// This is used as the implementation by all the `PartialEq` implementations of the different matrix types.
+fn matrices_are_equal<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>, LHS: Matrix<T, ROWS, COLUMNS>>(
+    rhs: &RHS,
+    lhs: &LHS,
+    tolerance: T,
+) -> bool {
+    for r in 0..ROWS {
+        for c in 0..COLUMNS {
+            if !rhs.get(r, c).is_equal_to(&lhs.get(r, c), tolerance) {
+                return false;
+            }
+        }
+    }
+    true
+}

+ 111 - 0
src/types/matrix/view.rs

@@ -0,0 +1,111 @@
+use std::fmt::{Debug, Formatter};
+use std::marker::PhantomData;
+use crate::{GenericMatrix, Matrix, Numeric};
+use crate::types::matrix::matrices_are_equal;
+
+pub struct View<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: Matrix<T, BASE_ROWS, BASE_COLUMNS>
+> {
+    base: &'a M,
+    origin: (usize, usize),
+    phantom: PhantomData<*const T>,
+}
+
+impl<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: Matrix<T, BASE_ROWS, BASE_COLUMNS>
+> Debug for View<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLUMNS, M> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let mut mat = GenericMatrix::<T, ROWS, COLUMNS>::new();
+        for r in 0..ROWS {
+            for c in 0..COLUMNS {
+                mat.data[r][c] = self.get(r, c);
+            }
+        }
+
+        f
+            .debug_struct("View")
+            .field("data", &mat.data)
+            .finish()
+    }
+}
+
+impl<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: Matrix<T, BASE_ROWS, BASE_COLUMNS>
+> View<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLUMNS, M> {
+    pub fn new(base: &'a M, origin: (usize, usize)) -> Self {
+        Self {
+            base,
+            origin,
+            phantom: PhantomData,
+        }
+    }
+}
+
+impl<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: Matrix<T, BASE_ROWS, BASE_COLUMNS>
+> Matrix<T, ROWS, COLUMNS> for View<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLUMNS, M> {
+    fn get(&self, row: usize, col: usize) -> T {
+        let (base_row, base_col) = self.origin;
+        self.base.get(base_row + row, base_col + col)
+    }
+
+    fn set(&mut self, row: usize, col: usize, value: T) {
+        let (base_row, base_col) = self.origin;
+        assert_eq!(self.base.get(base_row + row, base_col + col), value);
+    }
+}
+
+impl<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: Matrix<T, BASE_ROWS, BASE_COLUMNS>,
+    RHS: Matrix<T, ROWS, COLUMNS>
+> PartialEq<RHS> for View<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLUMNS, M> {
+    fn eq(&self, other: &RHS) -> bool {
+        matrices_are_equal(self, other, T::epsilon())
+    }
+
+    fn ne(&self, other: &RHS) -> bool {
+        !matrices_are_equal(self, other, T::epsilon())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{SquareMatrix};
+
+    #[test]
+    fn view_of_generic_matrix() {
+        let mat = SquareMatrix::diagonal(&[1, 2, 3, 4]);
+        let view = mat.view::<2, 2>((1, 1));
+        assert_eq!(view, SquareMatrix::diagonal(&[2, 3]));
+    }
+}

+ 1 - 1
src/types/mod.rs

@@ -3,7 +3,7 @@ pub use self::angular::{Angular, Degree, Radiant};
 #[cfg(feature = "complex")]
 pub use self::complex::Complex;
 #[cfg(feature = "matrix")]
-pub use self::matrix::{ColumnVector, GenericMatrix, Matrix, RowVector, SquareMatrix};
+pub use self::matrix::{ColumnVector, GenericMatrix, Matrix, RowVector, SquareMatrix, View};
 #[cfg(feature = "quaternion")]
 pub use self::quaternion::Quaternion;