diff --git a/examples/atp.rs b/examples/atp.rs index 236e8e3..7a96599 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use plotters::prelude::*; use time::{Date, Month}; use trueskill_tt::{History, KeyTable}; @@ -40,7 +42,7 @@ fn main() { let mut hist = History::builder().sigma(1.6).gamma(0.036).build(); - hist.add_events(composition, results, times, vec![]) + hist.add_events_with_prior(composition, results, times, vec![], HashMap::new()) .unwrap(); hist.convergence(10, 0.01, true); diff --git a/src/history.rs b/src/history.rs index 7fa9229..add2ba0 100644 --- a/src/history.rs +++ b/src/history.rs @@ -332,24 +332,14 @@ impl, O: Observer, K: Eq + Hash + Clone> History, O: Observer, K: Eq + Hash + Clone> History { - pub fn add_events( - &mut self, - composition: Vec>>, - results: Vec>, - times: Vec, - weights: Vec>>, - ) -> Result<(), InferenceError> { - self.add_events_with_prior(composition, results, times, weights, HashMap::new()) - } - +impl, O: Observer, K: Eq + Hash + Clone> History { pub fn add_events_with_prior( &mut self, composition: Vec>>, results: Vec>, - times: Vec, + times: Vec, weights: Vec>>, - mut priors: HashMap>, + mut priors: HashMap>, ) -> Result<(), InferenceError> { if !results.is_empty() && results.len() != composition.len() { return Err(InferenceError::MismatchedShape { @@ -513,12 +503,7 @@ impl, O: Observer, K: Eq + Hash + Clone> History( - &mut self, - winner: &Q, - loser: &Q, - time: i64, - ) -> Result<(), InferenceError> + pub fn record_winner(&mut self, winner: &Q, loser: &Q, time: T) -> Result<(), InferenceError> where K: Borrow, Q: Hash + Eq + ToOwned + ?Sized, @@ -534,7 +519,7 @@ impl, O: Observer, K: Eq + Hash + Clone> History(&mut self, a: &Q, b: &Q, time: i64) -> Result<(), InferenceError> + pub fn record_draw(&mut self, a: &Q, b: &Q, time: T) -> Result<(), InferenceError> where K: Borrow, Q: Hash + Eq + ToOwned + ?Sized, @@ -549,6 +534,62 @@ impl, O: Observer, K: Eq + Hash + Clone> History(&mut self, events: I) -> Result<(), InferenceError> + where + I: IntoIterator>, + { + use crate::event::Event; + let events: Vec> = events.into_iter().collect(); + if events.is_empty() { + return Ok(()); + } + + let mut composition: Vec>> = Vec::with_capacity(events.len()); + let mut results: Vec> = Vec::with_capacity(events.len()); + let mut times: Vec = Vec::with_capacity(events.len()); + let mut weights: Vec>> = Vec::with_capacity(events.len()); + let mut priors: HashMap> = HashMap::new(); + + for ev in events { + let ranks = ev.outcome.as_ranks(); + if ranks.len() != ev.teams.len() { + return Err(InferenceError::MismatchedShape { + kind: "outcome ranks vs teams", + expected: ev.teams.len(), + got: ranks.len(), + }); + } + + let mut event_comp: Vec> = Vec::with_capacity(ev.teams.len()); + let mut event_weights: Vec> = Vec::with_capacity(ev.teams.len()); + + for team in ev.teams { + let mut team_indices: Vec = Vec::with_capacity(team.members.len()); + let mut team_weights: Vec = Vec::with_capacity(team.members.len()); + for member in team.members { + let idx = self.keys.get_or_create(&member.key); + team_indices.push(idx); + team_weights.push(member.weight); + if let Some(prior) = member.prior { + priors.insert(idx, Rating::new(prior, self.beta, self.drift)); + } + } + event_comp.push(team_indices); + event_weights.push(team_weights); + } + composition.push(event_comp); + weights.push(event_weights); + + let max_rank = ranks.iter().copied().max().unwrap_or(0) as f64; + let inverted: Vec = ranks.iter().map(|&r| max_rank - r as f64).collect(); + results.push(inverted); + times.push(ev.time); + } + + self.add_events_with_prior(composition, results, times, weights, priors) + } } #[cfg(test)] @@ -825,7 +866,8 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition, results, times, vec![]).unwrap(); + h.add_events_with_prior(composition, results, times, vec![], HashMap::new()) + .unwrap(); h.convergence(ITERATIONS, EPSILON, false); @@ -876,7 +918,8 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition, results, times, vec![]).unwrap(); + h.add_events_with_prior(composition, results, times, vec![], HashMap::new()) + .unwrap(); let trueskill_log_evidence = h.log_evidence(false, &[]); let trueskill_log_evidence_online = h.log_evidence(true, &[]); @@ -964,8 +1007,14 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition.clone(), results.clone(), times, vec![]) - .unwrap(); + h.add_events_with_prior( + composition.clone(), + results.clone(), + times, + vec![], + HashMap::new(), + ) + .unwrap(); h.convergence(ITERATIONS, EPSILON, false); @@ -989,7 +1038,8 @@ mod tests { ); let times2: Vec = (n as i64 + 1..=2 * n as i64).collect(); - h.add_events(composition, results, times2, vec![]).unwrap(); + h.add_events_with_prior(composition, results, times2, vec![], HashMap::new()) + .unwrap(); assert_eq!(h.time_slices.len(), 6); @@ -1056,8 +1106,14 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition.clone(), results.clone(), times, vec![]) - .unwrap(); + h.add_events_with_prior( + composition.clone(), + results.clone(), + times, + vec![], + HashMap::new(), + ) + .unwrap(); h.convergence(ITERATIONS, EPSILON, false); @@ -1081,7 +1137,8 @@ mod tests { ); let times2: Vec = (n as i64 + 1..=2 * n as i64).collect(); - h.add_events(composition, results, times2, vec![]).unwrap(); + h.add_events_with_prior(composition, results, times2, vec![], HashMap::new()) + .unwrap(); assert_eq!(h.time_slices.len(), 6); @@ -1137,8 +1194,14 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition.clone(), vec![], times.clone(), vec![]) - .unwrap(); + h.add_events_with_prior( + composition.clone(), + vec![], + times.clone(), + vec![], + HashMap::new(), + ) + .unwrap(); let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0; @@ -1177,7 +1240,8 @@ mod tests { let mut h = History::builder().build(); - h.add_events(composition, vec![], times, vec![]).unwrap(); + h.add_events_with_prior(composition, vec![], times, vec![], HashMap::new()) + .unwrap(); assert_ulps_eq!( ((0.5f64 * 0.1765).ln() / 2.0).exp(), @@ -1208,18 +1272,25 @@ mod tests { .gamma(0.0) .build(); - h.add_events( + h.add_events_with_prior( composition.clone(), results.clone(), vec![0, 10, 20], vec![], + HashMap::new(), ) .unwrap(); h.convergence(ITERATIONS, EPSILON, false); - h.add_events(composition, results, vec![15, 10, 0], vec![]) - .unwrap(); + h.add_events_with_prior( + composition, + results, + vec![15, 10, 0], + vec![], + HashMap::new(), + ) + .unwrap(); assert_eq!(h.time_slices.len(), 4); @@ -1303,12 +1374,18 @@ mod tests { .gamma(0.0) .build(); - h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![]) - .unwrap(); + h.add_events_with_prior( + composition.clone(), + vec![], + vec![0, 10, 20], + vec![], + HashMap::new(), + ) + .unwrap(); h.convergence(ITERATIONS, EPSILON, false); - h.add_events(composition, vec![], vec![15, 10, 0], vec![]) + h.add_events_with_prior(composition, vec![], vec![15, 10, 0], vec![], HashMap::new()) .unwrap(); assert_eq!(h.time_slices.len(), 4); @@ -1398,7 +1475,8 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition, vec![], times, weights).unwrap(); + h.add_events_with_prior(composition, vec![], times, weights, HashMap::new()) + .unwrap(); let lc = h.learning_curves(); @@ -1459,7 +1537,8 @@ mod tests { epsilon: 1e-6, }) .build(); - h.add_events(composition, results, times, vec![]).unwrap(); + h.add_events_with_prior(composition, results, times, vec![], HashMap::new()) + .unwrap(); let report = h.converge().unwrap(); assert!(report.converged); diff --git a/tests/api_shape.rs b/tests/api_shape.rs new file mode 100644 index 0000000..3a8dec4 --- /dev/null +++ b/tests/api_shape.rs @@ -0,0 +1,84 @@ +//! Tests for the new T2 public API surface: typed add_events(iter) and the +//! fluent event builder (added in Task 16). + +use smallvec::smallvec; +use trueskill_tt::{ConstantDrift, ConvergenceOptions, Event, History, Member, Outcome, Team}; + +#[test] +fn add_events_bulk_via_iter() { + let mut h = History::builder() + .mu(0.0) + .sigma(2.0) + .beta(1.0) + .p_draw(0.0) + .drift(ConstantDrift(0.0)) + .convergence(ConvergenceOptions { + max_iter: 30, + epsilon: 1e-6, + }) + .build(); + + let events: Vec> = vec![ + Event { + time: 1, + teams: smallvec![ + Team::with_members([Member::new("a")]), + Team::with_members([Member::new("b")]), + ], + outcome: Outcome::winner(0, 2), + }, + Event { + time: 2, + teams: smallvec![ + Team::with_members([Member::new("b")]), + Team::with_members([Member::new("c")]), + ], + outcome: Outcome::winner(0, 2), + }, + ]; + + h.add_events(events).unwrap(); + let report = h.converge().unwrap(); + assert!(report.converged); + assert!(h.lookup(&"a").is_some()); + assert!(h.lookup(&"b").is_some()); + assert!(h.lookup(&"c").is_some()); +} + +#[test] +fn add_events_draw() { + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .p_draw(0.25) + .drift(ConstantDrift(25.0 / 300.0)) + .build(); + + let events: Vec> = vec![Event { + time: 1, + teams: smallvec![ + Team::with_members([Member::new("alice")]), + Team::with_members([Member::new("bob")]), + ], + outcome: Outcome::draw(2), + }]; + h.add_events(events).unwrap(); + h.converge().unwrap(); +} + +#[test] +fn add_events_rejects_mismatched_outcome_ranks() { + use trueskill_tt::InferenceError; + let mut h: History = History::builder().build(); + let events: Vec> = vec![Event { + time: 1, + teams: smallvec![ + Team::with_members([Member::new("a")]), + Team::with_members([Member::new("b")]), + ], + outcome: Outcome::ranking([0, 1, 2]), // 3 ranks but 2 teams + }]; + let err = h.add_events(events).unwrap_err(); + assert!(matches!(err, InferenceError::MismatchedShape { .. })); +}