Remove old VariableArena and Variable

This commit is contained in:
2023-01-11 23:06:23 +01:00
parent 8236496c4b
commit 5c202c3f80
2 changed files with 16 additions and 135 deletions

View File

@@ -5,110 +5,9 @@ use std::ops;
use log::*;
use crate::gaussian::Gaussian;
use crate::graph::{MessageArena, MessageId};
use crate::graph::{MessageArena, MessageId, VariableArena, VariableId};
use crate::math;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct VariableId {
index: usize,
}
pub struct VariableArena {
variables: Vec<Variable>,
}
impl VariableArena {
pub fn new() -> VariableArena {
VariableArena {
variables: Vec::new(),
}
}
pub fn create(&mut self) -> VariableId {
let index = self.variables.len();
self.variables.push(Variable {
value: Gaussian::from_pi_tau(0.0, 0.0),
factors: HashMap::new(),
});
VariableId { index }
}
}
impl ops::Index<VariableId> for VariableArena {
type Output = Variable;
fn index(&self, id: VariableId) -> &Self::Output {
&self.variables[id.index]
}
}
impl ops::IndexMut<VariableId> for VariableArena {
fn index_mut(&mut self, id: VariableId) -> &mut Self::Output {
&mut self.variables[id.index]
}
}
pub struct Variable {
value: Gaussian,
factors: HashMap<usize, Gaussian>,
}
impl Variable {
pub fn attach_factor(&mut self, factor: usize) {
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
}
pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 {
let old = self.factors[&factor];
self.factors.insert(factor, value * old / self.value);
let pi_delta = (self.value.pi() - value.pi()).abs();
let delta = if !pi_delta.is_finite() {
0.0
} else {
let pi_delta = pi_delta.sqrt();
let tau_delta = (self.value.tau() - value.tau()).abs();
if pi_delta > tau_delta {
pi_delta
} else {
tau_delta
}
};
self.value = value;
debug!("Variable::value old={:?}, new={:?}", old, value);
delta
}
pub fn get_value(&self) -> Gaussian {
self.value
}
pub fn get_value_mut(&mut self) -> &mut Gaussian {
&mut self.value
}
pub fn update_message(&mut self, factor: usize, message: Gaussian) {
let old = self.factors[&factor];
self.factors.insert(factor, message);
self.value = self.value / old * message;
debug!("Variable::message old={:?}, new={:?}", old, message);
}
pub fn get_message(&self, factor: usize) -> Gaussian {
self.factors[&factor]
}
}
pub struct PriorFactor {
id: MessageId,
variable: VariableId,
@@ -132,11 +31,10 @@ impl PriorFactor {
pub fn start(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
let old = message_arena[self.id];
let value = variable_arena[self.variable].get_value();
let value = variable_arena[self.variable];
message_arena[self.id] = self.gaussian * old / value;
*variable_arena[self.variable].get_value_mut() = self.gaussian;
variable_arena[self.variable] = self.gaussian;
}
}
@@ -169,7 +67,7 @@ impl LikelihoodFactor {
variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
) {
let x = variable_arena[self.value].get_value();
let x = variable_arena[self.value];
let fx = message_arena[self.value_msg];
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
@@ -178,8 +76,7 @@ impl LikelihoodFactor {
let old = message_arena[self.mean_msg];
message_arena[self.mean_msg] = gaussian;
*variable_arena[self.mean].get_value_mut() =
variable_arena[self.mean].get_value() / old * gaussian;
variable_arena[self.mean] = variable_arena[self.mean] / old * gaussian;
}
pub fn update_value(
@@ -187,7 +84,7 @@ impl LikelihoodFactor {
variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
) {
let y = variable_arena[self.mean].get_value();
let y = variable_arena[self.mean];
let fy = message_arena[self.mean_msg];
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
@@ -196,8 +93,7 @@ impl LikelihoodFactor {
let old = message_arena[self.value_msg];
message_arena[self.value_msg] = gaussian;
*variable_arena[self.value].get_value_mut() =
variable_arena[self.value].get_value() / old * gaussian;
variable_arena[self.value] = variable_arena[self.value] / old * gaussian;
}
}
@@ -255,17 +151,10 @@ impl SumFactor {
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
if variable == self.sum {
debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian);
} else {
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
}
let old = message_arena[message];
message_arena[message] = gaussian;
*variable_arena[variable].get_value_mut() =
variable_arena[variable].get_value() / old * gaussian;
variable_arena[variable] = variable_arena[variable] / old * gaussian;
}
pub fn update_sum(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
@@ -273,11 +162,7 @@ impl SumFactor {
.terms
.iter()
.zip(self.terms_msg.iter())
.map(|(term, msg)| {
let variable = &variable_arena[*term];
(variable.get_value(), message_arena[*msg])
})
.map(|(term, msg)| (variable_arena[*term], message_arena[*msg]))
.unzip();
self.internal_update(
@@ -321,13 +206,9 @@ impl SumFactor {
.enumerate()
.map(|(i, (term, msg))| {
if i == index {
let variable = &variable_arena[self.sum];
(variable.get_value(), message_arena[self.sum_msg])
(variable_arena[self.sum], message_arena[self.sum_msg])
} else {
let variable = &variable_arena[*term];
(variable.get_value(), message_arena[*msg])
(variable_arena[*term], message_arena[*msg])
}
})
.unzip();
@@ -403,7 +284,7 @@ impl TruncateFactor {
variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
) -> f64 {
let x = variable_arena[self.variable].get_value();
let x = variable_arena[self.variable];
let fx = message_arena[self.variable_msg];
let c = x.pi() - fx.pi();
@@ -425,7 +306,7 @@ impl TruncateFactor {
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
let old = message_arena[self.variable_msg];
let value = variable_arena[self.variable].get_value();
let value = variable_arena[self.variable];
message_arena[self.variable_msg] = gaussian * old / value;
@@ -444,7 +325,7 @@ impl TruncateFactor {
}
};
*variable_arena[self.variable].get_value_mut() = gaussian;
variable_arena[self.variable] = gaussian;
delta
}

View File

@@ -8,7 +8,7 @@ mod matrix;
use crate::factor_graph::*;
use crate::gaussian::Gaussian;
use crate::graph::MessageArena;
use crate::graph::{MessageArena, VariableArena};
use crate::matrix::Matrix;
/// Default initial mean of ratings.
@@ -265,7 +265,7 @@ impl TrueSkill {
rating_vars
.iter()
.map(|variable| variable_arena[*variable].get_value())
.map(|variable| variable_arena[*variable])
.map(|value| Rating {
mu: value.mu(),
sigma: value.sigma(),