Files
trueskill-tt/src/factor/trunc.rs
T
logaritmisk fcfe0ffe37 feat(factor): add TruncFactor::propagate_with_alpha for EP damping
Inherent method that applies α-damping to the outgoing message via
Gaussian::damp_natural. The Factor trait impl delegates with α=1.0,
preserving today's behavior bit-equal. Variable write switched from
`trunc` to `cavity * damped` — algebraically identical when α=1.0
(cavity * new_msg = trunc by construction); reflects partial-update
math when α<1.0.
2026-05-08 15:02:09 +02:00

184 lines
6.3 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use crate::{
N_INF, approx, cdf,
factor::{Factor, VarId, VarStore},
gaussian::Gaussian,
};
/// EP truncation factor on a diff variable.
///
/// Implements the rectified-Gaussian approximation that turns a diff
/// distribution into a "this team rank-beats that team" or "tied" likelihood.
/// Stores its outgoing message to the diff variable so the cavity computation
/// produces the correct EP message on each propagation.
///
/// The observed outcome is `diff > margin` when `tie` is `false`, and
/// `-margin < diff < margin` when `tie` is `true`.
#[derive(Debug)]
pub struct TruncFactor {
    /// Variable holding the difference distribution this factor truncates.
    pub diff: VarId,
    /// Threshold the difference is compared against (one-sided or two-sided,
    /// depending on `tie`).
    pub margin: f64,
    /// `true` selects the two-sided "tied" outcome; `false` the one-sided win.
    pub tie: bool,
    /// Outgoing message to the diff variable (initial: N_INF, the EP identity).
    /// It is divided out of the current marginal to form the cavity before
    /// every update, and replaced by the (possibly damped) new message after.
    pub(crate) msg: Gaussian,
    /// Cached evidence (linear, not log) computed from the cavity on first propagation.
    /// `None` until the first propagation; later propagations never overwrite it.
    pub(crate) evidence_cached: Option<f64>,
}
impl TruncFactor {
    /// Builds a truncation factor on `diff` that has not sent any message yet.
    ///
    /// The outgoing message starts at `N_INF` (the EP identity) and no evidence
    /// is cached; both are filled in by the first propagation.
    pub fn new(diff: VarId, margin: f64, tie: bool) -> Self {
        let msg = N_INF;
        let evidence_cached = None;
        Self {
            diff,
            margin,
            tie,
            msg,
            evidence_cached,
        }
    }
}
impl TruncFactor {
    /// Runs one EP update of this factor, damped by `alpha` in natural-parameter space.
    ///
    /// The freshly computed outgoing message is blended with the stored one via
    /// `Gaussian::damp_natural`. `alpha = 1.0` reproduces `Factor::propagate`
    /// exactly; `alpha < 1.0` stores `α·new_msg + (1−α)·old_msg`. The diff
    /// variable is then overwritten with `cavity * msg`.
    ///
    /// Returns `delta` between the previous and the newly stored message; a value
    /// near zero means the message has stopped changing.
    pub(crate) fn propagate_with_alpha(&mut self, vars: &mut VarStore, alpha: f64) -> (f64, f64) {
        let (margin, tie) = (self.margin, self.tie);

        // Cavity: the marginal with this factor's own contribution removed.
        let cavity = vars.get(self.diff) / self.msg;

        // Evidence is frozen at the cavity seen by the very first update.
        self.evidence_cached
            .get_or_insert_with(|| cavity_evidence(cavity, margin, tie));

        // Undamped target message: the truncated cavity divided by the cavity.
        let target = approx(cavity, margin, tie) / cavity;
        let previous = self.msg;
        self.msg = previous.damp_natural(target, alpha);

        // With alpha == 1.0, cavity * msg is exactly the truncated cavity; with
        // alpha < 1.0 it reflects only the partially applied update.
        vars.set(self.diff, cavity * self.msg);
        previous.delta(self.msg)
    }
}
impl Factor for TruncFactor {
    /// Undamped EP update: the same as `propagate_with_alpha` with `alpha = 1.0`.
    fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
        const UNDAMPED: f64 = 1.0;
        Self::propagate_with_alpha(self, vars, UNDAMPED)
    }

    /// Natural log of the cached evidence, or `0.0` (= ln 1) if this factor has
    /// never been propagated.
    fn log_evidence(&self, _vars: &VarStore) -> f64 {
        self.evidence_cached.map_or(0.0, f64::ln)
    }
}
/// Probability that the cavity distribution `diff` assigns to the observed outcome.
///
/// * `tie == false`: `P(diff > margin)`
/// * `tie == true`:  `P(-margin < diff < margin)`, i.e. `P(|diff| < margin)`
///
/// Both cases are rewritten with exact reflection identities so that, for an
/// unlikely outcome, `cdf` is evaluated in its lower tail. The naive forms
/// (`1 - cdf(..)`, or a difference of two CDF values both close to 1) can
/// cancel to exactly `0.0` for strong upsets, turning `ln(evidence)` into `-inf`.
fn cavity_evidence(diff: Gaussian, margin: f64, tie: bool) -> f64 {
    let sigma = diff.sigma();
    if tie {
        // (-margin, margin) is symmetric about 0, so N(mu, σ²) and N(-mu, σ²)
        // give it the same mass. With a non-negative mean, the interval lies in
        // the lower tail (or centre), where both CDF values stay accurate.
        let mu = diff.mu().abs();
        cdf(margin, mu, sigma) - cdf(-margin, mu, sigma)
    } else {
        // P(diff > margin) = P(-diff < -margin), with -diff ~ N(-mu, σ²).
        cdf(-margin, -diff.mu(), sigma)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::factor::VarStore;

    /// Allocates one diff variable with the given mean and std-dev, plus a
    /// fresh factor on it.
    fn fresh(mu: f64, sigma: f64, margin: f64, tie: bool) -> (VarStore, VarId, TruncFactor) {
        let mut vars = VarStore::new();
        let diff = vars.alloc(Gaussian::from_ms(mu, sigma));
        let factor = TruncFactor::new(diff, margin, tie);
        (vars, diff, factor)
    }

    #[test]
    fn idempotent_after_convergence() {
        // After 20 updates the reported delta should have collapsed to ~0.
        let (mut vars, _, mut f) = fresh(2.0, 3.0, 0.0, false);
        let mut last = f.propagate(&mut vars);
        for _ in 1..20 {
            last = f.propagate(&mut vars);
        }
        assert!(last.0 < 1e-10, "expected converged delta, got {}", last.0);
        assert!(last.1 < 1e-10);
    }

    #[test]
    fn evidence_cached_on_first_propagate() {
        let (mut vars, _, mut f) = fresh(2.0, 3.0, 0.0, false);
        assert!(f.evidence_cached.is_none());

        f.propagate(&mut vars);
        assert!(f.evidence_cached.is_some());
        let first = f.evidence_cached.unwrap();
        // P(diff > 0) for diff ~ N(2, 9) is roughly 0.748.
        assert!(first > 0.7);
        assert!(first < 0.8);

        // A later propagation must leave the cached value alone.
        f.propagate(&mut vars);
        assert_eq!(f.evidence_cached.unwrap(), first);
    }

    #[test]
    fn tie_evidence_uses_two_sided() {
        let (mut vars, _, mut f) = fresh(0.0, 2.0, 1.0, true);
        f.propagate(&mut vars);
        // diff ~ N(0, 4), margin 1: P(-1 < diff < 1) is roughly 0.383.
        let ev = f.evidence_cached.unwrap();
        assert!(ev > 0.35 && ev < 0.42);
    }

    #[test]
    fn propagate_with_alpha_one_matches_undamped_propagate() {
        let (mut vars_a, diff_a, mut f_a) = fresh(2.0, 3.0, 0.0, false);
        let (mut vars_b, diff_b, mut f_b) = fresh(2.0, 3.0, 0.0, false);

        let delta_a = f_a.propagate(&mut vars_a);
        let delta_b = f_b.propagate_with_alpha(&mut vars_b, 1.0);
        assert_eq!(delta_a, delta_b);

        let (result_a, result_b) = (vars_a.get(diff_a), vars_b.get(diff_b));
        assert_eq!(result_a.pi(), result_b.pi());
        assert_eq!(result_a.tau(), result_b.tau());

        assert_eq!(f_a.msg.pi(), f_b.msg.pi());
        assert_eq!(f_a.msg.tau(), f_b.msg.tau());
    }

    #[test]
    fn propagate_with_alpha_half_blends_msg_in_natural_params() {
        // Reference run: undamped, recording the message before and after.
        let (mut vars_full, _, mut f_full) = fresh(2.0, 3.0, 0.0, false);
        let (initial_msg_pi, initial_msg_tau) = (f_full.msg.pi(), f_full.msg.tau());
        f_full.propagate(&mut vars_full);
        let (undamped_msg_pi, undamped_msg_tau) = (f_full.msg.pi(), f_full.msg.tau());

        // Damped run at α = 0.5 from an identical starting point.
        let (mut vars_half, _, mut f_half) = fresh(2.0, 3.0, 0.0, false);
        f_half.propagate_with_alpha(&mut vars_half, 0.5);

        let expected_pi = 0.5 * undamped_msg_pi + 0.5 * initial_msg_pi;
        let expected_tau = 0.5 * undamped_msg_tau + 0.5 * initial_msg_tau;
        assert!((f_half.msg.pi() - expected_pi).abs() < 1e-12);
        assert!((f_half.msg.tau() - expected_tau).abs() < 1e-12);
    }
}