From e62568bf3e91a949ba5eac70a5b2bf2788d3cdc8 Mon Sep 17 00:00:00 2001
From: Anders Olsson
Date: Fri, 24 Apr 2026 12:47:41 +0200
Subject: [PATCH] feat(api): add current_skill / learning_curve / log_evidence /
 predict_*

New public query methods on History:

- current_skill(&K) -> Option<Gaussian>: latest posterior for a key
- learning_curve(&K) -> Vec<(T, Gaussian)>: single-key history
- learning_curves() -> HashMap<K, Vec<(T, Gaussian)>>: all-keys history
- log_evidence() -> f64: total log-evidence (was log_evidence(false, &[]))
- log_evidence_for(&[&K]) -> f64: subset log-evidence
- predict_quality(&[&[&K]]) -> f64: draw-probability match quality
- predict_outcome(&[&[&K]]) -> Vec<f64>: 2-team win probabilities

learning_curves() changed from returning HashMap<Index, Vec<(T, Gaussian)>>
to HashMap<K, Vec<(T, Gaussian)>>. A new learning_curves_by_index() helper
preserves the old Index-keyed shape for callers that ingest via the
pub(crate) Index path.

log_evidence(false, &[]) was renamed to log_evidence_internal and made
pub(crate); the new zero-arg log_evidence() wraps it.

predict_outcome is T2 2-team-only; N-team deferred to T4.

KeyTable::get no longer requires ToOwned (only needed for get_or_create),
allowing query methods to use simpler bounds.

