123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268 |
- use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
- use crate::{Float, Numeric, Primitive};
- /// Struct representing a complex number.
- ///
- /// # Example
- ///
- /// ```rust
- /// # use lineal::{Complex, complex, Float, Numeric};
- /// let c = Complex { real: 9.0, imag: 0.0 };
- /// assert_eq!(format!("{:?}", c), "Complex { real: 9.0, imag: 0.0 }");
- ///
- /// let three = c.sqrt().real;
- /// assert_eq!(three, 3.0);
- ///
- /// let i: Complex<f32> = complex!(-1.0).sqrt();
- /// assert_eq!(i, Complex::i());
- ///
- /// let a = Complex { real: 2.0, imag: 0.5 };
- /// let a_squared = a * a;
- /// assert_eq!(a_squared, Complex { real: 3.75, imag: 2.0 });
- /// assert_eq!(a_squared.sqrt(), a);
- ///
- /// let x = (16.0).sqrt();
- /// let y = complex!(16.0).sqrt().real;
- /// assert_eq!(x, y);
- /// ```
- #[derive(Debug, Copy, Clone)]
- pub struct Complex<T: Float + Numeric + Primitive> {
- pub real: T,
- pub imag: T,
- }
/// Builds a `Complex` value.
///
/// Supported forms:
///
/// * `complex!(real = re, imag = im)`: named fields in any order, with an
///   optional trailing comma. Omitted fields are taken from `Complex::default()`,
///   so `complex!()` is the default value.
/// * `complex!(re, im)`: both components given positionally.
/// * `complex!(re)`: a real number. The imaginary part is the literal `0.0`.
///
/// The named form is matched first on purpose. `real = x` is itself a valid
/// (assignment) expression. With the positional arms first,
/// `complex!(real = 1.0, imag = 2.0)` (no trailing comma) or
/// `complex!(imag = 1.0)` would be captured by them and fail to compile.
#[macro_export]
macro_rules! complex {
    ($($field:ident = $value:expr),* $(,)?) => {
        $crate::Complex {
            $( $field: $value, )*
            ..$crate::Complex::default()
        }
    };
    ($real:expr, $imag:expr) => {
        $crate::Complex { real: $real, imag: $imag }
    };
    ($real:expr) => {
        $crate::Complex { real: $real, imag: 0.0 }
    };
}
- impl<T: Float + Numeric + Primitive> Default for Complex<T> {
- fn default() -> Self {
- unsafe { Complex::whole(0) }
- }
- }
- impl<T: Float + Numeric + Primitive> Complex<T> {
- pub fn real(value: T) -> Self {
- Self {
- real: value,
- imag: unsafe { T::whole(0) },
- }
- }
- pub fn imag(value: T) -> Self {
- Self {
- real: unsafe { T::whole(0) },
- imag: value,
- }
- }
- /// The imaginary unit.
- pub fn i() -> Self {
- Self {
- real: unsafe { T::whole(0) },
- imag: unsafe { T::whole(1) },
- }
- }
- }
- impl<T: Float + Numeric + Primitive> Add for Complex<T> {
- type Output = Self;
- fn add(self, rhs: Self) -> Self::Output {
- Self {
- real: self.real + rhs.real,
- imag: self.imag + rhs.imag,
- }
- }
- }
- impl<T: Float + Numeric + Primitive> AddAssign for Complex<T> {
- fn add_assign(&mut self, rhs: Self) {
- self.real += rhs.real;
- self.imag += rhs.imag;
- }
- }
- impl<T: Float + Numeric + Primitive> Sub for Complex<T> {
- type Output = Self;
- fn sub(self, rhs: Self) -> Self::Output {
- Self {
- real: self.real - rhs.real,
- imag: self.imag - rhs.imag,
- }
- }
- }
- impl<T: Float + Numeric + Primitive> SubAssign for Complex<T> {
- fn sub_assign(&mut self, rhs: Self) {
- self.real -= rhs.real;
- self.imag -= rhs.imag;
- }
- }
- impl<T: Float + Numeric + Primitive> Mul for Complex<T> {
- type Output = Self;
- fn mul(self, rhs: Self) -> Self::Output {
- Self {
- real: self.real * rhs.real - self.imag * rhs.imag,
- imag: self.real * rhs.imag + self.imag * rhs.real,
- }
- }
- }
- impl<T: Float + Numeric + Primitive> MulAssign for Complex<T> {
- fn mul_assign(&mut self, rhs: Self) {
- *self = *self * rhs;
- }
- }
- impl<T: Float + Numeric + Primitive> Div for Complex<T> {
- type Output = Self;
- fn div(self, rhs: Self) -> Self::Output {
- let divisor = rhs.real * rhs.real + rhs.imag * rhs.imag;
- let mut result = self * Self { real: rhs.real, imag: -rhs.imag };
- result.real /= divisor;
- result.imag /= divisor;
- return result;
- }
- }
- impl<T: Float + Numeric + Primitive> DivAssign for Complex<T> {
- fn div_assign(&mut self, rhs: Self) {
- *self = *self / rhs;
- }
- }
- impl<T: Float + Numeric + Primitive> Neg for Complex<T> {
- type Output = Self;
- fn neg(self) -> Self::Output {
- Self {
- real: -self.real,
- imag: -self.imag,
- }
- }
- }
- impl<T: Float + Numeric + Primitive> PartialEq for Complex<T> {
- fn eq(&self, other: &Self) -> bool {
- self.real == other.real && self.imag == other.imag
- }
- fn ne(&self, other: &Self) -> bool {
- self.real != other.real || self.imag != other.imag
- }
- }
- impl<T: Float + Primitive + Numeric> Float for Complex<T> {
- fn abs(self) -> Self {
- Self {
- real: (self.real * self.real + self.imag * self.imag).sqrt(),
- imag: unsafe { T::whole(0) },
- }
- }
- fn sqrt(self) -> Self {
- let zero = unsafe { T::whole(0) };
- let two = unsafe { T::whole(2) };
- let abs_z = self.abs().real;
- let base_re = if self.imag == zero { self.real } else { (abs_z + self.real) / two };
- let re = if base_re < zero { zero } else { base_re.sqrt() };
- let im = if self.imag == zero { zero } else {
- (self.imag / self.imag.abs()) * ((abs_z - self.real) / two).sqrt()
- } + if base_re < zero { (-base_re).sqrt() } else { zero };
- Self {
- real: re,
- imag: im,
- }
- }
- fn sin(self) -> Self {
- Self {
- real: self.real.sin() * self.imag.cosh(),
- imag: self.real.cos() * self.imag.sinh(),
- }
- }
- fn cos(self) -> Self {
- Self {
- real: self.real.cos() * self.imag.cosh(),
- imag: self.real.sin() * self.imag.sinh(),
- }
- }
- fn sinh(self) -> Self {
- Self {
- real: self.real.sinh() * self.imag.cos(),
- imag: self.real.cosh() * self.imag.sin(),
- }
- }
- fn cosh(self) -> Self {
- Self {
- real: self.real.cosh() * self.imag.cos(),
- imag: self.real.sinh() * self.imag.sin(),
- }
- }
- }
- impl<T: Float + Numeric + Primitive> Numeric for Complex<T> {
- unsafe fn whole(value: u32) -> Self {
- Self { real: T::whole(value), imag: T::whole(0) }
- }
- }
#[cfg(test)]
mod tests {
    use crate::{Complex, Float};

