It works!

This commit is contained in:
2018-10-25 21:40:38 +02:00
parent e8aa60fcdd
commit 4fc27841e9
4 changed files with 44 additions and 4 deletions

View File

@@ -4,7 +4,9 @@ version = "0.1.0"
authors = ["Anders Olsson <anders.e.olsson@gmail.com>"]
[dependencies]
log = "0.4"
statrs = "0.10"
[dev-dependencies]
approx = "0.3"
env_logger = "0.5"

View File

@@ -54,6 +54,8 @@ impl Variable {
self.factors.insert(factor, value * old / self.value);
self.value = value;
// debug!("Variable::value value={:?}", self.value);
}
pub fn get_value(&self) -> Gaussian {
@@ -65,6 +67,8 @@ impl Variable {
self.factors.insert(factor, message);
self.value = self.value / old * message;
// debug!("Variable::message value={:?}", self.value);
}
pub fn get_message(&self, factor: usize) -> Gaussian {
@@ -95,6 +99,8 @@ impl PriorFactor {
}
pub fn start(&self, variable_arena: &mut VariableArena) {
debug!("Prior::down var={:?}, value={:?}", self.variable.index, self.gaussian);
variable_arena
.get_mut(self.variable)
.unwrap()
@@ -137,6 +143,8 @@ impl LikelihoodFactor {
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
debug!("Likelihood::up var={:?}, value={:?}", self.mean.index, gaussian);
variable_arena
.get_mut(self.mean)
.unwrap()
@@ -152,6 +160,8 @@ impl LikelihoodFactor {
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
debug!("Likelihood::down var={:?}, value={:?}", self.value.index, gaussian);
variable_arena
.get_mut(self.value)
.unwrap()
@@ -210,6 +220,12 @@ impl SumFactor {
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
if variable == self.sum {
debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian);
} else {
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
}
variable_arena
.get_mut(variable)
.unwrap()
@@ -264,6 +280,12 @@ impl SumFactor {
self.internal_update(variable_arena, idx_term, y, fy, &a);
}
pub fn update_all_terms(&self, variable_arena: &mut VariableArena) {
for i in 0..self.terms.len() {
self.update_term(variable_arena, i);
}
}
}
fn v_win(t: f64, e: f64) -> f64 {
@@ -337,6 +359,8 @@ impl TruncateFactor {
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
debug!("Trunc::up var={:?}, value={:?}", self.variable.index, gaussian);
variable_arena
.get_mut(self.variable)
.unwrap()

View File

@@ -1,11 +1,19 @@
use std::fmt;
use std::ops;
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy)]
pub struct Gaussian {
pi: f64,
tau: f64,
}
impl fmt::Debug for Gaussian {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "N(mu={:.03}, sigma={:.03})", self.mu(), self.sigma())
}
}
impl Gaussian {
pub fn from_pi_tau(pi: f64, tau: f64) -> Gaussian {
Gaussian { pi, tau }

View File

@@ -1,9 +1,14 @@
#[macro_use]
extern crate log;
extern crate statrs;
#[cfg(test)]
#[macro_use]
extern crate approx;
#[cfg(test)]
extern crate env_logger;
mod factor_graph;
mod gaussian;
mod math;
@@ -224,7 +229,7 @@ where
}
for factor in &team_perf_layer {
factor.update_term(&mut variable_arena, 0);
factor.update_all_terms(&mut variable_arena);
}
for factor in &perf_layer {
@@ -307,6 +312,7 @@ where
#[cfg(test)]
mod tests {
use approx::{AbsDiffEq, RelativeEq};
use env_logger;
use super::*;
@@ -392,6 +398,8 @@ mod tests {
#[test]
fn test_rate_2vs2() {
let _ = env_logger::try_init();
let alice = Rating::new(MU, SIGMA);
let bob = Rating::new(MU, SIGMA);
let chris = Rating::new(MU, SIGMA);
@@ -410,6 +418,4 @@ mod tests {
assert_relative_eq!(rating, expected, epsilon = EPSILON);
}
}
}