diff --git a/Cargo.toml b/Cargo.toml index 20851de..e291172 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,5 @@ ndarray = { version = "0.14", features = ["approx"] } [dev-dependencies] approx = "0.4" intel-mkl-src = "0.5" -rand = "0.7" -rand_xoshiro = "0.4" +rand = "0.8" time = "0.2" diff --git a/src/kernel/matern52.rs b/src/kernel/matern52.rs index b4663d4..280def1 100644 --- a/src/kernel/matern52.rs +++ b/src/kernel/matern52.rs @@ -4,7 +4,7 @@ use super::Kernel; pub struct Matern52 { var: f64, - _l_scale: f64, + l_scale: f64, lambda: f64, } @@ -12,15 +12,23 @@ impl Matern52 { pub fn new(var: f64, l_scale: f64) -> Self { Matern52 { var, - _l_scale: l_scale, + l_scale, lambda: 5.0f64.sqrt() / l_scale, } } } impl Kernel for Matern52 { - fn k_mat(&self, _ts1: &[f64], _ts2: Option<&[f64]>) -> Array2 { - unimplemented!(); + fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2 { + let ts2 = ts2.unwrap_or(ts1); + + let sqrt5 = 5.0f64.sqrt(); + + let r = super::distance(ts1, ts2) / self.l_scale; + let r2 = r.mapv(|v| (-sqrt5 * v).exp()); + let r3 = r.mapv(|v| v.powi(2)); + + r2 * (1.0 + sqrt5 * r + (5.0 / 3.0) * r3) * self.var } fn k_diag(&self, ts: &[f64]) -> Array1 { @@ -59,6 +67,10 @@ impl Kernel for Matern52 { ] } + fn noise_effect(&self) -> Array2 { + array![[0.0], [0.0], [1.0]] + } + fn transition(&self, t0: f64, t1: f64) -> Array2 { let d = t1 - t0; let a = self.lambda; @@ -102,3 +114,77 @@ impl Kernel for Matern52 { self.var * array![[x11, x12, x13], [x12, x22, x23], [x13, x23, x33]] } } + +#[cfg(test)] +mod tests { + extern crate intel_mkl_src; + + use approx::assert_abs_diff_eq; + use rand::{distributions::Standard, thread_rng, Rng}; + + use super::*; + + #[test] + fn test_kernel_matrix() { + let kernel = Matern52::new(0.2, 5.0); + + let ts = [1.26, 1.46, 2.67]; + + assert_abs_diff_eq!( + kernel.k_mat(&ts, None), + array![ + [0.2, 0.19973384192198906, 0.18769647349199411], + [0.19973384192198906, 0.2, 0.19077859940208852], + [0.18769647349199411, 0.19077859940208852, 0.2] + ] + ); + } + + #[test] + fn test_kernel_diag() { + let kernel = Matern52::new(0.2, 5.0); + + let ts: Vec<_> = thread_rng() + .sample_iter::(Standard) + .take(10) + .map(|x| x * 10.0) + .collect(); + + assert_eq!(kernel.k_mat(&ts, None).diag(), kernel.k_diag(&ts)); + } + + #[test] + fn test_kernel_order() { + let kernel = Matern52::new(0.2, 5.0); + + let m = kernel.order(); + + assert_eq!(kernel.state_mean(0.0).shape(), &[m]); + assert_eq!(kernel.state_cov(0.0).shape(), &[m, m]); + assert_eq!(kernel.measurement_vector().shape(), &[m]); + assert_eq!(kernel.feedback().shape(), &[m, m]); + assert_eq!(kernel.noise_effect().shape()[0], m); + assert_eq!(kernel.transition(0.0, 1.0).shape(), &[m, m]); + assert_eq!(kernel.noise_cov(0.0, 1.0).shape(), &[m, m]); + } + + #[test] + fn test_ssm_variance() { + let kernel = Matern52::new(0.2, 5.0); + + let ts: Vec<_> = thread_rng() + .sample_iter::(Standard) + .take(10) + .map(|x| x * 10.0) + .collect(); + + let h = kernel.measurement_vector(); + + let vars = ts + .iter() + .map(|t| h.dot(&kernel.state_cov(*t)).dot(&h)) + .collect::>(); + + assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts)); + } +}