T0 + T1 + T2: engine redesign through new API surface #1

Merged
logaritmisk merged 45 commits from t2-new-api-surface into main 2026-04-24 11:20:04 +00:00
4 changed files with 220 additions and 31 deletions
Showing only changes of commit e62568bf3e - Show all commits

View File

@@ -64,7 +64,7 @@ fn main() {
("wilander", "w023", 32600),
];
let curves = hist.learning_curves();
let curves = hist.learning_curves_by_index();
let mut x_spec = (f64::MAX, f64::MIN);
let mut y_spec = (f64::MAX, f64::MIN);

View File

@@ -276,31 +276,139 @@ impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O
(step, i)
}
pub fn learning_curves(&self) -> HashMap<Index, Vec<(T, Gaussian)>> {
/// Like `learning_curves`, but keyed by internal `Index`. Useful when
/// events were ingested via `Index` (rather than `record_winner` /
/// typed `add_events`), which doesn't populate the KeyTable.
pub fn learning_curves_by_index(&self) -> HashMap<Index, Vec<(T, Gaussian)>> {
let mut data: HashMap<Index, Vec<(T, Gaussian)>> = HashMap::new();
for b in &self.time_slices {
for (agent, skill) in b.skills.iter() {
let point = (b.time, skill.posterior());
if let Some(entry) = data.get_mut(&agent) {
entry.push(point);
} else {
data.insert(agent, vec![point]);
}
data.entry(agent)
.or_default()
.push((b.time, skill.posterior()));
}
}
data
}
pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 {
/// Learning curves for every competitor, keyed by the user-facing key.
///
/// Histories fed through the raw `Index` path (`add_events_with_prior`
/// without `intern`/`record_winner`) have an empty KeyTable, so this
/// returns an empty map for them — use `learning_curves_by_index()` there.
///
/// Note: `key(idx)` is O(n) per lookup, making this O(n²) in the number
/// of competitors. Acceptable for T2; T3 may optimize.
pub fn learning_curves(&self) -> HashMap<K, Vec<(T, Gaussian)>> {
    let mut curves: HashMap<K, Vec<(T, Gaussian)>> = HashMap::new();
    for ts in &self.time_slices {
        for (idx, skill) in ts.skills.iter() {
            // Indices without a registered key are silently skipped.
            let Some(owner) = self.keys.key(idx) else { continue };
            curves
                .entry(owner.clone())
                .or_default()
                .push((ts.time, skill.posterior()));
        }
    }
    curves
}
/// Posterior skill at the most recent time slice the competitor appears in,
/// or `None` if the key was never interned.
pub fn current_skill<Q>(&self, key: &Q) -> Option<Gaussian>
where
    K: std::borrow::Borrow<Q>,
    Q: std::hash::Hash + Eq + ?Sized,
{
    let idx = self.keys.get(key)?;
    // Walk backwards so the first hit is the latest appearance.
    for ts in self.time_slices.iter().rev() {
        if let Some(skill) = ts.skills.get(idx) {
            return Some(skill.posterior());
        }
    }
    None
}
/// Learning curve for one key: `(time, posterior)` pairs in time order.
/// An unknown key yields an empty vector.
pub fn learning_curve<Q>(&self, key: &Q) -> Vec<(T, Gaussian)>
where
    K: std::borrow::Borrow<Q>,
    Q: std::hash::Hash + Eq + ?Sized,
{
    match self.keys.get(key) {
        None => Vec::new(),
        Some(idx) => {
            let mut curve = Vec::new();
            for ts in &self.time_slices {
                if let Some(skill) = ts.skills.get(idx) {
                    curve.push((ts.time, skill.posterior()));
                }
            }
            curve
        }
    }
}
/// Shared implementation behind the public log-evidence accessors:
/// sums per-slice log-evidence, optionally restricted to `targets`.
pub(crate) fn log_evidence_internal(&mut self, forward: bool, targets: &[Index]) -> f64 {
    let mut total = 0.0;
    for ts in self.time_slices.iter() {
        total += ts.log_evidence(self.online, targets, forward, &self.agents);
    }
    total
}
/// Total log-evidence across the history.
pub fn log_evidence(&mut self) -> f64 {
self.log_evidence_internal(false, &[])
}
/// Log-evidence restricted to time slices containing at least one of the
/// given keys; unknown keys are ignored. Useful for leave-one-out
/// cross-validation.
pub fn log_evidence_for<Q>(&mut self, keys: &[&Q]) -> f64
where
    K: std::borrow::Borrow<Q>,
    Q: std::hash::Hash + Eq + ?Sized,
{
    let mut targets = Vec::with_capacity(keys.len());
    for k in keys {
        if let Some(idx) = self.keys.get(*k) {
            targets.push(idx);
        }
    }
    self.log_evidence_internal(false, &targets)
}
/// Draw-probability quality metric for the given teams (key slices).
///
/// Values range roughly [0, 1]; 1 == perfectly matched. Unknown keys and
/// keys with no recorded time slice are skipped.
pub fn predict_quality(&self, teams: &[&[&K]]) -> f64 {
    let mut groups: Vec<Vec<Gaussian>> = Vec::with_capacity(teams.len());
    for team in teams {
        let mut members = Vec::new();
        for k in team.iter() {
            let Some(idx) = self.keys.get(*k) else { continue };
            // Latest posterior: scan slices newest-first, take the first hit.
            let latest = self
                .time_slices
                .iter()
                .rev()
                .find_map(|ts| ts.skills.get(idx).map(|s| s.posterior()));
            if let Some(g) = latest {
                members.push(g);
            }
        }
        groups.push(members);
    }
    let group_refs: Vec<&[Gaussian]> = groups.iter().map(|g| g.as_slice()).collect();
    crate::quality(&group_refs, self.beta)
}
/// 2-team win probability: returns `[P(team0 wins), P(team1 wins)]`.
///
/// Panics if `teams.len() != 2`. N-team support lands in T4.
pub fn predict_outcome(&self, teams: &[&[&K]]) -> Vec<f64> {
    assert_eq!(teams.len(), 2, "predict_outcome T2: 2 teams only");
    // Combine a team's latest posteriors (each widened by beta²) into one Gaussian.
    let gather = |team: &[&K]| -> Gaussian {
        let mut combined = crate::N00;
        for k in team {
            let Some(idx) = self.keys.get(*k) else { continue };
            let latest = self
                .time_slices
                .iter()
                .rev()
                .find_map(|ts| ts.skills.get(idx).map(|s| s.posterior()));
            if let Some(g) = latest {
                combined = combined + g.forget(self.beta.powi(2));
            }
        }
        combined
    };
    let diff = gather(teams[0]) - gather(teams[1]);
    // P(team0 wins) = P(diff > 0) under the Gaussian difference.
    let p_first = 1.0 - crate::cdf(0.0, diff.mu(), diff.sigma());
    vec![p_first, 1.0 - p_first]
}
/// Run the full forward+backward convergence loop and return a summary.
pub fn converge(&mut self) -> Result<ConvergenceReport, InferenceError> {
use std::time::Instant;
@@ -319,7 +427,7 @@ impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O
self.observer.on_iteration_end(i, step);
}
let converged = !tuple_gt(step, opts.epsilon);
let log_evidence = self.log_evidence(false, &[]);
let log_evidence = self.log_evidence_internal(false, &[]);
self.observer.on_converged(i, step, converged);
Ok(ConvergenceReport {
iterations: i,
@@ -827,7 +935,7 @@ mod tests {
.unwrap();
h.convergence(ITERATIONS, EPSILON, false);
let lc = h.learning_curves();
let lc = h.learning_curves_by_index();
let aj_e = lc[&a].len();
let cj_e = lc[&c].len();
@@ -926,8 +1034,8 @@ mod tests {
h.add_events_with_prior(composition, results, times, vec![], HashMap::new())
.unwrap();
let trueskill_log_evidence = h.log_evidence(false, &[]);
let trueskill_log_evidence_online = h.log_evidence(true, &[]);
let trueskill_log_evidence = h.log_evidence_internal(false, &[]);
let trueskill_log_evidence_online = h.log_evidence_internal(true, &[]);
assert_ulps_eq!(
trueskill_log_evidence,
@@ -941,16 +1049,16 @@ mod tests {
epsilon = 1e-6
);
let evidence_second_event = h.log_evidence(false, &[b]).exp() * 2.0;
let evidence_second_event = h.log_evidence_internal(false, &[b]).exp() * 2.0;
assert_ulps_eq!(0.5, evidence_second_event, epsilon = 1e-6);
let evidence_third_event = h.log_evidence(false, &[a]).exp() * 2.0;
let evidence_third_event = h.log_evidence_internal(false, &[a]).exp() * 2.0;
assert_ulps_eq!(0.669885, evidence_third_event, epsilon = 1e-6);
h.convergence(ITERATIONS, EPSILON, false);
let loocv_hat = h.log_evidence(false, &[]).exp();
let p_d_m_hat = h.log_evidence(true, &[]).exp();
let loocv_hat = h.log_evidence_internal(false, &[]).exp();
let p_d_m_hat = h.log_evidence_internal(true, &[]).exp();
assert_ulps_eq!(loocv_hat, 0.241027, epsilon = 1e-6);
assert_ulps_eq!(p_d_m_hat, 0.172432, epsilon = 1e-6);
@@ -1208,38 +1316,38 @@ mod tests {
)
.unwrap();
let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0;
let p_d_m_2 = h.log_evidence_internal(false, &[]).exp() * 2.0;
assert_ulps_eq!(p_d_m_2, 0.17650911, epsilon = 1e-6);
assert_ulps_eq!(
p_d_m_2,
h.log_evidence(true, &[]).exp() * 2.0,
h.log_evidence_internal(true, &[]).exp() * 2.0,
epsilon = 1e-6
);
assert_ulps_eq!(
p_d_m_2,
h.log_evidence(true, &[a]).exp() * 2.0,
h.log_evidence_internal(true, &[a]).exp() * 2.0,
epsilon = 1e-6
);
assert_ulps_eq!(
p_d_m_2,
h.log_evidence(false, &[a]).exp() * 2.0,
h.log_evidence_internal(false, &[a]).exp() * 2.0,
epsilon = 1e-6
);
h.convergence(11, EPSILON, false);
let loocv_approx_2 = h.log_evidence(false, &[]).exp().sqrt();
let loocv_approx_2 = h.log_evidence_internal(false, &[]).exp().sqrt();
assert_ulps_eq!(loocv_approx_2, 0.001976774, epsilon = 0.000001);
let p_d_m_approx_2 = h.log_evidence(true, &[]).exp() * 2.0;
let p_d_m_approx_2 = h.log_evidence_internal(true, &[]).exp() * 2.0;
assert!(loocv_approx_2 - p_d_m_approx_2 < 1e-4);
assert_ulps_eq!(
loocv_approx_2,
h.log_evidence(true, &[b]).exp() * 2.0,
h.log_evidence_internal(true, &[b]).exp() * 2.0,
epsilon = 1e-4
);
@@ -1250,7 +1358,7 @@ mod tests {
assert_ulps_eq!(
((0.5f64 * 0.1765).ln() / 2.0).exp(),
(h.log_evidence(false, &[]) / 2.0).exp(),
(h.log_evidence_internal(false, &[]) / 2.0).exp(),
epsilon = 1e-4
);
}
@@ -1483,7 +1591,7 @@ mod tests {
h.add_events_with_prior(composition, vec![], times, weights, HashMap::new())
.unwrap();
let lc = h.learning_curves();
let lc = h.learning_curves_by_index();
assert_ulps_eq!(
lc[&a][0].1,
@@ -1508,7 +1616,7 @@ mod tests {
h.convergence(ITERATIONS, EPSILON, false);
let lc = h.learning_curves();
let lc = h.learning_curves_by_index();
assert_ulps_eq!(lc[&a][0].1, lc[&a][0].1, epsilon = 1e-6);
assert_ulps_eq!(lc[&b][0].1, lc[&a][0].1, epsilon = 1e-6);

View File

@@ -22,7 +22,7 @@ where
Self(HashMap::new())
}
pub fn get<Q: ?Sized + Hash + Eq + ToOwned<Owned = K>>(&self, k: &Q) -> Option<Index>
pub fn get<Q: ?Sized + Hash + Eq>(&self, k: &Q) -> Option<Index>
where
K: Borrow<Q>,
{

View File

@@ -142,3 +142,84 @@ fn fluent_event_builder_draw() {
.unwrap();
h.converge().unwrap();
}
#[test]
fn current_skill_and_learning_curve() {
    use trueskill_tt::History;
    let mut hist = History::builder()
        .mu(25.0)
        .sigma(25.0 / 3.0)
        .beta(25.0 / 6.0)
        .p_draw(0.0)
        .build();
    hist.record_winner(&"a", &"b", 1).unwrap();
    hist.record_winner(&"a", &"b", 2).unwrap();
    hist.converge().unwrap();
    // Two wins for "a" must pull the estimates apart around the 25.0 prior.
    let winner = hist.current_skill(&"a").unwrap();
    assert!(winner.mu() > 25.0);
    let loser = hist.current_skill(&"b").unwrap();
    assert!(loser.mu() < 25.0);
    // One curve point per time slice "a" appears in, in chronological order.
    let curve = hist.learning_curve(&"a");
    assert_eq!(curve.len(), 2);
    assert_eq!(curve[0].0, 1);
    assert_eq!(curve[1].0, 2);
    // Both competitors show up in the keyed map.
    let everyone = hist.learning_curves();
    assert_eq!(everyone.len(), 2);
    assert!(everyone.contains_key("a"));
    assert!(everyone.contains_key("b"));
}
#[test]
fn log_evidence_total_vs_subset() {
    use trueskill_tt::{ConstantDrift, History};
    let mut hist = History::builder()
        .mu(0.0)
        .sigma(6.0)
        .beta(1.0)
        .p_draw(0.0)
        .drift(ConstantDrift(0.0))
        .build();
    // One win each way, at distinct times.
    hist.record_winner(&"a", &"b", 1).unwrap();
    hist.record_winner(&"b", &"a", 2).unwrap();
    // Both the full total and the "a"-restricted evidence must be finite.
    let total = hist.log_evidence();
    let restricted_to_a = hist.log_evidence_for(&[&"a"]);
    assert!(total.is_finite());
    assert!(restricted_to_a.is_finite());
}
#[test]
fn predict_quality_two_teams() {
    use trueskill_tt::History;
    let mut hist = History::builder()
        .mu(25.0)
        .sigma(25.0 / 3.0)
        .beta(25.0 / 6.0)
        .p_draw(0.0)
        .build();
    hist.record_winner(&"a", &"b", 1).unwrap();
    hist.converge().unwrap();
    // Quality is a (0, 1] match-up score.
    let quality = hist.predict_quality(&[&[&"a"], &[&"b"]]);
    assert!(quality > 0.0);
    assert!(quality <= 1.0);
}
#[test]
fn predict_outcome_two_teams_sums_to_one() {
    use trueskill_tt::History;
    let mut hist = History::builder()
        .mu(25.0)
        .sigma(25.0 / 3.0)
        .beta(25.0 / 6.0)
        .p_draw(0.0)
        .build();
    hist.record_winner(&"a", &"b", 1).unwrap();
    hist.converge().unwrap();
    let probs = hist.predict_outcome(&[&[&"a"], &[&"b"]]);
    assert_eq!(probs.len(), 2);
    // A proper two-outcome distribution: sums to 1, prior winner favoured.
    let total = probs[0] + probs[1];
    assert!((total - 1.0).abs() < 1e-9);
    assert!(probs[0] > probs[1]);
}