Implement gaussian observation. Added test to fitter, test that are failing! Need to investigate
This commit is contained in:
@@ -321,3 +321,177 @@ impl<K: Kernel> Fitter for Recursive<K> {
|
||||
&mut self.ns[idx]
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate blas_src;
|
||||
|
||||
use approx::assert_relative_eq;
|
||||
use rand::{distributions::Standard, thread_rng, Rng};
|
||||
|
||||
use crate::kernel::Matern32;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn fitter() -> Recursive<Matern32> {
|
||||
Recursive::new(Matern32::new(2.0, 1.0))
|
||||
}
|
||||
|
||||
fn data_ts_train() -> Vec<f64> {
|
||||
vec![
|
||||
0.11616722, 0.31198904, 0.31203728, 0.74908024, 1.19731697, 1.20223002, 1.41614516,
|
||||
1.46398788, 1.73235229, 1.90142861,
|
||||
]
|
||||
}
|
||||
|
||||
fn data_ys() -> Array1<f64> {
|
||||
array![
|
||||
-1.10494786,
|
||||
-0.07702044,
|
||||
-0.25473925,
|
||||
3.22959111,
|
||||
0.90038114,
|
||||
0.30686385,
|
||||
1.70281621,
|
||||
-1.717506,
|
||||
0.63707278,
|
||||
-1.40986299
|
||||
]
|
||||
}
|
||||
|
||||
fn data_vs() -> Vec<f64> {
|
||||
vec![
|
||||
0.55064619, 0.3540315, 0.34114585, 2.21458142, 7.40431354, 0.35093921, 0.91847147,
|
||||
4.50764809, 0.43440729, 1.3308561,
|
||||
]
|
||||
}
|
||||
|
||||
fn data_mean() -> Array1<f64> {
|
||||
array![
|
||||
-0.52517486,
|
||||
-0.18391072,
|
||||
-0.18381275,
|
||||
0.59905936,
|
||||
0.62923813,
|
||||
0.6280899,
|
||||
0.56576719,
|
||||
0.53663651,
|
||||
0.26874937,
|
||||
0.04892406
|
||||
]
|
||||
}
|
||||
|
||||
fn data_var() -> Array1<f64> {
|
||||
array![
|
||||
0.20318775, 0.12410961, 0.12411533, 0.32855394, 0.19538865, 0.19410925, 0.18676754,
|
||||
0.19074449, 0.22105848, 0.33534931
|
||||
]
|
||||
}
|
||||
|
||||
fn data_loglik() -> f64 {
|
||||
-17.357282245711051
|
||||
}
|
||||
|
||||
fn data_ts_pred() -> &'static [f64] {
|
||||
&[0.0, 1.0, 2.0]
|
||||
}
|
||||
|
||||
fn data_mean_pred() -> Array1<f64> {
|
||||
array![-0.63981819, 0.67552349, -0.04684169]
|
||||
}
|
||||
|
||||
fn data_var_pred() -> Array1<f64> {
|
||||
array![0.33946081, 0.28362645, 0.45585554]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allocation() {
|
||||
let mut fitter = fitter();
|
||||
|
||||
// No data, hence fitter defined to be allocated
|
||||
assert!(fitter.is_allocated());
|
||||
|
||||
// Add some data
|
||||
for i in 0..8 {
|
||||
fitter.add_sample(i as f64);
|
||||
}
|
||||
|
||||
assert!(!fitter.is_allocated());
|
||||
|
||||
// Allocate the arrays
|
||||
fitter.allocate();
|
||||
|
||||
assert!(fitter.is_allocated());
|
||||
|
||||
// Add some data
|
||||
for i in 0..8 {
|
||||
fitter.add_sample(i as f64);
|
||||
}
|
||||
|
||||
assert!(!fitter.is_allocated());
|
||||
|
||||
// Re-allocate the arrays
|
||||
fitter.allocate();
|
||||
|
||||
assert!(fitter.is_allocated());
|
||||
|
||||
// Check that arrays have the appropriate size
|
||||
assert_eq!(fitter.ts.len(), 16);
|
||||
assert_eq!(fitter.ms.len(), 16);
|
||||
assert_eq!(fitter.vs.len(), 16);
|
||||
assert_eq!(fitter.ns.len(), 16);
|
||||
assert_eq!(fitter.xs.len(), 16);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_against_gpy() {
|
||||
let mut fitter = fitter();
|
||||
|
||||
for t in data_ts_train().into_iter() {
|
||||
fitter.add_sample(t);
|
||||
}
|
||||
|
||||
fitter.allocate();
|
||||
|
||||
fitter.xs = data_vs().iter().map(|v| 1.0 / v).collect();
|
||||
fitter.ns = data_ys()
|
||||
.iter()
|
||||
.zip(data_vs().iter())
|
||||
.map(|(ys, vs)| ys / vs)
|
||||
.collect();
|
||||
|
||||
fitter.fit();
|
||||
|
||||
// Estimation
|
||||
eprintln!("{:#?}", fitter.ms);
|
||||
|
||||
assert_relative_eq!(
|
||||
Array1::from(fitter.ms.clone()),
|
||||
data_mean(),
|
||||
max_relative = 0.0000001
|
||||
);
|
||||
assert_relative_eq!(
|
||||
Array1::from(fitter.vs.clone()),
|
||||
data_var(),
|
||||
max_relative = 0.0000001
|
||||
);
|
||||
|
||||
// Prediction.
|
||||
let (ms, vs) = fitter.predict(data_ts_pred());
|
||||
|
||||
assert_relative_eq!(Array1::from(ms), data_mean_pred(), max_relative = 0.0000001);
|
||||
assert_relative_eq!(Array1::from(vs), data_var_pred(), max_relative = 0.0000001);
|
||||
|
||||
/*
|
||||
# Log-likelihood.
|
||||
ll = fitter.ep_log_likelihood_contrib
|
||||
|
||||
# We need to add the unstable terms that cancel out with the EP
|
||||
# contributions to the log-likelihood. See appendix of the report.
|
||||
ll += sum(-0.5 * log(2 * pi * v) for v in DATA["vs"])
|
||||
ll += sum(-0.5 * y*y / v for y, v in zip(DATA["ys"], DATA["vs"]))
|
||||
|
||||
assert np.allclose(ll, DATA["loglik"])
|
||||
*/
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user