Added nba history example, and implemented kernel matern32.
This commit is contained in:
@@ -13,6 +13,9 @@ ndarray = "0.13"
|
|||||||
ndarray-linalg = { version = "0.12" }
|
ndarray-linalg = { version = "0.12" }
|
||||||
openblas-src = { version = "0.8", features = ["static"] }
|
openblas-src = { version = "0.8", features = ["static"] }
|
||||||
ordered-float = "1.0"
|
ordered-float = "1.0"
|
||||||
rand = "0.6"
|
rand = "0.7"
|
||||||
rand_xoshiro = "0.1"
|
rand_xoshiro = "0.4"
|
||||||
statrs = "0.12"
|
statrs = "0.12"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
time = "0.2"
|
||||||
|
|||||||
94
examples/nba-history.rs
Normal file
94
examples/nba-history.rs
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
extern crate openblas_src;
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::fs;
|
||||||
|
use std::io::{self, BufRead};
|
||||||
|
|
||||||
|
use kickscore as ks;
|
||||||
|
use time::Date;
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let reader = fs::File::open("examples/nba.csv").map(io::BufReader::new)?;
|
||||||
|
|
||||||
|
let mut teams = HashSet::new();
|
||||||
|
let mut observations = Vec::new();
|
||||||
|
|
||||||
|
let cutoff = time::date!(2019 - 06 - 01);
|
||||||
|
|
||||||
|
for line in reader.lines() {
|
||||||
|
let line = line?;
|
||||||
|
let data = line.split(',').collect::<Vec<_>>();
|
||||||
|
|
||||||
|
assert!(data.len() == 5);
|
||||||
|
|
||||||
|
let t = Date::parse(data[0], "%F")?;
|
||||||
|
|
||||||
|
if t > cutoff {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
teams.insert(data[1].to_string());
|
||||||
|
teams.insert(data[2].to_string());
|
||||||
|
|
||||||
|
if data[3].is_empty() || data[4].is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let t = t.midnight().timestamp() as f64;
|
||||||
|
|
||||||
|
let score_1: u16 = data[3].parse()?;
|
||||||
|
let score_2: u16 = data[4].parse()?;
|
||||||
|
|
||||||
|
if score_1 > score_2 {
|
||||||
|
observations.push((data[1].to_string(), data[2].to_string(), t));
|
||||||
|
} else if score_1 < score_2 {
|
||||||
|
observations.push((data[2].to_string(), data[1].to_string(), t));
|
||||||
|
} else {
|
||||||
|
panic!("there shouldn't be any tie games");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
|
||||||
|
|
||||||
|
let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
|
||||||
|
|
||||||
|
for team in teams {
|
||||||
|
let kernel: Vec<Box<dyn ks::Kernel>> = vec![
|
||||||
|
Box::new(ks::kernel::Constant::new(0.03)),
|
||||||
|
Box::new(ks::kernel::Matern32::new(0.138, 1.753 * seconds_in_year)),
|
||||||
|
];
|
||||||
|
|
||||||
|
model.add_item(&team, Box::new(kernel));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (winner, loser, t) in observations {
|
||||||
|
model.observe(&[&winner], &[&loser], t);
|
||||||
|
}
|
||||||
|
|
||||||
|
model.fit();
|
||||||
|
|
||||||
|
println!("Probability that CHI beats BOS...");
|
||||||
|
|
||||||
|
let (p_win, _) = model.probabilities(
|
||||||
|
&[&"CHI"],
|
||||||
|
&[&"BOS"],
|
||||||
|
time::date!(1996 - 01 - 01).midnight().timestamp() as f64,
|
||||||
|
);
|
||||||
|
println!(" ... in 1996: {:.2}%", 100.0 * p_win);
|
||||||
|
|
||||||
|
let (p_win, _) = model.probabilities(
|
||||||
|
&[&"CHI"],
|
||||||
|
&[&"BOS"],
|
||||||
|
time::date!(2001 - 01 - 01).midnight().timestamp() as f64,
|
||||||
|
);
|
||||||
|
println!(" ... in 2001: {:.2}%", 100.0 * p_win);
|
||||||
|
|
||||||
|
let (p_win, _) = model.probabilities(
|
||||||
|
&[&"CHI"],
|
||||||
|
&[&"BOS"],
|
||||||
|
time::date!(2020 - 01 - 01).midnight().timestamp() as f64,
|
||||||
|
);
|
||||||
|
println!(" ... in 2020: {:.2}%", 100.0 * p_win);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -2,10 +2,12 @@ use ndarray::prelude::*;
|
|||||||
|
|
||||||
mod constant;
|
mod constant;
|
||||||
mod exponential;
|
mod exponential;
|
||||||
|
mod matern32;
|
||||||
mod matern52;
|
mod matern52;
|
||||||
|
|
||||||
pub use constant::Constant;
|
pub use constant::Constant;
|
||||||
pub use exponential::Exponential;
|
pub use exponential::Exponential;
|
||||||
|
pub use matern32::Matern32;
|
||||||
pub use matern52::Matern52;
|
pub use matern52::Matern52;
|
||||||
|
|
||||||
pub trait Kernel {
|
pub trait Kernel {
|
||||||
|
|||||||
73
src/kernel/matern32.rs
Normal file
73
src/kernel/matern32.rs
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
use ndarray::prelude::*;
|
||||||
|
|
||||||
|
use super::Kernel;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Matern32 {
|
||||||
|
var: f64,
|
||||||
|
l_scale: f64,
|
||||||
|
lambda: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Matern32 {
|
||||||
|
pub fn new(var: f64, l_scale: f64) -> Self {
|
||||||
|
Matern32 {
|
||||||
|
var,
|
||||||
|
l_scale,
|
||||||
|
lambda: 3.0f64.sqrt() / l_scale,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Kernel for Matern32 {
|
||||||
|
fn k_diag(&self, ts: &[f64]) -> Array1<f64> {
|
||||||
|
Array1::ones(ts.len()) * self.var
|
||||||
|
}
|
||||||
|
|
||||||
|
fn order(&self) -> usize {
|
||||||
|
2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn state_mean(&self, t: f64) -> Array1<f64> {
|
||||||
|
Array1::zeros(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn state_cov(&self, t: f64) -> Array2<f64> {
|
||||||
|
let a = self.lambda;
|
||||||
|
|
||||||
|
array![[1.0, 0.0], [0.0, a * a]] * self.var
|
||||||
|
}
|
||||||
|
|
||||||
|
fn measurement_vector(&self) -> Array1<f64> {
|
||||||
|
array![1.0, 0.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn feedback(&self) -> Array2<f64> {
|
||||||
|
let a = self.lambda;
|
||||||
|
|
||||||
|
array![[0.0, 1.0], [-a.powi(2), -2.0 * a]]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
|
||||||
|
let d = t1 - t0;
|
||||||
|
let a = self.lambda;
|
||||||
|
|
||||||
|
let ba = array![[d * a + 1.0, d], [-d * a * a, 1.0 - d * a]];
|
||||||
|
|
||||||
|
(-d * a).exp() * ba
|
||||||
|
}
|
||||||
|
|
||||||
|
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64> {
|
||||||
|
let d = t1 - t0;
|
||||||
|
let a = self.lambda;
|
||||||
|
let da = d * a;
|
||||||
|
|
||||||
|
let c = (-2.0 * da).exp();
|
||||||
|
|
||||||
|
let x11 = 1.0 - c * (2.0 * da * da + 2.0 * da + 1.0);
|
||||||
|
let x12 = c * (2.0 * da * da * a);
|
||||||
|
let x22 = a * a * (1.0 - c * (2.0 * da * da - 2.0 * da + 1.0));
|
||||||
|
|
||||||
|
self.var * array![[x11, x12], [x12, x22]]
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user