Refactor so we can see if there is any way to improve the performance
This commit is contained in:
84
src/batch.rs
84
src/batch.rs
@@ -13,6 +13,12 @@ pub(crate) struct Skill {
|
|||||||
pub(crate) online: Gaussian,
|
pub(crate) online: Gaussian,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Skill {
|
||||||
|
fn posterior(&self) -> Gaussian {
|
||||||
|
self.likelihood * self.backward * self.forward
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for Skill {
|
impl Default for Skill {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -31,6 +37,29 @@ struct Item {
|
|||||||
likelihood: Gaussian,
|
likelihood: Gaussian,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Item {
|
||||||
|
fn within_prior(
|
||||||
|
&self,
|
||||||
|
online: bool,
|
||||||
|
forward: bool,
|
||||||
|
skills: &HashMap<Index, Skill>,
|
||||||
|
agents: &HashMap<Index, Agent>,
|
||||||
|
) -> Player {
|
||||||
|
let r = &agents[&self.agent].player;
|
||||||
|
let skill = &skills[&self.agent];
|
||||||
|
|
||||||
|
if online {
|
||||||
|
Player::new(skill.online, r.beta, r.gamma)
|
||||||
|
} else if forward {
|
||||||
|
Player::new(skill.forward, r.beta, r.gamma)
|
||||||
|
} else {
|
||||||
|
let wp = skill.posterior() / self.likelihood;
|
||||||
|
|
||||||
|
Player::new(wp, r.beta, r.gamma)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Team {
|
struct Team {
|
||||||
items: Vec<Item>,
|
items: Vec<Item>,
|
||||||
@@ -68,7 +97,7 @@ impl Batch {
|
|||||||
weights: Vec<Vec<Vec<f64>>>,
|
weights: Vec<Vec<Vec<f64>>>,
|
||||||
time: i64,
|
time: i64,
|
||||||
p_draw: f64,
|
p_draw: f64,
|
||||||
agents: &mut HashMap<Index, Agent>,
|
agents: &HashMap<Index, Agent>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
assert!(
|
assert!(
|
||||||
results.is_empty() || results.len() == composition.len(),
|
results.is_empty() || results.len() == composition.len(),
|
||||||
@@ -235,26 +264,6 @@ impl Batch {
|
|||||||
.collect::<HashMap<_, _>>()
|
.collect::<HashMap<_, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn within_prior(
|
|
||||||
&self,
|
|
||||||
item: &Item,
|
|
||||||
online: bool,
|
|
||||||
forward: bool,
|
|
||||||
agents: &HashMap<Index, Agent>,
|
|
||||||
) -> Player {
|
|
||||||
let r = &agents[&item.agent].player;
|
|
||||||
|
|
||||||
if online {
|
|
||||||
Player::new(self.skills[&item.agent].online, r.beta, r.gamma)
|
|
||||||
} else if forward {
|
|
||||||
Player::new(self.skills[&item.agent].forward, r.beta, r.gamma)
|
|
||||||
} else {
|
|
||||||
let wp = self.posterior(item.agent) / item.likelihood;
|
|
||||||
|
|
||||||
Player::new(wp, r.beta, r.gamma)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn within_priors(
|
pub(crate) fn within_priors(
|
||||||
&self,
|
&self,
|
||||||
event: usize,
|
event: usize,
|
||||||
@@ -268,13 +277,13 @@ impl Batch {
|
|||||||
.map(|team| {
|
.map(|team| {
|
||||||
team.items
|
team.items
|
||||||
.iter()
|
.iter()
|
||||||
.map(|item| self.within_prior(item, online, forward, agents))
|
.map(|item| item.within_prior(online, forward, &self.skills, agents))
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn iteration(&mut self, from: usize, agents: &mut HashMap<Index, Agent>) {
|
pub(crate) fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) {
|
||||||
for e in from..self.events.len() {
|
for e in from..self.events.len() {
|
||||||
let teams = self.within_priors(e, false, false, agents);
|
let teams = self.within_priors(e, false, false, agents);
|
||||||
let result = self.events[e].outputs();
|
let result = self.events[e].outputs();
|
||||||
@@ -295,7 +304,8 @@ impl Batch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn convergence(&mut self, agents: &mut HashMap<Index, Agent>) -> usize {
|
#[allow(dead_code)]
|
||||||
|
pub(crate) fn convergence(&mut self, agents: &HashMap<Index, Agent>) -> usize {
|
||||||
let epsilon = 1e-6;
|
let epsilon = 1e-6;
|
||||||
let iterations = 20;
|
let iterations = 20;
|
||||||
|
|
||||||
@@ -328,7 +338,7 @@ impl Batch {
|
|||||||
pub(crate) fn backward_prior_out(
|
pub(crate) fn backward_prior_out(
|
||||||
&self,
|
&self,
|
||||||
agent: &Index,
|
agent: &Index,
|
||||||
agents: &mut HashMap<Index, Agent>,
|
agents: &HashMap<Index, Agent>,
|
||||||
) -> Gaussian {
|
) -> Gaussian {
|
||||||
let skill = &self.skills[agent];
|
let skill = &self.skills[agent];
|
||||||
let n = skill.likelihood * skill.backward;
|
let n = skill.likelihood * skill.backward;
|
||||||
@@ -336,7 +346,7 @@ impl Batch {
|
|||||||
n.forget(agents[agent].player.gamma, skill.elapsed)
|
n.forget(agents[agent].player.gamma, skill.elapsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn new_backward_info(&mut self, agents: &mut HashMap<Index, Agent>) {
|
pub(crate) fn new_backward_info(&mut self, agents: &HashMap<Index, Agent>) {
|
||||||
for (agent, skill) in self.skills.iter_mut() {
|
for (agent, skill) in self.skills.iter_mut() {
|
||||||
skill.backward = agents[agent].message;
|
skill.backward = agents[agent].message;
|
||||||
}
|
}
|
||||||
@@ -344,7 +354,7 @@ impl Batch {
|
|||||||
self.iteration(0, agents);
|
self.iteration(0, agents);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn new_forward_info(&mut self, agents: &mut HashMap<Index, Agent>) {
|
pub(crate) fn new_forward_info(&mut self, agents: &HashMap<Index, Agent>) {
|
||||||
for (agent, skill) in self.skills.iter_mut() {
|
for (agent, skill) in self.skills.iter_mut() {
|
||||||
skill.forward = agents[agent].receive(skill.elapsed);
|
skill.forward = agents[agent].receive(skill.elapsed);
|
||||||
}
|
}
|
||||||
@@ -357,7 +367,7 @@ impl Batch {
|
|||||||
online: bool,
|
online: bool,
|
||||||
targets: &[Index],
|
targets: &[Index],
|
||||||
forward: bool,
|
forward: bool,
|
||||||
agents: &mut HashMap<Index, Agent>,
|
agents: &HashMap<Index, Agent>,
|
||||||
) -> f64 {
|
) -> f64 {
|
||||||
if targets.is_empty() {
|
if targets.is_empty() {
|
||||||
if online || forward {
|
if online || forward {
|
||||||
@@ -417,7 +427,7 @@ impl Batch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn get_composition(&self) -> Vec<Vec<Vec<Index>>> {
|
pub fn get_composition(&self) -> Vec<Vec<Vec<Index>>> {
|
||||||
self.events
|
self.events
|
||||||
.iter()
|
.iter()
|
||||||
.map(|event| {
|
.map(|event| {
|
||||||
@@ -430,7 +440,7 @@ impl Batch {
|
|||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn get_results(&self) -> Vec<Vec<f64>> {
|
pub fn get_results(&self) -> Vec<Vec<f64>> {
|
||||||
self.events
|
self.events
|
||||||
.iter()
|
.iter()
|
||||||
.map(|event| {
|
.map(|event| {
|
||||||
@@ -495,7 +505,7 @@ mod tests {
|
|||||||
vec![],
|
vec![],
|
||||||
0,
|
0,
|
||||||
0.0,
|
0.0,
|
||||||
&mut agents,
|
&agents,
|
||||||
);
|
);
|
||||||
|
|
||||||
let post = batch.posteriors();
|
let post = batch.posteriors();
|
||||||
@@ -507,7 +517,7 @@ mod tests {
|
|||||||
assert_ulps_eq!(post[&e], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6);
|
assert_ulps_eq!(post[&e], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6);
|
||||||
assert_ulps_eq!(post[&f], Gaussian::new(20.794779, 7.194481), epsilon = 1e-6);
|
assert_ulps_eq!(post[&f], Gaussian::new(20.794779, 7.194481), epsilon = 1e-6);
|
||||||
|
|
||||||
assert_eq!(batch.convergence(&mut agents), 1);
|
assert_eq!(batch.convergence(&agents), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -543,7 +553,7 @@ mod tests {
|
|||||||
vec![],
|
vec![],
|
||||||
0,
|
0,
|
||||||
0.0,
|
0.0,
|
||||||
&mut agents,
|
&agents,
|
||||||
);
|
);
|
||||||
|
|
||||||
let post = batch.posteriors();
|
let post = batch.posteriors();
|
||||||
@@ -552,7 +562,7 @@ mod tests {
|
|||||||
assert_ulps_eq!(post[&b], Gaussian::new(27.095590, 6.010330), epsilon = 1e-6);
|
assert_ulps_eq!(post[&b], Gaussian::new(27.095590, 6.010330), epsilon = 1e-6);
|
||||||
assert_ulps_eq!(post[&c], Gaussian::new(24.889681, 5.866311), epsilon = 1e-6);
|
assert_ulps_eq!(post[&c], Gaussian::new(24.889681, 5.866311), epsilon = 1e-6);
|
||||||
|
|
||||||
assert!(batch.convergence(&mut agents) > 1);
|
assert!(batch.convergence(&agents) > 1);
|
||||||
|
|
||||||
let post = batch.posteriors();
|
let post = batch.posteriors();
|
||||||
|
|
||||||
@@ -594,10 +604,10 @@ mod tests {
|
|||||||
vec![],
|
vec![],
|
||||||
0,
|
0,
|
||||||
0.0,
|
0.0,
|
||||||
&mut agents,
|
&agents,
|
||||||
);
|
);
|
||||||
|
|
||||||
batch.convergence(&mut agents);
|
batch.convergence(&agents);
|
||||||
|
|
||||||
let post = batch.posteriors();
|
let post = batch.posteriors();
|
||||||
|
|
||||||
@@ -618,7 +628,7 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(batch.events.len(), 6);
|
assert_eq!(batch.events.len(), 6);
|
||||||
|
|
||||||
batch.convergence(&mut agents);
|
batch.convergence(&agents);
|
||||||
|
|
||||||
let post = batch.posteriors();
|
let post = batch.posteriors();
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,11 @@ pub(crate) struct TeamMessage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TeamMessage {
|
impl TeamMessage {
|
||||||
|
/*
|
||||||
pub(crate) fn p(&self) -> Gaussian {
|
pub(crate) fn p(&self) -> Gaussian {
|
||||||
self.prior * self.likelihood_lose * self.likelihood_win * self.likelihood_draw
|
self.prior * self.likelihood_lose * self.likelihood_win * self.likelihood_draw
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
pub(crate) fn posterior_win(&self) -> Gaussian {
|
pub(crate) fn posterior_win(&self) -> Gaussian {
|
||||||
self.prior * self.likelihood_lose * self.likelihood_draw
|
self.prior * self.likelihood_lose * self.likelihood_draw
|
||||||
@@ -25,6 +27,7 @@ impl TeamMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
pub(crate) struct DrawMessage {
|
pub(crate) struct DrawMessage {
|
||||||
pub(crate) prior: Gaussian,
|
pub(crate) prior: Gaussian,
|
||||||
pub(crate) prior_team: Gaussian,
|
pub(crate) prior_team: Gaussian,
|
||||||
@@ -49,14 +52,16 @@ impl DrawMessage {
|
|||||||
self.likelihood_win * self.likelihood_lose
|
self.likelihood_win * self.likelihood_lose
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
pub(crate) struct DiffMessage {
|
pub(crate) struct DiffMessage {
|
||||||
pub(crate) prior: Gaussian,
|
pub(crate) prior: Gaussian,
|
||||||
pub(crate) likelihood: Gaussian,
|
pub(crate) likelihood: Gaussian,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DiffMessage {
|
impl DiffMessage {
|
||||||
|
/*
|
||||||
pub(crate) fn p(&self) -> Gaussian {
|
pub(crate) fn p(&self) -> Gaussian {
|
||||||
self.prior * self.likelihood
|
self.prior * self.likelihood
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
use crate::{gaussian::Gaussian, BETA, GAMMA, N_INF};
|
use crate::{gaussian::Gaussian, BETA, GAMMA};
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug)]
|
#[derive(Clone, Copy, Debug)]
|
||||||
pub struct Player {
|
pub struct Player {
|
||||||
pub(crate) prior: Gaussian,
|
pub(crate) prior: Gaussian,
|
||||||
pub(crate) beta: f64,
|
pub(crate) beta: f64,
|
||||||
pub(crate) gamma: f64,
|
pub(crate) gamma: f64,
|
||||||
pub(crate) draw: Gaussian,
|
// pub(crate) draw: Gaussian,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Player {
|
impl Player {
|
||||||
@@ -14,7 +14,7 @@ impl Player {
|
|||||||
prior,
|
prior,
|
||||||
beta,
|
beta,
|
||||||
gamma,
|
gamma,
|
||||||
draw: N_INF,
|
// draw: N_INF,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ impl Default for Player {
|
|||||||
prior: Gaussian::default(),
|
prior: Gaussian::default(),
|
||||||
beta: BETA,
|
beta: BETA,
|
||||||
gamma: GAMMA,
|
gamma: GAMMA,
|
||||||
draw: N_INF,
|
// draw: N_INF,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user