Added index support for VariableArena.

This commit is contained in:
2018-10-28 19:52:57 +01:00
parent 073a250701
commit 712bcd9f42
2 changed files with 44 additions and 52 deletions

View File

@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::f64;
use std::ops;
use gaussian::Gaussian;
use math;
@@ -30,13 +31,19 @@ impl VariableArena {
VariableId { index }
}
}
pub fn get(&mut self, id: VariableId) -> Option<&Variable> {
self.variables.get(id.index)
impl ops::Index<VariableId> for VariableArena {
type Output = Variable;
fn index(&self, id: VariableId) -> &Self::Output {
&self.variables[id.index]
}
}
pub fn get_mut(&mut self, id: VariableId) -> Option<&mut Variable> {
self.variables.get_mut(id.index)
impl ops::IndexMut<VariableId> for VariableArena {
fn index_mut(&mut self, id: VariableId) -> &mut Self::Output {
&mut self.variables[id.index]
}
}
@@ -108,7 +115,7 @@ impl PriorFactor {
variable: VariableId,
gaussian: Gaussian,
) -> PriorFactor {
variable_arena.get_mut(variable).unwrap().attach_factor(id);
variable_arena[variable].attach_factor(id);
PriorFactor {
id,
@@ -123,10 +130,7 @@ impl PriorFactor {
self.variable.index, self.gaussian
);
variable_arena
.get_mut(self.variable)
.unwrap()
.update_value(self.id, self.gaussian);
variable_arena[self.variable].update_value(self.id, self.gaussian);
}
}
@@ -145,8 +149,8 @@ impl LikelihoodFactor {
value: VariableId,
variance: f64,
) -> LikelihoodFactor {
variable_arena.get_mut(mean).unwrap().attach_factor(id);
variable_arena.get_mut(value).unwrap().attach_factor(id);
variable_arena[mean].attach_factor(id);
variable_arena[value].attach_factor(id);
LikelihoodFactor {
id,
@@ -157,10 +161,11 @@ impl LikelihoodFactor {
}
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
let (x, fx) = variable_arena
.get(self.value)
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
.unwrap();
let (x, fx) = {
let variable = &variable_arena[self.value];
(variable.get_value(), variable.get_message(self.id))
};
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
@@ -170,17 +175,15 @@ impl LikelihoodFactor {
self.mean.index, gaussian
);
variable_arena
.get_mut(self.mean)
.unwrap()
.update_message(self.id, gaussian);
variable_arena[self.mean].update_message(self.id, gaussian);
}
pub fn update_value(&self, variable_arena: &mut VariableArena) {
let (y, fy) = variable_arena
.get(self.mean)
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
.unwrap();
let (y, fy) = {
let variable = &variable_arena[self.mean];
(variable.get_value(), variable.get_message(self.id))
};
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
@@ -190,10 +193,7 @@ impl LikelihoodFactor {
self.value.index, gaussian
);
variable_arena
.get_mut(self.value)
.unwrap()
.update_message(self.id, gaussian);
variable_arena[self.value].update_message(self.id, gaussian);
}
}
@@ -212,10 +212,10 @@ impl SumFactor {
terms: Vec<VariableId>,
coeffs: Vec<f64>,
) -> SumFactor {
variable_arena.get_mut(sum).unwrap().attach_factor(id);
variable_arena[sum].attach_factor(id);
for term in &terms {
variable_arena.get_mut(*term).unwrap().attach_factor(id);
variable_arena[*term].attach_factor(id);
}
SumFactor {
@@ -257,10 +257,7 @@ impl SumFactor {
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
}
variable_arena
.get_mut(variable)
.unwrap()
.update_message(self.id, gaussian);
variable_arena[variable].update_message(self.id, gaussian);
}
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
@@ -268,10 +265,9 @@ impl SumFactor {
.terms
.iter()
.map(|term| {
variable_arena
.get(*term)
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
.unwrap()
let variable = &variable_arena[*term];
(variable.get_value(), variable.get_message(self.id))
})
.unzip();
@@ -300,12 +296,10 @@ impl SumFactor {
.iter()
.enumerate()
.map(|(i, term)| {
let variable = if i == index { self.sum } else { *term };
let variable_id = if i == index { self.sum } else { *term };
let variable = &variable_arena[variable_id];
variable_arena
.get(variable)
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
.unwrap()
(variable.get_value(), variable.get_message(self.id))
})
.unzip();
@@ -356,7 +350,7 @@ impl TruncateFactor {
epsilon: f64,
draw: bool,
) -> TruncateFactor {
variable_arena.get_mut(variable).unwrap().attach_factor(id);
variable_arena[variable].attach_factor(id);
TruncateFactor {
id,
@@ -367,10 +361,11 @@ impl TruncateFactor {
}
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
let (x, fx) = variable_arena
.get(self.variable)
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
.unwrap();
let (x, fx) = {
let variable = &variable_arena[self.variable];
(variable.get_value(), variable.get_message(self.id))
};
let c = x.pi() - fx.pi();
let d = x.tau() - fx.tau();
@@ -395,9 +390,6 @@ impl TruncateFactor {
self.variable.index, gaussian
);
variable_arena
.get_mut(self.variable)
.unwrap()
.update_value(self.id, gaussian)
variable_arena[self.variable].update_value(self.id, gaussian)
}
}

View File

@@ -261,7 +261,7 @@ impl TrueSkill {
rating_vars
.iter()
.map(|variable| variable_arena.get(*variable).unwrap().get_value())
.map(|variable| variable_arena[*variable].get_value())
.map(|value| Rating {
mu: value.mu(),
sigma: value.sigma(),