feat: added a Drift trait and a "default" ConstantDrift implementation
This commit is contained in:
@@ -1,25 +1,27 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA,
|
||||
agent::{self, Agent},
|
||||
batch::{self, Batch},
|
||||
drift::{ConstantDrift, Drift},
|
||||
gaussian::Gaussian,
|
||||
player::Player,
|
||||
sort_time, tuple_gt, tuple_max, Index, BETA, GAMMA, MU, P_DRAW, SIGMA,
|
||||
sort_time, tuple_gt, tuple_max,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct HistoryBuilder {
|
||||
pub struct HistoryBuilder<D: Drift = ConstantDrift> {
|
||||
time: bool,
|
||||
mu: f64,
|
||||
sigma: f64,
|
||||
beta: f64,
|
||||
gamma: f64,
|
||||
drift: D,
|
||||
p_draw: f64,
|
||||
online: bool,
|
||||
}
|
||||
|
||||
impl HistoryBuilder {
|
||||
impl<D: Drift> HistoryBuilder<D> {
|
||||
pub fn time(mut self, time: bool) -> Self {
|
||||
self.time = time;
|
||||
self
|
||||
@@ -40,9 +42,16 @@ impl HistoryBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn gamma(mut self, gamma: f64) -> Self {
|
||||
self.gamma = gamma;
|
||||
self
|
||||
pub fn drift<D2: Drift>(self, drift: D2) -> HistoryBuilder<D2> {
|
||||
HistoryBuilder {
|
||||
drift,
|
||||
time: self.time,
|
||||
mu: self.mu,
|
||||
sigma: self.sigma,
|
||||
beta: self.beta,
|
||||
p_draw: self.p_draw,
|
||||
online: self.online,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn p_draw(mut self, p_draw: f64) -> Self {
|
||||
@@ -55,7 +64,7 @@ impl HistoryBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> History {
|
||||
pub fn build(self) -> History<D> {
|
||||
History {
|
||||
size: 0,
|
||||
batches: Vec::new(),
|
||||
@@ -64,41 +73,48 @@ impl HistoryBuilder {
|
||||
mu: self.mu,
|
||||
sigma: self.sigma,
|
||||
beta: self.beta,
|
||||
gamma: self.gamma,
|
||||
drift: self.drift,
|
||||
p_draw: self.p_draw,
|
||||
online: self.online,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HistoryBuilder {
|
||||
impl HistoryBuilder<ConstantDrift> {
|
||||
pub fn gamma(mut self, gamma: f64) -> Self {
|
||||
self.drift = ConstantDrift(gamma);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HistoryBuilder<ConstantDrift> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
time: true,
|
||||
mu: MU,
|
||||
sigma: SIGMA,
|
||||
beta: BETA,
|
||||
gamma: GAMMA,
|
||||
drift: ConstantDrift(GAMMA),
|
||||
p_draw: P_DRAW,
|
||||
online: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct History {
|
||||
pub struct History<D: Drift = ConstantDrift> {
|
||||
size: usize,
|
||||
pub(crate) batches: Vec<Batch>,
|
||||
agents: HashMap<Index, Agent>,
|
||||
agents: HashMap<Index, Agent<D>>,
|
||||
time: bool,
|
||||
mu: f64,
|
||||
sigma: f64,
|
||||
beta: f64,
|
||||
gamma: f64,
|
||||
drift: D,
|
||||
p_draw: f64,
|
||||
online: bool,
|
||||
}
|
||||
|
||||
impl Default for History {
|
||||
impl Default for History<ConstantDrift> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
size: 0,
|
||||
@@ -108,18 +124,20 @@ impl Default for History {
|
||||
mu: MU,
|
||||
sigma: SIGMA,
|
||||
beta: BETA,
|
||||
gamma: GAMMA,
|
||||
drift: ConstantDrift(GAMMA),
|
||||
p_draw: P_DRAW,
|
||||
online: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl History {
|
||||
pub fn builder() -> HistoryBuilder {
|
||||
impl History<ConstantDrift> {
|
||||
pub fn builder() -> HistoryBuilder<ConstantDrift> {
|
||||
HistoryBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Drift> History<D> {
|
||||
fn iteration(&mut self) -> (f64, f64) {
|
||||
let mut step = (0.0, 0.0);
|
||||
|
||||
@@ -247,7 +265,7 @@ impl History {
|
||||
results: Vec<Vec<f64>>,
|
||||
times: Vec<i64>,
|
||||
weights: Vec<Vec<Vec<f64>>>,
|
||||
mut priors: HashMap<Index, Player>,
|
||||
mut priors: HashMap<Index, Player<D>>,
|
||||
) {
|
||||
assert!(times.is_empty() || self.time, "length(times)>0 but !h.time");
|
||||
assert!(
|
||||
@@ -286,10 +304,11 @@ impl History {
|
||||
Player::new(
|
||||
Gaussian::from_ms(self.mu, self.sigma),
|
||||
self.beta,
|
||||
self.gamma,
|
||||
self.drift,
|
||||
)
|
||||
}),
|
||||
..Default::default()
|
||||
message: N_INF,
|
||||
last_time: i64::MIN,
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -414,7 +433,7 @@ impl History {
|
||||
mod tests {
|
||||
use approx::assert_ulps_eq;
|
||||
|
||||
use crate::{Game, Gaussian, IndexMap, Player, EPSILON, ITERATIONS, P_DRAW};
|
||||
use crate::{ConstantDrift, EPSILON, Game, Gaussian, ITERATIONS, IndexMap, P_DRAW, Player};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -441,7 +460,7 @@ mod tests {
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
0.15 * 25.0 / 3.0,
|
||||
ConstantDrift(0.15 * 25.0 / 3.0),
|
||||
),
|
||||
);
|
||||
}
|
||||
@@ -503,7 +522,7 @@ mod tests {
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
0.15 * 25.0 / 3.0,
|
||||
ConstantDrift(0.15 * 25.0 / 3.0),
|
||||
),
|
||||
);
|
||||
}
|
||||
@@ -552,7 +571,7 @@ mod tests {
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
);
|
||||
}
|
||||
@@ -610,7 +629,7 @@ mod tests {
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user