Added test for linalg::solve. Fix bug in fitter (wrong argument order)
This commit is contained in:
@@ -264,9 +264,10 @@ impl<K: Kernel> Fitter for Recursive<K> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// RTS update using the right neighbor.
|
// RTS update using the right neighbor.
|
||||||
let a = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]);
|
let a = self.p_p[(j + 1) as usize].clone();
|
||||||
let a = a.dot(&p);
|
|
||||||
let b = self.p_p[(j + 1) as usize].clone();
|
let b = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]);
|
||||||
|
let b = b.dot(&p);
|
||||||
|
|
||||||
let g = crate::linalg::solve(a, b);
|
let g = crate::linalg::solve(a, b);
|
||||||
let g = g.t();
|
let g = g.t();
|
||||||
@@ -327,7 +328,6 @@ mod tests {
|
|||||||
extern crate blas_src;
|
extern crate blas_src;
|
||||||
|
|
||||||
use approx::assert_relative_eq;
|
use approx::assert_relative_eq;
|
||||||
use rand::{distributions::Standard, thread_rng, Rng};
|
|
||||||
|
|
||||||
use crate::kernel::Matern32;
|
use crate::kernel::Matern32;
|
||||||
|
|
||||||
@@ -463,7 +463,7 @@ mod tests {
|
|||||||
fitter.fit();
|
fitter.fit();
|
||||||
|
|
||||||
// Estimation
|
// Estimation
|
||||||
eprintln!("{:#?}", fitter.ms);
|
// eprintln!("{:#?}", fitter.ms);
|
||||||
|
|
||||||
assert_relative_eq!(
|
assert_relative_eq!(
|
||||||
Array1::from(fitter.ms.clone()),
|
Array1::from(fitter.ms.clone()),
|
||||||
@@ -479,8 +479,8 @@ mod tests {
|
|||||||
// Prediction.
|
// Prediction.
|
||||||
let (ms, vs) = fitter.predict(data_ts_pred());
|
let (ms, vs) = fitter.predict(data_ts_pred());
|
||||||
|
|
||||||
assert_relative_eq!(Array1::from(ms), data_mean_pred(), max_relative = 0.0000001);
|
assert_relative_eq!(Array1::from(ms), data_mean_pred(), max_relative = 0.000001);
|
||||||
assert_relative_eq!(Array1::from(vs), data_var_pred(), max_relative = 0.0000001);
|
assert_relative_eq!(Array1::from(vs), data_var_pred(), max_relative = 0.000001);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
# Log-likelihood.
|
# Log-likelihood.
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ pub fn solve(mut a: Array2<f64>, mut b: Array2<f64>) -> Array2<f64> {
|
|||||||
let mut ipiv = vec![0; n as usize];
|
let mut ipiv = vec![0; n as usize];
|
||||||
|
|
||||||
let (a_slice, layout) = as_slice_with_layout_mut(&mut a).expect("Matrix `a` not contiguous.");
|
let (a_slice, layout) = as_slice_with_layout_mut(&mut a).expect("Matrix `a` not contiguous.");
|
||||||
let (b_slice, _) = as_slice_with_layout_mut(&mut b).expect("Matrix `a` not contiguous.");
|
let (b_slice, _) = as_slice_with_layout_mut(&mut b).expect("Matrix `b` not contiguous.");
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
info = lapacke::dgesv(layout, n, nrhs, a_slice, lda, &mut ipiv, b_slice, ldb);
|
info = lapacke::dgesv(layout, n, nrhs, a_slice, lda, &mut ipiv, b_slice, ldb);
|
||||||
@@ -39,3 +39,27 @@ pub fn solve(mut a: Array2<f64>, mut b: Array2<f64>) -> Array2<f64> {
|
|||||||
|
|
||||||
b
|
b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use approx::assert_relative_eq;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_solve() {
|
||||||
|
let a = array![[2.0, 0.0], [0.0, 6.0]];
|
||||||
|
let b = array![[1.9645574, 0.56996938], [-0.56996938, 3.91924033]];
|
||||||
|
|
||||||
|
let e = array![[0.9822787, 0.28498469], [-0.0949949, 0.65320672]];
|
||||||
|
|
||||||
|
assert_relative_eq!(solve(a, b), e, max_relative = 0.0000001);
|
||||||
|
|
||||||
|
let a = array![[1.22650294, 0.67685001], [0.67685001, 5.34978493]];
|
||||||
|
let b = array![[1.00213885, 1.47494613], [-0.03356356, 2.14214437]];
|
||||||
|
|
||||||
|
let e = array![[0.88212208, 1.0552697], [-0.11787911, 0.26690514]];
|
||||||
|
|
||||||
|
assert_relative_eq!(solve(a, b), e, max_relative = 0.0000001);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user