feat: added a Drift trait and a "default" ConstantDrift implementation
This commit is contained in:
51
src/batch.rs
51
src/batch.rs
@@ -1,7 +1,8 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
agent::Agent, game::Game, gaussian::Gaussian, player::Player, tuple_gt, tuple_max, Index, N_INF,
|
||||
Index, N_INF, agent::Agent, drift::Drift, game::Game, gaussian::Gaussian, player::Player,
|
||||
tuple_gt, tuple_max,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -38,22 +39,22 @@ struct Item {
|
||||
}
|
||||
|
||||
impl Item {
|
||||
fn within_prior(
|
||||
fn within_prior<D: Drift>(
|
||||
&self,
|
||||
online: bool,
|
||||
forward: bool,
|
||||
skills: &HashMap<Index, Skill>,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
) -> Player {
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
) -> Player<D> {
|
||||
let r = &agents[&self.agent].player;
|
||||
let skill = &skills[&self.agent];
|
||||
|
||||
if online {
|
||||
Player::new(skill.online, r.beta, r.gamma)
|
||||
Player::new(skill.online, r.beta, r.drift)
|
||||
} else if forward {
|
||||
Player::new(skill.forward, r.beta, r.gamma)
|
||||
Player::new(skill.forward, r.beta, r.drift)
|
||||
} else {
|
||||
Player::new(skill.posterior() / self.likelihood, r.beta, r.gamma)
|
||||
Player::new(skill.posterior() / self.likelihood, r.beta, r.drift)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -79,13 +80,13 @@ impl Event {
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub(crate) fn within_priors(
|
||||
pub(crate) fn within_priors<D: Drift>(
|
||||
&self,
|
||||
online: bool,
|
||||
forward: bool,
|
||||
skills: &HashMap<Index, Skill>,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
) -> Vec<Vec<Player>> {
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
) -> Vec<Vec<Player<D>>> {
|
||||
self.teams
|
||||
.iter()
|
||||
.map(|team| {
|
||||
@@ -116,12 +117,12 @@ impl Batch {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_events(
|
||||
pub fn add_events<D: Drift>(
|
||||
&mut self,
|
||||
composition: Vec<Vec<Vec<Index>>>,
|
||||
results: Vec<Vec<f64>>,
|
||||
weights: Vec<Vec<Vec<f64>>>,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
) {
|
||||
let mut unique = Vec::with_capacity(10);
|
||||
|
||||
@@ -207,7 +208,7 @@ impl Batch {
|
||||
.collect::<HashMap<_, _>>()
|
||||
}
|
||||
|
||||
pub fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) {
|
||||
pub fn iteration<D: Drift>(&mut self, from: usize, agents: &HashMap<Index, Agent<D>>) {
|
||||
for event in self.events.iter_mut().skip(from) {
|
||||
let teams = event.within_priors(false, false, &self.skills, agents);
|
||||
let result = event.outputs();
|
||||
@@ -229,7 +230,7 @@ impl Batch {
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn convergence(&mut self, agents: &HashMap<Index, Agent>) -> usize {
|
||||
pub(crate) fn convergence<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) -> usize {
|
||||
let epsilon = 1e-6;
|
||||
let iterations = 20;
|
||||
|
||||
@@ -259,18 +260,18 @@ impl Batch {
|
||||
skill.forward * skill.likelihood
|
||||
}
|
||||
|
||||
pub(crate) fn backward_prior_out(
|
||||
pub(crate) fn backward_prior_out<D: Drift>(
|
||||
&self,
|
||||
agent: &Index,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
) -> Gaussian {
|
||||
let skill = &self.skills[agent];
|
||||
let n = skill.likelihood * skill.backward;
|
||||
|
||||
n.forget(agents[agent].player.gamma, skill.elapsed)
|
||||
n.forget(agents[agent].player.drift.variance_delta(skill.elapsed))
|
||||
}
|
||||
|
||||
pub(crate) fn new_backward_info(&mut self, agents: &HashMap<Index, Agent>) {
|
||||
pub(crate) fn new_backward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
|
||||
for (agent, skill) in self.skills.iter_mut() {
|
||||
skill.backward = agents[agent].message;
|
||||
}
|
||||
@@ -278,7 +279,7 @@ impl Batch {
|
||||
self.iteration(0, agents);
|
||||
}
|
||||
|
||||
pub(crate) fn new_forward_info(&mut self, agents: &HashMap<Index, Agent>) {
|
||||
pub(crate) fn new_forward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
|
||||
for (agent, skill) in self.skills.iter_mut() {
|
||||
skill.forward = agents[agent].receive(skill.elapsed);
|
||||
}
|
||||
@@ -286,12 +287,12 @@ impl Batch {
|
||||
self.iteration(0, agents);
|
||||
}
|
||||
|
||||
pub(crate) fn log_evidence(
|
||||
pub(crate) fn log_evidence<D: Drift>(
|
||||
&self,
|
||||
online: bool,
|
||||
targets: &[Index],
|
||||
forward: bool,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
) -> f64 {
|
||||
if targets.is_empty() {
|
||||
if online || forward {
|
||||
@@ -390,7 +391,7 @@ pub(crate) fn compute_elapsed(last_time: i64, actual_time: i64) -> i64 {
|
||||
mod tests {
|
||||
use approx::assert_ulps_eq;
|
||||
|
||||
use crate::{agent::Agent, player::Player, IndexMap};
|
||||
use crate::{IndexMap, agent::Agent, drift::ConstantDrift, player::Player};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -414,7 +415,7 @@ mod tests {
|
||||
player: Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
..Default::default()
|
||||
},
|
||||
@@ -490,7 +491,7 @@ mod tests {
|
||||
player: Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
..Default::default()
|
||||
},
|
||||
@@ -569,7 +570,7 @@ mod tests {
|
||||
player: Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
..Default::default()
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user