More test and fix a bug
This commit is contained in:
155
src/kernel.rs
155
src/kernel.rs
@@ -50,8 +50,14 @@ pub trait Kernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Kernel for Vec<Box<dyn Kernel>> {
|
impl Kernel for Vec<Box<dyn Kernel>> {
|
||||||
fn k_mat(&self, _ts1: &[f64], _ts2: Option<&[f64]>) -> Array2<f64> {
|
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
|
||||||
unimplemented!();
|
let n = ts1.len();
|
||||||
|
let m = ts2.map_or(n, |ts| ts.len());
|
||||||
|
|
||||||
|
self.iter()
|
||||||
|
.fold(Array2::zeros((n, m)), |k_diag: Array2<f64>, kernel| {
|
||||||
|
k_diag + kernel.k_mat(ts1, ts2)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn k_diag(&self, ts: &[f64]) -> Array1<f64> {
|
fn k_diag(&self, ts: &[f64]) -> Array1<f64> {
|
||||||
@@ -82,23 +88,23 @@ impl Kernel for Vec<Box<dyn Kernel>> {
|
|||||||
|
|
||||||
let dim = data
|
let dim = data
|
||||||
.iter()
|
.iter()
|
||||||
.fold((0, 0), |(w, h), m| (w + m.ncols(), h + m.nrows()));
|
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
|
||||||
|
|
||||||
let mut cov = Array2::zeros(dim);
|
let mut block = Array2::zeros(dim);
|
||||||
|
|
||||||
let mut r_d = 0;
|
let mut r_d = 0;
|
||||||
let mut c_d = 0;
|
let mut c_d = 0;
|
||||||
|
|
||||||
for m in data {
|
for m in data {
|
||||||
for ((r, c), v) in m.indexed_iter() {
|
for ((r, c), v) in m.indexed_iter() {
|
||||||
cov[(r + r_d, c + c_d)] = *v;
|
block[(r + r_d, c + c_d)] = *v;
|
||||||
}
|
}
|
||||||
|
|
||||||
r_d += m.nrows();
|
r_d += m.nrows();
|
||||||
c_d += m.ncols();
|
c_d += m.ncols();
|
||||||
}
|
}
|
||||||
|
|
||||||
cov
|
block
|
||||||
}
|
}
|
||||||
|
|
||||||
fn measurement_vector(&self) -> Array1<f64> {
|
fn measurement_vector(&self) -> Array1<f64> {
|
||||||
@@ -118,23 +124,23 @@ impl Kernel for Vec<Box<dyn Kernel>> {
|
|||||||
|
|
||||||
let dim = data
|
let dim = data
|
||||||
.iter()
|
.iter()
|
||||||
.fold((0, 0), |(w, h), m| (w + m.ncols(), h + m.nrows()));
|
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
|
||||||
|
|
||||||
let mut feedback = Array2::zeros(dim);
|
let mut block = Array2::zeros(dim);
|
||||||
|
|
||||||
let mut r_d = 0;
|
let mut r_d = 0;
|
||||||
let mut c_d = 0;
|
let mut c_d = 0;
|
||||||
|
|
||||||
for m in data {
|
for m in data {
|
||||||
for ((r, c), v) in m.indexed_iter() {
|
for ((r, c), v) in m.indexed_iter() {
|
||||||
feedback[(r + r_d, c + c_d)] = *v;
|
block[(r + r_d, c + c_d)] = *v;
|
||||||
}
|
}
|
||||||
|
|
||||||
r_d += m.nrows();
|
r_d += m.nrows();
|
||||||
c_d += m.ncols();
|
c_d += m.ncols();
|
||||||
}
|
}
|
||||||
|
|
||||||
feedback
|
block
|
||||||
}
|
}
|
||||||
|
|
||||||
fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
|
fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
|
||||||
@@ -145,23 +151,50 @@ impl Kernel for Vec<Box<dyn Kernel>> {
|
|||||||
|
|
||||||
let dim = data
|
let dim = data
|
||||||
.iter()
|
.iter()
|
||||||
.fold((0, 0), |(w, h), m| (w + m.ncols(), h + m.nrows()));
|
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
|
||||||
|
|
||||||
let mut transition = Array2::zeros(dim);
|
let mut block = Array2::zeros(dim);
|
||||||
|
|
||||||
let mut r_d = 0;
|
let mut r_d = 0;
|
||||||
let mut c_d = 0;
|
let mut c_d = 0;
|
||||||
|
|
||||||
for m in data {
|
for m in data {
|
||||||
for ((r, c), v) in m.indexed_iter() {
|
for ((r, c), v) in m.indexed_iter() {
|
||||||
transition[(r + r_d, c + c_d)] = *v;
|
block[(r + r_d, c + c_d)] = *v;
|
||||||
}
|
}
|
||||||
|
|
||||||
r_d += m.nrows();
|
r_d += m.nrows();
|
||||||
c_d += m.ncols();
|
c_d += m.ncols();
|
||||||
}
|
}
|
||||||
|
|
||||||
transition
|
block
|
||||||
|
}
|
||||||
|
|
||||||
|
fn noise_effect(&self) -> Array2<f64> {
|
||||||
|
let data = self
|
||||||
|
.iter()
|
||||||
|
.map(|kernel| kernel.noise_effect())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let dim = data
|
||||||
|
.iter()
|
||||||
|
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
|
||||||
|
|
||||||
|
let mut block = Array2::zeros(dim);
|
||||||
|
|
||||||
|
let mut r_d = 0;
|
||||||
|
let mut c_d = 0;
|
||||||
|
|
||||||
|
for m in data {
|
||||||
|
for ((r, c), v) in m.indexed_iter() {
|
||||||
|
block[(r + r_d, c + c_d)] = *v;
|
||||||
|
}
|
||||||
|
|
||||||
|
r_d += m.nrows();
|
||||||
|
c_d += m.ncols();
|
||||||
|
}
|
||||||
|
|
||||||
|
block
|
||||||
}
|
}
|
||||||
|
|
||||||
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64> {
|
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64> {
|
||||||
@@ -172,22 +205,108 @@ impl Kernel for Vec<Box<dyn Kernel>> {
|
|||||||
|
|
||||||
let dim = data
|
let dim = data
|
||||||
.iter()
|
.iter()
|
||||||
.fold((0, 0), |(w, h), m| (w + m.ncols(), h + m.nrows()));
|
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
|
||||||
|
|
||||||
let mut cov = Array2::zeros(dim);
|
let mut block = Array2::zeros(dim);
|
||||||
|
|
||||||
let mut r_d = 0;
|
let mut r_d = 0;
|
||||||
let mut c_d = 0;
|
let mut c_d = 0;
|
||||||
|
|
||||||
for m in data {
|
for m in data {
|
||||||
for ((r, c), v) in m.indexed_iter() {
|
for ((r, c), v) in m.indexed_iter() {
|
||||||
cov[(r + r_d, c + c_d)] = *v;
|
block[(r + r_d, c + c_d)] = *v;
|
||||||
}
|
}
|
||||||
|
|
||||||
r_d += m.nrows();
|
r_d += m.nrows();
|
||||||
c_d += m.ncols();
|
c_d += m.ncols();
|
||||||
}
|
}
|
||||||
|
|
||||||
cov
|
block
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
use approx::assert_abs_diff_eq;
|
||||||
|
use rand::{distributions::Standard, thread_rng, Rng};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_kernel_matrix() {
|
||||||
|
let kernel: Vec<Box<dyn Kernel>> = vec![
|
||||||
|
Box::new(Matern32::new(1.5, 0.7)),
|
||||||
|
Box::new(Matern52::new(0.2, 5.0)),
|
||||||
|
];
|
||||||
|
|
||||||
|
let ts = [1.26, 1.46, 2.67];
|
||||||
|
|
||||||
|
assert_abs_diff_eq!(
|
||||||
|
kernel.k_mat(&ts, None),
|
||||||
|
array![
|
||||||
|
[1.7, 1.5667546855502472, 0.3933043131476467],
|
||||||
|
[1.5667546855502472, 1.7, 0.4908539038626858],
|
||||||
|
[0.3933043131476467, 0.4908539038626858, 1.7]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_kernel_diag() {
|
||||||
|
let kernel: Vec<Box<dyn Kernel>> = vec![
|
||||||
|
Box::new(Matern32::new(1.5, 0.7)),
|
||||||
|
Box::new(Matern52::new(0.2, 5.0)),
|
||||||
|
];
|
||||||
|
|
||||||
|
let ts: Vec<_> = thread_rng()
|
||||||
|
.sample_iter::<f64, _>(Standard)
|
||||||
|
.take(10)
|
||||||
|
.map(|x| x * 10.0)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(kernel.k_mat(&ts, None).diag(), kernel.k_diag(&ts));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_kernel_order() {
|
||||||
|
let kernel: Vec<Box<dyn Kernel>> = vec![
|
||||||
|
Box::new(Matern32::new(1.5, 0.7)),
|
||||||
|
Box::new(Matern52::new(0.2, 5.0)),
|
||||||
|
];
|
||||||
|
|
||||||
|
let m = kernel.order();
|
||||||
|
|
||||||
|
assert_eq!(kernel.state_mean(0.0).shape(), &[m]);
|
||||||
|
assert_eq!(kernel.state_cov(0.0).shape(), &[m, m]);
|
||||||
|
assert_eq!(kernel.measurement_vector().shape(), &[m]);
|
||||||
|
assert_eq!(kernel.feedback().shape(), &[m, m]);
|
||||||
|
assert_eq!(kernel.noise_effect().shape()[0], m);
|
||||||
|
assert_eq!(kernel.transition(0.0, 1.0).shape(), &[m, m]);
|
||||||
|
assert_eq!(kernel.noise_cov(0.0, 1.0).shape(), &[m, m]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ssm_variance() {
|
||||||
|
let kernel: Vec<Box<dyn Kernel>> = vec![
|
||||||
|
Box::new(Matern32::new(1.5, 0.7)),
|
||||||
|
Box::new(Matern52::new(0.2, 5.0)),
|
||||||
|
];
|
||||||
|
|
||||||
|
let ts: Vec<_> = thread_rng()
|
||||||
|
.sample_iter::<f64, _>(Standard)
|
||||||
|
.take(10)
|
||||||
|
.map(|x| x * 10.0)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let h = kernel.measurement_vector();
|
||||||
|
|
||||||
|
let vars = ts
|
||||||
|
.iter()
|
||||||
|
.map(|t| h.dot(&kernel.state_cov(*t)).dot(&h))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user