T0 + T1 + T2: engine redesign through new API surface #1
@@ -1,15 +1,23 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
N_INF,
|
N_INF, approx, cdf,
|
||||||
factor::{Factor, VarId, VarStore},
|
factor::{Factor, VarId, VarStore},
|
||||||
gaussian::Gaussian,
|
gaussian::Gaussian,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// EP truncation factor on a diff variable.
|
||||||
|
///
|
||||||
|
/// Implements the rectified-Gaussian approximation that turns a diff
|
||||||
|
/// distribution into a "this team rank-beats that team" or "tied" likelihood.
|
||||||
|
/// Stores its outgoing message to the diff variable so the cavity computation
|
||||||
|
/// produces the correct EP message on each propagation.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct TruncFactor {
|
pub(crate) struct TruncFactor {
|
||||||
pub(crate) diff: VarId,
|
pub(crate) diff: VarId,
|
||||||
pub(crate) margin: f64,
|
pub(crate) margin: f64,
|
||||||
pub(crate) tie: bool,
|
pub(crate) tie: bool,
|
||||||
|
/// Outgoing message to the diff variable (initial: N_INF, the EP identity).
|
||||||
pub(crate) msg: Gaussian,
|
pub(crate) msg: Gaussian,
|
||||||
|
/// Cached evidence (linear, not log) computed from the cavity on first propagation.
|
||||||
pub(crate) evidence_cached: Option<f64>,
|
pub(crate) evidence_cached: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,7 +34,97 @@ impl TruncFactor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Factor for TruncFactor {
|
impl Factor for TruncFactor {
|
||||||
fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) {
|
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
|
||||||
unimplemented!("TruncFactor stub — implemented in Task 6")
|
let marginal = vars.get(self.diff);
|
||||||
|
// Cavity: marginal divided by our outgoing message.
|
||||||
|
let cavity = marginal / self.msg;
|
||||||
|
|
||||||
|
// First-time-only: cache the evidence contribution from the cavity.
|
||||||
|
if self.evidence_cached.is_none() {
|
||||||
|
self.evidence_cached = Some(cavity_evidence(cavity, self.margin, self.tie));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply the truncation approximation to the cavity.
|
||||||
|
let trunc = approx(cavity, self.margin, self.tie);
|
||||||
|
|
||||||
|
// New outgoing message such that cavity * new_msg = trunc.
|
||||||
|
let new_msg = trunc / cavity;
|
||||||
|
let old_msg = self.msg;
|
||||||
|
self.msg = new_msg;
|
||||||
|
|
||||||
|
// Update the marginal: marginal_new = cavity * new_msg = trunc.
|
||||||
|
vars.set(self.diff, trunc);
|
||||||
|
|
||||||
|
old_msg.delta(new_msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn log_evidence(&self, _vars: &VarStore) -> f64 {
|
||||||
|
self.evidence_cached.unwrap_or(1.0).ln()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// P(diff > margin) for non-tie, P(|diff| < margin) for tie.
|
||||||
|
fn cavity_evidence(diff: Gaussian, margin: f64, tie: bool) -> f64 {
|
||||||
|
if tie {
|
||||||
|
cdf(margin, diff.mu(), diff.sigma()) - cdf(-margin, diff.mu(), diff.sigma())
|
||||||
|
} else {
|
||||||
|
1.0 - cdf(margin, diff.mu(), diff.sigma())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::factor::VarStore;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn idempotent_after_convergence() {
|
||||||
|
// After enough iterations, propagate should return ~0 delta.
|
||||||
|
let mut vars = VarStore::new();
|
||||||
|
let diff = vars.alloc(Gaussian::from_ms(2.0, 3.0));
|
||||||
|
|
||||||
|
let mut f = TruncFactor::new(diff, 0.0, false);
|
||||||
|
|
||||||
|
// Propagate many times; delta should drop toward 0.
|
||||||
|
let mut last = (f64::INFINITY, f64::INFINITY);
|
||||||
|
for _ in 0..20 {
|
||||||
|
last = f.propagate(&mut vars);
|
||||||
|
}
|
||||||
|
assert!(last.0 < 1e-10, "expected converged delta, got {}", last.0);
|
||||||
|
assert!(last.1 < 1e-10);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn evidence_cached_on_first_propagate() {
|
||||||
|
let mut vars = VarStore::new();
|
||||||
|
let diff = vars.alloc(Gaussian::from_ms(2.0, 3.0));
|
||||||
|
|
||||||
|
let mut f = TruncFactor::new(diff, 0.0, false);
|
||||||
|
assert!(f.evidence_cached.is_none());
|
||||||
|
|
||||||
|
f.propagate(&mut vars);
|
||||||
|
assert!(f.evidence_cached.is_some());
|
||||||
|
let first = f.evidence_cached.unwrap();
|
||||||
|
|
||||||
|
// Evidence should be P(diff > 0) for diff ~ N(2, 9) ≈ 0.748
|
||||||
|
assert!(first > 0.7);
|
||||||
|
assert!(first < 0.8);
|
||||||
|
|
||||||
|
// Subsequent propagations don't change it.
|
||||||
|
f.propagate(&mut vars);
|
||||||
|
assert_eq!(f.evidence_cached.unwrap(), first);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tie_evidence_uses_two_sided() {
|
||||||
|
let mut vars = VarStore::new();
|
||||||
|
let diff = vars.alloc(Gaussian::from_ms(0.0, 2.0));
|
||||||
|
|
||||||
|
let mut f = TruncFactor::new(diff, 1.0, true);
|
||||||
|
f.propagate(&mut vars);
|
||||||
|
|
||||||
|
// For diff ~ N(0, 4), tie=true with margin=1: P(-1 < diff < 1) ≈ 0.383
|
||||||
|
let ev = f.evidence_cached.unwrap();
|
||||||
|
assert!(ev > 0.35 && ev < 0.42);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ fn compute_margin(p_draw: f64, sd: f64) -> f64 {
|
|||||||
ppf(0.5 - p_draw / 2.0, 0.0, sd).abs()
|
ppf(0.5 - p_draw / 2.0, 0.0, sd).abs()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cdf(x: f64, mu: f64, sigma: f64) -> f64 {
|
pub(crate) fn cdf(x: f64, mu: f64, sigma: f64) -> f64 {
|
||||||
let z = -(x - mu) / (sigma * SQRT_2);
|
let z = -(x - mu) / (sigma * SQRT_2);
|
||||||
|
|
||||||
0.5 * erfc(z)
|
0.5 * erfc(z)
|
||||||
|
|||||||
Reference in New Issue
Block a user