More refactoring

This commit is contained in:
2022-12-16 15:57:56 +01:00
parent 51467f7b69
commit 912a282cd8
2 changed files with 30 additions and 11 deletions

View File

@@ -80,6 +80,24 @@ impl Event {
.map(|team| team.output)
.collect::<Vec<_>>()
}
fn within_priors(
&self,
online: bool,
forward: bool,
skills: &HashMap<Index, Skill>,
agents: &HashMap<Index, Agent>,
) -> Vec<Vec<Player>> {
self.teams
.iter()
.map(|team| {
team.items
.iter()
.map(|item| item.within_prior(online, forward, skills, agents))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
}
}
#[derive(Debug)]
@@ -265,6 +283,7 @@ impl Batch {
.collect::<HashMap<_, _>>()
}
// TODO(anders): Remove this function.
pub(crate) fn within_priors(
&self,
event: usize,
@@ -285,13 +304,13 @@ impl Batch {
}
pub(crate) fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) {
for e in from..self.events.len() {
let teams = self.within_priors(e, false, false, agents);
let result = self.events[e].outputs();
for event in self.events.iter_mut().skip(from) {
let teams = event.within_priors(false, false, &self.skills, agents);
let result = event.outputs();
let g = Game::new(teams, result, self.events[e].weights.clone(), self.p_draw);
let g = Game::new(teams, result, event.weights.clone(), self.p_draw);
for (t, team) in self.events[e].teams.iter_mut().enumerate() {
for (t, team) in event.teams.iter_mut().enumerate() {
for (i, item) in team.items.iter_mut().enumerate() {
self.skills.get_mut(&item.agent).unwrap().likelihood =
(self.skills[&item.agent].likelihood / item.likelihood)
@@ -301,7 +320,7 @@ impl Batch {
}
}
self.events[e].evidence = g.evidence;
event.evidence = g.evidence;
}
}
@@ -375,9 +394,9 @@ impl Batch {
self.events
.iter()
.enumerate()
.map(|(e, event)| {
.map(|(_, event)| {
Game::new(
self.within_priors(e, online, forward, agents),
event.within_priors(online, forward, &self.skills, agents),
event.outputs(),
event.weights.clone(),
self.p_draw,
@@ -400,9 +419,9 @@ impl Batch {
.flat_map(|team| &team.items)
.any(|item| targets.contains(&item.agent))
})
.map(|(e, event)| {
.map(|(_, event)| {
Game::new(
self.within_priors(e, online, forward, agents),
event.within_priors(online, forward, &self.skills, agents),
event.outputs(),
event.weights.clone(),
self.p_draw,

View File

@@ -227,7 +227,7 @@ impl History {
pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 {
self.batches
.iter()
.map(|batch| batch.log_evidence(self.online, targets, forward, &mut self.agents))
.map(|batch| batch.log_evidence(self.online, targets, forward, &self.agents))
.sum()
}