Added delta and more tests.

This commit is contained in:
2018-10-26 08:57:58 +02:00
parent 7a378c26b0
commit c471ef3399
2 changed files with 99 additions and 15 deletions

View File

@@ -49,13 +49,29 @@ impl Variable {
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
}
pub fn update_value(&mut self, factor: usize, value: Gaussian) {
pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 {
let old = self.factors[&factor];
self.factors.insert(factor, value * old / self.value);
let pi_delta = (self.value.pi() - value.pi()).abs();
let delta = if !pi_delta.is_finite() {
0.0
} else {
let pi_delta = pi_delta.sqrt();
let tau_delta = (self.value.tau() - value.tau()).abs();
if pi_delta > tau_delta {
pi_delta
} else {
tau_delta
}
};
self.value = value;
// debug!("Variable::value value={:?}", self.value);
delta
}
pub fn get_value(&self) -> Gaussian {
@@ -67,8 +83,6 @@ impl Variable {
self.factors.insert(factor, message);
self.value = self.value / old * message;
// debug!("Variable::message value={:?}", self.value);
}
pub fn get_message(&self, factor: usize) -> Gaussian {
@@ -344,7 +358,7 @@ impl TruncateFactor {
}
}
pub fn update(&self, variable_arena: &mut VariableArena) {
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
let (x, fx) = variable_arena
.get(self.variable)
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
@@ -376,6 +390,6 @@ impl TruncateFactor {
variable_arena
.get_mut(self.variable)
.unwrap()
.update_value(self.id, gaussian);
.update_value(self.id, gaussian)
}
}