From 712bcd9f422cf7c7b8b578b70cf40fcbe154730b Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Sun, 28 Oct 2018 19:52:57 +0100 Subject: [PATCH] Added index support for VariableArena. --- src/factor_graph.rs | 94 +++++++++++++++++++++------------------------ src/lib.rs | 2 +- 2 files changed, 44 insertions(+), 52 deletions(-) diff --git a/src/factor_graph.rs b/src/factor_graph.rs index 486dbbb..6727ebf 100644 --- a/src/factor_graph.rs +++ b/src/factor_graph.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::f64; +use std::ops; use gaussian::Gaussian; use math; @@ -30,13 +31,19 @@ impl VariableArena { VariableId { index } } +} - pub fn get(&mut self, id: VariableId) -> Option<&Variable> { - self.variables.get(id.index) +impl ops::Index for VariableArena { + type Output = Variable; + + fn index(&self, id: VariableId) -> &Self::Output { + &self.variables[id.index] } +} - pub fn get_mut(&mut self, id: VariableId) -> Option<&mut Variable> { - self.variables.get_mut(id.index) +impl ops::IndexMut for VariableArena { + fn index_mut(&mut self, id: VariableId) -> &mut Self::Output { + &mut self.variables[id.index] } } @@ -108,7 +115,7 @@ impl PriorFactor { variable: VariableId, gaussian: Gaussian, ) -> PriorFactor { - variable_arena.get_mut(variable).unwrap().attach_factor(id); + variable_arena[variable].attach_factor(id); PriorFactor { id, @@ -123,10 +130,7 @@ impl PriorFactor { self.variable.index, self.gaussian ); - variable_arena - .get_mut(self.variable) - .unwrap() - .update_value(self.id, self.gaussian); + variable_arena[self.variable].update_value(self.id, self.gaussian); } } @@ -145,8 +149,8 @@ impl LikelihoodFactor { value: VariableId, variance: f64, ) -> LikelihoodFactor { - variable_arena.get_mut(mean).unwrap().attach_factor(id); - variable_arena.get_mut(value).unwrap().attach_factor(id); + variable_arena[mean].attach_factor(id); + variable_arena[value].attach_factor(id); LikelihoodFactor { id, @@ -157,10 +161,11 @@ impl LikelihoodFactor { } pub fn update_mean(&self, variable_arena: &mut VariableArena) { - let (x, fx) = variable_arena - .get(self.value) - .map(|variable| (variable.get_value(), variable.get_message(self.id))) - .unwrap(); + let (x, fx) = { + let variable = &variable_arena[self.value]; + + (variable.get_value(), variable.get_message(self.id)) + }; let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau())); @@ -170,17 +175,15 @@ impl LikelihoodFactor { self.mean.index, gaussian ); - variable_arena - .get_mut(self.mean) - .unwrap() - .update_message(self.id, gaussian); + variable_arena[self.mean].update_message(self.id, gaussian); } pub fn update_value(&self, variable_arena: &mut VariableArena) { - let (y, fy) = variable_arena - .get(self.mean) - .map(|variable| (variable.get_value(), variable.get_message(self.id))) - .unwrap(); + let (y, fy) = { + let variable = &variable_arena[self.mean]; + + (variable.get_value(), variable.get_message(self.id)) + }; let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau())); @@ -190,10 +193,7 @@ impl LikelihoodFactor { self.value.index, gaussian ); - variable_arena - .get_mut(self.value) - .unwrap() - .update_message(self.id, gaussian); + variable_arena[self.value].update_message(self.id, gaussian); } } @@ -212,10 +212,10 @@ impl SumFactor { terms: Vec, coeffs: Vec, ) -> SumFactor { - variable_arena.get_mut(sum).unwrap().attach_factor(id); + variable_arena[sum].attach_factor(id); for term in &terms { - variable_arena.get_mut(*term).unwrap().attach_factor(id); + variable_arena[*term].attach_factor(id); } SumFactor { @@ -257,10 +257,7 @@ impl SumFactor { debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian); } - variable_arena - .get_mut(variable) - .unwrap() - .update_message(self.id, gaussian); + variable_arena[variable].update_message(self.id, gaussian); } pub fn update_sum(&self, variable_arena: &mut VariableArena) { @@ -268,10 +265,9 @@ impl SumFactor { .terms .iter() .map(|term| { - variable_arena - .get(*term) - .map(|variable| (variable.get_value(), variable.get_message(self.id))) - .unwrap() + let variable = &variable_arena[*term]; + + (variable.get_value(), variable.get_message(self.id)) }) .unzip(); @@ -300,12 +296,10 @@ impl SumFactor { .iter() .enumerate() .map(|(i, term)| { - let variable = if i == index { self.sum } else { *term }; + let variable_id = if i == index { self.sum } else { *term }; + let variable = &variable_arena[variable_id]; - variable_arena - .get(variable) - .map(|variable| (variable.get_value(), variable.get_message(self.id))) - .unwrap() + (variable.get_value(), variable.get_message(self.id)) }) .unzip(); @@ -356,7 +350,7 @@ impl TruncateFactor { epsilon: f64, draw: bool, ) -> TruncateFactor { - variable_arena.get_mut(variable).unwrap().attach_factor(id); + variable_arena[variable].attach_factor(id); TruncateFactor { id, @@ -367,10 +361,11 @@ impl TruncateFactor { } pub fn update(&self, variable_arena: &mut VariableArena) -> f64 { - let (x, fx) = variable_arena - .get(self.variable) - .map(|variable| (variable.get_value(), variable.get_message(self.id))) - .unwrap(); + let (x, fx) = { + let variable = &variable_arena[self.variable]; + + (variable.get_value(), variable.get_message(self.id)) + }; let c = x.pi() - fx.pi(); let d = x.tau() - fx.tau(); @@ -395,9 +390,6 @@ impl TruncateFactor { self.variable.index, gaussian ); - variable_arena - .get_mut(self.variable) - .unwrap() - .update_value(self.id, gaussian) + variable_arena[self.variable].update_value(self.id, gaussian) } } diff --git a/src/lib.rs b/src/lib.rs index 1e54fba..da8fda0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -261,7 +261,7 @@ impl TrueSkill { rating_vars .iter() - .map(|variable| variable_arena.get(*variable).unwrap().get_value()) + .map(|variable| variable_arena[*variable].get_value()) .map(|value| Rating { mu: value.mu(), sigma: value.sigma(),