Files
trueskill-tt/src/gaussian.rs

235 lines
5.0 KiB
Rust

use std::ops;
use crate::{MU, N_INF, SIGMA};
/// A one-dimensional Gaussian described by its mean (`mu`) and standard
/// deviation (`sigma`).
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct Gaussian {
    /// Mean of the distribution.
    pub mu: f64,
    /// Standard deviation of the distribution; `new` expects it to be non-negative.
    pub sigma: f64,
}

impl Gaussian {
    /// Creates a Gaussian from a mean and a standard deviation.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics unless `sigma >= 0.0`
    /// (a NaN `sigma` therefore panics as well).
    pub fn new(mu: f64, sigma: f64) -> Self {
        debug_assert!(sigma >= 0.0, "sigma must be equal or larger than 0.0");
        Self { mu, sigma }
    }

    /// Precision of the distribution, `sigma^-2`.
    ///
    /// Yields `f64::INFINITY` whenever `sigma` is not strictly positive.
    fn pi(&self) -> f64 {
        match self.sigma {
            s if s > 0.0 => s.powi(-2),
            _ => f64::INFINITY,
        }
    }

    /// Mean multiplied by the precision, `mu * pi()`.
    ///
    /// Yields `f64::INFINITY` whenever `sigma` is not strictly positive.
    fn tau(&self) -> f64 {
        match self.sigma {
            s if s > 0.0 => self.mu * self.pi(),
            _ => f64::INFINITY,
        }
    }

    /// Returns the absolute differences `(|Δmu|, |Δsigma|)` between `self` and `m`.
    pub(crate) fn delta(&self, m: Gaussian) -> (f64, f64) {
        let mu_gap = (self.mu - m.mu).abs();
        let sigma_gap = (self.sigma - m.sigma).abs();
        (mu_gap, sigma_gap)
    }

    /// Subtracts `m` from `self`: the means are subtracted and the variances are
    /// subtracted (`sigma = sqrt(self.sigma² - m.sigma²)`).
    ///
    /// The resulting `sigma` is NaN when `m.sigma` exceeds `self.sigma`, because
    /// the square root is then taken of a negative number.
    pub(crate) fn exclude(&self, m: Gaussian) -> Self {
        let variance = self.sigma.powi(2) - m.sigma.powi(2);
        Gaussian {
            mu: self.mu - m.mu,
            sigma: variance.sqrt(),
        }
    }

    /// Widens the distribution: the mean is kept and `t * gamma²` is added to the
    /// variance.
    pub(crate) fn forget(&self, gamma: f64, t: u64) -> Self {
        // `t` is converted to f64 before scaling, exactly as a plain `as` cast does.
        let added_variance = t as f64 * gamma.powi(2);
        Gaussian {
            mu: self.mu,
            sigma: (self.sigma.powi(2) + added_variance).sqrt(),
        }
    }
}
/// The default Gaussian takes its mean and standard deviation from the
/// crate-level `MU` and `SIGMA` constants.
impl Default for Gaussian {
    fn default() -> Self {
        let (mu, sigma) = (MU, SIGMA);
        Self { mu, sigma }
    }
}
impl ops::Add<Gaussian> for Gaussian {
type Output = Gaussian;
fn add(self, rhs: Gaussian) -> Self::Output {
Gaussian {
mu: self.mu + rhs.mu,
sigma: (self.sigma.powi(2) + rhs.sigma.powi(2)).sqrt(),
}
}
}
impl ops::Sub<Gaussian> for Gaussian {
type Output = Gaussian;
fn sub(self, rhs: Gaussian) -> Self::Output {
Gaussian {
mu: self.mu - rhs.mu,
sigma: (self.sigma.powi(2) + rhs.sigma.powi(2)).sqrt(),
}
}
}
/// Multiplication of two Gaussians.
///
/// When both sigmas are non-zero, the `tau` and `pi` terms of the operands are
/// added and converted back to `(mu, sigma)` by `mu_sigma`. When either sigma is
/// exactly zero, the result is computed directly from the variances instead;
/// this yields sigma `0.0`, and a NaN mean if both sigmas are zero (`0/0`).
impl ops::Mul<Gaussian> for Gaussian {
    type Output = Gaussian;

    fn mul(self, rhs: Gaussian) -> Self::Output {
        // Regular case: combine in the tau/pi parameterisation.
        if self.sigma != 0.0 && rhs.sigma != 0.0 {
            let (mu, sigma) = mu_sigma(self.tau() + rhs.tau(), self.pi() + rhs.pi());
            return Self { mu, sigma };
        }

        // Degenerate case: at least one operand has sigma == 0.0.
        let lhs_var = self.sigma.powi(2);
        let rhs_var = rhs.sigma.powi(2);
        let mu = self.mu / (lhs_var / rhs_var + 1.0) + rhs.mu / (rhs_var / lhs_var + 1.0);
        let sigma = (1.0 / ((1.0 / lhs_var) + (1.0 / rhs_var))).sqrt();
        Self { mu, sigma }
    }
}
/// Scales a Gaussian by a scalar factor.
///
/// For a finite `rhs` the mean is multiplied by `rhs` and sigma by `|rhs|`.
/// Using the absolute value keeps sigma non-negative when the factor is
/// negative, matching the invariant that `Gaussian::new` asserts. For a
/// non-finite `rhs` (±infinity or NaN) the crate-level `N_INF` constant is
/// returned.
impl ops::Mul<f64> for Gaussian {
    type Output = Gaussian;

