66 lines
1.8 KiB
Rust
66 lines
1.8 KiB
Rust
use ndarray::{prelude::*, DataMut};
|
|
|
|
fn as_slice_with_layout_mut<S, T, D>(a: &mut ArrayBase<S, D>) -> Option<(&mut [T], lapacke::Layout)>
|
|
where
|
|
S: DataMut<Elem = T>,
|
|
D: Dimension,
|
|
{
|
|
if a.as_slice_mut().is_some() {
|
|
Some((a.as_slice_mut().unwrap(), lapacke::Layout::RowMajor))
|
|
} else if a.as_slice_memory_order_mut().is_some() {
|
|
Some((
|
|
a.as_slice_memory_order_mut().unwrap(),
|
|
lapacke::Layout::ColumnMajor,
|
|
))
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
pub fn solve(mut a: Array2<f64>, mut b: Array2<f64>) -> Array2<f64> {
|
|
assert!(a.is_square());
|
|
|
|
let n = a.ncols() as i32;
|
|
let nrhs = b.ncols() as i32;
|
|
let lda = n;
|
|
let ldb = nrhs;
|
|
|
|
let info;
|
|
let mut ipiv = vec![0; n as usize];
|
|
|
|
let (a_slice, layout) = as_slice_with_layout_mut(&mut a).expect("Matrix `a` not contiguous.");
|
|
let (b_slice, _) = as_slice_with_layout_mut(&mut b).expect("Matrix `b` not contiguous.");
|
|
|
|
unsafe {
|
|
info = lapacke::dgesv(layout, n, nrhs, a_slice, lda, &mut ipiv, b_slice, ldb);
|
|
}
|
|
|
|
assert!(info == 0, "info={}", info);
|
|
|
|
b
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    use approx::assert_relative_eq;

    /// Asserts that `solve(a, b)` matches `expected` to a relative tolerance of 1e-7.
    fn assert_solves(a: Array2<f64>, b: Array2<f64>, expected: Array2<f64>) {
        let solution = solve(a, b);
        assert_relative_eq!(solution, expected, max_relative = 0.0000001);
    }

    /// Checks `solve` against precomputed solutions of two 2×2 systems: one with a
    /// diagonal coefficient matrix and one with a dense symmetric coefficient matrix.
    #[test]
    fn test_solve() {
        assert_solves(
            array![[2.0, 0.0], [0.0, 6.0]],
            array![[1.9645574, 0.56996938], [-0.56996938, 3.91924033]],
            array![[0.9822787, 0.28498469], [-0.0949949, 0.65320672]],
        );

        assert_solves(
            array![[1.22650294, 0.67685001], [0.67685001, 5.34978493]],
            array![[1.00213885, 1.47494613], [-0.03356356, 2.14214437]],
            array![[0.88212208, 1.0552697], [-0.11787911, 0.26690514]],
        );
    }
}
|