Files
trueskill-tt/src/lib.rs
2024-04-03 10:25:10 +02:00

347 lines
8.1 KiB
Rust

use std::borrow::{Borrow, ToOwned};
use std::cmp::Reverse;
use std::collections::HashMap;
use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2};
use std::hash::Hash;
pub mod agent;
#[cfg(feature = "approx")]
mod approx;
pub mod batch;
mod game;
pub mod gaussian;
mod history;
mod matrix;
mod message;
pub mod player;
pub use game::Game;
pub use gaussian::Gaussian;
pub use history::History;
use matrix::Matrix;
use message::DiffMessage;
pub use player::Player;
/// Base scale of the model. `SIGMA` and `GAMMA` below are defined as multiples of it.
/// NOTE(review): presumably the default performance-noise standard deviation (compare the
/// `beta` argument of `quality`) — confirm against `Player`/`History`.
pub const BETA: f64 = 1.0;
/// Default mean. NOTE(review): presumably the prior mean for new players — confirm.
pub const MU: f64 = 0.0;
/// Default standard deviation: six times `BETA`.
pub const SIGMA: f64 = BETA * 6.0;
/// Three percent of `BETA`.
/// NOTE(review): presumably the dynamic (skill-drift) noise — confirm at call sites.
pub const GAMMA: f64 = BETA * 0.03;
/// Default draw probability (0.0 = draws are not expected).
/// NOTE(review): presumably what gets passed to `compute_margin` — confirm.
pub const P_DRAW: f64 = 0.0;
/// Tolerance threshold. NOTE(review): presumably the convergence tolerance checked with
/// `tuple_gt` — confirm.
pub const EPSILON: f64 = 1e-6;
/// NOTE(review): presumably the maximum number of convergence iterations — confirm.
pub const ITERATIONS: usize = 30;
/// sqrt(2 * pi) (= sqrt(tau)); the normalizing constant of the normal density used by `pdf`.
const SQRT_TAU: f64 = 2.5066282746310002;
/// Standard normal N(0, 1), assuming `from_ms` takes (mean, standard deviation).
pub const N01: Gaussian = Gaussian::from_ms(0.0, 1.0);
/// Degenerate Gaussian with zero mean and zero standard deviation.
pub const N00: Gaussian = Gaussian::from_ms(0.0, 0.0);
/// Zero-mean Gaussian with infinite standard deviation.
/// NOTE(review): presumably used as an uninformative message — confirm.
pub const N_INF: Gaussian = Gaussian::from_ms(0.0, f64::INFINITY);
/// Dense integer handle handed out by [`IndexMap`].
///
/// `IndexMap::get_or_create` assigns handles in insertion order starting at `0`.
#[derive(Copy, Clone, Default, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)]
pub struct Index(usize);

impl From<usize> for Index {
    fn from(ix: usize) -> Self {
        Self(ix)
    }
}

/// Maps keys to stable, densely numbered [`Index`] handles.
///
/// A new key receives the next unused handle (`0, 1, 2, ...`). A key that is already
/// present keeps the handle it was first given. This type only ever inserts entries,
/// so handles stay dense.
pub struct IndexMap<K>(HashMap<K, Index>);

impl<K> IndexMap<K>
where
    K: Eq + Hash,
{
    /// Creates an empty map.
    pub fn new() -> Self {
        Self(HashMap::new())
    }

    /// Returns the handle previously assigned to `k`, or `None` if `k` was never registered.
    ///
    /// Accepts any borrowed form of the key (e.g. `&str` for `String` or `Box<str>` keys).
    pub fn get<Q: ?Sized>(&self, k: &Q) -> Option<Index>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        // A lookup never needs an owned key, so no `ToOwned` bound is required here.
        self.0.get(k).copied()
    }

    /// Returns the handle for `k`, registering `k` under the next free handle if absent.
    pub fn get_or_create<Q: ?Sized>(&mut self, k: &Q) -> Index
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K>,
    {
        // Look up by borrow first so an owned key is only built on a miss.
        if let Some(&idx) = self.0.get(k) {
            return idx;
        }
        // Entries are only ever added, so the current length is the next unused handle.
        let idx = Index::from(self.0.len());
        self.0.insert(k.to_owned(), idx);
        idx
    }

    /// Reverse lookup: returns the key that was assigned `idx`, if any.
    ///
    /// This scans every entry (O(n)); avoid calling it in hot loops.
    pub fn key(&self, idx: Index) -> Option<&K> {
        self.0
            .iter()
            .find(|(_, &value)| value == idx)
            .map(|(key, _)| key)
    }

    /// Iterates over all registered keys in unspecified (hash) order.
    pub fn keys(&self) -> impl Iterator<Item = &K> {
        self.0.keys()
    }
}

