It works!

This commit is contained in:
2018-10-24 11:12:34 +02:00
parent b39c446b37
commit be74c1eac7
4 changed files with 190 additions and 68 deletions

View File

@@ -23,7 +23,7 @@ impl VariableArena {
let index = self.variables.len();
self.variables.push(Variable {
value: Gaussian::from_precision(0.0, 0.0),
value: Gaussian::from_pi_tau(0.0, 0.0),
factors: HashMap::new(),
});
@@ -46,26 +46,19 @@ pub struct Variable {
impl Variable {
pub fn attach_factor(&mut self, factor: usize) {
self.factors.insert(factor, Gaussian::from_precision(0.0, 0.0));
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
}
pub fn update_value(&mut self, factor: usize, value: Gaussian) {
let old = self.factors[&factor];
let intermediate = value * old;
let new = intermediate / self.value;
self.factors.insert(factor, intermediate / self.value);
/*
println!("update_value: old={}, value={:?}, new={:?}, self.value={:?}",
render_g(old),
value,
new,
self.value);
*/
self.value = new;
self.value = value;
}
pub fn get_value(&self) -> Gaussian {
pub fn get_value(&self) -> Gaussian {
self.value
}
@@ -77,8 +70,6 @@ impl Variable {
self.value = value;
println!("update_message: old={:?} msg={:?}, new={:?}", old, message, value);
self.factors.insert(factor, message);
}
@@ -110,8 +101,6 @@ impl PriorFactor {
}
pub fn start(&self, variable_arena: &mut VariableArena) {
println!("PriorFactor: variable.id={}, msg={:?}", self.variable.index, self.gaussian);
variable_arena.get_mut(self.variable).unwrap().update_value(self.id, self.gaussian);
}
}
@@ -153,11 +142,9 @@ impl LikelihoodFactor {
.map(|variable| variable.get_message(self.id))
.unwrap();
let a = 1.0 / (1.0 + self.variance * (x.precision_mean() - fx.precision_mean()));
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
let gaussian = Gaussian::from_precision(a * (x.precision_mean() - fx.precision_mean()), a * (x.precision() - fx.precision()));
println!("LikelihoodFactor: mean.id={}, msg={:?}", self.mean.index, gaussian);
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
variable_arena.get_mut(self.mean).unwrap().update_message(self.id, gaussian);
}
@@ -173,12 +160,9 @@ impl LikelihoodFactor {
.map(|variable| variable.get_message(self.id))
.unwrap();
let a = 1.0 / (1.0 + self.variance * (y.precision_mean() - fy.precision_mean()));
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
let gaussian = Gaussian::from_precision(a * (y.precision_mean() - fy.precision_mean()), a * (y.precision() - fy.precision()));
println!("LikelihoodFactor: value.id={}, msg={:?}", self.value.index, gaussian);
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
variable_arena.get_mut(self.value).unwrap().update_message(self.id, gaussian);
}
@@ -240,13 +224,7 @@ impl SumFactor {
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
if variable == self.sum {
println!("SumFactor: sum.id={}, msg={:?}", variable.index, gaussian);
} else {
println!("SumFactor: term.id={}, msg={:?}", variable.index, gaussian);
}
variable_arena.get_mut(variable).unwrap().update_value(self.id, gaussian);
variable_arena.get_mut(variable).unwrap().update_message(self.id, gaussian);
}
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
@@ -298,16 +276,16 @@ impl SumFactor {
v[index] = self.sum;
for term in v {
for term in &v {
let value = variable_arena
.get(term)
.get(*term)
.map(|variable| variable.get_value())
.unwrap();
y.push(value);
let value = variable_arena
.get(term)
.get(*term)
.map(|variable| variable.get_message(self.id))
.unwrap();