IT WORKS! But small difference between this version and the python. 46.2% vs 47.1%, but that's not much.
This commit is contained in:
@@ -10,6 +10,8 @@ pub trait Fitter {
|
|||||||
|
|
||||||
fn fit(&mut self);
|
fn fit(&mut self);
|
||||||
|
|
||||||
|
fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>);
|
||||||
|
|
||||||
fn vs(&self, idx: usize) -> f64;
|
fn vs(&self, idx: usize) -> f64;
|
||||||
fn vs_mut(&mut self, idx: usize) -> &mut f64;
|
fn vs_mut(&mut self, idx: usize) -> &mut f64;
|
||||||
|
|
||||||
|
|||||||
@@ -196,6 +196,82 @@ impl Fitter for RecursiveFitter {
|
|||||||
self.is_fitted = true;
|
self.is_fitted = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>) {
|
||||||
|
if !self.is_fitted {
|
||||||
|
panic!("new data since last call to `fit()`");
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.ts.is_empty() {
|
||||||
|
return (vec![0.0], self.kernel.k_diag(ts).to_vec());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut ms = Vec::new();
|
||||||
|
let mut vs = Vec::new();
|
||||||
|
|
||||||
|
let locations = ts
|
||||||
|
.iter()
|
||||||
|
.map(|t| {
|
||||||
|
self.ts
|
||||||
|
.iter()
|
||||||
|
.position(|tc| t <= tc)
|
||||||
|
.unwrap_or_else(|| self.ts.len())
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
for (i, nxt) in locations.into_iter().enumerate() {
|
||||||
|
if nxt == self.ts.len() {
|
||||||
|
// new point is *after* last observation
|
||||||
|
let a = self.kernel.transition(self.ts[self.ts.len() - 1], ts[i]);
|
||||||
|
let q = self.kernel.noise_cov(self.ts[self.ts.len() - 1], ts[i]);
|
||||||
|
|
||||||
|
ms.push(self.h.dot(&a.dot(&self.m_s[self.m_s.len() - 1])));
|
||||||
|
vs.push(
|
||||||
|
self.h
|
||||||
|
.dot(&(a.dot(&self.p_s[self.p_s.len() - 1]).dot(&a.t()) + &q))
|
||||||
|
.dot(&self.h),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
let j = nxt as i32 - 1;
|
||||||
|
|
||||||
|
let (m, p) = if j < 0 {
|
||||||
|
(self.kernel.state_mean(ts[i]), self.kernel.state_cov(ts[i]))
|
||||||
|
} else {
|
||||||
|
// Predictive mean and cov for new point based on left neighbor.
|
||||||
|
let a = self.kernel.transition(self.ts[j as usize], ts[i]);
|
||||||
|
let q = self.kernel.noise_cov(self.ts[j as usize], ts[i]);
|
||||||
|
|
||||||
|
(
|
||||||
|
a.dot(&self.m_f[j as usize]),
|
||||||
|
a.dot(&self.p_f[j as usize]).dot(&a.t()) + &q,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
// RTS update using the right neighbor.
|
||||||
|
let a = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]);
|
||||||
|
let g = a.dot(&p).dot(
|
||||||
|
&self.p_p[(j + 1) as usize]
|
||||||
|
.inv()
|
||||||
|
.expect("failed to inverse matrix"),
|
||||||
|
);
|
||||||
|
|
||||||
|
ms.push(self.h.dot(
|
||||||
|
&(&m + &g.dot(&(&self.m_s[(j + 1) as usize] - &self.m_p[(j + 1) as usize]))),
|
||||||
|
));
|
||||||
|
vs.push(
|
||||||
|
self.h
|
||||||
|
.dot(
|
||||||
|
&(&p + &g
|
||||||
|
.dot(&(&self.p_s[(j + 1) as usize] - &self.p_p[(j + 1) as usize])))
|
||||||
|
.dot(&g.t()),
|
||||||
|
)
|
||||||
|
.dot(&self.h),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(ms, vs)
|
||||||
|
}
|
||||||
|
|
||||||
fn vs(&self, idx: usize) -> f64 {
|
fn vs(&self, idx: usize) -> f64 {
|
||||||
*&self.vs[idx]
|
*&self.vs[idx]
|
||||||
}
|
}
|
||||||
|
|||||||
54
src/model.rs
54
src/model.rs
@@ -116,50 +116,28 @@ impl BinaryModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
false
|
false
|
||||||
|
|
||||||
/*
|
|
||||||
if method == "ep":
|
|
||||||
update = lambda obs: obs.ep_update(lr=lr)
|
|
||||||
elif method == "kl":
|
|
||||||
update = lambda obs: obs.kl_update(lr=lr)
|
|
||||||
else:
|
|
||||||
raise ValueError("'method' should be one of: 'ep', 'kl'")
|
|
||||||
self._last_method = method
|
|
||||||
for item in self._item.values():
|
|
||||||
item.fitter.allocate()
|
|
||||||
for i in range(max_iter):
|
|
||||||
max_diff = 0.0
|
|
||||||
# Recompute the Gaussian pseudo-observations.
|
|
||||||
for obs in self.observations:
|
|
||||||
diff = update(obs)
|
|
||||||
max_diff = max(max_diff, diff)
|
|
||||||
# Recompute the posterior of the score processes.
|
|
||||||
for item in self.item.values():
|
|
||||||
item.fitter.fit()
|
|
||||||
if verbose:
|
|
||||||
print("iteration {}, max diff: {:.5f}".format(
|
|
||||||
i+1, max_diff), flush=True)
|
|
||||||
if max_diff < tol:
|
|
||||||
return True
|
|
||||||
return False # Did not converge after `max_iter`.
|
|
||||||
*/
|
|
||||||
//
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
|
pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
|
||||||
(0.0, 0.0)
|
let mut elems = self.process_items(team_1, 1.0);
|
||||||
|
elems.extend(self.process_items(team_2, -1.0));
|
||||||
|
|
||||||
|
let prob = match self.win_obs {
|
||||||
|
BinaryModelObservation::Probit => {
|
||||||
|
let margin = 0.0;
|
||||||
|
|
||||||
|
let (m, v) = f_params(&elems, t, &self.storage);
|
||||||
|
let (logpart, _, _) = mm_probit_win(m - margin, v);
|
||||||
|
|
||||||
|
logpart.exp()
|
||||||
|
}
|
||||||
|
BinaryModelObservation::Logit => todo!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(prob, 1.0 - prob)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
|
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
|
||||||
/*
|
|
||||||
if isinstance(items, dict):
|
|
||||||
return [(self.item[k], sign * float(v)) for k, v in items.items()]
|
|
||||||
if isinstance(items, list) or isinstance(items, tuple):
|
|
||||||
return [(self.item[k], sign) for k in items]
|
|
||||||
else:
|
|
||||||
raise ValueError("items should be a list, a tuple or a dict")
|
|
||||||
*/
|
|
||||||
|
|
||||||
items
|
items
|
||||||
.iter()
|
.iter()
|
||||||
.map(|key| (self.storage.get_id(&key), sign))
|
.map(|key| (self.storage.get_id(&key), sign))
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
mod ordinal;
|
|
||||||
|
|
||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
|
|
||||||
|
mod ordinal;
|
||||||
|
|
||||||
pub use ordinal::*;
|
pub use ordinal::*;
|
||||||
|
|
||||||
pub trait Observation {
|
pub trait Observation {
|
||||||
@@ -10,3 +10,17 @@ pub trait Observation {
|
|||||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64;
|
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64;
|
||||||
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64;
|
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) {
|
||||||
|
let mut m = 0.0;
|
||||||
|
let mut v = 0.0;
|
||||||
|
|
||||||
|
for (item, coeff) in elems.iter().map(|(id, coeff)| (storage.item(*id), coeff)) {
|
||||||
|
let (ms, vs) = item.fitter.predict(&[t]);
|
||||||
|
|
||||||
|
m += coeff * ms[0];
|
||||||
|
v += coeff * coeff * vs[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
(m, v)
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use crate::utils::logphi;
|
|||||||
|
|
||||||
use super::Observation;
|
use super::Observation;
|
||||||
|
|
||||||
fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
pub fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
||||||
// Adapted from the GPML function `likErf.m`.
|
// Adapted from the GPML function `likErf.m`.
|
||||||
let z = mean_cav / (1.0 + cov_cav).sqrt();
|
let z = mean_cav / (1.0 + cov_cav).sqrt();
|
||||||
let (logpart, val) = logphi(z);
|
let (logpart, val) = logphi(z);
|
||||||
|
|||||||
Reference in New Issue
Block a user