Files
trueskill-tt/examples/atp.rs
T
logaritmisk 0705986929 feat(game): plumb ConvergenceOptions through to run_chain
Game and OwnedGame gain a convergence: ConvergenceOptions field set at
construction. Game::{ranked,scored} forward options.convergence into
OwnedGame::{new,new_scored} (previously it was silently discarded).
{ranked,scored}_with_arena take it as a parameter. run_chain reads
self.convergence.{epsilon, max_iter, alpha} instead of hardcoded
1e-6 / 10 / undamped. DiffFactor::propagate gains an alpha parameter
and dispatches into Trunc/MarginFactor::propagate_with_alpha.

In-tree callsites in src/time_slice.rs and src/history.rs pass
ConvergenceOptions::default(). Pre-existing T2 fallout in tests,
benches, and the atp example (struct literals missing the new `alpha`
field) is fixed by adding `alpha: 1.0`, so the workspace builds cleanly.
Default alpha is 1.0, so all 96 lib + 27 integration test goldens
remain bit-equal.
2026-05-08 15:10:35 +02:00

234 lines
6.3 KiB
Rust

use plotters::prelude::*;
use smallvec::smallvec;
use time::{Date, Month};
use trueskill_tt::{Event, History, Member, Outcome, Team, drift::ConstantDrift};
/// Fits a rating history to the ATP match log in `examples/atp.csv` and plots
/// the learning curves of a handful of well-known players to `plot.svg`.
///
/// Each player's curve is drawn as a `mu` line plus a shaded `mu ± sigma` band.
///
/// # Panics
///
/// Panics if the CSV file cannot be read or parsed, if the history fails to
/// build or converge, if no learning-curve points survive the cutoffs, or if
/// drawing the SVG fails.
fn main() {
    let mut reader =
        csv::Reader::open("examples/atp.csv").expect("failed to open examples/atp.csv");

    // Event times are whole days elapsed since this date.
    let from = Date::from_calendar_date(1900, Month::January, 1).unwrap();
    let time_format = time::format_description::parse("[year]-[month]-[day]").unwrap();

    let mut events: Vec<Event<i64, String>> = Vec::new();
    for row in reader.records() {
        let date = Date::parse(&row["time_start"], &time_format).unwrap();
        let time = (date - from).whole_days();
        events.push(Event {
            time,
            // Team 0 is built from the `w*_id` columns and team 1 from the `l*_id`
            // columns; rows with `double == "t"` have two members per team.
            teams: if &row["double"] == "t" {
                smallvec![
                    Team::with_members([
                        Member::new(row["w1_id"].to_owned()),
                        Member::new(row["w2_id"].to_owned()),
                    ]),
                    Team::with_members([
                        Member::new(row["l1_id"].to_owned()),
                        Member::new(row["l2_id"].to_owned()),
                    ]),
                ]
            } else {
                smallvec![
                    Team::with_members([Member::new(row["w1_id"].to_owned())]),
                    Team::with_members([Member::new(row["l1_id"].to_owned())]),
                ]
            },
            // NOTE(review): presumably "team 0 of 2 won" — confirm against `Outcome::winner`.
            outcome: Outcome::winner(0, 2),
        });
    }

    let mut hist: History<i64, _, _, String> = History::builder_with_key()
        .sigma(1.6)
        .drift(ConstantDrift(0.036))
        .convergence(trueskill_tt::ConvergenceOptions {
            max_iter: 10,
            epsilon: 0.01,
            alpha: 1.0,
        })
        .build();
    hist.add_events(events).unwrap();
    hist.converge().unwrap();

    // (legend label, player id, cutoff). Learning-curve points whose time (days
    // since 1900-01-01) is >= cutoff are dropped; i64::MAX keeps every point.
    let players = [
        ("agassi", "a092", 38800i64),
        ("borg", "b058", 30300),
        ("connors", "c044", 31250),
        ("courier", "c243", 35750),
        ("djokovic", "d643", i64::MAX),
        ("edberg", "e004", 34750),
        ("federer", "f324", i64::MAX),
        ("hewitt", "h432", 40750),
        ("mcenroe", "m047", 33000),
        ("lendl", "l018", 33750),
        ("murray", "mc10", 60750),
        ("nadal", "n409", i64::MAX),
        ("nastase", "n008", 28750),
        ("sampras", "s402", i64::MAX),
        ("wilander", "w023", 32600),
    ];

    // Each player's cut-off learning curve as (time, mu, sigma). It is computed
    // once and reused for both the axis bounds and the plotted series, so the two
    // can never disagree about which points are included.
    let curves: Vec<Vec<(f64, f64, f64)>> = players
        .iter()
        .map(|&(_, id, cutoff)| {
            let mut points = Vec::new();
            for (ts, gs) in hist.learning_curve(id) {
                if ts >= cutoff {
                    continue;
                }
                points.push((ts as f64, gs.mu(), gs.sigma()));
            }
            points
        })
        .collect();

    // Axis bounds: the time span, and the full extent of every mu ± sigma band.
    let mut x_spec = (f64::MAX, f64::MIN);
    let mut y_spec = (f64::MAX, f64::MIN);
    for &(ts, mu, sigma) in curves.iter().flatten() {
        x_spec.0 = x_spec.0.min(ts);
        x_spec.1 = x_spec.1.max(ts);
        y_spec.0 = y_spec.0.min(mu - sigma);
        y_spec.1 = y_spec.1.max(mu + sigma);
    }
    // Without any points the bounds stay inverted (MAX..MIN), which would yield a
    // meaningless chart; fail loudly instead.
    assert!(
        x_spec.0 <= x_spec.1,
        "no learning-curve points to plot; check the player ids and cutoffs"
    );

    let root = SVGBackend::new("plot.svg", (1280, 640)).into_drawing_area();
    root.fill(&WHITE).unwrap();
    let mut chart = ChartBuilder::on(&root)
        .margin(5)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(x_spec.0..x_spec.1, y_spec.0..y_spec.1)
        .unwrap();
    chart.configure_mesh().draw().unwrap();

    for (idx, (&(player, _, _), curve)) in players.iter().zip(&curves).enumerate() {
        let color = Palette99::pick(idx);

        // Closed outline of the mu ± sigma band: forwards along the upper edge,
        // then backwards along the lower edge.
        let band: Vec<(f64, f64)> = curve
            .iter()
            .map(|&(ts, mu, sigma)| (ts, mu + sigma))
            .chain(curve.iter().rev().map(|&(ts, mu, sigma)| (ts, mu - sigma)))
            .collect();
        chart
            .plotting_area()
            .draw(&Polygon::new(band, color.mix(0.15)))
            .unwrap();

        chart
            .draw_series(LineSeries::new(
                curve.iter().map(|&(ts, mu, _)| (ts, mu)),
                &color,
            ))
            .unwrap()
            .label(player)
            .legend(move |(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], &color));
    }

    chart
        .configure_series_labels()
        .background_style(WHITE.mix(0.8))
        .border_style(BLACK)
        .draw()
        .unwrap();
}
/// Minimal, dependency-free CSV reader used by this example.
///
/// The first line is the header row. Every line is split on each `,`; quoted
/// fields and escaped separators are not supported.
mod csv {
    use std::{
        fs::File,
        io::{self, BufRead, BufReader, Lines},
        ops,
        path::Path,
    };

