Make API a bit prettier
This commit is contained in:
@@ -3,7 +3,7 @@ extern crate blas_src;
|
||||
use kickscore as ks;
|
||||
|
||||
fn main() {
|
||||
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
|
||||
let mut model = ks::model::Binary::probit();
|
||||
|
||||
for player in &["A", "B", "C", "D", "E", "F"] {
|
||||
let kernel: [Box<dyn ks::Kernel>; 2] = [
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
extern crate blas_src;
|
||||
|
||||
use kickscore::{
|
||||
kernel::{self, Kernel},
|
||||
model::{binary, Binary},
|
||||
};
|
||||
use kickscore::{kernel, model::Binary};
|
||||
|
||||
fn main() {
|
||||
let mut model = Binary::new(binary::Observation::Probit);
|
||||
let mut model = Binary::probit();
|
||||
|
||||
// Spike's skill does not change over time.
|
||||
let k_spike = kernel::Constant::new(0.5);
|
||||
@@ -15,10 +12,7 @@ fn main() {
|
||||
let k_tom = kernel::Exponential::new(1.0, 1.0);
|
||||
|
||||
// Jerry's skill has a constant offset and smooth dynamics.
|
||||
let k_jerry: [Box<dyn Kernel>; 2] = [
|
||||
Box::new(kernel::Constant::new(1.0)),
|
||||
Box::new(kernel::Matern52::new(0.5, 1.0)),
|
||||
];
|
||||
let k_jerry = (kernel::Constant::new(1.0), kernel::Matern52::new(0.5, 1.0));
|
||||
|
||||
// Now we are ready to add the items in the model.
|
||||
model.add_item("Spike", k_spike);
|
||||
|
||||
@@ -54,7 +54,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
|
||||
|
||||
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
|
||||
let mut model = ks::model::Binary::probit();
|
||||
|
||||
for team in teams {
|
||||
let kernel: [Box<dyn ks::Kernel>; 2] = [
|
||||
|
||||
@@ -42,6 +42,14 @@ impl Binary {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn probit() -> Self {
|
||||
Self::new(Observation::Probit)
|
||||
}
|
||||
|
||||
pub fn logit() -> Self {
|
||||
Self::new(Observation::Logit)
|
||||
}
|
||||
|
||||
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
||||
if self.storage.contains_key(name) {
|
||||
panic!("item '{}' already added", name);
|
||||
|
||||
@@ -7,13 +7,13 @@ use crate::observation::*;
|
||||
use crate::storage::Storage;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum TernaryModelObservation {
|
||||
pub enum Observation {
|
||||
Probit,
|
||||
Logit,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum TernaryModelFitMethod {
|
||||
pub enum FitMethod {
|
||||
Ep,
|
||||
Kl,
|
||||
}
|
||||
@@ -21,14 +21,14 @@ pub enum TernaryModelFitMethod {
|
||||
pub struct TernaryModel {
|
||||
storage: Storage,
|
||||
last_t: f64,
|
||||
obs: TernaryModelObservation,
|
||||
observations: Vec<Box<dyn Observation>>,
|
||||
last_method: Option<TernaryModelFitMethod>,
|
||||
obs: Observation,
|
||||
observations: Vec<Box<dyn crate::observation::Observation>>,
|
||||
last_method: Option<FitMethod>,
|
||||
margin: f64,
|
||||
}
|
||||
|
||||
impl TernaryModel {
|
||||
pub fn new(obs: TernaryModelObservation, margin: f64) -> Self {
|
||||
pub fn new(obs: Observation, margin: f64) -> Self {
|
||||
TernaryModel {
|
||||
storage: Storage::default(),
|
||||
last_t: f64::NEG_INFINITY,
|
||||
@@ -39,6 +39,14 @@ impl TernaryModel {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn probit(margin: f64) -> Self {
|
||||
Self::new(Observation::Probit, margin)
|
||||
}
|
||||
|
||||
pub fn logit(margin: f64) -> Self {
|
||||
Self::new(Observation::Logit, margin)
|
||||
}
|
||||
|
||||
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
||||
if self.storage.contains_key(name) {
|
||||
panic!("item '{}' already added", name);
|
||||
@@ -78,26 +86,26 @@ impl TernaryModel {
|
||||
let mut elems = self.process_items(winners, 1.0);
|
||||
elems.extend(self.process_items(losers, -1.0));
|
||||
|
||||
let obs: Box<dyn Observation> = match (tie, self.obs) {
|
||||
(false, TernaryModelObservation::Probit) => Box::new(ProbitWinObservation::new(
|
||||
let obs: Box<dyn crate::observation::Observation> = match (tie, self.obs) {
|
||||
(false, Observation::Probit) => Box::new(ProbitWinObservation::new(
|
||||
&mut self.storage,
|
||||
&elems,
|
||||
t,
|
||||
margin,
|
||||
)),
|
||||
(false, TernaryModelObservation::Logit) => Box::new(LogitWinObservation::new(
|
||||
(false, Observation::Logit) => Box::new(LogitWinObservation::new(
|
||||
&mut self.storage,
|
||||
&elems,
|
||||
t,
|
||||
margin,
|
||||
)),
|
||||
(true, TernaryModelObservation::Probit) => Box::new(ProbitTieObservation::new(
|
||||
(true, Observation::Probit) => Box::new(ProbitTieObservation::new(
|
||||
&mut self.storage,
|
||||
&elems,
|
||||
t,
|
||||
margin,
|
||||
)),
|
||||
(true, TernaryModelObservation::Logit) => Box::new(LogitTieObservation::new(
|
||||
(true, Observation::Logit) => Box::new(LogitTieObservation::new(
|
||||
&mut self.storage,
|
||||
&elems,
|
||||
t,
|
||||
@@ -113,7 +121,7 @@ impl TernaryModel {
|
||||
pub fn fit(&mut self) -> bool {
|
||||
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
|
||||
|
||||
let method = TernaryModelFitMethod::Ep;
|
||||
let method = FitMethod::Ep;
|
||||
let lr = 1.0;
|
||||
let tol = 1e-3;
|
||||
let max_iter = 100;
|
||||
@@ -130,8 +138,8 @@ impl TernaryModel {
|
||||
|
||||
for obs in &mut self.observations {
|
||||
let diff = match method {
|
||||
TernaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr),
|
||||
TernaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr),
|
||||
FitMethod::Ep => obs.ep_update(&mut self.storage, lr),
|
||||
FitMethod::Kl => obs.kl_update(&mut self.storage, lr),
|
||||
};
|
||||
|
||||
if diff > max_diff {
|
||||
@@ -168,17 +176,17 @@ impl TernaryModel {
|
||||
elems.extend(self.process_items(team_2, -1.0));
|
||||
|
||||
let prob_1 = match self.obs {
|
||||
TernaryModelObservation::Probit => {
|
||||
Observation::Probit => {
|
||||
ProbitWinObservation::probability(&self.storage, &elems, t, margin)
|
||||
}
|
||||
TernaryModelObservation::Logit => unimplemented!(),
|
||||
Observation::Logit => unimplemented!(),
|
||||
};
|
||||
|
||||
let prob_2 = match self.obs {
|
||||
TernaryModelObservation::Probit => {
|
||||
Observation::Probit => {
|
||||
ProbitTieObservation::probability(&self.storage, &elems, t, margin)
|
||||
}
|
||||
TernaryModelObservation::Logit => unimplemented!(),
|
||||
Observation::Logit => unimplemented!(),
|
||||
};
|
||||
|
||||
(prob_1, prob_2, 1.0 - prob_1 - prob_2)
|
||||
|
||||
@@ -4,7 +4,7 @@ use kickscore as ks;
|
||||
|
||||
#[test]
|
||||
fn binary_1() {
|
||||
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
|
||||
let mut model = ks::model::Binary::probit();
|
||||
|
||||
let k_audrey = ks::kernel::Matern52::new(1.0, 2.0);
|
||||
let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0);
|
||||
|
||||
@@ -5,7 +5,7 @@ use kickscore as ks;
|
||||
|
||||
#[test]
|
||||
fn kickscore_basic() {
|
||||
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
|
||||
let mut model = ks::model::Binary::probit();
|
||||
|
||||
let k_spike = ks::kernel::Constant::new(0.5);
|
||||
let k_tom = ks::kernel::Exponential::new(1.0, 1.0);
|
||||
|
||||
@@ -58,7 +58,7 @@ fn nba_history() {
|
||||
|
||||
let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
|
||||
|
||||
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
|
||||
let mut model = ks::model::Binary::probit();
|
||||
|
||||
for team in teams {
|
||||
let kernel = (
|
||||
|
||||
Reference in New Issue
Block a user