diff --git a/examples/abcdef.rs b/examples/abcdef.rs index b59a7bb..1a14c9d 100644 --- a/examples/abcdef.rs +++ b/examples/abcdef.rs @@ -6,7 +6,7 @@ fn main() { let mut model = ks::model::Binary::new(ks::model::binary::Observation::Probit); for player in &["A", "B", "C", "D", "E", "F"] { - let kernel: Vec> = vec![ + let kernel: [Box; 2] = [ Box::new(ks::kernel::Constant::new(1.0)), Box::new(ks::kernel::Matern52::new(0.5, 1.0)), ]; diff --git a/src/kernel.rs b/src/kernel.rs index 544d217..9fc8ca9 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -215,6 +215,182 @@ impl Kernel for Vec> { } } +impl Kernel for [Box; N] { + fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2 { + let n = ts1.len(); + let m = ts2.map_or(n, |ts| ts.len()); + + self.iter() + .fold(Array2::zeros((n, m)), |k_diag: Array2, kernel| { + k_diag + kernel.k_mat(ts1, ts2) + }) + } + + fn k_diag(&self, ts: &[f64]) -> Array1 { + self.iter() + .fold(Array1::zeros(ts.len()), |k_diag: Array1, kernel| { + k_diag + kernel.k_diag(ts) + }) + } + + fn order(&self) -> usize { + self.iter().map(|kernel| kernel.order()).sum() + } + + fn state_mean(&self, t: f64) -> Array1 { + let data = self + .iter() + .flat_map(|kernel| kernel.state_mean(t).to_vec().into_iter()) + .collect::>(); + + Array1::from(data) + } + + fn state_cov(&self, t: f64) -> Array2 { + let data = self + .iter() + .map(|kernel| kernel.state_cov(t)) + .collect::>(); + + let dim = data + .iter() + .fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols())); + + let mut block = Array2::zeros(dim); + + let mut r_d = 0; + let mut c_d = 0; + + for m in data { + for ((r, c), v) in m.indexed_iter() { + block[(r + r_d, c + c_d)] = *v; + } + + r_d += m.nrows(); + c_d += m.ncols(); + } + + block + } + + fn measurement_vector(&self) -> Array1 { + let data = self + .iter() + .flat_map(|kernel| kernel.measurement_vector().to_vec().into_iter()) + .collect::>(); + + Array1::from(data) + } + + fn feedback(&self) -> Array2 { + let data = self + .iter() + .map(|kernel| kernel.feedback()) + .collect::>(); + + let dim = data + .iter() + .fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols())); + + let mut block = Array2::zeros(dim); + + let mut r_d = 0; + let mut c_d = 0; + + for m in data { + for ((r, c), v) in m.indexed_iter() { + block[(r + r_d, c + c_d)] = *v; + } + + r_d += m.nrows(); + c_d += m.ncols(); + } + + block + } + + fn transition(&self, t0: f64, t1: f64) -> Array2 { + let data = self + .iter() + .map(|kernel| kernel.transition(t0, t1)) + .collect::>(); + + let dim = data + .iter() + .fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols())); + + let mut block = Array2::zeros(dim); + + let mut r_d = 0; + let mut c_d = 0; + + for m in data { + for ((r, c), v) in m.indexed_iter() { + block[(r + r_d, c + c_d)] = *v; + } + + r_d += m.nrows(); + c_d += m.ncols(); + } + + block + } + + fn noise_effect(&self) -> Array2 { + let data = self + .iter() + .map(|kernel| kernel.noise_effect()) + .collect::>(); + + let dim = data + .iter() + .fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols())); + + let mut block = Array2::zeros(dim); + + let mut r_d = 0; + let mut c_d = 0; + + for m in data { + for ((r, c), v) in m.indexed_iter() { + block[(r + r_d, c + c_d)] = *v; + } + + r_d += m.nrows(); + c_d += m.ncols(); + } + + block + } + + fn noise_cov(&self, t0: f64, t1: f64) -> Array2 { + let data = self + .iter() + .map(|kernel| kernel.noise_cov(t0, t1)) + .collect::>(); + + let dim = data + .iter() + .fold((0, 0), |(h, w), m| (h + m.nrows(), w + m.ncols())); + + let mut block = Array2::zeros(dim); + + let mut r_d = 0; + let mut c_d = 0; + + for m in data { + for ((r, c), v) in m.indexed_iter() { + block[(r + r_d, c + c_d)] = *v; + } + + r_d += m.nrows(); + c_d += m.ncols(); + } + + block + } +} + #[cfg(test)] mod tests { extern crate blas_src;