Added TernaryModel.
This commit is contained in:
@@ -13,10 +13,10 @@ pub struct RecursiveFitter {
|
||||
#[derivative(Debug = "ignore")]
|
||||
kernel: Box<dyn Kernel>,
|
||||
ts: Vec<f64>,
|
||||
ms: ArrayD<f64>,
|
||||
vs: Array1<f64>,
|
||||
ns: ArrayD<f64>,
|
||||
xs: ArrayD<f64>,
|
||||
ms: ArrayD<f64>, // TODO Replace with a vec
|
||||
vs: Array1<f64>, // TODO Replace with a vec
|
||||
ns: ArrayD<f64>, // TODO Replace with a vec
|
||||
xs: ArrayD<f64>, // TODO Replace with a vec
|
||||
is_fitted: bool,
|
||||
h: Array1<f64>,
|
||||
i: Array2<f64>,
|
||||
@@ -181,16 +181,10 @@ impl Fitter for RecursiveFitter {
|
||||
} else {
|
||||
let a = self.p_p[i + 1].clone();
|
||||
let b = self.a[i].dot(&self.p_f[i]);
|
||||
// println!("a={:#?}", a);
|
||||
|
||||
let g = crate::linalg::solve(a, b);
|
||||
let g = g.t();
|
||||
|
||||
/*
|
||||
let g = self.a[i]
|
||||
.dot(&self.p_f[i])
|
||||
.dot(&self.p_p[i + 1].inv().expect("failed to inverse matrix"));
|
||||
*/
|
||||
|
||||
self.m_s[i] = &self.m_f[i] + &g.dot(&(&self.m_s[i + 1] - &self.m_p[i + 1]));
|
||||
self.p_s[i] =
|
||||
&self.p_f[i] + &g.dot(&(&self.p_s[i + 1] - &self.p_p[i + 1])).dot(&g.t());
|
||||
@@ -257,6 +251,7 @@ impl Fitter for RecursiveFitter {
|
||||
let a = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]);
|
||||
let a = a.dot(&p);
|
||||
let b = self.p_p[(j + 1) as usize].clone();
|
||||
|
||||
let g = crate::linalg::solve(a, b);
|
||||
let g = g.t();
|
||||
|
||||
|
||||
162
src/model.rs
162
src/model.rs
@@ -1,159 +1,5 @@
|
||||
use std::f64;
|
||||
mod binary;
|
||||
mod ternary;
|
||||
|
||||
use crate::fitter::RecursiveFitter;
|
||||
use crate::item::Item;
|
||||
use crate::kernel::Kernel;
|
||||
use crate::observation::*;
|
||||
use crate::storage::Storage;
|
||||
|
||||
pub enum BinaryModelObservation {
|
||||
Probit,
|
||||
Logit,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum BinaryModelFitMethod {
|
||||
Ep,
|
||||
Kl,
|
||||
}
|
||||
|
||||
pub struct BinaryModel {
|
||||
storage: Storage,
|
||||
last_t: f64,
|
||||
win_obs: BinaryModelObservation,
|
||||
observations: Vec<Box<dyn Observation>>,
|
||||
last_method: Option<BinaryModelFitMethod>,
|
||||
}
|
||||
|
||||
impl BinaryModel {
|
||||
pub fn new(win_obs: BinaryModelObservation) -> Self {
|
||||
BinaryModel {
|
||||
storage: Storage::new(),
|
||||
last_t: f64::NEG_INFINITY,
|
||||
win_obs,
|
||||
observations: Vec::new(),
|
||||
last_method: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
|
||||
if self.storage.contains_key(name) {
|
||||
panic!("item '{}' already added", name);
|
||||
}
|
||||
|
||||
self.storage.insert(
|
||||
name.to_string(),
|
||||
Item::new(Box::new(RecursiveFitter::new(kernel))),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn contains_item(&self, name: &str) -> bool {
|
||||
self.storage.contains_key(name)
|
||||
}
|
||||
|
||||
pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
|
||||
let id = self.storage.get_id(name);
|
||||
let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
|
||||
|
||||
(ms[0], vs[0])
|
||||
}
|
||||
|
||||
pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
|
||||
if t < self.last_t {
|
||||
panic!("observations must be added in chronological order");
|
||||
}
|
||||
|
||||
let mut elems = self.process_items(winners, 1.0);
|
||||
elems.extend(self.process_items(losers, -1.0));
|
||||
|
||||
let obs: Box<dyn Observation> = match self.win_obs {
|
||||
BinaryModelObservation::Probit => {
|
||||
Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
|
||||
}
|
||||
BinaryModelObservation::Logit => {
|
||||
Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
|
||||
}
|
||||
};
|
||||
|
||||
self.observations.push(obs);
|
||||
|
||||
/*
|
||||
for (item, _) in elems {
|
||||
item.link_observation(obs)
|
||||
}
|
||||
*/
|
||||
|
||||
self.last_t = t;
|
||||
}
|
||||
|
||||
pub fn fit(&mut self) -> bool {
|
||||
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
|
||||
|
||||
let method = BinaryModelFitMethod::Ep;
|
||||
let lr = 1.0;
|
||||
let tol = 1e-3;
|
||||
let max_iter = 100;
|
||||
let verbose = true;
|
||||
|
||||
self.last_method = Some(method);
|
||||
|
||||
for item in self.storage.items_mut() {
|
||||
item.fitter.allocate();
|
||||
}
|
||||
|
||||
for i in 0..max_iter {
|
||||
let mut max_diff = 0.0;
|
||||
|
||||
for obs in &mut self.observations {
|
||||
let diff = match method {
|
||||
BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
|
||||
BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
|
||||
};
|
||||
|
||||
if diff > max_diff {
|
||||
max_diff = diff;
|
||||
}
|
||||
}
|
||||
|
||||
for item in self.storage.items_mut() {
|
||||
item.fitter.fit();
|
||||
}
|
||||
|
||||
if verbose {
|
||||
println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
|
||||
}
|
||||
|
||||
if max_diff < tol {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
|
||||
let mut elems = self.process_items(team_1, 1.0);
|
||||
elems.extend(self.process_items(team_2, -1.0));
|
||||
|
||||
let prob = match self.win_obs {
|
||||
BinaryModelObservation::Probit => {
|
||||
let margin = 0.0;
|
||||
|
||||
let (m, v) = f_params(&elems, t, &self.storage);
|
||||
let (logpart, _, _) = mm_probit_win(m - margin, v);
|
||||
|
||||
logpart.exp()
|
||||
}
|
||||
BinaryModelObservation::Logit => todo!(),
|
||||
};
|
||||
|
||||
(prob, 1.0 - prob)
|
||||
}
|
||||
|
||||
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
|
||||
items
|
||||
.iter()
|
||||
.map(|key| (self.storage.get_id(&key), sign))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
pub use binary::*;
|
||||
pub use ternary::*;
|
||||
|
||||
152
src/model/binary.rs
Normal file
152
src/model/binary.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
use std::f64;
|
||||
|
||||
use crate::fitter::RecursiveFitter;
|
||||
use crate::item::Item;
|
||||
use crate::kernel::Kernel;
|
||||
use crate::observation::*;
|
||||
use crate::storage::Storage;
|
||||
|
||||
pub enum BinaryModelObservation {
|
||||
Probit,
|
||||
Logit,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum BinaryModelFitMethod {
|
||||
Ep,
|
||||
Kl,
|
||||
}
|
||||
|
||||
pub struct BinaryModel {
|
||||
storage: Storage,
|
||||
last_t: f64,
|
||||
win_obs: BinaryModelObservation,
|
||||
observations: Vec<Box<dyn Observation>>,
|
||||
last_method: Option<BinaryModelFitMethod>,
|
||||
}
|
||||
|
||||
impl BinaryModel {
|
||||
pub fn new(win_obs: BinaryModelObservation) -> Self {
|
||||
BinaryModel {
|
||||
storage: Storage::new(),
|
||||
last_t: f64::NEG_INFINITY,
|
||||
win_obs,
|
||||
observations: Vec::new(),
|
||||
last_method: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
|
||||
if self.storage.contains_key(name) {
|
||||
panic!("item '{}' already added", name);
|
||||
}
|
||||
|
||||
self.storage.insert(
|
||||
name.to_string(),
|
||||
Item::new(Box::new(RecursiveFitter::new(kernel))),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn contains_item(&self, name: &str) -> bool {
|
||||
self.storage.contains_key(name)
|
||||
}
|
||||
|
||||
pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
|
||||
let id = self.storage.get_id(name);
|
||||
let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
|
||||
|
||||
(ms[0], vs[0])
|
||||
}
|
||||
|
||||
pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
|
||||
if t < self.last_t {
|
||||
panic!("observations must be added in chronological order");
|
||||
}
|
||||
|
||||
let mut elems = self.process_items(winners, 1.0);
|
||||
elems.extend(self.process_items(losers, -1.0));
|
||||
|
||||
let obs: Box<dyn Observation> = match self.win_obs {
|
||||
BinaryModelObservation::Probit => {
|
||||
Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
|
||||
}
|
||||
BinaryModelObservation::Logit => {
|
||||
Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
|
||||
}
|
||||
};
|
||||
|
||||
self.observations.push(obs);
|
||||
|
||||
/*
|
||||
for (item, _) in elems {
|
||||
item.link_observation(obs)
|
||||
}
|
||||
*/
|
||||
|
||||
self.last_t = t;
|
||||
}
|
||||
|
||||
pub fn fit(&mut self) -> bool {
|
||||
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
|
||||
|
||||
let method = BinaryModelFitMethod::Ep;
|
||||
let lr = 1.0;
|
||||
let tol = 1e-3;
|
||||
let max_iter = 100;
|
||||
let verbose = true;
|
||||
|
||||
self.last_method = Some(method);
|
||||
|
||||
for item in self.storage.items_mut() {
|
||||
item.fitter.allocate();
|
||||
}
|
||||
|
||||
for i in 0..max_iter {
|
||||
let mut max_diff = 0.0;
|
||||
|
||||
for obs in &mut self.observations {
|
||||
let diff = match method {
|
||||
BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
|
||||
BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
|
||||
};
|
||||
|
||||
if diff > max_diff {
|
||||
max_diff = diff;
|
||||
}
|
||||
}
|
||||
|
||||
for item in self.storage.items_mut() {
|
||||
item.fitter.fit();
|
||||
}
|
||||
|
||||
if verbose {
|
||||
println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
|
||||
}
|
||||
|
||||
if max_diff < tol {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
|
||||
let mut elems = self.process_items(team_1, 1.0);
|
||||
elems.extend(self.process_items(team_2, -1.0));
|
||||
|
||||
let prob = match self.win_obs {
|
||||
BinaryModelObservation::Probit => probit_win_observation(&elems, t, 0.0, &self.storage),
|
||||
BinaryModelObservation::Logit => todo!(),
|
||||
};
|
||||
|
||||
(prob, 1.0 - prob)
|
||||
}
|
||||
|
||||
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
|
||||
items
|
||||
.iter()
|
||||
.map(|key| (self.storage.get_id(&key), sign))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
193
src/model/ternary.rs
Normal file
193
src/model/ternary.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use std::f64;
|
||||
|
||||
use crate::fitter::RecursiveFitter;
|
||||
use crate::item::Item;
|
||||
use crate::kernel::Kernel;
|
||||
use crate::observation::*;
|
||||
use crate::storage::Storage;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum TernaryModelObservation {
|
||||
Probit,
|
||||
Logit,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum TernaryModelFitMethod {
|
||||
Ep,
|
||||
Kl,
|
||||
}
|
||||
|
||||
pub struct TernaryModel {
|
||||
storage: Storage,
|
||||
last_t: f64,
|
||||
obs: TernaryModelObservation,
|
||||
observations: Vec<Box<dyn Observation>>,
|
||||
last_method: Option<TernaryModelFitMethod>,
|
||||
margin: f64,
|
||||
}
|
||||
|
||||
impl TernaryModel {
|
||||
pub fn new(obs: TernaryModelObservation, margin: f64) -> Self {
|
||||
TernaryModel {
|
||||
storage: Storage::new(),
|
||||
last_t: f64::NEG_INFINITY,
|
||||
obs, // default = probit
|
||||
observations: Vec::new(),
|
||||
last_method: None,
|
||||
margin, // default = 0.1
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
|
||||
if self.storage.contains_key(name) {
|
||||
panic!("item '{}' already added", name);
|
||||
}
|
||||
|
||||
self.storage.insert(
|
||||
name.to_string(),
|
||||
Item::new(Box::new(RecursiveFitter::new(kernel))),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn contains_item(&self, name: &str) -> bool {
|
||||
self.storage.contains_key(name)
|
||||
}
|
||||
|
||||
pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
|
||||
let id = self.storage.get_id(name);
|
||||
let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
|
||||
|
||||
(ms[0], vs[0])
|
||||
}
|
||||
|
||||
pub fn observe(
|
||||
&mut self,
|
||||
winners: &[&str],
|
||||
losers: &[&str],
|
||||
t: f64,
|
||||
tie: bool,
|
||||
margin: Option<f64>,
|
||||
) {
|
||||
if t < self.last_t {
|
||||
panic!("observations must be added in chronological order");
|
||||
}
|
||||
|
||||
let margin = margin.unwrap_or_else(|| self.margin);
|
||||
|
||||
let mut elems = self.process_items(winners, 1.0);
|
||||
elems.extend(self.process_items(losers, -1.0));
|
||||
|
||||
let obs: Box<dyn Observation> = match (tie, self.obs) {
|
||||
(false, TernaryModelObservation::Probit) => Box::new(ProbitWinObservation::new(
|
||||
&mut self.storage,
|
||||
&elems,
|
||||
t,
|
||||
margin,
|
||||
)),
|
||||
(false, TernaryModelObservation::Logit) => Box::new(LogitWinObservation::new(
|
||||
&mut self.storage,
|
||||
&elems,
|
||||
t,
|
||||
margin,
|
||||
)),
|
||||
(true, TernaryModelObservation::Probit) => Box::new(ProbitTieObservation::new(
|
||||
&mut self.storage,
|
||||
&elems,
|
||||
t,
|
||||
margin,
|
||||
)),
|
||||
(true, TernaryModelObservation::Logit) => Box::new(LogitTieObservation::new(
|
||||
&mut self.storage,
|
||||
&elems,
|
||||
t,
|
||||
margin,
|
||||
)),
|
||||
};
|
||||
|
||||
self.observations.push(obs);
|
||||
|
||||
self.last_t = t;
|
||||
}
|
||||
|
||||
pub fn fit(&mut self) -> bool {
|
||||
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
|
||||
|
||||
let method = TernaryModelFitMethod::Ep;
|
||||
let lr = 1.0;
|
||||
let tol = 1e-3;
|
||||
let max_iter = 100;
|
||||
let verbose = true;
|
||||
|
||||
self.last_method = Some(method);
|
||||
|
||||
for item in self.storage.items_mut() {
|
||||
item.fitter.allocate();
|
||||
}
|
||||
|
||||
for i in 0..max_iter {
|
||||
let mut max_diff = 0.0;
|
||||
|
||||
for obs in &mut self.observations {
|
||||
let diff = match method {
|
||||
TernaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
|
||||
TernaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
|
||||
};
|
||||
|
||||
if diff > max_diff {
|
||||
max_diff = diff;
|
||||
}
|
||||
}
|
||||
|
||||
for item in self.storage.items_mut() {
|
||||
item.fitter.fit();
|
||||
}
|
||||
|
||||
if verbose {
|
||||
println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
|
||||
}
|
||||
|
||||
if max_diff < tol {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn probabilities(
|
||||
&mut self,
|
||||
team_1: &[&str],
|
||||
team_2: &[&str],
|
||||
t: f64,
|
||||
margin: Option<f64>,
|
||||
) -> (f64, f64, f64) {
|
||||
let margin = margin.unwrap_or_else(|| self.margin);
|
||||
|
||||
let mut elems = self.process_items(team_1, 1.0);
|
||||
elems.extend(self.process_items(team_2, -1.0));
|
||||
|
||||
let prob_1 = match self.obs {
|
||||
TernaryModelObservation::Probit => {
|
||||
probit_win_observation(&elems, t, margin, &self.storage)
|
||||
}
|
||||
TernaryModelObservation::Logit => unimplemented!(),
|
||||
};
|
||||
|
||||
let prob_2 = match self.obs {
|
||||
TernaryModelObservation::Probit => {
|
||||
probit_tie_observation(&elems, t, margin, &self.storage)
|
||||
}
|
||||
TernaryModelObservation::Logit => unimplemented!(),
|
||||
};
|
||||
|
||||
(prob_1, prob_2, 1.0 - prob_1 - prob_2)
|
||||
}
|
||||
|
||||
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
|
||||
items
|
||||
.iter()
|
||||
.map(|key| (self.storage.get_id(&key), sign))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,21 @@
|
||||
use crate::storage::Storage;
|
||||
use crate::utils::logphi;
|
||||
|
||||
use super::Observation;
|
||||
use super::{f_params, Observation};
|
||||
|
||||
pub fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
||||
pub fn probit_win_observation(
|
||||
elems: &[(usize, f64)],
|
||||
t: f64,
|
||||
margin: f64,
|
||||
storage: &Storage,
|
||||
) -> f64 {
|
||||
let (m, v) = f_params(&elems, t, &storage);
|
||||
let (logpart, _, _) = mm_probit_win(m - margin, v);
|
||||
|
||||
logpart.exp()
|
||||
}
|
||||
|
||||
fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
||||
// Adapted from the GPML function `likErf.m`.
|
||||
let z = mean_cav / (1.0 + cov_cav).sqrt();
|
||||
let (logpart, val) = logphi(z);
|
||||
@@ -135,3 +147,60 @@ impl Observation for LogitWinObservation {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn probit_tie_observation(
|
||||
elems: &[(usize, f64)],
|
||||
t: f64,
|
||||
margin: f64,
|
||||
storage: &Storage,
|
||||
) -> f64 {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
pub struct ProbitTieObservation {
|
||||
//
|
||||
}
|
||||
|
||||
impl ProbitTieObservation {
|
||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
impl Observation for ProbitTieObservation {
|
||||
fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
||||
todo!();
|
||||
}
|
||||
|
||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||
todo!();
|
||||
}
|
||||
|
||||
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LogitTieObservation {
|
||||
//
|
||||
}
|
||||
|
||||
impl LogitTieObservation {
|
||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
impl Observation for LogitTieObservation {
|
||||
fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
||||
todo!();
|
||||
}
|
||||
|
||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||
todo!();
|
||||
}
|
||||
|
||||
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user