diff --git a/benches/history_converge.rs b/benches/history_converge.rs index e5163a8..d13fbdf 100644 --- a/benches/history_converge.rs +++ b/benches/history_converge.rs @@ -51,6 +51,7 @@ fn build_history_1v1( .convergence(ConvergenceOptions { max_iter: 30, epsilon: 1e-6, + alpha: 1.0, }) .build(); diff --git a/examples/atp.rs b/examples/atp.rs index e82c41a..4478924 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -48,6 +48,7 @@ fn main() { .convergence(trueskill_tt::ConvergenceOptions { max_iter: 10, epsilon: 0.01, + alpha: 1.0, }) .build(); diff --git a/src/game.rs b/src/game.rs index cb16408..d943846 100644 --- a/src/game.rs +++ b/src/game.rs @@ -44,11 +44,14 @@ impl DiffFactor { } } - pub(crate) fn propagate(&mut self, vars: &mut crate::factor::VarStore) -> (f64, f64) { - use crate::factor::Factor; + pub(crate) fn propagate( + &mut self, + vars: &mut crate::factor::VarStore, + alpha: f64, + ) -> (f64, f64) { match self { - Self::Trunc(f) => f.propagate(vars), - Self::Margin(f) => f.propagate(vars), + Self::Trunc(f) => f.propagate_with_alpha(vars, alpha), + Self::Margin(f) => f.propagate_with_alpha(vars, alpha), } } } @@ -87,6 +90,7 @@ pub struct OwnedGame> { result: Vec, weights: Vec>, p_draw: f64, + pub(crate) convergence: crate::ConvergenceOptions, pub(crate) likelihoods: Vec>, pub(crate) evidence: f64, } @@ -97,9 +101,17 @@ impl> OwnedGame { result: Vec, weights: Vec>, p_draw: f64, + convergence: crate::ConvergenceOptions, ) -> Self { let mut arena = ScratchArena::new(); - let g = Game::ranked_with_arena(teams.clone(), &result, &weights, p_draw, &mut arena); + let g = Game::ranked_with_arena( + teams.clone(), + &result, + &weights, + p_draw, + convergence, + &mut arena, + ); let likelihoods = g.likelihoods; let evidence = g.evidence; Self { @@ -107,6 +119,7 @@ impl> OwnedGame { result, weights, p_draw, + convergence, likelihoods, evidence, } @@ -117,9 +130,17 @@ impl> OwnedGame { scores: Vec, weights: Vec>, score_sigma: f64, + convergence: crate::ConvergenceOptions, ) -> Self { let mut arena = ScratchArena::new(); - let g = Game::scored_with_arena(teams.clone(), &scores, &weights, score_sigma, &mut arena); + let g = Game::scored_with_arena( + teams.clone(), + &scores, + &weights, + score_sigma, + convergence, + &mut arena, + ); let likelihoods = g.likelihoods; let evidence = g.evidence; Self { @@ -127,6 +148,7 @@ impl> OwnedGame { result: scores, weights, p_draw: 0.0, + convergence, likelihoods, evidence, } @@ -151,6 +173,7 @@ pub struct Game<'a, T: Time = i64, D: Drift = crate::drift::ConstantDrift> { result: &'a [f64], weights: &'a [Vec], p_draw: f64, + pub(crate) convergence: crate::ConvergenceOptions, pub(crate) likelihoods: Vec>, pub(crate) evidence: f64, } @@ -161,6 +184,7 @@ impl<'a, T: Time, D: Drift> Game<'a, T, D> { result: &'a [f64], weights: &'a [Vec], p_draw: f64, + convergence: crate::ConvergenceOptions, arena: &mut ScratchArena, ) -> Self { debug_assert!( @@ -186,12 +210,17 @@ impl<'a, T: Time, D: Drift> Game<'a, T, D> { }, "draw must be > 0.0 if there are teams with draw" ); + debug_assert!( + convergence.alpha > 0.0 && convergence.alpha <= 1.0, + "convergence alpha must be in (0.0, 1.0]" + ); let mut this = Self { teams, result, weights, p_draw, + convergence, likelihoods: Vec::new(), evidence: 0.0, }; @@ -205,6 +234,7 @@ impl<'a, T: Time, D: Drift> Game<'a, T, D> { scores: &'a [f64], weights: &'a [Vec], score_sigma: f64, + convergence: crate::ConvergenceOptions, arena: &mut ScratchArena, ) -> Self { debug_assert!( @@ -219,12 +249,17 @@ impl<'a, T: Time, D: Drift> Game<'a, T, D> { "weights must have the same dimensions as teams" ); debug_assert!(score_sigma > 0.0, "score_sigma must be positive"); + debug_assert!( + convergence.alpha > 0.0 && convergence.alpha <= 1.0, + "convergence alpha must be in (0.0, 1.0]" + ); let mut this = Self { teams, result: scores, weights, p_draw: 0.0, + convergence, likelihoods: Vec::new(), evidence: 0.0, }; @@ -239,6 +274,10 @@ impl<'a, T: Time, D: Drift> Game<'a, T, D> { { arena.reset(); + let alpha = self.convergence.alpha; + let epsilon = self.convergence.epsilon; + let max_iter = self.convergence.max_iter; + let n_teams = self.teams.len(); arena.sort_buf.extend(0..n_teams); @@ -267,7 +306,7 @@ impl<'a, T: Time, D: Drift> Game<'a, T, D> { let mut step = (f64::INFINITY, f64::INFINITY); let mut iter = 0; - while tuple_gt(step, 1e-6) && iter < 10 { + while tuple_gt(step, epsilon) && iter < max_iter { step = (0.0_f64, 0.0_f64); for (e, lf) in links[..n_diffs.saturating_sub(1)].iter_mut().enumerate() { @@ -275,7 +314,7 @@ impl<'a, T: Time, D: Drift> Game<'a, T, D> { let pl = arena.team_prior[e + 1] * arena.lhood_win[e + 1]; let raw = pw - pl; arena.vars.set(lf.diff(), raw * lf.msg()); - let d = lf.propagate(&mut arena.vars); + let d = lf.propagate(&mut arena.vars, alpha); step = tuple_max(step, d); let new_ll = pw - lf.msg(); @@ -289,7 +328,7 @@ impl<'a, T: Time, D: Drift> Game<'a, T, D> { let pl = arena.team_prior[e + 1] * arena.lhood_win[e + 1]; let raw = pw - pl; arena.vars.set(lf.diff(), raw * lf.msg()); - let d = lf.propagate(&mut arena.vars); + let d = lf.propagate(&mut arena.vars, alpha); step = tuple_max(step, d); let new_lw = pl + lf.msg(); @@ -305,7 +344,7 @@ impl<'a, T: Time, D: Drift> Game<'a, T, D> { let raw = (arena.team_prior[0] * arena.lhood_lose[0]) - (arena.team_prior[1] * arena.lhood_win[1]); arena.vars.set(links[0].diff(), raw * links[0].msg()); - links[0].propagate(&mut arena.vars); + links[0].propagate(&mut arena.vars, alpha); } // Boundary updates: close the chain at both ends. @@ -429,7 +468,13 @@ impl> Game<'_, T, D> { let teams_owned: Vec>> = teams.iter().map(|t| t.to_vec()).collect(); let weights: Vec> = teams.iter().map(|t| vec![1.0; t.len()]).collect(); - Ok(OwnedGame::new(teams_owned, result, weights, options.p_draw)) + Ok(OwnedGame::new( + teams_owned, + result, + weights, + options.p_draw, + options.convergence, + )) } pub fn scored( @@ -465,6 +510,7 @@ impl> Game<'_, T, D> { scores, weights, options.score_sigma, + options.convergence, )) } @@ -526,6 +572,7 @@ mod tests { &[0.0, 1.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -553,6 +600,7 @@ mod tests { &[0.0, 1.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -572,6 +620,7 @@ mod tests { &[0.0, 1.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); @@ -605,6 +654,7 @@ mod tests { &[1.0, 2.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -621,6 +671,7 @@ mod tests { &[2.0, 1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -632,7 +683,14 @@ mod tests { assert_ulps_eq!(b, Gaussian::from_ms(25.000000, 6.238469), epsilon = 1e-6); let w = [vec![1.0], vec![1.0], vec![1.0]]; - let g = Game::ranked_with_arena(teams, &[1.0, 2.0, 0.0], &w, 0.5, &mut ScratchArena::new()); + let g = Game::ranked_with_arena( + teams, + &[1.0, 2.0, 0.0], + &w, + 0.5, + crate::ConvergenceOptions::default(), + &mut ScratchArena::new(), + ); let p = g.posteriors(); let a = p[0][0]; @@ -664,6 +722,7 @@ mod tests { &[0.0, 0.0], &w, 0.25, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -691,6 +750,7 @@ mod tests { &[0.0, 0.0], &w, 0.25, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -726,6 +786,7 @@ mod tests { &[0.0, 0.0, 0.0], &w, 0.25, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -762,6 +823,7 @@ mod tests { &[0.0, 0.0, 0.0], &w, 0.25, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -813,6 +875,7 @@ mod tests { &[1.0, 0.0, 0.0], &w, 0.25, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -846,6 +909,7 @@ mod tests { &[1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -870,6 +934,7 @@ mod tests { &[1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -894,6 +959,7 @@ mod tests { &[1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -921,6 +987,7 @@ mod tests { &[1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -948,6 +1015,7 @@ mod tests { &[1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -967,8 +1035,8 @@ mod tests { let mut t = DiffFactor::Trunc(TruncFactor::new(dt, 0.0, false)); let mut m = DiffFactor::Margin(MarginFactor::new(dm, 5.0, 1.0)); - let _ = t.propagate(&mut vars); - let _ = m.propagate(&mut vars); + let _ = t.propagate(&mut vars, 1.0); + let _ = m.propagate(&mut vars, 1.0); // Smoke: both diffs got written; their msgs are non-N_INF. assert!(t.msg().pi() > 0.0); @@ -989,7 +1057,11 @@ mod tests { let weights = [vec![1.0], vec![1.0]]; let mut arena = ScratchArena::new(); let g = Game::scored_with_arena( - teams, &result, &weights, 1.0, // score_sigma + teams, + &result, + &weights, + 1.0, + crate::ConvergenceOptions::default(), &mut arena, ); let p = g.posteriors(); @@ -1008,7 +1080,8 @@ mod tests { vec![vec![prior], vec![prior]], &result, &weights, - 0.1, // tighter score_sigma + 0.1, + crate::ConvergenceOptions::default(), &mut arena2, ); let p_tight = g_tight.posteriors(); @@ -1116,6 +1189,7 @@ mod tests { &[1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -1150,6 +1224,7 @@ mod tests { &[1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -1184,6 +1259,7 @@ mod tests { &[1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); @@ -1222,6 +1298,7 @@ mod tests { &[1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let post_2vs1 = g.posteriors(); @@ -1235,6 +1312,7 @@ mod tests { &[1.0, 0.0], &w, 0.0, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ); let p = g.posteriors(); diff --git a/src/history.rs b/src/history.rs index 9ef7653..4680838 100644 --- a/src/history.rs +++ b/src/history.rs @@ -838,6 +838,7 @@ mod tests { &[0.0, 1.0], &w, P_DRAW, + crate::ConvergenceOptions::default(), &mut ScratchArena::new(), ) .posteriors(); diff --git a/src/time_slice.rs b/src/time_slice.rs index a6f4806..ee415a0 100644 --- a/src/time_slice.rs +++ b/src/time_slice.rs @@ -138,12 +138,22 @@ impl Event { let teams = self.within_priors(false, false, skills, agents); let result = self.outputs(); let g = match self.kind { - EventKind::Ranked => { - Game::ranked_with_arena(teams, &result, &self.weights, p_draw, arena) - } - EventKind::Scored { score_sigma } => { - Game::scored_with_arena(teams, &result, &self.weights, score_sigma, arena) - } + EventKind::Ranked => Game::ranked_with_arena( + teams, + &result, + &self.weights, + p_draw, + crate::ConvergenceOptions::default(), + arena, + ), + EventKind::Scored { score_sigma } => Game::scored_with_arena( + teams, + &result, + &self.weights, + score_sigma, + crate::ConvergenceOptions::default(), + arena, + ), }; for (t, team) in self.teams.iter_mut().enumerate() { @@ -322,6 +332,7 @@ impl TimeSlice { &result, &event.weights, self.p_draw, + crate::ConvergenceOptions::default(), &mut self.arena, ), EventKind::Scored { score_sigma } => Game::scored_with_arena( @@ -329,6 +340,7 @@ impl TimeSlice { &result, &event.weights, score_sigma, + crate::ConvergenceOptions::default(), &mut self.arena, ), }; @@ -504,16 +516,26 @@ impl TimeSlice { let teams = event.within_priors(online, forward, &self.skills, agents); let result = event.outputs(); match event.kind { - EventKind::Ranked => { - Game::ranked_with_arena(teams, &result, &event.weights, self.p_draw, arena) - .evidence - .ln() - } - EventKind::Scored { score_sigma } => { - Game::scored_with_arena(teams, &result, &event.weights, score_sigma, arena) - .evidence - .ln() - } + EventKind::Ranked => Game::ranked_with_arena( + teams, + &result, + &event.weights, + self.p_draw, + crate::ConvergenceOptions::default(), + arena, + ) + .evidence + .ln(), + EventKind::Scored { score_sigma } => Game::scored_with_arena( + teams, + &result, + &event.weights, + score_sigma, + crate::ConvergenceOptions::default(), + arena, + ) + .evidence + .ln(), } }; diff --git a/tests/api_shape.rs b/tests/api_shape.rs index dafea20..9b4f653 100644 --- a/tests/api_shape.rs +++ b/tests/api_shape.rs @@ -15,6 +15,7 @@ fn add_events_bulk_via_iter() { .convergence(ConvergenceOptions { max_iter: 30, epsilon: 1e-6, + alpha: 1.0, }) .build(); diff --git a/tests/determinism.rs b/tests/determinism.rs index 2f336d2..ce0ebae 100644 --- a/tests/determinism.rs +++ b/tests/determinism.rs @@ -16,6 +16,7 @@ fn build_and_converge(seed: u64) -> Vec<(i64, trueskill_tt::Gaussian)> { .convergence(ConvergenceOptions { max_iter: 30, epsilon: 1e-6, + alpha: 1.0, }) .build(); diff --git a/tests/record_winner.rs b/tests/record_winner.rs index ae18058..659cccc 100644 --- a/tests/record_winner.rs +++ b/tests/record_winner.rs @@ -10,6 +10,7 @@ fn record_winner_builds_history() { .convergence(ConvergenceOptions { max_iter: 30, epsilon: 1e-6, + alpha: 1.0, }) .build();