feat: added a Drift trait and a "default" ConstantDrift implementation

This commit is contained in:
2026-03-16 12:06:04 +01:00
parent 853f177fa8
commit a1f282a1c8
14 changed files with 423 additions and 127 deletions

View File

@@ -1,25 +1,27 @@
use std::collections::HashMap;
use crate::{
BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA,
agent::{self, Agent},
batch::{self, Batch},
drift::{ConstantDrift, Drift},
gaussian::Gaussian,
player::Player,
sort_time, tuple_gt, tuple_max, Index, BETA, GAMMA, MU, P_DRAW, SIGMA,
sort_time, tuple_gt, tuple_max,
};
#[derive(Clone)]
pub struct HistoryBuilder {
pub struct HistoryBuilder<D: Drift = ConstantDrift> {
time: bool,
mu: f64,
sigma: f64,
beta: f64,
gamma: f64,
drift: D,
p_draw: f64,
online: bool,
}
impl HistoryBuilder {
impl<D: Drift> HistoryBuilder<D> {
pub fn time(mut self, time: bool) -> Self {
self.time = time;
self
@@ -40,9 +42,16 @@ impl HistoryBuilder {
self
}
pub fn gamma(mut self, gamma: f64) -> Self {
self.gamma = gamma;
self
pub fn drift<D2: Drift>(self, drift: D2) -> HistoryBuilder<D2> {
HistoryBuilder {
drift,
time: self.time,
mu: self.mu,
sigma: self.sigma,
beta: self.beta,
p_draw: self.p_draw,
online: self.online,
}
}
pub fn p_draw(mut self, p_draw: f64) -> Self {
@@ -55,7 +64,7 @@ impl HistoryBuilder {
self
}
pub fn build(self) -> History {
pub fn build(self) -> History<D> {
History {
size: 0,
batches: Vec::new(),
@@ -64,41 +73,48 @@ impl HistoryBuilder {
mu: self.mu,
sigma: self.sigma,
beta: self.beta,
gamma: self.gamma,
drift: self.drift,
p_draw: self.p_draw,
online: self.online,
}
}
}
impl Default for HistoryBuilder {
impl HistoryBuilder<ConstantDrift> {
pub fn gamma(mut self, gamma: f64) -> Self {
self.drift = ConstantDrift(gamma);
self
}
}
impl Default for HistoryBuilder<ConstantDrift> {
fn default() -> Self {
Self {
time: true,
mu: MU,
sigma: SIGMA,
beta: BETA,
gamma: GAMMA,
drift: ConstantDrift(GAMMA),
p_draw: P_DRAW,
online: false,
}
}
}
pub struct History {
pub struct History<D: Drift = ConstantDrift> {
size: usize,
pub(crate) batches: Vec<Batch>,
agents: HashMap<Index, Agent>,
agents: HashMap<Index, Agent<D>>,
time: bool,
mu: f64,
sigma: f64,
beta: f64,
gamma: f64,
drift: D,
p_draw: f64,
online: bool,
}
impl Default for History {
impl Default for History<ConstantDrift> {
fn default() -> Self {
Self {
size: 0,
@@ -108,18 +124,20 @@ impl Default for History {
mu: MU,
sigma: SIGMA,
beta: BETA,
gamma: GAMMA,
drift: ConstantDrift(GAMMA),
p_draw: P_DRAW,
online: false,
}
}
}
impl History {
pub fn builder() -> HistoryBuilder {
impl History<ConstantDrift> {
pub fn builder() -> HistoryBuilder<ConstantDrift> {
HistoryBuilder::default()
}
}
impl<D: Drift> History<D> {
fn iteration(&mut self) -> (f64, f64) {
let mut step = (0.0, 0.0);
@@ -247,7 +265,7 @@ impl History {
results: Vec<Vec<f64>>,
times: Vec<i64>,
weights: Vec<Vec<Vec<f64>>>,
mut priors: HashMap<Index, Player>,
mut priors: HashMap<Index, Player<D>>,
) {
assert!(times.is_empty() || self.time, "length(times)>0 but !h.time");
assert!(
@@ -286,10 +304,11 @@ impl History {
Player::new(
Gaussian::from_ms(self.mu, self.sigma),
self.beta,
self.gamma,
self.drift,
)
}),
..Default::default()
message: N_INF,
last_time: i64::MIN,
},
);
}
@@ -414,7 +433,7 @@ impl History {
mod tests {
use approx::assert_ulps_eq;
use crate::{Game, Gaussian, IndexMap, Player, EPSILON, ITERATIONS, P_DRAW};
use crate::{ConstantDrift, EPSILON, Game, Gaussian, ITERATIONS, IndexMap, P_DRAW, Player};
use super::*;
@@ -441,7 +460,7 @@ mod tests {
Player::new(
Gaussian::from_ms(25.0, 25.0 / 3.0),
25.0 / 6.0,
0.15 * 25.0 / 3.0,
ConstantDrift(0.15 * 25.0 / 3.0),
),
);
}
@@ -503,7 +522,7 @@ mod tests {
Player::new(
Gaussian::from_ms(25.0, 25.0 / 3.0),
25.0 / 6.0,
0.15 * 25.0 / 3.0,
ConstantDrift(0.15 * 25.0 / 3.0),
),
);
}
@@ -552,7 +571,7 @@ mod tests {
Player::new(
Gaussian::from_ms(25.0, 25.0 / 3.0),
25.0 / 6.0,
25.0 / 300.0,
ConstantDrift(25.0 / 300.0),
),
);
}
@@ -610,7 +629,7 @@ mod tests {
Player::new(
Gaussian::from_ms(25.0, 25.0 / 3.0),
25.0 / 6.0,
25.0 / 300.0,
ConstantDrift(25.0 / 300.0),
),
);
}