Added index support for VariableArena.
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::f64;
|
use std::f64;
|
||||||
|
use std::ops;
|
||||||
|
|
||||||
use gaussian::Gaussian;
|
use gaussian::Gaussian;
|
||||||
use math;
|
use math;
|
||||||
@@ -30,13 +31,19 @@ impl VariableArena {
|
|||||||
|
|
||||||
VariableId { index }
|
VariableId { index }
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get(&mut self, id: VariableId) -> Option<&Variable> {
|
impl ops::Index<VariableId> for VariableArena {
|
||||||
self.variables.get(id.index)
|
type Output = Variable;
|
||||||
|
|
||||||
|
fn index(&self, id: VariableId) -> &Self::Output {
|
||||||
|
&self.variables[id.index]
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_mut(&mut self, id: VariableId) -> Option<&mut Variable> {
|
impl ops::IndexMut<VariableId> for VariableArena {
|
||||||
self.variables.get_mut(id.index)
|
fn index_mut(&mut self, id: VariableId) -> &mut Self::Output {
|
||||||
|
&mut self.variables[id.index]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +115,7 @@ impl PriorFactor {
|
|||||||
variable: VariableId,
|
variable: VariableId,
|
||||||
gaussian: Gaussian,
|
gaussian: Gaussian,
|
||||||
) -> PriorFactor {
|
) -> PriorFactor {
|
||||||
variable_arena.get_mut(variable).unwrap().attach_factor(id);
|
variable_arena[variable].attach_factor(id);
|
||||||
|
|
||||||
PriorFactor {
|
PriorFactor {
|
||||||
id,
|
id,
|
||||||
@@ -123,10 +130,7 @@ impl PriorFactor {
|
|||||||
self.variable.index, self.gaussian
|
self.variable.index, self.gaussian
|
||||||
);
|
);
|
||||||
|
|
||||||
variable_arena
|
variable_arena[self.variable].update_value(self.id, self.gaussian);
|
||||||
.get_mut(self.variable)
|
|
||||||
.unwrap()
|
|
||||||
.update_value(self.id, self.gaussian);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -145,8 +149,8 @@ impl LikelihoodFactor {
|
|||||||
value: VariableId,
|
value: VariableId,
|
||||||
variance: f64,
|
variance: f64,
|
||||||
) -> LikelihoodFactor {
|
) -> LikelihoodFactor {
|
||||||
variable_arena.get_mut(mean).unwrap().attach_factor(id);
|
variable_arena[mean].attach_factor(id);
|
||||||
variable_arena.get_mut(value).unwrap().attach_factor(id);
|
variable_arena[value].attach_factor(id);
|
||||||
|
|
||||||
LikelihoodFactor {
|
LikelihoodFactor {
|
||||||
id,
|
id,
|
||||||
@@ -157,10 +161,11 @@ impl LikelihoodFactor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
|
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
|
||||||
let (x, fx) = variable_arena
|
let (x, fx) = {
|
||||||
.get(self.value)
|
let variable = &variable_arena[self.value];
|
||||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
|
||||||
.unwrap();
|
(variable.get_value(), variable.get_message(self.id))
|
||||||
|
};
|
||||||
|
|
||||||
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
||||||
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
|
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
|
||||||
@@ -170,17 +175,15 @@ impl LikelihoodFactor {
|
|||||||
self.mean.index, gaussian
|
self.mean.index, gaussian
|
||||||
);
|
);
|
||||||
|
|
||||||
variable_arena
|
variable_arena[self.mean].update_message(self.id, gaussian);
|
||||||
.get_mut(self.mean)
|
|
||||||
.unwrap()
|
|
||||||
.update_message(self.id, gaussian);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_value(&self, variable_arena: &mut VariableArena) {
|
pub fn update_value(&self, variable_arena: &mut VariableArena) {
|
||||||
let (y, fy) = variable_arena
|
let (y, fy) = {
|
||||||
.get(self.mean)
|
let variable = &variable_arena[self.mean];
|
||||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
|
||||||
.unwrap();
|
(variable.get_value(), variable.get_message(self.id))
|
||||||
|
};
|
||||||
|
|
||||||
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
||||||
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
|
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
|
||||||
@@ -190,10 +193,7 @@ impl LikelihoodFactor {
|
|||||||
self.value.index, gaussian
|
self.value.index, gaussian
|
||||||
);
|
);
|
||||||
|
|
||||||
variable_arena
|
variable_arena[self.value].update_message(self.id, gaussian);
|
||||||
.get_mut(self.value)
|
|
||||||
.unwrap()
|
|
||||||
.update_message(self.id, gaussian);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,10 +212,10 @@ impl SumFactor {
|
|||||||
terms: Vec<VariableId>,
|
terms: Vec<VariableId>,
|
||||||
coeffs: Vec<f64>,
|
coeffs: Vec<f64>,
|
||||||
) -> SumFactor {
|
) -> SumFactor {
|
||||||
variable_arena.get_mut(sum).unwrap().attach_factor(id);
|
variable_arena[sum].attach_factor(id);
|
||||||
|
|
||||||
for term in &terms {
|
for term in &terms {
|
||||||
variable_arena.get_mut(*term).unwrap().attach_factor(id);
|
variable_arena[*term].attach_factor(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
SumFactor {
|
SumFactor {
|
||||||
@@ -257,10 +257,7 @@ impl SumFactor {
|
|||||||
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
|
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
|
||||||
}
|
}
|
||||||
|
|
||||||
variable_arena
|
variable_arena[variable].update_message(self.id, gaussian);
|
||||||
.get_mut(variable)
|
|
||||||
.unwrap()
|
|
||||||
.update_message(self.id, gaussian);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
|
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
|
||||||
@@ -268,10 +265,9 @@ impl SumFactor {
|
|||||||
.terms
|
.terms
|
||||||
.iter()
|
.iter()
|
||||||
.map(|term| {
|
.map(|term| {
|
||||||
variable_arena
|
let variable = &variable_arena[*term];
|
||||||
.get(*term)
|
|
||||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
(variable.get_value(), variable.get_message(self.id))
|
||||||
.unwrap()
|
|
||||||
})
|
})
|
||||||
.unzip();
|
.unzip();
|
||||||
|
|
||||||
@@ -300,12 +296,10 @@ impl SumFactor {
|
|||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, term)| {
|
.map(|(i, term)| {
|
||||||
let variable = if i == index { self.sum } else { *term };
|
let variable_id = if i == index { self.sum } else { *term };
|
||||||
|
let variable = &variable_arena[variable_id];
|
||||||
|
|
||||||
variable_arena
|
(variable.get_value(), variable.get_message(self.id))
|
||||||
.get(variable)
|
|
||||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
|
||||||
.unwrap()
|
|
||||||
})
|
})
|
||||||
.unzip();
|
.unzip();
|
||||||
|
|
||||||
@@ -356,7 +350,7 @@ impl TruncateFactor {
|
|||||||
epsilon: f64,
|
epsilon: f64,
|
||||||
draw: bool,
|
draw: bool,
|
||||||
) -> TruncateFactor {
|
) -> TruncateFactor {
|
||||||
variable_arena.get_mut(variable).unwrap().attach_factor(id);
|
variable_arena[variable].attach_factor(id);
|
||||||
|
|
||||||
TruncateFactor {
|
TruncateFactor {
|
||||||
id,
|
id,
|
||||||
@@ -367,10 +361,11 @@ impl TruncateFactor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
|
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
|
||||||
let (x, fx) = variable_arena
|
let (x, fx) = {
|
||||||
.get(self.variable)
|
let variable = &variable_arena[self.variable];
|
||||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
|
||||||
.unwrap();
|
(variable.get_value(), variable.get_message(self.id))
|
||||||
|
};
|
||||||
|
|
||||||
let c = x.pi() - fx.pi();
|
let c = x.pi() - fx.pi();
|
||||||
let d = x.tau() - fx.tau();
|
let d = x.tau() - fx.tau();
|
||||||
@@ -395,9 +390,6 @@ impl TruncateFactor {
|
|||||||
self.variable.index, gaussian
|
self.variable.index, gaussian
|
||||||
);
|
);
|
||||||
|
|
||||||
variable_arena
|
variable_arena[self.variable].update_value(self.id, gaussian)
|
||||||
.get_mut(self.variable)
|
|
||||||
.unwrap()
|
|
||||||
.update_value(self.id, gaussian)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -261,7 +261,7 @@ impl TrueSkill {
|
|||||||
|
|
||||||
rating_vars
|
rating_vars
|
||||||
.iter()
|
.iter()
|
||||||
.map(|variable| variable_arena.get(*variable).unwrap().get_value())
|
.map(|variable| variable_arena[*variable].get_value())
|
||||||
.map(|value| Rating {
|
.map(|value| Rating {
|
||||||
mu: value.mu(),
|
mu: value.mu(),
|
||||||
sigma: value.sigma(),
|
sigma: value.sigma(),
|
||||||
|
|||||||
Reference in New Issue
Block a user