use std::cmp::Reverse; use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2}; mod agent; #[cfg(feature = "approx")] mod approx; mod batch; mod game; mod gaussian; mod history; mod message; mod player; pub use game::Game; pub use gaussian::Gaussian; pub use history::History; use message::DiffMessage; pub use player::Player; pub const BETA: f64 = 1.0; pub const MU: f64 = 0.0; pub const SIGMA: f64 = BETA * 6.0; pub const GAMMA: f64 = BETA * 0.03; pub const P_DRAW: f64 = 0.0; pub const EPSILON: f64 = 1e-6; pub const ITERATIONS: usize = 30; const SQRT_TAU: f64 = 2.5066282746310002; const N01: Gaussian = Gaussian { mu: 0.0, sigma: 1.0, }; pub(crate) const N00: Gaussian = Gaussian { mu: 0.0, sigma: 0.0, }; pub(crate) const N_INF: Gaussian = Gaussian { mu: 0.0, sigma: f64::INFINITY, }; fn erfc(x: f64) -> f64 { let z = x.abs(); let t = 1.0 / (1.0 + z / 2.0); let a = -0.82215223 + t * 0.17087277; let b = 1.48851587 + t * a; let c = -1.13520398 + t * b; let d = 0.27886807 + t * c; let e = -0.18628806 + t * d; let f = 0.09678418 + t * e; let g = 0.37409196 + t * f; let h = 1.00002368 + t * g; let r = t * (-z * z - 1.26551223 + t * h).exp(); if x >= 0.0 { r } else { 2.0 - r } } fn erfc_inv(mut y: f64) -> f64 { if y >= 2.0 { return f64::NEG_INFINITY; } debug_assert!(y >= 0.0, "y must be nonnegative"); if y == 0.0 { return f64::INFINITY; } if y >= 1.0 { y = 2.0 - y; } let t = (-2.0 * (y / 2.0).ln()).sqrt(); let mut x = FRAC_1_SQRT_2 * ((2.30753 + t * 0.27061) / (1.0 + t * (0.99229 + t * 0.04481)) - t); for _ in 0..3 { let err = erfc(x) - y; x += err / (FRAC_2_SQRT_PI * (-(x.powi(2))).exp() - x * err) } if y < 1.0 { x } else { -x } } fn ppf(p: f64, mu: f64, sigma: f64) -> f64 { mu - sigma * SQRT_2 * erfc_inv(2.0 * p) } fn compute_margin(p_draw: f64, sd: f64) -> f64 { ppf(0.5 - p_draw / 2.0, 0.0, sd).abs() } fn cdf(x: f64, mu: f64, sigma: f64) -> f64 { let z = -(x - mu) / (sigma * SQRT_2); 0.5 * erfc(z) } fn pdf(x: f64, mu: f64, sigma: f64) -> f64 { let normalizer = (SQRT_TAU * sigma).powi(-1); let functional = (-((x - mu).powi(2)) / (2.0 * sigma.powi(2))).exp(); normalizer * functional } fn v_w(mu: f64, sigma: f64, margin: f64, tie: bool) -> (f64, f64) { if !tie { let alpha = (margin - mu) / sigma; let v = pdf(-alpha, 0.0, 1.0) / cdf(-alpha, 0.0, 1.0); let w = v * (v + (-alpha)); (v, w) } else { let alpha = (-margin - mu) / sigma; let beta = (margin - mu) / sigma; let v = (pdf(alpha, 0.0, 1.0) - pdf(beta, 0.0, 1.0)) / (cdf(beta, 0.0, 1.0) - cdf(alpha, 0.0, 1.0)); let u = (alpha * pdf(alpha, 0.0, 1.0) - beta * pdf(beta, 0.0, 1.0)) / (cdf(beta, 0.0, 1.0) - cdf(alpha, 0.0, 1.0)); let w = -(u - v.powi(2)); (v, w) } } fn trunc(mu: f64, sigma: f64, margin: f64, tie: bool) -> (f64, f64) { let (v, w) = v_w(mu, sigma, margin, tie); let mu_trunc = mu + sigma * v; let sigma_trunc = sigma * (1.0 - w).sqrt(); (mu_trunc, sigma_trunc) } pub(crate) fn approx(n: Gaussian, margin: f64, tie: bool) -> Gaussian { let (mu, sigma) = trunc(n.mu, n.sigma, margin, tie); Gaussian { mu, sigma } } pub(crate) fn tuple_max(v1: (f64, f64), v2: (f64, f64)) -> (f64, f64) { ( if v1.0 > v2.0 { v1.0 } else { v2.0 }, if v1.1 > v2.1 { v1.1 } else { v2.1 }, ) } pub(crate) fn tuple_gt(t: (f64, f64), e: f64) -> bool { t.0 > e || t.1 > e } pub(crate) fn sort_perm(x: &[f64], reverse: bool) -> Vec { let mut v = x.iter().enumerate().collect::>(); if reverse { v.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap()); } else { v.sort_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()); } v.into_iter().map(|(i, _)| i).collect() } pub(crate) fn sort_time(xs: &[u64], reverse: bool) -> Vec { let mut x = xs.iter().enumerate().collect::>(); if reverse { x.sort_by_key(|(_, &x)| Reverse(x)); } else { x.sort_by_key(|(_, &x)| x); } x.into_iter().map(|(i, _)| i).collect() } pub(crate) fn evidence(d: &[DiffMessage], margin: &[f64], tie: &[bool], e: usize) -> f64 { if tie[e] { cdf(margin[e], d[e].prior.mu, d[e].prior.sigma) - cdf(-margin[e], d[e].prior.mu, d[e].prior.sigma) } else { 1.0 - cdf(margin[e], d[e].prior.mu, d[e].prior.sigma) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_sort_perm() { assert_eq!(sort_perm(&[0.0, 1.0, 2.0, 0.0], true), vec![2, 1, 0, 3]); } #[test] fn test_sort_time() { assert_eq!(sort_time(&[0, 1, 2, 0], true), vec![2, 1, 0, 3]); } }