From 67d1412af894c58b77a031984f83c4cef71ca034 Mon Sep 17 00:00:00 2001
From: Anders Olsson
Date: Fri, 6 Mar 2020 09:49:55 +0100
Subject: [PATCH] Added TernaryModel.

---
NOTE(review): this copy of the patch was received with its line breaks
collapsed and with angle-bracketed text removed (the author e-mail on the
From: line, and generic arguments such as `Box<dyn Kernel>`, `Vec<f64>`,
`Option<f64>`). The line structure below is rebuilt so that every hunk
matches its header and the diffstat. The generic arguments are inferred
from how the values are used, and the indentation (in particular of the
context lines in src/fitter/recursive.rs) is a best guess. Confirm all of
this against upstream commit 67d1412 before applying. The author e-mail
could not be recovered and must be re-added for `git am` to accept this.

 src/fitter/recursive.rs    |  17 ++--
 src/model.rs               | 162 +------------------------------
 src/model/binary.rs        | 152 +++++++++++++++++++++++++++++
 src/model/ternary.rs       | 193 +++++++++++++++++++++++++++++++++++++
 src/observation/ordinal.rs |  73 +++++++++++++-
 5 files changed, 426 insertions(+), 171 deletions(-)
 create mode 100644 src/model/binary.rs
 create mode 100644 src/model/ternary.rs

