More refactoring
This commit is contained in:
39
src/batch.rs
39
src/batch.rs
@@ -80,6 +80,24 @@ impl Event {
|
|||||||
.map(|team| team.output)
|
.map(|team| team.output)
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn within_priors(
|
||||||
|
&self,
|
||||||
|
online: bool,
|
||||||
|
forward: bool,
|
||||||
|
skills: &HashMap<Index, Skill>,
|
||||||
|
agents: &HashMap<Index, Agent>,
|
||||||
|
) -> Vec<Vec<Player>> {
|
||||||
|
self.teams
|
||||||
|
.iter()
|
||||||
|
.map(|team| {
|
||||||
|
team.items
|
||||||
|
.iter()
|
||||||
|
.map(|item| item.within_prior(online, forward, skills, agents))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -265,6 +283,7 @@ impl Batch {
|
|||||||
.collect::<HashMap<_, _>>()
|
.collect::<HashMap<_, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(anders): Remove this function.
|
||||||
pub(crate) fn within_priors(
|
pub(crate) fn within_priors(
|
||||||
&self,
|
&self,
|
||||||
event: usize,
|
event: usize,
|
||||||
@@ -285,13 +304,13 @@ impl Batch {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) {
|
pub(crate) fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) {
|
||||||
for e in from..self.events.len() {
|
for event in self.events.iter_mut().skip(from) {
|
||||||
let teams = self.within_priors(e, false, false, agents);
|
let teams = event.within_priors(false, false, &self.skills, agents);
|
||||||
let result = self.events[e].outputs();
|
let result = event.outputs();
|
||||||
|
|
||||||
let g = Game::new(teams, result, self.events[e].weights.clone(), self.p_draw);
|
let g = Game::new(teams, result, event.weights.clone(), self.p_draw);
|
||||||
|
|
||||||
for (t, team) in self.events[e].teams.iter_mut().enumerate() {
|
for (t, team) in event.teams.iter_mut().enumerate() {
|
||||||
for (i, item) in team.items.iter_mut().enumerate() {
|
for (i, item) in team.items.iter_mut().enumerate() {
|
||||||
self.skills.get_mut(&item.agent).unwrap().likelihood =
|
self.skills.get_mut(&item.agent).unwrap().likelihood =
|
||||||
(self.skills[&item.agent].likelihood / item.likelihood)
|
(self.skills[&item.agent].likelihood / item.likelihood)
|
||||||
@@ -301,7 +320,7 @@ impl Batch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.events[e].evidence = g.evidence;
|
event.evidence = g.evidence;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -375,9 +394,9 @@ impl Batch {
|
|||||||
self.events
|
self.events
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(e, event)| {
|
.map(|(_, event)| {
|
||||||
Game::new(
|
Game::new(
|
||||||
self.within_priors(e, online, forward, agents),
|
event.within_priors(online, forward, &self.skills, agents),
|
||||||
event.outputs(),
|
event.outputs(),
|
||||||
event.weights.clone(),
|
event.weights.clone(),
|
||||||
self.p_draw,
|
self.p_draw,
|
||||||
@@ -400,9 +419,9 @@ impl Batch {
|
|||||||
.flat_map(|team| &team.items)
|
.flat_map(|team| &team.items)
|
||||||
.any(|item| targets.contains(&item.agent))
|
.any(|item| targets.contains(&item.agent))
|
||||||
})
|
})
|
||||||
.map(|(e, event)| {
|
.map(|(_, event)| {
|
||||||
Game::new(
|
Game::new(
|
||||||
self.within_priors(e, online, forward, agents),
|
event.within_priors(online, forward, &self.skills, agents),
|
||||||
event.outputs(),
|
event.outputs(),
|
||||||
event.weights.clone(),
|
event.weights.clone(),
|
||||||
self.p_draw,
|
self.p_draw,
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ impl History {
|
|||||||
pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 {
|
pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 {
|
||||||
self.batches
|
self.batches
|
||||||
.iter()
|
.iter()
|
||||||
.map(|batch| batch.log_evidence(self.online, targets, forward, &mut self.agents))
|
.map(|batch| batch.log_evidence(self.online, targets, forward, &self.agents))
|
||||||
.sum()
|
.sum()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user