generic.rs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720
  1. use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
  2. use std::ptr::swap;
  3. use crate::{Float, Matrix, Numeric, Primitive, View, ViewMut};
  4. #[cfg(feature = "angular")]
  5. use crate::Angular;
  6. use crate::types::matrix::{compare, ComparisonResult, MatrixMut, multiply};
  7. /// Struct representing a dense matrix.
  8. #[repr(transparent)]
  9. #[derive(Copy, Clone, Debug)]
  10. pub struct GenericMatrix<T: Numeric, const ROWS: usize, const COLUMNS: usize> {
  11. pub data: [[T; COLUMNS]; ROWS],
  12. }
  13. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Matrix<T, ROWS, COLUMNS> for GenericMatrix<T, ROWS, COLUMNS> {
  14. fn get(&self, row: usize, col: usize) -> T {
  15. self.data[row][col]
  16. }
  17. }
  18. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> MatrixMut<T, ROWS, COLUMNS> for GenericMatrix<T, ROWS, COLUMNS> {
  19. fn set(&mut self, row: usize, col: usize, value: T) {
  20. self.data[row][col] = value;
  21. }
  22. }
  23. /// Special kind of `GenericMatrix` where both dimensions are equal.
  24. pub type SquareMatrix<T, const DIMENSION: usize> = GenericMatrix<T, DIMENSION, DIMENSION>;
  25. /// Special kind of `GenericMatrix` with just one column.
  26. pub type ColumnVector<T, const ROWS: usize> = GenericMatrix<T, ROWS, 1>;
  27. /// Special kind of `GenericMatrix` with just one row.
  28. pub type RowVector<T, const COLUMNS: usize> = GenericMatrix<T, 1, COLUMNS>;
  29. /// A macro for easier creation of row vectors.
  30. ///
  31. /// # Example
  32. ///
  33. /// ```rust
  34. /// # use lineal::{RowVector, rvec};
  35. /// let a = rvec![1.0, 2.0, 3.0];
  36. /// // is the same as
  37. /// let b = RowVector {
  38. /// data: [[1.0, 2.0, 3.0]],
  39. /// };
  40. /// ```
  41. #[macro_export]
  42. macro_rules! rvec {
  43. ($($value:expr),* $(,)?) => {
  44. {
  45. RowVector {
  46. data: [[$( $value, )*]],
  47. }
  48. }
  49. }
  50. }
  51. /// A macro for easier creation of column vectors.
  52. ///
  53. /// # Example
  54. ///
  55. /// ```rust
  56. /// # use lineal::{ColumnVector, cvec};
  57. /// let a = cvec![1.0, 2.0, 3.0];
  58. /// // is the same as
  59. /// let b = ColumnVector {
  60. /// data: [[1.0], [2.0], [3.0]],
  61. /// };
  62. /// ```
  63. #[macro_export]
  64. macro_rules! cvec {
  65. ($($value:expr),* $(,)?) => {
  66. {
  67. ColumnVector {
  68. data: [$( [$value], )*],
  69. }
  70. }
  71. }
  72. }
  73. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Default for GenericMatrix<T, ROWS, COLUMNS> {
  74. /// Create a matrix with all cells being zero.
  75. fn default() -> Self {
  76. let zero = T::whole(0);
  77. Self {
  78. data: [[zero; COLUMNS]; ROWS],
  79. }
  80. }
  81. }
  82. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> GenericMatrix<T, ROWS, COLUMNS> {
  83. /// Create a matrix with all cells being zero.
  84. pub fn new() -> Self {
  85. Self::default()
  86. }
  87. /// Create a view for a part of the matrix.
  88. ///
  89. /// # Example
  90. ///
  91. /// ```rust
  92. /// # use lineal::GenericMatrix;
  93. /// let mat = GenericMatrix::from([
  94. /// [1, 2, 3, 4],
  95. /// [5, 6, 7, 8],
  96. /// [9, 10, 11, 12],
  97. /// ]);
  98. /// // 3 and 2 are the size of the view, 0 and 1 are the starting indices of the view.
  99. /// let view = mat.view::<3, 2>(0, 1);
  100. /// assert_eq!(
  101. /// view,
  102. /// GenericMatrix::from([
  103. /// [2, 3],
  104. /// [6, 7],
  105. /// [10, 11],
  106. /// ])
  107. /// );
  108. /// ```
  109. pub fn view<const VIEW_ROWS: usize, const VIEW_COLUMNS: usize>(&self, base_row: usize, base_col: usize) -> View<T, VIEW_ROWS, VIEW_COLUMNS, ROWS, COLUMNS, Self> {
  110. View::new(self, (base_row, base_col))
  111. }
  112. /// Create a mutable view for a part of the matrix.
  113. ///
  114. /// # Example
  115. ///
  116. /// ```rust
  117. /// # use lineal::{GenericMatrix, MatrixMut};
  118. /// let mut mat = GenericMatrix::from([
  119. /// [1, 2, 3, 4],
  120. /// [5, 6, 7, 8],
  121. /// [9, 10, 11, 12],
  122. /// ]);
  123. /// {
  124. /// // 3 and 2 are the size of the view, 0 and 1 are the starting indices of the view.
  125. /// let mut view = mat.view_mut::<3, 2>(0, 1);
  126. /// view.set(1, 0, 23);
  127. /// view.set(1, 1, 42);
  128. /// assert_eq!(
  129. /// view,
  130. /// GenericMatrix::from([
  131. /// [2, 3],
  132. /// [23, 42],
  133. /// [10, 11],
  134. /// ])
  135. /// );
  136. /// }
  137. /// assert_eq!(
  138. /// mat,
  139. /// GenericMatrix::from([
  140. /// [1, 2, 3, 4],
  141. /// [5, 23, 42, 8],
  142. /// [9, 10, 11, 12],
  143. /// ])
  144. /// );
  145. /// ```
  146. pub fn view_mut<const VIEW_ROWS: usize, const VIEW_COLUMNS: usize>(&mut self, base_row: usize, base_col: usize) -> ViewMut<T, VIEW_ROWS, VIEW_COLUMNS, ROWS, COLUMNS, Self> {
  147. ViewMut::new(self, (base_row, base_col))
  148. }
  149. /// Create a transpose of the input matrix.
  150. ///
  151. /// # Example
  152. ///
  153. /// ```rust
  154. /// # use lineal::{ColumnVector, RowVector, cvec, rvec};
  155. /// let a = cvec![1, 2, 3];
  156. /// let b = rvec![1, 2, 3];
  157. /// assert_eq!(a.transposed(), b);
  158. /// ```
  159. pub fn transposed(&self) -> GenericMatrix<T, COLUMNS, ROWS> {
  160. let mut result = GenericMatrix::default();
  161. for r in 0..ROWS {
  162. for c in 0..COLUMNS {
  163. result.data[c][r] = self.data[r][c];
  164. }
  165. }
  166. result
  167. }
  168. }
  169. impl<T: Numeric, const DIMENSION: usize> SquareMatrix<T, DIMENSION> {
  170. /// Create a square identity matrix.
  171. ///
  172. /// In an identity matrix the main diagonal is filled with ones,
  173. /// while the rest if the cells are zero.
  174. ///
  175. /// # Example
  176. ///
  177. /// ```rust
  178. /// # use lineal::{SquareMatrix};
  179. /// let a: SquareMatrix<f32, 3> = SquareMatrix::identity();
  180. /// // is the same as
  181. /// let b: SquareMatrix<f32, 3> = SquareMatrix {
  182. /// data: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
  183. /// };
  184. /// ```
  185. pub fn identity() -> Self {
  186. let mut mat = Self::default();
  187. let one = T::whole(1);
  188. for i in 0..DIMENSION {
  189. mat.data[i][i] = one;
  190. }
  191. mat
  192. }
  193. /// Create a square matrix with its main diagonal filled with the given values.
  194. ///
  195. /// # Example
  196. ///
  197. /// ```rust
  198. /// # use lineal::{SquareMatrix};
  199. /// let a = SquareMatrix::diagonal(&[1, 2, 3]);
  200. /// // is the same as
  201. /// let b = SquareMatrix {
  202. /// data: [
  203. /// [1, 0, 0],
  204. /// [0, 2, 0],
  205. /// [0, 0, 3],
  206. /// ],
  207. /// };
  208. /// ```
  209. pub fn diagonal(values: &[T; DIMENSION]) -> Self {
  210. let mut mat = Self::default();
  211. for i in 0..DIMENSION {
  212. mat.data[i][i] = values[i];
  213. }
  214. mat
  215. }
  216. }
  217. impl<T: Numeric + Float + PartialOrd, const DIMENSION: usize> SquareMatrix<T, DIMENSION> {
  218. /// Apply an LUP decomposition to the matrix.
  219. ///
  220. /// # Safety
  221. ///
  222. /// This function is unsafe, because even when the result is false,
  223. /// the matrix may already have been modified.
  224. pub unsafe fn lup_decompose(&mut self, pivot: &mut Vec<usize>, tolerance: T) -> bool {
  225. pivot.clear();
  226. pivot.reserve(DIMENSION + 1);
  227. for i in 0..(DIMENSION + 1) {
  228. pivot.push(i);
  229. }
  230. let zero = T::whole(0);
  231. for i in 0..DIMENSION {
  232. let mut max_a = zero;
  233. let mut max_i = i;
  234. for k in i..DIMENSION {
  235. let abs_a = self[(k, i)].abs();
  236. if abs_a > max_a {
  237. max_a = abs_a;
  238. max_i = k;
  239. }
  240. }
  241. if max_a <= tolerance {
  242. return false;
  243. }
  244. if max_i != i {
  245. pivot.swap(i, max_i);
  246. let p_a: *mut [T; DIMENSION] = (&mut self.data[i]) as *mut [T; DIMENSION];
  247. let p_a_max: *mut [T; DIMENSION] = (&mut self.data[max_i]) as *mut [T; DIMENSION];
  248. swap(p_a, p_a_max);
  249. pivot[DIMENSION] += 1;
  250. }
  251. for j in (i + 1)..DIMENSION {
  252. let divisor = self[(i, i)];
  253. self[(j, i)] /= divisor;
  254. for k in (i + 1)..DIMENSION {
  255. let a = self[(j, i)];
  256. let b = self[(i, k)];
  257. self[(j, k)] -= a * b;
  258. }
  259. }
  260. }
  261. true
  262. }
  263. /// Calculate the determinant of the input matrix.
  264. pub fn determinant(mut self, tolerance: T) -> Option<T> {
  265. let mut pivot = Vec::new();
  266. if !unsafe { self.lup_decompose(&mut pivot, tolerance) } {
  267. return None;
  268. }
  269. debug_assert_eq!(pivot.len(), DIMENSION + 1);
  270. let mut result = self[(0, 0)];
  271. for i in 1..DIMENSION {
  272. result *= self[(i, i)];
  273. }
  274. if (pivot[DIMENSION] - DIMENSION) % 2 == 0 {
  275. Some(result)
  276. } else {
  277. Some(-result)
  278. }
  279. }
  280. /// Solve the equation `self` * `x` = `b` for `x`.
  281. pub fn solve(mut self, b: &ColumnVector<T, DIMENSION>, tolerance: T) -> Option<RowVector<T, DIMENSION>> {
  282. let mut pivot = Vec::new();
  283. if !unsafe { self.lup_decompose(&mut pivot, tolerance) } {
  284. return None;
  285. }
  286. debug_assert_eq!(pivot.len(), DIMENSION + 1);
  287. let mut result = RowVector::new();
  288. for i in 0..DIMENSION {
  289. result[(0, i)] = b[(pivot[i], 0)];
  290. for k in 0..i {
  291. let factor = result[(0, k)];
  292. result[(0, i)] -= self[(i, k)] * factor;
  293. }
  294. }
  295. for i in (0..DIMENSION).rev() {
  296. for k in (i + 1)..DIMENSION {
  297. let factor = result[(0, k)];
  298. result[(0, i)] -= self[(i, k)] * factor;
  299. }
  300. result[(0, i)] /= self[(i, i)];
  301. }
  302. Some(result)
  303. }
  304. /// Calculate the inverse of the input matrix.
  305. pub fn inverted(mut self, tolerance: T) -> Option<Self> {
  306. let mut pivot = Vec::new();
  307. if !unsafe { self.lup_decompose(&mut pivot, tolerance) } {
  308. return None;
  309. }
  310. debug_assert_eq!(pivot.len(), DIMENSION + 1);
  311. let mut result = Self::new();
  312. let zero = T::whole(0);
  313. let one = T::whole(1);
  314. for j in 0..DIMENSION {
  315. for i in 0..DIMENSION {
  316. result[(i, j)] = if pivot[i] == j { one } else { zero };
  317. for k in 0..i {
  318. let factor = result[(k, j)];
  319. result[(i, j)] -= self[(i, k)] * factor;
  320. }
  321. }
  322. for i in (0..DIMENSION).rev() {
  323. for k in (i + 1)..DIMENSION {
  324. let factor = result[(k, j)];
  325. result[(i, j)] -= self[(i, k)] * factor;
  326. }
  327. result[(i, j)] /= self[(i, i)];
  328. }
  329. }
  330. Some(result)
  331. }
  332. }
  333. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> From<[[T; COLUMNS]; ROWS]> for GenericMatrix<T, ROWS, COLUMNS> {
  334. fn from(data: [[T; COLUMNS]; ROWS]) -> Self {
  335. Self {
  336. data,
  337. }
  338. }
  339. }
  340. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Index<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
  341. type Output = T;
  342. fn index(&self, index: (usize, usize)) -> &Self::Output {
  343. let (row, column) = index;
  344. &self.data[row][column]
  345. }
  346. }
  347. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> IndexMut<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
  348. fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
  349. let (row, column) = index;
  350. &mut self.data[row][column]
  351. }
  352. }
  353. impl<T: Numeric + Neg<Output=T>, const ROWS: usize, const COLUMNS: usize> Neg for GenericMatrix<T, ROWS, COLUMNS> {
  354. type Output = Self;
  355. fn neg(self) -> Self::Output {
  356. let mut result = Self::default();
  357. for r in 0..ROWS {
  358. for c in 0..COLUMNS {
  359. result.data[r][c] = -self.data[r][c];
  360. }
  361. }
  362. result
  363. }
  364. }
  365. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> Add<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
  366. type Output = Self;
  367. fn add(mut self, rhs: RHS) -> Self::Output {
  368. self += rhs;
  369. self
  370. }
  371. }
  372. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> AddAssign<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
  373. fn add_assign(&mut self, rhs: RHS) {
  374. for r in 0..ROWS {
  375. for c in 0..COLUMNS {
  376. self.data[r][c] += rhs.get(r, c);
  377. }
  378. }
  379. }
  380. }
  381. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> Sub<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
  382. type Output = Self;
  383. fn sub(mut self, rhs: RHS) -> Self::Output {
  384. self -= rhs;
  385. self
  386. }
  387. }
  388. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> SubAssign<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
  389. fn sub_assign(&mut self, rhs: RHS) {
  390. for r in 0..ROWS {
  391. for c in 0..COLUMNS {
  392. self.data[r][c] -= rhs.get(r, c);
  393. }
  394. }
  395. }
  396. }
  397. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Mul<T> for GenericMatrix<T, ROWS, COLUMNS> {
  398. type Output = Self;
  399. fn mul(mut self, rhs: T) -> Self::Output {
  400. self *= rhs;
  401. self
  402. }
  403. }
  404. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> MulAssign<T> for GenericMatrix<T, ROWS, COLUMNS> {
  405. fn mul_assign(&mut self, rhs: T) {
  406. for r in 0..ROWS {
  407. for c in 0..COLUMNS {
  408. self.data[r][c] *= rhs;
  409. }
  410. }
  411. }
  412. }
  413. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Div<T> for GenericMatrix<T, ROWS, COLUMNS> {
  414. type Output = Self;
  415. fn div(mut self, rhs: T) -> Self::Output {
  416. self /= rhs;
  417. self
  418. }
  419. }
  420. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> DivAssign<T> for GenericMatrix<T, ROWS, COLUMNS> {
  421. fn div_assign(&mut self, rhs: T) {
  422. for r in 0..ROWS {
  423. for c in 0..COLUMNS {
  424. self.data[r][c] /= rhs;
  425. }
  426. }
  427. }
  428. }
  429. impl<T: Numeric + Float + PartialOrd, const DIMENSION: usize> Div<SquareMatrix<T, DIMENSION>> for ColumnVector<T, DIMENSION> {
  430. type Output = RowVector<T, DIMENSION>;
  431. fn div(self, rhs: SquareMatrix<T, DIMENSION>) -> Self::Output {
  432. rhs.solve(&self, T::value(1e-9)).unwrap()
  433. }
  434. }
  435. impl<T: Numeric, const ROWS: usize, const COMMON: usize, const COLUMNS: usize> Mul<GenericMatrix<T, COMMON, COLUMNS>> for GenericMatrix<T, ROWS, COMMON> {
  436. type Output = GenericMatrix<T, ROWS, COLUMNS>;
  437. fn mul(self, rhs: GenericMatrix<T, COMMON, COLUMNS>) -> Self::Output {
  438. let mut result = Self::Output::default();
  439. multiply(&mut result, &self, &rhs);
  440. result
  441. }
  442. }
  443. /// Apply a transformation matrix to a column vector.
  444. impl<T: Numeric + Float + Primitive> Mul<ColumnVector<T, 3>> for SquareMatrix<T, 4> {
  445. type Output = RowVector<T, 3>;
  446. fn mul(self, rhs: ColumnVector<T, 3>) -> Self::Output {
  447. let x_old = rhs[(0, 0)];
  448. let y_old = rhs[(1, 0)];
  449. let z_old = rhs[(2, 0)];
  450. rvec![
  451. self[(0,0)] * x_old + self[(0, 1)] * y_old + self[(0, 2)] * z_old + self[(0, 3)],
  452. self[(1,0)] * x_old + self[(1, 1)] * y_old + self[(1, 2)] * z_old + self[(1, 3)],
  453. self[(2,0)] * x_old + self[(2, 1)] * y_old + self[(2, 2)] * z_old + self[(2, 3)],
  454. ]
  455. }
  456. }
  457. impl<T: Numeric + Float + Primitive> SquareMatrix<T, 4> {
  458. /// Create a 3D translation matrix.
  459. pub fn translation(x: T, y: T, z: T) -> Self {
  460. let mut result = SquareMatrix::identity();
  461. result[(0, 3)] = x;
  462. result[(1, 3)] = y;
  463. result[(2, 3)] = z;
  464. result
  465. }
  466. /// Create a 3D scale matrix.
  467. pub fn scale(x: T, y: T, z: T) -> Self {
  468. let one = T::whole(1);
  469. SquareMatrix::diagonal(&[x, y, z, one])
  470. }
  471. fn rotation_impl(angle: T, x: T, y: T, z: T) -> Self {
  472. let zero = T::whole(0);
  473. let one = T::whole(1);
  474. let c = angle.cos();
  475. let s = angle.sin();
  476. let omc = one - c;
  477. Self::from([
  478. [x * x * omc + c, x * y * omc - z * s, x * z * omc + y * s, zero],
  479. [y * x * omc + z * s, y * y * omc + c, y * z * omc - x * s, zero],
  480. [x * z * omc - y * s, y * z * omc + x * s, z * z * omc + c, zero],
  481. [zero, zero, zero, one],
  482. ])
  483. }
  484. /// Create a 3D rotation matrix.
  485. ///
  486. /// `angle` is expected to be in radians.
  487. #[cfg(not(feature = "angular"))]
  488. pub fn rotation(angle: T, x: T, y: T, z: T) -> Self {
  489. Self::rotation_impl(angle, x, y, z)
  490. }
  491. /// Create a 3D rotation matrix.
  492. #[cfg(feature = "angular")]
  493. pub fn rotation<A: Angular<T>>(angle: A, x: T, y: T, z: T) -> Self {
  494. Self::rotation_impl(angle.radiant(), x, y, z)
  495. }
  496. }
  497. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize, RHS: Matrix<T, ROWS, COLUMNS>> PartialEq<RHS> for GenericMatrix<T, ROWS, COLUMNS> {
  498. fn eq(&self, other: &RHS) -> bool {
  499. compare(self, other, T::epsilon()) == ComparisonResult::Equal
  500. }
  501. fn ne(&self, other: &RHS) -> bool {
  502. compare(self, other, T::epsilon()) == ComparisonResult::NotEqual
  503. }
  504. }
  505. #[cfg(test)]
  506. mod tests {
  507. use crate::{ColumnVector, Complex, cplx, GenericMatrix, Matrix, MatrixMut, RowVector, SquareMatrix};
  508. #[test]
  509. fn identity_matrix() {
  510. let mat: SquareMatrix<f64, 3> = SquareMatrix::identity();
  511. let expected = SquareMatrix::from([
  512. [1.0, 0.0, 0.0],
  513. [0.0, 1.0, 0.0],
  514. [0.0, 0.0, 1.0],
  515. ]);
  516. assert_eq!(mat, expected);
  517. }
  518. #[test]
  519. fn diagonal_matrix() {
  520. let mat: SquareMatrix<f32, 3> = SquareMatrix::diagonal(&[1.0, 2.0, 3.0]);
  521. let expected = SquareMatrix {
  522. data: [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]],
  523. };
  524. assert_eq!(mat, expected);
  525. }
  526. #[test]
  527. #[should_panic = "index out of bounds"]
  528. fn out_of_bounds_access() {
  529. let mat: SquareMatrix<f32, 1> = SquareMatrix::identity();
  530. mat[(1, 0)];
  531. }
  532. #[test]
  533. fn transposing_matrix() {
  534. let mat = GenericMatrix {
  535. data: [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
  536. }.transposed();
  537. let expected = GenericMatrix {
  538. data: [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]],
  539. };
  540. assert_eq!(mat, expected);
  541. }
  542. #[test]
  543. fn negating_matrix() {
  544. let mat = -GenericMatrix {
  545. data: [[1, -1], [2, -2]],
  546. };
  547. let expected = GenericMatrix {
  548. data: [[-1, 1], [-2, 2]],
  549. };
  550. assert_eq!(mat, expected);
  551. }
  552. #[test]
  553. fn dot_product_of_vectors() {
  554. let a = rvec![1, 3, -5];
  555. let b = cvec![4, -2, -1];
  556. let product = a * b;
  557. assert_eq!(product[(0, 0)], 3);
  558. }
  559. #[test]
  560. fn matrix_product_of_vectors() {
  561. let a = cvec![1, 3, -5];
  562. let b = rvec![4, -2, -1];
  563. let product = a * b;
  564. let expected = SquareMatrix {
  565. data: [
  566. [4, -2, -1],
  567. [12, -6, -3],
  568. [-20, 10, 5],
  569. ],
  570. };
  571. assert_eq!(product, expected);
  572. }
  573. #[test]
  574. fn complex_matrix_multiplication() {
  575. let a = SquareMatrix {
  576. data: [
  577. [cplx!(2.0, 1.0), cplx!(i = 5.0,)],
  578. [cplx!(3.0), cplx!(3.0, -4.0)],
  579. ],
  580. };
  581. let b = SquareMatrix {
  582. data: [
  583. [cplx!(1.0, -1.0), cplx!(4.0, 2.0)],
  584. [cplx!(1.0, -6.0), cplx!(3.0)],
  585. ],
  586. };
  587. let mat = a * b;
  588. let expected = SquareMatrix::from([
  589. [cplx!(33.0, 4.0), cplx!(6.0, 23.0)],
  590. [cplx!(-18.0, -25.0), cplx!(21.0, -6.0)],
  591. ]);
  592. assert_eq!(mat, expected);
  593. }
  594. #[test]
  595. fn invert_matrix() {
  596. let tolerance = 1e-9;
  597. let mat = SquareMatrix::from([
  598. [2.0, 1.0],
  599. [6.0, 4.0],
  600. ]);
  601. let inverted = mat.inverted(tolerance).unwrap();
  602. let expected = SquareMatrix::from([
  603. [2.0, -0.5],
  604. [-3.0, 1.0],
  605. ]);
  606. assert_eq!(inverted, expected);
  607. assert_eq!(inverted * mat, SquareMatrix::identity());
  608. }
  609. #[test]
  610. #[should_panic = "index out of bounds"]
  611. fn out_of_range_get() {
  612. let mat = SquareMatrix::from([[1]]);
  613. mat.get(1, 1);
  614. }
  615. #[test]
  616. #[should_panic = "index out of bounds"]
  617. fn out_of_range_set() {
  618. let mut mat = SquareMatrix::from([[1]]);
  619. mat.set(1, 1, 2);
  620. }
  621. }