Every #[cfg(test)] mod tests in src/history.rs now uses the new public API: add_events(iter), converge(), learning_curve(), current_skill() and log_evidence(). No golden values changed. Legacy methods removed: History::convergence(iters, eps, verbose) → use converge(); History::learning_curves_by_index() → use learning_curve() / learning_curves(); HistoryBuilder::gamma(f64) → use .drift(ConstantDrift(g)); add_events_with_prior downgraded from pub to pub(crate). Added: History::builder_with_key() for custom key types (used by the atp example); tests/equivalence.rs with Game-level golden integration tests. examples/atp.rs rewritten in the new API (Event<i64, String>, converge(), learning_curve(), drift(ConstantDrift(...))). Bench Batch::iteration: 21.4 µs (T1 reference: 22.88 µs). Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
233 lines
6.3 KiB
Rust
233 lines
6.3 KiB
Rust
use plotters::prelude::*;
|
|
use smallvec::smallvec;
|
|
use time::{Date, Month};
|
|
use trueskill_tt::{Event, History, Member, Outcome, Team, drift::ConstantDrift};
|
|
|
|
fn main() {
|
|
let mut csv = csv::Reader::open("examples/atp.csv").unwrap();
|
|
|
|
let from = Date::from_calendar_date(1900, Month::January, 1).unwrap();
|
|
let time_format = time::format_description::parse("[year]-[month]-[day]").unwrap();
|
|
|
|
let mut events: Vec<Event<i64, String>> = Vec::new();
|
|
|
|
for row in csv.records() {
|
|
let date = Date::parse(&row["time_start"], &time_format).unwrap();
|
|
let time = (date - from).whole_days();
|
|
|
|
if &row["double"] == "t" {
|
|
events.push(Event {
|
|
time,
|
|
teams: smallvec![
|
|
Team::with_members([
|
|
Member::new(row["w1_id"].to_owned()),
|
|
Member::new(row["w2_id"].to_owned()),
|
|
]),
|
|
Team::with_members([
|
|
Member::new(row["l1_id"].to_owned()),
|
|
Member::new(row["l2_id"].to_owned()),
|
|
]),
|
|
],
|
|
outcome: Outcome::winner(0, 2),
|
|
});
|
|
} else {
|
|
events.push(Event {
|
|
time,
|
|
teams: smallvec![
|
|
Team::with_members([Member::new(row["w1_id"].to_owned())]),
|
|
Team::with_members([Member::new(row["l1_id"].to_owned())]),
|
|
],
|
|
outcome: Outcome::winner(0, 2),
|
|
});
|
|
}
|
|
}
|
|
|
|
let mut hist: History<i64, _, _, String> = History::builder_with_key()
|
|
.sigma(1.6)
|
|
.drift(ConstantDrift(0.036))
|
|
.convergence(trueskill_tt::ConvergenceOptions {
|
|
max_iter: 10,
|
|
epsilon: 0.01,
|
|
})
|
|
.build();
|
|
|
|
hist.add_events(events).unwrap();
|
|
hist.converge().unwrap();
|
|
|
|
let players = [
|
|
("aggasi", "a092", 38800i64),
|
|
("borg", "b058", 30300),
|
|
("connors", "c044", 31250),
|
|
("courier", "c243", 35750),
|
|
("djokovic", "d643", i64::MAX),
|
|
("edberg", "e004", 34750),
|
|
("federer", "f324", i64::MAX),
|
|
("hewitt", "h432", 40750),
|
|
("mcenroe", "m047", 33000),
|
|
("lendl", "l018", 33750),
|
|
("murray", "mc10", 60750),
|
|
("nadal", "n409", i64::MAX),
|
|
("nastase", "n008", 28750),
|
|
("sampras", "s402", i64::MAX),
|
|
("wilander", "w023", 32600),
|
|
];
|
|
|
|
let mut x_spec = (f64::MAX, f64::MIN);
|
|
let mut y_spec = (f64::MAX, f64::MIN);
|
|
|
|
for &(_, id, cutoff) in &players {
|
|
for (ts, gs) in hist.learning_curve(id) {
|
|
if ts >= cutoff {
|
|
continue;
|
|
}
|
|
|
|
let ts = ts as f64;
|
|
|
|
if ts < x_spec.0 {
|
|
x_spec.0 = ts;
|
|
}
|
|
|
|
if ts > x_spec.1 {
|
|
x_spec.1 = ts;
|
|
}
|
|
|
|
let upper = gs.mu() + gs.sigma();
|
|
let lower = gs.mu() - gs.sigma();
|
|
|
|
if lower < y_spec.0 {
|
|
y_spec.0 = lower;
|
|
}
|
|
|
|
if upper > y_spec.1 {
|
|
y_spec.1 = upper;
|
|
}
|
|
}
|
|
}
|
|
|
|
let root = SVGBackend::new("plot.svg", (1280, 640)).into_drawing_area();
|
|
|
|
root.fill(&WHITE).unwrap();
|
|
|
|
let mut chart = ChartBuilder::on(&root)
|
|
.margin(5)
|
|
.x_label_area_size(30)
|
|
.y_label_area_size(30)
|
|
.build_cartesian_2d(x_spec.0..x_spec.1, y_spec.0..y_spec.1)
|
|
.unwrap();
|
|
|
|
chart.configure_mesh().draw().unwrap();
|
|
|
|
for (idx, &(player, id, cutoff)) in players.iter().enumerate() {
|
|
let mut data = Vec::new();
|
|
let mut upper = Vec::new();
|
|
let mut lower = Vec::new();
|
|
|
|
for (ts, gs) in hist.learning_curve(id) {
|
|
if ts >= cutoff {
|
|
continue;
|
|
}
|
|
|
|
data.push((ts as f64, gs.mu()));
|
|
upper.push((ts as f64, gs.mu() + gs.sigma()));
|
|
lower.push((ts as f64, gs.mu() - gs.sigma()));
|
|
}
|
|
|
|
let color = Palette99::pick(idx);
|
|
|
|
let band = upper
|
|
.into_iter()
|
|
.chain(lower.into_iter().rev())
|
|
.collect::<Vec<_>>();
|
|
|
|
chart
|
|
.plotting_area()
|
|
.draw(&Polygon::new(band, color.mix(0.15)))
|
|
.unwrap();
|
|
|
|
chart
|
|
.draw_series(LineSeries::new(data, &color))
|
|
.unwrap()
|
|
.label(player)
|
|
.legend(move |(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], &color));
|
|
}
|
|
|
|
chart
|
|
.configure_series_labels()
|
|
.background_style(WHITE.mix(0.8))
|
|
.border_style(BLACK)
|
|
.draw()
|
|
.unwrap();
|
|
}
|
|
|
|
/// Minimal, dependency-free reader for the comma-separated files used by this example.
///
/// The first line is the header row, and values are looked up by header name.
/// Lines are split on every `,`: quoting and escaping are not supported, so
/// fields must not themselves contain commas.
mod csv {
    use std::{
        fs::File,
        io::{self, BufRead, BufReader, Lines},
        ops,
        path::Path,
    };

