339 lines
8.1 KiB
Rust
339 lines
8.1 KiB
Rust
use std::borrow::{Borrow, ToOwned};
|
|
use std::cmp::Reverse;
|
|
use std::collections::HashMap;
|
|
use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2};
|
|
use std::hash::Hash;
|
|
|
|
pub mod agent;
|
|
#[cfg(feature = "approx")]
|
|
mod approx;
|
|
pub mod batch;
|
|
mod game;
|
|
pub mod gaussian;
|
|
mod history;
|
|
mod matrix;
|
|
mod message;
|
|
pub mod player;
|
|
|
|
pub use game::Game;
|
|
pub use gaussian::Gaussian;
|
|
pub use history::History;
|
|
use matrix::Matrix;
|
|
use message::DiffMessage;
|
|
pub use player::Player;
|
|
|
|
/// Default `beta` parameter; [`SIGMA`] and [`GAMMA`] are scaled from it.
pub const BETA: f64 = 1.0;

/// Default prior mean.
pub const MU: f64 = 0.0;

/// Default prior standard deviation: six times [`BETA`].
pub const SIGMA: f64 = 6.0 * BETA;

/// Default `gamma` parameter: 3% of [`BETA`].
pub const GAMMA: f64 = 0.03 * BETA;

/// Default draw probability (zero, i.e. no draws).
pub const P_DRAW: f64 = 0.0;

/// Default `epsilon` tolerance.
pub const EPSILON: f64 = 0.000_001;

/// Default iteration count.
pub const ITERATIONS: usize = 30;

/// `sqrt(2 * pi)` (equivalently `sqrt(tau)`), the normalizer used by `pdf`.
const SQRT_TAU: f64 = 2.5066282746310002;
|
|
|
|
/// Gaussian built with `from_ms(0.0, 1.0)`: mean 0, standard deviation 1
/// (assuming `from_ms` takes `(mu, sigma)` — matches the `mu`/`sigma` fields used elsewhere).
pub const N01: Gaussian = Gaussian::from_ms(0.0, 1.0);

/// Gaussian built with `from_ms(0.0, 0.0)`: mean 0 and zero standard deviation.
pub const N00: Gaussian = Gaussian::from_ms(0.0, 0.0);

/// Gaussian built with `from_ms(0.0, f64::INFINITY)`: mean 0 and infinite standard deviation.
pub const N_INF: Gaussian = Gaussian::from_ms(0.0, f64::INFINITY);
|
|
|
|
/// Dense, zero-based identifier handed out by [`IndexMap`].
#[derive(Copy, Clone, Default, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)]
pub struct Index(usize);

impl From<usize> for Index {
    fn from(ix: usize) -> Self {
        Self(ix)
    }
}

/// Assigns every distinct key an [`Index`], counting up from 0 in order of first insertion.
pub struct IndexMap<K>(HashMap<K, Index>);

impl<K> IndexMap<K>
where
    K: Eq + Hash,
{
    /// Creates an empty map.
    pub fn new() -> Self {
        Self(HashMap::new())
    }

    /// Returns the index previously assigned to `k`, or `None` if `k` was never registered.
    ///
    /// Only `Hash + Eq` is required of the borrowed form `Q`; a lookup never needs to
    /// build an owned key, so no `ToOwned` bound is imposed here.
    pub fn get<Q: ?Sized>(&self, k: &Q) -> Option<Index>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.0.get(k).copied()
    }

    /// Returns the index for `k`, assigning the next unused index if `k` is new.
    ///
    /// The key is looked up by reference first, so an owned copy (`to_owned`) is only
    /// created the first time a key is seen.
    pub fn get_or_create<Q: ?Sized>(&mut self, k: &Q) -> Index
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K>,
    {
        if let Some(&idx) = self.0.get(k) {
            return idx;
        }

        // Entries are never removed, so the current size is the next free index.
        let idx = Index::from(self.0.len());
        self.0.insert(k.to_owned(), idx);
        idx
    }

    /// Reverse lookup: returns the key that was assigned `idx`, if any.
    ///
    /// This scans every entry, so it is linear in the number of keys.
    pub fn key(&self, idx: Index) -> Option<&K> {
        self.0
            .iter()
            .find(|&(_, value)| *value == idx)
            .map(|(key, _)| key)
    }

    /// Iterates over all registered keys in hash-map order (not in index order).
    pub fn keys(&self) -> impl Iterator<Item = &K> {
        self.0.keys()
    }
}

