Files
kickscore/tests/nba-history.rs
2022-06-07 10:53:11 +02:00

111 lines
2.8 KiB
Rust

extern crate blas_src;
use std::cmp::Ordering;
use std::collections::HashSet;
use std::fs;
use std::io::{self, BufRead};
use approx::assert_abs_diff_eq;
use kickscore as ks;
use time::{macros::date, Date};
/// Fits a probit kickscore model on NBA game results up to 2019-06-01 and
/// checks the CHI-vs-BOS win probability (`p_win`) at three dates against
/// pinned regression values.
///
/// Reads `data/nba.csv` relative to the current working directory. Each data
/// row is expected to have exactly five comma-separated fields:
/// `date,team_1,team_2,score_1,score_2`, with the date formatted `YYYY-MM-DD`.
#[test]
fn nba_history() {
    /// Converts a calendar date to seconds since the Unix epoch at UTC
    /// midnight; this is the time axis used for every observation and
    /// prediction in this test.
    fn unix_seconds(date: Date) -> f64 {
        date.midnight().assume_utc().unix_timestamp() as f64
    }

    let reader = fs::File::open("data/nba.csv")
        .map(io::BufReader::new)
        .expect("failed to open data/nba.csv");
    let cutoff = date!(2019 - 06 - 01);
    let format = time::format_description::parse("[year]-[month]-[day]")
        .expect("date format description should be valid");

    let mut teams = HashSet::new();
    let mut observations = Vec::new();
    // `skip(1)` drops the CSV header row.
    for line in reader.lines().skip(1) {
        let line = line.expect("failed to read a line of data/nba.csv");
        let fields: Vec<&str> = line.split(',').collect();
        let (date, team_1, team_2, score_1, score_2) = match fields.as_slice() {
            &[date, team_1, team_2, score_1, score_2] => {
                (date, team_1, team_2, score_1, score_2)
            }
            _ => panic!("expected 5 comma-separated fields, got {:?}", line),
        };

        let date = Date::parse(date, &format).expect("invalid date in data/nba.csv");
        // Assumes rows are in chronological order: everything from the first
        // game after the cutoff onwards is ignored.
        if date > cutoff {
            break;
        }
        // Every team seen up to the cutoff gets a model item, even if it only
        // appears in rows without a recorded score.
        teams.insert(team_1.to_string());
        teams.insert(team_2.to_string());
        // Rows with a missing score produce no observation.
        if score_1.is_empty() || score_2.is_empty() {
            continue;
        }

        let t = unix_seconds(date);
        let score_1: u16 = score_1.parse().expect("invalid score in data/nba.csv");
        let score_2: u16 = score_2.parse().expect("invalid score in data/nba.csv");
        let (winner, loser) = match score_1.cmp(&score_2) {
            Ordering::Greater => (team_1, team_2),
            Ordering::Less => (team_2, team_1),
            Ordering::Equal => panic!("there shouldn't be any tie games"),
        };
        observations.push((winner.to_string(), loser.to_string(), t));
    }

    let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
    let mut model = ks::model::Binary::probit();
    for team in teams {
        // The per-team kernel is the tuple of a `Constant` and a `Matern32`
        // kernel. The Matern32's second argument is given in years and scaled
        // to seconds to match the observation timestamps.
        let kernel = (
            ks::kernel::Constant::new(0.03),
            ks::kernel::Matern32::new(0.138, 1.753 * seconds_in_year),
        );
        model.add_item(&team, kernel);
    }
    for (winner, loser, t) in observations {
        model.observe(&[&winner], &[&loser], t);
    }
    model.fit();

    // Pinned regression values for `p_win` of ["CHI"] vs ["BOS"], compared at
    // machine-epsilon tolerance.
    let expected: [(Date, f64); 3] = [
        (date!(1996 - 01 - 01), 0.8987931627209078),
        (date!(2001 - 01 - 01), 0.22890824995747874),
        (date!(2020 - 01 - 01), 0.2748029998412422),
    ];
    for (date, p_expected) in expected.iter().copied() {
        let (p_win, _) = model.probabilities(&["CHI"], &["BOS"], unix_seconds(date));
        assert_abs_diff_eq!(p_win, p_expected, epsilon = f64::EPSILON);
    }
}