From 347af1c908b3382b3762cfc442fc84716ada93c1 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Thu, 28 Apr 2022 10:23:37 +0200 Subject: [PATCH] Added test for linalg::solve. Fix bug in fitter (wrong argument order) --- src/fitter/recursive.rs | 14 +++++++------- src/linalg.rs | 26 +++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/src/fitter/recursive.rs b/src/fitter/recursive.rs index ee65beb..44326d7 100644 --- a/src/fitter/recursive.rs +++ b/src/fitter/recursive.rs @@ -264,9 +264,10 @@ impl Fitter for Recursive { }; // RTS update using the right neighbor. - let a = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]); - let a = a.dot(&p); - let b = self.p_p[(j + 1) as usize].clone(); + let a = self.p_p[(j + 1) as usize].clone(); + + let b = self.kernel.transition(ts[i], self.ts[(j + 1) as usize]); + let b = b.dot(&p); let g = crate::linalg::solve(a, b); let g = g.t(); @@ -327,7 +328,6 @@ mod tests { extern crate blas_src; use approx::assert_relative_eq; - use rand::{distributions::Standard, thread_rng, Rng}; use crate::kernel::Matern32; @@ -463,7 +463,7 @@ mod tests { fitter.fit(); // Estimation - eprintln!("{:#?}", fitter.ms); + // eprintln!("{:#?}", fitter.ms); assert_relative_eq!( Array1::from(fitter.ms.clone()), @@ -479,8 +479,8 @@ mod tests { // Prediction. let (ms, vs) = fitter.predict(data_ts_pred()); - assert_relative_eq!(Array1::from(ms), data_mean_pred(), max_relative = 0.0000001); - assert_relative_eq!(Array1::from(vs), data_var_pred(), max_relative = 0.0000001); + assert_relative_eq!(Array1::from(ms), data_mean_pred(), max_relative = 0.000001); + assert_relative_eq!(Array1::from(vs), data_var_pred(), max_relative = 0.000001); /* # Log-likelihood. diff --git a/src/linalg.rs b/src/linalg.rs index 1d4a5dd..9127cbf 100644 --- a/src/linalg.rs +++ b/src/linalg.rs @@ -29,7 +29,7 @@ pub fn solve(mut a: Array2, mut b: Array2) -> Array2 { let mut ipiv = vec![0; n as usize]; let (a_slice, layout) = as_slice_with_layout_mut(&mut a).expect("Matrix `a` not contiguous."); - let (b_slice, _) = as_slice_with_layout_mut(&mut b).expect("Matrix `a` not contiguous."); + let (b_slice, _) = as_slice_with_layout_mut(&mut b).expect("Matrix `b` not contiguous."); unsafe { info = lapacke::dgesv(layout, n, nrhs, a_slice, lda, &mut ipiv, b_slice, ldb); @@ -39,3 +39,27 @@ pub fn solve(mut a: Array2, mut b: Array2) -> Array2 { b } + +#[cfg(test)] +mod tests { + use approx::assert_relative_eq; + + use super::*; + + #[test] + fn test_solve() { + let a = array![[2.0, 0.0], [0.0, 6.0]]; + let b = array![[1.9645574, 0.56996938], [-0.56996938, 3.91924033]]; + + let e = array![[0.9822787, 0.28498469], [-0.0949949, 0.65320672]]; + + assert_relative_eq!(solve(a, b), e, max_relative = 0.0000001); + + let a = array![[1.22650294, 0.67685001], [0.67685001, 5.34978493]]; + let b = array![[1.00213885, 1.47494613], [-0.03356356, 2.14214437]]; + + let e = array![[0.88212208, 1.0552697], [-0.11787911, 0.26690514]]; + + assert_relative_eq!(solve(a, b), e, max_relative = 0.0000001); + } +}