diff --git a/src/fitter/recursive.rs b/src/fitter/recursive.rs
index 2ec9cb5..98f6b11 100644
--- a/src/fitter/recursive.rs
+++ b/src/fitter/recursive.rs
@@ -13,10 +13,10 @@ pub struct RecursiveFitter {
     #[derivative(Debug = "ignore")]
     kernel: Box<dyn Kernel>,
     ts: Vec<f64>,
-    ms: ArrayD<f64>,
-    vs: Array1<f64>,
-    ns: ArrayD<f64>,
-    xs: ArrayD<f64>,
+    ms: ArrayD<f64>, // TODO Replace with a vec
+    vs: Array1<f64>, // TODO Replace with a vec
+    ns: ArrayD<f64>, // TODO Replace with a vec
+    xs: ArrayD<f64>, // TODO Replace with a vec
     is_fitted: bool,
     h: Array1<f64>,
     i: Array2<f64>,
@@ -181,16 +181,10 @@ impl Fitter for RecursiveFitter {
             } else {
                 let a = self.p_p[i + 1].clone();
                 let b = self.a[i].dot(&self.p_f[i]);
-                // println!("a={:#?}", a);
+
                 let g = crate::linalg::solve(a, b);
                 let g = g.t();
 
-                /*
-                let g = self.a[i]
-                    .dot(&self.p_f[i])
-                    .dot(&self.p_p[i + 1].inv().expect("failed to inverse matrix"));
-                */
-
                 self.m_s[i] = &self.m_f[i] + &g.dot(&(&self.m_s[i + 1] - &self.m_p[i + 1]));
                 self.p_s[i] =
                     &self.p_f[i] + &g.dot(&(&self.p_s[i + 1] - &self.p_p[i + 1])).dot(&g.t());
@@ -257,6 +251,7 @@ impl Fitter for RecursiveFitter {
                 let a = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]);
                 let a = a.dot(&p);
                 let b = self.p_p[(j + 1) as usize].clone();
+
                 let g = crate::linalg::solve(a, b);
                 let g = g.t();
 
diff --git a/src/model.rs b/src/model.rs
index c4cee97..f3b0869 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -1,159 +1,5 @@
-use std::f64;
+mod binary;
+mod ternary;
 
-use crate::fitter::RecursiveFitter;
-use crate::item::Item;
-use crate::kernel::Kernel;
-use crate::observation::*;
-use crate::storage::Storage;
-
-pub enum BinaryModelObservation {
-    Probit,
-    Logit,
-}
-
-#[derive(Clone, Copy)]
-pub enum BinaryModelFitMethod {
-    Ep,
-    Kl,
-}
-
-pub struct BinaryModel {
-    storage: Storage,
-    last_t: f64,
-    win_obs: BinaryModelObservation,
-    observations: Vec<Box<dyn Observation>>,
-    last_method: Option<BinaryModelFitMethod>,
-}
-
-impl BinaryModel {
-    pub fn new(win_obs: BinaryModelObservation) -> Self {
-        BinaryModel {
-            storage: Storage::new(),
-            last_t: f64::NEG_INFINITY,
-            win_obs,
-            observations: Vec::new(),
-            last_method: None,
-        }
-    }
-
-    pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
-        if self.storage.contains_key(name) {
-            panic!("item '{}' already added", name);
-        }
-
-        self.storage.insert(
-            name.to_string(),
-            Item::new(Box::new(RecursiveFitter::new(kernel))),
-        );
-    }
-
-    pub fn contains_item(&self, name: &str) -> bool {
-        self.storage.contains_key(name)
-    }
-
-    pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
-        let id = self.storage.get_id(name);
-        let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
-
-        (ms[0], vs[0])
-    }
-
-    pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
-        if t < self.last_t {
-            panic!("observations must be added in chronological order");
-        }
-
-        let mut elems = self.process_items(winners, 1.0);
-        elems.extend(self.process_items(losers, -1.0));
-
-        let obs: Box<dyn Observation> = match self.win_obs {
-            BinaryModelObservation::Probit => {
-                Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
-            }
-            BinaryModelObservation::Logit => {
-                Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
-            }
-        };
-
-        self.observations.push(obs);
-
-        /*
-        for (item, _) in elems {
-            item.link_observation(obs)
-        }
-        */
-
-        self.last_t = t;
-    }
-
-    pub fn fit(&mut self) -> bool {
-        // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
-
-        let method = BinaryModelFitMethod::Ep;
-        let lr = 1.0;
-        let tol = 1e-3;
-        let max_iter = 100;
-        let verbose = true;
-
-        self.last_method = Some(method);
-
-        for item in self.storage.items_mut() {
-            item.fitter.allocate();
-        }
-
-        for i in 0..max_iter {
-            let mut max_diff = 0.0;
-
-            for obs in &mut self.observations {
-                let diff = match method {
-                    BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
-                    BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
-                };
-
-                if diff > max_diff {
-                    max_diff = diff;
-                }
-            }
-
-            for item in self.storage.items_mut() {
-                item.fitter.fit();
-            }
-
-            if verbose {
-                println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
-            }
-
-            if max_diff < tol {
-                return true;
-            }
-        }
-
-        false
-    }
-
-    pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
-        let mut elems = self.process_items(team_1, 1.0);
-        elems.extend(self.process_items(team_2, -1.0));
-
-        let prob = match self.win_obs {
-            BinaryModelObservation::Probit => {
-                let margin = 0.0;
-
-                let (m, v) = f_params(&elems, t, &self.storage);
-                let (logpart, _, _) = mm_probit_win(m - margin, v);
-
-                logpart.exp()
-            }
-            BinaryModelObservation::Logit => todo!(),
-        };
-
-        (prob, 1.0 - prob)
-    }
-
-    fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
-        items
-            .iter()
-            .map(|key| (self.storage.get_id(&key), sign))
-            .collect()
-    }
-}
+pub use binary::*;
+pub use ternary::*;
diff --git a/src/model/binary.rs b/src/model/binary.rs
new file mode 100644
index 0000000..c0b9fbd
--- /dev/null
+++ b/src/model/binary.rs
@@ -0,0 +1,152 @@
+use std::f64;
+
+use crate::fitter::RecursiveFitter;
+use crate::item::Item;
+use crate::kernel::Kernel;
+use crate::observation::*;
+use crate::storage::Storage;
+
+pub enum BinaryModelObservation {
+    Probit,
+    Logit,
+}
+
+#[derive(Clone, Copy)]
+pub enum BinaryModelFitMethod {
+    Ep,
+    Kl,
+}
+
+pub struct BinaryModel {
+    storage: Storage,
+    last_t: f64,
+    win_obs: BinaryModelObservation,
+    observations: Vec<Box<dyn Observation>>,
+    last_method: Option<BinaryModelFitMethod>,
+}
+
+impl BinaryModel {
+    pub fn new(win_obs: BinaryModelObservation) -> Self {
+        BinaryModel {
+            storage: Storage::new(),
+            last_t: f64::NEG_INFINITY,
+            win_obs,
+            observations: Vec::new(),
+            last_method: None,
+        }
+    }
+
+    pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
+        if self.storage.contains_key(name) {
+            panic!("item '{}' already added", name);
+        }
+
+        self.storage.insert(
+            name.to_string(),
+            Item::new(Box::new(RecursiveFitter::new(kernel))),
+        );
+    }
+
+    pub fn contains_item(&self, name: &str) -> bool {
+        self.storage.contains_key(name)
+    }
+
+    pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
+        let id = self.storage.get_id(name);
+        let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
+
+        (ms[0], vs[0])
+    }
+
+    pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
+        if t < self.last_t {
+            panic!("observations must be added in chronological order");
+        }
+
+        let mut elems = self.process_items(winners, 1.0);
+        elems.extend(self.process_items(losers, -1.0));
+
+        let obs: Box<dyn Observation> = match self.win_obs {
+            BinaryModelObservation::Probit => {
+                Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
+            }
+            BinaryModelObservation::Logit => {
+                Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
+            }
+        };
+
+        self.observations.push(obs);
+
+        /*
+        for (item, _) in elems {
+            item.link_observation(obs)
+        }
+        */
+
+        self.last_t = t;
+    }
+
+    pub fn fit(&mut self) -> bool {
+        // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
+
+        let method = BinaryModelFitMethod::Ep;
+        let lr = 1.0;
+        let tol = 1e-3;
+        let max_iter = 100;
+        let verbose = true;
+
+        self.last_method = Some(method);
+
+        for item in self.storage.items_mut() {
+            item.fitter.allocate();
+        }
+
+        for i in 0..max_iter {
+            let mut max_diff = 0.0;
+
+            for obs in &mut self.observations {
+                let diff = match method {
+                    BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
+                    BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
+                };
+
+                if diff > max_diff {
+                    max_diff = diff;
+                }
+            }
+
+            for item in self.storage.items_mut() {
+                item.fitter.fit();
+            }
+
+            if verbose {
+                println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
+            }
+
+            if max_diff < tol {
+                return true;
+            }
+        }
+
+        false
+    }
+
+    pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
+        let mut elems = self.process_items(team_1, 1.0);
+        elems.extend(self.process_items(team_2, -1.0));
+
+        let prob = match self.win_obs {
+            BinaryModelObservation::Probit => probit_win_observation(&elems, t, 0.0, &self.storage),
+            BinaryModelObservation::Logit => todo!(),
+        };
+
+        (prob, 1.0 - prob)
+    }
+
+    fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
+        items
+            .iter()
+            .map(|key| (self.storage.get_id(&key), sign))
+            .collect()
+    }
+}
diff --git a/src/model/ternary.rs b/src/model/ternary.rs
new file mode 100644
index 0000000..c471485
--- /dev/null
+++ b/src/model/ternary.rs
@@ -0,0 +1,193 @@
+use std::f64;
+
+use crate::fitter::RecursiveFitter;
+use crate::item::Item;
+use crate::kernel::Kernel;
+use crate::observation::*;
+use crate::storage::Storage;
+
+#[derive(Clone, Copy)]
+pub enum TernaryModelObservation {
+    Probit,
+    Logit,
+}
+
+#[derive(Clone, Copy)]
+pub enum TernaryModelFitMethod {
+    Ep,
+    Kl,
+}
+
+pub struct TernaryModel {
+    storage: Storage,
+    last_t: f64,
+    obs: TernaryModelObservation,
+    observations: Vec<Box<dyn Observation>>,
+    last_method: Option<TernaryModelFitMethod>,
+    margin: f64,
+}
+
+impl TernaryModel {
+    pub fn new(obs: TernaryModelObservation, margin: f64) -> Self {
+        TernaryModel {
+            storage: Storage::new(),
+            last_t: f64::NEG_INFINITY,
+            obs, // default = probit
+            observations: Vec::new(),
+            last_method: None,
+            margin, // default = 0.1
+        }
+    }
+
+    pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
+        if self.storage.contains_key(name) {
+            panic!("item '{}' already added", name);
+        }
+
+        self.storage.insert(
+            name.to_string(),
+            Item::new(Box::new(RecursiveFitter::new(kernel))),
+        );
+    }
+
+    pub fn contains_item(&self, name: &str) -> bool {
+        self.storage.contains_key(name)
+    }
+
+    pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
+        let id = self.storage.get_id(name);
+        let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
+
+        (ms[0], vs[0])
+    }
+
+    pub fn observe(
+        &mut self,
+        winners: &[&str],
+        losers: &[&str],
+        t: f64,
+        tie: bool,
+        margin: Option<f64>,
+    ) {
+        if t < self.last_t {
+            panic!("observations must be added in chronological order");
+        }
+
+        let margin = margin.unwrap_or_else(|| self.margin);
+
+        let mut elems = self.process_items(winners, 1.0);
+        elems.extend(self.process_items(losers, -1.0));
+
+        let obs: Box<dyn Observation> = match (tie, self.obs) {
+            (false, TernaryModelObservation::Probit) => Box::new(ProbitWinObservation::new(
+                &mut self.storage,
+                &elems,
+                t,
+                margin,
+            )),
+            (false, TernaryModelObservation::Logit) => Box::new(LogitWinObservation::new(
+                &mut self.storage,
+                &elems,
+                t,
+                margin,
+            )),
+            (true, TernaryModelObservation::Probit) => Box::new(ProbitTieObservation::new(
+                &mut self.storage,
+                &elems,
+                t,
+                margin,
+            )),
+            (true, TernaryModelObservation::Logit) => Box::new(LogitTieObservation::new(
+                &mut self.storage,
+                &elems,
+                t,
+                margin,
+            )),
+        };
+
+        self.observations.push(obs);
+
+        self.last_t = t;
+    }
+
+    pub fn fit(&mut self) -> bool {
+        // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
+
+        let method = TernaryModelFitMethod::Ep;
+        let lr = 1.0;
+        let tol = 1e-3;
+        let max_iter = 100;
+        let verbose = true;
+
+        self.last_method = Some(method);
+
+        for item in self.storage.items_mut() {
+            item.fitter.allocate();
+        }
+
+        for i in 0..max_iter {
+            let mut max_diff = 0.0;
+
+            for obs in &mut self.observations {
+                let diff = match method {
+                    TernaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
+                    TernaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
+                };
+
+                if diff > max_diff {
+                    max_diff = diff;
+                }
+            }
+
+            for item in self.storage.items_mut() {
+                item.fitter.fit();
+            }
+
+            if verbose {
+                println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
+            }
+
+            if max_diff < tol {
+                return true;
+            }
+        }
+
+        false
+    }
+
+    pub fn probabilities(
+        &mut self,
+        team_1: &[&str],
+        team_2: &[&str],
+        t: f64,
+        margin: Option<f64>,
+    ) -> (f64, f64, f64) {
+        let margin = margin.unwrap_or_else(|| self.margin);
+
+        let mut elems = self.process_items(team_1, 1.0);
+        elems.extend(self.process_items(team_2, -1.0));
+
+        let prob_1 = match self.obs {
+            TernaryModelObservation::Probit => {
+                probit_win_observation(&elems, t, margin, &self.storage)
+            }
+            TernaryModelObservation::Logit => unimplemented!(),
+        };
+
+        let prob_2 = match self.obs {
+            TernaryModelObservation::Probit => {
+                probit_tie_observation(&elems, t, margin, &self.storage)
+            }
+            TernaryModelObservation::Logit => unimplemented!(),
+        };
+
+        (prob_1, prob_2, 1.0 - prob_1 - prob_2)
+    }
+
+    fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
+        items
+            .iter()
+            .map(|key| (self.storage.get_id(&key), sign))
+            .collect()
+    }
+}
diff --git a/src/observation/ordinal.rs b/src/observation/ordinal.rs
index dddc2bb..d79fc97 100644
--- a/src/observation/ordinal.rs
+++ b/src/observation/ordinal.rs
@@ -1,9 +1,21 @@
 use crate::storage::Storage;
 use crate::utils::logphi;
 
