Added TernaryModel.

This commit is contained in:
2020-03-06 09:49:55 +01:00
parent 8fe57b4649
commit 67d1412af8
5 changed files with 426 additions and 171 deletions

View File

@@ -13,10 +13,10 @@ pub struct RecursiveFitter {
#[derivative(Debug = "ignore")]
kernel: Box<dyn Kernel>,
ts: Vec<f64>,
ms: ArrayD<f64>,
vs: Array1<f64>,
ns: ArrayD<f64>,
xs: ArrayD<f64>,
ms: ArrayD<f64>, // TODO Replace with a vec
vs: Array1<f64>, // TODO Replace with a vec
ns: ArrayD<f64>, // TODO Replace with a vec
xs: ArrayD<f64>, // TODO Replace with a vec
is_fitted: bool,
h: Array1<f64>,
i: Array2<f64>,
@@ -181,16 +181,10 @@ impl Fitter for RecursiveFitter {
} else {
let a = self.p_p[i + 1].clone();
let b = self.a[i].dot(&self.p_f[i]);
// println!("a={:#?}", a);
let g = crate::linalg::solve(a, b);
let g = g.t();
/*
let g = self.a[i]
.dot(&self.p_f[i])
.dot(&self.p_p[i + 1].inv().expect("failed to inverse matrix"));
*/
self.m_s[i] = &self.m_f[i] + &g.dot(&(&self.m_s[i + 1] - &self.m_p[i + 1]));
self.p_s[i] =
&self.p_f[i] + &g.dot(&(&self.p_s[i + 1] - &self.p_p[i + 1])).dot(&g.t());
@@ -257,6 +251,7 @@ impl Fitter for RecursiveFitter {
let a = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]);
let a = a.dot(&p);
let b = self.p_p[(j + 1) as usize].clone();
let g = crate::linalg::solve(a, b);
let g = g.t();

View File

@@ -1,159 +1,5 @@
use std::f64;
mod binary;
mod ternary;
use crate::fitter::RecursiveFitter;
use crate::item::Item;
use crate::kernel::Kernel;
use crate::observation::*;
use crate::storage::Storage;
/// Likelihood family used for win observations.
pub enum BinaryModelObservation {
    /// Gaussian-CDF (probit) link.
    Probit,
    /// Logistic (logit) link.
    Logit,
}
/// Inference scheme used by `BinaryModel::fit`.
#[derive(Clone, Copy)]
pub enum BinaryModelFitMethod {
    /// Expectation propagation (drives `ep_update`).
    Ep,
    /// KL-based updates (drives `kl_update`).
    Kl,
}
/// Dynamic pairwise-comparison model over binary (win/lose) outcomes.
pub struct BinaryModel {
    // Maps item names to per-item fitters.
    storage: Storage,
    // Timestamp of the most recent observation; enforces chronological order.
    last_t: f64,
    // Likelihood family used for win observations.
    win_obs: BinaryModelObservation,
    // All observations recorded so far, in chronological order.
    observations: Vec<Box<dyn Observation>>,
    // Method used on the last call to `fit`, if any.
    last_method: Option<BinaryModelFitMethod>,
}
impl BinaryModel {
    /// Creates an empty model using `win_obs` as the win likelihood.
    pub fn new(win_obs: BinaryModelObservation) -> Self {
        BinaryModel {
            storage: Storage::new(),
            // Nothing observed yet: any finite timestamp is accepted first.
            last_t: f64::NEG_INFINITY,
            win_obs,
            observations: Vec::new(),
            last_method: None,
        }
    }

    /// Registers a new item under `name` with the given dynamics kernel.
    ///
    /// # Panics
    /// Panics if an item with the same name was already added.
    pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
        if self.storage.contains_key(name) {
            panic!("item '{}' already added", name);
        }
        self.storage.insert(
            name.to_string(),
            Item::new(Box::new(RecursiveFitter::new(kernel))),
        );
    }

    /// Returns `true` if an item named `name` has been added.
    pub fn contains_item(&self, name: &str) -> bool {
        self.storage.contains_key(name)
    }

    /// Returns the predicted (mean, variance) of `name`'s score at time `t`.
    pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
        let id = self.storage.get_id(name);
        let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
        (ms[0], vs[0])
    }

    /// Records that `winners` beat `losers` at time `t`.
    ///
    /// # Panics
    /// Panics if `t` precedes the previously observed timestamp.
    pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
        if t < self.last_t {
            panic!("observations must be added in chronological order");
        }
        // Winners contribute with sign +1.0, losers with sign -1.0.
        let mut elems = self.process_items(winners, 1.0);
        elems.extend(self.process_items(losers, -1.0));
        let obs: Box<dyn Observation> = match self.win_obs {
            BinaryModelObservation::Probit => {
                Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
            }
            BinaryModelObservation::Logit => {
                Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
            }
        };
        self.observations.push(obs);
        /*
        for (item, _) in elems {
            item.link_observation(obs)
        }
        */
        self.last_t = t;
    }

    /// Runs iterative inference sweeps over all observations.
    /// Returns `true` if the largest single-observation change dropped below
    /// the tolerance within `max_iter` sweeps, `false` otherwise.
    pub fn fit(&mut self) -> bool {
        // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
        // NOTE(review): hyper-parameters are hard-coded (mirroring the
        // Python-style defaults quoted above) — except `verbose`, which is
        // `true` here.
        let method = BinaryModelFitMethod::Ep;
        let lr = 1.0;
        let tol = 1e-3;
        let max_iter = 100;
        let verbose = true;
        self.last_method = Some(method);
        for item in self.storage.items_mut() {
            item.fitter.allocate();
        }
        for i in 0..max_iter {
            // Largest per-observation update seen in this sweep.
            let mut max_diff = 0.0;
            for obs in &mut self.observations {
                let diff = match method {
                    BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
                    BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
                };
                if diff > max_diff {
                    max_diff = diff;
                }
            }
            // Re-fit every item against the updated pseudo-observations.
            for item in self.storage.items_mut() {
                item.fitter.fit();
            }
            if verbose {
                println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
            }
            if max_diff < tol {
                return true;
            }
        }
        false
    }

    /// Returns `(P(team_1 wins), P(team_2 wins))` at time `t`.
    pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
        let mut elems = self.process_items(team_1, 1.0);
        elems.extend(self.process_items(team_2, -1.0));
        let prob = match self.win_obs {
            BinaryModelObservation::Probit => {
                let margin = 0.0;
                let (m, v) = f_params(&elems, t, &self.storage);
                // The log-partition value is the log-probability of the win.
                let (logpart, _, _) = mm_probit_win(m - margin, v);
                logpart.exp()
            }
            BinaryModelObservation::Logit => todo!(),
        };
        (prob, 1.0 - prob)
    }

    /// Maps item names to `(id, sign)` pairs.
    fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
        items
            .iter()
            .map(|key| (self.storage.get_id(&key), sign))
            .collect()
    }
}
pub use binary::*;
pub use ternary::*;

