Add optional `approx` support for `Gaussian`, batch/history log-evidence, and
a team-game regression test.

Reviewer notes (everything above the first `diff --git` line is skipped by
`git apply` and `patch`):

* The patch had lost its line breaks and every `<...>` generic argument.
  Both are restored here. The restored generics (`<f64 as AbsDiffEq>`,
  `Vec<Item>`, `Vec<Team>`, `HashMap<String, Agent>`, `Vec<_>`) and the
  indentation of context lines are reconstructed from usage and rustfmt
  layout -- TODO confirm against the tree before applying.
* src/batch.rs: `log_evidence2` repeated the same filter and the same
  `Game::new(..)` map in four branches. It is now one filter-then-map pass
  with identical results; the signature is unchanged.
* src/batch.rs, src/history.rs: doc comments added to the new functions and
  the hunk headers recomputed accordingly. The `index` lines still carry the
  blob ids of the original patch, so the post-image ids for these two files
  are stale.

diff --git a/Cargo.toml b/Cargo.toml
index 53f6483..3234fe4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,10 +4,11 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
+approx = { version = "0.5.1", optional = true }
 
 [dev-dependencies]
-approx = "0.5.1"
 time = { version = "0.3.9", features = ["parsing"] }
+trueskill-tt = { path = ".", features = ["approx"] }
 
 [profile.release]
 debug = true
diff --git a/README.md b/README.md
index 51331c8..6223a5f 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,8 @@ Rust port of [TrueSkillThroughTime.py](https://github.com/glandfried/TrueSkillTh
 
 ## Todo
 
-- [ ] Change time from u64 to i64 so we can use i64::MIN in `batch::compute_elapsed()`
+- [x] Implement approx for Gaussian
+- [ ] Add more tests from `TrueSkillThroughTime.jl`
+- [ ] Time needs to be an enum so we can have multiple states (see `batch::compute_elapsed()`)
 - [ ] Add examples (use same TrueSkillThroughTime.(py|jl))
 - [ ] Add Observer (see [argmin](https://docs.rs/argmin/latest/argmin/core/trait.Observe.html) for inspiration)
diff --git a/src/approx.rs b/src/approx.rs
new file mode 100644
index 0000000..f187be9
--- /dev/null
+++ b/src/approx.rs
@@ -0,0 +1,43 @@
+use approx::{AbsDiffEq, RelativeEq, UlpsEq};
+
+use crate::gaussian::Gaussian;
+
+impl AbsDiffEq for Gaussian {
+    type Epsilon = <f64 as AbsDiffEq>::Epsilon;
+
+    fn default_epsilon() -> Self::Epsilon {
+        f64::default_epsilon()
+    }
+
+    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
+        f64::abs_diff_eq(&self.mu, &other.mu, epsilon)
+            && f64::abs_diff_eq(&self.sigma, &other.sigma, epsilon)
+    }
+}
+
+impl RelativeEq for Gaussian {
+    fn default_max_relative() -> Self::Epsilon {
+        f64::default_max_relative()
+    }
+
+    fn relative_eq(
+        &self,
+        other: &Self,
+        epsilon: Self::Epsilon,
+        max_relative: Self::Epsilon,
+    ) -> bool {
+        f64::relative_eq(&self.mu, &other.mu, epsilon, max_relative)
+            && f64::relative_eq(&self.sigma, &other.sigma, epsilon, max_relative)
+    }
+}
+
+impl UlpsEq for Gaussian {
+    fn default_max_ulps() -> u32 {
+        f64::default_max_ulps()
+    }
+
+    fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
+        f64::ulps_eq(&self.mu, &other.mu, epsilon, max_ulps)
+            && f64::ulps_eq(&self.sigma, &other.sigma, epsilon, max_ulps)
+    }
+}
diff --git a/src/batch.rs b/src/batch.rs
index 38fb1bb..bbe9731 100644
--- a/src/batch.rs
+++ b/src/batch.rs
@@ -24,16 +24,19 @@ impl Default for Skill {
     }
 }
 
+#[derive(Debug)]
 struct Item {
     agent: String,
     likelihood: Gaussian,
 }
 
+#[derive(Debug)]
 struct Team {
     items: Vec<Item>,
     output: f64,
 }
 
