feat(gaussian): add damp_natural helper for EP damping
Computes α·new + (1−α)·self in natural-parameter space. Will be used by TruncFactor and MarginFactor to support opt-in EP damping via ConvergenceOptions::alpha.
This commit is contained in:
@@ -96,6 +96,18 @@ impl Gaussian {
|
||||
let var = self.sigma().powi(2) + variance_delta;
|
||||
Self::from_ms(self.mu(), var.sqrt())
|
||||
}
|
||||
|
||||
/// EP damping in natural-parameter space: `α·new + (1−α)·self`.
|
||||
///
|
||||
/// Used by within-game inference to stabilise oscillating fixed-point
|
||||
/// loops on hard graphs. `alpha = 1.0` returns `new` exactly;
|
||||
/// `alpha < 1.0` shrinks each per-step update.
|
||||
pub fn damp_natural(self, new: Gaussian, alpha: f64) -> Gaussian {
|
||||
Gaussian::from_natural(
|
||||
alpha * new.pi() + (1.0 - alpha) * self.pi(),
|
||||
alpha * new.tau() + (1.0 - alpha) * self.tau(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Gaussian {
|
||||
@@ -231,4 +243,33 @@ mod tests {
|
||||
assert!((r.pi() - expected_pi).abs() < 1e-15);
|
||||
assert!((r.tau() - expected_tau).abs() < 1e-15);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn damp_natural_alpha_one_returns_new() {
|
||||
let old = Gaussian::from_ms(1.0, 2.0);
|
||||
let new = Gaussian::from_ms(5.0, 0.5);
|
||||
let damped = old.damp_natural(new, 1.0);
|
||||
assert_eq!(damped.pi(), new.pi());
|
||||
assert_eq!(damped.tau(), new.tau());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn damp_natural_alpha_zero_returns_self() {
|
||||
let old = Gaussian::from_ms(1.0, 2.0);
|
||||
let new = Gaussian::from_ms(5.0, 0.5);
|
||||
let damped = old.damp_natural(new, 0.0);
|
||||
assert_eq!(damped.pi(), old.pi());
|
||||
assert_eq!(damped.tau(), old.tau());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn damp_natural_alpha_half_is_midpoint_in_natural_params() {
|
||||
let old = Gaussian::from_ms(1.0, 2.0);
|
||||
let new = Gaussian::from_ms(5.0, 0.5);
|
||||
let damped = old.damp_natural(new, 0.5);
|
||||
let expected_pi = 0.5 * new.pi() + 0.5 * old.pi();
|
||||
let expected_tau = 0.5 * new.tau() + 0.5 * old.tau();
|
||||
assert!((damped.pi() - expected_pi).abs() < 1e-12);
|
||||
assert!((damped.tau() - expected_tau).abs() < 1e-12);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user