use std::fmt; use std::iter; use ndarray::prelude::*; use crate::kernel::Kernel; use super::Fitter; #[derive(Clone)] pub struct Recursive { ts_new: Vec, kernel: K, ts: Vec, ms: Vec, vs: Vec, ns: Vec, xs: Vec, is_fitted: bool, h: Array1, i: Array2, a: Vec>, q: Vec>, m_p: Vec>, p_p: Vec>, m_f: Vec>, p_f: Vec>, m_s: Vec>, p_s: Vec>, } impl Recursive { pub fn new(kernel: K) -> Self { let m = kernel.order(); let h = kernel.measurement_vector(); Self { ts_new: Vec::new(), kernel, ts: Vec::new(), ms: Vec::new(), vs: Vec::new(), ns: Vec::new(), xs: Vec::new(), is_fitted: true, h, i: Array::eye(m), a: Vec::new(), q: Vec::new(), m_p: Vec::new(), p_p: Vec::new(), m_f: Vec::new(), p_f: Vec::new(), m_s: Vec::new(), p_s: Vec::new(), } } } impl fmt::Debug for Recursive { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("RecursiveFitter") .field("ts_new", &self.ts_new) .field("ts", &self.ts) .field("ms", &self.ms) .field("vs", &self.vs) .field("ns", &self.ns) .field("xs", &self.xs) .field("is_fitted", &self.is_fitted) .field("h", &self.h) .field("i", &self.i) .field("a", &self.a) .field("q", &self.q) .field("m_p", &self.m_p) .field("p_p", &self.p_p) .field("m_f", &self.m_f) .field("p_f", &self.p_f) .field("m_s", &self.m_s) .field("p_s", &self.p_s) .finish() } } impl Fitter for Recursive { fn add_sample(&mut self, t: f64) -> usize { let idx = self.ts.len() + self.ts_new.len(); self.ts_new.push(t); self.is_fitted = false; idx } fn allocate(&mut self) { let n_new = self.ts_new.len(); if n_new == 0 { return; } // Usual variables. self.ts.extend(self.ts_new.iter()); self.ms.extend(iter::repeat(0.0).take(n_new)); self.vs.extend(self.kernel.k_diag(&self.ts_new).iter()); self.ns.extend(iter::repeat(0.0).take(n_new)); self.xs.extend(iter::repeat(0.0).take(n_new)); // Initialize the predictive, filtering and smoothing distributions. for t in &self.ts_new { let mean = self.kernel.state_mean(*t); self.m_p.push(mean.clone()); self.m_f.push(mean.clone()); self.m_s.push(mean); let cov = self.kernel.state_cov(*t); self.p_p.push(cov.clone()); self.p_f.push(cov.clone()); self.p_s.push(cov); } // Compute the new transition and noise covariance matrices. let m = self.kernel.order(); for _ in 0..n_new { self.a.push(Array2::zeros((m, m))); self.q.push(Array2::zeros((m, m))); } for i in (self.ts.len() - n_new)..self.ts.len() { if i == 0 { continue; } self.a[i - 1] = self.kernel.transition(self.ts[i - 1], self.ts[i]); self.q[i - 1] = self.kernel.noise_cov(self.ts[i - 1], self.ts[i]); } self.ts_new.clear(); } fn is_allocated(&self) -> bool { self.ts_new.is_empty() } fn fit(&mut self) { if !self.is_allocated() { panic!("new data since last call to `allocate()`"); } if self.ts.is_empty() { self.is_fitted = true; return; } // Forward pass (Kalman filter). for i in 0..self.ts.len() { if i > 0 { self.m_p[i] = self.a[i - 1].dot(&self.m_f[i - 1]); self.p_p[i] = self.a[i - 1].dot(&self.p_f[i - 1]).dot(&self.a[i - 1].t()) + &self.q[i - 1]; } // These are slightly modified equations to work with tau and nu. let k = self.p_p[i].dot(&self.h) / (1.0 + self.xs[i] * self.h.dot(&self.p_p[i]).dot(&self.h)); self.m_f[i] = &self.m_p[i] + &(&k * (self.ns[i] - self.xs[i] * self.h.dot(&self.m_p[i]))); // Covariance matrix is computed using the Joseph form. let outer = (self.xs[i] * &k) .iter() .flat_map(|a| self.h.iter().map(move |b| a * b)) .collect::>(); let outer = Array::from_shape_vec((self.h.len(), self.h.len()), outer) .expect("failed to create outer matrix"); let z = &self.i - &outer; let outer = k .iter() .flat_map(|a| k.iter().map(move |b| a * b)) .collect::>(); let outer = Array::from_shape_vec((self.h.len(), self.h.len()), outer) .expect("failed to create outer matrix"); self.p_f[i] = z.dot(&self.p_p[i]).dot(&z.t()) + self.xs[i] * outer; } // Backward pass (RTS smoother). for i in (0..self.ts.len()).rev() { if i == self.ts.len() - 1 { self.m_s[i] = self.m_f[i].clone(); self.p_s[i] = self.p_f[i].clone(); } else { let a = self.p_p[i + 1].clone(); let b = self.a[i].dot(&self.p_f[i]); let g = crate::linalg::solve(a, b); let g = g.t(); self.m_s[i] = &self.m_f[i] + &g.dot(&(&self.m_s[i + 1] - &self.m_p[i + 1])); self.p_s[i] = &self.p_f[i] + &g.dot(&(&self.p_s[i + 1] - &self.p_p[i + 1])).dot(&g.t()); } self.ms[i] = self.h.dot(&self.m_s[i]); self.vs[i] = self.h.dot(&self.p_s[i]).dot(&self.h); } self.is_fitted = true; } #[allow(clippy::many_single_char_names)] fn predict(&self, ts: &[f64]) -> (Vec, Vec) { if !self.is_fitted { panic!("new data since last call to `fit()`"); } if self.ts.is_empty() { return (vec![0.0], self.kernel.k_diag(ts).to_vec()); } let mut ms = Vec::new(); let mut vs = Vec::new(); let locations = ts.iter().map(|t| { self.ts .iter() .position(|tc| t <= tc) .unwrap_or(self.ts.len()) }); for (i, nxt) in locations.into_iter().enumerate() { if nxt == self.ts.len() { // new point is *after* last observation let a = self.kernel.transition(self.ts[self.ts.len() - 1], ts[i]); let q = self.kernel.noise_cov(self.ts[self.ts.len() - 1], ts[i]); ms.push(self.h.dot(&a.dot(&self.m_s[self.m_s.len() - 1]))); vs.push( self.h .dot(&(a.dot(&self.p_s[self.p_s.len() - 1]).dot(&a.t()) + &q)) .dot(&self.h), ); } else { let j = nxt as i32 - 1; let (m, p) = if j < 0 { (self.kernel.state_mean(ts[i]), self.kernel.state_cov(ts[i])) } else { // Predictive mean and cov for new point based on left neighbor. let a = self.kernel.transition(self.ts[j as usize], ts[i]); let q = self.kernel.noise_cov(self.ts[j as usize], ts[i]); ( a.dot(&self.m_f[j as usize]), a.dot(&self.p_f[j as usize]).dot(&a.t()) + &q, ) }; // RTS update using the right neighbor. let a = self.p_p[(j + 1) as usize].clone(); let b = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]); let b = b.dot(&p); let g = crate::linalg::solve(a, b); let g = g.t(); ms.push(self.h.dot( &(&m + &g.dot(&(&self.m_s[(j + 1) as usize] - &self.m_p[(j + 1) as usize]))), )); vs.push( self.h .dot( &(&p + &g .dot(&(&self.p_s[(j + 1) as usize] - &self.p_p[(j + 1) as usize])) .dot(&g.t())), ) .dot(&self.h), ); } } (ms, vs) } /// Contribution to the log-marginal likelihood of the model fn ep_log_likelihood_contrib(&self) -> f64 { // Note: this is *not* equal to the log of the marginal likelihood of the // regression model. See "stable computation of the marginal likelihood" // in the notes. if !self.is_fitted { panic!("new data since last call to `fit()`") } let mut val = 0.0; for i in 0..self.ts.len() { let o = self.h.dot(&self.m_p[i]); let v = self.h.dot(&self.p_p[i]).dot(&self.h); val += -0.5 * ((self.xs[i] * v + 1.0).ln() + (-self.ns[i].powi(2) * v - 2.0 * self.ns[i] * o + self.xs[i] * o.powi(2)) / (self.xs[i] * v + 1.0)); } val } fn vs(&self, idx: usize) -> f64 { self.vs[idx] } fn vs_mut(&mut self, idx: usize) -> &mut f64 { &mut self.vs[idx] } fn ms(&self, idx: usize) -> f64 { self.ms[idx] } fn ms_mut(&mut self, idx: usize) -> &mut f64 { &mut self.ms[idx] } fn xs(&self, idx: usize) -> f64 { self.xs[idx] } fn xs_mut(&mut self, idx: usize) -> &mut f64 { &mut self.xs[idx] } fn ns(&self, idx: usize) -> f64 { self.ns[idx] } fn ns_mut(&mut self, idx: usize) -> &mut f64 { &mut self.ns[idx] } } #[cfg(test)] mod tests { extern crate blas_src; use std::f64::consts::TAU; use approx::assert_relative_eq; use crate::kernel::Matern32; use super::*; fn fitter() -> Recursive { Recursive::new(Matern32::new(2.0, 1.0)) } fn data_ts_train() -> Vec { vec![ 0.11616722, 0.31198904, 0.31203728, 0.74908024, 1.19731697, 1.20223002, 1.41614516, 1.46398788, 1.73235229, 1.90142861, ] } fn data_ys() -> Array1 { array![ -1.10494786, -0.07702044, -0.25473925, 3.22959111, 0.90038114, 0.30686385, 1.70281621, -1.717506, 0.63707278, -1.40986299 ] } fn data_vs() -> Vec { vec![ 0.55064619, 0.3540315, 0.34114585, 2.21458142, 7.40431354, 0.35093921, 0.91847147, 4.50764809, 0.43440729, 1.3308561, ] } fn data_mean() -> Array1 { array![ -0.52517486, -0.18391072, -0.18381275, 0.59905936, 0.62923813, 0.6280899, 0.56576719, 0.53663651, 0.26874937, 0.04892406 ] } fn data_var() -> Array1 { array![ 0.20318775, 0.12410961, 0.12411533, 0.32855394, 0.19538865, 0.19410925, 0.18676754, 0.19074449, 0.22105848, 0.33534931 ] } fn data_loglik() -> f64 { -17.35728224571105 } fn data_ts_pred() -> &'static [f64] { &[0.0, 1.0, 2.0] } fn data_mean_pred() -> Array1 { array![-0.63981819, 0.67552349, -0.04684169] } fn data_var_pred() -> Array1 { array![0.33946081, 0.28362645, 0.45585554] } #[test] fn test_allocation() { let mut fitter = fitter(); // No data, hence fitter defined to be allocated assert!(fitter.is_allocated()); // Add some data for i in 0..8 { fitter.add_sample(i as f64); } assert!(!fitter.is_allocated()); // Allocate the arrays fitter.allocate(); assert!(fitter.is_allocated()); // Add some data for i in 0..8 { fitter.add_sample(i as f64); } assert!(!fitter.is_allocated()); // Re-allocate the arrays fitter.allocate(); assert!(fitter.is_allocated()); // Check that arrays have the appropriate size assert_eq!(fitter.ts.len(), 16); assert_eq!(fitter.ms.len(), 16); assert_eq!(fitter.vs.len(), 16); assert_eq!(fitter.ns.len(), 16); assert_eq!(fitter.xs.len(), 16); } #[test] fn test_against_gpy() { let mut fitter = fitter(); for t in data_ts_train().into_iter() { fitter.add_sample(t); } fitter.allocate(); fitter.xs = data_vs().iter().map(|v| 1.0 / v).collect(); fitter.ns = data_ys() .iter() .zip(data_vs().iter()) .map(|(ys, vs)| ys / vs) .collect(); fitter.fit(); // Estimation // eprintln!("{:#?}", fitter.ms); assert_relative_eq!( Array1::from(fitter.ms.clone()), data_mean(), max_relative = 0.0000001 ); assert_relative_eq!( Array1::from(fitter.vs.clone()), data_var(), max_relative = 0.0000001 ); // Prediction. let (ms, vs) = fitter.predict(data_ts_pred()); assert_relative_eq!(Array1::from(ms), data_mean_pred(), max_relative = 0.000001); assert_relative_eq!(Array1::from(vs), data_var_pred(), max_relative = 0.000001); // Log-likelihood let mut ll = fitter.ep_log_likelihood_contrib(); // We need to add the unstable terms that cancel out with the EP // contributions to the log-likelihood. See appendix of the report ll += data_vs().iter().map(|v| -0.5 * (TAU * v).ln()).sum::(); ll += data_ys() .iter() .zip(data_vs().iter()) .map(|(y, v)| -0.5 * y.powi(2) / v) .sum::(); assert_relative_eq!(ll, data_loglik(), max_relative = 0.0000001); } }