Use const instead. Close #1

This commit is contained in:
2020-02-21 09:23:28 +01:00
parent 8a1e6620ad
commit e7a2679941

View File

@@ -1,4 +1,4 @@
use std::f64::consts::PI; use std::f64::consts::{PI, SQRT_2};
use statrs::function::erf::erfc; use statrs::function::erf::erfc;
@@ -38,20 +38,15 @@ const QS: [f64; 6] = [
/// Normal cumulative density function. /// Normal cumulative density function.
fn normcdf(x: f64) -> f64 { fn normcdf(x: f64) -> f64 {
let SQRT2: f64 = 2.0f64.sqrt(); erfc(-x / SQRT_2) / 2.0
erfc(-x / SQRT2) / 2.0
} }
/// Compute the log of the normal cumulative density function. /// Compute the log of the normal cumulative density function.
pub fn logphi(z: f64) -> (f64, f64) { pub fn logphi(z: f64) -> (f64, f64) {
// Adapted from the GPML function `logphi.m`. // Adapted from the GPML function `logphi.m`.
let SQRT2: f64 = 2.0f64.sqrt();
let SQRT2PI: f64 = (2.0 * PI).sqrt();
if z * z < 0.0492 { if z * z < 0.0492 {
// First case: z close to zero. // First case: z close to zero.
let coef = -z / SQRT2PI; let coef = -z / (2.0 * PI).sqrt();
let mut val = 0.0; let mut val = 0.0;
for c in &CS { for c in &CS {
@@ -59,7 +54,7 @@ pub fn logphi(z: f64) -> (f64, f64) {
} }
let res = -2.0 * val - 2.0f64.ln(); let res = -2.0 * val - 2.0f64.ln();
let dres = (-(z * z) / 2.0 - res).exp() / SQRT2PI; let dres = (-(z * z) / 2.0 - res).exp() / (2.0 * PI).sqrt();
(res, dres) (res, dres)
} else if z < -11.3137 { } else if z < -11.3137 {
@@ -67,13 +62,13 @@ pub fn logphi(z: f64) -> (f64, f64) {
let mut num = 0.5641895835477550741; let mut num = 0.5641895835477550741;
for r in &RS { for r in &RS {
num = -z * num / SQRT2 + r; num = -z * num / SQRT_2 + r;
} }
let mut den = 1.0; let mut den = 1.0;
for q in &QS { for q in &QS {
den = -z * den / SQRT2 + q; den = -z * den / SQRT_2 + q;
} }
let res = (num / (2.0 * den)).ln() - (z * z) / 2.0; let res = (num / (2.0 * den)).ln() - (z * z) / 2.0;
@@ -82,7 +77,7 @@ pub fn logphi(z: f64) -> (f64, f64) {
(res, dres) (res, dres)
} else { } else {
let res = normcdf(z).ln(); let res = normcdf(z).ln();
let dres = (-(z * z) / 2.0 - res).exp() / SQRT2PI; let dres = (-(z * z) / 2.0 - res).exp() / (2.0 * PI).sqrt();
(res, dres) (res, dres)
} }