diff --git a/src/batch.rs b/src/batch.rs index 56a76bb..43872c0 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -80,6 +80,24 @@ impl Event { .map(|team| team.output) .collect::>() } + + fn within_priors( + &self, + online: bool, + forward: bool, + skills: &HashMap, + agents: &HashMap, + ) -> Vec> { + self.teams + .iter() + .map(|team| { + team.items + .iter() + .map(|item| item.within_prior(online, forward, skills, agents)) + .collect::>() + }) + .collect::>() + } } #[derive(Debug)] @@ -265,6 +283,7 @@ impl Batch { .collect::>() } + // TODO(anders): Remove this function. pub(crate) fn within_priors( &self, event: usize, @@ -285,13 +304,13 @@ impl Batch { } pub(crate) fn iteration(&mut self, from: usize, agents: &HashMap) { - for e in from..self.events.len() { - let teams = self.within_priors(e, false, false, agents); - let result = self.events[e].outputs(); + for event in self.events.iter_mut().skip(from) { + let teams = event.within_priors(false, false, &self.skills, agents); + let result = event.outputs(); - let g = Game::new(teams, result, self.events[e].weights.clone(), self.p_draw); + let g = Game::new(teams, result, event.weights.clone(), self.p_draw); - for (t, team) in self.events[e].teams.iter_mut().enumerate() { + for (t, team) in event.teams.iter_mut().enumerate() { for (i, item) in team.items.iter_mut().enumerate() { self.skills.get_mut(&item.agent).unwrap().likelihood = (self.skills[&item.agent].likelihood / item.likelihood) @@ -301,7 +320,7 @@ impl Batch { } } - self.events[e].evidence = g.evidence; + event.evidence = g.evidence; } } @@ -375,9 +394,9 @@ impl Batch { self.events .iter() .enumerate() - .map(|(e, event)| { + .map(|(_, event)| { Game::new( - self.within_priors(e, online, forward, agents), + event.within_priors(online, forward, &self.skills, agents), event.outputs(), event.weights.clone(), self.p_draw, @@ -400,9 +419,9 @@ impl Batch { .flat_map(|team| &team.items) .any(|item| targets.contains(&item.agent)) }) - .map(|(e, event)| { + .map(|(_, event)| { Game::new( - self.within_priors(e, online, forward, agents), + event.within_priors(online, forward, &self.skills, agents), event.outputs(), event.weights.clone(), self.p_draw, diff --git a/src/history.rs b/src/history.rs index 1c03dd6..b746b4f 100644 --- a/src/history.rs +++ b/src/history.rs @@ -227,7 +227,7 @@ impl History { pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 { self.batches .iter() - .map(|batch| batch.log_evidence(self.online, targets, forward, &mut self.agents)) + .map(|batch| batch.log_evidence(self.online, targets, forward, &self.agents)) .sum() }