From cd1079a811a48a036a11e2a8d203c3e199aa032a Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Mon, 27 Jun 2022 10:16:12 +0200 Subject: [PATCH] Use and Index struct instead of str and String for player id --- README.md | 2 +- src/batch.rs | 237 ++++++++--------- src/history.rs | 685 ++++++++++++++++++++++++++----------------------- src/lib.rs | 45 +++- 4 files changed, 514 insertions(+), 455 deletions(-) diff --git a/README.md b/README.md index 6223a5f..7672b0f 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Rust port of [TrueSkillThroughTime.py](https://github.com/glandfried/TrueSkillTh ## Todo - [x] Implement approx for Gaussian -- [ ] Add more tests from `TrueSkillThroughTime.jl` +- [x] Add more tests from `TrueSkillThroughTime.jl` - [ ] Time needs to be an enum so we can have multiple states (see `batch::compute_elapsed()`) - [ ] Add examples (use same TrueSkillThroughTime.(py|jl)) - [ ] Add Observer (see [argmin](https://docs.rs/argmin/latest/argmin/core/trait.Observe.html) for inspiration) diff --git a/src/batch.rs b/src/batch.rs index 1a87e8e..4ecc864 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, HashSet}; use crate::{ - agent::Agent, game::Game, gaussian::Gaussian, player::Player, tuple_gt, tuple_max, N_INF, + agent::Agent, game::Game, gaussian::Gaussian, player::Player, tuple_gt, tuple_max, Index, N_INF, }; #[derive(Debug)] @@ -27,7 +27,7 @@ impl Default for Skill { #[derive(Debug)] struct Item { - agent: String, + agent: Index, likelihood: Gaussian, } @@ -56,19 +56,19 @@ impl Event { #[derive(Debug)] pub struct Batch { pub(crate) events: Vec, - pub(crate) skills: HashMap, + pub(crate) skills: HashMap, pub(crate) time: u64, p_draw: f64, } impl Batch { pub(crate) fn new( - composition: Vec>>, + composition: Vec>>, results: Vec>, weights: Vec>>, time: u64, p_draw: f64, - agents: &mut HashMap, + agents: &mut HashMap, ) -> Self { assert!( results.is_empty() || results.len() == composition.len(), @@ -88,17 +88,17 @@ impl Batch { let elapsed = this_agent .iter() - .map(|&a| (a, compute_elapsed(agents[a].last_time, time))) + .map(|&idx| (idx, compute_elapsed(agents[&idx].last_time, time))) .collect::>(); let skills = this_agent .iter() - .map(|&a| { + .map(|&idx| { ( - a.to_string(), + idx, Skill { - forward: agents[a].receive(elapsed[a]), - elapsed: elapsed[a], + forward: agents[&idx].receive(elapsed[&idx]), + elapsed: elapsed[&idx], ..Default::default() }, ) @@ -111,7 +111,7 @@ impl Batch { .map(|t| { let items = (0..composition[e][t].len()) .map(|a| Item { - agent: composition[e][t][a].to_string(), + agent: composition[e][t][a], likelihood: N_INF, }) .collect::>(); @@ -153,10 +153,10 @@ impl Batch { pub(crate) fn add_events( &mut self, - composition: Vec>>, + composition: Vec>>, results: Vec>, weights: Vec>>, - agents: &mut HashMap, + agents: &mut HashMap, ) { let this_agent = composition .iter() @@ -165,17 +165,17 @@ impl Batch { .cloned() .collect::>(); - for a in this_agent { - let elapsed = compute_elapsed(agents[a].last_time, self.time); + for idx in this_agent { + let elapsed = compute_elapsed(agents[&idx].last_time, self.time); - if let Some(skill) = self.skills.get_mut(a) { + if let Some(skill) = self.skills.get_mut(&idx) { skill.elapsed = elapsed; - skill.forward = agents[a].receive(elapsed); + skill.forward = agents[&idx].receive(elapsed); } else { self.skills.insert( - a.to_string(), + idx, Skill { - forward: agents[a].receive(elapsed), + forward: agents[&idx].receive(elapsed), elapsed, ..Default::default() }, @@ -190,7 +190,7 @@ impl Batch { .map(|t| { let items = (0..composition[e][t].len()) .map(|a| Item { - agent: composition[e][t][a].to_string(), + agent: composition[e][t][a], likelihood: N_INF, }) .collect::>(); @@ -222,16 +222,16 @@ impl Batch { self.iteration(from, agents); } - pub fn posterior(&self, agent: &str) -> Gaussian { - let skill = &self.skills[agent]; + pub fn posterior(&self, agent: Index) -> Gaussian { + let skill = &self.skills[&agent]; skill.likelihood * skill.backward * skill.forward } - pub(crate) fn posteriors(&self) -> HashMap { + pub(crate) fn posteriors(&self) -> HashMap { self.skills .keys() - .map(|a| (a.to_string(), self.posterior(a))) + .map(|&idx| (idx, self.posterior(idx))) .collect::>() } @@ -240,7 +240,7 @@ impl Batch { item: &Item, online: bool, forward: bool, - agents: &mut HashMap, + agents: &mut HashMap, ) -> Player { let r = &agents[&item.agent].player; @@ -249,7 +249,7 @@ impl Batch { } else if forward { Player::new(self.skills[&item.agent].forward, r.beta, r.gamma) } else { - let wp = self.posterior(&item.agent) / item.likelihood; + let wp = self.posterior(item.agent) / item.likelihood; Player::new(wp, r.beta, r.gamma) } @@ -260,7 +260,7 @@ impl Batch { event: usize, online: bool, forward: bool, - agents: &mut HashMap, + agents: &mut HashMap, ) -> Vec> { self.events[event] .teams @@ -274,7 +274,7 @@ impl Batch { .collect::>() } - pub(crate) fn iteration(&mut self, from: usize, agents: &mut HashMap) { + pub(crate) fn iteration(&mut self, from: usize, agents: &mut HashMap) { for e in from..self.events.len() { let teams = self.within_priors(e, false, false, agents); let result = self.events[e].outputs(); @@ -295,7 +295,7 @@ impl Batch { } } - pub(crate) fn convergence(&mut self, agents: &mut HashMap) -> usize { + pub(crate) fn convergence(&mut self, agents: &mut HashMap) -> usize { let epsilon = 1e-6; let iterations = 20; @@ -319,7 +319,7 @@ impl Batch { i } - pub(crate) fn forward_prior_out(&self, agent: &str) -> Gaussian { + pub(crate) fn forward_prior_out(&self, agent: &Index) -> Gaussian { let skill = &self.skills[agent]; skill.forward * skill.likelihood @@ -327,8 +327,8 @@ impl Batch { pub(crate) fn backward_prior_out( &self, - agent: &str, - agents: &mut HashMap, + agent: &Index, + agents: &mut HashMap, ) -> Gaussian { let skill = &self.skills[agent]; let n = skill.likelihood * skill.backward; @@ -336,7 +336,7 @@ impl Batch { n.forget(agents[agent].player.gamma, skill.elapsed) } - pub(crate) fn new_backward_info(&mut self, agents: &mut HashMap) { + pub(crate) fn new_backward_info(&mut self, agents: &mut HashMap) { for (agent, skill) in self.skills.iter_mut() { skill.backward = agents[agent].message; } @@ -344,7 +344,7 @@ impl Batch { self.iteration(0, agents); } - pub(crate) fn new_forward_info(&mut self, agents: &mut HashMap) { + pub(crate) fn new_forward_info(&mut self, agents: &mut HashMap) { for (agent, skill) in self.skills.iter_mut() { skill.forward = agents[agent].receive(skill.elapsed); } @@ -352,14 +352,14 @@ impl Batch { self.iteration(0, agents); } - pub(crate) fn log_evidence2( + pub(crate) fn log_evidence( &self, online: bool, - agents2: &Vec<&str>, + targets: &[Index], forward: bool, - agents: &mut HashMap, + agents: &mut HashMap, ) -> f64 { - if agents2.is_empty() { + if targets.is_empty() { if online || forward { self.events .iter() @@ -388,7 +388,7 @@ impl Batch { .teams .iter() .flat_map(|team| &team.items) - .any(|item| agents2.contains(&item.agent.as_str())) + .any(|item| targets.contains(&item.agent)) }) .map(|(e, event)| { Game::new( @@ -409,7 +409,7 @@ impl Batch { .teams .iter() .flat_map(|team| &team.items) - .any(|item| agents2.contains(&item.agent.as_str())) + .any(|item| targets.contains(&item.agent)) }) .map(|event| event.evidence.ln()) .sum() @@ -417,19 +417,14 @@ impl Batch { } } - pub(crate) fn get_composition(&self) -> Vec>> { + pub(crate) fn get_composition(&self) -> Vec>> { self.events .iter() .map(|event| { event .teams .iter() - .map(|team| { - team.items - .iter() - .map(|item| item.agent.as_str()) - .collect::>() - }) + .map(|team| team.items.iter().map(|item| item.agent).collect::>()) .collect::>() }) .collect::>() @@ -463,17 +458,26 @@ pub(crate) fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 { mod tests { use approx::assert_ulps_eq; - use crate::{agent::Agent, player::Player}; + use crate::{agent::Agent, player::Player, IndexMap}; use super::*; #[test] fn test_one_event_each() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let d = index_map.get_or_create("d"); + let e = index_map.get_or_create("e"); + let f = index_map.get_or_create("f"); + let mut agents = HashMap::new(); - for agent in ["a", "b", "c", "d", "e", "f"] { + for agent in [a, b, c, d, e, f] { agents.insert( - agent.to_string(), + agent, Agent { player: Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0), ..Default::default() @@ -483,9 +487,9 @@ mod tests { let mut batch = Batch::new( vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["c"], vec!["d"]], - vec![vec!["e"], vec!["f"]], + vec![vec![a], vec![b]], + vec![vec![c], vec![d]], + vec![vec![e], vec![f]], ], vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]], vec![], @@ -496,34 +500,32 @@ mod tests { let post = batch.posteriors(); - assert_ulps_eq!(post["a"].mu, 29.205220743876975, epsilon = 0.000001); - assert_ulps_eq!(post["a"].sigma, 7.194481422570443, epsilon = 0.000001); - - assert_ulps_eq!(post["b"].mu, 20.79477925612302, epsilon = 0.000001); - assert_ulps_eq!(post["b"].sigma, 7.194481422570443, epsilon = 0.000001); - - assert_ulps_eq!(post["c"].mu, 20.79477925612302, epsilon = 0.000001); - assert_ulps_eq!(post["c"].sigma, 7.194481422570443, epsilon = 0.000001); - - assert_ulps_eq!(post["d"].mu, 29.205220743876975, epsilon = 0.000001); - assert_ulps_eq!(post["d"].sigma, 7.194481422570443, epsilon = 0.000001); - - assert_ulps_eq!(post["e"].mu, 29.205220743876975, epsilon = 0.000001); - assert_ulps_eq!(post["e"].sigma, 7.194481422570443, epsilon = 0.000001); - - assert_ulps_eq!(post["f"].mu, 20.79477925612302, epsilon = 0.000001); - assert_ulps_eq!(post["f"].sigma, 7.194481422570443, epsilon = 0.000001); + assert_ulps_eq!(post[&a], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6); + assert_ulps_eq!(post[&b], Gaussian::new(20.794779, 7.194481), epsilon = 1e-6); + assert_ulps_eq!(post[&c], Gaussian::new(20.794779, 7.194481), epsilon = 1e-6); + assert_ulps_eq!(post[&d], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6); + assert_ulps_eq!(post[&e], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6); + assert_ulps_eq!(post[&f], Gaussian::new(20.794779, 7.194481), epsilon = 1e-6); assert_eq!(batch.convergence(&mut agents), 1); } #[test] fn test_same_strength() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let d = index_map.get_or_create("d"); + let e = index_map.get_or_create("e"); + let f = index_map.get_or_create("f"); + let mut agents = HashMap::new(); - for agent in ["a", "b", "c", "d", "e", "f"] { + for agent in [a, b, c, d, e, f] { agents.insert( - agent.to_string(), + agent, Agent { player: Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0), ..Default::default() @@ -533,9 +535,9 @@ mod tests { let mut batch = Batch::new( vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["a"], vec!["c"]], - vec![vec!["b"], vec!["c"]], + vec![vec![a], vec![b]], + vec![vec![a], vec![c]], + vec![vec![b], vec![c]], ], vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]], vec![], @@ -546,36 +548,35 @@ mod tests { let post = batch.posteriors(); - assert_ulps_eq!(post["a"].mu, 24.96097857478182, epsilon = 0.000001); - assert_ulps_eq!(post["a"].sigma, 6.298544763358269, epsilon = 0.000001); - - assert_ulps_eq!(post["b"].mu, 27.095590669107086, epsilon = 0.000001); - assert_ulps_eq!(post["b"].sigma, 6.010330439043099, epsilon = 0.000001); - - assert_ulps_eq!(post["c"].mu, 24.88968178743119, epsilon = 0.000001); - assert_ulps_eq!(post["c"].sigma, 5.866311348102562, epsilon = 0.000001); + assert_ulps_eq!(post[&a], Gaussian::new(24.960978, 6.298544), epsilon = 1e-6); + assert_ulps_eq!(post[&b], Gaussian::new(27.095590, 6.010330), epsilon = 1e-6); + assert_ulps_eq!(post[&c], Gaussian::new(24.889681, 5.866311), epsilon = 1e-6); assert!(batch.convergence(&mut agents) > 1); let post = batch.posteriors(); - assert_ulps_eq!(post["a"].mu, 25.000000, epsilon = 0.000001); - assert_ulps_eq!(post["a"].sigma, 5.4192120, epsilon = 0.000001); - - assert_ulps_eq!(post["b"].mu, 25.000000, epsilon = 0.000001); - assert_ulps_eq!(post["b"].sigma, 5.4192120, epsilon = 0.000001); - - assert_ulps_eq!(post["c"].mu, 25.000000, epsilon = 0.000001); - assert_ulps_eq!(post["c"].sigma, 5.4192120, epsilon = 0.000001); + assert_ulps_eq!(post[&a], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6); + assert_ulps_eq!(post[&b], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6); + assert_ulps_eq!(post[&c], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6); } #[test] fn test_add_events() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let d = index_map.get_or_create("d"); + let e = index_map.get_or_create("e"); + let f = index_map.get_or_create("f"); + let mut agents = HashMap::new(); - for agent in ["a", "b", "c", "d", "e", "f"] { + for agent in [a, b, c, d, e, f] { agents.insert( - agent.to_string(), + agent, Agent { player: Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0), ..Default::default() @@ -585,9 +586,9 @@ mod tests { let mut batch = Batch::new( vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["a"], vec!["c"]], - vec![vec!["b"], vec!["c"]], + vec![vec![a], vec![b]], + vec![vec![a], vec![c]], + vec![vec![b], vec![c]], ], vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]], vec![], @@ -600,29 +601,15 @@ mod tests { let post = batch.posteriors(); - assert_ulps_eq!( - post["a"], - Gaussian::new(25.000000, 5.4192120), - epsilon = 0.000001 - ); - - assert_ulps_eq!( - post["b"], - Gaussian::new(25.000000, 5.4192120), - epsilon = 0.000001 - ); - - assert_ulps_eq!( - post["c"], - Gaussian::new(25.000000, 5.4192120), - epsilon = 0.000001 - ); + assert_ulps_eq!(post[&a], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6); + assert_ulps_eq!(post[&b], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6); + assert_ulps_eq!(post[&c], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6); batch.add_events( vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["a"], vec!["c"]], - vec![vec!["b"], vec!["c"]], + vec![vec![a], vec![b]], + vec![vec![a], vec![c]], + vec![vec![b], vec![c]], ], vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]], vec![], @@ -635,20 +622,8 @@ mod tests { let post = batch.posteriors(); - assert_ulps_eq!( - post["a"], - Gaussian::new(25.00000315330858, 3.880150268080797), - epsilon = 0.000001 - ); - assert_ulps_eq!( - post["b"], - Gaussian::new(25.00000315330858, 3.880150268080797), - epsilon = 0.000001 - ); - assert_ulps_eq!( - post["c"], - Gaussian::new(25.00000315330858, 3.880150268080797), - epsilon = 0.000001 - ); + assert_ulps_eq!(post[&a], Gaussian::new(25.000003, 3.880150), epsilon = 1e-6); + assert_ulps_eq!(post[&b], Gaussian::new(25.000003, 3.880150), epsilon = 1e-6); + assert_ulps_eq!(post[&c], Gaussian::new(25.000003, 3.880150), epsilon = 1e-6); } } diff --git a/src/history.rs b/src/history.rs index badb85b..4514d0a 100644 --- a/src/history.rs +++ b/src/history.rs @@ -5,7 +5,7 @@ use crate::{ batch::{self, Batch}, gaussian::Gaussian, player::Player, - sort_time, tuple_gt, tuple_max, BETA, GAMMA, MU, P_DRAW, SIGMA, + sort_time, tuple_gt, tuple_max, Index, BETA, GAMMA, MU, P_DRAW, SIGMA, }; #[derive(Clone)] @@ -88,7 +88,7 @@ impl Default for HistoryBuilder { pub struct History { size: usize, pub batches: Vec, - agents: HashMap, + agents: HashMap, time: bool, mu: f64, sigma: f64, @@ -206,17 +206,17 @@ impl History { (step, i) } - pub fn learning_curves(&self) -> HashMap> { - let mut data: HashMap> = HashMap::new(); + pub fn learning_curves(&self) -> HashMap> { + let mut data: HashMap> = HashMap::new(); for b in &self.batches { for agent in b.skills.keys() { - let point = (b.time, b.posterior(agent)); + let point = (b.time, b.posterior(*agent)); if let Some(entry) = data.get_mut(agent) { entry.push(point); } else { - data.insert(agent.to_string(), vec![point]); + data.insert(*agent, vec![point]); } } } @@ -224,20 +224,30 @@ impl History { data } - pub fn log_evidence(&mut self, forward: bool, agents: &Vec<&str>) -> f64 { + pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 { self.batches .iter() - .map(|batch| batch.log_evidence2(self.online, agents, forward, &mut self.agents)) + .map(|batch| batch.log_evidence(self.online, targets, forward, &mut self.agents)) .sum() } pub fn add_events( &mut self, - composition: Vec>>, + composition: Vec>>, results: Vec>, times: Vec, weights: Vec>>, - priors: HashMap, + ) { + self.add_events_with_prior(composition, results, times, weights, HashMap::new()) + } + + pub fn add_events_with_prior( + &mut self, + composition: Vec>>, + results: Vec>, + times: Vec, + weights: Vec>>, + priors: HashMap, ) { assert!(times.is_empty() || self.time, "length(times)>0 but !h.time"); assert!( @@ -265,11 +275,11 @@ impl History { .collect::>(); for agent in &this_agent { - if !self.agents.contains_key(*agent) { + if !self.agents.contains_key(agent) { self.agents.insert( - agent.to_string(), + *agent, Agent { - player: priors.get(*agent).cloned().unwrap_or_else(|| { + player: priors.get(agent).cloned().unwrap_or_else(|| { Player::new(Gaussian::new(self.mu, self.sigma), self.beta, self.gamma) }), ..Default::default() @@ -309,11 +319,11 @@ impl History { let intersect = this_agent .iter() - .filter(|&&agent| b.skills.contains_key(agent)) + .filter(|&agent| b.skills.contains_key(agent)) .cloned() .collect::>(); - for agent in intersect { + for agent in &intersect { b.skills.get_mut(agent).unwrap().elapsed = batch::compute_elapsed(self.agents[agent].last_time, b.time); @@ -386,11 +396,11 @@ impl History { let intersect = this_agent .iter() - .filter(|&&agent| b.skills.contains_key(agent)) + .filter(|&agent| b.skills.contains_key(agent)) .cloned() .collect::>(); - for agent in intersect { + for agent in &intersect { b.skills.get_mut(agent).unwrap().elapsed = batch::compute_elapsed(self.agents[agent].last_time, b.time); @@ -411,24 +421,30 @@ impl History { mod tests { use approx::assert_ulps_eq; - use crate::{Game, Gaussian, Player, EPSILON, ITERATIONS, P_DRAW}; + use crate::{Game, Gaussian, IndexMap, Player, EPSILON, ITERATIONS, P_DRAW}; use super::*; #[test] fn test_init() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let composition = vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["a"], vec!["c"]], - vec![vec!["b"], vec!["c"]], + vec![vec![a], vec![b]], + vec![vec![a], vec![c]], + vec![vec![b], vec![c]], ]; let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; let mut priors = HashMap::new(); - for agent in ["a", "b", "c"] { + for agent in [a, b, c] { priors.insert( - agent.to_string(), + agent, Player::new( Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, @@ -439,23 +455,19 @@ mod tests { let mut h = History::default(); - h.add_events(composition, results, vec![1, 2, 3], vec![], priors); + h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors); let p0 = h.batches[0].posteriors(); - assert_ulps_eq!( - p0["a"], - Gaussian::new(29.205220743876975, 7.194481422570443), - epsilon = 0.000001 - ); + assert_ulps_eq!(p0[&a], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6); - let observed = h.batches[1].skills["a"].forward.sigma; + let observed = h.batches[1].skills[&a].forward.sigma; let gamma: f64 = 0.15 * 25.0 / 3.0; - let expected = (gamma.powi(2) + h.batches[0].posterior("a").sigma.powi(2)).sqrt(); + let expected = (gamma.powi(2) + h.batches[0].posterior(a).sigma.powi(2)).sqrt(); assert_ulps_eq!(observed, expected, epsilon = 0.000001); - let observed = h.batches[1].posterior("a"); + let observed = h.batches[1].posterior(a); let p = Game::new( h.batches[1].within_priors(0, false, false, &mut h.agents), @@ -466,153 +478,174 @@ mod tests { .posteriors(); let expected = p[0][0]; - assert_ulps_eq!(observed, expected, epsilon = 0.000001); + assert_ulps_eq!(observed, expected, epsilon = 1e-6); } #[test] fn test_one_batch() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let composition = vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["b"], vec!["c"]], - vec![vec!["c"], vec!["a"]], + vec![vec![a], vec![b]], + vec![vec![b], vec![c]], + vec![vec![c], vec![a]], ]; let results = vec![vec![1.0, 0.0], vec![1.0, 0.0], vec![1.0, 0.0]]; let times = vec![1, 1, 1]; let mut priors = HashMap::new(); - for k in ["a", "b", "c"] { - let player = Player::new( - Gaussian::new(25.0, 25.0 / 3.0), - 25.0 / 6.0, - 0.15 * 25.0 / 3.0, + for agent in [a, b, c] { + priors.insert( + agent, + Player::new( + Gaussian::new(25.0, 25.0 / 3.0), + 25.0 / 6.0, + 0.15 * 25.0 / 3.0, + ), ); - - priors.insert(k.to_string(), player); } let mut h1 = History::default(); - h1.add_events(composition, results, times, vec![], priors); + h1.add_events_with_prior(composition, results, times, vec![], priors); assert_ulps_eq!( - h1.batches[0].posterior("a"), - Gaussian::new(22.904409330892914, 6.0103304390431), - epsilon = 0.000001 + h1.batches[0].posterior(a), + Gaussian::new(22.904409, 6.010330), + epsilon = 1e-6 ); assert_ulps_eq!( - h1.batches[0].posterior("c"), - Gaussian::new(25.110318212568806, 5.866311348102563), - epsilon = 0.000001 + h1.batches[0].posterior(c), + Gaussian::new(25.110318, 5.866311), + epsilon = 1e-6 ); h1.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h1.batches[0].posterior("a"), - Gaussian::new(25.00000000, 5.41921200), - epsilon = 0.000001 + h1.batches[0].posterior(a), + Gaussian::new(25.000000, 5.419212), + epsilon = 1e-6 ); assert_ulps_eq!( - h1.batches[0].posterior("c"), - Gaussian::new(25.00000000, 5.41921200), - epsilon = 0.000001 + h1.batches[0].posterior(c), + Gaussian::new(25.000000, 5.419212), + epsilon = 1e-6 ); let composition = vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["b"], vec!["c"]], - vec![vec!["c"], vec!["a"]], + vec![vec![a], vec![b]], + vec![vec![b], vec![c]], + vec![vec![c], vec![a]], ]; let results = vec![vec![1.0, 0.0], vec![1.0, 0.0], vec![1.0, 0.0]]; let times = vec![1, 2, 3]; let mut priors = HashMap::new(); - for k in ["a", "b", "c"] { - let player = Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0); - - priors.insert(k.to_string(), player); + for agent in [a, b, c] { + priors.insert( + agent, + Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0), + ); } - let mut h2 = History::builder().build(); + let mut h2 = History::default(); - h2.add_events(composition, results, times, vec![], priors); + h2.add_events_with_prior(composition, results, times, vec![], priors); assert_ulps_eq!( - h2.batches[2].posterior("a"), - Gaussian::new(22.90352227792141, 6.011017301320632), - epsilon = 0.000001 + h2.batches[2].posterior(a), + Gaussian::new(22.903522, 6.011017), + epsilon = 1e-6 ); assert_ulps_eq!( - h2.batches[2].posterior("c"), - Gaussian::new(25.110702468366718, 5.866811597660157), - epsilon = 0.000001 + h2.batches[2].posterior(c), + Gaussian::new(25.110702, 5.866811), + epsilon = 1e-6 ); h2.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h2.batches[2].posterior("a"), - Gaussian::new(24.99866831022851, 5.420053708148435), - epsilon = 0.000001 + h2.batches[2].posterior(a), + Gaussian::new(24.998668, 5.420053), + epsilon = 1e-6 ); assert_ulps_eq!( - h2.batches[2].posterior("c"), - Gaussian::new(25.000532179593538, 5.419827012784138), - epsilon = 0.000001 + h2.batches[2].posterior(c), + Gaussian::new(25.000532, 5.419827), + epsilon = 1e-6 ); } #[test] fn test_learning_curves() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let composition = vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["b"], vec!["c"]], - vec![vec!["c"], vec!["a"]], + vec![vec![a], vec![b]], + vec![vec![b], vec![c]], + vec![vec![c], vec![a]], ]; let results = vec![vec![1.0, 0.0], vec![1.0, 0.0], vec![1.0, 0.0]]; let times = vec![5, 6, 7]; let mut priors = HashMap::new(); - for k in ["a", "b", "c"] { - let player = Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0); - - priors.insert(k.to_string(), player); + for agent in [a, b, c] { + priors.insert( + agent, + Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0), + ); } let mut h = History::default(); - h.add_events(composition, results, times, vec![], priors); + h.add_events_with_prior(composition, results, times, vec![], priors); h.convergence(ITERATIONS, EPSILON, false); let lc = h.learning_curves(); - let aj_e = lc["a"].len(); - let cj_e = lc["c"].len(); + let aj_e = lc[&a].len(); + let cj_e = lc[&c].len(); - assert_eq!(lc["a"][0].0, 5); - assert_eq!(lc["a"][aj_e - 1].0, 7); + assert_eq!(lc[&a][0].0, 5); + assert_eq!(lc[&a][aj_e - 1].0, 7); assert_ulps_eq!( - lc["a"][aj_e - 1].1, - Gaussian::new(24.99866831022851, 5.420053708148435), - epsilon = 0.000001 + lc[&a][aj_e - 1].1, + Gaussian::new(24.998668, 5.420053), + epsilon = 1e-6 ); assert_ulps_eq!( - lc["c"][cj_e - 1].1, - Gaussian::new(25.000532179593538, 5.419827012784138), - epsilon = 0.000001 + lc[&c][cj_e - 1].1, + Gaussian::new(25.000532, 5.419827), + epsilon = 1e-6 ); } #[test] fn test_env_ttt() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let composition = vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["a"], vec!["c"]], - vec![vec!["b"], vec!["c"]], + vec![vec![a], vec![b]], + vec![vec![a], vec![c]], + vec![vec![b], vec![c]], ]; let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; @@ -624,36 +657,45 @@ mod tests { .time(false) .build(); - h.add_events(composition, results, vec![], vec![], HashMap::new()); + h.add_events(composition, results, vec![], vec![]); h.convergence(ITERATIONS, EPSILON, false); - assert_eq!(h.batches[2].skills["b"].elapsed, 1); - assert_eq!(h.batches[2].skills["c"].elapsed, 1); + assert_eq!(h.batches[2].skills[&b].elapsed, 1); + assert_eq!(h.batches[2].skills[&c].elapsed, 1); assert_ulps_eq!( - h.batches[0].posterior("a"), - Gaussian::new(25.0002673, 5.41938162), - epsilon = 0.000001 + h.batches[0].posterior(a), + Gaussian::new(25.000267, 5.419381), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("b"), - Gaussian::new(24.999465, 5.419425831), - epsilon = 0.000001 + h.batches[0].posterior(b), + Gaussian::new(24.999465, 5.419425), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[2].posterior("b"), - Gaussian::new(25.00053219, 5.419696790), - epsilon = 0.000001 + h.batches[2].posterior(b), + Gaussian::new(25.000532, 5.419696), + epsilon = 1e-6 ); } #[test] fn test_teams() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let d = index_map.get_or_create("d"); + let e = index_map.get_or_create("e"); + let f = index_map.get_or_create("f"); + let composition = vec![ - vec![vec!["a", "b"], vec!["c", "d"]], - vec![vec!["e", "f"], vec!["b", "c"]], - vec![vec!["a", "d"], vec!["e", "f"]], + vec![vec![a, b], vec![c, d]], + vec![vec![e, f], vec![b, c]], + vec![vec![a, d], vec![e, f]], ]; let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; @@ -665,76 +707,82 @@ mod tests { .time(false) .build(); - h.add_events(composition, results, vec![], vec![], HashMap::new()); + h.add_events(composition, results, vec![], vec![]); - let trueskill_log_evidence = h.log_evidence(false, &vec![]); - let trueskill_log_evidence_online = h.log_evidence(true, &vec![]); + let trueskill_log_evidence = h.log_evidence(false, &[]); + let trueskill_log_evidence_online = h.log_evidence(true, &[]); assert_ulps_eq!( trueskill_log_evidence, trueskill_log_evidence_online, - epsilon = 0.000001 + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("b").mu, - -1.0 * h.batches[0].posterior("c").mu, - epsilon = 0.000001 + h.batches[0].posterior(b).mu, + -1.0 * h.batches[0].posterior(c).mu, + epsilon = 1e-6 ); - let evidence_second_event = h.log_evidence(false, &vec!["b"]).exp() * 2.0; - assert_ulps_eq!(0.5, evidence_second_event, epsilon = 0.000001); + let evidence_second_event = h.log_evidence(false, &[b]).exp() * 2.0; + assert_ulps_eq!(0.5, evidence_second_event, epsilon = 1e-6); - let evidence_third_event = h.log_evidence(false, &vec!["a"]).exp() * 2.0; - assert_ulps_eq!(0.669885, evidence_third_event, epsilon = 0.000001); + let evidence_third_event = h.log_evidence(false, &[a]).exp() * 2.0; + assert_ulps_eq!(0.669885, evidence_third_event, epsilon = 1e-6); h.convergence(ITERATIONS, EPSILON, false); - let loocv_hat = h.log_evidence(false, &vec![]).exp(); - let p_d_m_hat = h.log_evidence(true, &vec![]).exp(); + let loocv_hat = h.log_evidence(false, &[]).exp(); + let p_d_m_hat = h.log_evidence(true, &[]).exp(); - assert_ulps_eq!(loocv_hat, 0.2410274245857821, epsilon = 0.000001); - assert_ulps_eq!(p_d_m_hat, 0.17243238958411006, epsilon = 0.000001); + assert_ulps_eq!(loocv_hat, 0.241027, epsilon = 1e-6); + assert_ulps_eq!(p_d_m_hat, 0.172432, epsilon = 1e-6); assert_ulps_eq!( - h.batches[0].posterior("a"), - h.batches[0].posterior("b"), - epsilon = 0.000001 + h.batches[0].posterior(a), + h.batches[0].posterior(b), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("c"), - h.batches[0].posterior("d"), - epsilon = 0.000001 + h.batches[0].posterior(c), + h.batches[0].posterior(d), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[1].posterior("e"), - h.batches[1].posterior("f"), - epsilon = 0.000001 + h.batches[1].posterior(e), + h.batches[1].posterior(f), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("a"), - Gaussian::new(4.084902364982456, 5.10691909049607), - epsilon = 0.000001 + h.batches[0].posterior(a), + Gaussian::new(4.084902, 5.106919), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("c"), - Gaussian::new(-0.5330294544847751, 5.10691909049607), - epsilon = 0.000001 + h.batches[0].posterior(c), + Gaussian::new(-0.533029, 5.106919), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[2].posterior("e"), - Gaussian::new(-3.551872900373382, 5.154569731627773), - epsilon = 0.000001 + h.batches[2].posterior(e), + Gaussian::new(-3.551872, 5.154569), + epsilon = 1e-6 ); } #[test] fn test_add_events() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let composition = vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["a"], vec!["c"]], - vec![vec!["b"], vec!["c"]], + vec![vec![a], vec![b]], + vec![vec![a], vec![c]], + vec![vec![b], vec![c]], ]; let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; @@ -746,36 +794,30 @@ mod tests { .time(false) .build(); - h.add_events( - composition.clone(), - results.clone(), - vec![], - vec![], - HashMap::new(), - ); + h.add_events(composition.clone(), results.clone(), vec![], vec![]); h.convergence(ITERATIONS, EPSILON, false); - assert_eq!(h.batches[2].skills["b"].elapsed, 1); - assert_eq!(h.batches[2].skills["c"].elapsed, 1); + assert_eq!(h.batches[2].skills[&b].elapsed, 1); + assert_eq!(h.batches[2].skills[&c].elapsed, 1); assert_ulps_eq!( - h.batches[0].posterior("a"), - Gaussian::new(0.0, 1.30061), - epsilon = 0.000001 + h.batches[0].posterior(a), + Gaussian::new(0.000000, 1.300610), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("b"), - Gaussian::new(0.0, 1.30061), - epsilon = 0.000001 + h.batches[0].posterior(b), + Gaussian::new(0.000000, 1.300610), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[2].posterior("b"), - Gaussian::new(0.0, 1.30061), - epsilon = 0.000001 + h.batches[2].posterior(b), + Gaussian::new(0.000000, 1.300610), + epsilon = 1e-6 ); - h.add_events(composition, results, vec![], vec![], HashMap::new()); + h.add_events(composition, results, vec![], vec![]); assert_eq!(h.batches.len(), 6); @@ -785,45 +827,51 @@ mod tests { .map(|b| b.get_composition()) .collect::>(), vec![ - vec![vec![vec!["a"], vec!["b"]]], - vec![vec![vec!["a"], vec!["c"]]], - vec![vec![vec!["b"], vec!["c"]]], - vec![vec![vec!["a"], vec!["b"]]], - vec![vec![vec!["a"], vec!["c"]]], - vec![vec![vec!["b"], vec!["c"]]] + vec![vec![vec![a], vec![b]]], + vec![vec![vec![a], vec![c]]], + vec![vec![vec![b], vec![c]]], + vec![vec![vec![a], vec![b]]], + vec![vec![vec![a], vec![c]]], + vec![vec![vec![b], vec![c]]] ] ); h.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h.batches[0].posterior("a"), - Gaussian::new(0.0, 0.9312360609998878), - epsilon = 0.000001 + h.batches[0].posterior(a), + Gaussian::new(0.000000, 0.931236), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[3].posterior("a"), - Gaussian::new(0.0, 0.9312360609998878), - epsilon = 0.000001 + h.batches[3].posterior(a), + Gaussian::new(0.000000, 0.931236), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[3].posterior("b"), - Gaussian::new(0.0, 0.9312360609998878), - epsilon = 0.000001 + h.batches[3].posterior(b), + Gaussian::new(0.000000, 0.931236), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[5].posterior("b"), - Gaussian::new(0.0, 0.9312360609998878), - epsilon = 0.000001 + h.batches[5].posterior(b), + Gaussian::new(0.000000, 0.931236), + epsilon = 1e-6 ); } #[test] fn test_only_add_events() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let composition = vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["a"], vec!["c"]], - vec![vec!["b"], vec!["c"]], + vec![vec![a], vec![b]], + vec![vec![a], vec![c]], + vec![vec![b], vec![c]], ]; let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; @@ -835,36 +883,30 @@ mod tests { .time(false) .build(); - h.add_events( - composition.clone(), - results.clone(), - vec![], - vec![], - HashMap::new(), - ); + h.add_events(composition.clone(), results.clone(), vec![], vec![]); h.convergence(ITERATIONS, EPSILON, false); - assert_eq!(h.batches[2].skills["b"].elapsed, 1); - assert_eq!(h.batches[2].skills["c"].elapsed, 1); + assert_eq!(h.batches[2].skills[&b].elapsed, 1); + assert_eq!(h.batches[2].skills[&c].elapsed, 1); assert_ulps_eq!( - h.batches[0].posterior("a"), - Gaussian::new(0.0, 1.30061), - epsilon = 0.000001 + h.batches[0].posterior(a), + Gaussian::new(0.000000, 1.300610), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("b"), - Gaussian::new(0.0, 1.30061), - epsilon = 0.000001 + h.batches[0].posterior(b), + Gaussian::new(0.000000, 1.300610), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[2].posterior("b"), - Gaussian::new(0.0, 1.30061), - epsilon = 0.000001 + h.batches[2].posterior(b), + Gaussian::new(0.000000, 1.300610), + epsilon = 1e-6 ); - h.add_events(composition, results, vec![], vec![], HashMap::new()); + h.add_events(composition, results, vec![], vec![]); assert_eq!(h.batches.len(), 6); @@ -874,99 +916,110 @@ mod tests { .map(|b| b.get_composition()) .collect::>(), vec![ - vec![vec![vec!["a"], vec!["b"]]], - vec![vec![vec!["a"], vec!["c"]]], - vec![vec![vec!["b"], vec!["c"]]], - vec![vec![vec!["a"], vec!["b"]]], - vec![vec![vec!["a"], vec!["c"]]], - vec![vec![vec!["b"], vec!["c"]]] + vec![vec![vec![a], vec![b]]], + vec![vec![vec![a], vec![c]]], + vec![vec![vec![b], vec![c]]], + vec![vec![vec![a], vec![b]]], + vec![vec![vec![a], vec![c]]], + vec![vec![vec![b], vec![c]]] ] ); h.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h.batches[0].posterior("a"), - Gaussian::new(0.0, 0.9312360609998878), - epsilon = 0.000001 + h.batches[0].posterior(a), + Gaussian::new(0.000000, 0.931236), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[3].posterior("a"), - Gaussian::new(0.0, 0.9312360609998878), - epsilon = 0.000001 + h.batches[3].posterior(a), + Gaussian::new(0.000000, 0.931236), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[3].posterior("b"), - Gaussian::new(0.0, 0.9312360609998878), - epsilon = 0.000001 + h.batches[3].posterior(b), + Gaussian::new(0.000000, 0.931236), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[5].posterior("b"), - Gaussian::new(0.0, 0.9312360609998878), - epsilon = 0.000001 + h.batches[5].posterior(b), + Gaussian::new(0.000000, 0.931236), + epsilon = 1e-6 ); } #[test] fn test_log_evidence() { - let composition = vec![vec![vec!["a"], vec!["b"]], vec![vec!["b"], vec!["a"]]]; + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + + let composition = vec![vec![vec![a], vec![b]], vec![vec![b], vec![a]]]; let mut h = History::builder().time(false).build(); - h.add_events(composition.clone(), vec![], vec![], vec![], HashMap::new()); + h.add_events(composition.clone(), vec![], vec![], vec![]); - let p_d_m_2 = h.log_evidence(false, &vec![]).exp() * 2.0; + let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0; - assert_ulps_eq!(p_d_m_2, 0.17650911, epsilon = 0.000001); + assert_ulps_eq!(p_d_m_2, 0.17650911, epsilon = 1e-6); assert_ulps_eq!( p_d_m_2, - h.log_evidence(true, &vec![]).exp() * 2.0, - epsilon = 0.000001 + h.log_evidence(true, &[]).exp() * 2.0, + epsilon = 1e-6 ); assert_ulps_eq!( p_d_m_2, - h.log_evidence(true, &vec!["a"]).exp() * 2.0, - epsilon = 0.000001 + h.log_evidence(true, &[a]).exp() * 2.0, + epsilon = 1e-6 ); assert_ulps_eq!( p_d_m_2, - h.log_evidence(false, &vec!["a"]).exp() * 2.0, - epsilon = 0.000001 + h.log_evidence(false, &[a]).exp() * 2.0, + epsilon = 1e-6 ); h.convergence(11, EPSILON, false); - let loocv_approx_2 = h.log_evidence(false, &vec![]).exp().sqrt(); + let loocv_approx_2 = h.log_evidence(false, &[]).exp().sqrt(); assert_ulps_eq!(loocv_approx_2, 0.001976774, epsilon = 0.000001); - let p_d_m_approx_2 = h.log_evidence(true, &vec![]).exp() * 2.0; + let p_d_m_approx_2 = h.log_evidence(true, &[]).exp() * 2.0; assert!(loocv_approx_2 - p_d_m_approx_2 < 1e-4); assert_ulps_eq!( loocv_approx_2, - h.log_evidence(true, &vec!["b"]).exp() * 2.0, - epsilon = 0.00001 + h.log_evidence(true, &[b]).exp() * 2.0, + epsilon = 1e-4 ); let mut h = History::builder().time(false).build(); - h.add_events(composition, vec![], vec![], vec![], HashMap::new()); + h.add_events(composition, vec![], vec![], vec![]); assert_ulps_eq!( ((0.5f64 * 0.1765).ln() / 2.0).exp(), - (h.log_evidence(false, &vec![]) / 2.0).exp(), + (h.log_evidence(false, &[]) / 2.0).exp(), epsilon = 1e-4 ); } #[test] fn test_add_events_with_time() { + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let composition = vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["a"], vec!["c"]], - vec![vec!["b"], vec!["c"]], + vec![vec![a], vec![b]], + vec![vec![a], vec![c]], + vec![vec![b], vec![c]], ]; let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; @@ -982,18 +1035,11 @@ mod tests { results.clone(), vec![0, 10, 20], vec![], - HashMap::new(), ); h.convergence(ITERATIONS, EPSILON, false); - h.add_events( - composition, - results, - vec![15, 10, 0], - vec![], - HashMap::new(), - ); + h.add_events(composition, results, vec![15, 10, 0], vec![]); assert_eq!(h.batches.len(), 4); @@ -1011,10 +1057,10 @@ mod tests { .map(|b| b.get_composition()) .collect::>(), vec![ - vec![vec![vec!["a"], vec!["b"]], vec![vec!["b"], vec!["c"]]], - vec![vec![vec!["a"], vec!["c"]], vec![vec!["a"], vec!["c"]]], - vec![vec![vec!["a"], vec!["b"]]], - vec![vec![vec!["b"], vec!["c"]]] + vec![vec![vec![a], vec![b]], vec![vec![b], vec![c]]], + vec![vec![vec![a], vec![c]], vec![vec![a], vec![c]]], + vec![vec![vec![a], vec![b]]], + vec![vec![vec![b], vec![c]]] ] ); @@ -1033,41 +1079,41 @@ mod tests { let end = h.batches.len() - 1; - assert_eq!(h.batches[0].skills["c"].elapsed, 0); - assert_eq!(h.batches[end].skills["c"].elapsed, 10); + assert_eq!(h.batches[0].skills[&c].elapsed, 0); + assert_eq!(h.batches[end].skills[&c].elapsed, 10); - assert_eq!(h.batches[0].skills["a"].elapsed, 0); - assert_eq!(h.batches[2].skills["a"].elapsed, 5); + assert_eq!(h.batches[0].skills[&a].elapsed, 0); + assert_eq!(h.batches[2].skills[&a].elapsed, 5); - assert_eq!(h.batches[0].skills["b"].elapsed, 0); - assert_eq!(h.batches[end].skills["b"].elapsed, 5); + assert_eq!(h.batches[0].skills[&b].elapsed, 0); + assert_eq!(h.batches[end].skills[&b].elapsed, 5); h.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h.batches[0].posterior("b"), - h.batches[end].posterior("b"), - epsilon = 0.000001 + h.batches[0].posterior(b), + h.batches[end].posterior(b), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("c"), - h.batches[end].posterior("c"), - epsilon = 0.000001 + h.batches[0].posterior(c), + h.batches[end].posterior(c), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("c"), - h.batches[0].posterior("b"), - epsilon = 0.000001 + h.batches[0].posterior(c), + h.batches[0].posterior(b), + epsilon = 1e-6 ); // --------------------------------------- let composition = vec![ - vec![vec!["a"], vec!["b"]], - vec![vec!["c"], vec!["a"]], - vec![vec!["b"], vec!["c"]], + vec![vec![a], vec![b]], + vec![vec![c], vec![a]], + vec![vec![b], vec![c]], ]; let mut h = History::builder() @@ -1077,17 +1123,11 @@ mod tests { .gamma(0.0) .build(); - h.add_events( - composition.clone(), - vec![], - vec![0, 10, 20], - vec![], - HashMap::new(), - ); + h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![]); h.convergence(ITERATIONS, EPSILON, false); - h.add_events(composition, vec![], vec![15, 10, 0], vec![], HashMap::new()); + h.add_events(composition, vec![], vec![15, 10, 0], vec![]); assert_eq!(h.batches.len(), 4); @@ -1105,10 +1145,10 @@ mod tests { .map(|b| b.get_composition()) .collect::>(), vec![ - vec![vec![vec!["a"], vec!["b"]], vec![vec!["b"], vec!["c"]]], - vec![vec![vec!["c"], vec!["a"]], vec![vec!["c"], vec!["a"]]], - vec![vec![vec!["a"], vec!["b"]]], - vec![vec![vec!["b"], vec!["c"]]] + vec![vec![vec![a], vec![b]], vec![vec![b], vec![c]]], + vec![vec![vec![c], vec![a]], vec![vec![c], vec![a]]], + vec![vec![vec![a], vec![b]]], + vec![vec![vec![b], vec![c]]] ] ); @@ -1127,39 +1167,44 @@ mod tests { let end = h.batches.len() - 1; - assert_eq!(h.batches[0].skills["c"].elapsed, 0); - assert_eq!(h.batches[end].skills["c"].elapsed, 10); + assert_eq!(h.batches[0].skills[&c].elapsed, 0); + assert_eq!(h.batches[end].skills[&c].elapsed, 10); - assert_eq!(h.batches[0].skills["a"].elapsed, 0); - assert_eq!(h.batches[2].skills["a"].elapsed, 5); + assert_eq!(h.batches[0].skills[&a].elapsed, 0); + assert_eq!(h.batches[2].skills[&a].elapsed, 5); - assert_eq!(h.batches[0].skills["b"].elapsed, 0); - assert_eq!(h.batches[end].skills["b"].elapsed, 5); + assert_eq!(h.batches[0].skills[&b].elapsed, 0); + assert_eq!(h.batches[end].skills[&b].elapsed, 5); h.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h.batches[0].posterior("b"), - h.batches[end].posterior("b"), - epsilon = 0.000001 + h.batches[0].posterior(b), + h.batches[end].posterior(b), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("c"), - h.batches[end].posterior("c"), - epsilon = 0.000001 + h.batches[0].posterior(c), + h.batches[end].posterior(c), + epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].posterior("c"), - h.batches[0].posterior("b"), - epsilon = 0.000001 + h.batches[0].posterior(c), + h.batches[0].posterior(b), + epsilon = 1e-6 ); } #[test] fn test_1vs1_weighted() { - let composition = vec![vec![vec!["a"], vec!["b"]], vec![vec!["b"], vec!["a"]]]; + let mut index_map = IndexMap::new(); + + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + + let composition = vec![vec![vec![a], vec![b]], vec![vec![b], vec![a]]]; let weights = vec![vec![vec![5.0], vec![4.0]], vec![vec![5.0], vec![4.0]]]; let mut h = History::builder() @@ -1170,38 +1215,38 @@ mod tests { .time(false) .build(); - h.add_events(composition.clone(), vec![], vec![], weights, HashMap::new()); + h.add_events(composition, vec![], vec![], weights); let lc = h.learning_curves(); assert_ulps_eq!( - lc["a"][0].1, - Gaussian::new(5.53765944, 4.758722), - epsilon = 0.000001 + lc[&a][0].1, + Gaussian::new(5.537659, 4.758722), + epsilon = 1e-6 ); assert_ulps_eq!( - lc["b"][0].1, - Gaussian::new(-0.83012755, 5.2395689), - epsilon = 0.000001 + lc[&b][0].1, + Gaussian::new(-0.830127, 5.239568), + epsilon = 1e-6 ); assert_ulps_eq!( - lc["a"][1].1, - Gaussian::new(1.7922776, 4.099566689), - epsilon = 0.000001 + lc[&a][1].1, + Gaussian::new(1.792277, 4.099566), + epsilon = 1e-6 ); assert_ulps_eq!( - lc["b"][1].1, - Gaussian::new(4.8455331752, 3.7476161), - epsilon = 0.000001 + lc[&b][1].1, + Gaussian::new(4.845533, 3.747616), + epsilon = 1e-6 ); h.convergence(ITERATIONS, EPSILON, false); let lc = h.learning_curves(); - assert_ulps_eq!(lc["a"][0].1, lc["a"][0].1, epsilon = 0.000001); - assert_ulps_eq!(lc["b"][0].1, lc["a"][0].1, epsilon = 0.000001); - assert_ulps_eq!(lc["a"][1].1, lc["a"][0].1, epsilon = 0.000001); - assert_ulps_eq!(lc["b"][1].1, lc["a"][0].1, epsilon = 0.000001); + assert_ulps_eq!(lc[&a][0].1, lc[&a][0].1, epsilon = 1e-6); + assert_ulps_eq!(lc[&b][0].1, lc[&a][0].1, epsilon = 1e-6); + assert_ulps_eq!(lc[&a][1].1, lc[&a][0].1, epsilon = 1e-6); + assert_ulps_eq!(lc[&b][1].1, lc[&a][0].1, epsilon = 1e-6); } } diff --git a/src/lib.rs b/src/lib.rs index 70ffa3c..1cab459 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,8 @@ +use std::borrow::{Borrow, ToOwned}; use std::cmp::Reverse; +use std::collections::HashMap; use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2}; +use std::hash::Hash; mod agent; #[cfg(feature = "approx")] @@ -27,19 +30,55 @@ pub const ITERATIONS: usize = 30; const SQRT_TAU: f64 = 2.5066282746310002; -const N01: Gaussian = Gaussian { +pub const N01: Gaussian = Gaussian { mu: 0.0, sigma: 1.0, }; -pub(crate) const N00: Gaussian = Gaussian { +pub const N00: Gaussian = Gaussian { mu: 0.0, sigma: 0.0, }; -pub(crate) const N_INF: Gaussian = Gaussian { +pub const N_INF: Gaussian = Gaussian { mu: 0.0, sigma: f64::INFINITY, }; +#[derive(Copy, Clone, Default, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)] +pub struct Index(usize); + +impl From for Index { + fn from(ix: usize) -> Self { + Self(ix) + } +} + +pub struct IndexMap(HashMap); + +impl IndexMap +where + K: Eq + Hash, +{ + pub fn new() -> Self { + Self(HashMap::new()) + } + + pub fn get_or_create(&mut self, k: &Q) -> Index + where + K: Borrow, + Q: Hash + Eq + ToOwned, + { + if let Some(idx) = self.0.get(k) { + *idx + } else { + let idx = Index::from(self.0.len()); + + self.0.insert(k.to_owned(), idx); + + idx + } + } +} + fn erfc(x: f64) -> f64 { let z = x.abs(); let t = 1.0 / (1.0 + z / 2.0);