This commit is contained in:
2018-10-24 20:16:16 +02:00
parent bcdabf9fbb
commit 4531fb3272

View File

@@ -131,10 +131,7 @@ impl LikelihoodFactor {
pub fn update_mean(&self, variable_arena: &mut VariableArena) { pub fn update_mean(&self, variable_arena: &mut VariableArena) {
let (x, fx) = variable_arena let (x, fx) = variable_arena
.get(self.value) .get(self.value)
.map(|variable| ( .map(|variable| (variable.get_value(), variable.get_message(self.id)))
variable.get_value(),
variable.get_message(self.id)
))
.unwrap(); .unwrap();
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
@@ -149,10 +146,7 @@ impl LikelihoodFactor {
pub fn update_value(&self, variable_arena: &mut VariableArena) { pub fn update_value(&self, variable_arena: &mut VariableArena) {
let (y, fy) = variable_arena let (y, fy) = variable_arena
.get(self.mean) .get(self.mean)
.map(|variable| ( .map(|variable| (variable.get_value(), variable.get_message(self.id)))
variable.get_value(),
variable.get_message(self.id)
))
.unwrap(); .unwrap();
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
@@ -202,14 +196,13 @@ impl SumFactor {
fy: Vec<Gaussian>, fy: Vec<Gaussian>,
a: &[f64], a: &[f64],
) { ) {
let (sum_pi, sum_tau) = a.iter().zip(y.iter().zip(fy.iter())) let (sum_pi, sum_tau) =
a.iter()
.zip(y.iter().zip(fy.iter()))
.fold((0.0, 0.0), |(pi, tau), (a, (y, fy))| { .fold((0.0, 0.0), |(pi, tau), (a, (y, fy))| {
let x = *y / *fy; let x = *y / *fy;
( (pi + a.powi(2) / x.pi(), tau + a * x.tau() / x.pi())
pi + a.powi(2) / x.pi(),
tau + a * x.tau() / x.pi()
)
}); });
let new_pi = 1.0 / sum_pi; let new_pi = 1.0 / sum_pi;
@@ -224,15 +217,13 @@ impl SumFactor {
} }
pub fn update_sum(&self, variable_arena: &mut VariableArena) { pub fn update_sum(&self, variable_arena: &mut VariableArena) {
let (y, fy) = self.terms let (y, fy) = self
.terms
.iter() .iter()
.map(|term| { .map(|term| {
variable_arena variable_arena
.get(*term) .get(*term)
.map(|variable| ( .map(|variable| (variable.get_value(), variable.get_message(self.id)))
variable.get_value(),
variable.get_message(self.id)
))
.unwrap() .unwrap()
}) })
.unzip(); .unzip();
@@ -244,7 +235,8 @@ impl SumFactor {
let idx_term = self.terms[index]; let idx_term = self.terms[index];
let idx_coeff = self.coeffs[index]; let idx_coeff = self.coeffs[index];
let a = self.coeffs let a = self
.coeffs
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, coeff)| { .map(|(i, coeff)| {
@@ -256,22 +248,16 @@ impl SumFactor {
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let (y, fy) = self.terms let (y, fy) = self
.terms
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, term)| { .map(|(i, term)| {
let variable = if i == index { let variable = if i == index { self.sum } else { *term };
self.sum
} else {
*term
};
variable_arena variable_arena
.get(variable) .get(variable)
.map(|variable| ( .map(|variable| (variable.get_value(), variable.get_message(self.id)))
variable.get_value(),
variable.get_message(self.id)
))
.unwrap() .unwrap()
}) })
.unzip(); .unzip();
@@ -330,10 +316,7 @@ impl TruncateFactor {
pub fn update(&self, variable_arena: &mut VariableArena) { pub fn update(&self, variable_arena: &mut VariableArena) {
let (x, fx) = variable_arena let (x, fx) = variable_arena
.get(self.variable) .get(self.variable)
.map(|variable| ( .map(|variable| (variable.get_value(), variable.get_message(self.id)))
variable.get_value(),
variable.get_message(self.id)
))
.unwrap(); .unwrap();
let c = x.pi() - fx.pi(); let c = x.pi() - fx.pi();