Rename recursive fitter and binary model

2021-10-27 11:05:43 +02:00
parent b2bf871500
commit 37746f6c02
15 changed files with 41 additions and 91 deletions
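For callers of the crate, the rename amounts to a path change; the following is a minimal before/after sketch derived from the diffs below (the rest of the kickscore API is assumed unchanged):

use kickscore as ks;

fn main() {
    // Before this commit, the model was re-exported at the crate root:
    // let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);

    // After it, the model lives in the now-public `model` module and the
    // enums drop the `BinaryModel` prefix:
    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
    let _ = &mut model;
}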

View File

@@ -3,7 +3,7 @@ extern crate blas_src;
 use kickscore as ks;
 fn main() {
-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
     for player in &["A", "B", "C", "D", "E", "F"] {
         let kernel: Vec<Box<dyn ks::Kernel>> = vec![

View File

@@ -3,7 +3,7 @@ extern crate blas_src;
 use kickscore as ks;
 fn main() {
-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
     // Spike's skill does not change over time.
     let k_spike = ks::kernel::Constant::new(0.5);

View File

@@ -54,7 +54,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
     for team in teams {
         let kernel: Vec<Box<dyn ks::Kernel>> = vec![

View File

@@ -1,6 +1,6 @@
 mod recursive;
-pub use recursive::RecursiveFitter;
+pub use recursive::Recursive;
 pub trait Fitter {
     fn add_sample(&mut self, t: f64) -> usize;

View File

@@ -7,7 +7,7 @@ use crate::kernel::Kernel;
 use super::Fitter;
-pub struct RecursiveFitter {
+pub struct Recursive {
     ts_new: Vec<f64>,
     kernel: Box<dyn Kernel>,
     ts: Vec<f64>,
@@ -28,7 +28,7 @@ pub struct RecursiveFitter {
     p_s: Vec<Array2<f64>>,
 }
-impl RecursiveFitter {
+impl Recursive {
     pub fn new(kernel: Box<dyn Kernel>) -> Self {
         let m = kernel.order();
         let h = kernel.measurement_vector();
@@ -56,7 +56,7 @@ impl RecursiveFitter {
     }
 }
-impl fmt::Debug for RecursiveFitter {
+impl fmt::Debug for Recursive {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("RecursiveFitter")
             .field("ts_new", &self.ts_new)
@@ -80,7 +80,7 @@ impl fmt::Debug for RecursiveFitter {
     }
 }
-impl Fitter for RecursiveFitter {
+impl Fitter for Recursive {
     fn add_sample(&mut self, t: f64) -> usize {
         let idx = self.ts.len() + self.ts_new.len();

View File

@@ -5,10 +5,10 @@ mod item;
 pub mod kernel;
 mod linalg;
 mod math;
-mod model;
+pub mod model;
 pub mod observation;
 mod storage;
 mod utils;
 pub use kernel::Kernel;
-pub use model::*;
+// pub use model::*;

View File

@@ -1,7 +1,7 @@
-mod binary;
+pub mod binary;
 mod difference;
 mod ternary;
-pub use binary::*;
+pub use binary::Binary;
 pub use difference::*;
 pub use ternary::*;

View File

@@ -1,6 +1,6 @@
 use std::f64;
-use crate::fitter::RecursiveFitter;
+use crate::fitter::Recursive;
 use crate::item::Item;
 use crate::kernel::Kernel;
 use crate::observation::*;
@@ -12,28 +12,28 @@ trait BinaryObeservation: Observation {
 }
 */
-pub enum BinaryModelObservation {
+pub enum Observation {
     Probit,
     Logit,
 }
 #[derive(Clone, Copy)]
-pub enum BinaryModelFitMethod {
+pub enum FitMethod {
     Ep,
     Kl,
 }
-pub struct BinaryModel {
+pub struct Binary {
     pub storage: Storage,
     last_t: f64,
-    win_obs: BinaryModelObservation,
-    observations: Vec<Box<dyn Observation>>,
-    last_method: Option<BinaryModelFitMethod>,
+    win_obs: Observation,
+    observations: Vec<Box<dyn crate::observation::Observation>>,
+    last_method: Option<FitMethod>,
 }
-impl BinaryModel {
-    pub fn new(win_obs: BinaryModelObservation) -> Self {
-        BinaryModel {
+impl Binary {
+    pub fn new(win_obs: Observation) -> Self {
+        Binary {
             storage: Storage::default(),
             last_t: f64::NEG_INFINITY,
             win_obs,
@@ -49,7 +49,7 @@ impl BinaryModel {
         self.storage.insert(
             name.to_string(),
-            Item::new(Box::new(RecursiveFitter::new(kernel))),
+            Item::new(Box::new(Recursive::new(kernel))),
         );
     }
@@ -72,11 +72,11 @@ impl BinaryModel {
         let mut elems = self.process_items(winners, 1.0);
         elems.extend(self.process_items(losers, -1.0));
-        let obs: Box<dyn Observation> = match self.win_obs {
-            BinaryModelObservation::Probit => {
+        let obs: Box<dyn crate::observation::Observation> = match self.win_obs {
+            Observation::Probit => {
                 Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
             }
-            BinaryModelObservation::Logit => {
+            Observation::Logit => {
                 Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
             }
         };
@@ -95,7 +95,7 @@ impl BinaryModel {
     pub fn fit(&mut self) -> bool {
         // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
-        let method = BinaryModelFitMethod::Ep;
+        let method = FitMethod::Ep;
         let lr = 1.0;
         let tol = 1e-3;
         let max_iter = 100;
@@ -112,8 +112,8 @@ impl BinaryModel {
         for obs in &mut self.observations {
             let diff = match method {
-                BinaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr),
-                BinaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr),
+                FitMethod::Ep => obs.ep_update(&mut self.storage, lr),
+                FitMethod::Kl => obs.kl_update(&mut self.storage, lr),
             };
             if diff > max_diff {
@@ -142,10 +142,8 @@ impl BinaryModel {
         elems.extend(self.process_items(team_2, -1.0));
         let prob = match self.win_obs {
-            BinaryModelObservation::Probit => {
-                ProbitWinObservation::probability(&self.storage, &elems, t, 0.0)
-            }
-            BinaryModelObservation::Logit => todo!(),
+            Observation::Probit => ProbitWinObservation::probability(&self.storage, &elems, t, 0.0),
+            Observation::Logit => todo!(),
         };
         (prob, 1.0 - prob)

View File

@@ -1,6 +1,6 @@
 use std::f64;
-use crate::fitter::RecursiveFitter;
+use crate::fitter::Recursive;
 use crate::item::Item;
 use crate::kernel::Kernel;
 use crate::observation::*;
@@ -38,7 +38,7 @@ impl DifferenceModel {
         self.storage.insert(
             name.to_string(),
-            Item::new(Box::new(RecursiveFitter::new(kernel))),
+            Item::new(Box::new(Recursive::new(kernel))),
         );
     }

View File

@@ -1,6 +1,6 @@
 use std::f64;
-use crate::fitter::RecursiveFitter;
+use crate::fitter::Recursive;
 use crate::item::Item;
 use crate::kernel::Kernel;
 use crate::observation::*;
@@ -46,7 +46,7 @@ impl TernaryModel {
         self.storage.insert(
             name.to_string(),
-            Item::new(Box::new(RecursiveFitter::new(kernel))),
+            Item::new(Box::new(Recursive::new(kernel))),
        );
     }

View File

@@ -1,9 +1,7 @@
-use std::f64::consts::{SQRT_2, TAU};
+use std::f64::consts::TAU;
 use crate::storage::Storage;
-use crate::utils::{
-    logphi, logsumexp2, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS,
-};
+use crate::utils::{logphi, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS};
 use super::{f_params, Core, Observation};
@@ -54,52 +52,6 @@ fn ll_probit_tie(x: f64, margin: f64) -> f64 {
     }
 }
-fn lambdas() -> [f64; 5] {
-    [
-        0.44 * SQRT_2,
-        0.41 * SQRT_2,
-        0.40 * SQRT_2,
-        0.39 * SQRT_2,
-        0.36 * SQRT_2,
-    ]
-}
-const CS: [f64; 5] = [
-    1.146480988574439e+02,
-    -1.508871030070582e+03,
-    2.676085036831241e+03,
-    -1.356294962039222e+03,
-    7.543285642111850e+01,
-];
-fn mm_logit_win(mean_cav: f64, cov_cav: f64) {
-    let mut arr1 = [0.0; 5];
-    let mut arr2 = [0.0; 5];
-    let mut arr3 = [0.0; 5];
-    for (i, x) in lambdas().iter().enumerate() {
-        let (a, b, c) = mm_probit_win(x * mean_cav, x * x * cov_cav);
-        arr1[i] = a;
-        arr2[i] = b;
-        arr3[i] = c;
-    }
-    let logpart1 = logsumexp2(arr1, CS);
-    /*
-    dlogpart1 = (np.dot(np.exp(arr1) * arr2, CS * LAMBDAS)
-                 / np.dot(np.exp(arr1), CS))
-    d2logpart1 = (np.dot(np.exp(arr1) * (arr2 * arr2 + arr3),
-                  CS * LAMBDAS * LAMBDAS)
-                  / np.dot(np.exp(arr1), CS)) - (dlogpart1 * dlogpart1)
-    */
-}
-fn mm_logit_tie(x: f64, margin: f64) {
-    //
-}
 fn ll_logit_win(x: f64, margin: f64) -> f64 {
     let z = x - margin;

View File

@@ -89,7 +89,7 @@ pub fn logphi(z: f64) -> (f64, f64) {
     }
 }
-pub fn logsumexp2(xs: [f64; 5], bs: [f64; 5]) -> f64 {
+pub fn logsumexp2(xs: &[f64], bs: &[f64]) -> f64 {
     let a = xs.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
     a + bs

View File

@@ -4,7 +4,7 @@ use kickscore as ks;
 #[test]
 fn binary_1() {
-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
     let k_audrey = ks::kernel::Matern52::new(1.0, 2.0);
     let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0);

View File

@@ -5,7 +5,7 @@ use kickscore as ks;
 #[test]
 fn kickscore_basic() {
-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
     let k_spike = ks::kernel::Constant::new(0.5);

View File

@@ -58,7 +58,7 @@ fn nba_history() {
     let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
     for team in teams {
         let kernel: Vec<Box<dyn ks::Kernel>> = vec![