Quellcode durchsuchen

:sparkles: matrix supports inversion now

Felix Bytow vor 2 Jahren
Ursprung
Commit
14628a21b7
2 geänderte Dateien mit 160 neuen und 12 gelöschten Zeilen
  1. 1 4
      src/types/complex.rs
  2. 159 8
      src/types/matrix/generic.rs

+ 1 - 4
src/types/complex.rs

@@ -98,10 +98,7 @@ impl<T: Float + Numeric + Primitive> Complex<T> {
 
     /// The imaginary unit.
     pub fn i() -> Self {
-        Self {
-            r: unsafe { T::whole(0) },
-            i: unsafe { T::whole(1) },
-        }
+        Self::imag(unsafe { T::whole(1) })
     }
 }
 

+ 159 - 8
src/types/matrix/generic.rs

@@ -1,6 +1,7 @@
 use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
+use std::ptr::swap;
 
-use crate::Numeric;
+use crate::{Float, Numeric};
 
 /// Struct representing a dense matrix.
 #[repr(transparent)]
@@ -75,6 +76,11 @@ impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Default for GenericMat
 }
 
 impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> GenericMatrix<T, ROWS, COLUMNS> {
+    /// Create a matrix with all cells being zero.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
     /// Create a transpose of the input matrix.
     ///
     /// # Example
@@ -146,6 +152,105 @@ impl<T: Numeric, const DIMENSION: usize> SquareMatrix<T, DIMENSION> {
     }
 }
 
+impl<T: Numeric + Float + PartialOrd, const DIMENSION: usize> SquareMatrix<T, DIMENSION> {
+    /// Apply an LUP decomposition to the matrix.
+    ///
+    /// # Safety
+    ///
+    /// This function is unsafe, because even when the result is false,
+    /// the matrix may already have been modified.
+    pub unsafe fn lup_decompose(&mut self, pivot: &mut RowVector<usize, DIMENSION>, tolerance: T) -> bool {
+        for i in 0..DIMENSION {
+            pivot[(0, i)] = i;
+        }
+
+        let zero = T::whole(0);
+
+        for i in 0..DIMENSION {
+            let mut max_a = zero;
+            let mut max_i = i;
+
+            for k in i..DIMENSION {
+                let abs_a = self[(k, i)].abs();
+                if abs_a > max_a {
+                    max_a = abs_a;
+                    max_i = k;
+                }
+            }
+
+            if max_a <= tolerance {
+                return false;
+            }
+
+            if max_i != i {
+                let p_i = (&mut pivot.data[0][i]) as *mut usize;
+                let p_i_max = (&mut pivot.data[0][max_i]) as *mut usize;
+                swap(p_i, p_i_max);
+
+                let p_a: *mut [T; DIMENSION] = (&mut self.data[i]) as *mut [T; DIMENSION];
+                let p_a_max: *mut [T; DIMENSION] = (&mut self.data[max_i]) as *mut [T; DIMENSION];
+                swap(p_a, p_a_max);
+            }
+
+            for j in (i + 1)..DIMENSION {
+                let divisor = self[(i, i)];
+                self[(j, i)] /= divisor;
+
+                for k in (i + 1)..DIMENSION {
+                    let a = self[(j, i)];
+                    let b = self[(i, k)];
+                    self[(j, k)] -= a * b;
+                }
+            }
+        }
+
+        true
+    }
+
+    /// Calculate the inverse of the input matrix.
+    pub fn inverted(mut self, tolerance: T) -> Option<Self> {
+        let mut pivot: RowVector<usize, DIMENSION> = RowVector::new();
+        if !unsafe { self.lup_decompose(&mut pivot, tolerance) } {
+            return None;
+        }
+
+        let mut result = Self::new();
+
+        let zero = unsafe { T::whole(0) };
+        let one = unsafe { T::whole(1) };
+
+        for j in 0..DIMENSION {
+            for i in 0..DIMENSION {
+                result[(i, j)] = if pivot[(0, i)] == j { one } else { zero };
+
+                for k in 0..i {
+                    let factor = result[(k, j)];
+                    result[(i, j)] -= self[(i, k)] * factor;
+                }
+            }
+
+            for i in (0..DIMENSION).rev() {
+                for k in (i + 1)..DIMENSION {
+                    let factor = result[(k, j)];
+                    result[(i, j)] -= self[(i, k)] * factor;
+                }
+
+                result[(i, j)] /= self[(i, i)];
+            }
+        }
+
+        Some(result)
+    }
+}
+
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> From<[[T; COLUMNS]; ROWS]> for GenericMatrix<T, ROWS, COLUMNS> {
+    fn from(data: [[T; COLUMNS]; ROWS]) -> Self {
+        Self {
+            data,
+        }
+    }
+}
+
 impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Index<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
     type Output = T;
 
@@ -270,7 +375,22 @@ impl<T: Numeric, const ROWS: usize, const COMMON: usize, const COLUMNS: usize> M
 
 #[cfg(test)]
 mod tests {
-    use crate::{ColumnVector, Complex, cplx, GenericMatrix, RowVector, SquareMatrix};
+    use crate::{ColumnVector, Complex, cplx, Float, GenericMatrix, Numeric, RowVector, SquareMatrix};
+
+    fn matrices_are_equal<T: Numeric + Float + PartialOrd, const ROWS: usize, const COLUMNS: usize>(
+        a: &GenericMatrix<T, ROWS, COLUMNS>,
+        b: &GenericMatrix<T, ROWS, COLUMNS>,
+        tolerance: T,
+    ) -> bool {
+        for r in 0..ROWS {
+            for c in 0..COLUMNS {
+                if (a[(r, c)] - b[(r, c)]).abs() > tolerance {
+                    return false;
+                }
+            }
+        }
+        true
+    }
 
     #[test]
     fn identity_matrix() {
@@ -332,6 +452,22 @@ mod tests {
         assert_eq!(product[(0, 0)], 3);
     }
 
+    #[test]
+    fn matrix_product_of_vectors() {
+        let a = cvec![1, 3, -5];
+        let b = rvec![4, -2, -1];
+        let product = a * b;
+        let expected = SquareMatrix {
+            data: [
+                [4, -2, -1],
+                [12, -6, -3],
+                [-20, 10, 5],
+            ],
+        };
+
+        assert_eq!(product, expected);
+    }
+
     #[test]
     fn complex_matrix_multiplication() {
         let a = SquareMatrix {
@@ -347,13 +483,28 @@ mod tests {
             ],
         };
         let mat = a * b;
-        let expected = SquareMatrix {
-            data: [
-                [cplx!(33.0, 4.0), cplx!(6.0, 23.0)],
-                [cplx!(-18.0, -25.0), cplx!(21.0, -6.0)],
-            ],
-        };
+        let expected = SquareMatrix::from([
+            [cplx!(33.0, 4.0), cplx!(6.0, 23.0)],
+            [cplx!(-18.0, -25.0), cplx!(21.0, -6.0)],
+        ]);
 
         assert_eq!(mat, expected);
     }
+
+    #[test]
+    fn invert_matrix() {
+        let tolerance = 1e-9;
+        let mat = SquareMatrix::from([
+            [2.0, 1.0],
+            [6.0, 4.0],
+        ]);
+        let inverted = mat.inverted(tolerance).unwrap();
+        let expected = SquareMatrix::from([
+            [2.0, -0.5],
+            [-3.0, 1.0],
+        ]);
+
+        assert!(matrices_are_equal(&inverted, &expected, tolerance));
+        assert!(matrices_are_equal(&(inverted * mat), &SquareMatrix::identity(), tolerance));
+    }
 }