Refactor so we can see if there is any way to improve the performance

This commit is contained in:
2022-12-16 15:38:29 +01:00
parent 5eb8e62d6e
commit 6dd84f7fd2
3 changed files with 57 additions and 42 deletions

View File

@@ -13,6 +13,12 @@ pub(crate) struct Skill {
pub(crate) online: Gaussian, pub(crate) online: Gaussian,
} }
impl Skill {
fn posterior(&self) -> Gaussian {
self.likelihood * self.backward * self.forward
}
}
impl Default for Skill { impl Default for Skill {
fn default() -> Self { fn default() -> Self {
Self { Self {
@@ -31,6 +37,29 @@ struct Item {
likelihood: Gaussian, likelihood: Gaussian,
} }
impl Item {
fn within_prior(
&self,
online: bool,
forward: bool,
skills: &HashMap<Index, Skill>,
agents: &HashMap<Index, Agent>,
) -> Player {
let r = &agents[&self.agent].player;
let skill = &skills[&self.agent];
if online {
Player::new(skill.online, r.beta, r.gamma)
} else if forward {
Player::new(skill.forward, r.beta, r.gamma)
} else {
let wp = skill.posterior() / self.likelihood;
Player::new(wp, r.beta, r.gamma)
}
}
}
#[derive(Debug)] #[derive(Debug)]
struct Team { struct Team {
items: Vec<Item>, items: Vec<Item>,
@@ -68,7 +97,7 @@ impl Batch {
weights: Vec<Vec<Vec<f64>>>, weights: Vec<Vec<Vec<f64>>>,
time: i64, time: i64,
p_draw: f64, p_draw: f64,
agents: &mut HashMap<Index, Agent>, agents: &HashMap<Index, Agent>,
) -> Self { ) -> Self {
assert!( assert!(
results.is_empty() || results.len() == composition.len(), results.is_empty() || results.len() == composition.len(),
@@ -235,26 +264,6 @@ impl Batch {
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>()
} }
fn within_prior(
&self,
item: &Item,
online: bool,
forward: bool,
agents: &HashMap<Index, Agent>,
) -> Player {
let r = &agents[&item.agent].player;
if online {
Player::new(self.skills[&item.agent].online, r.beta, r.gamma)
} else if forward {
Player::new(self.skills[&item.agent].forward, r.beta, r.gamma)
} else {
let wp = self.posterior(item.agent) / item.likelihood;
Player::new(wp, r.beta, r.gamma)
}
}
pub(crate) fn within_priors( pub(crate) fn within_priors(
&self, &self,
event: usize, event: usize,
@@ -268,13 +277,13 @@ impl Batch {
.map(|team| { .map(|team| {
team.items team.items
.iter() .iter()
.map(|item| self.within_prior(item, online, forward, agents)) .map(|item| item.within_prior(online, forward, &self.skills, agents))
.collect::<Vec<_>>() .collect::<Vec<_>>()
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
pub(crate) fn iteration(&mut self, from: usize, agents: &mut HashMap<Index, Agent>) { pub(crate) fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) {
for e in from..self.events.len() { for e in from..self.events.len() {
let teams = self.within_priors(e, false, false, agents); let teams = self.within_priors(e, false, false, agents);
let result = self.events[e].outputs(); let result = self.events[e].outputs();
@@ -295,7 +304,8 @@ impl Batch {
} }
} }
pub(crate) fn convergence(&mut self, agents: &mut HashMap<Index, Agent>) -> usize { #[allow(dead_code)]
pub(crate) fn convergence(&mut self, agents: &HashMap<Index, Agent>) -> usize {
let epsilon = 1e-6; let epsilon = 1e-6;
let iterations = 20; let iterations = 20;
@@ -328,7 +338,7 @@ impl Batch {
pub(crate) fn backward_prior_out( pub(crate) fn backward_prior_out(
&self, &self,
agent: &Index, agent: &Index,
agents: &mut HashMap<Index, Agent>, agents: &HashMap<Index, Agent>,
) -> Gaussian { ) -> Gaussian {
let skill = &self.skills[agent]; let skill = &self.skills[agent];
let n = skill.likelihood * skill.backward; let n = skill.likelihood * skill.backward;
@@ -336,7 +346,7 @@ impl Batch {
n.forget(agents[agent].player.gamma, skill.elapsed) n.forget(agents[agent].player.gamma, skill.elapsed)
} }
pub(crate) fn new_backward_info(&mut self, agents: &mut HashMap<Index, Agent>) { pub(crate) fn new_backward_info(&mut self, agents: &HashMap<Index, Agent>) {
for (agent, skill) in self.skills.iter_mut() { for (agent, skill) in self.skills.iter_mut() {
skill.backward = agents[agent].message; skill.backward = agents[agent].message;
} }
@@ -344,7 +354,7 @@ impl Batch {
self.iteration(0, agents); self.iteration(0, agents);
} }
pub(crate) fn new_forward_info(&mut self, agents: &mut HashMap<Index, Agent>) { pub(crate) fn new_forward_info(&mut self, agents: &HashMap<Index, Agent>) {
for (agent, skill) in self.skills.iter_mut() { for (agent, skill) in self.skills.iter_mut() {
skill.forward = agents[agent].receive(skill.elapsed); skill.forward = agents[agent].receive(skill.elapsed);
} }
@@ -357,7 +367,7 @@ impl Batch {
online: bool, online: bool,
targets: &[Index], targets: &[Index],
forward: bool, forward: bool,
agents: &mut HashMap<Index, Agent>, agents: &HashMap<Index, Agent>,
) -> f64 { ) -> f64 {
if targets.is_empty() { if targets.is_empty() {
if online || forward { if online || forward {
@@ -417,7 +427,7 @@ impl Batch {
} }
} }
pub(crate) fn get_composition(&self) -> Vec<Vec<Vec<Index>>> { pub fn get_composition(&self) -> Vec<Vec<Vec<Index>>> {
self.events self.events
.iter() .iter()
.map(|event| { .map(|event| {
@@ -430,7 +440,7 @@ impl Batch {
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
pub(crate) fn get_results(&self) -> Vec<Vec<f64>> { pub fn get_results(&self) -> Vec<Vec<f64>> {
self.events self.events
.iter() .iter()
.map(|event| { .map(|event| {
@@ -495,7 +505,7 @@ mod tests {
vec![], vec![],
0, 0,
0.0, 0.0,
&mut agents, &agents,
); );
let post = batch.posteriors(); let post = batch.posteriors();
@@ -507,7 +517,7 @@ mod tests {
assert_ulps_eq!(post[&e], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6); assert_ulps_eq!(post[&e], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6);
assert_ulps_eq!(post[&f], Gaussian::new(20.794779, 7.194481), epsilon = 1e-6); assert_ulps_eq!(post[&f], Gaussian::new(20.794779, 7.194481), epsilon = 1e-6);
assert_eq!(batch.convergence(&mut agents), 1); assert_eq!(batch.convergence(&agents), 1);
} }
#[test] #[test]
@@ -543,7 +553,7 @@ mod tests {
vec![], vec![],
0, 0,
0.0, 0.0,
&mut agents, &agents,
); );
let post = batch.posteriors(); let post = batch.posteriors();
@@ -552,7 +562,7 @@ mod tests {
assert_ulps_eq!(post[&b], Gaussian::new(27.095590, 6.010330), epsilon = 1e-6); assert_ulps_eq!(post[&b], Gaussian::new(27.095590, 6.010330), epsilon = 1e-6);
assert_ulps_eq!(post[&c], Gaussian::new(24.889681, 5.866311), epsilon = 1e-6); assert_ulps_eq!(post[&c], Gaussian::new(24.889681, 5.866311), epsilon = 1e-6);
assert!(batch.convergence(&mut agents) > 1); assert!(batch.convergence(&agents) > 1);
let post = batch.posteriors(); let post = batch.posteriors();
@@ -594,10 +604,10 @@ mod tests {
vec![], vec![],
0, 0,
0.0, 0.0,
&mut agents, &agents,
); );
batch.convergence(&mut agents); batch.convergence(&agents);
let post = batch.posteriors(); let post = batch.posteriors();
@@ -618,7 +628,7 @@ mod tests {
assert_eq!(batch.events.len(), 6); assert_eq!(batch.events.len(), 6);
batch.convergence(&mut agents); batch.convergence(&agents);
let post = batch.posteriors(); let post = batch.posteriors();

View File

@@ -8,9 +8,11 @@ pub(crate) struct TeamMessage {
} }
impl TeamMessage { impl TeamMessage {
/*
pub(crate) fn p(&self) -> Gaussian { pub(crate) fn p(&self) -> Gaussian {
self.prior * self.likelihood_lose * self.likelihood_win * self.likelihood_draw self.prior * self.likelihood_lose * self.likelihood_win * self.likelihood_draw
} }
*/
pub(crate) fn posterior_win(&self) -> Gaussian { pub(crate) fn posterior_win(&self) -> Gaussian {
self.prior * self.likelihood_lose * self.likelihood_draw self.prior * self.likelihood_lose * self.likelihood_draw
@@ -25,6 +27,7 @@ impl TeamMessage {
} }
} }
/*
pub(crate) struct DrawMessage { pub(crate) struct DrawMessage {
pub(crate) prior: Gaussian, pub(crate) prior: Gaussian,
pub(crate) prior_team: Gaussian, pub(crate) prior_team: Gaussian,
@@ -49,14 +52,16 @@ impl DrawMessage {
self.likelihood_win * self.likelihood_lose self.likelihood_win * self.likelihood_lose
} }
} }
*/
pub(crate) struct DiffMessage { pub(crate) struct DiffMessage {
pub(crate) prior: Gaussian, pub(crate) prior: Gaussian,
pub(crate) likelihood: Gaussian, pub(crate) likelihood: Gaussian,
} }
impl DiffMessage { impl DiffMessage {
/*
pub(crate) fn p(&self) -> Gaussian { pub(crate) fn p(&self) -> Gaussian {
self.prior * self.likelihood self.prior * self.likelihood
} }
*/
} }

View File

@@ -1,11 +1,11 @@
use crate::{gaussian::Gaussian, BETA, GAMMA, N_INF}; use crate::{gaussian::Gaussian, BETA, GAMMA};
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct Player { pub struct Player {
pub(crate) prior: Gaussian, pub(crate) prior: Gaussian,
pub(crate) beta: f64, pub(crate) beta: f64,
pub(crate) gamma: f64, pub(crate) gamma: f64,
pub(crate) draw: Gaussian, // pub(crate) draw: Gaussian,
} }
impl Player { impl Player {
@@ -14,7 +14,7 @@ impl Player {
prior, prior,
beta, beta,
gamma, gamma,
draw: N_INF, // draw: N_INF,
} }
} }
@@ -29,7 +29,7 @@ impl Default for Player {
prior: Gaussian::default(), prior: Gaussian::default(),
beta: BETA, beta: BETA,
gamma: GAMMA, gamma: GAMMA,
draw: N_INF, // draw: N_INF,
} }
} }
} }