    /// Splits one CSV line into owned fields on every `,`.
    fn split_fields(line: &str) -> Vec<String> {
        line.split(',').map(str::to_owned).collect()
    }

    /// A CSV file whose header line has already been read.
    pub struct Reader {
        /// Column names, in file order.
        header_map: Vec<String>,
        /// The lines following the header.
        lines: Lines<BufReader<File>>,
    }

    impl Reader {
        /// Opens `path` and reads its header line.
        ///
        /// A leading UTF-8 byte-order mark is stripped so it does not become
        /// part of the first column name. An empty file yields a reader with no
        /// columns and no records.
        ///
        /// # Errors
        ///
        /// Returns any I/O error raised while opening the file or reading the
        /// header line.
        pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
            let mut lines = File::open(path).map(BufReader::new)?.lines();
            let header_map = match lines.next() {
                Some(line) => {
                    let line = line?;
                    split_fields(line.trim_start_matches('\u{feff}'))
                }
                None => Vec::new(),
            };
            Ok(Self { header_map, lines })
        }

        /// Returns an iterator over the data rows that follow the header.
        pub fn records(&mut self) -> Records<'_> {
            Records {
                header_map: &self.header_map,
                lines: &mut self.lines,
            }
        }
    }

    /// Iterator over the data rows of a [`Reader`].
    pub struct Records<'a> {
        header_map: &'a [String],
        lines: &'a mut Lines<BufReader<File>>,
    }

    impl<'a> Iterator for Records<'a> {
        type Item = Record<'a>;

        /// Reads and splits the next line.
        ///
        /// # Panics
        ///
        /// Panics if reading the line fails; the item type has no room to carry
        /// an I/O error.
        fn next(&mut self) -> Option<Self::Item> {
            let line = self.lines.next()?.expect("failed to read CSV line");
            Some(Record {
                header_map: self.header_map,
                columns: split_fields(&line),
            })
        }
    }

    /// One data row whose fields can be looked up by header name.
    pub struct Record<'a> {
        header_map: &'a [String],
        columns: Vec<String>,
    }

    impl<'a> ops::Index<&str> for Record<'a> {
        type Output = str;

        /// Returns the field under the column named `index`.
        ///
        /// # Panics
        ///
        /// Panics with the column name if the header has no such column, or if
        /// this row has too few fields to reach it.
        fn index(&self, index: &str) -> &Self::Output {
            let position = self
                .header_map
                .iter()
                .position(|header| header == index)
                .unwrap_or_else(|| panic!("unknown CSV column {:?}", index));
            self.columns
                .get(position)
                .map(String::as_str)
                .unwrap_or_else(|| panic!("CSV row has no value for column {:?}", index))
        }
    }
}