From 2b950a0caa99f87893fb28644c44e50c639c71f5 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Thu, 25 Oct 2018 16:16:08 +0200 Subject: [PATCH] Start refactor of rate, so it can handle teams. --- src/lib.rs | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 71 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index f1110e2..ac73872 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,13 +72,16 @@ fn draw_margin(p: f64, beta: f64, total_players: f64) -> f64 { math::icdf((p + 1.0) / 2.0) * total_players.sqrt() * beta } -pub fn rate(rating_groups: &[&[R]], ranks: &[u16]) -> Vec +pub fn rate(rating_groups: &[(R, u16)], ranks: &[u16]) -> Vec where R: Rateable, { + // TODO Validate rating_groups is orderded in teams. + // TODO Validate that teams are orderd after rank. + let flatten_ratings = rating_groups .iter() - .flat_map(|group| group.iter()) + .map(|group| &group.0) .collect::>(); let size = flatten_ratings.len(); @@ -88,6 +91,70 @@ where let mut variable_arena = VariableArena::new(); + // --------------------------------------- + + let ratings = rating_groups; + + let rating_count = ratings.len(); + let team_count = ranks.len(); + + let rating_vars = (0..rating_count).map(|_| variable_arena.create()).collect::>(); + let perf_vars = (0..rating_count).map(|_| variable_arena.create()).collect::>(); + let team_perf_vars = (0..team_count).map(|_| variable_arena.create()).collect::>(); + let team_diff_vars = (0..team_count - 1).map(|_| variable_arena.create()).collect::>(); + + let mut factor_id = 0; + + let rating_layer = rating_vars + .iter() + .zip(flatten_ratings.iter()) + .map(|(rating_var, rating)| { + let gaussian = Gaussian::from_mu_sigma(rating.mu(), (rating.sigma().powi(2) + tau_sqr).sqrt()); + + factor_id += 1; + + PriorFactor::new( + &mut variable_arena, + factor_id, + *rating_var, + gaussian, + ) + }) + .collect::>(); + + let perf_layer = rating_vars + .iter() + .zip(perf_vars.iter()) + .map(|(rating_var, perf)| { + factor_id += 1; + + LikelihoodFactor::new( + &mut variable_arena, + factor_id, + *rating_var, + *perf, + beta_sqr, + ) + }) + .collect::>(); + + // let team_perf_layer = team_perf_vars.iter(); + + /* + def build_team_perf_layer(): + for team, team_perf_var in enumerate(team_perf_vars): + if team > 0: + start = team_sizes[team - 1] + else: + start = 0 + end = team_sizes[team] + child_perf_vars = perf_vars[start:end] + coeffs = flatten_weights[start:end] + yield SumFactor(team_perf_var, child_perf_vars, coeffs) + */ + + // --------------------------------------- + let mut ss = Vec::new(); let mut ps = Vec::new(); let mut ts = Vec::new(); @@ -355,7 +422,7 @@ mod tests { Rating::new(16.79337409436942, 6.348053083319977), ]; - let ratings = rate(&[&[alice], &[bob], &[chris], &[darren]], &[0, 1, 2, 3]); + let ratings = rate(&[(alice, 0), (bob, 1), (chris, 2), (darren, 3)], &[0, 1, 2, 3]); for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { assert_relative_eq!(rating, expected, epsilon = EPSILON); @@ -372,7 +439,7 @@ mod tests { Rating::new(25.0, 6.457515683245051), ]; - let ratings = rate(&[&[alice], &[bob]], &[0, 0]); + let ratings = rate(&[(alice, 0), (bob, 1)], &[0, 0]); for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { assert_relative_eq!(rating, expected, epsilon = EPSILON);