diff --git a/src/batch.rs b/src/batch.rs new file mode 100644 index 0000000..136a0b4 --- /dev/null +++ b/src/batch.rs @@ -0,0 +1,284 @@ +use std::collections::{HashMap, HashSet}; + +use crate::{Game, Gaussian, Player, N_INF}; + +pub struct Skill { + pub forward: Gaussian, + pub backward: Gaussian, + pub likelihood: Gaussian, + pub elapsed: f64, +} + +impl Default for Skill { + fn default() -> Self { + Self { + forward: N_INF, + backward: N_INF, + likelihood: N_INF, + elapsed: 0.0, + } + } +} + +pub struct Agent { + pub player: Player, + pub message: Gaussian, + pub last_time: f64, +} + +impl Agent { + pub fn new(player: Player, message: Gaussian, last_time: f64) -> Self { + Self { + player, + message, + last_time, + } + } + + pub fn receive(&self, elapsed: f64) -> Gaussian { + if self.message != N_INF { + self.message.forget(self.player.gamma, elapsed) + } else { + self.player.prior + } + } +} + +pub struct Item { + name: String, + likelihood: Gaussian, +} + +pub struct Team { + items: Vec, + output: u16, +} + +pub struct Event { + teams: Vec, + evidence: f64, +} + +impl Event { + pub fn names(&self) -> Vec<&str> { + self.teams + .iter() + .flat_map(|team| team.items.iter()) + .map(|item| item.name.as_str()) + .collect::>() + } + + pub fn result(&self) -> Vec { + self.teams + .iter() + .map(|team| team.output) + .collect::>() + } +} + +fn compute_elapsed(last_time: f64, actual_time: f64) -> f64 { + if last_time == f64::NEG_INFINITY { + 0.0 + } else if last_time == f64::INFINITY { + 1.0 + } else { + actual_time - last_time + } +} + +pub struct Batch { + skills: HashMap, + events: Vec, + time: f64, + agents: HashMap, + p_draw: f64, +} + +impl Batch { + pub fn new( + composition: Vec>>, + results: Vec>, + time: f64, + agents: HashMap, + p_draw: f64, + ) -> Self { + let mut this = Self { + skills: HashMap::new(), + events: Vec::new(), + time, + agents, + p_draw, + }; + + this.add_events(composition, results); + + this + } + + pub fn add_events(&mut self, composition: Vec>>, results: Vec>) { + let this_agent = composition + .iter() + .flat_map(|teams| teams.iter()) + .flat_map(|team| team.iter()) + .cloned() + .collect::>(); + + for a in this_agent { + let elapsed = compute_elapsed(self.agents[a].last_time, self.time); + + if let Some(skill) = self.skills.get_mut(a) { + skill.elapsed = elapsed; + skill.forward = self.agents[a].receive(elapsed); + } else { + self.skills.insert( + a.to_string(), + Skill { + forward: self.agents[a].receive(elapsed), + ..Default::default() + }, + ); + } + } + + let from = self.events.len() + 1; + + for e in 0..composition.len() { + let teams = (0..composition[e].len()) + .map(|t| { + let items = (0..composition[e][t].len()) + .map(|a| Item { + name: composition[e][t][a].to_string(), + likelihood: N_INF, + }) + .collect::>(); + + Team { + items, + output: results[e][t], + } + }) + .collect::>(); + + let event = Event { + teams, + evidence: 0.0, + }; + + self.events.push(event); + } + + self.iteration(from); + } + + fn posterior(&self, agent: &str) -> Gaussian { + let skill = &self.skills[agent]; + + skill.likelihood * skill.backward * skill.forward + } + + fn posteriors(&self) -> HashMap<&str, Gaussian> { + self.skills + .keys() + .map(|a| (a.as_str(), self.posterior(a))) + .collect::>() + } + + fn within_prior(&self, item: &Item) -> Player { + let r = &self.agents[&item.name].player; + let g = self.posterior(&item.name) / item.likelihood; + + Player::new(g, r.beta, r.gamma, N_INF) + } + + fn within_priors(&self, event: usize) -> Vec> { + self.events[event] + .teams + .iter() + .map(|team| { + team.items + .iter() + .map(|item| self.within_prior(item)) + .collect::>() + }) + .collect::>() + } + + fn iteration(&mut self, from: usize) { + for e in from..self.events.len() { + let teams = self.within_priors(e); + let result = self.events[e].result(); + + let g = Game::new(teams, result, self.p_draw); + + for (t, team) in self.events[e].teams.iter_mut().enumerate() { + for (i, item) in team.items.iter_mut().enumerate() { + self.skills.get_mut(&item.name).unwrap().likelihood = + (self.skills[&item.name].likelihood / item.likelihood) + * g.likelihoods[t][i]; + + item.likelihood = g.likelihoods[t][i]; + } + } + + self.events[e].evidence = g.evidence; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_one_event_each() { + let mut agents = HashMap::new(); + + for k in ["a", "b", "c", "d", "e", "f"] { + let agent = Agent::new( + Player::new( + Gaussian::new(25.0, 25.0 / 3.0), + 25.0 / 6.0, + 25.0 / 300.0, + N_INF, + ), + N_INF, + f64::NEG_INFINITY, + ); + + agents.insert(k.to_string(), agent); + } + + let b = Batch::new( + vec![ + vec![vec!["a"], vec!["b"]], + vec![vec!["c"], vec!["d"]], + vec![vec!["e"], vec!["f"]], + ], + vec![vec![1, 0], vec![0, 1], vec![1, 0]], + 0.0, + agents, + 0.0, + ); + + let post = b.posteriors(); + + assert_eq!(post["a"].mu(), 29.205); + assert_eq!(post["a"].sigma(), 7.194) + + /* + agents = dict() + for k in ["a", "b", "c", "d", "e", "f"]: + agents[k] = ttt.Agent(ttt.Player(ttt.Gaussian(25., 25.0/3), 25.0/6, 25.0/300 ) , ttt.Ninf, -ttt.inf) + b = ttt.Batch(composition=[ [["a"],["b"]], [["c"],["d"]] , [["e"],["f"]] ], results= [[1,0],[0,1],[1,0]], time = 0, agents=agents) + post = b.posteriors() + self.assertAlmostEqual(post["a"].mu,29.205,3) + self.assertAlmostEqual(post["a"].sigma,7.194,3) + + self.assertAlmostEqual(post["b"].mu,20.795,3) + self.assertAlmostEqual(post["b"].sigma,7.194,3) + self.assertAlmostEqual(post["c"].mu,20.795,3) + self.assertAlmostEqual(post["c"].sigma,7.194,3) + self.assertEqual(b.convergence(),1) + */ + } +} diff --git a/src/gaussian.rs b/src/gaussian.rs index 6554031..bf7cc38 100644 --- a/src/gaussian.rs +++ b/src/gaussian.rs @@ -40,11 +40,8 @@ impl Gaussian { self.sigma.powi(-2) } - pub fn forget(&self, gamma: f64, t: u32) -> Self { - Self::new( - self.mu, - (self.sigma().powi(2) + t as f64 * gamma.powi(2)).sqrt(), - ) + pub fn forget(&self, gamma: f64, t: f64) -> Self { + Self::new(self.mu, (self.sigma().powi(2) + t * gamma.powi(2)).sqrt()) } pub fn delta(&self, m: Gaussian) -> (f64, f64) { diff --git a/src/lib.rs b/src/lib.rs index c1d1f2c..7e29920 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +mod batch; mod game; mod gaussian; mod history; @@ -6,6 +7,7 @@ mod player; mod utils; mod variable; +pub use batch::*; pub use game::*; pub use gaussian::*; pub use history::*; diff --git a/src/player.rs b/src/player.rs index bb06d89..43ab8a3 100644 --- a/src/player.rs +++ b/src/player.rs @@ -6,7 +6,7 @@ use crate::{Gaussian, BETA, GAMMA, N_INF}; pub struct Player { pub prior: Gaussian, pub beta: f64, - gamma: f64, + pub gamma: f64, prior_draw: Gaussian, }