diff --git a/Cargo.toml b/Cargo.toml index ae16f6d..4be6f77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,5 @@ name = "trueskill-tt" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] +[dev-dependencies] +approx = "0.5.1" diff --git a/src/batch.rs b/src/batch.rs index 136a0b4..80eb979 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -140,7 +140,7 @@ impl Batch { } } - let from = self.events.len() + 1; + let from = self.events.len(); for e in 0..composition.len() { let teams = (0..composition[e].len()) @@ -176,10 +176,10 @@ impl Batch { skill.likelihood * skill.backward * skill.forward } - fn posteriors(&self) -> HashMap<&str, Gaussian> { + pub fn posteriors(&self) -> HashMap { self.skills .keys() - .map(|a| (a.as_str(), self.posterior(a))) + .map(|a| (a.clone(), self.posterior(a))) .collect::>() } @@ -223,10 +223,65 @@ impl Batch { self.events[e].evidence = g.evidence; } } + + pub fn convergence(&mut self) -> usize { + let epsilon = 1e-6; + let iterations = 20; + + let mut step = (f64::INFINITY, f64::INFINITY); + let mut i = 0; + + while (step.0 > epsilon || step.1 > epsilon) && i < iterations { + let old = self.posteriors(); + + self.iteration(0); + + let new = self.posteriors(); + + step = old.iter().fold((0.0, 0.0), |(o_l, o_r), (a, old)| { + let (n_l, n_r) = old.delta(new[a]); + + ( + if n_l > o_l { n_l } else { o_l }, + if n_r > o_r { n_r } else { o_r }, + ) + }); + + i += 1; + } + + i + } + + /* + def convergence(self, epsilon=1e-6, iterations = 20): + step, i = (inf, inf), 0 + while gr_tuple(step, epsilon) and (i < iterations): + old = self.posteriors().copy() + self.iteration() + step = dict_diff(old, self.posteriors()) + i += 1 + return i + def forward_prior_out(self, agent): + return self.skills[agent].forward * self.skills[agent].likelihood + def backward_prior_out(self, agent): + N = self.skills[agent].likelihood*self.skills[agent].backward + return N.forget(self.agents[agent].player.gamma, self.skills[agent].elapsed) + def new_backward_info(self): + for a in self.skills: + self.skills[a].backward = self.agents[a].message + return self.iteration() + def new_forward_info(self): + for a in self.skills: + self.skills[a].forward = self.agents[a].receive(self.skills[a].elapsed) + return self.iteration() + */ } #[cfg(test)] mod tests { + use approx::assert_ulps_eq; + use super::*; #[test] @@ -248,7 +303,7 @@ mod tests { agents.insert(k.to_string(), agent); } - let b = Batch::new( + let mut b = Batch::new( vec![ vec![vec!["a"], vec!["b"]], vec![vec!["c"], vec!["d"]], @@ -262,23 +317,71 @@ mod tests { let post = b.posteriors(); - assert_eq!(post["a"].mu(), 29.205); - assert_eq!(post["a"].sigma(), 7.194) + assert_eq!(post["a"].mu(), 29.205220743876975); + assert_eq!(post["a"].sigma(), 7.194481422570443); - /* - agents = dict() - for k in ["a", "b", "c", "d", "e", "f"]: - agents[k] = ttt.Agent(ttt.Player(ttt.Gaussian(25., 25.0/3), 25.0/6, 25.0/300 ) , ttt.Ninf, -ttt.inf) - b = ttt.Batch(composition=[ [["a"],["b"]], [["c"],["d"]] , [["e"],["f"]] ], results= [[1,0],[0,1],[1,0]], time = 0, agents=agents) - post = b.posteriors() - self.assertAlmostEqual(post["a"].mu,29.205,3) - self.assertAlmostEqual(post["a"].sigma,7.194,3) + assert_eq!(post["b"].mu(), 20.79477925612302); + assert_eq!(post["b"].sigma(), 7.194481422570443); - self.assertAlmostEqual(post["b"].mu,20.795,3) - self.assertAlmostEqual(post["b"].sigma,7.194,3) - self.assertAlmostEqual(post["c"].mu,20.795,3) - self.assertAlmostEqual(post["c"].sigma,7.194,3) - self.assertEqual(b.convergence(),1) - */ + assert_eq!(post["c"].mu(), 20.79477925612302); + assert_eq!(post["c"].sigma(), 7.194481422570443); + + assert_eq!(b.convergence(), 1); + } + + #[test] + fn test_same_strength() { + let mut agents = HashMap::new(); + + for k in ["a", "b", "c", "d", "e", "f"] { + let agent = Agent::new( + Player::new( + Gaussian::new(25.0, 25.0 / 3.0), + 25.0 / 6.0, + 25.0 / 300.0, + N_INF, + ), + N_INF, + f64::NEG_INFINITY, + ); + + agents.insert(k.to_string(), agent); + } + + let mut b = Batch::new( + vec![ + vec![vec!["a"], vec!["b"]], + vec![vec!["a"], vec!["c"]], + vec![vec!["b"], vec!["c"]], + ], + vec![vec![1, 0], vec![0, 1], vec![1, 0]], + 2.0, + agents, + 0.0, + ); + + let post = b.posteriors(); + + assert_eq!(post["a"].mu(), 24.96097857478182); + assert_eq!(post["a"].sigma(), 6.298544763358269); + + assert_eq!(post["b"].mu(), 27.095590669107086); + assert_eq!(post["b"].sigma(), 6.010330439043099); + + assert_eq!(post["c"].mu(), 24.88968178743119); + assert_eq!(post["c"].sigma(), 5.866311348102562); + + assert!(b.convergence() > 1); + + let post = b.posteriors(); + + assert_ulps_eq!(post["a"].mu(), 25.000000, epsilon = 0.000001); + assert_ulps_eq!(post["a"].sigma(), 5.4192120, epsilon = 0.000001); + + assert_ulps_eq!(post["b"].mu(), 25.000000, epsilon = 0.000001); + assert_ulps_eq!(post["b"].sigma(), 5.4192120, epsilon = 0.000001); + + assert_ulps_eq!(post["c"].mu(), 25.000000, epsilon = 0.000001); + assert_ulps_eq!(post["c"].sigma(), 5.4192120, epsilon = 0.000001); } } diff --git a/src/main.rs b/src/main.rs index 9c2d3b1..f019792 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,39 @@ +use std::collections::HashMap; + use trueskill_tt::*; fn main() { - let t_a = Player::new(Gaussian::new(29.0, 1.0), 25.0 / 6.0, GAMMA, N_INF); - let t_b = Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, GAMMA, N_INF); + let mut agents = HashMap::new(); - let g = Game::new(vec![vec![t_a], vec![t_b]], vec![0, 1], 0.0); - let p = g.posteriors(); + for k in ["a", "b", "c", "d", "e", "f"] { + let agent = Agent::new( + Player::new( + Gaussian::new(25.0, 25.0 / 3.0), + 25.0 / 6.0, + 25.0 / 300.0, + N_INF, + ), + N_INF, + f64::NEG_INFINITY, + ); + + agents.insert(k.to_string(), agent); + } + + let b = Batch::new( + vec![ + vec![vec!["a"], vec!["b"]], + vec![vec!["c"], vec!["d"]], + vec![vec!["e"], vec!["f"]], + ], + vec![vec![1, 0], vec![0, 1], vec![1, 0]], + 0.0, + agents, + 0.0, + ); + + let post = b.posteriors(); + + println!("{} {}", post["a"].mu(), 29.205); + println!("{} {}", post["a"].sigma(), 7.194) }