It works!
This commit is contained in:
@@ -4,7 +4,9 @@ version = "0.1.0"
|
||||
authors = ["Anders Olsson <anders.e.olsson@gmail.com>"]
|
||||
|
||||
[dependencies]
|
||||
log = "0.4"
|
||||
statrs = "0.10"
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.3"
|
||||
env_logger = "0.5"
|
||||
|
||||
@@ -54,6 +54,8 @@ impl Variable {
|
||||
|
||||
self.factors.insert(factor, value * old / self.value);
|
||||
self.value = value;
|
||||
|
||||
// debug!("Variable::value value={:?}", self.value);
|
||||
}
|
||||
|
||||
pub fn get_value(&self) -> Gaussian {
|
||||
@@ -65,6 +67,8 @@ impl Variable {
|
||||
|
||||
self.factors.insert(factor, message);
|
||||
self.value = self.value / old * message;
|
||||
|
||||
// debug!("Variable::message value={:?}", self.value);
|
||||
}
|
||||
|
||||
pub fn get_message(&self, factor: usize) -> Gaussian {
|
||||
@@ -95,6 +99,8 @@ impl PriorFactor {
|
||||
}
|
||||
|
||||
pub fn start(&self, variable_arena: &mut VariableArena) {
|
||||
debug!("Prior::down var={:?}, value={:?}", self.variable.index, self.gaussian);
|
||||
|
||||
variable_arena
|
||||
.get_mut(self.variable)
|
||||
.unwrap()
|
||||
@@ -137,6 +143,8 @@ impl LikelihoodFactor {
|
||||
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
||||
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
|
||||
|
||||
debug!("Likelihood::up var={:?}, value={:?}", self.mean.index, gaussian);
|
||||
|
||||
variable_arena
|
||||
.get_mut(self.mean)
|
||||
.unwrap()
|
||||
@@ -152,6 +160,8 @@ impl LikelihoodFactor {
|
||||
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
||||
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
|
||||
|
||||
debug!("Likelihood::down var={:?}, value={:?}", self.value.index, gaussian);
|
||||
|
||||
variable_arena
|
||||
.get_mut(self.value)
|
||||
.unwrap()
|
||||
@@ -210,6 +220,12 @@ impl SumFactor {
|
||||
|
||||
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
|
||||
|
||||
if variable == self.sum {
|
||||
debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian);
|
||||
} else {
|
||||
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
|
||||
}
|
||||
|
||||
variable_arena
|
||||
.get_mut(variable)
|
||||
.unwrap()
|
||||
@@ -264,6 +280,12 @@ impl SumFactor {
|
||||
|
||||
self.internal_update(variable_arena, idx_term, y, fy, &a);
|
||||
}
|
||||
|
||||
pub fn update_all_terms(&self, variable_arena: &mut VariableArena) {
|
||||
for i in 0..self.terms.len() {
|
||||
self.update_term(variable_arena, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn v_win(t: f64, e: f64) -> f64 {
|
||||
@@ -337,6 +359,8 @@ impl TruncateFactor {
|
||||
|
||||
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
|
||||
|
||||
debug!("Trunc::up var={:?}, value={:?}", self.variable.index, gaussian);
|
||||
|
||||
variable_arena
|
||||
.get_mut(self.variable)
|
||||
.unwrap()
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
use std::fmt;
|
||||
use std::ops;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Gaussian {
|
||||
pi: f64,
|
||||
tau: f64,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Gaussian {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "N(mu={:.03}, sigma={:.03})", self.mu(), self.sigma())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl Gaussian {
|
||||
pub fn from_pi_tau(pi: f64, tau: f64) -> Gaussian {
|
||||
Gaussian { pi, tau }
|
||||
|
||||
12
src/lib.rs
12
src/lib.rs
@@ -1,9 +1,14 @@
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
extern crate statrs;
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
extern crate approx;
|
||||
|
||||
#[cfg(test)]
|
||||
extern crate env_logger;
|
||||
|
||||
mod factor_graph;
|
||||
mod gaussian;
|
||||
mod math;
|
||||
@@ -224,7 +229,7 @@ where
|
||||
}
|
||||
|
||||
for factor in &team_perf_layer {
|
||||
factor.update_term(&mut variable_arena, 0);
|
||||
factor.update_all_terms(&mut variable_arena);
|
||||
}
|
||||
|
||||
for factor in &perf_layer {
|
||||
@@ -307,6 +312,7 @@ where
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use approx::{AbsDiffEq, RelativeEq};
|
||||
use env_logger;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -392,6 +398,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_rate_2vs2() {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
let alice = Rating::new(MU, SIGMA);
|
||||
let bob = Rating::new(MU, SIGMA);
|
||||
let chris = Rating::new(MU, SIGMA);
|
||||
@@ -410,6 +418,4 @@ mod tests {
|
||||
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user