diff --git a/src/factor/trunc.rs b/src/factor/trunc.rs index f5b6dfe..abe3dbb 100644 --- a/src/factor/trunc.rs +++ b/src/factor/trunc.rs @@ -1,15 +1,23 @@ use crate::{ - N_INF, + N_INF, approx, cdf, factor::{Factor, VarId, VarStore}, gaussian::Gaussian, }; +/// EP truncation factor on a diff variable. +/// +/// Implements the rectified-Gaussian approximation that turns a diff +/// distribution into a "this team rank-beats that team" or "tied" likelihood. +/// Stores its outgoing message to the diff variable so the cavity computation +/// produces the correct EP message on each propagation. #[derive(Debug)] pub(crate) struct TruncFactor { pub(crate) diff: VarId, pub(crate) margin: f64, pub(crate) tie: bool, + /// Outgoing message to the diff variable (initial: N_INF, the EP identity). pub(crate) msg: Gaussian, + /// Cached evidence (linear, not log) computed from the cavity on first propagation. pub(crate) evidence_cached: Option, } @@ -26,7 +34,97 @@ impl TruncFactor { } impl Factor for TruncFactor { - fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) { - unimplemented!("TruncFactor stub — implemented in Task 6") + fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) { + let marginal = vars.get(self.diff); + // Cavity: marginal divided by our outgoing message. + let cavity = marginal / self.msg; + + // First-time-only: cache the evidence contribution from the cavity. + if self.evidence_cached.is_none() { + self.evidence_cached = Some(cavity_evidence(cavity, self.margin, self.tie)); + } + + // Apply the truncation approximation to the cavity. + let trunc = approx(cavity, self.margin, self.tie); + + // New outgoing message such that cavity * new_msg = trunc. + let new_msg = trunc / cavity; + let old_msg = self.msg; + self.msg = new_msg; + + // Update the marginal: marginal_new = cavity * new_msg = trunc. + vars.set(self.diff, trunc); + + old_msg.delta(new_msg) + } + + fn log_evidence(&self, _vars: &VarStore) -> f64 { + self.evidence_cached.unwrap_or(1.0).ln() + } +} + +/// P(diff > margin) for non-tie, P(|diff| < margin) for tie. +fn cavity_evidence(diff: Gaussian, margin: f64, tie: bool) -> f64 { + if tie { + cdf(margin, diff.mu(), diff.sigma()) - cdf(-margin, diff.mu(), diff.sigma()) + } else { + 1.0 - cdf(margin, diff.mu(), diff.sigma()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::factor::VarStore; + + #[test] + fn idempotent_after_convergence() { + // After enough iterations, propagate should return ~0 delta. + let mut vars = VarStore::new(); + let diff = vars.alloc(Gaussian::from_ms(2.0, 3.0)); + + let mut f = TruncFactor::new(diff, 0.0, false); + + // Propagate many times; delta should drop toward 0. + let mut last = (f64::INFINITY, f64::INFINITY); + for _ in 0..20 { + last = f.propagate(&mut vars); + } + assert!(last.0 < 1e-10, "expected converged delta, got {}", last.0); + assert!(last.1 < 1e-10); + } + + #[test] + fn evidence_cached_on_first_propagate() { + let mut vars = VarStore::new(); + let diff = vars.alloc(Gaussian::from_ms(2.0, 3.0)); + + let mut f = TruncFactor::new(diff, 0.0, false); + assert!(f.evidence_cached.is_none()); + + f.propagate(&mut vars); + assert!(f.evidence_cached.is_some()); + let first = f.evidence_cached.unwrap(); + + // Evidence should be P(diff > 0) for diff ~ N(2, 9) ≈ 0.748 + assert!(first > 0.7); + assert!(first < 0.8); + + // Subsequent propagations don't change it. + f.propagate(&mut vars); + assert_eq!(f.evidence_cached.unwrap(), first); + } + + #[test] + fn tie_evidence_uses_two_sided() { + let mut vars = VarStore::new(); + let diff = vars.alloc(Gaussian::from_ms(0.0, 2.0)); + + let mut f = TruncFactor::new(diff, 1.0, true); + f.propagate(&mut vars); + + // For diff ~ N(0, 4), tie=true with margin=1: P(-1 < diff < 1) ≈ 0.383 + let ev = f.evidence_cached.unwrap(); + assert!(ev > 0.35 && ev < 0.42); } } diff --git a/src/lib.rs b/src/lib.rs index fd1f27c..bd496fd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -163,7 +163,7 @@ fn compute_margin(p_draw: f64, sd: f64) -> f64 { ppf(0.5 - p_draw / 2.0, 0.0, sd).abs() } -fn cdf(x: f64, mu: f64, sigma: f64) -> f64 { +pub(crate) fn cdf(x: f64, mu: f64, sigma: f64) -> f64 { let z = -(x - mu) / (sigma * SQRT_2); 0.5 * erfc(z)