IT WORKS! But small difference between this version and the python. 46.2% vs 47.1%, but that's not much.

This commit is contained in:
2020-02-20 16:36:02 +01:00
parent 9ac7e44776
commit 8a1e6620ad
5 changed files with 111 additions and 41 deletions

View File

@@ -196,6 +196,82 @@ impl Fitter for RecursiveFitter {
self.is_fitted = true;
}
fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>) {
if !self.is_fitted {
panic!("new data since last call to `fit()`");
}
if self.ts.is_empty() {
return (vec![0.0], self.kernel.k_diag(ts).to_vec());
}
let mut ms = Vec::new();
let mut vs = Vec::new();
let locations = ts
.iter()
.map(|t| {
self.ts
.iter()
.position(|tc| t <= tc)
.unwrap_or_else(|| self.ts.len())
})
.collect::<Vec<_>>();
for (i, nxt) in locations.into_iter().enumerate() {
if nxt == self.ts.len() {
// new point is *after* last observation
let a = self.kernel.transition(self.ts[self.ts.len() - 1], ts[i]);
let q = self.kernel.noise_cov(self.ts[self.ts.len() - 1], ts[i]);
ms.push(self.h.dot(&a.dot(&self.m_s[self.m_s.len() - 1])));
vs.push(
self.h
.dot(&(a.dot(&self.p_s[self.p_s.len() - 1]).dot(&a.t()) + &q))
.dot(&self.h),
);
} else {
let j = nxt as i32 - 1;
let (m, p) = if j < 0 {
(self.kernel.state_mean(ts[i]), self.kernel.state_cov(ts[i]))
} else {
// Predictive mean and cov for new point based on left neighbor.
let a = self.kernel.transition(self.ts[j as usize], ts[i]);
let q = self.kernel.noise_cov(self.ts[j as usize], ts[i]);
(
a.dot(&self.m_f[j as usize]),
a.dot(&self.p_f[j as usize]).dot(&a.t()) + &q,
)
};
// RTS update using the right neighbor.
let a = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]);
let g = a.dot(&p).dot(
&self.p_p[(j + 1) as usize]
.inv()
.expect("failed to inverse matrix"),
);
ms.push(self.h.dot(
&(&m + &g.dot(&(&self.m_s[(j + 1) as usize] - &self.m_p[(j + 1) as usize]))),
));
vs.push(
self.h
.dot(
&(&p + &g
.dot(&(&self.p_s[(j + 1) as usize] - &self.p_p[(j + 1) as usize])))
.dot(&g.t()),
)
.dot(&self.h),
);
}
}
(ms, vs)
}
fn vs(&self, idx: usize) -> f64 {
*&self.vs[idx]
}