e4ff46f45c
EP message cancellation can leave a Gaussian's precision (pi) at a tiny negative value — a round-off residue of what should be exactly zero. mu()/sigma() only special-cased pi == 0, so sigma() computed 1/sqrt(pi) = NaN for pi < 0. That NaN flowed through the moment-space Sub in the game diff-chain and poisoned every skill in the slice once it grew past ~75 competitors, making converge() return all-NaN on real-scale histories (a regression vs 0.1.0, which stored sigma directly). Guard pi <= 0.0 in both accessors (improper Gaussian: mu 0, sigma infinite), matching the existing pi == 0 handling. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
305 lines
10 KiB
Rust
305 lines
10 KiB
Rust
use std::ops;
|
||
|
||
use crate::{MU, N_INF, SIGMA};
|
||
|
||
/// A Gaussian distribution stored in natural parameters.
///
/// `pi = 1 / sigma^2` (precision)
/// `tau = mu * pi` (precision-adjusted mean)
///
/// Multiplication and division in message passing become pure adds/subs of
/// the stored fields with no `sqrt` or reciprocal in the hot path. `mu()` and
/// `sigma()` are accessors computed on demand.
///
/// Conventions used by the accessors:
/// - `pi <= 0` is an improper (uninformative) Gaussian: `mu() == 0`, `sigma() == +inf`.
/// - `pi == +inf` is a point mass: `sigma() == 0`.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct Gaussian {
    /// Precision, `1 / sigma^2`.
    pi: f64,
    /// Precision-adjusted mean, `mu * pi`.
    tau: f64,
}

impl Gaussian {
    /// Construct from mean and standard deviation.
    ///
    /// - `sigma == +inf` gives the improper Gaussian `{ pi: 0, tau: 0 }`.
    /// - `sigma == 0` gives a point mass with `pi = +inf`.
    /// - Otherwise `pi = 1 / sigma^2` and `tau = mu * pi`.
    pub const fn from_ms(mu: f64, sigma: f64) -> Self {
        if sigma == f64::INFINITY {
            Self { pi: 0.0, tau: 0.0 }
        } else if sigma == 0.0 {
            // Point mass at mu: tau = mu * pi = mu * inf.
            // - For mu == 0 the IEEE product is NaN, so it is special-cased to 0. This is
            //   N00, the only point mass used in practice.
            // - For mu != 0 the product is +inf or -inf following the sign of mu. Computing
            //   it, rather than hard-coding +inf, keeps a negative mean's sign.
            Self {
                pi: f64::INFINITY,
                tau: if mu == 0.0 { 0.0 } else { mu * f64::INFINITY },
            }
        } else {
            let pi = 1.0 / (sigma * sigma);
            Self { pi, tau: mu * pi }
        }
    }

    /// Construct directly from natural parameters (no validation).
    #[inline]
    pub(crate) const fn from_natural(pi: f64, tau: f64) -> Self {
        Self { pi, tau }
    }

    /// Precision, `1 / sigma^2`.
    #[inline]
    pub fn pi(&self) -> f64 {
        self.pi
    }

    /// Precision-adjusted mean, `mu * pi`.
    #[inline]
    pub fn tau(&self) -> f64 {
        self.tau
    }

    /// Mean, `tau / pi`.
    ///
    /// A non-positive precision is an improper (uninformative) Gaussian whose mean is
    /// undefined; it reports `0.0`, the same as `pi == 0`. EP message cancellation can
    /// land `pi` on a tiny negative value: a round-off residue of what should be exactly
    /// zero. Without this guard `tau / pi` would give a spurious finite mean.
    ///
    /// NOTE(review): a point mass built with a non-zero mean (`pi = inf`, `tau = ±inf`)
    /// gives `inf / inf = NaN` here. Only N00 (`tau = 0`) has a recoverable mean.
    #[inline]
    pub fn mu(&self) -> f64 {
        if self.pi <= 0.0 {
            0.0
        } else {
            self.tau / self.pi
        }
    }

