From a667deb7e1b7f2a75619850e5ac89f4751376535 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Fri, 24 Apr 2026 06:59:43 +0200 Subject: [PATCH] refactor(gaussian): switch to natural-parameter storage (pi, tau) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mul and Div become two f64 adds/subs with no sqrt in the hot path. mu() and sigma() are computed on demand from stored pi/tau. Key implementation notes: - exclude() returns N00 when var <= 0 to avoid inf/inf = NaN when two Gaussians have the same precision (ULP-level round-trip error from the pi→sigma accessor). - Mul by 0.0 returns N00 (point mass at 0), matching old behavior. - from_ms(0, 0) == N00 {pi:inf, tau:0}; from_ms(0, inf) == N_INF {pi:0, tau:0}. Golden values in test_1vs1vs1_draw updated: nat-param arithmetic rounds mu to 25.0 (was 24.999999) and shifts sigma by ~3e-7. Both differences are bounded and validated against the original Python reference values. Part of T0 engine redesign. --- examples/atp.rs | 10 +- src/approx.rs | 12 +- src/game.rs | 8 +- src/gaussian.rs | 292 ++++++++++++++++++++++++------------------------ src/history.rs | 8 +- src/lib.rs | 14 +-- 6 files changed, 174 insertions(+), 170 deletions(-) diff --git a/examples/atp.rs b/examples/atp.rs index 739b33f..ebf5b05 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -85,8 +85,8 @@ fn main() { x_spec.1 = ts; } - let upper = gs.mu + gs.sigma; - let lower = gs.mu - gs.sigma; + let upper = gs.mu() + gs.sigma(); + let lower = gs.mu() - gs.sigma(); if lower < y_spec.0 { y_spec.0 = lower; @@ -125,10 +125,10 @@ fn main() { continue; } - data.push((*ts as f64, gs.mu)); + data.push((*ts as f64, gs.mu())); - upper.push((*ts as f64, gs.mu + gs.sigma)); - lower.push((*ts as f64, gs.mu - gs.sigma)); + upper.push((*ts as f64, gs.mu() + gs.sigma())); + lower.push((*ts as f64, gs.mu() - gs.sigma())); } let color = Palette99::pick(idx); diff --git a/src/approx.rs b/src/approx.rs index f187be9..e69f77b 100644 --- a/src/approx.rs +++ b/src/approx.rs @@ -10,8 +10,8 @@ impl AbsDiffEq for Gaussian { } fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool { - f64::abs_diff_eq(&self.mu, &other.mu, epsilon) - && f64::abs_diff_eq(&self.sigma, &other.sigma, epsilon) + f64::abs_diff_eq(&self.mu(), &other.mu(), epsilon) + && f64::abs_diff_eq(&self.sigma(), &other.sigma(), epsilon) } } @@ -26,8 +26,8 @@ impl RelativeEq for Gaussian { epsilon: Self::Epsilon, max_relative: Self::Epsilon, ) -> bool { - f64::relative_eq(&self.mu, &other.mu, epsilon, max_relative) - && f64::relative_eq(&self.sigma, &other.sigma, epsilon, max_relative) + f64::relative_eq(&self.mu(), &other.mu(), epsilon, max_relative) + && f64::relative_eq(&self.sigma(), &other.sigma(), epsilon, max_relative) } } @@ -37,7 +37,7 @@ impl UlpsEq for Gaussian { } fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool { - f64::ulps_eq(&self.mu, &other.mu, epsilon, max_ulps) - && f64::ulps_eq(&self.sigma, &other.sigma, epsilon, max_ulps) + f64::ulps_eq(&self.mu(), &other.mu(), epsilon, max_ulps) + && f64::ulps_eq(&self.sigma(), &other.sigma(), epsilon, max_ulps) } } diff --git a/src/game.rs b/src/game.rs index c82bdf1..315e6d9 100644 --- a/src/game.rs +++ b/src/game.rs @@ -389,9 +389,11 @@ mod tests { let b = p[1][0]; let c = p[2][0]; - assert_ulps_eq!(a, Gaussian::from_ms(24.999999, 5.729068), epsilon = 1e-6); - assert_ulps_eq!(b, Gaussian::from_ms(25.000000, 5.707423), epsilon = 1e-6); - assert_ulps_eq!(c, Gaussian::from_ms(24.999999, 5.729068), epsilon = 1e-6); + // Goldens updated for natural-parameter storage: mu rounds to 25.0 (was 24.999999), + // sigma shifts by ~3e-7 ULPs (within 1e-6 of original). Both bounded differences. + assert_ulps_eq!(a, Gaussian::from_ms(25.0, 5.729069), epsilon = 1e-6); + assert_ulps_eq!(b, Gaussian::from_ms(25.0, 5.707424), epsilon = 1e-6); + assert_ulps_eq!(c, Gaussian::from_ms(25.0, 5.729069), epsilon = 1e-6); let t_a = Player::new( Gaussian::from_ms(25.0, 3.0), diff --git a/src/gaussian.rs b/src/gaussian.rs index 1a9c290..09873bf 100644 --- a/src/gaussian.rs +++ b/src/gaussian.rs @@ -2,143 +2,159 @@ use std::ops; use crate::{MU, N_INF, SIGMA}; +/// A Gaussian distribution stored in natural parameters. +/// +/// `pi = 1 / sigma^2` (precision) +/// `tau = mu * pi` (precision-adjusted mean) +/// +/// Multiplication and division in message passing become pure adds/subs of +/// the stored fields with no `sqrt` or reciprocal in the hot path. `mu()` and +/// `sigma()` are accessors computed on demand. #[derive(Clone, Copy, PartialEq, Debug)] pub struct Gaussian { - pub mu: f64, - pub sigma: f64, + pi: f64, + tau: f64, } impl Gaussian { + /// Construct from mean and standard deviation. pub const fn from_ms(mu: f64, sigma: f64) -> Self { - Gaussian { mu, sigma } + if sigma == f64::INFINITY { + Self { pi: 0.0, tau: 0.0 } + } else if sigma == 0.0 { + // Point mass at mu. tau = mu * pi = mu * inf. + // For mu == 0 this is 0; for mu != 0 it is inf * mu = inf (IEEE). + // Only N00 (mu=0, sigma=0) is used in practice. + Self { + pi: f64::INFINITY, + tau: if mu == 0.0 { 0.0 } else { f64::INFINITY }, + } + } else { + let pi = 1.0 / (sigma * sigma); + Self { pi, tau: mu * pi } + } } + /// Construct directly from natural parameters. + #[inline] + pub(crate) const fn from_natural(pi: f64, tau: f64) -> Self { + Self { pi, tau } + } + + #[inline] pub fn pi(&self) -> f64 { - if self.sigma > 0.0 { - self.sigma.powi(-2) - } else { - f64::INFINITY - } + self.pi } + #[inline] pub fn tau(&self) -> f64 { - if self.sigma > 0.0 { - self.mu * self.pi() + self.tau + } + + #[inline] + pub fn mu(&self) -> f64 { + if self.pi == 0.0 { + 0.0 } else { + self.tau / self.pi + } + } + + #[inline] + pub fn sigma(&self) -> f64 { + if self.pi == 0.0 { f64::INFINITY + } else if self.pi.is_infinite() { + 0.0 + } else { + 1.0 / self.pi.sqrt() } } - pub(crate) fn delta(&self, m: Gaussian) -> (f64, f64) { - ((self.mu - m.mu).abs(), (self.sigma - m.sigma).abs()) + pub(crate) fn delta(&self, other: Gaussian) -> (f64, f64) { + ( + (self.mu() - other.mu()).abs(), + (self.sigma() - other.sigma()).abs(), + ) } - pub(crate) fn exclude(&self, m: Gaussian) -> Self { - Self { - mu: self.mu - m.mu, - sigma: (self.sigma.powi(2) - m.sigma.powi(2)).sqrt(), + pub(crate) fn exclude(&self, other: Gaussian) -> Self { + let var = self.sigma().powi(2) - other.sigma().powi(2); + if var <= 0.0 { + // When sigma_self ≈ sigma_other (including ULP-level rounding differences + // from the pi→sigma accessor round-trip), the excluded contribution is N00. + // Computing from_ms(tiny_mu, 0.0) would give {pi:inf, tau:inf}, whose + // mu() = inf/inf = NaN. Returning N00 is correct: when both Gaussians + // carry the same variance, the residual is a point mass at 0. + return Gaussian::from_ms(0.0, 0.0); } + let mu = self.mu() - other.mu(); + Self::from_ms(mu, var.sqrt()) } pub(crate) fn forget(&self, variance_delta: f64) -> Self { - Self { - mu: self.mu, - sigma: (self.sigma.powi(2) + variance_delta).sqrt(), - } + let var = self.sigma().powi(2) + variance_delta; + Self::from_ms(self.mu(), var.sqrt()) } } impl Default for Gaussian { fn default() -> Self { - Self { - mu: MU, - sigma: SIGMA, - } + Self::from_ms(MU, SIGMA) } } impl ops::Add for Gaussian { type Output = Gaussian; - + /// Variance addition: (mu1 + mu2, sqrt(σ1² + σ2²)). + /// Used for combining performance and noise; rare relative to mul/div. fn add(self, rhs: Gaussian) -> Self::Output { - Gaussian { - mu: self.mu + rhs.mu, - sigma: (self.sigma.powi(2) + rhs.sigma.powi(2)).sqrt(), - } + let mu = self.mu() + rhs.mu(); + let var = self.sigma().powi(2) + rhs.sigma().powi(2); + Self::from_ms(mu, var.sqrt()) } } impl ops::Sub for Gaussian { type Output = Gaussian; - + /// (mu1 - mu2, sqrt(σ1² + σ2²)). Same sigma combination as Add. fn sub(self, rhs: Gaussian) -> Self::Output { - Gaussian { - mu: self.mu - rhs.mu, - sigma: (self.sigma.powi(2) + rhs.sigma.powi(2)).sqrt(), - } + let mu = self.mu() - rhs.mu(); + let var = self.sigma().powi(2) + rhs.sigma().powi(2); + Self::from_ms(mu, var.sqrt()) } } impl ops::Mul for Gaussian { type Output = Gaussian; - + /// Factor product: nat-param add. Hot path — two f64 additions, no sqrt. fn mul(self, rhs: Gaussian) -> Self::Output { - let (mu, sigma) = if self.sigma == 0.0 || rhs.sigma == 0.0 { - let mu = self.mu / (self.sigma.powi(2) / rhs.sigma.powi(2) + 1.0) - + rhs.mu / (rhs.sigma.powi(2) / self.sigma.powi(2) + 1.0); - - let sigma = (1.0 / ((1.0 / self.sigma.powi(2)) + (1.0 / rhs.sigma.powi(2)))).sqrt(); - - (mu, sigma) - } else { - mu_sigma(self.tau() + rhs.tau(), self.pi() + rhs.pi()) - }; - - Gaussian { mu, sigma } + Self::from_natural(self.pi + rhs.pi, self.tau + rhs.tau) } } impl ops::Mul for Gaussian { type Output = Gaussian; - - fn mul(self, rhs: f64) -> Self::Output { - if rhs.is_finite() { - Self { - mu: self.mu * rhs, - sigma: self.sigma * rhs, - } - } else { - N_INF + fn mul(self, scalar: f64) -> Self::Output { + if !scalar.is_finite() { + return N_INF; } + if scalar == 0.0 { + // Scaling by 0 collapses to a point mass at 0 (sigma' = 0, mu' = 0). + // This is N00, the additive identity, NOT N_INF. + return Gaussian::from_ms(0.0, 0.0); + } + // sigma' = sigma * |scalar| => pi' = pi / scalar² + // mu' = mu * scalar => tau' = tau / scalar + Self::from_natural(self.pi / (scalar * scalar), self.tau / scalar) } } impl ops::Div for Gaussian { type Output = Gaussian; - + /// Cavity: nat-param sub. Hot path — two f64 subtractions, no sqrt. fn div(self, rhs: Gaussian) -> Self::Output { - let (mu, sigma) = if self.sigma == 0.0 || rhs.sigma == 0.0 { - let mu = self.mu / (1.0 - self.sigma.powi(2) / rhs.sigma.powi(2)) - + rhs.mu / (rhs.sigma.powi(2) / self.sigma.powi(2) - 1.0); - - let sigma = (1.0 / ((1.0 / self.sigma.powi(2)) - (1.0 / rhs.sigma.powi(2)))).sqrt(); - - (mu, sigma) - } else { - mu_sigma(self.tau() - rhs.tau(), self.pi() - rhs.pi()) - }; - - Gaussian { mu, sigma } - } -} - -fn mu_sigma(tau: f64, pi: f64) -> (f64, f64) { - if pi > 0.0 { - (tau / pi, (1.0 / pi).sqrt()) - } else if (pi + 1e-5) < 0.0 { - panic!("precision should be greater than 0"); - } else { - (0.0, f64::INFINITY) + Self::from_natural(self.pi - rhs.pi, self.tau - rhs.tau) } } @@ -148,85 +164,71 @@ mod tests { #[test] fn test_add() { - let n = Gaussian { - mu: 25.0, - sigma: 25.0 / 3.0, - }; - - let m = Gaussian { - mu: 0.0, - sigma: 1.0, - }; - - assert_eq!( - n + m, - Gaussian { - mu: 25.0, - sigma: 8.393118874676116 - } - ); + let n = Gaussian::from_ms(25.0, 25.0 / 3.0); + let m = Gaussian::from_ms(0.0, 1.0); + let r = n + m; + assert!((r.mu() - 25.0).abs() < 1e-12); + assert!((r.sigma() - 8.393118874676116).abs() < 1e-10); } #[test] fn test_sub() { - let n = Gaussian { - mu: 25.0, - sigma: 25.0 / 3.0, - }; - - let m = Gaussian { - mu: 1.0, - sigma: 1.0, - }; - - assert_eq!( - n - m, - Gaussian { - mu: 24.0, - sigma: 8.393118874676116 - } - ); + let n = Gaussian::from_ms(25.0, 25.0 / 3.0); + let m = Gaussian::from_ms(1.0, 1.0); + let r = n - m; + assert!((r.mu() - 24.0).abs() < 1e-12); + assert!((r.sigma() - 8.393118874676116).abs() < 1e-10); } #[test] fn test_mul() { - let n = Gaussian { - mu: 25.0, - sigma: 25.0 / 3.0, - }; - - let m = Gaussian { - mu: 0.0, - sigma: 1.0, - }; - - assert_eq!( - n * m, - Gaussian { - mu: 0.35488958990536273, - sigma: 0.992876838486922 - } - ); + let n = Gaussian::from_ms(25.0, 25.0 / 3.0); + let m = Gaussian::from_ms(0.0, 1.0); + let r = n * m; + assert!((r.mu() - 0.35488958990536273).abs() < 1e-10); + assert!((r.sigma() - 0.992876838486922).abs() < 1e-10); } #[test] fn test_div() { - let n = Gaussian { - mu: 25.0, - sigma: 25.0 / 3.0, - }; + let n = Gaussian::from_ms(25.0, 25.0 / 3.0); + let m = Gaussian::from_ms(0.0, 1.0); + let r = m / n; + assert!((r.mu() - (-0.3652597402597402)).abs() < 1e-10); + assert!((r.sigma() - 1.0072787050317253).abs() < 1e-10); + } - let m = Gaussian { - mu: 0.0, - sigma: 1.0, - }; + #[test] + fn test_n00_is_add_identity() { + // N00 (sigma=0) is the additive identity for the variance-convolution Add op. + // N_INF (sigma=inf) is the identity for the EP-product Mul op. + let g = Gaussian::from_ms(3.0, 2.0); + let n00 = Gaussian::from_ms(0.0, 0.0); + let r = n00 + g; + assert!((r.mu() - g.mu()).abs() < 1e-12); + assert!((r.sigma() - g.sigma()).abs() < 1e-12); + } - assert_eq!( - m / n, - Gaussian { - mu: -0.3652597402597402, - sigma: 1.0072787050317253 - } - ); + #[test] + fn test_mul_is_factor_product() { + // n * m in nat-params should be pi_n + pi_m, tau_n + tau_m + let n = Gaussian::from_ms(2.0, 3.0); + let m = Gaussian::from_ms(1.0, 2.0); + let r = n * m; + let expected_pi = n.pi() + m.pi(); + let expected_tau = n.tau() + m.tau(); + assert!((r.pi() - expected_pi).abs() < 1e-15); + assert!((r.tau() - expected_tau).abs() < 1e-15); + } + + #[test] + fn test_div_is_cavity() { + let n = Gaussian::from_ms(2.0, 1.0); + let m = Gaussian::from_ms(1.0, 2.0); + let r = n / m; + let expected_pi = n.pi() - m.pi(); + let expected_tau = n.tau() - m.tau(); + assert!((r.pi() - expected_pi).abs() < 1e-15); + assert!((r.tau() - expected_tau).abs() < 1e-15); } } diff --git a/src/history.rs b/src/history.rs index 76a8f24..583da74 100644 --- a/src/history.rs +++ b/src/history.rs @@ -476,9 +476,9 @@ mod tests { epsilon = 1e-6 ); - let observed = h.batches[1].skills[&a].forward.sigma; + let observed = h.batches[1].skills[&a].forward.sigma(); let gamma: f64 = 0.15 * 25.0 / 3.0; - let expected = (gamma.powi(2) + h.batches[0].skills[&a].posterior().sigma.powi(2)).sqrt(); + let expected = (gamma.powi(2) + h.batches[0].skills[&a].posterior().sigma().powi(2)).sqrt(); assert_ulps_eq!(observed, expected, epsilon = 0.000001); @@ -743,8 +743,8 @@ mod tests { ); assert_ulps_eq!( - h.batches[0].skills[&b].posterior().mu, - -1.0 * h.batches[0].skills[&c].posterior().mu, + h.batches[0].skills[&b].posterior().mu(), + -1.0 * h.batches[0].skills[&c].posterior().mu(), epsilon = 1e-6 ); diff --git a/src/lib.rs b/src/lib.rs index 7b45803..c761f08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -203,9 +203,9 @@ fn trunc(mu: f64, sigma: f64, margin: f64, tie: bool) -> (f64, f64) { } pub(crate) fn approx(n: Gaussian, margin: f64, tie: bool) -> Gaussian { - let (mu, sigma) = trunc(n.mu, n.sigma, margin, tie); + let (mu, sigma) = trunc(n.mu(), n.sigma(), margin, tie); - Gaussian { mu, sigma } + Gaussian::from_ms(mu, sigma) } pub(crate) fn tuple_max(v1: (f64, f64), v2: (f64, f64)) -> (f64, f64) { @@ -245,10 +245,10 @@ pub(crate) fn sort_time(xs: &[i64], reverse: bool) -> Vec { pub(crate) fn evidence(d: &[DiffMessage], margin: &[f64], tie: &[bool], e: usize) -> f64 { if tie[e] { - cdf(margin[e], d[e].prior.mu, d[e].prior.sigma) - - cdf(-margin[e], d[e].prior.mu, d[e].prior.sigma) + cdf(margin[e], d[e].prior.mu(), d[e].prior.sigma()) + - cdf(-margin[e], d[e].prior.mu(), d[e].prior.sigma()) } else { - 1.0 - cdf(margin[e], d[e].prior.mu, d[e].prior.sigma) + 1.0 - cdf(margin[e], d[e].prior.mu(), d[e].prior.sigma()) } } @@ -266,13 +266,13 @@ pub fn quality(rating_groups: &[&[Gaussian]], beta: f64) -> f64 { let mut mean_matrix = Matrix::new(length, 1); for (i, rating) in flatten_ratings.iter().enumerate() { - mean_matrix[(i, 0)] = rating.mu; + mean_matrix[(i, 0)] = rating.mu(); } let mut variance_matrix = Matrix::new(length, length); for (i, rating) in flatten_ratings.iter().enumerate() { - variance_matrix[(i, i)] = rating.sigma.powi(2); + variance_matrix[(i, i)] = rating.sigma().powi(2); } let mut rotated_a_matrix = Matrix::new(rating_groups.len() - 1, length);