Added tests for Matern52
This commit is contained in:
@@ -12,6 +12,5 @@ ndarray = { version = "0.14", features = ["approx"] }
|
|||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
approx = "0.4"
|
approx = "0.4"
|
||||||
intel-mkl-src = "0.5"
|
intel-mkl-src = "0.5"
|
||||||
rand = "0.7"
|
rand = "0.8"
|
||||||
rand_xoshiro = "0.4"
|
|
||||||
time = "0.2"
|
time = "0.2"
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use super::Kernel;
|
|||||||
|
|
||||||
pub struct Matern52 {
|
pub struct Matern52 {
|
||||||
var: f64,
|
var: f64,
|
||||||
_l_scale: f64,
|
l_scale: f64,
|
||||||
lambda: f64,
|
lambda: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -12,15 +12,23 @@ impl Matern52 {
|
|||||||
pub fn new(var: f64, l_scale: f64) -> Self {
|
pub fn new(var: f64, l_scale: f64) -> Self {
|
||||||
Matern52 {
|
Matern52 {
|
||||||
var,
|
var,
|
||||||
_l_scale: l_scale,
|
l_scale,
|
||||||
lambda: 5.0f64.sqrt() / l_scale,
|
lambda: 5.0f64.sqrt() / l_scale,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Kernel for Matern52 {
|
impl Kernel for Matern52 {
|
||||||
fn k_mat(&self, _ts1: &[f64], _ts2: Option<&[f64]>) -> Array2<f64> {
|
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
|
||||||
unimplemented!();
|
let ts2 = ts2.unwrap_or(ts1);
|
||||||
|
|
||||||
|
let sqrt5 = 5.0f64.sqrt();
|
||||||
|
|
||||||
|
let r = super::distance(ts1, ts2) / self.l_scale;
|
||||||
|
let r2 = r.mapv(|v| (-sqrt5 * v).exp());
|
||||||
|
let r3 = r.mapv(|v| v.powi(2));
|
||||||
|
|
||||||
|
r2 * (1.0 + sqrt5 * r + (5.0 / 3.0) * r3) * self.var
|
||||||
}
|
}
|
||||||
|
|
||||||
fn k_diag(&self, ts: &[f64]) -> Array1<f64> {
|
fn k_diag(&self, ts: &[f64]) -> Array1<f64> {
|
||||||
@@ -59,6 +67,10 @@ impl Kernel for Matern52 {
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn noise_effect(&self) -> Array2<f64> {
|
||||||
|
array![[0.0], [0.0], [1.0]]
|
||||||
|
}
|
||||||
|
|
||||||
fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
|
fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
|
||||||
let d = t1 - t0;
|
let d = t1 - t0;
|
||||||
let a = self.lambda;
|
let a = self.lambda;
|
||||||
@@ -102,3 +114,77 @@ impl Kernel for Matern52 {
|
|||||||
self.var * array![[x11, x12, x13], [x12, x22, x23], [x13, x23, x33]]
|
self.var * array![[x11, x12, x13], [x12, x22, x23], [x13, x23, x33]]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
use approx::assert_abs_diff_eq;
|
||||||
|
use rand::{distributions::Standard, thread_rng, Rng};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_kernel_matrix() {
|
||||||
|
let kernel = Matern52::new(0.2, 5.0);
|
||||||
|
|
||||||
|
let ts = [1.26, 1.46, 2.67];
|
||||||
|
|
||||||
|
assert_abs_diff_eq!(
|
||||||
|
kernel.k_mat(&ts, None),
|
||||||
|
array![
|
||||||
|
[0.2, 0.19973384192198906, 0.18769647349199411],
|
||||||
|
[0.19973384192198906, 0.2, 0.19077859940208852],
|
||||||
|
[0.18769647349199411, 0.19077859940208852, 0.2]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_kernel_diag() {
|
||||||
|
let kernel = Matern52::new(0.2, 5.0);
|
||||||
|
|
||||||
|
let ts: Vec<_> = thread_rng()
|
||||||
|
.sample_iter::<f64, _>(Standard)
|
||||||
|
.take(10)
|
||||||
|
.map(|x| x * 10.0)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(kernel.k_mat(&ts, None).diag(), kernel.k_diag(&ts));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_kernel_order() {
|
||||||
|
let kernel = Matern52::new(0.2, 5.0);
|
||||||
|
|
||||||
|
let m = kernel.order();
|
||||||
|
|
||||||
|
assert_eq!(kernel.state_mean(0.0).shape(), &[m]);
|
||||||
|
assert_eq!(kernel.state_cov(0.0).shape(), &[m, m]);
|
||||||
|
assert_eq!(kernel.measurement_vector().shape(), &[m]);
|
||||||
|
assert_eq!(kernel.feedback().shape(), &[m, m]);
|
||||||
|
assert_eq!(kernel.noise_effect().shape()[0], m);
|
||||||
|
assert_eq!(kernel.transition(0.0, 1.0).shape(), &[m, m]);
|
||||||
|
assert_eq!(kernel.noise_cov(0.0, 1.0).shape(), &[m, m]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ssm_variance() {
|
||||||
|
let kernel = Matern52::new(0.2, 5.0);
|
||||||
|
|
||||||
|
let ts: Vec<_> = thread_rng()
|
||||||
|
.sample_iter::<f64, _>(Standard)
|
||||||
|
.take(10)
|
||||||
|
.map(|x| x * 10.0)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let h = kernel.measurement_vector();
|
||||||
|
|
||||||
|
let vars = ts
|
||||||
|
.iter()
|
||||||
|
.map(|t| h.dot(&kernel.state_cov(*t)).dot(&h))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user