Added Kernel impl for tuples
This commit is contained in:
236
src/kernel.rs
236
src/kernel.rs
@@ -1,3 +1,5 @@
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use ndarray::prelude::*;
|
||||
|
||||
mod affine;
|
||||
@@ -417,6 +419,19 @@ mod tests {
|
||||
[0.3933043131476467, 0.4908539038626858, 1.7]
|
||||
]
|
||||
);
|
||||
|
||||
let kernel = (Matern32::new(1.5, 0.7), Matern52::new(0.2, 5.0));
|
||||
|
||||
let ts = [1.26, 1.46, 2.67];
|
||||
|
||||
assert_abs_diff_eq!(
|
||||
kernel.k_mat(&ts, None),
|
||||
array![
|
||||
[1.7, 1.5667546855502472, 0.3933043131476467],
|
||||
[1.5667546855502472, 1.7, 0.4908539038626858],
|
||||
[0.3933043131476467, 0.4908539038626858, 1.7]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -433,6 +448,16 @@ mod tests {
|
||||
.collect();
|
||||
|
||||
assert_eq!(kernel.k_mat(&ts, None).diag(), kernel.k_diag(&ts));
|
||||
|
||||
let kernel = (Matern32::new(1.5, 0.7), Matern52::new(0.2, 5.0));
|
||||
|
||||
let ts: Vec<_> = thread_rng()
|
||||
.sample_iter::<f64, _>(Standard)
|
||||
.take(10)
|
||||
.map(|x| x * 10.0)
|
||||
.collect();
|
||||
|
||||
assert_eq!(kernel.k_mat(&ts, None).diag(), kernel.k_diag(&ts));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -451,6 +476,18 @@ mod tests {
|
||||
assert_eq!(kernel.noise_effect().shape()[0], m);
|
||||
assert_eq!(kernel.transition(0.0, 1.0).shape(), &[m, m]);
|
||||
assert_eq!(kernel.noise_cov(0.0, 1.0).shape(), &[m, m]);
|
||||
|
||||
let kernel = (Matern32::new(1.5, 0.7), Matern52::new(0.2, 5.0));
|
||||
|
||||
let m = kernel.order();
|
||||
|
||||
assert_eq!(kernel.state_mean(0.0).shape(), &[m]);
|
||||
assert_eq!(kernel.state_cov(0.0).shape(), &[m, m]);
|
||||
assert_eq!(kernel.measurement_vector().shape(), &[m]);
|
||||
assert_eq!(kernel.feedback().shape(), &[m, m]);
|
||||
assert_eq!(kernel.noise_effect().shape()[0], m);
|
||||
assert_eq!(kernel.transition(0.0, 1.0).shape(), &[m, m]);
|
||||
assert_eq!(kernel.noise_cov(0.0, 1.0).shape(), &[m, m]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -474,5 +511,204 @@ mod tests {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts));
|
||||
|
||||
let kernel = (Matern32::new(1.5, 0.7), Matern52::new(0.2, 5.0));
|
||||
|
||||
let ts: Vec<_> = thread_rng()
|
||||
.sample_iter::<f64, _>(Standard)
|
||||
.take(10)
|
||||
.map(|x| x * 10.0)
|
||||
.collect();
|
||||
|
||||
let h = kernel.measurement_vector();
|
||||
|
||||
let vars = ts
|
||||
.iter()
|
||||
.map(|t| h.dot(&kernel.state_cov(*t)).dot(&h))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts));
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! tuple_impls {
|
||||
( $( $name:ident )+ ) => {
|
||||
impl<$($name: Kernel),+> Kernel for ($($name,)+) {
|
||||
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
|
||||
let n = ts1.len();
|
||||
let m = ts2.map_or(n, |ts| ts.len());
|
||||
|
||||
let ($($name,)+) = &self;
|
||||
|
||||
Array2::zeros((n, m)) $(+$name.k_mat(ts1, ts2))+
|
||||
}
|
||||
|
||||
fn k_diag(&self, ts: &[f64]) -> Array1<f64> {
|
||||
let ($($name,)+) = &self;
|
||||
|
||||
Array1::zeros(ts.len()) $(+$name.k_diag(ts))+
|
||||
}
|
||||
|
||||
fn order(&self) -> usize {
|
||||
let ($($name,)+) = &self;
|
||||
|
||||
0 $(+$name.order())+
|
||||
}
|
||||
|
||||
fn state_mean(&self, t: f64) -> Array1<f64> {
|
||||
let ($($name,)+) = &self;
|
||||
|
||||
Array1::from_iter([$($name.state_mean(t).into_iter(),)+].into_iter().flatten().cloned())
|
||||
}
|
||||
|
||||
fn state_cov(&self, t: f64) -> Array2<f64> {
|
||||
let ($($name,)+) = &self;
|
||||
|
||||
let data = [$($name.state_cov(t),)+];
|
||||
|
||||
let dim = data
|
||||
.iter()
|
||||
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
|
||||
|
||||
let mut block = Array2::zeros(dim);
|
||||
|
||||
let mut r_d = 0;
|
||||
let mut c_d = 0;
|
||||
|
||||
for m in data {
|
||||
for ((r, c), v) in m.indexed_iter() {
|
||||
block[(r + r_d, c + c_d)] = *v;
|
||||
}
|
||||
|
||||
r_d += m.nrows();
|
||||
c_d += m.ncols();
|
||||
}
|
||||
|
||||
block
|
||||
}
|
||||
|
||||
fn measurement_vector(&self) -> Array1<f64> {
|
||||
let ($($name,)+) = &self;
|
||||
|
||||
Array1::from_iter([$($name.measurement_vector().into_iter(),)+].into_iter().flatten().cloned())
|
||||
}
|
||||
|
||||
fn feedback(&self) -> Array2<f64> {
|
||||
let ($($name,)+) = &self;
|
||||
|
||||
let data = [$($name.feedback(),)+];
|
||||
|
||||
let dim = data
|
||||
.iter()
|
||||
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
|
||||
|
||||
let mut block = Array2::zeros(dim);
|
||||
|
||||
let mut r_d = 0;
|
||||
let mut c_d = 0;
|
||||
|
||||
for m in data {
|
||||
for ((r, c), v) in m.indexed_iter() {
|
||||
block[(r + r_d, c + c_d)] = *v;
|
||||
}
|
||||
|
||||
r_d += m.nrows();
|
||||
c_d += m.ncols();
|
||||
}
|
||||
|
||||
block
|
||||
}
|
||||
|
||||
fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
|
||||
let ($($name,)+) = &self;
|
||||
|
||||
let data = [$($name.transition(t0, t1),)+];
|
||||
|
||||
let dim = data
|
||||
.iter()
|
||||
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
|
||||
|
||||
let mut block = Array2::zeros(dim);
|
||||
|
||||
let mut r_d = 0;
|
||||
let mut c_d = 0;
|
||||
|
||||
for m in data {
|
||||
for ((r, c), v) in m.indexed_iter() {
|
||||
block[(r + r_d, c + c_d)] = *v;
|
||||
}
|
||||
|
||||
r_d += m.nrows();
|
||||
c_d += m.ncols();
|
||||
}
|
||||
|
||||
block
|
||||
}
|
||||
|
||||
fn noise_effect(&self) -> Array2<f64> {
|
||||
let ($($name,)+) = &self;
|
||||
|
||||
let data = [$($name.noise_effect(),)+];
|
||||
|
||||
let dim = data
|
||||
.iter()
|
||||
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
|
||||
|
||||
let mut block = Array2::zeros(dim);
|
||||
|
||||
let mut r_d = 0;
|
||||
let mut c_d = 0;
|
||||
|
||||
for m in data {
|
||||
for ((r, c), v) in m.indexed_iter() {
|
||||
block[(r + r_d, c + c_d)] = *v;
|
||||
}
|
||||
|
||||
r_d += m.nrows();
|
||||
c_d += m.ncols();
|
||||
}
|
||||
|
||||
block
|
||||
}
|
||||
|
||||
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64> {
|
||||
let ($($name,)+) = &self;
|
||||
|
||||
let data = [$($name.noise_cov(t0, t1),)+];
|
||||
|
||||
let dim = data
|
||||
.iter()
|
||||
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
|
||||
|
||||
let mut block = Array2::zeros(dim);
|
||||
|
||||
let mut r_d = 0;
|
||||
let mut c_d = 0;
|
||||
|
||||
for m in data {
|
||||
for ((r, c), v) in m.indexed_iter() {
|
||||
block[(r + r_d, c + c_d)] = *v;
|
||||
}
|
||||
|
||||
r_d += m.nrows();
|
||||
c_d += m.ncols();
|
||||
}
|
||||
|
||||
block
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
tuple_impls! { A }
|
||||
tuple_impls! { A B }
|
||||
tuple_impls! { A B C }
|
||||
tuple_impls! { A B C D }
|
||||
tuple_impls! { A B C D E }
|
||||
tuple_impls! { A B C D E F }
|
||||
tuple_impls! { A B C D E F G }
|
||||
tuple_impls! { A B C D E F G H }
|
||||
tuple_impls! { A B C D E F G H I }
|
||||
tuple_impls! { A B C D E F G H I J }
|
||||
tuple_impls! { A B C D E F G H I J K }
|
||||
tuple_impls! { A B C D E F G H I J K L }
|
||||
|
||||
Reference in New Issue
Block a user