diff --git a/src/batch.rs b/src/batch.rs index 434e31a..fcc10d9 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet}; use crate::{Game, Gaussian, Player, N_INF}; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Skill { pub forward: Gaussian, pub backward: Gaussian, @@ -96,7 +96,6 @@ pub struct Batch { pub skills: HashMap, events: Vec, time: f64, - agents: HashMap, p_draw: f64, } @@ -105,23 +104,27 @@ impl Batch { composition: Vec>>, results: Vec>, time: f64, - agents: HashMap, + agents: &mut HashMap, p_draw: f64, ) -> Self { let mut this = Self { skills: HashMap::new(), events: Vec::new(), time, - agents, p_draw, }; - this.add_events(composition, results); + this.add_events(composition, results, agents); this } - pub fn add_events(&mut self, composition: Vec>>, results: Vec>) { + pub fn add_events( + &mut self, + composition: Vec>>, + results: Vec>, + agents: &mut HashMap, + ) { let this_agent = composition .iter() .flat_map(|teams| teams.iter()) @@ -130,16 +133,16 @@ impl Batch { .collect::>(); for a in this_agent { - let elapsed = compute_elapsed(self.agents[a].last_time, self.time); + let elapsed = compute_elapsed(agents[a].last_time, self.time); if let Some(skill) = self.skills.get_mut(a) { skill.elapsed = elapsed; - skill.forward = self.agents[a].receive(elapsed); + skill.forward = agents[a].receive(elapsed); } else { self.skills.insert( a.to_string(), Skill { - forward: self.agents[a].receive(elapsed), + forward: agents[a].receive(elapsed), ..Default::default() }, ); @@ -173,7 +176,7 @@ impl Batch { self.events.push(event); } - self.iteration(from); + self.iteration(from, agents); } pub fn posterior(&self, agent: &str) -> Gaussian { @@ -189,29 +192,33 @@ impl Batch { .collect::>() } - fn within_prior(&self, item: &Item) -> Player { - let r = &self.agents[&item.name].player; + fn within_prior(&self, item: &Item, agents: &mut HashMap) -> Player { + let r = &agents[&item.name].player; let g = self.posterior(&item.name) / item.likelihood; Player::new(g, r.beta, r.gamma, N_INF) } - pub fn within_priors(&self, event: usize) -> Vec> { + pub fn within_priors( + &self, + event: usize, + agents: &mut HashMap, + ) -> Vec> { self.events[event] .teams .iter() .map(|team| { team.items .iter() - .map(|item| self.within_prior(item)) + .map(|item| self.within_prior(item, agents)) .collect::>() }) .collect::>() } - fn iteration(&mut self, from: usize) { + fn iteration(&mut self, from: usize, agents: &mut HashMap) { for e in from..self.events.len() { - let teams = self.within_priors(e); + let teams = self.within_priors(e, agents); let result = self.events[e].result(); let g = Game::new(teams, result, self.p_draw); @@ -230,7 +237,7 @@ impl Batch { } } - pub fn convergence(&mut self) -> usize { + pub fn convergence(&mut self, agents: &mut HashMap) -> usize { let epsilon = 1e-6; let iterations = 20; @@ -240,7 +247,7 @@ impl Batch { while (step.0 > epsilon || step.1 > epsilon) && i < iterations { let old = self.posteriors(); - self.iteration(0); + self.iteration(0, agents); let new = self.posteriors(); @@ -259,36 +266,34 @@ impl Batch { i } - /* - def convergence(self, epsilon=1e-6, iterations = 20): - step, i = (inf, inf), 0 - while gr_tuple(step, epsilon) and (i < iterations): - old = self.posteriors().copy() - self.iteration() - step = dict_diff(old, self.posteriors()) - i += 1 - return i - */ - pub fn forward_prior_out(&self, agent: &str) -> Gaussian { let skill = &self.skills[agent]; skill.forward * skill.likelihood } - /* - def backward_prior_out(self, agent): - N = self.skills[agent].likelihood*self.skills[agent].backward - return N.forget(self.agents[agent].player.gamma, self.skills[agent].elapsed) - def new_backward_info(self): - for a in self.skills: - self.skills[a].backward = self.agents[a].message - return self.iteration() - def new_forward_info(self): - for a in self.skills: - self.skills[a].forward = self.agents[a].receive(self.skills[a].elapsed) - return self.iteration() - */ + pub fn backward_prior_out(&self, agent: &str, agents: &mut HashMap) -> Gaussian { + let skill = &self.skills[agent]; + let n = skill.likelihood * skill.backward; + + n.forget(agents[agent].player.gamma, skill.elapsed) + } + + pub fn new_backward_info(&mut self, agents: &mut HashMap) { + for (agent, skill) in self.skills.iter_mut() { + skill.backward = agents[agent].message; + } + + self.iteration(0, agents); + } + + pub fn new_forward_info(&mut self, agents: &mut HashMap) { + for (agent, skill) in self.skills.iter_mut() { + skill.forward = agents[agent].receive(skill.elapsed); + } + + self.iteration(0, agents); + } } #[cfg(test)] @@ -324,7 +329,7 @@ mod tests { ], vec![vec![1, 0], vec![0, 1], vec![1, 0]], 0.0, - agents, + &mut agents, 0.0, ); @@ -339,7 +344,7 @@ mod tests { assert_eq!(post["c"].mu(), 20.79477925612302); assert_eq!(post["c"].sigma(), 7.194481422570443); - assert_eq!(b.convergence(), 1); + assert_eq!(b.convergence(&mut agents), 1); } #[test] @@ -369,7 +374,7 @@ mod tests { ], vec![vec![1, 0], vec![0, 1], vec![1, 0]], 2.0, - agents, + &mut agents, 0.0, ); @@ -384,7 +389,7 @@ mod tests { assert_eq!(post["c"].mu(), 24.88968178743119); assert_eq!(post["c"].sigma(), 5.866311348102562); - assert!(b.convergence() > 1); + assert!(b.convergence(&mut agents) > 1); let post = b.posteriors(); diff --git a/src/game.rs b/src/game.rs index 53866c7..f766a0c 100644 --- a/src/game.rs +++ b/src/game.rs @@ -1,4 +1,3 @@ -use std::cmp::Reverse; use std::collections::HashSet; use crate::{message::DiffMessages, utils, variable::TeamVariable, Gaussian, Player, N00}; @@ -73,7 +72,7 @@ impl Game { } let r = &self.result; - let o = utils::sortperm(r); + let o = utils::sortperm(r, true); let t = (0..self.teams.len()) .map(|e| TeamVariable { diff --git a/src/history.rs b/src/history.rs index 3a55a2b..6c26791 100644 --- a/src/history.rs +++ b/src/history.rs @@ -51,8 +51,6 @@ impl History { }) .collect::>(); - println!("{:#?}", agents); - let mut this = Self { size: composition.len(), batches: Vec::new(), @@ -75,13 +73,7 @@ impl History { results: Vec>, times: Vec, ) { - let o = { - let mut o = utils::sortperm(×); - o.reverse(); - o - }; - - let o = o; + let o = utils::sortperm(×, false); let mut i = 0; while i < self.size { @@ -102,7 +94,7 @@ impl History { composition, results, t as f64, - self.agents.clone(), + &mut self.agents, self.p_draw, ); @@ -119,16 +111,89 @@ impl History { } } - fn iteration(&self) { - todo!() + fn iteration(&mut self) -> (f64, f64) { + let mut step = (0.0, 0.0); + + clean(self.agents.values_mut(), false); + + for j in (0..self.batches.len() - 1).rev() { + for agent in self.batches[j + 1].skills.keys() { + self.agents.get_mut(agent).unwrap().message = + self.batches[j + 1].backward_prior_out(agent, &mut self.agents); + } + + let old = self.batches[j].posteriors(); + + self.batches[j].new_backward_info(&mut self.agents); + + let new = self.batches[j].posteriors(); + + step = old + .iter() + .fold(step, |step, (a, old)| tuple_max(step, old.delta(new[a]))); + } + + clean(self.agents.values_mut(), false); + + for j in 1..self.batches.len() { + for agent in self.batches[j - 1].skills.keys() { + self.agents.get_mut(agent).unwrap().message = + self.batches[j - 1].forward_prior_out(agent); + } + + let old = self.batches[j].posteriors(); + + self.batches[j].new_forward_info(&mut self.agents); + + let new = self.batches[j].posteriors(); + + step = old + .iter() + .fold(step, |step, (a, old)| tuple_max(step, old.delta(new[a]))); + } + + if self.batches.len() == 1 { + let old = self.batches[0].posteriors(); + + self.batches[0].convergence(&mut self.agents); + + let new = self.batches[0].posteriors(); + + step = old + .iter() + .fold(step, |step, (a, old)| tuple_max(step, old.delta(new[a]))); + } + + step } - fn convergence(&self) { + pub fn convergence(&mut self) -> ((f64, f64), usize) { let epsilon = 1e-6; let iterations = 30; - let verbose = true; + let verbose = false; - todo!() + let mut step = (f64::INFINITY, f64::INFINITY); + let mut i = 0; + + while (step.0 > epsilon || step.1 > epsilon) && i < iterations { + if verbose { + print!("Iteration = {}", i); + } + + step = self.iteration(); + + i += 1; + + if verbose { + println!(", step = {:?}", step); + } + } + + if verbose { + println!("End"); + } + + (step, i) } fn learning_curves(&self) { @@ -140,6 +205,23 @@ impl History { } } +fn clean<'a, A: Iterator>(agents: A, last_time: bool) { + for a in agents { + a.message = N_INF; + + if last_time { + a.last_time = f64::NEG_INFINITY; + } + } +} + +fn tuple_max(a: (f64, f64), b: (f64, f64)) -> (f64, f64) { + ( + if a.0 > b.0 { a.0 } else { b.0 }, + if a.1 > b.1 { a.1 } else { b.1 }, + ) +} + #[cfg(test)] mod tests { use approx::assert_ulps_eq; @@ -170,7 +252,7 @@ mod tests { priors.insert(k.to_string(), player); } - let h = History::new( + let mut h = History::new( composition, results, vec![1, 2, 3], @@ -194,10 +276,172 @@ mod tests { assert_ulps_eq!(observed, expected, epsilon = 0.000001); let observed = h.batches[1].posterior("a"); - let p = Game::new(h.batches[1].within_priors(0), vec![0, 1], P_DRAW).posteriors(); + let p = Game::new( + h.batches[1].within_priors(0, &mut h.agents), + vec![0, 1], + P_DRAW, + ) + .posteriors(); let expected = p[0][0]; assert_ulps_eq!(observed.mu(), expected.mu(), epsilon = 0.000001); assert_ulps_eq!(observed.sigma(), expected.sigma(), epsilon = 0.000001); } + + #[test] + fn test_one_batch() { + let composition = vec![ + vec![vec!["aj"], vec!["bj"]], + vec![vec!["bj"], vec!["cj"]], + vec![vec!["cj"], vec!["aj"]], + ]; + let results = vec![vec![1, 0], vec![1, 0], vec![1, 0]]; + let times = vec![1, 1, 1]; + + let mut priors = HashMap::new(); + + for k in ["aj", "bj", "cj"] { + let player = Player::new( + Gaussian::new(25.0, 25.0 / 3.0), + 25.0 / 6.0, + 0.15 * 25.0 / 3.0, + N_INF, + ); + + priors.insert(k.to_string(), player); + } + + let mut h1 = History::new( + composition, + results, + times, + priors, + MU, + BETA, + SIGMA, + GAMMA, + P_DRAW, + ); + + assert_ulps_eq!( + h1.batches[0].posterior("aj").mu(), + 22.904409330892914, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h1.batches[0].posterior("aj").sigma(), + 6.0103304390431, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h1.batches[0].posterior("cj").mu(), + 25.110318212568806, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h1.batches[0].posterior("cj").sigma(), + 5.866311348102563, + epsilon = 0.000001 + ); + + let (_step, _i) = h1.convergence(); + + assert_ulps_eq!( + h1.batches[0].posterior("aj").mu(), + 25.00000000, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h1.batches[0].posterior("aj").sigma(), + 5.41921200, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h1.batches[0].posterior("cj").mu(), + 25.00000000, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h1.batches[0].posterior("cj").sigma(), + 5.41921200, + epsilon = 0.000001 + ); + + let composition = vec![ + vec![vec!["aj"], vec!["bj"]], + vec![vec!["bj"], vec!["cj"]], + vec![vec!["cj"], vec!["aj"]], + ]; + let results = vec![vec![1, 0], vec![1, 0], vec![1, 0]]; + let times = vec![1, 2, 3]; + + let mut priors = HashMap::new(); + + for k in ["aj", "bj", "cj"] { + let player = Player::new( + Gaussian::new(25.0, 25.0 / 3.0), + 25.0 / 6.0, + 25.0 / 300.0, + N_INF, + ); + + priors.insert(k.to_string(), player); + } + + let mut h2 = History::new( + composition, + results, + times, + priors, + MU, + BETA, + SIGMA, + GAMMA, + P_DRAW, + ); + + assert_ulps_eq!( + h2.batches[2].posterior("aj").mu(), + 22.90352227792141, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h2.batches[2].posterior("aj").sigma(), + 6.011017301320632, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h2.batches[2].posterior("cj").mu(), + 25.110702468366718, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h2.batches[2].posterior("cj").sigma(), + 5.866811597660157, + epsilon = 0.000001 + ); + + let (_step, _i) = h2.convergence(); + + assert_ulps_eq!( + h2.batches[2].posterior("aj").mu(), + 24.99999999, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h2.batches[2].posterior("aj").sigma(), + 5.419212002, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h2.batches[2].posterior("cj").mu(), + 24.99999999, + epsilon = 0.000001 + ); + assert_ulps_eq!( + h2.batches[2].posterior("cj").sigma(), + 5.419212002, + epsilon = 0.000001 + ); + } } diff --git a/src/main.rs b/src/main.rs index f019792..7b36285 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,37 +3,38 @@ use std::collections::HashMap; use trueskill_tt::*; fn main() { - let mut agents = HashMap::new(); + let composition = vec![ + vec![vec!["aj"], vec!["bj"]], + vec![vec!["bj"], vec!["cj"]], + vec![vec!["cj"], vec!["aj"]], + ]; + let results = vec![vec![1, 0], vec![1, 0], vec![1, 0]]; + let times = vec![1, 2, 3]; - for k in ["a", "b", "c", "d", "e", "f"] { - let agent = Agent::new( - Player::new( - Gaussian::new(25.0, 25.0 / 3.0), - 25.0 / 6.0, - 25.0 / 300.0, - N_INF, - ), + let mut priors = HashMap::new(); + + for k in ["aj", "bj", "cj"] { + let player = Player::new( + Gaussian::new(25.0, 25.0 / 3.0), + 25.0 / 6.0, + 25.0 / 300.0, N_INF, - f64::NEG_INFINITY, ); - agents.insert(k.to_string(), agent); + priors.insert(k.to_string(), player); } - let b = Batch::new( - vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["c"], vec!["d"]], - vec![vec!["e"], vec!["f"]], - ], - vec![vec![1, 0], vec![0, 1], vec![1, 0]], - 0.0, - agents, - 0.0, + let mut h2 = History::new( + composition, + results, + times, + priors, + MU, + BETA, + SIGMA, + GAMMA, + P_DRAW, ); - let post = b.posteriors(); - - println!("{} {}", post["a"].mu(), 29.205); - println!("{} {}", post["a"].sigma(), 7.194) + let (step, i) = h2.convergence(); } diff --git a/src/utils.rs b/src/utils.rs index dda82a1..226c500 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -134,9 +134,15 @@ pub(crate) fn compute_margin(p_draw: f64, sd: f64) -> f64 { ppf(0.5 - p_draw / 2.0, 0.0, sd).abs() } -pub(crate) fn sortperm(xs: &[T]) -> Vec { +pub(crate) fn sortperm(xs: &[T], reverse: bool) -> Vec { let mut x = xs.iter().enumerate().collect::>(); - x.sort_unstable_by_key(|(_, x)| Reverse(*x)); + + if reverse { + x.sort_unstable_by_key(|(_, x)| Reverse(*x)); + } else { + x.sort_unstable_by_key(|(_, x)| *x); + } + x.into_iter().map(|(i, _)| i).collect() } @@ -216,6 +222,7 @@ mod tests { #[test] fn test_sortperm() { - assert_eq!(sortperm(&[0, 1, 2, 0]), vec![2, 1, 0, 3]); + assert_eq!(sortperm(&[0, 1, 2, 0], true), vec![2, 1, 0, 3]); + assert_eq!(sortperm(&[1, 1, 1], false), vec![0, 1, 2]); } }