    /// A CSV file whose header row has already been consumed.
    pub struct Reader {
        /// Column names, in file order.
        headers: Vec<String>,
        /// The remaining (data) lines of the file.
        lines: Lines<BufReader<File>>,
    }

    impl Reader {
        /// Opens `path` and reads its first line as the header row.
        ///
        /// An empty file yields a reader with no columns and no records.
        ///
        /// # Errors
        ///
        /// Returns any I/O error from opening the file or reading the header line.
        pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
            let mut lines = BufReader::new(File::open(path)?).lines();

            let headers = match lines.next() {
                Some(header) => split_fields(&header?),
                None => Vec::new(),
            };

            Ok(Self { headers, lines })
        }

        /// Returns an iterator over the data rows that have not been read yet.
        pub fn records(&mut self) -> Records<'_> {
            Records {
                headers: &self.headers,
                lines: &mut self.lines,
            }
        }
    }

    /// Splits one CSV line into owned fields.
    fn split_fields(line: &str) -> Vec<String> {
        line.split(',').map(str::to_owned).collect()
    }

    /// Iterator over the data rows of a [`Reader`]. Blank lines are skipped.
    pub struct Records<'a> {
        headers: &'a [String],
        lines: &'a mut Lines<BufReader<File>>,
    }

    impl<'a> Iterator for Records<'a> {
        type Item = Record<'a>;

        /// # Panics
        ///
        /// Panics if reading a line from the underlying file fails.
        fn next(&mut self) -> Option<Self::Item> {
            loop {
                let line = self
                    .lines
                    .next()?
                    .expect("failed to read a line from the CSV file");

                // A blank line (e.g. extra trailing newlines) holds no record;
                // turning it into one would only make every lookup fail.
                if line.trim().is_empty() {
                    continue;
                }

                return Some(Record {
                    headers: self.headers,
                    columns: split_fields(&line),
                });
            }
        }
    }

    /// One data row, indexable by column name: `record["column"]`.
    pub struct Record<'a> {
        headers: &'a [String],
        columns: Vec<String>,
    }

    impl Record<'_> {
        /// Returns the value in `column`. Returns `None` if the header row has
        /// no such column or this row is too short to contain it.
        pub fn get(&self, column: &str) -> Option<&str> {
            let position = self.headers.iter().position(|header| header == column)?;
            self.columns.get(position).map(String::as_str)
        }
    }

    impl ops::Index<&str> for Record<'_> {
        type Output = str;

        /// # Panics
        ///
        /// Panics with the column name if [`Record::get`] would return `None`.
        fn index(&self, column: &str) -> &Self::Output {
            self.get(column)
                .unwrap_or_else(|| panic!("CSV record has no value for column {column:?}"))
        }
    }
}
|