use std::f64; use crate::fitter::RecursiveFitter; use crate::item::Item; use crate::kernel::Kernel; use crate::observation::*; use crate::storage::Storage; pub enum BinaryModelObservation { Probit, Logit, } #[derive(Clone, Copy)] pub enum BinaryModelFitMethod { Ep, Kl, } pub struct BinaryModel { storage: Storage, last_t: f64, win_obs: BinaryModelObservation, observations: Vec>, last_method: Option, } impl BinaryModel { pub fn new(win_obs: BinaryModelObservation) -> Self { BinaryModel { storage: Storage::new(), last_t: f64::NEG_INFINITY, win_obs, observations: Vec::new(), last_method: None, } } pub fn add_item(&mut self, name: &str, kernel: Box) { if self.storage.contains_key(name) { panic!("item '{}' already added", name); } self.storage.insert( name.to_string(), Item::new(Box::new(RecursiveFitter::new(kernel))), ); } pub fn contains_item(&self, name: &str) -> bool { self.storage.contains_key(name) } pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) { let id = self.storage.get_id(name); let (ms, vs) = self.storage.item(id).fitter.predict(&[t]); (ms[0], vs[0]) } pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) { if t < self.last_t { panic!("observations must be added in chronological order"); } let mut elems = self.process_items(winners, 1.0); elems.extend(self.process_items(losers, -1.0)); let obs: Box = match self.win_obs { BinaryModelObservation::Probit => { Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0)) } BinaryModelObservation::Logit => { Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0)) } }; self.observations.push(obs); /* for (item, _) in elems { item.link_observation(obs) } */ self.last_t = t; } pub fn fit(&mut self) -> bool { // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False): let method = BinaryModelFitMethod::Ep; let lr = 1.0; let tol = 1e-3; let max_iter = 100; let verbose = true; self.last_method = Some(method); for item in self.storage.items_mut() { item.fitter.allocate(); } for i in 0..max_iter { let mut max_diff = 0.0; for obs in &mut self.observations { let diff = match method { BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage), BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage), }; if diff > max_diff { max_diff = diff; } } for item in self.storage.items_mut() { item.fitter.fit(); } if verbose { println!("iteration {}, max diff: {:.5}", i + 1, max_diff); } if max_diff < tol { return true; } } false } pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) { let mut elems = self.process_items(team_1, 1.0); elems.extend(self.process_items(team_2, -1.0)); let prob = match self.win_obs { BinaryModelObservation::Probit => probit_win_observation(&elems, t, 0.0, &self.storage), BinaryModelObservation::Logit => todo!(), }; (prob, 1.0 - prob) } fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> { items .iter() .map(|key| (self.storage.get_id(&key), sign)) .collect() } }