feat(api): typed add_events(iter); generify internal path over T
Public API gains:
History::add_events<I: IntoIterator<Item = Event<T, K>>>(events)
-> Result<(), InferenceError>
which accepts the typed Event<T, K> shape added in Task 10. Ranks
from Outcome::Ranked are mapped to the legacy "higher f64 = better"
results internally.
add_events_with_prior now takes Vec<T> for times (was Vec<i64>),
generifying the whole internal path over T in a single fully-generic
impl<T: Time, D: Drift<T>, O: Observer<T>, K> block. The i64-specific
block is gone; record_winner/record_draw are now generic over T.
add_events_with_prior stays pub (not pub(crate)) because the ATP
example calls it directly with pre-built Index-based composition;
the new typed add_events is the primary public API going forward.
In-crate tests updated to call add_events_with_prior with an empty
HashMap. tests/api_shape.rs added with 3 integration tests covering
bulk ingest, draw, and mismatched-outcome error.
Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use plotters::prelude::*;
|
||||
use time::{Date, Month};
|
||||
use trueskill_tt::{History, KeyTable};
|
||||
@@ -40,7 +42,7 @@ fn main() {
|
||||
|
||||
let mut hist = History::builder().sigma(1.6).gamma(0.036).build();
|
||||
|
||||
hist.add_events(composition, results, times, vec![])
|
||||
hist.add_events_with_prior(composition, results, times, vec![], HashMap::new())
|
||||
.unwrap();
|
||||
hist.convergence(10, 0.01, true);
|
||||
|
||||
|
||||
157
src/history.rs
157
src/history.rs
@@ -332,24 +332,14 @@ impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K> {
|
||||
pub fn add_events(
|
||||
&mut self,
|
||||
composition: Vec<Vec<Vec<Index>>>,
|
||||
results: Vec<Vec<f64>>,
|
||||
times: Vec<i64>,
|
||||
weights: Vec<Vec<Vec<f64>>>,
|
||||
) -> Result<(), InferenceError> {
|
||||
self.add_events_with_prior(composition, results, times, weights, HashMap::new())
|
||||
}
|
||||
|
||||
impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O, K> {
|
||||
pub fn add_events_with_prior(
|
||||
&mut self,
|
||||
composition: Vec<Vec<Vec<Index>>>,
|
||||
results: Vec<Vec<f64>>,
|
||||
times: Vec<i64>,
|
||||
times: Vec<T>,
|
||||
weights: Vec<Vec<Vec<f64>>>,
|
||||
mut priors: HashMap<Index, Rating<i64, D>>,
|
||||
mut priors: HashMap<Index, Rating<T, D>>,
|
||||
) -> Result<(), InferenceError> {
|
||||
if !results.is_empty() && results.len() != composition.len() {
|
||||
return Err(InferenceError::MismatchedShape {
|
||||
@@ -513,12 +503,7 @@ impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn record_winner<Q>(
|
||||
&mut self,
|
||||
winner: &Q,
|
||||
loser: &Q,
|
||||
time: i64,
|
||||
) -> Result<(), InferenceError>
|
||||
pub fn record_winner<Q>(&mut self, winner: &Q, loser: &Q, time: T) -> Result<(), InferenceError>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
|
||||
@@ -534,7 +519,7 @@ impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K
|
||||
)
|
||||
}
|
||||
|
||||
pub fn record_draw<Q>(&mut self, a: &Q, b: &Q, time: i64) -> Result<(), InferenceError>
|
||||
pub fn record_draw<Q>(&mut self, a: &Q, b: &Q, time: T) -> Result<(), InferenceError>
|
||||
where
|
||||
K: Borrow<Q>,
|
||||
Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
|
||||
@@ -549,6 +534,62 @@ impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K
|
||||
HashMap::new(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Bulk-ingest typed events.
|
||||
pub fn add_events<I>(&mut self, events: I) -> Result<(), InferenceError>
|
||||
where
|
||||
I: IntoIterator<Item = crate::event::Event<T, K>>,
|
||||
{
|
||||
use crate::event::Event;
|
||||
let events: Vec<Event<T, K>> = events.into_iter().collect();
|
||||
if events.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut composition: Vec<Vec<Vec<Index>>> = Vec::with_capacity(events.len());
|
||||
let mut results: Vec<Vec<f64>> = Vec::with_capacity(events.len());
|
||||
let mut times: Vec<T> = Vec::with_capacity(events.len());
|
||||
let mut weights: Vec<Vec<Vec<f64>>> = Vec::with_capacity(events.len());
|
||||
let mut priors: HashMap<Index, Rating<T, D>> = HashMap::new();
|
||||
|
||||
for ev in events {
|
||||
let ranks = ev.outcome.as_ranks();
|
||||
if ranks.len() != ev.teams.len() {
|
||||
return Err(InferenceError::MismatchedShape {
|
||||
kind: "outcome ranks vs teams",
|
||||
expected: ev.teams.len(),
|
||||
got: ranks.len(),
|
||||
});
|
||||
}
|
||||
|
||||
let mut event_comp: Vec<Vec<Index>> = Vec::with_capacity(ev.teams.len());
|
||||
let mut event_weights: Vec<Vec<f64>> = Vec::with_capacity(ev.teams.len());
|
||||
|
||||
for team in ev.teams {
|
||||
let mut team_indices: Vec<Index> = Vec::with_capacity(team.members.len());
|
||||
let mut team_weights: Vec<f64> = Vec::with_capacity(team.members.len());
|
||||
for member in team.members {
|
||||
let idx = self.keys.get_or_create(&member.key);
|
||||
team_indices.push(idx);
|
||||
team_weights.push(member.weight);
|
||||
if let Some(prior) = member.prior {
|
||||
priors.insert(idx, Rating::new(prior, self.beta, self.drift));
|
||||
}
|
||||
}
|
||||
event_comp.push(team_indices);
|
||||
event_weights.push(team_weights);
|
||||
}
|
||||
composition.push(event_comp);
|
||||
weights.push(event_weights);
|
||||
|
||||
let max_rank = ranks.iter().copied().max().unwrap_or(0) as f64;
|
||||
let inverted: Vec<f64> = ranks.iter().map(|&r| max_rank - r as f64).collect();
|
||||
results.push(inverted);
|
||||
times.push(ev.time);
|
||||
}
|
||||
|
||||
self.add_events_with_prior(composition, results, times, weights, priors)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -825,7 +866,8 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition, results, times, vec![]).unwrap();
|
||||
h.add_events_with_prior(composition, results, times, vec![], HashMap::new())
|
||||
.unwrap();
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
@@ -876,7 +918,8 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition, results, times, vec![]).unwrap();
|
||||
h.add_events_with_prior(composition, results, times, vec![], HashMap::new())
|
||||
.unwrap();
|
||||
|
||||
let trueskill_log_evidence = h.log_evidence(false, &[]);
|
||||
let trueskill_log_evidence_online = h.log_evidence(true, &[]);
|
||||
@@ -964,8 +1007,14 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition.clone(), results.clone(), times, vec![])
|
||||
.unwrap();
|
||||
h.add_events_with_prior(
|
||||
composition.clone(),
|
||||
results.clone(),
|
||||
times,
|
||||
vec![],
|
||||
HashMap::new(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
@@ -989,7 +1038,8 @@ mod tests {
|
||||
);
|
||||
|
||||
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
||||
h.add_events(composition, results, times2, vec![]).unwrap();
|
||||
h.add_events_with_prior(composition, results, times2, vec![], HashMap::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(h.time_slices.len(), 6);
|
||||
|
||||
@@ -1056,8 +1106,14 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition.clone(), results.clone(), times, vec![])
|
||||
.unwrap();
|
||||
h.add_events_with_prior(
|
||||
composition.clone(),
|
||||
results.clone(),
|
||||
times,
|
||||
vec![],
|
||||
HashMap::new(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
@@ -1081,7 +1137,8 @@ mod tests {
|
||||
);
|
||||
|
||||
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
||||
h.add_events(composition, results, times2, vec![]).unwrap();
|
||||
h.add_events_with_prior(composition, results, times2, vec![], HashMap::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(h.time_slices.len(), 6);
|
||||
|
||||
@@ -1137,8 +1194,14 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition.clone(), vec![], times.clone(), vec![])
|
||||
.unwrap();
|
||||
h.add_events_with_prior(
|
||||
composition.clone(),
|
||||
vec![],
|
||||
times.clone(),
|
||||
vec![],
|
||||
HashMap::new(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0;
|
||||
|
||||
@@ -1177,7 +1240,8 @@ mod tests {
|
||||
|
||||
let mut h = History::builder().build();
|
||||
|
||||
h.add_events(composition, vec![], times, vec![]).unwrap();
|
||||
h.add_events_with_prior(composition, vec![], times, vec![], HashMap::new())
|
||||
.unwrap();
|
||||
|
||||
assert_ulps_eq!(
|
||||
((0.5f64 * 0.1765).ln() / 2.0).exp(),
|
||||
@@ -1208,18 +1272,25 @@ mod tests {
|
||||
.gamma(0.0)
|
||||
.build();
|
||||
|
||||
h.add_events(
|
||||
h.add_events_with_prior(
|
||||
composition.clone(),
|
||||
results.clone(),
|
||||
vec![0, 10, 20],
|
||||
vec![],
|
||||
HashMap::new(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
h.add_events(composition, results, vec![15, 10, 0], vec![])
|
||||
.unwrap();
|
||||
h.add_events_with_prior(
|
||||
composition,
|
||||
results,
|
||||
vec![15, 10, 0],
|
||||
vec![],
|
||||
HashMap::new(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(h.time_slices.len(), 4);
|
||||
|
||||
@@ -1303,12 +1374,18 @@ mod tests {
|
||||
.gamma(0.0)
|
||||
.build();
|
||||
|
||||
h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![])
|
||||
.unwrap();
|
||||
h.add_events_with_prior(
|
||||
composition.clone(),
|
||||
vec![],
|
||||
vec![0, 10, 20],
|
||||
vec![],
|
||||
HashMap::new(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
h.add_events(composition, vec![], vec![15, 10, 0], vec![])
|
||||
h.add_events_with_prior(composition, vec![], vec![15, 10, 0], vec![], HashMap::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(h.time_slices.len(), 4);
|
||||
@@ -1398,7 +1475,8 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition, vec![], times, weights).unwrap();
|
||||
h.add_events_with_prior(composition, vec![], times, weights, HashMap::new())
|
||||
.unwrap();
|
||||
|
||||
let lc = h.learning_curves();
|
||||
|
||||
@@ -1459,7 +1537,8 @@ mod tests {
|
||||
epsilon: 1e-6,
|
||||
})
|
||||
.build();
|
||||
h.add_events(composition, results, times, vec![]).unwrap();
|
||||
h.add_events_with_prior(composition, results, times, vec![], HashMap::new())
|
||||
.unwrap();
|
||||
|
||||
let report = h.converge().unwrap();
|
||||
assert!(report.converged);
|
||||
|
||||
84
tests/api_shape.rs
Normal file
84
tests/api_shape.rs
Normal file
@@ -0,0 +1,84 @@
|
||||
//! Tests for the new T2 public API surface: typed add_events(iter) and the
|
||||
//! fluent event builder (added in Task 16).
|
||||
|
||||
use smallvec::smallvec;
|
||||
use trueskill_tt::{ConstantDrift, ConvergenceOptions, Event, History, Member, Outcome, Team};
|
||||
|
||||
#[test]
|
||||
fn add_events_bulk_via_iter() {
|
||||
let mut h = History::builder()
|
||||
.mu(0.0)
|
||||
.sigma(2.0)
|
||||
.beta(1.0)
|
||||
.p_draw(0.0)
|
||||
.drift(ConstantDrift(0.0))
|
||||
.convergence(ConvergenceOptions {
|
||||
max_iter: 30,
|
||||
epsilon: 1e-6,
|
||||
})
|
||||
.build();
|
||||
|
||||
let events: Vec<Event<i64, &'static str>> = vec![
|
||||
Event {
|
||||
time: 1,
|
||||
teams: smallvec![
|
||||
Team::with_members([Member::new("a")]),
|
||||
Team::with_members([Member::new("b")]),
|
||||
],
|
||||
outcome: Outcome::winner(0, 2),
|
||||
},
|
||||
Event {
|
||||
time: 2,
|
||||
teams: smallvec![
|
||||
Team::with_members([Member::new("b")]),
|
||||
Team::with_members([Member::new("c")]),
|
||||
],
|
||||
outcome: Outcome::winner(0, 2),
|
||||
},
|
||||
];
|
||||
|
||||
h.add_events(events).unwrap();
|
||||
let report = h.converge().unwrap();
|
||||
assert!(report.converged);
|
||||
assert!(h.lookup(&"a").is_some());
|
||||
assert!(h.lookup(&"b").is_some());
|
||||
assert!(h.lookup(&"c").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_events_draw() {
|
||||
let mut h = History::builder()
|
||||
.mu(25.0)
|
||||
.sigma(25.0 / 3.0)
|
||||
.beta(25.0 / 6.0)
|
||||
.p_draw(0.25)
|
||||
.drift(ConstantDrift(25.0 / 300.0))
|
||||
.build();
|
||||
|
||||
let events: Vec<Event<i64, &'static str>> = vec![Event {
|
||||
time: 1,
|
||||
teams: smallvec![
|
||||
Team::with_members([Member::new("alice")]),
|
||||
Team::with_members([Member::new("bob")]),
|
||||
],
|
||||
outcome: Outcome::draw(2),
|
||||
}];
|
||||
h.add_events(events).unwrap();
|
||||
h.converge().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_events_rejects_mismatched_outcome_ranks() {
|
||||
use trueskill_tt::InferenceError;
|
||||
let mut h: History = History::builder().build();
|
||||
let events: Vec<Event<i64, &'static str>> = vec![Event {
|
||||
time: 1,
|
||||
teams: smallvec![
|
||||
Team::with_members([Member::new("a")]),
|
||||
Team::with_members([Member::new("b")]),
|
||||
],
|
||||
outcome: Outcome::ranking([0, 1, 2]), // 3 ranks but 2 teams
|
||||
}];
|
||||
let err = h.add_events(events).unwrap_err();
|
||||
assert!(matches!(err, InferenceError::MismatchedShape { .. }));
|
||||
}
|
||||
Reference in New Issue
Block a user