From 18d55a8ccf63aad001645d1b2677de273e17b24d Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Mon, 4 Jul 2022 23:13:57 +0200 Subject: [PATCH] Clean up example --- Cargo.toml | 1 + examples/atp.rs | 92 ++++++++++++++++++++++++++++++------------------- 2 files changed, 57 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4f9acd0..91210be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ approx = { version = "0.5.1", optional = true } [dev-dependencies] plotters = { version = "0.3.1", default-features = false, features = ["svg_backend", "all_elements", "all_series"] } +plotters-backend = "0.3.2" time = { version = "0.3.9", features = ["parsing"] } trueskill-tt = { path = ".", features = ["approx"] } diff --git a/examples/atp.rs b/examples/atp.rs index 363b3c4..02a1f19 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -1,5 +1,5 @@ use plotters::prelude::*; -use time::Date; +use time::{Date, Month}; use trueskill_tt::{History, IndexMap}; fn main() { @@ -9,6 +9,7 @@ fn main() { let mut results = Vec::new(); let mut times = Vec::new(); + let from = Date::from_calendar_date(1900, Month::January, 1).unwrap(); let time_format = time::format_description::parse("[year]-[month]-[day]").unwrap(); let mut index_map = IndexMap::new(); @@ -32,13 +33,9 @@ fn main() { results.push(vec![1.0, 0.0]); - let time = Date::parse(&row["time_start"], &time_format) - .unwrap() - .midnight() - .assume_utc() - .unix_timestamp(); + let date = Date::parse(&row["time_start"], &time_format).unwrap(); - times.push(time / (60 * 60 * 24)); + times.push((date - from).whole_days()); } let mut hist = History::builder().sigma(1.6).gamma(0.036).build(); @@ -47,24 +44,21 @@ fn main() { hist.convergence(10, 0.01, true); let players = [ - ("djokovic", "d643"), - ("federer", "f324"), - ("sampras", "s402"), - ("lendl", "l018"), - ("connors", "c044"), - ("nadal", "n409"), - ("john_mcenroe", "m047"), - ("bjorn_borg", "b058"), - ("aggasi", "a092"), - ("hewitt", "h432"), - ("edberg", "e004"), - ("vilas", "v028"), - ("nastase", "n008"), - ("courier", "c243"), - ("kuerten", "k293"), - ("murray", "mc10"), - ("wilander", "w023"), - ("roddick", "r485"), + ("aggasi", "a092", 38800), + ("borg", "b058", 30300), + ("connors", "c044", 31250), + ("courier", "c243", 35750), + ("djokovic", "d643", i64::MAX), + ("edberg", "e004", 34750), + ("federer", "f324", i64::MAX), + ("hewitt", "h432", 40750), + ("mcenroe", "m047", 33000), + ("lendl", "l018", 33750), + ("murray", "mc10", 60750), + ("nadal", "n409", i64::MAX), + ("nastase", "n008", 28750), + ("sampras", "s402", i64::MAX), + ("wilander", "w023", 32600), ]; let curves = hist.learning_curves(); @@ -72,8 +66,15 @@ fn main() { let mut x_spec = (f64::MAX, f64::MIN); let mut y_spec = (f64::MAX, f64::MIN); - for id in players.iter().map(|&(_, id)| index_map.get_or_create(id)) { + for (id, cutoff) in players + .iter() + .map(|&(_, id, cutoff)| (index_map.get_or_create(id), cutoff)) + { for (ts, gs) in &curves[&id] { + if *ts >= cutoff { + continue; + } + let ts = *ts as f64; if ts < x_spec.0 { @@ -84,24 +85,24 @@ fn main() { x_spec.1 = ts; } - let mu = gs.mu as f64; + let upper = gs.mu + gs.sigma; + let lower = gs.mu - gs.sigma; - if mu < y_spec.0 { - y_spec.0 = mu; + if lower < y_spec.0 { + y_spec.0 = lower; } - if mu > y_spec.1 { - y_spec.1 = mu; + if upper > y_spec.1 { + y_spec.1 = upper; } } } - let root = SVGBackend::new("plot.svg", (1024, 1024)).into_drawing_area(); + let root = SVGBackend::new("plot.svg", (1280, 640)).into_drawing_area(); root.fill(&WHITE).unwrap(); let mut chart = ChartBuilder::on(&root) - .caption("Hello world", ("sans-serif", 50).into_font()) .margin(5) .x_label_area_size(30) .y_label_area_size(30) @@ -110,19 +111,38 @@ fn main() { chart.configure_mesh().draw().unwrap(); - for (idx, (player, id)) in players + for (idx, (player, id, cutoff)) in players .iter() - .map(|&(player, id)| (player, index_map.get_or_create(id))) + .map(|&(player, id, cutoff)| (player, index_map.get_or_create(id), cutoff)) .enumerate() { let mut data = Vec::new(); + let mut upper = Vec::new(); + let mut lower = Vec::new(); + + for (ts, gs) in curves[&id].iter() { + if *ts >= cutoff { + continue; + } - for (ts, gs) in &curves[&id] { data.push((*ts as f64, gs.mu)); + + upper.push((*ts as f64, gs.mu + gs.sigma)); + lower.push((*ts as f64, gs.mu - gs.sigma)); } let color = Palette99::pick(idx); + let band = upper + .into_iter() + .chain(lower.into_iter().rev()) + .collect::>(); + + chart + .plotting_area() + .draw(&Polygon::new(band, &color.mix(0.15))) + .unwrap(); + chart .draw_series(LineSeries::new(data, &color)) .unwrap()