A lot of progress.
This commit is contained in:
113
src/utils.rs
Normal file
113
src/utils.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use std::f64::consts::PI;
|
||||
|
||||
const CS: [f64; 14] = [
|
||||
0.00048204,
|
||||
-0.00142906,
|
||||
0.0013200243174,
|
||||
0.0009461589032,
|
||||
-0.0045563339802,
|
||||
0.00556964649138,
|
||||
0.00125993961762116,
|
||||
-0.01621575378835404,
|
||||
0.02629651521057465,
|
||||
-0.001829764677455021,
|
||||
2.0 * (1.0 - PI / 3.0),
|
||||
(4.0 - PI) / 3.0,
|
||||
1.0,
|
||||
1.0,
|
||||
];
|
||||
|
||||
const RS: [f64; 5] = [
|
||||
1.2753666447299659525,
|
||||
5.019049726784267463450,
|
||||
6.1602098531096305441,
|
||||
7.409740605964741794425,
|
||||
2.9788656263939928886,
|
||||
];
|
||||
|
||||
const QS: [f64; 6] = [
|
||||
2.260528520767326969592,
|
||||
9.3960340162350541504,
|
||||
12.048951927855129036034,
|
||||
17.081440747466004316,
|
||||
9.608965327192787870698,
|
||||
3.3690752069827527677,
|
||||
];
|
||||
|
||||
/// Normal cumulative density function.
|
||||
fn normcdf(x: f64) -> f64 {
|
||||
// If X ~ N(0,1), returns P(X < x).
|
||||
// erfc(-x / SQRT2) / 2.0
|
||||
|
||||
todo!();
|
||||
}
|
||||
|
||||
/// Compute the log of the normal cumulative density function.
|
||||
pub fn logphi(z: f64) -> (f64, f64) {
|
||||
// Adapted from the GPML function `logphi.m`.
|
||||
let SQRT2: f64 = 2.0f64.sqrt();
|
||||
let SQRT2PI: f64 = (2.0 * PI).sqrt();
|
||||
|
||||
if z * z < 0.0492 {
|
||||
// First case: z close to zero.
|
||||
let coef = -z / SQRT2PI;
|
||||
let mut val = 0.0;
|
||||
|
||||
for c in &CS {
|
||||
val = coef * (c + val);
|
||||
}
|
||||
|
||||
let res = -2.0 * val - 2.0f64.ln();
|
||||
let dres = (-(z * z) / 2.0 - res).exp() / SQRT2PI;
|
||||
|
||||
(res, dres)
|
||||
} else if z < -11.3137 {
|
||||
// Second case: z very small.
|
||||
let mut num = 0.5641895835477550741;
|
||||
|
||||
for r in &RS {
|
||||
num = -z * num / SQRT2 + r;
|
||||
}
|
||||
|
||||
let mut den = 1.0;
|
||||
|
||||
for q in &QS {
|
||||
den = -z * den / SQRT2 + q;
|
||||
}
|
||||
|
||||
let res = (num / (2.0 * den)).ln() - (z * z) / 2.0;
|
||||
let dres = (den / num).abs() * (2.0 / PI).sqrt();
|
||||
|
||||
(res, dres)
|
||||
} else {
|
||||
let res = normcdf(z).ln();
|
||||
let dres = (-(z * z) / 2.0 - res).exp() / SQRT2PI;
|
||||
|
||||
(res, dres)
|
||||
}
|
||||
|
||||
/*
|
||||
if z * z < 0.0492:
|
||||
# First case: z close to zero.
|
||||
coef = -z / SQRT2PI
|
||||
val = 0
|
||||
for c in CS:
|
||||
val = coef * (c + val)
|
||||
res = -2 * val - log(2)
|
||||
dres = exp(-(z * z) / 2 - res) / SQRT2PI
|
||||
elif z < -11.3137:
|
||||
# Second case: z very small.
|
||||
num = 0.5641895835477550741
|
||||
for r in RS:
|
||||
num = -z * num / SQRT2 + r
|
||||
den = 1.0
|
||||
for q in QS:
|
||||
den = -z * den / SQRT2 + q
|
||||
res = log(num / (2 * den)) - (z * z) / 2
|
||||
dres = abs(den / num) * sqrt(2.0 / pi)
|
||||
else:
|
||||
res = log(normcdf(z))
|
||||
dres = exp(-(z * z) / 2 - res) / SQRT2PI
|
||||
return res, dres
|
||||
*/
|
||||
}
|
||||
Reference in New Issue
Block a user