Procházet zdrojové kódy

:sparkles: matrices now support complex cells, added documentation

Felix Bytow před 2 roky
rodič
revize
1eb25c9023
5 změnil soubory, kde provedl 204 přidání a 95 odebrání
  1. 1 1
      src/lib.rs
  2. 65 68
      src/types/complex.rs
  3. 136 24
      src/types/matrix/generic.rs
  4. 1 1
      src/types/matrix/mod.rs
  5. 1 1
      src/types/mod.rs

+ 1 - 1
src/lib.rs

@@ -12,7 +12,7 @@
 //! As a consequence the complex types `Complex<f32>` and `Complex<f64>` are also considered
 //! `Float` in our context.
 pub use self::traits::{Float, Numeric, Primitive};
-pub use self::types::{ColumnVector, Complex, GenericMatrix, Quaternion, RowVector};
+pub use self::types::{ColumnVector, Complex, GenericMatrix, Quaternion, RowVector, SquareMatrix};
 
 mod traits;
 mod types;

+ 65 - 68
src/types/complex.rs

@@ -8,30 +8,30 @@ use crate::{Float, Numeric, Primitive};
 ///
 /// ```rust
 /// # use lineal::{Complex, cplx, Float, Numeric};
-/// let c = Complex { real: 9.0, imag: 0.0 };
-/// assert_eq!(format!("{:?}", c), "Complex { real: 9.0, imag: 0.0 }");
+/// let c = Complex { r: 9.0, i: 0.0 };
+/// assert_eq!(format!("{:?}", c), "Complex { r: 9.0, i: 0.0 }");
 ///
-/// let three = c.sqrt().real;
+/// let three = c.sqrt().r;
 /// assert_eq!(three, 3.0);
 ///
 /// let i: Complex<f32> = cplx!(-1.0).sqrt();
 /// assert_eq!(i, Complex::i());
 ///
-/// let a = Complex { real: 2.0, imag: 0.5 };
+/// let a = Complex { r: 2.0, i: 0.5 };
 /// let a_squared = a * a;
-/// assert_eq!(a_squared, Complex { real: 3.75, imag: 2.0 });
+/// assert_eq!(a_squared, Complex { r: 3.75, i: 2.0 });
 /// assert_eq!(a_squared.sqrt(), a);
 ///
 /// let x = (16.0).sqrt();
-/// let y = cplx!(16.0).sqrt().real;
+/// let y = cplx!(16.0).sqrt().r;
 /// assert_eq!(x, y);
 /// ```
 #[derive(Debug, Copy, Clone)]
 pub struct Complex<T: Float + Numeric + Primitive> {
     /// The real component of the complex number.
-    pub real: T,
+    pub r: T,
     /// The imaginary component of the complex number.
-    pub imag: T,
+    pub i: T,
 }
 
 /// A macro for easier creation of `Complex` numbers.
