From 0efb0312dac8504e01e78550e48296c6122cc485 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Sun, 19 Jun 2022 23:52:52 +0200 Subject: [PATCH] More tests, more code --- src/batch.rs | 82 ++++++-- src/game.rs | 8 +- src/history.rs | 503 +++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 472 insertions(+), 121 deletions(-) diff --git a/src/batch.rs b/src/batch.rs index bbe9731..b89a943 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -37,7 +37,7 @@ struct Team { } #[derive(Debug)] -struct Event { +pub(crate) struct Event { teams: Vec, evidence: f64, weights: Vec>, @@ -53,7 +53,7 @@ impl Event { } pub(crate) struct Batch { - events: Vec, + pub(crate) events: Vec, pub(crate) skills: HashMap, pub(crate) time: u64, p_draw: f64, @@ -414,9 +414,40 @@ impl Batch { } } } + + pub(crate) fn get_composition(&self) -> Vec>> { + self.events + .iter() + .map(|event| { + event + .teams + .iter() + .map(|team| { + team.items + .iter() + .map(|item| item.agent.as_str()) + .collect::>() + }) + .collect::>() + }) + .collect::>() + } + + pub(crate) fn get_results(&self) -> Vec> { + self.events + .iter() + .map(|event| { + event + .teams + .iter() + .map(|team| team.output) + .collect::>() + }) + .collect::>() + } } -fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 { +pub(crate) fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 { if last_time == u64::MIN { 0 } else if last_time == u64::MAX { @@ -425,6 +456,7 @@ fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 { actual_time - last_time } } + #[cfg(test)] mod tests { use approx::assert_ulps_eq; @@ -566,14 +598,23 @@ mod tests { let post = batch.posteriors(); - assert_ulps_eq!(post["a"].mu, 25.000000, epsilon = 0.000001); - assert_ulps_eq!(post["a"].sigma, 5.4192120, epsilon = 0.000001); + assert_ulps_eq!( + post["a"], + Gaussian::new(25.000000, 5.4192120), + epsilon = 0.000001 + ); - assert_ulps_eq!(post["b"].mu, 25.000000, epsilon = 0.000001); - assert_ulps_eq!(post["b"].sigma, 5.4192120, epsilon = 0.000001); + assert_ulps_eq!( + post["b"], + Gaussian::new(25.000000, 5.4192120), + epsilon = 0.000001 + ); - assert_ulps_eq!(post["c"].mu, 25.000000, epsilon = 0.000001); - assert_ulps_eq!(post["c"].sigma, 5.4192120, epsilon = 0.000001); + assert_ulps_eq!( + post["c"], + Gaussian::new(25.000000, 5.4192120), + epsilon = 0.000001 + ); batch.add_events( vec![ @@ -592,13 +633,20 @@ mod tests { let post = batch.posteriors(); - assert_ulps_eq!(post["a"].mu, 25.00000315330858, epsilon = 0.000001); - assert_ulps_eq!(post["a"].sigma, 3.880150268080797, epsilon = 0.000001); - - assert_ulps_eq!(post["b"].mu, 25.00000315330858, epsilon = 0.000001); - assert_ulps_eq!(post["b"].sigma, 3.880150268080797, epsilon = 0.000001); - - assert_ulps_eq!(post["c"].mu, 25.00000315330858, epsilon = 0.000001); - assert_ulps_eq!(post["c"].sigma, 3.880150268080797, epsilon = 0.000001); + assert_ulps_eq!( + post["a"], + Gaussian::new(25.00000315330858, 3.880150268080797), + epsilon = 0.000001 + ); + assert_ulps_eq!( + post["b"], + Gaussian::new(25.00000315330858, 3.880150268080797), + epsilon = 0.000001 + ); + assert_ulps_eq!( + post["c"], + Gaussian::new(25.00000315330858, 3.880150268080797), + epsilon = 0.000001 + ); } } diff --git a/src/game.rs b/src/game.rs index 60510cc..a5dc8dc 100644 --- a/src/game.rs +++ b/src/game.rs @@ -24,12 +24,12 @@ impl Game { ) -> Self { assert!( (result.is_empty() || result.len() == teams.len()), - "result.must be empty or the same length as teams" + "result must be empty or the same length as teams" ); assert!( (weights.is_empty() || weights.len() == teams.len()), - "weights.must be empty or the same length as teams" + "weights must be empty or the same length as teams" ); assert!( @@ -38,7 +38,7 @@ impl Game { .iter() .zip(teams.iter()) .all(|(w, t)| w.len() == t.len()), - "weights.must be empty or has the same dimensions as teams" + "weights must be empty or has the same dimensions as teams" ); assert!( @@ -52,7 +52,7 @@ impl Game { r.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); r.windows(2).all(|w| w[0] != w[1]) }, - "draw.must be > 0.0 if there is teams with draw" + "draw must be > 0.0 if there is teams with draw" ); if result.is_empty() { diff --git a/src/history.rs b/src/history.rs index eb8e929..485ed70 100644 --- a/src/history.rs +++ b/src/history.rs @@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet}; use crate::{ agent::{self, Agent}, - batch::Batch, + batch::{self, Batch}, gaussian::Gaussian, player::Player, sort_time, tuple_gt, tuple_max, @@ -40,15 +40,15 @@ impl History { ) -> Self { assert!( results.is_empty() || results.len() == composition.len(), - "TODO: Add a comment here" + "(length(results) > 0) & (length(composition) != length(results))" ); assert!( times.is_empty() || times.len() == composition.len(), - "TODO: Add a comment here" + "length(times) > 0) & (length(composition) != length(times))" ); assert!( weights.is_empty() || weights.len() == composition.len(), - "TODO: Add a comment here" + "(length(weights) > 0) & (length(composition) != length(weights))" ); let this_agent = composition @@ -293,6 +293,183 @@ impl History { .map(|batch| batch.log_evidence2(self.online, agents, forward, &mut self.agents)) .sum() } + + pub fn add_events( + &mut self, + composition: Vec>>, + results: Vec>, + times: Vec, + weights: Vec>>, + priors: HashMap, + ) { + assert!(times.is_empty() || self.time, "length(times)>0 but !h.time"); + assert!( + !times.is_empty() || !self.time, + "length(times)==0 but h.time" + ); + assert!( + results.is_empty() || results.len() == composition.len(), + "(length(results) > 0) & (length(composition) != length(results))" + ); + assert!( + times.is_empty() || times.len() == composition.len(), + "length(times) > 0) & (length(composition) != length(times))" + ); + assert!( + weights.is_empty() || weights.len() == composition.len(), + "(length(weights) > 0) & (length(composition) != length(weights))" + ); + + let this_agent = composition + .iter() + .flatten() + .flatten() + .cloned() + .collect::>(); + + for agent in &this_agent { + if !self.agents.contains_key(*agent) { + self.agents.insert( + agent.to_string(), + Agent { + player: priors.get(*agent).cloned().unwrap_or_else(|| { + Player::new(Gaussian::new(self.mu, self.sigma), self.beta, self.gamma) + }), + ..Default::default() + }, + ); + } + } + + agent::clean(self.agents.values_mut(), true); + + let n = composition.len(); + let o = if self.time { + sort_time(×, false) + } else { + (0..composition.len()).collect::>() + }; + + let mut i = 0; + let mut k = 0; + + while i < n { + let mut j = i + 1; + let t = if self.time { times[o[i]] } else { i as u64 + 1 }; + + while self.time && j < self.size && times[o[j]] == t { + j += 1; + } + + while (!self.time && (self.size > k)) + || (self.time && self.batches.len() > k && self.batches[k].time < t) + { + let b = &mut self.batches[k]; + + if k > 0 { + b.new_forward_info(&mut self.agents); + } + + let intersect = this_agent + .iter() + .filter(|&&agent| b.skills.contains_key(agent)) + .cloned() + .collect::>(); + + for agent in intersect { + b.skills.get_mut(agent).unwrap().elapsed = + batch::compute_elapsed(self.agents[agent].last_time, b.time); + + let a = self.agents.get_mut(agent).unwrap(); + + a.last_time = if self.time { b.time } else { u64::MAX }; + a.message = b.forward_prior_out(agent); + } + + k += 1; + } + + let composition = (i..j) + .map(|e| composition[o[e]].clone()) + .collect::>(); + let results = if results.is_empty() { + Vec::new() + } else { + (i..j).map(|e| results[o[e]].clone()).collect::>() + }; + + let weights = if weights.is_empty() { + Vec::new() + } else { + (i..j).map(|e| weights[o[e]].clone()).collect::>() + }; + + if self.time && self.batches.len() > k && self.batches[k].time == t { + let b = &mut self.batches[k]; + + b.add_events(composition, results, weights, &mut self.agents); + + for a in b.skills.keys() { + let agent = self.agents.get_mut(a).unwrap(); + + agent.last_time = if self.time { t } else { u64::MAX }; + agent.message = b.forward_prior_out(a); + } + } else { + let b = Batch::new( + composition, + results, + weights, + t, + self.p_draw, + &mut self.agents, + ); + + self.batches.insert(k, b); + + let b = &mut self.batches[k]; + + for a in b.skills.keys() { + let agent = self.agents.get_mut(a).unwrap(); + + agent.last_time = if self.time { t } else { u64::MAX }; + agent.message = b.forward_prior_out(a); + } + + k += 1; + } + + i = j; + } + + while self.time && self.batches.len() > k { + let b = &mut self.batches[k]; + + b.new_backward_info(&mut self.agents); + + let intersect = this_agent + .iter() + .filter(|&&agent| b.skills.contains_key(agent)) + .cloned() + .collect::>(); + + for agent in intersect { + b.skills.get_mut(agent).unwrap().elapsed = + batch::compute_elapsed(self.agents[agent].last_time, b.time); + + let a = self.agents.get_mut(agent).unwrap(); + + a.last_time = if self.time { b.time } else { u64::MAX }; + a.message = b.forward_prior_out(agent); + } + + k += 1; + } + + self.size += n; + + self.iteration(); + } } #[cfg(test)] @@ -341,8 +518,11 @@ mod tests { let p0 = h.batches[0].posteriors(); - assert_ulps_eq!(p0["a"].mu, 29.205220743876975, epsilon = 0.000001); - assert_ulps_eq!(p0["a"].sigma, 7.194481422570443, epsilon = 0.000001); + assert_ulps_eq!( + p0["a"], + Gaussian::new(29.205220743876975, 7.194481422570443), + epsilon = 0.000001 + ); let observed = h.batches[1].skills["a"].forward.sigma; let gamma: f64 = 0.15 * 25.0 / 3.0; @@ -351,6 +531,7 @@ mod tests { assert_ulps_eq!(observed, expected, epsilon = 0.000001); let observed = h.batches[1].posterior("a"); + let p = Game::new( h.batches[1].within_priors(0, false, false, &mut h.agents), vec![0.0, 1.0], @@ -360,8 +541,7 @@ mod tests { .posteriors(); let expected = p[0][0]; - assert_ulps_eq!(observed.mu, expected.mu, epsilon = 0.000001); - assert_ulps_eq!(observed.sigma, expected.sigma, epsilon = 0.000001); + assert_ulps_eq!(observed, expected, epsilon = 0.000001); } #[test] @@ -401,46 +581,26 @@ mod tests { ); assert_ulps_eq!( - h1.batches[0].posterior("a").mu, - 22.904409330892914, + h1.batches[0].posterior("a"), + Gaussian::new(22.904409330892914, 6.0103304390431), epsilon = 0.000001 ); assert_ulps_eq!( - h1.batches[0].posterior("a").sigma, - 6.0103304390431, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h1.batches[0].posterior("c").mu, - 25.110318212568806, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h1.batches[0].posterior("c").sigma, - 5.866311348102563, + h1.batches[0].posterior("c"), + Gaussian::new(25.110318212568806, 5.866311348102563), epsilon = 0.000001 ); let (_step, _i) = h1.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h1.batches[0].posterior("a").mu, - 25.00000000, + h1.batches[0].posterior("a"), + Gaussian::new(25.00000000, 5.41921200), epsilon = 0.000001 ); assert_ulps_eq!( - h1.batches[0].posterior("a").sigma, - 5.41921200, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h1.batches[0].posterior("c").mu, - 25.00000000, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h1.batches[0].posterior("c").sigma, - 5.41921200, + h1.batches[0].posterior("c"), + Gaussian::new(25.00000000, 5.41921200), epsilon = 0.000001 ); @@ -475,46 +635,26 @@ mod tests { ); assert_ulps_eq!( - h2.batches[2].posterior("a").mu, - 22.90352227792141, + h2.batches[2].posterior("a"), + Gaussian::new(22.90352227792141, 6.011017301320632), epsilon = 0.000001 ); assert_ulps_eq!( - h2.batches[2].posterior("a").sigma, - 6.011017301320632, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h2.batches[2].posterior("c").mu, - 25.110702468366718, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h2.batches[2].posterior("c").sigma, - 5.866811597660157, + h2.batches[2].posterior("c"), + Gaussian::new(25.110702468366718, 5.866811597660157), epsilon = 0.000001 ); - let (_step, _i) = h2.convergence(ITERATIONS, EPSILON, false); + h2.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h2.batches[2].posterior("a").mu, - 24.99866831022851, + h2.batches[2].posterior("a"), + Gaussian::new(24.99866831022851, 5.420053708148435), epsilon = 0.000001 ); assert_ulps_eq!( - h2.batches[2].posterior("a").sigma, - 5.420053708148435, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h2.batches[2].posterior("c").mu, - 25.000532179593538, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h2.batches[2].posterior("c").sigma, - 5.419827012784138, + h2.batches[2].posterior("c"), + Gaussian::new(25.000532179593538, 5.419827012784138), epsilon = 0.000001 ); } @@ -562,23 +702,13 @@ mod tests { assert_eq!(lc["a"][aj_e - 1].0, 7); assert_ulps_eq!( - lc["a"][aj_e - 1].1.mu, - 24.99866831022851, + lc["a"][aj_e - 1].1, + Gaussian::new(24.99866831022851, 5.420053708148435), epsilon = 0.000001 ); assert_ulps_eq!( - lc["a"][aj_e - 1].1.sigma, - 5.420053708148435, - epsilon = 0.000001 - ); - assert_ulps_eq!( - lc["c"][cj_e - 1].1.mu, - 25.000532179593538, - epsilon = 0.000001 - ); - assert_ulps_eq!( - lc["c"][cj_e - 1].1.sigma, - 5.419827012784138, + lc["c"][cj_e - 1].1, + Gaussian::new(25.000532179593538, 5.419827012784138), epsilon = 0.000001 ); } @@ -612,33 +742,18 @@ mod tests { assert_eq!(h.batches[2].skills["c"].elapsed, 1); assert_ulps_eq!( - h.batches[0].posterior("a").mu, - 25.0002673, + h.batches[0].posterior("a"), + Gaussian::new(25.0002673, 5.41938162), epsilon = 0.000001 ); assert_ulps_eq!( - h.batches[0].posterior("a").sigma, - 5.41938162, + h.batches[0].posterior("b"), + Gaussian::new(24.999465, 5.419425831), epsilon = 0.000001 ); assert_ulps_eq!( - h.batches[0].posterior("b").mu, - 24.999465, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h.batches[0].posterior("b").sigma, - 5.419425831, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h.batches[2].posterior("b").mu, - 25.00053219, - epsilon = 0.000001 - ); - assert_ulps_eq!( - h.batches[2].posterior("b").sigma, - 5.419696790, + h.batches[2].posterior("b"), + Gaussian::new(25.00053219, 5.419696790), epsilon = 0.000001 ); } @@ -727,4 +842,192 @@ mod tests { epsilon = 0.000001 ); } + + #[test] + fn test_add_events() { + let composition = vec![ + vec![vec!["a"], vec!["b"]], + vec![vec!["a"], vec!["c"]], + vec![vec!["b"], vec!["c"]], + ]; + let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; + + let mut h = History::new( + composition.clone(), + results.clone(), + vec![], + vec![], + HashMap::new(), + 0.0, + 2.0, + 1.0, + 0.0, + 0.0, + false, + ); + + let (_step, _i) = h.convergence(ITERATIONS, EPSILON, false); + + assert_eq!(h.batches[2].skills["b"].elapsed, 1); + assert_eq!(h.batches[2].skills["c"].elapsed, 1); + + assert_ulps_eq!( + h.batches[0].posterior("a"), + Gaussian::new(0.0, 1.30061), + epsilon = 0.000001 + ); + assert_ulps_eq!( + h.batches[0].posterior("b"), + Gaussian::new(0.0, 1.30061), + epsilon = 0.000001 + ); + assert_ulps_eq!( + h.batches[2].posterior("b"), + Gaussian::new(0.0, 1.30061), + epsilon = 0.000001 + ); + + h.add_events(composition, results, vec![], vec![], HashMap::new()); + + assert_eq!(h.batches.len(), 6); + + assert_eq!( + h.batches + .iter() + .map(|b| b.get_composition()) + .collect::>(), + vec![ + vec![vec![vec!["a"], vec!["b"]]], + vec![vec![vec!["a"], vec!["c"]]], + vec![vec![vec!["b"], vec!["c"]]], + vec![vec![vec!["a"], vec!["b"]]], + vec![vec![vec!["a"], vec!["c"]]], + vec![vec![vec!["b"], vec!["c"]]] + ] + ); + + h.convergence(ITERATIONS, EPSILON, false); + + assert_ulps_eq!( + h.batches[0].posterior("a"), + Gaussian::new(0.0, 0.9312360609998878), + epsilon = 0.000001 + ); + assert_ulps_eq!( + h.batches[3].posterior("a"), + Gaussian::new(0.0, 0.9312360609998878), + epsilon = 0.000001 + ); + assert_ulps_eq!( + h.batches[3].posterior("b"), + Gaussian::new(0.0, 0.9312360609998878), + epsilon = 0.000001 + ); + assert_ulps_eq!( + h.batches[5].posterior("b"), + Gaussian::new(0.0, 0.9312360609998878), + epsilon = 0.000001 + ); + } + + #[test] + fn test_add_events_with_time() { + let composition = vec![ + vec![vec!["a"], vec!["b"]], + vec![vec!["a"], vec!["c"]], + vec![vec!["b"], vec!["c"]], + ]; + let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; + + let mut h = History::new( + composition.clone(), + results.clone(), + vec![0, 10, 20], + vec![], + HashMap::new(), + 0.0, + 2.0, + 1.0, + 0.0, + 0.0, + false, + ); + + h.convergence(ITERATIONS, EPSILON, false); + + h.add_events( + composition, + results, + vec![15, 10, 0], + vec![], + HashMap::new(), + ); + + assert_eq!(h.batches.len(), 4); + + assert_eq!( + h.batches + .iter() + .map(|batch| batch.events.len()) + .collect::>(), + vec![2, 2, 1, 1] + ); + + assert_eq!( + h.batches + .iter() + .map(|b| b.get_composition()) + .collect::>(), + vec![ + vec![vec![vec!["a"], vec!["b"]], vec![vec!["b"], vec!["c"]]], + vec![vec![vec!["a"], vec!["c"]], vec![vec!["a"], vec!["c"]]], + vec![vec![vec!["a"], vec!["b"]]], + vec![vec![vec!["b"], vec!["c"]]] + ] + ); + + assert_eq!( + h.batches + .iter() + .map(|b| b.get_results()) + .collect::>(), + vec![ + vec![vec![1.0, 0.0], vec![1.0, 0.0]], + vec![vec![0.0, 1.0], vec![0.0, 1.0]], + vec![vec![1.0, 0.0]], + vec![vec![1.0, 0.0]] + ] + ); + + let end = h.batches.len() - 1; + + assert_eq!(h.batches[0].skills["c"].elapsed, 0); + assert_eq!(h.batches[end].skills["c"].elapsed, 10); + + assert_eq!(h.batches[0].skills["a"].elapsed, 0); + assert_eq!(h.batches[2].skills["a"].elapsed, 5); + + assert_eq!(h.batches[0].skills["b"].elapsed, 0); + assert_eq!(h.batches[end].skills["b"].elapsed, 5); + + h.convergence(ITERATIONS, EPSILON, false); + + assert_ulps_eq!( + h.batches[0].posterior("b"), + h.batches[end].posterior("b"), + epsilon = 0.000001 + ); + + assert_ulps_eq!( + h.batches[0].posterior("c"), + h.batches[end].posterior("c"), + epsilon = 0.000001 + ); + + assert_ulps_eq!( + h.batches[0].posterior("c"), + h.batches[0].posterior("b"), + epsilon = 0.000001 + ); + } }