Parcourir la source

:sparkles: there is now Matrix and MatrixMut and View and ViewMut

Felix Bytow il y a 2 ans
Parent
commit
951117bf6c
5 fichiers modifiés avec 249 ajouts et 37 suppressions
  1. 1 1
      src/lib.rs
  2. 80 24
      src/types/matrix/generic.rs
  3. 61 5
      src/types/matrix/mod.rs
  4. 106 6
      src/types/matrix/view.rs
  5. 1 1
      src/types/mod.rs

+ 1 - 1
src/lib.rs

@@ -13,7 +13,7 @@
 //! `Float` in our context.
 pub use self::traits::{Float, Numeric, Primitive};
 #[cfg(feature = "matrix")]
-pub use self::types::{ColumnVector, GenericMatrix, Matrix, RowVector, SquareMatrix, View};
+pub use self::types::{ColumnVector, GenericMatrix, Matrix, MatrixMut, RowVector, SquareMatrix, View, ViewMut};
 #[cfg(feature = "angular")]
 pub use self::types::{Angular, Degree, Radiant};
 #[cfg(feature = "complex")]

+ 80 - 24
src/types/matrix/generic.rs

@@ -1,10 +1,10 @@
 use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
 use std::ptr::swap;
 
-use crate::{Float, Matrix, Numeric, Primitive, View};
+use crate::{Float, Matrix, Numeric, Primitive, View, ViewMut};
 #[cfg(feature = "angular")]
 use crate::Angular;
-use crate::types::matrix::matrices_are_equal;
+use crate::types::matrix::{compare, ComparisonResult, MatrixMut, multiply};
 
 /// Struct representing a dense matrix.
 #[repr(transparent)]
@@ -17,7 +17,9 @@ impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Matrix<T, ROWS, COLUMN
     fn get(&self, row: usize, col: usize) -> T {
         self.data[row][col]
     }
+}
 
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> MatrixMut<T, ROWS, COLUMNS> for GenericMatrix<T, ROWS, COLUMNS> {
     fn set(&mut self, row: usize, col: usize, value: T) {
         self.data[row][col] = value;
     }
@@ -94,8 +96,68 @@ impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> GenericMatrix<T, ROWS,
         Self::default()
     }
 
-    pub fn view<const VIEW_ROWS: usize, const VIEW_COLUMNS: usize>(&self, origin: (usize, usize)) -> View<T, VIEW_ROWS, VIEW_COLUMNS, ROWS, COLUMNS, Self> {
-        View::new(self, origin)
+    /// Create a view for a part of the matrix.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use lineal::GenericMatrix;
+    /// let mat = GenericMatrix::from([
+    ///     [1, 2, 3, 4],
+    ///     [5, 6, 7, 8],
+    ///     [9, 10, 11, 12],
+    /// ]);
+    /// // 3 and 2 are the size of the view, 0 and 1 are the starting indices of the view.
+    /// let view = mat.view::<3, 2>(0, 1);
+    /// assert_eq!(
+    ///     view,
+    ///     GenericMatrix::from([
+    ///         [2, 3],
+    ///         [6, 7],
+    ///         [10, 11],
+    ///     ])
+    /// );
+    /// ```
+    pub fn view<const VIEW_ROWS: usize, const VIEW_COLUMNS: usize>(&self, base_row: usize, base_col: usize) -> View<T, VIEW_ROWS, VIEW_COLUMNS, ROWS, COLUMNS, Self> {
+        View::new(self, (base_row, base_col))
+    }
+
+    /// Create a mutable view for a part of the matrix.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use lineal::{GenericMatrix, MatrixMut};
+    /// let mut mat = GenericMatrix::from([
+    ///     [1, 2, 3, 4],
+    ///     [5, 6, 7, 8],
+    ///     [9, 10, 11, 12],
+    /// ]);
+    /// {
+    ///     // 3 and 2 are the size of the view, 0 and 1 are the starting indices of the view.
+    ///     let mut view = mat.view_mut::<3, 2>(0, 1);
+    ///     view.set(1, 0, 23);
+    ///     view.set(1, 1, 42);
+    ///     assert_eq!(
+    ///         view,
+    ///         GenericMatrix::from([
+    ///             [2, 3],
+    ///             [23, 42],
+    ///             [10, 11],
+    ///         ])
+    ///     );
+    /// }
+    /// assert_eq!(
+    ///     mat,
+    ///     GenericMatrix::from([
+    ///         [1, 2, 3, 4],
+    ///         [5, 23, 42, 8],
+    ///         [9, 10, 11, 12],
+    ///     ])
+    /// );
+    /// ```
+    pub fn view_mut<const VIEW_ROWS: usize, const VIEW_COLUMNS: usize>(&mut self, base_row: usize, base_col: usize) -> ViewMut<T, VIEW_ROWS, VIEW_COLUMNS, ROWS, COLUMNS, Self> {
+        ViewMut::new(self, (base_row, base_col))
     }
 
     /// Create a transpose of the input matrix.
@@ -352,39 +414,39 @@ impl<T: Numeric + Neg<Output=T>, const ROWS: usize, const COLUMNS: usize> Neg fo
     }
 }
 
-impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Add for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> Add<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
     type Output = Self;
 
-    fn add(mut self, rhs: Self) -> Self::Output {
+    fn add(mut self, rhs: RHS) -> Self::Output {
         self += rhs;
         self
     }
 }
 
-impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> AddAssign for GenericMatrix<T, ROWS, COLUMNS> {
-    fn add_assign(&mut self, rhs: Self) {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> AddAssign<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
+    fn add_assign(&mut self, rhs: RHS) {
         for r in 0..ROWS {
             for c in 0..COLUMNS {
-                self.data[r][c] += rhs.data[r][c];
+                self.data[r][c] += rhs.get(r, c);
             }
         }
     }
 }
 
-impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Sub for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> Sub<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
     type Output = Self;
 
-    fn sub(mut self, rhs: Self) -> Self::Output {
+    fn sub(mut self, rhs: RHS) -> Self::Output {
         self -= rhs;
         self
     }
 }
 
-impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> SubAssign for GenericMatrix<T, ROWS, COLUMNS> {
-    fn sub_assign(&mut self, rhs: Self) {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> SubAssign<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
+    fn sub_assign(&mut self, rhs: RHS) {
         for r in 0..ROWS {
             for c in 0..COLUMNS {
-                self.data[r][c] -= rhs.data[r][c];
+                self.data[r][c] -= rhs.get(r, c);
             }
         }
     }
@@ -441,13 +503,7 @@ impl<T: Numeric, const ROWS: usize, const COMMON: usize, const COLUMNS: usize> M
 
     fn mul(self, rhs: GenericMatrix<T, COMMON, COLUMNS>) -> Self::Output {
         let mut result = Self::Output::default();
-        for i in 0..ROWS {
-            for j in 0..COLUMNS {
-                for k in 0..COMMON {
-                    result.data[i][j] += self.data[i][k] * rhs.data[k][j];
-                }
-            }
-        }
+        multiply(&mut result, &self, &rhs);
         result
     }
 }
@@ -518,17 +574,17 @@ impl<T: Numeric + Float + Primitive> SquareMatrix<T, 4> {
 
 impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> PartialEq<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
     fn eq(&self, other: &RHS) -> bool {
-        matrices_are_equal(self, other, T::epsilon())
+        compare(self, other, T::epsilon()) == ComparisonResult::Equal
     }
 
     fn ne(&self, other: &RHS) -> bool {
-        !matrices_are_equal(self, other, T::epsilon())
+        compare(self, other, T::epsilon()) == ComparisonResult::NotEqual
     }
 }
 
 #[cfg(test)]
 mod tests {
-    use crate::{ColumnVector, Complex, cplx, GenericMatrix, Matrix, RowVector, SquareMatrix};
+    use crate::{ColumnVector, Complex, cplx, GenericMatrix, Matrix, MatrixMut, RowVector, SquareMatrix};
 
     #[test]
     fn identity_matrix() {

+ 61 - 5
src/types/matrix/mod.rs

@@ -1,7 +1,7 @@
 use crate::Numeric;
 
 pub use self::generic::{ColumnVector, GenericMatrix, RowVector, SquareMatrix};
-pub use self::view::{View};
+pub use self::view::{View, ViewMut};
 
 mod generic;
 mod view;
@@ -14,7 +14,9 @@ pub trait Matrix<T: Numeric, const ROWS: usize, const COLUMNS: usize> {
     ///
     /// This function panics, if the row or column are outside of the valid range of the matrix.
     fn get(&self, row: usize, col: usize) -> T;
+}
 
+pub trait MatrixMut<T: Numeric, const ROWS: usize, const COLUMNS: usize>: Matrix<T, ROWS, COLUMNS> {
     /// Set the value of the given cell.
     ///
     /// # Panics
@@ -26,19 +28,73 @@ pub trait Matrix<T: Numeric, const ROWS: usize, const COLUMNS: usize> {
     fn set(&mut self, row: usize, col: usize, value: T);
 }
 
+/// Helper class representing the result of a matrix comparison.
+#[derive(Eq, PartialEq, Debug, Copy, Clone)]
+pub enum ComparisonResult {
+    /// The two matrices are considered unequal.
+    NotEqual,
+
+    /// The two matrices are considered equal.
+    Equal,
+}
+
 /// Compare two matrices for equality with a given tolerance.
 /// This is used as the implementation by all the `PartialEq` implementations of the different matrix types.
-fn matrices_are_equal<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>, LHS: Matrix<T, ROWS, COLUMNS>>(
+pub fn compare<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>, LHS: Matrix<T, ROWS, COLUMNS>>(
     rhs: &RHS,
     lhs: &LHS,
     tolerance: T,
-) -> bool {
+) -> ComparisonResult {
     for r in 0..ROWS {
         for c in 0..COLUMNS {
             if !rhs.get(r, c).is_equal_to(&lhs.get(r, c), tolerance) {
-                return false;
+                return ComparisonResult::NotEqual;
+            }
+        }
+    }
+    ComparisonResult::Equal
+}
+
+pub fn add<
+    T: Numeric, const ROWS: usize, const COLUMNS: usize,
+    RESULT: MatrixMut<T, ROWS, COLUMNS>,
+    LHS: Matrix<T, ROWS, COLUMNS>,
+    RHS: Matrix<T, ROWS, COLUMNS>
+>(result: &mut RESULT, lhs: &LHS, rhs: &RHS) {
+    for r in 0..ROWS {
+        for c in 0..COLUMNS {
+            result.set(r, c, lhs.get(r, c) + rhs.get(r, c));
+        }
+    }
+}
+
+pub fn subtract<
+    T: Numeric, const ROWS: usize, const COLUMNS: usize,
+    RESULT: MatrixMut<T, ROWS, COLUMNS>,
+    LHS: Matrix<T, ROWS, COLUMNS>,
+    RHS: Matrix<T, ROWS, COLUMNS>
+>(result: &mut RESULT, lhs: &LHS, rhs: &RHS) {
+    for r in 0..ROWS {
+        for c in 0..COLUMNS {
+            result.set(r, c, lhs.get(r, c) - rhs.get(r, c));
+        }
+    }
+}
+
+pub fn multiply<
+    T: Numeric, const ROWS: usize, const COMMON: usize, const COLUMNS: usize,
+    RESULT: MatrixMut<T, ROWS, COLUMNS>,
+    LHS: Matrix<T, ROWS, COMMON>,
+    RHS: Matrix<T, COMMON, COLUMNS>
+>(result: &mut RESULT, lhs: &LHS, rhs: &RHS) {
+    for i in 0..ROWS {
+        for j in 0..COLUMNS {
+            for k in 0..COMMON {
+                result.set(i, j, result.get(i, j) + lhs.get(i, k) * rhs.get(k, j));
             }
         }
     }
-    true
 }
+
+#[cfg(test)]
+mod tests {}

+ 106 - 6
src/types/matrix/view.rs

@@ -1,7 +1,7 @@
 use std::fmt::{Debug, Formatter};
 use std::marker::PhantomData;
-use crate::{GenericMatrix, Matrix, Numeric};
-use crate::types::matrix::matrices_are_equal;
+use crate::{GenericMatrix, Matrix, MatrixMut, Numeric};
+use crate::types::matrix::{compare, ComparisonResult};
 
 pub struct View<
     'a,
@@ -17,6 +17,20 @@ pub struct View<
     phantom: PhantomData<*const T>,
 }
 
+pub struct ViewMut<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: MatrixMut<T, BASE_ROWS, BASE_COLUMNS>
+> {
+    base: &'a mut M,
+    origin: (usize, usize),
+    phantom: PhantomData<*const T>,
+}
+
 impl<
     'a,
     T: Numeric,
@@ -41,6 +55,30 @@ impl<
     }
 }
 
+impl<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: MatrixMut<T, BASE_ROWS, BASE_COLUMNS>
+> Debug for ViewMut<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLUMNS, M> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let mut mat = GenericMatrix::<T, ROWS, COLUMNS>::new();
+        for r in 0..ROWS {
+            for c in 0..COLUMNS {
+                mat.data[r][c] = self.get(r, c);
+            }
+        }
+
+        f
+            .debug_struct("ViewMut")
+            .field("data", &mat.data)
+            .finish()
+    }
+}
+
 impl<
     'a,
     T: Numeric,
@@ -59,6 +97,24 @@ impl<
     }
 }
 
