More things, better things, awesome
This commit is contained in:
@@ -4,10 +4,11 @@ version = "0.1.0"
|
|||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
approx = { version = "0.5.1", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
approx = "0.5.1"
|
|
||||||
time = { version = "0.3.9", features = ["parsing"] }
|
time = { version = "0.3.9", features = ["parsing"] }
|
||||||
|
trueskill-tt = { path = ".", features = ["approx"] }
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
debug = true
|
debug = true
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ Rust port of [TrueSkillThroughTime.py](https://github.com/glandfried/TrueSkillTh
|
|||||||
|
|
||||||
## Todo
|
## Todo
|
||||||
|
|
||||||
- [ ] Change time from u64 to i64 so we can use i64::MIN in `batch::compute_elapsed()`
|
- [x] Implement approx for Gaussian
|
||||||
|
- [ ] Add more tests from `TrueSkillThroughTime.jl`
|
||||||
|
- [ ] Time needs to be an enum so we can have multiple states (see `batch::compute_elapsed()`)
|
||||||
- [ ] Add examples (use same TrueSkillThroughTime.(py|jl))
|
- [ ] Add examples (use same TrueSkillThroughTime.(py|jl))
|
||||||
- [ ] Add Observer (see [argmin](https://docs.rs/argmin/latest/argmin/core/trait.Observe.html) for inspiration)
|
- [ ] Add Observer (see [argmin](https://docs.rs/argmin/latest/argmin/core/trait.Observe.html) for inspiration)
|
||||||
|
|||||||
43
src/approx.rs
Normal file
43
src/approx.rs
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
use approx::{AbsDiffEq, RelativeEq, UlpsEq};
|
||||||
|
|
||||||
|
use crate::gaussian::Gaussian;
|
||||||
|
|
||||||
|
impl AbsDiffEq for Gaussian {
|
||||||
|
type Epsilon = <f64 as AbsDiffEq>::Epsilon;
|
||||||
|
|
||||||
|
fn default_epsilon() -> Self::Epsilon {
|
||||||
|
f64::default_epsilon()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
|
||||||
|
f64::abs_diff_eq(&self.mu, &other.mu, epsilon)
|
||||||
|
&& f64::abs_diff_eq(&self.sigma, &other.sigma, epsilon)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RelativeEq for Gaussian {
|
||||||
|
fn default_max_relative() -> Self::Epsilon {
|
||||||
|
f64::default_max_relative()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn relative_eq(
|
||||||
|
&self,
|
||||||
|
other: &Self,
|
||||||
|
epsilon: Self::Epsilon,
|
||||||
|
max_relative: Self::Epsilon,
|
||||||
|
) -> bool {
|
||||||
|
f64::relative_eq(&self.mu, &other.mu, epsilon, max_relative)
|
||||||
|
&& f64::relative_eq(&self.sigma, &other.sigma, epsilon, max_relative)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UlpsEq for Gaussian {
|
||||||
|
fn default_max_ulps() -> u32 {
|
||||||
|
f64::default_max_ulps()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
|
||||||
|
f64::ulps_eq(&self.mu, &other.mu, epsilon, max_ulps)
|
||||||
|
&& f64::ulps_eq(&self.sigma, &other.sigma, epsilon, max_ulps)
|
||||||
|
}
|
||||||
|
}
|
||||||
68
src/batch.rs
68
src/batch.rs
@@ -24,16 +24,19 @@ impl Default for Skill {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
struct Item {
|
struct Item {
|
||||||
agent: String,
|
agent: String,
|
||||||
likelihood: Gaussian,
|
likelihood: Gaussian,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
struct Team {
|
struct Team {
|
||||||
items: Vec<Item>,
|
items: Vec<Item>,
|
||||||
output: f64,
|
output: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
struct Event {
|
struct Event {
|
||||||
teams: Vec<Team>,
|
teams: Vec<Team>,
|
||||||
evidence: f64,
|
evidence: f64,
|
||||||
@@ -346,6 +349,71 @@ impl Batch {
|
|||||||
|
|
||||||
self.iteration(0, agents);
|
self.iteration(0, agents);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn log_evidence2(
|
||||||
|
&self,
|
||||||
|
online: bool,
|
||||||
|
agents2: &Vec<&str>,
|
||||||
|
forward: bool,
|
||||||
|
agents: &mut HashMap<String, Agent>,
|
||||||
|
) -> f64 {
|
||||||
|
if agents2.is_empty() {
|
||||||
|
if online || forward {
|
||||||
|
self.events
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(e, event)| {
|
||||||
|
Game::new(
|
||||||
|
self.within_priors(e, online, forward, agents),
|
||||||
|
event.outputs(),
|
||||||
|
event.weights.clone(),
|
||||||
|
self.p_draw,
|
||||||
|
)
|
||||||
|
.evidence
|
||||||
|
.ln()
|
||||||
|
})
|
||||||
|
.sum()
|
||||||
|
} else {
|
||||||
|
self.events.iter().map(|event| event.evidence.ln()).sum()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if online || forward {
|
||||||
|
self.events
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter(|(_, event)| {
|
||||||
|
event
|
||||||
|
.teams
|
||||||
|
.iter()
|
||||||
|
.flat_map(|team| &team.items)
|
||||||
|
.any(|item| agents2.contains(&item.agent.as_str()))
|
||||||
|
})
|
||||||
|
.map(|(e, event)| {
|
||||||
|
Game::new(
|
||||||
|
self.within_priors(e, online, forward, agents),
|
||||||
|
event.outputs(),
|
||||||
|
event.weights.clone(),
|
||||||
|
self.p_draw,
|
||||||
|
)
|
||||||
|
.evidence
|
||||||
|
.ln()
|
||||||
|
})
|
||||||
|
.sum()
|
||||||
|
} else {
|
||||||
|
self.events
|
||||||
|
.iter()
|
||||||
|
.filter(|event| {
|
||||||
|
event
|
||||||
|
.teams
|
||||||
|
.iter()
|
||||||
|
.flat_map(|team| &team.items)
|
||||||
|
.any(|item| agents2.contains(&item.agent.as_str()))
|
||||||
|
})
|
||||||
|
.map(|event| event.evidence.ln())
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 {
|
fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 {
|
||||||
|
|||||||
@@ -158,6 +158,8 @@ impl Game {
|
|||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
self.evidence = 1.0;
|
||||||
|
|
||||||
let mut step = (f64::INFINITY, f64::INFINITY);
|
let mut step = (f64::INFINITY, f64::INFINITY);
|
||||||
let mut iter = 0;
|
let mut iter = 0;
|
||||||
|
|
||||||
@@ -194,7 +196,7 @@ impl Game {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if d.len() == 1 {
|
if d.len() == 1 {
|
||||||
self.evidence *= evidence(&d, &margin, &tie, 0);
|
self.evidence = evidence(&d, &margin, &tie, 0);
|
||||||
|
|
||||||
d[0].prior = t[0].posterior_win() - t[1].posterior_lose();
|
d[0].prior = t[0].posterior_win() - t[1].posterior_lose();
|
||||||
d[0].likelihood = approx(d[0].prior, margin[0], tie[0]) / d[0].prior;
|
d[0].likelihood = approx(d[0].prior, margin[0], tie[0]) / d[0].prior;
|
||||||
|
|||||||
@@ -286,13 +286,20 @@ impl History {
|
|||||||
|
|
||||||
data
|
data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn log_evidence(&mut self, forward: bool, agents: &Vec<&str>) -> f64 {
|
||||||
|
self.batches
|
||||||
|
.iter()
|
||||||
|
.map(|batch| batch.log_evidence2(self.online, agents, forward, &mut self.agents))
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use approx::assert_ulps_eq;
|
use approx::assert_ulps_eq;
|
||||||
|
|
||||||
use crate::{Game, Player, BETA, EPSILON, GAMMA, ITERATIONS, MU, P_DRAW, SIGMA};
|
use crate::{Game, Gaussian, Player, BETA, EPSILON, GAMMA, ITERATIONS, MU, P_DRAW, SIGMA};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
@@ -635,4 +642,89 @@ mod tests {
|
|||||||
epsilon = 0.000001
|
epsilon = 0.000001
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_teams() {
|
||||||
|
let composition = vec![
|
||||||
|
vec![vec!["a", "b"], vec!["c", "d"]],
|
||||||
|
vec![vec!["e", "f"], vec!["b", "c"]],
|
||||||
|
vec![vec!["a", "d"], vec!["e", "f"]],
|
||||||
|
];
|
||||||
|
let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
|
||||||
|
|
||||||
|
let mut h = History::new(
|
||||||
|
composition,
|
||||||
|
results,
|
||||||
|
vec![],
|
||||||
|
vec![],
|
||||||
|
HashMap::new(),
|
||||||
|
0.0,
|
||||||
|
6.0,
|
||||||
|
1.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
let trueskill_log_evidence = h.log_evidence(false, &vec![]);
|
||||||
|
let trueskill_log_evidence_online = h.log_evidence(true, &vec![]);
|
||||||
|
|
||||||
|
assert_ulps_eq!(
|
||||||
|
trueskill_log_evidence,
|
||||||
|
trueskill_log_evidence_online,
|
||||||
|
epsilon = 0.000001
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_ulps_eq!(
|
||||||
|
h.batches[0].posterior("b").mu,
|
||||||
|
-1.0 * h.batches[0].posterior("c").mu,
|
||||||
|
epsilon = 0.000001
|
||||||
|
);
|
||||||
|
|
||||||
|
let evidence_second_event = h.log_evidence(false, &vec!["b"]).exp() * 2.0;
|
||||||
|
assert_ulps_eq!(0.5, evidence_second_event, epsilon = 0.000001);
|
||||||
|
|
||||||
|
let evidence_third_event = h.log_evidence(false, &vec!["a"]).exp() * 2.0;
|
||||||
|
assert_ulps_eq!(0.669885, evidence_third_event, epsilon = 0.000001);
|
||||||
|
|
||||||
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
|
let loocv_hat = h.log_evidence(false, &vec![]).exp();
|
||||||
|
let p_d_m_hat = h.log_evidence(true, &vec![]).exp();
|
||||||
|
|
||||||
|
assert_ulps_eq!(loocv_hat, 0.2410274245857821, epsilon = 0.000001);
|
||||||
|
assert_ulps_eq!(p_d_m_hat, 0.17243238958411006, epsilon = 0.000001);
|
||||||
|
|
||||||
|
assert_ulps_eq!(
|
||||||
|
h.batches[0].posterior("a"),
|
||||||
|
h.batches[0].posterior("b"),
|
||||||
|
epsilon = 0.000001
|
||||||
|
);
|
||||||
|
assert_ulps_eq!(
|
||||||
|
h.batches[0].posterior("c"),
|
||||||
|
h.batches[0].posterior("d"),
|
||||||
|
epsilon = 0.000001
|
||||||
|
);
|
||||||
|
assert_ulps_eq!(
|
||||||
|
h.batches[1].posterior("e"),
|
||||||
|
h.batches[1].posterior("f"),
|
||||||
|
epsilon = 0.000001
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_ulps_eq!(
|
||||||
|
h.batches[0].posterior("a"),
|
||||||
|
Gaussian::new(4.084902364982456, 5.10691909049607),
|
||||||
|
epsilon = 0.000001
|
||||||
|
);
|
||||||
|
assert_ulps_eq!(
|
||||||
|
h.batches[0].posterior("c"),
|
||||||
|
Gaussian::new(-0.5330294544847751, 5.10691909049607),
|
||||||
|
epsilon = 0.000001
|
||||||
|
);
|
||||||
|
assert_ulps_eq!(
|
||||||
|
h.batches[2].posterior("e"),
|
||||||
|
Gaussian::new(-3.551872900373382, 5.154569731627773),
|
||||||
|
epsilon = 0.000001
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ use std::cmp::Reverse;
|
|||||||
use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2};
|
use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2};
|
||||||
|
|
||||||
mod agent;
|
mod agent;
|
||||||
|
#[cfg(feature = "approx")]
|
||||||
|
mod approx;
|
||||||
mod batch;
|
mod batch;
|
||||||
mod game;
|
mod game;
|
||||||
mod gaussian;
|
mod gaussian;
|
||||||
|
|||||||
Reference in New Issue
Block a user