feat: added a Drift trait and a "default" ConstantDrift implementation

This commit is contained in:
2026-03-16 12:06:04 +01:00
parent 853f177fa8
commit a1f282a1c8
14 changed files with 423 additions and 127 deletions

View File

@@ -1,7 +1,8 @@
use std::collections::HashMap;
use crate::{
agent::Agent, game::Game, gaussian::Gaussian, player::Player, tuple_gt, tuple_max, Index, N_INF,
Index, N_INF, agent::Agent, drift::Drift, game::Game, gaussian::Gaussian, player::Player,
tuple_gt, tuple_max,
};
#[derive(Debug)]
@@ -38,22 +39,22 @@ struct Item {
}
impl Item {
fn within_prior(
fn within_prior<D: Drift>(
&self,
online: bool,
forward: bool,
skills: &HashMap<Index, Skill>,
agents: &HashMap<Index, Agent>,
) -> Player {
agents: &HashMap<Index, Agent<D>>,
) -> Player<D> {
let r = &agents[&self.agent].player;
let skill = &skills[&self.agent];
if online {
Player::new(skill.online, r.beta, r.gamma)
Player::new(skill.online, r.beta, r.drift)
} else if forward {
Player::new(skill.forward, r.beta, r.gamma)
Player::new(skill.forward, r.beta, r.drift)
} else {
Player::new(skill.posterior() / self.likelihood, r.beta, r.gamma)
Player::new(skill.posterior() / self.likelihood, r.beta, r.drift)
}
}
}
@@ -79,13 +80,13 @@ impl Event {
.collect::<Vec<_>>()
}
pub(crate) fn within_priors(
pub(crate) fn within_priors<D: Drift>(
&self,
online: bool,
forward: bool,
skills: &HashMap<Index, Skill>,
agents: &HashMap<Index, Agent>,
) -> Vec<Vec<Player>> {
agents: &HashMap<Index, Agent<D>>,
) -> Vec<Vec<Player<D>>> {
self.teams
.iter()
.map(|team| {
@@ -116,12 +117,12 @@ impl Batch {
}
}
pub fn add_events(
pub fn add_events<D: Drift>(
&mut self,
composition: Vec<Vec<Vec<Index>>>,
results: Vec<Vec<f64>>,
weights: Vec<Vec<Vec<f64>>>,
agents: &HashMap<Index, Agent>,
agents: &HashMap<Index, Agent<D>>,
) {
let mut unique = Vec::with_capacity(10);
@@ -207,7 +208,7 @@ impl Batch {
.collect::<HashMap<_, _>>()
}
pub fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) {
pub fn iteration<D: Drift>(&mut self, from: usize, agents: &HashMap<Index, Agent<D>>) {
for event in self.events.iter_mut().skip(from) {
let teams = event.within_priors(false, false, &self.skills, agents);
let result = event.outputs();
@@ -229,7 +230,7 @@ impl Batch {
}
#[allow(dead_code)]
pub(crate) fn convergence(&mut self, agents: &HashMap<Index, Agent>) -> usize {
pub(crate) fn convergence<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) -> usize {
let epsilon = 1e-6;
let iterations = 20;
@@ -259,18 +260,18 @@ impl Batch {
skill.forward * skill.likelihood
}
pub(crate) fn backward_prior_out(
pub(crate) fn backward_prior_out<D: Drift>(
&self,
agent: &Index,
agents: &HashMap<Index, Agent>,
agents: &HashMap<Index, Agent<D>>,
) -> Gaussian {
let skill = &self.skills[agent];
let n = skill.likelihood * skill.backward;
n.forget(agents[agent].player.gamma, skill.elapsed)
n.forget(agents[agent].player.drift.variance_delta(skill.elapsed))
}
pub(crate) fn new_backward_info(&mut self, agents: &HashMap<Index, Agent>) {
pub(crate) fn new_backward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
for (agent, skill) in self.skills.iter_mut() {
skill.backward = agents[agent].message;
}
@@ -278,7 +279,7 @@ impl Batch {
self.iteration(0, agents);
}
pub(crate) fn new_forward_info(&mut self, agents: &HashMap<Index, Agent>) {
pub(crate) fn new_forward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
for (agent, skill) in self.skills.iter_mut() {
skill.forward = agents[agent].receive(skill.elapsed);
}
@@ -286,12 +287,12 @@ impl Batch {
self.iteration(0, agents);
}
pub(crate) fn log_evidence(
pub(crate) fn log_evidence<D: Drift>(
&self,
online: bool,
targets: &[Index],
forward: bool,
agents: &HashMap<Index, Agent>,
agents: &HashMap<Index, Agent<D>>,
) -> f64 {
if targets.is_empty() {
if online || forward {
@@ -390,7 +391,7 @@ pub(crate) fn compute_elapsed(last_time: i64, actual_time: i64) -> i64 {
mod tests {
use approx::assert_ulps_eq;
use crate::{agent::Agent, player::Player, IndexMap};
use crate::{IndexMap, agent::Agent, drift::ConstantDrift, player::Player};
use super::*;
@@ -414,7 +415,7 @@ mod tests {
player: Player::new(
Gaussian::from_ms(25.0, 25.0 / 3.0),
25.0 / 6.0,
25.0 / 300.0,
ConstantDrift(25.0 / 300.0),
),
..Default::default()
},
@@ -490,7 +491,7 @@ mod tests {
player: Player::new(
Gaussian::from_ms(25.0, 25.0 / 3.0),
25.0 / 6.0,
25.0 / 300.0,
ConstantDrift(25.0 / 300.0),
),
..Default::default()
},
@@ -569,7 +570,7 @@ mod tests {
player: Player::new(
Gaussian::from_ms(25.0, 25.0 / 3.0),
25.0 / 6.0,
25.0 / 300.0,
ConstantDrift(25.0 / 300.0),
),
..Default::default()
},