use std::ops; use crate::{MU, N_INF, SIGMA}; #[derive(Clone, Copy, PartialEq, Debug)] pub struct Gaussian { pub mu: f64, pub sigma: f64, } impl Gaussian { pub fn new(mu: f64, sigma: f64) -> Self { debug_assert!(sigma >= 0.0, "sigma must be equal or larger than 0.0"); Gaussian { mu, sigma } } fn pi(&self) -> f64 { if self.sigma > 0.0 { self.sigma.powi(-2) } else { f64::INFINITY } } fn tau(&self) -> f64 { if self.sigma > 0.0 { self.mu * self.pi() } else { f64::INFINITY } } pub(crate) fn delta(&self, m: Gaussian) -> (f64, f64) { ((self.mu - m.mu).abs(), (self.sigma - m.sigma).abs()) } pub(crate) fn exclude(&self, m: Gaussian) -> Self { Self { mu: self.mu - m.mu, sigma: (self.sigma.powi(2) - m.sigma.powi(2)).sqrt(), } } pub(crate) fn forget(&self, gamma: f64, t: u64) -> Self { Self { mu: self.mu, sigma: (self.sigma.powi(2) + t as f64 * gamma.powi(2)).sqrt(), } } } impl Default for Gaussian { fn default() -> Self { Self { mu: MU, sigma: SIGMA, } } } impl ops::Add for Gaussian { type Output = Gaussian; fn add(self, rhs: Gaussian) -> Self::Output { Gaussian { mu: self.mu + rhs.mu, sigma: (self.sigma.powi(2) + rhs.sigma.powi(2)).sqrt(), } } } impl ops::Sub for Gaussian { type Output = Gaussian; fn sub(self, rhs: Gaussian) -> Self::Output { Gaussian { mu: self.mu - rhs.mu, sigma: (self.sigma.powi(2) + rhs.sigma.powi(2)).sqrt(), } } } impl ops::Mul for Gaussian { type Output = Gaussian; fn mul(self, rhs: Gaussian) -> Self::Output { let (mu, sigma) = if self.sigma == 0.0 || rhs.sigma == 0.0 { let mu = self.mu / (self.sigma.powi(2) / rhs.sigma.powi(2) + 1.0) + rhs.mu / (rhs.sigma.powi(2) / self.sigma.powi(2) + 1.0); let sigma = (1.0 / ((1.0 / self.sigma.powi(2)) + (1.0 / rhs.sigma.powi(2)))).sqrt(); (mu, sigma) } else { mu_sigma(self.tau() + rhs.tau(), self.pi() + rhs.pi()) }; Gaussian { mu, sigma } } } impl ops::Mul for Gaussian { type Output = Gaussian; fn mul(self, rhs: f64) -> Self::Output { if rhs.is_finite() { Self { mu: self.mu * rhs, sigma: self.sigma * rhs, } } else { N_INF } } } impl ops::Div for Gaussian { type Output = Gaussian; fn div(self, rhs: Gaussian) -> Self::Output { let (mu, sigma) = if self.sigma == 0.0 || rhs.sigma == 0.0 { let mu = self.mu / (1.0 - self.sigma.powi(2) / rhs.sigma.powi(2)) + rhs.mu / (rhs.sigma.powi(2) / self.sigma.powi(2) - 1.0); let sigma = (1.0 / ((1.0 / self.sigma.powi(2)) - (1.0 / rhs.sigma.powi(2)))).sqrt(); (mu, sigma) } else { mu_sigma(self.tau() - rhs.tau(), self.pi() - rhs.pi()) }; Gaussian { mu, sigma } } } fn mu_sigma(tau: f64, pi: f64) -> (f64, f64) { if pi > 0.0 { (tau / pi, (1.0 / pi).sqrt()) } else if (pi + 1e-5) < 0.0 { panic!("precision should be greater than 0"); } else { (0.0, f64::INFINITY) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_add() { let n = Gaussian { mu: 25.0, sigma: 25.0 / 3.0, }; let m = Gaussian { mu: 0.0, sigma: 1.0, }; assert_eq!( n + m, Gaussian { mu: 25.0, sigma: 8.393118874676116 } ); } #[test] fn test_sub() { let n = Gaussian { mu: 25.0, sigma: 25.0 / 3.0, }; let m = Gaussian { mu: 1.0, sigma: 1.0, }; assert_eq!( n - m, Gaussian { mu: 24.0, sigma: 8.393118874676116 } ); } #[test] fn test_mul() { let n = Gaussian { mu: 25.0, sigma: 25.0 / 3.0, }; let m = Gaussian { mu: 0.0, sigma: 1.0, }; assert_eq!( n * m, Gaussian { mu: 0.35488958990536273, sigma: 0.992876838486922 } ); } #[test] fn test_div() { let n = Gaussian { mu: 25.0, sigma: 25.0 / 3.0, }; let m = Gaussian { mu: 0.0, sigma: 1.0, }; assert_eq!( m / n, Gaussian { mu: -0.3652597402597402, sigma: 1.0072787050317253 } ); } }