Port from julia version instead
This commit is contained in:
117
src/gaussian.rs
117
src/gaussian.rs
@@ -1,62 +1,58 @@
|
||||
use std::ops;
|
||||
|
||||
use crate::{utils, MU, SIGMA};
|
||||
use crate::{MU, N_INF, SIGMA};
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Debug)]
|
||||
pub struct Gaussian {
|
||||
mu: f64,
|
||||
sigma: f64,
|
||||
pub(crate) mu: f64,
|
||||
pub(crate) sigma: f64,
|
||||
}
|
||||
|
||||
impl Gaussian {
|
||||
#[inline]
|
||||
pub const fn new(mu: f64, sigma: f64) -> Self {
|
||||
pub fn new(mu: f64, sigma: f64) -> Self {
|
||||
debug_assert!(sigma >= 0.0, "sigma must be equal or larger than 0.0");
|
||||
|
||||
Gaussian { mu, sigma }
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn mu(&self) -> f64 {
|
||||
self.mu
|
||||
fn pi(&self) -> f64 {
|
||||
if self.sigma > 0.0 {
|
||||
self.sigma.powi(-2)
|
||||
} else {
|
||||
f64::INFINITY
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn sigma(&self) -> f64 {
|
||||
self.sigma
|
||||
fn tau(&self) -> f64 {
|
||||
if self.sigma > 0.0 {
|
||||
self.mu * self.pi()
|
||||
} else {
|
||||
f64::INFINITY
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn tau(&self) -> f64 {
|
||||
self.mu * self.pi()
|
||||
pub(crate) fn delta(&self, m: Gaussian) -> (f64, f64) {
|
||||
((self.mu - m.mu).abs(), (self.sigma - m.sigma).abs())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn pi(&self) -> f64 {
|
||||
self.sigma.powi(-2)
|
||||
pub(crate) fn exclude(&self, m: Gaussian) -> Self {
|
||||
Self {
|
||||
mu: self.mu - m.mu,
|
||||
sigma: (self.sigma.powi(2) - m.sigma.powi(2)).sqrt(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn forget(&self, gamma: f64, t: f64) -> Self {
|
||||
Self::new(self.mu, (self.sigma().powi(2) + t * gamma.powi(2)).sqrt())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn delta(&self, m: Gaussian) -> (f64, f64) {
|
||||
((self.mu() - m.mu()).abs(), (self.sigma() - m.sigma()).abs())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn exclude(&self, m: Gaussian) -> Self {
|
||||
Self::new(
|
||||
self.mu() - m.mu(),
|
||||
(self.sigma().powi(2) - m.sigma().powi(2)).sqrt(),
|
||||
)
|
||||
pub(crate) fn forget(&self, gamma: f64, t: u64) -> Self {
|
||||
Self {
|
||||
mu: self.mu,
|
||||
sigma: (self.sigma.powi(2) + t as f64 * gamma.powi(2)).sqrt(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Gaussian {
|
||||
#[inline]
|
||||
fn default() -> Self {
|
||||
Gaussian {
|
||||
Self {
|
||||
mu: MU,
|
||||
sigma: SIGMA,
|
||||
}
|
||||
@@ -66,7 +62,6 @@ impl Default for Gaussian {
|
||||
impl ops::Add<Gaussian> for Gaussian {
|
||||
type Output = Gaussian;
|
||||
|
||||
#[inline]
|
||||
fn add(self, rhs: Gaussian) -> Self::Output {
|
||||
Gaussian {
|
||||
mu: self.mu + rhs.mu,
|
||||
@@ -78,7 +73,6 @@ impl ops::Add<Gaussian> for Gaussian {
|
||||
impl ops::Sub<Gaussian> for Gaussian {
|
||||
type Output = Gaussian;
|
||||
|
||||
#[inline]
|
||||
fn sub(self, rhs: Gaussian) -> Self::Output {
|
||||
Gaussian {
|
||||
mu: self.mu - rhs.mu,
|
||||
@@ -90,25 +84,66 @@ impl ops::Sub<Gaussian> for Gaussian {
|
||||
impl ops::Mul<Gaussian> for Gaussian {
|
||||
type Output = Gaussian;
|
||||
|
||||
#[inline]
|
||||
fn mul(self, rhs: Gaussian) -> Self::Output {
|
||||
let (mu, sigma) = utils::mu_sigma(self.tau() + rhs.tau(), self.pi() + rhs.pi());
|
||||
let (mu, sigma) = if self.sigma == 0.0 || rhs.sigma == 0.0 {
|
||||
let mu = self.mu / (self.sigma.powi(2) / rhs.sigma.powi(2) + 1.0)
|
||||
+ rhs.mu / (rhs.sigma.powi(2) / self.sigma.powi(2) + 1.0);
|
||||
|
||||
let sigma = (1.0 / ((1.0 / self.sigma.powi(2)) + (1.0 / rhs.sigma.powi(2)))).sqrt();
|
||||
|
||||
(mu, sigma)
|
||||
} else {
|
||||
mu_sigma(self.tau() + rhs.tau(), self.pi() + rhs.pi())
|
||||
};
|
||||
|
||||
Gaussian { mu, sigma }
|
||||
}
|
||||
}
|
||||
|
||||
impl ops::Mul<f64> for Gaussian {
|
||||
type Output = Gaussian;
|
||||
|
||||
fn mul(self, rhs: f64) -> Self::Output {
|
||||
if rhs.is_finite() {
|
||||
Self {
|
||||
mu: self.mu * rhs,
|
||||
sigma: self.sigma * rhs,
|
||||
}
|
||||
} else {
|
||||
N_INF
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ops::Div<Gaussian> for Gaussian {
|
||||
type Output = Gaussian;
|
||||
|
||||
#[inline]
|
||||
fn div(self, rhs: Gaussian) -> Self::Output {
|
||||
let (mu, sigma) = utils::mu_sigma(self.tau() - rhs.tau(), self.pi() - rhs.pi());
|
||||
let (mu, sigma) = if self.sigma == 0.0 || rhs.sigma == 0.0 {
|
||||
let mu = self.mu / (1.0 - self.sigma.powi(2) / rhs.sigma.powi(2))
|
||||
+ rhs.mu / (rhs.sigma.powi(2) / self.sigma.powi(2) - 1.0);
|
||||
|
||||
let sigma = (1.0 / ((1.0 / self.sigma.powi(2)) - (1.0 / rhs.sigma.powi(2)))).sqrt();
|
||||
|
||||
(mu, sigma)
|
||||
} else {
|
||||
mu_sigma(self.tau() - rhs.tau(), self.pi() - rhs.pi())
|
||||
};
|
||||
|
||||
Gaussian { mu, sigma }
|
||||
}
|
||||
}
|
||||
|
||||
fn mu_sigma(tau: f64, pi: f64) -> (f64, f64) {
|
||||
if pi > 0.0 {
|
||||
(tau / pi, (1.0 / pi).sqrt())
|
||||
} else if (pi + 1e-5) < 0.0 {
|
||||
panic!("precision should be greater than 0");
|
||||
} else {
|
||||
(0.0, f64::INFINITY)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user