Part of T2 of
docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.
Co-Authored-By: Claude Sonnet 4.6 --- examples/atp.rs | 2 +- src/history.rs | 166 +++++++++++++++++++++++++++++++++++++-------- src/key_table.rs | 2 +- tests/api_shape.rs | 81 ++++++++++++++++++++++ 4 files changed, 220 insertions(+), 31 deletions(-) diff --git a/examples/atp.rs b/examples/atp.rs index 7a96599..9aa136b 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -64,7 +64,7 @@ fn main() { ("wilander", "w023", 32600), ]; - let curves = hist.learning_curves(); + let curves = hist.learning_curves_by_index(); let mut x_spec = (f64::MAX, f64::MIN); let mut y_spec = (f64::MAX, f64::MIN); diff --git a/src/history.rs b/src/history.rs index 7ba717a..7606432 100644 --- a/src/history.rs +++ b/src/history.rs @@ -276,31 +276,139 @@ impl, O: Observer, K: Eq + Hash + Clone> History HashMap> { + /// Like `learning_curves`, but keyed by internal `Index`. Useful when + /// events were ingested via `Index` (rather than `record_winner` / + /// typed `add_events`), which doesn't populate the KeyTable. + pub fn learning_curves_by_index(&self) -> HashMap> { let mut data: HashMap> = HashMap::new(); - for b in &self.time_slices { for (agent, skill) in b.skills.iter() { - let point = (b.time, skill.posterior()); - - if let Some(entry) = data.get_mut(&agent) { - entry.push(point); - } else { - data.insert(agent, vec![point]); - } + data.entry(agent) + .or_default() + .push((b.time, skill.posterior())); } } - data } - pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 { + /// Learning curves for all competitors, keyed by their user-facing key. + /// + /// Returns an empty map for histories ingested via the raw `Index` path + /// (i.e. `add_events_with_prior` without `intern`/`record_winner`). + /// Use `learning_curves_by_index()` in that case. + /// + /// Note: `key(idx)` is O(n) per lookup; this method is therefore O(n²) + /// in the number of competitors. Acceptable for T2; T3 may optimize. 
+ pub fn learning_curves(&self) -> HashMap> { + let mut data: HashMap> = HashMap::new(); + for slice in &self.time_slices { + for (idx, skill) in slice.skills.iter() { + if let Some(key) = self.keys.key(idx).cloned() { + data.entry(key) + .or_default() + .push((slice.time, skill.posterior())); + } + } + } + data + } + + /// Skill estimate at the latest time slice the competitor appears in. + pub fn current_skill(&self, key: &Q) -> Option + where + K: std::borrow::Borrow, + Q: std::hash::Hash + Eq + ?Sized, + { + let idx = self.keys.get(key)?; + self.time_slices + .iter() + .rev() + .find_map(|ts| ts.skills.get(idx).map(|sk| sk.posterior())) + } + + /// Learning curve for a single key: (time, posterior) pairs in time order. + pub fn learning_curve(&self, key: &Q) -> Vec<(T, Gaussian)> + where + K: std::borrow::Borrow, + Q: std::hash::Hash + Eq + ?Sized, + { + let Some(idx) = self.keys.get(key) else { + return Vec::new(); + }; + self.time_slices + .iter() + .filter_map(|ts| ts.skills.get(idx).map(|sk| (ts.time, sk.posterior()))) + .collect() + } + + pub(crate) fn log_evidence_internal(&mut self, forward: bool, targets: &[Index]) -> f64 { self.time_slices .iter() .map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents)) .sum() } + /// Total log-evidence across the history. + pub fn log_evidence(&mut self) -> f64 { + self.log_evidence_internal(false, &[]) + } + + /// Log-evidence restricted to time slices containing at least one of the + /// given keys. Useful for leave-one-out cross-validation. + pub fn log_evidence_for(&mut self, keys: &[&Q]) -> f64 + where + K: std::borrow::Borrow, + Q: std::hash::Hash + Eq + ?Sized, + { + let targets: Vec = keys.iter().filter_map(|k| self.keys.get(*k)).collect(); + self.log_evidence_internal(false, &targets) + } + + /// Draw-probability quality metric for the given teams (key slices). + /// + /// Values range roughly [0, 1]; 1 == perfectly matched. 
+ pub fn predict_quality(&self, teams: &[&[&K]]) -> f64 { + let groups: Vec> = teams + .iter() + .map(|team| { + team.iter() + .filter_map(|k| self.keys.get(*k)) + .filter_map(|idx| { + self.time_slices + .iter() + .rev() + .find_map(|ts| ts.skills.get(idx).map(|s| s.posterior())) + }) + .collect() + }) + .collect(); + let group_refs: Vec<&[Gaussian]> = groups.iter().map(|g| g.as_slice()).collect(); + crate::quality(&group_refs, self.beta) + } + + /// 2-team win probability: returns `[P(team0 wins), P(team1 wins)]`. + /// + /// Panics if `teams.len() != 2`. N-team support lands in T4. + pub fn predict_outcome(&self, teams: &[&[&K]]) -> Vec { + assert_eq!(teams.len(), 2, "predict_outcome T2: 2 teams only"); + let gather = |team: &[&K]| -> Gaussian { + team.iter() + .filter_map(|k| self.keys.get(*k)) + .filter_map(|idx| { + self.time_slices + .iter() + .rev() + .find_map(|ts| ts.skills.get(idx).map(|s| s.posterior())) + }) + .fold(crate::N00, |acc, g| acc + g.forget(self.beta.powi(2))) + }; + let a = gather(teams[0]); + let b = gather(teams[1]); + let diff = a - b; + let p_a = 1.0 - crate::cdf(0.0, diff.mu(), diff.sigma()); + vec![p_a, 1.0 - p_a] + } + /// Run the full forward+backward convergence loop and return a summary. 
pub fn converge(&mut self) -> Result { use std::time::Instant; @@ -319,7 +427,7 @@ impl, O: Observer, K: Eq + Hash + Clone> History>(&self, k: &Q) -> Option + pub fn get(&self, k: &Q) -> Option where K: Borrow, { diff --git a/tests/api_shape.rs b/tests/api_shape.rs index 886be48..676d568 100644 --- a/tests/api_shape.rs +++ b/tests/api_shape.rs @@ -142,3 +142,84 @@ fn fluent_event_builder_draw() { .unwrap(); h.converge().unwrap(); } + +#[test] +fn current_skill_and_learning_curve() { + use trueskill_tt::History; + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .p_draw(0.0) + .build(); + h.record_winner(&"a", &"b", 1).unwrap(); + h.record_winner(&"a", &"b", 2).unwrap(); + h.converge().unwrap(); + + let a = h.current_skill(&"a").unwrap(); + assert!(a.mu() > 25.0); + let b = h.current_skill(&"b").unwrap(); + assert!(b.mu() < 25.0); + + let a_curve = h.learning_curve(&"a"); + assert_eq!(a_curve.len(), 2); + assert_eq!(a_curve[0].0, 1); + assert_eq!(a_curve[1].0, 2); + + let all = h.learning_curves(); + assert_eq!(all.len(), 2); + assert!(all.contains_key("a")); + assert!(all.contains_key("b")); +} + +#[test] +fn log_evidence_total_vs_subset() { + use trueskill_tt::{ConstantDrift, History}; + let mut h = History::builder() + .mu(0.0) + .sigma(6.0) + .beta(1.0) + .p_draw(0.0) + .drift(ConstantDrift(0.0)) + .build(); + h.record_winner(&"a", &"b", 1).unwrap(); + h.record_winner(&"b", &"a", 2).unwrap(); + let total = h.log_evidence(); + let a_only = h.log_evidence_for(&[&"a"]); + assert!(total.is_finite()); + assert!(a_only.is_finite()); +} + +#[test] +fn predict_quality_two_teams() { + use trueskill_tt::History; + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .p_draw(0.0) + .build(); + h.record_winner(&"a", &"b", 1).unwrap(); + h.converge().unwrap(); + + let q = h.predict_quality(&[&[&"a"], &[&"b"]]); + assert!(q > 0.0 && q <= 1.0); +} + +#[test] +fn predict_outcome_two_teams_sums_to_one() { + 
use trueskill_tt::History; + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .p_draw(0.0) + .build(); + h.record_winner(&"a", &"b", 1).unwrap(); + h.converge().unwrap(); + + let p = h.predict_outcome(&[&[&"a"], &[&"b"]]); + assert_eq!(p.len(), 2); + assert!((p[0] + p[1] - 1.0).abs() < 1e-9); + assert!(p[0] > p[1]); +}