diff --git a/examples/atp.rs b/examples/atp.rs index 9aa136b..e82c41a 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -1,53 +1,61 @@ -use std::collections::HashMap; - use plotters::prelude::*; +use smallvec::smallvec; use time::{Date, Month}; -use trueskill_tt::{History, KeyTable}; +use trueskill_tt::{Event, History, Member, Outcome, Team, drift::ConstantDrift}; fn main() { let mut csv = csv::Reader::open("examples/atp.csv").unwrap(); - let mut composition = Vec::new(); - let mut results = Vec::new(); - let mut times = Vec::new(); - let from = Date::from_calendar_date(1900, Month::January, 1).unwrap(); let time_format = time::format_description::parse("[year]-[month]-[day]").unwrap(); - let mut index_map = KeyTable::new(); + let mut events: Vec> = Vec::new(); for row in csv.records() { - if &row["double"] == "t" { - let w1_id = index_map.get_or_create(&row["w1_id"]); - let w2_id = index_map.get_or_create(&row["w2_id"]); - - let l1_id = index_map.get_or_create(&row["l1_id"]); - let l2_id = index_map.get_or_create(&row["l2_id"]); - - composition.push(vec![vec![w1_id, w2_id], vec![l1_id, l2_id]]); - } else { - let w1_id = index_map.get_or_create(&row["w1_id"]); - - let l1_id = index_map.get_or_create(&row["l1_id"]); - - composition.push(vec![vec![w1_id], vec![l1_id]]); - } - - results.push(vec![1.0, 0.0]); - let date = Date::parse(&row["time_start"], &time_format).unwrap(); + let time = (date - from).whole_days(); - times.push((date - from).whole_days()); + if &row["double"] == "t" { + events.push(Event { + time, + teams: smallvec![ + Team::with_members([ + Member::new(row["w1_id"].to_owned()), + Member::new(row["w2_id"].to_owned()), + ]), + Team::with_members([ + Member::new(row["l1_id"].to_owned()), + Member::new(row["l2_id"].to_owned()), + ]), + ], + outcome: Outcome::winner(0, 2), + }); + } else { + events.push(Event { + time, + teams: smallvec![ + Team::with_members([Member::new(row["w1_id"].to_owned())]), + Team::with_members([Member::new(row["l1_id"].to_owned())]), + ], + outcome: Outcome::winner(0, 2), + }); + } } - let mut hist = History::builder().sigma(1.6).gamma(0.036).build(); + let mut hist: History = History::builder_with_key() + .sigma(1.6) + .drift(ConstantDrift(0.036)) + .convergence(trueskill_tt::ConvergenceOptions { + max_iter: 10, + epsilon: 0.01, + }) + .build(); - hist.add_events_with_prior(composition, results, times, vec![], HashMap::new()) - .unwrap(); - hist.convergence(10, 0.01, true); + hist.add_events(events).unwrap(); + hist.converge().unwrap(); let players = [ - ("aggasi", "a092", 38800), + ("aggasi", "a092", 38800i64), ("borg", "b058", 30300), ("connors", "c044", 31250), ("courier", "c243", 35750), @@ -64,21 +72,16 @@ fn main() { ("wilander", "w023", 32600), ]; - let curves = hist.learning_curves_by_index(); - let mut x_spec = (f64::MAX, f64::MIN); let mut y_spec = (f64::MAX, f64::MIN); - for (id, cutoff) in players - .iter() - .map(|&(_, id, cutoff)| (index_map.get_or_create(id), cutoff)) - { - for (ts, gs) in &curves[&id] { - if *ts >= cutoff { + for &(_, id, cutoff) in &players { + for (ts, gs) in hist.learning_curve(id) { + if ts >= cutoff { continue; } - let ts = *ts as f64; + let ts = ts as f64; if ts < x_spec.0 { x_spec.0 = ts; @@ -114,24 +117,19 @@ fn main() { chart.configure_mesh().draw().unwrap(); - for (idx, (player, id, cutoff)) in players - .iter() - .map(|&(player, id, cutoff)| (player, index_map.get_or_create(id), cutoff)) - .enumerate() - { + for (idx, &(player, id, cutoff)) in players.iter().enumerate() { let mut data = Vec::new(); let mut upper = Vec::new(); let mut lower = Vec::new(); - for (ts, gs) in curves[&id].iter() { - if *ts >= cutoff { + for (ts, gs) in hist.learning_curve(id) { + if ts >= cutoff { continue; } - data.push((*ts as f64, gs.mu())); - - upper.push((*ts as f64, gs.mu() + gs.sigma())); - lower.push((*ts as f64, gs.mu() - gs.sigma())); + data.push((ts as f64, gs.mu())); + upper.push((ts as f64, gs.mu() + gs.sigma())); + lower.push((ts as f64, gs.mu() - gs.sigma())); } let color = Palette99::pick(idx); diff --git a/src/history.rs b/src/history.rs index b738996..5191929 100644 --- a/src/history.rs +++ b/src/history.rs @@ -115,13 +115,6 @@ impl, O: Observer, K: Eq + Hash + Clone> HistoryBuilder< } } -impl, K: Eq + Hash + Clone> HistoryBuilder { - pub fn gamma(mut self, gamma: f64) -> Self { - self.drift = ConstantDrift(gamma); - self - } -} - impl Default for HistoryBuilder { fn default() -> Self { Self { @@ -171,6 +164,24 @@ impl History { } } +impl History { + /// Like `builder()` but uses a custom key type `K` instead of the default `&'static str`. + pub fn builder_with_key() -> HistoryBuilder { + HistoryBuilder { + mu: MU, + sigma: SIGMA, + beta: BETA, + drift: ConstantDrift(GAMMA), + p_draw: P_DRAW, + online: false, + convergence: ConvergenceOptions::default(), + observer: NullObserver, + _time: PhantomData, + _key: PhantomData, + } + } +} + impl, O: Observer, K: Eq + Hash + Clone> History { pub fn intern(&mut self, key: &Q) -> Index where @@ -246,57 +257,8 @@ impl, O: Observer, K: Eq + Hash + Clone> History ((f64, f64), usize) { - let mut step = (f64::INFINITY, f64::INFINITY); - let mut i = 0; - - while tuple_gt(step, epsilon) && i < iterations { - if verbose { - print!("Iteration = {}", i); - } - - step = self.iteration(); - - i += 1; - - if verbose { - println!(", step = {:?}", step); - } - } - - if verbose { - println!("End"); - } - - (step, i) - } - - /// Like `learning_curves`, but keyed by internal `Index`. Useful when - /// events were ingested via `Index` (rather than `record_winner` / - /// typed `add_events`), which doesn't populate the KeyTable. - pub fn learning_curves_by_index(&self) -> HashMap> { - let mut data: HashMap> = HashMap::new(); - for b in &self.time_slices { - for (agent, skill) in b.skills.iter() { - data.entry(agent) - .or_default() - .push((b.time, skill.posterior())); - } - } - data - } - /// Learning curves for all competitors, keyed by their user-facing key. /// - /// Returns an empty map for histories ingested via the raw `Index` path - /// (i.e. `add_events_with_prior` without `intern`/`record_winner`). - /// Use `learning_curves_by_index()` in that case. - /// /// Note: `key(idx)` is O(n) per lookup; this method is therefore O(n²) /// in the number of competitors. Acceptable for T2; T3 may optimize. pub fn learning_curves(&self) -> HashMap> { @@ -441,7 +403,7 @@ impl, O: Observer, K: Eq + Hash + Clone> History, O: Observer, K: Eq + Hash + Clone> History { - pub fn add_events_with_prior( + pub(crate) fn add_events_with_prior( &mut self, composition: Vec>>, results: Vec>, @@ -708,45 +670,58 @@ impl, O: Observer, K: Eq + Hash + Clone> History Vec> { + pairs + .iter() + .copied() + .zip(outcomes.iter().cloned()) + .zip(times.iter().copied()) + .map(|(((a, b), outcome), time)| Event { + time, + teams: smallvec![ + Team::with_members([Member::new(a)]), + Team::with_members([Member::new(b)]), + ], + outcome, + }) + .collect() + } + #[test] fn test_init() { - let mut index_map = KeyTable::new(); + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .drift(ConstantDrift(0.15 * 25.0 / 3.0)) + .build(); - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); - let c = index_map.get_or_create("c"); + let events = make_events_1v1( + &[("a", "b"), ("a", "c"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(1, 2), + Outcome::winner(0, 2), + ], + &[1, 2, 3], + ); + h.add_events(events).unwrap(); - let composition = vec![ - vec![vec![a], vec![b]], - vec![vec![a], vec![c]], - vec![vec![b], vec![c]], - ]; - let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; - - let mut priors = HashMap::new(); - - for agent in [a, b, c] { - priors.insert( - agent, - Rating::new( - Gaussian::from_ms(25.0, 25.0 / 3.0), - 25.0 / 6.0, - ConstantDrift(0.15 * 25.0 / 3.0), - ), - ); - } - - let mut h = History::default(); - - h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors) - .unwrap(); + let a = h.keys.get("a").unwrap(); + let b = h.keys.get("b").unwrap(); + let c = h.keys.get("c").unwrap(); let p0 = h.time_slices[0].posteriors(); @@ -789,41 +764,32 @@ mod tests { let expected = p[0][0]; assert_ulps_eq!(observed, expected, epsilon = 1e-6); + + let _ = (b, c); } #[test] fn test_one_batch() { - let mut index_map = KeyTable::new(); + let mut h1 = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .drift(ConstantDrift(0.15 * 25.0 / 3.0)) + .build(); - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); - let c = index_map.get_or_create("c"); + let events = make_events_1v1( + &[("a", "b"), ("b", "c"), ("c", "a")], + &[ + Outcome::winner(0, 2), + Outcome::winner(0, 2), + Outcome::winner(0, 2), + ], + &[1, 1, 1], + ); + h1.add_events(events).unwrap(); - let composition = vec![ - vec![vec![a], vec![b]], - vec![vec![b], vec![c]], - vec![vec![c], vec![a]], - ]; - let results = vec![vec![1.0, 0.0], vec![1.0, 0.0], vec![1.0, 0.0]]; - let times = vec![1, 1, 1]; - - let mut priors = HashMap::new(); - - for agent in [a, b, c] { - priors.insert( - agent, - Rating::new( - Gaussian::from_ms(25.0, 25.0 / 3.0), - 25.0 / 6.0, - ConstantDrift(0.15 * 25.0 / 3.0), - ), - ); - } - - let mut h1 = History::default(); - - h1.add_events_with_prior(composition, results, times, vec![], priors) - .unwrap(); + let a = h1.keys.get("a").unwrap(); + let c = h1.keys.get("c").unwrap(); assert_ulps_eq!( h1.time_slices[0].skills.get(a).unwrap().posterior(), @@ -836,7 +802,7 @@ mod tests { epsilon = 1e-6 ); - h1.convergence(ITERATIONS, EPSILON, false); + h1.converge().unwrap(); assert_ulps_eq!( h1.time_slices[0].skills.get(a).unwrap().posterior(), @@ -849,31 +815,26 @@ mod tests { epsilon = 1e-6 ); - let composition = vec![ - vec![vec![a], vec![b]], - vec![vec![b], vec![c]], - vec![vec![c], vec![a]], - ]; - let results = vec![vec![1.0, 0.0], vec![1.0, 0.0], vec![1.0, 0.0]]; - let times = vec![1, 2, 3]; + let mut h2 = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .drift(ConstantDrift(25.0 / 300.0)) + .build(); - let mut priors = HashMap::new(); + let events = make_events_1v1( + &[("a", "b"), ("b", "c"), ("c", "a")], + &[ + Outcome::winner(0, 2), + Outcome::winner(0, 2), + Outcome::winner(0, 2), + ], + &[1, 2, 3], + ); + h2.add_events(events).unwrap(); - for agent in [a, b, c] { - priors.insert( - agent, - Rating::new( - Gaussian::from_ms(25.0, 25.0 / 3.0), - 25.0 / 6.0, - ConstantDrift(25.0 / 300.0), - ), - ); - } - - let mut h2 = History::default(); - - h2.add_events_with_prior(composition, results, times, vec![], priors) - .unwrap(); + let a = h2.keys.get("a").unwrap(); + let c = h2.keys.get("c").unwrap(); assert_ulps_eq!( h2.time_slices[2].skills.get(a).unwrap().posterior(), @@ -886,7 +847,7 @@ mod tests { epsilon = 1e-6 ); - h2.convergence(ITERATIONS, EPSILON, false); + h2.converge().unwrap(); assert_ulps_eq!( h2.time_slices[2].skills.get(a).unwrap().posterior(), @@ -902,54 +863,41 @@ mod tests { #[test] fn test_learning_curves() { - let mut index_map = KeyTable::new(); + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .drift(ConstantDrift(25.0 / 300.0)) + .build(); - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); - let c = index_map.get_or_create("c"); + let events = make_events_1v1( + &[("a", "b"), ("b", "c"), ("c", "a")], + &[ + Outcome::winner(0, 2), + Outcome::winner(0, 2), + Outcome::winner(0, 2), + ], + &[5, 6, 7], + ); + h.add_events(events).unwrap(); + h.converge().unwrap(); - let composition = vec![ - vec![vec![a], vec![b]], - vec![vec![b], vec![c]], - vec![vec![c], vec![a]], - ]; - let results = vec![vec![1.0, 0.0], vec![1.0, 0.0], vec![1.0, 0.0]]; - let times = vec![5, 6, 7]; + let lc_a = h.learning_curve("a"); + let lc_c = h.learning_curve("c"); - let mut priors = HashMap::new(); + let aj_e = lc_a.len(); + let cj_e = lc_c.len(); - for agent in [a, b, c] { - priors.insert( - agent, - Rating::new( - Gaussian::from_ms(25.0, 25.0 / 3.0), - 25.0 / 6.0, - ConstantDrift(25.0 / 300.0), - ), - ); - } - - let mut h = History::default(); - - h.add_events_with_prior(composition, results, times, vec![], priors) - .unwrap(); - h.convergence(ITERATIONS, EPSILON, false); - - let lc = h.learning_curves_by_index(); - - let aj_e = lc[&a].len(); - let cj_e = lc[&c].len(); - - assert_eq!(lc[&a][0].0, 5); - assert_eq!(lc[&a][aj_e - 1].0, 7); + assert_eq!(lc_a[0].0, 5); + assert_eq!(lc_a[aj_e - 1].0, 7); assert_ulps_eq!( - lc[&a][aj_e - 1].1, + lc_a[aj_e - 1].1, Gaussian::from_ms(24.998668, 5.420053), epsilon = 1e-6 ); assert_ulps_eq!( - lc[&c][cj_e - 1].1, + lc_c[cj_e - 1].1, Gaussian::from_ms(25.000532, 5.419827), epsilon = 1e-6 ); @@ -957,32 +905,28 @@ mod tests { #[test] fn test_env_ttt() { - let mut index_map = KeyTable::new(); - - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); - let c = index_map.get_or_create("c"); - - let composition = vec![ - vec![vec![a], vec![b]], - vec![vec![a], vec![c]], - vec![vec![b], vec![c]], - ]; - let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; - let mut h = History::builder() .mu(25.0) .sigma(25.0 / 3.0) .beta(25.0 / 6.0) - .gamma(25.0 / 300.0) + .drift(ConstantDrift(25.0 / 300.0)) .build(); - let n = composition.len(); - let times: Vec = (1..=n as i64).collect(); - h.add_events_with_prior(composition, results, times, vec![], HashMap::new()) - .unwrap(); + let events = make_events_1v1( + &[("a", "b"), ("a", "c"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(1, 2), + Outcome::winner(0, 2), + ], + &[1, 2, 3], + ); + h.add_events(events).unwrap(); + h.converge().unwrap(); - h.convergence(ITERATIONS, EPSILON, false); + let a = h.keys.get("a").unwrap(); + let b = h.keys.get("b").unwrap(); + let c = h.keys.get("c").unwrap(); assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 2); assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1); @@ -1006,33 +950,47 @@ mod tests { #[test] fn test_teams() { - let mut index_map = KeyTable::new(); - - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); - let c = index_map.get_or_create("c"); - let d = index_map.get_or_create("d"); - let e = index_map.get_or_create("e"); - let f = index_map.get_or_create("f"); - - let composition = vec![ - vec![vec![a, b], vec![c, d]], - vec![vec![e, f], vec![b, c]], - vec![vec![a, d], vec![e, f]], - ]; - let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; - - let mut h = History::builder() + let mut h: History = History::builder() .mu(0.0) .sigma(6.0) .beta(1.0) - .gamma(0.0) + .drift(ConstantDrift(0.0)) .build(); - let n = composition.len(); - let times: Vec = (1..=n as i64).collect(); - h.add_events_with_prior(composition, results, times, vec![], HashMap::new()) - .unwrap(); + let events: Vec> = vec![ + Event { + time: 1, + teams: smallvec![ + Team::with_members([Member::new("a"), Member::new("b")]), + Team::with_members([Member::new("c"), Member::new("d")]), + ], + outcome: Outcome::winner(0, 2), + }, + Event { + time: 2, + teams: smallvec![ + Team::with_members([Member::new("e"), Member::new("f")]), + Team::with_members([Member::new("b"), Member::new("c")]), + ], + outcome: Outcome::winner(1, 2), + }, + Event { + time: 3, + teams: smallvec![ + Team::with_members([Member::new("a"), Member::new("d")]), + Team::with_members([Member::new("e"), Member::new("f")]), + ], + outcome: Outcome::winner(0, 2), + }, + ]; + h.add_events(events).unwrap(); + + let a = h.keys.get("a").unwrap(); + let b = h.keys.get("b").unwrap(); + let c = h.keys.get("c").unwrap(); + let d = h.keys.get("d").unwrap(); + let e = h.keys.get("e").unwrap(); + let f = h.keys.get("f").unwrap(); let trueskill_log_evidence = h.log_evidence_internal(false, &[]); let trueskill_log_evidence_online = h.log_evidence_internal(true, &[]); @@ -1055,7 +1013,7 @@ mod tests { let evidence_third_event = h.log_evidence_internal(false, &[a]).exp() * 2.0; assert_ulps_eq!(0.669885, evidence_third_event, epsilon = 1e-6); - h.convergence(ITERATIONS, EPSILON, false); + h.converge().unwrap(); let loocv_hat = h.log_evidence_internal(false, &[]).exp(); let p_d_m_hat = h.log_evidence_internal(true, &[]).exp(); @@ -1098,38 +1056,29 @@ mod tests { #[test] fn test_add_events() { - let mut index_map = KeyTable::new(); - - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); - let c = index_map.get_or_create("c"); - - let composition = vec![ - vec![vec![a], vec![b]], - vec![vec![a], vec![c]], - vec![vec![b], vec![c]], - ]; - let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; - - let mut h = History::builder() + let mut h: History = History::builder() .mu(0.0) .sigma(2.0) .beta(1.0) - .gamma(0.0) + .drift(ConstantDrift(0.0)) .build(); - let n = composition.len(); - let times: Vec = (1..=n as i64).collect(); - h.add_events_with_prior( - composition.clone(), - results.clone(), - times, - vec![], - HashMap::new(), - ) - .unwrap(); + let events = make_events_1v1( + &[("a", "b"), ("a", "c"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(1, 2), + Outcome::winner(0, 2), + ], + &[1, 2, 3], + ); + h.add_events(events).unwrap(); - h.convergence(ITERATIONS, EPSILON, false); + let a = h.keys.get("a").unwrap(); + let b = h.keys.get("b").unwrap(); + let c = h.keys.get("c").unwrap(); + + h.converge().unwrap(); assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 2); assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1); @@ -1150,9 +1099,16 @@ mod tests { epsilon = 1e-6 ); - let times2: Vec = (n as i64 + 1..=2 * n as i64).collect(); - h.add_events_with_prior(composition, results, times2, vec![], HashMap::new()) - .unwrap(); + let events2 = make_events_1v1( + &[("a", "b"), ("a", "c"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(1, 2), + Outcome::winner(0, 2), + ], + &[4, 5, 6], + ); + h.add_events(events2).unwrap(); assert_eq!(h.time_slices.len(), 6); @@ -1171,7 +1127,7 @@ mod tests { ] ); - h.convergence(ITERATIONS, EPSILON, false); + h.converge().unwrap(); assert_ulps_eq!( h.time_slices[0].skills.get(a).unwrap().posterior(), @@ -1197,38 +1153,29 @@ mod tests { #[test] fn test_only_add_events() { - let mut index_map = KeyTable::new(); - - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); - let c = index_map.get_or_create("c"); - - let composition = vec![ - vec![vec![a], vec![b]], - vec![vec![a], vec![c]], - vec![vec![b], vec![c]], - ]; - let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; - - let mut h = History::builder() + let mut h: History = History::builder() .mu(0.0) .sigma(2.0) .beta(1.0) - .gamma(0.0) + .drift(ConstantDrift(0.0)) .build(); - let n = composition.len(); - let times: Vec = (1..=n as i64).collect(); - h.add_events_with_prior( - composition.clone(), - results.clone(), - times, - vec![], - HashMap::new(), - ) - .unwrap(); + let events = make_events_1v1( + &[("a", "b"), ("a", "c"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(1, 2), + Outcome::winner(0, 2), + ], + &[1, 2, 3], + ); + h.add_events(events).unwrap(); - h.convergence(ITERATIONS, EPSILON, false); + let a = h.keys.get("a").unwrap(); + let b = h.keys.get("b").unwrap(); + let c = h.keys.get("c").unwrap(); + + h.converge().unwrap(); assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 2); assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1); @@ -1249,9 +1196,16 @@ mod tests { epsilon = 1e-6 ); - let times2: Vec = (n as i64 + 1..=2 * n as i64).collect(); - h.add_events_with_prior(composition, results, times2, vec![], HashMap::new()) - .unwrap(); + let events2 = make_events_1v1( + &[("a", "b"), ("a", "c"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(1, 2), + Outcome::winner(0, 2), + ], + &[4, 5, 6], + ); + h.add_events(events2).unwrap(); assert_eq!(h.time_slices.len(), 6); @@ -1270,7 +1224,7 @@ mod tests { ] ); - h.convergence(ITERATIONS, EPSILON, false); + h.converge().unwrap(); assert_ulps_eq!( h.time_slices[0].skills.get(a).unwrap().posterior(), @@ -1296,25 +1250,20 @@ mod tests { #[test] fn test_log_evidence() { - let mut index_map = KeyTable::new(); + use crate::ConvergenceOptions; - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); + let mut h: History = History::builder().build(); - let composition = vec![vec![vec![a], vec![b]], vec![vec![b], vec![a]]]; + // empty results in the old API = team 0 wins; reproduce with Outcome::winner(0,2) + let events = make_events_1v1( + &[("a", "b"), ("b", "a")], + &[Outcome::winner(0, 2), Outcome::winner(0, 2)], + &[1, 2], + ); + h.add_events(events).unwrap(); - let mut h = History::builder().build(); - - let n = composition.len(); - let times: Vec = (1..=n as i64).collect(); - h.add_events_with_prior( - composition.clone(), - vec![], - times.clone(), - vec![], - HashMap::new(), - ) - .unwrap(); + let a = h.keys.get("a").unwrap(); + let b = h.keys.get("b").unwrap(); let p_d_m_2 = h.log_evidence_internal(false, &[]).exp() * 2.0; @@ -1335,7 +1284,12 @@ mod tests { epsilon = 1e-6 ); - h.convergence(11, EPSILON, false); + // run exactly 11 iterations (old test used convergence(11, ...)) + h.convergence = ConvergenceOptions { + max_iter: 11, + epsilon: EPSILON, + }; + h.converge().unwrap(); let loocv_approx_2 = h.log_evidence_internal(false, &[]).exp().sqrt(); @@ -1351,59 +1305,57 @@ mod tests { epsilon = 1e-4 ); - let mut h = History::builder().build(); + let mut h2: History = History::builder().build(); - h.add_events_with_prior(composition, vec![], times, vec![], HashMap::new()) - .unwrap(); + let events = make_events_1v1( + &[("a", "b"), ("b", "a")], + &[Outcome::winner(0, 2), Outcome::winner(0, 2)], + &[1, 2], + ); + h2.add_events(events).unwrap(); assert_ulps_eq!( ((0.5f64 * 0.1765).ln() / 2.0).exp(), - (h.log_evidence_internal(false, &[]) / 2.0).exp(), + (h2.log_evidence_internal(false, &[]) / 2.0).exp(), epsilon = 1e-4 ); } #[test] fn test_add_events_with_time() { - let mut index_map = KeyTable::new(); - - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); - let c = index_map.get_or_create("c"); - - let composition = vec![ - vec![vec![a], vec![b]], - vec![vec![a], vec![c]], - vec![vec![b], vec![c]], - ]; - let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; - - let mut h = History::builder() + let mut h: History = History::builder() .mu(0.0) .sigma(2.0) .beta(1.0) - .gamma(0.0) + .drift(ConstantDrift(0.0)) .build(); - h.add_events_with_prior( - composition.clone(), - results.clone(), - vec![0, 10, 20], - vec![], - HashMap::new(), - ) - .unwrap(); + let events = make_events_1v1( + &[("a", "b"), ("a", "c"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(1, 2), + Outcome::winner(0, 2), + ], + &[0, 10, 20], + ); + h.add_events(events).unwrap(); + h.converge().unwrap(); - h.convergence(ITERATIONS, EPSILON, false); + let a = h.keys.get("a").unwrap(); + let b = h.keys.get("b").unwrap(); + let c = h.keys.get("c").unwrap(); - h.add_events_with_prior( - composition, - results, - vec![15, 10, 0], - vec![], - HashMap::new(), - ) - .unwrap(); + let events2 = make_events_1v1( + &[("a", "b"), ("a", "c"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(1, 2), + Outcome::winner(0, 2), + ], + &[15, 10, 0], + ); + h.add_events(events2).unwrap(); assert_eq!(h.time_slices.len(), 4); @@ -1452,7 +1404,7 @@ mod tests { assert_eq!(h.time_slices[0].skills.get(b).unwrap().elapsed, 0); assert_eq!(h.time_slices[end].skills.get(b).unwrap().elapsed, 5); - h.convergence(ITERATIONS, EPSILON, false); + h.converge().unwrap(); assert_ulps_eq!( h.time_slices[0].skills.get(b).unwrap().posterior(), @@ -1472,39 +1424,46 @@ mod tests { epsilon = 1e-6 ); - // --------------------------------------- + // second scenario: team-0 wins (empty results in old API), different composition order - let composition = vec![ - vec![vec![a], vec![b]], - vec![vec![c], vec![a]], - vec![vec![b], vec![c]], - ]; - - let mut h = History::builder() + let mut h2: History = History::builder() .mu(0.0) .sigma(2.0) .beta(1.0) - .gamma(0.0) + .drift(ConstantDrift(0.0)) .build(); - h.add_events_with_prior( - composition.clone(), - vec![], - vec![0, 10, 20], - vec![], - HashMap::new(), - ) - .unwrap(); + let events = make_events_1v1( + &[("a", "b"), ("c", "a"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(0, 2), + Outcome::winner(0, 2), + ], + &[0, 10, 20], + ); + h2.add_events(events).unwrap(); + h2.converge().unwrap(); - h.convergence(ITERATIONS, EPSILON, false); + let a = h2.keys.get("a").unwrap(); + let b = h2.keys.get("b").unwrap(); + let c = h2.keys.get("c").unwrap(); - h.add_events_with_prior(composition, vec![], vec![15, 10, 0], vec![], HashMap::new()) - .unwrap(); + let events2 = make_events_1v1( + &[("a", "b"), ("c", "a"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(0, 2), + Outcome::winner(0, 2), + ], + &[15, 10, 0], + ); + h2.add_events(events2).unwrap(); - assert_eq!(h.time_slices.len(), 4); + assert_eq!(h2.time_slices.len(), 4); assert_eq!( - h.time_slices + h2.time_slices .iter() .map(|ts| ts.events.len()) .collect::>(), @@ -1512,7 +1471,7 @@ mod tests { ); assert_eq!( - h.time_slices + h2.time_slices .iter() .map(|b| b.get_composition()) .collect::>(), @@ -1525,7 +1484,7 @@ mod tests { ); assert_eq!( - h.time_slices + h2.time_slices .iter() .map(|b| b.get_results()) .collect::>(), @@ -1537,110 +1496,108 @@ mod tests { ] ); - let end = h.time_slices.len() - 1; + let end = h2.time_slices.len() - 1; - assert_eq!(h.time_slices[0].skills.get(c).unwrap().elapsed, 0); - assert_eq!(h.time_slices[end].skills.get(c).unwrap().elapsed, 10); + assert_eq!(h2.time_slices[0].skills.get(c).unwrap().elapsed, 0); + assert_eq!(h2.time_slices[end].skills.get(c).unwrap().elapsed, 10); - assert_eq!(h.time_slices[0].skills.get(a).unwrap().elapsed, 0); - assert_eq!(h.time_slices[2].skills.get(a).unwrap().elapsed, 5); + assert_eq!(h2.time_slices[0].skills.get(a).unwrap().elapsed, 0); + assert_eq!(h2.time_slices[2].skills.get(a).unwrap().elapsed, 5); - assert_eq!(h.time_slices[0].skills.get(b).unwrap().elapsed, 0); - assert_eq!(h.time_slices[end].skills.get(b).unwrap().elapsed, 5); + assert_eq!(h2.time_slices[0].skills.get(b).unwrap().elapsed, 0); + assert_eq!(h2.time_slices[end].skills.get(b).unwrap().elapsed, 5); - h.convergence(ITERATIONS, EPSILON, false); + h2.converge().unwrap(); assert_ulps_eq!( - h.time_slices[0].skills.get(b).unwrap().posterior(), - h.time_slices[end].skills.get(b).unwrap().posterior(), + h2.time_slices[0].skills.get(b).unwrap().posterior(), + h2.time_slices[end].skills.get(b).unwrap().posterior(), epsilon = 1e-6 ); assert_ulps_eq!( - h.time_slices[0].skills.get(c).unwrap().posterior(), - h.time_slices[end].skills.get(c).unwrap().posterior(), + h2.time_slices[0].skills.get(c).unwrap().posterior(), + h2.time_slices[end].skills.get(c).unwrap().posterior(), epsilon = 1e-6 ); assert_ulps_eq!( - h.time_slices[0].skills.get(c).unwrap().posterior(), - h.time_slices[0].skills.get(b).unwrap().posterior(), + h2.time_slices[0].skills.get(c).unwrap().posterior(), + h2.time_slices[0].skills.get(b).unwrap().posterior(), epsilon = 1e-6 ); } #[test] fn test_1vs1_weighted() { - let mut index_map = KeyTable::new(); - - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); - - let composition = vec![vec![vec![a], vec![b]], vec![vec![b], vec![a]]]; - let weights = vec![vec![vec![5.0], vec![4.0]], vec![vec![5.0], vec![4.0]]]; - - let mut h = History::builder() + let mut h: History = History::builder() .mu(2.0) .sigma(6.0) .beta(1.0) - .gamma(0.0) + .drift(ConstantDrift(0.0)) .build(); - let n = composition.len(); - let times: Vec = (1..=n as i64).collect(); - h.add_events_with_prior(composition, vec![], times, weights, HashMap::new()) - .unwrap(); + // empty results in old API = team 0 wins: a wins event 1, b wins event 2 + let events: Vec> = vec![ + Event { + time: 1, + teams: smallvec![ + Team::with_members([Member::new("a").with_weight(5.0)]), + Team::with_members([Member::new("b").with_weight(4.0)]), + ], + outcome: Outcome::winner(0, 2), + }, + Event { + time: 2, + teams: smallvec![ + Team::with_members([Member::new("b").with_weight(5.0)]), + Team::with_members([Member::new("a").with_weight(4.0)]), + ], + outcome: Outcome::winner(0, 2), + }, + ]; + h.add_events(events).unwrap(); - let lc = h.learning_curves_by_index(); + let lc_a = h.learning_curve("a"); + let lc_b = h.learning_curve("b"); assert_ulps_eq!( - lc[&a][0].1, + lc_a[0].1, Gaussian::from_ms(5.537659, 4.758722), epsilon = 1e-6 ); assert_ulps_eq!( - lc[&b][0].1, + lc_b[0].1, Gaussian::from_ms(-0.830127, 5.239568), epsilon = 1e-6 ); assert_ulps_eq!( - lc[&a][1].1, + lc_a[1].1, Gaussian::from_ms(1.792277, 4.099566), epsilon = 1e-6 ); assert_ulps_eq!( - lc[&b][1].1, + lc_b[1].1, Gaussian::from_ms(4.845533, 3.747616), epsilon = 1e-6 ); - h.convergence(ITERATIONS, EPSILON, false); + h.converge().unwrap(); - let lc = h.learning_curves_by_index(); + let lc_a = h.learning_curve("a"); + let lc_b = h.learning_curve("b"); - assert_ulps_eq!(lc[&a][0].1, lc[&a][0].1, epsilon = 1e-6); - assert_ulps_eq!(lc[&b][0].1, lc[&a][0].1, epsilon = 1e-6); - assert_ulps_eq!(lc[&a][1].1, lc[&a][0].1, epsilon = 1e-6); - assert_ulps_eq!(lc[&b][1].1, lc[&a][0].1, epsilon = 1e-6); + assert_ulps_eq!(lc_a[0].1, lc_a[0].1, epsilon = 1e-6); + assert_ulps_eq!(lc_b[0].1, lc_a[0].1, epsilon = 1e-6); + assert_ulps_eq!(lc_a[1].1, lc_a[0].1, epsilon = 1e-6); + assert_ulps_eq!(lc_b[1].1, lc_a[0].1, epsilon = 1e-6); } #[test] fn test_converge_returns_report() { use crate::ConvergenceOptions; - let mut index_map = crate::KeyTable::new(); - let a = index_map.get_or_create("a"); - let b = index_map.get_or_create("b"); - let c = index_map.get_or_create("c"); - let composition = vec![ - vec![vec![a], vec![b]], - vec![vec![a], vec![c]], - vec![vec![b], vec![c]], - ]; - let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; - let times: Vec = vec![1, 2, 3]; - - let mut h = History::builder() + let mut h: History = History::builder() .mu(0.0) .sigma(2.0) .beta(1.0) @@ -1650,8 +1607,17 @@ mod tests { epsilon: 1e-6, }) .build(); - h.add_events_with_prior(composition, results, times, vec![], HashMap::new()) - .unwrap(); + + let events = make_events_1v1( + &[("a", "b"), ("a", "c"), ("b", "c")], + &[ + Outcome::winner(0, 2), + Outcome::winner(1, 2), + Outcome::winner(0, 2), + ], + &[1, 2, 3], + ); + h.add_events(events).unwrap(); let report = h.converge().unwrap(); assert!(report.converged); diff --git a/tests/equivalence.rs b/tests/equivalence.rs new file mode 100644 index 0000000..222d7dd --- /dev/null +++ b/tests/equivalence.rs @@ -0,0 +1,61 @@ +//! Equivalence tests: every historical golden from the pre-T2 tests is +//! reproduced here at the integration level via the new public API. +//! +//! The in-crate tests in `src/history.rs::tests` and +//! `src/time_slice.rs::tests` are the primary regression net for numerical +//! behavior. This file provides Game-level goldens that stand alone and are +//! more naturally expressed as integration tests. + +use approx::assert_ulps_eq; +use trueskill_tt::{ConstantDrift, Game, GameOptions, Gaussian, Outcome, Rating}; + +type R = Rating; + +fn ts_rating(mu: f64, sigma: f64, beta: f64, gamma: f64) -> R { + R::new(Gaussian::from_ms(mu, sigma), beta, ConstantDrift(gamma)) +} + +#[test] +fn game_1v1_golden_matches_historical() { + let a = ts_rating(25.0, 25.0 / 3.0, 25.0 / 6.0, 25.0 / 300.0); + let b = ts_rating(25.0, 25.0 / 3.0, 25.0 / 6.0, 25.0 / 300.0); + let (a_post, b_post) = Game::::one_v_one(&a, &b, Outcome::winner(0, 2)).unwrap(); + // Historical golden from pre-T2 test_1vs1 (team 0 wins): + assert_ulps_eq!( + a_post, + Gaussian::from_ms(29.205220, 7.194481), + epsilon = 1e-6 + ); + assert_ulps_eq!( + b_post, + Gaussian::from_ms(20.794779, 7.194481), + epsilon = 1e-6 + ); +} + +#[test] +fn game_1v1_draw_golden() { + let a = ts_rating(25.0, 25.0 / 3.0, 25.0 / 6.0, 25.0 / 300.0); + let b = ts_rating(25.0, 25.0 / 3.0, 25.0 / 6.0, 25.0 / 300.0); + let g = Game::::ranked( + &[&[a], &[b]], + Outcome::draw(2), + &GameOptions { + p_draw: 0.25, + convergence: Default::default(), + }, + ) + .unwrap(); + let p = g.posteriors(); + // Historical golden from pre-T2 test_1vs1_draw: + assert_ulps_eq!( + p[0][0], + Gaussian::from_ms(24.999999, 6.469480), + epsilon = 1e-6 + ); + assert_ulps_eq!( + p[1][0], + Gaussian::from_ms(24.999999, 6.469480), + epsilon = 1e-6 + ); +}