Implements tiers T0, T1, and T2 of `docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md`. All three tiers landed together on this branch because each builds on the previous one; this PR rolls them up for a single review pass.

Per-tier plans:

- T0: `docs/superpowers/plans/2026-04-23-t0-numerical-parity.md`
- T1: `docs/superpowers/plans/2026-04-24-t1-factor-graph.md`
- T2: `docs/superpowers/plans/2026-04-24-t2-new-api-surface.md`

## Summary

### T0: Numerical parity (internal)

- `Gaussian` switched to natural-parameter storage `(pi, tau)`; mul/div are now ~7× faster (218 ps vs. 1.57 ns).
- `HashMap<Index, _>` → dense `Vec<_>` keyed by `Index.0` (via `AgentStore<D>`, now `CompetitorStore<T, D>` after the T2 renames, and `SkillStore`).
- `ScratchArena` eliminates per-event allocations in `Game::likelihoods`.
- Seed `InferenceError` type added (one variant).
- Tests: 38 → 53 passing (cumulative through T1).
- Benchmark: `Batch::iteration` 29.84 → 21.25 µs.

### T1: Factor graph machinery (internal)

- `Factor` trait + `BuiltinFactor` enum (`TeamSum` / `RankDiff` / `Trunc`) drive within-game inference.
- `VarStore`: flat storage for variable marginals.
- `Schedule` trait + `EpsilonOrMax` impl replace the hand-rolled EP loop.
- `Game::likelihoods` rebuilt on the factor-graph machinery; iteration counts and goldens preserved to within 1e-6.
- 53 tests passing.
- Benchmark: `Batch::iteration` 23.01 µs (a slight regression from T0's 21.25 µs; T2 brings it back down to 21.36 µs).

### T2: New API surface (breaking)

**Renames:**

- `IndexMap` → `KeyTable`, `Player` → `Rating`, `Agent` → `Competitor`, `Batch` → `TimeSlice`

**New types:**

- `Time` trait with `Untimed` ZST and `i64` impls; `Drift<T>`, `Rating<T, D>`, `Competitor<T, D>`, `TimeSlice<T>`, and `History<T, D, O, K>` are now all generic.
- `Event<T, K>`, `Team<K>`, `Member<K>`, `Outcome` (`Ranked` variant; `#[non_exhaustive]`).
- `Observer<T>` trait + `NullObserver`.
- `ConvergenceOptions`, `ConvergenceReport`.
- `GameOptions`, `OwnedGame<T, D>`.
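The `Time` abstraction can be pictured with a minimal sketch. It makes several assumptions:

- The PR names `elapsed_to` but not its signature, so the trait method signature `elapsed_to(&self, later: &Self) -> i64` is assumed.
- `compute_elapsed` here is a stand-in for `time_slice::compute_elapsed`. Returning 0 on a competitor's first appearance is assumed.
- `forget` grows variance by `elapsed * gamma^2`, which is the TrueSkill Through Time convention.

None of this is the crate's actual definition.

```rust
/// Hypothetical time axis; the crate's real `Time` trait may differ.
pub trait Time: Copy + PartialOrd {
    /// Number of drift steps from `self` to a later time.
    fn elapsed_to(&self, later: &Self) -> i64;
}

/// Zero-sized "no time axis" marker: no drift ever accumulates.
#[derive(Clone, Copy, Debug, Default, PartialEq, PartialOrd)]
pub struct Untimed;

impl Time for Untimed {
    fn elapsed_to(&self, _later: &Self) -> i64 {
        0
    }
}

impl Time for i64 {
    fn elapsed_to(&self, later: &Self) -> i64 {
        *later - *self
    }
}

/// Stand-in for `time_slice::compute_elapsed`: steps since the competitor
/// was last seen. Returning 0 when never seen before is an assumption.
pub fn compute_elapsed<T: Time>(last: Option<&T>, now: &T) -> i64 {
    last.map_or(0, |l| l.elapsed_to(now))
}

/// Constant drift (assumed TTT convention): variance grows by elapsed * gamma^2.
pub fn forget(sigma: f64, gamma: f64, elapsed: i64) -> f64 {
    (sigma * sigma + elapsed as f64 * gamma * gamma).sqrt()
}
```

With `Untimed`, every reappearance yields 0 elapsed steps, so `forget` returns sigma unchanged. With `i64` timestamps, the gap between appearances controls how much uncertainty is added. This matches `elapsed == 2` in `test_env_ttt` for a competitor last seen at time 1 and next seen at time 3.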
**Three-tier ingestion:**

- `history.record_winner(&K, &K, T)` / `record_draw(&K, &K, T)`: 1v1 convenience.
- `history.add_events(iter)`: typed bulk ingestion.
- `history.event(T).team([...]).weights([...]).ranking([...]).commit()`: fluent builder.

**Query API:** `current_skill`, `learning_curve`, `learning_curves` (keyed on `K`), `log_evidence`, `log_evidence_for`, `predict_quality`, `predict_outcome`.

**Game constructors:** `ranked`, `one_v_one`, `free_for_all`, `custom`, all returning `Result<_, InferenceError>`.

**`factors` module:** `Factor`, `Schedule`, `VarStore`, `VarId`, `BuiltinFactor`, `EpsilonOrMax`, `ScheduleReport`, `TeamSumFactor`, `RankDiffFactor`, and `TruncFactor` are now public.

**Errors:** `InferenceError` gains `MismatchedShape`, `InvalidProbability`, and `ConvergenceFailed`; panics at API boundaries are converted to `Result`.

**Removed (breaking):** `History::convergence(iters, eps, verbose)` (use `HistoryBuilder::convergence(ConvergenceOptions)` + `History::converge()`), `HistoryBuilder::gamma(f64)` (use `.drift(ConstantDrift(gamma))`), `HistoryBuilder::time(bool)`, `History.time: bool`, `learning_curves_by_index`, and the public nested-`Vec` `add_events`.

## Behavior change (documented in CHANGELOG)

With `Time = Untimed`, `elapsed_to` always returns 0, so no drift accumulates between slices. The old `time=false` mode implicitly forced `elapsed = 1` on reappearance via an `i64::MAX` sentinel; that quirk cannot be reproduced under a typed time axis. Tests that depended on it now use `History::<i64, _>` with explicit `1..=n` timestamps. One test (`test_env_ttt`) had three Gaussian goldens updated to reflect the corrected semantics; see commit `33a7d90`.

## Final numbers

| Metric | Before T0 | After T2 | Delta |
|---|---|---|---|
| `Batch::iteration` | 29.84 µs | 21.36 µs | **-28%** |
| `Gaussian::mul` | 1.57 ns | 219 ps | **-86%** |
| `Gaussian::div` | 1.57 ns | 219 ps | **-86%** |
| Tests passing | 38 | 90 | +52 |

All other Gaussian ops are unchanged (~219 ps for add/sub, ~264 ps for `pi`/`tau` reads).
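Why `(pi, tau)` storage makes `mul`/`div` cheap: with precision `pi = 1/sigma^2` and `tau = mu/sigma^2`, the product or quotient of two Gaussian densities is just the sum or difference of those parameters. Below is a standalone sketch illustrating this, not the crate's `Gaussian`. The field names follow the `(pi, tau)` naming above; the constructor and operator impls are assumptions.

```rust
use std::ops::{Add, Div, Mul};

/// Hypothetical natural-parameter Gaussian (illustration only).
/// pi  = 1 / sigma^2  (precision)
/// tau = mu / sigma^2 (precision-weighted mean)
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Gaussian {
    pi: f64,
    tau: f64,
}

impl Gaussian {
    /// Build from mean and standard deviation (assumes 0 < sigma < inf).
    pub fn from_ms(mu: f64, sigma: f64) -> Self {
        let pi = 1.0 / (sigma * sigma);
        Gaussian { pi, tau: mu * pi }
    }

    pub fn mu(&self) -> f64 {
        self.tau / self.pi
    }

    pub fn sigma(&self) -> f64 {
        (1.0 / self.pi).sqrt()
    }
}

/// Product of densities (combining messages): natural parameters add.
impl Mul for Gaussian {
    type Output = Gaussian;
    fn mul(self, rhs: Gaussian) -> Gaussian {
        Gaussian { pi: self.pi + rhs.pi, tau: self.tau + rhs.tau }
    }
}

/// Quotient of densities (removing a message): natural parameters subtract.
impl Div for Gaussian {
    type Output = Gaussian;
    fn div(self, rhs: Gaussian) -> Gaussian {
        Gaussian { pi: self.pi - rhs.pi, tau: self.tau - rhs.tau }
    }
}

/// Sum of independent variables: means and variances add, so this
/// direction has to convert out of (pi, tau) and back.
impl Add for Gaussian {
    type Output = Gaussian;
    fn add(self, rhs: Gaussian) -> Gaussian {
        let var = 1.0 / self.pi + 1.0 / rhs.pi;
        let mu = self.mu() + rhs.mu();
        let pi = 1.0 / var;
        Gaussian { pi, tau: mu * pi }
    }
}
```

The design trades cheaper `mul`/`div`, which dominate EP message passing, for a round-trip in `add`/`sub`.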
## Test plan

- [x] `cargo test --features approx`: 90/90 pass (68 lib + 10 api_shape + 6 game + 4 record_winner + 2 equivalence)
- [x] `cargo clippy --all-targets --features approx -- -D warnings`: clean
- [x] `cargo +nightly fmt --check`: clean
- [x] `cargo bench --bench batch`: 21.36 µs
- [x] `cargo bench --bench gaussian`: unchanged from T1
- [x] `cargo run --example atp --features approx`: rewritten in the new API, runs clean
- [x] Historical `Game`-level goldens preserved in `tests/equivalence.rs`
- [x] Public API matches spec Section 4 (verified by integration tests in `tests/api_shape.rs`)

## Commit history

~45 commits total across T0, T1, and T2. Each task is self-contained and individually tested, so the branch is bisectable. See `git log main..t2-new-api-surface` for the full list.

## Deferred to later tiers

- `Outcome::Scored` + `MarginFactor`: T4
- `Damped` / `Residual` schedules: T4
- `Send + Sync` bounds + Rayon parallelism: T3
- N-team `predict_outcome`: T4
- Full `Game::custom` ergonomics: T4

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Reviewed-on: #1
Co-authored-by: Anders Olsson <anders.e.olsson@gmail.com>
Co-committed-by: Anders Olsson <anders.e.olsson@gmail.com>
1629 lines
50 KiB
Rust
use std::{borrow::Borrow, collections::HashMap, hash::Hash, marker::PhantomData};

use crate::{
    BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA,
    competitor::{self, Competitor},
    convergence::{ConvergenceOptions, ConvergenceReport},
    drift::{ConstantDrift, Drift},
    error::InferenceError,
    gaussian::Gaussian,
    key_table::KeyTable,
    observer::{NullObserver, Observer},
    rating::Rating,
    sort_time,
    storage::CompetitorStore,
    time::Time,
    time_slice::{self, TimeSlice},
    tuple_gt, tuple_max,
};

#[derive(Clone)]
pub struct HistoryBuilder<
    T: Time = i64,
    D: Drift<T> = ConstantDrift,
    O: Observer<T> = NullObserver,
    K: Eq + Hash + Clone = &'static str,
> {
    mu: f64,
    sigma: f64,
    beta: f64,
    drift: D,
    p_draw: f64,
    online: bool,
    convergence: ConvergenceOptions,
    observer: O,
    _time: PhantomData<T>,
    _key: PhantomData<K>,
}

impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> HistoryBuilder<T, D, O, K> {
    /// Prior skill mean for competitors seen for the first time.
    pub fn mu(mut self, mu: f64) -> Self {
        self.mu = mu;
        self
    }

    /// Prior skill standard deviation for competitors seen for the first time.
    pub fn sigma(mut self, sigma: f64) -> Self {
        self.sigma = sigma;
        self
    }

    /// Performance noise: standard deviation of a competitor's performance
    /// around their skill in a single game.
    pub fn beta(mut self, beta: f64) -> Self {
        self.beta = beta;
        self
    }

    /// Skill drift model applied between time slices (e.g. `ConstantDrift(gamma)`).
    pub fn drift<D2: Drift<T>>(self, drift: D2) -> HistoryBuilder<T, D2, O, K> {
        HistoryBuilder {
            drift,
            mu: self.mu,
            sigma: self.sigma,
            beta: self.beta,
            p_draw: self.p_draw,
            online: self.online,
            convergence: self.convergence,
            observer: self.observer,
            _time: self._time,
            _key: self._key,
        }
    }

    /// Prior probability of a draw, passed to each new `TimeSlice`.
    pub fn p_draw(mut self, p_draw: f64) -> Self {
        self.p_draw = p_draw;
        self
    }

    /// Forwarded as the `online` flag when computing per-slice log-evidence.
    pub fn online(mut self, online: bool) -> Self {
        self.online = online;
        self
    }

    /// Stopping criteria used by `History::converge`.
    pub fn convergence(mut self, opts: ConvergenceOptions) -> Self {
        self.convergence = opts;
        self
    }

    /// Observer notified after each `converge` iteration and once on completion.
    pub fn observer<O2: Observer<T>>(self, observer: O2) -> HistoryBuilder<T, D, O2, K> {
        HistoryBuilder {
            mu: self.mu,
            sigma: self.sigma,
            beta: self.beta,
            drift: self.drift,
            p_draw: self.p_draw,
            online: self.online,
            convergence: self.convergence,
            observer,
            _time: self._time,
            _key: self._key,
        }
    }

    /// Create an empty `History` with these settings.
    pub fn build(self) -> History<T, D, O, K> {
        History {
            size: 0,
            time_slices: Vec::new(),
            agents: CompetitorStore::new(),
            keys: KeyTable::new(),
            mu: self.mu,
            sigma: self.sigma,
            beta: self.beta,
            drift: self.drift,
            p_draw: self.p_draw,
            online: self.online,
            convergence: self.convergence,
            observer: self.observer,
        }
    }
}

impl Default for HistoryBuilder<i64, ConstantDrift, NullObserver, &'static str> {
    fn default() -> Self {
        Self {
            mu: MU,
            sigma: SIGMA,
            beta: BETA,
            drift: ConstantDrift(GAMMA),
            p_draw: P_DRAW,
            online: false,
            convergence: ConvergenceOptions::default(),
            observer: NullObserver,
            _time: PhantomData,
            _key: PhantomData,
        }
    }
}

pub struct History<
    T: Time = i64,
    D: Drift<T> = ConstantDrift,
    O: Observer<T> = NullObserver,
    K: Eq + Hash + Clone = &'static str,
> {
    size: usize,
    pub(crate) time_slices: Vec<TimeSlice<T>>,
    pub(crate) agents: CompetitorStore<T, D>,
    keys: KeyTable<K>,
    mu: f64,
    sigma: f64,
    beta: f64,
    drift: D,
    p_draw: f64,
    online: bool,
    convergence: ConvergenceOptions,
    observer: O,
}

impl Default for History<i64, ConstantDrift, NullObserver, &'static str> {
    fn default() -> Self {
        HistoryBuilder::default().build()
    }
}

impl History<i64, ConstantDrift, NullObserver, &'static str> {
    /// Start configuring a `History` with the crate defaults (`i64` time,
    /// `ConstantDrift`, `NullObserver`, `&'static str` keys).
    pub fn builder() -> HistoryBuilder<i64, ConstantDrift, NullObserver, &'static str> {
        HistoryBuilder::default()
    }
}

impl<K: Eq + Hash + Clone> History<i64, ConstantDrift, NullObserver, K> {
    /// Like `builder()` but uses a custom key type `K` instead of the default `&'static str`.
    pub fn builder_with_key() -> HistoryBuilder<i64, ConstantDrift, NullObserver, K> {
        HistoryBuilder {
            mu: MU,
            sigma: SIGMA,
            beta: BETA,
            drift: ConstantDrift(GAMMA),
            p_draw: P_DRAW,
            online: false,
            convergence: ConvergenceOptions::default(),
            observer: NullObserver,
            _time: PhantomData,
            _key: PhantomData,
        }
    }
}

impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O, K> {
    /// Return the `Index` for `key`, allocating a new one if the key is unseen.
    pub fn intern<Q>(&mut self, key: &Q) -> Index
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
    {
        self.keys.get_or_create(key)
    }

    /// Return the `Index` for `key` if it has already been interned.
    pub fn lookup<Q>(&self, key: &Q) -> Option<Index>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
    {
        self.keys.get(key)
    }
}

impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O, K> {
    /// One backward sweep followed by one forward sweep over all time slices.
    /// Returns the largest posterior change observed, as a `(mu, sigma)` delta pair.
    fn iteration(&mut self) -> (f64, f64) {
        let mut step = (0.0, 0.0);

        competitor::clean(self.agents.values_mut(), false);

        // Backward sweep: propagate messages from later slices to earlier ones.
        // `saturating_sub` keeps an empty history from underflowing `len() - 1`.
        for j in (0..self.time_slices.len().saturating_sub(1)).rev() {
            for agent in self.time_slices[j + 1].skills.keys() {
                self.agents.get_mut(agent).unwrap().message =
                    self.time_slices[j + 1].backward_prior_out(&agent, &self.agents);
            }

            let old = self.time_slices[j].posteriors();

            self.time_slices[j].new_backward_info(&self.agents);

            let new = self.time_slices[j].posteriors();

            step = old
                .iter()
                .fold(step, |step, (a, old)| tuple_max(step, old.delta(new[a])));
        }

        competitor::clean(self.agents.values_mut(), false);

        // Forward sweep: propagate messages from earlier slices to later ones.
        for j in 1..self.time_slices.len() {
            for agent in self.time_slices[j - 1].skills.keys() {
                self.agents.get_mut(agent).unwrap().message =
                    self.time_slices[j - 1].forward_prior_out(&agent);
            }

            let old = self.time_slices[j].posteriors();

            self.time_slices[j].new_forward_info(&self.agents);

            let new = self.time_slices[j].posteriors();

            step = old
                .iter()
                .fold(step, |step, (a, old)| tuple_max(step, old.delta(new[a])));
        }

        // With a single slice neither sweep runs, so iterate that slice directly.
        if self.time_slices.len() == 1 {
            let old = self.time_slices[0].posteriors();

            self.time_slices[0].iteration(0, &self.agents);

            let new = self.time_slices[0].posteriors();

            step = old
                .iter()
                .fold(step, |step, (a, old)| tuple_max(step, old.delta(new[a])));
        }

        step
    }

    /// Learning curves for all competitors, keyed by their user-facing key.
    ///
    /// Note: `key(idx)` is O(n) per lookup, so this method is O(n²) in the
    /// number of competitors. Acceptable for T2; T3 may optimize it.
    pub fn learning_curves(&self) -> HashMap<K, Vec<(T, Gaussian)>> {
        let mut data: HashMap<K, Vec<(T, Gaussian)>> = HashMap::new();
        for slice in &self.time_slices {
            for (idx, skill) in slice.skills.iter() {
                if let Some(key) = self.keys.key(idx).cloned() {
                    data.entry(key)
                        .or_default()
                        .push((slice.time, skill.posterior()));
                }
            }
        }
        data
    }

    /// Skill estimate at the latest time slice the competitor appears in.
    pub fn current_skill<Q>(&self, key: &Q) -> Option<Gaussian>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let idx = self.keys.get(key)?;
        self.time_slices
            .iter()
            .rev()
            .find_map(|ts| ts.skills.get(idx).map(|sk| sk.posterior()))
    }

    /// Learning curve for a single key: `(time, posterior)` pairs in time order.
    /// Returns an empty `Vec` for unknown keys.
    pub fn learning_curve<Q>(&self, key: &Q) -> Vec<(T, Gaussian)>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let Some(idx) = self.keys.get(key) else {
            return Vec::new();
        };
        self.time_slices
            .iter()
            .filter_map(|ts| ts.skills.get(idx).map(|sk| (ts.time, sk.posterior())))
            .collect()
    }

    pub(crate) fn log_evidence_internal(&mut self, forward: bool, targets: &[Index]) -> f64 {
        self.time_slices
            .iter()
            .map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents))
            .sum()
    }

    /// Total log-evidence across the history.
    pub fn log_evidence(&mut self) -> f64 {
        self.log_evidence_internal(false, &[])
    }

    /// Log-evidence restricted to time slices containing at least one of the
    /// given keys. Useful for leave-one-out cross-validation.
    pub fn log_evidence_for<Q>(&mut self, keys: &[&Q]) -> f64
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let targets: Vec<Index> = keys.iter().filter_map(|k| self.keys.get(*k)).collect();
        self.log_evidence_internal(false, &targets)
    }

    /// Draw-probability quality metric for the given teams (key slices).
    ///
    /// Values lie in [0, 1]; higher values indicate a more evenly matched game.
    /// Unknown keys, and keys with no time slice yet, are skipped.
    pub fn predict_quality(&self, teams: &[&[&K]]) -> f64 {
        let groups: Vec<Vec<Gaussian>> = teams
            .iter()
            .map(|team| {
                team.iter()
                    .filter_map(|k| self.keys.get(*k))
                    .filter_map(|idx| {
                        self.time_slices
                            .iter()
                            .rev()
                            .find_map(|ts| ts.skills.get(idx).map(|s| s.posterior()))
                    })
                    .collect()
            })
            .collect();
        let group_refs: Vec<&[Gaussian]> = groups.iter().map(|g| g.as_slice()).collect();
        crate::quality(&group_refs, self.beta)
    }

    /// 2-team win probability: returns `[P(team0 wins), P(team1 wins)]`.
    ///
    /// Panics if `teams.len() != 2`. N-team support lands in T4.
    /// Unknown keys are skipped.
    pub fn predict_outcome(&self, teams: &[&[&K]]) -> Vec<f64> {
        assert_eq!(teams.len(), 2, "predict_outcome T2: 2 teams only");
        let gather = |team: &[&K]| -> Gaussian {
            team.iter()
                .filter_map(|k| self.keys.get(*k))
                .filter_map(|idx| {
                    self.time_slices
                        .iter()
                        .rev()
                        .find_map(|ts| ts.skills.get(idx).map(|s| s.posterior()))
                })
                .fold(crate::N00, |acc, g| acc + g.forget(self.beta.powi(2)))
        };
        let a = gather(teams[0]);
        let b = gather(teams[1]);
        let diff = a - b;
        // P(team0 wins) = P(diff > 0) = 1 - CDF of the performance difference at 0.
        let p_a = 1.0 - crate::cdf(0.0, diff.mu(), diff.sigma());
        vec![p_a, 1.0 - p_a]
    }

    /// Run forward+backward sweeps until the largest posterior change drops
    /// below `ConvergenceOptions::epsilon` or `max_iter` sweeps have run, then
    /// return a summary.
    pub fn converge(&mut self) -> Result<ConvergenceReport, InferenceError> {
        use std::time::Instant;

        use smallvec::SmallVec;

        let opts = self.convergence;
        let mut step = (f64::INFINITY, f64::INFINITY);
        let mut i = 0;
        let mut per_iter: SmallVec<[std::time::Duration; 32]> = SmallVec::new();
        while tuple_gt(step, opts.epsilon) && i < opts.max_iter {
            let t0 = Instant::now();
            step = self.iteration();
            per_iter.push(t0.elapsed());
            i += 1;
            self.observer.on_iteration_end(i, step);
        }
        let converged = !tuple_gt(step, opts.epsilon);
        let log_evidence = self.log_evidence_internal(false, &[]);
        self.observer.on_converged(i, step, converged);
        Ok(ConvergenceReport {
            iterations: i,
            final_step: step,
            log_evidence,
            converged,
            per_iteration_time: per_iter,
            slices_skipped: 0,
        })
    }
}

impl<T: Time, D: Drift<T>, O: Observer<T>, K: Eq + Hash + Clone> History<T, D, O, K> {
    pub(crate) fn add_events_with_prior(
        &mut self,
        composition: Vec<Vec<Vec<Index>>>,
        results: Vec<Vec<f64>>,
        times: Vec<T>,
        weights: Vec<Vec<Vec<f64>>>,
        mut priors: HashMap<Index, Rating<T, D>>,
    ) -> Result<(), InferenceError> {
        if !results.is_empty() && results.len() != composition.len() {
            return Err(InferenceError::MismatchedShape {
                kind: "results",
                expected: composition.len(),
                got: results.len(),
            });
        }
        if times.len() != composition.len() {
            return Err(InferenceError::MismatchedShape {
                kind: "times",
                expected: composition.len(),
                got: times.len(),
            });
        }
        if !weights.is_empty() && weights.len() != composition.len() {
            return Err(InferenceError::MismatchedShape {
                kind: "weights",
                expected: composition.len(),
                got: weights.len(),
            });
        }

        competitor::clean(self.agents.values_mut(), true);

        let mut this_agent = Vec::with_capacity(1024);

        for agent in composition.iter().flatten().flatten() {
            if this_agent.contains(agent) {
                continue;
            }

            this_agent.push(*agent);

            if !self.agents.contains(*agent) {
                self.agents.insert(
                    *agent,
                    Competitor {
                        rating: priors.remove(agent).unwrap_or_else(|| {
                            Rating::new(
                                Gaussian::from_ms(self.mu, self.sigma),
                                self.beta,
                                self.drift,
                            )
                        }),
                        message: N_INF,
                        last_time: None,
                    },
                );
            }
        }

        let n = composition.len();
        // `o` holds indices into `times` in ascending time order.
        let o = sort_time(&times, false);

        let mut i = 0;
        let mut k = 0;

        // Each pass handles the group of events `o[i..j]` that share time `t`.
        while i < n {
            let mut j = i + 1;
            let t = times[o[i]];

            while j < n && times[o[j]] == t {
                j += 1;
            }

            // Walk existing slices earlier than `t`, refreshing forward messages
            // and elapsed times for the competitors touched by this call.
            while self.time_slices.len() > k && self.time_slices[k].time < t {
                let time_slice = &mut self.time_slices[k];

                if k > 0 {
                    time_slice.new_forward_info(&self.agents);
                }

                for agent_idx in &this_agent {
                    if let Some(skill) = time_slice.skills.get_mut(*agent_idx) {
                        skill.elapsed = time_slice::compute_elapsed(
                            self.agents[*agent_idx].last_time.as_ref(),
                            &time_slice.time,
                        );

                        let agent = self.agents.get_mut(*agent_idx).unwrap();

                        agent.last_time = Some(time_slice.time);
                        agent.message = time_slice.forward_prior_out(agent_idx);
                    }
                }

                k += 1;
            }

            let composition = (i..j)
                .map(|e| composition[o[e]].clone())
                .collect::<Vec<_>>();

            let results = if results.is_empty() {
                Vec::new()
            } else {
                (i..j).map(|e| results[o[e]].clone()).collect::<Vec<_>>()
            };

            let weights = if weights.is_empty() {
                Vec::new()
            } else {
                (i..j).map(|e| weights[o[e]].clone()).collect::<Vec<_>>()
            };

            // Merge into an existing slice at time `t`, or insert a new one.
            if self.time_slices.len() > k && self.time_slices[k].time == t {
                let time_slice = &mut self.time_slices[k];
                time_slice.add_events(composition, results, weights, &self.agents);

                for agent_idx in time_slice.skills.keys() {
                    let agent = self.agents.get_mut(agent_idx).unwrap();

                    agent.last_time = Some(t);
                    agent.message = time_slice.forward_prior_out(&agent_idx);
                }
            } else {
                let mut time_slice = TimeSlice::new(t, self.p_draw);
                time_slice.add_events(composition, results, weights, &self.agents);

                self.time_slices.insert(k, time_slice);

                let time_slice = &self.time_slices[k];

                for agent_idx in time_slice.skills.keys() {
                    let agent = self.agents.get_mut(agent_idx).unwrap();

                    agent.last_time = Some(t);
                    agent.message = time_slice.forward_prior_out(&agent_idx);
                }

                k += 1;
            }

            i = j;
        }

        // Propagate forward messages through any remaining later slices.
        while self.time_slices.len() > k {
            let time_slice = &mut self.time_slices[k];

            time_slice.new_forward_info(&self.agents);

            for agent_idx in &this_agent {
                if let Some(skill) = time_slice.skills.get_mut(*agent_idx) {
                    skill.elapsed = time_slice::compute_elapsed(
                        self.agents[*agent_idx].last_time.as_ref(),
                        &time_slice.time,
                    );

                    let agent = self.agents.get_mut(*agent_idx).unwrap();

                    agent.last_time = Some(time_slice.time);
                    agent.message = time_slice.forward_prior_out(agent_idx);
                }
            }

            k += 1;
        }

        self.size += n;
        Ok(())
    }

    /// Record a 1v1 game at `time` in which `winner` beat `loser`.
    pub fn record_winner<Q>(&mut self, winner: &Q, loser: &Q, time: T) -> Result<(), InferenceError>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
    {
        let w = self.intern(winner);
        let l = self.intern(loser);
        self.add_events_with_prior(
            vec![vec![vec![w], vec![l]]],
            vec![vec![1.0, 0.0]],
            vec![time],
            vec![],
            HashMap::new(),
        )
    }

    /// Record a drawn 1v1 game between `a` and `b` at `time`.
    pub fn record_draw<Q>(&mut self, a: &Q, b: &Q, time: T) -> Result<(), InferenceError>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ToOwned<Owned = K> + ?Sized,
    {
        let a_idx = self.intern(a);
        let b_idx = self.intern(b);
        self.add_events_with_prior(
            vec![vec![vec![a_idx], vec![b_idx]]],
            vec![vec![0.0, 0.0]],
            vec![time],
            vec![],
            HashMap::new(),
        )
    }

    /// Start a fluent event builder for a single match at `time`.
    pub fn event(&mut self, time: T) -> crate::event_builder::EventBuilder<'_, T, D, O, K> {
        crate::event_builder::EventBuilder::new(self, time)
    }

    /// Bulk-ingest typed events.
    pub fn add_events<I>(&mut self, events: I) -> Result<(), InferenceError>
    where
        I: IntoIterator<Item = crate::event::Event<T, K>>,
    {
        use crate::event::Event;
        let events: Vec<Event<T, K>> = events.into_iter().collect();
        if events.is_empty() {
            return Ok(());
        }

        let mut composition: Vec<Vec<Vec<Index>>> = Vec::with_capacity(events.len());
        let mut results: Vec<Vec<f64>> = Vec::with_capacity(events.len());
        let mut times: Vec<T> = Vec::with_capacity(events.len());
        let mut weights: Vec<Vec<Vec<f64>>> = Vec::with_capacity(events.len());
        let mut priors: HashMap<Index, Rating<T, D>> = HashMap::new();

        for ev in events {
            let ranks = ev.outcome.as_ranks();
            if ranks.len() != ev.teams.len() {
                return Err(InferenceError::MismatchedShape {
                    kind: "outcome ranks vs teams",
                    expected: ev.teams.len(),
                    got: ranks.len(),
                });
            }

            let mut event_comp: Vec<Vec<Index>> = Vec::with_capacity(ev.teams.len());
            let mut event_weights: Vec<Vec<f64>> = Vec::with_capacity(ev.teams.len());

            for team in ev.teams {
                let mut team_indices: Vec<Index> = Vec::with_capacity(team.members.len());
                let mut team_weights: Vec<f64> = Vec::with_capacity(team.members.len());
                for member in team.members {
                    let idx = self.keys.get_or_create(&member.key);
                    team_indices.push(idx);
                    team_weights.push(member.weight);
                    if let Some(prior) = member.prior {
                        priors.insert(idx, Rating::new(prior, self.beta, self.drift));
                    }
                }
                event_comp.push(team_indices);
                event_weights.push(team_weights);
            }
            composition.push(event_comp);
            weights.push(event_weights);

            // Ranks use 0 for first place; the engine expects scores where higher
            // is better, so invert each rank against the worst rank.
            let max_rank = ranks.iter().copied().max().unwrap_or(0) as f64;
            let inverted: Vec<f64> = ranks.iter().map(|&r| max_rank - r as f64).collect();
            results.push(inverted);
            times.push(ev.time);
        }

        self.add_events_with_prior(composition, results, times, weights, priors)
    }
}

#[cfg(test)]
mod tests {
    use approx::assert_ulps_eq;
    use smallvec::smallvec;

    use super::*;
    use crate::{
        ConstantDrift, EPSILON, Event, Game, Gaussian, Member, Outcome, P_DRAW, Team,
        arena::ScratchArena,
    };

    fn make_events_1v1(
        pairs: &[(&'static str, &'static str)],
        outcomes: &[Outcome],
        times: &[i64],
    ) -> Vec<Event<i64, &'static str>> {
        pairs
            .iter()
            .copied()
            .zip(outcomes.iter().cloned())
            .zip(times.iter().copied())
            .map(|(((a, b), outcome), time)| Event {
                time,
                teams: smallvec![
                    Team::with_members([Member::new(a)]),
                    Team::with_members([Member::new(b)]),
                ],
                outcome,
            })
            .collect()
    }

    #[test]
    fn test_init() {
        let mut h = History::builder()
            .mu(25.0)
            .sigma(25.0 / 3.0)
            .beta(25.0 / 6.0)
            .drift(ConstantDrift(0.15 * 25.0 / 3.0))
            .build();

        let events = make_events_1v1(
            &[("a", "b"), ("a", "c"), ("b", "c")],
            &[
                Outcome::winner(0, 2),
                Outcome::winner(1, 2),
                Outcome::winner(0, 2),
            ],
            &[1, 2, 3],
        );
        h.add_events(events).unwrap();

        let a = h.keys.get("a").unwrap();
        let b = h.keys.get("b").unwrap();
        let c = h.keys.get("c").unwrap();

        let p0 = h.time_slices[0].posteriors();

        assert_ulps_eq!(
            p0[&a],
            Gaussian::from_ms(29.205220, 7.194481),
            epsilon = 1e-6
        );

        let observed = h.time_slices[1].skills.get(a).unwrap().forward.sigma();
        let gamma: f64 = 0.15 * 25.0 / 3.0;
        let expected = (gamma.powi(2)
            + h.time_slices[0]
                .skills
                .get(a)
                .unwrap()
                .posterior()
                .sigma()
                .powi(2))
        .sqrt();

        assert_ulps_eq!(observed, expected, epsilon = 1e-6);

        let observed = h.time_slices[1].skills.get(a).unwrap().posterior();

        let w = [vec![1.0], vec![1.0]];
        let p = Game::ranked_with_arena(
            h.time_slices[1].events[0].within_priors(
                false,
                false,
                &h.time_slices[1].skills,
                &h.agents,
            ),
            &[0.0, 1.0],
            &w,
            P_DRAW,
            &mut ScratchArena::new(),
        )
        .posteriors();
        let expected = p[0][0];

        assert_ulps_eq!(observed, expected, epsilon = 1e-6);

        let _ = (b, c);
    }

    #[test]
fn test_one_batch() {
|
|
let mut h1 = History::builder()
|
|
.mu(25.0)
|
|
.sigma(25.0 / 3.0)
|
|
.beta(25.0 / 6.0)
|
|
.drift(ConstantDrift(0.15 * 25.0 / 3.0))
|
|
.build();
|
|
|
|
let events = make_events_1v1(
|
|
&[("a", "b"), ("b", "c"), ("c", "a")],
|
|
&[
|
|
Outcome::winner(0, 2),
|
|
Outcome::winner(0, 2),
|
|
Outcome::winner(0, 2),
|
|
],
|
|
&[1, 1, 1],
|
|
);
|
|
h1.add_events(events).unwrap();
|
|
|
|
let a = h1.keys.get("a").unwrap();
|
|
let c = h1.keys.get("c").unwrap();
|
|
|
|
assert_ulps_eq!(
|
|
h1.time_slices[0].skills.get(a).unwrap().posterior(),
|
|
Gaussian::from_ms(22.904409, 6.010330),
|
|
epsilon = 1e-6
|
|
);
|
|
assert_ulps_eq!(
|
|
h1.time_slices[0].skills.get(c).unwrap().posterior(),
|
|
Gaussian::from_ms(25.110318, 5.866311),
|
|
epsilon = 1e-6
|
|
);
|
|
|
|
h1.converge().unwrap();
|
|
|
|
assert_ulps_eq!(
|
|
h1.time_slices[0].skills.get(a).unwrap().posterior(),
|
|
Gaussian::from_ms(25.000000, 5.419212),
|
|
epsilon = 1e-6
|
|
);
|
|
assert_ulps_eq!(
|
|
h1.time_slices[0].skills.get(c).unwrap().posterior(),
|
|
Gaussian::from_ms(25.000000, 5.419212),
|
|
epsilon = 1e-6
|
|
);
|
|
|
|
let mut h2 = History::builder()
|
|
.mu(25.0)
|
|
.sigma(25.0 / 3.0)
|
|
.beta(25.0 / 6.0)
|
|
.drift(ConstantDrift(25.0 / 300.0))
|
|
.build();
|
|
|
|
let events = make_events_1v1(
|
|
&[("a", "b"), ("b", "c"), ("c", "a")],
|
|
&[
|
|
Outcome::winner(0, 2),
|
|
Outcome::winner(0, 2),
|
|
Outcome::winner(0, 2),
|
|
],
|
|
&[1, 2, 3],
|
|
);
|
|
h2.add_events(events).unwrap();
|
|
|
|
let a = h2.keys.get("a").unwrap();
|
|
let c = h2.keys.get("c").unwrap();
|
|
|
|
assert_ulps_eq!(
|
|
h2.time_slices[2].skills.get(a).unwrap().posterior(),
|
|
Gaussian::from_ms(22.903522, 6.011017),
|
|
epsilon = 1e-6
|
|
);
|
|
assert_ulps_eq!(
|
|
h2.time_slices[2].skills.get(c).unwrap().posterior(),
|
|
Gaussian::from_ms(25.110702, 5.866811),
|
|
epsilon = 1e-6
|
|
);
|
|
|
|
h2.converge().unwrap();
|
|
|
|
assert_ulps_eq!(
|
|
h2.time_slices[2].skills.get(a).unwrap().posterior(),
|
|
Gaussian::from_ms(24.998668, 5.420053),
|
|
epsilon = 1e-6
|
|
);
|
|
assert_ulps_eq!(
|
|
h2.time_slices[2].skills.get(c).unwrap().posterior(),
|
|
Gaussian::from_ms(25.000532, 5.419827),
|
|
epsilon = 1e-6
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_learning_curves() {
|
|
let mut h = History::builder()
|
|
.mu(25.0)
|
|
.sigma(25.0 / 3.0)
|
|
.beta(25.0 / 6.0)
|
|
.drift(ConstantDrift(25.0 / 300.0))
|
|
.build();
|
|
|
|
let events = make_events_1v1(
|
|
&[("a", "b"), ("b", "c"), ("c", "a")],
|
|
&[
|
|
Outcome::winner(0, 2),
|
|
Outcome::winner(0, 2),
|
|
Outcome::winner(0, 2),
|
|
],
|
|
&[5, 6, 7],
|
|
);
|
|
h.add_events(events).unwrap();
|
|
h.converge().unwrap();
|
|
|
|
let lc_a = h.learning_curve("a");
|
|
let lc_c = h.learning_curve("c");
|
|
|
|
let aj_e = lc_a.len();
|
|
let cj_e = lc_c.len();
|
|
|
|
assert_eq!(lc_a[0].0, 5);
|
|
assert_eq!(lc_a[aj_e - 1].0, 7);
|
|
|
|
assert_ulps_eq!(
|
|
lc_a[aj_e - 1].1,
|
|
Gaussian::from_ms(24.998668, 5.420053),
|
|
epsilon = 1e-6
|
|
);
|
|
assert_ulps_eq!(
|
|
lc_c[cj_e - 1].1,
|
|
Gaussian::from_ms(25.000532, 5.419827),
|
|
epsilon = 1e-6
|
|
);
|
|
}
|
|
|
|
#[test]
|
|
fn test_env_ttt() {
|
|
let mut h = History::builder()
|
|
.mu(25.0)
|
|
.sigma(25.0 / 3.0)
|
|
.beta(25.0 / 6.0)
|
|
.drift(ConstantDrift(25.0 / 300.0))
|
|
.build();
|
|
|
|
let events = make_events_1v1(
|
|
&[("a", "b"), ("a", "c"), ("b", "c")],
|
|
&[
|
|
Outcome::winner(0, 2),
|
|
Outcome::winner(1, 2),
|
|
Outcome::winner(0, 2),
|
|
],
|
|
&[1, 2, 3],
|
|
);
|
|
h.add_events(events).unwrap();
|
|
h.converge().unwrap();
|
|
|
|
let a = h.keys.get("a").unwrap();
|
|
let b = h.keys.get("b").unwrap();
|
|
let c = h.keys.get("c").unwrap();
|
|
|
|
assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 2);
|
|
assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);
|
|
|
|
assert_ulps_eq!(
|
|
h.time_slices[0].skills.get(a).unwrap().posterior(),
|
|
Gaussian::from_ms(25.000267, 5.419423),
|
|
epsilon = 1e-6
|
|
);
|
|
assert_ulps_eq!(
|
|
h.time_slices[0].skills.get(b).unwrap().posterior(),
|
|
Gaussian::from_ms(24.999198, 5.419512),
|
|
epsilon = 1e-6
|
|
);
|
|
assert_ulps_eq!(
|
|
h.time_slices[2].skills.get(b).unwrap().posterior(),
|
|
Gaussian::from_ms(25.001332, 5.420054),
|
|
epsilon = 1e-6
|
|
);
|
|
}

    #[test]
    fn test_teams() {
        let mut h: History<i64, _, _, &'static str> = History::builder()
            .mu(0.0)
            .sigma(6.0)
            .beta(1.0)
            .drift(ConstantDrift(0.0))
            .build();

        let events: Vec<Event<i64, &'static str>> = vec![
            Event {
                time: 1,
                teams: smallvec![
                    Team::with_members([Member::new("a"), Member::new("b")]),
                    Team::with_members([Member::new("c"), Member::new("d")]),
                ],
                outcome: Outcome::winner(0, 2),
            },
            Event {
                time: 2,
                teams: smallvec![
                    Team::with_members([Member::new("e"), Member::new("f")]),
                    Team::with_members([Member::new("b"), Member::new("c")]),
                ],
                outcome: Outcome::winner(1, 2),
            },
            Event {
                time: 3,
                teams: smallvec![
                    Team::with_members([Member::new("a"), Member::new("d")]),
                    Team::with_members([Member::new("e"), Member::new("f")]),
                ],
                outcome: Outcome::winner(0, 2),
            },
        ];
        h.add_events(events).unwrap();

        let a = h.keys.get("a").unwrap();
        let b = h.keys.get("b").unwrap();
        let c = h.keys.get("c").unwrap();
        let d = h.keys.get("d").unwrap();
        let e = h.keys.get("e").unwrap();
        let f = h.keys.get("f").unwrap();

        let trueskill_log_evidence = h.log_evidence_internal(false, &[]);
        let trueskill_log_evidence_online = h.log_evidence_internal(true, &[]);

        assert_ulps_eq!(
            trueskill_log_evidence,
            trueskill_log_evidence_online,
            epsilon = 1e-6
        );

        assert_ulps_eq!(
            h.time_slices[0].skills.get(b).unwrap().posterior().mu(),
            -h.time_slices[0].skills.get(c).unwrap().posterior().mu(),
            epsilon = 1e-6
        );

        let evidence_second_event = h.log_evidence_internal(false, &[b]).exp() * 2.0;
        assert_ulps_eq!(0.5, evidence_second_event, epsilon = 1e-6);

        let evidence_third_event = h.log_evidence_internal(false, &[a]).exp() * 2.0;
        assert_ulps_eq!(0.669885, evidence_third_event, epsilon = 1e-6);

        h.converge().unwrap();

        let loocv_hat = h.log_evidence_internal(false, &[]).exp();
        let p_d_m_hat = h.log_evidence_internal(true, &[]).exp();

        assert_ulps_eq!(loocv_hat, 0.241027, epsilon = 1e-6);
        assert_ulps_eq!(p_d_m_hat, 0.172432, epsilon = 1e-6);

        assert_ulps_eq!(
            h.time_slices[0].skills.get(a).unwrap().posterior(),
            h.time_slices[0].skills.get(b).unwrap().posterior(),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[0].skills.get(c).unwrap().posterior(),
            h.time_slices[0].skills.get(d).unwrap().posterior(),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[1].skills.get(e).unwrap().posterior(),
            h.time_slices[1].skills.get(f).unwrap().posterior(),
            epsilon = 1e-6
        );

        assert_ulps_eq!(
            h.time_slices[0].skills.get(a).unwrap().posterior(),
            Gaussian::from_ms(4.084902, 5.106919),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[0].skills.get(c).unwrap().posterior(),
            Gaussian::from_ms(-0.533029, 5.106919),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[2].skills.get(e).unwrap().posterior(),
            Gaussian::from_ms(-3.551872, 5.154569),
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_add_events() {
        let mut h: History<i64, _, _, &'static str> = History::builder()
            .mu(0.0)
            .sigma(2.0)
            .beta(1.0)
            .drift(ConstantDrift(0.0))
            .build();

        let events = make_events_1v1(
            &[("a", "b"), ("a", "c"), ("b", "c")],
            &[
                Outcome::winner(0, 2),
                Outcome::winner(1, 2),
                Outcome::winner(0, 2),
            ],
            &[1, 2, 3],
        );
        h.add_events(events).unwrap();

        let a = h.keys.get("a").unwrap();
        let b = h.keys.get("b").unwrap();
        let c = h.keys.get("c").unwrap();

        h.converge().unwrap();

        assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 2);
        assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);

        assert_ulps_eq!(
            h.time_slices[0].skills.get(a).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 1.300610),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[0].skills.get(b).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 1.300610),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[2].skills.get(b).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 1.300610),
            epsilon = 1e-6
        );

        let events2 = make_events_1v1(
            &[("a", "b"), ("a", "c"), ("b", "c")],
            &[
                Outcome::winner(0, 2),
                Outcome::winner(1, 2),
                Outcome::winner(0, 2),
            ],
            &[4, 5, 6],
        );
        h.add_events(events2).unwrap();

        assert_eq!(h.time_slices.len(), 6);

        assert_eq!(
            h.time_slices
                .iter()
                .map(|ts| ts.get_composition())
                .collect::<Vec<_>>(),
            vec![
                vec![vec![vec![a], vec![b]]],
                vec![vec![vec![a], vec![c]]],
                vec![vec![vec![b], vec![c]]],
                vec![vec![vec![a], vec![b]]],
                vec![vec![vec![a], vec![c]]],
                vec![vec![vec![b], vec![c]]]
            ]
        );

        h.converge().unwrap();

        assert_ulps_eq!(
            h.time_slices[0].skills.get(a).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 0.931236),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[3].skills.get(a).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 0.931236),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[3].skills.get(b).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 0.931236),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[5].skills.get(b).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 0.931236),
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_only_add_events() {
        let mut h: History<i64, _, _, &'static str> = History::builder()
            .mu(0.0)
            .sigma(2.0)
            .beta(1.0)
            .drift(ConstantDrift(0.0))
            .build();

        let events = make_events_1v1(
            &[("a", "b"), ("a", "c"), ("b", "c")],
            &[
                Outcome::winner(0, 2),
                Outcome::winner(1, 2),
                Outcome::winner(0, 2),
            ],
            &[1, 2, 3],
        );
        h.add_events(events).unwrap();

        let a = h.keys.get("a").unwrap();
        let b = h.keys.get("b").unwrap();
        let c = h.keys.get("c").unwrap();

        h.converge().unwrap();

        assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 2);
        assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);

        assert_ulps_eq!(
            h.time_slices[0].skills.get(a).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 1.300610),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[0].skills.get(b).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 1.300610),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[2].skills.get(b).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 1.300610),
            epsilon = 1e-6
        );

        let events2 = make_events_1v1(
            &[("a", "b"), ("a", "c"), ("b", "c")],
            &[
                Outcome::winner(0, 2),
                Outcome::winner(1, 2),
                Outcome::winner(0, 2),
            ],
            &[4, 5, 6],
        );
        h.add_events(events2).unwrap();

        assert_eq!(h.time_slices.len(), 6);

        assert_eq!(
            h.time_slices
                .iter()
                .map(|ts| ts.get_composition())
                .collect::<Vec<_>>(),
            vec![
                vec![vec![vec![a], vec![b]]],
                vec![vec![vec![a], vec![c]]],
                vec![vec![vec![b], vec![c]]],
                vec![vec![vec![a], vec![b]]],
                vec![vec![vec![a], vec![c]]],
                vec![vec![vec![b], vec![c]]]
            ]
        );

        h.converge().unwrap();

        assert_ulps_eq!(
            h.time_slices[0].skills.get(a).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 0.931236),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[3].skills.get(a).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 0.931236),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[3].skills.get(b).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 0.931236),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            h.time_slices[5].skills.get(b).unwrap().posterior(),
            Gaussian::from_ms(0.000000, 0.931236),
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_log_evidence() {
        use crate::ConvergenceOptions;

        let mut h: History<i64, _, _, &'static str> = History::builder().build();

        // Empty `results` in the old API meant team 0 wins; reproduce with Outcome::winner(0, 2).
        let events = make_events_1v1(
            &[("a", "b"), ("b", "a")],
            &[Outcome::winner(0, 2), Outcome::winner(0, 2)],
            &[1, 2],
        );
        h.add_events(events).unwrap();

        let a = h.keys.get("a").unwrap();
        let b = h.keys.get("b").unwrap();

        let p_d_m_2 = h.log_evidence_internal(false, &[]).exp() * 2.0;

        assert_ulps_eq!(p_d_m_2, 0.17650911, epsilon = 1e-6);
        assert_ulps_eq!(
            p_d_m_2,
            h.log_evidence_internal(true, &[]).exp() * 2.0,
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            p_d_m_2,
            h.log_evidence_internal(true, &[a]).exp() * 2.0,
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            p_d_m_2,
            h.log_evidence_internal(false, &[a]).exp() * 2.0,
            epsilon = 1e-6
        );

        // Cap the sweep at 11 iterations, matching the old test's convergence(11, ...) call.
        h.convergence = ConvergenceOptions {
            max_iter: 11,
            epsilon: EPSILON,
        };
        h.converge().unwrap();

        let loocv_approx_2 = h.log_evidence_internal(false, &[]).exp().sqrt();

        assert_ulps_eq!(loocv_approx_2, 0.001976774, epsilon = 1e-6);

        let p_d_m_approx_2 = h.log_evidence_internal(true, &[]).exp() * 2.0;

        assert!(loocv_approx_2 - p_d_m_approx_2 < 1e-4);

        assert_ulps_eq!(
            loocv_approx_2,
            h.log_evidence_internal(true, &[b]).exp() * 2.0,
            epsilon = 1e-4
        );

        let mut h2: History<i64, _, _, &'static str> = History::builder().build();

        let events = make_events_1v1(
            &[("a", "b"), ("b", "a")],
            &[Outcome::winner(0, 2), Outcome::winner(0, 2)],
            &[1, 2],
        );
        h2.add_events(events).unwrap();

        assert_ulps_eq!(
            ((0.5f64 * 0.1765).ln() / 2.0).exp(),
            (h2.log_evidence_internal(false, &[]) / 2.0).exp(),
            epsilon = 1e-4
        );
    }
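
    // Sketch (hypothetical helper, not part of the crate API): how to read the last
    // assertion in `test_log_evidence`. For n events with total log evidence
    // L = ln p_1 + ... + ln p_n, exp(L / n) is the geometric mean of the per-event
    // evidences. Reading 0.5 and 0.1765 as the two per-event factors, the expression
    // `((0.5 * 0.1765).ln() / 2.0).exp()` is their geometric mean, which the test
    // compares with `(L / 2).exp()`.
    fn geometric_mean_evidence(log_evidence: f64, n_events: usize) -> f64 {
        assert!(n_events > 0, "geometric mean needs at least one event");
        (log_evidence / n_events as f64).exp()
    }

    #[test]
    fn sketch_geometric_mean_evidence() {
        let per_event = [0.5_f64, 0.1765];
        // Total log evidence is the sum of the per-event log evidences.
        let total: f64 = per_event.iter().map(|p| p.ln()).sum();
        let gm = geometric_mean_evidence(total, per_event.len());
        // For two events the geometric mean is the square root of their product.
        assert!((gm - (0.5_f64 * 0.1765).sqrt()).abs() < 1e-12);
    }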

    #[test]
    fn test_add_events_with_time() {
        let mut h: History<i64, _, _, &'static str> = History::builder()
            .mu(0.0)
            .sigma(2.0)
            .beta(1.0)
            .drift(ConstantDrift(0.0))
            .build();

        let events = make_events_1v1(
            &[("a", "b"), ("a", "c"), ("b", "c")],
            &[
                Outcome::winner(0, 2),
                Outcome::winner(1, 2),
                Outcome::winner(0, 2),
            ],
            &[0, 10, 20],
        );
        h.add_events(events).unwrap();
        h.converge().unwrap();

        let a = h.keys.get("a").unwrap();
        let b = h.keys.get("b").unwrap();
        let c = h.keys.get("c").unwrap();

        let events2 = make_events_1v1(
            &[("a", "b"), ("a", "c"), ("b", "c")],
            &[
                Outcome::winner(0, 2),
                Outcome::winner(1, 2),
                Outcome::winner(0, 2),
            ],
            &[15, 10, 0],
        );
        h.add_events(events2).unwrap();

        assert_eq!(h.time_slices.len(), 4);

        assert_eq!(
            h.time_slices
                .iter()
                .map(|ts| ts.events.len())
                .collect::<Vec<_>>(),
            vec![2, 2, 1, 1]
        );

        assert_eq!(
            h.time_slices
                .iter()
                .map(|ts| ts.get_composition())
                .collect::<Vec<_>>(),
            vec![
                vec![vec![vec![a], vec![b]], vec![vec![b], vec![c]]],
                vec![vec![vec![a], vec![c]], vec![vec![a], vec![c]]],
                vec![vec![vec![a], vec![b]]],
                vec![vec![vec![b], vec![c]]]
            ]
        );

        assert_eq!(
            h.time_slices
                .iter()
                .map(|ts| ts.get_results())
                .collect::<Vec<_>>(),
            vec![
                vec![vec![1.0, 0.0], vec![1.0, 0.0]],
                vec![vec![0.0, 1.0], vec![0.0, 1.0]],
                vec![vec![1.0, 0.0]],
                vec![vec![1.0, 0.0]]
            ]
        );

        let end = h.time_slices.len() - 1;

        assert_eq!(h.time_slices[0].skills.get(c).unwrap().elapsed, 0);
        assert_eq!(h.time_slices[end].skills.get(c).unwrap().elapsed, 10);

        assert_eq!(h.time_slices[0].skills.get(a).unwrap().elapsed, 0);
        assert_eq!(h.time_slices[2].skills.get(a).unwrap().elapsed, 5);

        assert_eq!(h.time_slices[0].skills.get(b).unwrap().elapsed, 0);
        assert_eq!(h.time_slices[end].skills.get(b).unwrap().elapsed, 5);

        h.converge().unwrap();

        assert_ulps_eq!(
            h.time_slices[0].skills.get(b).unwrap().posterior(),
            h.time_slices[end].skills.get(b).unwrap().posterior(),
            epsilon = 1e-6
        );

        assert_ulps_eq!(
            h.time_slices[0].skills.get(c).unwrap().posterior(),
            h.time_slices[end].skills.get(c).unwrap().posterior(),
            epsilon = 1e-6
        );

        assert_ulps_eq!(
            h.time_slices[0].skills.get(c).unwrap().posterior(),
            h.time_slices[0].skills.get(b).unwrap().posterior(),
            epsilon = 1e-6
        );

        // second scenario: team-0 wins (empty results in old API), different composition order

        let mut h2: History<i64, _, _, &'static str> = History::builder()
            .mu(0.0)
            .sigma(2.0)
            .beta(1.0)
            .drift(ConstantDrift(0.0))
            .build();

        let events = make_events_1v1(
            &[("a", "b"), ("c", "a"), ("b", "c")],
            &[
                Outcome::winner(0, 2),
                Outcome::winner(0, 2),
                Outcome::winner(0, 2),
            ],
            &[0, 10, 20],
        );
        h2.add_events(events).unwrap();
        h2.converge().unwrap();

        let a = h2.keys.get("a").unwrap();
        let b = h2.keys.get("b").unwrap();
        let c = h2.keys.get("c").unwrap();

        let events2 = make_events_1v1(
            &[("a", "b"), ("c", "a"), ("b", "c")],
            &[
                Outcome::winner(0, 2),
                Outcome::winner(0, 2),
                Outcome::winner(0, 2),
            ],
            &[15, 10, 0],
        );
        h2.add_events(events2).unwrap();

        assert_eq!(h2.time_slices.len(), 4);

        assert_eq!(
            h2.time_slices
                .iter()
                .map(|ts| ts.events.len())
                .collect::<Vec<_>>(),
            vec![2, 2, 1, 1]
        );

        assert_eq!(
            h2.time_slices
                .iter()
                .map(|ts| ts.get_composition())
                .collect::<Vec<_>>(),
            vec![
                vec![vec![vec![a], vec![b]], vec![vec![b], vec![c]]],
                vec![vec![vec![c], vec![a]], vec![vec![c], vec![a]]],
                vec![vec![vec![a], vec![b]]],
                vec![vec![vec![b], vec![c]]]
            ]
        );

        assert_eq!(
            h2.time_slices
                .iter()
                .map(|ts| ts.get_results())
                .collect::<Vec<_>>(),
            vec![
                vec![vec![1.0, 0.0], vec![1.0, 0.0]],
                vec![vec![1.0, 0.0], vec![1.0, 0.0]],
                vec![vec![1.0, 0.0]],
                vec![vec![1.0, 0.0]]
            ]
        );

        let end = h2.time_slices.len() - 1;

        assert_eq!(h2.time_slices[0].skills.get(c).unwrap().elapsed, 0);
        assert_eq!(h2.time_slices[end].skills.get(c).unwrap().elapsed, 10);

        assert_eq!(h2.time_slices[0].skills.get(a).unwrap().elapsed, 0);
        assert_eq!(h2.time_slices[2].skills.get(a).unwrap().elapsed, 5);

        assert_eq!(h2.time_slices[0].skills.get(b).unwrap().elapsed, 0);
        assert_eq!(h2.time_slices[end].skills.get(b).unwrap().elapsed, 5);

        h2.converge().unwrap();

        assert_ulps_eq!(
            h2.time_slices[0].skills.get(b).unwrap().posterior(),
            h2.time_slices[end].skills.get(b).unwrap().posterior(),
            epsilon = 1e-6
        );

        assert_ulps_eq!(
            h2.time_slices[0].skills.get(c).unwrap().posterior(),
            h2.time_slices[end].skills.get(c).unwrap().posterior(),
            epsilon = 1e-6
        );

        assert_ulps_eq!(
            h2.time_slices[0].skills.get(c).unwrap().posterior(),
            h2.time_slices[0].skills.get(b).unwrap().posterior(),
            epsilon = 1e-6
        );
    }
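
    // Sketch (hypothetical helper, not the crate's implementation): one way to derive
    // the slice counts and `elapsed` values asserted in `test_add_events_with_time`.
    // It assumes what that test checks: events sharing a timestamp merge into one
    // slice, and a player's `elapsed` in a slice is that slice's time minus the time
    // of the player's previous slice (0 on first appearance). Winners are ignored
    // because they do not affect `elapsed`.
    fn slice_elapsed(
        events: &[(&'static str, &'static str, i64)],
    ) -> Vec<(i64, usize, Vec<(&'static str, i64)>)> {
        // Group events by timestamp: time -> (number of events, players present).
        let mut by_time: std::collections::BTreeMap<
            i64,
            (usize, std::collections::BTreeSet<&'static str>),
        > = std::collections::BTreeMap::new();
        for &(p, q, t) in events {
            let slot = by_time.entry(t).or_default();
            slot.0 += 1;
            slot.1.insert(p);
            slot.1.insert(q);
        }
        // Walk slices in time order; elapsed = gap since the player's previous slice.
        let mut last_seen: std::collections::HashMap<&'static str, i64> =
            std::collections::HashMap::new();
        by_time
            .into_iter()
            .map(|(t, (count, players))| {
                let elapsed: Vec<(&'static str, i64)> = players
                    .into_iter()
                    .map(|p| {
                        let gap = last_seen.get(p).map_or(0, |&prev| t - prev);
                        last_seen.insert(p, t);
                        (p, gap)
                    })
                    .collect();
                (t, count, elapsed)
            })
            .collect()
    }

    #[test]
    fn sketch_slice_elapsed_matches_expected_gaps() {
        // Same pairings and timestamps as the first scenario above (both batches).
        let slices = slice_elapsed(&[
            ("a", "b", 0),
            ("a", "c", 10),
            ("b", "c", 20),
            ("a", "b", 15),
            ("a", "c", 10),
            ("b", "c", 0),
        ]);
        let counts: Vec<usize> = slices.iter().map(|(_, n, _)| *n).collect();
        assert_eq!(counts, vec![2, 2, 1, 1]);
        // t = 0: first appearance for everyone.
        assert_eq!(slices[0].2, vec![("a", 0), ("b", 0), ("c", 0)]);
        // t = 15: a was last seen at 10, b at 0.
        assert_eq!(slices[2].2, vec![("a", 5), ("b", 15)]);
        // t = 20: b was last seen at 15, c at 10.
        assert_eq!(slices[3].2, vec![("b", 5), ("c", 10)]);
    }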

    #[test]
    fn test_1vs1_weighted() {
        let mut h: History<i64, _, _, &'static str> = History::builder()
            .mu(2.0)
            .sigma(6.0)
            .beta(1.0)
            .drift(ConstantDrift(0.0))
            .build();

        // Empty results in the old API meant team 0 wins: a wins event 1, b wins event 2.
        let events: Vec<Event<i64, &'static str>> = vec![
            Event {
                time: 1,
                teams: smallvec![
                    Team::with_members([Member::new("a").with_weight(5.0)]),
                    Team::with_members([Member::new("b").with_weight(4.0)]),
                ],
                outcome: Outcome::winner(0, 2),
            },
            Event {
                time: 2,
                teams: smallvec![
                    Team::with_members([Member::new("b").with_weight(5.0)]),
                    Team::with_members([Member::new("a").with_weight(4.0)]),
                ],
                outcome: Outcome::winner(0, 2),
            },
        ];
        h.add_events(events).unwrap();

        let lc_a = h.learning_curve("a");
        let lc_b = h.learning_curve("b");

        assert_ulps_eq!(
            lc_a[0].1,
            Gaussian::from_ms(5.537659, 4.758722),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            lc_b[0].1,
            Gaussian::from_ms(-0.830127, 5.239568),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            lc_a[1].1,
            Gaussian::from_ms(1.792277, 4.099566),
            epsilon = 1e-6
        );
        assert_ulps_eq!(
            lc_b[1].1,
            Gaussian::from_ms(4.845533, 3.747616),
            epsilon = 1e-6
        );

        h.converge().unwrap();

        let lc_a = h.learning_curve("a");
        let lc_b = h.learning_curve("b");

        // The two events mirror each other (swap a and b) and drift is zero, so after
        // convergence every posterior on both learning curves should match lc_a[0].
        assert_ulps_eq!(lc_b[0].1, lc_a[0].1, epsilon = 1e-6);
        assert_ulps_eq!(lc_a[1].1, lc_a[0].1, epsilon = 1e-6);
        assert_ulps_eq!(lc_b[1].1, lc_a[0].1, epsilon = 1e-6);
    }

    #[test]
    fn test_converge_returns_report() {
        use crate::ConvergenceOptions;

        let mut h: History<i64, _, _, &'static str> = History::builder()
            .mu(0.0)
            .sigma(2.0)
            .beta(1.0)
            .drift(ConstantDrift(0.0))
            .convergence(ConvergenceOptions {
                max_iter: 30,
                epsilon: 1e-6,
            })
            .build();

        let events = make_events_1v1(
            &[("a", "b"), ("a", "c"), ("b", "c")],
            &[
                Outcome::winner(0, 2),
                Outcome::winner(1, 2),
                Outcome::winner(0, 2),
            ],
            &[1, 2, 3],
        );
        h.add_events(events).unwrap();

        let report = h.converge().unwrap();
        assert!(report.converged);
        assert!(report.iterations > 0);
        assert!(report.iterations < 30);
        assert!(report.final_step.0 <= 1e-6);
    }
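
    // Sketch (hypothetical types, not the crate's `ConvergenceReport` or
    // `EpsilonOrMax`): a minimal model of the contract `test_converge_returns_report`
    // checks. A sweep repeats until its step size falls to `epsilon` or `max_iter`
    // sweeps have run. The report records whether the tolerance was reached, how
    // many sweeps ran, and the last step size. The step size is simplified here to a
    // single f64, whereas the crate's `final_step` is a tuple.
    struct SketchReport {
        converged: bool,
        iterations: usize,
        final_step: f64,
    }

    fn iterate_until(max_iter: usize, epsilon: f64, mut sweep: impl FnMut() -> f64) -> SketchReport {
        let mut final_step = f64::INFINITY;
        for i in 1..=max_iter {
            final_step = sweep();
            if final_step <= epsilon {
                return SketchReport { converged: true, iterations: i, final_step };
            }
        }
        SketchReport { converged: false, iterations: max_iter, final_step }
    }

    #[test]
    fn sketch_epsilon_or_max_report() {
        // Step sizes 1.0, 0.5, 0.25, ...: the tolerance 1e-6 is reached on sweep 21.
        let mut step = 2.0_f64;
        let report = iterate_until(30, 1e-6, || {
            step /= 2.0;
            step
        });
        assert!(report.converged);
        assert!(report.iterations > 0 && report.iterations < 30);
        assert!(report.final_step <= 1e-6);

        // A sweep whose step never shrinks stops at the iteration cap instead.
        let capped = iterate_until(11, 1e-6, || 1.0);
        assert!(!capped.converged);
        assert_eq!(capped.iterations, 11);
    }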
}