IT WORKS! But small difference between this version and the python. 46.2% vs 47.1%, but that's not much.
This commit is contained in:
@@ -196,6 +196,82 @@ impl Fitter for RecursiveFitter {
|
||||
self.is_fitted = true;
|
||||
}
|
||||
|
||||
fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>) {
|
||||
if !self.is_fitted {
|
||||
panic!("new data since last call to `fit()`");
|
||||
}
|
||||
|
||||
if self.ts.is_empty() {
|
||||
return (vec![0.0], self.kernel.k_diag(ts).to_vec());
|
||||
}
|
||||
|
||||
let mut ms = Vec::new();
|
||||
let mut vs = Vec::new();
|
||||
|
||||
let locations = ts
|
||||
.iter()
|
||||
.map(|t| {
|
||||
self.ts
|
||||
.iter()
|
||||
.position(|tc| t <= tc)
|
||||
.unwrap_or_else(|| self.ts.len())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for (i, nxt) in locations.into_iter().enumerate() {
|
||||
if nxt == self.ts.len() {
|
||||
// new point is *after* last observation
|
||||
let a = self.kernel.transition(self.ts[self.ts.len() - 1], ts[i]);
|
||||
let q = self.kernel.noise_cov(self.ts[self.ts.len() - 1], ts[i]);
|
||||
|
||||
ms.push(self.h.dot(&a.dot(&self.m_s[self.m_s.len() - 1])));
|
||||
vs.push(
|
||||
self.h
|
||||
.dot(&(a.dot(&self.p_s[self.p_s.len() - 1]).dot(&a.t()) + &q))
|
||||
.dot(&self.h),
|
||||
);
|
||||
} else {
|
||||
let j = nxt as i32 - 1;
|
||||
|
||||
let (m, p) = if j < 0 {
|
||||
(self.kernel.state_mean(ts[i]), self.kernel.state_cov(ts[i]))
|
||||
} else {
|
||||
// Predictive mean and cov for new point based on left neighbor.
|
||||
let a = self.kernel.transition(self.ts[j as usize], ts[i]);
|
||||
let q = self.kernel.noise_cov(self.ts[j as usize], ts[i]);
|
||||
|
||||
(
|
||||
a.dot(&self.m_f[j as usize]),
|
||||
a.dot(&self.p_f[j as usize]).dot(&a.t()) + &q,
|
||||
)
|
||||
};
|
||||
|
||||
// RTS update using the right neighbor.
|
||||
let a = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]);
|
||||
let g = a.dot(&p).dot(
|
||||
&self.p_p[(j + 1) as usize]
|
||||
.inv()
|
||||
.expect("failed to inverse matrix"),
|
||||
);
|
||||
|
||||
ms.push(self.h.dot(
|
||||
&(&m + &g.dot(&(&self.m_s[(j + 1) as usize] - &self.m_p[(j + 1) as usize]))),
|
||||
));
|
||||
vs.push(
|
||||
self.h
|
||||
.dot(
|
||||
&(&p + &g
|
||||
.dot(&(&self.p_s[(j + 1) as usize] - &self.p_p[(j + 1) as usize])))
|
||||
.dot(&g.t()),
|
||||
)
|
||||
.dot(&self.h),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
(ms, vs)
|
||||
}
|
||||
|
||||
fn vs(&self, idx: usize) -> f64 {
|
||||
*&self.vs[idx]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user