From 22c61d47b1feb2317f0ae1a76367ef916540ab55 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Tue, 28 Jun 2022 23:18:55 +0200 Subject: [PATCH] Change time to use i64 instead of u64 --- .gitignore | 2 ++ Cargo.toml | 1 + examples/atp.rs | 94 ++++++++++++++++++++++++++++++++++++++++++++++++- src/agent.rs | 8 ++--- src/batch.rs | 12 +++---- src/gaussian.rs | 2 +- src/history.rs | 18 +++++----- src/lib.rs | 2 +- 8 files changed, 117 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 4fffb2f..573ede2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /target /Cargo.lock + +*.svg diff --git a/Cargo.toml b/Cargo.toml index 3234fe4..4f9acd0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" approx = { version = "0.5.1", optional = true } [dev-dependencies] +plotters = { version = "0.3.1", default-features = false, features = ["svg_backend", "all_elements", "all_series"] } time = { version = "0.3.9", features = ["parsing"] } trueskill-tt = { path = ".", features = ["approx"] } diff --git a/examples/atp.rs b/examples/atp.rs index bfc77d8..da2785b 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -1,3 +1,4 @@ +use plotters::prelude::*; use time::Date; use trueskill_tt::{History, IndexMap}; @@ -37,13 +38,104 @@ fn main() { .assume_utc() .unix_timestamp(); - times.push(time as u64 / (60 * 60 * 24)); + times.push(time / (60 * 60 * 24)); } let mut hist = History::builder().sigma(1.6).gamma(0.036).build(); hist.add_events(composition, results, times, vec![]); hist.convergence(10, 0.01, true); + + let players = [ + ("djokovic", "d643"), + ("federer", "f324"), + ("sampras", "s402"), + ("lendl", "l018"), + ("connors", "c044"), + ("nadal", "n409"), + ("john_mcenroe", "m047"), + ("bjorn_borg", "b058"), + ("aggasi", "a092"), + ("hewitt", "h432"), + ("edberg", "e004"), + ("vilas", "v028"), + ("nastase", "n008"), + ("courier", "c243"), + ("kuerten", "k293"), + ("murray", "mc10"), + ("wilander", "w023"), + ("roddick", "r485"), + ]; + + let curves = hist.learning_curves(); + + let mut x_spec = (f64::MAX, f64::MIN); + let mut y_spec = (f64::MAX, f64::MIN); + + for id in players.iter().map(|&(_, id)| index_map.get_or_create(id)) { + for (ts, gs) in &curves[&id] { + let ts = *ts as f64; + + if ts < x_spec.0 { + x_spec.0 = ts; + } + + if ts > x_spec.1 { + x_spec.1 = ts; + } + + let mu = gs.mu as f64; + + if mu < y_spec.0 { + y_spec.0 = mu; + } + + if mu > y_spec.1 { + y_spec.1 = mu; + } + } + } + + let root = SVGBackend::new("plot.svg", (1024, 1024)).into_drawing_area(); + + root.fill(&WHITE).unwrap(); + + let mut chart = ChartBuilder::on(&root) + .caption("Hello world", ("sans-serif", 50).into_font()) + .margin(5) + .x_label_area_size(30) + .y_label_area_size(30) + .build_cartesian_2d(x_spec.0..x_spec.1, y_spec.0..y_spec.1) + .unwrap(); + + chart.configure_mesh().draw().unwrap(); + + for (idx, (player, id)) in players + .iter() + .map(|&(player, id)| (player, index_map.get_or_create(id))) + .enumerate() + { + let mut data = Vec::new(); + + for (ts, gs) in &curves[&id] { + data.push((*ts as f64, gs.mu)); + } + + let color = Palette99::pick(idx); + + chart + .draw_series(LineSeries::new(data, &color)) + .unwrap() + .label(format!("{}", player)) + .legend(move |(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], &color)); + } + + chart + .configure_series_labels() + .background_style(&WHITE.mix(0.8)) + .border_style(&BLACK) + .draw() + .unwrap(); } mod csv { diff --git a/src/agent.rs b/src/agent.rs index 43bdcff..671be40 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -4,11 +4,11 @@ use crate::{gaussian::Gaussian, player::Player, N_INF}; pub(crate) struct Agent { pub(crate) player: Player, pub(crate) message: Gaussian, - pub(crate) last_time: u64, + pub(crate) last_time: i64, } impl Agent { - pub(crate) fn receive(&self, elapsed: u64) -> Gaussian { + pub(crate) fn receive(&self, elapsed: i64) -> Gaussian { if self.message != N_INF { self.message.forget(self.player.gamma, elapsed) } else { @@ -22,7 +22,7 @@ impl Default for Agent { Self { player: Player::default(), message: N_INF, - last_time: u64::MIN, + last_time: i64::MIN, } } } @@ -32,7 +32,7 @@ pub(crate) fn clean<'a, A: Iterator>(agents: A, last_time: a.message = N_INF; if last_time { - a.last_time = 0; + a.last_time = i64::MIN; } } } diff --git a/src/batch.rs b/src/batch.rs index 4ecc864..a6ab818 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -9,7 +9,7 @@ pub(crate) struct Skill { pub(crate) forward: Gaussian, backward: Gaussian, likelihood: Gaussian, - pub(crate) elapsed: u64, + pub(crate) elapsed: i64, pub(crate) online: Gaussian, } @@ -57,7 +57,7 @@ impl Event { pub struct Batch { pub(crate) events: Vec, pub(crate) skills: HashMap, - pub(crate) time: u64, + pub(crate) time: i64, p_draw: f64, } @@ -66,7 +66,7 @@ impl Batch { composition: Vec>>, results: Vec>, weights: Vec>>, - time: u64, + time: i64, p_draw: f64, agents: &mut HashMap, ) -> Self { @@ -444,10 +444,10 @@ impl Batch { } } -pub(crate) fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 { - if last_time == u64::MIN { +pub(crate) fn compute_elapsed(last_time: i64, actual_time: i64) -> i64 { + if last_time == i64::MIN { 0 - } else if last_time == u64::MAX { + } else if last_time == i64::MAX { 1 } else { actual_time - last_time diff --git a/src/gaussian.rs b/src/gaussian.rs index b069699..18b3780 100644 --- a/src/gaussian.rs +++ b/src/gaussian.rs @@ -42,7 +42,7 @@ impl Gaussian { } } - pub(crate) fn forget(&self, gamma: f64, t: u64) -> Self { + pub(crate) fn forget(&self, gamma: f64, t: i64) -> Self { Self { mu: self.mu, sigma: (self.sigma.powi(2) + t as f64 * gamma.powi(2)).sqrt(), diff --git a/src/history.rs b/src/history.rs index 4514d0a..8964551 100644 --- a/src/history.rs +++ b/src/history.rs @@ -206,8 +206,8 @@ impl History { (step, i) } - pub fn learning_curves(&self) -> HashMap> { - let mut data: HashMap> = HashMap::new(); + pub fn learning_curves(&self) -> HashMap> { + let mut data: HashMap> = HashMap::new(); for b in &self.batches { for agent in b.skills.keys() { @@ -235,7 +235,7 @@ impl History { &mut self, composition: Vec>>, results: Vec>, - times: Vec, + times: Vec, weights: Vec>>, ) { self.add_events_with_prior(composition, results, times, weights, HashMap::new()) @@ -245,7 +245,7 @@ impl History { &mut self, composition: Vec>>, results: Vec>, - times: Vec, + times: Vec, weights: Vec>>, priors: HashMap, ) { @@ -302,7 +302,7 @@ impl History { while i < n { let mut j = i + 1; - let t = if self.time { times[o[i]] } else { i as u64 + 1 }; + let t = if self.time { times[o[i]] } else { i as i64 + 1 }; while self.time && j < n && times[o[j]] == t { j += 1; @@ -329,7 +329,7 @@ impl History { let a = self.agents.get_mut(agent).unwrap(); - a.last_time = if self.time { b.time } else { u64::MAX }; + a.last_time = if self.time { b.time } else { i64::MAX }; a.message = b.forward_prior_out(agent); } @@ -359,7 +359,7 @@ impl History { for a in b.skills.keys() { let agent = self.agents.get_mut(a).unwrap(); - agent.last_time = if self.time { t } else { u64::MAX }; + agent.last_time = if self.time { t } else { i64::MAX }; agent.message = b.forward_prior_out(a); } } else { @@ -379,7 +379,7 @@ impl History { for a in b.skills.keys() { let agent = self.agents.get_mut(a).unwrap(); - agent.last_time = if self.time { t } else { u64::MAX }; + agent.last_time = if self.time { t } else { i64::MAX }; agent.message = b.forward_prior_out(a); } @@ -406,7 +406,7 @@ impl History { let a = self.agents.get_mut(agent).unwrap(); - a.last_time = if self.time { b.time } else { u64::MAX }; + a.last_time = if self.time { b.time } else { i64::MAX }; a.message = b.forward_prior_out(agent); } diff --git a/src/lib.rs b/src/lib.rs index 66a0762..1fc6b7b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -223,7 +223,7 @@ pub(crate) fn sort_perm(x: &[f64], reverse: bool) -> Vec { v.into_iter().map(|(i, _)| i).collect() } -pub(crate) fn sort_time(xs: &[u64], reverse: bool) -> Vec { +pub(crate) fn sort_time(xs: &[i64], reverse: bool) -> Vec { let mut x = xs.iter().enumerate().collect::>(); if reverse {