From fcfe0ffe37201f567bc63eed87ffd6d4bfc48453 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Fri, 8 May 2026 15:02:09 +0200 Subject: [PATCH] feat(factor): add TruncFactor::propagate_with_alpha for EP damping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Inherent method that applies α-damping to the outgoing message via Gaussian::damp_natural. The Factor trait impl delegates with α=1.0, preserving today's behavior bit-equal. Variable write switched from `trunc` to `cavity * damped` — algebraically identical when α=1.0 (cavity * new_msg = trunc by construction); reflects partial-update math when α<1.0. --- src/factor/trunc.rs | 75 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 11 deletions(-) diff --git a/src/factor/trunc.rs b/src/factor/trunc.rs index 6090a39..4b1aaa2 100644 --- a/src/factor/trunc.rs +++ b/src/factor/trunc.rs @@ -33,29 +33,37 @@ impl TruncFactor { } } -impl Factor for TruncFactor { - fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) { +impl TruncFactor { + /// Propagate this factor's message, optionally damping the update in + /// natural-parameter space. `alpha = 1.0` matches `Factor::propagate` + /// exactly; `alpha < 1.0` writes `α·new_msg + (1−α)·old_msg`. + pub(crate) fn propagate_with_alpha(&mut self, vars: &mut VarStore, alpha: f64) -> (f64, f64) { let marginal = vars.get(self.diff); - // Cavity: marginal divided by our outgoing message. let cavity = marginal / self.msg; - // First-time-only: cache the evidence contribution from the cavity. if self.evidence_cached.is_none() { self.evidence_cached = Some(cavity_evidence(cavity, self.margin, self.tie)); } - // Apply the truncation approximation to the cavity. let trunc = approx(cavity, self.margin, self.tie); - - // New outgoing message such that cavity * new_msg = trunc. let new_msg = trunc / cavity; + + let damped = self.msg.damp_natural(new_msg, alpha); let old_msg = self.msg; - self.msg = new_msg; + self.msg = damped; - // Update the marginal: marginal_new = cavity * new_msg = trunc. - vars.set(self.diff, trunc); + // marginal_new = cavity * stored_msg. With alpha = 1.0 this equals + // `trunc` (since cavity * new_msg = trunc by construction); with + // alpha < 1.0 it reflects the partially-applied update. + vars.set(self.diff, cavity * damped); - old_msg.delta(new_msg) + old_msg.delta(damped) + } +} + +impl Factor for TruncFactor { + fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) { + self.propagate_with_alpha(vars, 1.0) } fn log_evidence(&self, _vars: &VarStore) -> f64 { @@ -127,4 +135,49 @@ mod tests { let ev = f.evidence_cached.unwrap(); assert!(ev > 0.35 && ev < 0.42); } + + #[test] + fn propagate_with_alpha_one_matches_undamped_propagate() { + let mut vars_a = VarStore::new(); + let diff_a = vars_a.alloc(Gaussian::from_ms(2.0, 3.0)); + let mut f_a = TruncFactor::new(diff_a, 0.0, false); + let delta_a = f_a.propagate(&mut vars_a); + let result_a = vars_a.get(diff_a); + + let mut vars_b = VarStore::new(); + let diff_b = vars_b.alloc(Gaussian::from_ms(2.0, 3.0)); + let mut f_b = TruncFactor::new(diff_b, 0.0, false); + let delta_b = f_b.propagate_with_alpha(&mut vars_b, 1.0); + let result_b = vars_b.get(diff_b); + + assert_eq!(result_a.pi(), result_b.pi()); + assert_eq!(result_a.tau(), result_b.tau()); + assert_eq!(delta_a, delta_b); + assert_eq!(f_a.msg.pi(), f_b.msg.pi()); + assert_eq!(f_a.msg.tau(), f_b.msg.tau()); + } + + #[test] + fn propagate_with_alpha_half_blends_msg_in_natural_params() { + // Run undamped to capture (initial_msg, undamped_new_msg). + let mut vars_full = VarStore::new(); + let diff_full = vars_full.alloc(Gaussian::from_ms(2.0, 3.0)); + let mut f_full = TruncFactor::new(diff_full, 0.0, false); + let initial_msg_pi = f_full.msg.pi(); + let initial_msg_tau = f_full.msg.tau(); + f_full.propagate(&mut vars_full); + let undamped_msg_pi = f_full.msg.pi(); + let undamped_msg_tau = f_full.msg.tau(); + + // Run damped at α = 0.5 from the same initial state. + let mut vars_half = VarStore::new(); + let diff_half = vars_half.alloc(Gaussian::from_ms(2.0, 3.0)); + let mut f_half = TruncFactor::new(diff_half, 0.0, false); + f_half.propagate_with_alpha(&mut vars_half, 0.5); + + let expected_pi = 0.5 * undamped_msg_pi + 0.5 * initial_msg_pi; + let expected_tau = 0.5 * undamped_msg_tau + 0.5 * initial_msg_tau; + assert!((f_half.msg.pi() - expected_pi).abs() < 1e-12); + assert!((f_half.msg.tau() - expected_tau).abs() < 1e-12); + } }