Use openblas instead
This commit is contained in:
@@ -217,7 +217,7 @@ impl Kernel for Vec<Box<dyn Kernel>> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate intel_mkl_src;
|
||||
extern crate blas_src;
|
||||
|
||||
use approx::assert_abs_diff_eq;
|
||||
use rand::{distributions::Standard, thread_rng, Rng};
|
||||
|
||||
@@ -77,7 +77,7 @@ impl Kernel for Affine {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate intel_mkl_src;
|
||||
extern crate blas_src;
|
||||
|
||||
use approx::assert_abs_diff_eq;
|
||||
use rand::{distributions::Standard, thread_rng, Rng};
|
||||
|
||||
@@ -59,7 +59,7 @@ impl Kernel for Constant {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate intel_mkl_src;
|
||||
extern crate blas_src;
|
||||
|
||||
use approx::assert_abs_diff_eq;
|
||||
use rand::{distributions::Standard, thread_rng, Rng};
|
||||
|
||||
@@ -62,7 +62,7 @@ impl Kernel for Exponential {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate intel_mkl_src;
|
||||
extern crate blas_src;
|
||||
|
||||
use approx::assert_abs_diff_eq;
|
||||
use rand::{distributions::Standard, thread_rng, Rng};
|
||||
|
||||
@@ -89,7 +89,7 @@ impl Kernel for Matern32 {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate intel_mkl_src;
|
||||
extern crate blas_src;
|
||||
|
||||
use approx::assert_abs_diff_eq;
|
||||
use rand::{distributions::Standard, thread_rng, Rng};
|
||||
|
||||
@@ -117,7 +117,7 @@ impl Kernel for Matern52 {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate intel_mkl_src;
|
||||
extern crate blas_src;
|
||||
|
||||
use approx::assert_abs_diff_eq;
|
||||
use rand::{distributions::Standard, thread_rng, Rng};
|
||||
|
||||
@@ -68,7 +68,7 @@ impl Kernel for Wiener {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
extern crate intel_mkl_src;
|
||||
extern crate blas_src;
|
||||
|
||||
use approx::assert_abs_diff_eq;
|
||||
use rand::{distributions::Standard, thread_rng, Rng};
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use std::f64::consts::TAU;
|
||||
use std::f64::consts::{SQRT_2, TAU};
|
||||
|
||||
use crate::storage::Storage;
|
||||
use crate::utils::{logphi, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS};
|
||||
use crate::utils::{
|
||||
logphi, logsumexp2, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS,
|
||||
};
|
||||
|
||||
use super::{f_params, Core, Observation};
|
||||
|
||||
@@ -52,6 +54,66 @@ fn ll_probit_tie(x: f64, margin: f64) -> f64 {
|
||||
}
|
||||
}
|
||||
|
||||
fn lambdas() -> [f64; 5] {
|
||||
[
|
||||
0.44 * SQRT_2,
|
||||
0.41 * SQRT_2,
|
||||
0.40 * SQRT_2,
|
||||
0.39 * SQRT_2,
|
||||
0.36 * SQRT_2,
|
||||
]
|
||||
}
|
||||
|
||||
const CS: [f64; 5] = [
|
||||
1.146480988574439e+02,
|
||||
-1.508871030070582e+03,
|
||||
2.676085036831241e+03,
|
||||
-1.356294962039222e+03,
|
||||
7.543285642111850e+01,
|
||||
];
|
||||
|
||||
fn mm_logit_win(mean_cav: f64, cov_cav: f64) {
|
||||
let mut arr1 = [0.0; 5];
|
||||
let mut arr2 = [0.0; 5];
|
||||
let mut arr3 = [0.0; 5];
|
||||
|
||||
for (i, x) in lambdas().iter().enumerate() {
|
||||
let (a, b, c) = mm_probit_win(x * mean_cav, x * x * cov_cav);
|
||||
|
||||
arr1[i] = a;
|
||||
arr2[i] = b;
|
||||
arr3[i] = c;
|
||||
}
|
||||
|
||||
let logpart1 = logsumexp2(arr1, CS);
|
||||
|
||||
/*
|
||||
dlogpart1 = (np.dot(np.exp(arr1) * arr2, CS * LAMBDAS)
|
||||
/ np.dot(np.exp(arr1), CS))
|
||||
d2logpart1 = (np.dot(np.exp(arr1) * (arr2 * arr2 + arr3),
|
||||
CS * LAMBDAS * LAMBDAS)
|
||||
/ np.dot(np.exp(arr1), CS)) - (dlogpart1 * dlogpart1)
|
||||
*/
|
||||
}
|
||||
|
||||
fn mm_logit_tie(x: f64, margin: f64) {
|
||||
//
|
||||
}
|
||||
|
||||
fn ll_logit_win(x: f64, margin: f64) -> f64 {
|
||||
let z = x - margin;
|
||||
|
||||
if z > 0.0 {
|
||||
-((-z).exp().ln_1p())
|
||||
} else {
|
||||
z - z.exp().ln_1p()
|
||||
}
|
||||
}
|
||||
|
||||
fn ll_logit_tie(x: f64, margin: f64) -> f64 {
|
||||
ll_logit_win(x, margin) + ll_logit_win(-x, margin) + (2.0 * margin).exp_m1().ln()
|
||||
}
|
||||
|
||||
fn cvi_expectations<F>(mean: f64, var: f64, ll_fct: F) -> (f64, f64, f64)
|
||||
where
|
||||
F: Fn(f64) -> f64,
|
||||
@@ -153,16 +215,20 @@ impl Observation for ProbitTieObservation {
|
||||
|
||||
pub struct LogitWinObservation {
|
||||
core: Core,
|
||||
_margin: f64,
|
||||
margin: f64,
|
||||
}
|
||||
|
||||
impl LogitWinObservation {
|
||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||
Self {
|
||||
core: Core::new(storage, elems, t),
|
||||
_margin: margin,
|
||||
margin,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn probability(_storage: &Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> f64 {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl Observation for LogitWinObservation {
|
||||
@@ -171,22 +237,30 @@ impl Observation for LogitWinObservation {
|
||||
}
|
||||
|
||||
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||
self.core.kl_update(storage, lr, |mean, var| todo!())
|
||||
let margin = self.margin;
|
||||
|
||||
self.core.kl_update(storage, lr, |mean, var| {
|
||||
cvi_expectations(mean, var, |x| ll_logit_win(x, margin))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LogitTieObservation {
|
||||
core: Core,
|
||||
_margin: f64,
|
||||
margin: f64,
|
||||
}
|
||||
|
||||
impl LogitTieObservation {
|
||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||
Self {
|
||||
core: Core::new(storage, elems, t),
|
||||
_margin: margin,
|
||||
margin,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn probability(_storage: &Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> f64 {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl Observation for LogitTieObservation {
|
||||
@@ -195,7 +269,11 @@ impl Observation for LogitTieObservation {
|
||||
}
|
||||
|
||||
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||
self.core.kl_update(storage, lr, |mean, var| todo!())
|
||||
let margin = self.margin;
|
||||
|
||||
self.core.kl_update(storage, lr, |mean, var| {
|
||||
cvi_expectations(mean, var, |x| ll_logit_tie(x, margin))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
12
src/utils.rs
12
src/utils.rs
@@ -1,3 +1,4 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::f64::consts::{PI, SQRT_2, TAU};
|
||||
|
||||
use crate::math::erfc;
|
||||
@@ -88,6 +89,17 @@ pub fn logphi(z: f64) -> (f64, f64) {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn logsumexp2(xs: [f64; 5], bs: [f64; 5]) -> f64 {
|
||||
let a = xs.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
|
||||
|
||||
a + bs
|
||||
.iter()
|
||||
.zip(xs.iter().map(|x| (x - a).exp()))
|
||||
.map(|(b, x)| b * x)
|
||||
.sum::<f64>()
|
||||
.ln()
|
||||
}
|
||||
|
||||
pub const ROOTS_HERMITENORM_XS: [f64; 30] = [
|
||||
-9.706236,
|
||||
-8.68083772,
|
||||
|
||||
Reference in New Issue
Block a user