diff --git a/src/factor_graph.rs b/src/factor_graph.rs index c46ceff..150a995 100644 --- a/src/factor_graph.rs +++ b/src/factor_graph.rs @@ -1,13 +1,14 @@ -use std::collections::HashMap; use std::f64; -use std::ops; - -use log::*; use crate::gaussian::Gaussian; -use crate::graph::{MessageArena, MessageId, VariableArena, VariableId}; use crate::math; +mod message; +mod variable; + +pub use message::*; +pub use variable::*; + pub struct PriorFactor { id: MessageId, variable: VariableId, @@ -126,10 +127,8 @@ impl SumFactor { fn internal_update( &self, - variable_arena: &mut VariableArena, - message_arena: &mut MessageArena, - variable: VariableId, - message: MessageId, + variable: &mut Gaussian, + message: &mut Gaussian, y: &[Gaussian], fy: &[Gaussian], a: &[f64], @@ -151,10 +150,10 @@ impl SumFactor { let gaussian = Gaussian::from_pi_tau(new_pi, new_tau); - let old = message_arena[message]; + let old = *message; - message_arena[message] = gaussian; - variable_arena[variable] = variable_arena[variable] / old * gaussian; + *message = gaussian; + *variable = *variable / old * gaussian; } pub fn update_sum(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) { @@ -166,10 +165,8 @@ impl SumFactor { .unzip(); self.internal_update( - variable_arena, - message_arena, - self.sum, - self.sum_msg, + &mut variable_arena[self.sum], + &mut message_arena[self.sum_msg], &y, &fy, &self.coeffs, @@ -214,10 +211,8 @@ impl SumFactor { .unzip(); self.internal_update( - variable_arena, - message_arena, - idx_term, - idx_term_msg, + &mut variable_arena[idx_term], + &mut message_arena[idx_term_msg], &y, &fy, &a, diff --git a/src/graph/message.rs b/src/factor_graph/message.rs similarity index 100% rename from src/graph/message.rs rename to src/factor_graph/message.rs diff --git a/src/graph/variable.rs b/src/factor_graph/variable.rs similarity index 100% rename from src/graph/variable.rs rename to src/factor_graph/variable.rs diff --git a/src/graph.rs b/src/graph.rs deleted file mode 100644 index 07b1388..0000000 --- a/src/graph.rs +++ /dev/null @@ -1,420 +0,0 @@ -mod message; -mod variable; - -pub use message::*; -pub use variable::*; - -/* -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct FactorId(usize); - -pub struct Graph { - variables: VariableArena, - factors: Vec, - messages: MessageArena, -} - -impl Graph { - #[inline(always)] - pub fn add_variable(&mut self) -> VariableId { - self.variables.create() - } - - pub fn add_factor(&mut self, factor: Factor) -> FactorId { - let idx = self.factors.len(); - self.factors.push(factor); - - FactorId(idx) - } -} - -pub enum Factor { - Prior { gaussian: Gaussian }, -} -*/ -// FACTORS - -/* -pub struct PriorFactor { - id: usize, - variable: VariableId, - gaussian: Gaussian, -} - -impl PriorFactor { - pub fn new( - variable_arena: &mut VariableArena, - id: usize, - variable: VariableId, - gaussian: Gaussian, - ) -> PriorFactor { - variable_arena[variable].attach_factor(id); - - PriorFactor { - id, - variable, - gaussian, - } - } - - pub fn start(&self, variable_arena: &mut VariableArena) { - debug!( - "Prior::down var={:?}, value={:?}", - self.variable.index, self.gaussian - ); - - variable_arena[self.variable].update_value(self.id, self.gaussian); - } -} -*/ - -/* -pub struct Variable { - value: Gaussian, - factors: HashMap, -} - -impl Variable { - pub fn attach_factor(&mut self, factor: usize) { - self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0)); - } - - pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 { - let old = self.factors[&factor]; - - self.factors.insert(factor, value * old / self.value); - - let pi_delta = (self.value.pi() - value.pi()).abs(); - - let delta = if !pi_delta.is_finite() { - 0.0 - } else { - let pi_delta = pi_delta.sqrt(); - let tau_delta = (self.value.tau() - value.tau()).abs(); - - if pi_delta > tau_delta { - pi_delta - } else { - tau_delta - } - }; - - self.value = value; - - debug!("Variable::value old={:?}, new={:?}", old, value); - - delta - } - - pub fn get_value(&self) -> Gaussian { - self.value - } - - pub fn get_message(&self, factor: usize) -> Gaussian { - self.factors[&factor] - } - - pub fn update_message(&mut self, factor: usize, message: Gaussian) { - let old = self.factors[&factor]; - - self.factors.insert(factor, message); - self.value = self.value / old * message; - - debug!("Variable::message old={:?}, new={:?}", old, message); - } -} -*/ - -/* -pub struct PriorFactor { - id: usize, - variable: VariableId, - gaussian: Gaussian, -} - -impl PriorFactor { - pub fn new( - variable_arena: &mut VariableArena, - id: usize, - variable: VariableId, - gaussian: Gaussian, - ) -> PriorFactor { - variable_arena[variable].attach_factor(id); - - PriorFactor { - id, - variable, - gaussian, - } - } - - pub fn start(&self, variable_arena: &mut VariableArena) { - debug!( - "Prior::down var={:?}, value={:?}", - self.variable.index, self.gaussian - ); - - variable_arena[self.variable].update_value(self.id, self.gaussian); - } -} - -pub struct LikelihoodFactor { - id: usize, - mean: VariableId, - value: VariableId, - variance: f64, -} - -impl LikelihoodFactor { - pub fn new( - variable_arena: &mut VariableArena, - id: usize, - mean: VariableId, - value: VariableId, - variance: f64, - ) -> LikelihoodFactor { - variable_arena[mean].attach_factor(id); - variable_arena[value].attach_factor(id); - - LikelihoodFactor { - id, - mean, - value, - variance, - } - } - - pub fn update_mean(&self, variable_arena: &mut VariableArena) { - let (x, fx) = { - let variable = &variable_arena[self.value]; - - (variable.get_value(), variable.get_message(self.id)) - }; - - let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); - let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau())); - - debug!( - "Likelihood::up var={:?}, value={:?}", - self.mean.index, gaussian - ); - - variable_arena[self.mean].update_message(self.id, gaussian); - } - - pub fn update_value(&self, variable_arena: &mut VariableArena) { - let (y, fy) = { - let variable = &variable_arena[self.mean]; - - (variable.get_value(), variable.get_message(self.id)) - }; - - let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); - let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau())); - - debug!( - "Likelihood::down var={:?}, value={:?}", - self.value.index, gaussian - ); - - variable_arena[self.value].update_message(self.id, gaussian); - } -} - -pub struct SumFactor { - id: usize, - sum: VariableId, - terms: Vec, - coeffs: Vec, -} - -impl SumFactor { - pub fn new( - variable_arena: &mut VariableArena, - id: usize, - sum: VariableId, - terms: Vec, - coeffs: Vec, - ) -> SumFactor { - variable_arena[sum].attach_factor(id); - - for term in &terms { - variable_arena[*term].attach_factor(id); - } - - SumFactor { - id, - sum, - terms, - coeffs, - } - } - - fn internal_update( - &self, - variable_arena: &mut VariableArena, - variable: VariableId, - y: &[Gaussian], - fy: &[Gaussian], - a: &[f64], - ) { - let (sum_pi, sum_tau) = - a.iter() - .zip(y.iter().zip(fy.iter())) - .fold((0.0, 0.0f64), |(pi, tau), (a, (y, fy))| { - let x = *y / *fy; - - let new_pi = a.powi(2) / x.pi(); - let new_tau = a * x.mu(); - - (pi + new_pi, tau + new_tau) - }); - - let new_pi = 1.0 / sum_pi; - let new_tau = new_pi * sum_tau; - - let gaussian = Gaussian::from_pi_tau(new_pi, new_tau); - - if variable == self.sum { - debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian); - } else { - debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian); - } - - variable_arena[variable].update_message(self.id, gaussian); - } - - pub fn update_sum(&self, variable_arena: &mut VariableArena) { - let (y, fy): (Vec<_>, Vec<_>) = self - .terms - .iter() - .map(|term| { - let variable = &variable_arena[*term]; - - (variable.get_value(), variable.get_message(self.id)) - }) - .unzip(); - - self.internal_update(variable_arena, self.sum, &y, &fy, &self.coeffs); - } - - pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) { - let idx_term = self.terms[index]; - let idx_coeff = self.coeffs[index]; - - let a = self - .coeffs - .iter() - .enumerate() - .map(|(i, coeff)| { - if i == index { - 1.0 / idx_coeff - } else { - -coeff / idx_coeff - } - }) - .collect::>(); - - let (y, fy): (Vec<_>, Vec<_>) = self - .terms - .iter() - .enumerate() - .map(|(i, term)| { - let variable_id = if i == index { self.sum } else { *term }; - let variable = &variable_arena[variable_id]; - - (variable.get_value(), variable.get_message(self.id)) - }) - .unzip(); - - self.internal_update(variable_arena, idx_term, &y, &fy, &a); - } - - pub fn update_all_terms(&self, variable_arena: &mut VariableArena) { - for i in 0..self.terms.len() { - self.update_term(variable_arena, i); - } - } -} - -fn v_win(t: f64, e: f64) -> f64 { - math::pdf(t - e) / math::cdf(t - e) -} - -fn w_win(t: f64, e: f64) -> f64 { - let vwin = v_win(t, e); - - vwin * (vwin + t - e) -} - -fn v_draw(t: f64, e: f64) -> f64 { - (math::pdf(-e - t) - math::pdf(e - t)) / (math::cdf(e - t) - math::cdf(-e - t)) -} - -fn w_draw(t: f64, e: f64) -> f64 { - let vdraw = v_draw(t, e); - let n = (vdraw * vdraw) + ((e - t) * math::pdf(e - t) + (e + t) * math::pdf(e + t)); - let d = math::cdf(e - t) - math::cdf(-e - t); - - n / d -} - -pub struct TruncateFactor { - id: usize, - variable: VariableId, - epsilon: f64, - draw: bool, -} - -impl TruncateFactor { - pub fn new( - variable_arena: &mut VariableArena, - id: usize, - variable: VariableId, - epsilon: f64, - draw: bool, - ) -> TruncateFactor { - variable_arena[variable].attach_factor(id); - - TruncateFactor { - id, - variable, - epsilon, - draw, - } - } - - pub fn update(&self, variable_arena: &mut VariableArena) -> f64 { - let (x, fx) = { - let variable = &variable_arena[self.variable]; - - (variable.get_value(), variable.get_message(self.id)) - }; - - let c = x.pi() - fx.pi(); - let d = x.tau() - fx.tau(); - - let sqrt_c = c.sqrt(); - - let t = d / sqrt_c; - let e = self.epsilon * sqrt_c; - - let (v, w) = if self.draw { - (v_draw(t, e), w_draw(t, e)) - } else { - (v_win(t, e), w_win(t, e)) - }; - - let m_w = 1.0 - w; - - let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w); - - debug!( - "Trunc::up var={:?}, value={:?}", - self.variable.index, gaussian - ); - - variable_arena[self.variable].update_value(self.id, gaussian) - } -} -*/ diff --git a/src/graph/distribution.rs b/src/graph/distribution.rs deleted file mode 100644 index 9ee54c2..0000000 --- a/src/graph/distribution.rs +++ /dev/null @@ -1,42 +0,0 @@ -use std::marker::PhantomData; -use std::ops; - -use crate::gaussian::Gaussian; - -pub struct DistributionArena { - arena: Vec, - _phantom: PhantomData, -} - -impl DistributionArena -where - T: From, -{ - pub fn create(&mut self) -> T { - let idx = self.arena.len(); - - self.arena.push(Gaussian::from_pi_tau(0.0, 0.0)); - - idx.into() - } -} - -impl ops::Index for DistributionArena -where - T: Into, -{ - type Output = Gaussian; - - fn index(&self, id: T) -> &Self::Output { - &self.arena[id.into()] - } -} - -impl ops::IndexMut for DistributionArena -where - T: Into, -{ - fn index_mut(&mut self, id: T) -> &mut Self::Output { - &mut self.arena[id.into()] - } -} diff --git a/src/lib.rs b/src/lib.rs index 1c0dc73..6678aba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,13 +2,11 @@ mod factor_graph; mod gaussian; -mod graph; mod math; mod matrix; use crate::factor_graph::*; use crate::gaussian::Gaussian; -use crate::graph::{MessageArena, VariableArena}; use crate::matrix::Matrix; /// Default initial mean of ratings.