From 7a3ff508cd945267247b43d08a53bac5805be0a0 Mon Sep 17 00:00:00 2001
From: Anders Olsson
Date: Wed, 27 Apr 2022 11:58:49 +0200
Subject: [PATCH] Implement gaussian observation.

Added tests to the fitter; the tests are failing! Need to investigate.
---
 src/fitter/recursive.rs     | 174 ++++++++++++++++++++++++++++++++++++
 src/observation/gaussian.rs |  43 ++++++---
 2 files changed, 205 insertions(+), 12 deletions(-)

diff --git a/src/fitter/recursive.rs b/src/fitter/recursive.rs
index 2978916..ee65beb 100644
--- a/src/fitter/recursive.rs
+++ b/src/fitter/recursive.rs
@@ -321,3 +321,177 @@ impl Fitter for Recursive {
         &mut self.ns[idx]
     }
 }
+
+#[cfg(test)]
+mod tests {
+    extern crate blas_src;
+
+    use approx::assert_relative_eq;
+    use rand::{distributions::Standard, thread_rng, Rng};
+
+    use crate::kernel::Matern32;
+
+    use super::*;
+
+    fn fitter() -> Recursive {
+        Recursive::new(Matern32::new(2.0, 1.0))
+    }
+
+    fn data_ts_train() -> Vec<f64> {
+        vec![
+            0.11616722, 0.31198904, 0.31203728, 0.74908024, 1.19731697, 1.20223002, 1.41614516,
+            1.46398788, 1.73235229, 1.90142861,
+        ]
+    }
+
+    fn data_ys() -> Array1<f64> {
+        array![
+            -1.10494786,
+            -0.07702044,
+            -0.25473925,
+            3.22959111,
+            0.90038114,
+            0.30686385,
+            1.70281621,
+            -1.717506,
+            0.63707278,
+            -1.40986299
+        ]
+    }
+
+    fn data_vs() -> Vec<f64> {
+        vec![
+            0.55064619, 0.3540315, 0.34114585, 2.21458142, 7.40431354, 0.35093921, 0.91847147,
+            4.50764809, 0.43440729, 1.3308561,
+        ]
+    }
+
+    fn data_mean() -> Array1<f64> {
+        array![
+            -0.52517486,
+            -0.18391072,
+            -0.18381275,
+            0.59905936,
+            0.62923813,
+            0.6280899,
+            0.56576719,
+            0.53663651,
+            0.26874937,
+            0.04892406
+        ]
+    }
+
+    fn data_var() -> Array1<f64> {
+        array![
+            0.20318775, 0.12410961, 0.12411533, 0.32855394, 0.19538865, 0.19410925, 0.18676754,
+            0.19074449, 0.22105848, 0.33534931
+        ]
+    }
+
+    fn data_loglik() -> f64 {
+        -17.357282245711051
+    }
+
+    fn data_ts_pred() -> &'static [f64] {
+        &[0.0, 1.0, 2.0]
+    }
+
+    fn data_mean_pred() -> Array1<f64> {
+        array![-0.63981819, 0.67552349, -0.04684169]
+    }
+
+    fn data_var_pred() -> Array1<f64> {
+        array![0.33946081, 0.28362645, 0.45585554]
+    }
+
+    #[test]
+    fn test_allocation() {
+        let mut fitter = fitter();
+
+        // No data, hence fitter defined to be allocated
+        assert!(fitter.is_allocated());
+
+        // Add some data
+        for i in 0..8 {
+            fitter.add_sample(i as f64);
+        }
+
+        assert!(!fitter.is_allocated());
+
+        // Allocate the arrays
+        fitter.allocate();
+
+        assert!(fitter.is_allocated());
+
+        // Add some data
+        for i in 0..8 {
+            fitter.add_sample(i as f64);
+        }
+
+        assert!(!fitter.is_allocated());
+
+        // Re-allocate the arrays
+        fitter.allocate();
+
+        assert!(fitter.is_allocated());
+
+        // Check that arrays have the appropriate size
+        assert_eq!(fitter.ts.len(), 16);
+        assert_eq!(fitter.ms.len(), 16);
+        assert_eq!(fitter.vs.len(), 16);
+        assert_eq!(fitter.ns.len(), 16);
+        assert_eq!(fitter.xs.len(), 16);
+    }
+
+    #[test]
+    fn test_against_gpy() {
+        let mut fitter = fitter();
+
+        for t in data_ts_train().into_iter() {
+            fitter.add_sample(t);
+        }
+
+        fitter.allocate();
+
+        fitter.xs = data_vs().iter().map(|v| 1.0 / v).collect();
+        fitter.ns = data_ys()
+            .iter()
+            .zip(data_vs().iter())
+            .map(|(ys, vs)| ys / vs)
+            .collect();
+
+        fitter.fit();
+
+        // Estimation
+        eprintln!("{:#?}", fitter.ms);
+
+        assert_relative_eq!(
+            Array1::from(fitter.ms.clone()),
+            data_mean(),
+            max_relative = 0.0000001
+        );
+        assert_relative_eq!(
+            Array1::from(fitter.vs.clone()),
+            data_var(),
+            max_relative = 0.0000001
+        );
+
+        // Prediction.
+        let (ms, vs) = fitter.predict(data_ts_pred());
+
+        assert_relative_eq!(Array1::from(ms), data_mean_pred(), max_relative = 0.0000001);
+        assert_relative_eq!(Array1::from(vs), data_var_pred(), max_relative = 0.0000001);
+
+        /*
+        # Log-likelihood.
+        ll = fitter.ep_log_likelihood_contrib
+
+        # We need to add the unstable terms that cancel out with the EP
+        # contributions to the log-likelihood. See appendix of the report.
+        ll += sum(-0.5 * log(2 * pi * v) for v in DATA["vs"])
+        ll += sum(-0.5 * y*y / v for y, v in zip(DATA["ys"], DATA["vs"]))
+
+        assert np.allclose(ll, DATA["loglik"])
+        */
+    }
+}
diff --git a/src/observation/gaussian.rs b/src/observation/gaussian.rs
index b33b6ad..0185448 100644
--- a/src/observation/gaussian.rs
+++ b/src/observation/gaussian.rs
@@ -1,24 +1,43 @@
+use std::f64::consts::TAU;
+
 use crate::storage::Storage;
 
