Revert back.

This commit is contained in:
2018-10-24 07:11:38 +02:00
parent 9cf91fbdf8
commit b39c446b37
5 changed files with 112 additions and 137 deletions

View File

@@ -3,7 +3,7 @@ use std::collections::HashMap;
use gaussian::Gaussian;
use math;
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct VariableId {
index: usize,
}
@@ -23,7 +23,7 @@ impl VariableArena {
let index = self.variables.len();
self.variables.push(Variable {
value: Gaussian::with_pi_tau(0.0, 0.0),
value: Gaussian::from_precision(0.0, 0.0),
factors: HashMap::new(),
});
@@ -46,16 +46,23 @@ pub struct Variable {
impl Variable {
pub fn attach_factor(&mut self, factor: usize) {
self.factors.insert(factor, Gaussian::new());
self.factors.insert(factor, Gaussian::from_precision(0.0, 0.0));
}
pub fn update_value(&mut self, factor: usize, value: Gaussian) {
let old = self.factors[&factor];
let intermediate = value * old;
let value = intermediate / self.value;
let new = intermediate / self.value;
self.value = value;
/*
println!("update_value: old={}, value={:?}, new={:?}, self.value={:?}",
render_g(old),
value,
new,
self.value);
*/
self.value = new;
}
pub fn get_value(&self) -> Gaussian {
@@ -70,6 +77,8 @@ impl Variable {
self.value = value;
println!("update_message: old={:?} msg={:?}, new={:?}", old, message, value);
self.factors.insert(factor, message);
}
@@ -91,9 +100,7 @@ impl PriorFactor {
variable: VariableId,
gaussian: Gaussian,
) -> PriorFactor {
if let Some(variable) = variable_arena.get_mut(variable) {
variable.attach_factor(id);
}
variable_arena.get_mut(variable).unwrap().attach_factor(id);
PriorFactor {
id,
@@ -103,9 +110,9 @@ impl PriorFactor {
}
pub fn start(&self, variable_arena: &mut VariableArena) {
if let Some(variable) = variable_arena.get_mut(self.variable) {
variable.update_value(self.id, self.gaussian);
}
println!("PriorFactor: variable.id={}, msg={:?}", self.variable.index, self.gaussian);
variable_arena.get_mut(self.variable).unwrap().update_value(self.id, self.gaussian);
}
}
@@ -124,13 +131,8 @@ impl LikelihoodFactor {
value: VariableId,
variance: f64,
) -> LikelihoodFactor {
if let Some(variable) = variable_arena.get_mut(mean) {
variable.attach_factor(id);
}
if let Some(variable) = variable_arena.get_mut(value) {
variable.attach_factor(id);
}
variable_arena.get_mut(mean).unwrap().attach_factor(id);
variable_arena.get_mut(value).unwrap().attach_factor(id);
LikelihoodFactor {
id,
@@ -151,13 +153,13 @@ impl LikelihoodFactor {
.map(|variable| variable.get_message(self.id))
.unwrap();
let a = 1.0 / (1.0 + self.variance * (x.pi - fx.pi));
let a = 1.0 / (1.0 + self.variance * (x.precision_mean() - fx.precision_mean()));
let gaussian = Gaussian::with_pi_tau(a * (x.pi - fx.pi), a * (x.tau - fx.tau));
let gaussian = Gaussian::from_precision(a * (x.precision_mean() - fx.precision_mean()), a * (x.precision() - fx.precision()));
if let Some(variable) = variable_arena.get_mut(self.mean) {
variable.update_message(self.id, gaussian);
}
println!("LikelihoodFactor: mean.id={}, msg={:?}", self.mean.index, gaussian);
variable_arena.get_mut(self.mean).unwrap().update_message(self.id, gaussian);
}
pub fn update_value(&self, variable_arena: &mut VariableArena) {
@@ -165,18 +167,20 @@ impl LikelihoodFactor {
.get(self.mean)
.map(|variable| variable.get_value())
.unwrap();
let fy = variable_arena
.get(self.mean)
.map(|variable| variable.get_message(self.id))
.unwrap();
let a = 1.0 / (1.0 + self.variance * (y.pi - fy.pi));
let a = 1.0 / (1.0 + self.variance * (y.precision_mean() - fy.precision_mean()));
let gaussian = Gaussian::with_pi_tau(a * (y.pi - fy.pi), a * (y.tau - fy.tau));
let gaussian = Gaussian::from_precision(a * (y.precision_mean() - fy.precision_mean()), a * (y.precision() - fy.precision()));
if let Some(variable) = variable_arena.get_mut(self.value) {
variable.update_message(self.id, gaussian);
}
println!("LikelihoodFactor: value.id={}, msg={:?}", self.value.index, gaussian);
variable_arena.get_mut(self.value).unwrap().update_message(self.id, gaussian);
}
}
@@ -195,14 +199,10 @@ impl SumFactor {
terms: Vec<VariableId>,
coeffs: Vec<f64>,
) -> SumFactor {
if let Some(variable) = variable_arena.get_mut(sum) {
variable.attach_factor(id);
}
variable_arena.get_mut(sum).unwrap().attach_factor(id);
for term in &terms {
if let Some(variable) = variable_arena.get_mut(*term) {
variable.attach_factor(id);
}
variable_arena.get_mut(*term).unwrap().attach_factor(id);
}
SumFactor {
@@ -231,18 +231,22 @@ impl SumFactor {
let gy = y[i];
let gfy = fy[i];
sum_pi += da.powi(2) / (gy.pi - gfy.pi);
sum_tau += da * (gy.tau - gfy.tau) / (gy.pi - gfy.pi);
sum_pi += da.powi(2) / (gy.pi() - gfy.pi());
sum_tau += da * (gy.tau() - gfy.tau()) / (gy.pi() - gfy.pi());
}
let new_pi = 1.0 / sum_pi;
let new_tau = new_pi * sum_tau;
let gaussian = Gaussian::with_pi_tau(new_pi, new_tau);
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
if let Some(variable) = variable_arena.get_mut(variable) {
variable.update_value(self.id, gaussian);
if variable == self.sum {
println!("SumFactor: sum.id={}, msg={:?}", variable.index, gaussian);
} else {
println!("SumFactor: term.id={}, msg={:?}", variable.index, gaussian);
}
variable_arena.get_mut(variable).unwrap().update_value(self.id, gaussian);
}
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
@@ -351,9 +355,7 @@ impl TruncateFactor {
epsilon: f64,
draw: bool,
) -> TruncateFactor {
if let Some(variable) = variable_arena.get_mut(variable) {
variable.attach_factor(id);
}
variable_arena.get_mut(variable).unwrap().attach_factor(id);
TruncateFactor {
id,
@@ -368,13 +370,15 @@ impl TruncateFactor {
.get(self.variable)
.map(|variable| variable.get_value())
.unwrap();
let fx = variable_arena
.get_mut(self.variable)
.map(|variable| variable.get_message(self.id))
.unwrap();
let c = x.pi - fx.pi;
let d = x.tau - fx.tau;
let c = x.pi() - fx.pi();
let d = x.tau() - fx.tau();
let sqrt_c = c.sqrt();
let t = d / sqrt_c;
@@ -388,10 +392,8 @@ impl TruncateFactor {
let m_w = 1.0 - w;
let gaussian = Gaussian::with_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
if let Some(variable) = variable_arena.get_mut(self.variable) {
variable.update_value(self.id, gaussian);
}
variable_arena.get_mut(self.variable).unwrap().update_value(self.id, gaussian);
}
}