Fix tests
This commit is contained in:
@@ -15,5 +15,6 @@ rand_xoshiro = "0.4"
|
|||||||
statrs = "0.13"
|
statrs = "0.13"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
openblas-src = { version = "0.10", features = ["system"] }
|
approx = "0.4"
|
||||||
|
intel-mkl-src = "0.5"
|
||||||
time = "0.2"
|
time = "0.2"
|
||||||
|
|||||||
68402
data/nba.csv
Normal file
68402
data/nba.csv
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
|||||||
extern crate openblas_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
use kickscore as ks;
|
use kickscore as ks;
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
extern crate openblas_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
use kickscore as ks;
|
use kickscore as ks;
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
extern crate openblas_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
@@ -8,14 +8,14 @@ use kickscore as ks;
|
|||||||
use time::Date;
|
use time::Date;
|
||||||
|
|
||||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let reader = fs::File::open("examples/nba.csv").map(io::BufReader::new)?;
|
let reader = fs::File::open("data/nba.csv").map(io::BufReader::new)?;
|
||||||
|
|
||||||
let mut teams = HashSet::new();
|
let mut teams = HashSet::new();
|
||||||
let mut observations = Vec::new();
|
let mut observations = Vec::new();
|
||||||
|
|
||||||
let cutoff = time::date!(2019 - 06 - 01);
|
let cutoff = time::date!(2019 - 06 - 01);
|
||||||
|
|
||||||
for line in reader.lines() {
|
for line in reader.lines().skip(1) {
|
||||||
let line = line?;
|
let line = line?;
|
||||||
let data = line.split(',').collect::<Vec<_>>();
|
let data = line.split(',').collect::<Vec<_>>();
|
||||||
|
|
||||||
|
|||||||
45
tests/kickscore-basics.rs
Normal file
45
tests/kickscore-basics.rs
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
use approx::assert_abs_diff_eq;
|
||||||
|
use kickscore as ks;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn kickscore_basic() {
|
||||||
|
let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
|
||||||
|
|
||||||
|
let k_spike = ks::kernel::Constant::new(0.5);
|
||||||
|
|
||||||
|
let k_tom = ks::kernel::Exponential::new(1.0, 1.0);
|
||||||
|
|
||||||
|
let k_jerry: Vec<Box<dyn ks::Kernel>> = vec![
|
||||||
|
Box::new(ks::kernel::Constant::new(1.0)),
|
||||||
|
Box::new(ks::kernel::Matern52::new(0.5, 1.0)),
|
||||||
|
];
|
||||||
|
|
||||||
|
model.add_item("Spike", Box::new(k_spike));
|
||||||
|
model.add_item("Tom", Box::new(k_tom));
|
||||||
|
model.add_item("Jerry", Box::new(k_jerry));
|
||||||
|
|
||||||
|
model.observe(&["Jerry"], &["Tom"], 0.0);
|
||||||
|
model.observe(&["Jerry"], &["Tom"], 0.9);
|
||||||
|
|
||||||
|
model.observe(&["Tom"], &["Spike"], 1.7);
|
||||||
|
model.observe(&["Tom"], &["Jerry"], 2.1);
|
||||||
|
|
||||||
|
model.observe(&["Jerry"], &["Tom"], 3.0);
|
||||||
|
model.observe(&["Jerry"], &["Tom", "Spike"], 3.5);
|
||||||
|
|
||||||
|
model.fit();
|
||||||
|
|
||||||
|
let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0);
|
||||||
|
|
||||||
|
assert_abs_diff_eq!(p_win, 0.7299975928462964, epsilon = f64::EPSILON);
|
||||||
|
|
||||||
|
let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 2.0);
|
||||||
|
|
||||||
|
assert_abs_diff_eq!(p_win, 0.4455095363120037, epsilon = f64::EPSILON);
|
||||||
|
|
||||||
|
let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], -1.0);
|
||||||
|
|
||||||
|
assert_abs_diff_eq!(p_win, 0.903756079972532, epsilon = f64::EPSILON);
|
||||||
|
}
|
||||||
106
tests/nba-history.rs
Normal file
106
tests/nba-history.rs
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::fs;
|
||||||
|
use std::io::{self, BufRead};
|
||||||
|
|
||||||
|
use approx::assert_abs_diff_eq;
|
||||||
|
use kickscore as ks;
|
||||||
|
use time::Date;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn nba_history() {
|
||||||
|
let reader = fs::File::open("data/nba.csv")
|
||||||
|
.map(io::BufReader::new)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut teams = HashSet::new();
|
||||||
|
let mut observations = Vec::new();
|
||||||
|
|
||||||
|
let cutoff = time::date!(2019 - 06 - 01);
|
||||||
|
|
||||||
|
for line in reader.lines().skip(1) {
|
||||||
|
let line = line.unwrap();
|
||||||
|
let data = line.split(',').collect::<Vec<_>>();
|
||||||
|
|
||||||
|
assert!(data.len() == 5);
|
||||||
|
|
||||||
|
let t = Date::parse(data[0], "%F").unwrap();
|
||||||
|
|
||||||
|
if t > cutoff {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
teams.insert(data[1].to_string());
|
||||||
|
teams.insert(data[2].to_string());
|
||||||
|
|
||||||
|
if data[3].is_empty() || data[4].is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let t = t.midnight().assume_utc().timestamp() as f64;
|
||||||
|
|
||||||
|
let score_1: u16 = data[3].parse().unwrap();
|
||||||
|
let score_2: u16 = data[4].parse().unwrap();
|
||||||
|
|
||||||
|
if score_1 > score_2 {
|
||||||
|
observations.push((data[1].to_string(), data[2].to_string(), t));
|
||||||
|
} else if score_1 < score_2 {
|
||||||
|
observations.push((data[2].to_string(), data[1].to_string(), t));
|
||||||
|
} else {
|
||||||
|
panic!("there shouldn't be any tie games");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
|
||||||
|
|
||||||
|
let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
|
||||||
|
|
||||||
|
for team in teams {
|
||||||
|
let kernel: Vec<Box<dyn ks::Kernel>> = vec![
|
||||||
|
Box::new(ks::kernel::Constant::new(0.03)),
|
||||||
|
Box::new(ks::kernel::Matern32::new(0.138, 1.753 * seconds_in_year)),
|
||||||
|
];
|
||||||
|
|
||||||
|
model.add_item(&team, Box::new(kernel));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (winner, loser, t) in observations {
|
||||||
|
model.observe(&[&winner], &[&loser], t);
|
||||||
|
}
|
||||||
|
|
||||||
|
model.fit();
|
||||||
|
|
||||||
|
let (p_win, _) = model.probabilities(
|
||||||
|
&[&"CHI"],
|
||||||
|
&[&"BOS"],
|
||||||
|
time::date!(1996 - 01 - 01)
|
||||||
|
.midnight()
|
||||||
|
.assume_utc()
|
||||||
|
.timestamp() as f64,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_abs_diff_eq!(p_win, 0.9002599772490479, epsilon = f64::EPSILON);
|
||||||
|
|
||||||
|
let (p_win, _) = model.probabilities(
|
||||||
|
&[&"CHI"],
|
||||||
|
&[&"BOS"],
|
||||||
|
time::date!(2001 - 01 - 01)
|
||||||
|
.midnight()
|
||||||
|
.assume_utc()
|
||||||
|
.timestamp() as f64,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_abs_diff_eq!(p_win, 0.22837870685441986, epsilon = f64::EPSILON);
|
||||||
|
|
||||||
|
let (p_win, _) = model.probabilities(
|
||||||
|
&[&"CHI"],
|
||||||
|
&[&"BOS"],
|
||||||
|
time::date!(2020 - 01 - 01)
|
||||||
|
.midnight()
|
||||||
|
.assume_utc()
|
||||||
|
.timestamp() as f64,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_abs_diff_eq!(p_win, 0.2748029998412422, epsilon = f64::EPSILON);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user