use crate::{ N_INF, approx, cdf, factor::{Factor, VarId, VarStore}, gaussian::Gaussian, }; /// EP truncation factor on a diff variable. /// /// Implements the rectified-Gaussian approximation that turns a diff /// distribution into a "this team rank-beats that team" or "tied" likelihood. /// Stores its outgoing message to the diff variable so the cavity computation /// produces the correct EP message on each propagation. #[derive(Debug)] pub struct TruncFactor { pub diff: VarId, pub margin: f64, pub tie: bool, /// Outgoing message to the diff variable (initial: N_INF, the EP identity). pub(crate) msg: Gaussian, /// Cached evidence (linear, not log) computed from the cavity on first propagation. pub(crate) evidence_cached: Option, } impl TruncFactor { pub fn new(diff: VarId, margin: f64, tie: bool) -> Self { Self { diff, margin, tie, msg: N_INF, evidence_cached: None, } } } impl TruncFactor { /// Propagate this factor's message, optionally damping the update in /// natural-parameter space. `alpha = 1.0` matches `Factor::propagate` /// exactly; `alpha < 1.0` writes `α·new_msg + (1−α)·old_msg`. pub(crate) fn propagate_with_alpha(&mut self, vars: &mut VarStore, alpha: f64) -> (f64, f64) { let marginal = vars.get(self.diff); let cavity = marginal / self.msg; if self.evidence_cached.is_none() { self.evidence_cached = Some(cavity_evidence(cavity, self.margin, self.tie)); } let trunc = approx(cavity, self.margin, self.tie); let new_msg = trunc / cavity; let damped = self.msg.damp_natural(new_msg, alpha); let old_msg = self.msg; self.msg = damped; // marginal_new = cavity * stored_msg. With alpha = 1.0 this equals // `trunc` (since cavity * new_msg = trunc by construction); with // alpha < 1.0 it reflects the partially-applied update. vars.set(self.diff, cavity * damped); old_msg.delta(damped) } } impl Factor for TruncFactor { fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) { self.propagate_with_alpha(vars, 1.0) } fn log_evidence(&self, _vars: &VarStore) -> f64 { self.evidence_cached.unwrap_or(1.0).ln() } } /// P(diff > margin) for non-tie, P(|diff| < margin) for tie. fn cavity_evidence(diff: Gaussian, margin: f64, tie: bool) -> f64 { if tie { cdf(margin, diff.mu(), diff.sigma()) - cdf(-margin, diff.mu(), diff.sigma()) } else { 1.0 - cdf(margin, diff.mu(), diff.sigma()) } } #[cfg(test)] mod tests { use super::*; use crate::factor::VarStore; #[test] fn idempotent_after_convergence() { // After enough iterations, propagate should return ~0 delta. let mut vars = VarStore::new(); let diff = vars.alloc(Gaussian::from_ms(2.0, 3.0)); let mut f = TruncFactor::new(diff, 0.0, false); // Propagate many times; delta should drop toward 0. let mut last = (f64::INFINITY, f64::INFINITY); for _ in 0..20 { last = f.propagate(&mut vars); } assert!(last.0 < 1e-10, "expected converged delta, got {}", last.0); assert!(last.1 < 1e-10); } #[test] fn evidence_cached_on_first_propagate() { let mut vars = VarStore::new(); let diff = vars.alloc(Gaussian::from_ms(2.0, 3.0)); let mut f = TruncFactor::new(diff, 0.0, false); assert!(f.evidence_cached.is_none()); f.propagate(&mut vars); assert!(f.evidence_cached.is_some()); let first = f.evidence_cached.unwrap(); // Evidence should be P(diff > 0) for diff ~ N(2, 9) ≈ 0.748 assert!(first > 0.7); assert!(first < 0.8); // Subsequent propagations don't change it. f.propagate(&mut vars); assert_eq!(f.evidence_cached.unwrap(), first); } #[test] fn tie_evidence_uses_two_sided() { let mut vars = VarStore::new(); let diff = vars.alloc(Gaussian::from_ms(0.0, 2.0)); let mut f = TruncFactor::new(diff, 1.0, true); f.propagate(&mut vars); // For diff ~ N(0, 4), tie=true with margin=1: P(-1 < diff < 1) ≈ 0.383 let ev = f.evidence_cached.unwrap(); assert!(ev > 0.35 && ev < 0.42); } #[test] fn propagate_with_alpha_one_matches_undamped_propagate() { let mut vars_a = VarStore::new(); let diff_a = vars_a.alloc(Gaussian::from_ms(2.0, 3.0)); let mut f_a = TruncFactor::new(diff_a, 0.0, false); let delta_a = f_a.propagate(&mut vars_a); let result_a = vars_a.get(diff_a); let mut vars_b = VarStore::new(); let diff_b = vars_b.alloc(Gaussian::from_ms(2.0, 3.0)); let mut f_b = TruncFactor::new(diff_b, 0.0, false); let delta_b = f_b.propagate_with_alpha(&mut vars_b, 1.0); let result_b = vars_b.get(diff_b); assert_eq!(result_a.pi(), result_b.pi()); assert_eq!(result_a.tau(), result_b.tau()); assert_eq!(delta_a, delta_b); assert_eq!(f_a.msg.pi(), f_b.msg.pi()); assert_eq!(f_a.msg.tau(), f_b.msg.tau()); } #[test] fn propagate_with_alpha_half_blends_msg_in_natural_params() { // Run undamped to capture (initial_msg, undamped_new_msg). let mut vars_full = VarStore::new(); let diff_full = vars_full.alloc(Gaussian::from_ms(2.0, 3.0)); let mut f_full = TruncFactor::new(diff_full, 0.0, false); let initial_msg_pi = f_full.msg.pi(); let initial_msg_tau = f_full.msg.tau(); f_full.propagate(&mut vars_full); let undamped_msg_pi = f_full.msg.pi(); let undamped_msg_tau = f_full.msg.tau(); // Run damped at α = 0.5 from the same initial state. let mut vars_half = VarStore::new(); let diff_half = vars_half.alloc(Gaussian::from_ms(2.0, 3.0)); let mut f_half = TruncFactor::new(diff_half, 0.0, false); f_half.propagate_with_alpha(&mut vars_half, 0.5); let expected_pi = 0.5 * undamped_msg_pi + 0.5 * initial_msg_pi; let expected_tau = 0.5 * undamped_msg_tau + 0.5 * initial_msg_tau; assert!((f_half.msg.pi() - expected_pi).abs() < 1e-12); assert!((f_half.msg.tau() - expected_tau).abs() < 1e-12); } }