use std::collections::HashMap; use gaussian::Gaussian; use math; #[derive(Clone, Copy, Debug, PartialEq)] pub struct VariableId { index: usize, } pub struct VariableArena { variables: Vec, } impl VariableArena { pub fn new() -> VariableArena { VariableArena { variables: Vec::new(), } } pub fn create(&mut self) -> VariableId { let index = self.variables.len(); self.variables.push(Variable { value: Gaussian::from_precision(0.0, 0.0), factors: HashMap::new(), }); VariableId { index } } pub fn get(&mut self, id: VariableId) -> Option<&Variable> { self.variables.get(id.index) } pub fn get_mut(&mut self, id: VariableId) -> Option<&mut Variable> { self.variables.get_mut(id.index) } } pub struct Variable { value: Gaussian, factors: HashMap, } impl Variable { pub fn attach_factor(&mut self, factor: usize) { self.factors.insert(factor, Gaussian::from_precision(0.0, 0.0)); } pub fn update_value(&mut self, factor: usize, value: Gaussian) { let old = self.factors[&factor]; let intermediate = value * old; let new = intermediate / self.value; /* println!("update_value: old={}, value={:?}, new={:?}, self.value={:?}", render_g(old), value, new, self.value); */ self.value = new; } pub fn get_value(&self) -> Gaussian { self.value } pub fn update_message(&mut self, factor: usize, message: Gaussian) { let old = self.factors[&factor]; let intermediate = self.value / old; let value = intermediate * message; self.value = value; println!("update_message: old={:?} msg={:?}, new={:?}", old, message, value); self.factors.insert(factor, message); } pub fn get_message(&self, factor: usize) -> Gaussian { self.factors[&factor] } } pub struct PriorFactor { id: usize, variable: VariableId, gaussian: Gaussian, } impl PriorFactor { pub fn new( variable_arena: &mut VariableArena, id: usize, variable: VariableId, gaussian: Gaussian, ) -> PriorFactor { variable_arena.get_mut(variable).unwrap().attach_factor(id); PriorFactor { id, variable, gaussian, } } pub fn start(&self, variable_arena: &mut VariableArena) { println!("PriorFactor: variable.id={}, msg={:?}", self.variable.index, self.gaussian); variable_arena.get_mut(self.variable).unwrap().update_value(self.id, self.gaussian); } } pub struct LikelihoodFactor { id: usize, mean: VariableId, value: VariableId, variance: f64, } impl LikelihoodFactor { pub fn new( variable_arena: &mut VariableArena, id: usize, mean: VariableId, value: VariableId, variance: f64, ) -> LikelihoodFactor { variable_arena.get_mut(mean).unwrap().attach_factor(id); variable_arena.get_mut(value).unwrap().attach_factor(id); LikelihoodFactor { id, mean, value, variance, } } pub fn update_mean(&self, variable_arena: &mut VariableArena) { let x = variable_arena .get(self.value) .map(|variable| variable.get_value()) .unwrap(); let fx = variable_arena .get_mut(self.value) .map(|variable| variable.get_message(self.id)) .unwrap(); let a = 1.0 / (1.0 + self.variance * (x.precision_mean() - fx.precision_mean())); let gaussian = Gaussian::from_precision(a * (x.precision_mean() - fx.precision_mean()), a * (x.precision() - fx.precision())); println!("LikelihoodFactor: mean.id={}, msg={:?}", self.mean.index, gaussian); variable_arena.get_mut(self.mean).unwrap().update_message(self.id, gaussian); } pub fn update_value(&self, variable_arena: &mut VariableArena) { let y = variable_arena .get(self.mean) .map(|variable| variable.get_value()) .unwrap(); let fy = variable_arena .get(self.mean) .map(|variable| variable.get_message(self.id)) .unwrap(); let a = 1.0 / (1.0 + self.variance * (y.precision_mean() - fy.precision_mean())); let gaussian = Gaussian::from_precision(a * (y.precision_mean() - fy.precision_mean()), a * (y.precision() - fy.precision())); println!("LikelihoodFactor: value.id={}, msg={:?}", self.value.index, gaussian); variable_arena.get_mut(self.value).unwrap().update_message(self.id, gaussian); } } pub struct SumFactor { id: usize, sum: VariableId, terms: Vec, coeffs: Vec, } impl SumFactor { pub fn new( variable_arena: &mut VariableArena, id: usize, sum: VariableId, terms: Vec, coeffs: Vec, ) -> SumFactor { variable_arena.get_mut(sum).unwrap().attach_factor(id); for term in &terms { variable_arena.get_mut(*term).unwrap().attach_factor(id); } SumFactor { id, sum, terms, coeffs, } } fn internal_update( &self, variable_arena: &mut VariableArena, variable: VariableId, y: Vec, fy: Vec, a: &Vec, ) { let size = a.len(); let mut sum_pi = 0.0; let mut sum_tau = 0.0; for i in 0..size { let da = a[i]; let gy = y[i]; let gfy = fy[i]; sum_pi += da.powi(2) / (gy.pi() - gfy.pi()); sum_tau += da * (gy.tau() - gfy.tau()) / (gy.pi() - gfy.pi()); } let new_pi = 1.0 / sum_pi; let new_tau = new_pi * sum_tau; let gaussian = Gaussian::from_pi_tau(new_pi, new_tau); if variable == self.sum { println!("SumFactor: sum.id={}, msg={:?}", variable.index, gaussian); } else { println!("SumFactor: term.id={}, msg={:?}", variable.index, gaussian); } variable_arena.get_mut(variable).unwrap().update_value(self.id, gaussian); } pub fn update_sum(&self, variable_arena: &mut VariableArena) { let mut y = Vec::new(); for term in &self.terms { let value = variable_arena .get(*term) .map(|variable| variable.get_value()) .unwrap(); y.push(value); } let mut fy = Vec::new(); for term in &self.terms { let value = variable_arena .get(*term) .map(|variable| variable.get_message(self.id)) .unwrap(); fy.push(value); } self.internal_update(variable_arena, self.sum, y, fy, &self.coeffs); } pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) { let size = self.coeffs.len(); let idx_coeff = self.coeffs[index]; let mut a = vec![0.0; size]; for i in 0..size { if i != index { a[i] = -self.coeffs[i] / idx_coeff; } } a[index] = 1.0 / idx_coeff; let idx_term = self.terms[index]; let mut y = Vec::new(); let mut fy = Vec::new(); let mut v = self.terms.clone(); v[index] = self.sum; for term in v { let value = variable_arena .get(term) .map(|variable| variable.get_value()) .unwrap(); y.push(value); let value = variable_arena .get(term) .map(|variable| variable.get_message(self.id)) .unwrap(); fy.push(value); } self.internal_update(variable_arena, idx_term, y, fy, &a); } } fn v_win(t: f64, e: f64) -> f64 { math::pdf(t - e) / math::cdf(t - e) } fn w_win(t: f64, e: f64) -> f64 { let vwin = v_win(t, e); vwin * (vwin + t - e) } fn v_draw(t: f64, e: f64) -> f64 { (math::pdf(-e - t) - math::pdf(e - t)) / (math::cdf(e - t) - math::cdf(-e - t)) } fn w_draw(t: f64, e: f64) -> f64 { let vdraw = v_draw(t, e); let n = (vdraw * vdraw) + ((e - t) * math::pdf(e - t) + (e + t) * math::pdf(e + t)); let d = math::cdf(e - t) - math::cdf(-e - t); n / d } pub struct TruncateFactor { id: usize, variable: VariableId, epsilon: f64, draw: bool, } impl TruncateFactor { pub fn new( variable_arena: &mut VariableArena, id: usize, variable: VariableId, epsilon: f64, draw: bool, ) -> TruncateFactor { variable_arena.get_mut(variable).unwrap().attach_factor(id); TruncateFactor { id, variable, epsilon, draw, } } pub fn update(&self, variable_arena: &mut VariableArena) { let x = variable_arena .get(self.variable) .map(|variable| variable.get_value()) .unwrap(); let fx = variable_arena .get_mut(self.variable) .map(|variable| variable.get_message(self.id)) .unwrap(); let c = x.pi() - fx.pi(); let d = x.tau() - fx.tau(); let sqrt_c = c.sqrt(); let t = d / sqrt_c; let e = self.epsilon * sqrt_c; let (v, w) = if self.draw { (v_draw(t, e), w_draw(t, e)) } else { (v_win(t, e), w_win(t, e)) }; let m_w = 1.0 - w; let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w); variable_arena.get_mut(self.variable).unwrap().update_value(self.id, gaussian); } }