From 6b77664b276a138440acf5c4f37435a77cb1b1b7 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Thu, 28 Apr 2022 13:28:40 +0200 Subject: [PATCH] Implement Fitter::ep_log_likelihood_contrib and add test for it --- src/fitter.rs | 2 ++ src/fitter/recursive.rs | 46 +++++++++++++++++++++++++++++++++-------- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/src/fitter.rs b/src/fitter.rs index 5711631..ea67001 100644 --- a/src/fitter.rs +++ b/src/fitter.rs @@ -12,6 +12,8 @@ pub trait Fitter { fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>); + fn ep_log_likelihood_contrib(&self) -> f64; + fn vs(&self, idx: usize) -> f64; fn vs_mut(&mut self, idx: usize) -> &mut f64; diff --git a/src/fitter/recursive.rs b/src/fitter/recursive.rs index 5ce5c54..b9039c7 100644 --- a/src/fitter/recursive.rs +++ b/src/fitter/recursive.rs @@ -290,6 +290,30 @@ impl Fitter for Recursive { (ms, vs) } + /// Contribution to the log-marginal likelihood of the model + fn ep_log_likelihood_contrib(&self) -> f64 { + // Note: this is *not* equal to the log of the marginal likelihood of the + // regression model. See "stable computation of the marginal likelihood" + // in the notes. 
+ if !self.is_fitted { + panic!("new data since last call to `fit()`") + } + + let mut val = 0.0; + + for i in 0..self.ts.len() { + let o = self.h.dot(&self.m_p[i]); + let v = self.h.dot(&self.p_p[i]).dot(&self.h); + + val += -0.5 + * ((self.xs[i] * v + 1.0).ln() + + (-self.ns[i].powi(2) * v - 2.0 * self.ns[i] * o + self.xs[i] * o.powi(2)) / (self.xs[i] * v + 1.0)); + } + + val + } + fn vs(&self, idx: usize) -> f64 { self.vs[idx] } @@ -327,6 +351,8 @@ impl Fitter for Recursive { mod tests { extern crate blas_src; + use std::f64::consts::TAU; + use approx::assert_relative_eq; use crate::kernel::Matern32; @@ -482,16 +508,18 @@ mod tests { assert_relative_eq!(Array1::from(ms), data_mean_pred(), max_relative = 0.000001); assert_relative_eq!(Array1::from(vs), data_var_pred(), max_relative = 0.000001); - /* - # Log-likelihood. - ll = fitter.ep_log_likelihood_contrib - # We need to add the unstable terms that cancel out with the EP - # contributions to the log-likelihood. See appendix of the report. - ll += sum(-0.5 * log(2 * pi * v) for v in DATA["vs"]) - ll += sum(-0.5 * y*y / v for y, v in zip(DATA["ys"], DATA["vs"])) + // Log-likelihood + let mut ll = fitter.ep_log_likelihood_contrib(); + // We need to add the unstable terms that cancel out with the EP + // contributions to the log-likelihood. See appendix of the report + ll += data_vs().iter().map(|v| -0.5 * (TAU * v).ln()).sum::<f64>(); + ll += data_ys() .iter() .zip(data_vs().iter()) .map(|(y, v)| -0.5 * y.powi(2) / v) .sum::<f64>(); - assert np.allclose(ll, DATA["loglik"]) - */ + assert_relative_eq!(ll, data_loglik(), max_relative = 0.0000001); } }