use std::ops; use crate::{MU, N_INF, SIGMA}; /// A Gaussian distribution stored in natural parameters. /// /// `pi = 1 / sigma^2` (precision) /// `tau = mu * pi` (precision-adjusted mean) /// /// Multiplication and division in message passing become pure adds/subs of /// the stored fields with no `sqrt` or reciprocal in the hot path. `mu()` and /// `sigma()` are accessors computed on demand. #[derive(Clone, Copy, PartialEq, Debug)] pub struct Gaussian { pi: f64, tau: f64, } impl Gaussian { /// Construct from mean and standard deviation. pub const fn from_ms(mu: f64, sigma: f64) -> Self { if sigma == f64::INFINITY { Self { pi: 0.0, tau: 0.0 } } else if sigma == 0.0 { // Point mass at mu. tau = mu * pi = mu * inf. // For mu == 0 this is 0; for mu != 0 it is inf * mu = inf (IEEE). // Only N00 (mu=0, sigma=0) is used in practice. Self { pi: f64::INFINITY, tau: if mu == 0.0 { 0.0 } else { f64::INFINITY }, } } else { let pi = 1.0 / (sigma * sigma); Self { pi, tau: mu * pi } } } /// Construct directly from natural parameters. #[inline] pub(crate) const fn from_natural(pi: f64, tau: f64) -> Self { Self { pi, tau } } #[inline] pub fn pi(&self) -> f64 { self.pi } #[inline] pub fn tau(&self) -> f64 { self.tau } #[inline] pub fn mu(&self) -> f64 { // A non-positive precision is an improper (uninformative) Gaussian — its mean is // undefined. Treat it like `pi == 0` and return 0. EP message cancellation can land // `pi` on a tiny negative value (round-off of exactly zero); without this guard // `tau / pi` would yield a spurious finite mean. if self.pi <= 0.0 { 0.0 } else { self.tau / self.pi } } #[inline] pub fn sigma(&self) -> f64 { // A non-positive precision is improper → infinite standard deviation. Guarding // `pi <= 0.0` (not just `== 0.0`) keeps `1.0 / pi.sqrt()` from returning NaN when EP // cancellation produces a tiny negative precision (round-off of exactly zero). if self.pi <= 0.0 { f64::INFINITY } else if self.pi.is_infinite() { 0.0 } else { 1.0 / self.pi.sqrt() } } pub(crate) fn delta(&self, other: Gaussian) -> (f64, f64) { ( (self.mu() - other.mu()).abs(), (self.sigma() - other.sigma()).abs(), ) } pub(crate) fn exclude(&self, other: Gaussian) -> Self { let var = self.sigma().powi(2) - other.sigma().powi(2); if var <= 0.0 { // When sigma_self ≈ sigma_other (including ULP-level rounding differences // from the pi→sigma accessor round-trip), the excluded contribution is N00. // Computing from_ms(tiny_mu, 0.0) would give {pi:inf, tau:inf}, whose // mu() = inf/inf = NaN. Returning N00 is correct: when both Gaussians // carry the same variance, the residual is a point mass at 0. return Gaussian::from_ms(0.0, 0.0); } let mu = self.mu() - other.mu(); Self::from_ms(mu, var.sqrt()) } pub(crate) fn forget(&self, variance_delta: f64) -> Self { let var = self.sigma().powi(2) + variance_delta; Self::from_ms(self.mu(), var.sqrt()) } /// EP damping in natural-parameter space: `α·new + (1−α)·self`. /// /// Used by within-game inference to stabilise oscillating fixed-point /// loops on hard graphs. `alpha = 1.0` returns `new` exactly; /// `alpha < 1.0` shrinks each per-step update. pub fn damp_natural(self, new: Gaussian, alpha: f64) -> Gaussian { Gaussian::from_natural( alpha * new.pi() + (1.0 - alpha) * self.pi(), alpha * new.tau() + (1.0 - alpha) * self.tau(), ) } } impl Default for Gaussian { fn default() -> Self { Self::from_ms(MU, SIGMA) } } impl ops::Add for Gaussian { type Output = Gaussian; /// Variance addition: (mu1 + mu2, sqrt(σ1² + σ2²)). /// Used for combining performance and noise; rare relative to mul/div. fn add(self, rhs: Gaussian) -> Self::Output { let mu = self.mu() + rhs.mu(); let var = self.sigma().powi(2) + rhs.sigma().powi(2); Self::from_ms(mu, var.sqrt()) } } impl ops::Sub for Gaussian { type Output = Gaussian; /// (mu1 - mu2, sqrt(σ1² + σ2²)). Same sigma combination as Add. fn sub(self, rhs: Gaussian) -> Self::Output { let mu = self.mu() - rhs.mu(); let var = self.sigma().powi(2) + rhs.sigma().powi(2); Self::from_ms(mu, var.sqrt()) } } impl ops::Mul for Gaussian { type Output = Gaussian; /// Factor product: nat-param add. Hot path — two f64 additions, no sqrt. fn mul(self, rhs: Gaussian) -> Self::Output { Self::from_natural(self.pi + rhs.pi, self.tau + rhs.tau) } } impl ops::Mul for Gaussian { type Output = Gaussian; fn mul(self, scalar: f64) -> Self::Output { if !scalar.is_finite() { return N_INF; } if scalar == 0.0 { // Scaling by 0 collapses to a point mass at 0 (sigma' = 0, mu' = 0). // This is N00, the additive identity, NOT N_INF. return Gaussian::from_ms(0.0, 0.0); } // sigma' = sigma * |scalar| => pi' = pi / scalar² // mu' = mu * scalar => tau' = tau / scalar Self::from_natural(self.pi / (scalar * scalar), self.tau / scalar) } } impl ops::Div for Gaussian { type Output = Gaussian; /// Cavity: nat-param sub. Hot path — two f64 subtractions, no sqrt. fn div(self, rhs: Gaussian) -> Self::Output { Self::from_natural(self.pi - rhs.pi, self.tau - rhs.tau) } } #[cfg(test)] mod tests { use super::*; #[test] fn non_positive_precision_is_improper_not_nan() { // EP message cancellation can leave `pi` a tiny negative (round-off of exactly zero). // Such a Gaussian is improper/uninformative: mu() must be 0 and sigma() infinite, not // NaN. A NaN here propagates through the moment-space `Sub` in the game chain and // poisons every skill in the slice. let tiny_neg = Gaussian::from_natural(-5.55e-17, -8.88e-16); assert_eq!(tiny_neg.mu(), 0.0); assert!(tiny_neg.sigma().is_infinite()); // A frankly-negative precision is treated the same way. let neg = Gaussian::from_natural(-1.0, 2.0); assert_eq!(neg.mu(), 0.0); assert!(neg.sigma().is_infinite()); // Subtracting such a message must not produce NaN (the original failure path). let proper = Gaussian::from_ms(9.75, 1.256); let diff = proper - tiny_neg; assert!(diff.pi().is_finite() && !diff.pi().is_nan()); assert!(diff.tau().is_finite() && !diff.tau().is_nan()); } #[test] fn test_add() { let n = Gaussian::from_ms(25.0, 25.0 / 3.0); let m = Gaussian::from_ms(0.0, 1.0); let r = n + m; assert!((r.mu() - 25.0).abs() < 1e-12); assert!((r.sigma() - 8.393118874676116).abs() < 1e-10); } #[test] fn test_sub() { let n = Gaussian::from_ms(25.0, 25.0 / 3.0); let m = Gaussian::from_ms(1.0, 1.0); let r = n - m; assert!((r.mu() - 24.0).abs() < 1e-12); assert!((r.sigma() - 8.393118874676116).abs() < 1e-10); } #[test] fn test_mul() { let n = Gaussian::from_ms(25.0, 25.0 / 3.0); let m = Gaussian::from_ms(0.0, 1.0); let r = n * m; assert!((r.mu() - 0.35488958990536273).abs() < 1e-10); assert!((r.sigma() - 0.992876838486922).abs() < 1e-10); } #[test] fn test_div() { let n = Gaussian::from_ms(25.0, 25.0 / 3.0); let m = Gaussian::from_ms(0.0, 1.0); let r = m / n; assert!((r.mu() - (-0.3652597402597402)).abs() < 1e-10); assert!((r.sigma() - 1.0072787050317253).abs() < 1e-10); } #[test] fn test_n00_is_add_identity() { // N00 (sigma=0) is the additive identity for the variance-convolution Add op. // N_INF (sigma=inf) is the identity for the EP-product Mul op. let g = Gaussian::from_ms(3.0, 2.0); let n00 = Gaussian::from_ms(0.0, 0.0); let r = n00 + g; assert!((r.mu() - g.mu()).abs() < 1e-12); assert!((r.sigma() - g.sigma()).abs() < 1e-12); } #[test] fn test_mul_is_factor_product() { // n * m in nat-params should be pi_n + pi_m, tau_n + tau_m let n = Gaussian::from_ms(2.0, 3.0); let m = Gaussian::from_ms(1.0, 2.0); let r = n * m; let expected_pi = n.pi() + m.pi(); let expected_tau = n.tau() + m.tau(); assert!((r.pi() - expected_pi).abs() < 1e-15); assert!((r.tau() - expected_tau).abs() < 1e-15); } #[test] fn test_div_is_cavity() { let n = Gaussian::from_ms(2.0, 1.0); let m = Gaussian::from_ms(1.0, 2.0); let r = n / m; let expected_pi = n.pi() - m.pi(); let expected_tau = n.tau() - m.tau(); assert!((r.pi() - expected_pi).abs() < 1e-15); assert!((r.tau() - expected_tau).abs() < 1e-15); } #[test] fn damp_natural_alpha_one_returns_new() { let old = Gaussian::from_ms(1.0, 2.0); let new = Gaussian::from_ms(5.0, 0.5); let damped = old.damp_natural(new, 1.0); assert_eq!(damped.pi(), new.pi()); assert_eq!(damped.tau(), new.tau()); } #[test] fn damp_natural_alpha_zero_returns_self() { let old = Gaussian::from_ms(1.0, 2.0); let new = Gaussian::from_ms(5.0, 0.5); let damped = old.damp_natural(new, 0.0); assert_eq!(damped.pi(), old.pi()); assert_eq!(damped.tau(), old.tau()); } #[test] fn damp_natural_alpha_half_is_midpoint_in_natural_params() { let old = Gaussian::from_ms(1.0, 2.0); let new = Gaussian::from_ms(5.0, 0.5); let damped = old.damp_natural(new, 0.5); let expected_pi = 0.5 * new.pi() + 0.5 * old.pi(); let expected_tau = 0.5 * new.tau() + 0.5 * old.tau(); assert!((damped.pi() - expected_pi).abs() < 1e-12); assert!((damped.tau() - expected_tau).abs() < 1e-12); } }