Added two test cases from sublee/trueskill.

This commit is contained in:
2018-10-26 13:14:58 +02:00
parent c471ef3399
commit 3424b5f45f
3 changed files with 181 additions and 62 deletions

View File

@@ -1,4 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::f64;
use gaussian::Gaussian; use gaussian::Gaussian;
use math; use math;
@@ -71,6 +72,8 @@ impl Variable {
self.value = value; self.value = value;
debug!("Variable::value old={:?}, new={:?}", old, value);
delta delta
} }
@@ -83,6 +86,8 @@ impl Variable {
self.factors.insert(factor, message); self.factors.insert(factor, message);
self.value = self.value / old * message; self.value = self.value / old * message;
debug!("Variable::message old={:?}, new={:?}", old, message);
} }
pub fn get_message(&self, factor: usize) -> Gaussian { pub fn get_message(&self, factor: usize) -> Gaussian {
@@ -229,13 +234,15 @@ impl SumFactor {
fy: Vec<Gaussian>, fy: Vec<Gaussian>,
a: &[f64], a: &[f64],
) { ) {
let (sum_pi, sum_tau) = let (sum_pi, sum_tau) = a.iter()
a.iter()
.zip(y.iter().zip(fy.iter())) .zip(y.iter().zip(fy.iter()))
.fold((0.0, 0.0), |(pi, tau), (a, (y, fy))| { .fold((0.0, 0.0f64), |(pi, tau), (a, (y, fy))| {
let x = *y / *fy; let x = *y / *fy;
(pi + a.powi(2) / x.pi(), tau + a * x.tau() / x.pi()) let new_pi = a.powi(2) / x.pi();
let new_tau = a * x.mu();
(pi + new_pi, tau + new_tau)
}); });
let new_pi = 1.0 / sum_pi; let new_pi = 1.0 / sum_pi;
@@ -278,12 +285,10 @@ impl SumFactor {
.coeffs .coeffs
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, coeff)| { .map(|(i, coeff)| if i == index {
if i == index {
1.0 / idx_coeff 1.0 / idx_coeff
} else { } else {
-coeff / idx_coeff -coeff / idx_coeff
}
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();

View File

@@ -9,7 +9,7 @@ pub struct Gaussian {
impl fmt::Debug for Gaussian { impl fmt::Debug for Gaussian {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "N(mu={:.03}, sigma={:.03})", self.mu(), self.sigma()) write!(f, "N(mu={:.03}, sigma={:.03}, pi={:.03}, tau={:.03})", self.mu(), self.sigma(), self.pi(), self.tau())
} }
} }

View File

@@ -346,6 +346,18 @@ mod tests {
} }
} }
fn generate_free_for_all(players: u16) -> (Vec<(Rating, u16)>, Vec<u16>) {
(0..players)
.map(|i| ((Rating::new(MU, SIGMA), i), i))
.unzip()
}
fn generate_free_for_all_draw(players: u16) -> (Vec<(Rating, u16)>, Vec<u16>) {
(0..players)
.map(|i| ((Rating::new(MU, SIGMA), i), 0))
.unzip()
}
#[test] #[test]
fn test_quality_1vs1() { fn test_quality_1vs1() {
let alice = Rating::new(MU, SIGMA); let alice = Rating::new(MU, SIGMA);
@@ -367,9 +379,9 @@ mod tests {
let expected_ratings = vec![ let expected_ratings = vec![
Rating::new(33.20668089876779, 6.34810941351329), Rating::new(33.20668089876779, 6.34810941351329),
Rating::new(27.401455165087352, 5.7871628131345645), Rating::new(27.40145516508735, 5.78716281313456),
Rating::new(22.598544839299667, 5.787162810091708), Rating::new(22.59854483929966, 5.78716281009170),
Rating::new(16.793319100187123, 6.348109386031168), Rating::new(16.79331910018712, 6.34810938603116),
]; ];
let ratings = rate( let ratings = rate(
@@ -385,18 +397,17 @@ mod tests {
#[test] #[test]
fn test_rate_8_free_for_all_draw() { fn test_rate_8_free_for_all_draw() {
let (ratings, ranks): (Vec<_>, Vec<_>) = let (ratings, ranks) = generate_free_for_all_draw(8);
(0..8u16).map(|i| ((Rating::new(MU, SIGMA), i), 0)).unzip();
let expected_ratings = vec![ let expected_ratings = vec![
Rating::new(25.000000000000014, 4.592173723582464), Rating::new(25.00000000000001, 4.59217372358246),
Rating::new(25.000000000000004, 4.582694291508923), Rating::new(25.00000000000000, 4.58269429150892),
Rating::new(25.000000000000004, 4.576403088132029), Rating::new(25.00000000000000, 4.57640308813202),
Rating::new(25.000000000000004, 4.5732660302525785), Rating::new(25.00000000000000, 4.57326603025257),
Rating::new(25.000000000000004, 4.573266030252499), Rating::new(25.00000000000000, 4.57326603025249),
Rating::new(25.000000000000004, 4.5764030881313005), Rating::new(25.00000000000000, 4.57640308813130),
Rating::new(25.000000000000004, 4.582694291504568), Rating::new(25.00000000000000, 4.58269429150456),
Rating::new(25.000000000000004, 4.5921737235617845), Rating::new(25.00000000000000, 4.59217372356178),
]; ];
let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA); let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA);
@@ -408,26 +419,25 @@ mod tests {
#[test] #[test]
fn test_rate_16_free_for_all() { fn test_rate_16_free_for_all() {
let (ratings, ranks): (Vec<_>, Vec<_>) = let (ratings, ranks) = generate_free_for_all(16);
(0..16u16).map(|i| ((Rating::new(MU, SIGMA), i), i)).unzip();
let expected_ratings = vec![ let expected_ratings = vec![
Rating::new(38.331788391720295, 5.541377093821206), Rating::new(38.33178839172029, 5.54137709382120),
Rating::new(34.746418422647835, 4.891059608895904), Rating::new(34.74641842264783, 4.89105960889590),
Rating::new(32.61843846390381, 4.65295506775174), Rating::new(32.61843846390381, 4.65295506775174),
Rating::new(31.046552382343123, 4.522499560708352), Rating::new(31.04655238234312, 4.52249956070835),
Rating::new(29.776879742967576, 4.4373517127954045), Rating::new(29.77687974296757, 4.43735171279540),
Rating::new(28.700098150398343, 4.376108908633618), Rating::new(28.70009815039834, 4.37610890863361),
Rating::new(27.27926253322585, 4.345628543772217), Rating::new(27.27926253322585, 4.34562854377221),
Rating::new(25.770921292726943, 4.331516572196449), Rating::new(25.77092129272694, 4.33151657219644),
Rating::new(24.24208410982518, 4.328837113282726), Rating::new(24.24208410982518, 4.32883711328272),
Rating::new(22.69645835682008, 4.3358108926893175), Rating::new(22.69645835682008, 4.33581089268931),
Rating::new(21.111803677204012, 4.353110849422467), Rating::new(21.11180367720401, 4.35311084942246),
Rating::new(19.453144013532874, 4.383947550430257), Rating::new(19.45314401353287, 4.38394755043025),
Rating::new(17.666795353614532, 4.43588072704469), Rating::new(17.66679535361453, 4.43588072704469),
Rating::new(15.65513147838557, 4.527230274864565), Rating::new(15.65513147838557, 4.52723027486456),
Rating::new(13.192401824247439, 4.713609609918767), Rating::new(13.19240182424743, 4.71360960991876),
Rating::new(9.461957485668828, 5.277768698162301), Rating::new(9.461957485668828, 5.27776869816230),
]; ];
let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA); let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA);
@@ -443,8 +453,8 @@ mod tests {
let bob = Rating::new(MU, SIGMA); let bob = Rating::new(MU, SIGMA);
let expected_ratings = vec![ let expected_ratings = vec![
Rating::new(25.0, 6.457515683245051), Rating::new(25.00000000000000, 6.45751568324505),
Rating::new(25.0, 6.457515683245051), Rating::new(25.00000000000000, 6.45751568324505),
]; ];
let ratings = rate(&[(alice, 0), (bob, 1)], &[0, 0], DELTA); let ratings = rate(&[(alice, 0), (bob, 1)], &[0, 0], DELTA);
@@ -456,18 +466,16 @@ mod tests {
#[test] #[test]
fn test_rate_2vs2() { fn test_rate_2vs2() {
let _ = env_logger::try_init();
let alice = Rating::new(MU, SIGMA); let alice = Rating::new(MU, SIGMA);
let bob = Rating::new(MU, SIGMA); let bob = Rating::new(MU, SIGMA);
let chris = Rating::new(MU, SIGMA); let chris = Rating::new(MU, SIGMA);
let darren = Rating::new(MU, SIGMA); let darren = Rating::new(MU, SIGMA);
let expected_ratings = vec![ let expected_ratings = vec![
Rating::new(28.108322399069035, 7.77436345109384), Rating::new(28.10832239906903, 7.77436345109384),
Rating::new(28.108322399069035, 7.77436345109384), Rating::new(28.10832239906903, 7.77436345109384),
Rating::new(21.891677600930958, 7.77436345109384), Rating::new(21.89167760093095, 7.77436345109384),
Rating::new(21.891677600930958, 7.77436345109384), Rating::new(21.89167760093095, 7.77436345109384),
]; ];
let ratings = rate( let ratings = rate(
@@ -483,8 +491,6 @@ mod tests {
#[test] #[test]
fn test_rate_4vs4() { fn test_rate_4vs4() {
let _ = env_logger::try_init();
let alice = Rating::new(MU, SIGMA); let alice = Rating::new(MU, SIGMA);
let bob = Rating::new(MU, SIGMA); let bob = Rating::new(MU, SIGMA);
let chris = Rating::new(MU, SIGMA); let chris = Rating::new(MU, SIGMA);
@@ -495,14 +501,14 @@ mod tests {
let laura = Rating::new(MU, SIGMA); let laura = Rating::new(MU, SIGMA);
let expected_ratings = vec![ let expected_ratings = vec![
Rating::new(27.19791584649575, 8.058911711843994), Rating::new(27.19791584649575, 8.05891171184399),
Rating::new(27.19791584649575, 8.058911711843994), Rating::new(27.19791584649575, 8.05891171184399),
Rating::new(27.19791584649575, 8.058911711843994), Rating::new(27.19791584649575, 8.05891171184399),
Rating::new(27.19791584649575, 8.058911711843994), Rating::new(27.19791584649575, 8.05891171184399),
Rating::new(22.802084153504236, 8.058911711843994), Rating::new(22.80208415350423, 8.05891171184399),
Rating::new(22.802084153504236, 8.058911711843994), Rating::new(22.80208415350423, 8.05891171184399),
Rating::new(22.802084153504236, 8.058911711843994), Rating::new(22.80208415350423, 8.05891171184399),
Rating::new(22.802084153504236, 8.058911711843994), Rating::new(22.80208415350423, 8.05891171184399),
]; ];
let ratings = rate( let ratings = rate(
@@ -524,4 +530,112 @@ mod tests {
assert_relative_eq!(rating, expected, epsilon = EPSILON); assert_relative_eq!(rating, expected, epsilon = EPSILON);
} }
} }
#[test]
fn test_rate_sublee_trueskill_issue_3_case_1() {
let ratings = vec![
(Rating::new(42.234, 3.728), 0),
(Rating::new(43.290, 3.842), 0),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1),
];
let expected_ratings = vec![
Rating::new(49.04740037944730, 3.64544755829110),
Rating::new(50.52625967591528, 3.75146367648882),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812),
];
let ratings = rate(
ratings.as_ref(),
&[0, 1],
DELTA,
);
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
assert_relative_eq!(rating, expected, epsilon = EPSILON);
}
}
#[test]
fn test_rate_sublee_trueskill_issue_3_case_2() {
let _ = env_logger::try_init();
let ratings = vec![
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0),
(Rating::new(42.234, 3.728), 1),
(Rating::new(43.290, 3.842), 1),
];
let expected_ratings = vec![
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(42.23400000000000, 3.72893127376255),
Rating::new(43.29000000000000, 3.84290364756188),
];
let ratings = rate(
ratings.as_ref(),
&[0, 1],
DELTA,
);
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
assert_relative_eq!(rating, expected, epsilon = EPSILON);
}
}
} }