Clean up dependencies
@@ -12,10 +12,8 @@ ndarray = { version = "0.14", features = ["approx"] }
 ordered-float = "1.0"
 rand = "0.7"
 rand_xoshiro = "0.4"
-statrs = "0.13"
 
 [dev-dependencies]
 approx = "0.4"
-blis-src = "0.2"
 intel-mkl-src = "0.5"
 time = "0.2"
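
Both removed crates are only needed by code this commit also deletes: `statrs` is used solely by the deleted `src/expm.rs` (factorials and binomial coefficients in the Padé error term), and `blis-src` was a BLAS source backend kept alongside `intel-mkl-src` in the dev-dependencies. For illustration, a minimal sketch of the single `statrs` call site from the deleted file (assuming the `statrs` 0.13 `function::factorial` module, exactly as the deleted code below uses it):

```rust
// Sketch only: mirrors `pade_error_coefficient` from the deleted src/expm.rs.
use statrs::function::factorial::{binomial, factorial};

/// Leading Padé backward-error coefficient C_{2m+1} = 1 / (binom(2m, m) * (2m + 1)!).
fn pade_error_coefficient(m: u64) -> f64 {
    1.0 / (binomial(2 * m, m) * factorial(2 * m + 1))
}

fn main() {
    // For m = 3 this is 1 / (20 * 5040) ≈ 9.92e-6.
    println!("{}", pade_error_coefficient(3));
}
```
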
@@ -34,7 +34,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
             continue;
         }
 
-        let t = t.midnight().assume_utc().timestamp() as f64;
+        let t = t.midnight().assume_utc().unix_timestamp() as f64;
 
         let score_1: u16 = data[3].parse()?;
         let score_2: u16 = data[4].parse()?;
@@ -75,7 +75,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         time::date!(1996 - 01 - 01)
             .midnight()
             .assume_utc()
-            .timestamp() as f64,
+            .unix_timestamp() as f64,
     );
     println!(" ... in 1996: {:.2}%", 100.0 * p_win);
 
@@ -85,7 +85,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         time::date!(2001 - 01 - 01)
             .midnight()
             .assume_utc()
-            .timestamp() as f64,
+            .unix_timestamp() as f64,
     );
     println!(" ... in 2001: {:.2}%", 100.0 * p_win);
 
@@ -95,7 +95,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         time::date!(2020 - 01 - 01)
             .midnight()
             .assume_utc()
-            .timestamp() as f64,
+            .unix_timestamp() as f64,
     );
     println!(" ... in 2020: {:.2}%", 100.0 * p_win);
 
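
All four hunks above track the same rename in the `time` 0.2 series: the Unix-epoch accessor on `OffsetDateTime` is `unix_timestamp` rather than `timestamp`. A minimal sketch of the updated call chain (assuming the same `time = "0.2"` dependency this crate already declares):

```rust
// Midnight of 1996-01-01, interpreted as UTC, converted to Unix seconds as f64.
fn main() {
    let t = time::date!(1996 - 01 - 01)
        .midnight()
        .assume_utc()
        .unix_timestamp() as f64;
    println!("{}", t);
}
```
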
src/condest.rs (921 lines deleted)
@@ -1,921 +0,0 @@
//! This crate implements the matrix 1-norm estimator by [Higham and Tisseur].
//!
//! [Higham and Tisseur]: http://eprints.ma.man.ac.uk/321/1/covered/MIMS_ep2006_145.pdf
use std::cmp;
use std::collections::BTreeSet;
use std::slice;

use ndarray::{prelude::*, s, ArrayBase, Data, DataMut, Dimension, Ix1, Ix2};
use ordered_float::NotNan;
use rand::{thread_rng, Rng, SeedableRng};
use rand_xoshiro::Xoshiro256StarStar;

pub struct Normest1 {
    n: usize,
    t: usize,
    rng: Xoshiro256StarStar,
    x_matrix: Array2<f64>,
    y_matrix: Array2<f64>,
    z_matrix: Array2<f64>,
    w_vector: Array1<f64>,
    sign_matrix: Array2<f64>,
    sign_matrix_old: Array2<f64>,
    column_is_parallel: Vec<bool>,
    indices: Vec<usize>,
    indices_history: BTreeSet<usize>,
    h: Vec<NotNan<f64>>,
}

/// A trait to generalize over 1-norm estimates of a matrix `A`, matrix powers `A^m`,
/// or matrix products `A1 * A2 * ... * An`.
///
/// In the 1-norm estimator, one repeatedly constructs a matrix-matrix product between some n×n
/// matrix X and some other n×t matrix Y. If one wanted to estimate the 1-norm of a matrix m times
/// itself, X^m, it might thus be computationally less expensive to repeatedly apply
/// X * ( * ( X ... ( X * Y ) rather than to calculate Z = X^m = X * X * ... * X and then apply Z *
/// Y. In the first case, one has several matrix-matrix multiplications with complexity O(m*n*n*t),
/// while in the latter case one has O(m*n*n*n) (plus one more O(n*n*t)).
///
/// So in case of t << n, it is cheaper to repeatedly apply matrix multiplication to the smaller
/// matrix on the RHS, rather than to construct one definite matrix on the LHS. Of course, this is
/// modified by the number of iterations needed when performing the norm estimate, sustained
/// performance of the matrix multiplication method used, etc.
///
/// It is at the designation of the user to check what is more efficient: to pass in one definite
/// matrix or choose the alternative route described here.
trait LinearOperator {
    fn multiply_matrix<S>(
        &self,
        b: &mut ArrayBase<S, Ix2>,
        c: &mut ArrayBase<S, Ix2>,
        transpose: bool,
    ) where
        S: DataMut<Elem = f64>;
}

impl<S1> LinearOperator for ArrayBase<S1, Ix2>
where
    S1: Data<Elem = f64>,
{
    fn multiply_matrix<S2>(
        &self,
        b: &mut ArrayBase<S2, Ix2>,
        c: &mut ArrayBase<S2, Ix2>,
        transpose: bool,
    ) where
        S2: DataMut<Elem = f64>,
    {
        let (n_rows, n_cols) = self.dim();
        assert_eq!(
            n_rows, n_cols,
            "Number of rows and columns does not match: `self` has to be a square matrix"
        );
        let n = n_rows;

        let (b_n, b_t) = b.dim();
        let (c_n, c_t) = b.dim();

        assert_eq!(
            n, b_n,
            "Number of rows of b not equal to number of rows of `self`."
        );
        assert_eq!(
            n, c_n,
            "Number of rows of c not equal to number of rows of `self`."
        );

        assert_eq!(
            b_t, c_t,
            "Number of columns of b not equal to number of columns of c."
        );

        let t = b_t;

        let (a_slice, a_layout) =
            as_slice_with_layout(self).expect("Matrix `self` not contiguous.");
        let (b_slice, b_layout) = as_slice_with_layout(b).expect("Matrix `b` not contiguous.");
        let (c_slice, c_layout) = as_slice_with_layout_mut(c).expect("Matrix `c` not contiguous.");

        assert_eq!(a_layout, b_layout);
        assert_eq!(a_layout, c_layout);

        let layout = a_layout;

        let a_transpose = if transpose {
            cblas::Transpose::Ordinary
        } else {
            cblas::Transpose::None
        };

        unsafe {
            cblas::dgemm(
                layout,
                a_transpose,
                cblas::Transpose::None,
                n as i32,
                t as i32,
                n as i32,
                1.0,
                a_slice,
                n as i32,
                b_slice,
                t as i32,
                0.0,
                c_slice,
                t as i32,
            )
        }
    }
}

impl<S1> LinearOperator for [&ArrayBase<S1, Ix2>]
where
    S1: Data<Elem = f64>,
{
    fn multiply_matrix<S2>(
        &self,
        b: &mut ArrayBase<S2, Ix2>,
        c: &mut ArrayBase<S2, Ix2>,
        transpose: bool,
    ) where
        S2: DataMut<Elem = f64>,
    {
        if !self.is_empty() {
            let mut reversed;
            let mut forward;

            // TODO: Investigate, if an enum instead of a trait object might be more performant.
            // This probably doesn't matter for large matrices, but could have a measurable impact
            // on small ones.
            let a_iter: &mut dyn DoubleEndedIterator<Item = _> = if transpose {
                reversed = self.iter().rev();
                &mut reversed
            } else {
                forward = self.iter();
                &mut forward
            };
            let a = a_iter.next().unwrap(); // Ok because of if condition
            a.multiply_matrix(b, c, transpose);

            // NOTE: The swap in the loop body makes use of the fact that in all instances where
            // `multiply_matrix` is used, the values potentially stored in `b` are not required
            // anymore.
            for a in a_iter {
                std::mem::swap(b, c);
                a.multiply_matrix(b, c, transpose);
            }
        }
    }
}

impl<S1> LinearOperator for (&ArrayBase<S1, Ix2>, usize)
where
    S1: Data<Elem = f64>,
{
    fn multiply_matrix<S2>(
        &self,
        b: &mut ArrayBase<S2, Ix2>,
        c: &mut ArrayBase<S2, Ix2>,
        transpose: bool,
    ) where
        S2: DataMut<Elem = f64>,
    {
        let a = self.0;
        let m = self.1;
        if m > 0 {
            a.multiply_matrix(b, c, transpose);
            for _ in 1..m {
                std::mem::swap(b, c);
                self.0.multiply_matrix(b, c, transpose);
            }
        }
    }
}

impl Normest1 {
    pub fn new(n: usize, t: usize) -> Self {
        assert!(
            t <= n,
            "Cannot have more iteration columns t than columns in the matrix."
        );
        let rng =
            Xoshiro256StarStar::from_rng(&mut thread_rng()).expect("Rng initialization failed.");
        let x_matrix = unsafe { Array2::<f64>::uninitialized((n, t)) };
        let y_matrix = unsafe { Array2::<f64>::uninitialized((n, t)) };
        let z_matrix = unsafe { Array2::<f64>::uninitialized((n, t)) };

        let w_vector = unsafe { Array1::uninitialized(n) };

        let sign_matrix = unsafe { Array2::<f64>::uninitialized((n, t)) };
        let sign_matrix_old = unsafe { Array2::<f64>::uninitialized((n, t)) };

        let column_is_parallel = vec![false; t];

        let indices = (0..n).collect();
        let indices_history = BTreeSet::new();

        let h = vec![unsafe { NotNan::unchecked_new(0.0) }; n];

        Normest1 {
            n,
            t,
            rng,
            x_matrix,
            y_matrix,
            z_matrix,
            w_vector,
            sign_matrix,
            sign_matrix_old,
            column_is_parallel,
            indices,
            indices_history,
            h,
        }
    }

    fn calculate<L>(&mut self, a_linear_operator: &L, itmax: usize) -> f64
    where
        L: LinearOperator + ?Sized,
    {
        assert!(itmax > 1, "normest1 is undefined for iterations itmax < 2");

        // Explicitly empty the index history; all other quantities will be overwritten at some
        // point.
        self.indices_history.clear();

        let n = self.n;
        let t = self.t;

        let sample = [-1., 1.0];

        // “We now explain our choice of starting matrix. We take the first column of X to be the
        // vector of 1s, which is the starting vector used in Algorithm 2.1. This has the advantage
        // that for a matrix with nonnegative elements the algorithm converges with an exact estimate
        // on the second iteration, and such matrices arise in applications, for example as a
        // stochastic matrix or as the inverse of an M -matrix.”
        //
        // “The remaining columns are chosen as rand {− 1 , 1 } , with a check for and correction of
        // parallel columns, exactly as for S in the body of the algorithm. We choose random vectors
        // because it is difficult to argue for any particular fixed vectors and because randomness
        // lessens the importance of counterexamples (see the comments in the next section).”
        {
            let rng_mut = &mut self.rng;
            self.x_matrix
                .mapv_inplace(|_| sample[rng_mut.gen_range(0, sample.len())]);
            self.x_matrix.column_mut(0).fill(1.);
        }

        // Resample the x_matrix to make sure no columns are parallel
        find_parallel_columns_in(
            &self.x_matrix,
            &mut self.y_matrix,
            &mut self.column_is_parallel,
        );
        for (i, is_parallel) in self.column_is_parallel.iter().enumerate() {
            if *is_parallel {
                resample_column(&mut self.x_matrix, i, &mut self.rng, &sample);
            }
        }

        // Set all columns to unit vectors
        self.x_matrix.mapv_inplace(|x| x / n as f64);

        let mut estimate = 0.0;
        let mut best_index = 0;

        'optimization_loop: for k in 0..itmax {
            // Y = A X
            a_linear_operator.multiply_matrix(&mut self.x_matrix, &mut self.y_matrix, false);

            // est = max{‖Y(:,j)‖₁ : j = 1:t}
            let (max_norm_index, max_norm) = matrix_onenorm_with_index(&self.y_matrix);

            // if est > est_old or k=2
            if max_norm > estimate || k == 1 {
                // ind_best = indⱼ where est = ‖Y(:,j)‖₁, w = Y(:, ind_best)
                estimate = max_norm;
                best_index = self.indices[max_norm_index];
                self.w_vector.assign(&self.y_matrix.column(max_norm_index));
            } else if k > 1 && max_norm <= estimate {
                break 'optimization_loop;
            }

            if k >= itmax {
                break 'optimization_loop;
            }

            // S = sign(Y)
            assign_signum_of_array(&self.y_matrix, &mut self.sign_matrix);

            // TODO: Combine the test checking for parallelity between _all_ columns between S
            // and S_old with the “if t > 1” test below.
            //
            // > If every column of S is parallel to a column of Sold, goto (6), end
            //
            // NOTE: We are reusing `y_matrix` here as a temporary value.
            if are_all_columns_parallel_between(
                &self.sign_matrix_old,
                &self.sign_matrix,
                &mut self.y_matrix,
            ) {
                break 'optimization_loop;
            }

            // FIXME: Is an explicit if condition here necessary?
            if t > 1 {
                // > Ensure that no column of S is parallel to another column of S
                // > or to a column of Sold by replacing columns of S by rand{-1,+1}
                //
                // NOTE: We are reusing `y_matrix` here as a temporary value.
                resample_parallel_columns(
                    &mut self.sign_matrix,
                    &self.sign_matrix_old,
                    &mut self.y_matrix,
                    &mut self.column_is_parallel,
                    &mut self.rng,
                    &sample,
                );
            }

            // > est_old = est, Sold = S
            // NOTE: Other than in the original algorithm, we store the sign matrix at this point
            // already. This way, we can reuse the sign matrix as additional workspace which is
            // useful when performing matrix multiplication with A^m or A1 A2 ... An (see the
            // description of the LinearOperator trait for explanation).
            //
            // NOTE: We don't “save” the old estimate, because we are using max_norm as another name
            // for the new estimate instead of overwriting/reusing est.
            self.sign_matrix_old.assign(&self.sign_matrix);

            // Z = A^T S
            a_linear_operator.multiply_matrix(&mut self.sign_matrix, &mut self.z_matrix, true);

            // hᵢ= ‖Z(i,:)‖_∞
            let mut max_h = 0.0;
            for (row, h_element) in self.z_matrix.genrows().into_iter().zip(self.h.iter_mut()) {
                let h = vector_maxnorm(&row);
                max_h = if h > max_h { h } else { max_h };
                // Convert f64 to NotNan for using sort_unstable_by below
                *h_element = h.into();
            }

            // TODO: This test for equality needs an approximate equality test instead.
            if k > 0 && max_h == self.h[best_index].into() {
                break 'optimization_loop;
            }

            // > Sort h so that h_1 >= ... >= h_n and re-order correspondingly.
            // NOTE: h itself doesn't need to be reordered. Only the order of
            // the indices is relevant.
            {
                let h_ref = &self.h;
                self.indices
                    .sort_unstable_by(|i, j| h_ref[*j].cmp(&h_ref[*i]));
            }

            self.x_matrix.fill(0.0);
            if t > 1 {
                // > Replace ind(1:t) by the first t indices in ind(1:n) that are not in ind_hist.
                //
                // > X(:, j) = e_ind_j, j = 1:t
                //
                // > ind_hist = [ind_hist ind(1:t)]
                //
                // NOTE: It's not actually needed to operate on the `indices` vector. What's important
                // is that the history of indices, `indices_history`, gets updated with visited indices,
                // and that each column of `x_matrix` is assigned that unit vector that is defined by the
                // respective index.
                //
                // If so many indices have already been used that `n_cols - indices_history.len() < t`
                // (which means that we have less than `t` unused indices remaining), we have to use a few
                // historical indices when filling up the columns in `x_matrix`. For that, we put the
                // historical indices after the fresh indices, but otherwise keep the order induced by `h`
                // above.
                let fresh_indices = cmp::min(t, n - self.indices_history.len());
                if fresh_indices == 0 {
                    break 'optimization_loop;
                }
                let mut current_column_fresh = 0;
                let mut current_column_historical = fresh_indices;
                let mut index_iterator = self.indices.iter();

                let mut all_first_t_in_history = true;
                // First, iterate over the first t sorted indices.
                for i in (&mut index_iterator).take(t) {
                    if !self.indices_history.contains(i) {
                        all_first_t_in_history = false;
                        self.x_matrix[(*i, current_column_fresh)] = 1.0;
                        current_column_fresh += 1;
                        self.indices_history.insert(*i);
                    } else if current_column_historical < t {
                        self.x_matrix[(*i, current_column_historical)] = 1.0;
                        current_column_historical += 1;
                    }
                }

                // > if ind(1:t) is contained in ind_hist, goto (6), end
                if all_first_t_in_history {
                    break 'optimization_loop;
                }

                // Iterate over the remaining indices
                'fill_x: for i in index_iterator {
                    if current_column_fresh >= t {
                        break 'fill_x;
                    }
                    if !self.indices_history.contains(i) {
                        self.x_matrix[(*i, current_column_fresh)] = 1.0;
                        current_column_fresh += 1;
                        self.indices_history.insert(*i);
                    } else if current_column_historical < t {
                        self.x_matrix[(*i, current_column_historical)] = 1.0;
                        current_column_historical += 1;
                    }
                }
            }
        }

        estimate
    }

    /// Estimate the 1-norm of matrix `a` using up to `itmax` iterations.
    pub fn normest1<S>(&mut self, a: &ArrayBase<S, Ix2>, itmax: usize) -> f64
    where
        S: Data<Elem = f64>,
    {
        self.calculate(a, itmax)
    }

    /// Estimate the 1-norm of a marix `a` to the power `m` up to `itmax` iterations.
    pub fn normest1_pow<S>(&mut self, a: &ArrayBase<S, Ix2>, m: usize, itmax: usize) -> f64
    where
        S: Data<Elem = f64>,
    {
        self.calculate(&(a, m), itmax)
    }

    /// Estimate the 1-norm of a product of matrices `a1 a2 ... an` up to `itmax` iterations.
    pub fn normest1_prod<S>(&mut self, aprod: &[&ArrayBase<S, Ix2>], itmax: usize) -> f64
    where
        S: Data<Elem = f64>,
    {
        self.calculate(aprod, itmax)
    }
}

/// Assigns the sign of matrix `a` to matrix `b`.
///
/// Panics if matrices `a` and `b` have different shape and strides, or if either underlying array is
/// non-contiguous. This is to make sure that the iteration order over the matrices is the same.
fn assign_signum_of_array<S1, S2, D>(a: &ArrayBase<S1, D>, b: &mut ArrayBase<S2, D>)
where
    S1: Data<Elem = f64>,
    S2: DataMut<Elem = f64>,
    D: Dimension,
{
    assert_eq!(a.strides(), b.strides());
    let (a_slice, a_layout) = as_slice_with_layout(a).expect("Matrix `a` is not contiguous.");
    let (b_slice, b_layout) = as_slice_with_layout_mut(b).expect("Matrix `b` is not contiguous.");
    assert_eq!(a_layout, b_layout);

    signum_of_slice(a_slice, b_slice);
}

fn signum_of_slice(source: &[f64], destination: &mut [f64]) {
    for (s, d) in source.iter().zip(destination) {
        *d = s.signum();
    }
}

/// Calculate the onenorm of a vector (an `ArrayBase` with dimension `Ix1`).
fn vector_onenorm<S>(a: &ArrayBase<S, Ix1>) -> f64
where
    S: Data<Elem = f64>,
{
    let stride = a.strides()[0];
    assert!(stride >= 0);
    let stride = stride as usize;
    let n_elements = a.len();
    let a_slice = {
        let a = a.as_ptr();
        let total_len = n_elements * stride;
        unsafe { slice::from_raw_parts(a, total_len) }
    };

    unsafe { cblas::dasum(n_elements as i32, a_slice, stride as i32) }
}

/// Calculate the maximum norm of a vector (an `ArrayBase` with dimension `Ix1`).
fn vector_maxnorm<S>(a: &ArrayBase<S, Ix1>) -> f64
where
    S: Data<Elem = f64>,
{
    let stride = a.strides()[0];
    assert!(stride >= 0);
    let stride = stride as usize;
    let n_elements = a.len();
    let a_slice = {
        let a = a.as_ptr();
        let total_len = n_elements * stride;
        unsafe { slice::from_raw_parts(a, total_len) }
    };

    let idx = unsafe { cblas::idamax(n_elements as i32, a_slice, stride as i32) as usize };
    f64::abs(a[idx])
}

// /// Calculate the onenorm of a matrix (an `ArrayBase` with dimension `Ix2`).
// fn matrix_onenorm<S>(a: &ArrayBase<S, Ix2>) -> f64
// where S: Data<Elem=f64>,
// {
//     let (n_rows, n_cols) = a.dim();
//     if let Some((a_slice, layout)) = as_slice_with_layout(a) {
//         let layout = match layout {
//             cblas::Layout::RowMajor => lapacke::Layout::RowMajor,
//             cblas::Layout::ColumnMajor => lapacke::Layout::ColumnMajor,
//         };
//         unsafe {
//             lapacke::dlange(
//                 layout,
//                 b'1',
//                 n_rows as i32,
//                 n_cols as i32,
//                 a_slice,
//                 n_rows as i32,
//             )
//         }
//     // Fall through case for non-contiguous arrays.
//     } else {
//         a.gencolumns().into_iter()
//             .fold(0.0, |max, column| {
//                 let onenorm = column.fold(0.0, |acc, element| { acc + f64::abs(*element) });
//                 if onenorm > max { onenorm } else { max }
//             })
//     }
// }

/// Returns the one-norm of a matrix `a` together with the index of that column for
/// which the norm is maximal.
fn matrix_onenorm_with_index<S>(a: &ArrayBase<S, Ix2>) -> (usize, f64)
where
    S: Data<Elem = f64>,
{
    let mut max_norm = 0.0;
    let mut max_norm_index = 0;
    for (i, column) in a.gencolumns().into_iter().enumerate() {
        let norm = vector_onenorm(&column);
        if norm > max_norm {
            max_norm = norm;
            max_norm_index = i;
        }
    }
    (max_norm_index, max_norm)
}

/// Finds columns in the matrix `a` that are parallel to to some other column in `a`.
///
/// Assumes that all entries of `a` are either +1 or -1.
///
/// If column `j` of matrix `a` is parallel to some column `i`, `column_is_parallel[i]` is set to
/// `true`. The matrix `c` is used as an intermediate value for the matrix product `a^t * a`.
///
/// This function does not reset `column_is_parallel` to `false`. Entries that are `true` will be
/// assumed to be parallel and not checked.
///
/// Panics if arrays `a` and `c` don't have the same dimensions, or if the length of the slice
/// `column_is_parallel` is not equal to the number of columns in `a`.
fn find_parallel_columns_in<S1, S2>(
    a: &ArrayBase<S1, Ix2>,
    c: &mut ArrayBase<S2, Ix2>,
    column_is_parallel: &mut [bool],
) where
    S1: Data<Elem = f64>,
    S2: DataMut<Elem = f64>,
{
    let a_dim = a.dim();
    let c_dim = c.dim();
    assert_eq!(a_dim, c_dim);

    let (n_rows, n_cols) = a_dim;

    assert_eq!(column_is_parallel.len(), n_cols);
    {
        let (a_slice, a_layout) = as_slice_with_layout(a).expect("Matrix `a` is not contiguous.");
        let (c_slice, c_layout) =
            as_slice_with_layout_mut(c).expect("Matrix `c` is not contiguous.");
        assert_eq!(a_layout, c_layout);
        let layout = a_layout;

        // NOTE: When calling the wrapped Fortran dsyrk subroutine with row major layout,
        // cblas::*syrk changes `'U'` to `'L'` (`Upper` to `Lower`), and `'O'` to `'N'` (`Ordinary`
        // to `None`). Different from `cblas::*gemm`, however, it does not automatically make sure
        // that the other arguments are changed to make sense in a routine expecting column major
        // order (in `cblas::*gemm`, this happens by flipping the matrices `a` and `b` as
        // arguments).
        //
        // So while `cblas::dsyrk` changes transposition and the position of where the results are
        // written to, it passes the other arguments on to the Fortran routine as is.
        //
        // For example, in case matrix `a` is a 4x2 matrix in column major order, and we want to
        // perform the operation `a^T a` on it (resulting in a symmetric 2x2 matrix), we would pass
        // TRANS='T', N=2 (order of c), K=4 (number of rows because of 'T'), LDA=4 (max(1,k)
        // because of 'T'), LDC=2.
        //
        // But if `a` is in row major order and we want to perform the same operation, we pass
        // TRANS='T' (gets translated to 'N'), N=2, K=2 (number of columns, because we 'T' -> 'N'),
        // LDA=2 (max(1,n) because of 'N'), LDC=2.
        //
        // In other words, because of row major order, the Fortran routine actually sees our 4x2
        // matrix as a 2x4 matrix, and if we want to calculate `a^T a`, `cblas::dsyrk` makes sure
        // `'N'` is passed.
        let (k, lda) = match layout {
            cblas::Layout::ColumnMajor => (n_cols, n_rows),
            cblas::Layout::RowMajor => (n_rows, n_cols),
        };
        unsafe {
            cblas::dsyrk(
                layout,
                cblas::Part::Upper,
                cblas::Transpose::Ordinary,
                n_cols as i32,
                k as i32,
                1.0,
                a_slice,
                lda as i32,
                0.0,
                c_slice,
                n_cols as i32,
            );
        }
    }

    // c is upper triangular and contains all pair-wise vector products:
    //
    // x x x x x
    // . x x x x
    // . . x x x
    // . . . x x
    // . . . . x

    // Don't check more rows than we have columns
    'rows: for (i, row) in c.genrows().into_iter().enumerate().take(n_cols) {
        // Skip if the column is already found to be parallel or if we are checking
        // the last column
        if column_is_parallel[i] || i >= n_cols - 1 {
            continue 'rows;
        }
        for (j, element) in row.slice(s![i + 1..]).iter().enumerate() {
            // Check if the vectors are parallel or anti-parallel
            if f64::abs(*element) == n_rows as f64 {
                column_is_parallel[i + j + 1] = true;
            }
        }
    }
}

/// Checks whether any columns of the matrix `a` are parallel to any columns of `b`.
///
/// Assumes that we have parallelity only if all entries of two columns `a` and `b` are either +1
/// or -1.
///
/// `The matrix `c` is used as an intermediate value for the matrix product `a^t * b`.
///
/// `column_is_parallel[j]` is set to `true` if column `j` of matrix `a` is parallel to some column
/// `i` of the matrix `b`,
///
/// This function does not reset `column_is_parallel` to `false`. Entries that are `true` will be
/// assumed to be parallel and not checked.
///
/// Panics if arrays `a`, `b`, and `c` don't have the same dimensions, or if the length of the slice
/// `column_is_parallel` is not equal to the number of columns in `a`.
fn find_parallel_columns_between<S1, S2, S3>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix2>,
    c: &mut ArrayBase<S3, Ix2>,
    column_is_parallel: &mut [bool],
) where
    S1: Data<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: DataMut<Elem = f64>,
{
    let a_dim = a.dim();
    let b_dim = b.dim();
    let c_dim = c.dim();
    assert_eq!(a_dim, b_dim);
    assert_eq!(a_dim, c_dim);

    let (n_rows, n_cols) = a_dim;

    assert_eq!(column_is_parallel.len(), n_cols);

    // Extra scope, because c_slice needs to be dropped after the dgemm
    {
        let (a_slice, a_layout) = as_slice_with_layout(&a).expect("Matrix `a` not contiguous.");
        let (b_slice, b_layout) = as_slice_with_layout(&b).expect("Matrix `b` not contiguous.");
        let (c_slice, c_layout) = as_slice_with_layout_mut(c).expect("Matrix `c` not contiguous.");

        assert_eq!(a_layout, b_layout);
        assert_eq!(a_layout, c_layout);

        let layout = a_layout;

        unsafe {
            cblas::dgemm(
                layout,
                cblas::Transpose::Ordinary,
                cblas::Transpose::None,
                n_cols as i32,
                n_cols as i32,
                n_rows as i32,
                1.0,
                a_slice,
                n_cols as i32,
                b_slice,
                n_cols as i32,
                0.0,
                c_slice,
                n_cols as i32,
            );
        }
    }

    // We are iterating over the rows because it's more memory efficient (for row-major array). In
    // terms of logic there is no difference: we simply check if the current column of a (that's
    // the outer loop) is parallel to any column of b (inner loop). By iterating via columns we would check if
    // any column of a is parallel to the, in that case, current column of b.
    // TODO: Implement for column major arrays.
    'rows: for (i, row) in c.genrows().into_iter().enumerate().take(n_cols) {
        // Skip if the column is already found to be parallel the last column.
        if column_is_parallel[i] {
            continue 'rows;
        }
        for element in row {
            if f64::abs(*element) == n_rows as f64 {
                column_is_parallel[i] = true;
                continue 'rows;
            }
        }
    }
}

/// Check if every column in `a` is parallel to some column in `b`.
///
/// Assumes that we have parallelity only if all entries of two columns `a` and `b` are either +1
/// or -1.
fn are_all_columns_parallel_between<S1, S2>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S1, Ix2>,
    c: &mut ArrayBase<S2, Ix2>,
) -> bool
where
    S1: Data<Elem = f64>,
    S2: DataMut<Elem = f64>,
{
    let a_dim = a.dim();
    let b_dim = b.dim();
    let c_dim = c.dim();
    assert_eq!(a_dim, b_dim);
    assert_eq!(a_dim, c_dim);

    let (n_rows, n_cols) = a_dim;

    // Extra scope, because c_slice needs to be dropped after the dgemm
    {
        let (a_slice, a_layout) = as_slice_with_layout(&a).expect("Matrix `a` not contiguous.");
        let (b_slice, b_layout) = as_slice_with_layout(&b).expect("Matrix `b` not contiguous.");
        let (c_slice, c_layout) = as_slice_with_layout_mut(c).expect("Matrix `c` not contiguous.");

        assert_eq!(a_layout, b_layout);
        assert_eq!(a_layout, c_layout);

        let layout = a_layout;

        unsafe {
            cblas::dgemm(
                layout,
                cblas::Transpose::Ordinary,
                cblas::Transpose::None,
                n_cols as i32,
                n_cols as i32,
                n_rows as i32,
                1.0,
                a_slice,
                n_cols as i32,
                b_slice,
                n_cols as i32,
                0.0,
                c_slice,
                n_rows as i32,
            );
        }
    }

    // We are iterating over the rows because it's more memory efficient (for row-major array). In
    // terms of logic there is no difference: we simply check if a specific column of a is parallel
    // to any column of b. By iterating via columns we would check if any column of a is parallel
    // to a specific column of b.
    'rows: for row in c.genrows() {
        for element in row {
            // If a parallel column was found, cut to the next one.
            if f64::abs(*element) == n_rows as f64 {
                continue 'rows;
            }
        }
        // This return statement should only be reached if not a single column parallel to the
        // current one was found.
        return false;
    }
    true
}

/// Find parallel columns in matrix `a` and columns in `a` that are parallel to any columns in
/// matrix `b`, and replace those with random vectors. Returns `true` if resampling has taken place.
fn resample_parallel_columns<S1, S2, S3, R>(
    a: &mut ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix2>,
    c: &mut ArrayBase<S3, Ix2>,
    column_is_parallel: &mut [bool],
    rng: &mut R,
    sample: &[f64],
) -> bool
where
    S1: DataMut<Elem = f64>,
    S2: Data<Elem = f64>,
    S3: DataMut<Elem = f64>,
    R: Rng,
{
    column_is_parallel.iter_mut().for_each(|x| {
        *x = false;
    });
    find_parallel_columns_in(a, c, column_is_parallel);
    find_parallel_columns_between(a, b, c, column_is_parallel);
    let mut has_resampled = false;
    for (i, is_parallel) in column_is_parallel.iter_mut().enumerate() {
        if *is_parallel {
            resample_column(a, i, rng, sample);
            has_resampled = true;
        }
    }
    has_resampled
}

/// Resamples column `i` of matrix `a` with elements drawn from `sample` using `rng`.
///
/// Panics if `i` exceeds the number of columns in `a`.
fn resample_column<R, S>(a: &mut ArrayBase<S, Ix2>, i: usize, rng: &mut R, sample: &[f64])
where
    S: DataMut<Elem = f64>,
    R: Rng,
{
    assert!(
        i < a.dim().1,
        "Trying to resample column with index exceeding matrix dimensions"
    );
    assert!(!sample.is_empty());
    a.column_mut(i)
        .mapv_inplace(|_| sample[rng.gen_range(0, sample.len())]);
}

/// Returns slice and layout underlying an array `a`.
fn as_slice_with_layout<S, T, D>(a: &ArrayBase<S, D>) -> Option<(&[T], cblas::Layout)>
where
    S: Data<Elem = T>,
    D: Dimension,
{
    if let Some(a_slice) = a.as_slice() {
        Some((a_slice, cblas::Layout::RowMajor))
    } else if let Some(a_slice) = a.as_slice_memory_order() {
        Some((a_slice, cblas::Layout::ColumnMajor))
    } else {
        None
    }
}

/// Returns mutable slice and layout underlying an array `a`.
fn as_slice_with_layout_mut<S, T, D>(a: &mut ArrayBase<S, D>) -> Option<(&mut [T], cblas::Layout)>
where
    S: DataMut<Elem = T>,
    D: Dimension,
{
    if a.as_slice_mut().is_some() {
        Some((a.as_slice_mut().unwrap(), cblas::Layout::RowMajor))
    } else if a.as_slice_memory_order_mut().is_some() {
        Some((
            a.as_slice_memory_order_mut().unwrap(),
            cblas::Layout::ColumnMajor,
        ))
    } else {
        None
    }
    // XXX: The above is a workaround for Rust not having non-lexical lifetimes yet.
    // More information here:
    // http://smallcultfollowing.com/babysteps/blog/2016/04/27/non-lexical-lifetimes-introduction/#problem-case-3-conditional-control-flow-across-functions
    //
    // if let Some(slice) = a.as_slice_mut() {
    //     Some((slice, cblas::Layout::RowMajor))
    // } else if let Some(slice) = a.as_slice_memory_order_mut() {
    //     Some((slice, cblas::Layout::ColumnMajor))
    // } else {
    //     None
    // }
}
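
For context on what the deleted module provided: `Normest1` allocates all workspace for an n×n problem with `t` extraction columns up front and then estimates ‖A‖₁, ‖A^m‖₁, or ‖A1·A2⋯An‖₁ without ever forming the matrix powers or products. A hypothetical usage sketch against the API defined above (names and signatures are taken from the deleted file; the surrounding crate setup is assumed):

```rust
use ndarray::Array2;

fn main() {
    let n = 100;   // matrix dimension
    let t = 2;     // extraction columns, chosen with t << n
    let itmax = 5; // iteration cap for the estimator

    let a = Array2::<f64>::eye(n);
    let mut estimator = Normest1::new(n, t);

    // Estimate ‖A‖₁, ‖A²‖₁ and ‖A·A‖₁ without forming A² explicitly.
    let est = estimator.normest1(&a, itmax);
    let est_pow = estimator.normest1_pow(&a, 2, itmax);
    let est_prod = estimator.normest1_prod(&[&a, &a], itmax);

    println!("{} {} {}", est, est_pow, est_prod);
}
```
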
src/expm.rs (781 lines deleted)
@@ -1,781 +0,0 @@
/// This crate contains `expm`, an implementation of Algorithm 6.1 by [Al-Mohy, Higham] in the Rust
/// programming language. It calculates the exponential of a matrix. See the linked paper for more
/// information.
///
/// An important ingredient is `normest1`, Algorithm 2.4 in [Higham, Tisseur], which estimates
/// the 1-norm of a matrix.
///
/// Furthermore, to fully understand the algorithm as described in the original paper, one has to
/// understand that the factor $\lvert C_{2m+1} \rvert$ arises during the Padé approximation of the
/// exponential function. The derivation is described in [Gautschi 2012], pp. 363--365, and the
/// factor reads:
///
/// \begin{equation}
/// C_{n,m} = (-1)^n \frac{n!m!}{(n+m)!(n+m+1)!},
/// \end{equation}
///
/// or using only the diagonal elements, $m=n$:
///
/// \begin{equation}
/// C_m = (-1)^m \frac{m!m!}{(2m)!(2m+1)!}
/// \end{equation}
///
///
/// [Al-Mohy, Higham]: http://eprints.ma.man.ac.uk/1300/1/covered/MIMS_ep2009_9.pdf
/// [Higham, Tisseur]: http://eprints.ma.man.ac.uk/321/1/covered/MIMS_ep2006_145.pdf
/// [Gautschi 2012]: https://doi.org/10.1007/978-0-8176-8259-0
use ndarray::{self, prelude::*, Data, DataMut, Dimension, Zip};

use crate::condest::Normest1;

// Can we calculate these at compile time?
const THETA_3: f64 = 1.495585217958292e-2;
const THETA_5: f64 = 2.539398330063230e-1;
const THETA_7: f64 = 9.504178996162932e-1;
const THETA_9: f64 = 2.097847961257068e0;
// const THETA_13: f64 = 5.371920351148152e0 // Alg 3.1
const THETA_13: f64 = 4.25; // Alg 5.1

const PADE_COEFF_3: [f64; 4] = [120., 60., 12., 1.];
const PADE_COEFF_5: [f64; 6] = [30240., 15120., 3360., 420., 30., 1.];
const PADE_COEFF_7: [f64; 8] = [
    17297280., 8648640., 1995840., 277200., 25200., 1512., 56., 1.,
];
const PADE_COEFF_9: [f64; 10] = [
    17643225600.,
    8821612800.,
    2075673600.,
    302702400.,
    30270240.,
    2162160.,
    110880.,
    3960.,
    90.,
    1.,
];
const PADE_COEFF_13: [f64; 14] = [
    64764752532480000.,
    32382376266240000.,
    7771770303897600.,
    1187353796428800.,
    129060195264000.,
    10559470521600.,
    670442572800.,
    33522128640.,
    1323241920.,
    40840800.,
    960960.,
    16380.,
    182.,
    1.,
];

/// Calculates the of leading terms in the backward error function for the [m/m] Padé approximant
/// to the exponential function, i.e. it calculates:
///
/// \begin{align}
/// C_{2m+1} &= \frac{(m!)^2}{(2m)!(2m+1)!} \\
/// &= \frac{1}{\binom{2m}{m} (2m+1)!}
/// \end{align}
///
/// NOTE: Depending on the notation used in the scientific papers, the coefficient `C` is,
/// confusingly, sometimes indexed `C_i` and sometimes `C_{2m+1}`. These essentially mean the same
/// thing and is due to the power series expansion of the backward error function:
///
/// \begin{equation}
/// h(x) = \sum^\infty_{i=2m+1} C_i x^i
/// \end{equation}
fn pade_error_coefficient(m: u64) -> f64 {
    use statrs::function::factorial::{binomial, factorial};

    1.0 / (binomial(2 * m, m) * factorial(2 * m + 1))
}

#[allow(non_camel_case_types)]
struct PadeOrder_3;
#[allow(non_camel_case_types)]
struct PadeOrder_5;
#[allow(non_camel_case_types)]
struct PadeOrder_7;
#[allow(non_camel_case_types)]
struct PadeOrder_9;
#[allow(non_camel_case_types)]
struct PadeOrder_13;

enum PadeOrders {
    _3,
    _5,
    _7,
    _9,
    _13,
}

trait PadeOrder {
    const ORDER: u64;

    /// Return the coefficients arising in both the numerator as well as in the denominator of the
    /// Padé approximant (they are the same, due to $p(x) = q(-x)$.
    ///
    /// TODO: This is a great usecase for const generics, returning &[u64; Self::ORDER] and
    /// possibly calculating the values at compile time instead of hardcoding them.
    /// Maybe possible once RFC 2000 lands? See the PR https://github.com/rust-lang/rust/pull/53645
    fn coefficients() -> &'static [f64];

    fn calculate_pade_sums<S1, S2, S3>(
        a: &ArrayBase<S1, Ix2>,
        a_powers: &[&ArrayBase<S1, Ix2>],
        u: &mut ArrayBase<S2, Ix2>,
        v: &mut ArrayBase<S3, Ix2>,
        work: &mut ArrayBase<S2, Ix2>,
    ) where
        S1: Data<Elem = f64>,
        S2: DataMut<Elem = f64>,
        S3: DataMut<Elem = f64>;
}

macro_rules! impl_padeorder {
    ($($ty:ty, $m:literal, $const_coeff:ident),+) => {

        $(

        impl PadeOrder for $ty {
            const ORDER: u64 = $m;

            fn coefficients() -> &'static [f64] {
                &$const_coeff
            }

            fn calculate_pade_sums<S1, S2, S3>(
                a: &ArrayBase<S1, Ix2>,
                a_powers: &[&ArrayBase<S1, Ix2>],
                u: &mut ArrayBase<S2, Ix2>,
                v: &mut ArrayBase<S3, Ix2>,
                work: &mut ArrayBase<S2, Ix2>,
            )
            where S1: Data<Elem=f64>,
                  S2: DataMut<Elem=f64>,
                  S3: DataMut<Elem=f64>,
            {
                assert_eq!(a_powers.len(), ($m - 1)/2 + 1);

                let (n_rows, n_cols) = a.dim();
                assert_eq!(n_rows, n_cols, "Pade sum only defined for square matrices.");
                let n = n_rows as i32;

                // Iterator to get 2 coefficients, c_{2i} and c_{2i+1}, and 1 matrix power at a time.
                let mut iterator = Self::coefficients().chunks_exact(2).zip(a_powers.iter());

                // First element from the iterator.
                //
                // NOTE: The unwrap() and unreachable!() are permissable because the assertion above
                // ensures the validity.
                //
                // TODO: An optimization is probably to just set u and v to zero and only assign the
                // coefficients to its diagonal, given that A_0 = A^0 = 1.
                let (c_0, c_1, a_pow) = match iterator.next().unwrap() {
                    (&[c_0, c_1], a_pow) => (c_0, c_1, a_pow),
                    _ => unreachable!()
                };

                work.zip_mut_with(a_pow, |x, &y| *x = c_1 * y);
                v.zip_mut_with(a_pow, |x, &y| *x = c_0 * y);

                // Rest of the iterator
                while let Some(item) = iterator.next() {
                    let (c_2k, c_2k1, a_pow) = match item {
                        (&[c_2k, c_2k1], a_pow) => (c_2k, c_2k1, a_pow),
                        _ => unreachable!()
                    };

                    work.zip_mut_with(a_pow, |x, &y| *x += c_2k1 * y);
                    v.zip_mut_with(a_pow, |x, &y| *x += c_2k * y);
                }

                let (a_slice, a_layout) = as_slice_with_layout(a).expect("Matrix `a` not contiguous.");
                let (work_slice, _) = as_slice_with_layout(work).expect("Matrix `work` not contiguous.");
                let (u_slice, u_layout) = as_slice_with_layout_mut(u).expect("Matrix `u` not contiguous.");
                assert_eq!(a_layout, u_layout, "Memory layout mismatch between matrices; currently only row major matrices are supported.");
                let layout = a_layout;
                unsafe {
                    cblas::dgemm(
                        layout,
                        cblas::Transpose::None,
                        cblas::Transpose::None,
                        n,
                        n,
                        n,
                        1.0,
                        a_slice,
                        n,
                        work_slice,
                        n,
                        0.0,
                        u_slice,
                        n,
                    )
                }
            }
        }

        )+
    }
}

impl_padeorder!(
    PadeOrder_3,
    3,
    PADE_COEFF_3,
    PadeOrder_5,
    5,
    PADE_COEFF_5,
    PadeOrder_7,
    7,
    PADE_COEFF_7,
    PadeOrder_9,
    9,
    PADE_COEFF_9
);

impl PadeOrder for PadeOrder_13 {
    const ORDER: u64 = 13;

    fn coefficients() -> &'static [f64] {
        &PADE_COEFF_13
    }

    fn calculate_pade_sums<S1, S2, S3>(
        a: &ArrayBase<S1, Ix2>,
        a_powers: &[&ArrayBase<S1, Ix2>],
        u: &mut ArrayBase<S2, Ix2>,
        v: &mut ArrayBase<S3, Ix2>,
        work: &mut ArrayBase<S2, Ix2>,
    ) where
        S1: Data<Elem = f64>,
        S2: DataMut<Elem = f64>,
        S3: DataMut<Elem = f64>,
    {
        assert_eq!(a_powers.len(), (13 - 1) / 2 + 1);

        let (n_rows, n_cols) = a.dim();
        assert_eq!(n_rows, n_cols, "Pade sum only defined for square matrices.");
        let n = n_rows;

        let coefficients = Self::coefficients();

        Zip::from(&mut *work)
            .and(a_powers[0])
            .and(a_powers[1])
            .and(a_powers[2])
            .and(a_powers[3])
            .apply(|x, &a0, &a2, &a4, &a6| {
                *x = *x
                    + coefficients[1] * a0
                    + coefficients[3] * a2
                    + coefficients[5] * a4
                    + coefficients[7] * a6;
            });

        {
            let (a_slice, a_layout) = as_slice_with_layout(a).expect("Matrix `a` not contiguous.");
            let (work_slice, _) =
                as_slice_with_layout(work).expect("Matrix `work` not contiguous.");
            let (u_slice, u_layout) =
                as_slice_with_layout_mut(u).expect("Matrix `u` not contiguous.");
            assert_eq!(a_layout, u_layout, "Memory layout mismatch between matrices; currently only row major matrices are supported.");
            let layout = a_layout;
            unsafe {
                cblas::dgemm(
                    layout,
                    cblas::Transpose::None,
                    cblas::Transpose::None,
                    n as i32,
                    n as i32,
                    n as i32,
                    1.0,
                    a_slice,
                    n as i32,
                    work_slice,
                    n as i32,
                    0.0,
                    u_slice,
                    n as i32,
                )
            }
        }

        Zip::from(&mut *work)
            .and(a_powers[1])
            .and(a_powers[2])
            .and(a_powers[3])
            .apply(|x, &a2, &a4, &a6| {
                *x = coefficients[8] * a2 + coefficients[10] * a4 + coefficients[12] * a6;
            });

        {
            let (a6_slice, a6_layout) =
                as_slice_with_layout(a_powers[3]).expect("Matrix `a6` not contiguous.");
            let (work_slice, _) =
                as_slice_with_layout(work).expect("Matrix `work` not contiguous.");
            let (v_slice, v_layout) =
                as_slice_with_layout_mut(v).expect("Matrix `v` not contiguous.");
            assert_eq!(a6_layout, v_layout, "Memory layout mismatch between matrices; currently only row major matrices are supported.");
            let layout = a6_layout;
            unsafe {
                cblas::dgemm(
                    layout,
                    cblas::Transpose::None,
                    cblas::Transpose::None,
                    n as i32,
                    n as i32,
                    n as i32,
                    1.0,
                    a6_slice,
                    n as i32,
                    work_slice,
                    n as i32,
                    0.0,
                    v_slice,
                    n as i32,
                )
            }
        }

        Zip::from(v)
            .and(a_powers[0])
            .and(a_powers[1])
            .and(a_powers[2])
            .and(a_powers[3])
            .apply(|x, &a0, &a2, &a4, &a6| {
                *x = *x
                    + coefficients[0] * a0
                    + coefficients[2] * a2
                    + coefficients[4] * a4
                    + coefficients[6] * a6;
            })
    }
}

/// Storage for calculating the matrix exponential.
pub struct Expm {
    n: usize,
    itmax: usize,
    eye: Array2<f64>,
    a1: Array2<f64>,
    a2: Array2<f64>,
    a4: Array2<f64>,
    a6: Array2<f64>,
    a8: Array2<f64>,
    a_abs: Array2<f64>,
    u: Array2<f64>,
    work: Array2<f64>,
    pivot: Array1<i32>,
    normest1: Normest1,
    layout: cblas::Layout,
}

impl Expm {
    /// Allocates all space to calculate the matrix exponential for a square matrix of dimension
    /// n×n.
    pub fn new(n: usize) -> Self {
        let eye = Array2::<f64>::eye(n);
        let a1 = Array2::<f64>::zeros((n, n));
        let a2 = Array2::<f64>::zeros((n, n));
        let a4 = Array2::<f64>::zeros((n, n));
        let a6 = Array2::<f64>::zeros((n, n));
        let a8 = Array2::<f64>::zeros((n, n));
        let a_abs = Array2::<f64>::zeros((n, n));
        let u = Array2::<f64>::zeros((n, n));
        let work = Array2::<f64>::zeros((n, n));
        let pivot = Array1::<i32>::zeros(n);
        let layout = cblas::Layout::RowMajor;

        // TODO: Investigate what an optimal value for t is when estimating the 1-norm.
        // Python's SciPY uses t=2. Why?
        let t = 2;
        let itmax = 5;

        let normest1 = Normest1::new(n, t);

        Expm {
            n,
            itmax,
            eye,
            a1,
            a2,
            a4,
            a6,
            a8,
            a_abs,
            u,
            work,
            pivot,
            normest1,
            layout,
        }
    }

    /// Calculate the matrix exponential of the n×n matrix `a` storing the result in matrix `b`.
    ///
    /// NOTE: Panics if input matrices `a` and `b` don't have matching dimensions, are not square,
    /// not in row-major order, or don't have the same dimension as the `Expm` object `expm` is
    /// called on.
    pub fn expm<S1, S2>(&mut self, a: &ArrayBase<S1, Ix2>, b: &mut ArrayBase<S2, Ix2>)
    where
        S1: Data<Elem = f64>,
        S2: DataMut<Elem = f64>,
    {
        assert_eq!(
            a.dim(),
            b.dim(),
            "Input matrices `a` and `b` have to have matching dimensions."
        );
        let (n_rows, n_cols) = a.dim();
        assert_eq!(
            n_rows, n_cols,
            "expm is only implemented for square matrices."
        );
        assert_eq!(
            n_rows, self.n,
            "Dimension mismatch between matrix `a` and preconfigured `Expm` struct."
        );

        // Rename b to v to be in line with the nomenclature of the original paper.
        let v = b;

        self.a1.assign(a);

        let n = self.n as i32;

        {
            let (a_slice, a_layout) =
                as_slice_with_layout(&self.a1).expect("Matrix `a` not contiguous.");
            let (a2_slice, _) =
                as_slice_with_layout_mut(&mut self.a2).expect("Matrix `a2` not contiguous.");
            assert_eq!(a_layout, self.layout, "Memory layout mismatch between matrices; currently only row major matrices are supported.");
            unsafe {
                cblas::dgemm(
                    self.layout,
                    cblas::Transpose::None,
                    cblas::Transpose::None,
                    n,
                    n,
                    n,
                    1.0,
                    a_slice,
                    n,
                    a_slice,
                    n,
                    0.0,
                    a2_slice,
                    n as i32,
                )
            }
        }

        let d4_estimated = self
            .normest1
            .normest1_pow(&self.a2, 2, self.itmax)
            .powf(1.0 / 4.0);
        let d6_estimated = self
            .normest1
            .normest1_pow(&self.a2, 3, self.itmax)
            .powf(1.0 / 6.0);
        let eta_1 = d4_estimated.max(d6_estimated);

        if eta_1 <= THETA_3 && self.ell(3) == 0 {
            println!("eta_1 condition");
            self.solve_via_pade(PadeOrders::_3, v);
            return;
        }

        {
            let (a2_slice, _) =
                as_slice_with_layout(&self.a2).expect("Matrix `a2` not contiguous.");
            let (a4_slice, _) =
                as_slice_with_layout_mut(&mut self.a4).expect("Matrix `a4` not contiguous.");
            unsafe {
                cblas::dgemm(
                    self.layout,
                    cblas::Transpose::None,
                    cblas::Transpose::None,
                    self.n as i32,
                    self.n as i32,
                    self.n as i32,
                    1.0,
                    a2_slice,
                    n as i32,
                    a2_slice,
                    n as i32,
                    0.0,
                    a4_slice,
                    n as i32,
                )
            }
        }

        let d4_precise = self.normest1.normest1(&self.a4, self.itmax).powf(1.0 / 4.0);
        let eta_2 = d4_precise.max(d6_estimated);

        if eta_2 <= THETA_5 && self.ell(5) == 0 {
            println!("eta_2 condition");
            self.solve_via_pade(PadeOrders::_5, v);
            return;
        }

        {
            let (a2_slice, _) =
                as_slice_with_layout(&self.a2).expect("Matrix `a2` not contiguous.");
            let (a4_slice, _) =
                as_slice_with_layout(&self.a4).expect("Matrix `a4` not contiguous.");
            let (a6_slice, _) =
                as_slice_with_layout_mut(&mut self.a6).expect("Matrix `a6` not contiguous.");
            unsafe {
                cblas::dgemm(
                    self.layout,
                    cblas::Transpose::None,
                    cblas::Transpose::None,
                    self.n as i32,
                    self.n as i32,
                    self.n as i32,
                    1.0,
                    a2_slice,
                    n as i32,
                    a4_slice,
                    n as i32,
                    0.0,
                    a6_slice,
                    n as i32,
                )
            }
        }

        let d6_precise = self.normest1.normest1(&self.a6, self.itmax).powf(1.0 / 6.0);
        let d8_estimated = self.normest1.normest1_pow(&self.a4, 2, self.itmax);
        let eta_3 = d6_precise.max(d8_estimated);

        if eta_3 <= THETA_7 && self.ell(7) == 0 {
            println!("eta_3 (first) condition");
            self.solve_via_pade(PadeOrders::_7, v);
            return;
        }

        {
            let (a4_slice, _) =
                as_slice_with_layout(&self.a4).expect("Matrix `a4` not contiguous.");
            let (a8_slice, _) =
                as_slice_with_layout_mut(&mut self.a8).expect("Matrix `a8` not contiguous.");
            unsafe {
                cblas::dgemm(
                    self.layout,
                    cblas::Transpose::None,
                    cblas::Transpose::None,
                    self.n as i32,
                    self.n as i32,
                    self.n as i32,
                    1.0,
                    a4_slice,
                    n as i32,
                    a4_slice,
                    n as i32,
                    0.0,
                    a8_slice,
                    n as i32,
                )
            }
        }

        if eta_3 <= THETA_9 && self.ell(9) == 0 {
            println!("eta_3 (second) condition");
            self.solve_via_pade(PadeOrders::_9, v);
            return;
        }

        let eta_4 = d8_estimated.max(
            self.normest1
                .normest1_prod(&[&self.a4, &self.a6], self.itmax)
                .powf(1.0 / 10.0),
        );
        let eta_5 = eta_3.min(eta_4);

        use std::cmp;
        use std::f64;
        let mut s = cmp::max(f64::ceil(f64::log2(eta_5 / THETA_13)) as i32, 0);
        self.a1.mapv_inplace(|x| x / 2f64.powi(s));
        s += self.ell(13);
        self.a1.zip_mut_with(a, |x, &y| *x = y / 2f64.powi(s));
        self.a2.mapv_inplace(|x| x / 2f64.powi(2 * s));
        self.a4.mapv_inplace(|x| x / 2f64.powi(4 * s));
        self.a6.mapv_inplace(|x| x / 2f64.powi(6 * s));

        self.solve_via_pade(PadeOrders::_13, v);

        // TODO: Call code fragment 2.1 in the paper if `a` is triangular, instead of the code below.
        //
        // NOTE: it's guaranteed that s >= 0 by its definition.
        let (u_slice, _) =
            as_slice_with_layout_mut(&mut self.u).expect("Matrix `u` not contiguous.");

        // NOTE: v initially contains r after `solve_via_pade`.
        let (v_slice, _) = as_slice_with_layout_mut(v).expect("Matrix `v` not contiguous.");
        for _ in 0..s {
            unsafe {
                cblas::dgemm(
                    self.layout,
                    cblas::Transpose::None,
                    cblas::Transpose::None,
                    self.n as i32,
||||||
self.n as i32,
|
|
||||||
self.n as i32,
|
|
||||||
1.0,
|
|
||||||
v_slice,
|
|
||||||
n as i32,
|
|
||||||
v_slice,
|
|
||||||
n as i32,
|
|
||||||
0.0,
|
|
||||||
u_slice,
|
|
||||||
n as i32,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
u_slice.swap_with_slice(v_slice);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A helper function (as it is called in the original paper) returning the
/// $\max(\lceil \log_2(\alpha/u) / (2m) \rceil, 0)$, where
/// $\alpha = \lvert c_{2m+1} \rvert \, \texttt{normest}(\lvert A \rvert^{2m+1}) / \lVert A \rVert_1$.
||||||
fn ell(&mut self, m: usize) -> i32 {
|
|
||||||
Zip::from(&mut self.a_abs)
|
|
||||||
.and(&self.a1)
|
|
||||||
.apply(|x, &y| *x = y.abs());
|
|
||||||
|
|
||||||
let c2m1 = pade_error_coefficient(m as u64);
|
|
||||||
|
|
||||||
let norm_abs_a_2m1 = self
|
|
||||||
.normest1
|
|
||||||
.normest1_pow(&self.a_abs, 2 * m + 1, self.itmax);
|
|
||||||
let norm_a = self.normest1.normest1(&self.a1, self.itmax);
|
|
||||||
let alpha = c2m1.abs() * norm_abs_a_2m1 / norm_a;
|
|
||||||
|
|
||||||
// The unit roundoff, defined as half the machine epsilon.
|
|
||||||
let u = std::f64::EPSILON / 2.0;
|
|
||||||
|
|
||||||
use std::cmp;
|
|
||||||
use std::f64;
|
|
||||||
|
|
||||||
cmp::max(0, f64::ceil(f64::log2(alpha / u) / (2 * m) as f64) as i32)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn solve_via_pade<S>(&mut self, pade_order: PadeOrders, v: &mut ArrayBase<S, Ix2>)
|
|
||||||
where
|
|
||||||
S: DataMut<Elem = f64>,
|
|
||||||
{
|
|
||||||
use PadeOrders::*;
|
|
||||||
|
|
||||||
macro_rules! pade {
|
|
||||||
($order:ty, [$(&$apow:expr),+]) => {
|
|
||||||
<$order as PadeOrder>::calculate_pade_sums(&self.a1, &[$(&$apow),+], &mut self.u, v, &mut self.work);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
match pade_order {
|
|
||||||
_3 => pade!(PadeOrder_3, [&self.eye, &self.a2]),
|
|
||||||
_5 => pade!(PadeOrder_5, [&self.eye, &self.a2, &self.a4]),
|
|
||||||
_7 => pade!(PadeOrder_7, [&self.eye, &self.a2, &self.a4, &self.a6]),
|
|
||||||
_9 => pade!(
|
|
||||||
PadeOrder_9,
|
|
||||||
[&self.eye, &self.a2, &self.a4, &self.a6, &self.a8]
|
|
||||||
),
|
|
||||||
_13 => pade!(PadeOrder_13, [&self.eye, &self.a2, &self.a4, &self.a6]),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Here we set v = p <- u + v and u = q <- -u + v, overwriting u and v via work.
|
|
||||||
self.work.assign(v);
|
|
||||||
|
|
||||||
Zip::from(&mut *v).and(&self.u).apply(|x, &y| {
|
|
||||||
*x += y;
|
|
||||||
});
|
|
||||||
|
|
||||||
Zip::from(&mut self.u).and(&self.work).apply(|x, &y| {
|
|
||||||
*x = -*x + y;
|
|
||||||
});
|
|
||||||
|
|
||||||
let (u_slice, _) =
|
|
||||||
as_slice_with_layout_mut(&mut self.u).expect("Matrix `u` not contiguous.");
|
|
||||||
let (v_slice, _) = as_slice_with_layout_mut(v).expect("Matrix `v` not contiguous.");
|
|
||||||
let (pivot_slice, _) =
|
|
||||||
as_slice_with_layout_mut(&mut self.pivot).expect("Vector `pivot` not contiguous.");
|
|
||||||
|
|
||||||
let n = self.n as i32;
|
|
||||||
|
|
||||||
let layout = {
|
|
||||||
match self.layout {
|
|
||||||
cblas::Layout::ColumnMajor => lapacke::Layout::ColumnMajor,
|
|
||||||
cblas::Layout::RowMajor => lapacke::Layout::RowMajor,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// FIXME: Handle the info for error management.
|
|
||||||
let _ = unsafe { lapacke::dgesv(layout, n, n, u_slice, n, pivot_slice, v_slice, n) };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Calculate the matrix exponential of the n×n matrix `a` storing the result in matrix `b`.
|
|
||||||
///
|
|
||||||
/// NOTE: Panics if input matrices `a` and `b` don't have matching dimensions, are not square,
/// are not in row-major order, or don't have the same dimension as the `Expm` object `expm` is
|
|
||||||
/// called on.
|
|
||||||
pub fn expm<S1, S2>(a: &ArrayBase<S1, Ix2>, b: &mut ArrayBase<S2, Ix2>)
|
|
||||||
where
|
|
||||||
S1: Data<Elem = f64>,
|
|
||||||
S2: DataMut<Elem = f64>,
|
|
||||||
{
|
|
||||||
let (n, _) = a.dim();
|
|
||||||
|
|
||||||
let mut expm = Expm::new(n);
|
|
||||||
expm.expm(a, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns slice and layout underlying an array `a`.
|
|
||||||
fn as_slice_with_layout<S, T, D>(a: &ArrayBase<S, D>) -> Option<(&[T], cblas::Layout)>
|
|
||||||
where
|
|
||||||
S: Data<Elem = T>,
|
|
||||||
D: Dimension,
|
|
||||||
{
|
|
||||||
if let Some(a_slice) = a.as_slice() {
|
|
||||||
Some((a_slice, cblas::Layout::RowMajor))
|
|
||||||
} else if let Some(a_slice) = a.as_slice_memory_order() {
|
|
||||||
Some((a_slice, cblas::Layout::ColumnMajor))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns mutable slice and layout underlying an array `a`.
|
|
||||||
fn as_slice_with_layout_mut<S, T, D>(a: &mut ArrayBase<S, D>) -> Option<(&mut [T], cblas::Layout)>
|
|
||||||
where
|
|
||||||
S: DataMut<Elem = T>,
|
|
||||||
D: Dimension,
|
|
||||||
{
|
|
||||||
if a.as_slice_mut().is_some() {
|
|
||||||
Some((a.as_slice_mut().unwrap(), cblas::Layout::RowMajor))
|
|
||||||
} else if a.as_slice_memory_order_mut().is_some() {
|
|
||||||
Some((
|
|
||||||
a.as_slice_memory_order_mut().unwrap(),
|
|
||||||
cblas::Layout::ColumnMajor,
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
// XXX: The above is a workaround for Rust not having non-lexical lifetimes yet.
|
|
||||||
// More information here:
|
|
||||||
// http://smallcultfollowing.com/babysteps/blog/2016/04/27/non-lexical-lifetimes-introduction/#problem-case-3-conditional-control-flow-across-functions
|
|
||||||
//
|
|
||||||
// if let Some(slice) = a.as_slice_mut() {
|
|
||||||
// Some((slice, cblas::Layout::RowMajor))
|
|
||||||
// } else if let Some(slice) = a.as_slice_memory_order_mut() {
|
|
||||||
// Some((slice, cblas::Layout::ColumnMajor))
|
|
||||||
// } else {
|
|
||||||
// None
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
@@ -10,21 +10,6 @@ pub use exponential::Exponential;
 pub use matern32::Matern32;
 pub use matern52::Matern52;
 
-#[inline]
-pub(crate) fn transition(t0: f64, t1: f64, feedback: Array2<f64>) -> Array2<f64> {
-    let a = feedback * (t1 - t0);
-
-    if a.shape() == [1, 1] {
-        array![[a[(0, 0)].exp()]]
-    } else {
-        let mut b = Array2::<f64>::zeros(a.dim());
-
-        crate::expm::expm(&a, &mut b);
-
-        b
-    }
-}
-
 pub(crate) fn distance(ts1: &[f64], ts2: &[f64]) -> Array2<f64> {
     let mut r = Array2::zeros((ts1.len(), ts2.len()));
 
@@ -54,27 +39,9 @@ pub trait Kernel {
         unimplemented!();
     }
 
-    fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
-        transition(t0, t1, self.feedback())
-    }
+    fn transition(&self, t0: f64, t1: f64) -> Array2<f64>;
 
     fn noise_cov(&self, _t0: f64, _t1: f64) -> Array2<f64> {
-        /*
-        mat = self.noise_effect.dot(self.noise_density).dot(self.noise_effect.T)
-        #print(g)
-        print(mat)
-        Phi = np.vstack((
-            np.hstack((self.feedback, mat)),
-            np.hstack((np.zeros_like(mat), -self.feedback.T))))
-        print(Phi)
-        m = self.order
-        AB = np.dot(sp.linalg.expm(Phi * (t2 - t1)), np.eye(2*m, m, k=-m))
-        print(AB)
-        return sp.linalg.solve(AB[m:,:].T, AB[:m,:].T)
-        */
-
-        // let mat = self.noise_effect()
-
         todo!();
     }
 }
@@ -59,7 +59,6 @@ impl Kernel for Constant {
 
 #[cfg(test)]
 mod tests {
-    extern crate blis_src;
     extern crate intel_mkl_src;
 
     use approx::assert_abs_diff_eq;
@@ -126,18 +125,4 @@ mod tests {
 
         assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts));
     }
-
-    #[test]
-    fn test_ssm_matrices() {
-        let kernel = Constant::new(2.5);
-
-        let deltas = [0.01, 1.0, 10.0];
-
-        for delta in &deltas {
-            assert_abs_diff_eq!(
-                crate::kernel::transition(0.0, *delta, kernel.feedback()),
-                kernel.transition(0.0, *delta)
-            );
-        }
-    }
 }
@@ -62,7 +62,6 @@ impl Kernel for Exponential {
 
 #[cfg(test)]
 mod tests {
-    extern crate blis_src;
     extern crate intel_mkl_src;
 
    use approx::assert_abs_diff_eq;
@@ -133,18 +132,4 @@ mod tests {
 
         assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts));
     }
-
-    #[test]
-    fn test_ssm_matrices() {
-        let kernel = Exponential::new(1.1, 2.2);
-
-        let deltas = [0.01, 1.0, 10.0];
-
-        for delta in &deltas {
-            assert_abs_diff_eq!(
-                crate::kernel::transition(0.0, *delta, kernel.feedback()),
-                kernel.transition(0.0, *delta)
-            );
-        }
-    }
 }
@@ -20,8 +20,15 @@ impl Matern32 {
 }
 
 impl Kernel for Matern32 {
-    fn k_mat(&self, _ts1: &[f64], _ts2: Option<&[f64]>) -> Array2<f64> {
-        unimplemented!();
+    fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
+        let ts2 = ts2.unwrap_or(ts1);
+
+        let sqrt3 = 3.0f64.sqrt();
+
+        let r = super::distance(ts1, ts2) / self.l_scale;
+        let r2 = r.mapv(|v| (-sqrt3 * v).exp());
+
+        r2 * (1.0 + sqrt3 * r) * self.var
     }
 
     fn k_diag(&self, ts: &[f64]) -> Array1<f64> {
@@ -52,6 +59,10 @@ impl Kernel for Matern32 {
         array![[0.0, 1.0], [-a.powi(2), -2.0 * a]]
     }
 
+    fn noise_effect(&self) -> Array2<f64> {
+        array![[0.0], [1.0]]
+    }
+
     fn transition(&self, t0: f64, t1: f64) -> Array2<f64> {
         let d = t1 - t0;
         let a = self.lambda;
@@ -75,3 +86,77 @@ impl Kernel for Matern32 {
         self.var * array![[x11, x12], [x12, x22]]
     }
 }
+
+#[cfg(test)]
+mod tests {
+    extern crate intel_mkl_src;
+
+    use approx::assert_abs_diff_eq;
+    use rand::{distributions::Standard, thread_rng, Rng};
+
+    use super::*;
+
+    #[test]
+    fn test_kernel_matrix() {
+        let kernel = Matern32::new(1.5, 0.7);
+
+        let ts = [1.26, 1.46, 2.67];
+
+        assert_abs_diff_eq!(
+            kernel.k_mat(&ts, None),
+            array![
+                [1.5, 1.3670208436282583, 0.20560783965565255],
+                [1.3670208436282583, 1.5, 0.30007530446059727],
+                [0.20560783965565255, 0.30007530446059727, 1.5]
+            ]
+        );
+    }
+
+    #[test]
+    fn test_kernel_diag() {
+        let kernel = Matern32::new(1.5, 0.7);
+
+        let ts: Vec<_> = thread_rng()
+            .sample_iter::<f64, _>(Standard)
+            .take(10)
+            .map(|x| x * 10.0)
+            .collect();
+
+        assert_eq!(kernel.k_mat(&ts, None).diag(), kernel.k_diag(&ts));
+    }
+
+    #[test]
+    fn test_kernel_order() {
+        let kernel = Matern32::new(1.5, 0.7);
+
+        let m = kernel.order();
+
+        assert_eq!(kernel.state_mean(0.0).shape(), &[m]);
+        assert_eq!(kernel.state_cov(0.0).shape(), &[m, m]);
+        assert_eq!(kernel.measurement_vector().shape(), &[m]);
+        assert_eq!(kernel.feedback().shape(), &[m, m]);
+        assert_eq!(kernel.noise_effect().shape()[0], m);
+        assert_eq!(kernel.transition(0.0, 1.0).shape(), &[m, m]);
+        assert_eq!(kernel.noise_cov(0.0, 1.0).shape(), &[m, m]);
+    }
+
+    #[test]
+    fn test_ssm_variance() {
+        let kernel = Matern32::new(1.5, 0.7);
+
+        let ts: Vec<_> = thread_rng()
+            .sample_iter::<f64, _>(Standard)
+            .take(10)
+            .map(|x| x * 10.0)
+            .collect();
+
+        let h = kernel.measurement_vector();
+
+        let vars = ts
+            .iter()
+            .map(|t| h.dot(&kernel.state_cov(*t)).dot(&h))
+            .collect::<Vec<_>>();
+
+        assert_abs_diff_eq!(Array::from(vars), kernel.k_diag(&ts));
+    }
+}
@@ -1,11 +1,10 @@
 // https://github.com/lucasmaystre/kickscore/tree/master/kickscore
 
-mod condest;
-mod expm;
 mod fitter;
 mod item;
 pub mod kernel;
 mod linalg;
+mod math;
 mod model;
 pub mod observation;
 mod storage;
src/math.rs (new file, 434 lines)
@@ -0,0 +1,434 @@
pub fn erfc(x: f64) -> f64 {
    if x.is_nan() {
        f64::NAN
    } else if x == f64::INFINITY {
        0.0
    } else if x == f64::NEG_INFINITY {
        2.0
    } else {
        erf_impl(x, true)
    }
}
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator of `erf_impl`
|
||||||
|
/// in the interval [1e-10, 0.5].
|
||||||
|
const ERF_IMPL_AN: &[f64] = &[
|
||||||
|
0.00337916709551257388990745,
|
||||||
|
-0.00073695653048167948530905,
|
||||||
|
-0.374732337392919607868241,
|
||||||
|
0.0817442448733587196071743,
|
||||||
|
-0.0421089319936548595203468,
|
||||||
|
0.0070165709512095756344528,
|
||||||
|
-0.00495091255982435110337458,
|
||||||
|
0.000871646599037922480317225,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator of `erf_impl`
|
||||||
|
/// in the interval [1e-10, 0.5]
|
||||||
|
const ERF_IMPL_AD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
-0.218088218087924645390535,
|
||||||
|
0.412542972725442099083918,
|
||||||
|
-0.0841891147873106755410271,
|
||||||
|
0.0655338856400241519690695,
|
||||||
|
-0.0120019604454941768171266,
|
||||||
|
0.00408165558926174048329689,
|
||||||
|
-0.000615900721557769691924509,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [0.5, 0.75].
|
||||||
|
const ERF_IMPL_BN: &[f64] = &[
|
||||||
|
-0.0361790390718262471360258,
|
||||||
|
0.292251883444882683221149,
|
||||||
|
0.281447041797604512774415,
|
||||||
|
0.125610208862766947294894,
|
||||||
|
0.0274135028268930549240776,
|
||||||
|
0.00250839672168065762786937,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [0.5, 0.75].
|
||||||
|
const ERF_IMPL_BD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
1.8545005897903486499845,
|
||||||
|
1.43575803037831418074962,
|
||||||
|
0.582827658753036572454135,
|
||||||
|
0.124810476932949746447682,
|
||||||
|
0.0113724176546353285778481,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [0.75, 1.25].
|
||||||
|
const ERF_IMPL_CN: &[f64] = &[
|
||||||
|
-0.0397876892611136856954425,
|
||||||
|
0.153165212467878293257683,
|
||||||
|
0.191260295600936245503129,
|
||||||
|
0.10276327061989304213645,
|
||||||
|
0.029637090615738836726027,
|
||||||
|
0.0046093486780275489468812,
|
||||||
|
0.000307607820348680180548455,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [0.75, 1.25].
|
||||||
|
const ERF_IMPL_CD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
1.95520072987627704987886,
|
||||||
|
1.64762317199384860109595,
|
||||||
|
0.768238607022126250082483,
|
||||||
|
0.209793185936509782784315,
|
||||||
|
0.0319569316899913392596356,
|
||||||
|
0.00213363160895785378615014,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [1.25, 2.25].
|
||||||
|
const ERF_IMPL_DN: &[f64] = &[
|
||||||
|
-0.0300838560557949717328341,
|
||||||
|
0.0538578829844454508530552,
|
||||||
|
0.0726211541651914182692959,
|
||||||
|
0.0367628469888049348429018,
|
||||||
|
0.00964629015572527529605267,
|
||||||
|
0.00133453480075291076745275,
|
||||||
|
0.778087599782504251917881e-4,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [1.25, 2.25].
|
||||||
|
const ERF_IMPL_DD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
1.75967098147167528287343,
|
||||||
|
1.32883571437961120556307,
|
||||||
|
0.552528596508757581287907,
|
||||||
|
0.133793056941332861912279,
|
||||||
|
0.0179509645176280768640766,
|
||||||
|
0.00104712440019937356634038,
|
||||||
|
-0.106640381820357337177643e-7,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [2.25, 3.5].
|
||||||
|
const ERF_IMPL_EN: &[f64] = &[
|
||||||
|
-0.0117907570137227847827732,
|
||||||
|
0.014262132090538809896674,
|
||||||
|
0.0202234435902960820020765,
|
||||||
|
0.00930668299990432009042239,
|
||||||
|
0.00213357802422065994322516,
|
||||||
|
0.00025022987386460102395382,
|
||||||
|
0.120534912219588189822126e-4,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [2.25, 3.5].
|
||||||
|
const ERF_IMPL_ED: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
1.50376225203620482047419,
|
||||||
|
0.965397786204462896346934,
|
||||||
|
0.339265230476796681555511,
|
||||||
|
0.0689740649541569716897427,
|
||||||
|
0.00771060262491768307365526,
|
||||||
|
0.000371421101531069302990367,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [3.5, 5.25].
|
||||||
|
const ERF_IMPL_FN: &[f64] = &[
|
||||||
|
-0.00546954795538729307482955,
|
||||||
|
0.00404190278731707110245394,
|
||||||
|
0.0054963369553161170521356,
|
||||||
|
0.00212616472603945399437862,
|
||||||
|
0.000394984014495083900689956,
|
||||||
|
0.365565477064442377259271e-4,
|
||||||
|
0.135485897109932323253786e-5,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [3.5, 5.25].
|
||||||
|
const ERF_IMPL_FD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
1.21019697773630784832251,
|
||||||
|
0.620914668221143886601045,
|
||||||
|
0.173038430661142762569515,
|
||||||
|
0.0276550813773432047594539,
|
||||||
|
0.00240625974424309709745382,
|
||||||
|
0.891811817251336577241006e-4,
|
||||||
|
-0.465528836283382684461025e-11,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [5.25, 8].
|
||||||
|
const ERF_IMPL_GN: &[f64] = &[
|
||||||
|
-0.00270722535905778347999196,
|
||||||
|
0.0013187563425029400461378,
|
||||||
|
0.00119925933261002333923989,
|
||||||
|
0.00027849619811344664248235,
|
||||||
|
0.267822988218331849989363e-4,
|
||||||
|
0.923043672315028197865066e-6,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [5.25, 8].
|
||||||
|
const ERF_IMPL_GD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
0.814632808543141591118279,
|
||||||
|
0.268901665856299542168425,
|
||||||
|
0.0449877216103041118694989,
|
||||||
|
0.00381759663320248459168994,
|
||||||
|
0.000131571897888596914350697,
|
||||||
|
0.404815359675764138445257e-11,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [8, 11.5].
|
||||||
|
const ERF_IMPL_HN: &[f64] = &[
|
||||||
|
-0.00109946720691742196814323,
|
||||||
|
0.000406425442750422675169153,
|
||||||
|
0.000274499489416900707787024,
|
||||||
|
0.465293770646659383436343e-4,
|
||||||
|
0.320955425395767463401993e-5,
|
||||||
|
0.778286018145020892261936e-7,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [8, 11.5].
|
||||||
|
const ERF_IMPL_HD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
0.588173710611846046373373,
|
||||||
|
0.139363331289409746077541,
|
||||||
|
0.0166329340417083678763028,
|
||||||
|
0.00100023921310234908642639,
|
||||||
|
0.24254837521587225125068e-4,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [11.5, 17].
|
||||||
|
const ERF_IMPL_IN: &[f64] = &[
|
||||||
|
-0.00056907993601094962855594,
|
||||||
|
0.000169498540373762264416984,
|
||||||
|
0.518472354581100890120501e-4,
|
||||||
|
0.382819312231928859704678e-5,
|
||||||
|
0.824989931281894431781794e-7,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [11.5, 17].
|
||||||
|
const ERF_IMPL_ID: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
0.339637250051139347430323,
|
||||||
|
0.043472647870310663055044,
|
||||||
|
0.00248549335224637114641629,
|
||||||
|
0.535633305337152900549536e-4,
|
||||||
|
-0.117490944405459578783846e-12,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [17, 24].
|
||||||
|
const ERF_IMPL_JN: &[f64] = &[
|
||||||
|
-0.000241313599483991337479091,
|
||||||
|
0.574224975202501512365975e-4,
|
||||||
|
0.115998962927383778460557e-4,
|
||||||
|
0.581762134402593739370875e-6,
|
||||||
|
0.853971555085673614607418e-8,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [17, 24].
|
||||||
|
const ERF_IMPL_JD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
0.233044138299687841018015,
|
||||||
|
0.0204186940546440312625597,
|
||||||
|
0.000797185647564398289151125,
|
||||||
|
0.117019281670172327758019e-4,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [24, 38].
|
||||||
|
const ERF_IMPL_KN: &[f64] = &[
|
||||||
|
-0.000146674699277760365803642,
|
||||||
|
0.162666552112280519955647e-4,
|
||||||
|
0.269116248509165239294897e-5,
|
||||||
|
0.979584479468091935086972e-7,
|
||||||
|
0.101994647625723465722285e-8,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [24, 38].
|
||||||
|
const ERF_IMPL_KD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
0.165907812944847226546036,
|
||||||
|
0.0103361716191505884359634,
|
||||||
|
0.000286593026373868366935721,
|
||||||
|
0.298401570840900340874568e-5,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [38, 60].
|
||||||
|
const ERF_IMPL_LN: &[f64] = &[
|
||||||
|
-0.583905797629771786720406e-4,
|
||||||
|
0.412510325105496173512992e-5,
|
||||||
|
0.431790922420250949096906e-6,
|
||||||
|
0.993365155590013193345569e-8,
|
||||||
|
0.653480510020104699270084e-10,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [38, 60].
|
||||||
|
const ERF_IMPL_LD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
0.105077086072039915406159,
|
||||||
|
0.00414278428675475620830226,
|
||||||
|
0.726338754644523769144108e-4,
|
||||||
|
0.477818471047398785369849e-6,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [60, 85].
|
||||||
|
const ERF_IMPL_MN: &[f64] = &[
|
||||||
|
-0.196457797609229579459841e-4,
|
||||||
|
0.157243887666800692441195e-5,
|
||||||
|
0.543902511192700878690335e-7,
|
||||||
|
0.317472492369117710852685e-9,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [60, 85].
|
||||||
|
const ERF_IMPL_MD: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
0.052803989240957632204885,
|
||||||
|
0.000926876069151753290378112,
|
||||||
|
0.541011723226630257077328e-5,
|
||||||
|
0.535093845803642394908747e-15,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a numerator in `erf_impl`
|
||||||
|
/// in the interval [85, 110].
|
||||||
|
const ERF_IMPL_NN: &[f64] = &[
|
||||||
|
-0.789224703978722689089794e-5,
|
||||||
|
0.622088451660986955124162e-6,
|
||||||
|
0.145728445676882396797184e-7,
|
||||||
|
0.603715505542715364529243e-10,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Polynomial coefficients for a denominator in `erf_impl`
|
||||||
|
/// in the interval [85, 110].
|
||||||
|
const ERF_IMPL_ND: &[f64] = &[
|
||||||
|
1.0,
|
||||||
|
0.0375328846356293715248719,
|
||||||
|
0.000467919535974625308126054,
|
||||||
|
0.193847039275845656900547e-5,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// `erf_impl` computes the error function at `z`.
|
||||||
|
/// If `inv` is true, `1 - erf` is calculated as opposed to `erf`
|
||||||
|
fn erf_impl(z: f64, inv: bool) -> f64 {
|
||||||
|
if z < 0.0 {
|
||||||
|
if !inv {
|
||||||
|
return -erf_impl(-z, false);
|
||||||
|
}
|
||||||
|
if z < -0.5 {
|
||||||
|
return 2.0 - erf_impl(-z, true);
|
||||||
|
}
|
||||||
|
return 1.0 + erf_impl(-z, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = if z < 0.5 {
|
||||||
|
if z < 1e-10 {
|
||||||
|
z * 1.125 + z * 0.003379167095512573896158903121545171688
|
||||||
|
} else {
|
||||||
|
z * 1.125 + z * polynomial(z, ERF_IMPL_AN) / polynomial(z, ERF_IMPL_AD)
|
||||||
|
}
|
||||||
|
} else if z < 110.0 {
|
||||||
|
let (r, b) = if z < 0.75 {
|
||||||
|
(
|
||||||
|
polynomial(z - 0.5, ERF_IMPL_BN) / polynomial(z - 0.5, ERF_IMPL_BD),
|
||||||
|
0.3440242112,
|
||||||
|
)
|
||||||
|
} else if z < 1.25 {
|
||||||
|
(
|
||||||
|
polynomial(z - 0.75, ERF_IMPL_CN) / polynomial(z - 0.75, ERF_IMPL_CD),
|
||||||
|
0.419990927,
|
||||||
|
)
|
||||||
|
} else if z < 2.25 {
|
||||||
|
(
|
||||||
|
polynomial(z - 1.25, ERF_IMPL_DN) / polynomial(z - 1.25, ERF_IMPL_DD),
|
||||||
|
0.4898625016,
|
||||||
|
)
|
||||||
|
} else if z < 3.5 {
|
||||||
|
(
|
||||||
|
polynomial(z - 2.25, ERF_IMPL_EN) / polynomial(z - 2.25, ERF_IMPL_ED),
|
||||||
|
0.5317370892,
|
||||||
|
)
|
||||||
|
} else if z < 5.25 {
|
||||||
|
(
|
||||||
|
polynomial(z - 3.5, ERF_IMPL_FN) / polynomial(z - 3.5, ERF_IMPL_FD),
|
||||||
|
0.5489973426,
|
||||||
|
)
|
||||||
|
} else if z < 8.0 {
|
||||||
|
(
|
||||||
|
polynomial(z - 5.25, ERF_IMPL_GN) / polynomial(z - 5.25, ERF_IMPL_GD),
|
||||||
|
0.5571740866,
|
||||||
|
)
|
||||||
|
} else if z < 11.5 {
|
||||||
|
(
|
||||||
|
polynomial(z - 8.0, ERF_IMPL_HN) / polynomial(z - 8.0, ERF_IMPL_HD),
|
||||||
|
0.5609807968,
|
||||||
|
)
|
||||||
|
} else if z < 17.0 {
|
||||||
|
(
|
||||||
|
polynomial(z - 11.5, ERF_IMPL_IN) / polynomial(z - 11.5, ERF_IMPL_ID),
|
||||||
|
0.5626493692,
|
||||||
|
)
|
||||||
|
} else if z < 24.0 {
|
||||||
|
(
|
||||||
|
polynomial(z - 17.0, ERF_IMPL_JN) / polynomial(z - 17.0, ERF_IMPL_JD),
|
||||||
|
0.5634598136,
|
||||||
|
)
|
||||||
|
} else if z < 38.0 {
|
||||||
|
(
|
||||||
|
polynomial(z - 24.0, ERF_IMPL_KN) / polynomial(z - 24.0, ERF_IMPL_KD),
|
||||||
|
0.5638477802,
|
||||||
|
)
|
||||||
|
} else if z < 60.0 {
|
||||||
|
(
|
||||||
|
polynomial(z - 38.0, ERF_IMPL_LN) / polynomial(z - 38.0, ERF_IMPL_LD),
|
||||||
|
0.5640528202,
|
||||||
|
)
|
||||||
|
} else if z < 85.0 {
|
||||||
|
(
|
||||||
|
polynomial(z - 60.0, ERF_IMPL_MN) / polynomial(z - 60.0, ERF_IMPL_MD),
|
||||||
|
0.5641309023,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
polynomial(z - 85.0, ERF_IMPL_NN) / polynomial(z - 85.0, ERF_IMPL_ND),
|
||||||
|
0.5641584396,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
let g = (-z * z).exp() / z;
|
||||||
|
g * b + g * r
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
if inv && z >= 0.5 {
|
||||||
|
result
|
||||||
|
} else if z >= 0.5 || inv {
|
||||||
|
1.0 - result
|
||||||
|
} else {
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn polynomial(z: f64, coeff: &[f64]) -> f64 {
    let n = coeff.len();
    if n == 0 {
        return 0.0;
    }

    // Evaluate the polynomial at `z` using Horner's method,
    // with `coeff[0]` as the constant term.
    let mut sum = *coeff.last().unwrap();
    for c in coeff[0..n - 1].iter().rev() {
        sum = *c + z * sum;
    }
    sum
}
@@ -1,6 +1,6 @@
 use std::f64::consts::{PI, SQRT_2};
 
-use statrs::function::erf::erfc;
+use crate::math::erfc;
 
 const CS: [f64; 14] = [
     0.00048204,
@@ -38,7 +38,7 @@ fn nba_history() {
             continue;
         }
 
-        let t = t.midnight().assume_utc().timestamp() as f64;
+        let t = t.midnight().assume_utc().unix_timestamp() as f64;
 
         let score_1: u16 = data[3].parse().unwrap();
         let score_2: u16 = data[4].parse().unwrap();
@@ -77,7 +77,7 @@ fn nba_history() {
         time::date!(1996 - 01 - 01)
             .midnight()
             .assume_utc()
-            .timestamp() as f64,
+            .unix_timestamp() as f64,
     );
 
     assert_abs_diff_eq!(p_win, 0.9002599772490479, epsilon = f64::EPSILON);
@@ -88,7 +88,7 @@ fn nba_history() {
         time::date!(2001 - 01 - 01)
             .midnight()
             .assume_utc()
-            .timestamp() as f64,
+            .unix_timestamp() as f64,
     );
 
     assert_abs_diff_eq!(p_win, 0.22837870685441986, epsilon = f64::EPSILON);
@@ -99,7 +99,7 @@ fn nba_history() {
         time::date!(2020 - 01 - 01)
             .midnight()
             .assume_utc()
-            .timestamp() as f64,
+            .unix_timestamp() as f64,
     );
 
     assert_abs_diff_eq!(p_win, 0.2748029998412422, epsilon = f64::EPSILON);