fcfe0ffe37
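Adds an inherent method, `TruncFactor::propagate_with_alpha`, that applies α-damping to the outgoing message via `Gaussian::damp_natural`. The `Factor` trait impl delegates with α = 1.0, so the trait path and `propagate_with_alpha(vars, 1.0)` produce bit-identical results. The variable write switches from `trunc` to `cavity * damped`. When α = 1.0 the two are algebraically identical, since `cavity * new_msg = trunc` by construction, though floating-point rounding may differ in the last bit. When α < 1.0 it reflects the partial-update math.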
184 lines
6.3 KiB
Rust
184 lines
6.3 KiB
Rust
use crate::{
|
||
N_INF, approx, cdf,
|
||
factor::{Factor, VarId, VarStore},
|
||
gaussian::Gaussian,
|
||
};
|
||
|
||
/// EP truncation factor on a diff variable.
|
||
///
|
||
/// Implements the rectified-Gaussian approximation that turns a diff
|
||
/// distribution into a "this team rank-beats that team" or "tied" likelihood.
|
||
/// Stores its outgoing message to the diff variable so the cavity computation
|
||
/// produces the correct EP message on each propagation.
|
||
#[derive(Debug)]
|
||
pub struct TruncFactor {
|
||
pub diff: VarId,
|
||
pub margin: f64,
|
||
pub tie: bool,
|
||
/// Outgoing message to the diff variable (initial: N_INF, the EP identity).
|
||
pub(crate) msg: Gaussian,
|
||
/// Cached evidence (linear, not log) computed from the cavity on first propagation.
|
||
pub(crate) evidence_cached: Option<f64>,
|
||
}
|
||
|
||
impl TruncFactor {
|
||
pub fn new(diff: VarId, margin: f64, tie: bool) -> Self {
|
||
Self {
|
||
diff,
|
||
margin,
|
||
tie,
|
||
msg: N_INF,
|
||
evidence_cached: None,
|
||
}
|
||
}
|
||
}
|
||
|
||
impl TruncFactor {
|
||
/// Propagate this factor's message, optionally damping the update in
|
||
/// natural-parameter space. `alpha = 1.0` matches `Factor::propagate`
|
||
/// exactly; `alpha < 1.0` writes `α·new_msg + (1−α)·old_msg`.
|
||
pub(crate) fn propagate_with_alpha(&mut self, vars: &mut VarStore, alpha: f64) -> (f64, f64) {
|
||
let marginal = vars.get(self.diff);
|
||
let cavity = marginal / self.msg;
|
||
|
||
if self.evidence_cached.is_none() {
|
||
self.evidence_cached = Some(cavity_evidence(cavity, self.margin, self.tie));
|
||
}
|
||
|
||
let trunc = approx(cavity, self.margin, self.tie);
|
||
let new_msg = trunc / cavity;
|
||
|
||
let damped = self.msg.damp_natural(new_msg, alpha);
|
||
let old_msg = self.msg;
|
||
self.msg = damped;
|
||
|
||
// marginal_new = cavity * stored_msg. With alpha = 1.0 this equals
|
||
// `trunc` (since cavity * new_msg = trunc by construction); with
|
||
// alpha < 1.0 it reflects the partially-applied update.
|
||
vars.set(self.diff, cavity * damped);
|
||
|
||
old_msg.delta(damped)
|
||
}
|
||
}
|
||
|
||
impl Factor for TruncFactor {
|
||
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
|
||
self.propagate_with_alpha(vars, 1.0)
|
||
}
|
||
|
||
fn log_evidence(&self, _vars: &VarStore) -> f64 {
|
||
self.evidence_cached.unwrap_or(1.0).ln()
|
||
}
|
||
}
|
||
|
||
/// P(diff > margin) for non-tie, P(|diff| < margin) for tie.
|
||
fn cavity_evidence(diff: Gaussian, margin: f64, tie: bool) -> f64 {
|
||
if tie {
|
||
cdf(margin, diff.mu(), diff.sigma()) - cdf(-margin, diff.mu(), diff.sigma())
|
||
} else {
|
||
1.0 - cdf(margin, diff.mu(), diff.sigma())
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::factor::VarStore;

    #[test]
    fn idempotent_after_convergence() {
        // After enough iterations, propagate should return ~0 delta.
        let mut vars = VarStore::new();
        let diff = vars.alloc(Gaussian::from_ms(2.0, 3.0));

        let mut f = TruncFactor::new(diff, 0.0, false);

        // Propagate many times; delta should drop toward 0.
        let mut last = (f64::INFINITY, f64::INFINITY);
        for _ in 0..20 {
            last = f.propagate(&mut vars);
        }
        assert!(last.0 < 1e-10, "expected converged delta, got {}", last.0);
        assert!(last.1 < 1e-10);
    }

    #[test]
    fn evidence_cached_on_first_propagate() {
        let mut vars = VarStore::new();
        let diff = vars.alloc(Gaussian::from_ms(2.0, 3.0));

        let mut f = TruncFactor::new(diff, 0.0, false);
        assert!(f.evidence_cached.is_none());

        f.propagate(&mut vars);
        assert!(f.evidence_cached.is_some());
        let first = f.evidence_cached.unwrap();

        // Evidence should be P(diff > 0) for diff ~ N(2, 9) ≈ 0.748
        assert!(first > 0.7);
        assert!(first < 0.8);

        // Subsequent propagations don't change it.
        f.propagate(&mut vars);
        assert_eq!(f.evidence_cached.unwrap(), first);
    }

    #[test]
    fn tie_evidence_uses_two_sided() {
        let mut vars = VarStore::new();
        let diff = vars.alloc(Gaussian::from_ms(0.0, 2.0));

        let mut f = TruncFactor::new(diff, 1.0, true);
        f.propagate(&mut vars);

        // For diff ~ N(0, 4), tie=true with margin=1: P(-1 < diff < 1) ≈ 0.383
        let ev = f.evidence_cached.unwrap();
        assert!(ev > 0.35 && ev < 0.42);
    }

    #[test]
    fn propagate_with_alpha_one_matches_undamped_propagate() {
        let mut vars_a = VarStore::new();
        let diff_a = vars_a.alloc(Gaussian::from_ms(2.0, 3.0));
        let mut f_a = TruncFactor::new(diff_a, 0.0, false);
        let delta_a = f_a.propagate(&mut vars_a);
        let result_a = vars_a.get(diff_a);

        let mut vars_b = VarStore::new();
        let diff_b = vars_b.alloc(Gaussian::from_ms(2.0, 3.0));
        let mut f_b = TruncFactor::new(diff_b, 0.0, false);
        let delta_b = f_b.propagate_with_alpha(&mut vars_b, 1.0);
        let result_b = vars_b.get(diff_b);

        assert_eq!(result_a.pi(), result_b.pi());
        assert_eq!(result_a.tau(), result_b.tau());
        assert_eq!(delta_a, delta_b);
        assert_eq!(f_a.msg.pi(), f_b.msg.pi());
        assert_eq!(f_a.msg.tau(), f_b.msg.tau());
    }

    #[test]
    fn propagate_with_alpha_half_blends_msg_in_natural_params() {
        // Run undamped to capture (initial_msg, undamped_new_msg).
        let mut vars_full = VarStore::new();
        let diff_full = vars_full.alloc(Gaussian::from_ms(2.0, 3.0));
        let mut f_full = TruncFactor::new(diff_full, 0.0, false);
        let initial_msg_pi = f_full.msg.pi();
        let initial_msg_tau = f_full.msg.tau();
        f_full.propagate(&mut vars_full);
        let undamped_msg_pi = f_full.msg.pi();
        let undamped_msg_tau = f_full.msg.tau();

        // Run damped at α = 0.5 from the same initial state.
        let mut vars_half = VarStore::new();
        let diff_half = vars_half.alloc(Gaussian::from_ms(2.0, 3.0));
        let mut f_half = TruncFactor::new(diff_half, 0.0, false);
        f_half.propagate_with_alpha(&mut vars_half, 0.5);

        let expected_pi = 0.5 * undamped_msg_pi + 0.5 * initial_msg_pi;
        let expected_tau = 0.5 * undamped_msg_tau + 0.5 * initial_msg_tau;
        assert!((f_half.msg.pi() - expected_pi).abs() < 1e-12);
        assert!((f_half.msg.tau() - expected_tau).abs() < 1e-12);
    }

    #[test]
    fn propagate_with_alpha_half_writes_cavity_times_damped_msg() {
        // The marginal written back must be cavity * stored (damped) message,
        // not the full-step truncated marginal.
        let prior = Gaussian::from_ms(2.0, 3.0);
        let mut vars = VarStore::new();
        let diff = vars.alloc(prior);
        let mut f = TruncFactor::new(diff, 0.0, false);

        f.propagate_with_alpha(&mut vars, 0.5);

        // The initial message is N_INF, so the first cavity is the prior itself.
        let expected = prior * f.msg;
        let got = vars.get(diff);
        assert!((got.pi() - expected.pi()).abs() < 1e-12);
        assert!((got.tau() - expected.tau()).abs() < 1e-12);
    }

    #[test]
    fn log_evidence_defaults_to_zero_then_uses_cached_value() {
        let mut vars = VarStore::new();
        let diff = vars.alloc(Gaussian::from_ms(2.0, 3.0));
        let mut f = TruncFactor::new(diff, 0.0, false);

        // Nothing cached yet: ln(1) = 0.
        assert_eq!(f.log_evidence(&vars), 0.0);

        f.propagate(&mut vars);
        let cached = f.evidence_cached.unwrap();
        assert_eq!(f.log_evidence(&vars), cached.ln());
    }
}
|