    /// Standard deviation, `1 / sqrt(pi)`.
    ///
    /// A non-positive precision is improper and reports `+inf`. Guarding `pi <= 0.0`
    /// (not just `== 0.0`) stops `1.0 / pi.sqrt()` from returning NaN when EP
    /// cancellation leaves a tiny negative precision. An infinite precision (point mass)
    /// reports `0.0`.
    #[inline]
    pub fn sigma(&self) -> f64 {
        if self.pi <= 0.0 {
            f64::INFINITY
        } else if self.pi.is_infinite() {
            0.0
        } else {
            1.0 / self.pi.sqrt()
        }
    }

    /// Absolute differences `(|Δmu|, |Δsigma|)` between `self` and `other`.
    ///
    /// Equal components yield exactly `0.0`. This matters for two improper Gaussians,
    /// whose sigmas are both `+inf`: a plain `inf - inf` would be NaN and would corrupt
    /// any convergence check built on these deltas.
    pub(crate) fn delta(&self, other: Gaussian) -> (f64, f64) {
        let gap = |a: f64, b: f64| if a == b { 0.0 } else { (a - b).abs() };
        (gap(self.mu(), other.mu()), gap(self.sigma(), other.sigma()))
    }

    /// Remove `other`'s contribution from `self` in moment space.
    ///
    /// The result has mean `mu_self - mu_other` and variance
    /// `sigma_self^2 - sigma_other^2`. A non-positive variance difference collapses to
    /// N00 (see the inline comment).
    ///
    /// NOTE(review): if both operands are improper (`sigma() == inf`), the variance
    /// difference is `inf - inf = NaN` and the result carries NaN parameters. Confirm
    /// whether callers can reach that case.
    pub(crate) fn exclude(&self, other: Gaussian) -> Self {
        let var = self.sigma().powi(2) - other.sigma().powi(2);
        if var <= 0.0 {
            // When sigma_self ≈ sigma_other, including ULP-level differences from the
            // pi→sigma accessor round-trip, the excluded contribution is N00.
            // from_ms(tiny_mu, 0.0) would give a point mass with infinite tau, whose
            // mu() is inf/inf = NaN. Returning N00 is correct: when both Gaussians carry
            // the same variance, the residual is a point mass at 0.
            return Gaussian::from_ms(0.0, 0.0);
        }
        let mu = self.mu() - other.mu();
        Self::from_ms(mu, var.sqrt())
    }

    /// Widen `self` by adding `variance_delta` to its variance, keeping the mean.
    ///
    /// An improper input stays improper (`inf + d = inf`). A `variance_delta` that drives
    /// the variance negative yields `sqrt(negative) = NaN`; callers presumably pass a
    /// non-negative delta (TODO confirm).
    pub(crate) fn forget(&self, variance_delta: f64) -> Self {
        let var = self.sigma().powi(2) + variance_delta;
        Self::from_ms(self.mu(), var.sqrt())
    }

