Implement gaussian observation. Added test to fitter, test that are failing! Need to investigate
This commit is contained in:
@@ -321,3 +321,177 @@ impl<K: Kernel> Fitter for Recursive<K> {
|
|||||||
&mut self.ns[idx]
|
&mut self.ns[idx]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
extern crate blas_src;
|
||||||
|
|
||||||
|
use approx::assert_relative_eq;
|
||||||
|
use rand::{distributions::Standard, thread_rng, Rng};
|
||||||
|
|
||||||
|
use crate::kernel::Matern32;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn fitter() -> Recursive<Matern32> {
|
||||||
|
Recursive::new(Matern32::new(2.0, 1.0))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_ts_train() -> Vec<f64> {
|
||||||
|
vec![
|
||||||
|
0.11616722, 0.31198904, 0.31203728, 0.74908024, 1.19731697, 1.20223002, 1.41614516,
|
||||||
|
1.46398788, 1.73235229, 1.90142861,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_ys() -> Array1<f64> {
|
||||||
|
array![
|
||||||
|
-1.10494786,
|
||||||
|
-0.07702044,
|
||||||
|
-0.25473925,
|
||||||
|
3.22959111,
|
||||||
|
0.90038114,
|
||||||
|
0.30686385,
|
||||||
|
1.70281621,
|
||||||
|
-1.717506,
|
||||||
|
0.63707278,
|
||||||
|
-1.40986299
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_vs() -> Vec<f64> {
|
||||||
|
vec![
|
||||||
|
0.55064619, 0.3540315, 0.34114585, 2.21458142, 7.40431354, 0.35093921, 0.91847147,
|
||||||
|
4.50764809, 0.43440729, 1.3308561,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_mean() -> Array1<f64> {
|
||||||
|
array![
|
||||||
|
-0.52517486,
|
||||||
|
-0.18391072,
|
||||||
|
-0.18381275,
|
||||||
|
0.59905936,
|
||||||
|
0.62923813,
|
||||||
|
0.6280899,
|
||||||
|
0.56576719,
|
||||||
|
0.53663651,
|
||||||
|
0.26874937,
|
||||||
|
0.04892406
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_var() -> Array1<f64> {
|
||||||
|
array![
|
||||||
|
0.20318775, 0.12410961, 0.12411533, 0.32855394, 0.19538865, 0.19410925, 0.18676754,
|
||||||
|
0.19074449, 0.22105848, 0.33534931
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_loglik() -> f64 {
|
||||||
|
-17.357282245711051
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_ts_pred() -> &'static [f64] {
|
||||||
|
&[0.0, 1.0, 2.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_mean_pred() -> Array1<f64> {
|
||||||
|
array![-0.63981819, 0.67552349, -0.04684169]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_var_pred() -> Array1<f64> {
|
||||||
|
array![0.33946081, 0.28362645, 0.45585554]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_allocation() {
|
||||||
|
let mut fitter = fitter();
|
||||||
|
|
||||||
|
// No data, hence fitter defined to be allocated
|
||||||
|
assert!(fitter.is_allocated());
|
||||||
|
|
||||||
|
// Add some data
|
||||||
|
for i in 0..8 {
|
||||||
|
fitter.add_sample(i as f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(!fitter.is_allocated());
|
||||||
|
|
||||||
|
// Allocate the arrays
|
||||||
|
fitter.allocate();
|
||||||
|
|
||||||
|
assert!(fitter.is_allocated());
|
||||||
|
|
||||||
|
// Add some data
|
||||||
|
for i in 0..8 {
|
||||||
|
fitter.add_sample(i as f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(!fitter.is_allocated());
|
||||||
|
|
||||||
|
// Re-allocate the arrays
|
||||||
|
fitter.allocate();
|
||||||
|
|
||||||
|
assert!(fitter.is_allocated());
|
||||||
|
|
||||||
|
// Check that arrays have the appropriate size
|
||||||
|
assert_eq!(fitter.ts.len(), 16);
|
||||||
|
assert_eq!(fitter.ms.len(), 16);
|
||||||
|
assert_eq!(fitter.vs.len(), 16);
|
||||||
|
assert_eq!(fitter.ns.len(), 16);
|
||||||
|
assert_eq!(fitter.xs.len(), 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_against_gpy() {
|
||||||
|
let mut fitter = fitter();
|
||||||
|
|
||||||
|
for t in data_ts_train().into_iter() {
|
||||||
|
fitter.add_sample(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
fitter.allocate();
|
||||||
|
|
||||||
|
fitter.xs = data_vs().iter().map(|v| 1.0 / v).collect();
|
||||||
|
fitter.ns = data_ys()
|
||||||
|
.iter()
|
||||||
|
.zip(data_vs().iter())
|
||||||
|
.map(|(ys, vs)| ys / vs)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
fitter.fit();
|
||||||
|
|
||||||
|
// Estimation
|
||||||
|
eprintln!("{:#?}", fitter.ms);
|
||||||
|
|
||||||
|
assert_relative_eq!(
|
||||||
|
Array1::from(fitter.ms.clone()),
|
||||||
|
data_mean(),
|
||||||
|
max_relative = 0.0000001
|
||||||
|
);
|
||||||
|
assert_relative_eq!(
|
||||||
|
Array1::from(fitter.vs.clone()),
|
||||||
|
data_var(),
|
||||||
|
max_relative = 0.0000001
|
||||||
|
);
|
||||||
|
|
||||||
|
// Prediction.
|
||||||
|
let (ms, vs) = fitter.predict(data_ts_pred());
|
||||||
|
|
||||||
|
assert_relative_eq!(Array1::from(ms), data_mean_pred(), max_relative = 0.0000001);
|
||||||
|
assert_relative_eq!(Array1::from(vs), data_var_pred(), max_relative = 0.0000001);
|
||||||
|
|
||||||
|
/*
|
||||||
|
# Log-likelihood.
|
||||||
|
ll = fitter.ep_log_likelihood_contrib
|
||||||
|
|
||||||
|
# We need to add the unstable terms that cancel out with the EP
|
||||||
|
# contributions to the log-likelihood. See appendix of the report.
|
||||||
|
ll += sum(-0.5 * log(2 * pi * v) for v in DATA["vs"])
|
||||||
|
ll += sum(-0.5 * y*y / v for y, v in zip(DATA["ys"], DATA["vs"]))
|
||||||
|
|
||||||
|
assert np.allclose(ll, DATA["loglik"])
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,24 +1,43 @@
|
|||||||
|
use std::f64::consts::TAU;
|
||||||
|
|
||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
|
|
||||||
use super::Observation;
|
use super::{Core, Observation};
|
||||||
|
|
||||||
pub struct GaussianObservation;
|
fn mm_gaussian(mean_cav: f64, var_cav: f64, diff: f64, var_obs: f64) -> (f64, f64, f64) {
|
||||||
|
let logpart =
|
||||||
|
-0.5 * ((TAU * (var_obs + var_cav)).ln() + (diff - mean_cav).powi(2) / (var_obs + var_cav));
|
||||||
|
|
||||||
|
let dlogpart = (diff - mean_cav) / (var_obs + var_cav);
|
||||||
|
let d2logpart = -1.0 / (var_obs + var_cav);
|
||||||
|
|
||||||
|
(logpart, dlogpart, d2logpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct GaussianObservation {
|
||||||
|
core: Core,
|
||||||
|
diff: f64,
|
||||||
|
var: f64,
|
||||||
|
}
|
||||||
|
|
||||||
impl GaussianObservation {
|
impl GaussianObservation {
|
||||||
pub fn new(
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, diff: f64, var: f64) -> Self {
|
||||||
_storage: &mut Storage,
|
Self {
|
||||||
_elems: &[(usize, f64)],
|
core: Core::new(storage, elems, t),
|
||||||
_diff: f64,
|
diff,
|
||||||
_t: f64,
|
var,
|
||||||
_var: f64,
|
}
|
||||||
) -> Self {
|
|
||||||
unimplemented!();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Observation for GaussianObservation {
|
impl Observation for GaussianObservation {
|
||||||
fn ep_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||||
unimplemented!();
|
let diff = self.diff;
|
||||||
|
let var = self.var;
|
||||||
|
|
||||||
|
self.core.ep_update(storage, lr, |mean_cav, cov_cav| {
|
||||||
|
mm_gaussian(mean_cav, cov_cav, diff, var)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
||||||
|
|||||||
Reference in New Issue
Block a user