//! `complex.rs` — the `Complex` number type, the `complex!` constructor macro, and their operator and trait implementations.
  1. use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
  2. use crate::{Float, Numeric, Primitive};
  3. /// Struct representing a complex number.
  4. ///
  5. /// # Example
  6. ///
  7. /// ```rust
  8. /// # use lineal::{Complex, complex, Float, Numeric};
  9. /// let c = Complex { real: 9.0, imag: 0.0 };
  10. /// assert_eq!(format!("{:?}", c), "Complex { real: 9.0, imag: 0.0 }");
  11. ///
  12. /// let three = c.sqrt().real;
  13. /// assert_eq!(three, 3.0);
  14. ///
  15. /// let i: Complex<f32> = complex!(-1.0).sqrt();
  16. /// assert_eq!(i, Complex::i());
  17. ///
  18. /// let a = Complex { real: 2.0, imag: 0.5 };
  19. /// let a_squared = a * a;
  20. /// assert_eq!(a_squared, Complex { real: 3.75, imag: 2.0 });
  21. /// assert_eq!(a_squared.sqrt(), a);
  22. ///
  23. /// let x = (16.0).sqrt();
  24. /// let y = complex!(16.0).sqrt().real;
  25. /// assert_eq!(x, y);
  26. /// ```
  27. #[derive(Debug, Copy, Clone)]
  28. pub struct Complex<T: Float + Numeric + Primitive> {
  29. pub real: T,
  30. pub imag: T,
  31. }
  32. #[macro_export]
  33. macro_rules! complex {
  34. ($real:expr, $imag:expr) => {
  35. Complex { real: $real, imag: $imag }
  36. };
  37. ($real:expr) => {
  38. Complex { real: $real, imag: 0.0 }
  39. };
  40. ($($field:ident = $value:expr),* $(,)?) => {
  41. Complex {
  42. $( $field: $value, )*
  43. ..Complex::default()
  44. }
  45. }
  46. }
  47. impl<T: Float + Numeric + Primitive> Default for Complex<T> {
  48. fn default() -> Self {
  49. unsafe { Complex::whole(0) }
  50. }
  51. }
  52. impl<T: Float + Numeric + Primitive> Complex<T> {
  53. pub fn real(value: T) -> Self {
  54. Self {
  55. real: value,
  56. imag: unsafe { T::whole(0) },
  57. }
  58. }
  59. pub fn imag(value: T) -> Self {
  60. Self {
  61. real: unsafe { T::whole(0) },
  62. imag: value,
  63. }
  64. }
  65. /// The imaginary unit.
  66. pub fn i() -> Self {
  67. Self {
  68. real: unsafe { T::whole(0) },
  69. imag: unsafe { T::whole(1) },
  70. }
  71. }
  72. }
  73. impl<T: Float + Numeric + Primitive> Add for Complex<T> {
  74. type Output = Self;
  75. fn add(self, rhs: Self) -> Self::Output {
  76. Self {
  77. real: self.real + rhs.real,
  78. imag: self.imag + rhs.imag,
  79. }
  80. }
  81. }
  82. impl<T: Float + Numeric + Primitive> AddAssign for Complex<T> {
  83. fn add_assign(&mut self, rhs: Self) {
  84. self.real += rhs.real;
  85. self.imag += rhs.imag;
  86. }
  87. }
  88. impl<T: Float + Numeric + Primitive> Sub for Complex<T> {
  89. type Output = Self;
  90. fn sub(self, rhs: Self) -> Self::Output {
  91. Self {
  92. real: self.real - rhs.real,
  93. imag: self.imag - rhs.imag,
  94. }
  95. }
  96. }
  97. impl<T: Float + Numeric + Primitive> SubAssign for Complex<T> {
  98. fn sub_assign(&mut self, rhs: Self) {
  99. self.real -= rhs.real;
  100. self.imag -= rhs.imag;
  101. }
  102. }
  103. impl<T: Float + Numeric + Primitive> Mul for Complex<T> {
  104. type Output = Self;
  105. fn mul(self, rhs: Self) -> Self::Output {
  106. Self {
  107. real: self.real * rhs.real - self.imag * rhs.imag,
  108. imag: self.real * rhs.imag + self.imag * rhs.real,
  109. }
  110. }
  111. }
  112. impl<T: Float + Numeric + Primitive> MulAssign for Complex<T> {
  113. fn mul_assign(&mut self, rhs: Self) {
  114. *self = *self * rhs;
  115. }
  116. }
  117. impl<T: Float + Numeric + Primitive> Div for Complex<T> {
  118. type Output = Self;
  119. fn div(self, rhs: Self) -> Self::Output {
  120. let divisor = rhs.real * rhs.real + rhs.imag * rhs.imag;
  121. let mut result = self * Self { real: rhs.real, imag: -rhs.imag };
  122. result.real /= divisor;
  123. result.imag /= divisor;
  124. return result;
  125. }
  126. }
  127. impl<T: Float + Numeric + Primitive> DivAssign for Complex<T> {
  128. fn div_assign(&mut self, rhs: Self) {
  129. *self = *self / rhs;
  130. }
  131. }
  132. impl<T: Float + Numeric + Primitive> Neg for Complex<T> {
  133. type Output = Self;
  134. fn neg(self) -> Self::Output {
  135. Self {
  136. real: -self.real,
  137. imag: -self.imag,
  138. }
  139. }
  140. }
  141. impl<T: Float + Numeric + Primitive> PartialEq for Complex<T> {
  142. fn eq(&self, other: &Self) -> bool {
  143. self.real == other.real && self.imag == other.imag
  144. }
  145. fn ne(&self, other: &Self) -> bool {
  146. self.real != other.real || self.imag != other.imag
  147. }
  148. }
  149. impl<T: Float + Primitive + Numeric> Float for Complex<T> {
  150. fn abs(self) -> Self {
  151. Self {
  152. real: (self.real * self.real + self.imag * self.imag).sqrt(),
  153. imag: unsafe { T::whole(0) },
  154. }
  155. }
  156. fn sqrt(self) -> Self {
  157. let zero = unsafe { T::whole(0) };
  158. let two = unsafe { T::whole(2) };
  159. let abs_z = self.abs().real;
  160. let base_re = if self.imag == zero { self.real } else { (abs_z + self.real) / two };
  161. let re = if base_re < zero { zero } else { base_re.sqrt() };
  162. let im = if self.imag == zero { zero } else {
  163. (self.imag / self.imag.abs()) * ((abs_z - self.real) / two).sqrt()
  164. } + if base_re < zero { (-base_re).sqrt() } else { zero };
  165. Self {
  166. real: re,
  167. imag: im,
  168. }
  169. }
  170. fn sin(self) -> Self {
  171. Self {
  172. real: self.real.sin() * self.imag.cosh(),
  173. imag: self.real.cos() * self.imag.sinh(),
  174. }
  175. }
  176. fn cos(self) -> Self {
  177. Self {
  178. real: self.real.cos() * self.imag.cosh(),
  179. imag: self.real.sin() * self.imag.sinh(),
  180. }
  181. }
  182. fn sinh(self) -> Self {
  183. Self {
  184. real: self.real.sinh() * self.imag.cos(),
  185. imag: self.real.cosh() * self.imag.sin(),
  186. }
  187. }
  188. fn cosh(self) -> Self {
  189. Self {
  190. real: self.real.cosh() * self.imag.cos(),
  191. imag: self.real.sinh() * self.imag.sin(),
  192. }
  193. }
  194. }
  195. impl<T: Float + Numeric + Primitive> Numeric for Complex<T> {
  196. unsafe fn whole(value: u32) -> Self {
  197. Self { real: T::whole(value), imag: T::whole(0) }
  198. }
  199. }
  200. #[cfg(test)]
  201. mod tests {
  202. use crate::{Complex, Float};
  203. #[test]
  204. fn macro_creation() {
  205. let a = complex!(-1.0);
  206. let b = complex!(real = -1.0,);
  207. assert_eq!(a, b);
  208. let c = complex!(imag = 1.0,);
  209. assert_eq!(a.sqrt(), c);
  210. let d = complex!(real = -1.0, imag = 0.0,);
  211. assert_eq!(c * c, d);
  212. let e = complex!(-1.0, 0.0);
  213. assert_eq!(a, e);
  214. }
  215. #[test]
  216. fn division() {
  217. let mut a: Complex<f32> = Complex { real: 3.0, imag: 2.0 };
  218. let b: Complex<f32> = Complex { real: 4.0, imag: -5.0 };
  219. let expected: Complex<f32> = Complex { real: 2.0 / 41.0, imag: 23.0 / 41.0 };
  220. a /= b;
  221. assert_eq!(format!("{:?}", a), format!("{:?}", expected));
  222. }
  223. #[test]
  224. fn square_root_with_negatives() {
  225. let a = Complex { real: -1.0, imag: -1.0 };
  226. let root = a.sqrt();
  227. let expected = Complex { real: 0.45508986, imag: -1.09868411 };
  228. assert!((root.real - expected.real).abs() < 0.00001);
  229. assert!((root.imag - expected.imag).abs() < 0.00001);
  230. }
  231. }