Added delta and more tests.
This commit is contained in:
@@ -49,13 +49,29 @@ impl Variable {
|
||||
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
|
||||
}
|
||||
|
||||
pub fn update_value(&mut self, factor: usize, value: Gaussian) {
|
||||
pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 {
|
||||
let old = self.factors[&factor];
|
||||
|
||||
self.factors.insert(factor, value * old / self.value);
|
||||
|
||||
let pi_delta = (self.value.pi() - value.pi()).abs();
|
||||
|
||||
let delta = if !pi_delta.is_finite() {
|
||||
0.0
|
||||
} else {
|
||||
let pi_delta = pi_delta.sqrt();
|
||||
let tau_delta = (self.value.tau() - value.tau()).abs();
|
||||
|
||||
if pi_delta > tau_delta {
|
||||
pi_delta
|
||||
} else {
|
||||
tau_delta
|
||||
}
|
||||
};
|
||||
|
||||
self.value = value;
|
||||
|
||||
// debug!("Variable::value value={:?}", self.value);
|
||||
delta
|
||||
}
|
||||
|
||||
pub fn get_value(&self) -> Gaussian {
|
||||
@@ -67,8 +83,6 @@ impl Variable {
|
||||
|
||||
self.factors.insert(factor, message);
|
||||
self.value = self.value / old * message;
|
||||
|
||||
// debug!("Variable::message value={:?}", self.value);
|
||||
}
|
||||
|
||||
pub fn get_message(&self, factor: usize) -> Gaussian {
|
||||
@@ -344,7 +358,7 @@ impl TruncateFactor {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&self, variable_arena: &mut VariableArena) {
|
||||
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
|
||||
let (x, fx) = variable_arena
|
||||
.get(self.variable)
|
||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
||||
@@ -376,6 +390,6 @@ impl TruncateFactor {
|
||||
variable_arena
|
||||
.get_mut(self.variable)
|
||||
.unwrap()
|
||||
.update_value(self.id, gaussian);
|
||||
.update_value(self.id, gaussian)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user