diff --git a/src/fitter.rs b/src/fitter.rs index 9bcfc66..c336be1 100644 --- a/src/fitter.rs +++ b/src/fitter.rs @@ -10,6 +10,8 @@ pub trait Fitter { fn fit(&mut self); + fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>); + fn vs(&self, idx: usize) -> f64; fn vs_mut(&mut self, idx: usize) -> &mut f64; diff --git a/src/fitter/recursive.rs b/src/fitter/recursive.rs index 8c5016c..ee99a9c 100644 --- a/src/fitter/recursive.rs +++ b/src/fitter/recursive.rs @@ -196,6 +196,82 @@ impl Fitter for RecursiveFitter { self.is_fitted = true; } + fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>) { + if !self.is_fitted { + panic!("new data since last call to `fit()`"); + } + + if self.ts.is_empty() { + return (vec![0.0], self.kernel.k_diag(ts).to_vec()); + } + + let mut ms = Vec::new(); + let mut vs = Vec::new(); + + let locations = ts + .iter() + .map(|t| { + self.ts + .iter() + .position(|tc| t <= tc) + .unwrap_or_else(|| self.ts.len()) + }) + .collect::<Vec<usize>>(); + + for (i, nxt) in locations.into_iter().enumerate() { + if nxt == self.ts.len() { + // new point is *after* last observation + let a = self.kernel.transition(self.ts[self.ts.len() - 1], ts[i]); + let q = self.kernel.noise_cov(self.ts[self.ts.len() - 1], ts[i]); + + ms.push(self.h.dot(&a.dot(&self.m_s[self.m_s.len() - 1]))); + vs.push( + self.h + .dot(&(a.dot(&self.p_s[self.p_s.len() - 1]).dot(&a.t()) + &q)) + .dot(&self.h), + ); + } else { + let j = nxt as i32 - 1; + + let (m, p) = if j < 0 { + (self.kernel.state_mean(ts[i]), self.kernel.state_cov(ts[i])) + } else { + // Predictive mean and cov for new point based on left neighbor. + let a = self.kernel.transition(self.ts[j as usize], ts[i]); + let q = self.kernel.noise_cov(self.ts[j as usize], ts[i]); + + ( + a.dot(&self.m_f[j as usize]), + a.dot(&self.p_f[j as usize]).dot(&a.t()) + &q, + ) + }; + + // RTS update using the right neighbor. 
+ let a = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]); + let g = a.dot(&p).dot( + &self.p_p[(j + 1) as usize] + .inv() + .expect("failed to inverse matrix"), + ); + + ms.push(self.h.dot( + &(&m + &g.dot(&(&self.m_s[(j + 1) as usize] - &self.m_p[(j + 1) as usize]))), + )); + vs.push( + self.h + .dot( + &(&p + &g + .dot(&(&self.p_s[(j + 1) as usize] - &self.p_p[(j + 1) as usize]))) + .dot(&g.t()), + ) + .dot(&self.h), + ); + } + } + + (ms, vs) + } + fn vs(&self, idx: usize) -> f64 { *&self.vs[idx] } diff --git a/src/model.rs b/src/model.rs index 61c19f1..7a43935 100644 --- a/src/model.rs +++ b/src/model.rs @@ -116,50 +116,28 @@ impl BinaryModel { } false - - /* - if method == "ep": - update = lambda obs: obs.ep_update(lr=lr) - elif method == "kl": - update = lambda obs: obs.kl_update(lr=lr) - else: - raise ValueError("'method' should be one of: 'ep', 'kl'") - self._last_method = method - for item in self._item.values(): - item.fitter.allocate() - for i in range(max_iter): - max_diff = 0.0 - # Recompute the Gaussian pseudo-observations. - for obs in self.observations: - diff = update(obs) - max_diff = max(max_diff, diff) - # Recompute the posterior of the score processes. - for item in self.item.values(): - item.fitter.fit() - if verbose: - print("iteration {}, max diff: {:.5f}".format( - i+1, max_diff), flush=True) - if max_diff < tol: - return True - return False # Did not converge after `max_iter`. 
- */ - // } pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) { - (0.0, 0.0) + let mut elems = self.process_items(team_1, 1.0); + elems.extend(self.process_items(team_2, -1.0)); + + let prob = match self.win_obs { + BinaryModelObservation::Probit => { + let margin = 0.0; + + let (m, v) = f_params(&elems, t, &self.storage); + let (logpart, _, _) = mm_probit_win(m - margin, v); + + logpart.exp() + } + BinaryModelObservation::Logit => todo!(), + }; + + (prob, 1.0 - prob) } fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> { - /* - if isinstance(items, dict): - return [(self.item[k], sign * float(v)) for k, v in items.items()] - if isinstance(items, list) or isinstance(items, tuple): - return [(self.item[k], sign) for k in items] - else: - raise ValueError("items should be a list, a tuple or a dict") - */ - items .iter() .map(|key| (self.storage.get_id(&key), sign)) diff --git a/src/observation.rs b/src/observation.rs index d1a4e1a..f4bc284 100644 --- a/src/observation.rs +++ b/src/observation.rs @@ -1,7 +1,7 @@ -mod ordinal; - use crate::storage::Storage; +mod ordinal; + pub use ordinal::*; pub trait Observation { @@ -10,3 +10,17 @@ pub trait Observation { fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64; fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64; } + +pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) { + let mut m = 0.0; + let mut v = 0.0; + + for (item, coeff) in elems.iter().map(|(id, coeff)| (storage.item(*id), coeff)) { + let (ms, vs) = item.fitter.predict(&[t]); + + m += coeff * ms[0]; + v += coeff * coeff * vs[0]; + } + + (m, v) +} diff --git a/src/observation/ordinal.rs b/src/observation/ordinal.rs index 2af9d6a..dddc2bb 100644 --- a/src/observation/ordinal.rs +++ b/src/observation/ordinal.rs @@ -3,7 +3,7 @@ use crate::utils::logphi; use super::Observation; -fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) { +pub fn 
mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) { // Adapted from the GPML function `likErf.m`. let z = mean_cav / (1.0 + cov_cav).sqrt(); let (logpart, val) = logphi(z);