Implement Fitter::ep_log_likelihood_contrib and add test for it

This commit is contained in:
2022-04-28 13:28:40 +02:00
parent ba06c40b2e
commit 6b77664b27
2 changed files with 39 additions and 9 deletions

View File

@@ -290,6 +290,30 @@ impl<K: Kernel> Fitter for Recursive<K> {
(ms, vs)
}
/// Contribution to the log-marginal likelihood of the model
fn ep_log_likelihood_contrib(&self) -> f64 {
// Note: this is *not* equal to the log of the marginal likelihood of the
// regression model. See "stable computation of the marginal likelihood"
// in the notes.
if !self.is_fitted {
panic!("new data since last call to `fit()`")
}
let mut val = 0.0;
for i in 0..self.ts.len() {
let o = self.h.dot(&self.m_p[i]);
let v = self.h.dot(&self.p_p[i]).dot(&self.h);
val += -0.5
* ((self.xs[i] * v + 1.0).ln()
+ (-self.ns[i].powi(2) * v - 2.0 * self.ns[i] * o + self.xs[i] * o.powi(2))
/ (self.xs[i] * v + 1.0));
}
val
}
fn vs(&self, idx: usize) -> f64 {
self.vs[idx]
}
@@ -327,6 +351,8 @@ impl<K: Kernel> Fitter for Recursive<K> {
mod tests {
extern crate blas_src;
use std::f64::consts::TAU;
use approx::assert_relative_eq;
use crate::kernel::Matern32;
@@ -482,16 +508,18 @@ mod tests {
assert_relative_eq!(Array1::from(ms), data_mean_pred(), max_relative = 0.000001);
assert_relative_eq!(Array1::from(vs), data_var_pred(), max_relative = 0.000001);
/*
# Log-likelihood.
ll = fitter.ep_log_likelihood_contrib
// Log-likelihood
let mut ll = fitter.ep_log_likelihood_contrib();
# We need to add the unstable terms that cancel out with the EP
# contributions to the log-likelihood. See appendix of the report.
ll += sum(-0.5 * log(2 * pi * v) for v in DATA["vs"])
ll += sum(-0.5 * y*y / v for y, v in zip(DATA["ys"], DATA["vs"]))
// We need to add the unstable terms that cancel out with the EP
// contributions to the log-likelihood. See appendix of the report
ll += data_vs().iter().map(|v| -0.5 * (TAU * v).ln()).sum::<f64>();
ll += data_ys()
.iter()
.zip(data_vs().iter())
.map(|(y, v)| -0.5 * y.powi(2) / v)
.sum::<f64>();
assert np.allclose(ll, DATA["loglik"])
*/
assert_relative_eq!(ll, data_loglik(), max_relative = 0.0000001);
}
}