From 5c202c3f804abf62405a97c851f0a54dee68d094 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Wed, 11 Jan 2023 23:06:23 +0100 Subject: [PATCH] Remove old VariableArena and Variable --- src/factor_graph.rs | 147 +++++--------------------------------------- src/lib.rs | 4 +- 2 files changed, 16 insertions(+), 135 deletions(-) diff --git a/src/factor_graph.rs b/src/factor_graph.rs index 466cb92..c46ceff 100644 --- a/src/factor_graph.rs +++ b/src/factor_graph.rs @@ -5,110 +5,9 @@ use std::ops; use log::*; use crate::gaussian::Gaussian; -use crate::graph::{MessageArena, MessageId}; +use crate::graph::{MessageArena, MessageId, VariableArena, VariableId}; use crate::math; -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct VariableId { - index: usize, -} - -pub struct VariableArena { - variables: Vec, -} - -impl VariableArena { - pub fn new() -> VariableArena { - VariableArena { - variables: Vec::new(), - } - } - - pub fn create(&mut self) -> VariableId { - let index = self.variables.len(); - - self.variables.push(Variable { - value: Gaussian::from_pi_tau(0.0, 0.0), - factors: HashMap::new(), - }); - - VariableId { index } - } -} - -impl ops::Index for VariableArena { - type Output = Variable; - - fn index(&self, id: VariableId) -> &Self::Output { - &self.variables[id.index] - } -} - -impl ops::IndexMut for VariableArena { - fn index_mut(&mut self, id: VariableId) -> &mut Self::Output { - &mut self.variables[id.index] - } -} - -pub struct Variable { - value: Gaussian, - factors: HashMap, -} - -impl Variable { - pub fn attach_factor(&mut self, factor: usize) { - self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0)); - } - - pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 { - let old = self.factors[&factor]; - - self.factors.insert(factor, value * old / self.value); - - let pi_delta = (self.value.pi() - value.pi()).abs(); - - let delta = if !pi_delta.is_finite() { - 0.0 - } else { - let pi_delta = pi_delta.sqrt(); - let tau_delta = (self.value.tau() - value.tau()).abs(); - - if pi_delta > tau_delta { - pi_delta - } else { - tau_delta - } - }; - - self.value = value; - - debug!("Variable::value old={:?}, new={:?}", old, value); - - delta - } - - pub fn get_value(&self) -> Gaussian { - self.value - } - - pub fn get_value_mut(&mut self) -> &mut Gaussian { - &mut self.value - } - - pub fn update_message(&mut self, factor: usize, message: Gaussian) { - let old = self.factors[&factor]; - - self.factors.insert(factor, message); - self.value = self.value / old * message; - - debug!("Variable::message old={:?}, new={:?}", old, message); - } - - pub fn get_message(&self, factor: usize) -> Gaussian { - self.factors[&factor] - } -} - pub struct PriorFactor { id: MessageId, variable: VariableId, @@ -132,11 +31,10 @@ impl PriorFactor { pub fn start(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) { let old = message_arena[self.id]; - let value = variable_arena[self.variable].get_value(); + let value = variable_arena[self.variable]; message_arena[self.id] = self.gaussian * old / value; - - *variable_arena[self.variable].get_value_mut() = self.gaussian; + variable_arena[self.variable] = self.gaussian; } } @@ -169,7 +67,7 @@ impl LikelihoodFactor { variable_arena: &mut VariableArena, message_arena: &mut MessageArena, ) { - let x = variable_arena[self.value].get_value(); + let x = variable_arena[self.value]; let fx = message_arena[self.value_msg]; let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); @@ -178,8 +76,7 @@ impl LikelihoodFactor { let old = message_arena[self.mean_msg]; message_arena[self.mean_msg] = gaussian; - *variable_arena[self.mean].get_value_mut() = - variable_arena[self.mean].get_value() / old * gaussian; + variable_arena[self.mean] = variable_arena[self.mean] / old * gaussian; } pub fn update_value( @@ -187,7 +84,7 @@ impl LikelihoodFactor { variable_arena: &mut VariableArena, message_arena: &mut MessageArena, ) { - let y = variable_arena[self.mean].get_value(); + let y = variable_arena[self.mean]; let fy = message_arena[self.mean_msg]; let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); @@ -196,8 +93,7 @@ impl LikelihoodFactor { let old = message_arena[self.value_msg]; message_arena[self.value_msg] = gaussian; - *variable_arena[self.value].get_value_mut() = - variable_arena[self.value].get_value() / old * gaussian; + variable_arena[self.value] = variable_arena[self.value] / old * gaussian; } } @@ -255,17 +151,10 @@ impl SumFactor { let gaussian = Gaussian::from_pi_tau(new_pi, new_tau); - if variable == self.sum { - debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian); - } else { - debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian); - } - let old = message_arena[message]; message_arena[message] = gaussian; - *variable_arena[variable].get_value_mut() = - variable_arena[variable].get_value() / old * gaussian; + variable_arena[variable] = variable_arena[variable] / old * gaussian; } pub fn update_sum(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) { @@ -273,11 +162,7 @@ impl SumFactor { .terms .iter() .zip(self.terms_msg.iter()) - .map(|(term, msg)| { - let variable = &variable_arena[*term]; - - (variable.get_value(), message_arena[*msg]) - }) + .map(|(term, msg)| (variable_arena[*term], message_arena[*msg])) .unzip(); self.internal_update( @@ -321,13 +206,9 @@ impl SumFactor { .enumerate() .map(|(i, (term, msg))| { if i == index { - let variable = &variable_arena[self.sum]; - - (variable.get_value(), message_arena[self.sum_msg]) + (variable_arena[self.sum], message_arena[self.sum_msg]) } else { - let variable = &variable_arena[*term]; - - (variable.get_value(), message_arena[*msg]) + (variable_arena[*term], message_arena[*msg]) } }) .unzip(); @@ -403,7 +284,7 @@ impl TruncateFactor { variable_arena: &mut VariableArena, message_arena: &mut MessageArena, ) -> f64 { - let x = variable_arena[self.variable].get_value(); + let x = variable_arena[self.variable]; let fx = message_arena[self.variable_msg]; let c = x.pi() - fx.pi(); @@ -425,7 +306,7 @@ impl TruncateFactor { let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w); let old = message_arena[self.variable_msg]; - let value = variable_arena[self.variable].get_value(); + let value = variable_arena[self.variable]; message_arena[self.variable_msg] = gaussian * old / value; @@ -444,7 +325,7 @@ impl TruncateFactor { } }; - *variable_arena[self.variable].get_value_mut() = gaussian; + variable_arena[self.variable] = gaussian; delta } diff --git a/src/lib.rs b/src/lib.rs index 7b1da48..1c0dc73 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ mod matrix; use crate::factor_graph::*; use crate::gaussian::Gaussian; -use crate::graph::MessageArena; +use crate::graph::{MessageArena, VariableArena}; use crate::matrix::Matrix; /// Default initial mean of ratings. @@ -265,7 +265,7 @@ impl TrueSkill { rating_vars .iter() - .map(|variable| variable_arena[*variable].get_value()) + .map(|variable| variable_arena[*variable]) .map(|value| Rating { mu: value.mu(), sigma: value.sigma(),