feat(api): typed add_events(iter); generify internal path over T
Public API gains:
History::add_events<I: IntoIterator<Item = Event<T, K>>>(events)
-> Result<(), InferenceError>
which accepts the typed Event<T, K> shape added in Task 10. Ranks
from Outcome::Ranked are mapped to the legacy "higher f64 = better"
results internally.
add_events_with_prior now takes Vec<T> for times (was Vec<i64>),
generifying the whole internal path over T in a single fully-generic
impl<T: Time, D: Drift<T>, O: Observer<T>, K> block. The i64-specific
block is gone; record_winner/record_draw are now generic over T.
add_events_with_prior stays pub (not pub(crate)) because the ATP
example calls it directly with pre-built Index-based composition;
the new typed add_events is the primary public API going forward.
In-crate tests updated to call add_events_with_prior with an empty
HashMap. tests/api_shape.rs added with 3 integration tests covering
bulk ingest, draw, and mismatched-outcome error.
Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use plotters::prelude::*;
|
use plotters::prelude::*;
|
||||||
use time::{Date, Month};
|
use time::{Date, Month};
|
||||||
use trueskill_tt::{History, KeyTable};
|
use trueskill_tt::{History, KeyTable};
|
||||||
@@ -40,7 +42,7 @@ fn main() {
|
|||||||
|
|
||||||
let mut hist = History::builder().sigma(1.6).gamma(0.036).build();
|
let mut hist = History::builder().sigma(1.6).gamma(0.036).build();
|
||||||
|
|
||||||
hist.add_events(composition, results, times, vec![])
|
hist.add_events_with_prior(composition, results, times, vec![], HashMap::new())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
hist.convergence(10, 0.01, true);
|
hist.convergence(10, 0.01, true);
|
||||||
|
|
||||||
|
|||||||
147
src/history.rs
147
src/history.rs
@@ -332,24 +332,14 @@ impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K> {
|
impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O, K> {
|
||||||
pub fn add_events(
|
|
||||||
&mut self,
|
|
||||||
composition: Vec<Vec<Vec<Index>>>,
|
|
||||||
results: Vec<Vec<f64>>,
|
|
||||||
times: Vec<i64>,
|
|
||||||
weights: Vec<Vec<Vec<f64>>>,
|
|
||||||
) -> Result<(), InferenceError> {
|
|
||||||
self.add_events_with_prior(composition, results, times, weights, HashMap::new())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn add_events_with_prior(
|
pub fn add_events_with_prior(
|
||||||
&mut self,
|
&mut self,
|
||||||
composition: Vec<Vec<Vec<Index>>>,
|
composition: Vec<Vec<Vec<Index>>>,
|
||||||
results: Vec<Vec<f64>>,
|
results: Vec<Vec<f64>>,
|
||||||
times: Vec<i64>,
|
times: Vec<T>,
|
||||||
weights: Vec<Vec<Vec<f64>>>,
|
weights: Vec<Vec<Vec<f64>>>,
|
||||||
mut priors: HashMap<Index, Rating<i64, D>>,
|
mut priors: HashMap<Index, Rating<T, D>>,
|
||||||
) -> Result<(), InferenceError> {
|
) -> Result<(), InferenceError> {
|
||||||
if !results.is_empty() && results.len() != composition.len() {
|
if !results.is_empty() && results.len() != composition.len() {
|
||||||
return Err(InferenceError::MismatchedShape {
|
return Err(InferenceError::MismatchedShape {
|
||||||
@@ -513,12 +503,7 @@ impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn record_winner<Q>(
|
pub fn record_winner<Q>(&mut self, winner: &Q, loser: &Q, time: T) -> Result<(), InferenceError>
|
||||||
&mut self,
|
|
||||||
winner: &Q,
|
|
||||||
loser: &Q,
|
|
||||||
time: i64,
|
|
||||||
) -> Result<(), InferenceError>
|
|
||||||
where
|
where
|
||||||
K: Borrow<Q>,
|
K: Borrow<Q>,
|
||||||
Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
|
Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
|
||||||
@@ -534,7 +519,7 @@ impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn record_draw<Q>(&mut self, a: &Q, b: &Q, time: i64) -> Result<(), InferenceError>
|
pub fn record_draw<Q>(&mut self, a: &Q, b: &Q, time: T) -> Result<(), InferenceError>
|
||||||
where
|
where
|
||||||
K: Borrow<Q>,
|
K: Borrow<Q>,
|
||||||
Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
|
Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
|
||||||
@@ -549,6 +534,62 @@ impl<D: Drift<i64>, O: Observer<i64>, K: Eq + Hash + Clone> History<i64, D, O, K
|
|||||||
HashMap::new(),
|
HashMap::new(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Bulk-ingest typed events.
|
||||||
|
pub fn add_events<I>(&mut self, events: I) -> Result<(), InferenceError>
|
||||||
|
where
|
||||||
|
I: IntoIterator<Item = crate::event::Event<T, K>>,
|
||||||
|
{
|
||||||
|
use crate::event::Event;
|
||||||
|
let events: Vec<Event<T, K>> = events.into_iter().collect();
|
||||||
|
if events.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut composition: Vec<Vec<Vec<Index>>> = Vec::with_capacity(events.len());
|
||||||
|
let mut results: Vec<Vec<f64>> = Vec::with_capacity(events.len());
|
||||||
|
let mut times: Vec<T> = Vec::with_capacity(events.len());
|
||||||
|
let mut weights: Vec<Vec<Vec<f64>>> = Vec::with_capacity(events.len());
|
||||||
|
let mut priors: HashMap<Index, Rating<T, D>> = HashMap::new();
|
||||||
|
|
||||||
|
for ev in events {
|
||||||
|
let ranks = ev.outcome.as_ranks();
|
||||||
|
if ranks.len() != ev.teams.len() {
|
||||||
|
return Err(InferenceError::MismatchedShape {
|
||||||
|
kind: "outcome ranks vs teams",
|
||||||
|
expected: ev.teams.len(),
|
||||||
|
got: ranks.len(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut event_comp: Vec<Vec<Index>> = Vec::with_capacity(ev.teams.len());
|
||||||
|
let mut event_weights: Vec<Vec<f64>> = Vec::with_capacity(ev.teams.len());
|
||||||
|
|
||||||
|
for team in ev.teams {
|
||||||
|
let mut team_indices: Vec<Index> = Vec::with_capacity(team.members.len());
|
||||||
|
let mut team_weights: Vec<f64> = Vec::with_capacity(team.members.len());
|
||||||
|
for member in team.members {
|
||||||
|
let idx = self.keys.get_or_create(&member.key);
|
||||||
|
team_indices.push(idx);
|
||||||
|
team_weights.push(member.weight);
|
||||||
|
if let Some(prior) = member.prior {
|
||||||
|
priors.insert(idx, Rating::new(prior, self.beta, self.drift));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
event_comp.push(team_indices);
|
||||||
|
event_weights.push(team_weights);
|
||||||
|
}
|
||||||
|
composition.push(event_comp);
|
||||||
|
weights.push(event_weights);
|
||||||
|
|
||||||
|
let max_rank = ranks.iter().copied().max().unwrap_or(0) as f64;
|
||||||
|
let inverted: Vec<f64> = ranks.iter().map(|&r| max_rank - r as f64).collect();
|
||||||
|
results.push(inverted);
|
||||||
|
times.push(ev.time);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.add_events_with_prior(composition, results, times, weights, priors)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -825,7 +866,8 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition, results, times, vec![]).unwrap();
|
h.add_events_with_prior(composition, results, times, vec![], HashMap::new())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
@@ -876,7 +918,8 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition, results, times, vec![]).unwrap();
|
h.add_events_with_prior(composition, results, times, vec![], HashMap::new())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let trueskill_log_evidence = h.log_evidence(false, &[]);
|
let trueskill_log_evidence = h.log_evidence(false, &[]);
|
||||||
let trueskill_log_evidence_online = h.log_evidence(true, &[]);
|
let trueskill_log_evidence_online = h.log_evidence(true, &[]);
|
||||||
@@ -964,7 +1007,13 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition.clone(), results.clone(), times, vec![])
|
h.add_events_with_prior(
|
||||||
|
composition.clone(),
|
||||||
|
results.clone(),
|
||||||
|
times,
|
||||||
|
vec![],
|
||||||
|
HashMap::new(),
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
@@ -989,7 +1038,8 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
||||||
h.add_events(composition, results, times2, vec![]).unwrap();
|
h.add_events_with_prior(composition, results, times2, vec![], HashMap::new())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(h.time_slices.len(), 6);
|
assert_eq!(h.time_slices.len(), 6);
|
||||||
|
|
||||||
@@ -1056,7 +1106,13 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition.clone(), results.clone(), times, vec![])
|
h.add_events_with_prior(
|
||||||
|
composition.clone(),
|
||||||
|
results.clone(),
|
||||||
|
times,
|
||||||
|
vec![],
|
||||||
|
HashMap::new(),
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
@@ -1081,7 +1137,8 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
||||||
h.add_events(composition, results, times2, vec![]).unwrap();
|
h.add_events_with_prior(composition, results, times2, vec![], HashMap::new())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(h.time_slices.len(), 6);
|
assert_eq!(h.time_slices.len(), 6);
|
||||||
|
|
||||||
@@ -1137,7 +1194,13 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition.clone(), vec![], times.clone(), vec![])
|
h.add_events_with_prior(
|
||||||
|
composition.clone(),
|
||||||
|
vec![],
|
||||||
|
times.clone(),
|
||||||
|
vec![],
|
||||||
|
HashMap::new(),
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0;
|
let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0;
|
||||||
@@ -1177,7 +1240,8 @@ mod tests {
|
|||||||
|
|
||||||
let mut h = History::builder().build();
|
let mut h = History::builder().build();
|
||||||
|
|
||||||
h.add_events(composition, vec![], times, vec![]).unwrap();
|
h.add_events_with_prior(composition, vec![], times, vec![], HashMap::new())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
((0.5f64 * 0.1765).ln() / 2.0).exp(),
|
((0.5f64 * 0.1765).ln() / 2.0).exp(),
|
||||||
@@ -1208,17 +1272,24 @@ mod tests {
|
|||||||
.gamma(0.0)
|
.gamma(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
h.add_events(
|
h.add_events_with_prior(
|
||||||
composition.clone(),
|
composition.clone(),
|
||||||
results.clone(),
|
results.clone(),
|
||||||
vec![0, 10, 20],
|
vec![0, 10, 20],
|
||||||
vec![],
|
vec![],
|
||||||
|
HashMap::new(),
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
h.add_events(composition, results, vec![15, 10, 0], vec![])
|
h.add_events_with_prior(
|
||||||
|
composition,
|
||||||
|
results,
|
||||||
|
vec![15, 10, 0],
|
||||||
|
vec![],
|
||||||
|
HashMap::new(),
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(h.time_slices.len(), 4);
|
assert_eq!(h.time_slices.len(), 4);
|
||||||
@@ -1303,12 +1374,18 @@ mod tests {
|
|||||||
.gamma(0.0)
|
.gamma(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![])
|
h.add_events_with_prior(
|
||||||
|
composition.clone(),
|
||||||
|
vec![],
|
||||||
|
vec![0, 10, 20],
|
||||||
|
vec![],
|
||||||
|
HashMap::new(),
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
h.add_events(composition, vec![], vec![15, 10, 0], vec![])
|
h.add_events_with_prior(composition, vec![], vec![15, 10, 0], vec![], HashMap::new())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(h.time_slices.len(), 4);
|
assert_eq!(h.time_slices.len(), 4);
|
||||||
@@ -1398,7 +1475,8 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition, vec![], times, weights).unwrap();
|
h.add_events_with_prior(composition, vec![], times, weights, HashMap::new())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let lc = h.learning_curves();
|
let lc = h.learning_curves();
|
||||||
|
|
||||||
@@ -1459,7 +1537,8 @@ mod tests {
|
|||||||
epsilon: 1e-6,
|
epsilon: 1e-6,
|
||||||
})
|
})
|
||||||
.build();
|
.build();
|
||||||
h.add_events(composition, results, times, vec![]).unwrap();
|
h.add_events_with_prior(composition, results, times, vec![], HashMap::new())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let report = h.converge().unwrap();
|
let report = h.converge().unwrap();
|
||||||
assert!(report.converged);
|
assert!(report.converged);
|
||||||
|
|||||||
84
tests/api_shape.rs
Normal file
84
tests/api_shape.rs
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
//! Tests for the new T2 public API surface: typed add_events(iter) and the
|
||||||
|
//! fluent event builder (added in Task 16).
|
||||||
|
|
||||||
|
use smallvec::smallvec;
|
||||||
|
use trueskill_tt::{ConstantDrift, ConvergenceOptions, Event, History, Member, Outcome, Team};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn add_events_bulk_via_iter() {
|
||||||
|
let mut h = History::builder()
|
||||||
|
.mu(0.0)
|
||||||
|
.sigma(2.0)
|
||||||
|
.beta(1.0)
|
||||||
|
.p_draw(0.0)
|
||||||
|
.drift(ConstantDrift(0.0))
|
||||||
|
.convergence(ConvergenceOptions {
|
||||||
|
max_iter: 30,
|
||||||
|
epsilon: 1e-6,
|
||||||
|
})
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let events: Vec<Event<i64, &'static str>> = vec![
|
||||||
|
Event {
|
||||||
|
time: 1,
|
||||||
|
teams: smallvec![
|
||||||
|
Team::with_members([Member::new("a")]),
|
||||||
|
Team::with_members([Member::new("b")]),
|
||||||
|
],
|
||||||
|
outcome: Outcome::winner(0, 2),
|
||||||
|
},
|
||||||
|
Event {
|
||||||
|
time: 2,
|
||||||
|
teams: smallvec![
|
||||||
|
Team::with_members([Member::new("b")]),
|
||||||
|
Team::with_members([Member::new("c")]),
|
||||||
|
],
|
||||||
|
outcome: Outcome::winner(0, 2),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
h.add_events(events).unwrap();
|
||||||
|
let report = h.converge().unwrap();
|
||||||
|
assert!(report.converged);
|
||||||
|
assert!(h.lookup(&"a").is_some());
|
||||||
|
assert!(h.lookup(&"b").is_some());
|
||||||
|
assert!(h.lookup(&"c").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn add_events_draw() {
|
||||||
|
let mut h = History::builder()
|
||||||
|
.mu(25.0)
|
||||||
|
.sigma(25.0 / 3.0)
|
||||||
|
.beta(25.0 / 6.0)
|
||||||
|
.p_draw(0.25)
|
||||||
|
.drift(ConstantDrift(25.0 / 300.0))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let events: Vec<Event<i64, &'static str>> = vec![Event {
|
||||||
|
time: 1,
|
||||||
|
teams: smallvec![
|
||||||
|
Team::with_members([Member::new("alice")]),
|
||||||
|
Team::with_members([Member::new("bob")]),
|
||||||
|
],
|
||||||
|
outcome: Outcome::draw(2),
|
||||||
|
}];
|
||||||
|
h.add_events(events).unwrap();
|
||||||
|
h.converge().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn add_events_rejects_mismatched_outcome_ranks() {
|
||||||
|
use trueskill_tt::InferenceError;
|
||||||
|
let mut h: History = History::builder().build();
|
||||||
|
let events: Vec<Event<i64, &'static str>> = vec![Event {
|
||||||
|
time: 1,
|
||||||
|
teams: smallvec![
|
||||||
|
Team::with_members([Member::new("a")]),
|
||||||
|
Team::with_members([Member::new("b")]),
|
||||||
|
],
|
||||||
|
outcome: Outcome::ranking([0, 1, 2]), // 3 ranks but 2 teams
|
||||||
|
}];
|
||||||
|
let err = h.add_events(events).unwrap_err();
|
||||||
|
assert!(matches!(err, InferenceError::MismatchedShape { .. }));
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user