IT WORKS! There is still a small accuracy difference between this version and the Python one — 46.2% vs 47.1% — but the gap is minor.

This commit is contained in:
2020-02-20 16:36:02 +01:00
parent 9ac7e44776
commit 8a1e6620ad
5 changed files with 111 additions and 41 deletions

View File

@@ -10,6 +10,8 @@ pub trait Fitter {
fn fit(&mut self); fn fit(&mut self);
fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>);
fn vs(&self, idx: usize) -> f64; fn vs(&self, idx: usize) -> f64;
fn vs_mut(&mut self, idx: usize) -> &mut f64; fn vs_mut(&mut self, idx: usize) -> &mut f64;

View File

@@ -196,6 +196,82 @@ impl Fitter for RecursiveFitter {
self.is_fitted = true; self.is_fitted = true;
} }
/// Posterior predictive mean and variance of the (scalar, `h`-projected)
/// score process at each query time in `ts`, using the filtered/smoothed
/// state estimates produced by the last call to `fit()`.
///
/// # Panics
///
/// Panics if observations were added since the last call to `fit()`.
fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>) {
    if !self.is_fitted {
        panic!("new data since last call to `fit()`");
    }
    // No observations yet: fall back to the prior (zero mean, kernel variance).
    if self.ts.is_empty() {
        return (vec![0.0], self.kernel.k_diag(ts).to_vec());
    }
    let mut ms = Vec::with_capacity(ts.len());
    let mut vs = Vec::with_capacity(ts.len());
    // For each query time, index of the first observation at or after it;
    // `self.ts.len()` means the query time lies past the last observation.
    let locations = ts
        .iter()
        .map(|t| {
            self.ts
                .iter()
                .position(|tc| t <= tc)
                .unwrap_or(self.ts.len())
        })
        .collect::<Vec<_>>();
    for (i, nxt) in locations.into_iter().enumerate() {
        if nxt == self.ts.len() {
            // New point is *after* the last observation: propagate the last
            // smoothed state forward through the transition model.
            let last = self.ts.len() - 1;
            let a = self.kernel.transition(self.ts[last], ts[i]);
            let q = self.kernel.noise_cov(self.ts[last], ts[i]);
            ms.push(self.h.dot(&a.dot(&self.m_s[last])));
            vs.push(
                self.h
                    .dot(&(a.dot(&self.p_s[last]).dot(&a.t()) + &q))
                    .dot(&self.h),
            );
        } else {
            // Predictive mean and cov for the new point based on the left
            // neighbor, or the prior when the point precedes all observations.
            let (m, p) = match nxt.checked_sub(1) {
                None => (self.kernel.state_mean(ts[i]), self.kernel.state_cov(ts[i])),
                Some(j) => {
                    let a = self.kernel.transition(self.ts[j], ts[i]);
                    let q = self.kernel.noise_cov(self.ts[j], ts[i]);
                    (
                        a.dot(&self.m_f[j]),
                        a.dot(&self.p_f[j]).dot(&a.t()) + &q,
                    )
                }
            };
            // RTS (Rauch–Tung–Striebel) update using the right neighbor `nxt`.
            let a = self.kernel.transition(ts[i], self.ts[nxt]);
            let g = a
                .dot(&p)
                .dot(&self.p_p[nxt].inv().expect("failed to inverse matrix"));
            ms.push(
                self.h
                    .dot(&(&m + &g.dot(&(&self.m_s[nxt] - &self.m_p[nxt])))),
            );
            // Smoothed covariance: P + G (P_s - P_p) Gᵀ. Note that Gᵀ
            // multiplies only the correction term; the previous
            // parenthesization applied Gᵀ to the whole sum, (P + G·Δ)·Gᵀ,
            // which is not the RTS smoother update and likely accounts for
            // the small accuracy gap vs the reference Python implementation.
            vs.push(
                self.h
                    .dot(&(&p + &g
                        .dot(&(&self.p_s[nxt] - &self.p_p[nxt]))
                        .dot(&g.t())))
                    .dot(&self.h),
            );
        }
    }
    (ms, vs)
}
fn vs(&self, idx: usize) -> f64 { fn vs(&self, idx: usize) -> f64 {
*&self.vs[idx] *&self.vs[idx]
} }

View File

@@ -116,50 +116,28 @@ impl BinaryModel {
} }
false false
/*
if method == "ep":
update = lambda obs: obs.ep_update(lr=lr)
elif method == "kl":
update = lambda obs: obs.kl_update(lr=lr)
else:
raise ValueError("'method' should be one of: 'ep', 'kl'")
self._last_method = method
for item in self._item.values():
item.fitter.allocate()
for i in range(max_iter):
max_diff = 0.0
# Recompute the Gaussian pseudo-observations.
for obs in self.observations:
diff = update(obs)
max_diff = max(max_diff, diff)
# Recompute the posterior of the score processes.
for item in self.item.values():
item.fitter.fit()
if verbose:
print("iteration {}, max diff: {:.5f}".format(
i+1, max_diff), flush=True)
if max_diff < tol:
return True
return False # Did not converge after `max_iter`.
*/
//
} }
pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) { pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
(0.0, 0.0) let mut elems = self.process_items(team_1, 1.0);
elems.extend(self.process_items(team_2, -1.0));
let prob = match self.win_obs {
BinaryModelObservation::Probit => {
let margin = 0.0;
let (m, v) = f_params(&elems, t, &self.storage);
let (logpart, _, _) = mm_probit_win(m - margin, v);
logpart.exp()
}
BinaryModelObservation::Logit => todo!(),
};
(prob, 1.0 - prob)
} }
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> { fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
/*
if isinstance(items, dict):
return [(self.item[k], sign * float(v)) for k, v in items.items()]
if isinstance(items, list) or isinstance(items, tuple):
return [(self.item[k], sign) for k in items]
else:
raise ValueError("items should be a list, a tuple or a dict")
*/
items items
.iter() .iter()
.map(|key| (self.storage.get_id(&key), sign)) .map(|key| (self.storage.get_id(&key), sign))

View File

@@ -1,7 +1,7 @@
mod ordinal;
use crate::storage::Storage; use crate::storage::Storage;
mod ordinal;
pub use ordinal::*; pub use ordinal::*;
pub trait Observation { pub trait Observation {
@@ -10,3 +10,17 @@ pub trait Observation {
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64; fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64;
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64; fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64;
} }
/// Mean and variance of the linear combination of item scores at time `t`.
///
/// Each `(id, coeff)` pair contributes `coeff * mean` to the mean and
/// `coeff^2 * variance` to the variance (items are treated as independent).
pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) {
    let mut mean = 0.0;
    let mut var = 0.0;
    for &(id, coeff) in elems {
        // Predictive mean/variance of this item's score process at `t`.
        let (ms, vs) = storage.item(id).fitter.predict(&[t]);
        mean += coeff * ms[0];
        var += coeff * coeff * vs[0];
    }
    (mean, var)
}

View File

@@ -3,7 +3,7 @@ use crate::utils::logphi;
use super::Observation; use super::Observation;
fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) { pub fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
// Adapted from the GPML function `likErf.m`. // Adapted from the GPML function `likErf.m`.
let z = mean_cav / (1.0 + cov_cav).sqrt(); let z = mean_cav / (1.0 + cov_cav).sqrt();
let (logpart, val) = logphi(z); let (logpart, val) = logphi(z);