diff --git a/examples/abcdef.rs b/examples/abcdef.rs index b4e8699..8c68fbb 100644 --- a/examples/abcdef.rs +++ b/examples/abcdef.rs @@ -3,7 +3,7 @@ extern crate blas_src; use kickscore as ks; fn main() { - let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); + let mut model = ks::model::Binary::probit(); for player in &["A", "B", "C", "D", "E", "F"] { let kernel: [Box; 2] = [ diff --git a/examples/kickscore-basics.rs b/examples/kickscore-basics.rs index 24102e2..1b6bb6b 100644 --- a/examples/kickscore-basics.rs +++ b/examples/kickscore-basics.rs @@ -1,12 +1,9 @@ extern crate blas_src; -use kickscore::{ - kernel::{self, Kernel}, - model::{binary, Binary}, -}; +use kickscore::{kernel, model::Binary}; fn main() { - let mut model = Binary::new(binary::Observation::Probit); + let mut model = Binary::probit(); // Spike's skill does not change over time. let k_spike = kernel::Constant::new(0.5); @@ -15,10 +12,7 @@ fn main() { let k_tom = kernel::Exponential::new(1.0, 1.0); // Jerry's skill has a constant offset and smooth dynamics. - let k_jerry: [Box; 2] = [ - Box::new(kernel::Constant::new(1.0)), - Box::new(kernel::Matern52::new(0.5, 1.0)), - ]; + let k_jerry = (kernel::Constant::new(1.0), kernel::Matern52::new(0.5, 1.0)); // Now we are ready to add the items in the model. model.add_item("Spike", k_spike); diff --git a/examples/nba-history.rs b/examples/nba-history.rs index 7e2cc66..25f9efd 100644 --- a/examples/nba-history.rs +++ b/examples/nba-history.rs @@ -54,7 +54,7 @@ fn main() -> Result<(), Box> { let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0; - let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); + let mut model = ks::model::Binary::probit(); for team in teams { let kernel: [Box; 2] = [ diff --git a/src/model/binary.rs b/src/model/binary.rs index 443d627..a3dadc4 100644 --- a/src/model/binary.rs +++ b/src/model/binary.rs @@ -42,6 +42,14 @@ impl Binary { } } + pub fn probit() -> Self { + Self::new(Observation::Probit) + } + + pub fn logit() -> Self { + Self::new(Observation::Logit) + } + pub fn add_item(&mut self, name: &str, kernel: K) { if self.storage.contains_key(name) { panic!("item '{}' already added", name); diff --git a/src/model/ternary.rs b/src/model/ternary.rs index a44858a..ea7411c 100644 --- a/src/model/ternary.rs +++ b/src/model/ternary.rs @@ -7,13 +7,13 @@ use crate::observation::*; use crate::storage::Storage; #[derive(Clone, Copy)] -pub enum TernaryModelObservation { +pub enum Observation { Probit, Logit, } #[derive(Clone, Copy)] -pub enum TernaryModelFitMethod { +pub enum FitMethod { Ep, Kl, } @@ -21,14 +21,14 @@ pub enum TernaryModelFitMethod { pub struct TernaryModel { storage: Storage, last_t: f64, - obs: TernaryModelObservation, - observations: Vec>, - last_method: Option, + obs: Observation, + observations: Vec>, + last_method: Option, margin: f64, } impl TernaryModel { - pub fn new(obs: TernaryModelObservation, margin: f64) -> Self { + pub fn new(obs: Observation, margin: f64) -> Self { TernaryModel { storage: Storage::default(), last_t: f64::NEG_INFINITY, @@ -39,6 +39,14 @@ impl TernaryModel { } } + pub fn probit(margin: f64) -> Self { + Self::new(Observation::Probit, margin) + } + + pub fn logit(margin: f64) -> Self { + Self::new(Observation::Logit, margin) + } + pub fn add_item(&mut self, name: &str, kernel: K) { if self.storage.contains_key(name) { panic!("item '{}' already added", name); @@ -78,26 +86,26 @@ impl TernaryModel { let mut elems = self.process_items(winners, 1.0); elems.extend(self.process_items(losers, -1.0)); - let obs: Box = match (tie, self.obs) { - (false, TernaryModelObservation::Probit) => Box::new(ProbitWinObservation::new( + let obs: Box = match (tie, self.obs) { + (false, Observation::Probit) => Box::new(ProbitWinObservation::new( &mut self.storage, &elems, t, margin, )), - (false, TernaryModelObservation::Logit) => Box::new(LogitWinObservation::new( + (false, Observation::Logit) => Box::new(LogitWinObservation::new( &mut self.storage, &elems, t, margin, )), - (true, TernaryModelObservation::Probit) => Box::new(ProbitTieObservation::new( + (true, Observation::Probit) => Box::new(ProbitTieObservation::new( &mut self.storage, &elems, t, margin, )), - (true, TernaryModelObservation::Logit) => Box::new(LogitTieObservation::new( + (true, Observation::Logit) => Box::new(LogitTieObservation::new( &mut self.storage, &elems, t, @@ -113,7 +121,7 @@ impl TernaryModel { pub fn fit(&mut self) -> bool { // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False): - let method = TernaryModelFitMethod::Ep; + let method = FitMethod::Ep; let lr = 1.0; let tol = 1e-3; let max_iter = 100; @@ -130,8 +138,8 @@ impl TernaryModel { for obs in &mut self.observations { let diff = match method { - TernaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr), - TernaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr), + FitMethod::Ep => obs.ep_update(&mut self.storage, lr), + FitMethod::Kl => obs.kl_update(&mut self.storage, lr), }; if diff > max_diff { @@ -168,17 +176,17 @@ impl TernaryModel { elems.extend(self.process_items(team_2, -1.0)); let prob_1 = match self.obs { - TernaryModelObservation::Probit => { + Observation::Probit => { ProbitWinObservation::probability(&self.storage, &elems, t, margin) } - TernaryModelObservation::Logit => unimplemented!(), + Observation::Logit => unimplemented!(), }; let prob_2 = match self.obs { - TernaryModelObservation::Probit => { + Observation::Probit => { ProbitTieObservation::probability(&self.storage, &elems, t, margin) } - TernaryModelObservation::Logit => unimplemented!(), + Observation::Logit => unimplemented!(), }; (prob_1, prob_2, 1.0 - prob_1 - prob_2) diff --git a/tests/binary-1.rs b/tests/binary-1.rs index badcb2b..1645208 100644 --- a/tests/binary-1.rs +++ b/tests/binary-1.rs @@ -4,7 +4,7 @@ use kickscore as ks; #[test] fn binary_1() { - let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); + let mut model = ks::model::Binary::probit(); let k_audrey = ks::kernel::Matern52::new(1.0, 2.0); let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0); diff --git a/tests/kickscore-basics.rs b/tests/kickscore-basics.rs index 8ca6887..b1d9753 100644 --- a/tests/kickscore-basics.rs +++ b/tests/kickscore-basics.rs @@ -5,7 +5,7 @@ use kickscore as ks; #[test] fn kickscore_basic() { - let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); + let mut model = ks::model::Binary::probit(); let k_spike = ks::kernel::Constant::new(0.5); let k_tom = ks::kernel::Exponential::new(1.0, 1.0); diff --git a/tests/nba-history.rs b/tests/nba-history.rs index de20909..ac1aeb2 100644 --- a/tests/nba-history.rs +++ b/tests/nba-history.rs @@ -58,7 +58,7 @@ fn nba_history() { let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0; - let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); + let mut model = ks::model::Binary::probit(); for team in teams { let kernel = (