Make API a bit prettier

This commit is contained in:
2022-04-27 09:06:29 +02:00
parent aa8580970a
commit a8cef3806a
8 changed files with 42 additions and 32 deletions

View File

@@ -3,7 +3,7 @@ extern crate blas_src;
use kickscore as ks;
fn main() {
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
let mut model = ks::model::Binary::probit();
for player in &["A", "B", "C", "D", "E", "F"] {
let kernel: [Box<dyn ks::Kernel>; 2] = [

View File

@@ -1,12 +1,9 @@
extern crate blas_src;
use kickscore::{
kernel::{self, Kernel},
model::{binary, Binary},
};
use kickscore::{kernel, model::Binary};
fn main() {
let mut model = Binary::new(binary::Observation::Probit);
let mut model = Binary::probit();
// Spike's skill does not change over time.
let k_spike = kernel::Constant::new(0.5);
@@ -15,10 +12,7 @@ fn main() {
let k_tom = kernel::Exponential::new(1.0, 1.0);
// Jerry's skill has a constant offset and smooth dynamics.
let k_jerry: [Box<dyn Kernel>; 2] = [
Box::new(kernel::Constant::new(1.0)),
Box::new(kernel::Matern52::new(0.5, 1.0)),
];
let k_jerry = (kernel::Constant::new(1.0), kernel::Matern52::new(0.5, 1.0));
// Now we are ready to add the items in the model.
model.add_item("Spike", k_spike);

View File

@@ -54,7 +54,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
let mut model = ks::model::Binary::probit();
for team in teams {
let kernel: [Box<dyn ks::Kernel>; 2] = [

View File

@@ -42,6 +42,14 @@ impl Binary {
}
}
pub fn probit() -> Self {
Self::new(Observation::Probit)
}
pub fn logit() -> Self {
Self::new(Observation::Logit)
}
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
if self.storage.contains_key(name) {
panic!("item '{}' already added", name);

View File

@@ -7,13 +7,13 @@ use crate::observation::*;
use crate::storage::Storage;
#[derive(Clone, Copy)]
pub enum TernaryModelObservation {
pub enum Observation {
Probit,
Logit,
}
#[derive(Clone, Copy)]
pub enum TernaryModelFitMethod {
pub enum FitMethod {
Ep,
Kl,
}
@@ -21,14 +21,14 @@ pub enum TernaryModelFitMethod {
pub struct TernaryModel {
storage: Storage,
last_t: f64,
obs: TernaryModelObservation,
observations: Vec<Box<dyn Observation>>,
last_method: Option<TernaryModelFitMethod>,
obs: Observation,
observations: Vec<Box<dyn crate::observation::Observation>>,
last_method: Option<FitMethod>,
margin: f64,
}
impl TernaryModel {
pub fn new(obs: TernaryModelObservation, margin: f64) -> Self {
pub fn new(obs: Observation, margin: f64) -> Self {
TernaryModel {
storage: Storage::default(),
last_t: f64::NEG_INFINITY,
@@ -39,6 +39,14 @@ impl TernaryModel {
}
}
pub fn probit(margin: f64) -> Self {
Self::new(Observation::Probit, margin)
}
pub fn logit(margin: f64) -> Self {
Self::new(Observation::Logit, margin)
}
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
if self.storage.contains_key(name) {
panic!("item '{}' already added", name);
@@ -78,26 +86,26 @@ impl TernaryModel {
let mut elems = self.process_items(winners, 1.0);
elems.extend(self.process_items(losers, -1.0));
let obs: Box<dyn Observation> = match (tie, self.obs) {
(false, TernaryModelObservation::Probit) => Box::new(ProbitWinObservation::new(
let obs: Box<dyn crate::observation::Observation> = match (tie, self.obs) {
(false, Observation::Probit) => Box::new(ProbitWinObservation::new(
&mut self.storage,
&elems,
t,
margin,
)),
(false, TernaryModelObservation::Logit) => Box::new(LogitWinObservation::new(
(false, Observation::Logit) => Box::new(LogitWinObservation::new(
&mut self.storage,
&elems,
t,
margin,
)),
(true, TernaryModelObservation::Probit) => Box::new(ProbitTieObservation::new(
(true, Observation::Probit) => Box::new(ProbitTieObservation::new(
&mut self.storage,
&elems,
t,
margin,
)),
(true, TernaryModelObservation::Logit) => Box::new(LogitTieObservation::new(
(true, Observation::Logit) => Box::new(LogitTieObservation::new(
&mut self.storage,
&elems,
t,
@@ -113,7 +121,7 @@ impl TernaryModel {
pub fn fit(&mut self) -> bool {
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
let method = TernaryModelFitMethod::Ep;
let method = FitMethod::Ep;
let lr = 1.0;
let tol = 1e-3;
let max_iter = 100;
@@ -130,8 +138,8 @@ impl TernaryModel {
for obs in &mut self.observations {
let diff = match method {
TernaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr),
TernaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr),
FitMethod::Ep => obs.ep_update(&mut self.storage, lr),
FitMethod::Kl => obs.kl_update(&mut self.storage, lr),
};
if diff > max_diff {
@@ -168,17 +176,17 @@ impl TernaryModel {
elems.extend(self.process_items(team_2, -1.0));
let prob_1 = match self.obs {
TernaryModelObservation::Probit => {
Observation::Probit => {
ProbitWinObservation::probability(&self.storage, &elems, t, margin)
}
TernaryModelObservation::Logit => unimplemented!(),
Observation::Logit => unimplemented!(),
};
let prob_2 = match self.obs {
TernaryModelObservation::Probit => {
Observation::Probit => {
ProbitTieObservation::probability(&self.storage, &elems, t, margin)
}
TernaryModelObservation::Logit => unimplemented!(),
Observation::Logit => unimplemented!(),
};
(prob_1, prob_2, 1.0 - prob_1 - prob_2)

View File

@@ -4,7 +4,7 @@ use kickscore as ks;
#[test]
fn binary_1() {
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
let mut model = ks::model::Binary::probit();
let k_audrey = ks::kernel::Matern52::new(1.0, 2.0);
let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0);

View File

@@ -5,7 +5,7 @@ use kickscore as ks;
#[test]
fn kickscore_basic() {
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
let mut model = ks::model::Binary::probit();
let k_spike = ks::kernel::Constant::new(0.5);
let k_tom = ks::kernel::Exponential::new(1.0, 1.0);

View File

@@ -58,7 +58,7 @@ fn nba_history() {
let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
let mut model = ks::model::Binary::probit();
for team in teams {
let kernel = (