More refactoring

This commit is contained in:
2022-12-16 15:57:56 +01:00
parent 51467f7b69
commit 912a282cd8
2 changed files with 30 additions and 11 deletions

View File

@@ -80,6 +80,24 @@ impl Event {
.map(|team| team.output) .map(|team| team.output)
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
fn within_priors(
&self,
online: bool,
forward: bool,
skills: &HashMap<Index, Skill>,
agents: &HashMap<Index, Agent>,
) -> Vec<Vec<Player>> {
self.teams
.iter()
.map(|team| {
team.items
.iter()
.map(|item| item.within_prior(online, forward, skills, agents))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
}
} }
#[derive(Debug)] #[derive(Debug)]
@@ -265,6 +283,7 @@ impl Batch {
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>()
} }
// TODO(anders): Remove this function.
pub(crate) fn within_priors( pub(crate) fn within_priors(
&self, &self,
event: usize, event: usize,
@@ -285,13 +304,13 @@ impl Batch {
} }
pub(crate) fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) { pub(crate) fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) {
for e in from..self.events.len() { for event in self.events.iter_mut().skip(from) {
let teams = self.within_priors(e, false, false, agents); let teams = event.within_priors(false, false, &self.skills, agents);
let result = self.events[e].outputs(); let result = event.outputs();
let g = Game::new(teams, result, self.events[e].weights.clone(), self.p_draw); let g = Game::new(teams, result, event.weights.clone(), self.p_draw);
for (t, team) in self.events[e].teams.iter_mut().enumerate() { for (t, team) in event.teams.iter_mut().enumerate() {
for (i, item) in team.items.iter_mut().enumerate() { for (i, item) in team.items.iter_mut().enumerate() {
self.skills.get_mut(&item.agent).unwrap().likelihood = self.skills.get_mut(&item.agent).unwrap().likelihood =
(self.skills[&item.agent].likelihood / item.likelihood) (self.skills[&item.agent].likelihood / item.likelihood)
@@ -301,7 +320,7 @@ impl Batch {
} }
} }
self.events[e].evidence = g.evidence; event.evidence = g.evidence;
} }
} }
@@ -375,9 +394,9 @@ impl Batch {
self.events self.events
.iter() .iter()
.enumerate() .enumerate()
.map(|(e, event)| { .map(|(_, event)| {
Game::new( Game::new(
self.within_priors(e, online, forward, agents), event.within_priors(online, forward, &self.skills, agents),
event.outputs(), event.outputs(),
event.weights.clone(), event.weights.clone(),
self.p_draw, self.p_draw,
@@ -400,9 +419,9 @@ impl Batch {
.flat_map(|team| &team.items) .flat_map(|team| &team.items)
.any(|item| targets.contains(&item.agent)) .any(|item| targets.contains(&item.agent))
}) })
.map(|(e, event)| { .map(|(_, event)| {
Game::new( Game::new(
self.within_priors(e, online, forward, agents), event.within_priors(online, forward, &self.skills, agents),
event.outputs(), event.outputs(),
event.weights.clone(), event.weights.clone(),
self.p_draw, self.p_draw,

View File

@@ -227,7 +227,7 @@ impl History {
pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 { pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 {
self.batches self.batches
.iter() .iter()
.map(|batch| batch.log_evidence(self.online, targets, forward, &mut self.agents)) .map(|batch| batch.log_evidence(self.online, targets, forward, &self.agents))
.sum() .sum()
} }