impl<K> Default for IndexMap<K>
where
    K: Eq + Hash,
{
    fn default() -> Self {
        Self::new()
    }
}
/// Complementary error function, `erfc(x) = 1 - erf(x)`.
///
/// The coefficients are those of the Chebyshev-fitted approximation `erfcc` from
/// Numerical Recipes (quoted fractional error below about 1.2e-7).
fn erfc(x: f64) -> f64 {
    let z = x.abs();
    let t = 1.0 / (1.0 + z / 2.0);
    // Horner evaluation of the fitted polynomial in `t`.
    let a = -0.82215223 + t * 0.17087277;
    let b = 1.48851587 + t * a;
    let c = -1.13520398 + t * b;
    let d = 0.27886807 + t * c;
    let e = -0.18628806 + t * d;
    let f = 0.09678418 + t * e;
    let g = 0.37409196 + t * f;
    let h = 1.00002368 + t * g;
    let r = t * (-z * z - 1.26551223 + t * h).exp();
    // The fit is for |x|; use erfc(-x) = 2 - erfc(x) for negative arguments.
    if x >= 0.0 {
        r
    } else {
        2.0 - r
    }
}

/// Inverse of [`erfc`]: returns `x` such that `erfc(x) == y`.
///
/// * `y >= 2` returns `-inf` and `y == 0` returns `+inf`.
/// * Inputs in `(0, 1)` give positive results; inputs in `(1, 2)` give negative results.
///
/// # Panics
/// In debug builds, panics if `y` is negative (or NaN).
fn erfc_inv(y: f64) -> f64 {
    if y >= 2.0 {
        return f64::NEG_INFINITY;
    }
    debug_assert!(y >= 0.0, "y must be nonnegative");
    if y == 0.0 {
        return f64::INFINITY;
    }
    // erfc(-x) = 2 - erfc(x): solve on the half (0, 1] and flip the sign afterwards.
    // The flag must be captured *before* reflecting `y`. Otherwise every input in
    // (1, 2) would come back with the wrong sign.
    let upper_half = y >= 1.0;
    let y = if upper_half { 2.0 - y } else { y };
    // Initial guess from the rational normal-quantile approximation
    // (Abramowitz & Stegun 26.2.22) for tail probability y / 2, scaled by 1/sqrt(2).
    // It is positive for y < 1, which is what keeps the refinement below in its
    // basin of convergence.
    let t = (-2.0 * (y / 2.0).ln()).sqrt();
    let rational = (2.30753 + t * 0.27061) / (1.0 + t * (0.99229 + t * 0.04481));
    let mut x = -FRAC_1_SQRT_2 * (rational - t);
    // Halley refinement of f(x) = erfc(x) - y, where f'(x) = -2/sqrt(pi) * exp(-x^2).
    for _ in 0..3 {
        let err = erfc(x) - y;
        x += err / (FRAC_2_SQRT_PI * (-(x.powi(2))).exp() - x * err);
    }
    if upper_half {
        -x
    } else {
        x
    }
}
/// Quantile function (inverse CDF) of a normal distribution with mean `mu` and
/// standard deviation `sigma`, evaluated at probability `p`.
fn ppf(p: f64, mu: f64, sigma: f64) -> f64 {
    // Uses the identity Phi^-1(p) = -sqrt(2) * erfc^-1(2p).
    let inverse = erfc_inv(2.0 * p);
    mu - sigma * SQRT_2 * inverse
}
/// Returns the non-negative margin `m` such that a zero-mean normal with standard
/// deviation `sd` puts probability mass `p_draw` inside `[-m, m]`.
fn compute_margin(p_draw: f64, sd: f64) -> f64 {
    // The lower tail outside [-m, m] holds half of the non-draw mass.
    let lower_tail = 0.5 - p_draw / 2.0;
    let quantile = ppf(lower_tail, 0.0, sd);
    quantile.abs()
}
/// Cumulative distribution function of a normal with mean `mu` and standard deviation
/// `sigma`, evaluated at `x`.
///
/// Uses `Phi(x) = erfc(-(x - mu) / (sigma * sqrt(2))) / 2`.
fn cdf(x: f64, mu: f64, sigma: f64) -> f64 {
    let offset = x - mu;
    let arg = -offset / (sigma * SQRT_2);
    erfc(arg) * 0.5
}
/// Probability density of a normal with mean `mu` and standard deviation `sigma` at `x`.
fn pdf(x: f64, mu: f64, sigma: f64) -> f64 {
    let variance = sigma.powi(2);
    let exponent = -((x - mu).powi(2)) / (2.0 * variance);
    // 1 / (sqrt(2*pi) * sigma), applied as a multiplication by the reciprocal.
    (SQRT_TAU * sigma).powi(-1) * exponent.exp()
}
/// Truncation correction terms `(v, w)` for a normal with mean `mu` and standard
/// deviation `sigma`.
///
/// * `tie == false`: truncation to values above `margin`.
/// * `tie == true`: truncation to the interval `[-margin, margin]`.
///
/// The truncated moments are `mu + sigma * v` and `sigma * sqrt(1 - w)` (see `trunc`).
fn v_w(mu: f64, sigma: f64, margin: f64, tie: bool) -> (f64, f64) {
    if tie {
        // Standardized interval bounds.
        let lo = (-margin - mu) / sigma;
        let hi = (margin - mu) / sigma;
        let density_lo = pdf(lo, 0.0, 1.0);
        let density_hi = pdf(hi, 0.0, 1.0);
        let mass = cdf(hi, 0.0, 1.0) - cdf(lo, 0.0, 1.0);
        let v = (density_lo - density_hi) / mass;
        let u = (lo * density_lo - hi * density_hi) / mass;
        return (v, -(u - v.powi(2)));
    }
    // Standardized distance from the mean to the margin.
    let alpha = (margin - mu) / sigma;
    let v = pdf(-alpha, 0.0, 1.0) / cdf(-alpha, 0.0, 1.0);
    (v, v * (v + (-alpha)))
}
/// Mean and standard deviation of a normal `(mu, sigma)` after the truncation
/// described by `margin` and `tie` (see `v_w`).
fn trunc(mu: f64, sigma: f64, margin: f64, tie: bool) -> (f64, f64) {
    let (v, w) = v_w(mu, sigma, margin, tie);
    (mu + sigma * v, sigma * (1.0 - w).sqrt())
}
/// Replaces the truncation of `n` (see `trunc`) with the Gaussian that has the same
/// mean and standard deviation.
pub(crate) fn approx(n: Gaussian, margin: f64, tie: bool) -> Gaussian {
    let moments = trunc(n.mu, n.sigma, margin, tie);
    Gaussian {
        mu: moments.0,
        sigma: moments.1,
    }
}
/// Component-wise maximum of two `(f64, f64)` pairs.
///
/// Each component is chosen with `a > b`. If either value is NaN, the component from
/// `v2` is returned.
pub(crate) fn tuple_max(v1: (f64, f64), v2: (f64, f64)) -> (f64, f64) {
    let larger = |a: f64, b: f64| if a > b { a } else { b };
    (larger(v1.0, v2.0), larger(v1.1, v2.1))
}
/// Returns `true` when at least one component of `t` is strictly greater than `e`.
///
/// NaN components never count as greater.
pub(crate) fn tuple_gt(t: (f64, f64), e: f64) -> bool {
    let (first, second) = t;
    [first, second].iter().any(|&component| component > e)
}
/// Returns the index permutation that sorts `x` ascending (descending when `reverse`).
///
/// The sort is stable, so equal values keep their input order.
///
/// # Panics
/// Panics if `x` has more than one element and contains a NaN.
pub(crate) fn sort_perm(x: &[f64], reverse: bool) -> Vec<usize> {
    let mut order: Vec<usize> = (0..x.len()).collect();
    order.sort_by(|&i, &j| {
        let (lhs, rhs) = if reverse { (x[j], x[i]) } else { (x[i], x[j]) };
        lhs.partial_cmp(&rhs).unwrap()
    });
    order
}
/// Returns the index permutation that sorts the timestamps `xs` ascending
/// (descending when `reverse`).
///
/// The sort is stable, so equal timestamps keep their input order.
pub(crate) fn sort_time(xs: &[i64], reverse: bool) -> Vec<usize> {
    let mut order: Vec<usize> = (0..xs.len()).collect();
    if reverse {
        order.sort_by_key(|&i| Reverse(xs[i]));
    } else {
        order.sort_by_key(|&i| xs[i]);
    }
    order
}
/// Probability of the observed outcome for difference `e` under `d[e].prior`.
///
/// * Tie: the prior mass inside `[-margin[e], margin[e]]`.
/// * Otherwise: the prior mass above `margin[e]`.
///
/// # Panics
/// Panics if `e` is out of bounds for `tie`, `margin` or `d`.
pub(crate) fn evidence(d: &[DiffMessage], margin: &[f64], tie: &[bool], e: usize) -> f64 {
    let is_tie = tie[e];
    let m = margin[e];
    let (mu, sigma) = (d[e].prior.mu, d[e].prior.sigma);
    if !is_tie {
        return 1.0 - cdf(m, mu, sigma);
    }
    cdf(m, mu, sigma) - cdf(-m, mu, sigma)
}
/// Calculates the match quality of the given rating groups.
///
/// Each inner slice is one group (team) of player ratings. Groups are compared pairwise
/// in the given order: group `i` against group `i + 1`. The result is interpreted as
/// the draw probability of the match; values closer to 1 indicate a more even match.
///
/// `beta` is the per-player performance noise: `beta^2` is added once per player to
/// the variance of every pairwise group difference.
///
/// # Panics
/// Panics if `rating_groups` is empty.
pub fn quality(rating_groups: &[&[Gaussian]], beta: f64) -> f64 {
    assert!(
        !rating_groups.is_empty(),
        "quality requires at least one rating group"
    );
    let flatten_ratings = rating_groups
        .iter()
        .flat_map(|group| group.iter())
        .collect::<Vec<_>>();
    let flatten_weights = vec![1.0; flatten_ratings.len()];
    let length = flatten_ratings.len();

    // Column vector of means and diagonal matrix of variances, one entry per player.
    let mut mean_matrix = Matrix::new(length, 1);
    let mut variance_matrix = Matrix::new(length, length);
    for (i, rating) in flatten_ratings.iter().enumerate() {
        mean_matrix[(i, 0)] = rating.mu;
        variance_matrix[(i, i)] = rating.sigma.powi(2);
    }

    // Row `r` encodes "group r minus group r + 1": +w for each player of group r and
    // -w for each player of group r + 1. Players are laid out in flattened order, so
    // group r + 1 starts exactly where group r ends.
    let mut rotated_a_matrix = Matrix::new(rating_groups.len() - 1, length);
    let mut start = 0;
    for (row, pair) in rating_groups.windows(2).enumerate() {
        let (current, next) = (pair[0], pair[1]);
        let boundary = start + current.len();
        for n in start..boundary {
            rotated_a_matrix[(row, n)] = flatten_weights[n];
        }
        for n in boundary..boundary + next.len() {
            rotated_a_matrix[(row, n)] = -flatten_weights[n];
        }
        // The next row's "current" group is this row's "next" group.
        start = boundary;
    }

    let a_matrix = rotated_a_matrix.transpose();
    let ata = beta.powi(2) * &rotated_a_matrix * &a_matrix;
    let atsa = &rotated_a_matrix * &variance_matrix * &a_matrix;
    let start = mean_matrix.transpose() * &a_matrix;
    let middle = &ata + &atsa;
    let end = &rotated_a_matrix * &mean_matrix;
    // exp(-1/2 * m^T A (beta^2 A^T A + A^T S A)^-1 A^T m)
    //   * sqrt(det(beta^2 A^T A) / det(beta^2 A^T A + A^T S A))
    let e_arg = (-0.5 * &start * &middle.inverse() * &end).determinant();
    let s_arg = ata.determinant() / middle.determinant();
    e_arg.exp() * s_arg.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;
    use ::approx::assert_ulps_eq;

    /// Descending order; the two equal values keep their input order (stable sort).
    #[test]
    fn test_sort_perm() {
        let values = [0.0, 1.0, 2.0, 0.0];
        let expected = vec![2, 1, 0, 3];
        assert_eq!(sort_perm(&values, true), expected);
    }

    /// Same ordering contract as `sort_perm`, for integer timestamps.
    #[test]
    fn test_sort_time() {
        let times = [0, 1, 2, 0];
        let expected = vec![2, 1, 0, 3];
        assert_eq!(sort_time(&times, true), expected);
    }

    /// Reference quality for two identical one-player groups.
    #[test]
    fn test_quality() {
        let a = Gaussian::from_ms(25.0, 3.0);
        let b = Gaussian::from_ms(25.0, 3.0);
        let beta = 25.0 / 3.0 / 2.0;
        let q = quality(&[&[a], &[b]], beta);
        assert_ulps_eq!(q, 0.8115343414514944, epsilon = 1e-6)
    }
}