Fixed a bug, and added a custom solve function (using lapacke).

This commit is contained in:
2020-02-24 11:51:26 +01:00
parent eae980a74c
commit 9b79722984
8 changed files with 144 additions and 18 deletions

View File

@@ -1,7 +1,6 @@
use derivative::Derivative;
use ndarray::prelude::*;
use ndarray::stack;
use ndarray_linalg::Inverse;
use crate::kernel::Kernel;
@@ -101,22 +100,22 @@ impl Fitter for RecursiveFitter {
}
// Compute the new transition and noise covariance matrices.
let m = self.kernel.order();
for _ in 0..n_new {
self.a.push(Array2::zeros((m, m)));
self.q.push(Array2::zeros((m, m)));
}
for i in (self.ts.len() - n_new)..self.ts.len() {
if i == 0 {
continue;
}
self.a
.push(self.kernel.transition(self.ts[i - 1], self.ts[i]));
self.q
.push(self.kernel.noise_cov(self.ts[i - 1], self.ts[i]));
self.a[i - 1] = self.kernel.transition(self.ts[i - 1], self.ts[i]);
self.q[i - 1] = self.kernel.noise_cov(self.ts[i - 1], self.ts[i]);
}
let m = self.kernel.order();
self.a.push(Array2::zeros((m, m)));
self.q.push(Array2::zeros((m, m)));
self.ts_new.clear();
}
@@ -180,9 +179,17 @@ impl Fitter for RecursiveFitter {
self.m_s[i] = self.m_f[i].clone();
self.p_s[i] = self.p_f[i].clone();
} else {
let a = self.p_p[i + 1].clone();
let b = self.a[i].dot(&self.p_f[i]);
// println!("a={:#?}", a);
let g = crate::linalg::solve(a, b);
let g = g.t();
/*
let g = self.a[i]
.dot(&self.p_f[i])
.dot(&self.p_p[i + 1].inv().expect("failed to inverse matrix"));
*/
self.m_s[i] = &self.m_f[i] + &g.dot(&(&self.m_s[i + 1] - &self.m_p[i + 1]));
self.p_s[i] =
@@ -248,11 +255,10 @@ impl Fitter for RecursiveFitter {
// RTS update using the right neighbor.
let a = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]);
let g = a.dot(&p).dot(
&self.p_p[(j + 1) as usize]
.inv()
.expect("failed to inverse matrix"),
);
let a = a.dot(&p);
let b = self.p_p[(j + 1) as usize].clone();
let g = crate::linalg::solve(a, b);
let g = g.t();
ms.push(self.h.dot(
&(&m + &g.dot(&(&self.m_s[(j + 1) as usize] - &self.m_p[(j + 1) as usize]))),