From c471ef33994b156970c266820095e8124b5e3b58 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Fri, 26 Oct 2018 08:57:58 +0200 Subject: [PATCH] Added delta and more tests. --- src/factor_graph.rs | 26 ++++++++++---- src/lib.rs | 88 ++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 99 insertions(+), 15 deletions(-) diff --git a/src/factor_graph.rs b/src/factor_graph.rs index 594d6a3..72fea73 100644 --- a/src/factor_graph.rs +++ b/src/factor_graph.rs @@ -49,13 +49,29 @@ impl Variable { self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0)); } - pub fn update_value(&mut self, factor: usize, value: Gaussian) { + pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 { let old = self.factors[&factor]; self.factors.insert(factor, value * old / self.value); + + let pi_delta = (self.value.pi() - value.pi()).abs(); + + let delta = if !pi_delta.is_finite() { + 0.0 + } else { + let pi_delta = pi_delta.sqrt(); + let tau_delta = (self.value.tau() - value.tau()).abs(); + + if pi_delta > tau_delta { + pi_delta + } else { + tau_delta + } + }; + self.value = value; - // debug!("Variable::value value={:?}", self.value); + delta } pub fn get_value(&self) -> Gaussian { @@ -67,8 +83,6 @@ impl Variable { self.factors.insert(factor, message); self.value = self.value / old * message; - - // debug!("Variable::message value={:?}", self.value); } pub fn get_message(&self, factor: usize) -> Gaussian { @@ -344,7 +358,7 @@ impl TruncateFactor { } } - pub fn update(&self, variable_arena: &mut VariableArena) { + pub fn update(&self, variable_arena: &mut VariableArena) -> f64 { let (x, fx) = variable_arena .get(self.variable) .map(|variable| (variable.get_value(), variable.get_message(self.id))) @@ -376,6 +390,6 @@ impl TruncateFactor { variable_arena .get_mut(self.variable) .unwrap() - .update_value(self.id, gaussian); + .update_value(self.id, gaussian) } } diff --git a/src/lib.rs b/src/lib.rs index 9bc80e3..50bc991 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,7 +77,7 @@ fn draw_margin(p: f64, beta: f64, total_players: f64) -> f64 { math::icdf((p + 1.0) / 2.0) * total_players.sqrt() * beta } -pub fn rate(ratings: &[(R, u16)], ranks: &[u16]) -> Vec +pub fn rate(ratings: &[(R, u16)], ranks: &[u16], min_delta: f64) -> Vec where R: Rateable, { @@ -205,19 +205,29 @@ where factor.update_sum(&mut variable_arena); } - for _ in 0..5 { + for _ in 0..10 { + let mut delta = 0.0; + for factor in &team_diff_layer { factor.update_sum(&mut variable_arena); } for factor in &trunc_layer { - factor.update(&mut variable_arena); + let d = factor.update(&mut variable_arena); + + if d > delta { + delta = d; + } } for factor in &team_diff_layer { factor.update_term(&mut variable_arena, 0); factor.update_term(&mut variable_arena, 1); } + + if delta < min_delta { + break; + } } for factor in &team_perf_layer { @@ -356,15 +366,16 @@ mod tests { let darren = Rating::new(MU, SIGMA); let expected_ratings = vec![ - Rating::new(33.20778932559525, 6.347937214998893), - Rating::new(27.401497882797486, 5.787057812482782), - Rating::new(22.598576351652632, 5.7871159419307645), - Rating::new(16.79337409436942, 6.348053083319977), + Rating::new(33.20668089876779, 6.34810941351329), + Rating::new(27.401455165087352, 5.7871628131345645), + Rating::new(22.598544839299667, 5.787162810091708), + Rating::new(16.793319100187123, 6.348109386031168), ]; let ratings = rate( &[(alice, 0), (bob, 1), (chris, 2), (darren, 3)], &[0, 1, 2, 3], + DELTA, ); for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { @@ -372,6 +383,60 @@ mod tests { } } + #[test] + fn test_rate_8_free_for_all_draw() { + let (ratings, ranks): (Vec<_>, Vec<_>) = + (0..8u16).map(|i| ((Rating::new(MU, SIGMA), i), 0)).unzip(); + + let expected_ratings = vec![ + Rating::new(25.000000000000014, 4.592173723582464), + Rating::new(25.000000000000004, 4.582694291508923), + Rating::new(25.000000000000004, 4.576403088132029), + Rating::new(25.000000000000004, 4.5732660302525785), + Rating::new(25.000000000000004, 4.573266030252499), + Rating::new(25.000000000000004, 4.5764030881313005), + Rating::new(25.000000000000004, 4.582694291504568), + Rating::new(25.000000000000004, 4.5921737235617845), + ]; + + let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA); + + for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { + assert_relative_eq!(rating, expected, epsilon = EPSILON); + } + } + + #[test] + fn test_rate_16_free_for_all() { + let (ratings, ranks): (Vec<_>, Vec<_>) = + (0..16u16).map(|i| ((Rating::new(MU, SIGMA), i), i)).unzip(); + + let expected_ratings = vec![ + Rating::new(38.331788391720295, 5.541377093821206), + Rating::new(34.746418422647835, 4.891059608895904), + Rating::new(32.61843846390381, 4.65295506775174), + Rating::new(31.046552382343123, 4.522499560708352), + Rating::new(29.776879742967576, 4.4373517127954045), + Rating::new(28.700098150398343, 4.376108908633618), + Rating::new(27.27926253322585, 4.345628543772217), + Rating::new(25.770921292726943, 4.331516572196449), + Rating::new(24.24208410982518, 4.328837113282726), + Rating::new(22.69645835682008, 4.3358108926893175), + Rating::new(21.111803677204012, 4.353110849422467), + Rating::new(19.453144013532874, 4.383947550430257), + Rating::new(17.666795353614532, 4.43588072704469), + Rating::new(15.65513147838557, 4.527230274864565), + Rating::new(13.192401824247439, 4.713609609918767), + Rating::new(9.461957485668828, 5.277768698162301), + ]; + + let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA); + + for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { + assert_relative_eq!(rating, expected, epsilon = EPSILON); + } + } + #[test] fn test_rate_1vs1_draw() { let alice = Rating::new(MU, SIGMA); @@ -382,7 +447,7 @@ mod tests { Rating::new(25.0, 6.457515683245051), ]; - let ratings = rate(&[(alice, 0), (bob, 1)], &[0, 0]); + let ratings = rate(&[(alice, 0), (bob, 1)], &[0, 0], DELTA); for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { assert_relative_eq!(rating, expected, epsilon = EPSILON); @@ -405,7 +470,11 @@ mod tests { Rating::new(21.891677600930958, 7.77436345109384), ]; - let ratings = rate(&[(alice, 0), (bob, 0), (chris, 1), (darren, 1)], &[0, 1]); + let ratings = rate( + &[(alice, 0), (bob, 0), (chris, 1), (darren, 1)], + &[0, 1], + DELTA, + ); for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { assert_relative_eq!(rating, expected, epsilon = EPSILON); @@ -448,6 +517,7 @@ mod tests { (laura, 1), ], &[0, 1], + DELTA, ); for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {