use std::collections::HashMap; use crate::{ Index, N_INF, agent::Agent, drift::Drift, game::Game, gaussian::Gaussian, player::Player, tuple_gt, tuple_max, }; #[derive(Debug)] pub(crate) struct Skill { pub(crate) forward: Gaussian, backward: Gaussian, likelihood: Gaussian, pub(crate) elapsed: i64, pub(crate) online: Gaussian, } impl Skill { pub(crate) fn posterior(&self) -> Gaussian { self.likelihood * self.backward * self.forward } } impl Default for Skill { fn default() -> Self { Self { forward: N_INF, backward: N_INF, likelihood: N_INF, elapsed: 0, online: N_INF, } } } #[derive(Debug)] struct Item { agent: Index, likelihood: Gaussian, } impl Item { fn within_prior( &self, online: bool, forward: bool, skills: &HashMap, agents: &HashMap>, ) -> Player { let r = &agents[&self.agent].player; let skill = &skills[&self.agent]; if online { Player::new(skill.online, r.beta, r.drift) } else if forward { Player::new(skill.forward, r.beta, r.drift) } else { Player::new(skill.posterior() / self.likelihood, r.beta, r.drift) } } } #[derive(Debug)] struct Team { items: Vec, output: f64, } #[derive(Debug)] pub(crate) struct Event { teams: Vec, evidence: f64, weights: Vec>, } impl Event { fn outputs(&self) -> Vec { self.teams .iter() .map(|team| team.output) .collect::>() } pub(crate) fn within_priors( &self, online: bool, forward: bool, skills: &HashMap, agents: &HashMap>, ) -> Vec>> { self.teams .iter() .map(|team| { team.items .iter() .map(|item| item.within_prior(online, forward, skills, agents)) .collect::>() }) .collect::>() } } #[derive(Debug)] pub struct Batch { pub(crate) events: Vec, pub(crate) skills: HashMap, pub(crate) time: i64, p_draw: f64, } impl Batch { pub fn new(time: i64, p_draw: f64) -> Self { Self { events: Vec::new(), skills: HashMap::new(), time, p_draw, } } pub fn add_events( &mut self, composition: Vec>>, results: Vec>, weights: Vec>>, agents: &HashMap>, ) { let mut unique = Vec::with_capacity(10); let this_agent = composition.iter().flatten().flatten().filter(|idx| { if !unique.contains(idx) { unique.push(*idx); return true; } false }); for idx in this_agent { let elapsed = compute_elapsed(agents[&idx].last_time, self.time); if let Some(skill) = self.skills.get_mut(idx) { skill.elapsed = elapsed; skill.forward = agents[&idx].receive(elapsed); } else { self.skills.insert( *idx, Skill { forward: agents[&idx].receive(elapsed), elapsed, ..Default::default() }, ); } } let events = composition.iter().enumerate().map(|(e, event)| { let teams = event .iter() .enumerate() .map(|(t, team)| { let items = team .iter() .map(|&agent| Item { agent, likelihood: N_INF, }) .collect::>(); Team { items, output: if results.is_empty() { (event.len() - (t + 1)) as f64 } else { results[e][t] }, } }) .collect::>(); let weights = if weights.is_empty() { teams .iter() .map(|team| vec![1.0; team.items.len()]) .collect::>() } else { weights[e].clone() }; Event { teams, evidence: 0.0, weights, } }); let from = self.events.len(); self.events.extend(events); self.iteration(from, agents); } pub(crate) fn posteriors(&self) -> HashMap { self.skills .iter() .map(|(&idx, skill)| (idx, skill.posterior())) .collect::>() } pub fn iteration(&mut self, from: usize, agents: &HashMap>) { for event in self.events.iter_mut().skip(from) { let teams = event.within_priors(false, false, &self.skills, agents); let result = event.outputs(); let g = Game::new(teams, &result, &event.weights, self.p_draw); for (t, team) in event.teams.iter_mut().enumerate() { for (i, item) in team.items.iter_mut().enumerate() { self.skills.get_mut(&item.agent).unwrap().likelihood = (self.skills[&item.agent].likelihood / item.likelihood) * g.likelihoods[t][i]; item.likelihood = g.likelihoods[t][i]; } } event.evidence = g.evidence; } } #[allow(dead_code)] pub(crate) fn convergence(&mut self, agents: &HashMap>) -> usize { let epsilon = 1e-6; let iterations = 20; let mut step = (f64::INFINITY, f64::INFINITY); let mut i = 0; while tuple_gt(step, epsilon) && i < iterations { let old = self.posteriors(); self.iteration(0, agents); let new = self.posteriors(); step = old.iter().fold((0.0, 0.0), |step, (a, old)| { tuple_max(step, old.delta(new[a])) }); i += 1; } i } pub(crate) fn forward_prior_out(&self, agent: &Index) -> Gaussian { let skill = &self.skills[agent]; skill.forward * skill.likelihood } pub(crate) fn backward_prior_out( &self, agent: &Index, agents: &HashMap>, ) -> Gaussian { let skill = &self.skills[agent]; let n = skill.likelihood * skill.backward; n.forget(agents[agent].player.drift.variance_delta(skill.elapsed)) } pub(crate) fn new_backward_info(&mut self, agents: &HashMap>) { for (agent, skill) in self.skills.iter_mut() { skill.backward = agents[agent].message; } self.iteration(0, agents); } pub(crate) fn new_forward_info(&mut self, agents: &HashMap>) { for (agent, skill) in self.skills.iter_mut() { skill.forward = agents[agent].receive(skill.elapsed); } self.iteration(0, agents); } pub(crate) fn log_evidence( &self, online: bool, targets: &[Index], forward: bool, agents: &HashMap>, ) -> f64 { if targets.is_empty() { if online || forward { self.events .iter() .enumerate() .map(|(_, event)| { Game::new( event.within_priors(online, forward, &self.skills, agents), &event.outputs(), &event.weights, self.p_draw, ) .evidence .ln() }) .sum() } else { self.events.iter().map(|event| event.evidence.ln()).sum() } } else if online || forward { self.events .iter() .enumerate() .filter(|(_, event)| { event .teams .iter() .flat_map(|team| &team.items) .any(|item| targets.contains(&item.agent)) }) .map(|(_, event)| { Game::new( event.within_priors(online, forward, &self.skills, agents), &event.outputs(), &event.weights, self.p_draw, ) .evidence .ln() }) .sum() } else { self.events .iter() .filter(|event| { event .teams .iter() .flat_map(|team| &team.items) .any(|item| targets.contains(&item.agent)) }) .map(|event| event.evidence.ln()) .sum() } } pub fn get_composition(&self) -> Vec>> { self.events .iter() .map(|event| { event .teams .iter() .map(|team| team.items.iter().map(|item| item.agent).collect::>()) .collect::>() }) .collect::>() } pub fn get_results(&self) -> Vec> { self.events .iter() .map(|event| { event .teams .iter() .map(|team| team.output) .collect::>() }) .collect::>() } } pub(crate) fn compute_elapsed(last_time: i64, actual_time: i64) -> i64 { if last_time == i64::MIN { 0 } else if last_time == i64::MAX { 1 } else { actual_time - last_time } } #[cfg(test)] mod tests { use approx::assert_ulps_eq; use crate::{IndexMap, agent::Agent, drift::ConstantDrift, player::Player}; use super::*; #[test] fn test_one_event_each() { let mut index_map = IndexMap::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); let c = index_map.get_or_create("c"); let d = index_map.get_or_create("d"); let e = index_map.get_or_create("e"); let f = index_map.get_or_create("f"); let mut agents = HashMap::new(); for agent in [a, b, c, d, e, f] { agents.insert( agent, Agent { player: Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, ConstantDrift(25.0 / 300.0), ), ..Default::default() }, ); } let mut batch = Batch::new(0, 0.0); batch.add_events( vec![ vec![vec![a], vec![b]], vec![vec![c], vec![d]], vec![vec![e], vec![f]], ], vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]], vec![], &agents, ); let post = batch.posteriors(); assert_ulps_eq!( post[&a], Gaussian::from_ms(29.205220, 7.194481), epsilon = 1e-6 ); assert_ulps_eq!( post[&b], Gaussian::from_ms(20.794779, 7.194481), epsilon = 1e-6 ); assert_ulps_eq!( post[&c], Gaussian::from_ms(20.794779, 7.194481), epsilon = 1e-6 ); assert_ulps_eq!( post[&d], Gaussian::from_ms(29.205220, 7.194481), epsilon = 1e-6 ); assert_ulps_eq!( post[&e], Gaussian::from_ms(29.205220, 7.194481), epsilon = 1e-6 ); assert_ulps_eq!( post[&f], Gaussian::from_ms(20.794779, 7.194481), epsilon = 1e-6 ); assert_eq!(batch.convergence(&agents), 1); } #[test] fn test_same_strength() { let mut index_map = IndexMap::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); let c = index_map.get_or_create("c"); let d = index_map.get_or_create("d"); let e = index_map.get_or_create("e"); let f = index_map.get_or_create("f"); let mut agents = HashMap::new(); for agent in [a, b, c, d, e, f] { agents.insert( agent, Agent { player: Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, ConstantDrift(25.0 / 300.0), ), ..Default::default() }, ); } let mut batch = Batch::new(0, 0.0); batch.add_events( vec![ vec![vec![a], vec![b]], vec![vec![a], vec![c]], vec![vec![b], vec![c]], ], vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]], vec![], &agents, ); let post = batch.posteriors(); assert_ulps_eq!( post[&a], Gaussian::from_ms(24.960978, 6.298544), epsilon = 1e-6 ); assert_ulps_eq!( post[&b], Gaussian::from_ms(27.095590, 6.010330), epsilon = 1e-6 ); assert_ulps_eq!( post[&c], Gaussian::from_ms(24.889681, 5.866311), epsilon = 1e-6 ); assert!(batch.convergence(&agents) > 1); let post = batch.posteriors(); assert_ulps_eq!( post[&a], Gaussian::from_ms(25.000000, 5.419212), epsilon = 1e-6 ); assert_ulps_eq!( post[&b], Gaussian::from_ms(25.000000, 5.419212), epsilon = 1e-6 ); assert_ulps_eq!( post[&c], Gaussian::from_ms(25.000000, 5.419212), epsilon = 1e-6 ); } #[test] fn test_add_events() { let mut index_map = IndexMap::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); let c = index_map.get_or_create("c"); let d = index_map.get_or_create("d"); let e = index_map.get_or_create("e"); let f = index_map.get_or_create("f"); let mut agents = HashMap::new(); for agent in [a, b, c, d, e, f] { agents.insert( agent, Agent { player: Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, ConstantDrift(25.0 / 300.0), ), ..Default::default() }, ); } let mut batch = Batch::new(0, 0.0); batch.add_events( vec![ vec![vec![a], vec![b]], vec![vec![a], vec![c]], vec![vec![b], vec![c]], ], vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]], vec![], &agents, ); batch.convergence(&agents); let post = batch.posteriors(); assert_ulps_eq!( post[&a], Gaussian::from_ms(25.000000, 5.419212), epsilon = 1e-6 ); assert_ulps_eq!( post[&b], Gaussian::from_ms(25.000000, 5.419212), epsilon = 1e-6 ); assert_ulps_eq!( post[&c], Gaussian::from_ms(25.000000, 5.419212), epsilon = 1e-6 ); batch.add_events( vec![ vec![vec![a], vec![b]], vec![vec![a], vec![c]], vec![vec![b], vec![c]], ], vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]], vec![], &agents, ); assert_eq!(batch.events.len(), 6); batch.convergence(&agents); let post = batch.posteriors(); assert_ulps_eq!( post[&a], Gaussian::from_ms(25.000003, 3.880150), epsilon = 1e-6 ); assert_ulps_eq!( post[&b], Gaussian::from_ms(25.000003, 3.880150), epsilon = 1e-6 ); assert_ulps_eq!( post[&c], Gaussian::from_ms(25.000003, 3.880150), epsilon = 1e-6 ); } }