test: translate in-crate tests to new T2 API; delete legacy methods
Every #[cfg(test)] mod tests in src/history.rs now uses the new public API: add_events(iter) / converge() / learning_curve() / current_skill() / log_evidence(). No golden value changed. Legacy methods removed: - History::convergence(iters, eps, verbose) → use converge() - History::learning_curves_by_index() → use learning_curve() / learning_curves() - HistoryBuilder::gamma(f64) → use .drift(ConstantDrift(g)) - add_events_with_prior downgraded from pub to pub(crate) Added: - History::builder_with_key() for custom key types (used by atp example) - tests/equivalence.rs: Game-level golden integration tests examples/atp.rs rewritten in new API (Event<i64, String>, converge(), learning_curve(), drift(ConstantDrift(...))). Bench Batch::iteration: 21.4 µs (T1 reference: 22.88 µs). Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
102
examples/atp.rs
102
examples/atp.rs
@@ -1,53 +1,61 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use plotters::prelude::*;
|
||||
use smallvec::smallvec;
|
||||
use time::{Date, Month};
|
||||
use trueskill_tt::{History, KeyTable};
|
||||
use trueskill_tt::{Event, History, Member, Outcome, Team, drift::ConstantDrift};
|
||||
|
||||
fn main() {
|
||||
let mut csv = csv::Reader::open("examples/atp.csv").unwrap();
|
||||
|
||||
let mut composition = Vec::new();
|
||||
let mut results = Vec::new();
|
||||
let mut times = Vec::new();
|
||||
|
||||
let from = Date::from_calendar_date(1900, Month::January, 1).unwrap();
|
||||
let time_format = time::format_description::parse("[year]-[month]-[day]").unwrap();
|
||||
|
||||
let mut index_map = KeyTable::new();
|
||||
let mut events: Vec<Event<i64, String>> = Vec::new();
|
||||
|
||||
for row in csv.records() {
|
||||
if &row["double"] == "t" {
|
||||
let w1_id = index_map.get_or_create(&row["w1_id"]);
|
||||
let w2_id = index_map.get_or_create(&row["w2_id"]);
|
||||
|
||||
let l1_id = index_map.get_or_create(&row["l1_id"]);
|
||||
let l2_id = index_map.get_or_create(&row["l2_id"]);
|
||||
|
||||
composition.push(vec![vec![w1_id, w2_id], vec![l1_id, l2_id]]);
|
||||
} else {
|
||||
let w1_id = index_map.get_or_create(&row["w1_id"]);
|
||||
|
||||
let l1_id = index_map.get_or_create(&row["l1_id"]);
|
||||
|
||||
composition.push(vec![vec![w1_id], vec![l1_id]]);
|
||||
}
|
||||
|
||||
results.push(vec![1.0, 0.0]);
|
||||
|
||||
let date = Date::parse(&row["time_start"], &time_format).unwrap();
|
||||
let time = (date - from).whole_days();
|
||||
|
||||
times.push((date - from).whole_days());
|
||||
if &row["double"] == "t" {
|
||||
events.push(Event {
|
||||
time,
|
||||
teams: smallvec![
|
||||
Team::with_members([
|
||||
Member::new(row["w1_id"].to_owned()),
|
||||
Member::new(row["w2_id"].to_owned()),
|
||||
]),
|
||||
Team::with_members([
|
||||
Member::new(row["l1_id"].to_owned()),
|
||||
Member::new(row["l2_id"].to_owned()),
|
||||
]),
|
||||
],
|
||||
outcome: Outcome::winner(0, 2),
|
||||
});
|
||||
} else {
|
||||
events.push(Event {
|
||||
time,
|
||||
teams: smallvec![
|
||||
Team::with_members([Member::new(row["w1_id"].to_owned())]),
|
||||
Team::with_members([Member::new(row["l1_id"].to_owned())]),
|
||||
],
|
||||
outcome: Outcome::winner(0, 2),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let mut hist = History::builder().sigma(1.6).gamma(0.036).build();
|
||||
let mut hist: History<i64, _, _, String> = History::builder_with_key()
|
||||
.sigma(1.6)
|
||||
.drift(ConstantDrift(0.036))
|
||||
.convergence(trueskill_tt::ConvergenceOptions {
|
||||
max_iter: 10,
|
||||
epsilon: 0.01,
|
||||
})
|
||||
.build();
|
||||
|
||||
hist.add_events_with_prior(composition, results, times, vec![], HashMap::new())
|
||||
.unwrap();
|
||||
hist.convergence(10, 0.01, true);
|
||||
hist.add_events(events).unwrap();
|
||||
hist.converge().unwrap();
|
||||
|
||||
let players = [
|
||||
("aggasi", "a092", 38800),
|
||||
("aggasi", "a092", 38800i64),
|
||||
("borg", "b058", 30300),
|
||||
("connors", "c044", 31250),
|
||||
("courier", "c243", 35750),
|
||||
@@ -64,21 +72,16 @@ fn main() {
|
||||
("wilander", "w023", 32600),
|
||||
];
|
||||
|
||||
let curves = hist.learning_curves_by_index();
|
||||
|
||||
let mut x_spec = (f64::MAX, f64::MIN);
|
||||
let mut y_spec = (f64::MAX, f64::MIN);
|
||||
|
||||
for (id, cutoff) in players
|
||||
.iter()
|
||||
.map(|&(_, id, cutoff)| (index_map.get_or_create(id), cutoff))
|
||||
{
|
||||
for (ts, gs) in &curves[&id] {
|
||||
if *ts >= cutoff {
|
||||
for &(_, id, cutoff) in &players {
|
||||
for (ts, gs) in hist.learning_curve(id) {
|
||||
if ts >= cutoff {
|
||||
continue;
|
||||
}
|
||||
|
||||
let ts = *ts as f64;
|
||||
let ts = ts as f64;
|
||||
|
||||
if ts < x_spec.0 {
|
||||
x_spec.0 = ts;
|
||||
@@ -114,24 +117,19 @@ fn main() {
|
||||
|
||||
chart.configure_mesh().draw().unwrap();
|
||||
|
||||
for (idx, (player, id, cutoff)) in players
|
||||
.iter()
|
||||
.map(|&(player, id, cutoff)| (player, index_map.get_or_create(id), cutoff))
|
||||
.enumerate()
|
||||
{
|
||||
for (idx, &(player, id, cutoff)) in players.iter().enumerate() {
|
||||
let mut data = Vec::new();
|
||||
let mut upper = Vec::new();
|
||||
let mut lower = Vec::new();
|
||||
|
||||
for (ts, gs) in curves[&id].iter() {
|
||||
if *ts >= cutoff {
|
||||
for (ts, gs) in hist.learning_curve(id) {
|
||||
if ts >= cutoff {
|
||||
continue;
|
||||
}
|
||||
|
||||
data.push((*ts as f64, gs.mu()));
|
||||
|
||||
upper.push((*ts as f64, gs.mu() + gs.sigma()));
|
||||
lower.push((*ts as f64, gs.mu() - gs.sigma()));
|
||||
data.push((ts as f64, gs.mu()));
|
||||
upper.push((ts as f64, gs.mu() + gs.sigma()));
|
||||
lower.push((ts as f64, gs.mu() - gs.sigma()));
|
||||
}
|
||||
|
||||
let color = Palette99::pick(idx);
|
||||
|
||||
Reference in New Issue
Block a user