generic.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640
  1. use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
  2. use std::ptr::swap;
  3. use crate::{Float, Numeric, Primitive};
  4. #[cfg(feature = "angular")]
  5. use crate::Angular;
  6. /// Struct representing a dense matrix.
  7. #[repr(transparent)]
  8. #[derive(Copy, Clone, Debug, PartialEq)]
  9. pub struct GenericMatrix<T: Numeric, const ROWS: usize, const COLUMNS: usize> {
  10. pub data: [[T; COLUMNS]; ROWS],
  11. }
  12. /// Special kind of `GenericMatrix` where both dimensions are equal.
  13. pub type SquareMatrix<T, const DIMENSION: usize> = GenericMatrix<T, DIMENSION, DIMENSION>;
  14. /// Special kind of `GenericMatrix` with just one column.
  15. pub type ColumnVector<T, const ROWS: usize> = GenericMatrix<T, ROWS, 1>;
  16. /// Special kind of `GenericMatrix` with just one row.
  17. pub type RowVector<T, const COLUMNS: usize> = GenericMatrix<T, 1, COLUMNS>;
  /// A macro for easier creation of row vectors.
  ///
  /// Every argument becomes one cell of the single row, in order; a trailing
  /// comma is accepted. The expansion names the type through `$crate`, so the
  /// caller does not need to import `RowVector` just to use the macro.
  ///
  /// # Example
  ///
  /// ```rust
  /// # use lineal::{RowVector, rvec};
  /// let a = rvec![1.0, 2.0, 3.0];
  /// // is the same as
  /// let b = RowVector {
  ///     data: [[1.0, 2.0, 3.0]],
  /// };
  /// assert_eq!(a, b);
  /// ```
  #[macro_export]
  macro_rules! rvec {
      ($($value:expr),* $(,)?) => {
          {
              // `$crate::` keeps the expansion valid in crates that have not
              // imported `RowVector`.
              $crate::RowVector {
                  data: [[$($value,)*]],
              }
          }
      };
  }
  /// A macro for easier creation of column vectors.
  ///
  /// Every argument becomes its own single-cell row, in order; a trailing
  /// comma is accepted. The expansion names the type through `$crate`, so the
  /// caller does not need to import `ColumnVector` just to use the macro.
  ///
  /// # Example
  ///
  /// ```rust
  /// # use lineal::{ColumnVector, cvec};
  /// let a = cvec![1.0, 2.0, 3.0];
  /// // is the same as
  /// let b = ColumnVector {
  ///     data: [[1.0], [2.0], [3.0]],
  /// };
  /// assert_eq!(a, b);
  /// ```
  #[macro_export]
  macro_rules! cvec {
      ($($value:expr),* $(,)?) => {
          {
              // `$crate::` keeps the expansion valid in crates that have not
              // imported `ColumnVector`.
              $crate::ColumnVector {
                  data: [$([$value],)*],
              }
          }
      };
  }
  62. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Default for GenericMatrix<T, ROWS, COLUMNS> {
  63. /// Create a matrix with all cells being zero.
  64. fn default() -> Self {
  65. let zero = T::whole(0);
  66. Self {
  67. data: [[zero; COLUMNS]; ROWS],
  68. }
  69. }
  70. }
  71. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> GenericMatrix<T, ROWS, COLUMNS> {
  72. /// Create a matrix with all cells being zero.
  73. pub fn new() -> Self {
  74. Self::default()
  75. }
  76. /// Create a transpose of the input matrix.
  77. ///
  78. /// # Example
  79. ///
  80. /// ```rust
  81. /// # use lineal::{ColumnVector, RowVector, cvec, rvec};
  82. /// let a = cvec![1, 2, 3];
  83. /// let b = rvec![1, 2, 3];
  84. /// assert_eq!(a.transposed(), b);
  85. /// ```
  86. pub fn transposed(&self) -> GenericMatrix<T, COLUMNS, ROWS> {
  87. let mut result = GenericMatrix::default();
  88. for r in 0..ROWS {
  89. for c in 0..COLUMNS {
  90. result.data[c][r] = self.data[r][c];
  91. }
  92. }
  93. result
  94. }
  95. }
  96. impl<T: Numeric, const DIMENSION: usize> SquareMatrix<T, DIMENSION> {
  97. /// Create a square identity matrix.
  98. ///
  99. /// In an identity matrix the main diagonal is filled with ones,
  100. /// while the rest if the cells are zero.
  101. ///
  102. /// # Example
  103. ///
  104. /// ```rust
  105. /// # use lineal::{SquareMatrix};
  106. /// let a: SquareMatrix<f32, 3> = SquareMatrix::identity();
  107. /// // is the same as
  108. /// let b: SquareMatrix<f32, 3> = SquareMatrix {
  109. /// data: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
  110. /// };
  111. /// ```
  112. pub fn identity() -> Self {
  113. let mut mat = Self::default();
  114. let one = T::whole(1);
  115. for i in 0..DIMENSION {
  116. mat.data[i][i] = one;
  117. }
  118. mat
  119. }
  120. /// Create a square matrix with its main diagonal filled with the given values.
  121. ///
  122. /// # Example
  123. ///
  124. /// ```rust
  125. /// # use lineal::{SquareMatrix};
  126. /// let a = SquareMatrix::diagonal(&[1, 2, 3]);
  127. /// // is the same as
  128. /// let b = SquareMatrix {
  129. /// data: [
  130. /// [1, 0, 0],
  131. /// [0, 2, 0],
  132. /// [0, 0, 3],
  133. /// ],
  134. /// };
  135. /// ```
  136. pub fn diagonal(values: &[T; DIMENSION]) -> Self {
  137. let mut mat = Self::default();
  138. for i in 0..DIMENSION {
  139. mat.data[i][i] = values[i];
  140. }
  141. mat
  142. }
  143. }
  144. impl<T: Numeric + Float + PartialOrd, const DIMENSION: usize> SquareMatrix<T, DIMENSION> {
  145. /// Apply an LUP decomposition to the matrix.
  146. ///
  147. /// # Safety
  148. ///
  149. /// This function is unsafe, because even when the result is false,
  150. /// the matrix may already have been modified.
  151. pub unsafe fn lup_decompose(&mut self, pivot: &mut Vec<usize>, tolerance: T) -> bool {
  152. pivot.clear();
  153. pivot.reserve(DIMENSION + 1);
  154. for i in 0..(DIMENSION + 1) {
  155. pivot.push(i);
  156. }
  157. let zero = T::whole(0);
  158. for i in 0..DIMENSION {
  159. let mut max_a = zero;
  160. let mut max_i = i;
  161. for k in i..DIMENSION {
  162. let abs_a = self[(k, i)].abs();
  163. if abs_a > max_a {
  164. max_a = abs_a;
  165. max_i = k;
  166. }
  167. }
  168. if max_a <= tolerance {
  169. return false;
  170. }
  171. if max_i != i {
  172. pivot.swap(i, max_i);
  173. let p_a: *mut [T; DIMENSION] = (&mut self.data[i]) as *mut [T; DIMENSION];
  174. let p_a_max: *mut [T; DIMENSION] = (&mut self.data[max_i]) as *mut [T; DIMENSION];
  175. swap(p_a, p_a_max);
  176. pivot[DIMENSION] += 1;
  177. }
  178. for j in (i + 1)..DIMENSION {
  179. let divisor = self[(i, i)];
  180. self[(j, i)] /= divisor;
  181. for k in (i + 1)..DIMENSION {
  182. let a = self[(j, i)];
  183. let b = self[(i, k)];
  184. self[(j, k)] -= a * b;
  185. }
  186. }
  187. }
  188. true
  189. }
  190. /// Calculate the determinant of the input matrix.
  191. pub fn determinant(mut self, tolerance: T) -> Option<T> {
  192. let mut pivot = Vec::new();
  193. if !unsafe { self.lup_decompose(&mut pivot, tolerance) } {
  194. return None;
  195. }
  196. debug_assert_eq!(pivot.len(), DIMENSION + 1);
  197. let mut result = self[(0, 0)];
  198. for i in 1..DIMENSION {
  199. result *= self[(i, i)];
  200. }
  201. if (pivot[DIMENSION] - DIMENSION) % 2 == 0 {
  202. Some(result)
  203. } else {
  204. Some(-result)
  205. }
  206. }
  207. /// Solve the equation `self` * `x` = `b` for `x`.
  208. pub fn solve(mut self, b: &ColumnVector<T, DIMENSION>, tolerance: T) -> Option<RowVector<T, DIMENSION>> {
  209. let mut pivot = Vec::new();
  210. if !unsafe { self.lup_decompose(&mut pivot, tolerance) } {
  211. return None;
  212. }
  213. debug_assert_eq!(pivot.len(), DIMENSION + 1);
  214. let mut result = RowVector::new();
  215. for i in 0..DIMENSION {
  216. result[(0, i)] = b[(pivot[i], 0)];
  217. for k in 0..i {
  218. let factor = result[(0, k)];
  219. result[(0, i)] -= self[(i, k)] * factor;
  220. }
  221. }
  222. for i in (0..DIMENSION).rev() {
  223. for k in (i + 1)..DIMENSION {
  224. let factor = result[(0, k)];
  225. result[(0, i)] -= self[(i, k)] * factor;
  226. }
  227. result[(0, i)] /= self[(i, i)];
  228. }
  229. Some(result)
  230. }
  231. /// Calculate the inverse of the input matrix.
  232. pub fn inverted(mut self, tolerance: T) -> Option<Self> {
  233. let mut pivot = Vec::new();
  234. if !unsafe { self.lup_decompose(&mut pivot, tolerance) } {
  235. return None;
  236. }
  237. debug_assert_eq!(pivot.len(), DIMENSION + 1);
  238. let mut result = Self::new();
  239. let zero = T::whole(0);
  240. let one = T::whole(1);
  241. for j in 0..DIMENSION {
  242. for i in 0..DIMENSION {
  243. result[(i, j)] = if pivot[i] == j { one } else { zero };
  244. for k in 0..i {
  245. let factor = result[(k, j)];
  246. result[(i, j)] -= self[(i, k)] * factor;
  247. }
  248. }
  249. for i in (0..DIMENSION).rev() {
  250. for k in (i + 1)..DIMENSION {
  251. let factor = result[(k, j)];
  252. result[(i, j)] -= self[(i, k)] * factor;
  253. }
  254. result[(i, j)] /= self[(i, i)];
  255. }
  256. }
  257. Some(result)
  258. }
  259. }
  260. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> From<[[T; COLUMNS]; ROWS]> for GenericMatrix<T, ROWS, COLUMNS> {
  261. fn from(data: [[T; COLUMNS]; ROWS]) -> Self {
  262. Self {
  263. data,
  264. }
  265. }
  266. }
  267. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Index<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
  268. type Output = T;
  269. fn index(&self, index: (usize, usize)) -> &Self::Output {
  270. let (row, column) = index;
  271. &self.data[row][column]
  272. }
  273. }
  274. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> IndexMut<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
  275. fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
  276. let (row, column) = index;
  277. &mut self.data[row][column]
  278. }
  279. }
  280. impl<T: Numeric + Neg<Output=T>, const ROWS: usize, const COLUMNS: usize> Neg for GenericMatrix<T, ROWS, COLUMNS> {
  281. type Output = Self;
  282. fn neg(self) -> Self::Output {
  283. let mut result = Self::default();
  284. for r in 0..ROWS {
  285. for c in 0..COLUMNS {
  286. result.data[r][c] = -self.data[r][c];
  287. }
  288. }
  289. result
  290. }
  291. }
  292. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Add for GenericMatrix<T, ROWS, COLUMNS> {
  293. type Output = Self;
  294. fn add(mut self, rhs: Self) -> Self::Output {
  295. self += rhs;
  296. self
  297. }
  298. }
  299. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> AddAssign for GenericMatrix<T, ROWS, COLUMNS> {
  300. fn add_assign(&mut self, rhs: Self) {
  301. for r in 0..ROWS {
  302. for c in 0..COLUMNS {
  303. self.data[r][c] += rhs.data[r][c];
  304. }
  305. }
  306. }
  307. }
  308. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Sub for GenericMatrix<T, ROWS, COLUMNS> {
  309. type Output = Self;
  310. fn sub(mut self, rhs: Self) -> Self::Output {
  311. self -= rhs;
  312. self
  313. }
  314. }
  315. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> SubAssign for GenericMatrix<T, ROWS, COLUMNS> {
  316. fn sub_assign(&mut self, rhs: Self) {
  317. for r in 0..ROWS {
  318. for c in 0..COLUMNS {
  319. self.data[r][c] -= rhs.data[r][c];
  320. }
  321. }
  322. }
  323. }
  324. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Mul<T> for GenericMatrix<T, ROWS, COLUMNS> {
  325. type Output = Self;
  326. fn mul(mut self, rhs: T) -> Self::Output {
  327. self *= rhs;
  328. self
  329. }
  330. }
  331. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> MulAssign<T> for GenericMatrix<T, ROWS, COLUMNS> {
  332. fn mul_assign(&mut self, rhs: T) {
  333. for r in 0..ROWS {
  334. for c in 0..COLUMNS {
  335. self.data[r][c] *= rhs;
  336. }
  337. }
  338. }
  339. }
  340. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Div<T> for GenericMatrix<T, ROWS, COLUMNS> {
  341. type Output = Self;
  342. fn div(mut self, rhs: T) -> Self::Output {
  343. self /= rhs;
  344. self
  345. }
  346. }
  347. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> DivAssign<T> for GenericMatrix<T, ROWS, COLUMNS> {
  348. fn div_assign(&mut self, rhs: T) {
  349. for r in 0..ROWS {
  350. for c in 0..COLUMNS {
  351. self.data[r][c] /= rhs;
  352. }
  353. }
  354. }
  355. }
  356. impl<T: Numeric + Float + PartialOrd, const DIMENSION: usize> Div<SquareMatrix<T, DIMENSION>> for ColumnVector<T, DIMENSION> {
  357. type Output = RowVector<T, DIMENSION>;
  358. fn div(self, rhs: SquareMatrix<T, DIMENSION>) -> Self::Output {
  359. rhs.solve(&self, T::value(1e-9)).unwrap()
  360. }
  361. }
  362. impl<T: Numeric, const ROWS: usize, const COMMON: usize, const COLUMNS: usize> Mul<GenericMatrix<T, COMMON, COLUMNS>> for GenericMatrix<T, ROWS, COMMON> {
  363. type Output = GenericMatrix<T, ROWS, COLUMNS>;
  364. fn mul(self, rhs: GenericMatrix<T, COMMON, COLUMNS>) -> Self::Output {
  365. let mut result = Self::Output::default();
  366. for i in 0..ROWS {
  367. for j in 0..COLUMNS {
  368. for k in 0..COMMON {
  369. result.data[i][j] += self.data[i][k] * rhs.data[k][j];
  370. }
  371. }
  372. }
  373. result
  374. }
  375. }
  376. /// Apply a transformation matrix to a column vector.
  377. impl<T: Numeric + Float + Primitive> Mul<ColumnVector<T, 3>> for SquareMatrix<T, 4> {
  378. type Output = ColumnVector<T, 3>;
  379. fn mul(self, rhs: ColumnVector<T, 3>) -> Self::Output {
  380. let x_old = rhs[(0, 0)];
  381. let y_old = rhs[(1, 0)];
  382. let z_old = rhs[(2, 0)];
  383. cvec![
  384. self[(0,0)] * x_old + self[(0, 1)] * y_old + self[(0, 2)] * z_old + self[(0, 3)],
  385. self[(1,0)] * x_old + self[(1, 1)] * y_old + self[(1, 2)] * z_old + self[(1, 3)],
  386. self[(2,0)] * x_old + self[(2, 1)] * y_old + self[(2, 2)] * z_old + self[(2, 3)],
  387. ]
  388. }
  389. }
  390. impl<T: Numeric + Float + Primitive> SquareMatrix<T, 4> {
  391. /// Create a 3D translation matrix.
  392. pub fn translation(x: T, y: T, z: T) -> Self {
  393. let mut result = SquareMatrix::identity();
  394. result[(0, 3)] = x;
  395. result[(1, 3)] = y;
  396. result[(2, 3)] = z;
  397. result
  398. }
  399. /// Create a 3D scale matrix.
  400. pub fn scale(x: T, y: T, z: T) -> Self {
  401. let one = T::whole(1);
  402. SquareMatrix::diagonal(&[x, y, z, one])
  403. }
  404. fn rotation_impl(angle: T, x: T, y: T, z: T) -> Self {
  405. let zero = T::whole(0);
  406. let one = T::whole(1);
  407. let c = angle.cos();
  408. let s = angle.sin();
  409. let omc = one - c;
  410. Self::from([
  411. [x * x * omc + c, x * y * omc - z * s, x * z * omc + y * s, zero],
  412. [y * x * omc + z * s, y * y * omc + c, y * z * omc - x * s, zero],
  413. [x * z * omc - y * s, y * z * omc + x * s, z * z * omc + c, zero],
  414. [zero, zero, zero, one],
  415. ])
  416. }
  417. /// Create a 3D rotation matrix.
  418. ///
  419. /// `angle` is expected to be in radians.
  420. #[cfg(not(feature = "angular"))]
  421. pub fn rotation(angle: T, x: T, y: T, z: T) -> Self {
  422. Self::rotation_impl(angle, x, y, z)
  423. }
  424. /// Create a 3D rotation matrix.
  425. #[cfg(feature = "angular")]
  426. pub fn rotation<A: Angular<T>>(angle: A, x: T, y: T, z: T) -> Self {
  427. Self::rotation_impl(angle.radiant(), x, y, z)
  428. }
  429. }
  #[cfg(test)]
  mod tests {
      use crate::{ColumnVector, Complex, cplx, Float, GenericMatrix, Numeric, RowVector, SquareMatrix};

      /// Return `true` when every pair of corresponding cells differs by at most
      /// `tolerance` in absolute value.
      fn matrices_are_equal<T: Numeric + Float + PartialOrd, const ROWS: usize, const COLUMNS: usize>(
          a: &GenericMatrix<T, ROWS, COLUMNS>,
          b: &GenericMatrix<T, ROWS, COLUMNS>,
          tolerance: T,
      ) -> bool {
          for r in 0..ROWS {
              for c in 0..COLUMNS {
                  if (a[(r, c)] - b[(r, c)]).abs() > tolerance {
                      return false;
                  }
              }
          }
          true
      }

      #[test]
      fn identity_matrix() {
          let mat: SquareMatrix<f64, 3> = SquareMatrix::identity();
          let expected = SquareMatrix::from([
              [1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
          ]);
          assert_eq!(mat, expected);
      }

      #[test]
      fn diagonal_matrix() {
          let mat: SquareMatrix<f32, 3> = SquareMatrix::diagonal(&[1.0, 2.0, 3.0]);
          let expected = SquareMatrix {
              data: [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]],
          };
          assert_eq!(mat, expected);
      }

      #[test]
      #[should_panic = "index out of bounds"]
      fn out_of_bounds_access() {
          let mat: SquareMatrix<f32, 1> = SquareMatrix::identity();
          mat[(1, 0)];
      }

      #[test]
      fn transposing_matrix() {
          let mat = GenericMatrix {
              data: [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
          }.transposed();
          let expected = GenericMatrix {
              data: [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]],
          };
          assert_eq!(mat, expected);
      }

      #[test]
      fn negating_matrix() {
          let mat = -GenericMatrix {
              data: [[1, -1], [2, -2]],
          };
          let expected = GenericMatrix {
              data: [[-1, 1], [-2, 2]],
          };
          assert_eq!(mat, expected);
      }

      #[test]
      fn dot_product_of_vectors() {
          let a = rvec![1, 3, -5];
          let b = cvec![4, -2, -1];
          let product = a * b;
          assert_eq!(product[(0, 0)], 3);
      }

      #[test]
      fn matrix_product_of_vectors() {
          let a = cvec![1, 3, -5];
          let b = rvec![4, -2, -1];
          let product = a * b;
          let expected = SquareMatrix {
              data: [
                  [4, -2, -1],
                  [12, -6, -3],
                  [-20, 10, 5],
              ],
          };
          assert_eq!(product, expected);
      }

      #[test]
      fn complex_matrix_multiplication() {
          let a = SquareMatrix {
              data: [
                  [cplx!(2.0, 1.0), cplx!(i = 5.0,)],
                  [cplx!(3.0), cplx!(3.0, -4.0)],
              ],
          };
          let b = SquareMatrix {
              data: [
                  [cplx!(1.0, -1.0), cplx!(4.0, 2.0)],
                  [cplx!(1.0, -6.0), cplx!(3.0)],
              ],
          };
          let mat = a * b;
          let expected = SquareMatrix::from([
              [cplx!(33.0, 4.0), cplx!(6.0, 23.0)],
              [cplx!(-18.0, -25.0), cplx!(21.0, -6.0)],
          ]);
          assert_eq!(mat, expected);
      }

      #[test]
      fn invert_matrix() {
          let tolerance = 1e-9;
          let mat = SquareMatrix::from([
              [2.0, 1.0],
              [6.0, 4.0],
          ]);
          let inverted = mat.inverted(tolerance).unwrap();
          let expected = SquareMatrix::from([
              [2.0, -0.5],
              [-3.0, 1.0],
          ]);
          assert!(matrices_are_equal(&inverted, &expected, tolerance));
          assert!(matrices_are_equal(&(inverted * mat), &SquareMatrix::identity(), tolerance));
      }

      #[test]
      fn determinant_of_matrix() {
          let tolerance: f64 = 1e-9;
          let mat: SquareMatrix<f64, 2> = SquareMatrix::from([
              [2.0, 1.0],
              [6.0, 4.0],
          ]);
          let det = mat.determinant(tolerance).unwrap();
          assert!((det - 2.0).abs() <= tolerance);
          // A permutation matrix with a single row swap has determinant -1.
          let swap: SquareMatrix<f64, 2> = SquareMatrix::from([
              [0.0, 1.0],
              [1.0, 0.0],
          ]);
          let det = swap.determinant(tolerance).unwrap();
          assert!((det + 1.0).abs() <= tolerance);
      }

      #[test]
      fn solving_linear_system() {
          let tolerance: f64 = 1e-9;
          let mat: SquareMatrix<f64, 2> = SquareMatrix::from([
              [2.0, 1.0],
              [6.0, 4.0],
          ]);
          let b = cvec![3.0, 10.0];
          let x = mat.solve(&b, tolerance).unwrap();
          assert!(matrices_are_equal(&x, &rvec![1.0, 1.0], tolerance));
          // `b / mat` is shorthand for `solve` with a fixed tolerance.
          assert!(matrices_are_equal(&(b / mat), &rvec![1.0, 1.0], tolerance));
      }

      #[test]
      fn singular_matrix_is_rejected() {
          let tolerance: f64 = 1e-9;
          let mat: SquareMatrix<f64, 2> = SquareMatrix::from([
              [1.0, 2.0],
              [2.0, 4.0],
          ]);
          assert_eq!(mat.determinant(tolerance), None);
          assert_eq!(mat.inverted(tolerance), None);
          assert_eq!(mat.solve(&cvec![1.0, 2.0], tolerance), None);
      }
  }
  546. assert!(matrices_are_equal(&(inverted * mat), &SquareMatrix::identity(), tolerance));
  547. }
  548. }