Remove old VariableArena and Variable
This commit is contained in:
@@ -5,110 +5,9 @@ use std::ops;
|
||||
use log::*;
|
||||
|
||||
use crate::gaussian::Gaussian;
|
||||
use crate::graph::{MessageArena, MessageId};
|
||||
use crate::graph::{MessageArena, MessageId, VariableArena, VariableId};
|
||||
use crate::math;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct VariableId {
|
||||
index: usize,
|
||||
}
|
||||
|
||||
pub struct VariableArena {
|
||||
variables: Vec<Variable>,
|
||||
}
|
||||
|
||||
impl VariableArena {
|
||||
pub fn new() -> VariableArena {
|
||||
VariableArena {
|
||||
variables: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create(&mut self) -> VariableId {
|
||||
let index = self.variables.len();
|
||||
|
||||
self.variables.push(Variable {
|
||||
value: Gaussian::from_pi_tau(0.0, 0.0),
|
||||
factors: HashMap::new(),
|
||||
});
|
||||
|
||||
VariableId { index }
|
||||
}
|
||||
}
|
||||
|
||||
impl ops::Index<VariableId> for VariableArena {
|
||||
type Output = Variable;
|
||||
|
||||
fn index(&self, id: VariableId) -> &Self::Output {
|
||||
&self.variables[id.index]
|
||||
}
|
||||
}
|
||||
|
||||
impl ops::IndexMut<VariableId> for VariableArena {
|
||||
fn index_mut(&mut self, id: VariableId) -> &mut Self::Output {
|
||||
&mut self.variables[id.index]
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Variable {
|
||||
value: Gaussian,
|
||||
factors: HashMap<usize, Gaussian>,
|
||||
}
|
||||
|
||||
impl Variable {
|
||||
pub fn attach_factor(&mut self, factor: usize) {
|
||||
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
|
||||
}
|
||||
|
||||
pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 {
|
||||
let old = self.factors[&factor];
|
||||
|
||||
self.factors.insert(factor, value * old / self.value);
|
||||
|
||||
let pi_delta = (self.value.pi() - value.pi()).abs();
|
||||
|
||||
let delta = if !pi_delta.is_finite() {
|
||||
0.0
|
||||
} else {
|
||||
let pi_delta = pi_delta.sqrt();
|
||||
let tau_delta = (self.value.tau() - value.tau()).abs();
|
||||
|
||||
if pi_delta > tau_delta {
|
||||
pi_delta
|
||||
} else {
|
||||
tau_delta
|
||||
}
|
||||
};
|
||||
|
||||
self.value = value;
|
||||
|
||||
debug!("Variable::value old={:?}, new={:?}", old, value);
|
||||
|
||||
delta
|
||||
}
|
||||
|
||||
pub fn get_value(&self) -> Gaussian {
|
||||
self.value
|
||||
}
|
||||
|
||||
pub fn get_value_mut(&mut self) -> &mut Gaussian {
|
||||
&mut self.value
|
||||
}
|
||||
|
||||
pub fn update_message(&mut self, factor: usize, message: Gaussian) {
|
||||
let old = self.factors[&factor];
|
||||
|
||||
self.factors.insert(factor, message);
|
||||
self.value = self.value / old * message;
|
||||
|
||||
debug!("Variable::message old={:?}, new={:?}", old, message);
|
||||
}
|
||||
|
||||
pub fn get_message(&self, factor: usize) -> Gaussian {
|
||||
self.factors[&factor]
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PriorFactor {
|
||||
id: MessageId,
|
||||
variable: VariableId,
|
||||
@@ -132,11 +31,10 @@ impl PriorFactor {
|
||||
|
||||
pub fn start(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
|
||||
let old = message_arena[self.id];
|
||||
let value = variable_arena[self.variable].get_value();
|
||||
let value = variable_arena[self.variable];
|
||||
|
||||
message_arena[self.id] = self.gaussian * old / value;
|
||||
|
||||
*variable_arena[self.variable].get_value_mut() = self.gaussian;
|
||||
variable_arena[self.variable] = self.gaussian;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,7 +67,7 @@ impl LikelihoodFactor {
|
||||
variable_arena: &mut VariableArena,
|
||||
message_arena: &mut MessageArena,
|
||||
) {
|
||||
let x = variable_arena[self.value].get_value();
|
||||
let x = variable_arena[self.value];
|
||||
let fx = message_arena[self.value_msg];
|
||||
|
||||
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
||||
@@ -178,8 +76,7 @@ impl LikelihoodFactor {
|
||||
let old = message_arena[self.mean_msg];
|
||||
|
||||
message_arena[self.mean_msg] = gaussian;
|
||||
*variable_arena[self.mean].get_value_mut() =
|
||||
variable_arena[self.mean].get_value() / old * gaussian;
|
||||
variable_arena[self.mean] = variable_arena[self.mean] / old * gaussian;
|
||||
}
|
||||
|
||||
pub fn update_value(
|
||||
@@ -187,7 +84,7 @@ impl LikelihoodFactor {
|
||||
variable_arena: &mut VariableArena,
|
||||
message_arena: &mut MessageArena,
|
||||
) {
|
||||
let y = variable_arena[self.mean].get_value();
|
||||
let y = variable_arena[self.mean];
|
||||
let fy = message_arena[self.mean_msg];
|
||||
|
||||
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
||||
@@ -196,8 +93,7 @@ impl LikelihoodFactor {
|
||||
let old = message_arena[self.value_msg];
|
||||
|
||||
message_arena[self.value_msg] = gaussian;
|
||||
*variable_arena[self.value].get_value_mut() =
|
||||
variable_arena[self.value].get_value() / old * gaussian;
|
||||
variable_arena[self.value] = variable_arena[self.value] / old * gaussian;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,17 +151,10 @@ impl SumFactor {
|
||||
|
||||
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
|
||||
|
||||
if variable == self.sum {
|
||||
debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian);
|
||||
} else {
|
||||
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
|
||||
}
|
||||
|
||||
let old = message_arena[message];
|
||||
|
||||
message_arena[message] = gaussian;
|
||||
*variable_arena[variable].get_value_mut() =
|
||||
variable_arena[variable].get_value() / old * gaussian;
|
||||
variable_arena[variable] = variable_arena[variable] / old * gaussian;
|
||||
}
|
||||
|
||||
pub fn update_sum(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
|
||||
@@ -273,11 +162,7 @@ impl SumFactor {
|
||||
.terms
|
||||
.iter()
|
||||
.zip(self.terms_msg.iter())
|
||||
.map(|(term, msg)| {
|
||||
let variable = &variable_arena[*term];
|
||||
|
||||
(variable.get_value(), message_arena[*msg])
|
||||
})
|
||||
.map(|(term, msg)| (variable_arena[*term], message_arena[*msg]))
|
||||
.unzip();
|
||||
|
||||
self.internal_update(
|
||||
@@ -321,13 +206,9 @@ impl SumFactor {
|
||||
.enumerate()
|
||||
.map(|(i, (term, msg))| {
|
||||
if i == index {
|
||||
let variable = &variable_arena[self.sum];
|
||||
|
||||
(variable.get_value(), message_arena[self.sum_msg])
|
||||
(variable_arena[self.sum], message_arena[self.sum_msg])
|
||||
} else {
|
||||
let variable = &variable_arena[*term];
|
||||
|
||||
(variable.get_value(), message_arena[*msg])
|
||||
(variable_arena[*term], message_arena[*msg])
|
||||
}
|
||||
})
|
||||
.unzip();
|
||||
@@ -403,7 +284,7 @@ impl TruncateFactor {
|
||||
variable_arena: &mut VariableArena,
|
||||
message_arena: &mut MessageArena,
|
||||
) -> f64 {
|
||||
let x = variable_arena[self.variable].get_value();
|
||||
let x = variable_arena[self.variable];
|
||||
let fx = message_arena[self.variable_msg];
|
||||
|
||||
let c = x.pi() - fx.pi();
|
||||
@@ -425,7 +306,7 @@ impl TruncateFactor {
|
||||
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
|
||||
|
||||
let old = message_arena[self.variable_msg];
|
||||
let value = variable_arena[self.variable].get_value();
|
||||
let value = variable_arena[self.variable];
|
||||
|
||||
message_arena[self.variable_msg] = gaussian * old / value;
|
||||
|
||||
@@ -444,7 +325,7 @@ impl TruncateFactor {
|
||||
}
|
||||
};
|
||||
|
||||
*variable_arena[self.variable].get_value_mut() = gaussian;
|
||||
variable_arena[self.variable] = gaussian;
|
||||
|
||||
delta
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ mod matrix;
|
||||
|
||||
use crate::factor_graph::*;
|
||||
use crate::gaussian::Gaussian;
|
||||
use crate::graph::MessageArena;
|
||||
use crate::graph::{MessageArena, VariableArena};
|
||||
use crate::matrix::Matrix;
|
||||
|
||||
/// Default initial mean of ratings.
|
||||
@@ -265,7 +265,7 @@ impl TrueSkill {
|
||||
|
||||
rating_vars
|
||||
.iter()
|
||||
.map(|variable| variable_arena[*variable].get_value())
|
||||
.map(|variable| variable_arena[*variable])
|
||||
.map(|value| Rating {
|
||||
mu: value.mu(),
|
||||
sigma: value.sigma(),
|
||||
|
||||
Reference in New Issue
Block a user