diff --git a/src/batch.rs b/src/batch.rs index 46475d9..56a76bb 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -14,7 +14,7 @@ pub(crate) struct Skill { } impl Skill { - fn posterior(&self) -> Gaussian { + pub(crate) fn posterior(&self) -> Gaussian { self.likelihood * self.backward * self.forward } } @@ -251,6 +251,7 @@ impl Batch { self.iteration(from, agents); } + // TODO(anders): Use Item::posterior() instead. pub fn posterior(&self, agent: Index) -> Gaussian { let skill = &self.skills[&agent]; @@ -259,8 +260,8 @@ impl Batch { pub(crate) fn posteriors(&self) -> HashMap { self.skills - .keys() - .map(|&idx| (idx, self.posterior(idx))) + .iter() + .map(|(&idx, skill)| (idx, skill.posterior())) .collect::>() } @@ -388,42 +389,40 @@ impl Batch { } else { self.events.iter().map(|event| event.evidence.ln()).sum() } + } else if online || forward { + self.events + .iter() + .enumerate() + .filter(|(_, event)| { + event + .teams + .iter() + .flat_map(|team| &team.items) + .any(|item| targets.contains(&item.agent)) + }) + .map(|(e, event)| { + Game::new( + self.within_priors(e, online, forward, agents), + event.outputs(), + event.weights.clone(), + self.p_draw, + ) + .evidence + .ln() + }) + .sum() } else { - if online || forward { - self.events - .iter() - .enumerate() - .filter(|(_, event)| { - event - .teams - .iter() - .flat_map(|team| &team.items) - .any(|item| targets.contains(&item.agent)) - }) - .map(|(e, event)| { - Game::new( - self.within_priors(e, online, forward, agents), - event.outputs(), - event.weights.clone(), - self.p_draw, - ) - .evidence - .ln() - }) - .sum() - } else { - self.events - .iter() - .filter(|event| { - event - .teams - .iter() - .flat_map(|team| &team.items) - .any(|item| targets.contains(&item.agent)) - }) - .map(|event| event.evidence.ln()) - .sum() - } + self.events + .iter() + .filter(|event| { + event + .teams + .iter() + .flat_map(|team| &team.items) + .any(|item| targets.contains(&item.agent)) + }) + .map(|event| event.evidence.ln()) + .sum() } } diff --git a/src/history.rs b/src/history.rs index 8964551..1c03dd6 100644 --- a/src/history.rs +++ b/src/history.rs @@ -210,8 +210,8 @@ impl History { let mut data: HashMap> = HashMap::new(); for b in &self.batches { - for agent in b.skills.keys() { - let point = (b.time, b.posterior(*agent)); + for (agent, skill) in b.skills.iter() { + let point = (b.time, skill.posterior()); if let Some(entry) = data.get_mut(agent) { entry.push(point);