@@ -45,29 +45,26 @@ pub struct Complex<T: Float + Numeric + Primitive> {
 /// // first real, second imaginary
 /// let b = cplx!(1.0, -1.0);
 /// // with named arguments
-/// let c = cplx!(real = -1.0,);
-/// let d = cplx!(imag = 1.0,);
-/// let e = cplx!(real = 1.1, imag = 0.0,);
+/// let c = cplx!(r = -1.0,);
+/// let d = cplx!(i = 1.0,);
+/// let e = cplx!(r = 1.1, i = 0.0,);
 /// // even reverse order
-/// let f = cplx!(imag = 42.0, real = 23.0,);
+/// let f = cplx!(i = 42.0, r = 23.0,);
 /// ```
 #[macro_export]
 macro_rules! cplx {
     ($real:expr, $imag:expr) => {
         {
-            use crate::Complex;
-            Complex { real: $real, imag: $imag }
+            Complex { r: $real, i: $imag }
         }
     };
     ($real:expr) => {
         {
-            use crate::Complex;
-            Complex { real: $real, imag: 0.0 }
+            Complex { r: $real, i: 0.0 }
         }
     };
     ($($field:ident = $value:expr),* $(,)?) => {
         {
-            use crate::Complex;
             Complex {
                 $( $field: $value, )*
                 ..Complex::default()
@@ -86,24 +83,24 @@ impl<T: Float + Numeric + Primitive> Complex<T> {
     /// Create a complex number with its imaginary component being zero.
     pub fn real(value: T) -> Self {
         Self {
-            real: value,
-            imag: unsafe { T::whole(0) },
+            r: value,
+            i: unsafe { T::whole(0) },
         }
     }
 
     /// Create a complex number with its real component being zero.
     pub fn imag(value: T) -> Self {
         Self {
-            real: unsafe { T::whole(0) },
-            imag: value,
+            r: unsafe { T::whole(0) },
+            i: value,
         }
     }
 
     /// The imaginary unit.
     pub fn i() -> Self {
         Self {
-            real: unsafe { T::whole(0) },
-            imag: unsafe { T::whole(1) },
+            r: unsafe { T::whole(0) },
+            i: unsafe { T::whole(1) },
         }
     }
 }
@@ -113,16 +110,16 @@ impl<T: Float + Numeric + Primitive> Add for Complex<T> {
 
     fn add(self, rhs: Self) -> Self::Output {
         Self {
-            real: self.real + rhs.real,
-            imag: self.imag + rhs.imag,
+            r: self.r + rhs.r,
+            i: self.i + rhs.i,
         }
     }
 }
 
 impl<T: Float + Numeric + Primitive> AddAssign for Complex<T> {
     fn add_assign(&mut self, rhs: Self) {
-        self.real += rhs.real;
-        self.imag += rhs.imag;
+        self.r += rhs.r;
+        self.i += rhs.i;
     }
 }
 
@@ -131,16 +128,16 @@ impl<T: Float + Numeric + Primitive> Sub for Complex<T> {
 
     fn sub(self, rhs: Self) -> Self::Output {
         Self {
-            real: self.real - rhs.real,
-            imag: self.imag - rhs.imag,
+            r: self.r - rhs.r,
+            i: self.i - rhs.i,
         }
     }
 }
 
 impl<T: Float + Numeric + Primitive> SubAssign for Complex<T> {
     fn sub_assign(&mut self, rhs: Self) {
-        self.real -= rhs.real;
-        self.imag -= rhs.imag;
+        self.r -= rhs.r;
+        self.i -= rhs.i;
     }
 }
 
@@ -149,8 +146,8 @@ impl<T: Float + Numeric + Primitive> Mul for Complex<T> {
 
     fn mul(self, rhs: Self) -> Self::Output {
         Self {
-            real: self.real * rhs.real - self.imag * rhs.imag,
-            imag: self.real * rhs.imag + self.imag * rhs.real,
+            r: self.r * rhs.r - self.i * rhs.i,
+            i: self.r * rhs.i + self.i * rhs.r,
         }
     }
 }
@@ -165,10 +162,10 @@ impl<T: Float + Numeric + Primitive> Div for Complex<T> {
     type Output = Self;
 
     fn div(self, rhs: Self) -> Self::Output {
-        let divisor = rhs.real * rhs.real + rhs.imag * rhs.imag;
-        let mut result = self * Self { real: rhs.real, imag: -rhs.imag };
-        result.real /= divisor;
-        result.imag /= divisor;
+        let divisor = rhs.r * rhs.r + rhs.i * rhs.i;
+        let mut result = self * Self { r: rhs.r, i: -rhs.i };
+        result.r /= divisor;
+        result.i /= divisor;
         return result;
     }
 }
@@ -184,80 +181,80 @@ impl<T: Float + Numeric + Primitive> Neg for Complex<T> {
 
     fn neg(self) -> Self::Output {
         Self {
-            real: -self.real,
-            imag: -self.imag,
+            r: -self.r,
+            i: -self.i,
         }
     }
 }
 
 impl<T: Float + Numeric + Primitive> PartialEq for Complex<T> {
     fn eq(&self, other: &Self) -> bool {
-        self.real == other.real && self.imag == other.imag
+        self.r == other.r && self.i == other.i
     }
 
     fn ne(&self, other: &Self) -> bool {
-        self.real != other.real || self.imag != other.imag
+        self.r != other.r || self.i != other.i
     }
 }
 
 impl<T: Float + Primitive + Numeric> Float for Complex<T> {
     fn abs(self) -> Self {
         Self {
-            real: (self.real * self.real + self.imag * self.imag).sqrt(),
-            imag: unsafe { T::whole(0) },
+            r: (self.r * self.r + self.i * self.i).sqrt(),
+            i: unsafe { T::whole(0) },
         }
     }
 
     fn sqrt(self) -> Self {
         let zero = unsafe { T::whole(0) };
         let two = unsafe { T::whole(2) };
-        let abs_z = self.abs().real;
+        let abs_z = self.abs().r;
 
-        let base_re = if self.imag == zero { self.real } else { (abs_z + self.real) / two };
+        let base_re = if self.i == zero { self.r } else { (abs_z + self.r) / two };
         let re = if base_re < zero { zero } else { base_re.sqrt() };
 
-        let im = if self.imag == zero { zero } else {
-            (self.imag / self.imag.abs()) * ((abs_z - self.real) / two).sqrt()
+        let im = if self.i == zero { zero } else {
+            (self.i / self.i.abs()) * ((abs_z - self.r) / two).sqrt()
         } + if base_re < zero { (-base_re).sqrt() } else { zero };
 
         Self {
-            real: re,
-            imag: im,
+            r: re,
+            i: im,
         }
     }
 
     fn sin(self) -> Self {
         Self {
-            real: self.real.sin() * self.imag.cosh(),
-            imag: self.real.cos() * self.imag.sinh(),
+            r: self.r.sin() * self.i.cosh(),
+            i: self.r.cos() * self.i.sinh(),
         }
     }
 
     fn cos(self) -> Self {
         Self {
-            real: self.real.cos() * self.imag.cosh(),
-            imag: self.real.sin() * self.imag.sinh(),
+            r: self.r.cos() * self.i.cosh(),
+            i: self.r.sin() * self.i.sinh(),
         }
     }
 
     fn sinh(self) -> Self {
         Self {
-            real: self.real.sinh() * self.imag.cos(),
-            imag: self.real.cosh() * self.imag.sin(),
+            r: self.r.sinh() * self.i.cos(),
+            i: self.r.cosh() * self.i.sin(),
         }
     }
 
     fn cosh(self) -> Self {
         Self {
-            real: self.real.cosh() * self.imag.cos(),
-            imag: self.real.sinh() * self.imag.sin(),
+            r: self.r.cosh() * self.i.cos(),
+            i: self.r.sinh() * self.i.sin(),
         }
     }
 }
 
 impl<T: Float + Numeric + Primitive> Numeric for Complex<T> {
     unsafe fn whole(value: u32) -> Self {
-        Self { real: T::whole(value), imag: T::whole(0) }
+        Self { r: T::whole(value), i: T::whole(0) }
     }
 }
 
@@ -268,11 +265,11 @@ mod tests {
     #[test]
     fn macro_creation() {
         let a = cplx!(-1.0);
-        let b = cplx!(real = -1.0,);
+        let b = cplx!(r = -1.0,);
         assert_eq!(a, b);
-        let c = cplx!(imag = 1.0,);
+        let c = cplx!(i = 1.0,);
         assert_eq!(a.sqrt(), c);
-        let d = cplx!(real = -1.0, imag = 0.0,);
+        let d = cplx!(r = -1.0, i = 0.0,);
         assert_eq!(c * c, d);
         let e = cplx!(-1.0, 0.0);
         assert_eq!(a, e);
@@ -280,19 +277,19 @@ mod tests {
 
     #[test]
     fn division() {
-        let mut a: Complex<f32> = Complex { real: 3.0, imag: 2.0 };
-        let b: Complex<f32> = Complex { real: 4.0, imag: -5.0 };
-        let expected: Complex<f32> = Complex { real: 2.0 / 41.0, imag: 23.0 / 41.0 };
+        let mut a: Complex<f32> = Complex { r: 3.0, i: 2.0 };
+        let b: Complex<f32> = Complex { r: 4.0, i: -5.0 };
+        let expected: Complex<f32> = Complex { r: 2.0 / 41.0, i: 23.0 / 41.0 };
         a /= b;
         assert_eq!(format!("{:?}", a), format!("{:?}", expected));
     }
 
     #[test]
     fn square_root_with_negatives() {
-        let a = Complex { real: -1.0, imag: -1.0 };
+        let a = Complex { r: -1.0, i: -1.0 };
         let root = a.sqrt();
-        let expected = Complex { real: 0.45508986, imag: -1.09868411 };
-        assert!((root.real - expected.real).abs() < 0.00001);
-        assert!((root.imag - expected.imag).abs() < 0.00001);
+        let expected = Complex { r: 0.45508986, i: -1.09868411 };
+        assert!((root.r - expected.r).abs() < 0.00001);
+        assert!((root.i - expected.i).abs() < 0.00001);
     }
 }

+ 136 - 24
src/types/matrix/generic.rs

@@ -1,22 +1,39 @@
 use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
 
-use crate::{Numeric, Primitive};
+use crate::Numeric;
 
+/// Struct representing a dense matrix.
 #[repr(transparent)]
 #[derive(Copy, Clone, Debug, PartialEq)]
-pub struct GenericMatrix<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> {
+pub struct GenericMatrix<T: Numeric, const ROWS: usize, const COLUMNS: usize> {
     pub data: [[T; COLUMNS]; ROWS],
 }
 
+/// Special kind of `GenericMatrix` where both dimensions are equal.
 pub type SquareMatrix<T, const DIMENSION: usize> = GenericMatrix<T, DIMENSION, DIMENSION>;
+
+/// Special kind of `GenericMatrix` with just one column.
 pub type ColumnVector<T, const ROWS: usize> = GenericMatrix<T, ROWS, 1>;
+
+/// Special kind of `GenericMatrix` with just one row.
 pub type RowVector<T, const COLUMNS: usize> = GenericMatrix<T, 1, COLUMNS>;
 
+/// A macro for easier creation of row vectors.
+///
+/// # Example
+///
+/// ```rust
+/// # use lineal::{RowVector, rvec};
+/// let a = rvec![1.0, 2.0, 3.0];
+/// // is the same as
+/// let b = RowVector {
+///     data: [[1.0, 2.0, 3.0]],
+/// };
+/// ```
 #[macro_export]
 macro_rules! rvec {
     ($($value:expr),* $(,)?) => {
         {
-            use crate::RowVector;
             RowVector {
                 data: [[$( $value, )*]],
             }
@@ -24,11 +41,22 @@ macro_rules! rvec {
     }
 }
 
+/// A macro for easier creation of column vectors.
+///
+/// # Example
+///
+/// ```rust
+/// # use lineal::{ColumnVector, cvec};
+/// let a = cvec![1.0, 2.0, 3.0];
+/// // is the same as
+/// let b = ColumnVector {
+///     data: [[1.0], [2.0], [3.0]],
+/// };
+/// ```
 #[macro_export]
 macro_rules! cvec {
     ($($value:expr),* $(,)?) => {
         {
-            use crate::ColumnVector;
             ColumnVector {
                 data: [$( [$value], )*],
             }
@@ -36,7 +64,8 @@ macro_rules! cvec {
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Default for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Default for GenericMatrix<T, ROWS, COLUMNS> {
+    /// Create a matrix with all cells being zero.
     fn default() -> Self {
         let zero = unsafe { T::whole(0) };
         Self {
@@ -45,8 +74,18 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Default fo
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> GenericMatrix<T, ROWS, COLUMNS> {
-    fn transpose(&self) -> GenericMatrix<T, COLUMNS, ROWS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> GenericMatrix<T, ROWS, COLUMNS> {
+    /// Create a transpose of the input matrix.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use lineal::{ColumnVector, RowVector, cvec, rvec};
+    /// let a = cvec![1, 2, 3];
+    /// let b = rvec![1, 2, 3];
+    /// assert_eq!(a.transposed(), b);
+    /// ```
+    pub fn transposed(&self) -> GenericMatrix<T, COLUMNS, ROWS> {
         let mut result = GenericMatrix::default();
         for r in 0..ROWS {
             for c in 0..COLUMNS {
@@ -57,8 +96,23 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> GenericMat
     }
 }
 
-impl<T: Numeric + Primitive, const DIMENSION: usize> SquareMatrix<T, DIMENSION> {
-    fn identity() -> Self {
+impl<T: Numeric, const DIMENSION: usize> SquareMatrix<T, DIMENSION> {
+    /// Create a square identity matrix.
+    ///
+    /// In an identity matrix the main diagonal is filled with ones,
+    /// while the rest if the cells are zero.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use lineal::{SquareMatrix};
+    /// let a: SquareMatrix<f32, 3> = SquareMatrix::identity();
+    /// // is the same as
+    /// let b: SquareMatrix<f32, 3> = SquareMatrix {
+    ///     data: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+    /// };
+    /// ```
+    pub fn identity() -> Self {
         let mut mat = Self::default();
         let one = unsafe { T::whole(1) };
         for i in 0..DIMENSION {
@@ -66,9 +120,33 @@ impl<T: Numeric + Primitive, const DIMENSION: usize> SquareMatrix<T, DIMENSION>
         }
         mat
     }
+
+    /// Create a square matrix with its main diagonal filled with the given values.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// # use lineal::{SquareMatrix};
+    /// let a = SquareMatrix::diagonal(&[1, 2, 3]);
+    /// // is the same as
+    /// let b = SquareMatrix {
+    ///     data: [
+    ///         [1, 0, 0],
+    ///         [0, 2, 0],
+    ///         [0, 0, 3],
+    ///     ],
+    /// };
+    /// ```
+    pub fn diagonal(values: &[T; DIMENSION]) -> Self {
+        let mut mat = Self::default();
+        for i in 0..DIMENSION {
+            mat.data[i][i] = values[i];
+        }
+        mat
+    }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Index<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Index<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
     type Output = T;
 
     fn index(&self, index: (usize, usize)) -> &Self::Output {
@@ -77,14 +155,14 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Index<(usi
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> IndexMut<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> IndexMut<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
     fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
         let (row, column) = index;
         &mut self.data[row][column]
     }
 }
 
-impl<T: Numeric + Primitive + Neg<Output=T>, const ROWS: usize, const COLUMNS: usize> Neg for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric + Neg<Output=T>, const ROWS: usize, const COLUMNS: usize> Neg for GenericMatrix<T, ROWS, COLUMNS> {
     type Output = Self;
 
     fn neg(self) -> Self::Output {
@@ -98,7 +176,7 @@ impl<T: Numeric + Primitive + Neg<Output=T>, const ROWS: usize, const COLUMNS: u
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Add for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Add for GenericMatrix<T, ROWS, COLUMNS> {
     type Output = Self;
 
     fn add(mut self, rhs: Self) -> Self::Output {
@@ -107,7 +185,7 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Add for Ge
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> AddAssign for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> AddAssign for GenericMatrix<T, ROWS, COLUMNS> {
     fn add_assign(&mut self, rhs: Self) {
         for r in 0..ROWS {
             for c in 0..COLUMNS {
@@ -117,7 +195,7 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> AddAssign
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Sub for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Sub for GenericMatrix<T, ROWS, COLUMNS> {
     type Output = Self;
 
     fn sub(mut self, rhs: Self) -> Self::Output {
@@ -126,7 +204,7 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Sub for Ge
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> SubAssign for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> SubAssign for GenericMatrix<T, ROWS, COLUMNS> {
     fn sub_assign(&mut self, rhs: Self) {
         for r in 0..ROWS {
             for c in 0..COLUMNS {
@@ -136,7 +214,7 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> SubAssign
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Mul<T> for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Mul<T> for GenericMatrix<T, ROWS, COLUMNS> {
     type Output = Self;
 
     fn mul(mut self, rhs: T) -> Self::Output {
@@ -145,7 +223,7 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Mul<T> for
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> MulAssign<T> for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> MulAssign<T> for GenericMatrix<T, ROWS, COLUMNS> {
     fn mul_assign(&mut self, rhs: T) {
         for r in 0..ROWS {
             for c in 0..COLUMNS {
@@ -155,7 +233,7 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> MulAssign<
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Div<T> for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Div<T> for GenericMatrix<T, ROWS, COLUMNS> {
     type Output = Self;
 
     fn div(mut self, rhs: T) -> Self::Output {
@@ -164,7 +242,7 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> Div<T> for
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> DivAssign<T> for GenericMatrix<T, ROWS, COLUMNS> {
+impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> DivAssign<T> for GenericMatrix<T, ROWS, COLUMNS> {
     fn div_assign(&mut self, rhs: T) {
         for r in 0..ROWS {
             for c in 0..COLUMNS {
@@ -174,7 +252,7 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COLUMNS: usize> DivAssign<
     }
 }
 
-impl<T: Numeric + Primitive, const ROWS: usize, const COMMON: usize, const COLUMNS: usize> Mul<GenericMatrix<T, COMMON, COLUMNS>> for GenericMatrix<T, ROWS, COMMON> {
+impl<T: Numeric, const ROWS: usize, const COMMON: usize, const COLUMNS: usize> Mul<GenericMatrix<T, COMMON, COLUMNS>> for GenericMatrix<T, ROWS, COMMON> {
     type Output = GenericMatrix<T, ROWS, COLUMNS>;
 
     fn mul(self, rhs: GenericMatrix<T, COMMON, COLUMNS>) -> Self::Output {
@@ -192,8 +270,7 @@ impl<T: Numeric + Primitive, const ROWS: usize, const COMMON: usize, const COLUM
 
 #[cfg(test)]
 mod tests {
-    use crate::GenericMatrix;
-    use crate::types::matrix::generic::SquareMatrix;
+    use crate::{ColumnVector, Complex, cplx, GenericMatrix, RowVector, SquareMatrix};
 
     #[test]
     fn identity_matrix() {
@@ -205,6 +282,16 @@ mod tests {
         assert_eq!(mat, expected);
     }
 
+    #[test]
+    fn diagonal_matrix() {
+        let mat: SquareMatrix<f32, 3> = SquareMatrix::diagonal(&[1.0, 2.0, 3.0]);
+        let expected = SquareMatrix {
+            data: [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]],
+        };
+
+        assert_eq!(mat, expected);
+    }
+
     #[test]
     #[should_panic = "index out of bounds"]
     fn out_of_bounds_access() {
@@ -216,7 +303,7 @@ mod tests {
     fn transposing_matrix() {
         let mat = GenericMatrix {
             data: [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
-        }.transpose();
+        }.transposed();
         let expected = GenericMatrix {
             data: [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]],
         };
@@ -244,4 +331,29 @@ mod tests {
 
         assert_eq!(product[(0, 0)], 3);
     }
+
+    #[test]
+    fn complex_matrix_multiplication() {
+        let a = SquareMatrix {
+            data: [
+                [cplx!(2.0, 1.0), cplx!(i = 5.0,)],
+                [cplx!(3.0), cplx!(3.0, -4.0)],
+            ],
+        };
+        let b = SquareMatrix {
+            data: [
+                [cplx!(1.0, -1.0), cplx!(4.0, 2.0)],
+                [cplx!(1.0, -6.0), cplx!(3.0)],
+            ],
+        };
+        let mat = a * b;
+        let expected = SquareMatrix {
+            data: [
+                [cplx!(33.0, 4.0), cplx!(6.0, 23.0)],
+                [cplx!(-18.0, -25.0), cplx!(21.0, -6.0)],
+            ],
+        };
+
+        assert_eq!(mat, expected);
+    }
 }

+ 1 - 1
src/types/matrix/mod.rs

@@ -1,3 +1,3 @@
-pub use self::generic::{ColumnVector, GenericMatrix, RowVector};
+pub use self::generic::{ColumnVector, GenericMatrix, RowVector, SquareMatrix};
 
 mod generic;

+ 1 - 1
src/types/mod.rs

@@ -1,5 +1,5 @@
 pub use self::complex::Complex;
-pub use self::matrix::{ColumnVector, GenericMatrix, RowVector};
+pub use self::matrix::{ColumnVector, GenericMatrix, RowVector, SquareMatrix};
 pub use self::quaternion::Quaternion;
 
 mod complex;