-use super::Observation;
+use super::{f_params, Observation};
 
-pub fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
+pub fn probit_win_observation(
+    elems: &[(usize, f64)],
+    t: f64,
+    margin: f64,
+    storage: &Storage,
+) -> f64 {
+    let (m, v) = f_params(&elems, t, &storage);
+    let (logpart, _, _) = mm_probit_win(m - margin, v);
+
+    logpart.exp()
+}
+
+fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
     // Adapted from the GPML function `likErf.m`.
     let z = mean_cav / (1.0 + cov_cav).sqrt();
     let (logpart, val) = logphi(z);
@@ -135,3 +147,60 @@ impl Observation for LogitWinObservation {
         todo!();
     }
 }
+
+pub fn probit_tie_observation(
+    elems: &[(usize, f64)],
+    t: f64,
+    margin: f64,
+    storage: &Storage,
+) -> f64 {
+    unimplemented!();
+}
+
+pub struct ProbitTieObservation {
+    //
+}
+
+impl ProbitTieObservation {
+    pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
+        todo!();
+    }
+}
+
+impl Observation for ProbitTieObservation {
+    fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
+        todo!();
+    }
+
+    fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
+        todo!();
+    }
+
+    fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
+        todo!();
+    }
+}
+
+pub struct LogitTieObservation {
+    //
+}
+
+impl LogitTieObservation {
+    pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
+        todo!();
+    }
+}
+
+impl Observation for LogitTieObservation {
+    fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
+        todo!();
+    }
+
+    fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
+        todo!();
+    }
+
+    fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
+        todo!();
+    }
+}