From 3424b5f45fa6018dcb4348e5b1386bddb5efdad6 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Fri, 26 Oct 2018 13:14:58 +0200 Subject: [PATCH] Added two test cases from sublee/trueskill. --- src/factor_graph.rs | 31 ++++--- src/gaussian.rs | 2 +- src/lib.rs | 210 ++++++++++++++++++++++++++++++++++---------- 3 files changed, 181 insertions(+), 62 deletions(-) diff --git a/src/factor_graph.rs b/src/factor_graph.rs index 72fea73..1232b9a 100644 --- a/src/factor_graph.rs +++ b/src/factor_graph.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::f64; use gaussian::Gaussian; use math; @@ -71,6 +72,8 @@ impl Variable { self.value = value; + debug!("Variable::value old={:?}, new={:?}", old, value); + delta } @@ -83,6 +86,8 @@ impl Variable { self.factors.insert(factor, message); self.value = self.value / old * message; + + debug!("Variable::message old={:?}, new={:?}", old, message); } pub fn get_message(&self, factor: usize) -> Gaussian { @@ -229,14 +234,16 @@ impl SumFactor { fy: Vec, a: &[f64], ) { - let (sum_pi, sum_tau) = - a.iter() - .zip(y.iter().zip(fy.iter())) - .fold((0.0, 0.0), |(pi, tau), (a, (y, fy))| { - let x = *y / *fy; + let (sum_pi, sum_tau) = a.iter() + .zip(y.iter().zip(fy.iter())) + .fold((0.0, 0.0f64), |(pi, tau), (a, (y, fy))| { + let x = *y / *fy; - (pi + a.powi(2) / x.pi(), tau + a * x.tau() / x.pi()) - }); + let new_pi = a.powi(2) / x.pi(); + let new_tau = a * x.mu(); + + (pi + new_pi, tau + new_tau) + }); let new_pi = 1.0 / sum_pi; let new_tau = new_pi * sum_tau; @@ -278,12 +285,10 @@ impl SumFactor { .coeffs .iter() .enumerate() - .map(|(i, coeff)| { - if i == index { - 1.0 / idx_coeff - } else { - -coeff / idx_coeff - } + .map(|(i, coeff)| if i == index { + 1.0 / idx_coeff + } else { + -coeff / idx_coeff }) .collect::>(); diff --git a/src/gaussian.rs b/src/gaussian.rs index e63c975..6677d88 100644 --- a/src/gaussian.rs +++ b/src/gaussian.rs @@ -9,7 +9,7 @@ pub struct Gaussian { impl fmt::Debug for Gaussian { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "N(mu={:.03}, sigma={:.03})", self.mu(), self.sigma()) + write!(f, "N(mu={:.03}, sigma={:.03}, pi={:.03}, tau={:.03})", self.mu(), self.sigma(), self.pi(), self.tau()) } } diff --git a/src/lib.rs b/src/lib.rs index 50bc991..dcac01d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -346,6 +346,18 @@ mod tests { } } + fn generate_free_for_all(players: u16) -> (Vec<(Rating, u16)>, Vec) { + (0..players) + .map(|i| ((Rating::new(MU, SIGMA), i), i)) + .unzip() + } + + fn generate_free_for_all_draw(players: u16) -> (Vec<(Rating, u16)>, Vec) { + (0..players) + .map(|i| ((Rating::new(MU, SIGMA), i), 0)) + .unzip() + } + #[test] fn test_quality_1vs1() { let alice = Rating::new(MU, SIGMA); @@ -367,9 +379,9 @@ mod tests { let expected_ratings = vec![ Rating::new(33.20668089876779, 6.34810941351329), - Rating::new(27.401455165087352, 5.7871628131345645), - Rating::new(22.598544839299667, 5.787162810091708), - Rating::new(16.793319100187123, 6.348109386031168), + Rating::new(27.40145516508735, 5.78716281313456), + Rating::new(22.59854483929966, 5.78716281009170), + Rating::new(16.79331910018712, 6.34810938603116), ]; let ratings = rate( @@ -385,18 +397,17 @@ mod tests { #[test] fn test_rate_8_free_for_all_draw() { - let (ratings, ranks): (Vec<_>, Vec<_>) = - (0..8u16).map(|i| ((Rating::new(MU, SIGMA), i), 0)).unzip(); + let (ratings, ranks) = generate_free_for_all_draw(8); let expected_ratings = vec![ - Rating::new(25.000000000000014, 4.592173723582464), - Rating::new(25.000000000000004, 4.582694291508923), - Rating::new(25.000000000000004, 4.576403088132029), - Rating::new(25.000000000000004, 4.5732660302525785), - Rating::new(25.000000000000004, 4.573266030252499), - Rating::new(25.000000000000004, 4.5764030881313005), - Rating::new(25.000000000000004, 4.582694291504568), - Rating::new(25.000000000000004, 4.5921737235617845), + Rating::new(25.00000000000001, 4.59217372358246), + Rating::new(25.00000000000000, 4.58269429150892), + Rating::new(25.00000000000000, 4.57640308813202), + Rating::new(25.00000000000000, 4.57326603025257), + Rating::new(25.00000000000000, 4.57326603025249), + Rating::new(25.00000000000000, 4.57640308813130), + Rating::new(25.00000000000000, 4.58269429150456), + Rating::new(25.00000000000000, 4.59217372356178), ]; let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA); @@ -408,26 +419,25 @@ mod tests { #[test] fn test_rate_16_free_for_all() { - let (ratings, ranks): (Vec<_>, Vec<_>) = - (0..16u16).map(|i| ((Rating::new(MU, SIGMA), i), i)).unzip(); + let (ratings, ranks) = generate_free_for_all(16); let expected_ratings = vec![ - Rating::new(38.331788391720295, 5.541377093821206), - Rating::new(34.746418422647835, 4.891059608895904), + Rating::new(38.33178839172029, 5.54137709382120), + Rating::new(34.74641842264783, 4.89105960889590), Rating::new(32.61843846390381, 4.65295506775174), - Rating::new(31.046552382343123, 4.522499560708352), - Rating::new(29.776879742967576, 4.4373517127954045), - Rating::new(28.700098150398343, 4.376108908633618), - Rating::new(27.27926253322585, 4.345628543772217), - Rating::new(25.770921292726943, 4.331516572196449), - Rating::new(24.24208410982518, 4.328837113282726), - Rating::new(22.69645835682008, 4.3358108926893175), - Rating::new(21.111803677204012, 4.353110849422467), - Rating::new(19.453144013532874, 4.383947550430257), - Rating::new(17.666795353614532, 4.43588072704469), - Rating::new(15.65513147838557, 4.527230274864565), - Rating::new(13.192401824247439, 4.713609609918767), - Rating::new(9.461957485668828, 5.277768698162301), + Rating::new(31.04655238234312, 4.52249956070835), + Rating::new(29.77687974296757, 4.43735171279540), + Rating::new(28.70009815039834, 4.37610890863361), + Rating::new(27.27926253322585, 4.34562854377221), + Rating::new(25.77092129272694, 4.33151657219644), + Rating::new(24.24208410982518, 4.32883711328272), + Rating::new(22.69645835682008, 4.33581089268931), + Rating::new(21.11180367720401, 4.35311084942246), + Rating::new(19.45314401353287, 4.38394755043025), + Rating::new(17.66679535361453, 4.43588072704469), + Rating::new(15.65513147838557, 4.52723027486456), + Rating::new(13.19240182424743, 4.71360960991876), + Rating::new(9.461957485668828, 5.27776869816230), ]; let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA); @@ -443,8 +453,8 @@ mod tests { let bob = Rating::new(MU, SIGMA); let expected_ratings = vec![ - Rating::new(25.0, 6.457515683245051), - Rating::new(25.0, 6.457515683245051), + Rating::new(25.00000000000000, 6.45751568324505), + Rating::new(25.00000000000000, 6.45751568324505), ]; let ratings = rate(&[(alice, 0), (bob, 1)], &[0, 0], DELTA); @@ -456,18 +466,16 @@ mod tests { #[test] fn test_rate_2vs2() { - let _ = env_logger::try_init(); - let alice = Rating::new(MU, SIGMA); let bob = Rating::new(MU, SIGMA); let chris = Rating::new(MU, SIGMA); let darren = Rating::new(MU, SIGMA); let expected_ratings = vec![ - Rating::new(28.108322399069035, 7.77436345109384), - Rating::new(28.108322399069035, 7.77436345109384), - Rating::new(21.891677600930958, 7.77436345109384), - Rating::new(21.891677600930958, 7.77436345109384), + Rating::new(28.10832239906903, 7.77436345109384), + Rating::new(28.10832239906903, 7.77436345109384), + Rating::new(21.89167760093095, 7.77436345109384), + Rating::new(21.89167760093095, 7.77436345109384), ]; let ratings = rate( @@ -483,8 +491,6 @@ mod tests { #[test] fn test_rate_4vs4() { - let _ = env_logger::try_init(); - let alice = Rating::new(MU, SIGMA); let bob = Rating::new(MU, SIGMA); let chris = Rating::new(MU, SIGMA); @@ -495,14 +501,14 @@ mod tests { let laura = Rating::new(MU, SIGMA); let expected_ratings = vec![ - Rating::new(27.19791584649575, 8.058911711843994), - Rating::new(27.19791584649575, 8.058911711843994), - Rating::new(27.19791584649575, 8.058911711843994), - Rating::new(27.19791584649575, 8.058911711843994), - Rating::new(22.802084153504236, 8.058911711843994), - Rating::new(22.802084153504236, 8.058911711843994), - Rating::new(22.802084153504236, 8.058911711843994), - Rating::new(22.802084153504236, 8.058911711843994), + Rating::new(27.19791584649575, 8.05891171184399), + Rating::new(27.19791584649575, 8.05891171184399), + Rating::new(27.19791584649575, 8.05891171184399), + Rating::new(27.19791584649575, 8.05891171184399), + Rating::new(22.80208415350423, 8.05891171184399), + Rating::new(22.80208415350423, 8.05891171184399), + Rating::new(22.80208415350423, 8.05891171184399), + Rating::new(22.80208415350423, 8.05891171184399), ]; let ratings = rate( @@ -524,4 +530,112 @@ mod tests { assert_relative_eq!(rating, expected, epsilon = EPSILON); } } + + #[test] + fn test_rate_sublee_trueskill_issue_3_case_1() { + let ratings = vec![ + (Rating::new(42.234, 3.728), 0), + (Rating::new(43.290, 3.842), 0), + + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + (Rating::new(16.667, 0.500), 1), + ]; + + let expected_ratings = vec![ + Rating::new(49.04740037944730, 3.64544755829110), + Rating::new(50.52625967591528, 3.75146367648882), + + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + Rating::new(16.54109749124066, 0.50668947816812), + ]; + + let ratings = rate( + ratings.as_ref(), + &[0, 1], + DELTA, + ); + + for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { + assert_relative_eq!(rating, expected, epsilon = EPSILON); + } + } + + #[test] + fn test_rate_sublee_trueskill_issue_3_case_2() { + let _ = env_logger::try_init(); + + let ratings = vec![ + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + (Rating::new(25.000, 0.500), 0), + + (Rating::new(42.234, 3.728), 1), + (Rating::new(43.290, 3.842), 1), + ]; + + let expected_ratings = vec![ + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + Rating::new(25.00000000000000, 0.50689687752485), + + Rating::new(42.23400000000000, 3.72893127376255), + Rating::new(43.29000000000000, 3.84290364756188), + ]; + + let ratings = rate( + ratings.as_ref(), + &[0, 1], + DELTA, + ); + + for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { + assert_relative_eq!(rating, expected, epsilon = EPSILON); + } + } }