    /// EP damping in natural-parameter space: `α·new + (1−α)·self`.
    ///
    /// Used by within-game inference to stabilise oscillating fixed-point loops on hard
    /// graphs. `alpha = 1.0` returns `new` exactly and `alpha = 0.0` returns `self`
    /// exactly. `alpha` strictly between them shrinks each per-step update.
    pub fn damp_natural(self, new: Gaussian, alpha: f64) -> Gaussian {
        // Return the endpoints verbatim. The blended formula would evaluate
        // `0.0 * inf = NaN` whenever the discarded side has an infinite parameter
        // (e.g. a point mass with `pi = inf`), breaking the contract above.
        if alpha == 1.0 {
            return new;
        }
        if alpha == 0.0 {
            return self;
        }
        Gaussian::from_natural(
            alpha * new.pi() + (1.0 - alpha) * self.pi(),
            alpha * new.tau() + (1.0 - alpha) * self.tau(),
        )
    }
}
|
||
|
||
impl Default for Gaussian {
|
||
fn default() -> Self {
|
||
Self::from_ms(MU, SIGMA)
|
||
}
|
||
}
|
||
|
||
impl ops::Add<Gaussian> for Gaussian {
|
||
type Output = Gaussian;
|
||
/// Variance addition: (mu1 + mu2, sqrt(σ1² + σ2²)).
|
||
/// Used for combining performance and noise; rare relative to mul/div.
|
||
fn add(self, rhs: Gaussian) -> Self::Output {
|
||
let mu = self.mu() + rhs.mu();
|
||
let var = self.sigma().powi(2) + rhs.sigma().powi(2);
|
||
Self::from_ms(mu, var.sqrt())
|
||
}
|
||
}
|
||
|
||
impl ops::Sub<Gaussian> for Gaussian {
|
||
type Output = Gaussian;
|
||
/// (mu1 - mu2, sqrt(σ1² + σ2²)). Same sigma combination as Add.
|
||
fn sub(self, rhs: Gaussian) -> Self::Output {
|
||
let mu = self.mu() - rhs.mu();
|
||
let var = self.sigma().powi(2) + rhs.sigma().powi(2);
|
||
Self::from_ms(mu, var.sqrt())
|
||
}
|
||
}
|
||
|
||
impl ops::Mul<Gaussian> for Gaussian {
|
||
type Output = Gaussian;
|
||
/// Factor product: nat-param add. Hot path — two f64 additions, no sqrt.
|
||
fn mul(self, rhs: Gaussian) -> Self::Output {
|
||
Self::from_natural(self.pi + rhs.pi, self.tau + rhs.tau)
|
||
}
|
||
}
|
||
|
||
impl ops::Mul<f64> for Gaussian {
|
||
type Output = Gaussian;
|
||
fn mul(self, scalar: f64) -> Self::Output {
|
||
if !scalar.is_finite() {
|
||
return N_INF;
|
||
}
|
||
if scalar == 0.0 {
|
||
// Scaling by 0 collapses to a point mass at 0 (sigma' = 0, mu' = 0).
|
||
// This is N00, the additive identity, NOT N_INF.
|
||
return Gaussian::from_ms(0.0, 0.0);
|
||
}
|
||
// sigma' = sigma * |scalar| => pi' = pi / scalar²
|
||
// mu' = mu * scalar => tau' = tau / scalar
|
||
Self::from_natural(self.pi / (scalar * scalar), self.tau / scalar)
|
||
}
|
||
}
|
||
|
||
impl ops::Div<Gaussian> for Gaussian {
|
||
type Output = Gaussian;
|
||
/// Cavity: nat-param sub. Hot path — two f64 subtractions, no sqrt.
|
||
fn div(self, rhs: Gaussian) -> Self::Output {
|
||
Self::from_natural(self.pi - rhs.pi, self.tau - rhs.tau)
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn non_positive_precision_is_improper_not_nan() {
        // EP message cancellation can leave `pi` a tiny negative value (a round-off residue
        // of exactly zero). Such a Gaussian is improper/uninformative: mu() must be 0 and
        // sigma() +inf, not NaN. A NaN here propagates through the moment-space `Sub` in
        // the game chain and poisons every skill in the slice.
        let tiny_neg = Gaussian::from_natural(-5.55e-17, -8.88e-16);
        assert_eq!(tiny_neg.mu(), 0.0);
        assert_eq!(tiny_neg.sigma(), f64::INFINITY);

        // A frankly-negative precision is treated the same way.
        let neg = Gaussian::from_natural(-1.0, 2.0);
        assert_eq!(neg.mu(), 0.0);
        assert_eq!(neg.sigma(), f64::INFINITY);

        // Subtracting such a message must not produce NaN (the original failure path).
        // The improper operand has infinite variance, so the difference is itself
        // exactly improper: pi == 0, tau == 0.
        let proper = Gaussian::from_ms(9.75, 1.256);
        let diff = proper - tiny_neg;
        assert_eq!(diff.pi(), 0.0);
        assert_eq!(diff.tau(), 0.0);
        assert_eq!(diff.sigma(), f64::INFINITY);
    }

    #[test]
    fn test_add() {
        let n = Gaussian::from_ms(25.0, 25.0 / 3.0);
        let m = Gaussian::from_ms(0.0, 1.0);
        let r = n + m;
        assert!((r.mu() - 25.0).abs() < 1e-12);
        assert!((r.sigma() - 8.393118874676116).abs() < 1e-10);
    }

    #[test]
    fn test_sub() {
        let n = Gaussian::from_ms(25.0, 25.0 / 3.0);
        let m = Gaussian::from_ms(1.0, 1.0);
        let r = n - m;
        assert!((r.mu() - 24.0).abs() < 1e-12);
        assert!((r.sigma() - 8.393118874676116).abs() < 1e-10);
    }

    #[test]
    fn test_mul() {
        let n = Gaussian::from_ms(25.0, 25.0 / 3.0);
        let m = Gaussian::from_ms(0.0, 1.0);
        let r = n * m;
        assert!((r.mu() - 0.35488958990536273).abs() < 1e-10);
        assert!((r.sigma() - 0.992876838486922).abs() < 1e-10);
    }

    #[test]
    fn test_div() {
        let n = Gaussian::from_ms(25.0, 25.0 / 3.0);
        let m = Gaussian::from_ms(0.0, 1.0);
        let r = m / n;
        assert!((r.mu() - (-0.3652597402597402)).abs() < 1e-10);
        assert!((r.sigma() - 1.0072787050317253).abs() < 1e-10);
    }

    #[test]
    fn test_n00_is_add_identity() {
        // N00 (sigma=0) is the additive identity for the variance-convolution Add op.
        // N_INF (sigma=inf) is the identity for the EP-product Mul op.
        let g = Gaussian::from_ms(3.0, 2.0);
        let n00 = Gaussian::from_ms(0.0, 0.0);
        let r = n00 + g;
        assert!((r.mu() - g.mu()).abs() < 1e-12);
        assert!((r.sigma() - g.sigma()).abs() < 1e-12);
    }

    #[test]
    fn test_mul_is_factor_product() {
        // n * m in nat-params should be pi_n + pi_m, tau_n + tau_m
        let n = Gaussian::from_ms(2.0, 3.0);
        let m = Gaussian::from_ms(1.0, 2.0);
        let r = n * m;
        let expected_pi = n.pi() + m.pi();
        let expected_tau = n.tau() + m.tau();
        assert!((r.pi() - expected_pi).abs() < 1e-15);
        assert!((r.tau() - expected_tau).abs() < 1e-15);
    }

    #[test]
    fn test_div_is_cavity() {
        let n = Gaussian::from_ms(2.0, 1.0);
        let m = Gaussian::from_ms(1.0, 2.0);
        let r = n / m;
        let expected_pi = n.pi() - m.pi();
        let expected_tau = n.tau() - m.tau();
        assert!((r.pi() - expected_pi).abs() < 1e-15);
        assert!((r.tau() - expected_tau).abs() < 1e-15);
    }

    #[test]
    fn damp_natural_alpha_one_returns_new() {
        let old = Gaussian::from_ms(1.0, 2.0);
        let new = Gaussian::from_ms(5.0, 0.5);
        let damped = old.damp_natural(new, 1.0);
        assert_eq!(damped.pi(), new.pi());
        assert_eq!(damped.tau(), new.tau());
    }

    #[test]
    fn damp_natural_alpha_zero_returns_self() {
        let old = Gaussian::from_ms(1.0, 2.0);
        let new = Gaussian::from_ms(5.0, 0.5);
        let damped = old.damp_natural(new, 0.0);
        assert_eq!(damped.pi(), old.pi());
        assert_eq!(damped.tau(), old.tau());
    }

    #[test]
    fn damp_natural_alpha_half_is_midpoint_in_natural_params() {
        let old = Gaussian::from_ms(1.0, 2.0);
        let new = Gaussian::from_ms(5.0, 0.5);
        let damped = old.damp_natural(new, 0.5);
        let expected_pi = 0.5 * new.pi() + 0.5 * old.pi();
        let expected_tau = 0.5 * new.tau() + 0.5 * old.tau();
        assert!((damped.pi() - expected_pi).abs() < 1e-12);
        assert!((damped.tau() - expected_tau).abs() < 1e-12);
    }
}
|