Impl kernel for [Box<dyn Kernel>; N]

This commit is contained in:
2021-10-27 11:30:59 +02:00
parent 37746f6c02
commit 1eea1bfb71
2 changed files with 177 additions and 1 deletions

View File

@@ -6,7 +6,7 @@ fn main() {
let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit);
for player in &["A", "B", "C", "D", "E", "F"] { for player in &["A", "B", "C", "D", "E", "F"] {
let kernel: Vec<Box<dyn ks::Kernel>> = vec![ let kernel: [Box<dyn ks::Kernel>; 2] = [
Box::new(ks::kernel::Constant::new(1.0)), Box::new(ks::kernel::Constant::new(1.0)),
Box::new(ks::kernel::Matern52::new(0.5, 1.0)), Box::new(ks::kernel::Matern52::new(0.5, 1.0)),
]; ];

View File

@@ -215,6 +215,182 @@ impl Kernel for Vec<Box<dyn Kernel>> {
} }
} }
impl<const N: usize> Kernel for [Box<dyn Kernel>; N] {
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
let n = ts1.len();
let m = ts2.map_or(n, |ts| ts.len());
self.iter()
.fold(Array2::zeros((n, m)), |k_diag: Array2<f64>, kernel| {
k_diag + kernel.k_mat(ts1, ts2)
})
}
fn k_diag(&self, ts: &[f64]) -> Array1<f64> {
self.iter()
.fold(Array1::zeros(ts.len()), |k_diag: Array1<f64>, kernel| {
k_diag + kernel.k_diag(ts)
})
}
fn order(&self) -> usize {
self.iter().map(|kernel| kernel.order()).sum()
}
fn state_mean(&self, t: f64) -> Array1<f64> {
let data = self
.iter()
.flat_map(|kernel| kernel.state_mean(t).to_vec().into_iter())
.collect::<Vec<f64>>();
Array1::from(data)
}
fn state_cov(&self, t: f64) -> Array2<f64> {
let data = self
.iter()
.map(|kernel| kernel.state_cov(t))
.collect::<Vec<_>>();
let dim = data
.iter()
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
let mut block = Array2::zeros(dim);
let mut r_d = 0;
let mut c_d = 0;
for m in data {
for ((r, c), v) in m.indexed_iter() {
block[(r + r_d, c + c_d)] = *v;
}
r_d += m.nrows();
c_d += m.ncols();
}
block
}
fn measurement_vector(&self) -> Array1<f64> {
let data = self
.iter()
.flat_map(|kernel| kernel.measurement_vector().to_vec().into_iter())
.collect::<Vec<f64>>();
Array1::from(data)
}
fn feedback(&self) -> Array2<f64> {
let data = self
.iter()
.map(|kernel| kernel.feedback())
.collect::<Vec<_>>();
let dim = data
.iter()
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
let mut block = Array2::zeros(dim);
let mut r_d = 0;
let mut c_d = 0;
for m in data {
for ((r, c), v) in m.indexed_iter() {
block[(r + r_d, c + c_d)] = *v;
}
r_d += m.nrows();
c_d += m.ncols();
}
block
}
fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
let data = self
.iter()
.map(|kernel| kernel.transition(t0, t1))
.collect::<Vec<_>>();
let dim = data
.iter()
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
let mut block = Array2::zeros(dim);
let mut r_d = 0;
let mut c_d = 0;
for m in data {
for ((r, c), v) in m.indexed_iter() {
block[(r + r_d, c + c_d)] = *v;
}
r_d += m.nrows();
c_d += m.ncols();
}
block
}
fn noise_effect(&self) -> Array2<f64> {
let data = self
.iter()
.map(|kernel| kernel.noise_effect())
.collect::<Vec<_>>();
let dim = data
.iter()
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
let mut block = Array2::zeros(dim);
let mut r_d = 0;
let mut c_d = 0;
for m in data {
for ((r, c), v) in m.indexed_iter() {
block[(r + r_d, c + c_d)] = *v;
}
r_d += m.nrows();
c_d += m.ncols();
}
block
}
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64> {
let data = self
.iter()
.map(|kernel| kernel.noise_cov(t0, t1))
.collect::<Vec<_>>();
let dim = data
.iter()
.fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols()));
let mut block = Array2::zeros(dim);
let mut r_d = 0;
let mut c_d = 0;
for m in data {
for ((r, c), v) in m.indexed_iter() {
block[(r + r_d, c + c_d)] = *v;
}
r_d += m.nrows();
c_d += m.ncols();
}
block
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
extern crate blas_src; extern crate blas_src;