generic.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
  2. use std::ptr::swap;
  3. use crate::{Float, Numeric, Primitive};
  4. #[cfg(feature = "angular")]
  5. use crate::Angular;
  6. /// Struct representing a dense matrix.
  7. #[repr(transparent)]
  8. #[derive(Copy, Clone, Debug, PartialEq)]
  9. pub struct GenericMatrix<T: Numeric, const ROWS: usize, const COLUMNS: usize> {
  10. pub data: [[T; COLUMNS]; ROWS],
  11. }
  12. /// Special kind of `GenericMatrix` where both dimensions are equal.
  13. pub type SquareMatrix<T, const DIMENSION: usize> = GenericMatrix<T, DIMENSION, DIMENSION>;
  14. /// Special kind of `GenericMatrix` with just one column.
  15. pub type ColumnVector<T, const ROWS: usize> = GenericMatrix<T, ROWS, 1>;
  16. /// Special kind of `GenericMatrix` with just one row.
  17. pub type RowVector<T, const COLUMNS: usize> = GenericMatrix<T, 1, COLUMNS>;
  /// A macro for easier creation of row vectors.
  ///
  /// The expansion names the type through `$crate::RowVector`, so the macro also works
  /// in modules (and downstream crates) that have not imported `RowVector`.
  /// A trailing comma after the last value is accepted.
  ///
  /// # Example
  ///
  /// ```rust
  /// # use lineal::{RowVector, rvec};
  /// let a = rvec![1.0, 2.0, 3.0];
  /// // is the same as
  /// let b = RowVector {
  ///     data: [[1.0, 2.0, 3.0]],
  /// };
  /// assert_eq!(a, b);
  /// ```
  #[macro_export]
  macro_rules! rvec {
      ($($value:expr),* $(,)?) => {
          {
              // `$crate::` keeps the path valid regardless of what the caller imported.
              $crate::RowVector {
                  data: [[$($value,)*]],
              }
          }
      };
  }
  /// A macro for easier creation of column vectors.
  ///
  /// The expansion names the type through `$crate::ColumnVector`, so the macro also works
  /// in modules (and downstream crates) that have not imported `ColumnVector`.
  /// A trailing comma after the last value is accepted.
  ///
  /// # Example
  ///
  /// ```rust
  /// # use lineal::{ColumnVector, cvec};
  /// let a = cvec![1.0, 2.0, 3.0];
  /// // is the same as
  /// let b = ColumnVector {
  ///     data: [[1.0], [2.0], [3.0]],
  /// };
  /// assert_eq!(a, b);
  /// ```
  #[macro_export]
  macro_rules! cvec {
      ($($value:expr),* $(,)?) => {
          {
              // `$crate::` keeps the path valid regardless of what the caller imported.
              $crate::ColumnVector {
                  data: [$([$value],)*],
              }
          }
      };
  }
  62. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Default for GenericMatrix<T, ROWS, COLUMNS> {
  63. /// Create a matrix with all cells being zero.
  64. fn default() -> Self {
  65. let zero = unsafe { T::whole(0) };
  66. Self {
  67. data: [[zero; COLUMNS]; ROWS],
  68. }
  69. }
  70. }
  71. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> GenericMatrix<T, ROWS, COLUMNS> {
  72. /// Create a matrix with all cells being zero.
  73. pub fn new() -> Self {
  74. Self::default()
  75. }
  76. /// Create a transpose of the input matrix.
  77. ///
  78. /// # Example
  79. ///
  80. /// ```rust
  81. /// # use lineal::{ColumnVector, RowVector, cvec, rvec};
  82. /// let a = cvec![1, 2, 3];
  83. /// let b = rvec![1, 2, 3];
  84. /// assert_eq!(a.transposed(), b);
  85. /// ```
  86. pub fn transposed(&self) -> GenericMatrix<T, COLUMNS, ROWS> {
  87. let mut result = GenericMatrix::default();
  88. for r in 0..ROWS {
  89. for c in 0..COLUMNS {
  90. result.data[c][r] = self.data[r][c];
  91. }
  92. }
  93. result
  94. }
  95. }
  96. impl<T: Numeric, const DIMENSION: usize> SquareMatrix<T, DIMENSION> {
  97. /// Create a square identity matrix.
  98. ///
  99. /// In an identity matrix the main diagonal is filled with ones,
  100. /// while the rest if the cells are zero.
  101. ///
  102. /// # Example
  103. ///
  104. /// ```rust
  105. /// # use lineal::{SquareMatrix};
  106. /// let a: SquareMatrix<f32, 3> = SquareMatrix::identity();
  107. /// // is the same as
  108. /// let b: SquareMatrix<f32, 3> = SquareMatrix {
  109. /// data: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
  110. /// };
  111. /// ```
  112. pub fn identity() -> Self {
  113. let mut mat = Self::default();
  114. let one = unsafe { T::whole(1) };
  115. for i in 0..DIMENSION {
  116. mat.data[i][i] = one;
  117. }
  118. mat
  119. }
  120. /// Create a square matrix with its main diagonal filled with the given values.
  121. ///
  122. /// # Example
  123. ///
  124. /// ```rust
  125. /// # use lineal::{SquareMatrix};
  126. /// let a = SquareMatrix::diagonal(&[1, 2, 3]);
  127. /// // is the same as
  128. /// let b = SquareMatrix {
  129. /// data: [
  130. /// [1, 0, 0],
  131. /// [0, 2, 0],
  132. /// [0, 0, 3],
  133. /// ],
  134. /// };
  135. /// ```
  136. pub fn diagonal(values: &[T; DIMENSION]) -> Self {
  137. let mut mat = Self::default();
  138. for i in 0..DIMENSION {
  139. mat.data[i][i] = values[i];
  140. }
  141. mat
  142. }
  143. }
  144. impl<T: Numeric + Float + PartialOrd, const DIMENSION: usize> SquareMatrix<T, DIMENSION> {
  145. /// Apply an LUP decomposition to the matrix.
  146. ///
  147. /// # Safety
  148. ///
  149. /// This function is unsafe, because even when the result is false,
  150. /// the matrix may already have been modified.
  151. pub unsafe fn lup_decompose(&mut self, pivot: &mut Vec<usize>, tolerance: T) -> bool {
  152. pivot.clear();
  153. pivot.reserve(DIMENSION + 1);
  154. for i in 0..(DIMENSION + 1) {
  155. pivot.push(i);
  156. }
  157. let zero = T::whole(0);
  158. for i in 0..DIMENSION {
  159. let mut max_a = zero;
  160. let mut max_i = i;
  161. for k in i..DIMENSION {
  162. let abs_a = self[(k, i)].abs();
  163. if abs_a > max_a {
  164. max_a = abs_a;
  165. max_i = k;
  166. }
  167. }
  168. if max_a <= tolerance {
  169. return false;
  170. }
  171. if max_i != i {
  172. pivot.swap(i, max_i);
  173. let p_a: *mut [T; DIMENSION] = (&mut self.data[i]) as *mut [T; DIMENSION];
  174. let p_a_max: *mut [T; DIMENSION] = (&mut self.data[max_i]) as *mut [T; DIMENSION];
  175. swap(p_a, p_a_max);
  176. pivot[DIMENSION] += 1;
  177. }
  178. for j in (i + 1)..DIMENSION {
  179. let divisor = self[(i, i)];
  180. self[(j, i)] /= divisor;
  181. for k in (i + 1)..DIMENSION {
  182. let a = self[(j, i)];
  183. let b = self[(i, k)];
  184. self[(j, k)] -= a * b;
  185. }
  186. }
  187. }
  188. true
  189. }
  190. /// Calculate the determinant of the input matrix.
  191. pub fn determinant(mut self, tolerance: T) -> Option<T> {
  192. let mut pivot = Vec::new();
  193. if !unsafe { self.lup_decompose(&mut pivot, tolerance) } {
  194. return None;
  195. }
  196. let mut result = self[(0, 0)];
  197. for i in 1..DIMENSION {
  198. result *= self[(i, i)];
  199. }
  200. if (pivot[DIMENSION] - DIMENSION) % 2 == 0 {
  201. Some(result)
  202. } else {
  203. Some(-result)
  204. }
  205. }
  206. /// Calculate the inverse of the input matrix.
  207. pub fn inverted(mut self, tolerance: T) -> Option<Self> {
  208. let mut pivot = Vec::new();
  209. if !unsafe { self.lup_decompose(&mut pivot, tolerance) } {
  210. return None;
  211. }
  212. let mut result = Self::new();
  213. let zero = unsafe { T::whole(0) };
  214. let one = unsafe { T::whole(1) };
  215. for j in 0..DIMENSION {
  216. for i in 0..DIMENSION {
  217. result[(i, j)] = if pivot[i] == j { one } else { zero };
  218. for k in 0..i {
  219. let factor = result[(k, j)];
  220. result[(i, j)] -= self[(i, k)] * factor;
  221. }
  222. }
  223. for i in (0..DIMENSION).rev() {
  224. for k in (i + 1)..DIMENSION {
  225. let factor = result[(k, j)];
  226. result[(i, j)] -= self[(i, k)] * factor;
  227. }
  228. result[(i, j)] /= self[(i, i)];
  229. }
  230. }
  231. Some(result)
  232. }
  233. }
  234. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> From<[[T; COLUMNS]; ROWS]> for GenericMatrix<T, ROWS, COLUMNS> {
  235. fn from(data: [[T; COLUMNS]; ROWS]) -> Self {
  236. Self {
  237. data,
  238. }
  239. }
  240. }
  241. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Index<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
  242. type Output = T;
  243. fn index(&self, index: (usize, usize)) -> &Self::Output {
  244. let (row, column) = index;
  245. &self.data[row][column]
  246. }
  247. }
  248. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> IndexMut<(usize, usize)> for GenericMatrix<T, ROWS, COLUMNS> {
  249. fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
  250. let (row, column) = index;
  251. &mut self.data[row][column]
  252. }
  253. }
  254. impl<T: Numeric + Neg<Output=T>, const ROWS: usize, const COLUMNS: usize> Neg for GenericMatrix<T, ROWS, COLUMNS> {
  255. type Output = Self;
  256. fn neg(self) -> Self::Output {
  257. let mut result = Self::default();
  258. for r in 0..ROWS {
  259. for c in 0..COLUMNS {
  260. result.data[r][c] = -self.data[r][c];
  261. }
  262. }
  263. result
  264. }
  265. }
  266. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Add for GenericMatrix<T, ROWS, COLUMNS> {
  267. type Output = Self;
  268. fn add(mut self, rhs: Self) -> Self::Output {
  269. self += rhs;
  270. self
  271. }
  272. }
  273. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> AddAssign for GenericMatrix<T, ROWS, COLUMNS> {
  274. fn add_assign(&mut self, rhs: Self) {
  275. for r in 0..ROWS {
  276. for c in 0..COLUMNS {
  277. self.data[r][c] += rhs.data[r][c];
  278. }
  279. }
  280. }
  281. }
  282. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Sub for GenericMatrix<T, ROWS, COLUMNS> {
  283. type Output = Self;
  284. fn sub(mut self, rhs: Self) -> Self::Output {
  285. self -= rhs;
  286. self
  287. }
  288. }
  289. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> SubAssign for GenericMatrix<T, ROWS, COLUMNS> {
  290. fn sub_assign(&mut self, rhs: Self) {
  291. for r in 0..ROWS {
  292. for c in 0..COLUMNS {
  293. self.data[r][c] -= rhs.data[r][c];
  294. }
  295. }
  296. }
  297. }
  298. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Mul<T> for GenericMatrix<T, ROWS, COLUMNS> {
  299. type Output = Self;
  300. fn mul(mut self, rhs: T) -> Self::Output {
  301. self *= rhs;
  302. self
  303. }
  304. }
  305. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> MulAssign<T> for GenericMatrix<T, ROWS, COLUMNS> {
  306. fn mul_assign(&mut self, rhs: T) {
  307. for r in 0..ROWS {
  308. for c in 0..COLUMNS {
  309. self.data[r][c] *= rhs;
  310. }
  311. }
  312. }
  313. }
  314. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> Div<T> for GenericMatrix<T, ROWS, COLUMNS> {
  315. type Output = Self;
  316. fn div(mut self, rhs: T) -> Self::Output {
  317. self /= rhs;
  318. self
  319. }
  320. }
  321. impl<T: Numeric, const ROWS: usize, const COLUMNS: usize> DivAssign<T> for GenericMatrix<T, ROWS, COLUMNS> {
  322. fn div_assign(&mut self, rhs: T) {
  323. for r in 0..ROWS {
  324. for c in 0..COLUMNS {
  325. self.data[r][c] /= rhs;
  326. }
  327. }
  328. }
  329. }
  330. impl<T: Numeric, const ROWS: usize, const COMMON: usize, const COLUMNS: usize> Mul<GenericMatrix<T, COMMON, COLUMNS>> for GenericMatrix<T, ROWS, COMMON> {
  331. type Output = GenericMatrix<T, ROWS, COLUMNS>;
  332. fn mul(self, rhs: GenericMatrix<T, COMMON, COLUMNS>) -> Self::Output {
  333. let mut result = Self::Output::default();
  334. for i in 0..ROWS {
  335. for j in 0..COLUMNS {
  336. for k in 0..COMMON {
  337. result.data[i][j] += self.data[i][k] * rhs.data[k][j];
  338. }
  339. }
  340. }
  341. result
  342. }
  343. }
  344. /// Apply a transformation matrix to a column vector.
  345. impl<T: Numeric + Float + Primitive> Mul<ColumnVector<T, 3>> for SquareMatrix<T, 4> {
  346. type Output = ColumnVector<T, 3>;
  347. fn mul(self, rhs: ColumnVector<T, 3>) -> Self::Output {
  348. let x_old = rhs[(0, 0)];
  349. let y_old = rhs[(1, 0)];
  350. let z_old = rhs[(2, 0)];
  351. cvec![
  352. self[(0,0)] * x_old + self[(0, 1)] * y_old + self[(0, 2)] * z_old + self[(0, 3)],
  353. self[(1,0)] * x_old + self[(1, 1)] * y_old + self[(1, 2)] * z_old + self[(1, 3)],
  354. self[(2,0)] * x_old + self[(2, 1)] * y_old + self[(2, 2)] * z_old + self[(2, 3)],
  355. ]
  356. }
  357. }
  358. impl<T: Numeric + Float + Primitive> SquareMatrix<T, 4> {
  359. /// Create a 3D translation matrix.
  360. pub fn translation(x: T, y: T, z: T) -> Self {
  361. let mut result = SquareMatrix::identity();
  362. result[(0, 3)] = x;
  363. result[(1, 3)] = y;
  364. result[(2, 3)] = z;
  365. result
  366. }
  367. /// Create a 3D scale matrix.
  368. pub fn scale(x: T, y: T, z: T) -> Self {
  369. let one = unsafe { T::whole(1) };
  370. SquareMatrix::diagonal(&[x, y, z, one])
  371. }
  372. fn rotation_impl(angle: T, x: T, y: T, z: T) -> Self {
  373. let zero = unsafe { T::whole(0) };
  374. let one = unsafe { T::whole(1) };
  375. let c = angle.cos();
  376. let s = angle.sin();
  377. let omc = one - c;
  378. Self::from([
  379. [x * x * omc + c, x * y * omc - z * s, x * z * omc + y * s, zero],
  380. [y * x * omc + z * s, y * y * omc + c, y * z * omc - x * s, zero],
  381. [x * z * omc - y * s, y * z * omc + x * s, z * z * omc + c, zero],
  382. [zero, zero, zero, one],
  383. ])
  384. }
  385. /// Create a 3D rotation matrix.
  386. ///
  387. /// `angle` is expected to be in radians.
  388. #[cfg(not(feature = "angular"))]
  389. pub fn rotation(angle: T, x: T, y: T, z: T) -> Self {
  390. Self::rotation_impl(angle, x, y, z)
  391. }
  392. /// Create a 3D rotation matrix.
  393. #[cfg(feature = "angular")]
  394. pub fn rotation<A: Angular<T>>(angle: A, x: T, y: T, z: T) -> Self {
  395. Self::rotation_impl(angle.radiant(), x, y, z)
  396. }
  397. }
  #[cfg(test)]
  mod tests {
      use crate::{ColumnVector, Complex, cplx, Float, GenericMatrix, Numeric, RowVector, SquareMatrix};

      /// Compare two matrices cell by cell, allowing an absolute difference of `tolerance`.
      fn matrices_are_equal<T: Numeric + Float + PartialOrd, const ROWS: usize, const COLUMNS: usize>(
          a: &GenericMatrix<T, ROWS, COLUMNS>,
          b: &GenericMatrix<T, ROWS, COLUMNS>,
          tolerance: T,
      ) -> bool {
          for r in 0..ROWS {
              for c in 0..COLUMNS {
                  if (a[(r, c)] - b[(r, c)]).abs() > tolerance {
                      return false;
                  }
              }
          }
          true
      }

      #[test]
      fn identity_matrix() {
          let mat: SquareMatrix<f64, 3> = SquareMatrix::identity();
          let expected = SquareMatrix::from([
              [1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
          ]);
          assert_eq!(mat, expected);
      }

      #[test]
      fn diagonal_matrix() {
          let mat: SquareMatrix<f32, 3> = SquareMatrix::diagonal(&[1.0, 2.0, 3.0]);
          let expected = SquareMatrix {
              data: [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]],
          };
          assert_eq!(mat, expected);
      }

      #[test]
      #[should_panic = "index out of bounds"]
      fn out_of_bounds_access() {
          let mat: SquareMatrix<f32, 1> = SquareMatrix::identity();
          mat[(1, 0)];
      }

      #[test]
      fn transposing_matrix() {
          let mat = GenericMatrix {
              data: [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
          }.transposed();
          let expected = GenericMatrix {
              data: [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]],
          };
          assert_eq!(mat, expected);
      }

      #[test]
      fn negating_matrix() {
          let mat = -GenericMatrix {
              data: [[1, -1], [2, -2]],
          };
          let expected = GenericMatrix {
              data: [[-1, 1], [-2, 2]],
          };
          assert_eq!(mat, expected);
      }

      #[test]
      fn dot_product_of_vectors() {
          let a = rvec![1, 3, -5];
          let b = cvec![4, -2, -1];
          let product = a * b;
          assert_eq!(product[(0, 0)], 3);
      }

      #[test]
      fn matrix_product_of_vectors() {
          let a = cvec![1, 3, -5];
          let b = rvec![4, -2, -1];
          let product = a * b;
          let expected = SquareMatrix {
              data: [
                  [4, -2, -1],
                  [12, -6, -3],
                  [-20, 10, 5],
              ],
          };
          assert_eq!(product, expected);
      }

      #[test]
      fn complex_matrix_multiplication() {
          let a = SquareMatrix {
              data: [
                  [cplx!(2.0, 1.0), cplx!(i = 5.0,)],
                  [cplx!(3.0), cplx!(3.0, -4.0)],
              ],
          };
          let b = SquareMatrix {
              data: [
                  [cplx!(1.0, -1.0), cplx!(4.0, 2.0)],
                  [cplx!(1.0, -6.0), cplx!(3.0)],
              ],
          };
          let mat = a * b;
          let expected = SquareMatrix::from([
              [cplx!(33.0, 4.0), cplx!(6.0, 23.0)],
              [cplx!(-18.0, -25.0), cplx!(21.0, -6.0)],
          ]);
          assert_eq!(mat, expected);
      }

      #[test]
      fn invert_matrix() {
          let tolerance = 1e-9;
          let mat = SquareMatrix::from([
              [2.0, 1.0],
              [6.0, 4.0],
          ]);
          let inverted = mat.inverted(tolerance).unwrap();
          let expected = SquareMatrix::from([
              [2.0, -0.5],
              [-3.0, 1.0],
          ]);
          assert!(matrices_are_equal(&inverted, &expected, tolerance));
          assert!(matrices_are_equal(&(inverted * mat), &SquareMatrix::identity(), tolerance));
      }

      #[test]
      fn determinant_of_matrix() {
          let tolerance = 1e-9;
          let mat: SquareMatrix<f64, 3> = SquareMatrix::from([
              [1.0, 2.0, 3.0],
              [0.0, 1.0, 4.0],
              [5.0, 6.0, 0.0],
          ]);
          let det = mat.determinant(tolerance).unwrap();
          assert!((det - 1.0).abs() <= tolerance);
          // Partial pivoting swaps the two rows here; the sign must still come out positive.
          let swapped: SquareMatrix<f64, 2> = SquareMatrix::from([
              [2.0, 1.0],
              [6.0, 4.0],
          ]);
          let det = swapped.determinant(tolerance).unwrap();
          assert!((det - 2.0).abs() <= tolerance);
      }

      #[test]
      fn singular_matrix_has_no_determinant_or_inverse() {
          let tolerance = 1e-9;
          let mat: SquareMatrix<f64, 2> = SquareMatrix::from([
              [1.0, 2.0],
              [2.0, 4.0],
          ]);
          assert_eq!(mat.determinant(tolerance), None);
          assert_eq!(mat.inverted(tolerance), None);
      }
  }