    fn mul(self, rhs: f64) -> Self::Output {
        if !rhs.is_finite() {
            return N_INF;
        }

        Self {
            mu: self.mu * rhs,
            // A standard deviation scales with the magnitude of the factor only.
            sigma: self.sigma * rhs.abs(),
        }
    }
}
/// Division of two Gaussians.
///
/// When both sigmas are non-zero, the `tau` and `pi` terms of `rhs` are
/// subtracted from those of `self` and converted back to `(mu, sigma)` by
/// `mu_sigma`. When either sigma is exactly zero, the result is computed
/// directly from the variances instead.
///
/// # Panics
///
/// Panics (inside `mu_sigma`) when both sigmas are non-zero and the resulting
/// precision `self.pi() - rhs.pi()` is below `-1e-5`.
impl ops::Div<Gaussian> for Gaussian {
    type Output = Gaussian;

    fn div(self, rhs: Gaussian) -> Self::Output {
        // Regular case: combine in the tau/pi parameterisation.
        if self.sigma != 0.0 && rhs.sigma != 0.0 {
            let (mu, sigma) = mu_sigma(self.tau() - rhs.tau(), self.pi() - rhs.pi());
            return Self { mu, sigma };
        }

        // Degenerate case: at least one operand has sigma == 0.0.
        let lhs_var = self.sigma.powi(2);
        let rhs_var = rhs.sigma.powi(2);
        let mu = self.mu / (1.0 - lhs_var / rhs_var) + rhs.mu / (rhs_var / lhs_var - 1.0);
        let sigma = (1.0 / ((1.0 / lhs_var) - (1.0 / rhs_var))).sqrt();
        Self { mu, sigma }
    }
}
/// Converts a `(tau, pi)` pair back into `(mu, sigma)`.
///
/// * `pi > 0.0`: returns `(tau / pi, sqrt(1 / pi))`.
/// * `pi` not positive but no lower than `-1e-5` (or NaN): returns
///   `(0.0, f64::INFINITY)`; the small tolerance absorbs slightly negative
///   values.
///
/// # Panics
///
/// Panics when `pi + 1e-5 < 0.0`, i.e. the precision is clearly negative.
fn mu_sigma(tau: f64, pi: f64) -> (f64, f64) {
    match pi {
        p if p > 0.0 => (tau / p, (1.0 / p).sqrt()),
        p if p + 1e-5 < 0.0 => panic!("precision should be greater than 0"),
        _ => (0.0, f64::INFINITY),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `+` adds the means and combines the sigmas as `sqrt(a² + b²)`.
    #[test]
    fn test_add() {
        let expected = Gaussian {
            mu: 25.0,
            sigma: 8.393118874676116,
        };
        let lhs = Gaussian {
            mu: 25.0,
            sigma: 25.0 / 3.0,
        };
        let rhs = Gaussian {
            mu: 0.0,
            sigma: 1.0,
        };

        assert_eq!(lhs + rhs, expected);
    }

    /// `-` subtracts the means and combines the sigmas as `sqrt(a² + b²)`.
    #[test]
    fn test_sub() {
        let expected = Gaussian {
            mu: 24.0,
            sigma: 8.393118874676116,
        };
        let lhs = Gaussian {
            mu: 25.0,
            sigma: 25.0 / 3.0,
        };
        let rhs = Gaussian {
            mu: 1.0,
            sigma: 1.0,
        };

        assert_eq!(lhs - rhs, expected);
    }

    /// `*` between two Gaussians with non-zero sigmas.
    #[test]
    fn test_mul() {
        let expected = Gaussian {
            mu: 0.35488958990536273,
            sigma: 0.992876838486922,
        };
        let lhs = Gaussian {
            mu: 25.0,
            sigma: 25.0 / 3.0,
        };
        let rhs = Gaussian {
            mu: 0.0,
            sigma: 1.0,
        };

        assert_eq!(lhs * rhs, expected);
    }

    /// `/` with the narrow Gaussian as numerator and the wide one as denominator.
    #[test]
    fn test_div() {
        let expected = Gaussian {
            mu: -0.3652597402597402,
            sigma: 1.0072787050317253,
        };
        let numerator = Gaussian {
            mu: 0.0,
            sigma: 1.0,
        };
        let denominator = Gaussian {
            mu: 25.0,
            sigma: 25.0 / 3.0,
        };

        assert_eq!(numerator / denominator, expected);
    }
}