Added TernaryModel.
This commit is contained in:
152
src/model/binary.rs
Normal file
152
src/model/binary.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use std::f64;
|
||||
|
||||
use crate::fitter::RecursiveFitter;
|
||||
use crate::item::Item;
|
||||
use crate::kernel::Kernel;
|
||||
use crate::observation::*;
|
||||
use crate::storage::Storage;
|
||||
|
||||
pub enum BinaryModelObservation {
|
||||
Probit,
|
||||
Logit,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum BinaryModelFitMethod {
|
||||
Ep,
|
||||
Kl,
|
||||
}
|
||||
|
||||
pub struct BinaryModel {
|
||||
storage: Storage,
|
||||
last_t: f64,
|
||||
win_obs: BinaryModelObservation,
|
||||
observations: Vec<Box<dyn Observation>>,
|
||||
last_method: Option<BinaryModelFitMethod>,
|
||||
}
|
||||
|
||||
impl BinaryModel {
|
||||
pub fn new(win_obs: BinaryModelObservation) -> Self {
|
||||
BinaryModel {
|
||||
storage: Storage::new(),
|
||||
last_t: f64::NEG_INFINITY,
|
||||
win_obs,
|
||||
observations: Vec::new(),
|
||||
last_method: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
|
||||
if self.storage.contains_key(name) {
|
||||
panic!("item '{}' already added", name);
|
||||
}
|
||||
|
||||
self.storage.insert(
|
||||
name.to_string(),
|
||||
Item::new(Box::new(RecursiveFitter::new(kernel))),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn contains_item(&self, name: &str) -> bool {
|
||||
self.storage.contains_key(name)
|
||||
}
|
||||
|
||||
pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
|
||||
let id = self.storage.get_id(name);
|
||||
let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
|
||||
|
||||
(ms[0], vs[0])
|
||||
}
|
||||
|
||||
pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
|
||||
if t < self.last_t {
|
||||
panic!("observations must be added in chronological order");
|
||||
}
|
||||
|
||||
let mut elems = self.process_items(winners, 1.0);
|
||||
elems.extend(self.process_items(losers, -1.0));
|
||||
|
||||
let obs: Box<dyn Observation> = match self.win_obs {
|
||||
BinaryModelObservation::Probit => {
|
||||
Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
|
||||
}
|
||||
BinaryModelObservation::Logit => {
|
||||
Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
|
||||
}
|
||||
};
|
||||
|
||||
self.observations.push(obs);
|
||||
|
||||
/*
|
||||
for (item, _) in elems {
|
||||
item.link_observation(obs)
|
||||
}
|
||||
*/
|
||||
|
||||
self.last_t = t;
|
||||
}
|
||||
|
||||
pub fn fit(&mut self) -> bool {
|
||||
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
|
||||
|
||||
let method = BinaryModelFitMethod::Ep;
|
||||
let lr = 1.0;
|
||||
let tol = 1e-3;
|
||||
let max_iter = 100;
|
||||
let verbose = true;
|
||||
|
||||
self.last_method = Some(method);
|
||||
|
||||
for item in self.storage.items_mut() {
|
||||
item.fitter.allocate();
|
||||
}
|
||||
|
||||
for i in 0..max_iter {
|
||||
let mut max_diff = 0.0;
|
||||
|
||||
for obs in &mut self.observations {
|
||||
let diff = match method {
|
||||
BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
|
||||
BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
|
||||
};
|
||||
|
||||
if diff > max_diff {
|
||||
max_diff = diff;
|
||||
}
|
||||
}
|
||||
|
||||
for item in self.storage.items_mut() {
|
||||
item.fitter.fit();
|
||||
}
|
||||
|
||||
if verbose {
|
||||
println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
|
||||
}
|
||||
|
||||
if max_diff < tol {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
|
||||
let mut elems = self.process_items(team_1, 1.0);
|
||||
elems.extend(self.process_items(team_2, -1.0));
|
||||
|
||||
let prob = match self.win_obs {
|
||||
BinaryModelObservation::Probit => probit_win_observation(&elems, t, 0.0, &self.storage),
|
||||
BinaryModelObservation::Logit => todo!(),
|
||||
};
|
||||
|
||||
(prob, 1.0 - prob)
|
||||
}
|
||||
|
||||
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
|
||||
items
|
||||
.iter()
|
||||
.map(|key| (self.storage.get_id(&key), sign))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user