T0 + T1 + T2: engine redesign through new API surface #1

Merged
logaritmisk merged 45 commits from t2-new-api-surface into main 2026-04-24 11:20:04 +00:00
6 changed files with 178 additions and 164 deletions
Showing only changes of commit 5e752f9e98 - Show all commits

View File

@@ -1,7 +1,7 @@
use criterion::{Criterion, criterion_group, criterion_main};
use trueskill_tt::{
BETA, Competitor, GAMMA, KeyTable, MU, P_DRAW, Rating, SIGMA, batch::Batch,
drift::ConstantDrift, gaussian::Gaussian, storage::CompetitorStore,
BETA, Competitor, GAMMA, KeyTable, MU, P_DRAW, Rating, SIGMA, TimeSlice, drift::ConstantDrift,
gaussian::Gaussian, storage::CompetitorStore,
};
fn criterion_benchmark(criterion: &mut Criterion) {
@@ -33,11 +33,11 @@ fn criterion_benchmark(criterion: &mut Criterion) {
weights.push(vec![vec![1.0], vec![1.0]]);
}
let mut batch = Batch::new(1, P_DRAW);
batch.add_events(composition, results, weights, &agents);
let mut time_slice = TimeSlice::new(1, P_DRAW);
time_slice.add_events(composition, results, weights, &agents);
criterion.bench_function("Batch::iteration", |b| {
b.iter(|| batch.iteration(0, &agents))
b.iter(|| time_slice.iteration(0, &agents))
});
}

View File

@@ -2,7 +2,7 @@ use crate::{factor::VarStore, gaussian::Gaussian};
/// Reusable scratch buffers for `Game::likelihoods`.
///
/// A `Batch` owns one arena; all events in the slice share it across
/// A `TimeSlice` owns one arena; all events in the slice share it across
/// the convergence iterations. All Vecs are cleared (not dropped) on
/// `reset()` so their heap capacity is reused across games.
#[derive(Debug, Default)]

View File

