A lot of progress.

This commit is contained in:
2020-02-20 09:19:11 +01:00
parent fd249da405
commit 7528b3b67b
18 changed files with 2191 additions and 110 deletions

View File

@@ -14,7 +14,38 @@ pub trait Kernel {
fn state_mean(&self, t: f64) -> Array1<f64>;
fn state_cov(&self, t: f64) -> Array2<f64>;
fn measurement_vector(&self) -> Array1<f64>;
fn transition(&self, t0: f64, t1: f64) -> Array2<f64>;
fn feedback(&self) -> Array2<f64>;
fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
let f = self.feedback();
let a = f * (t1 - t0);
let mut b = Array2::<f64>::zeros(a.dim());
crate::expm::expm(&a, &mut b);
b
}
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64> {
/*
mat = self.noise_effect.dot(self.noise_density).dot(self.noise_effect.T)
#print(g)
print(mat)
Phi = np.vstack((
np.hstack((self.feedback, mat)),
np.hstack((np.zeros_like(mat), -self.feedback.T))))
print(Phi)
m = self.order
AB = np.dot(sp.linalg.expm(Phi * (t2 - t1)), np.eye(2*m, m, k=-m))
print(AB)
return sp.linalg.solve(AB[m:,:].T, AB[:m,:].T)
*/
// let mat = self.noise_effect()
todo!();
}
}
impl Kernel for Vec<Box<dyn Kernel>> {
@@ -74,7 +105,57 @@ impl Kernel for Vec<Box<dyn Kernel>> {
Array1::from(data)
}
fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
todo!();
fn feedback(&self) -> Array2<f64> {
let data = self
.iter()
.map(|kernel| kernel.feedback())
.collect::<Vec<_>>();
let dim = data
.iter()
.fold((0, 0), |(w, h), m| (w + m.ncols(), h + m.nrows()));
let mut feedback = Array2::zeros(dim);
let mut r_d = 0;
let mut c_d = 0;
for m in data {
for ((r, c), v) in m.indexed_iter() {
feedback[(r + r_d, c + c_d)] = *v;
}
r_d += m.nrows();
c_d += m.ncols();
}
feedback
}
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64> {
let data = self
.iter()
.map(|kernel| kernel.noise_cov(t0, t1))
.collect::<Vec<_>>();
let dim = data
.iter()
.fold((0, 0), |(w, h), m| (w + m.ncols(), h + m.nrows()));
let mut cov = Array2::zeros(dim);
let mut r_d = 0;
let mut c_d = 0;
for m in data {
for ((r, c), v) in m.indexed_iter() {
cov[(r + r_d, c + c_d)] = *v;
}
r_d += m.nrows();
c_d += m.ncols();
}
cov
}
}