+impl<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: MatrixMut<T, BASE_ROWS, BASE_COLUMNS>
+> ViewMut<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLUMNS, M> {
+    pub fn new(base: &'a mut M, origin: (usize, usize)) -> Self {
+        Self {
+            base,
+            origin,
+            phantom: PhantomData,
+        }
+    }
+}
+
 impl<
     'a,
     T: Numeric,
@@ -72,10 +128,35 @@ impl<
         let (base_row, base_col) = self.origin;
         self.base.get(base_row + row, base_col + col)
     }
+}
 
+impl<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: MatrixMut<T, BASE_ROWS, BASE_COLUMNS>
+> Matrix<T, ROWS, COLUMNS> for ViewMut<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLUMNS, M> {
+    fn get(&self, row: usize, col: usize) -> T {
+        let (base_row, base_col) = self.origin;
+        self.base.get(base_row + row, base_col + col)
+    }
+}
+
+impl<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: MatrixMut<T, BASE_ROWS, BASE_COLUMNS>
+> MatrixMut<T, ROWS, COLUMNS> for ViewMut<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLUMNS, M> {
     fn set(&mut self, row: usize, col: usize, value: T) {
         let (base_row, base_col) = self.origin;
-        assert_eq!(self.base.get(base_row + row, base_col + col), value);
+        self.base.set(base_row + row, base_col + col, value);
     }
 }
 
