Make API a bit prettier
This commit is contained in:
@@ -3,7 +3,7 @@ extern crate blas_src;
|
|||||||
use kickscore as ks;
|
use kickscore as ks;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
|
let mut model = ks::model::Binary::probit();
|
||||||
|
|
||||||
for player in &["A", "B", "C", "D", "E", "F"] {
|
for player in &["A", "B", "C", "D", "E", "F"] {
|
||||||
let kernel: [Box<dyn ks::Kernel>; 2] = [
|
let kernel: [Box<dyn ks::Kernel>; 2] = [
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
extern crate blas_src;
|
extern crate blas_src;
|
||||||
|
|
||||||
use kickscore::{
|
use kickscore::{kernel, model::Binary};
|
||||||
kernel::{self, Kernel},
|
|
||||||
model::{binary, Binary},
|
|
||||||
};
|
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let mut model = Binary::new(binary::Observation::Probit);
|
let mut model = Binary::probit();
|
||||||
|
|
||||||
// Spike's skill does not change over time.
|
// Spike's skill does not change over time.
|
||||||
let k_spike = kernel::Constant::new(0.5);
|
let k_spike = kernel::Constant::new(0.5);
|
||||||
@@ -15,10 +12,7 @@ fn main() {
|
|||||||
let k_tom = kernel::Exponential::new(1.0, 1.0);
|
let k_tom = kernel::Exponential::new(1.0, 1.0);
|
||||||
|
|
||||||
// Jerry's skill has a constant offset and smooth dynamics.
|
// Jerry's skill has a constant offset and smooth dynamics.
|
||||||
let k_jerry: [Box<dyn Kernel>; 2] = [
|
let k_jerry = (kernel::Constant::new(1.0), kernel::Matern52::new(0.5, 1.0));
|
||||||
Box::new(kernel::Constant::new(1.0)),
|
|
||||||
Box::new(kernel::Matern52::new(0.5, 1.0)),
|
|
||||||
];
|
|
||||||
|
|
||||||
// Now we are ready to add the items in the model.
|
// Now we are ready to add the items in the model.
|
||||||
model.add_item("Spike", k_spike);
|
model.add_item("Spike", k_spike);
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
|
let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
|
||||||
|
|
||||||
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
|
let mut model = ks::model::Binary::probit();
|
||||||
|
|
||||||
for team in teams {
|
for team in teams {
|
||||||
let kernel: [Box<dyn ks::Kernel>; 2] = [
|
let kernel: [Box<dyn ks::Kernel>; 2] = [
|
||||||
|
|||||||
@@ -42,6 +42,14 @@ impl Binary {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn probit() -> Self {
|
||||||
|
Self::new(Observation::Probit)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn logit() -> Self {
|
||||||
|
Self::new(Observation::Logit)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
||||||
if self.storage.contains_key(name) {
|
if self.storage.contains_key(name) {
|
||||||
panic!("item '{}' already added", name);
|
panic!("item '{}' already added", name);
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ use crate::observation::*;
|
|||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
pub enum TernaryModelObservation {
|
pub enum Observation {
|
||||||
Probit,
|
Probit,
|
||||||
Logit,
|
Logit,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
pub enum TernaryModelFitMethod {
|
pub enum FitMethod {
|
||||||
Ep,
|
Ep,
|
||||||
Kl,
|
Kl,
|
||||||
}
|
}
|
||||||
@@ -21,14 +21,14 @@ pub enum TernaryModelFitMethod {
|
|||||||
pub struct TernaryModel {
|
pub struct TernaryModel {
|
||||||
storage: Storage,
|
storage: Storage,
|
||||||
last_t: f64,
|
last_t: f64,
|
||||||
obs: TernaryModelObservation,
|
obs: Observation,
|
||||||
observations: Vec<Box<dyn Observation>>,
|
observations: Vec<Box<dyn crate::observation::Observation>>,
|
||||||
last_method: Option<TernaryModelFitMethod>,
|
last_method: Option<FitMethod>,
|
||||||
margin: f64,
|
margin: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TernaryModel {
|
impl TernaryModel {
|
||||||
pub fn new(obs: TernaryModelObservation, margin: f64) -> Self {
|
pub fn new(obs: Observation, margin: f64) -> Self {
|
||||||
TernaryModel {
|
TernaryModel {
|
||||||
storage: Storage::default(),
|
storage: Storage::default(),
|
||||||
last_t: f64::NEG_INFINITY,
|
last_t: f64::NEG_INFINITY,
|
||||||
@@ -39,6 +39,14 @@ impl TernaryModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn probit(margin: f64) -> Self {
|
||||||
|
Self::new(Observation::Probit, margin)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn logit(margin: f64) -> Self {
|
||||||
|
Self::new(Observation::Logit, margin)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
||||||
if self.storage.contains_key(name) {
|
if self.storage.contains_key(name) {
|
||||||
panic!("item '{}' already added", name);
|
panic!("item '{}' already added", name);
|
||||||
@@ -78,26 +86,26 @@ impl TernaryModel {
|
|||||||
let mut elems = self.process_items(winners, 1.0);
|
let mut elems = self.process_items(winners, 1.0);
|
||||||
elems.extend(self.process_items(losers, -1.0));
|
elems.extend(self.process_items(losers, -1.0));
|
||||||
|
|
||||||
let obs: Box<dyn Observation> = match (tie, self.obs) {
|
let obs: Box<dyn crate::observation::Observation> = match (tie, self.obs) {
|
||||||
(false, TernaryModelObservation::Probit) => Box::new(ProbitWinObservation::new(
|
(false, Observation::Probit) => Box::new(ProbitWinObservation::new(
|
||||||
&mut self.storage,
|
&mut self.storage,
|
||||||
&elems,
|
&elems,
|
||||||
t,
|
t,
|
||||||
margin,
|
margin,
|
||||||
)),
|
)),
|
||||||
(false, TernaryModelObservation::Logit) => Box::new(LogitWinObservation::new(
|
(false, Observation::Logit) => Box::new(LogitWinObservation::new(
|
||||||
&mut self.storage,
|
&mut self.storage,
|
||||||
&elems,
|
&elems,
|
||||||
t,
|
t,
|
||||||
margin,
|
margin,
|
||||||
)),
|
)),
|
||||||
(true, TernaryModelObservation::Probit) => Box::new(ProbitTieObservation::new(
|
(true, Observation::Probit) => Box::new(ProbitTieObservation::new(
|
||||||
&mut self.storage,
|
&mut self.storage,
|
||||||
&elems,
|
&elems,
|
||||||
t,
|
t,
|
||||||
margin,
|
margin,
|
||||||
)),
|
)),
|
||||||
(true, TernaryModelObservation::Logit) => Box::new(LogitTieObservation::new(
|
(true, Observation::Logit) => Box::new(LogitTieObservation::new(
|
||||||
&mut self.storage,
|
&mut self.storage,
|
||||||
&elems,
|
&elems,
|
||||||
t,
|
t,
|
||||||
@@ -113,7 +121,7 @@ impl TernaryModel {
|
|||||||
pub fn fit(&mut self) -> bool {
|
pub fn fit(&mut self) -> bool {
|
||||||
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
|
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
|
||||||
|
|
||||||
let method = TernaryModelFitMethod::Ep;
|
let method = FitMethod::Ep;
|
||||||
let lr = 1.0;
|
let lr = 1.0;
|
||||||
let tol = 1e-3;
|
let tol = 1e-3;
|
||||||
let max_iter = 100;
|
let max_iter = 100;
|
||||||
@@ -130,8 +138,8 @@ impl TernaryModel {
|
|||||||
|
|
||||||
for obs in &mut self.observations {
|
for obs in &mut self.observations {
|
||||||
let diff = match method {
|
let diff = match method {
|
||||||
TernaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr),
|
FitMethod::Ep => obs.ep_update(&mut self.storage, lr),
|
||||||
TernaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr),
|
FitMethod::Kl => obs.kl_update(&mut self.storage, lr),
|
||||||
};
|
};
|
||||||
|
|
||||||
if diff > max_diff {
|
if diff > max_diff {
|
||||||
@@ -168,17 +176,17 @@ impl TernaryModel {
|
|||||||
elems.extend(self.process_items(team_2, -1.0));
|
elems.extend(self.process_items(team_2, -1.0));
|
||||||
|
|
||||||
let prob_1 = match self.obs {
|
let prob_1 = match self.obs {
|
||||||
TernaryModelObservation::Probit => {
|
Observation::Probit => {
|
||||||
ProbitWinObservation::probability(&self.storage, &elems, t, margin)
|
ProbitWinObservation::probability(&self.storage, &elems, t, margin)
|
||||||
}
|
}
|
||||||
TernaryModelObservation::Logit => unimplemented!(),
|
Observation::Logit => unimplemented!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let prob_2 = match self.obs {
|
let prob_2 = match self.obs {
|
||||||
TernaryModelObservation::Probit => {
|
Observation::Probit => {
|
||||||
ProbitTieObservation::probability(&self.storage, &elems, t, margin)
|
ProbitTieObservation::probability(&self.storage, &elems, t, margin)
|
||||||
}
|
}
|
||||||
TernaryModelObservation::Logit => unimplemented!(),
|
Observation::Logit => unimplemented!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
(prob_1, prob_2, 1.0 - prob_1 - prob_2)
|
(prob_1, prob_2, 1.0 - prob_1 - prob_2)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use kickscore as ks;
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn binary_1() {
|
fn binary_1() {
|
||||||
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
|
let mut model = ks::model::Binary::probit();
|
||||||
|
|
||||||
let k_audrey = ks::kernel::Matern52::new(1.0, 2.0);
|
let k_audrey = ks::kernel::Matern52::new(1.0, 2.0);
|
||||||
let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0);
|
let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0);
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ use kickscore as ks;
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn kickscore_basic() {
|
fn kickscore_basic() {
|
||||||
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
|
let mut model = ks::model::Binary::probit();
|
||||||
|
|
||||||
let k_spike = ks::kernel::Constant::new(0.5);
|
let k_spike = ks::kernel::Constant::new(0.5);
|
||||||
let k_tom = ks::kernel::Exponential::new(1.0, 1.0);
|
let k_tom = ks::kernel::Exponential::new(1.0, 1.0);
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ fn nba_history() {
|
|||||||
|
|
||||||
let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
|
let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;
|
||||||
|
|
||||||
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
|
let mut model = ks::model::Binary::probit();
|
||||||
|
|
||||||
for team in teams {
|
for team in teams {
|
||||||
let kernel = (
|
let kernel = (
|
||||||
|
|||||||
Reference in New Issue
Block a user