Adds soft Gaussian-observation evidence on the per-pair diff variable,
enabling continuous score margins as a richer alternative to ranks.
Public API:
- `Outcome::Scored([scores])` (non-breaking enum extension under
`#[non_exhaustive]`).
- `Game::scored(teams, outcome, options)` constructor parallel to
`Game::ranked`.
- `EventBuilder::scores([...])` fluent helper.
- `HistoryBuilder::score_sigma(σ)` knob (default 1.0, validated > 0).
- `GameOptions::score_sigma`.
- `EventKind` re-exported from `lib.rs` (annotated `#[non_exhaustive]`).
- New `InferenceError::InvalidParameter { name, value }` variant.
Internals:
- `MarginFactor` (`factor/margin.rs`): Gaussian observation factor whose
update has a closed form, so it converges in a single EP step; its
cavity-cached log-evidence computation mirrors `TruncFactor`'s.
- `BuiltinFactor::Margin` dispatch arm.
- `DiffFactor` enum in `game.rs` lets `Game::likelihoods` and the new
`likelihoods_scored` share the per-pair link abstraction.
- Per-event `EventKind { Ranked, Scored { score_sigma } }` routed through
`TimeSlice::add_events`, `iteration_direct`, and `log_evidence`.
Tests: 88 lib + 27 integration (4 new in `tests/scored.rs`); existing
golden outputs are byte-identical. Bench: `benches/scored.rs` baseline is
~960µs for 60 events over a 20-player pool with default convergence settings.
Plan: docs/superpowers/plans/2026-04-27-t4-margin-factor.md
Spec item marked Done.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
169 lines · 4.6 KiB · Rust
//! Factor graph machinery for within-game inference.
|
|
|
|
use crate::gaussian::Gaussian;
|
|
|
|
/// Handle naming one variable inside a `VarStore`.
///
/// The wrapped `u32` is the variable's slot in the owning store. An id is
/// therefore only meaningful for the store that issued it. Each variable
/// holds its current Gaussian marginal and belongs to exactly one `VarStore`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct VarId(pub u32);
|
|
|
|
/// Flat storage of variable marginals.
|
|
///
|
|
/// Variables are allocated by `alloc()` and accessed by `VarId`. The store is
|
|
/// reused across `Game::ranked_with_arena` calls (it lives in the `ScratchArena`); call
|
|
/// `clear()` before reuse.
|
|
#[derive(Debug, Default)]
|
|
pub struct VarStore {
|
|
pub(crate) marginals: Vec<Gaussian>,
|
|
}
|
|
|
|
impl VarStore {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
pub fn clear(&mut self) {
|
|
self.marginals.clear();
|
|
}
|
|
|
|
pub fn len(&self) -> usize {
|
|
self.marginals.len()
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.marginals.is_empty()
|
|
}
|
|
|
|
pub fn alloc(&mut self, init: Gaussian) -> VarId {
|
|
let id = VarId(self.marginals.len() as u32);
|
|
self.marginals.push(init);
|
|
id
|
|
}
|
|
|
|
pub fn get(&self, id: VarId) -> Gaussian {
|
|
self.marginals[id.0 as usize]
|
|
}
|
|
|
|
pub fn set(&mut self, id: VarId, g: Gaussian) {
|
|
self.marginals[id.0 as usize] = g;
|
|
}
|
|
}
|
|
|
|
/// A factor in the EP graph.
|
|
///
|
|
/// Factors hold their own outgoing messages and propagate them by reading
|
|
/// connected variable marginals from a `VarStore` and writing back updated
|
|
/// marginals.
|
|
pub trait Factor: Send + Sync {
|
|
/// Update outgoing messages and write back to the var store.
|
|
///
|
|
/// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this
|
|
/// propagation. Used by the `Schedule` to detect convergence.
|
|
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64);
|
|
|
|
/// Optional log-evidence contribution. Default 0.0 (no contribution).
|
|
fn log_evidence(&self, _vars: &VarStore) -> f64 {
|
|
0.0
|
|
}
|
|
}
|
|
|
|
/// Enum dispatcher for the built-in factor types.
|
|
///
|
|
/// Using an enum instead of `Box<dyn Factor>` keeps factor data inline and
|
|
/// avoids virtual-call overhead in the hot inference loop.
|
|
#[derive(Debug)]
|
|
pub enum BuiltinFactor {
|
|
TeamSum(team_sum::TeamSumFactor),
|
|
RankDiff(rank_diff::RankDiffFactor),
|
|
Trunc(trunc::TruncFactor),
|
|
Margin(margin::MarginFactor),
|
|
}
|
|
|
|
impl Factor for BuiltinFactor {
|
|
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
|
|
match self {
|
|
Self::TeamSum(f) => f.propagate(vars),
|
|
Self::RankDiff(f) => f.propagate(vars),
|
|
Self::Trunc(f) => f.propagate(vars),
|
|
Self::Margin(f) => f.propagate(vars),
|
|
}
|
|
}
|
|
|
|
fn log_evidence(&self, vars: &VarStore) -> f64 {
|
|
match self {
|
|
Self::Trunc(f) => f.log_evidence(vars),
|
|
Self::Margin(f) => f.log_evidence(vars),
|
|
_ => 0.0,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub mod margin;
|
|
pub mod rank_diff;
|
|
pub mod team_sum;
|
|
pub mod trunc;
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::N_INF;

    #[test]
    fn alloc_assigns_sequential_ids() {
        let mut store = VarStore::new();
        let ids: Vec<VarId> = (0..3).map(|_| store.alloc(N_INF)).collect();
        assert_eq!(ids, [VarId(0), VarId(1), VarId(2)]);
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn get_returns_initial_value() {
        let initial = Gaussian::from_ms(2.5, 1.0);
        let mut store = VarStore::new();
        let handle = store.alloc(initial);
        assert_eq!(store.get(handle), initial);
    }

    #[test]
    fn set_updates_value() {
        let mut store = VarStore::new();
        let handle = store.alloc(N_INF);
        let replacement = Gaussian::from_ms(3.0, 0.5);
        store.set(handle, replacement);
        assert_eq!(store.get(handle), replacement);
    }

    #[test]
    fn clear_resets_length_keeping_capacity() {
        let mut store = VarStore::new();
        for _ in 0..2 {
            store.alloc(N_INF);
        }
        let capacity_before = store.marginals.capacity();
        store.clear();
        assert_eq!(store.len(), 0);
        assert_eq!(store.marginals.capacity(), capacity_before);
    }

    // Assumes `Gaussian::from_ms` takes (mean, sigma) and
    // `MarginFactor::new` takes (var, observed margin, sigma); TODO confirm.
    // Under that reading the expected constants are the conjugate-Gaussian
    // closed forms for a N(0, 6²) prior observed at 5.0 with sigma 1.0:
    // posterior mean 5·36/37 = 180/37 and log N(5; 0, 37) ≈ -3.0622.
    #[test]
    fn builtin_factor_dispatches_to_margin() {
        use super::margin::MarginFactor;

        let mut vars = VarStore::new();
        let diff = vars.alloc(Gaussian::from_ms(0.0, 6.0));
        let mut factor = BuiltinFactor::Margin(MarginFactor::new(diff, 5.0, 1.0));

        factor.propagate(&mut vars);

        let posterior_mu = vars.get(diff).mu();
        assert!((posterior_mu - 4.864864864864865).abs() < 1e-12);

        let log_z = factor.log_evidence(&vars);
        assert!((log_z - (-3.062235327364623)).abs() < 1e-10);
    }
}
|