T0 + T1 + T2: engine redesign through new API surface #1

Merged
logaritmisk merged 45 commits from t2-new-api-surface into main 2026-04-24 11:20:04 +00:00
3 changed files with 205 additions and 40 deletions
Showing only changes of commit 244b94a3e5 - Show all commits

View File

@@ -1,3 +1,5 @@
use std::collections::HashMap;
use plotters::prelude::*; use plotters::prelude::*;
use time::{Date, Month}; use time::{Date, Month};
use trueskill_tt::{History, KeyTable}; use trueskill_tt::{History, KeyTable};
@@ -40,7 +42,7 @@ fn main() {
let mut hist = History::builder().sigma(1.6).gamma(0.036).build(); let mut hist = History::builder().sigma(1.6).gamma(0.036).build();
hist.add_events(composition, results, times, vec![]) hist.add_events_with_prior(composition, results, times, vec![], HashMap::new())
.unwrap(); .unwrap();
hist.convergence(10, 0.01, true); hist.convergence(10, 0.01, true);

View File

@@ -332,24 +332,14 @@ impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O
} }
} }
impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K> { impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O, K> {
pub fn add_events(
&mut self,
composition: Vec<Vec<Vec<Index>>>,
results: Vec<Vec<f64>>,
times: Vec<i64>,
weights: Vec<Vec<Vec<f64>>>,
) -> Result<(), InferenceError> {
self.add_events_with_prior(composition, results, times, weights, HashMap::new())
}
pub fn add_events_with_prior( pub fn add_events_with_prior(
&mut self, &mut self,
composition: Vec<Vec<Vec<Index>>>, composition: Vec<Vec<Vec<Index>>>,
results: Vec<Vec<f64>>, results: Vec<Vec<f64>>,
times: Vec<i64>, times: Vec<T>,
weights: Vec<Vec<Vec<f64>>>, weights: Vec<Vec<Vec<f64>>>,
mut priors: HashMap<Index, Rating<i64, D>>, mut priors: HashMap<Index, Rating<T, D>>,
) -> Result<(), InferenceError> { ) -> Result<(), InferenceError> {
if !results.is_empty() && results.len() != composition.len() { if !results.is_empty() && results.len() != composition.len() {
return Err(InferenceError::MismatchedShape { return Err(InferenceError::MismatchedShape {
@@ -513,12 +503,7 @@ impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K
Ok(()) Ok(())
} }
pub fn record_winner<Q>( pub fn record_winner<Q>(&mut self, winner: &Q, loser: &Q, time: T) -> Result<(), InferenceError>
&mut self,
winner: &Q,
loser: &Q,
time: i64,
) -> Result<(), InferenceError>
where where
K: Borrow<Q>, K: Borrow<Q>,
Q: Hash + Eq + ToOwned<Owned = K> + ?Sized, Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
@@ -534,7 +519,7 @@ impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K
) )
} }
pub fn record_draw<Q>(&mut self, a: &Q, b: &Q, time: i64) -> Result<(), InferenceError> pub fn record_draw<Q>(&mut self, a: &Q, b: &Q, time: T) -> Result<(), InferenceError>
where where
K: Borrow<Q>, K: Borrow<Q>,
Q: Hash + Eq + ToOwned<Owned = K> + ?Sized, Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
@@ -549,6 +534,62 @@ impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K
HashMap::new(), HashMap::new(),
) )
} }
/// Bulk-ingest typed events.
pub fn add_events<I>(&mut self, events: I) -> Result<(), InferenceError>
where
I: IntoIterator<Item = crate::event::Event<T, K>>,
{
use crate::event::Event;
let events: Vec<Event<T, K>> = events.into_iter().collect();
if events.is_empty() {
return Ok(());
}
let mut composition: Vec<Vec<Vec<Index>>> = Vec::with_capacity(events.len());
let mut results: Vec<Vec<f64>> = Vec::with_capacity(events.len());
let mut times: Vec<T> = Vec::with_capacity(events.len());
let mut weights: Vec<Vec<Vec<f64>>> = Vec::with_capacity(events.len());
let mut priors: HashMap<Index, Rating<T, D>> = HashMap::new();
for ev in events {
let ranks = ev.outcome.as_ranks();
if ranks.len() != ev.teams.len() {
return Err(InferenceError::MismatchedShape {
kind: "outcome ranks vs teams",
expected: ev.teams.len(),
got: ranks.len(),
});
}
let mut event_comp: Vec<Vec<Index>> = Vec::with_capacity(ev.teams.len());
let mut event_weights: Vec<Vec<f64>> = Vec::with_capacity(ev.teams.len());
for team in ev.teams {
let mut team_indices: Vec<Index> = Vec::with_capacity(team.members.len());
let mut team_weights: Vec<f64> = Vec::with_capacity(team.members.len());
for member in team.members {
let idx = self.keys.get_or_create(&member.key);
team_indices.push(idx);
team_weights.push(member.weight);
if let Some(prior) = member.prior {
priors.insert(idx, Rating::new(prior, self.beta, self.drift));
}
}
event_comp.push(team_indices);
event_weights.push(team_weights);
}
composition.push(event_comp);
weights.push(event_weights);
let max_rank = ranks.iter().copied().max().unwrap_or(0) as f64;
let inverted: Vec<f64> = ranks.iter().map(|&r| max_rank - r as f64).collect();
results.push(inverted);
times.push(ev.time);
}
self.add_events_with_prior(composition, results, times, weights, priors)
}
} }
#[cfg(test)] #[cfg(test)]
@@ -825,7 +866,8 @@ mod tests {
let n = composition.len(); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect(); let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition, results, times, vec![]).unwrap(); h.add_events_with_prior(composition, results, times, vec![], HashMap::new())
.unwrap();
h.convergence(ITERATIONS, EPSILON, false); h.convergence(ITERATIONS, EPSILON, false);
@@ -876,7 +918,8 @@ mod tests {
let n = composition.len(); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect(); let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition, results, times, vec![]).unwrap(); h.add_events_with_prior(composition, results, times, vec![], HashMap::new())
.unwrap();
let trueskill_log_evidence = h.log_evidence(false, &[]); let trueskill_log_evidence = h.log_evidence(false, &[]);
let trueskill_log_evidence_online = h.log_evidence(true, &[]); let trueskill_log_evidence_online = h.log_evidence(true, &[]);
@@ -964,8 +1007,14 @@ mod tests {
let n = composition.len(); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect(); let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition.clone(), results.clone(), times, vec![]) h.add_events_with_prior(
.unwrap(); composition.clone(),
results.clone(),
times,
vec![],
HashMap::new(),
)
.unwrap();
h.convergence(ITERATIONS, EPSILON, false); h.convergence(ITERATIONS, EPSILON, false);
@@ -989,7 +1038,8 @@ mod tests {
); );
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect(); let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
h.add_events(composition, results, times2, vec![]).unwrap(); h.add_events_with_prior(composition, results, times2, vec![], HashMap::new())
.unwrap();
assert_eq!(h.time_slices.len(), 6); assert_eq!(h.time_slices.len(), 6);
@@ -1056,8 +1106,14 @@ mod tests {
let n = composition.len(); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect(); let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition.clone(), results.clone(), times, vec![]) h.add_events_with_prior(
.unwrap(); composition.clone(),
results.clone(),
times,
vec![],
HashMap::new(),
)
.unwrap();
h.convergence(ITERATIONS, EPSILON, false); h.convergence(ITERATIONS, EPSILON, false);
@@ -1081,7 +1137,8 @@ mod tests {
); );
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect(); let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
h.add_events(composition, results, times2, vec![]).unwrap(); h.add_events_with_prior(composition, results, times2, vec![], HashMap::new())
.unwrap();
assert_eq!(h.time_slices.len(), 6); assert_eq!(h.time_slices.len(), 6);
@@ -1137,8 +1194,14 @@ mod tests {
let n = composition.len(); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect(); let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition.clone(), vec![], times.clone(), vec![]) h.add_events_with_prior(
.unwrap(); composition.clone(),
vec![],
times.clone(),
vec![],
HashMap::new(),
)
.unwrap();
let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0; let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0;
@@ -1177,7 +1240,8 @@ mod tests {
let mut h = History::builder().build(); let mut h = History::builder().build();
h.add_events(composition, vec![], times, vec![]).unwrap(); h.add_events_with_prior(composition, vec![], times, vec![], HashMap::new())
.unwrap();
assert_ulps_eq!( assert_ulps_eq!(
((0.5f64 * 0.1765).ln() / 2.0).exp(), ((0.5f64 * 0.1765).ln() / 2.0).exp(),
@@ -1208,18 +1272,25 @@ mod tests {
.gamma(0.0) .gamma(0.0)
.build(); .build();
h.add_events( h.add_events_with_prior(
composition.clone(), composition.clone(),
results.clone(), results.clone(),
vec![0, 10, 20], vec![0, 10, 20],
vec![], vec![],
HashMap::new(),
) )
.unwrap(); .unwrap();
h.convergence(ITERATIONS, EPSILON, false); h.convergence(ITERATIONS, EPSILON, false);
h.add_events(composition, results, vec![15, 10, 0], vec![]) h.add_events_with_prior(
.unwrap(); composition,
results,
vec![15, 10, 0],
vec![],
HashMap::new(),
)
.unwrap();
assert_eq!(h.time_slices.len(), 4); assert_eq!(h.time_slices.len(), 4);
@@ -1303,12 +1374,18 @@ mod tests {
.gamma(0.0) .gamma(0.0)
.build(); .build();
h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![]) h.add_events_with_prior(
.unwrap(); composition.clone(),
vec![],
vec![0, 10, 20],
vec![],
HashMap::new(),
)
.unwrap();
h.convergence(ITERATIONS, EPSILON, false); h.convergence(ITERATIONS, EPSILON, false);
h.add_events(composition, vec![], vec![15, 10, 0], vec![]) h.add_events_with_prior(composition, vec![], vec![15, 10, 0], vec![], HashMap::new())
.unwrap(); .unwrap();
assert_eq!(h.time_slices.len(), 4); assert_eq!(h.time_slices.len(), 4);
@@ -1398,7 +1475,8 @@ mod tests {
let n = composition.len(); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect(); let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition, vec![], times, weights).unwrap(); h.add_events_with_prior(composition, vec![], times, weights, HashMap::new())
.unwrap();
let lc = h.learning_curves(); let lc = h.learning_curves();
@@ -1459,7 +1537,8 @@ mod tests {
epsilon: 1e-6, epsilon: 1e-6,
}) })
.build(); .build();
h.add_events(composition, results, times, vec![]).unwrap(); h.add_events_with_prior(composition, results, times, vec![], HashMap::new())
.unwrap();
let report = h.converge().unwrap(); let report = h.converge().unwrap();
assert!(report.converged); assert!(report.converged);

84
tests/api_shape.rs Normal file
View File

@@ -0,0 +1,84 @@
//! Tests for the new T2 public API surface: typed add_events(iter) and the
//! fluent event builder (added in Task 16).
use smallvec::smallvec;
use trueskill_tt::{ConstantDrift, ConvergenceOptions, Event, History, Member, Outcome, Team};
#[test]
fn add_events_bulk_via_iter() {
let mut h = History::builder()
.mu(0.0)
.sigma(2.0)
.beta(1.0)
.p_draw(0.0)
.drift(ConstantDrift(0.0))
.convergence(ConvergenceOptions {
max_iter: 30,
epsilon: 1e-6,
})
.build();
let events: Vec<Event<i64, &'static str>> = vec![
Event {
time: 1,
teams: smallvec![
Team::with_members([Member::new("a")]),
Team::with_members([Member::new("b")]),
],
outcome: Outcome::winner(0, 2),
},
Event {
time: 2,
teams: smallvec![
Team::with_members([Member::new("b")]),
Team::with_members([Member::new("c")]),
],
outcome: Outcome::winner(0, 2),
},
];
h.add_events(events).unwrap();
let report = h.converge().unwrap();
assert!(report.converged);
assert!(h.lookup(&"a").is_some());
assert!(h.lookup(&"b").is_some());
assert!(h.lookup(&"c").is_some());
}
#[test]
fn add_events_draw() {
let mut h = History::builder()
.mu(25.0)
.sigma(25.0 / 3.0)
.beta(25.0 / 6.0)
.p_draw(0.25)
.drift(ConstantDrift(25.0 / 300.0))
.build();
let events: Vec<Event<i64, &'static str>> = vec![Event {
time: 1,
teams: smallvec![
Team::with_members([Member::new("alice")]),
Team::with_members([Member::new("bob")]),
],
outcome: Outcome::draw(2),
}];
h.add_events(events).unwrap();
h.converge().unwrap();
}
#[test]
fn add_events_rejects_mismatched_outcome_ranks() {
use trueskill_tt::InferenceError;
let mut h: History = History::builder().build();
let events: Vec<Event<i64, &'static str>> = vec![Event {
time: 1,
teams: smallvec![
Team::with_members([Member::new("a")]),
Team::with_members([Member::new("b")]),
],
outcome: Outcome::ranking([0, 1, 2]), // 3 ranks but 2 teams
}];
let err = h.add_events(events).unwrap_err();
assert!(matches!(err, InferenceError::MismatchedShape { .. }));
}