use std::collections::HashMap; use std::f64; use std::ops; use log::*; use crate::gaussian::Gaussian; use crate::math; #[derive(Clone, Copy, Debug, PartialEq)] pub struct VariableId { index: usize, } pub struct VariableArena { variables: Vec, } impl VariableArena { pub fn new() -> VariableArena { VariableArena { variables: Vec::new(), } } pub fn create(&mut self) -> VariableId { let index = self.variables.len(); self.variables.push(Variable { value: Gaussian::from_pi_tau(0.0, 0.0), factors: HashMap::new(), }); VariableId { index } } } impl ops::Index for VariableArena { type Output = Variable; fn index(&self, id: VariableId) -> &Self::Output { &self.variables[id.index] } } impl ops::IndexMut for VariableArena { fn index_mut(&mut self, id: VariableId) -> &mut Self::Output { &mut self.variables[id.index] } } pub struct Variable { value: Gaussian, factors: HashMap, } impl Variable { pub fn attach_factor(&mut self, factor: usize) { self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0)); } pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 { let old = self.factors[&factor]; self.factors.insert(factor, value * old / self.value); let pi_delta = (self.value.pi() - value.pi()).abs(); let delta = if !pi_delta.is_finite() { 0.0 } else { let pi_delta = pi_delta.sqrt(); let tau_delta = (self.value.tau() - value.tau()).abs(); if pi_delta > tau_delta { pi_delta } else { tau_delta } }; self.value = value; debug!("Variable::value old={:?}, new={:?}", old, value); delta } pub fn get_value(&self) -> Gaussian { self.value } pub fn update_message(&mut self, factor: usize, message: Gaussian) { let old = self.factors[&factor]; self.factors.insert(factor, message); self.value = self.value / old * message; debug!("Variable::message old={:?}, new={:?}", old, message); } pub fn get_message(&self, factor: usize) -> Gaussian { self.factors[&factor] } } pub struct PriorFactor { id: usize, variable: VariableId, gaussian: Gaussian, } impl PriorFactor { pub fn new( variable_arena: &mut VariableArena, id: usize, variable: VariableId, gaussian: Gaussian, ) -> PriorFactor { variable_arena[variable].attach_factor(id); PriorFactor { id, variable, gaussian, } } pub fn start(&self, variable_arena: &mut VariableArena) { debug!( "Prior::down var={:?}, value={:?}", self.variable.index, self.gaussian ); variable_arena[self.variable].update_value(self.id, self.gaussian); } } pub struct LikelihoodFactor { id: usize, mean: VariableId, value: VariableId, variance: f64, } impl LikelihoodFactor { pub fn new( variable_arena: &mut VariableArena, id: usize, mean: VariableId, value: VariableId, variance: f64, ) -> LikelihoodFactor { variable_arena[mean].attach_factor(id); variable_arena[value].attach_factor(id); LikelihoodFactor { id, mean, value, variance, } } pub fn update_mean(&self, variable_arena: &mut VariableArena) { let (x, fx) = { let variable = &variable_arena[self.value]; (variable.get_value(), variable.get_message(self.id)) }; let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau())); debug!( "Likelihood::up var={:?}, value={:?}", self.mean.index, gaussian ); variable_arena[self.mean].update_message(self.id, gaussian); } pub fn update_value(&self, variable_arena: &mut VariableArena) { let (y, fy) = { let variable = &variable_arena[self.mean]; (variable.get_value(), variable.get_message(self.id)) }; let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau())); debug!( "Likelihood::down var={:?}, value={:?}", self.value.index, gaussian ); variable_arena[self.value].update_message(self.id, gaussian); } } pub struct SumFactor { id: usize, sum: VariableId, terms: Vec, coeffs: Vec, } impl SumFactor { pub fn new( variable_arena: &mut VariableArena, id: usize, sum: VariableId, terms: Vec, coeffs: Vec, ) -> SumFactor { variable_arena[sum].attach_factor(id); for term in &terms { variable_arena[*term].attach_factor(id); } SumFactor { id, sum, terms, coeffs, } } fn internal_update( &self, variable_arena: &mut VariableArena, variable: VariableId, y: &[Gaussian], fy: &[Gaussian], a: &[f64], ) { let (sum_pi, sum_tau) = a.iter() .zip(y.iter().zip(fy.iter())) .fold((0.0, 0.0f64), |(pi, tau), (a, (y, fy))| { let x = *y / *fy; let new_pi = a.powi(2) / x.pi(); let new_tau = a * x.mu(); (pi + new_pi, tau + new_tau) }); let new_pi = 1.0 / sum_pi; let new_tau = new_pi * sum_tau; let gaussian = Gaussian::from_pi_tau(new_pi, new_tau); if variable == self.sum { debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian); } else { debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian); } variable_arena[variable].update_message(self.id, gaussian); } pub fn update_sum(&self, variable_arena: &mut VariableArena) { let (y, fy): (Vec<_>, Vec<_>) = self .terms .iter() .map(|term| { let variable = &variable_arena[*term]; (variable.get_value(), variable.get_message(self.id)) }) .unzip(); self.internal_update(variable_arena, self.sum, &y, &fy, &self.coeffs); } pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) { let idx_term = self.terms[index]; let idx_coeff = self.coeffs[index]; let a = self .coeffs .iter() .enumerate() .map(|(i, coeff)| { if i == index { 1.0 / idx_coeff } else { -coeff / idx_coeff } }) .collect::>(); let (y, fy): (Vec<_>, Vec<_>) = self .terms .iter() .enumerate() .map(|(i, term)| { let variable_id = if i == index { self.sum } else { *term }; let variable = &variable_arena[variable_id]; (variable.get_value(), variable.get_message(self.id)) }) .unzip(); self.internal_update(variable_arena, idx_term, &y, &fy, &a); } pub fn update_all_terms(&self, variable_arena: &mut VariableArena) { for i in 0..self.terms.len() { self.update_term(variable_arena, i); } } } fn v_win(t: f64, e: f64) -> f64 { math::pdf(t - e) / math::cdf(t - e) } fn w_win(t: f64, e: f64) -> f64 { let vwin = v_win(t, e); vwin * (vwin + t - e) } fn v_draw(t: f64, e: f64) -> f64 { (math::pdf(-e - t) - math::pdf(e - t)) / (math::cdf(e - t) - math::cdf(-e - t)) } fn w_draw(t: f64, e: f64) -> f64 { let vdraw = v_draw(t, e); let n = (vdraw * vdraw) + ((e - t) * math::pdf(e - t) + (e + t) * math::pdf(e + t)); let d = math::cdf(e - t) - math::cdf(-e - t); n / d } pub struct TruncateFactor { id: usize, variable: VariableId, epsilon: f64, draw: bool, } impl TruncateFactor { pub fn new( variable_arena: &mut VariableArena, id: usize, variable: VariableId, epsilon: f64, draw: bool, ) -> TruncateFactor { variable_arena[variable].attach_factor(id); TruncateFactor { id, variable, epsilon, draw, } } pub fn update(&self, variable_arena: &mut VariableArena) -> f64 { let (x, fx) = { let variable = &variable_arena[self.variable]; (variable.get_value(), variable.get_message(self.id)) }; let c = x.pi() - fx.pi(); let d = x.tau() - fx.tau(); let sqrt_c = c.sqrt(); let t = d / sqrt_c; let e = self.epsilon * sqrt_c; let (v, w) = if self.draw { (v_draw(t, e), w_draw(t, e)) } else { (v_win(t, e), w_win(t, e)) }; let m_w = 1.0 - w; let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w); debug!( "Trunc::up var={:?}, value={:?}", self.variable.index, gaussian ); variable_arena[self.variable].update_value(self.id, gaussian) } }