-use super::Observation;
+use super::{Core, Observation};
 
-pub struct GaussianObservation;
+fn mm_gaussian(mean_cav: f64, var_cav: f64, diff: f64, var_obs: f64) -> (f64, f64, f64) {
+    let logpart =
+        -0.5 * ((TAU * (var_obs + var_cav)).ln() + (diff - mean_cav).powi(2) / (var_obs + var_cav));
+
+    let dlogpart = (diff - mean_cav) / (var_obs + var_cav);
+    let d2logpart = -1.0 / (var_obs + var_cav);
+
+    (logpart, dlogpart, d2logpart)
+}
+
+pub struct GaussianObservation {
+    core: Core,
+    diff: f64,
+    var: f64,
+}
 
 impl GaussianObservation {
-    pub fn new(
-        _storage: &mut Storage,
-        _elems: &[(usize, f64)],
-        _diff: f64,
-        _t: f64,
-        _var: f64,
-    ) -> Self {
-        unimplemented!();
+    pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, diff: f64, var: f64) -> Self {
+        Self {
+            core: Core::new(storage, elems, t),
+            diff,
+            var,
+        }
     }
 }
 
 impl Observation for GaussianObservation {
-    fn ep_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
-        unimplemented!();
+    fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
+        let diff = self.diff;
+        let var = self.var;
+
+        self.core.ep_update(storage, lr, |mean_cav, cov_cav| {
+            mm_gaussian(mean_cav, cov_cav, diff, var)
+        })
+    }
 
     fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {