Added index support for VariableArena.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use std::f64;
|
||||
use std::ops;
|
||||
|
||||
use gaussian::Gaussian;
|
||||
use math;
|
||||
@@ -30,13 +31,19 @@ impl VariableArena {
|
||||
|
||||
VariableId { index }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&mut self, id: VariableId) -> Option<&Variable> {
|
||||
self.variables.get(id.index)
|
||||
impl ops::Index<VariableId> for VariableArena {
|
||||
type Output = Variable;
|
||||
|
||||
fn index(&self, id: VariableId) -> &Self::Output {
|
||||
&self.variables[id.index]
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self, id: VariableId) -> Option<&mut Variable> {
|
||||
self.variables.get_mut(id.index)
|
||||
impl ops::IndexMut<VariableId> for VariableArena {
|
||||
fn index_mut(&mut self, id: VariableId) -> &mut Self::Output {
|
||||
&mut self.variables[id.index]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,7 +115,7 @@ impl PriorFactor {
|
||||
variable: VariableId,
|
||||
gaussian: Gaussian,
|
||||
) -> PriorFactor {
|
||||
variable_arena.get_mut(variable).unwrap().attach_factor(id);
|
||||
variable_arena[variable].attach_factor(id);
|
||||
|
||||
PriorFactor {
|
||||
id,
|
||||
@@ -123,10 +130,7 @@ impl PriorFactor {
|
||||
self.variable.index, self.gaussian
|
||||
);
|
||||
|
||||
variable_arena
|
||||
.get_mut(self.variable)
|
||||
.unwrap()
|
||||
.update_value(self.id, self.gaussian);
|
||||
variable_arena[self.variable].update_value(self.id, self.gaussian);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,8 +149,8 @@ impl LikelihoodFactor {
|
||||
value: VariableId,
|
||||
variance: f64,
|
||||
) -> LikelihoodFactor {
|
||||
variable_arena.get_mut(mean).unwrap().attach_factor(id);
|
||||
variable_arena.get_mut(value).unwrap().attach_factor(id);
|
||||
variable_arena[mean].attach_factor(id);
|
||||
variable_arena[value].attach_factor(id);
|
||||
|
||||
LikelihoodFactor {
|
||||
id,
|
||||
@@ -157,10 +161,11 @@ impl LikelihoodFactor {
|
||||
}
|
||||
|
||||
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
|
||||
let (x, fx) = variable_arena
|
||||
.get(self.value)
|
||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
||||
.unwrap();
|
||||
let (x, fx) = {
|
||||
let variable = &variable_arena[self.value];
|
||||
|
||||
(variable.get_value(), variable.get_message(self.id))
|
||||
};
|
||||
|
||||
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
||||
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
|
||||
@@ -170,17 +175,15 @@ impl LikelihoodFactor {
|
||||
self.mean.index, gaussian
|
||||
);
|
||||
|
||||
variable_arena
|
||||
.get_mut(self.mean)
|
||||
.unwrap()
|
||||
.update_message(self.id, gaussian);
|
||||
variable_arena[self.mean].update_message(self.id, gaussian);
|
||||
}
|
||||
|
||||
pub fn update_value(&self, variable_arena: &mut VariableArena) {
|
||||
let (y, fy) = variable_arena
|
||||
.get(self.mean)
|
||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
||||
.unwrap();
|
||||
let (y, fy) = {
|
||||
let variable = &variable_arena[self.mean];
|
||||
|
||||
(variable.get_value(), variable.get_message(self.id))
|
||||
};
|
||||
|
||||
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
||||
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
|
||||
@@ -190,10 +193,7 @@ impl LikelihoodFactor {
|
||||
self.value.index, gaussian
|
||||
);
|
||||
|
||||
variable_arena
|
||||
.get_mut(self.value)
|
||||
.unwrap()
|
||||
.update_message(self.id, gaussian);
|
||||
variable_arena[self.value].update_message(self.id, gaussian);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,10 +212,10 @@ impl SumFactor {
|
||||
terms: Vec<VariableId>,
|
||||
coeffs: Vec<f64>,
|
||||
) -> SumFactor {
|
||||
variable_arena.get_mut(sum).unwrap().attach_factor(id);
|
||||
variable_arena[sum].attach_factor(id);
|
||||
|
||||
for term in &terms {
|
||||
variable_arena.get_mut(*term).unwrap().attach_factor(id);
|
||||
variable_arena[*term].attach_factor(id);
|
||||
}
|
||||
|
||||
SumFactor {
|
||||
@@ -257,10 +257,7 @@ impl SumFactor {
|
||||
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
|
||||
}
|
||||
|
||||
variable_arena
|
||||
.get_mut(variable)
|
||||
.unwrap()
|
||||
.update_message(self.id, gaussian);
|
||||
variable_arena[variable].update_message(self.id, gaussian);
|
||||
}
|
||||
|
||||
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
|
||||
@@ -268,10 +265,9 @@ impl SumFactor {
|
||||
.terms
|
||||
.iter()
|
||||
.map(|term| {
|
||||
variable_arena
|
||||
.get(*term)
|
||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
||||
.unwrap()
|
||||
let variable = &variable_arena[*term];
|
||||
|
||||
(variable.get_value(), variable.get_message(self.id))
|
||||
})
|
||||
.unzip();
|
||||
|
||||
@@ -300,12 +296,10 @@ impl SumFactor {
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, term)| {
|
||||
let variable = if i == index { self.sum } else { *term };
|
||||
let variable_id = if i == index { self.sum } else { *term };
|
||||
let variable = &variable_arena[variable_id];
|
||||
|
||||
variable_arena
|
||||
.get(variable)
|
||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
||||
.unwrap()
|
||||
(variable.get_value(), variable.get_message(self.id))
|
||||
})
|
||||
.unzip();
|
||||
|
||||
@@ -356,7 +350,7 @@ impl TruncateFactor {
|
||||
epsilon: f64,
|
||||
draw: bool,
|
||||
) -> TruncateFactor {
|
||||
variable_arena.get_mut(variable).unwrap().attach_factor(id);
|
||||
variable_arena[variable].attach_factor(id);
|
||||
|
||||
TruncateFactor {
|
||||
id,
|
||||
@@ -367,10 +361,11 @@ impl TruncateFactor {
|
||||
}
|
||||
|
||||
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
|
||||
let (x, fx) = variable_arena
|
||||
.get(self.variable)
|
||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
||||
.unwrap();
|
||||
let (x, fx) = {
|
||||
let variable = &variable_arena[self.variable];
|
||||
|
||||
(variable.get_value(), variable.get_message(self.id))
|
||||
};
|
||||
|
||||
let c = x.pi() - fx.pi();
|
||||
let d = x.tau() - fx.tau();
|
||||
@@ -395,9 +390,6 @@ impl TruncateFactor {
|
||||
self.variable.index, gaussian
|
||||
);
|
||||
|
||||
variable_arena
|
||||
.get_mut(self.variable)
|
||||
.unwrap()
|
||||
.update_value(self.id, gaussian)
|
||||
variable_arena[self.variable].update_value(self.id, gaussian)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,7 +261,7 @@ impl TrueSkill {
|
||||
|
||||
rating_vars
|
||||
.iter()
|
||||
.map(|variable| variable_arena.get(*variable).unwrap().get_value())
|
||||
.map(|variable| variable_arena[*variable].get_value())
|
||||
.map(|value| Rating {
|
||||
mu: value.mu(),
|
||||
sigma: value.sigma(),
|
||||
|
||||
Reference in New Issue
Block a user