152
src/model/binary.rs Normal file
View File

@@ -0,0 +1,152 @@
use std::f64;
use crate::fitter::RecursiveFitter;
use crate::item::Item;
use crate::kernel::Kernel;
use crate::observation::*;
use crate::storage::Storage;
/// Likelihood family used for win observations.
pub enum BinaryModelObservation {
    /// Gaussian-CDF (probit) link.
    Probit,
    /// Logistic (logit) link.
    Logit,
}
/// Inference scheme used by `BinaryModel::fit`.
#[derive(Clone, Copy)]
pub enum BinaryModelFitMethod {
    /// Expectation propagation (drives `ep_update`).
    Ep,
    /// KL-based updates (drives `kl_update`).
    Kl,
}
/// Dynamic pairwise-comparison model over binary (win/lose) outcomes.
pub struct BinaryModel {
    // Maps item names to per-item fitters.
    storage: Storage,
    // Timestamp of the most recent observation; enforces chronological order.
    last_t: f64,
    // Likelihood family used for win observations.
    win_obs: BinaryModelObservation,
    // All observations recorded so far, in chronological order.
    observations: Vec<Box<dyn Observation>>,
    // Method used on the last call to `fit`, if any.
    last_method: Option<BinaryModelFitMethod>,
}
impl BinaryModel {
    /// Creates an empty model using `win_obs` as the win likelihood.
    pub fn new(win_obs: BinaryModelObservation) -> Self {
        BinaryModel {
            storage: Storage::new(),
            last_t: f64::NEG_INFINITY, // nothing observed yet
            win_obs,
            observations: Vec::new(),
            last_method: None,
        }
    }

