Files
kickscore/src/model/binary.rs
2020-03-06 09:49:55 +01:00

153 lines
4.0 KiB
Rust

use std::f64;
use crate::fitter::RecursiveFitter;
use crate::item::Item;
use crate::kernel::Kernel;
use crate::observation::*;
use crate::storage::Storage;
/// Observation model used to turn a match outcome into a likelihood term.
///
/// `BinaryModel::observe` builds a `ProbitWinObservation` for `Probit` and a
/// `LogitWinObservation` for `Logit`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinaryModelObservation {
    /// Use `ProbitWinObservation` for outcomes.
    Probit,
    /// Use `LogitWinObservation` for outcomes.
    Logit,
}
/// Update rule applied to every observation during `BinaryModel::fit`.
///
/// `Ep` calls `Observation::ep_update` (expectation propagation); `Kl` calls
/// `Observation::kl_update` (presumably KL-divergence minimisation — confirm against
/// the observation module).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinaryModelFitMethod {
    /// Update observations with `ep_update`.
    Ep,
    /// Update observations with `kl_update`.
    Kl,
}
/// Dynamic rating model for binary (win/loss) outcomes between teams of items.
///
/// Items are registered with `add_item`, outcomes are recorded with `observe`, and the
/// item score processes are inferred by `fit`.
pub struct BinaryModel {
/// Registered items, inserted under their name by `add_item`.
storage: Storage,
/// Time of the most recent observation; starts at negative infinity and is used by
/// `observe` to reject out-of-order observations.
last_t: f64,
/// Observation model chosen at construction; selects which observation type `observe` builds.
win_obs: BinaryModelObservation,
/// Every observation recorded by `observe`, in insertion order; each is updated on every
/// iteration of `fit`.
observations: Vec<Box<dyn Observation>>,
/// Update rule used by the most recent call to `fit`; `None` until `fit` has run.
last_method: Option<BinaryModelFitMethod>,
}
impl BinaryModel {
pub fn new(win_obs: BinaryModelObservation) -> Self {
BinaryModel {
storage: Storage::new(),
last_t: f64::NEG_INFINITY,
win_obs,
observations: Vec::new(),
last_method: None,
}
}
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
if self.storage.contains_key(name) {
panic!("item '{}' already added", name);
}
self.storage.insert(
name.to_string(),
Item::new(Box::new(RecursiveFitter::new(kernel))),
);
}
pub fn contains_item(&self, name: &str) -> bool {
self.storage.contains_key(name)
}
pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
let id = self.storage.get_id(name);
let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
(ms[0], vs[0])
}
pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
if t < self.last_t {
panic!("observations must be added in chronological order");
}
let mut elems = self.process_items(winners, 1.0);
elems.extend(self.process_items(losers, -1.0));
let obs: Box<dyn Observation> = match self.win_obs {
BinaryModelObservation::Probit => {
Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
}
BinaryModelObservation::Logit => {
Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
}
};
self.observations.push(obs);
/*
for (item, _) in elems {
item.link_observation(obs)
}
*/
self.last_t = t;
}
pub fn fit(&mut self) -> bool {
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
let method = BinaryModelFitMethod::Ep;
let lr = 1.0;
let tol = 1e-3;
let max_iter = 100;
let verbose = true;
self.last_method = Some(method);
for item in self.storage.items_mut() {
item.fitter.allocate();
}
for i in 0..max_iter {
let mut max_diff = 0.0;
for obs in &mut self.observations {
let diff = match method {
BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
};
if diff > max_diff {
max_diff = diff;
}
}
for item in self.storage.items_mut() {
item.fitter.fit();
}
if verbose {
println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
}
if max_diff < tol {
return true;
}
}
false
}
pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
let mut elems = self.process_items(team_1, 1.0);
elems.extend(self.process_items(team_2, -1.0));
let prob = match self.win_obs {
BinaryModelObservation::Probit => probit_win_observation(&elems, t, 0.0, &self.storage),
BinaryModelObservation::Logit => todo!(),
};
(prob, 1.0 - prob)
}
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
items
.iter()
.map(|key| (self.storage.get_id(&key), sign))
.collect()
}
}