diff --git a/Cargo.toml b/Cargo.toml index 4ff5f48..46d60f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ license = "MIT OR Apache-2.0" [dependencies] cblas = "0.2" lapacke = "0.2" -ndarray = "0.14" +ndarray = { version = "0.14", features = ["approx"] } ordered-float = "1.0" rand = "0.7" rand_xoshiro = "0.4" diff --git a/src/expm.rs b/src/expm.rs index 153f70b..ad1d078 100644 --- a/src/expm.rs +++ b/src/expm.rs @@ -88,7 +88,7 @@ const PADE_COEFF_13: [f64; 14] = [ fn pade_error_coefficient(m: u64) -> f64 { use statrs::function::factorial::{binomial, factorial}; - return 1.0 / (binomial(2 * m, m) * factorial(2 * m + 1)); + 1.0 / (binomial(2 * m, m) * factorial(2 * m + 1)) } #[allow(non_camel_case_types)] diff --git a/src/fitter/recursive.rs b/src/fitter/recursive.rs index 199a55c..0bfddd7 100644 --- a/src/fitter/recursive.rs +++ b/src/fitter/recursive.rs @@ -227,15 +227,12 @@ impl Fitter for RecursiveFitter { let mut ms = Vec::new(); let mut vs = Vec::new(); - let locations = ts - .iter() - .map(|t| { - self.ts - .iter() - .position(|tc| t <= tc) - .unwrap_or_else(|| self.ts.len()) - }) - .collect::>(); + let locations = ts.iter().map(|t| { + self.ts + .iter() + .position(|tc| t <= tc) + .unwrap_or_else(|| self.ts.len()) + }); for (i, nxt) in locations.into_iter().enumerate() { if nxt == self.ts.len() { diff --git a/src/kernel.rs b/src/kernel.rs index 2dc5983..854f74a 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -11,6 +11,7 @@ pub use matern32::Matern32; pub use matern52::Matern52; pub trait Kernel { + fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> ArrayD; fn k_diag(&self, ts: &[f64]) -> Array1; fn order(&self) -> usize; fn state_mean(&self, t: f64) -> Array1; @@ -18,6 +19,14 @@ pub trait Kernel { fn measurement_vector(&self) -> Array1; fn feedback(&self) -> Array2; + fn noise_effect(&self) -> Array2 { + unimplemented!(); + } + + fn noise_density(&self) -> Array2 { + unimplemented!(); + } + fn transition(&self, t0: f64, t1: f64) -> Array2 { let f = self.feedback(); @@ -51,6 +60,10 @@ pub trait Kernel { } impl Kernel for Vec> { + fn k_mat(&self, _ts1: &[f64], _ts2: Option<&[f64]>) -> ArrayD { + unimplemented!(); + } + fn k_diag(&self, ts: &[f64]) -> Array1 { self.iter() .fold(Array1::zeros(ts.len()), |k_diag: Array1, kernel| { diff --git a/src/kernel/constant.rs b/src/kernel/constant.rs index d8ff469..09cba3f 100644 --- a/src/kernel/constant.rs +++ b/src/kernel/constant.rs @@ -1,4 +1,5 @@ use ndarray::prelude::*; +use ndarray::IxDyn; use super::Kernel; @@ -13,6 +14,13 @@ impl Constant { } impl Kernel for Constant { + fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> ArrayD { + let n = ts1.len(); + let m = ts2.map_or(n, |ts| ts.len()); + + ArrayD::ones(IxDyn(&[n, m])) * self.var + } + fn k_diag(&self, ts: &[f64]) -> Array1 { Array1::ones(ts.len()) * self.var } @@ -37,6 +45,10 @@ impl Kernel for Constant { array![[0.0]] } + fn noise_effect(&self) -> Array2 { + array![[1.0]] + } + fn transition(&self, _t0: f64, _t1: f64) -> Array2 { array![[1.0]] } @@ -45,3 +57,71 @@ impl Kernel for Constant { array![[0.0]] } } + +#[cfg(test)] +mod tests { + use approx::assert_abs_diff_eq; + use rand::{distributions::Standard, thread_rng, Rng}; + + use super::*; + + #[test] + fn test_kernel_matrix() { + let kernel = Constant::new(2.5); + + let ts = [1.26, 1.46, 2.67]; + + assert_abs_diff_eq!( + kernel.k_mat(&ts, None), + array![[2.5, 2.5, 2.5], [2.5, 2.5, 2.5], [2.5, 2.5, 2.5]].into_dyn() + ); + } + + #[test] + fn test_kernel_diag() { + let kernel = Constant::new(2.5); + + let ts: Vec<_> = thread_rng() + .sample_iter::(Standard) + .take(10) + .map(|x| x * 10.0) + .collect(); + + assert_eq!(kernel.k_mat(&ts, None).diag(), kernel.k_diag(&ts)); + } + + #[test] + fn test_kernel_order() { + let kernel = Constant::new(2.5); + + let m = kernel.order(); + + assert_eq!(kernel.state_mean(0.0).shape(), &[m]); + assert_eq!(kernel.state_cov(0.0).shape(), &[m, m]); + assert_eq!(kernel.measurement_vector().shape(), &[m]); + assert_eq!(kernel.feedback().shape(), &[m, m]); + assert_eq!(kernel.noise_effect().shape()[0], m); + assert_eq!(kernel.transition(0.0, 1.0).shape(), &[m, m]); + assert_eq!(kernel.noise_cov(0.0, 1.0).shape(), &[m, m]); + } + + #[test] + fn test_ssm_variance() { + let kernel = Constant::new(2.5); + + let ts: Vec<_> = thread_rng() + .sample_iter::(Standard) + .take(10) + .map(|x| x * 10.0) + .collect(); + + let h = kernel.measurement_vector(); + + let vars = ts + .iter() + .map(|t| h.dot(&kernel.state_cov(*t)).dot(&h)) + .collect::>(); + + assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts)); + } +} diff --git a/src/kernel/exponential.rs b/src/kernel/exponential.rs index 92ef02b..83ecd9a 100644 --- a/src/kernel/exponential.rs +++ b/src/kernel/exponential.rs @@ -14,6 +14,10 @@ impl Exponential { } impl Kernel for Exponential { + fn k_mat(&self, _ts1: &[f64], _ts2: Option<&[f64]>) -> ArrayD { + unimplemented!(); + } + fn k_diag(&self, ts: &[f64]) -> Array1 { Array1::ones(ts.len()) * self.var } diff --git a/src/kernel/matern32.rs b/src/kernel/matern32.rs index afe7df5..8b8037d 100644 --- a/src/kernel/matern32.rs +++ b/src/kernel/matern32.rs @@ -20,6 +20,10 @@ impl Matern32 { } impl Kernel for Matern32 { + fn k_mat(&self, _ts1: &[f64], _ts2: Option<&[f64]>) -> ArrayD { + unimplemented!(); + } + fn k_diag(&self, ts: &[f64]) -> Array1 { Array1::ones(ts.len()) * self.var } diff --git a/src/kernel/matern52.rs b/src/kernel/matern52.rs index 9d0ff76..cfe3c43 100644 --- a/src/kernel/matern52.rs +++ b/src/kernel/matern52.rs @@ -19,6 +19,10 @@ impl Matern52 { } impl Kernel for Matern52 { + fn k_mat(&self, _ts1: &[f64], _ts2: Option<&[f64]>) -> ArrayD { + unimplemented!(); + } + fn k_diag(&self, ts: &[f64]) -> Array1 { Array1::ones(ts.len()) * self.var }