Files
trueskill-rs/src/factor_graph.rs
2019-01-02 22:24:42 +01:00

398 lines
9.5 KiB
Rust

use std::collections::HashMap;
use std::f64;
use std::ops;
use log::*;
use crate::gaussian::Gaussian;
use crate::math;
/// Handle to a [`Variable`] stored in a [`VariableArena`].
///
/// It is just an index into the arena's storage, so it is cheap to copy and
/// only meaningful for the arena that produced it. `Eq`/`Hash` let ids be used
/// as keys in maps and sets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct VariableId {
    index: usize,
}
/// Owner of every [`Variable`] in a factor graph.
///
/// Variables are addressed by the [`VariableId`] returned from
/// [`VariableArena::create`]. Factors store `VariableId`s and receive the
/// arena as `&mut VariableArena` whenever they send a message.
pub struct VariableArena {
    variables: Vec<Variable>,
}

impl VariableArena {
    /// Creates an empty arena.
    pub fn new() -> VariableArena {
        VariableArena {
            variables: Vec::new(),
        }
    }

    /// Appends a new variable and returns its id.
    ///
    /// The variable starts with value `Gaussian::from_pi_tau(0.0, 0.0)` and no
    /// attached factors. Ids are handed out sequentially starting at 0.
    pub fn create(&mut self) -> VariableId {
        let index = self.variables.len();
        self.variables.push(Variable {
            value: Gaussian::from_pi_tau(0.0, 0.0),
            factors: HashMap::new(),
        });
        VariableId { index }
    }
}

/// An empty arena, identical to [`VariableArena::new`].
impl Default for VariableArena {
    fn default() -> Self {
        VariableArena::new()
    }
}
impl ops::Index<VariableId> for VariableArena {
type Output = Variable;
fn index(&self, id: VariableId) -> &Self::Output {
&self.variables[id.index]
}
}
impl ops::IndexMut<VariableId> for VariableArena {
fn index_mut(&mut self, id: VariableId) -> &mut Self::Output {
&mut self.variables[id.index]
}
}
/// A node of the factor graph: a Gaussian value plus the latest message
/// received from each attached factor.
pub struct Variable {
    /// Current value of the variable.
    value: Gaussian,
    /// Latest message from each attached factor, keyed by the factor's id.
    factors: HashMap<usize, Gaussian>,
}

impl Variable {
    /// Registers `factor` with this variable, with an initial message of
    /// `Gaussian::from_pi_tau(0.0, 0.0)`.
    ///
    /// Attaching an already-attached factor id resets its message.
    pub fn attach_factor(&mut self, factor: usize) {
        self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
    }

    /// Overwrites this variable's value with `value` on behalf of `factor` and
    /// returns how far the value moved.
    ///
    /// The factor's stored message becomes `value * old_message / old_value`.
    /// The returned delta is `max(|Δtau|, sqrt(|Δpi|))`, or `0.0` when the
    /// precision difference is not finite.
    ///
    /// # Panics
    ///
    /// Panics if `factor` was never attached with [`Variable::attach_factor`].
    pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 {
        let old_value = self.value;
        // Single lookup instead of index-then-insert.
        let message = self
            .factors
            .get_mut(&factor)
            .expect("factor is not attached to this variable");
        *message = value * *message / old_value;

        let pi_delta = (old_value.pi() - value.pi()).abs();
        let delta = if !pi_delta.is_finite() {
            0.0
        } else {
            let pi_delta = pi_delta.sqrt();
            let tau_delta = (old_value.tau() - value.tau()).abs();
            // Deliberately not `f64::max`: a NaN `tau_delta` must still be
            // reported rather than silently ignored.
            if pi_delta > tau_delta {
                pi_delta
            } else {
                tau_delta
            }
        };
        self.value = value;
        // Log the previous *value* (not the previous message) so that `old`
        // and `new` describe the same quantity.
        debug!("Variable::value old={:?}, new={:?}", old_value, value);
        delta
    }

    /// Returns the current value.
    pub fn get_value(&self) -> Gaussian {
        self.value
    }

    /// Stores `message` as the latest message from `factor` and folds it into
    /// the value: the previous message is divided out and the new one is
    /// multiplied in.
    ///
    /// # Panics
    ///
    /// Panics if `factor` was never attached with [`Variable::attach_factor`].
    pub fn update_message(&mut self, factor: usize, message: Gaussian) {
        let slot = self
            .factors
            .get_mut(&factor)
            .expect("factor is not attached to this variable");
        let old = std::mem::replace(slot, message);
        self.value = self.value / old * message;
        debug!("Variable::message old={:?}, new={:?}", old, message);
    }

    /// Returns the latest message received from `factor`.
    ///
    /// # Panics
    ///
    /// Panics if `factor` was never attached with [`Variable::attach_factor`].
    pub fn get_message(&self, factor: usize) -> Gaussian {
        self.factors[&factor]
    }
}
/// Factor that holds a fixed Gaussian for a single variable.
pub struct PriorFactor {
    id: usize,
    variable: VariableId,
    gaussian: Gaussian,
}

