Clean up example

This commit is contained in:
2022-07-04 23:13:57 +02:00
parent 32df04fb6d
commit 18d55a8ccf
2 changed files with 57 additions and 36 deletions

View File

@@ -8,6 +8,7 @@ approx = { version = "0.5.1", optional = true }
[dev-dependencies] [dev-dependencies]
plotters = { version = "0.3.1", default-features = false, features = ["svg_backend", "all_elements", "all_series"] } plotters = { version = "0.3.1", default-features = false, features = ["svg_backend", "all_elements", "all_series"] }
plotters-backend = "0.3.2"
time = { version = "0.3.9", features = ["parsing"] } time = { version = "0.3.9", features = ["parsing"] }
trueskill-tt = { path = ".", features = ["approx"] } trueskill-tt = { path = ".", features = ["approx"] }

View File

@@ -1,5 +1,5 @@
use plotters::prelude::*; use plotters::prelude::*;
use time::Date; use time::{Date, Month};
use trueskill_tt::{History, IndexMap}; use trueskill_tt::{History, IndexMap};
fn main() { fn main() {
@@ -9,6 +9,7 @@ fn main() {
let mut results = Vec::new(); let mut results = Vec::new();
let mut times = Vec::new(); let mut times = Vec::new();
let from = Date::from_calendar_date(1900, Month::January, 1).unwrap();
let time_format = time::format_description::parse("[year]-[month]-[day]").unwrap(); let time_format = time::format_description::parse("[year]-[month]-[day]").unwrap();
let mut index_map = IndexMap::new(); let mut index_map = IndexMap::new();
@@ -32,13 +33,9 @@ fn main() {
results.push(vec![1.0, 0.0]); results.push(vec![1.0, 0.0]);
let time = Date::parse(&row["time_start"], &time_format) let date = Date::parse(&row["time_start"], &time_format).unwrap();
.unwrap()
.midnight()
.assume_utc()
.unix_timestamp();
times.push(time / (60 * 60 * 24)); times.push((date - from).whole_days());
} }
let mut hist = History::builder().sigma(1.6).gamma(0.036).build(); let mut hist = History::builder().sigma(1.6).gamma(0.036).build();
@@ -47,24 +44,21 @@ fn main() {
hist.convergence(10, 0.01, true); hist.convergence(10, 0.01, true);
let players = [ let players = [
("djokovic", "d643"), ("aggasi", "a092", 38800),
("federer", "f324"), ("borg", "b058", 30300),
("sampras", "s402"), ("connors", "c044", 31250),
("lendl", "l018"), ("courier", "c243", 35750),
("connors", "c044"), ("djokovic", "d643", i64::MAX),
("nadal", "n409"), ("edberg", "e004", 34750),
("john_mcenroe", "m047"), ("federer", "f324", i64::MAX),
("bjorn_borg", "b058"), ("hewitt", "h432", 40750),
("aggasi", "a092"), ("mcenroe", "m047", 33000),
("hewitt", "h432"), ("lendl", "l018", 33750),
("edberg", "e004"), ("murray", "mc10", 60750),
("vilas", "v028"), ("nadal", "n409", i64::MAX),
("nastase", "n008"), ("nastase", "n008", 28750),
("courier", "c243"), ("sampras", "s402", i64::MAX),
("kuerten", "k293"), ("wilander", "w023", 32600),
("murray", "mc10"),
("wilander", "w023"),
("roddick", "r485"),
]; ];
let curves = hist.learning_curves(); let curves = hist.learning_curves();
@@ -72,8 +66,15 @@ fn main() {
let mut x_spec = (f64::MAX, f64::MIN); let mut x_spec = (f64::MAX, f64::MIN);
let mut y_spec = (f64::MAX, f64::MIN); let mut y_spec = (f64::MAX, f64::MIN);
for id in players.iter().map(|&(_, id)| index_map.get_or_create(id)) { for (id, cutoff) in players
.iter()
.map(|&(_, id, cutoff)| (index_map.get_or_create(id), cutoff))
{
for (ts, gs) in &curves[&id] { for (ts, gs) in &curves[&id] {
if *ts >= cutoff {
continue;
}
let ts = *ts as f64; let ts = *ts as f64;
if ts < x_spec.0 { if ts < x_spec.0 {
@@ -84,24 +85,24 @@ fn main() {
x_spec.1 = ts; x_spec.1 = ts;
} }
let mu = gs.mu as f64; let upper = gs.mu + gs.sigma;
let lower = gs.mu - gs.sigma;
if mu < y_spec.0 { if lower < y_spec.0 {
y_spec.0 = mu; y_spec.0 = lower;
} }
if mu > y_spec.1 { if upper > y_spec.1 {
y_spec.1 = mu; y_spec.1 = upper;
} }
} }
} }
let root = SVGBackend::new("plot.svg", (1024, 1024)).into_drawing_area(); let root = SVGBackend::new("plot.svg", (1280, 640)).into_drawing_area();
root.fill(&WHITE).unwrap(); root.fill(&WHITE).unwrap();
let mut chart = ChartBuilder::on(&root) let mut chart = ChartBuilder::on(&root)
.caption("Hello world", ("sans-serif", 50).into_font())
.margin(5) .margin(5)
.x_label_area_size(30) .x_label_area_size(30)
.y_label_area_size(30) .y_label_area_size(30)
@@ -110,19 +111,38 @@ fn main() {
chart.configure_mesh().draw().unwrap(); chart.configure_mesh().draw().unwrap();
for (idx, (player, id)) in players for (idx, (player, id, cutoff)) in players
.iter() .iter()
.map(|&(player, id)| (player, index_map.get_or_create(id))) .map(|&(player, id, cutoff)| (player, index_map.get_or_create(id), cutoff))
.enumerate() .enumerate()
{ {
let mut data = Vec::new(); let mut data = Vec::new();
let mut upper = Vec::new();
let mut lower = Vec::new();
for (ts, gs) in curves[&id].iter() {
if *ts >= cutoff {
continue;
}
for (ts, gs) in &curves[&id] {
data.push((*ts as f64, gs.mu)); data.push((*ts as f64, gs.mu));
upper.push((*ts as f64, gs.mu + gs.sigma));
lower.push((*ts as f64, gs.mu - gs.sigma));
} }
let color = Palette99::pick(idx); let color = Palette99::pick(idx);
let band = upper
.into_iter()
.chain(lower.into_iter().rev())
.collect::<Vec<_>>();
chart
.plotting_area()
.draw(&Polygon::new(band, &color.mix(0.15)))
.unwrap();
chart chart
.draw_series(LineSeries::new(data, &color)) .draw_series(LineSeries::new(data, &color))
.unwrap() .unwrap()