153 lines
4.0 KiB
Rust
153 lines
4.0 KiB
Rust
use std::f64;
|
|
|
|
use crate::fitter::RecursiveFitter;
|
|
use crate::item::Item;
|
|
use crate::kernel::Kernel;
|
|
use crate::observation::*;
|
|
use crate::storage::Storage;
|
|
|
|
/// Selects which win-observation model a [`BinaryModel`] uses for pairwise
/// outcomes.
///
/// The choice is made once, when the model is constructed, and determines
/// which observation type is created for every recorded comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryModelObservation {
    /// Record comparisons as `ProbitWinObservation`s.
    Probit,
    /// Record comparisons as `LogitWinObservation`s.
    Logit,
}
|
|
|
|
/// Selects which per-observation update rule is applied while fitting a
/// [`BinaryModel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryModelFitMethod {
    /// Apply each observation's `ep_update`
    /// (presumably expectation propagation — confirm against `Observation`).
    Ep,
    /// Apply each observation's `kl_update`
    /// (presumably KL-divergence minimisation — confirm against `Observation`).
    Kl,
}
|
|
|
|
pub struct BinaryModel {
|
|
storage: Storage,
|
|
last_t: f64,
|
|
win_obs: BinaryModelObservation,
|
|
observations: Vec<Box<dyn Observation>>,
|
|
last_method: Option<BinaryModelFitMethod>,
|
|
}
|
|
|
|
impl BinaryModel {
|
|
pub fn new(win_obs: BinaryModelObservation) -> Self {
|
|
BinaryModel {
|
|
storage: Storage::new(),
|
|
last_t: f64::NEG_INFINITY,
|
|
win_obs,
|
|
observations: Vec::new(),
|
|
last_method: None,
|
|
}
|
|
}
|
|
|
|
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
|
|
if self.storage.contains_key(name) {
|
|
panic!("item '{}' already added", name);
|
|
}
|
|
|
|
self.storage.insert(
|
|
name.to_string(),
|
|
Item::new(Box::new(RecursiveFitter::new(kernel))),
|
|
);
|
|
}
|
|
|
|
pub fn contains_item(&self, name: &str) -> bool {
|
|
self.storage.contains_key(name)
|
|
}
|
|
|
|
pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
|
|
let id = self.storage.get_id(name);
|
|
let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
|
|
|
|
(ms[0], vs[0])
|
|
}
|
|
|
|
pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
|
|
if t < self.last_t {
|
|
panic!("observations must be added in chronological order");
|
|
}
|
|
|
|
let mut elems = self.process_items(winners, 1.0);
|
|
elems.extend(self.process_items(losers, -1.0));
|
|
|
|
let obs: Box<dyn Observation> = match self.win_obs {
|
|
BinaryModelObservation::Probit => {
|
|
Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
|
|
}
|
|
BinaryModelObservation::Logit => {
|
|
Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
|
|
}
|
|
};
|
|
|
|
self.observations.push(obs);
|
|
|
|
/*
|
|
for (item, _) in elems {
|
|
item.link_observation(obs)
|
|
}
|
|
*/
|
|
|
|
self.last_t = t;
|
|
}
|
|
|
|
pub fn fit(&mut self) -> bool {
|
|
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
|
|
|
|
let method = BinaryModelFitMethod::Ep;
|
|
let lr = 1.0;
|
|
let tol = 1e-3;
|
|
let max_iter = 100;
|
|
let verbose = true;
|
|
|
|
self.last_method = Some(method);
|
|
|
|
for item in self.storage.items_mut() {
|
|
item.fitter.allocate();
|
|
}
|
|
|
|
for i in 0..max_iter {
|
|
let mut max_diff = 0.0;
|
|
|
|
for obs in &mut self.observations {
|
|
let diff = match method {
|
|
BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
|
|
BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
|
|
};
|
|
|
|
if diff > max_diff {
|
|
max_diff = diff;
|
|
}
|
|
}
|
|
|
|
for item in self.storage.items_mut() {
|
|
item.fitter.fit();
|
|
}
|
|
|
|
if verbose {
|
|
println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
|
|
}
|
|
|
|
if max_diff < tol {
|
|
return true;
|
|
}
|
|
}
|
|
|
|
false
|
|
}
|
|
|
|
pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
|
|
let mut elems = self.process_items(team_1, 1.0);
|
|
elems.extend(self.process_items(team_2, -1.0));
|
|
|
|
let prob = match self.win_obs {
|
|
BinaryModelObservation::Probit => probit_win_observation(&elems, t, 0.0, &self.storage),
|
|
BinaryModelObservation::Logit => todo!(),
|
|
};
|
|
|
|
(prob, 1.0 - prob)
|
|
}
|
|
|
|
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
|
|
items
|
|
.iter()
|
|
.map(|key| (self.storage.get_id(&key), sign))
|
|
.collect()
|
|
}
|
|
}
|