feat(factor): add TruncFactor::propagate_with_alpha for EP damping
Inherent method that applies α-damping to the outgoing message via Gaussian::damp_natural. The Factor trait impl delegates with α=1.0, preserving today's behavior bit-equal. Variable write switched from `trunc` to `cavity * damped` — algebraically identical when α=1.0 (cavity * new_msg = trunc by construction); reflects partial-update math when α<1.0.
This commit is contained in:
+64
-11
@@ -33,29 +33,37 @@ impl TruncFactor {
|
||||
}
|
||||
}
|
||||
|
||||
impl Factor for TruncFactor {
|
||||
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
|
||||
impl TruncFactor {
|
||||
/// Propagate this factor's message, optionally damping the update in
|
||||
/// natural-parameter space. `alpha = 1.0` matches `Factor::propagate`
|
||||
/// exactly; `alpha < 1.0` writes `α·new_msg + (1−α)·old_msg`.
|
||||
pub(crate) fn propagate_with_alpha(&mut self, vars: &mut VarStore, alpha: f64) -> (f64, f64) {
|
||||
let marginal = vars.get(self.diff);
|
||||
// Cavity: marginal divided by our outgoing message.
|
||||
let cavity = marginal / self.msg;
|
||||
|
||||
// First-time-only: cache the evidence contribution from the cavity.
|
||||
if self.evidence_cached.is_none() {
|
||||
self.evidence_cached = Some(cavity_evidence(cavity, self.margin, self.tie));
|
||||
}
|
||||
|
||||
// Apply the truncation approximation to the cavity.
|
||||
let trunc = approx(cavity, self.margin, self.tie);
|
||||
|
||||
// New outgoing message such that cavity * new_msg = trunc.
|
||||
let new_msg = trunc / cavity;
|
||||
|
||||
let damped = self.msg.damp_natural(new_msg, alpha);
|
||||
let old_msg = self.msg;
|
||||
self.msg = new_msg;
|
||||
self.msg = damped;
|
||||
|
||||
// Update the marginal: marginal_new = cavity * new_msg = trunc.
|
||||
vars.set(self.diff, trunc);
|
||||
// marginal_new = cavity * stored_msg. With alpha = 1.0 this equals
|
||||
// `trunc` (since cavity * new_msg = trunc by construction); with
|
||||
// alpha < 1.0 it reflects the partially-applied update.
|
||||
vars.set(self.diff, cavity * damped);
|
||||
|
||||
old_msg.delta(new_msg)
|
||||
old_msg.delta(damped)
|
||||
}
|
||||
}
|
||||
|
||||
impl Factor for TruncFactor {
|
||||
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
|
||||
self.propagate_with_alpha(vars, 1.0)
|
||||
}
|
||||
|
||||
fn log_evidence(&self, _vars: &VarStore) -> f64 {
|
||||
@@ -127,4 +135,49 @@ mod tests {
|
||||
let ev = f.evidence_cached.unwrap();
|
||||
assert!(ev > 0.35 && ev < 0.42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn propagate_with_alpha_one_matches_undamped_propagate() {
|
||||
let mut vars_a = VarStore::new();
|
||||
let diff_a = vars_a.alloc(Gaussian::from_ms(2.0, 3.0));
|
||||
let mut f_a = TruncFactor::new(diff_a, 0.0, false);
|
||||
let delta_a = f_a.propagate(&mut vars_a);
|
||||
let result_a = vars_a.get(diff_a);
|
||||
|
||||
let mut vars_b = VarStore::new();
|
||||
let diff_b = vars_b.alloc(Gaussian::from_ms(2.0, 3.0));
|
||||
let mut f_b = TruncFactor::new(diff_b, 0.0, false);
|
||||
let delta_b = f_b.propagate_with_alpha(&mut vars_b, 1.0);
|
||||
let result_b = vars_b.get(diff_b);
|
||||
|
||||
assert_eq!(result_a.pi(), result_b.pi());
|
||||
assert_eq!(result_a.tau(), result_b.tau());
|
||||
assert_eq!(delta_a, delta_b);
|
||||
assert_eq!(f_a.msg.pi(), f_b.msg.pi());
|
||||
assert_eq!(f_a.msg.tau(), f_b.msg.tau());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn propagate_with_alpha_half_blends_msg_in_natural_params() {
|
||||
// Run undamped to capture (initial_msg, undamped_new_msg).
|
||||
let mut vars_full = VarStore::new();
|
||||
let diff_full = vars_full.alloc(Gaussian::from_ms(2.0, 3.0));
|
||||
let mut f_full = TruncFactor::new(diff_full, 0.0, false);
|
||||
let initial_msg_pi = f_full.msg.pi();
|
||||
let initial_msg_tau = f_full.msg.tau();
|
||||
f_full.propagate(&mut vars_full);
|
||||
let undamped_msg_pi = f_full.msg.pi();
|
||||
let undamped_msg_tau = f_full.msg.tau();
|
||||
|
||||
// Run damped at α = 0.5 from the same initial state.
|
||||
let mut vars_half = VarStore::new();
|
||||
let diff_half = vars_half.alloc(Gaussian::from_ms(2.0, 3.0));
|
||||
let mut f_half = TruncFactor::new(diff_half, 0.0, false);
|
||||
f_half.propagate_with_alpha(&mut vars_half, 0.5);
|
||||
|
||||
let expected_pi = 0.5 * undamped_msg_pi + 0.5 * initial_msg_pi;
|
||||
let expected_tau = 0.5 * undamped_msg_tau + 0.5 * initial_msg_tau;
|
||||
assert!((f_half.msg.pi() - expected_pi).abs() < 1e-12);
|
||||
assert!((f_half.msg.tau() - expected_tau).abs() < 1e-12);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user