Added builder for History, and start migrating test to use builder instead.

This commit is contained in:
2022-06-22 21:13:24 +02:00
parent a573f5cf0a
commit 73b9cabac8
3 changed files with 216 additions and 70 deletions

View File

@@ -4,8 +4,8 @@ use crate::{MU, N_INF, SIGMA};
#[derive(Clone, Copy, PartialEq, Debug)] #[derive(Clone, Copy, PartialEq, Debug)]
pub struct Gaussian { pub struct Gaussian {
pub(crate) mu: f64, pub mu: f64,
pub(crate) sigma: f64, pub sigma: f64,
} }
impl Gaussian { impl Gaussian {

View File

@@ -5,9 +5,102 @@ use crate::{
batch::{self, Batch}, batch::{self, Batch},
gaussian::Gaussian, gaussian::Gaussian,
player::Player, player::Player,
sort_time, tuple_gt, tuple_max, sort_time, tuple_gt, tuple_max, BETA, EPSILON, GAMMA, ITERATIONS, MU, P_DRAW, SIGMA,
}; };
#[derive(Clone)]
pub struct HistoryBuilder {
time: bool,
mu: f64,
sigma: f64,
beta: f64,
gamma: f64,
p_draw: f64,
online: bool,
epsilon: f64,
iterations: usize,
}
impl HistoryBuilder {
pub fn time(mut self, time: bool) -> Self {
self.time = time;
self
}
pub fn mu(mut self, mu: f64) -> Self {
self.mu = mu;
self
}
pub fn sigma(mut self, sigma: f64) -> Self {
self.sigma = sigma;
self
}
pub fn beta(mut self, beta: f64) -> Self {
self.beta = beta;
self
}
pub fn gamma(mut self, gamma: f64) -> Self {
self.gamma = gamma;
self
}
pub fn p_draw(mut self, p_draw: f64) -> Self {
self.p_draw = p_draw;
self
}
pub fn online(mut self, online: bool) -> Self {
self.online = online;
self
}
pub fn epsilon(mut self, epsilon: f64) -> Self {
self.epsilon = epsilon;
self
}
pub fn iterations(mut self, iterations: usize) -> Self {
self.iterations = iterations;
self
}
pub fn build(self) -> History {
History {
size: 0,
batches: Vec::new(),
agents: HashMap::new(),
time: self.time,
mu: self.mu,
sigma: self.sigma,
beta: self.beta,
gamma: self.gamma,
p_draw: self.p_draw,
online: self.online,
epsilon: self.epsilon,
iterations: self.iterations,
}
}
}
impl Default for HistoryBuilder {
fn default() -> Self {
Self {
time: true,
mu: MU,
sigma: SIGMA,
beta: BETA,
gamma: GAMMA,
p_draw: P_DRAW,
online: false,
epsilon: EPSILON,
iterations: ITERATIONS,
}
}
}
pub struct History { pub struct History {
size: usize, size: usize,
pub batches: Vec<Batch>, pub batches: Vec<Batch>,
@@ -19,12 +112,15 @@ pub struct History {
gamma: f64, gamma: f64,
p_draw: f64, p_draw: f64,
online: bool, online: bool,
weights: Vec<Vec<Vec<f64>>>,
epsilon: f64, epsilon: f64,
iterations: usize, iterations: usize,
} }
impl History { impl History {
pub fn builder() -> HistoryBuilder {
HistoryBuilder::default()
}
pub fn new( pub fn new(
composition: Vec<Vec<Vec<&str>>>, composition: Vec<Vec<Vec<&str>>>,
results: Vec<Vec<f64>>, results: Vec<Vec<f64>>,
@@ -87,7 +183,6 @@ impl History {
gamma, gamma,
p_draw, p_draw,
online, online,
weights: weights.clone(),
epsilon: 0.0, epsilon: 0.0,
iterations: 10, iterations: 10,
}; };
@@ -147,9 +242,9 @@ impl History {
&mut self.agents, &mut self.agents,
); );
self.batches.push(b); let idx = self.batches.len();
let idx = self.batches.len() - 1; self.batches.push(b);
if online { if online {
let new = 100.0 * (i as f64 / self.size as f64); let new = 100.0 * (i as f64 / self.size as f64);
@@ -357,7 +452,7 @@ impl History {
let mut j = i + 1; let mut j = i + 1;
let t = if self.time { times[o[i]] } else { i as u64 + 1 }; let t = if self.time { times[o[i]] } else { i as u64 + 1 };
while self.time && j < self.size && times[o[j]] == t { while self.time && j < n && times[o[j]] == t {
j += 1; j += 1;
} }
@@ -427,7 +522,7 @@ impl History {
self.batches.insert(k, b); self.batches.insert(k, b);
let b = &mut self.batches[k]; let b = &self.batches[k];
for a in b.skills.keys() { for a in b.skills.keys() {
let agent = self.agents.get_mut(a).unwrap(); let agent = self.agents.get_mut(a).unwrap();
@@ -467,8 +562,6 @@ impl History {
} }
self.size += n; self.size += n;
self.iteration();
} }
} }
@@ -502,19 +595,9 @@ mod tests {
); );
} }
let mut h = History::new( let mut h = History::builder().build();
composition,
results, h.add_events(composition, results, vec![1, 2, 3], vec![], priors);
vec![1, 2, 3],
vec![],
priors,
MU,
SIGMA,
BETA,
GAMMA,
P_DRAW,
false,
);
let p0 = h.batches[0].posteriors(); let p0 = h.batches[0].posteriors();
@@ -566,19 +649,9 @@ mod tests {
priors.insert(k.to_string(), player); priors.insert(k.to_string(), player);
} }
let mut h1 = History::new( let mut h1 = History::builder().build();
composition,
results, h1.add_events(composition, results, times, vec![], priors);
times,
vec![],
priors,
MU,
SIGMA,
BETA,
GAMMA,
P_DRAW,
false,
);
assert_ulps_eq!( assert_ulps_eq!(
h1.batches[0].posterior("a"), h1.batches[0].posterior("a"),
@@ -591,7 +664,7 @@ mod tests {
epsilon = 0.000001 epsilon = 0.000001
); );
let (_step, _i) = h1.convergence(ITERATIONS, EPSILON, false); h1.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!( assert_ulps_eq!(
h1.batches[0].posterior("a"), h1.batches[0].posterior("a"),
@@ -620,19 +693,9 @@ mod tests {
priors.insert(k.to_string(), player); priors.insert(k.to_string(), player);
} }
let mut h2 = History::new( let mut h2 = History::builder().build();
composition,
results, h2.add_events(composition, results, times, vec![], priors);
times,
vec![],
priors,
MU,
SIGMA,
BETA,
GAMMA,
P_DRAW,
false,
);
assert_ulps_eq!( assert_ulps_eq!(
h2.batches[2].posterior("a"), h2.batches[2].posterior("a"),
@@ -677,20 +740,9 @@ mod tests {
priors.insert(k.to_string(), player); priors.insert(k.to_string(), player);
} }
let mut h = History::new( let mut h = History::builder().build();
composition,
results,
times,
vec![],
priors,
MU,
SIGMA,
BETA,
GAMMA,
P_DRAW,
false,
);
h.add_events(composition, results, times, vec![], priors);
h.convergence(ITERATIONS, EPSILON, false); h.convergence(ITERATIONS, EPSILON, false);
let lc = h.learning_curves(); let lc = h.learning_curves();
@@ -930,6 +982,101 @@ mod tests {
); );
} }
#[test]
fn test_only_add_events() {
let mut h = History::new(
vec![],
vec![],
vec![],
vec![],
HashMap::new(),
0.0,
2.0,
1.0,
0.0,
0.0,
false,
);
let composition = vec![
vec![vec!["a"], vec!["b"]],
vec![vec!["a"], vec!["c"]],
vec![vec!["b"], vec!["c"]],
];
let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
h.add_events(
composition.clone(),
results.clone(),
vec![],
vec![],
HashMap::new(),
);
let (_step, _i) = h.convergence(ITERATIONS, EPSILON, false);
assert_eq!(h.batches[2].skills["b"].elapsed, 1);
assert_eq!(h.batches[2].skills["c"].elapsed, 1);
assert_ulps_eq!(
h.batches[0].posterior("a"),
Gaussian::new(0.0, 1.30061),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("b"),
Gaussian::new(0.0, 1.30061),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[2].posterior("b"),
Gaussian::new(0.0, 1.30061),
epsilon = 0.000001
);
h.add_events(composition, results, vec![], vec![], HashMap::new());
assert_eq!(h.batches.len(), 6);
assert_eq!(
h.batches
.iter()
.map(|b| b.get_composition())
.collect::<Vec<_>>(),
vec![
vec![vec![vec!["a"], vec!["b"]]],
vec![vec![vec!["a"], vec!["c"]]],
vec![vec![vec!["b"], vec!["c"]]],
vec![vec![vec!["a"], vec!["b"]]],
vec![vec![vec!["a"], vec!["c"]]],
vec![vec![vec!["b"], vec!["c"]]]
]
);
h.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h.batches[0].posterior("a"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[3].posterior("a"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[3].posterior("b"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[5].posterior("b"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
}
#[test] #[test]
fn test_log_evidence() { fn test_log_evidence() {
let composition = vec![vec![vec!["a"], vec!["b"]], vec![vec!["b"], vec!["a"]]]; let composition = vec![vec![vec!["a"], vec!["b"]], vec![vec!["b"], vec!["a"]]];

View File

@@ -11,18 +11,17 @@ mod history;
mod message; mod message;
mod player; mod player;
use gaussian::Gaussian;
use message::DiffMessage;
pub use game::Game; pub use game::Game;
pub use gaussian::Gaussian;
pub use history::History; pub use history::History;
use message::DiffMessage;
pub use player::Player; pub use player::Player;
const BETA: f64 = 1.0; pub const BETA: f64 = 1.0;
pub const MU: f64 = 0.0; pub const MU: f64 = 0.0;
pub const SIGMA: f64 = BETA * 6.0; pub const SIGMA: f64 = BETA * 6.0;
const GAMMA: f64 = BETA * 0.03; pub const GAMMA: f64 = BETA * 0.03;
const P_DRAW: f64 = 0.0; pub const P_DRAW: f64 = 0.0;
pub const EPSILON: f64 = 1e-6; pub const EPSILON: f64 = 1e-6;
pub const ITERATIONS: usize = 30; pub const ITERATIONS: usize = 30;