impl PriorFactor {
    /// Builds prior factor `id` for `variable` and attaches it to that
    /// variable in `variable_arena`.
    pub fn new(
        variable_arena: &mut VariableArena,
        id: usize,
        variable: VariableId,
        gaussian: Gaussian,
    ) -> PriorFactor {
        let factor = PriorFactor {
            id,
            variable,
            gaussian,
        };
        variable_arena[factor.variable].attach_factor(factor.id);
        factor
    }

    /// Overwrites the variable's value with the stored Gaussian via
    /// `Variable::update_value`. The delta it returns is discarded.
    pub fn start(&self, variable_arena: &mut VariableArena) {
        let target = self.variable;
        let prior = self.gaussian;
        debug!("Prior::down var={:?}, value={:?}", target.index, prior);
        let _ = variable_arena[target].update_value(self.id, prior);
    }
}
pub struct LikelihoodFactor {
id: usize,
mean: VariableId,
value: VariableId,
variance: f64,
}
impl LikelihoodFactor {
pub fn new(
variable_arena: &mut VariableArena,
id: usize,
mean: VariableId,
value: VariableId,
variance: f64,
) -> LikelihoodFactor {
variable_arena[mean].attach_factor(id);
variable_arena[value].attach_factor(id);
LikelihoodFactor {
id,
mean,
value,
variance,
}
}
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
let (x, fx) = {
let variable = &variable_arena[self.value];
(variable.get_value(), variable.get_message(self.id))
};
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
debug!(
"Likelihood::up var={:?}, value={:?}",
self.mean.index, gaussian
);
variable_arena[self.mean].update_message(self.id, gaussian);
}
pub fn update_value(&self, variable_arena: &mut VariableArena) {
let (y, fy) = {
let variable = &variable_arena[self.mean];
(variable.get_value(), variable.get_message(self.id))
};
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
debug!(
"Likelihood::down var={:?}, value={:?}",
self.value.index, gaussian
);
variable_arena[self.value].update_message(self.id, gaussian);
}
}
/// Factor relating `sum` to a weighted combination of `terms`:
/// `sum = Σ coeffs[i] * terms[i]`.
pub struct SumFactor {
    id: usize,
    sum: VariableId,
    terms: Vec<VariableId>,
    coeffs: Vec<f64>,
}

impl SumFactor {
    /// Builds sum factor `id` and attaches it to `sum` and then to every term,
    /// in order.
    pub fn new(
        variable_arena: &mut VariableArena,
        id: usize,
        sum: VariableId,
        terms: Vec<VariableId>,
        coeffs: Vec<f64>,
    ) -> SumFactor {
        variable_arena[sum].attach_factor(id);
        terms
            .iter()
            .for_each(|&term| variable_arena[term].attach_factor(id));
        SumFactor {
            id,
            sum,
            terms,
            coeffs,
        }
    }

    /// Combines weighted sources into one Gaussian and sends it to `variable`.
    ///
    /// Each source contributes `y[i] / fy[i]`. The outgoing Gaussian has
    /// variance `Σ a[i]² / pi_i` and mean `Σ a[i] * mu_i` over those
    /// contributions. Iteration stops at the shortest of the three slices.
    fn internal_update(
        &self,
        variable_arena: &mut VariableArena,
        variable: VariableId,
        y: &[Gaussian],
        fy: &[Gaussian],
        a: &[f64],
    ) {
        // Accumulates a variance (Σ w²/π), not a precision.
        let mut total_variance = 0.0f64;
        let mut weighted_mean = 0.0f64;
        for (weight, (value, message)) in a.iter().zip(y.iter().zip(fy.iter())) {
            let ratio = *value / *message;
            total_variance += weight.powi(2) / ratio.pi();
            weighted_mean += weight * ratio.mu();
        }
        let precision = 1.0 / total_variance;
        let outgoing = Gaussian::from_pi_tau(precision, precision * weighted_mean);
        if variable != self.sum {
            debug!("Sum::up var={:?}, value={:?}", variable.index, outgoing);
        } else {
            debug!("Sum::down var={:?}, value={:?}", variable.index, outgoing);
        }
        variable_arena[variable].update_message(self.id, outgoing);
    }

    /// Sends a message to `sum` computed from all terms weighted by `coeffs`.
    pub fn update_sum(&self, variable_arena: &mut VariableArena) {
        let mut values = Vec::with_capacity(self.terms.len());
        let mut messages = Vec::with_capacity(self.terms.len());
        for &term in &self.terms {
            let node = &variable_arena[term];
            values.push(node.get_value());
            messages.push(node.get_message(self.id));
        }
        self.internal_update(variable_arena, self.sum, &values, &messages, &self.coeffs);
    }

