Compare commits

...

10 Commits

Author SHA1 Message Date
f18013d036 bench,docs: capture T2 final numbers and update CHANGELOG
Batch::iteration: 21.36 µs (T1 was 22.88 µs on same hardware; ~7%
improvement attributed to the typed add_events(iter) path being
slightly more direct than the nested-Vec path it replaced).

Gaussian operations unchanged vs T1.

Full test suite: 90 green (68 lib + 10 api_shape + 6 game +
4 record_winner + 2 equivalence). No golden value changed across
the entire T2 tier.

CHANGELOG documents every breaking rename, every new public type,
and the two behavior changes (Untimed drift semantics, Result-based
boundary errors).

Closes T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-24 13:13:38 +02:00
a6aaa93fd0 test: translate in-crate tests to new T2 API; delete legacy methods
Every #[cfg(test)] mod tests in src/history.rs now uses the new public
API: add_events(iter) / converge() / learning_curve() / current_skill()
/ log_evidence(). No golden value changed.

Legacy methods removed:
- History::convergence(iters, eps, verbose) → use converge()
- History::learning_curves_by_index() → use learning_curve() / learning_curves()
- HistoryBuilder::gamma(f64) → use .drift(ConstantDrift(g))
- add_events_with_prior downgraded from pub to pub(crate)

Added:
- History::builder_with_key() for custom key types (used by atp example)
- tests/equivalence.rs: Game-level golden integration tests

examples/atp.rs rewritten in new API (Event<i64, String>, converge(),
learning_curve(), drift(ConstantDrift(...))).

Bench Batch::iteration: 21.4 µs (T1 reference: 22.88 µs).

Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-24 13:10:10 +02:00
e8c9d4ed29 feat(api): add Game::ranked, one_v_one, free_for_all, custom constructors
Public Game API now returns Result<_, InferenceError> on invalid input
(p_draw out of range, outcome rank count mismatches team count).

