T0 + T1 + T2: engine redesign through new API surface #1
@@ -1,7 +1,7 @@
|
||||
use criterion::{Criterion, criterion_group, criterion_main};
|
||||
use trueskill_tt::{
|
||||
BETA, Competitor, GAMMA, KeyTable, MU, P_DRAW, Rating, SIGMA, batch::Batch,
|
||||
drift::ConstantDrift, gaussian::Gaussian, storage::CompetitorStore,
|
||||
BETA, Competitor, GAMMA, KeyTable, MU, P_DRAW, Rating, SIGMA, TimeSlice, drift::ConstantDrift,
|
||||
gaussian::Gaussian, storage::CompetitorStore,
|
||||
};
|
||||
|
||||
fn criterion_benchmark(criterion: &mut Criterion) {
|
||||
@@ -33,11 +33,11 @@ fn criterion_benchmark(criterion: &mut Criterion) {
|
||||
weights.push(vec![vec![1.0], vec![1.0]]);
|
||||
}
|
||||
|
||||
let mut batch = Batch::new(1, P_DRAW);
|
||||
batch.add_events(composition, results, weights, &agents);
|
||||
let mut time_slice = TimeSlice::new(1, P_DRAW);
|
||||
time_slice.add_events(composition, results, weights, &agents);
|
||||
|
||||
criterion.bench_function("Batch::iteration", |b| {
|
||||
b.iter(|| batch.iteration(0, &agents))
|
||||
b.iter(|| time_slice.iteration(0, &agents))
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::{factor::VarStore, gaussian::Gaussian};
|
||||
|
||||
/// Reusable scratch buffers for `Game::likelihoods`.
|
||||
///
|
||||
/// A `Batch` owns one arena; all events in the slice share it across
|
||||
/// A `TimeSlice` owns one arena; all events in the slice share it across
|
||||
/// the convergence iterations. All Vecs are cleared (not dropped) on
|
||||
/// `reset()` so their heap capacity is reused across games.
|
||||
#[derive(Debug, Default)]
|
||||
|
||||
283
src/history.rs
283
src/history.rs
@@ -2,13 +2,13 @@ use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA,
|
||||
batch::{self, Batch},
|
||||
competitor::{self, Competitor},
|
||||
drift::{ConstantDrift, Drift},
|
||||
gaussian::Gaussian,
|
||||
rating::Rating,
|
||||
sort_time,
|
||||
storage::CompetitorStore,
|
||||
time_slice::{self, TimeSlice},
|
||||
tuple_gt, tuple_max,
|
||||
};
|
||||
|
||||
@@ -69,7 +69,7 @@ impl<D: Drift> HistoryBuilder<D> {
|
||||
pub fn build(self) -> History<D> {
|
||||
History {
|
||||
size: 0,
|
||||
batches: Vec::new(),
|
||||
time_slices: Vec::new(),
|
||||
agents: CompetitorStore::new(),
|
||||
time: self.time,
|
||||
mu: self.mu,
|
||||
@@ -105,7 +105,7 @@ impl Default for HistoryBuilder<ConstantDrift> {
|
||||
|
||||
pub struct History<D: Drift = ConstantDrift> {
|
||||
size: usize,
|
||||
pub(crate) batches: Vec<Batch>,
|
||||
pub(crate) time_slices: Vec<TimeSlice>,
|
||||
agents: CompetitorStore<D>,
|
||||
time: bool,
|
||||
mu: f64,
|
||||
@@ -120,7 +120,7 @@ impl Default for History<ConstantDrift> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
size: 0,
|
||||
batches: Vec::new(),
|
||||
time_slices: Vec::new(),
|
||||
agents: CompetitorStore::new(),
|
||||
time: true,
|
||||
mu: MU,
|
||||
@@ -145,17 +145,17 @@ impl<D: Drift> History<D> {
|
||||
|
||||
competitor::clean(self.agents.values_mut(), false);
|
||||
|
||||
for j in (0..self.batches.len() - 1).rev() {
|
||||
for agent in self.batches[j + 1].skills.keys() {
|
||||
for j in (0..self.time_slices.len() - 1).rev() {
|
||||
for agent in self.time_slices[j + 1].skills.keys() {
|
||||
self.agents.get_mut(agent).unwrap().message =
|
||||
self.batches[j + 1].backward_prior_out(&agent, &self.agents);
|
||||
self.time_slices[j + 1].backward_prior_out(&agent, &self.agents);
|
||||
}
|
||||
|
||||
let old = self.batches[j].posteriors();
|
||||
let old = self.time_slices[j].posteriors();
|
||||
|
||||
self.batches[j].new_backward_info(&self.agents);
|
||||
self.time_slices[j].new_backward_info(&self.agents);
|
||||
|
||||
let new = self.batches[j].posteriors();
|
||||
let new = self.time_slices[j].posteriors();
|
||||
|
||||
step = old
|
||||
.iter()
|
||||
@@ -164,29 +164,29 @@ impl<D: Drift> History<D> {
|
||||
|
||||
competitor::clean(self.agents.values_mut(), false);
|
||||
|
||||
for j in 1..self.batches.len() {
|
||||
for agent in self.batches[j - 1].skills.keys() {
|
||||
for j in 1..self.time_slices.len() {
|
||||
for agent in self.time_slices[j - 1].skills.keys() {
|
||||
self.agents.get_mut(agent).unwrap().message =
|
||||
self.batches[j - 1].forward_prior_out(&agent);
|
||||
self.time_slices[j - 1].forward_prior_out(&agent);
|
||||
}
|
||||
|
||||
let old = self.batches[j].posteriors();
|
||||
let old = self.time_slices[j].posteriors();
|
||||
|
||||
self.batches[j].new_forward_info(&self.agents);
|
||||
self.time_slices[j].new_forward_info(&self.agents);
|
||||
|
||||
let new = self.batches[j].posteriors();
|
||||
let new = self.time_slices[j].posteriors();
|
||||
|
||||
step = old
|
||||
.iter()
|
||||
.fold(step, |step, (a, old)| tuple_max(step, old.delta(new[a])));
|
||||
}
|
||||
|
||||
if self.batches.len() == 1 {
|
||||
let old = self.batches[0].posteriors();
|
||||
if self.time_slices.len() == 1 {
|
||||
let old = self.time_slices[0].posteriors();
|
||||
|
||||
self.batches[0].iteration(0, &self.agents);
|
||||
self.time_slices[0].iteration(0, &self.agents);
|
||||
|
||||
let new = self.batches[0].posteriors();
|
||||
let new = self.time_slices[0].posteriors();
|
||||
|
||||
step = old
|
||||
.iter()
|
||||
@@ -229,7 +229,7 @@ impl<D: Drift> History<D> {
|
||||
pub fn learning_curves(&self) -> HashMap<Index, Vec<(i64, Gaussian)>> {
|
||||
let mut data: HashMap<Index, Vec<(i64, Gaussian)>> = HashMap::new();
|
||||
|
||||
for b in &self.batches {
|
||||
for b in &self.time_slices {
|
||||
for (agent, skill) in b.skills.iter() {
|
||||
let point = (b.time, skill.posterior());
|
||||
|
||||
@@ -245,9 +245,9 @@ impl<D: Drift> History<D> {
|
||||
}
|
||||
|
||||
pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 {
|
||||
self.batches
|
||||
self.time_slices
|
||||
.iter()
|
||||
.map(|batch| batch.log_evidence(self.online, targets, forward, &self.agents))
|
||||
.map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents))
|
||||
.sum()
|
||||
}
|
||||
|
||||
@@ -335,24 +335,26 @@ impl<D: Drift> History<D> {
|
||||
}
|
||||
|
||||
while (!self.time && (self.size > k))
|
||||
|| (self.time && self.batches.len() > k && self.batches[k].time < t)
|
||||
|| (self.time && self.time_slices.len() > k && self.time_slices[k].time < t)
|
||||
{
|
||||
let batch = &mut self.batches[k];
|
||||
let time_slice = &mut self.time_slices[k];
|
||||
|
||||
if k > 0 {
|
||||
batch.new_forward_info(&self.agents);
|
||||
time_slice.new_forward_info(&self.agents);
|
||||
}
|
||||
|
||||
// TODO: Is it faster to iterate over agents in batch instead?
|
||||
for agent_idx in &this_agent {
|
||||
if let Some(skill) = batch.skills.get_mut(*agent_idx) {
|
||||
skill.elapsed =
|
||||
batch::compute_elapsed(self.agents[*agent_idx].last_time, batch.time);
|
||||
if let Some(skill) = time_slice.skills.get_mut(*agent_idx) {
|
||||
skill.elapsed = time_slice::compute_elapsed(
|
||||
self.agents[*agent_idx].last_time,
|
||||
time_slice.time,
|
||||
);
|
||||
|
||||
let agent = self.agents.get_mut(*agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { batch.time } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(agent_idx);
|
||||
agent.last_time = if self.time { time_slice.time } else { i64::MAX };
|
||||
agent.message = time_slice.forward_prior_out(agent_idx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,29 +377,29 @@ impl<D: Drift> History<D> {
|
||||
(i..j).map(|e| weights[o[e]].clone()).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
if self.time && self.batches.len() > k && self.batches[k].time == t {
|
||||
let batch = &mut self.batches[k];
|
||||
batch.add_events(composition, results, weights, &self.agents);
|
||||
if self.time && self.time_slices.len() > k && self.time_slices[k].time == t {
|
||||
let time_slice = &mut self.time_slices[k];
|
||||
time_slice.add_events(composition, results, weights, &self.agents);
|
||||
|
||||
for agent_idx in batch.skills.keys() {
|
||||
for agent_idx in time_slice.skills.keys() {
|
||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { t } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(&agent_idx);
|
||||
agent.message = time_slice.forward_prior_out(&agent_idx);
|
||||
}
|
||||
} else {
|
||||
let mut batch: Batch = Batch::new(t, self.p_draw);
|
||||
batch.add_events(composition, results, weights, &self.agents);
|
||||
let mut time_slice: TimeSlice = TimeSlice::new(t, self.p_draw);
|
||||
time_slice.add_events(composition, results, weights, &self.agents);
|
||||
|
||||
self.batches.insert(k, batch);
|
||||
self.time_slices.insert(k, time_slice);
|
||||
|
||||
let batch = &self.batches[k];
|
||||
let time_slice = &self.time_slices[k];
|
||||
|
||||
for agent_idx in batch.skills.keys() {
|
||||
for agent_idx in time_slice.skills.keys() {
|
||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { t } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(&agent_idx);
|
||||
agent.message = time_slice.forward_prior_out(&agent_idx);
|
||||
}
|
||||
|
||||
k += 1;
|
||||
@@ -406,21 +408,23 @@ impl<D: Drift> History<D> {
|
||||
i = j;
|
||||
}
|
||||
|
||||
while self.time && self.batches.len() > k {
|
||||
let batch = &mut self.batches[k];
|
||||
while self.time && self.time_slices.len() > k {
|
||||
let time_slice = &mut self.time_slices[k];
|
||||
|
||||
batch.new_forward_info(&self.agents);
|
||||
time_slice.new_forward_info(&self.agents);
|
||||
|
||||
// TODO: Is it faster to iterate over agents in batch instead?
|
||||
for agent_idx in &this_agent {
|
||||
if let Some(skill) = batch.skills.get_mut(*agent_idx) {
|
||||
skill.elapsed =
|
||||
batch::compute_elapsed(self.agents[*agent_idx].last_time, batch.time);
|
||||
if let Some(skill) = time_slice.skills.get_mut(*agent_idx) {
|
||||
skill.elapsed = time_slice::compute_elapsed(
|
||||
self.agents[*agent_idx].last_time,
|
||||
time_slice.time,
|
||||
);
|
||||
|
||||
let agent = self.agents.get_mut(*agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { batch.time } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(agent_idx);
|
||||
agent.last_time = if self.time { time_slice.time } else { i64::MAX };
|
||||
agent.message = time_slice.forward_prior_out(agent_idx);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -473,7 +477,7 @@ mod tests {
|
||||
|
||||
h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors);
|
||||
|
||||
let p0 = h.batches[0].posteriors();
|
||||
let p0 = h.time_slices[0].posteriors();
|
||||
|
||||
assert_ulps_eq!(
|
||||
p0[&a],
|
||||
@@ -481,10 +485,10 @@ mod tests {
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
let observed = h.batches[1].skills.get(a).unwrap().forward.sigma();
|
||||
let observed = h.time_slices[1].skills.get(a).unwrap().forward.sigma();
|
||||
let gamma: f64 = 0.15 * 25.0 / 3.0;
|
||||
let expected = (gamma.powi(2)
|
||||
+ h.batches[0]
|
||||
+ h.time_slices[0]
|
||||
.skills
|
||||
.get(a)
|
||||
.unwrap()
|
||||
@@ -495,11 +499,16 @@ mod tests {
|
||||
|
||||
assert_ulps_eq!(observed, expected, epsilon = 0.000001);
|
||||
|
||||
let observed = h.batches[1].skills.get(a).unwrap().posterior();
|
||||
let observed = h.time_slices[1].skills.get(a).unwrap().posterior();
|
||||
|
||||
let w = [vec![1.0], vec![1.0]];
|
||||
let p = Game::new(
|
||||
h.batches[1].events[0].within_priors(false, false, &h.batches[1].skills, &h.agents),
|
||||
h.time_slices[1].events[0].within_priors(
|
||||
false,
|
||||
false,
|
||||
&h.time_slices[1].skills,
|
||||
&h.agents,
|
||||
),
|
||||
&[0.0, 1.0],
|
||||
&w,
|
||||
P_DRAW,
|
||||
@@ -545,12 +554,12 @@ mod tests {
|
||||
h1.add_events_with_prior(composition, results, times, vec![], priors);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h1.batches[0].skills.get(a).unwrap().posterior(),
|
||||
h1.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(22.904409, 6.010330),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h1.batches[0].skills.get(c).unwrap().posterior(),
|
||||
h1.time_slices[0].skills.get(c).unwrap().posterior(),
|
||||
Gaussian::from_ms(25.110318, 5.866311),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
@@ -558,12 +567,12 @@ mod tests {
|
||||
h1.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h1.batches[0].skills.get(a).unwrap().posterior(),
|
||||
h1.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(25.000000, 5.419212),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h1.batches[0].skills.get(c).unwrap().posterior(),
|
||||
h1.time_slices[0].skills.get(c).unwrap().posterior(),
|
||||
Gaussian::from_ms(25.000000, 5.419212),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
@@ -594,12 +603,12 @@ mod tests {
|
||||
h2.add_events_with_prior(composition, results, times, vec![], priors);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h2.batches[2].skills.get(a).unwrap().posterior(),
|
||||
h2.time_slices[2].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(22.903522, 6.011017),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h2.batches[2].skills.get(c).unwrap().posterior(),
|
||||
h2.time_slices[2].skills.get(c).unwrap().posterior(),
|
||||
Gaussian::from_ms(25.110702, 5.866811),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
@@ -607,12 +616,12 @@ mod tests {
|
||||
h2.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h2.batches[2].skills.get(a).unwrap().posterior(),
|
||||
h2.time_slices[2].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(24.998668, 5.420053),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h2.batches[2].skills.get(c).unwrap().posterior(),
|
||||
h2.time_slices[2].skills.get(c).unwrap().posterior(),
|
||||
Gaussian::from_ms(25.000532, 5.419827),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
@@ -699,21 +708,21 @@ mod tests {
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1);
|
||||
assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1);
|
||||
assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 1);
|
||||
assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(25.000267, 5.419381),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(b).unwrap().posterior(),
|
||||
Gaussian::from_ms(24.999465, 5.419425),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[2].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[2].skills.get(b).unwrap().posterior(),
|
||||
Gaussian::from_ms(25.000532, 5.419696),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
@@ -757,8 +766,8 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(b).unwrap().posterior().mu(),
|
||||
-h.batches[0].skills.get(c).unwrap().posterior().mu(),
|
||||
h.time_slices[0].skills.get(b).unwrap().posterior().mu(),
|
||||
-h.time_slices[0].skills.get(c).unwrap().posterior().mu(),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
@@ -777,33 +786,33 @@ mod tests {
|
||||
assert_ulps_eq!(p_d_m_hat, 0.172432, epsilon = 1e-6);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(b).unwrap().posterior(),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||
h.batches[0].skills.get(d).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(c).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(d).unwrap().posterior(),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[1].skills.get(e).unwrap().posterior(),
|
||||
h.batches[1].skills.get(f).unwrap().posterior(),
|
||||
h.time_slices[1].skills.get(e).unwrap().posterior(),
|
||||
h.time_slices[1].skills.get(f).unwrap().posterior(),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(4.084902, 5.106919),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(c).unwrap().posterior(),
|
||||
Gaussian::from_ms(-0.533029, 5.106919),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[2].skills.get(e).unwrap().posterior(),
|
||||
h.time_slices[2].skills.get(e).unwrap().posterior(),
|
||||
Gaussian::from_ms(-3.551872, 5.154569),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
@@ -836,31 +845,31 @@ mod tests {
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1);
|
||||
assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1);
|
||||
assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 1);
|
||||
assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 1.300610),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(b).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 1.300610),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[2].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[2].skills.get(b).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 1.300610),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
h.add_events(composition, results, vec![], vec![]);
|
||||
|
||||
assert_eq!(h.batches.len(), 6);
|
||||
assert_eq!(h.time_slices.len(), 6);
|
||||
|
||||
assert_eq!(
|
||||
h.batches
|
||||
h.time_slices
|
||||
.iter()
|
||||
.map(|b| b.get_composition())
|
||||
.collect::<Vec<_>>(),
|
||||
@@ -877,22 +886,22 @@ mod tests {
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 0.931236),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[3].skills.get(a).unwrap().posterior(),
|
||||
h.time_slices[3].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 0.931236),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[3].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[3].skills.get(b).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 0.931236),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[5].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[5].skills.get(b).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 0.931236),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
@@ -925,31 +934,31 @@ mod tests {
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1);
|
||||
assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1);
|
||||
assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 1);
|
||||
assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 1.300610),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(b).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 1.300610),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[2].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[2].skills.get(b).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 1.300610),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
h.add_events(composition, results, vec![], vec![]);
|
||||
|
||||
assert_eq!(h.batches.len(), 6);
|
||||
assert_eq!(h.time_slices.len(), 6);
|
||||
|
||||
assert_eq!(
|
||||
h.batches
|
||||
h.time_slices
|
||||
.iter()
|
||||
.map(|b| b.get_composition())
|
||||
.collect::<Vec<_>>(),
|
||||
@@ -966,22 +975,22 @@ mod tests {
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 0.931236),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[3].skills.get(a).unwrap().posterior(),
|
||||
h.time_slices[3].skills.get(a).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 0.931236),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[3].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[3].skills.get(b).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 0.931236),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
assert_ulps_eq!(
|
||||
h.batches[5].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[5].skills.get(b).unwrap().posterior(),
|
||||
Gaussian::from_ms(0.000000, 0.931236),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
@@ -1079,18 +1088,18 @@ mod tests {
|
||||
|
||||
h.add_events(composition, results, vec![15, 10, 0], vec![]);
|
||||
|
||||
assert_eq!(h.batches.len(), 4);
|
||||
assert_eq!(h.time_slices.len(), 4);
|
||||
|
||||
assert_eq!(
|
||||
h.batches
|
||||
h.time_slices
|
||||
.iter()
|
||||
.map(|batch| batch.events.len())
|
||||
.map(|ts| ts.events.len())
|
||||
.collect::<Vec<_>>(),
|
||||
vec![2, 2, 1, 1]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
h.batches
|
||||
h.time_slices
|
||||
.iter()
|
||||
.map(|b| b.get_composition())
|
||||
.collect::<Vec<_>>(),
|
||||
@@ -1103,7 +1112,7 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
h.batches
|
||||
h.time_slices
|
||||
.iter()
|
||||
.map(|b| b.get_results())
|
||||
.collect::<Vec<_>>(),
|
||||
@@ -1115,34 +1124,34 @@ mod tests {
|
||||
]
|
||||
);
|
||||
|
||||
let end = h.batches.len() - 1;
|
||||
let end = h.time_slices.len() - 1;
|
||||
|
||||
assert_eq!(h.batches[0].skills.get(c).unwrap().elapsed, 0);
|
||||
assert_eq!(h.batches[end].skills.get(c).unwrap().elapsed, 10);
|
||||
assert_eq!(h.time_slices[0].skills.get(c).unwrap().elapsed, 0);
|
||||
assert_eq!(h.time_slices[end].skills.get(c).unwrap().elapsed, 10);
|
||||
|
||||
assert_eq!(h.batches[0].skills.get(a).unwrap().elapsed, 0);
|
||||
assert_eq!(h.batches[2].skills.get(a).unwrap().elapsed, 5);
|
||||
assert_eq!(h.time_slices[0].skills.get(a).unwrap().elapsed, 0);
|
||||
assert_eq!(h.time_slices[2].skills.get(a).unwrap().elapsed, 5);
|
||||
|
||||
assert_eq!(h.batches[0].skills.get(b).unwrap().elapsed, 0);
|
||||
assert_eq!(h.batches[end].skills.get(b).unwrap().elapsed, 5);
|
||||
assert_eq!(h.time_slices[0].skills.get(b).unwrap().elapsed, 0);
|
||||
assert_eq!(h.time_slices[end].skills.get(b).unwrap().elapsed, 5);
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||
h.batches[end].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[end].skills.get(b).unwrap().posterior(),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||
h.batches[end].skills.get(c).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(c).unwrap().posterior(),
|
||||
h.time_slices[end].skills.get(c).unwrap().posterior(),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(c).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(b).unwrap().posterior(),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
@@ -1167,18 +1176,18 @@ mod tests {
|
||||
|
||||
h.add_events(composition, vec![], vec![15, 10, 0], vec![]);
|
||||
|
||||
assert_eq!(h.batches.len(), 4);
|
||||
assert_eq!(h.time_slices.len(), 4);
|
||||
|
||||
assert_eq!(
|
||||
h.batches
|
||||
h.time_slices
|
||||
.iter()
|
||||
.map(|batch| batch.events.len())
|
||||
.map(|ts| ts.events.len())
|
||||
.collect::<Vec<_>>(),
|
||||
vec![2, 2, 1, 1]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
h.batches
|
||||
h.time_slices
|
||||
.iter()
|
||||
.map(|b| b.get_composition())
|
||||
.collect::<Vec<_>>(),
|
||||
@@ -1191,7 +1200,7 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
h.batches
|
||||
h.time_slices
|
||||
.iter()
|
||||
.map(|b| b.get_results())
|
||||
.collect::<Vec<_>>(),
|
||||
@@ -1203,34 +1212,34 @@ mod tests {
|
||||
]
|
||||
);
|
||||
|
||||
let end = h.batches.len() - 1;
|
||||
let end = h.time_slices.len() - 1;
|
||||
|
||||
assert_eq!(h.batches[0].skills.get(c).unwrap().elapsed, 0);
|
||||
assert_eq!(h.batches[end].skills.get(c).unwrap().elapsed, 10);
|
||||
assert_eq!(h.time_slices[0].skills.get(c).unwrap().elapsed, 0);
|
||||
assert_eq!(h.time_slices[end].skills.get(c).unwrap().elapsed, 10);
|
||||
|
||||
assert_eq!(h.batches[0].skills.get(a).unwrap().elapsed, 0);
|
||||
assert_eq!(h.batches[2].skills.get(a).unwrap().elapsed, 5);
|
||||
assert_eq!(h.time_slices[0].skills.get(a).unwrap().elapsed, 0);
|
||||
assert_eq!(h.time_slices[2].skills.get(a).unwrap().elapsed, 5);
|
||||
|
||||
assert_eq!(h.batches[0].skills.get(b).unwrap().elapsed, 0);
|
||||
assert_eq!(h.batches[end].skills.get(b).unwrap().elapsed, 5);
|
||||
assert_eq!(h.time_slices[0].skills.get(b).unwrap().elapsed, 0);
|
||||
assert_eq!(h.time_slices[end].skills.get(b).unwrap().elapsed, 5);
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||
h.batches[end].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[end].skills.get(b).unwrap().posterior(),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||
h.batches[end].skills.get(c).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(c).unwrap().posterior(),
|
||||
h.time_slices[end].skills.get(c).unwrap().posterior(),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
assert_ulps_eq!(
|
||||
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(c).unwrap().posterior(),
|
||||
h.time_slices[0].skills.get(b).unwrap().posterior(),
|
||||
epsilon = 1e-6
|
||||
);
|
||||
}
|
||||
|
||||
@@ -6,7 +6,8 @@ use std::{
|
||||
#[cfg(feature = "approx")]
|
||||
mod approx;
|
||||
pub(crate) mod arena;
|
||||
pub mod batch;
|
||||
mod time_slice;
|
||||
pub use time_slice::TimeSlice;
|
||||
mod competitor;
|
||||
pub mod drift;
|
||||
mod error;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{Index, batch::Skill};
|
||||
use crate::{Index, time_slice::Skill};
|
||||
|
||||
/// Dense Vec-backed store for per-agent skill state within a TimeSlice.
|
||||
///
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
//! A single time step's worth of events.
|
||||
//!
|
||||
//! Renamed from `Batch` in T2.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
@@ -106,7 +110,7 @@ impl Event {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Batch {
|
||||
pub struct TimeSlice {
|
||||
pub(crate) events: Vec<Event>,
|
||||
pub(crate) skills: SkillStore,
|
||||
pub(crate) time: i64,
|
||||
@@ -114,7 +118,7 @@ pub struct Batch {
|
||||
arena: ScratchArena,
|
||||
}
|
||||
|
||||
impl Batch {
|
||||
impl TimeSlice {
|
||||
pub fn new(time: i64, p_draw: f64) -> Self {
|
||||
Self {
|
||||
events: Vec::new(),
|
||||
@@ -431,9 +435,9 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
let mut batch = Batch::new(0, 0.0);
|
||||
let mut time_slice = TimeSlice::new(0, 0.0);
|
||||
|
||||
batch.add_events(
|
||||
time_slice.add_events(
|
||||
vec![
|
||||
vec![vec![a], vec![b]],
|
||||
vec![vec![c], vec![d]],
|
||||
@@ -444,7 +448,7 @@ mod tests {
|
||||
&agents,
|
||||
);
|
||||
|
||||
let post = batch.posteriors();
|
||||
let post = time_slice.posteriors();
|
||||
|
||||
assert_ulps_eq!(
|
||||
post[&a],
|
||||
@@ -477,7 +481,7 @@ mod tests {
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
assert_eq!(batch.convergence(&agents), 1);
|
||||
assert_eq!(time_slice.convergence(&agents), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -507,9 +511,9 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
let mut batch = Batch::new(0, 0.0);
|
||||
let mut time_slice = TimeSlice::new(0, 0.0);
|
||||
|
||||
batch.add_events(
|
||||
time_slice.add_events(
|
||||
vec![
|
||||
vec![vec![a], vec![b]],
|
||||
vec![vec![a], vec![c]],
|
||||
@@ -520,7 +524,7 @@ mod tests {
|
||||
&agents,
|
||||
);
|
||||
|
||||
let post = batch.posteriors();
|
||||
let post = time_slice.posteriors();
|
||||
|
||||
assert_ulps_eq!(
|
||||
post[&a],
|
||||
@@ -538,9 +542,9 @@ mod tests {
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
assert!(batch.convergence(&agents) > 1);
|
||||
assert!(time_slice.convergence(&agents) > 1);
|
||||
|
||||
let post = batch.posteriors();
|
||||
let post = time_slice.posteriors();
|
||||
|
||||
assert_ulps_eq!(
|
||||
post[&a],
|
||||
@@ -586,9 +590,9 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
let mut batch = Batch::new(0, 0.0);
|
||||
let mut time_slice = TimeSlice::new(0, 0.0);
|
||||
|
||||
batch.add_events(
|
||||
time_slice.add_events(
|
||||
vec![
|
||||
vec![vec![a], vec![b]],
|
||||
vec![vec![a], vec![c]],
|
||||
@@ -599,9 +603,9 @@ mod tests {
|
||||
&agents,
|
||||
);
|
||||
|
||||
batch.convergence(&agents);
|
||||
time_slice.convergence(&agents);
|
||||
|
||||
let post = batch.posteriors();
|
||||
let post = time_slice.posteriors();
|
||||
|
||||
assert_ulps_eq!(
|
||||
post[&a],
|
||||
@@ -619,7 +623,7 @@ mod tests {
|
||||
epsilon = 1e-6
|
||||
);
|
||||
|
||||
batch.add_events(
|
||||
time_slice.add_events(
|
||||
vec![
|
||||
vec![vec![a], vec![b]],
|
||||
vec![vec![a], vec![c]],
|
||||
@@ -630,11 +634,11 @@ mod tests {
|
||||
&agents,
|
||||
);
|
||||
|
||||
assert_eq!(batch.events.len(), 6);
|
||||
assert_eq!(time_slice.events.len(), 6);
|
||||
|
||||
batch.convergence(&agents);
|
||||
time_slice.convergence(&agents);
|
||||
|
||||
let post = batch.posteriors();
|
||||
let post = time_slice.posteriors();
|
||||
|
||||
assert_ulps_eq!(
|
||||
post[&a],
|
||||
Reference in New Issue
Block a user