Port from julia version instead
This commit is contained in:
214
src/lib.rs
214
src/lib.rs
@@ -1,25 +1,211 @@
|
||||
use std::cmp::Reverse;
|
||||
use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2};
|
||||
|
||||
mod agent;
|
||||
mod batch;
|
||||
mod game;
|
||||
mod gaussian;
|
||||
mod history;
|
||||
mod message;
|
||||
mod player;
|
||||
mod utils;
|
||||
mod variable;
|
||||
|
||||
pub use batch::*;
|
||||
pub use game::*;
|
||||
pub use gaussian::*;
|
||||
pub use history::*;
|
||||
pub use player::*;
|
||||
use gaussian::Gaussian;
|
||||
use message::DiffMessage;
|
||||
|
||||
pub const BETA: f64 = 1.0;
|
||||
pub use game::Game;
|
||||
pub use history::History;
|
||||
pub use player::Player;
|
||||
|
||||
const BETA: f64 = 1.0;
|
||||
pub const MU: f64 = 0.0;
|
||||
pub const SIGMA: f64 = BETA * 6.0;
|
||||
pub const GAMMA: f64 = BETA * 0.03;
|
||||
pub const P_DRAW: f64 = 0.0;
|
||||
const GAMMA: f64 = BETA * 0.03;
|
||||
const P_DRAW: f64 = 0.0;
|
||||
pub const EPSILON: f64 = 1e-6;
|
||||
pub const ITERATIONS: usize = 30;
|
||||
|
||||
pub const N01: Gaussian = Gaussian::new(0.0, 1.0);
|
||||
pub const N00: Gaussian = Gaussian::new(0.0, 0.0);
|
||||
pub const N_INF: Gaussian = Gaussian::new(0.0, f64::INFINITY);
|
||||
pub const N_MS: Gaussian = Gaussian::new(MU, SIGMA);
|
||||
const SQRT_TAU: f64 = 2.5066282746310002;
|
||||
|
||||
const N01: Gaussian = Gaussian {
|
||||
mu: 0.0,
|
||||
sigma: 1.0,
|
||||
};
|
||||
pub(crate) const N00: Gaussian = Gaussian {
|
||||
mu: 0.0,
|
||||
sigma: 0.0,
|
||||
};
|
||||
pub(crate) const N_INF: Gaussian = Gaussian {
|
||||
mu: 0.0,
|
||||
sigma: f64::INFINITY,
|
||||
};
|
||||
|
||||
fn erfc(x: f64) -> f64 {
|
||||
let z = x.abs();
|
||||
let t = 1.0 / (1.0 + z / 2.0);
|
||||
|
||||
let a = -0.82215223 + t * 0.17087277;
|
||||
let b = 1.48851587 + t * a;
|
||||
let c = -1.13520398 + t * b;
|
||||
let d = 0.27886807 + t * c;
|
||||
let e = -0.18628806 + t * d;
|
||||
let f = 0.09678418 + t * e;
|
||||
let g = 0.37409196 + t * f;
|
||||
let h = 1.00002368 + t * g;
|
||||
|
||||
let r = t * (-z * z - 1.26551223 + t * h).exp();
|
||||
|
||||
if x >= 0.0 {
|
||||
r
|
||||
} else {
|
||||
2.0 - r
|
||||
}
|
||||
}
|
||||
|
||||
fn erfc_inv(mut y: f64) -> f64 {
|
||||
if y >= 2.0 {
|
||||
return f64::NEG_INFINITY;
|
||||
}
|
||||
|
||||
debug_assert!(y >= 0.0, "y must be nonnegative");
|
||||
|
||||
if y == 0.0 {
|
||||
return f64::INFINITY;
|
||||
}
|
||||
|
||||
if y >= 1.0 {
|
||||
y = 2.0 - y;
|
||||
}
|
||||
|
||||
let t = (-2.0 * (y / 2.0).ln()).sqrt();
|
||||
|
||||
let mut x = FRAC_1_SQRT_2 * ((2.30753 + t * 0.27061) / (1.0 + t * (0.99229 + t * 0.04481)) - t);
|
||||
|
||||
for _ in 0..3 {
|
||||
let err = erfc(x) - y;
|
||||
|
||||
x += err / (FRAC_2_SQRT_PI * (-(x.powi(2))).exp() - x * err)
|
||||
}
|
||||
|
||||
if y < 1.0 {
|
||||
x
|
||||
} else {
|
||||
-x
|
||||
}
|
||||
}
|
||||
|
||||
fn ppf(p: f64, mu: f64, sigma: f64) -> f64 {
|
||||
mu - sigma * SQRT_2 * erfc_inv(2.0 * p)
|
||||
}
|
||||
|
||||
fn compute_margin(p_draw: f64, sd: f64) -> f64 {
|
||||
ppf(0.5 - p_draw / 2.0, 0.0, sd).abs()
|
||||
}
|
||||
|
||||
fn cdf(x: f64, mu: f64, sigma: f64) -> f64 {
|
||||
let z = -(x - mu) / (sigma * SQRT_2);
|
||||
|
||||
0.5 * erfc(z)
|
||||
}
|
||||
|
||||
fn pdf(x: f64, mu: f64, sigma: f64) -> f64 {
|
||||
let normalizer = (SQRT_TAU * sigma).powi(-1);
|
||||
let functional = (-((x - mu).powi(2)) / (2.0 * sigma.powi(2))).exp();
|
||||
|
||||
normalizer * functional
|
||||
}
|
||||
|
||||
fn v_w(mu: f64, sigma: f64, margin: f64, tie: bool) -> (f64, f64) {
|
||||
if !tie {
|
||||
let alpha = (margin - mu) / sigma;
|
||||
|
||||
let v = pdf(-alpha, 0.0, 1.0) / cdf(-alpha, 0.0, 1.0);
|
||||
let w = v * (v + (-alpha));
|
||||
|
||||
(v, w)
|
||||
} else {
|
||||
let alpha = (-margin - mu) / sigma;
|
||||
let beta = (margin - mu) / sigma;
|
||||
|
||||
let v = (pdf(alpha, 0.0, 1.0) - pdf(beta, 0.0, 1.0))
|
||||
/ (cdf(beta, 0.0, 1.0) - cdf(alpha, 0.0, 1.0));
|
||||
let u = (alpha * pdf(alpha, 0.0, 1.0) - beta * pdf(beta, 0.0, 1.0))
|
||||
/ (cdf(beta, 0.0, 1.0) - cdf(alpha, 0.0, 1.0));
|
||||
let w = -(u - v.powi(2));
|
||||
|
||||
(v, w)
|
||||
}
|
||||
}
|
||||
|
||||
fn trunc(mu: f64, sigma: f64, margin: f64, tie: bool) -> (f64, f64) {
|
||||
let (v, w) = v_w(mu, sigma, margin, tie);
|
||||
|
||||
let mu_trunc = mu + sigma * v;
|
||||
let sigma_trunc = sigma * (1.0 - w).sqrt();
|
||||
|
||||
(mu_trunc, sigma_trunc)
|
||||
}
|
||||
|
||||
pub(crate) fn approx(n: Gaussian, margin: f64, tie: bool) -> Gaussian {
|
||||
let (mu, sigma) = trunc(n.mu, n.sigma, margin, tie);
|
||||
|
||||
Gaussian { mu, sigma }
|
||||
}
|
||||
|
||||
pub(crate) fn tuple_max(v1: (f64, f64), v2: (f64, f64)) -> (f64, f64) {
|
||||
(
|
||||
if v1.0 > v2.0 { v1.0 } else { v2.0 },
|
||||
if v1.1 > v2.1 { v1.1 } else { v2.1 },
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn tuple_gt(t: (f64, f64), e: f64) -> bool {
|
||||
t.0 > e || t.1 > e
|
||||
}
|
||||
|
||||
pub(crate) fn sort_perm(x: &[f64], reverse: bool) -> Vec<usize> {
|
||||
let mut v = x.iter().enumerate().collect::<Vec<_>>();
|
||||
|
||||
if reverse {
|
||||
v.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
|
||||
} else {
|
||||
v.sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap());
|
||||
}
|
||||
|
||||
v.into_iter().map(|(i, _)| i).collect()
|
||||
}
|
||||
|
||||
pub(crate) fn sort_time(xs: &[u64], reverse: bool) -> Vec<usize> {
|
||||
let mut x = xs.iter().enumerate().collect::<Vec<_>>();
|
||||
|
||||
if reverse {
|
||||
x.sort_by_key(|(_, &x)| Reverse(x));
|
||||
} else {
|
||||
x.sort_by_key(|(_, &x)| x);
|
||||
}
|
||||
|
||||
x.into_iter().map(|(i, _)| i).collect()
|
||||
}
|
||||
|
||||
pub(crate) fn evidence(d: &[DiffMessage], margin: &[f64], tie: &[bool], e: usize) -> f64 {
|
||||
if tie[e] {
|
||||
cdf(margin[e], d[e].prior.mu, d[e].prior.sigma)
|
||||
- cdf(-margin[e], d[e].prior.mu, d[e].prior.sigma)
|
||||
} else {
|
||||
1.0 - cdf(margin[e], d[e].prior.mu, d[e].prior.sigma)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sort_perm() {
|
||||
assert_eq!(sort_perm(&[0.0, 1.0, 2.0, 0.0], true), vec![2, 1, 0, 3]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sort_time() {
|
||||
assert_eq!(sort_time(&[0, 1, 2, 0], true), vec![2, 1, 0, 3]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user