    /// Registers a new item under `name` with the given dynamics kernel.
    ///
    /// # Panics
    /// Panics if an item with the same name was already added.
    pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
        if self.storage.contains_key(name) {
            panic!("item '{}' already added", name);
        }
        let fitter = Box::new(RecursiveFitter::new(kernel));
        self.storage.insert(name.to_string(), Item::new(fitter));
    }

    /// Returns `true` if an item named `name` has been added.
    pub fn contains_item(&self, name: &str) -> bool {
        self.storage.contains_key(name)
    }

    /// Returns the predicted (mean, variance) of `name`'s score at time `t`.
    pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
        let item = self.storage.item(self.storage.get_id(name));
        let (means, vars) = item.fitter.predict(&[t]);
        (means[0], vars[0])
    }

    /// Records that `winners` beat `losers` at time `t`.
    ///
    /// # Panics
    /// Panics if `t` precedes the previously observed timestamp.
    pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
        if t < self.last_t {
            panic!("observations must be added in chronological order");
        }
        // Winners carry sign +1.0, losers sign -1.0.
        let elems: Vec<(usize, f64)> = self
            .process_items(winners, 1.0)
            .into_iter()
            .chain(self.process_items(losers, -1.0))
            .collect();
        let observation: Box<dyn Observation> = match self.win_obs {
            BinaryModelObservation::Probit => {
                Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
            }
            BinaryModelObservation::Logit => {
                Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
            }
        };
        self.observations.push(observation);
        self.last_t = t;
    }

    /// Runs iterative inference sweeps over all observations.
    /// Returns `true` if the largest single-observation change dropped below
    /// the tolerance within the sweep budget, `false` otherwise.
    pub fn fit(&mut self) -> bool {
        // Hard-coded hyper-parameters (mirroring the Python-style defaults,
        // except `verbose`, which is enabled here).
        let method = BinaryModelFitMethod::Ep;
        let lr = 1.0;
        let tol = 1e-3;
        let max_iter = 100;
        let verbose = true;
        self.last_method = Some(method);
        for item in self.storage.items_mut() {
            item.fitter.allocate();
        }
        for sweep in 0..max_iter {
            // Largest per-observation update seen in this sweep.
            let mut largest_change: f64 = 0.0;
            for obs in self.observations.iter_mut() {
                let change = match method {
                    BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
                    BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
                };
                largest_change = largest_change.max(change);
            }
            // Re-fit every item against the updated pseudo-observations.
            for item in self.storage.items_mut() {
                item.fitter.fit();
            }
            if verbose {
                println!("iteration {}, max diff: {:.5}", sweep + 1, largest_change);
            }
            if largest_change < tol {
                return true;
            }
        }
        false
    }

    /// Returns `(P(team_1 wins), P(team_2 wins))` at time `t`.
    pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
        let elems: Vec<(usize, f64)> = self
            .process_items(team_1, 1.0)
            .into_iter()
            .chain(self.process_items(team_2, -1.0))
            .collect();
        let p_win = match self.win_obs {
            BinaryModelObservation::Probit => probit_win_observation(&elems, t, 0.0, &self.storage),
            BinaryModelObservation::Logit => todo!(),
        };
        (p_win, 1.0 - p_win)
    }

    /// Maps item names to `(id, sign)` pairs.
    fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
        let mut out = Vec::with_capacity(items.len());
        for key in items {
            out.push((self.storage.get_id(&key), sign));
        }
        out
    }
}

193
src/model/ternary.rs Normal file
View File