@@ -2,13 +2,13 @@ use std::collections::HashMap;
use crate::{
BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA,
batch::{self, Batch},
competitor::{self, Competitor},
drift::{ConstantDrift, Drift},
gaussian::Gaussian,
rating::Rating,
sort_time,
storage::CompetitorStore,
time_slice::{self, TimeSlice},
tuple_gt, tuple_max,
};
@@ -69,7 +69,7 @@ impl<D: Drift> HistoryBuilder<D> {
pub fn build(self) -> History<D> {
History {
size: 0,
batches: Vec::new(),
time_slices: Vec::new(),
agents: CompetitorStore::new(),
time: self.time,
mu: self.mu,
@@ -105,7 +105,7 @@ impl Default for HistoryBuilder<ConstantDrift> {
pub struct History<D: Drift = ConstantDrift> {
size: usize,
pub(crate) batches: Vec<Batch>,
pub(crate) time_slices: Vec<TimeSlice>,
agents: CompetitorStore<D>,
time: bool,
mu: f64,
@@ -120,7 +120,7 @@ impl Default for History<ConstantDrift> {
fn default() -> Self {
Self {
size: 0,
batches: Vec::new(),
time_slices: Vec::new(),
agents: CompetitorStore::new(),
time: true,
mu: MU,
@@ -145,17 +145,17 @@ impl<D: Drift> History<D> {
competitor::clean(self.agents.values_mut(), false);
for j in (0..self.batches.len() - 1).rev() {
for agent in self.batches[j + 1].skills.keys() {
for j in (0..self.time_slices.len() - 1).rev() {
for agent in self.time_slices[j + 1].skills.keys() {
self.agents.get_mut(agent).unwrap().message =
self.batches[j + 1].backward_prior_out(&agent, &self.agents);
self.time_slices[j + 1].backward_prior_out(&agent, &self.agents);
}
let old = self.batches[j].posteriors();
let old = self.time_slices[j].posteriors();
self.batches[j].new_backward_info(&self.agents);
self.time_slices[j].new_backward_info(&self.agents);
let new = self.batches[j].posteriors();
let new = self.time_slices[j].posteriors();
step = old
.iter()
@@ -164,29 +164,29 @@ impl<D: Drift> History<D> {
competitor::clean(self.agents.values_mut(), false);
for j in 1..self.batches.len() {
for agent in self.batches[j - 1].skills.keys() {
for j in 1..self.time_slices.len() {
for agent in self.time_slices[j - 1].skills.keys() {
self.agents.get_mut(agent).unwrap().message =
self.batches[j - 1].forward_prior_out(&agent);
self.time_slices[j - 1].forward_prior_out(&agent);
}
let old = self.batches[j].posteriors();
let old = self.time_slices[j].posteriors();
self.batches[j].new_forward_info(&self.agents);
self.time_slices[j].new_forward_info(&self.agents);
let new = self.batches[j].posteriors();
let new = self.time_slices[j].posteriors();
step = old
.iter()
.fold(step, |step, (a, old)| tuple_max(step, old.delta(new[a])));
}
if self.batches.len() == 1 {
let old = self.batches[0].posteriors();
if self.time_slices.len() == 1 {
let old = self.time_slices[0].posteriors();
self.batches[0].iteration(0, &self.agents);
self.time_slices[0].iteration(0, &self.agents);
let new = self.batches[0].posteriors();
let new = self.time_slices[0].posteriors();
step = old
.iter()
@@ -229,7 +229,7 @@ impl<D: Drift> History<D> {
pub fn learning_curves(&self) -> HashMap<Index, Vec<(i64, Gaussian)>> {
let mut data: HashMap<Index, Vec<(i64, Gaussian)>> = HashMap::new();
for b in &self.batches {
for b in &self.time_slices {
for (agent, skill) in b.skills.iter() {
let point = (b.time, skill.posterior());
@@ -245,9 +245,9 @@ impl<D: Drift> History<D> {
}
pub fn log_evidence(&mut self, forward: bool, targets: &[Index]) -> f64 {
self.batches
self.time_slices
.iter()
.map(|batch| batch.log_evidence(self.online, targets, forward, &self.agents))
.map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents))
.sum()
}
@@ -335,24 +335,26 @@ impl<D: Drift> History<D> {
}
while (!self.time && (self.size > k))
|| (self.time && self.batches.len() > k && self.batches[k].time < t)
|| (self.time && self.time_slices.len() > k && self.time_slices[k].time < t)
{
let batch = &mut self.batches[k];
let time_slice = &mut self.time_slices[k];
if k > 0 {
batch.new_forward_info(&self.agents);
time_slice.new_forward_info(&self.agents);
}
// TODO: Is it faster to iterate over agents in batch instead?
for agent_idx in &this_agent {
if let Some(skill) = batch.skills.get_mut(*agent_idx) {
skill.elapsed =
batch::compute_elapsed(self.agents[*agent_idx].last_time, batch.time);
if let Some(skill) = time_slice.skills.get_mut(*agent_idx) {
skill.elapsed = time_slice::compute_elapsed(
self.agents[*agent_idx].last_time,
time_slice.time,
);
let agent = self.agents.get_mut(*agent_idx).unwrap();
agent.last_time = if self.time { batch.time } else { i64::MAX };
agent.message = batch.forward_prior_out(agent_idx);
agent.last_time = if self.time { time_slice.time } else { i64::MAX };
agent.message = time_slice.forward_prior_out(agent_idx);
}
}
@@ -375,29 +377,29 @@ impl<D: Drift> History<D> {
(i..j).map(|e| weights[o[e]].clone()).collect::<Vec<_>>()
};
if self.time && self.batches.len() > k && self.batches[k].time == t {
let batch = &mut self.batches[k];
batch.add_events(composition, results, weights, &self.agents);
if self.time && self.time_slices.len() > k && self.time_slices[k].time == t {
let time_slice = &mut self.time_slices[k];
time_slice.add_events(composition, results, weights, &self.agents);
for agent_idx in batch.skills.keys() {
for agent_idx in time_slice.skills.keys() {
let agent = self.agents.get_mut(agent_idx).unwrap();
agent.last_time = if self.time { t } else { i64::MAX };
agent.message = batch.forward_prior_out(&agent_idx);
agent.message = time_slice.forward_prior_out(&agent_idx);
}
} else {
let mut batch: Batch = Batch::new(t, self.p_draw);
batch.add_events(composition, results, weights, &self.agents);
let mut time_slice: TimeSlice = TimeSlice::new(t, self.p_draw);
time_slice.add_events(composition, results, weights, &self.agents);
self.batches.insert(k, batch);
self.time_slices.insert(k, time_slice);
let batch = &self.batches[k];
let time_slice = &self.time_slices[k];
for agent_idx in batch.skills.keys() {
for agent_idx in time_slice.skills.keys() {
let agent = self.agents.get_mut(agent_idx).unwrap();
agent.last_time = if self.time { t } else { i64::MAX };
agent.message = batch.forward_prior_out(&agent_idx);
agent.message = time_slice.forward_prior_out(&agent_idx);
}
k += 1;
@@ -406,21 +408,23 @@ impl<D: Drift> History<D> {
i = j;
}
while self.time && self.batches.len() > k {
let batch = &mut self.batches[k];
while self.time && self.time_slices.len() > k {
let time_slice = &mut self.time_slices[k];
batch.new_forward_info(&self.agents);
time_slice.new_forward_info(&self.agents);
// TODO: Is it faster to iterate over agents in batch instead?
for agent_idx in &this_agent {
if let Some(skill) = batch.skills.get_mut(*agent_idx) {
skill.elapsed =
batch::compute_elapsed(self.agents[*agent_idx].last_time, batch.time);
if let Some(skill) = time_slice.skills.get_mut(*agent_idx) {
skill.elapsed = time_slice::compute_elapsed(
self.agents[*agent_idx].last_time,
time_slice.time,
);
let agent = self.agents.get_mut(*agent_idx).unwrap();
agent.last_time = if self.time { batch.time } else { i64::MAX };
agent.message = batch.forward_prior_out(agent_idx);
agent.last_time = if self.time { time_slice.time } else { i64::MAX };
agent.message = time_slice.forward_prior_out(agent_idx);
}
}
@@ -473,7 +477,7 @@ mod tests {
h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors);
let p0 = h.batches[0].posteriors();
let p0 = h.time_slices[0].posteriors();
assert_ulps_eq!(
p0[&a],
@@ -481,10 +485,10 @@ mod tests {
epsilon = 1e-6
);
let observed = h.batches[1].skills.get(a).unwrap().forward.sigma();
let observed = h.time_slices[1].skills.get(a).unwrap().forward.sigma();
let gamma: f64 = 0.15 * 25.0 / 3.0;
let expected = (gamma.powi(2)
+ h.batches[0]
+ h.time_slices[0]
.skills
.get(a)
.unwrap()
@@ -495,11 +499,16 @@ mod tests {
assert_ulps_eq!(observed, expected, epsilon = 0.000001);
let observed = h.batches[1].skills.get(a).unwrap().posterior();
let observed = h.time_slices[1].skills.get(a).unwrap().posterior();
let w = [vec![1.0], vec![1.0]];
let p = Game::new(
h.batches[1].events[0].within_priors(false, false, &h.batches[1].skills, &h.agents),
h.time_slices[1].events[0].within_priors(
false,
false,
&h.time_slices[1].skills,
&h.agents,
),
&[0.0, 1.0],
&w,
P_DRAW,
@@ -545,12 +554,12 @@ mod tests {
h1.add_events_with_prior(composition, results, times, vec![], priors);
assert_ulps_eq!(
h1.batches[0].skills.get(a).unwrap().posterior(),
h1.time_slices[0].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(22.904409, 6.010330),
epsilon = 1e-6
);
assert_ulps_eq!(
h1.batches[0].skills.get(c).unwrap().posterior(),
h1.time_slices[0].skills.get(c).unwrap().posterior(),
Gaussian::from_ms(25.110318, 5.866311),
epsilon = 1e-6
);
@@ -558,12 +567,12 @@ mod tests {
h1.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h1.batches[0].skills.get(a).unwrap().posterior(),
h1.time_slices[0].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(25.000000, 5.419212),
epsilon = 1e-6
);
assert_ulps_eq!(
h1.batches[0].skills.get(c).unwrap().posterior(),
h1.time_slices[0].skills.get(c).unwrap().posterior(),
Gaussian::from_ms(25.000000, 5.419212),
epsilon = 1e-6
);
@@ -594,12 +603,12 @@ mod tests {
h2.add_events_with_prior(composition, results, times, vec![], priors);
assert_ulps_eq!(
h2.batches[2].skills.get(a).unwrap().posterior(),
h2.time_slices[2].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(22.903522, 6.011017),
epsilon = 1e-6
);
assert_ulps_eq!(
h2.batches[2].skills.get(c).unwrap().posterior(),
h2.time_slices[2].skills.get(c).unwrap().posterior(),
Gaussian::from_ms(25.110702, 5.866811),
epsilon = 1e-6
);
@@ -607,12 +616,12 @@ mod tests {
h2.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h2.batches[2].skills.get(a).unwrap().posterior(),
h2.time_slices[2].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(24.998668, 5.420053),
epsilon = 1e-6
);
assert_ulps_eq!(
h2.batches[2].skills.get(c).unwrap().posterior(),
h2.time_slices[2].skills.get(c).unwrap().posterior(),
Gaussian::from_ms(25.000532, 5.419827),
epsilon = 1e-6
);
@@ -699,21 +708,21 @@ mod tests {
h.convergence(ITERATIONS, EPSILON, false);
assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1);
assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1);
assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 1);
assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);
assert_ulps_eq!(
h.batches[0].skills.get(a).unwrap().posterior(),
h.time_slices[0].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(25.000267, 5.419381),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[0].skills.get(b).unwrap().posterior(),
h.time_slices[0].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(24.999465, 5.419425),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[2].skills.get(b).unwrap().posterior(),
h.time_slices[2].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(25.000532, 5.419696),
epsilon = 1e-6
);
@@ -757,8 +766,8 @@ mod tests {
);
assert_ulps_eq!(
h.batches[0].skills.get(b).unwrap().posterior().mu(),
-h.batches[0].skills.get(c).unwrap().posterior().mu(),
h.time_slices[0].skills.get(b).unwrap().posterior().mu(),
-h.time_slices[0].skills.get(c).unwrap().posterior().mu(),
epsilon = 1e-6
);
@@ -777,33 +786,33 @@ mod tests {
assert_ulps_eq!(p_d_m_hat, 0.172432, epsilon = 1e-6);
assert_ulps_eq!(
h.batches[0].skills.get(a).unwrap().posterior(),
h.batches[0].skills.get(b).unwrap().posterior(),
h.time_slices[0].skills.get(a).unwrap().posterior(),
h.time_slices[0].skills.get(b).unwrap().posterior(),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[0].skills.get(c).unwrap().posterior(),
h.batches[0].skills.get(d).unwrap().posterior(),
h.time_slices[0].skills.get(c).unwrap().posterior(),
h.time_slices[0].skills.get(d).unwrap().posterior(),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[1].skills.get(e).unwrap().posterior(),
h.batches[1].skills.get(f).unwrap().posterior(),
h.time_slices[1].skills.get(e).unwrap().posterior(),
h.time_slices[1].skills.get(f).unwrap().posterior(),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[0].skills.get(a).unwrap().posterior(),
h.time_slices[0].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(4.084902, 5.106919),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[0].skills.get(c).unwrap().posterior(),
h.time_slices[0].skills.get(c).unwrap().posterior(),
Gaussian::from_ms(-0.533029, 5.106919),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[2].skills.get(e).unwrap().posterior(),
h.time_slices[2].skills.get(e).unwrap().posterior(),
Gaussian::from_ms(-3.551872, 5.154569),
epsilon = 1e-6
);
@@ -836,31 +845,31 @@ mod tests {
h.convergence(ITERATIONS, EPSILON, false);
assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1);
assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1);
assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 1);
assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);
assert_ulps_eq!(
h.batches[0].skills.get(a).unwrap().posterior(),
h.time_slices[0].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(0.000000, 1.300610),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[0].skills.get(b).unwrap().posterior(),
h.time_slices[0].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(0.000000, 1.300610),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[2].skills.get(b).unwrap().posterior(),
h.time_slices[2].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(0.000000, 1.300610),
epsilon = 1e-6
);
h.add_events(composition, results, vec![], vec![]);
assert_eq!(h.batches.len(), 6);
assert_eq!(h.time_slices.len(), 6);
assert_eq!(
h.batches
h.time_slices
.iter()
.map(|b| b.get_composition())
.collect::<Vec<_>>(),
@@ -877,22 +886,22 @@ mod tests {
h.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h.batches[0].skills.get(a).unwrap().posterior(),
h.time_slices[0].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(0.000000, 0.931236),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[3].skills.get(a).unwrap().posterior(),
h.time_slices[3].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(0.000000, 0.931236),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[3].skills.get(b).unwrap().posterior(),
h.time_slices[3].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(0.000000, 0.931236),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[5].skills.get(b).unwrap().posterior(),
h.time_slices[5].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(0.000000, 0.931236),
epsilon = 1e-6
);
@@ -925,31 +934,31 @@ mod tests {
h.convergence(ITERATIONS, EPSILON, false);
assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1);
assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1);
assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 1);
assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);
assert_ulps_eq!(
h.batches[0].skills.get(a).unwrap().posterior(),
h.time_slices[0].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(0.000000, 1.300610),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[0].skills.get(b).unwrap().posterior(),
h.time_slices[0].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(0.000000, 1.300610),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[2].skills.get(b).unwrap().posterior(),
h.time_slices[2].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(0.000000, 1.300610),
epsilon = 1e-6
);
h.add_events(composition, results, vec![], vec![]);
assert_eq!(h.batches.len(), 6);
assert_eq!(h.time_slices.len(), 6);
assert_eq!(
h.batches
h.time_slices
.iter()
.map(|b| b.get_composition())
.collect::<Vec<_>>(),
@@ -966,22 +975,22 @@ mod tests {
h.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h.batches[0].skills.get(a).unwrap().posterior(),
h.time_slices[0].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(0.000000, 0.931236),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[3].skills.get(a).unwrap().posterior(),
h.time_slices[3].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(0.000000, 0.931236),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[3].skills.get(b).unwrap().posterior(),
h.time_slices[3].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(0.000000, 0.931236),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[5].skills.get(b).unwrap().posterior(),
h.time_slices[5].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(0.000000, 0.931236),
epsilon = 1e-6
);
@@ -1079,18 +1088,18 @@ mod tests {
h.add_events(composition, results, vec![15, 10, 0], vec![]);
assert_eq!(h.batches.len(), 4);
assert_eq!(h.time_slices.len(), 4);
assert_eq!(
h.batches
h.time_slices
.iter()
.map(|batch| batch.events.len())
.map(|ts| ts.events.len())
.collect::<Vec<_>>(),
vec![2, 2, 1, 1]
);
assert_eq!(
h.batches
h.time_slices
.iter()
.map(|b| b.get_composition())
.collect::<Vec<_>>(),
@@ -1103,7 +1112,7 @@ mod tests {
);
assert_eq!(
h.batches
h.time_slices
.iter()
.map(|b| b.get_results())
.collect::<Vec<_>>(),
@@ -1115,34 +1124,34 @@ mod tests {
]
);
let end = h.batches.len() - 1;
let end = h.time_slices.len() - 1;
assert_eq!(h.batches[0].skills.get(c).unwrap().elapsed, 0);
assert_eq!(h.batches[end].skills.get(c).unwrap().elapsed, 10);
assert_eq!(h.time_slices[0].skills.get(c).unwrap().elapsed, 0);
assert_eq!(h.time_slices[end].skills.get(c).unwrap().elapsed, 10);
assert_eq!(h.batches[0].skills.get(a).unwrap().elapsed, 0);
assert_eq!(h.batches[2].skills.get(a).unwrap().elapsed, 5);
assert_eq!(h.time_slices[0].skills.get(a).unwrap().elapsed, 0);
assert_eq!(h.time_slices[2].skills.get(a).unwrap().elapsed, 5);
assert_eq!(h.batches[0].skills.get(b).unwrap().elapsed, 0);
assert_eq!(h.batches[end].skills.get(b).unwrap().elapsed, 5);
assert_eq!(h.time_slices[0].skills.get(b).unwrap().elapsed, 0);
assert_eq!(h.time_slices[end].skills.get(b).unwrap().elapsed, 5);
h.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h.batches[0].skills.get(b).unwrap().posterior(),
h.batches[end].skills.get(b).unwrap().posterior(),
h.time_slices[0].skills.get(b).unwrap().posterior(),
h.time_slices[end].skills.get(b).unwrap().posterior(),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[0].skills.get(c).unwrap().posterior(),
h.batches[end].skills.get(c).unwrap().posterior(),
h.time_slices[0].skills.get(c).unwrap().posterior(),
h.time_slices[end].skills.get(c).unwrap().posterior(),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[0].skills.get(c).unwrap().posterior(),
h.batches[0].skills.get(b).unwrap().posterior(),
h.time_slices[0].skills.get(c).unwrap().posterior(),
h.time_slices[0].skills.get(b).unwrap().posterior(),
epsilon = 1e-6
);
@@ -1167,18 +1176,18 @@ mod tests {
h.add_events(composition, vec![], vec![15, 10, 0], vec![]);
assert_eq!(h.batches.len(), 4);
assert_eq!(h.time_slices.len(), 4);
assert_eq!(
h.batches
h.time_slices
.iter()
.map(|batch| batch.events.len())
.map(|ts| ts.events.len())
.collect::<Vec<_>>(),
vec![2, 2, 1, 1]
);
assert_eq!(
h.batches
h.time_slices
.iter()
.map(|b| b.get_composition())
.collect::<Vec<_>>(),
@@ -1191,7 +1200,7 @@ mod tests {
);
assert_eq!(
h.batches
h.time_slices
.iter()
.map(|b| b.get_results())
.collect::<Vec<_>>(),
@@ -1203,34 +1212,34 @@ mod tests {
]
);
let end = h.batches.len() - 1;
let end = h.time_slices.len() - 1;
assert_eq!(h.batches[0].skills.get(c).unwrap().elapsed, 0);
assert_eq!(h.batches[end].skills.get(c).unwrap().elapsed, 10);
assert_eq!(h.time_slices[0].skills.get(c).unwrap().elapsed, 0);
assert_eq!(h.time_slices[end].skills.get(c).unwrap().elapsed, 10);
assert_eq!(h.batches[0].skills.get(a).unwrap().elapsed, 0);
assert_eq!(h.batches[2].skills.get(a).unwrap().elapsed, 5);
assert_eq!(h.time_slices[0].skills.get(a).unwrap().elapsed, 0);
assert_eq!(h.time_slices[2].skills.get(a).unwrap().elapsed, 5);
assert_eq!(h.batches[0].skills.get(b).unwrap().elapsed, 0);
assert_eq!(h.batches[end].skills.get(b).unwrap().elapsed, 5);
assert_eq!(h.time_slices[0].skills.get(b).unwrap().elapsed, 0);
assert_eq!(h.time_slices[end].skills.get(b).unwrap().elapsed, 5);
h.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h.batches[0].skills.get(b).unwrap().posterior(),
h.batches[end].skills.get(b).unwrap().posterior(),
h.time_slices[0].skills.get(b).unwrap().posterior(),
h.time_slices[end].skills.get(b).unwrap().posterior(),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[0].skills.get(c).unwrap().posterior(),
h.batches[end].skills.get(c).unwrap().posterior(),
h.time_slices[0].skills.get(c).unwrap().posterior(),
h.time_slices[end].skills.get(c).unwrap().posterior(),
epsilon = 1e-6
);
assert_ulps_eq!(
h.batches[0].skills.get(c).unwrap().posterior(),
h.batches[0].skills.get(b).unwrap().posterior(),
h.time_slices[0].skills.get(c).unwrap().posterior(),
h.time_slices[0].skills.get(b).unwrap().posterior(),
epsilon = 1e-6
);
}

View File

@@ -6,7 +6,8 @@ use std::{
#[cfg(feature = "approx")]
mod approx;
pub(crate) mod arena;
pub mod batch;
mod time_slice;
pub use time_slice::TimeSlice;
mod competitor;
pub mod drift;
mod error;

View File

@@ -1,4 +1,4 @@
use crate::{Index, batch::Skill};
use crate::{Index, time_slice::Skill};
/// Dense Vec-backed store for per-agent skill state within a TimeSlice.
///

View File

@@ -1,3 +1,7 @@
//! A single time step's worth of events.
//!
//! Renamed from `Batch` in T2.
use std::collections::HashMap;
use crate::{
@@ -106,7 +110,7 @@ impl Event {
}
#[derive(Debug)]
pub struct Batch {
pub struct TimeSlice {
pub(crate) events: Vec<Event>,
pub(crate) skills: SkillStore,
pub(crate) time: i64,
@@ -114,7 +118,7 @@ pub struct Batch {
arena: ScratchArena,
}
impl Batch {
impl TimeSlice {
pub fn new(time: i64, p_draw: f64) -> Self {
Self {
events: Vec::new(),
@@ -431,9 +435,9 @@ mod tests {
);
}
let mut batch = Batch::new(0, 0.0);
let mut time_slice = TimeSlice::new(0, 0.0);
batch.add_events(
time_slice.add_events(
vec![
vec![vec![a], vec![b]],
vec![vec![c], vec![d]],
@@ -444,7 +448,7 @@ mod tests {
&agents,
);
let post = batch.posteriors();
let post = time_slice.posteriors();
assert_ulps_eq!(
post[&a],
@@ -477,7 +481,7 @@ mod tests {
epsilon = 1e-6
);
assert_eq!(batch.convergence(&agents), 1);
assert_eq!(time_slice.convergence(&agents), 1);
}
#[test]
@@ -507,9 +511,9 @@ mod tests {
);
}
let mut batch = Batch::new(0, 0.0);
let mut time_slice = TimeSlice::new(0, 0.0);
batch.add_events(
time_slice.add_events(
vec![
vec![vec![a], vec![b]],
vec![vec![a], vec![c]],
@@ -520,7 +524,7 @@ mod tests {
&agents,
);
let post = batch.posteriors();
let post = time_slice.posteriors();
assert_ulps_eq!(
post[&a],
@@ -538,9 +542,9 @@ mod tests {
epsilon = 1e-6
);
assert!(batch.convergence(&agents) > 1);
assert!(time_slice.convergence(&agents) > 1);
let post = batch.posteriors();
let post = time_slice.posteriors();
assert_ulps_eq!(
post[&a],
@@ -586,9 +590,9 @@ mod tests {
);
}
let mut batch = Batch::new(0, 0.0);
let mut time_slice = TimeSlice::new(0, 0.0);
batch.add_events(
time_slice.add_events(
vec![
vec![vec![a], vec![b]],
vec![vec![a], vec![c]],
@@ -599,9 +603,9 @@ mod tests {
&agents,
);
batch.convergence(&agents);
time_slice.convergence(&agents);
let post = batch.posteriors();
let post = time_slice.posteriors();
assert_ulps_eq!(
post[&a],
@@ -619,7 +623,7 @@ mod tests {
epsilon = 1e-6
);
batch.add_events(
time_slice.add_events(
vec![
vec![vec![a], vec![b]],
vec![vec![a], vec![c]],
@@ -630,11 +634,11 @@ mod tests {
&agents,
);
assert_eq!(batch.events.len(), 6);
assert_eq!(time_slice.events.len(), 6);
batch.convergence(&agents);
time_slice.convergence(&agents);
let post = batch.posteriors();
let post = time_slice.posteriors();
assert_ulps_eq!(
post[&a],