From 6f90aa8170d0639f1dca7f534a16fd50ef7b8a8f Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Tue, 26 Apr 2022 22:41:13 +0200 Subject: [PATCH] Added Kernel impl for tuples --- src/kernel.rs | 236 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 236 insertions(+) diff --git a/src/kernel.rs b/src/kernel.rs index 9fc8ca9..3ca1bd3 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -1,3 +1,5 @@ +#![allow(non_snake_case)] + use ndarray::prelude::*; mod affine; @@ -417,6 +419,19 @@ mod tests { [0.3933043131476467, 0.4908539038626858, 1.7] ] ); + + let kernel = (Matern32::new(1.5, 0.7), Matern52::new(0.2, 5.0)); + + let ts = [1.26, 1.46, 2.67]; + + assert_abs_diff_eq!( + kernel.k_mat(&ts, None), + array![ + [1.7, 1.5667546855502472, 0.3933043131476467], + [1.5667546855502472, 1.7, 0.4908539038626858], + [0.3933043131476467, 0.4908539038626858, 1.7] + ] + ); } #[test] @@ -433,6 +448,16 @@ mod tests { .collect(); assert_eq!(kernel.k_mat(&ts, None).diag(), kernel.k_diag(&ts)); + + let kernel = (Matern32::new(1.5, 0.7), Matern52::new(0.2, 5.0)); + + let ts: Vec<_> = thread_rng() + .sample_iter::(Standard) + .take(10) + .map(|x| x * 10.0) + .collect(); + + assert_eq!(kernel.k_mat(&ts, None).diag(), kernel.k_diag(&ts)); } #[test] @@ -451,6 +476,18 @@ mod tests { assert_eq!(kernel.noise_effect().shape()[0], m); assert_eq!(kernel.transition(0.0, 1.0).shape(), &[m, m]); assert_eq!(kernel.noise_cov(0.0, 1.0).shape(), &[m, m]); + + let kernel = (Matern32::new(1.5, 0.7), Matern52::new(0.2, 5.0)); + + let m = kernel.order(); + + assert_eq!(kernel.state_mean(0.0).shape(), &[m]); + assert_eq!(kernel.state_cov(0.0).shape(), &[m, m]); + assert_eq!(kernel.measurement_vector().shape(), &[m]); + assert_eq!(kernel.feedback().shape(), &[m, m]); + assert_eq!(kernel.noise_effect().shape()[0], m); + assert_eq!(kernel.transition(0.0, 1.0).shape(), &[m, m]); + assert_eq!(kernel.noise_cov(0.0, 1.0).shape(), &[m, m]); } #[test] @@ -474,5 +511,204 @@ mod tests { .collect::>(); assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts)); + + let kernel = (Matern32::new(1.5, 0.7), Matern52::new(0.2, 5.0)); + + let ts: Vec<_> = thread_rng() + .sample_iter::(Standard) + .take(10) + .map(|x| x * 10.0) + .collect(); + + let h = kernel.measurement_vector(); + + let vars = ts + .iter() + .map(|t| h.dot(&kernel.state_cov(*t)).dot(&h)) + .collect::>(); + + assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts)); } } + +macro_rules! tuple_impls { + ( $( $name:ident )+ ) => { + impl<$($name: Kernel),+> Kernel for ($($name,)+) { + fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2 { + let n = ts1.len(); + let m = ts2.map_or(n, |ts| ts.len()); + + let ($($name,)+) = &self; + + Array2::zeros((n, m)) $(+$name.k_mat(ts1, ts2))+ + } + + fn k_diag(&self, ts: &[f64]) -> Array1 { + let ($($name,)+) = &self; + + Array1::zeros(ts.len()) $(+$name.k_diag(ts))+ + } + + fn order(&self) -> usize { + let ($($name,)+) = &self; + + 0 $(+$name.order())+ + } + + fn state_mean(&self, t: f64) -> Array1 { + let ($($name,)+) = &self; + + Array1::from_iter([$($name.state_mean(t).into_iter(),)+].into_iter().flatten().cloned()) + } + + fn state_cov(&self, t: f64) -> Array2 { + let ($($name,)+) = &self; + + let data = [$($name.state_cov(t),)+]; + + let dim = data + .iter() + .fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols())); + + let mut block = Array2::zeros(dim); + + let mut r_d = 0; + let mut c_d = 0; + + for m in data { + for ((r, c), v) in m.indexed_iter() { + block[(r + r_d, c + c_d)] = *v; + } + + r_d += m.nrows(); + c_d += m.ncols(); + } + + block + } + + fn measurement_vector(&self) -> Array1 { + let ($($name,)+) = &self; + + Array1::from_iter([$($name.measurement_vector().into_iter(),)+].into_iter().flatten().cloned()) + } + + fn feedback(&self) -> Array2 { + let ($($name,)+) = &self; + + let data = [$($name.feedback(),)+]; + + let dim = data + .iter() + .fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols())); + + let mut block = Array2::zeros(dim); + + let mut r_d = 0; + let mut c_d = 0; + + for m in data { + for ((r, c), v) in m.indexed_iter() { + block[(r + r_d, c + c_d)] = *v; + } + + r_d += m.nrows(); + c_d += m.ncols(); + } + + block + } + + fn transition(&self, t0: f64, t1: f64) -> Array2 { + let ($($name,)+) = &self; + + let data = [$($name.transition(t0, t1),)+]; + + let dim = data + .iter() + .fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols())); + + let mut block = Array2::zeros(dim); + + let mut r_d = 0; + let mut c_d = 0; + + for m in data { + for ((r, c), v) in m.indexed_iter() { + block[(r + r_d, c + c_d)] = *v; + } + + r_d += m.nrows(); + c_d += m.ncols(); + } + + block + } + + fn noise_effect(&self) -> Array2 { + let ($($name,)+) = &self; + + let data = [$($name.noise_effect(),)+]; + + let dim = data + .iter() + .fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols())); + + let mut block = Array2::zeros(dim); + + let mut r_d = 0; + let mut c_d = 0; + + for m in data { + for ((r, c), v) in m.indexed_iter() { + block[(r + r_d, c + c_d)] = *v; + } + + r_d += m.nrows(); + c_d += m.ncols(); + } + + block + } + + fn noise_cov(&self, t0: f64, t1: f64) -> Array2 { + let ($($name,)+) = &self; + + let data = [$($name.noise_cov(t0, t1),)+]; + + let dim = data + .iter() + .fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols())); + + let mut block = Array2::zeros(dim); + + let mut r_d = 0; + let mut c_d = 0; + + for m in data { + for ((r, c), v) in m.indexed_iter() { + block[(r + r_d, c + c_d)] = *v; + } + + r_d += m.nrows(); + c_d += m.ncols(); + } + + block + } + } + }; +} + +tuple_impls! { A } +tuple_impls! { A B } +tuple_impls! { A B C } +tuple_impls! { A B C D } +tuple_impls! { A B C D E } +tuple_impls! { A B C D E F } +tuple_impls! { A B C D E F G } +tuple_impls! { A B C D E F G H } +tuple_impls! { A B C D E F G H I } +tuple_impls! { A B C D E F G H I J } +tuple_impls! { A B C D E F G H I J K } +tuple_impls! { A B C D E F G H I J K L }