It works!
This commit is contained in:
@@ -4,7 +4,9 @@ version = "0.1.0"
|
|||||||
authors = ["Anders Olsson <anders.e.olsson@gmail.com>"]
|
authors = ["Anders Olsson <anders.e.olsson@gmail.com>"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
log = "0.4"
|
||||||
statrs = "0.10"
|
statrs = "0.10"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
approx = "0.3"
|
approx = "0.3"
|
||||||
|
env_logger = "0.5"
|
||||||
|
|||||||
@@ -54,6 +54,8 @@ impl Variable {
|
|||||||
|
|
||||||
self.factors.insert(factor, value * old / self.value);
|
self.factors.insert(factor, value * old / self.value);
|
||||||
self.value = value;
|
self.value = value;
|
||||||
|
|
||||||
|
// debug!("Variable::value value={:?}", self.value);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_value(&self) -> Gaussian {
|
pub fn get_value(&self) -> Gaussian {
|
||||||
@@ -65,6 +67,8 @@ impl Variable {
|
|||||||
|
|
||||||
self.factors.insert(factor, message);
|
self.factors.insert(factor, message);
|
||||||
self.value = self.value / old * message;
|
self.value = self.value / old * message;
|
||||||
|
|
||||||
|
// debug!("Variable::message value={:?}", self.value);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_message(&self, factor: usize) -> Gaussian {
|
pub fn get_message(&self, factor: usize) -> Gaussian {
|
||||||
@@ -95,6 +99,8 @@ impl PriorFactor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn start(&self, variable_arena: &mut VariableArena) {
|
pub fn start(&self, variable_arena: &mut VariableArena) {
|
||||||
|
debug!("Prior::down var={:?}, value={:?}", self.variable.index, self.gaussian);
|
||||||
|
|
||||||
variable_arena
|
variable_arena
|
||||||
.get_mut(self.variable)
|
.get_mut(self.variable)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@@ -137,6 +143,8 @@ impl LikelihoodFactor {
|
|||||||
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
||||||
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
|
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
|
||||||
|
|
||||||
|
debug!("Likelihood::up var={:?}, value={:?}", self.mean.index, gaussian);
|
||||||
|
|
||||||
variable_arena
|
variable_arena
|
||||||
.get_mut(self.mean)
|
.get_mut(self.mean)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@@ -152,6 +160,8 @@ impl LikelihoodFactor {
|
|||||||
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
||||||
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
|
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
|
||||||
|
|
||||||
|
debug!("Likelihood::down var={:?}, value={:?}", self.value.index, gaussian);
|
||||||
|
|
||||||
variable_arena
|
variable_arena
|
||||||
.get_mut(self.value)
|
.get_mut(self.value)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@@ -210,6 +220,12 @@ impl SumFactor {
|
|||||||
|
|
||||||
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
|
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
|
||||||
|
|
||||||
|
if variable == self.sum {
|
||||||
|
debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian);
|
||||||
|
} else {
|
||||||
|
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
|
||||||
|
}
|
||||||
|
|
||||||
variable_arena
|
variable_arena
|
||||||
.get_mut(variable)
|
.get_mut(variable)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@@ -264,6 +280,12 @@ impl SumFactor {
|
|||||||
|
|
||||||
self.internal_update(variable_arena, idx_term, y, fy, &a);
|
self.internal_update(variable_arena, idx_term, y, fy, &a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn update_all_terms(&self, variable_arena: &mut VariableArena) {
|
||||||
|
for i in 0..self.terms.len() {
|
||||||
|
self.update_term(variable_arena, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn v_win(t: f64, e: f64) -> f64 {
|
fn v_win(t: f64, e: f64) -> f64 {
|
||||||
@@ -337,6 +359,8 @@ impl TruncateFactor {
|
|||||||
|
|
||||||
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
|
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
|
||||||
|
|
||||||
|
debug!("Trunc::up var={:?}, value={:?}", self.variable.index, gaussian);
|
||||||
|
|
||||||
variable_arena
|
variable_arena
|
||||||
.get_mut(self.variable)
|
.get_mut(self.variable)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|||||||
@@ -1,11 +1,19 @@
|
|||||||
|
use std::fmt;
|
||||||
use std::ops;
|
use std::ops;
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug)]
|
#[derive(Clone, Copy)]
|
||||||
pub struct Gaussian {
|
pub struct Gaussian {
|
||||||
pi: f64,
|
pi: f64,
|
||||||
tau: f64,
|
tau: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for Gaussian {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(f, "N(mu={:.03}, sigma={:.03})", self.mu(), self.sigma())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
impl Gaussian {
|
impl Gaussian {
|
||||||
pub fn from_pi_tau(pi: f64, tau: f64) -> Gaussian {
|
pub fn from_pi_tau(pi: f64, tau: f64) -> Gaussian {
|
||||||
Gaussian { pi, tau }
|
Gaussian { pi, tau }
|
||||||
|
|||||||
12
src/lib.rs
12
src/lib.rs
@@ -1,9 +1,14 @@
|
|||||||
|
#[macro_use]
|
||||||
|
extern crate log;
|
||||||
extern crate statrs;
|
extern crate statrs;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate approx;
|
extern crate approx;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
extern crate env_logger;
|
||||||
|
|
||||||
mod factor_graph;
|
mod factor_graph;
|
||||||
mod gaussian;
|
mod gaussian;
|
||||||
mod math;
|
mod math;
|
||||||
@@ -224,7 +229,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
for factor in &team_perf_layer {
|
for factor in &team_perf_layer {
|
||||||
factor.update_term(&mut variable_arena, 0);
|
factor.update_all_terms(&mut variable_arena);
|
||||||
}
|
}
|
||||||
|
|
||||||
for factor in &perf_layer {
|
for factor in &perf_layer {
|
||||||
@@ -307,6 +312,7 @@ where
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use approx::{AbsDiffEq, RelativeEq};
|
use approx::{AbsDiffEq, RelativeEq};
|
||||||
|
use env_logger;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
@@ -392,6 +398,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rate_2vs2() {
|
fn test_rate_2vs2() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
|
|
||||||
let alice = Rating::new(MU, SIGMA);
|
let alice = Rating::new(MU, SIGMA);
|
||||||
let bob = Rating::new(MU, SIGMA);
|
let bob = Rating::new(MU, SIGMA);
|
||||||
let chris = Rating::new(MU, SIGMA);
|
let chris = Rating::new(MU, SIGMA);
|
||||||
@@ -410,6 +418,4 @@ mod tests {
|
|||||||
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user