Added two test cases from sublee/trueskill.

This commit is contained in:
2018-10-26 13:14:58 +02:00
parent c471ef3399
commit 3424b5f45f
3 changed files with 181 additions and 62 deletions

View File

@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::f64;
use gaussian::Gaussian;
use math;
@@ -71,6 +72,8 @@ impl Variable {
self.value = value;
debug!("Variable::value old={:?}, new={:?}", old, value);
delta
}
@@ -83,6 +86,8 @@ impl Variable {
self.factors.insert(factor, message);
self.value = self.value / old * message;
debug!("Variable::message old={:?}, new={:?}", old, message);
}
pub fn get_message(&self, factor: usize) -> Gaussian {
@@ -229,14 +234,16 @@ impl SumFactor {
fy: Vec<Gaussian>,
a: &[f64],
) {
let (sum_pi, sum_tau) =
a.iter()
.zip(y.iter().zip(fy.iter()))
.fold((0.0, 0.0), |(pi, tau), (a, (y, fy))| {
let x = *y / *fy;
let (sum_pi, sum_tau) = a.iter()
.zip(y.iter().zip(fy.iter()))
.fold((0.0, 0.0f64), |(pi, tau), (a, (y, fy))| {
let x = *y / *fy;
(pi + a.powi(2) / x.pi(), tau + a * x.tau() / x.pi())
});
let new_pi = a.powi(2) / x.pi();
let new_tau = a * x.mu();
(pi + new_pi, tau + new_tau)
});
let new_pi = 1.0 / sum_pi;
let new_tau = new_pi * sum_tau;
@@ -278,12 +285,10 @@ impl SumFactor {
.coeffs
.iter()
.enumerate()
.map(|(i, coeff)| {
if i == index {
1.0 / idx_coeff
} else {
-coeff / idx_coeff
}
.map(|(i, coeff)| if i == index {
1.0 / idx_coeff
} else {
-coeff / idx_coeff
})
.collect::<Vec<_>>();