Fixed a bug, and added a custom solve function (using lapacke).
This commit is contained in:
43
src/linalg.rs
Normal file
43
src/linalg.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use ndarray::{prelude::*, DataMut};
|
||||
|
||||
fn as_slice_with_layout_mut<S, T, D>(a: &mut ArrayBase<S, D>) -> Option<(&mut [T], lapacke::Layout)>
|
||||
where
|
||||
S: DataMut<Elem = T>,
|
||||
D: Dimension,
|
||||
{
|
||||
if a.as_slice_mut().is_some() {
|
||||
Some((a.as_slice_mut().unwrap(), lapacke::Layout::RowMajor))
|
||||
} else if a.as_slice_memory_order_mut().is_some() {
|
||||
Some((
|
||||
a.as_slice_memory_order_mut().unwrap(),
|
||||
lapacke::Layout::ColumnMajor,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn solve(mut a: Array2<f64>, mut b: Array2<f64>) -> Array2<f64> {
|
||||
assert!(a.is_square());
|
||||
|
||||
let n = a.ncols() as i32;
|
||||
let nrhs = b.ncols() as i32;
|
||||
let lda = n;
|
||||
let ldb = nrhs;
|
||||
|
||||
let info;
|
||||
let mut ipiv = vec![0; n as usize];
|
||||
|
||||
let (a_slice, layout) = as_slice_with_layout_mut(&mut a).expect("Matrix `a` not contiguous.");
|
||||
let (b_slice, _) = as_slice_with_layout_mut(&mut b).expect("Matrix `a` not contiguous.");
|
||||
|
||||
unsafe {
|
||||
info = lapacke::dgesv(layout, n, nrhs, a_slice, lda, &mut ipiv, b_slice, ldb);
|
||||
}
|
||||
|
||||
if info != 0 {
|
||||
panic!("info={}", info);
|
||||
}
|
||||
|
||||
b
|
||||
}
|
||||
Reference in New Issue
Block a user