@@ -0,0 +1,193 @@
use std::f64;
use crate::fitter::RecursiveFitter;
use crate::item::Item;
use crate::kernel::Kernel;
use crate::observation::*;
use crate::storage::Storage;
/// Likelihood family used for win and tie observations.
#[derive(Clone, Copy)]
pub enum TernaryModelObservation {
    /// Gaussian-CDF (probit) link.
    Probit,
    /// Logistic (logit) link.
    Logit,
}
/// Inference scheme used by `TernaryModel::fit`.
#[derive(Clone, Copy)]
pub enum TernaryModelFitMethod {
    /// Expectation propagation (drives `ep_update`).
    Ep,
    /// KL-based updates (drives `kl_update`).
    Kl,
}
/// Dynamic pairwise-comparison model over ternary (win/tie/lose) outcomes.
pub struct TernaryModel {
    // Maps item names to per-item fitters.
    storage: Storage,
    // Timestamp of the most recent observation; enforces chronological order.
    last_t: f64,
    // Likelihood family used for all observations.
    obs: TernaryModelObservation,
    // All observations recorded so far, in chronological order.
    observations: Vec<Box<dyn Observation>>,
    // Method used on the last call to `fit`, if any.
    last_method: Option<TernaryModelFitMethod>,
    // Default margin used when a call does not supply one.
    margin: f64,
}
impl TernaryModel {
    /// Creates an empty model.
    ///
    /// `obs` selects the likelihood family (upstream default: probit) and
    /// `margin` is the default win/tie margin (upstream default: 0.1).
    pub fn new(obs: TernaryModelObservation, margin: f64) -> Self {
        TernaryModel {
            storage: Storage::new(),
            // Nothing observed yet: any finite timestamp is accepted first.
            last_t: f64::NEG_INFINITY,
            obs, // default = probit
            observations: Vec::new(),
            last_method: None,
            margin, // default = 0.1
        }
    }

    /// Registers a new item under `name` with the given dynamics kernel.
    ///
    /// # Panics
    /// Panics if an item with the same name was already added.
    pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
        if self.storage.contains_key(name) {
            panic!("item '{}' already added", name);
        }
        self.storage.insert(
            name.to_string(),
            Item::new(Box::new(RecursiveFitter::new(kernel))),
        );
    }

    /// Returns `true` if an item named `name` has been added.
    pub fn contains_item(&self, name: &str) -> bool {
        self.storage.contains_key(name)
    }

    /// Returns the predicted (mean, variance) of `name`'s score at time `t`.
    pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
        let id = self.storage.get_id(name);
        let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
        (ms[0], vs[0])
    }

    /// Records an outcome at time `t`: a win of `winners` over `losers`, or a
    /// tie between them if `tie` is `true`. `margin` overrides the model
    /// default when provided.
    ///
    /// # Panics
    /// Panics if `t` precedes the previously observed timestamp.
    pub fn observe(
        &mut self,
        winners: &[&str],
        losers: &[&str],
        t: f64,
        tie: bool,
        margin: Option<f64>,
    ) {
        if t < self.last_t {
            panic!("observations must be added in chronological order");
        }
        // A field copy is free — no need for the lazy `unwrap_or_else` form
        // (clippy: unnecessary_lazy_evaluations).
        let margin = margin.unwrap_or(self.margin);
        // Winners carry sign +1.0, losers sign -1.0.
        let mut elems = self.process_items(winners, 1.0);
        elems.extend(self.process_items(losers, -1.0));
        let obs: Box<dyn Observation> = match (tie, self.obs) {
            (false, TernaryModelObservation::Probit) => Box::new(ProbitWinObservation::new(
                &mut self.storage,
                &elems,
                t,
                margin,
            )),
            (false, TernaryModelObservation::Logit) => Box::new(LogitWinObservation::new(
                &mut self.storage,
                &elems,
                t,
                margin,
            )),
            (true, TernaryModelObservation::Probit) => Box::new(ProbitTieObservation::new(
                &mut self.storage,
                &elems,
                t,
                margin,
            )),
            (true, TernaryModelObservation::Logit) => Box::new(LogitTieObservation::new(
                &mut self.storage,
                &elems,
                t,
                margin,
            )),
        };
        self.observations.push(obs);
        self.last_t = t;
    }

    /// Runs iterative inference sweeps over all observations.
    /// Returns `true` if the largest single-observation change dropped below
    /// the tolerance within `max_iter` sweeps, `false` otherwise.
    pub fn fit(&mut self) -> bool {
        // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
        // NOTE(review): hyper-parameters are hard-coded (mirroring the
        // Python-style defaults quoted above) — except `verbose`, which is
        // `true` here.
        let method = TernaryModelFitMethod::Ep;
        let lr = 1.0;
        let tol = 1e-3;
        let max_iter = 100;
        let verbose = true;
        self.last_method = Some(method);
        for item in self.storage.items_mut() {
            item.fitter.allocate();
        }
        for i in 0..max_iter {
            // Largest per-observation update seen in this sweep.
            let mut max_diff = 0.0;
            for obs in &mut self.observations {
                let diff = match method {
                    TernaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
                    TernaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
                };
                if diff > max_diff {
                    max_diff = diff;
                }
            }
            // Re-fit every item against the updated pseudo-observations.
            for item in self.storage.items_mut() {
                item.fitter.fit();
            }
            if verbose {
                println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
            }
            if max_diff < tol {
                return true;
            }
        }
        false
    }

    /// Returns the outcome probabilities at time `t`.
    ///
    /// NOTE(review): components are (win, tie, complement) where the third is
    /// computed as `1 - p_win - p_tie` — confirm the intended ordering with
    /// callers once `probit_tie_observation` is implemented.
    pub fn probabilities(
        &mut self,
        team_1: &[&str],
        team_2: &[&str],
        t: f64,
        margin: Option<f64>,
    ) -> (f64, f64, f64) {
        // A field copy is free — eager `unwrap_or` suffices.
        let margin = margin.unwrap_or(self.margin);
        let mut elems = self.process_items(team_1, 1.0);
        elems.extend(self.process_items(team_2, -1.0));
        // Dispatch once on the likelihood family instead of matching twice.
        let (prob_win, prob_tie) = match self.obs {
            TernaryModelObservation::Probit => (
                probit_win_observation(&elems, t, margin, &self.storage),
                probit_tie_observation(&elems, t, margin, &self.storage),
            ),
            TernaryModelObservation::Logit => unimplemented!(),
        };
        (prob_win, prob_tie, 1.0 - prob_win - prob_tie)
    }

    /// Maps item names to `(id, sign)` pairs.
    fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
        items
            .iter()
            .map(|key| (self.storage.get_id(&key), sign))
            .collect()
    }
}