impl<K> Default for IndexMap<K>
where
    K: Eq + Hash,
{
    fn default() -> Self {
        Self::new()
    }
}
|
|
|
|
/// Complementary error function `erfc(x) = 1 - erf(x)`.
///
/// The coefficients are those of the Numerical Recipes `erfcc` Chebyshev fit. The polynomial
/// is evaluated for `|x|`, and negative inputs use `erfc(-x) = 2 - erfc(x)`.
fn erfc(x: f64) -> f64 {
    let z = x.abs();
    let t = 1.0 / (1.0 + z / 2.0);

    // Horner evaluation of the fitted polynomial in `t`, innermost (highest degree) first.
    let a = -0.82215223 + t * 0.17087277;
    let b = 1.48851587 + t * a;
    let c = -1.13520398 + t * b;
    let d = 0.27886807 + t * c;
    let e = -0.18628806 + t * d;
    let f = 0.09678418 + t * e;
    let g = 0.37409196 + t * f;
    let h = 1.00002368 + t * g;

    let r = t * (-z * z - 1.26551223 + t * h).exp();

    if x >= 0.0 { r } else { 2.0 - r }
}

/// Inverse of [`erfc`]: returns `x` with `erfc(x) == y`.
///
/// Valid for `0 <= y <= 2`; `y == 0` yields `+inf` and `y >= 2` yields `-inf`.
/// A rational initial guess (Numerical Recipes `inverfc`) is refined with three Halley steps
/// against [`erfc`].
fn erfc_inv(y: f64) -> f64 {
    if y >= 2.0 {
        return f64::NEG_INFINITY;
    }

    debug_assert!(y >= 0.0, "y must be nonnegative");

    if y == 0.0 {
        return f64::INFINITY;
    }

    // erfc(-x) = 2 - erfc(x): solve on (0, 1] and mirror the answer afterwards. The branch
    // must be decided from the ORIGINAL `y`, before it is folded into `pp`.
    let upper_half = y >= 1.0;
    let pp = if upper_half { 2.0 - y } else { y };

    let t = (-2.0 * (pp / 2.0).ln()).sqrt();

    // Initial guess. The bracketed term is negative for `pp` in (0, 1), so the leading factor
    // must be negative to land near the non-negative root erfc_inv(pp).
    let num = 2.30753 + t * 0.27061;
    let den = 1.0 + t * (0.99229 + t * 0.04481);
    let mut x = -FRAC_1_SQRT_2 * (num / den - t);

    for _ in 0..3 {
        let err = erfc(x) - pp;
        // Halley update; d/dx erfc(x) = -(2/sqrt(pi)) * exp(-x^2).
        x += err / (FRAC_2_SQRT_PI * (-(x.powi(2))).exp() - x * err);
    }

    if upper_half { -x } else { x }
}
|
|
|
|
fn ppf(p: f64, mu: f64, sigma: f64) -> f64 {
|
|
mu - sigma * SQRT_2 * erfc_inv(2.0 * p)
|
|
}
|
|
|
|
fn compute_margin(p_draw: f64, sd: f64) -> f64 {
|
|
ppf(0.5 - p_draw / 2.0, 0.0, sd).abs()
|
|
}
|
|
|
|
fn cdf(x: f64, mu: f64, sigma: f64) -> f64 {
|
|
let z = -(x - mu) / (sigma * SQRT_2);
|
|
|
|
0.5 * erfc(z)
|
|
}
|
|
|
|
fn pdf(x: f64, mu: f64, sigma: f64) -> f64 {
|
|
let normalizer = (SQRT_TAU * sigma).powi(-1);
|
|
let functional = (-((x - mu).powi(2)) / (2.0 * sigma.powi(2))).exp();
|
|
|
|
normalizer * functional
|
|
}
|
|
|
|
fn v_w(mu: f64, sigma: f64, margin: f64, tie: bool) -> (f64, f64) {
|
|
if !tie {
|
|
let alpha = (margin - mu) / sigma;
|
|
|
|
let v = pdf(-alpha, 0.0, 1.0) / cdf(-alpha, 0.0, 1.0);
|
|
let w = v * (v + (-alpha));
|
|
|
|
(v, w)
|
|
} else {
|
|
let alpha = (-margin - mu) / sigma;
|
|
let beta = (margin - mu) / sigma;
|
|
|
|
let v = (pdf(alpha, 0.0, 1.0) - pdf(beta, 0.0, 1.0))
|
|
/ (cdf(beta, 0.0, 1.0) - cdf(alpha, 0.0, 1.0));
|
|
let u = (alpha * pdf(alpha, 0.0, 1.0) - beta * pdf(beta, 0.0, 1.0))
|
|
/ (cdf(beta, 0.0, 1.0) - cdf(alpha, 0.0, 1.0));
|
|
let w = -(u - v.powi(2));
|
|
|
|
(v, w)
|
|
}
|
|
}
|
|
|
|
fn trunc(mu: f64, sigma: f64, margin: f64, tie: bool) -> (f64, f64) {
|
|
let (v, w) = v_w(mu, sigma, margin, tie);
|
|
|
|
let mu_trunc = mu + sigma * v;
|
|
let sigma_trunc = sigma * (1.0 - w).sqrt();
|
|
|
|
(mu_trunc, sigma_trunc)
|
|
}
|
|
|
|
pub(crate) fn approx(n: Gaussian, margin: f64, tie: bool) -> Gaussian {
|
|
let (mu, sigma) = trunc(n.mu, n.sigma, margin, tie);
|
|
|
|
Gaussian { mu, sigma }
|
|
}
|
|
|
|
/// Component-wise maximum of two pairs.
///
/// Deliberately uses `a > b` rather than `f64::max`: a NaN in `v1` yields the `v2`
/// component, while a NaN in `v2` is returned as-is.
pub(crate) fn tuple_max(v1: (f64, f64), v2: (f64, f64)) -> (f64, f64) {
    let larger = |a: f64, b: f64| if a > b { a } else { b };
    (larger(v1.0, v2.0), larger(v1.1, v2.1))
}
|
|
|
|
/// Returns `true` if either component of `t` is strictly greater than `e`.
pub(crate) fn tuple_gt(t: (f64, f64), e: f64) -> bool {
    let (first, second) = t;
    [first, second].iter().any(|&component| component > e)
}
|
|
|
|
/// Returns the permutation of indices that sorts `x` (descending when `reverse`).
///
/// The sort is stable, so equal values keep their original relative order.
///
/// # Panics
/// If two compared values are unordered (a NaN is involved).
pub(crate) fn sort_perm(x: &[f64], reverse: bool) -> Vec<usize> {
    let mut order: Vec<usize> = (0..x.len()).collect();
    order.sort_by(|&i, &j| {
        let (lhs, rhs) = if reverse { (&x[j], &x[i]) } else { (&x[i], &x[j]) };
        lhs.partial_cmp(rhs).unwrap()
    });
    order
}
|
|
|
|
/// Returns the permutation of indices that sorts the timestamps `xs`
/// (latest first when `reverse`). Ties keep their original order (stable sort).
pub(crate) fn sort_time(xs: &[i64], reverse: bool) -> Vec<usize> {
    let mut order: Vec<usize> = (0..xs.len()).collect();
    if reverse {
        order.sort_by_key(|&i| Reverse(xs[i]));
    } else {
        order.sort_by_key(|&i| xs[i]);
    }
    order
}
|
|
|
|
pub(crate) fn evidence(d: &[DiffMessage], margin: &[f64], tie: &[bool], e: usize) -> f64 {
|
|
if tie[e] {
|
|
cdf(margin[e], d[e].prior.mu, d[e].prior.sigma)
|
|
- cdf(-margin[e], d[e].prior.mu, d[e].prior.sigma)
|
|
} else {
|
|
1.0 - cdf(margin[e], d[e].prior.mu, d[e].prior.sigma)
|
|
}
|
|
}
|
|
|
|
/// Calculates the match quality of the given rating groups. A result is the draw probability in the association
|
|
pub fn quality(rating_groups: &[&[Gaussian]], beta: f64) -> f64 {
|
|
let flatten_ratings = rating_groups
|
|
.iter()
|
|
.flat_map(|group| group.iter())
|
|
.collect::<Vec<_>>();
|
|
|
|
let flatten_weights = vec![1.0; flatten_ratings.len()].into_boxed_slice();
|
|
|
|
let length = flatten_ratings.len();
|
|
|
|
let mut mean_matrix = Matrix::new(length, 1);
|
|
|
|
for (i, rating) in flatten_ratings.iter().enumerate() {
|
|
mean_matrix[(i, 0)] = rating.mu;
|
|
}
|
|
|
|
let mut variance_matrix = Matrix::new(length, length);
|
|
|
|
for (i, rating) in flatten_ratings.iter().enumerate() {
|
|
variance_matrix[(i, i)] = rating.sigma.powi(2);
|
|
}
|
|
|
|
let mut rotated_a_matrix = Matrix::new(rating_groups.len() - 1, length);
|
|
|
|
let mut t = 0;
|
|
let mut x = 0;
|
|
|
|
for (row, group) in rating_groups.windows(2).enumerate() {
|
|
let current = group[0];
|
|
let next = group[1];
|
|
|
|
for n in t..t + current.len() {
|
|
rotated_a_matrix[(row, n)] = flatten_weights[n];
|
|
|
|
x += 1;
|
|
}
|
|
|
|
t += current.len();
|
|
|
|
for n in x..x + next.len() {
|
|
rotated_a_matrix[(row, n)] = -flatten_weights[n];
|
|
}
|
|
|
|
x += next.len();
|
|
}
|
|
|
|
let a_matrix = rotated_a_matrix.transpose();
|
|
|
|
let ata = beta.powi(2) * &rotated_a_matrix * &a_matrix;
|
|
let atsa = &rotated_a_matrix * &variance_matrix * &a_matrix;
|
|
|
|
let start = mean_matrix.transpose() * &a_matrix;
|
|
let middle = &ata + &atsa;
|
|
let end = &rotated_a_matrix * &mean_matrix;
|
|
|
|
let e_arg = (-0.5 * &start * &middle.inverse() * &end).determinant();
|
|
let s_arg = ata.determinant() / middle.determinant();
|
|
|
|
e_arg.exp() * s_arg.sqrt()
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use ::approx::assert_ulps_eq;

    use super::*;

    /// Descending order; the two zeros keep their original relative order.
    #[test]
    fn test_sort_perm() {
        let permutation = sort_perm(&[0.0, 1.0, 2.0, 0.0], true);
        assert_eq!(permutation, vec![2, 1, 0, 3]);
    }

    /// Latest timestamp first; equal timestamps keep their original relative order.
    #[test]
    fn test_sort_time() {
        let permutation = sort_time(&[0, 1, 2, 0], true);
        assert_eq!(permutation, vec![2, 1, 0, 3]);
    }

    /// Two identical single-player teams.
    #[test]
    fn test_quality() {
        let (a, b) = (Gaussian::from_ms(25.0, 3.0), Gaussian::from_ms(25.0, 3.0));
        let beta = 25.0 / 3.0 / 2.0;

        let q = quality(&[&[a], &[b]], beta);

        assert_ulps_eq!(q, 0.8115343414514944, epsilon = 1e-6)
    }
}
|