From aacaa60baa62700224daf2ead51cf6e5b581c923 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Fri, 8 May 2026 15:03:45 +0200 Subject: [PATCH] feat(factor): add MarginFactor::propagate_with_alpha for EP damping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirrors TruncFactor: inherent damped-propagate method, trait impl delegates with α=1.0. Existing goldens unchanged because cavity*new_msg equals the previous marginal write when α=1.0. --- src/factor/margin.rs | 66 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 6 deletions(-) diff --git a/src/factor/margin.rs b/src/factor/margin.rs index aa27fc1..4b67124 100644 --- a/src/factor/margin.rs +++ b/src/factor/margin.rs @@ -32,8 +32,11 @@ impl MarginFactor { } } -impl Factor for MarginFactor { - fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) { +impl MarginFactor { + /// Propagate this factor's message, optionally damping the update in + /// natural-parameter space. `alpha = 1.0` matches `Factor::propagate` + /// exactly; `alpha < 1.0` writes `α·new_msg + (1−α)·old_msg`. + pub(crate) fn propagate_with_alpha(&mut self, vars: &mut VarStore, alpha: f64) -> (f64, f64) { let marginal = vars.get(self.diff); let cavity = marginal / self.msg; @@ -42,12 +45,18 @@ impl Factor for MarginFactor { } let new_msg = Gaussian::from_ms(self.m_obs, self.sigma); - let new_marginal = cavity * new_msg; + let damped = self.msg.damp_natural(new_msg, alpha); let old_msg = self.msg; - self.msg = new_msg; - vars.set(self.diff, new_marginal); + self.msg = damped; + vars.set(self.diff, cavity * damped); - old_msg.delta(new_msg) + old_msg.delta(damped) + } +} + +impl Factor for MarginFactor { + fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) { + self.propagate_with_alpha(vars, 1.0) } fn log_evidence(&self, _vars: &VarStore) -> f64 { @@ -120,4 +129,49 @@ mod tests { let logz = f.log_evidence(&vars); assert!((logz - (-3.062235327364623)).abs() < 1e-10); } + + #[test] + fn propagate_with_alpha_one_matches_undamped_propagate() { + let mut vars_a = VarStore::new(); + let diff_a = vars_a.alloc(Gaussian::from_ms(0.0, 6.0)); + let mut f_a = MarginFactor::new(diff_a, 5.0, 1.0); + let delta_a = f_a.propagate(&mut vars_a); + let result_a = vars_a.get(diff_a); + + let mut vars_b = VarStore::new(); + let diff_b = vars_b.alloc(Gaussian::from_ms(0.0, 6.0)); + let mut f_b = MarginFactor::new(diff_b, 5.0, 1.0); + let delta_b = f_b.propagate_with_alpha(&mut vars_b, 1.0); + let result_b = vars_b.get(diff_b); + + assert_eq!(result_a.pi(), result_b.pi()); + assert_eq!(result_a.tau(), result_b.tau()); + assert_eq!(delta_a, delta_b); + assert_eq!(f_a.msg.pi(), f_b.msg.pi()); + assert_eq!(f_a.msg.tau(), f_b.msg.tau()); + } + + #[test] + fn propagate_with_alpha_half_blends_msg_in_natural_params() { + // Run undamped to capture (initial_msg, undamped_new_msg). + let mut vars_full = VarStore::new(); + let diff_full = vars_full.alloc(Gaussian::from_ms(0.0, 6.0)); + let mut f_full = MarginFactor::new(diff_full, 5.0, 1.0); + let initial_msg_pi = f_full.msg.pi(); + let initial_msg_tau = f_full.msg.tau(); + f_full.propagate(&mut vars_full); + let undamped_msg_pi = f_full.msg.pi(); + let undamped_msg_tau = f_full.msg.tau(); + + // Run damped at α = 0.5 from the same initial state. + let mut vars_half = VarStore::new(); + let diff_half = vars_half.alloc(Gaussian::from_ms(0.0, 6.0)); + let mut f_half = MarginFactor::new(diff_half, 5.0, 1.0); + f_half.propagate_with_alpha(&mut vars_half, 0.5); + + let expected_pi = 0.5 * undamped_msg_pi + 0.5 * initial_msg_pi; + let expected_tau = 0.5 * undamped_msg_tau + 0.5 * initial_msg_tau; + assert!((f_half.msg.pi() - expected_pi).abs() < 1e-12); + assert!((f_half.msg.tau() - expected_tau).abs() < 1e-12); + } }