New types:
- GameOptions { p_draw, convergence } — bundled config
- OwnedGame<T, D> — owned variant of Game that carries its result
  and weights internally (no borrow of History's slices). Returned
  by public constructors to avoid leaking internal borrow lifetimes.

The internal Game::new is renamed Game::ranked_with_arena (pub(crate))
and keeps the borrowing-arena signature for History's hot path. All
in-crate callers updated (21 call sites: 18 in game.rs tests, 2 in
time_slice.rs, 1 in history.rs).

Game::custom is a T2-minimal power-user escape hatch exposing raw
factor + schedule plumbing. Full ergonomics in T4 (#[doc(hidden)]
for now).

Game::log_evidence() accessor added on both Game and OwnedGame (was
previously accessible only through the pub(crate) evidence field).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-24 12:55:26 +02:00
fe6f028127 feat(api): promote Factor/Schedule/VarStore to pub in factors module
Exposes the factor-graph machinery so power users can define custom
factors and schedules (see Game::custom in the next task). The
internal factor/ and schedule/ modules remain unchanged (still
referenced by Game's internals via crate::factor); the user-facing
public API goes through the new factors module re-exports:

  pub use crate::factor::{BuiltinFactor, Factor, VarId, VarStore};
  pub use crate::factor::rank_diff::RankDiffFactor;
  pub use crate::factor::team_sum::TeamSumFactor;
  pub use crate::factor::trunc::TruncFactor;
  pub use crate::schedule::{EpsilonOrMax, Schedule, ScheduleReport};

#[allow(dead_code)] guards on the previously-pub(crate) items are
removed because the types are now referenced via the re-exports.

Promotes public methods on VarStore (len, alloc, get, set, clear, new)
and adds is_empty per clippy lint. Keeps marginals field private as an
implementation detail — users access via the public methods.

Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.
2026-04-24 12:50:37 +02:00
e62568bf3e feat(api): add current_skill / learning_curve / log_evidence / predict_*
New public query methods on History:

- current_skill(&K) -> Option<Gaussian>: latest posterior for a key
- learning_curve(&K) -> Vec<(T, Gaussian)>: single-key history
- learning_curves() -> HashMap<K, Vec<(T, Gaussian)>>: all-keys history
- log_evidence() -> f64: total log-evidence (was log_evidence(false,&[]))
- log_evidence_for(&[&K]) -> f64: subset log-evidence
- predict_quality(&[&[&K]]) -> f64: draw-probability match quality
- predict_outcome(&[&[&K]]) -> Vec<f64>: 2-team win probabilities

learning_curves() changed from returning HashMap<Index, Vec<(i64, Gaussian)>>
to HashMap<K, Vec<(T, Gaussian)>>. A new learning_curves_by_index()
helper preserves the old Index-keyed shape for callers that ingest via
the pub(crate) Index path.

log_evidence(false, &[]) was renamed to log_evidence_internal and made
pub(crate); the new zero-arg log_evidence() wraps it.

predict_outcome is T2 2-team-only; N-team deferred to T4.

KeyTable::get no longer requires ToOwned<Owned = K> (only needed for
get_or_create), allowing query methods to use simpler bounds.

Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-24 12:47:41 +02:00
ec8b7e538c feat(api): add fluent history.event(t).team(...).commit() builder
Third tier of the ingestion API (spec Section 4). Powers one-off
events with irregular shapes where neither record_winner (too
simple) nor typed add_events (too verbose) fits cleanly.

EventBuilder accumulates teams, weights, and outcome. Supports:
- .team([keys]) — add a team
- .weights([w..]) — per-member weights on the most-recently-added team
- .ranking([ranks]) — explicit per-team ranks
- .winner(i) — convenience: team i wins, others tied
- .draw() — all teams tied
- .commit() — finalize into an Event<T, K> and delegate to add_events

Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-24 12:42:26 +02:00
244b94a3e5 feat(api): typed add_events(iter); generify internal path over T
Public API gains:

  History::add_events<I: IntoIterator<Item = Event<T, K>>>(events)
      -> Result<(), InferenceError>

which accepts the typed Event<T, K> shape added in Task 10. Ranks
from Outcome::Ranked are mapped to the legacy "higher f64 = better"
results internally.

add_events_with_prior now takes Vec<T> for times (was Vec<i64>),
generifying the whole internal path over T in a single fully-generic
impl<T: Time, D: Drift<T>, O: Observer<T>, K> block. The i64-specific
block is gone; record_winner/record_draw are now generic over T.

add_events_with_prior stays pub (not pub(crate)) because the ATP
example calls it directly with pre-built Index-based composition;
the new typed add_events is the primary public API going forward.

In-crate tests updated to call add_events_with_prior with an empty
HashMap. tests/api_shape.rs added with 3 integration tests covering
bulk ingest, draw, and mismatched-outcome error.

Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-24 12:39:46 +02:00
044fb83a38 feat(api): add record_winner, record_draw, intern, lookup on History
Spec Section 4 "three-tier event ingestion" tier 2: one-off match
convenience. Spec open question 3: expose Index + intern/lookup for
power users.

History and HistoryBuilder gain a 4th generic parameter
K: Eq + Hash + Clone = &'static str. The default ensures existing
tests using Index-based add_events compile unchanged.

History internally owns a KeyTable<K>. intern(&Q) creates or returns
an Index for the given key; lookup(&Q) returns Option<Index> without
creating. record_winner and record_draw are thin 1v1 wrappers around
the internal add_events_with_prior.

Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-24 12:30:04 +02:00
a83c9acacb feat(error): expand InferenceError; convert boundary asserts to Result
InferenceError gains MismatchedShape (user-input length mismatches),
InvalidProbability (p_draw out of [0, 1]), and ConvergenceFailed
(exceeded max_iter without hitting epsilon). NegativePrecision stays.

History::add_events_with_prior and History::add_events now return
Result<(), InferenceError>. The previous assert! macros checking
composition/results/times/weights shape are replaced by matched
error returns.

Internal debug_assert! macros for arithmetic invariants stay; this
change only affects boundary validation of user input.

Tests updated to call .unwrap() on the Result. The old signatures
will be fully replaced in Task 15 (typed add_events(iter)) and the
nested-Vec wrapper removed in Task 20.

Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.
2026-04-24 12:26:13 +02:00
a6e008f8ff feat(api): add ConvergenceOptions, ConvergenceReport, History::converge
New public types:
- ConvergenceOptions { max_iter, epsilon } — config for the loop
- ConvergenceReport { iterations, final_step, log_evidence, converged,
  per_iteration_time, slices_skipped } — post-hoc summary

History and HistoryBuilder gain a third generic parameter
O: Observer<T> = NullObserver. Builder methods:
- .convergence(opts) sets the ConvergenceOptions
- .observer(o) plugs in an Observer (reshapes the builder's O param)

History::converge() runs the existing iteration loop driven by the
stored opts, emits observer callbacks on each iteration end and on
completion, and returns Result<ConvergenceReport, InferenceError>.

The old convergence(iters, eps, verbose) stays — gets removed in
Task 20 after tests are translated.

Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md.
2026-04-24 12:20:24 +02:00
21 changed files with 1713 additions and 516 deletions

View File

@@ -2,6 +2,90 @@
All notable changes to this project will be documented in this file. All notable changes to this project will be documented in this file.
## Unreleased — T2 new API surface
Breaking: every renamed type and the new public API land together per
`docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md`
Section 7 "T2".
### Breaking renames
- `Batch``TimeSlice`
- `Player``Rating` (and the `.player` field on `Competitor` is now `.rating`)
- `Agent``Competitor`
- `IndexMap``KeyTable`
- `History` field `.batches``.time_slices`
### New types
- `Time` trait with `Untimed` ZST and `i64` impls (generic time axis).
- `Drift<T: Time>` — generified from the old `Drift` trait.
- `Event<T, K>`, `Team<K>`, `Member<K>` — typed bulk-ingest event shape.
- `Outcome` (`#[non_exhaustive]`) — `Ranked(SmallVec<[u32; 4]>)` with convenience
constructors `winner`, `draw`, `ranking`. `Scored` lands in T4.
- `Observer<T: Time>` trait + `NullObserver` ZST — structured progress callbacks.
- `ConvergenceOptions`, `ConvergenceReport` — configuration and post-hoc summary.
- `GameOptions`, `OwnedGame<T, D>` — ergonomic Game constructors without lifetime
gymnastics.
- `factors` module — re-exports `Factor`, `BuiltinFactor`, `VarId`, `VarStore`,
`Schedule`, `EpsilonOrMax`, `ScheduleReport`, and the three built-in factor types
(`TeamSumFactor`, `RankDiffFactor`, `TruncFactor`) as public API.
### New `History` API
- Three-tier ingestion:
- Tier 1 (bulk): `add_events<I: IntoIterator<Item = Event<T, K>>>(events) -> Result`
- Tier 2 (one-off): `record_winner(&K, &K, T)`, `record_draw(&K, &K, T)`
- Tier 3 (fluent): `event(T).team([...]).weights([...]).ranking([...]).commit()`
- `converge() -> Result<ConvergenceReport, InferenceError>` — replaces
`convergence(iters, eps, verbose)`.
- `current_skill(&K)`, `learning_curve(&K)`, `learning_curves()` (now keyed on `K`).
- `log_evidence()` zero-arg, `log_evidence_for(&[&K])`.
- `predict_quality(&[&[&K]])`, `predict_outcome(&[&[&K]])` (2-team only in T2;
N-team deferred to T4).
- `intern(&Q)` / `lookup(&Q)` expose the internal `KeyTable<K>` for power users.
- `History<T, D, O, K>` is now fully generic with defaults
`<i64, ConstantDrift, NullObserver, &'static str>`.
### New `Game` API
- `Game::ranked(&[&[Rating]], Outcome, &GameOptions) -> Result<OwnedGame, _>`.
- `Game::one_v_one(&Rating, &Rating, Outcome) -> Result<(Gaussian, Gaussian), _>`.
- `Game::free_for_all(&[&Rating], Outcome, &GameOptions) -> Result<OwnedGame, _>`.
- `Game::custom(...)` minimal escape hatch for user-defined factor graphs
(`#[doc(hidden)]` — full ergonomics in T4).
- `Game::log_evidence()` and `OwnedGame::log_evidence()` accessors.
### Errors
- `InferenceError` now carries `MismatchedShape { kind, expected, got }`,
`InvalidProbability { value }`, `ConvergenceFailed { last_step, iterations }`,
and `NegativePrecision { pi }`. Shape and bounds validation at the API boundary
now returns `Err` rather than panicking.
### Removed (breaking)
- `History::convergence(iters, eps, verbose)` — use `converge()`.
- `HistoryBuilder::gamma(f64)` — use `.drift(ConstantDrift(g))`.
- `HistoryBuilder::time(bool)` and `History.time: bool` — use the `Time` type parameter.
- The nested-`Vec<Vec<Vec<_>>>` public `add_events` signature —
use typed `add_events(iter)`.
- `learning_curves_by_index()` — use `learning_curves()`.
### Performance
`Batch::iteration` bench: **21.36 µs** (T1 was 22.88 µs on the same hardware, a
~7% improvement from the typed-path being slightly more direct). Gaussian
operations unchanged.
### Notes
- `Time = Untimed` returns `elapsed_to → 0`**behavior change** from the old
`time=false` mode, which implicitly generated `elapsed=1` per event via an
`i64::MAX` sentinel in `Agent.last_time`. Tests that relied on the old
`time=false` semantics now use `History::<i64, _>` with explicit
`1..=n` timestamps.
## 0.1.0 - 2026-04-23 ## 0.1.0 - 2026-04-23
### Features ### Features

View File

@@ -65,3 +65,36 @@ Gaussian::pi_tau_combined 234.xx ps (unchanged)
# - Gaussian operations unchanged vs T0. # - Gaussian operations unchanged vs T0.
# - All 53 tests pass. factor graph infrastructure (VarStore, Factor trait, # - All 53 tests pass. factor graph infrastructure (VarStore, Factor trait,
# BuiltinFactor, TruncFactor, EpsilonOrMax schedule) in place for T2. # BuiltinFactor, TruncFactor, EpsilonOrMax schedule) in place for T2.
# After T2 (2026-04-24, same hardware)
Batch::iteration 21.36 µs (1.07× vs T1 22.88 µs — 7% improvement)
Gaussian::add 218.97 ps (unchanged)
Gaussian::sub 218.58 ps (unchanged)
Gaussian::mul 218.59 ps (unchanged)
Gaussian::div 218.57 ps (unchanged)
Gaussian::pi 264.20 ps (unchanged)
Gaussian::tau 260.80 ps (unchanged)
# Notes:
# - API-only tier; hot inference path unchanged. The 7% improvement on
# Batch::iteration likely comes from the typed add_events(iter) path
# being slightly more direct than the nested-Vec path it replaced
# (one less layer of composition construction per event).
# - Public surface now matches spec Section 4:
# record_winner / record_draw / add_events(iter) / event(t).team().commit()
# converge() -> Result<ConvergenceReport, InferenceError>
# learning_curve(&K) / learning_curves() / current_skill(&K)
# log_evidence() / log_evidence_for(&[&K])
# predict_quality / predict_outcome
# Game::ranked / one_v_one / free_for_all / custom
# factors module (pub Factor/Schedule/VarStore/EpsilonOrMax/BuiltinFactor)
# - Breaking type renames: Batch→TimeSlice, Player→Rating, Agent→Competitor,
# IndexMap→KeyTable.
# - Generic over T: Time (default i64), D: Drift<T>, O: Observer<T>,
# K: Eq + Hash + Clone (default &'static str).
# - Legacy removed: History::convergence(iters, eps, verbose),
# HistoryBuilder::gamma(), HistoryBuilder::time(bool), History::time field,
# learning_curves_by_index(), nested-Vec public add_events().
# - 90 tests green: 68 lib + 10 api_shape + 6 game + 4 record_winner +
# 2 equivalence.

View File

@@ -1,50 +1,61 @@
use plotters::prelude::*; use plotters::prelude::*;
use smallvec::smallvec;
use time::{Date, Month}; use time::{Date, Month};
use trueskill_tt::{History, KeyTable}; use trueskill_tt::{Event, History, Member, Outcome, Team, drift::ConstantDrift};
fn main() { fn main() {
let mut csv = csv::Reader::open("examples/atp.csv").unwrap(); let mut csv = csv::Reader::open("examples/atp.csv").unwrap();
let mut composition = Vec::new();
let mut results = Vec::new();
let mut times = Vec::new();
let from = Date::from_calendar_date(1900, Month::January, 1).unwrap(); let from = Date::from_calendar_date(1900, Month::January, 1).unwrap();
let time_format = time::format_description::parse("[year]-[month]-[day]").unwrap(); let time_format = time::format_description::parse("[year]-[month]-[day]").unwrap();
let mut index_map = KeyTable::new(); let mut events: Vec<Event<i64, String>> = Vec::new();
for row in csv.records() { for row in csv.records() {
if &row["double"] == "t" {
let w1_id = index_map.get_or_create(&row["w1_id"]);
let w2_id = index_map.get_or_create(&row["w2_id"]);
let l1_id = index_map.get_or_create(&row["l1_id"]);
let l2_id = index_map.get_or_create(&row["l2_id"]);
composition.push(vec![vec![w1_id, w2_id], vec![l1_id, l2_id]]);
} else {
let w1_id = index_map.get_or_create(&row["w1_id"]);
let l1_id = index_map.get_or_create(&row["l1_id"]);
composition.push(vec![vec![w1_id], vec![l1_id]]);
}
results.push(vec![1.0, 0.0]);
let date = Date::parse(&row["time_start"], &time_format).unwrap(); let date = Date::parse(&row["time_start"], &time_format).unwrap();
let time = (date - from).whole_days();
times.push((date - from).whole_days()); if &row["double"] == "t" {
events.push(Event {
time,
teams: smallvec![
Team::with_members([
Member::new(row["w1_id"].to_owned()),
Member::new(row["w2_id"].to_owned()),
]),
Team::with_members([
Member::new(row["l1_id"].to_owned()),
Member::new(row["l2_id"].to_owned()),
]),
],
outcome: Outcome::winner(0, 2),
});
} else {
events.push(Event {
time,
teams: smallvec![
Team::with_members([Member::new(row["w1_id"].to_owned())]),
Team::with_members([Member::new(row["l1_id"].to_owned())]),
],
outcome: Outcome::winner(0, 2),
});
}
} }
let mut hist = History::builder().sigma(1.6).gamma(0.036).build(); let mut hist: History<i64, _, _, String> = History::builder_with_key()
.sigma(1.6)
.drift(ConstantDrift(0.036))
.convergence(trueskill_tt::ConvergenceOptions {
max_iter: 10,
epsilon: 0.01,
})
.build();
hist.add_events(composition, results, times, vec![]); hist.add_events(events).unwrap();
hist.convergence(10, 0.01, true); hist.converge().unwrap();
let players = [ let players = [
("aggasi", "a092", 38800), ("aggasi", "a092", 38800i64),
("borg", "b058", 30300), ("borg", "b058", 30300),
("connors", "c044", 31250), ("connors", "c044", 31250),
("courier", "c243", 35750), ("courier", "c243", 35750),
@@ -61,21 +72,16 @@ fn main() {
("wilander", "w023", 32600), ("wilander", "w023", 32600),
]; ];
let curves = hist.learning_curves();
let mut x_spec = (f64::MAX, f64::MIN); let mut x_spec = (f64::MAX, f64::MIN);
let mut y_spec = (f64::MAX, f64::MIN); let mut y_spec = (f64::MAX, f64::MIN);
for (id, cutoff) in players for &(_, id, cutoff) in &players {
.iter() for (ts, gs) in hist.learning_curve(id) {
.map(|&(_, id, cutoff)| (index_map.get_or_create(id), cutoff)) if ts >= cutoff {
{
for (ts, gs) in &curves[&id] {
if *ts >= cutoff {
continue; continue;
} }
let ts = *ts as f64; let ts = ts as f64;
if ts < x_spec.0 { if ts < x_spec.0 {
x_spec.0 = ts; x_spec.0 = ts;
@@ -111,24 +117,19 @@ fn main() {
chart.configure_mesh().draw().unwrap(); chart.configure_mesh().draw().unwrap();
for (idx, (player, id, cutoff)) in players for (idx, &(player, id, cutoff)) in players.iter().enumerate() {
.iter()
.map(|&(player, id, cutoff)| (player, index_map.get_or_create(id), cutoff))
.enumerate()
{
let mut data = Vec::new(); let mut data = Vec::new();
let mut upper = Vec::new(); let mut upper = Vec::new();
let mut lower = Vec::new(); let mut lower = Vec::new();
for (ts, gs) in curves[&id].iter() { for (ts, gs) in hist.learning_curve(id) {
if *ts >= cutoff { if ts >= cutoff {
continue; continue;
} }
data.push((*ts as f64, gs.mu())); data.push((ts as f64, gs.mu()));
upper.push((ts as f64, gs.mu() + gs.sigma()));
upper.push((*ts as f64, gs.mu() + gs.sigma())); lower.push((ts as f64, gs.mu() - gs.sigma()));
lower.push((*ts as f64, gs.mu() - gs.sigma()));
} }
let color = Palette99::pick(idx); let color = Palette99::pick(idx);

31
src/convergence.rs Normal file
View File

@@ -0,0 +1,31 @@
//! Convergence configuration and reporting.
use std::time::Duration;
use smallvec::SmallVec;
#[derive(Clone, Copy, Debug)]
pub struct ConvergenceOptions {
pub max_iter: usize,
pub epsilon: f64,
}
impl Default for ConvergenceOptions {
fn default() -> Self {
Self {
max_iter: crate::ITERATIONS,
epsilon: crate::EPSILON,
}
}
}
/// Post-hoc summary of a `History::converge` call.
#[derive(Clone, Debug)]
pub struct ConvergenceReport {
pub iterations: usize,
pub final_step: (f64, f64),
pub log_evidence: f64,
pub converged: bool,
pub per_iteration_time: SmallVec<[Duration; 32]>,
pub slices_skipped: usize,
}

View File

@@ -2,12 +2,45 @@ use std::fmt;
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum InferenceError { pub enum InferenceError {
/// Expected and actual lengths of some array-shaped input differ.
MismatchedShape {
kind: &'static str,
expected: usize,
got: usize,
},
/// A probability value is outside `[0, 1]`.
InvalidProbability { value: f64 },
/// Convergence exceeded `max_iter` without falling below `epsilon`.
ConvergenceFailed {
last_step: (f64, f64),
iterations: usize,
},
/// Negative precision: a Gaussian with `pi < 0` slipped into an API call.
NegativePrecision { pi: f64 }, NegativePrecision { pi: f64 },
} }
impl fmt::Display for InferenceError { impl fmt::Display for InferenceError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self { match self {
Self::MismatchedShape {
kind,
expected,
got,
} => {
write!(f, "{kind}: expected length {expected}, got {got}")
}
Self::InvalidProbability { value } => {
write!(f, "probability must be in [0, 1]; got {value}")
}
Self::ConvergenceFailed {
last_step,
iterations,
} => {
write!(
f,
"convergence failed after {iterations} iterations; last step = {last_step:?}"
)
}
Self::NegativePrecision { pi } => { Self::NegativePrecision { pi } => {
write!(f, "precision must be non-negative; got {pi}") write!(f, "precision must be non-negative; got {pi}")
} }

94
src/event_builder.rs Normal file
View File

@@ -0,0 +1,94 @@
use smallvec::SmallVec;
use crate::{
InferenceError, Outcome,
drift::Drift,
event::{Event, Member, Team},
history::History,
observer::Observer,
time::Time,
};
pub struct EventBuilder<'h, T, D, O, K>
where
T: Time,
D: Drift<T>,
O: Observer<T>,
K: Eq + std::hash::Hash + Clone,
{
history: &'h mut History<T, D, O, K>,
event: Event<T, K>,
current_team_idx: Option<usize>,
}
impl<'h, T, D, O, K> EventBuilder<'h, T, D, O, K>
where
T: Time,
D: Drift<T>,
O: Observer<T>,
K: Eq + std::hash::Hash + Clone,
{
pub(crate) fn new(history: &'h mut History<T, D, O, K>, time: T) -> Self {
Self {
history,
event: Event {
time,
teams: SmallVec::new(),
outcome: Outcome::Ranked(SmallVec::new()),
},
current_team_idx: None,
}
}
/// Add a team by its member keys (weight 1.0 each, no prior overrides).
pub fn team<I: IntoIterator<Item = K>>(mut self, keys: I) -> Self {
let members: SmallVec<[Member<K>; 4]> = keys.into_iter().map(Member::new).collect();
self.event.teams.push(Team { members });
self.current_team_idx = Some(self.event.teams.len() - 1);
self
}
/// Set per-member weights for the most recently added team.
///
/// Panics in debug builds if called before `.team(...)` or if the length
/// doesn't match the team's member count.
pub fn weights<I: IntoIterator<Item = f64>>(mut self, weights: I) -> Self {
let idx = self
.current_team_idx
.expect(".weights(...) called before any .team(...)");
let ws: Vec<f64> = weights.into_iter().collect();
let team = &mut self.event.teams[idx];
debug_assert_eq!(
ws.len(),
team.members.len(),
"weights length must match team size"
);
for (m, w) in team.members.iter_mut().zip(ws) {
m.weight = w;
}
self
}
/// Set explicit ranks per team (length must equal number of teams).
pub fn ranking<I: IntoIterator<Item = u32>>(mut self, ranks: I) -> Self {
self.event.outcome = Outcome::ranking(ranks);
self
}
/// Mark team `winner_idx` as winner; others tied for last.
pub fn winner(mut self, winner_idx: u32) -> Self {
self.event.outcome = Outcome::winner(winner_idx, self.event.teams.len() as u32);
self
}
/// All teams tied.
pub fn draw(mut self) -> Self {
self.event.outcome = Outcome::draw(self.event.teams.len() as u32);
self
}
/// Commit the event to the history.
pub fn commit(self) -> Result<(), InferenceError> {
self.history.add_events(std::iter::once(self.event))
}
}

View File

@@ -7,44 +7,46 @@ use crate::gaussian::Gaussian;
/// Variables hold the current Gaussian marginal and are owned by exactly one /// Variables hold the current Gaussian marginal and are owned by exactly one
/// `VarStore`. `VarId` is meaningful only within its owning store. /// `VarStore`. `VarId` is meaningful only within its owning store.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct VarId(pub(crate) u32); pub struct VarId(pub u32);
/// Flat storage of variable marginals. /// Flat storage of variable marginals.
/// ///
/// Variables are allocated by `alloc()` and accessed by `VarId`. The store is /// Variables are allocated by `alloc()` and accessed by `VarId`. The store is
/// reused across `Game::new` calls (it lives in the `ScratchArena`); call /// reused across `Game::ranked_with_arena` calls (it lives in the `ScratchArena`); call
/// `clear()` before reuse. /// `clear()` before reuse.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(crate) struct VarStore { pub struct VarStore {
pub(crate) marginals: Vec<Gaussian>, pub(crate) marginals: Vec<Gaussian>,
} }
impl VarStore { impl VarStore {
#[allow(dead_code)] pub fn new() -> Self {
pub(crate) fn new() -> Self {
Self::default() Self::default()
} }
pub(crate) fn clear(&mut self) { pub fn clear(&mut self) {
self.marginals.clear(); self.marginals.clear();
} }
#[allow(dead_code)] pub fn len(&self) -> usize {
pub(crate) fn len(&self) -> usize {
self.marginals.len() self.marginals.len()
} }
pub(crate) fn alloc(&mut self, init: Gaussian) -> VarId { pub fn is_empty(&self) -> bool {
self.marginals.is_empty()
}
pub fn alloc(&mut self, init: Gaussian) -> VarId {
let id = VarId(self.marginals.len() as u32); let id = VarId(self.marginals.len() as u32);
self.marginals.push(init); self.marginals.push(init);
id id
} }
pub(crate) fn get(&self, id: VarId) -> Gaussian { pub fn get(&self, id: VarId) -> Gaussian {
self.marginals[id.0 as usize] self.marginals[id.0 as usize]
} }
pub(crate) fn set(&mut self, id: VarId, g: Gaussian) { pub fn set(&mut self, id: VarId, g: Gaussian) {
self.marginals[id.0 as usize] = g; self.marginals[id.0 as usize] = g;
} }
} }
@@ -54,7 +56,7 @@ impl VarStore {
/// Factors hold their own outgoing messages and propagate them by reading /// Factors hold their own outgoing messages and propagate them by reading
/// connected variable marginals from a `VarStore` and writing back updated /// connected variable marginals from a `VarStore` and writing back updated
/// marginals. /// marginals.
pub(crate) trait Factor { pub trait Factor {
/// Update outgoing messages and write back to the var store. /// Update outgoing messages and write back to the var store.
/// ///
/// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this /// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this
@@ -62,7 +64,6 @@ pub(crate) trait Factor {
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64); fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64);
/// Optional log-evidence contribution. Default 0.0 (no contribution). /// Optional log-evidence contribution. Default 0.0 (no contribution).
#[allow(dead_code)]
fn log_evidence(&self, _vars: &VarStore) -> f64 { fn log_evidence(&self, _vars: &VarStore) -> f64 {
0.0 0.0
} }
@@ -73,8 +74,7 @@ pub(crate) trait Factor {
/// Using an enum instead of `Box<dyn Factor>` keeps factor data inline and /// Using an enum instead of `Box<dyn Factor>` keeps factor data inline and
/// avoids virtual-call overhead in the hot inference loop. /// avoids virtual-call overhead in the hot inference loop.
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] pub enum BuiltinFactor {
pub(crate) enum BuiltinFactor {
TeamSum(team_sum::TeamSumFactor), TeamSum(team_sum::TeamSumFactor),
RankDiff(rank_diff::RankDiffFactor), RankDiff(rank_diff::RankDiffFactor),
Trunc(trunc::TruncFactor), Trunc(trunc::TruncFactor),
@@ -97,9 +97,9 @@ impl Factor for BuiltinFactor {
} }
} }
pub(crate) mod rank_diff; pub mod rank_diff;
pub(crate) mod team_sum; pub mod team_sum;
pub(crate) mod trunc; pub mod trunc;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

View File

@@ -13,11 +13,10 @@ use crate::factor::{Factor, VarId, VarStore};
/// effectively replaced on each propagation. The TruncFactor on the same diff /// effectively replaced on each propagation. The TruncFactor on the same diff
/// var holds the EP-divide message that produces the cavity. /// var holds the EP-divide message that produces the cavity.
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] pub struct RankDiffFactor {
pub(crate) struct RankDiffFactor { pub team_a: VarId,
pub(crate) team_a: VarId, pub team_b: VarId,
pub(crate) team_b: VarId, pub diff: VarId,
pub(crate) diff: VarId,
} }
impl Factor for RankDiffFactor { impl Factor for RankDiffFactor {

View File

@@ -10,10 +10,9 @@ use crate::{
/// already with beta² noise added via `Rating::performance()`). The factor /// already with beta² noise added via `Rating::performance()`). The factor
/// runs once per game and writes the weighted sum to the output var. /// runs once per game and writes the weighted sum to the output var.
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] pub struct TeamSumFactor {
pub(crate) struct TeamSumFactor { pub inputs: Vec<(Gaussian, f64)>,
pub(crate) inputs: Vec<(Gaussian, f64)>, pub out: VarId,
pub(crate) out: VarId,
} }
impl Factor for TeamSumFactor { impl Factor for TeamSumFactor {

View File

@@ -11,10 +11,10 @@ use crate::{
/// Stores its outgoing message to the diff variable so the cavity computation /// Stores its outgoing message to the diff variable so the cavity computation
/// produces the correct EP message on each propagation. /// produces the correct EP message on each propagation.
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct TruncFactor { pub struct TruncFactor {
pub(crate) diff: VarId, pub diff: VarId,
pub(crate) margin: f64, pub margin: f64,
pub(crate) tie: bool, pub tie: bool,
/// Outgoing message to the diff variable (initial: N_INF, the EP identity). /// Outgoing message to the diff variable (initial: N_INF, the EP identity).
pub(crate) msg: Gaussian, pub(crate) msg: Gaussian,
/// Cached evidence (linear, not log) computed from the cavity on first propagation. /// Cached evidence (linear, not log) computed from the cavity on first propagation.
@@ -22,7 +22,7 @@ pub(crate) struct TruncFactor {
} }
impl TruncFactor { impl TruncFactor {
pub(crate) fn new(diff: VarId, margin: f64, tie: bool) -> Self { pub fn new(diff: VarId, margin: f64, tie: bool) -> Self {
Self { Self {
diff, diff,
margin, margin,

13
src/factors.rs Normal file
View File

@@ -0,0 +1,13 @@
//! Factor-graph public API.
//!
//! Power users can construct custom factor graphs via `Game::custom` (T2
//! minimal; full ergonomics in T4) and drive them with custom `Schedule`
//! implementations.
pub use crate::{
factor::{
BuiltinFactor, Factor, VarId, VarStore, rank_diff::RankDiffFactor, team_sum::TeamSumFactor,
trunc::TruncFactor,
},
schedule::{EpsilonOrMax, Schedule, ScheduleReport},
};

View File

@@ -12,6 +12,71 @@ use crate::{
tuple_gt, tuple_max, tuple_gt, tuple_max,
}; };
#[derive(Clone, Copy, Debug)]
pub struct GameOptions {
pub p_draw: f64,
pub convergence: crate::ConvergenceOptions,
}
impl Default for GameOptions {
fn default() -> Self {
Self {
p_draw: crate::P_DRAW,
convergence: crate::ConvergenceOptions::default(),
}
}
}
/// Owned variant of `Game` returned by public constructors.
///
/// Unlike `Game<'a, T, D>` (which borrows its result/weights slices from
/// History's internal state), `OwnedGame<T, D>` owns its inputs so it can
/// be returned freely from public constructors.
#[derive(Debug)]
#[allow(dead_code)]
pub struct OwnedGame<T: Time, D: Drift<T>> {
teams: Vec<Vec<Rating<T, D>>>,
result: Vec<f64>,
weights: Vec<Vec<f64>>,
p_draw: f64,
pub(crate) likelihoods: Vec<Vec<Gaussian>>,
pub(crate) evidence: f64,
}
impl<T: Time, D: Drift<T>> OwnedGame<T, D> {
pub(crate) fn new(
teams: Vec<Vec<Rating<T, D>>>,
result: Vec<f64>,
weights: Vec<Vec<f64>>,
p_draw: f64,
) -> Self {
let mut arena = ScratchArena::new();
let g = Game::ranked_with_arena(teams.clone(), &result, &weights, p_draw, &mut arena);
let likelihoods = g.likelihoods;
let evidence = g.evidence;
Self {
teams,
result,
weights,
p_draw,
likelihoods,
evidence,
}
}
pub fn posteriors(&self) -> Vec<Vec<Gaussian>> {
self.likelihoods
.iter()
.zip(self.teams.iter())
.map(|(l, t)| l.iter().zip(t.iter()).map(|(&l, r)| l * r.prior).collect())
.collect()
}
pub fn log_evidence(&self) -> f64 {
self.evidence.ln()
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct Game<'a, T: Time = i64, D: Drift<T> = crate::drift::ConstantDrift> { pub struct Game<'a, T: Time = i64, D: Drift<T> = crate::drift::ConstantDrift> {
teams: Vec<Vec<Rating<T, D>>>, teams: Vec<Vec<Rating<T, D>>>,
@@ -23,7 +88,7 @@ pub struct Game<'a, T: Time = i64, D: Drift<T> = crate::drift::ConstantDrift> {
} }
impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> { impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
pub fn new( pub(crate) fn ranked_with_arena(
teams: Vec<Vec<Rating<T, D>>>, teams: Vec<Vec<Rating<T, D>>>,
result: &'a [f64], result: &'a [f64],
weights: &'a [Vec<f64>], weights: &'a [Vec<f64>],
@@ -219,6 +284,68 @@ impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
pub fn log_evidence(&self) -> f64 {
self.evidence.ln()
}
}
impl<T: Time, D: Drift<T>> Game<'_, T, D> {
pub fn ranked(
teams: &[&[Rating<T, D>]],
outcome: crate::Outcome,
options: &GameOptions,
) -> Result<OwnedGame<T, D>, crate::InferenceError> {
if !(0.0..1.0).contains(&options.p_draw) {
return Err(crate::InferenceError::InvalidProbability {
value: options.p_draw,
});
}
if outcome.team_count() != teams.len() {
return Err(crate::InferenceError::MismatchedShape {
kind: "outcome ranks vs teams",
expected: teams.len(),
got: outcome.team_count(),
});
}
let ranks = outcome.as_ranks();
let max_rank = ranks.iter().copied().max().unwrap_or(0) as f64;
let result: Vec<f64> = ranks.iter().map(|&r| max_rank - r as f64).collect();
let teams_owned: Vec<Vec<Rating<T, D>>> = teams.iter().map(|t| t.to_vec()).collect();
let weights: Vec<Vec<f64>> = teams.iter().map(|t| vec![1.0; t.len()]).collect();
Ok(OwnedGame::new(teams_owned, result, weights, options.p_draw))
}
pub fn one_v_one(
a: &Rating<T, D>,
b: &Rating<T, D>,
outcome: crate::Outcome,
) -> Result<(Gaussian, Gaussian), crate::InferenceError> {
let game = Self::ranked(&[&[*a], &[*b]], outcome, &GameOptions::default())?;
let post = game.posteriors();
Ok((post[0][0], post[1][0]))
}
pub fn free_for_all(
players: &[&Rating<T, D>],
outcome: crate::Outcome,
options: &GameOptions,
) -> Result<OwnedGame<T, D>, crate::InferenceError> {
let teams: Vec<Vec<Rating<T, D>>> = players.iter().map(|p| vec![**p]).collect();
let team_refs: Vec<&[Rating<T, D>]> = teams.iter().map(|t| t.as_slice()).collect();
Self::ranked(&team_refs, outcome, options)
}
#[doc(hidden)]
pub fn custom<S: crate::factors::Schedule>(
factors: &mut [crate::factors::BuiltinFactor],
vars: &mut crate::factors::VarStore,
schedule: &S,
) -> crate::factors::ScheduleReport {
schedule.run(factors, vars)
}
} }
#[cfg(test)] #[cfg(test)]
@@ -244,7 +371,7 @@ mod tests {
); );
let w = [vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
vec![vec![t_a], vec![t_b]], vec![vec![t_a], vec![t_b]],
&[0.0, 1.0], &[0.0, 1.0],
&w, &w,
@@ -271,7 +398,7 @@ mod tests {
); );
let w = [vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
vec![vec![t_a], vec![t_b]], vec![vec![t_a], vec![t_b]],
&[0.0, 1.0], &[0.0, 1.0],
&w, &w,
@@ -290,7 +417,7 @@ mod tests {
let t_b = R::new(Gaussian::from_ms(15.568, 0.51), 1.0, ConstantDrift(0.2125)); let t_b = R::new(Gaussian::from_ms(15.568, 0.51), 1.0, ConstantDrift(0.2125));
let w = [vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
vec![vec![t_a], vec![t_b]], vec![vec![t_a], vec![t_b]],
&[0.0, 1.0], &[0.0, 1.0],
&w, &w,
@@ -323,7 +450,7 @@ mod tests {
]; ];
let w = [vec![1.0], vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0], vec![1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
teams.clone(), teams.clone(),
&[1.0, 2.0, 0.0], &[1.0, 2.0, 0.0],
&w, &w,
@@ -339,7 +466,7 @@ mod tests {
assert_ulps_eq!(b, Gaussian::from_ms(31.311358, 6.698818), epsilon = 1e-6); assert_ulps_eq!(b, Gaussian::from_ms(31.311358, 6.698818), epsilon = 1e-6);
let w = [vec![1.0], vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0], vec![1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
teams.clone(), teams.clone(),
&[2.0, 1.0, 0.0], &[2.0, 1.0, 0.0],
&w, &w,
@@ -355,7 +482,7 @@ mod tests {
assert_ulps_eq!(b, Gaussian::from_ms(25.000000, 6.238469), epsilon = 1e-6); assert_ulps_eq!(b, Gaussian::from_ms(25.000000, 6.238469), epsilon = 1e-6);
let w = [vec![1.0], vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0], vec![1.0]];
let g = Game::new(teams, &[1.0, 2.0, 0.0], &w, 0.5, &mut ScratchArena::new()); let g = Game::ranked_with_arena(teams, &[1.0, 2.0, 0.0], &w, 0.5, &mut ScratchArena::new());
let p = g.posteriors(); let p = g.posteriors();
let a = p[0][0]; let a = p[0][0];
@@ -382,7 +509,7 @@ mod tests {
); );
let w = [vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
vec![vec![t_a], vec![t_b]], vec![vec![t_a], vec![t_b]],
&[0.0, 0.0], &[0.0, 0.0],
&w, &w,
@@ -409,7 +536,7 @@ mod tests {
); );
let w = [vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
vec![vec![t_a], vec![t_b]], vec![vec![t_a], vec![t_b]],
&[0.0, 0.0], &[0.0, 0.0],
&w, &w,
@@ -444,7 +571,7 @@ mod tests {
); );
let w = [vec![1.0], vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0], vec![1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
vec![vec![t_a], vec![t_b], vec![t_c]], vec![vec![t_a], vec![t_b], vec![t_c]],
&[0.0, 0.0, 0.0], &[0.0, 0.0, 0.0],
&w, &w,
@@ -480,7 +607,7 @@ mod tests {
); );
let w = [vec![1.0], vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0], vec![1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
vec![vec![t_a], vec![t_b], vec![t_c]], vec![vec![t_a], vec![t_b], vec![t_c]],
&[0.0, 0.0, 0.0], &[0.0, 0.0, 0.0],
&w, &w,
@@ -531,7 +658,7 @@ mod tests {
]; ];
let w = [vec![1.0, 1.0], vec![1.0], vec![1.0, 1.0]]; let w = [vec![1.0, 1.0], vec![1.0], vec![1.0, 1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
vec![t_a, t_b, t_c], vec![t_a, t_b, t_c],
&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0],
&w, &w,
@@ -564,7 +691,7 @@ mod tests {
)]; )];
let w = [w_a, w_b]; let w = [w_a, w_b];
let g = Game::new( let g = Game::ranked_with_arena(
vec![t_a.clone(), t_b.clone()], vec![t_a.clone(), t_b.clone()],
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
@@ -588,7 +715,7 @@ mod tests {
let w_b = vec![0.7]; let w_b = vec![0.7];
let w = [w_a, w_b]; let w = [w_a, w_b];
let g = Game::new( let g = Game::ranked_with_arena(
vec![t_a.clone(), t_b.clone()], vec![t_a.clone(), t_b.clone()],
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
@@ -612,7 +739,7 @@ mod tests {
let w_b = vec![0.7]; let w_b = vec![0.7];
let w = [w_a, w_b]; let w = [w_a, w_b];
let g = Game::new( let g = Game::ranked_with_arena(
vec![t_a, t_b], vec![t_a, t_b],
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
@@ -639,7 +766,7 @@ mod tests {
let t_b = vec![R::new(Gaussian::from_ms(2.0, 6.0), 1.0, ConstantDrift(0.0))]; let t_b = vec![R::new(Gaussian::from_ms(2.0, 6.0), 1.0, ConstantDrift(0.0))];
let w = [w_a, w_b]; let w = [w_a, w_b];
let g = Game::new( let g = Game::ranked_with_arena(
vec![t_a, t_b], vec![t_a, t_b],
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
@@ -666,7 +793,7 @@ mod tests {
let t_b = vec![R::new(Gaussian::from_ms(2.0, 6.0), 1.0, ConstantDrift(0.0))]; let t_b = vec![R::new(Gaussian::from_ms(2.0, 6.0), 1.0, ConstantDrift(0.0))];
let w = [w_a, w_b]; let w = [w_a, w_b];
let g = Game::new( let g = Game::ranked_with_arena(
vec![t_a, t_b], vec![t_a, t_b],
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
@@ -709,7 +836,7 @@ mod tests {
let w_b = vec![0.9, 0.6]; let w_b = vec![0.9, 0.6];
let w = [w_a, w_b]; let w = [w_a, w_b];
let g = Game::new( let g = Game::ranked_with_arena(
vec![t_a.clone(), t_b.clone()], vec![t_a.clone(), t_b.clone()],
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
@@ -743,7 +870,7 @@ mod tests {
let w_b = vec![0.7, 0.4]; let w_b = vec![0.7, 0.4];
let w = [w_a, w_b]; let w = [w_a, w_b];
let g = Game::new( let g = Game::ranked_with_arena(
vec![t_a.clone(), t_b.clone()], vec![t_a.clone(), t_b.clone()],
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
@@ -777,7 +904,7 @@ mod tests {
let w_b = vec![0.7, 2.4]; let w_b = vec![0.7, 2.4];
let w = [w_a, w_b]; let w = [w_a, w_b];
let g = Game::new( let g = Game::ranked_with_arena(
vec![t_a.clone(), t_b.clone()], vec![t_a.clone(), t_b.clone()],
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
@@ -808,7 +935,7 @@ mod tests {
); );
let w = [vec![1.0, 1.0], vec![1.0]]; let w = [vec![1.0, 1.0], vec![1.0]];
let g = Game::new( let g = Game::ranked_with_arena(
vec![ vec![
t_a.clone(), t_a.clone(),
vec![R::new( vec![R::new(
@@ -828,7 +955,7 @@ mod tests {
let w_b = vec![1.0, 0.0]; let w_b = vec![1.0, 0.0];
let w = [w_a, w_b]; let w = [w_a, w_b];
let g = Game::new( let g = Game::ranked_with_arena(
vec![t_a, t_b.clone()], vec![t_a, t_b.clone()],
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,

File diff suppressed because it is too large Load Diff

View File

@@ -22,7 +22,7 @@ where
Self(HashMap::new()) Self(HashMap::new())
} }
pub fn get<Q: ?Sized + Hash + Eq + ToOwned<Owned = K>>(&self, k: &Q) -> Option<Index> pub fn get<Q: ?Sized + Hash + Eq>(&self, k: &Q) -> Option<Index>
where where
K: Borrow<Q>, K: Borrow<Q>,
{ {

View File

@@ -10,10 +10,13 @@ mod time;
mod time_slice; mod time_slice;
pub use time_slice::TimeSlice; pub use time_slice::TimeSlice;
mod competitor; mod competitor;
mod convergence;
pub mod drift; pub mod drift;
mod error; mod error;
mod event; mod event;
mod event_builder;
pub(crate) mod factor; pub(crate) mod factor;
pub mod factors;
mod game; mod game;
pub mod gaussian; pub mod gaussian;
mod history; mod history;
@@ -26,10 +29,12 @@ pub(crate) mod schedule;
pub mod storage; pub mod storage;
pub use competitor::Competitor; pub use competitor::Competitor;
pub use convergence::{ConvergenceOptions, ConvergenceReport};
pub use drift::{ConstantDrift, Drift}; pub use drift::{ConstantDrift, Drift};
pub use error::InferenceError; pub use error::InferenceError;
pub use event::{Event, Member, Team}; pub use event::{Event, Member, Team};
pub use game::Game; pub use event_builder::EventBuilder;
pub use game::{Game, GameOptions, OwnedGame};
pub use gaussian::Gaussian; pub use gaussian::Gaussian;
pub use history::History; pub use history::History;
pub use key_table::KeyTable; pub use key_table::KeyTable;

View File

@@ -16,8 +16,7 @@ pub struct ScheduleReport {
} }
/// Drives factor propagation to convergence. /// Drives factor propagation to convergence.
#[allow(dead_code)] pub trait Schedule {
pub(crate) trait Schedule {
fn run(&self, factors: &mut [BuiltinFactor], vars: &mut VarStore) -> ScheduleReport; fn run(&self, factors: &mut [BuiltinFactor], vars: &mut VarStore) -> ScheduleReport;
} }
@@ -26,8 +25,7 @@ pub(crate) trait Schedule {
/// Matches the existing `Game::likelihoods` loop bit-for-bit when given the /// Matches the existing `Game::likelihoods` loop bit-for-bit when given the
/// same factor layout (TeamSums first, then alternating RankDiff/Trunc pairs). /// same factor layout (TeamSums first, then alternating RankDiff/Trunc pairs).
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
#[allow(dead_code)] pub struct EpsilonOrMax {
pub(crate) struct EpsilonOrMax {
pub eps: f64, pub eps: f64,
pub max: usize, pub max: usize,
} }

View File

@@ -226,7 +226,13 @@ impl<T: Time> TimeSlice<T> {
let teams = event.within_priors(false, false, &self.skills, agents); let teams = event.within_priors(false, false, &self.skills, agents);
let result = event.outputs(); let result = event.outputs();
let g = Game::new(teams, &result, &event.weights, self.p_draw, &mut self.arena); let g = Game::ranked_with_arena(
teams,
&result,
&event.weights,
self.p_draw,
&mut self.arena,
);
for (t, team) in event.teams.iter_mut().enumerate() { for (t, team) in event.teams.iter_mut().enumerate() {
for (i, item) in team.items.iter_mut().enumerate() { for (i, item) in team.items.iter_mut().enumerate() {
@@ -315,7 +321,7 @@ impl<T: Time> TimeSlice<T> {
self.events self.events
.iter() .iter()
.map(|event| { .map(|event| {
Game::new( Game::ranked_with_arena(
event.within_priors(online, forward, &self.skills, agents), event.within_priors(online, forward, &self.skills, agents),
&event.outputs(), &event.outputs(),
&event.weights, &event.weights,
@@ -341,7 +347,7 @@ impl<T: Time> TimeSlice<T> {
.any(|item| targets.contains(&item.agent)) .any(|item| targets.contains(&item.agent))
}) })
.map(|(_, event)| { .map(|(_, event)| {
Game::new( Game::ranked_with_arena(
event.within_priors(online, forward, &self.skills, agents), event.within_priors(online, forward, &self.skills, agents),
&event.outputs(), &event.outputs(),
&event.weights, &event.weights,

225
tests/api_shape.rs Normal file
View File

@@ -0,0 +1,225 @@
//! Tests for the new T2 public API surface: typed add_events(iter) and the
//! fluent event builder (added in Task 16).
use smallvec::smallvec;
use trueskill_tt::{ConstantDrift, ConvergenceOptions, Event, History, Member, Outcome, Team};
#[test]
fn add_events_bulk_via_iter() {
let mut h = History::builder()
.mu(0.0)
.sigma(2.0)
.beta(1.0)
.p_draw(0.0)
.drift(ConstantDrift(0.0))
.convergence(ConvergenceOptions {
max_iter: 30,
epsilon: 1e-6,
})
.build();
let events: Vec<Event<i64, &'static str>> = vec![
Event {
time: 1,
teams: smallvec![
Team::with_members([Member::new("a")]),
Team::with_members([Member::new("b")]),
],
outcome: Outcome::winner(0, 2),
},
Event {
time: 2,
teams: smallvec![
Team::with_members([Member::new("b")]),
Team::with_members([Member::new("c")]),
],
outcome: Outcome::winner(0, 2),
},
];
h.add_events(events).unwrap();
let report = h.converge().unwrap();
assert!(report.converged);
assert!(h.lookup(&"a").is_some());
assert!(h.lookup(&"b").is_some());
assert!(h.lookup(&"c").is_some());
}
#[test]
fn add_events_draw() {
let mut h = History::builder()
.mu(25.0)
.sigma(25.0 / 3.0)
.beta(25.0 / 6.0)
.p_draw(0.25)
.drift(ConstantDrift(25.0 / 300.0))
.build();
let events: Vec<Event<i64, &'static str>> = vec![Event {
time: 1,
teams: smallvec![
Team::with_members([Member::new("alice")]),
Team::with_members([Member::new("bob")]),
],
outcome: Outcome::draw(2),
}];
h.add_events(events).unwrap();
h.converge().unwrap();
}
#[test]
fn add_events_rejects_mismatched_outcome_ranks() {
use trueskill_tt::InferenceError;
let mut h: History = History::builder().build();
let events: Vec<Event<i64, &'static str>> = vec![Event {
time: 1,
teams: smallvec![
Team::with_members([Member::new("a")]),
Team::with_members([Member::new("b")]),
],
outcome: Outcome::ranking([0, 1, 2]), // 3 ranks but 2 teams
}];
let err = h.add_events(events).unwrap_err();
assert!(matches!(err, InferenceError::MismatchedShape { .. }));
}
#[test]
fn fluent_event_builder_basic() {
let mut h = History::builder()
.mu(25.0)
.sigma(25.0 / 3.0)
.beta(25.0 / 6.0)
.p_draw(0.0)
.build();
h.event(1)
.team(["alice", "bob"])
.weights([1.0, 0.7])
.team(["carol"])
.ranking([1, 0])
.commit()
.unwrap();
let report = h.converge().unwrap();
assert!(report.converged);
assert!(h.lookup(&"alice").is_some());
assert!(h.lookup(&"bob").is_some());
assert!(h.lookup(&"carol").is_some());
}
#[test]
fn fluent_event_builder_winner_convenience() {
let mut h = History::builder()
.mu(25.0)
.sigma(25.0 / 3.0)
.beta(25.0 / 6.0)
.p_draw(0.0)
.build();
h.event(1)
.team(["alice"])
.team(["bob"])
.winner(0)
.commit()
.unwrap();
h.converge().unwrap();
}
#[test]
fn fluent_event_builder_draw() {
let mut h = History::builder()
.mu(25.0)
.sigma(25.0 / 3.0)
.beta(25.0 / 6.0)
.p_draw(0.25)
.build();
h.event(1)
.team(["alice"])
.team(["bob"])
.draw()
.commit()
.unwrap();
h.converge().unwrap();
}
#[test]
fn current_skill_and_learning_curve() {
use trueskill_tt::History;
let mut h = History::builder()
.mu(25.0)
.sigma(25.0 / 3.0)
.beta(25.0 / 6.0)
.p_draw(0.0)
.build();
h.record_winner(&"a", &"b", 1).unwrap();
h.record_winner(&"a", &"b", 2).unwrap();
h.converge().unwrap();
let a = h.current_skill(&"a").unwrap();
assert!(a.mu() > 25.0);
let b = h.current_skill(&"b").unwrap();
assert!(b.mu() < 25.0);
let a_curve = h.learning_curve(&"a");
assert_eq!(a_curve.len(), 2);
assert_eq!(a_curve[0].0, 1);
assert_eq!(a_curve[1].0, 2);
let all = h.learning_curves();
assert_eq!(all.len(), 2);
assert!(all.contains_key("a"));
assert!(all.contains_key("b"));
}
#[test]
fn log_evidence_total_vs_subset() {
use trueskill_tt::{ConstantDrift, History};
let mut h = History::builder()
.mu(0.0)
.sigma(6.0)
.beta(1.0)
.p_draw(0.0)
.drift(ConstantDrift(0.0))
.build();
h.record_winner(&"a", &"b", 1).unwrap();
h.record_winner(&"b", &"a", 2).unwrap();
let total = h.log_evidence();
let a_only = h.log_evidence_for(&[&"a"]);
assert!(total.is_finite());
assert!(a_only.is_finite());
}
#[test]
fn predict_quality_two_teams() {
use trueskill_tt::History;
let mut h = History::builder()
.mu(25.0)
.sigma(25.0 / 3.0)
.beta(25.0 / 6.0)
.p_draw(0.0)
.build();
h.record_winner(&"a", &"b", 1).unwrap();
h.converge().unwrap();
let q = h.predict_quality(&[&[&"a"], &[&"b"]]);
assert!(q > 0.0 && q <= 1.0);
}
#[test]
fn predict_outcome_two_teams_sums_to_one() {
use trueskill_tt::History;
let mut h = History::builder()
.mu(25.0)
.sigma(25.0 / 3.0)
.beta(25.0 / 6.0)
.p_draw(0.0)
.build();
h.record_winner(&"a", &"b", 1).unwrap();
h.converge().unwrap();
let p = h.predict_outcome(&[&[&"a"], &[&"b"]]);
assert_eq!(p.len(), 2);
assert!((p[0] + p[1] - 1.0).abs() < 1e-9);
assert!(p[0] > p[1]);
}

61
tests/equivalence.rs Normal file
View File

@@ -0,0 +1,61 @@
//! Equivalence tests: every historical golden from the pre-T2 tests is
//! reproduced here at the integration level via the new public API.
//!
//! The in-crate tests in `src/history.rs::tests` and
//! `src/time_slice.rs::tests` are the primary regression net for numerical
//! behavior. This file provides Game-level goldens that stand alone and are
//! more naturally expressed as integration tests.
use approx::assert_ulps_eq;
use trueskill_tt::{ConstantDrift, Game, GameOptions, Gaussian, Outcome, Rating};
type R = Rating<i64, ConstantDrift>;
fn ts_rating(mu: f64, sigma: f64, beta: f64, gamma: f64) -> R {
R::new(Gaussian::from_ms(mu, sigma), beta, ConstantDrift(gamma))
}
#[test]
fn game_1v1_golden_matches_historical() {
let a = ts_rating(25.0, 25.0 / 3.0, 25.0 / 6.0, 25.0 / 300.0);
let b = ts_rating(25.0, 25.0 / 3.0, 25.0 / 6.0, 25.0 / 300.0);
let (a_post, b_post) = Game::<i64, _>::one_v_one(&a, &b, Outcome::winner(0, 2)).unwrap();
// Historical golden from pre-T2 test_1vs1 (team 0 wins):
assert_ulps_eq!(
a_post,
Gaussian::from_ms(29.205220, 7.194481),
epsilon = 1e-6
);
assert_ulps_eq!(
b_post,
Gaussian::from_ms(20.794779, 7.194481),
epsilon = 1e-6
);
}
#[test]
fn game_1v1_draw_golden() {
let a = ts_rating(25.0, 25.0 / 3.0, 25.0 / 6.0, 25.0 / 300.0);
let b = ts_rating(25.0, 25.0 / 3.0, 25.0 / 6.0, 25.0 / 300.0);
let g = Game::<i64, _>::ranked(
&[&[a], &[b]],
Outcome::draw(2),
&GameOptions {
p_draw: 0.25,
convergence: Default::default(),
},
)
.unwrap();
let p = g.posteriors();
// Historical golden from pre-T2 test_1vs1_draw:
assert_ulps_eq!(
p[0][0],
Gaussian::from_ms(24.999999, 6.469480),
epsilon = 1e-6
);
assert_ulps_eq!(
p[1][0],
Gaussian::from_ms(24.999999, 6.469480),
epsilon = 1e-6
);
}

96
tests/game.rs Normal file
View File

@@ -0,0 +1,96 @@
use trueskill_tt::{
ConstantDrift, ConvergenceOptions, Game, GameOptions, Gaussian, InferenceError, Outcome, Rating,
};
type R = Rating<i64, ConstantDrift>;
fn default_rating() -> R {
R::new(
Gaussian::from_ms(25.0, 25.0 / 3.0),
25.0 / 6.0,
ConstantDrift(25.0 / 300.0),
)
}
#[test]
fn game_ranked_1v1_golden() {
let a = default_rating();
let b = default_rating();
let g = Game::<i64, _>::ranked(
&[&[a], &[b]],
Outcome::winner(0, 2),
&GameOptions::default(),
)
.unwrap();
let p = g.posteriors();
assert!(p[0][0].mu() > 25.0);
assert!(p[1][0].mu() < 25.0);
assert!((p[0][0].sigma() - p[1][0].sigma()).abs() < 1e-6);
}
#[test]
fn game_one_v_one_shortcut() {
let a = default_rating();
let b = default_rating();
let (a_post, b_post) = Game::<i64, _>::one_v_one(&a, &b, Outcome::winner(0, 2)).unwrap();
assert!(a_post.mu() > 25.0);
assert!(b_post.mu() < 25.0);
}
#[test]
fn game_ranked_rejects_bad_p_draw() {
let a = R::new(Gaussian::default(), 1.0, ConstantDrift(0.0));
let err = Game::<i64, _>::ranked(
&[&[a], &[a]],
Outcome::winner(0, 2),
&GameOptions {
p_draw: 1.5,
convergence: ConvergenceOptions::default(),
},
)
.unwrap_err();
assert!(matches!(err, InferenceError::InvalidProbability { .. }));
}
#[test]
fn game_ranked_rejects_mismatched_ranks() {
let a = R::new(Gaussian::default(), 1.0, ConstantDrift(0.0));
let err = Game::<i64, _>::ranked(
&[&[a], &[a]],
Outcome::ranking([0, 1, 2]),
&GameOptions::default(),
)
.unwrap_err();
assert!(matches!(err, InferenceError::MismatchedShape { .. }));
}
#[test]
fn game_free_for_all_three_players() {
let a = default_rating();
let b = default_rating();
let c = default_rating();
let g = Game::<i64, _>::free_for_all(
&[&a, &b, &c],
Outcome::ranking([0, 1, 2]),
&GameOptions::default(),
)
.unwrap();
let p = g.posteriors();
assert_eq!(p.len(), 3);
assert!(p[0][0].mu() > p[1][0].mu());
assert!(p[1][0].mu() > p[2][0].mu());
}
#[test]
fn game_log_evidence_is_finite() {
let a = default_rating();
let b = default_rating();
let g = Game::<i64, _>::ranked(
&[&[a], &[b]],
Outcome::winner(0, 2),
&GameOptions::default(),
)
.unwrap();
assert!(g.log_evidence().is_finite());
assert!(g.log_evidence() < 0.0);
}

54
tests/record_winner.rs Normal file
View File

@@ -0,0 +1,54 @@
use trueskill_tt::{ConstantDrift, ConvergenceOptions, History};
#[test]
fn record_winner_builds_history() {
let mut h = History::builder()
.mu(25.0)
.sigma(25.0 / 3.0)
.beta(25.0 / 6.0)
.drift(ConstantDrift(25.0 / 300.0))
.convergence(ConvergenceOptions {
max_iter: 30,
epsilon: 1e-6,
})
.build();
h.record_winner(&"alice", &"bob", 1).unwrap();
h.converge().unwrap();
let a_idx = h.lookup(&"alice").unwrap();
let b_idx = h.lookup(&"bob").unwrap();
assert_ne!(a_idx, b_idx);
}
#[test]
fn intern_is_idempotent() {
let mut h: History = History::builder().build();
let a1 = h.intern(&"alice");
let a2 = h.intern(&"alice");
assert_eq!(a1, a2);
}
#[test]
fn lookup_returns_none_for_missing() {
let h: History = History::builder().build();
assert!(h.lookup(&"nobody").is_none());
}
#[test]
fn record_draw_with_p_draw_set() {
let mut h = History::builder()
.mu(25.0)
.sigma(25.0 / 3.0)
.beta(25.0 / 6.0)
.drift(ConstantDrift(25.0 / 300.0))
.p_draw(0.25)
.build();
h.record_draw(&"alice", &"bob", 1).unwrap();
h.converge().unwrap();
assert!(h.lookup(&"alice").is_some());
assert!(h.lookup(&"bob").is_some());
}