    /// Every `complex!` form builds the value it spells out.
    #[test]
    fn macro_creation() {
        // Pin the component type so no method is ever called on an ambiguous `{float}`.
        let a: Complex<f64> = complex!(-1.0);
        let b = complex!(real = -1.0,);
        assert_eq!(a, b);
        let c = complex!(imag = 1.0,);
        assert_eq!(a.sqrt(), c);
        let d = complex!(real = -1.0, imag = 0.0,);
        assert_eq!(c * c, d);
        let e = complex!(-1.0, 0.0);
        assert_eq!(a, e);
    }

    /// (3 + 2i) / (4 - 5i) = (2 + 23i) / 41, exact in every intermediate step.
    #[test]
    fn division() {
        let mut a: Complex<f32> = Complex { real: 3.0, imag: 2.0 };
        let b: Complex<f32> = Complex { real: 4.0, imag: -5.0 };
        let expected: Complex<f32> = Complex { real: 2.0 / 41.0, imag: 23.0 / 41.0 };
        a /= b;
        assert_eq!(a, expected);
    }

    /// Principal root of -1 - i has a positive real and a negative imaginary part.
    #[test]
    fn square_root_with_negatives() {
        // Explicit `f64`: with an unconstrained float type, `(x - y).abs()` below
        // would be a call on an ambiguous `{float}` (error E0689).
        let a: Complex<f64> = Complex { real: -1.0, imag: -1.0 };
        let root = a.sqrt();
        let expected: Complex<f64> = Complex { real: 0.45508986, imag: -1.09868411 };
        assert!((root.real - expected.real).abs() < 0.00001);
        assert!((root.imag - expected.imag).abs() < 0.00001);
    }

    /// Operators and their compound-assignment forms; all values are exact in f64.
    #[test]
    fn arithmetic_operators() {
        let x: Complex<f64> = Complex { real: 1.5, imag: -2.0 };
        let y: Complex<f64> = Complex { real: 0.5, imag: 4.0 };
        assert_eq!(x + y, Complex { real: 2.0, imag: 2.0 });
        assert_eq!(x - y, Complex { real: 1.0, imag: -6.0 });
        assert_eq!(-x, Complex { real: -1.5, imag: 2.0 });
        // (1.5 - 2i)(0.5 + 4i) = 0.75 + 6i - 1i + 8 = 8.75 + 5i
        assert_eq!(x * y, Complex { real: 8.75, imag: 5.0 });

        let mut acc = x;
        acc += y;
        acc -= y;
        assert_eq!(acc, x);
        acc *= y;
        acc /= y;
        assert_eq!(acc, x);
    }
}
|