    /// Solves the sum relation for `terms[index]` and sends it that message.
    ///
    /// The `index`-th slot is replaced by `sum` with weight
    /// `1 / coeffs[index]`. Every other term gets weight
    /// `-coeffs[i] / coeffs[index]`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds for `terms` or `coeffs`.
    pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) {
        let target = self.terms[index];
        let pivot = self.coeffs[index];

        let mut weights = Vec::with_capacity(self.coeffs.len());
        for (i, &coeff) in self.coeffs.iter().enumerate() {
            let weight = if i == index { 1.0 / pivot } else { -coeff / pivot };
            weights.push(weight);
        }

        let mut values = Vec::with_capacity(self.terms.len());
        let mut messages = Vec::with_capacity(self.terms.len());
        for (i, &term) in self.terms.iter().enumerate() {
            let source = if i == index { self.sum } else { term };
            let node = &variable_arena[source];
            values.push(node.get_value());
            messages.push(node.get_message(self.id));
        }

        self.internal_update(variable_arena, target, &values, &messages, &weights);
    }

    /// Calls [`SumFactor::update_term`] for every term, in order.
    pub fn update_all_terms(&self, variable_arena: &mut VariableArena) {
        (0..self.terms.len()).for_each(|position| self.update_term(variable_arena, position));
    }
}
/// Non-draw "V" function: `pdf(t - e) / cdf(t - e)`.
fn v_win(t: f64, e: f64) -> f64 {
    let shifted = t - e;
    let density = math::pdf(shifted);
    let mass = math::cdf(shifted);
    density / mass
}

/// Non-draw "W" function: `v * (v + t - e)` where `v = v_win(t, e)`.
fn w_win(t: f64, e: f64) -> f64 {
    let v = v_win(t, e);
    // Keep the left-to-right grouping `(v + t) - e` of the original formula.
    let factor = v + t - e;
    v * factor
}
/// Draw version of the "V" function:
/// `(pdf(-e - t) - pdf(e - t)) / (cdf(e - t) - cdf(-e - t))`.
fn v_draw(t: f64, e: f64) -> f64 {
    (math::pdf(-e - t) - math::pdf(e - t)) / (math::cdf(e - t) - math::cdf(-e - t))
}

/// Draw version of the "W" function:
///
/// `v² + ((e - t)·pdf(e - t) + (e + t)·pdf(e + t)) / (cdf(e - t) - cdf(-e - t))`
/// with `v = v_draw(t, e)`.
///
/// Only the density term is divided by the normalizer. `v²` is added
/// afterwards, so that `1 - w` equals the variance of a unit-variance normal
/// centred at `t` truncated to `[-e, e]`. Dividing `v²` as well overestimates
/// `w`.
fn w_draw(t: f64, e: f64) -> f64 {
    let v = v_draw(t, e);
    let numer = (e - t) * math::pdf(e - t) + (e + t) * math::pdf(e + t);
    let denom = math::cdf(e - t) - math::cdf(-e - t);
    v * v + numer / denom
}
/// Factor that truncates its variable using margin `epsilon`. It uses the draw
/// V/W functions when `draw` is true and the win V/W functions otherwise.
pub struct TruncateFactor {
    id: usize,
    variable: VariableId,
    epsilon: f64,
    draw: bool,
}

impl TruncateFactor {
    /// Builds truncate factor `id` and attaches it to `variable`.
    pub fn new(
        variable_arena: &mut VariableArena,
        id: usize,
        variable: VariableId,
        epsilon: f64,
        draw: bool,
    ) -> TruncateFactor {
        let factor = TruncateFactor {
            id,
            variable,
            epsilon,
            draw,
        };
        variable_arena[variable].attach_factor(id);
        factor
    }

    /// Recomputes the truncated Gaussian and writes it as the variable's new
    /// value via `Variable::update_value`, returning the delta it reports.
    pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
        let node = &variable_arena[self.variable];
        let current = node.get_value();
        let own_message = node.get_message(self.id);

        // Value with this factor's own message divided out.
        let div_pi = current.pi() - own_message.pi();
        let div_tau = current.tau() - own_message.tau();
        let root_pi = div_pi.sqrt();

        // Standardised difference and margin fed to the V/W functions.
        let diff = div_tau / root_pi;
        let margin = self.epsilon * root_pi;

        let (v, w) = if !self.draw {
            (v_win(diff, margin), w_win(diff, margin))
        } else {
            (v_draw(diff, margin), w_draw(diff, margin))
        };

        let denom = 1.0 - w;
        let truncated = Gaussian::from_pi_tau(div_pi / denom, (div_tau + root_pi * v) / denom);
        debug!(
            "Trunc::up var={:?}, value={:?}",
            self.variable.index, truncated
        );
        variable_arena[self.variable].update_value(self.id, truncated)
    }
}