Remove old VariableArena and Variable

This commit is contained in:
2023-01-11 23:06:23 +01:00
parent 8236496c4b
commit 5c202c3f80
2 changed files with 16 additions and 135 deletions

View File

@@ -5,110 +5,9 @@ use std::ops;
use log::*; use log::*;
use crate::gaussian::Gaussian; use crate::gaussian::Gaussian;
use crate::graph::{MessageArena, MessageId}; use crate::graph::{MessageArena, MessageId, VariableArena, VariableId};
use crate::math; use crate::math;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct VariableId {
index: usize,
}
pub struct VariableArena {
variables: Vec<Variable>,
}
impl VariableArena {
pub fn new() -> VariableArena {
VariableArena {
variables: Vec::new(),
}
}
pub fn create(&mut self) -> VariableId {
let index = self.variables.len();
self.variables.push(Variable {
value: Gaussian::from_pi_tau(0.0, 0.0),
factors: HashMap::new(),
});
VariableId { index }
}
}
impl ops::Index<VariableId> for VariableArena {
type Output = Variable;
fn index(&self, id: VariableId) -> &Self::Output {
&self.variables[id.index]
}
}
impl ops::IndexMut<VariableId> for VariableArena {
fn index_mut(&mut self, id: VariableId) -> &mut Self::Output {
&mut self.variables[id.index]
}
}
pub struct Variable {
value: Gaussian,
factors: HashMap<usize, Gaussian>,
}
impl Variable {
pub fn attach_factor(&mut self, factor: usize) {
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
}
pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 {
let old = self.factors[&factor];
self.factors.insert(factor, value * old / self.value);
let pi_delta = (self.value.pi() - value.pi()).abs();
let delta = if !pi_delta.is_finite() {
0.0
} else {
let pi_delta = pi_delta.sqrt();
let tau_delta = (self.value.tau() - value.tau()).abs();
if pi_delta > tau_delta {
pi_delta
} else {
tau_delta
}
};
self.value = value;
debug!("Variable::value old={:?}, new={:?}", old, value);
delta
}
pub fn get_value(&self) -> Gaussian {
self.value
}
pub fn get_value_mut(&mut self) -> &mut Gaussian {
&mut self.value
}
pub fn update_message(&mut self, factor: usize, message: Gaussian) {
let old = self.factors[&factor];
self.factors.insert(factor, message);
self.value = self.value / old * message;
debug!("Variable::message old={:?}, new={:?}", old, message);
}
pub fn get_message(&self, factor: usize) -> Gaussian {
self.factors[&factor]
}
}
pub struct PriorFactor { pub struct PriorFactor {
id: MessageId, id: MessageId,
variable: VariableId, variable: VariableId,
@@ -132,11 +31,10 @@ impl PriorFactor {
pub fn start(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) { pub fn start(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
let old = message_arena[self.id]; let old = message_arena[self.id];
let value = variable_arena[self.variable].get_value(); let value = variable_arena[self.variable];
message_arena[self.id] = self.gaussian * old / value; message_arena[self.id] = self.gaussian * old / value;
variable_arena[self.variable] = self.gaussian;
*variable_arena[self.variable].get_value_mut() = self.gaussian;
} }
} }
@@ -169,7 +67,7 @@ impl LikelihoodFactor {
variable_arena: &mut VariableArena, variable_arena: &mut VariableArena,
message_arena: &mut MessageArena, message_arena: &mut MessageArena,
) { ) {
let x = variable_arena[self.value].get_value(); let x = variable_arena[self.value];
let fx = message_arena[self.value_msg]; let fx = message_arena[self.value_msg];
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
@@ -178,8 +76,7 @@ impl LikelihoodFactor {
let old = message_arena[self.mean_msg]; let old = message_arena[self.mean_msg];
message_arena[self.mean_msg] = gaussian; message_arena[self.mean_msg] = gaussian;
*variable_arena[self.mean].get_value_mut() = variable_arena[self.mean] = variable_arena[self.mean] / old * gaussian;
variable_arena[self.mean].get_value() / old * gaussian;
} }
pub fn update_value( pub fn update_value(
@@ -187,7 +84,7 @@ impl LikelihoodFactor {
variable_arena: &mut VariableArena, variable_arena: &mut VariableArena,
message_arena: &mut MessageArena, message_arena: &mut MessageArena,
) { ) {
let y = variable_arena[self.mean].get_value(); let y = variable_arena[self.mean];
let fy = message_arena[self.mean_msg]; let fy = message_arena[self.mean_msg];
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
@@ -196,8 +93,7 @@ impl LikelihoodFactor {
let old = message_arena[self.value_msg]; let old = message_arena[self.value_msg];
message_arena[self.value_msg] = gaussian; message_arena[self.value_msg] = gaussian;
*variable_arena[self.value].get_value_mut() = variable_arena[self.value] = variable_arena[self.value] / old * gaussian;
variable_arena[self.value].get_value() / old * gaussian;
} }
} }
@@ -255,17 +151,10 @@ impl SumFactor {
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau); let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
if variable == self.sum {
debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian);
} else {
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
}
let old = message_arena[message]; let old = message_arena[message];
message_arena[message] = gaussian; message_arena[message] = gaussian;
*variable_arena[variable].get_value_mut() = variable_arena[variable] = variable_arena[variable] / old * gaussian;
variable_arena[variable].get_value() / old * gaussian;
} }
pub fn update_sum(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) { pub fn update_sum(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
@@ -273,11 +162,7 @@ impl SumFactor {
.terms .terms
.iter() .iter()
.zip(self.terms_msg.iter()) .zip(self.terms_msg.iter())
.map(|(term, msg)| { .map(|(term, msg)| (variable_arena[*term], message_arena[*msg]))
let variable = &variable_arena[*term];
(variable.get_value(), message_arena[*msg])
})
.unzip(); .unzip();
self.internal_update( self.internal_update(
@@ -321,13 +206,9 @@ impl SumFactor {
.enumerate() .enumerate()
.map(|(i, (term, msg))| { .map(|(i, (term, msg))| {
if i == index { if i == index {
let variable = &variable_arena[self.sum]; (variable_arena[self.sum], message_arena[self.sum_msg])
(variable.get_value(), message_arena[self.sum_msg])
} else { } else {
let variable = &variable_arena[*term]; (variable_arena[*term], message_arena[*msg])
(variable.get_value(), message_arena[*msg])
} }
}) })
.unzip(); .unzip();
@@ -403,7 +284,7 @@ impl TruncateFactor {
variable_arena: &mut VariableArena, variable_arena: &mut VariableArena,
message_arena: &mut MessageArena, message_arena: &mut MessageArena,
) -> f64 { ) -> f64 {
let x = variable_arena[self.variable].get_value(); let x = variable_arena[self.variable];
let fx = message_arena[self.variable_msg]; let fx = message_arena[self.variable_msg];
let c = x.pi() - fx.pi(); let c = x.pi() - fx.pi();
@@ -425,7 +306,7 @@ impl TruncateFactor {
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w); let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
let old = message_arena[self.variable_msg]; let old = message_arena[self.variable_msg];
let value = variable_arena[self.variable].get_value(); let value = variable_arena[self.variable];
message_arena[self.variable_msg] = gaussian * old / value; message_arena[self.variable_msg] = gaussian * old / value;
@@ -444,7 +325,7 @@ impl TruncateFactor {
} }
}; };
*variable_arena[self.variable].get_value_mut() = gaussian; variable_arena[self.variable] = gaussian;
delta delta
} }

View File

@@ -8,7 +8,7 @@ mod matrix;
use crate::factor_graph::*; use crate::factor_graph::*;
use crate::gaussian::Gaussian; use crate::gaussian::Gaussian;
use crate::graph::MessageArena; use crate::graph::{MessageArena, VariableArena};
use crate::matrix::Matrix; use crate::matrix::Matrix;
/// Default initial mean of ratings. /// Default initial mean of ratings.
@@ -265,7 +265,7 @@ impl TrueSkill {
rating_vars rating_vars
.iter() .iter()
.map(|variable| variable_arena[*variable].get_value()) .map(|variable| variable_arena[*variable])
.map(|value| Rating { .map(|value| Rating {
mu: value.mu(), mu: value.mu(),
sigma: value.sigma(), sigma: value.sigma(),