More things, better things, awesome

This commit is contained in:
2022-06-18 23:39:42 +02:00
parent dc10504b80
commit c9d9d59535
7 changed files with 214 additions and 4 deletions

43
src/approx.rs Normal file
View File

@@ -0,0 +1,43 @@
use approx::{AbsDiffEq, RelativeEq, UlpsEq};
use crate::gaussian::Gaussian;
impl AbsDiffEq for Gaussian {
type Epsilon = <f64 as AbsDiffEq>::Epsilon;
fn default_epsilon() -> Self::Epsilon {
f64::default_epsilon()
}
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
f64::abs_diff_eq(&self.mu, &other.mu, epsilon)
&& f64::abs_diff_eq(&self.sigma, &other.sigma, epsilon)
}
}
impl RelativeEq for Gaussian {
fn default_max_relative() -> Self::Epsilon {
f64::default_max_relative()
}
fn relative_eq(
&self,
other: &Self,
epsilon: Self::Epsilon,
max_relative: Self::Epsilon,
) -> bool {
f64::relative_eq(&self.mu, &other.mu, epsilon, max_relative)
&& f64::relative_eq(&self.sigma, &other.sigma, epsilon, max_relative)
}
}
impl UlpsEq for Gaussian {
fn default_max_ulps() -> u32 {
f64::default_max_ulps()
}
fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
f64::ulps_eq(&self.mu, &other.mu, epsilon, max_ulps)
&& f64::ulps_eq(&self.sigma, &other.sigma, epsilon, max_ulps)
}
}

View File

@@ -24,16 +24,19 @@ impl Default for Skill {
}
}
#[derive(Debug)]
struct Item {
agent: String,
likelihood: Gaussian,
}
#[derive(Debug)]
struct Team {
items: Vec<Item>,
output: f64,
}
#[derive(Debug)]
struct Event {
teams: Vec<Team>,
evidence: f64,
@@ -346,6 +349,71 @@ impl Batch {
self.iteration(0, agents);
}
pub(crate) fn log_evidence2(
&self,
online: bool,
agents2: &Vec<&str>,
forward: bool,
agents: &mut HashMap<String, Agent>,
) -> f64 {
if agents2.is_empty() {
if online || forward {
self.events
.iter()
.enumerate()
.map(|(e, event)| {
Game::new(
self.within_priors(e, online, forward, agents),
event.outputs(),
event.weights.clone(),
self.p_draw,
)
.evidence
.ln()
})
.sum()
} else {
self.events.iter().map(|event| event.evidence.ln()).sum()
}
} else {
if online || forward {
self.events
.iter()
.enumerate()
.filter(|(_, event)| {
event
.teams
.iter()
.flat_map(|team| &team.items)
.any(|item| agents2.contains(&item.agent.as_str()))
})
.map(|(e, event)| {
Game::new(
self.within_priors(e, online, forward, agents),
event.outputs(),
event.weights.clone(),
self.p_draw,
)
.evidence
.ln()
})
.sum()
} else {
self.events
.iter()
.filter(|event| {
event
.teams
.iter()
.flat_map(|team| &team.items)
.any(|item| agents2.contains(&item.agent.as_str()))
})
.map(|event| event.evidence.ln())
.sum()
}
}
}
}
fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 {

View File

@@ -158,6 +158,8 @@ impl Game {
.collect::<Vec<_>>()
};
self.evidence = 1.0;
let mut step = (f64::INFINITY, f64::INFINITY);
let mut iter = 0;
@@ -194,7 +196,7 @@ impl Game {
}
if d.len() == 1 {
self.evidence *= evidence(&d, &margin, &tie, 0);
self.evidence = evidence(&d, &margin, &tie, 0);
d[0].prior = t[0].posterior_win() - t[1].posterior_lose();
d[0].likelihood = approx(d[0].prior, margin[0], tie[0]) / d[0].prior;

View File

@@ -286,13 +286,20 @@ impl History {
data
}
pub fn log_evidence(&mut self, forward: bool, agents: &Vec<&str>) -> f64 {
self.batches
.iter()
.map(|batch| batch.log_evidence2(self.online, agents, forward, &mut self.agents))
.sum()
}
}
#[cfg(test)]
mod tests {
use approx::assert_ulps_eq;
use crate::{Game, Player, BETA, EPSILON, GAMMA, ITERATIONS, MU, P_DRAW, SIGMA};
use crate::{Game, Gaussian, Player, BETA, EPSILON, GAMMA, ITERATIONS, MU, P_DRAW, SIGMA};
use super::*;
@@ -635,4 +642,89 @@ mod tests {
epsilon = 0.000001
);
}
#[test]
fn test_teams() {
let composition = vec![
vec![vec!["a", "b"], vec!["c", "d"]],
vec![vec!["e", "f"], vec!["b", "c"]],
vec![vec!["a", "d"], vec!["e", "f"]],
];
let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
let mut h = History::new(
composition,
results,
vec![],
vec![],
HashMap::new(),
0.0,
6.0,
1.0,
0.0,
0.0,
false,
);
let trueskill_log_evidence = h.log_evidence(false, &vec![]);
let trueskill_log_evidence_online = h.log_evidence(true, &vec![]);
assert_ulps_eq!(
trueskill_log_evidence,
trueskill_log_evidence_online,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("b").mu,
-1.0 * h.batches[0].posterior("c").mu,
epsilon = 0.000001
);
let evidence_second_event = h.log_evidence(false, &vec!["b"]).exp() * 2.0;
assert_ulps_eq!(0.5, evidence_second_event, epsilon = 0.000001);
let evidence_third_event = h.log_evidence(false, &vec!["a"]).exp() * 2.0;
assert_ulps_eq!(0.669885, evidence_third_event, epsilon = 0.000001);
h.convergence(ITERATIONS, EPSILON, false);
let loocv_hat = h.log_evidence(false, &vec![]).exp();
let p_d_m_hat = h.log_evidence(true, &vec![]).exp();
assert_ulps_eq!(loocv_hat, 0.2410274245857821, epsilon = 0.000001);
assert_ulps_eq!(p_d_m_hat, 0.17243238958411006, epsilon = 0.000001);
assert_ulps_eq!(
h.batches[0].posterior("a"),
h.batches[0].posterior("b"),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("c"),
h.batches[0].posterior("d"),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[1].posterior("e"),
h.batches[1].posterior("f"),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("a"),
Gaussian::new(4.084902364982456, 5.10691909049607),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("c"),
Gaussian::new(-0.5330294544847751, 5.10691909049607),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[2].posterior("e"),
Gaussian::new(-3.551872900373382, 5.154569731627773),
epsilon = 0.000001
);
}
}

View File

@@ -2,6 +2,8 @@ use std::cmp::Reverse;
use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2};
mod agent;
#[cfg(feature = "approx")]
mod approx;
mod batch;
mod game;
mod gaussian;