Rename recursive fitter and binary model
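At a glance, the rename moves the public binary model from the crate root into the `model` module and shortens the fitter name. A minimal sketch of what calling code looks like after this commit; the kernel choice and the `fit()` call are taken from the example and test hunks below, anything beyond that is an assumption:

```rust
use kickscore as ks;

fn main() {
    // Old path (re-exported at the crate root, removed by this commit):
    //     let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
    // New path, as used throughout the updated examples and tests:
    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);

    // Kernels are untouched by the rename, e.g. a constant-skill kernel:
    let _kernel: Box<dyn ks::Kernel> = Box::new(ks::kernel::Constant::new(0.5));

    // `fit()` still returns a bool (see the `pub fn fit(&mut self) -> bool` hunk below).
    let _converged = model.fit();
}
```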
@@ -3,7 +3,7 @@ extern crate blas_src;
 use kickscore as ks;

 fn main() {
-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);

     for player in &["A", "B", "C", "D", "E", "F"] {
         let kernel: Vec<Box<dyn ks::Kernel>> = vec![
@@ -3,7 +3,7 @@ extern crate blas_src;
 use kickscore as ks;

 fn main() {
-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);

     // Spike's skill does not change over time.
     let k_spike = ks::kernel::Constant::new(0.5);
@@ -54,7 +54,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {

     let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;

-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);

     for team in teams {
         let kernel: Vec<Box<dyn ks::Kernel>> = vec![
@@ -1,6 +1,6 @@
 mod recursive;

-pub use recursive::RecursiveFitter;
+pub use recursive::Recursive;

 pub trait Fitter {
     fn add_sample(&mut self, t: f64) -> usize;
@@ -7,7 +7,7 @@ use crate::kernel::Kernel;

 use super::Fitter;

-pub struct RecursiveFitter {
+pub struct Recursive {
     ts_new: Vec<f64>,
     kernel: Box<dyn Kernel>,
     ts: Vec<f64>,
@@ -28,7 +28,7 @@ pub struct RecursiveFitter {
     p_s: Vec<Array2<f64>>,
 }

-impl RecursiveFitter {
+impl Recursive {
     pub fn new(kernel: Box<dyn Kernel>) -> Self {
         let m = kernel.order();
         let h = kernel.measurement_vector();
@@ -56,7 +56,7 @@ impl RecursiveFitter {
     }
 }

-impl fmt::Debug for RecursiveFitter {
+impl fmt::Debug for Recursive {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("RecursiveFitter")
             .field("ts_new", &self.ts_new)
@@ -80,7 +80,7 @@ impl fmt::Debug for RecursiveFitter {
     }
 }

-impl Fitter for RecursiveFitter {
+impl Fitter for Recursive {
     fn add_sample(&mut self, t: f64) -> usize {
         let idx = self.ts.len() + self.ts_new.len();

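The two hunks above only rename `RecursiveFitter` to `Recursive`; the surrounding structure is unchanged. A self-contained toy (not kickscore code; `Fitter`, `Recursive`, `Item`, and the storage map are simplified stand-ins) of the pattern the model code relies on, where each named player owns a boxed fitter:

```rust
use std::collections::HashMap;

// Stand-in for the Fitter trait (only the method shown in the hunk above).
trait Fitter {
    fn add_sample(&mut self, t: f64) -> usize;
}

// Stand-in for the renamed recursive fitter: just records sample times.
struct Recursive {
    ts: Vec<f64>,
}

impl Fitter for Recursive {
    fn add_sample(&mut self, t: f64) -> usize {
        self.ts.push(t);
        self.ts.len() - 1 // index of the newly added sample
    }
}

// Stand-in for Item: owns the fitter for one player.
struct Item {
    fitter: Box<dyn Fitter>,
}

fn main() {
    let mut storage: HashMap<String, Item> = HashMap::new();
    storage.insert(
        "Spike".to_string(),
        Item { fitter: Box::new(Recursive { ts: Vec::new() }) },
    );

    let idx = storage.get_mut("Spike").unwrap().fitter.add_sample(0.0);
    assert_eq!(idx, 0);
}
```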
@@ -5,10 +5,10 @@ mod item;
 pub mod kernel;
 mod linalg;
 mod math;
-mod model;
+pub mod model;
 pub mod observation;
 mod storage;
 mod utils;

 pub use kernel::Kernel;
-pub use model::*;
+// pub use model::*;
@@ -1,7 +1,7 @@
-mod binary;
+pub mod binary;
 mod difference;
 mod ternary;

-pub use binary::*;
+pub use binary::Binary;
 pub use difference::*;
 pub use ternary::*;
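With `pub mod binary;` plus `pub use binary::Binary;` here, and `pub mod model;` in the crate-root hunk just above, downstream code can reach the model either as `ks::model::Binary` or as `ks::model::binary::Binary`, while the enums take the full module path because they are not re-exported. A small sketch of the resulting paths, using only what these hunks export:

```rust
use kickscore as ks;

// The re-exported alias and the module-level path name the same type.
type ViaReexport = ks::model::Binary;
type ViaModule = ks::model::binary::Binary;

fn main() {
    // The Observation enum is not re-exported, so it needs the full path.
    let _a: ViaReexport = ks::model::binary::Binary::new(ks::model::binary::Observation::Probit);
    let _b: ViaModule = ks::model::Binary::new(ks::model::binary::Observation::Probit);
}
```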
@@ -1,6 +1,6 @@
 use std::f64;

-use crate::fitter::RecursiveFitter;
+use crate::fitter::Recursive;
 use crate::item::Item;
 use crate::kernel::Kernel;
 use crate::observation::*;
@@ -12,28 +12,28 @@ trait BinaryObeservation: Observation {
 }
 */

-pub enum BinaryModelObservation {
+pub enum Observation {
     Probit,
     Logit,
 }

 #[derive(Clone, Copy)]
-pub enum BinaryModelFitMethod {
+pub enum FitMethod {
     Ep,
     Kl,
 }

-pub struct BinaryModel {
+pub struct Binary {
     pub storage: Storage,
     last_t: f64,
-    win_obs: BinaryModelObservation,
-    observations: Vec<Box<dyn Observation>>,
-    last_method: Option<BinaryModelFitMethod>,
+    win_obs: Observation,
+    observations: Vec<Box<dyn crate::observation::Observation>>,
+    last_method: Option<FitMethod>,
 }

-impl BinaryModel {
-    pub fn new(win_obs: BinaryModelObservation) -> Self {
-        BinaryModel {
+impl Binary {
+    pub fn new(win_obs: Observation) -> Self {
+        Binary {
             storage: Storage::default(),
             last_t: f64::NEG_INFINITY,
             win_obs,
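The field type `Box<dyn crate::observation::Observation>` is now spelled out in full because the new enum in this module takes the same name as the `Observation` trait pulled in by `use crate::observation::*;`. A self-contained toy (not kickscore code; the trait method is made up for illustration) showing why the qualification is needed:

```rust
mod observation {
    pub trait Observation {
        fn log_likelihood(&self) -> f64;
    }
}

mod binary {
    // The enum claims the unqualified name `Observation` inside this module...
    pub enum Observation {
        Probit,
        Logit,
    }

    pub struct Binary {
        pub win_obs: Observation,
        // ...so the trait object must spell out the full path to the trait.
        pub observations: Vec<Box<dyn crate::observation::Observation>>,
    }
}

fn main() {
    let model = binary::Binary {
        win_obs: binary::Observation::Probit,
        observations: Vec::new(),
    };
    assert!(matches!(model.win_obs, binary::Observation::Probit));
}
```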
@@ -49,7 +49,7 @@ impl BinaryModel {

         self.storage.insert(
             name.to_string(),
-            Item::new(Box::new(RecursiveFitter::new(kernel))),
+            Item::new(Box::new(Recursive::new(kernel))),
         );
     }

@@ -72,11 +72,11 @@ impl BinaryModel {
         let mut elems = self.process_items(winners, 1.0);
         elems.extend(self.process_items(losers, -1.0));

-        let obs: Box<dyn Observation> = match self.win_obs {
-            BinaryModelObservation::Probit => {
+        let obs: Box<dyn crate::observation::Observation> = match self.win_obs {
+            Observation::Probit => {
                 Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
             }
-            BinaryModelObservation::Logit => {
+            Observation::Logit => {
                 Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
             }
         };
@@ -95,7 +95,7 @@ impl BinaryModel {
     pub fn fit(&mut self) -> bool {
         // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):

-        let method = BinaryModelFitMethod::Ep;
+        let method = FitMethod::Ep;
         let lr = 1.0;
         let tol = 1e-3;
         let max_iter = 100;
@@ -112,8 +112,8 @@ impl BinaryModel {

             for obs in &mut self.observations {
                 let diff = match method {
-                    BinaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr),
-                    BinaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr),
+                    FitMethod::Ep => obs.ep_update(&mut self.storage, lr),
+                    FitMethod::Kl => obs.kl_update(&mut self.storage, lr),
                 };

                 if diff > max_diff {
@@ -142,10 +142,8 @@ impl BinaryModel {
         elems.extend(self.process_items(team_2, -1.0));

         let prob = match self.win_obs {
-            BinaryModelObservation::Probit => {
-                ProbitWinObservation::probability(&self.storage, &elems, t, 0.0)
-            }
-            BinaryModelObservation::Logit => todo!(),
+            Observation::Probit => ProbitWinObservation::probability(&self.storage, &elems, t, 0.0),
+            Observation::Logit => todo!(),
         };

         (prob, 1.0 - prob)
@@ -1,6 +1,6 @@
 use std::f64;

-use crate::fitter::RecursiveFitter;
+use crate::fitter::Recursive;
 use crate::item::Item;
 use crate::kernel::Kernel;
 use crate::observation::*;
@@ -38,7 +38,7 @@ impl DifferenceModel {

         self.storage.insert(
             name.to_string(),
-            Item::new(Box::new(RecursiveFitter::new(kernel))),
+            Item::new(Box::new(Recursive::new(kernel))),
         );
     }

@@ -1,6 +1,6 @@
 use std::f64;

-use crate::fitter::RecursiveFitter;
+use crate::fitter::Recursive;
 use crate::item::Item;
 use crate::kernel::Kernel;
 use crate::observation::*;
@@ -46,7 +46,7 @@ impl TernaryModel {

         self.storage.insert(
             name.to_string(),
-            Item::new(Box::new(RecursiveFitter::new(kernel))),
+            Item::new(Box::new(Recursive::new(kernel))),
        );
     }

@@ -1,9 +1,7 @@
-use std::f64::consts::{SQRT_2, TAU};
+use std::f64::consts::TAU;

 use crate::storage::Storage;
-use crate::utils::{
-    logphi, logsumexp2, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS,
-};
+use crate::utils::{logphi, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS};

 use super::{f_params, Core, Observation};

@@ -54,52 +52,6 @@ fn ll_probit_tie(x: f64, margin: f64) -> f64 {
     }
 }

-fn lambdas() -> [f64; 5] {
-    [
-        0.44 * SQRT_2,
-        0.41 * SQRT_2,
-        0.40 * SQRT_2,
-        0.39 * SQRT_2,
-        0.36 * SQRT_2,
-    ]
-}
-
-const CS: [f64; 5] = [
-    1.146480988574439e+02,
-    -1.508871030070582e+03,
-    2.676085036831241e+03,
-    -1.356294962039222e+03,
-    7.543285642111850e+01,
-];
-
-fn mm_logit_win(mean_cav: f64, cov_cav: f64) {
-    let mut arr1 = [0.0; 5];
-    let mut arr2 = [0.0; 5];
-    let mut arr3 = [0.0; 5];
-
-    for (i, x) in lambdas().iter().enumerate() {
-        let (a, b, c) = mm_probit_win(x * mean_cav, x * x * cov_cav);
-
-        arr1[i] = a;
-        arr2[i] = b;
-        arr3[i] = c;
-    }
-
-    let logpart1 = logsumexp2(arr1, CS);
-
-    /*
-    dlogpart1 = (np.dot(np.exp(arr1) * arr2, CS * LAMBDAS)
-                 / np.dot(np.exp(arr1), CS))
-    d2logpart1 = (np.dot(np.exp(arr1) * (arr2 * arr2 + arr3),
-                         CS * LAMBDAS * LAMBDAS)
-                 / np.dot(np.exp(arr1), CS)) - (dlogpart1 * dlogpart1)
-    */
-}
-
-fn mm_logit_tie(x: f64, margin: f64) {
-    //
-}
-
 fn ll_logit_win(x: f64, margin: f64) -> f64 {
     let z = x - margin;

@@ -89,7 +89,7 @@ pub fn logphi(z: f64) -> (f64, f64) {
     }
 }

-pub fn logsumexp2(xs: [f64; 5], bs: [f64; 5]) -> f64 {
+pub fn logsumexp2(xs: &[f64], bs: &[f64]) -> f64 {
     let a = xs.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();

     a + bs
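The signature change from fixed `[f64; 5]` arrays to slices makes `logsumexp2` usable outside the five-term logistic approximation removed above. The hunk cuts the body off after `a + bs`, but the general shape of a numerically stable weighted log-sum-exp, ln(Σ_i b_i·exp(x_i)), is roughly the following sketch; this is a reconstruction under that assumption, not the crate's exact code:

```rust
/// Numerically stable weighted log-sum-exp: ln(sum_i b_i * exp(x_i)).
/// Sketch only; the crate's own body is truncated in the hunk above.
pub fn logsumexp2(xs: &[f64], bs: &[f64]) -> f64 {
    // Shift by the largest exponent so exp() cannot overflow.
    let a = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let sum: f64 = xs.iter().zip(bs).map(|(x, b)| b * (x - a).exp()).sum();
    a + sum.ln()
}

fn main() {
    // With unit weights this reduces to the ordinary log-sum-exp.
    let v = logsumexp2(&[0.0, 0.0], &[1.0, 1.0]);
    assert!((v - 2.0_f64.ln()).abs() < 1e-12);
}
```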
@@ -4,7 +4,7 @@ use kickscore as ks;

 #[test]
 fn binary_1() {
-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);

     let k_audrey = ks::kernel::Matern52::new(1.0, 2.0);
     let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0);
@@ -5,7 +5,7 @@ use kickscore as ks;

 #[test]
 fn kickscore_basic() {
-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);

     let k_spike = ks::kernel::Constant::new(0.5);

@@ -58,7 +58,7 @@ fn nba_history() {

     let seconds_in_year = 365.25 * 24.0 * 60.0 * 60.0;

-    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
+    let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);

     for team in teams {
         let kernel: Vec<Box<dyn ks::Kernel>> = vec![