+#[derive(Debug)]
 struct Event {
     teams: Vec<Team>,
     evidence: f64,
@@ -346,6 +349,54 @@ impl Batch {
 
         self.iteration(0, agents);
     }
+
+    /// Sums the natural log of the evidence of this batch's events.
+    ///
+    /// If `agents2` is non-empty, only events in which at least one team item
+    /// belongs to an agent named in `agents2` contribute to the sum; an empty
+    /// `agents2` selects every event.
+    ///
+    /// When `online` or `forward` is set, each selected event's evidence is
+    /// recomputed from a fresh `Game` built on `within_priors(..)`; otherwise
+    /// the `evidence` already stored on the event is used.
+    pub(crate) fn log_evidence2(
+        &self,
+        online: bool,
+        agents2: &Vec<&str>,
+        forward: bool,
+        agents: &mut HashMap<String, Agent>,
+    ) -> f64 {
+        // One pass replaces four near-identical branches: the filter decides
+        // which events count, the map decides how their evidence is obtained.
+        let recompute = online || forward;
+
+        self.events
+            .iter()
+            .enumerate()
+            .filter(|(_, event)| {
+                agents2.is_empty()
+                    || event
+                        .teams
+                        .iter()
+                        .flat_map(|team| &team.items)
+                        .any(|item| agents2.contains(&item.agent.as_str()))
+            })
+            .map(|(e, event)| {
+                if recompute {
+                    Game::new(
+                        self.within_priors(e, online, forward, agents),
+                        event.outputs(),
+                        event.weights.clone(),
+                        self.p_draw,
+                    )
+                    .evidence
+                    .ln()
+                } else {
+                    event.evidence.ln()
+                }
+            })
+            .sum()
+    }
 }
 
 fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 {
diff --git a/src/game.rs b/src/game.rs
index 58db44c..60510cc 100644
--- a/src/game.rs
+++ b/src/game.rs
@@ -158,6 +158,8 @@ impl Game {
                 .collect::<Vec<_>>()
         };
 
+        self.evidence = 1.0;
+
         let mut step = (f64::INFINITY, f64::INFINITY);
         let mut iter = 0;
 
@@ -194,7 +196,7 @@ impl Game {
         }
 
         if d.len() == 1 {
-            self.evidence *= evidence(&d, &margin, &tie, 0);
+            self.evidence = evidence(&d, &margin, &tie, 0);
 
             d[0].prior = t[0].posterior_win() - t[1].posterior_lose();
             d[0].likelihood = approx(d[0].prior, margin[0], tie[0]) / d[0].prior;
diff --git a/src/history.rs b/src/history.rs
index 5d814c9..eb8e929 100644
--- a/src/history.rs
+++ b/src/history.rs
@@ -286,13 +286,26 @@ impl History {
 
         data
     }
+
+    /// Sums `Batch::log_evidence2` over every batch in the history.
+    ///
+    /// `forward` and `agents` are passed through to each batch together with
+    /// this history's `online` flag and agent map. An empty `agents` list
+    /// selects every event; otherwise only events involving one of the named
+    /// agents are counted.
+    pub fn log_evidence(&mut self, forward: bool, agents: &Vec<&str>) -> f64 {
+        self.batches
+            .iter()
+            .map(|batch| batch.log_evidence2(self.online, agents, forward, &mut self.agents))
+            .sum()
+    }
 }
 
 #[cfg(test)]
 mod tests {
     use approx::assert_ulps_eq;
 
-    use crate::{Game, Player, BETA, EPSILON, GAMMA, ITERATIONS, MU, P_DRAW, SIGMA};
+    use crate::{Game, Gaussian, Player, BETA, EPSILON, GAMMA, ITERATIONS, MU, P_DRAW, SIGMA};
 
     use super::*;
 
@@ -635,4 +648,89 @@ mod tests {
             epsilon = 0.000001
         );
     }
+
+    #[test]
+    fn test_teams() {
+        let composition = vec![
+            vec![vec!["a", "b"], vec!["c", "d"]],
+            vec![vec!["e", "f"], vec!["b", "c"]],
+            vec![vec!["a", "d"], vec!["e", "f"]],
+        ];
+        let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
+
+        let mut h = History::new(
+            composition,
+            results,
+            vec![],
+            vec![],
+            HashMap::new(),
+            0.0,
+            6.0,
+            1.0,
+            0.0,
+            0.0,
+            false,
+        );
+
+        let trueskill_log_evidence = h.log_evidence(false, &vec![]);
+        let trueskill_log_evidence_online = h.log_evidence(true, &vec![]);
+
+        assert_ulps_eq!(
+            trueskill_log_evidence,
+            trueskill_log_evidence_online,
+            epsilon = 0.000001
+        );
+
+        assert_ulps_eq!(
+            h.batches[0].posterior("b").mu,
+            -1.0 * h.batches[0].posterior("c").mu,
+            epsilon = 0.000001
+        );
+
+        let evidence_second_event = h.log_evidence(false, &vec!["b"]).exp() * 2.0;
+        assert_ulps_eq!(0.5, evidence_second_event, epsilon = 0.000001);
+
+        let evidence_third_event = h.log_evidence(false, &vec!["a"]).exp() * 2.0;
+        assert_ulps_eq!(0.669885, evidence_third_event, epsilon = 0.000001);
+
+        h.convergence(ITERATIONS, EPSILON, false);
+
+        let loocv_hat = h.log_evidence(false, &vec![]).exp();
+        let p_d_m_hat = h.log_evidence(true, &vec![]).exp();
+
+        assert_ulps_eq!(loocv_hat, 0.2410274245857821, epsilon = 0.000001);
+        assert_ulps_eq!(p_d_m_hat, 0.17243238958411006, epsilon = 0.000001);
+
+        assert_ulps_eq!(
+            h.batches[0].posterior("a"),
+            h.batches[0].posterior("b"),
+            epsilon = 0.000001
+        );
+        assert_ulps_eq!(
+            h.batches[0].posterior("c"),
+            h.batches[0].posterior("d"),
+            epsilon = 0.000001
+        );
+        assert_ulps_eq!(
+            h.batches[1].posterior("e"),
+            h.batches[1].posterior("f"),
+            epsilon = 0.000001
+        );
+
+        assert_ulps_eq!(
+            h.batches[0].posterior("a"),
+            Gaussian::new(4.084902364982456, 5.10691909049607),
+            epsilon = 0.000001
+        );
+        assert_ulps_eq!(
+            h.batches[0].posterior("c"),
+            Gaussian::new(-0.5330294544847751, 5.10691909049607),
+            epsilon = 0.000001
+        );
+        assert_ulps_eq!(
+            h.batches[2].posterior("e"),
+            Gaussian::new(-3.551872900373382, 5.154569731627773),
+            epsilon = 0.000001
+        );
+    }
 }
diff --git a/src/lib.rs b/src/lib.rs
index aa427f0..aceb40c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,6 +2,8 @@ use std::cmp::Reverse;
 use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2};
 
 mod agent;
+#[cfg(feature = "approx")]
+mod approx;
 mod batch;
 mod game;
 mod gaussian;