View File

@@ -1,9 +1,21 @@
use crate::storage::Storage;
use crate::utils::logphi;
use super::Observation;
use super::{f_params, Observation};
pub fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
/// Probability that the items with positive sign in `elems` win by more than
/// `margin` at time `t`, under the probit likelihood.
pub fn probit_win_observation(
    elems: &[(usize, f64)],
    t: f64,
    margin: f64,
    storage: &Storage,
) -> f64 {
    // `elems` and `storage` are already references — re-borrowing them was a
    // clippy `needless_borrow`.
    let (m, v) = f_params(elems, t, storage);
    // The log-partition value is the log-probability of the win.
    let (logpart, _, _) = mm_probit_win(m - margin, v);
    logpart.exp()
}
fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
// Adapted from the GPML function `likErf.m`.
let z = mean_cav / (1.0 + cov_cav).sqrt();
let (logpart, val) = logphi(z);
@@ -135,3 +147,60 @@ impl Observation for LogitWinObservation {
todo!();
}
}
/// Tie counterpart of `probit_win_observation`.
///
/// Placeholder — always panics; the tie marginal is not yet implemented.
pub fn probit_tie_observation(
    elems: &[(usize, f64)],
    t: f64,
    margin: f64,
    storage: &Storage,
) -> f64 {
    unimplemented!();
}
/// Tie observation under the probit likelihood.
/// Placeholder — fields are not yet defined.
pub struct ProbitTieObservation {
    //
}
impl ProbitTieObservation {
    /// Constructor stub — not yet implemented.
    pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
        todo!();
    }
}
// Trait stub: every method is a `todo!()` placeholder.
impl Observation for ProbitTieObservation {
    /// Not yet implemented.
    fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
        todo!();
    }
    /// Not yet implemented.
    fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
        todo!();
    }
    /// Not yet implemented.
    fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
        todo!();
    }
}
/// Tie observation under the logit likelihood.
/// Placeholder — fields are not yet defined.
pub struct LogitTieObservation {
    //
}
impl LogitTieObservation {
    /// Constructor stub — not yet implemented.
    pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
        todo!();
    }
}
// Trait stub: every method is a `todo!()` placeholder.
impl Observation for LogitTieObservation {
    /// Not yet implemented.
    fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
        todo!();
    }
    /// Not yet implemented.
    fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
        todo!();
    }
    /// Not yet implemented.
    fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
        todo!();
    }
}