Added index support for VariableArena.

This commit is contained in:
2018-10-28 19:52:57 +01:00
parent 073a250701
commit 712bcd9f42
2 changed files with 44 additions and 52 deletions

View File

@@ -1,5 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::f64; use std::f64;
use std::ops;
use gaussian::Gaussian; use gaussian::Gaussian;
use math; use math;
@@ -30,13 +31,19 @@ impl VariableArena {
VariableId { index } VariableId { index }
} }
}
pub fn get(&mut self, id: VariableId) -> Option<&Variable> { impl ops::Index<VariableId> for VariableArena {
self.variables.get(id.index) type Output = Variable;
fn index(&self, id: VariableId) -> &Self::Output {
&self.variables[id.index]
} }
}
pub fn get_mut(&mut self, id: VariableId) -> Option<&mut Variable> { impl ops::IndexMut<VariableId> for VariableArena {
self.variables.get_mut(id.index) fn index_mut(&mut self, id: VariableId) -> &mut Self::Output {
&mut self.variables[id.index]
} }
} }
@@ -108,7 +115,7 @@ impl PriorFactor {
variable: VariableId, variable: VariableId,
gaussian: Gaussian, gaussian: Gaussian,
) -> PriorFactor { ) -> PriorFactor {
variable_arena.get_mut(variable).unwrap().attach_factor(id); variable_arena[variable].attach_factor(id);
PriorFactor { PriorFactor {
id, id,
@@ -123,10 +130,7 @@ impl PriorFactor {
self.variable.index, self.gaussian self.variable.index, self.gaussian
); );
variable_arena variable_arena[self.variable].update_value(self.id, self.gaussian);
.get_mut(self.variable)
.unwrap()
.update_value(self.id, self.gaussian);
} }
} }
@@ -145,8 +149,8 @@ impl LikelihoodFactor {
value: VariableId, value: VariableId,
variance: f64, variance: f64,
) -> LikelihoodFactor { ) -> LikelihoodFactor {
variable_arena.get_mut(mean).unwrap().attach_factor(id); variable_arena[mean].attach_factor(id);
variable_arena.get_mut(value).unwrap().attach_factor(id); variable_arena[value].attach_factor(id);
LikelihoodFactor { LikelihoodFactor {
id, id,
@@ -157,10 +161,11 @@ impl LikelihoodFactor {
} }
pub fn update_mean(&self, variable_arena: &mut VariableArena) { pub fn update_mean(&self, variable_arena: &mut VariableArena) {
let (x, fx) = variable_arena let (x, fx) = {
.get(self.value) let variable = &variable_arena[self.value];
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
.unwrap(); (variable.get_value(), variable.get_message(self.id))
};
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau())); let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
@@ -170,17 +175,15 @@ impl LikelihoodFactor {
self.mean.index, gaussian self.mean.index, gaussian
); );
variable_arena variable_arena[self.mean].update_message(self.id, gaussian);
.get_mut(self.mean)
.unwrap()
.update_message(self.id, gaussian);
} }
pub fn update_value(&self, variable_arena: &mut VariableArena) { pub fn update_value(&self, variable_arena: &mut VariableArena) {
let (y, fy) = variable_arena let (y, fy) = {
.get(self.mean) let variable = &variable_arena[self.mean];
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
.unwrap(); (variable.get_value(), variable.get_message(self.id))
};
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau())); let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
@@ -190,10 +193,7 @@ impl LikelihoodFactor {
self.value.index, gaussian self.value.index, gaussian
); );
variable_arena variable_arena[self.value].update_message(self.id, gaussian);
.get_mut(self.value)
.unwrap()
.update_message(self.id, gaussian);
} }
} }
@@ -212,10 +212,10 @@ impl SumFactor {
terms: Vec<VariableId>, terms: Vec<VariableId>,
coeffs: Vec<f64>, coeffs: Vec<f64>,
) -> SumFactor { ) -> SumFactor {
variable_arena.get_mut(sum).unwrap().attach_factor(id); variable_arena[sum].attach_factor(id);
for term in &terms { for term in &terms {
variable_arena.get_mut(*term).unwrap().attach_factor(id); variable_arena[*term].attach_factor(id);
} }
SumFactor { SumFactor {
@@ -257,10 +257,7 @@ impl SumFactor {
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian); debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
} }
variable_arena variable_arena[variable].update_message(self.id, gaussian);
.get_mut(variable)
.unwrap()
.update_message(self.id, gaussian);
} }
pub fn update_sum(&self, variable_arena: &mut VariableArena) { pub fn update_sum(&self, variable_arena: &mut VariableArena) {
@@ -268,10 +265,9 @@ impl SumFactor {
.terms .terms
.iter() .iter()
.map(|term| { .map(|term| {
variable_arena let variable = &variable_arena[*term];
.get(*term)
.map(|variable| (variable.get_value(), variable.get_message(self.id))) (variable.get_value(), variable.get_message(self.id))
.unwrap()
}) })
.unzip(); .unzip();
@@ -300,12 +296,10 @@ impl SumFactor {
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, term)| { .map(|(i, term)| {
let variable = if i == index { self.sum } else { *term }; let variable_id = if i == index { self.sum } else { *term };
let variable = &variable_arena[variable_id];
variable_arena (variable.get_value(), variable.get_message(self.id))
.get(variable)
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
.unwrap()
}) })
.unzip(); .unzip();
@@ -356,7 +350,7 @@ impl TruncateFactor {
epsilon: f64, epsilon: f64,
draw: bool, draw: bool,
) -> TruncateFactor { ) -> TruncateFactor {
variable_arena.get_mut(variable).unwrap().attach_factor(id); variable_arena[variable].attach_factor(id);
TruncateFactor { TruncateFactor {
id, id,
@@ -367,10 +361,11 @@ impl TruncateFactor {
} }
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 { pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
let (x, fx) = variable_arena let (x, fx) = {
.get(self.variable) let variable = &variable_arena[self.variable];
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
.unwrap(); (variable.get_value(), variable.get_message(self.id))
};
let c = x.pi() - fx.pi(); let c = x.pi() - fx.pi();
let d = x.tau() - fx.tau(); let d = x.tau() - fx.tau();
@@ -395,9 +390,6 @@ impl TruncateFactor {
self.variable.index, gaussian self.variable.index, gaussian
); );
variable_arena variable_arena[self.variable].update_value(self.id, gaussian)
.get_mut(self.variable)
.unwrap()
.update_value(self.id, gaussian)
} }
} }

View File

@@ -261,7 +261,7 @@ impl TrueSkill {
rating_vars rating_vars
.iter() .iter()
.map(|variable| variable_arena.get(*variable).unwrap().get_value()) .map(|variable| variable_arena[*variable].get_value())
.map(|value| Rating { .map(|value| Rating {
mu: value.mu(), mu: value.mu(),
sigma: value.sigma(), sigma: value.sigma(),