@@ -90,11 +171,30 @@ impl<
     RHS: Matrix<T, ROWS, COLUMNS>
 > PartialEq<RHS> for View<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLUMNS, M> {
     fn eq(&self, other: &RHS) -> bool {
-        matrices_are_equal(self, other, T::epsilon())
+        compare(self, other, T::epsilon()) == ComparisonResult::Equal
+    }
+
+    fn ne(&self, other: &RHS) -> bool {
+        compare(self, other, T::epsilon()) == ComparisonResult::NotEqual
+    }
+}
+
+impl<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLUMNS: usize,
+    M: MatrixMut<T, BASE_ROWS, BASE_COLUMNS>,
+    RHS: Matrix<T, ROWS, COLUMNS>
+> PartialEq<RHS> for ViewMut<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLUMNS, M> {
+    fn eq(&self, other: &RHS) -> bool {
+        compare(self, other, T::epsilon()) == ComparisonResult::Equal
     }
 
     fn ne(&self, other: &RHS) -> bool {
-        !matrices_are_equal(self, other, T::epsilon())
+        compare(self, other, T::epsilon()) == ComparisonResult::NotEqual
     }
 }
 
@@ -105,7 +205,7 @@ mod tests {
     #[test]
     fn view_of_generic_matrix() {
         let mat = SquareMatrix::diagonal(&[1, 2, 3, 4]);
-        let view = mat.view::<2, 2>((1, 1));
+        let view = mat.view::<2, 2>(1, 1);
         assert_eq!(view, SquareMatrix::diagonal(&[2, 3]));
     }
 }

+ 1 - 1
src/types/mod.rs

@@ -3,7 +3,7 @@ pub use self::angular::{Angular, Degree, Radiant};
 #[cfg(feature = "complex")]
 pub use self::complex::Complex;
 #[cfg(feature = "matrix")]
-pub use self::matrix::{ColumnVector, GenericMatrix, Matrix, RowVector, SquareMatrix, View};
+pub use self::matrix::{ColumnVector, GenericMatrix, Matrix, MatrixMut, RowVector, SquareMatrix, View, ViewMut};
 #[cfg(feature = "quaternion")]
 pub use self::quaternion::Quaternion;