More refactoring
This commit is contained in:
39
src/batch.rs
39
src/batch.rs
@@ -80,6 +80,24 @@ impl Event {
|
||||
.map(|team| team.output)
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn within_priors(
|
||||
&self,
|
||||
online: bool,
|
||||
forward: bool,
|
||||
skills: &HashMap<Index, Skill>,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
) -> Vec<Vec<Player>> {
|
||||
self.teams
|
||||
.iter()
|
||||
.map(|team| {
|
||||
team.items
|
||||
.iter()
|
||||
.map(|item| item.within_prior(online, forward, skills, agents))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -265,6 +283,7 @@ impl Batch {
|
||||
.collect::<HashMap<_, _>>()
|
||||
}
|
||||
|
||||
// TODO(anders): Remove this function.
|
||||
pub(crate) fn within_priors(
|
||||
&self,
|
||||
event: usize,
|
||||
@@ -285,13 +304,13 @@ impl Batch {
|
||||
}
|
||||
|
||||
pub(crate) fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) {
|
||||
for e in from..self.events.len() {
|
||||
let teams = self.within_priors(e, false, false, agents);
|
||||
let result = self.events[e].outputs();
|
||||
for event in self.events.iter_mut().skip(from) {
|
||||
let teams = event.within_priors(false, false, &self.skills, agents);
|
||||
let result = event.outputs();
|
||||
|
||||
let g = Game::new(teams, result, self.events[e].weights.clone(), self.p_draw);
|
||||
let g = Game::new(teams, result, event.weights.clone(), self.p_draw);
|
||||
|
||||
for (t, team) in self.events[e].teams.iter_mut().enumerate() {
|
||||
for (t, team) in event.teams.iter_mut().enumerate() {
|
||||
for (i, item) in team.items.iter_mut().enumerate() {
|
||||
self.skills.get_mut(&item.agent).unwrap().likelihood =
|
||||
(self.skills[&item.agent].likelihood / item.likelihood)
|
||||
@@ -301,7 +320,7 @@ impl Batch {
|
||||
}
|
||||
}
|
||||
|
||||
self.events[e].evidence = g.evidence;
|
||||
event.evidence = g.evidence;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,9 +394,9 @@ impl Batch {
|
||||
self.events
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(e, event)| {
|
||||
.map(|(_, event)| {
|
||||
Game::new(
|
||||
self.within_priors(e, online, forward, agents),
|
||||
event.within_priors(online, forward, &self.skills, agents),
|
||||
event.outputs(),
|
||||
event.weights.clone(),
|
||||
self.p_draw,
|
||||
@@ -400,9 +419,9 @@ impl Batch {
|
||||
.flat_map(|team| &team.items)
|
||||
.any(|item| targets.contains(&item.agent))
|
||||
})
|
||||
.map(|(e, event)| {
|
||||
.map(|(_, event)| {
|
||||
Game::new(
|
||||
self.within_priors(e, online, forward, agents),
|
||||
event.within_priors(online, forward, &self.skills, agents),
|
||||
event.outputs(),
|
||||
event.weights.clone(),
|
||||
self.p_draw,
|
||||
|
||||
@@ -227,7 +227,7 @@ impl History {
|
||||
pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 {
|
||||
self.batches
|
||||
.iter()
|
||||
.map(|batch| batch.log_evidence(self.online, targets, forward, &mut self.agents))
|
||||
.map(|batch| batch.log_evidence(self.online, targets, forward, &self.agents))
|
||||
.sum()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user