diff --git a/src/utils.rs b/src/utils.rs index 0092076..0019cd2 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,4 @@ -use std::f64::consts::PI; +use std::f64::consts::{PI, SQRT_2}; use statrs::function::erf::erfc; @@ -38,20 +38,15 @@ const QS: [f64; 6] = [ /// Normal cumulative density function. fn normcdf(x: f64) -> f64 { - let SQRT2: f64 = 2.0f64.sqrt(); - - erfc(-x / SQRT2) / 2.0 + erfc(-x / SQRT_2) / 2.0 } /// Compute the log of the normal cumulative density function. pub fn logphi(z: f64) -> (f64, f64) { // Adapted from the GPML function `logphi.m`. - let SQRT2: f64 = 2.0f64.sqrt(); - let SQRT2PI: f64 = (2.0 * PI).sqrt(); - if z * z < 0.0492 { // First case: z close to zero. - let coef = -z / SQRT2PI; + let coef = -z / (2.0 * PI).sqrt(); let mut val = 0.0; for c in &CS { @@ -59,7 +54,7 @@ pub fn logphi(z: f64) -> (f64, f64) { } let res = -2.0 * val - 2.0f64.ln(); - let dres = (-(z * z) / 2.0 - res).exp() / SQRT2PI; + let dres = (-(z * z) / 2.0 - res).exp() / (2.0 * PI).sqrt(); (res, dres) } else if z < -11.3137 { @@ -67,13 +62,13 @@ pub fn logphi(z: f64) -> (f64, f64) { let mut num = 0.5641895835477550741; for r in &RS { - num = -z * num / SQRT2 + r; + num = -z * num / SQRT_2 + r; } let mut den = 1.0; for q in &QS { - den = -z * den / SQRT2 + q; + den = -z * den / SQRT_2 + q; } let res = (num / (2.0 * den)).ln() - (z * z) / 2.0; @@ -82,7 +77,7 @@ pub fn logphi(z: f64) -> (f64, f64) { (res, dres) } else { let res = normcdf(z).ln(); - let dres = (-(z * z) / 2.0 - res).exp() / SQRT2PI; + let dres = (-(z * z) / 2.0 - res).exp() / (2.0 * PI).sqrt(); (res, dres) }