From 4fc27841e934d4a8d174adad5fb779dbaa978c44 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Thu, 25 Oct 2018 21:40:38 +0200 Subject: [PATCH] It works! --- Cargo.toml | 2 ++ src/factor_graph.rs | 24 ++++++++++++++++++++++++ src/gaussian.rs | 10 +++++++++- src/lib.rs | 12 +++++++++--- 4 files changed, 44 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index aa1ad51..b7bdec8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,9 @@ version = "0.1.0" authors = ["Anders Olsson "] [dependencies] +log = "0.4" statrs = "0.10" [dev-dependencies] approx = "0.3" +env_logger = "0.5" diff --git a/src/factor_graph.rs b/src/factor_graph.rs index 2b2bf75..607411b 100644 --- a/src/factor_graph.rs +++ b/src/factor_graph.rs @@ -54,6 +54,8 @@ impl Variable { self.factors.insert(factor, value * old / self.value); self.value = value; + + // debug!("Variable::value value={:?}", self.value); } pub fn get_value(&self) -> Gaussian { @@ -65,6 +67,8 @@ impl Variable { self.factors.insert(factor, message); self.value = self.value / old * message; + + // debug!("Variable::message value={:?}", self.value); } pub fn get_message(&self, factor: usize) -> Gaussian { @@ -95,6 +99,8 @@ impl PriorFactor { } pub fn start(&self, variable_arena: &mut VariableArena) { + debug!("Prior::down var={:?}, value={:?}", self.variable.index, self.gaussian); + variable_arena .get_mut(self.variable) .unwrap() @@ -137,6 +143,8 @@ impl LikelihoodFactor { let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau())); + debug!("Likelihood::up var={:?}, value={:?}", self.mean.index, gaussian); + variable_arena .get_mut(self.mean) .unwrap() @@ -152,6 +160,8 @@ impl LikelihoodFactor { let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau())); + debug!("Likelihood::down var={:?}, value={:?}", self.value.index, gaussian); + variable_arena .get_mut(self.value) .unwrap() @@ -210,6 +220,12 @@ impl SumFactor { let gaussian = Gaussian::from_pi_tau(new_pi, new_tau); + if variable == self.sum { + debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian); + } else { + debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian); + } + variable_arena .get_mut(variable) .unwrap() @@ -264,6 +280,12 @@ impl SumFactor { self.internal_update(variable_arena, idx_term, y, fy, &a); } + + pub fn update_all_terms(&self, variable_arena: &mut VariableArena) { + for i in 0..self.terms.len() { + self.update_term(variable_arena, i); + } + } } fn v_win(t: f64, e: f64) -> f64 { @@ -337,6 +359,8 @@ impl TruncateFactor { let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w); + debug!("Trunc::up var={:?}, value={:?}", self.variable.index, gaussian); + variable_arena .get_mut(self.variable) .unwrap() diff --git a/src/gaussian.rs b/src/gaussian.rs index 33166ac..ad26d15 100644 --- a/src/gaussian.rs +++ b/src/gaussian.rs @@ -1,11 +1,19 @@ +use std::fmt; use std::ops; -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy)] pub struct Gaussian { pi: f64, tau: f64, } +impl fmt::Debug for Gaussian { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "N(mu={:.03}, sigma={:.03})", self.mu(), self.sigma()) + } +} + + impl Gaussian { pub fn from_pi_tau(pi: f64, tau: f64) -> Gaussian { Gaussian { pi, tau } diff --git a/src/lib.rs b/src/lib.rs index ea2cd12..75a6396 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,14 @@ +#[macro_use] +extern crate log; extern crate statrs; #[cfg(test)] #[macro_use] extern crate approx; +#[cfg(test)] +extern crate env_logger; + mod factor_graph; mod gaussian; mod math; @@ -224,7 +229,7 @@ where } for factor in &team_perf_layer { - factor.update_term(&mut variable_arena, 0); + factor.update_all_terms(&mut variable_arena); } for factor in &perf_layer { @@ -307,6 +312,7 @@ where #[cfg(test)] mod tests { use approx::{AbsDiffEq, RelativeEq}; + use env_logger; use super::*; @@ -392,6 +398,8 @@ mod tests { #[test] fn test_rate_2vs2() { + let _ = env_logger::try_init(); + let alice = Rating::new(MU, SIGMA); let bob = Rating::new(MU, SIGMA); let chris = Rating::new(MU, SIGMA); @@ -410,6 +418,4 @@ mod tests { assert_relative_eq!(rating, expected, epsilon = EPSILON); } } - - }