diff --git a/examples/abcdef.rs b/examples/abcdef.rs index 01726a0..b59a7bb 100644 --- a/examples/abcdef.rs +++ b/examples/abcdef.rs @@ -3,7 +3,7 @@ extern crate blas_src; use kickscore as ks; fn main() { - let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit); + let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); for player in &["A", "B", "C", "D", "E", "F"] { let kernel: Vec> = vec![ diff --git a/examples/kickscore-basics.rs b/examples/kickscore-basics.rs index e3d39a5..6c7b862 100644 --- a/examples/kickscore-basics.rs +++ b/examples/kickscore-basics.rs @@ -3,7 +3,7 @@ extern crate blas_src; use kickscore as ks; fn main() { - let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit); + let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); // Spike's skill does not change over time. let k_spike = ks::kernel::Constant::new(0.5); diff --git a/examples/nba-history.rs b/examples/nba-history.rs index 75a6554..f463724 100644 --- a/examples/nba-history.rs +++ b/examples/nba-history.rs @@ -54,7 +54,7 @@ fn main() -> Result<(), Box> { let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0; - let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit); + let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); for team in teams { let kernel: Vec> = vec![ diff --git a/src/fitter.rs b/src/fitter.rs index c336be1..5711631 100644 --- a/src/fitter.rs +++ b/src/fitter.rs @@ -1,6 +1,6 @@ mod recursive; -pub use recursive::RecursiveFitter; +pub use recursive::Recursive; pub trait Fitter { fn add_sample(&mut self, t: f64) -> usize; diff --git a/src/fitter/recursive.rs b/src/fitter/recursive.rs index 348d688..40972d5 100644 --- a/src/fitter/recursive.rs +++ b/src/fitter/recursive.rs @@ -7,7 +7,7 @@ use crate::kernel::Kernel; use super::Fitter; -pub struct RecursiveFitter { +pub struct Recursive { ts_new: Vec, kernel: Box, ts: Vec, @@ -28,7 +28,7 @@ pub struct RecursiveFitter { p_s: Vec>, } -impl RecursiveFitter { +impl Recursive { pub fn new(kernel: Box) -> Self { let m = kernel.order(); let h = kernel.measurement_vector(); @@ -56,7 +56,7 @@ impl RecursiveFitter { } } -impl fmt::Debug for RecursiveFitter { +impl fmt::Debug for Recursive { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("RecursiveFitter") .field("ts_new", &self.ts_new) @@ -80,7 +80,7 @@ impl fmt::Debug for RecursiveFitter { } } -impl Fitter for RecursiveFitter { +impl Fitter for Recursive { fn add_sample(&mut self, t: f64) -> usize { let idx = self.ts.len() + self.ts_new.len(); diff --git a/src/lib.rs b/src/lib.rs index a0c9d5c..908f606 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,10 +5,10 @@ mod item; pub mod kernel; mod linalg; mod math; -mod model; +pub mod model; pub mod observation; mod storage; mod utils; pub use kernel::Kernel; -pub use model::*; +// pub use model::*; diff --git a/src/model.rs b/src/model.rs index 536e417..c54365d 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,7 +1,7 @@ -mod binary; +pub mod binary; mod difference; mod ternary; -pub use binary::*; +pub use binary::Binary; pub use difference::*; pub use ternary::*; diff --git a/src/model/binary.rs b/src/model/binary.rs index def22db..8eeb770 100644 --- a/src/model/binary.rs +++ b/src/model/binary.rs @@ -1,6 +1,6 @@ use std::f64; -use crate::fitter::RecursiveFitter; +use crate::fitter::Recursive; use crate::item::Item; use crate::kernel::Kernel; use crate::observation::*; @@ -12,28 +12,28 @@ trait BinaryObeservation: Observation { } */ -pub enum BinaryModelObservation { +pub enum Observation { Probit, Logit, } #[derive(Clone, Copy)] -pub enum BinaryModelFitMethod { +pub enum FitMethod { Ep, Kl, } -pub struct BinaryModel { +pub struct Binary { pub storage: Storage, last_t: f64, - win_obs: BinaryModelObservation, - observations: Vec>, - last_method: Option, + win_obs: Observation, + observations: Vec>, + last_method: Option, } -impl BinaryModel { - pub fn new(win_obs: BinaryModelObservation) -> Self { - BinaryModel { +impl Binary { + pub fn new(win_obs: Observation) -> Self { + Binary { storage: Storage::default(), last_t: f64::NEG_INFINITY, win_obs, @@ -49,7 +49,7 @@ impl BinaryModel { self.storage.insert( name.to_string(), - Item::new(Box::new(RecursiveFitter::new(kernel))), + Item::new(Box::new(Recursive::new(kernel))), ); } @@ -72,11 +72,11 @@ impl BinaryModel { let mut elems = self.process_items(winners, 1.0); elems.extend(self.process_items(losers, -1.0)); - let obs: Box = match self.win_obs { - BinaryModelObservation::Probit => { + let obs: Box = match self.win_obs { + Observation::Probit => { Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0)) } - BinaryModelObservation::Logit => { + Observation::Logit => { Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0)) } }; @@ -95,7 +95,7 @@ impl BinaryModel { pub fn fit(&mut self) -> bool { // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False): - let method = BinaryModelFitMethod::Ep; + let method = FitMethod::Ep; let lr = 1.0; let tol = 1e-3; let max_iter = 100; @@ -112,8 +112,8 @@ impl BinaryModel { for obs in &mut self.observations { let diff = match method { - BinaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr), - BinaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr), + FitMethod::Ep => obs.ep_update(&mut self.storage, lr), + FitMethod::Kl => obs.kl_update(&mut self.storage, lr), }; if diff > max_diff { @@ -142,10 +142,8 @@ impl BinaryModel { elems.extend(self.process_items(team_2, -1.0)); let prob = match self.win_obs { - BinaryModelObservation::Probit => { - ProbitWinObservation::probability(&self.storage, &elems, t, 0.0) - } - BinaryModelObservation::Logit => todo!(), + Observation::Probit => ProbitWinObservation::probability(&self.storage, &elems, t, 0.0), + Observation::Logit => todo!(), }; (prob, 1.0 - prob) diff --git a/src/model/difference.rs b/src/model/difference.rs index de72f45..c62a6bd 100644 --- a/src/model/difference.rs +++ b/src/model/difference.rs @@ -1,6 +1,6 @@ use std::f64; -use crate::fitter::RecursiveFitter; +use crate::fitter::Recursive; use crate::item::Item; use crate::kernel::Kernel; use crate::observation::*; @@ -38,7 +38,7 @@ impl DifferenceModel { self.storage.insert( name.to_string(), - Item::new(Box::new(RecursiveFitter::new(kernel))), + Item::new(Box::new(Recursive::new(kernel))), ); } diff --git a/src/model/ternary.rs b/src/model/ternary.rs index 7a2ff9a..8a7d970 100644 --- a/src/model/ternary.rs +++ b/src/model/ternary.rs @@ -1,6 +1,6 @@ use std::f64; -use crate::fitter::RecursiveFitter; +use crate::fitter::Recursive; use crate::item::Item; use crate::kernel::Kernel; use crate::observation::*; @@ -46,7 +46,7 @@ impl TernaryModel { self.storage.insert( name.to_string(), - Item::new(Box::new(RecursiveFitter::new(kernel))), + Item::new(Box::new(Recursive::new(kernel))), ); } diff --git a/src/observation/ordinal.rs b/src/observation/ordinal.rs index 5bb7496..a5885cb 100644 --- a/src/observation/ordinal.rs +++ b/src/observation/ordinal.rs @@ -1,9 +1,7 @@ -use std::f64::consts::{SQRT_2, TAU}; +use std::f64::consts::TAU; use crate::storage::Storage; -use crate::utils::{ - logphi, logsumexp2, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS, -}; +use crate::utils::{logphi, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS}; use super::{f_params, Core, Observation}; @@ -54,52 +52,6 @@ fn ll_probit_tie(x: f64, margin: f64) -> f64 { } } -fn lambdas() -> [f64; 5] { - [ - 0.44 * SQRT_2, - 0.41 * SQRT_2, - 0.40 * SQRT_2, - 0.39 * SQRT_2, - 0.36 * SQRT_2, - ] -} - -const CS: [f64; 5] = [ - 1.146480988574439e+02, - -1.508871030070582e+03, - 2.676085036831241e+03, - -1.356294962039222e+03, - 7.543285642111850e+01, -]; - -fn mm_logit_win(mean_cav: f64, cov_cav: f64) { - let mut arr1 = [0.0; 5]; - let mut arr2 = [0.0; 5]; - let mut arr3 = [0.0; 5]; - - for (i, x) in lambdas().iter().enumerate() { - let (a, b, c) = mm_probit_win(x * mean_cav, x * x * cov_cav); - - arr1[i] = a; - arr2[i] = b; - arr3[i] = c; - } - - let logpart1 = logsumexp2(arr1, CS); - - /* - dlogpart1 = (np.dot(np.exp(arr1) * arr2, CS * LAMBDAS) - / np.dot(np.exp(arr1), CS)) - d2logpart1 = (np.dot(np.exp(arr1) * (arr2 * arr2 + arr3), - CS * LAMBDAS * LAMBDAS) - / np.dot(np.exp(arr1), CS)) - (dlogpart1 * dlogpart1) - */ -} - -fn mm_logit_tie(x: f64, margin: f64) { - // -} - fn ll_logit_win(x: f64, margin: f64) -> f64 { let z = x - margin; diff --git a/src/utils.rs b/src/utils.rs index 719e4b9..ad8a4f2 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -89,7 +89,7 @@ pub fn logphi(z: f64) -> (f64, f64) { } } -pub fn logsumexp2(xs: [f64; 5], bs: [f64; 5]) -> f64 { +pub fn logsumexp2(xs: &[f64], bs: &[f64]) -> f64 { let a = xs.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); a + bs diff --git a/tests/binary-1.rs b/tests/binary-1.rs index ec84cc9..68079e8 100644 --- a/tests/binary-1.rs +++ b/tests/binary-1.rs @@ -4,7 +4,7 @@ use kickscore as ks; #[test] fn binary_1() { - let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit); + let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); let k_audrey = ks::kernel::Matern52::new(1.0, 2.0); let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0); diff --git a/tests/kickscore-basics.rs b/tests/kickscore-basics.rs index 6ea33fd..dcc7f98 100644 --- a/tests/kickscore-basics.rs +++ b/tests/kickscore-basics.rs @@ -5,7 +5,7 @@ use kickscore as ks; #[test] fn kickscore_basic() { - let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit); + let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); let k_spike = ks::kernel::Constant::new(0.5); diff --git a/tests/nba-history.rs b/tests/nba-history.rs index c70795b..3c67719 100644 --- a/tests/nba-history.rs +++ b/tests/nba-history.rs @@ -58,7 +58,7 @@ fn nba_history() { let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0; - let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit); + let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); for team in teams { let kernel: Vec> = vec![