瀏覽代碼

:white_check_mark: added matrix multiplication test

Felix Bytow 2 年之前
父節點
當前提交
04b2e43c5b
共有 3 個文件被更改,包括 46 次插入4 次删除
  1. 20 0
      src/types/matrix/generic.rs
  2. 0 3
      src/types/matrix/mod.rs
  3. 26 1
      src/types/matrix/view.rs

+ 20 - 0
src/types/matrix/generic.rs

@@ -384,6 +384,26 @@ impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> From<[[T; COLUMNS]; RO
     }
 }
 
+impl<
+    'a,
+    T: Numeric,
+    const ROWS: usize,
+    const COLUMNS: usize,
+    const BASE_ROWS: usize,
+    const BASE_COLS: usize,
+    M: Matrix<T, BASE_ROWS, BASE_COLS>
+> From<View<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLS, M>> for GenericMatrix<T, ROWS, COLUMNS> {
+    fn from(v: View<'a, T, ROWS, COLUMNS, BASE_ROWS, BASE_COLS, M>) -> Self {
+        let mut result = Self::new();
+        for r in 0..ROWS {
+            for c in 0..COLUMNS {
+                result.data[r][c] = v.get(r, c);
+            }
+        }
+        result
+    }
+}
+
 impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Index<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
     type Output = T;
 

+ 0 - 3
src/types/matrix/mod.rs

@@ -95,6 +95,3 @@ pub fn multiply<
         }
     }
 }
-
-#[cfg(test)]
-mod tests {}

+ 26 - 1
src/types/matrix/view.rs

@@ -200,7 +200,8 @@ impl<
 
 #[cfg(test)]
 mod tests {
-    use crate::{SquareMatrix};
+    use crate::{ColumnVector, RowVector, SquareMatrix};
+    use crate::types::matrix::multiply;
 
     #[test]
     fn view_of_generic_matrix() {
@@ -208,4 +209,28 @@ mod tests {
         let view = mat.view::<2, 2>(1, 1);
         assert_eq!(view, SquareMatrix::diagonal(&[2, 3]));
     }
+
+    #[test]
+    fn multiply_views_inside_generic_matrix() {
+        let mut mat = SquareMatrix::from([
+            [0, 4, -2, -1],
+            [1, 0, 0, 0],
+            [3, 0, 0, 0],
+            [-5, 0, 0, 0],
+        ]);
+
+        let a = ColumnVector::from(mat.view::<3, 1>(1, 0));
+        let b = RowVector::from(mat.view::<1, 3>(0, 1));
+        multiply(&mut mat.view_mut::<3, 3>(1, 1), &a, &b);
+
+        assert_eq!(
+            mat,
+            SquareMatrix::from([
+                [0, 4, -2, -1],
+                [1, 4, -2, -1],
+                [3, 12, -6, -3],
+                [-5, -20, 10, 5],
+            ])
+        );
+    }
 }