8 Commits

Author SHA1 Message Date
logaritmisk dbce69f350 test(game): integration tests for ConvergenceOptions behavior
Two end-to-end tests on a 4-team ranked game:
- max_iter=1 produces measurably different posteriors than the default,
  proving run_chain reads convergence.max_iter
- alpha=0.5 with extra iterations reaches the same fixed point as
  alpha=1.0, proving damping doesn't break convergence on benign graphs
2026-05-08 15:13:23 +02:00
logaritmisk 0705986929 feat(game): plumb ConvergenceOptions through to run_chain
Game and OwnedGame gain a convergence: ConvergenceOptions field set at
construction. Game::{ranked,scored} forward options.convergence into
OwnedGame::{new,new_scored} (previously dropped on the floor).
{ranked,scored}_with_arena take it as a parameter. run_chain reads
self.convergence.{epsilon, max_iter, alpha} instead of hardcoded
1e-6 / 10 / undamped. DiffFactor::propagate gains an alpha parameter
and dispatches into Trunc/MarginFactor::propagate_with_alpha.

In-tree callsites in src/time_slice.rs and src/history.rs pass
ConvergenceOptions::default(). Pre-existing T2 fallout in tests,
benches, and the atp example (struct literals missing the new alpha
field) is fixed by adding alpha: 1.0 so the workspace builds clean.
Default alpha is 1.0, so all 96 lib + 27 integration test goldens
remain bit-equal.
2026-05-08 15:10:35 +02:00
logaritmisk aacaa60baa feat(factor): add MarginFactor::propagate_with_alpha for EP damping
Mirrors TruncFactor: inherent damped-propagate method, trait impl
delegates with α=1.0. Existing goldens unchanged because cavity*new_msg
equals the previous marginal write when α=1.0.
2026-05-08 15:03:45 +02:00
logaritmisk fcfe0ffe37 feat(factor): add TruncFactor::propagate_with_alpha for EP damping
Inherent method that applies α-damping to the outgoing message via
Gaussian::damp_natural. The Factor trait impl delegates with α=1.0,
preserving today's behavior bit-equal. Variable write switched from
`trunc` to `cavity * damped` — algebraically identical when α=1.0
(cavity * new_msg = trunc by construction); reflects partial-update
math when α<1.0.
2026-05-08 15:02:09 +02:00
logaritmisk 0fa4e7d277 feat(convergence): add ConvergenceOptions::alpha damping field
Adds an EP damping coefficient defaulting to 1.0 (undamped). Will be
read by run_chain in a follow-up commit. By itself this commit changes
no behavior — existing constructors using ..Default::default() pick up
the new field automatically.
2026-05-08 15:00:34 +02:00
logaritmisk 0dd7dab266 feat(gaussian): add damp_natural helper for EP damping
Computes α·new + (1−α)·self in natural-parameter space. Will be used
by TruncFactor and MarginFactor to support opt-in EP damping via
ConvergenceOptions::alpha.
2026-05-08 14:59:18 +02:00
logaritmisk 43cc6d82f9 docs: implementation plan for game-local Damped EP
Six tasks: Gaussian::damp_natural helper, ConvergenceOptions::alpha
field, TruncFactor and MarginFactor propagate_with_alpha pair, DiffFactor
+ Game integration (the big task — must land atomically), and
end-to-end tests for max_iter and alpha behavior.
2026-05-08 14:57:41 +02:00
logaritmisk 48a6049dc6 docs: spec for game-local Damped EP
Smallest-scope realisation of spec §"Built-in schedules" Damped: a
ConvergenceOptions::alpha field plumbed through run_chain to a new
Gaussian::damp_natural helper applied inside TruncFactor and
MarginFactor's propagate. alpha=1.0 default keeps every existing
golden bit-equal; alpha<1.0 stabilises oscillating fixed-point loops
on hard graphs.

Defers Schedule trait integration, nat-param convergence switch,
oscillation auto-detect, Residual/OneShot, and Synergy/ScoreFactor —
each gets its own future plan.
2026-05-08 14:52:36 +02:00
14 changed files with 1825 additions and 49 deletions
+1
View File
@@ -51,6 +51,7 @@ fn build_history_1v1(
.convergence(ConvergenceOptions { .convergence(ConvergenceOptions {
max_iter: 30, max_iter: 30,
epsilon: 1e-6, epsilon: 1e-6,
alpha: 1.0,
}) })
.build(); .build();
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,320 @@
# Damped EP — Game-Local Damping
## Summary
Add an opt-in EP damping knob to within-game inference. Users set
`ConvergenceOptions::alpha < 1.0` to damp message updates and stabilise
oscillating fixed-point loops on hard graphs. `alpha = 1.0` (the default)
is bit-equal to today.
This is the smallest-scope realisation of the spec's `Damped` schedule:
**game-local**, not plumbed through the `Schedule` trait. The `Schedule`
trait is shipped infrastructure that `run_chain` does not currently call;
wiring `Schedule` into game inference is a separate future task. This
design touches only what the user can actually reach via `GameOptions`.
## Scope
### What ships
1. New field `ConvergenceOptions::alpha: f64` (default `1.0`).
2. `run_chain` reads `options.convergence.{epsilon, max_iter, alpha}`
instead of the hardcoded `1e-6` / `10` / undamped — fixes the existing
latent bug where the first two were already on `GameOptions` but never
read by inference.
3. `Gaussian::damp_natural(self, new, alpha) -> Gaussian` — public helper
computing `α·new + (1−α)·self` in natural-parameter space.
4. `TruncFactor` and `MarginFactor` gain inherent
`propagate_with_alpha(&mut self, vars, alpha) -> (f64, f64)`. Their
`Factor::propagate` impls become one-line delegations passing
`alpha = 1.0`.
5. `DiffFactor::propagate` (game-private enum at `src/game.rs:20-54`)
gains an `alpha: f64` parameter and dispatches into the underlying
factor's `propagate_with_alpha`.
### What does not ship
- No `Damped` impl in `src/schedule.rs`. The `Schedule` trait stays as
it is; integration with `run_chain` is a separate task.
- No nat-param convergence switch. `(|Δmu|, |Δsigma|)` stays the
delta basis (matches today). The spec's "stopping in natural-param
space" wants its own design pass and test re-tuning.
- No oscillation auto-detect. `alpha` is user-supplied and constant for
the duration of a `run_chain` call.
- No `Residual`, `OneShot`, or `SynergyFactor` / `ScoreFactor` work —
separate future plans.
## Design
### `ConvergenceOptions::alpha`
```rust
// src/convergence.rs
#[derive(Clone, Copy, Debug)]
pub struct ConvergenceOptions {
pub max_iter: usize,
pub epsilon: f64,
pub alpha: f64,
}
impl Default for ConvergenceOptions {
fn default() -> Self {
Self {
max_iter: crate::ITERATIONS,
epsilon: crate::EPSILON,
alpha: 1.0,
}
}
}
```
`alpha = 1.0` ⇒ undamped (bit-equal to today). Recommended starting
point if a graph oscillates: `0.5``0.7`. Values approaching `0.0` make
each step tinier and slow convergence; `alpha = 0.0` is degenerate
(factor never updates). Validation in `run_chain`:
```rust
debug_assert!(
opts.convergence.alpha > 0.0 && opts.convergence.alpha <= 1.0,
"convergence alpha must be in (0.0, 1.0]"
);
```
### `Gaussian::damp_natural`
```rust
impl Gaussian {
/// EP damping in natural-parameter space: `α·new + (1−α)·self`.
///
/// Used by within-game schedules to stabilise oscillating fixed-point
/// loops on hard graphs. `alpha = 1.0` returns `new` exactly;
/// `alpha < 1.0` shrinks each per-step update.
pub fn damp_natural(self, new: Gaussian, alpha: f64) -> Gaussian {
Gaussian::from_natural(
alpha * new.pi() + (1.0 - alpha) * self.pi(),
alpha * new.tau() + (1.0 - alpha) * self.tau(),
)
}
}
```
Public on `Gaussian`. The name encodes the WHY (EP damping); the doc
comment fixes the math. No new dependency.
The existing `Mul<f64> for Gaussian` is **distribution scaling**
(`sigma → sigma·|scalar|`), not nat-param interpolation, so it can't be
reused here.
### `TruncFactor::propagate_with_alpha`
```rust
impl TruncFactor {
pub(crate) fn propagate_with_alpha(
&mut self,
vars: &mut VarStore,
alpha: f64,
) -> (f64, f64) {
let marginal = vars.get(self.diff);
let cavity = marginal / self.msg;
if self.evidence_cached.is_none() {
self.evidence_cached = Some(cavity_evidence(cavity, self.margin, self.tie));
}
let trunc = approx(cavity, self.margin, self.tie);
let new_msg = trunc / cavity;
let damped = self.msg.damp_natural(new_msg, alpha);
let old_msg = self.msg;
self.msg = damped;
// marginal_new = cavity * stored_msg (NOT cavity * new_msg with damping)
vars.set(self.diff, cavity * damped);
old_msg.delta(damped)
}
}
impl Factor for TruncFactor {
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
self.propagate_with_alpha(vars, 1.0)
}
}
```
Two important points:
- The variable receives `cavity * damped` (i.e. `cavity * self.msg`),
not `trunc`. With `alpha = 1.0` these are equal (since
`cavity * new_msg = trunc` by construction), so today's behaviour is
preserved bit-equal. With `alpha < 1.0` the marginal reflects the
partially-applied update.
- The reported delta is `old_msg.delta(damped)` — delta of the actually
stored message, not of the raw `new_msg`. This is the textbook EP
damping convention: the convergence loop measures the trajectory it
is actually walking.
`MarginFactor` follows the same shape, with its own
`propagate_with_alpha` body (the existing `propagate` math, with the
`damp_natural` step inserted in the same place and the var write
switched to `cavity * damped`).
### `DiffFactor::propagate` signature
```rust
// src/game.rs
impl DiffFactor {
pub(crate) fn propagate(
&mut self,
vars: &mut VarStore,
alpha: f64,
) -> (f64, f64) {
match self {
Self::Trunc(f) => f.propagate_with_alpha(vars, alpha),
Self::Margin(f) => f.propagate_with_alpha(vars, alpha),
}
}
}
```
`DiffFactor` is `pub(crate)` and only used inside `run_chain`, so the
signature change has no public-API impact.
### `run_chain` changes
Inside `Game::run_chain` (`src/game.rs:236-348`):
1. Capture `let alpha = opts.convergence.alpha;` once at the top
(avoids repeated `opts.convergence.alpha` lookups in the hot loop).
2. Replace the loop guard
`while tuple_gt(step, 1e-6) && iter < 10`
with
`while tuple_gt(step, opts.convergence.epsilon) && iter < opts.convergence.max_iter`.
3. Replace each `lf.propagate(&mut arena.vars)` call site (three of
them: forward sweep, backward sweep, `n_diffs == 1` special case)
with `lf.propagate(&mut arena.vars, alpha)`.
The threading of `opts: &GameOptions` into `run_chain` is the only
new caller obligation. Today `run_chain` doesn't take `opts`; the two
callers (`likelihoods`, `likelihoods_scored`) currently invoke it
without options. Both will need to pass the options through. The
`Game<'a, T, D>` struct does not currently hold `GameOptions`; the
options are constructed and discarded around the call to
`{ranked,scored}_with_arena`. So:
- `Game::ranked_with_arena` and `Game::scored_with_arena` already
receive `p_draw` / `score_sigma` as scalar params; we extend them to
accept `&ConvergenceOptions` (or the full `&GameOptions`) too.
- `likelihoods` / `likelihoods_scored` either store the options on
`Game` or accept them as method parameters and forward to
`run_chain`.
The simplest plumbing: store `convergence: ConvergenceOptions` as a
field on `Game<'a, T, D>` and `OwnedGame<T, D>` populated at
construction time. Then `run_chain` can read it from `&self`.
## Convergence semantics
With `alpha < 1.0` the per-step update shrinks; convergence may take
more iterations to reach the same `epsilon` threshold. Users who damp
should also raise `max_iter` accordingly. Documentation example:
```rust
let mut opts = GameOptions::default();
opts.convergence.alpha = 0.5;
opts.convergence.max_iter = 30;
```
## Testing strategy
### Regression net (no new file)
The existing 88 lib tests and 27 integration tests are the bit-equal
regression net. With `alpha = 1.0` (the default), every assertion must
pass unchanged. If any test fails, the damping path leaked into the
undamped trajectory.
### New tests
1. **`Gaussian::damp_natural` arithmetic**
(`src/gaussian.rs` test mod):
- `α = 1.0` returns `new` exactly (bit-equal `pi` and `tau`).
- `α = 0.0` returns `self` exactly.
- `α = 0.5`: pi and tau are exact midpoints in nat-param space.
- Three asserts, no new file.
2. **`TruncFactor::propagate_with_alpha` shrinks the step**
(`src/factor/trunc.rs` test mod):
- Set up a TruncFactor step. Run `propagate_with_alpha(α=1.0)` once,
record `delta_undamped` and the resulting `self.msg`.
- Reset to a fresh factor at the same starting state. Run
`propagate_with_alpha(α=0.5)` once, record `delta_damped` and
`damped_msg`.
- Assert: `damped_msg.pi()` equals `0.5 * undamped_msg.pi() + 0.5 * initial_msg.pi()` within 1e-12 (and same for `tau`).
- Assert: `delta_damped.0 <= delta_undamped.0` (mu-delta is no larger; the relationship is monotone in `α` but not strictly `0.5×` for the `delta()` function which is `(|Δmu|, |Δsigma|)`).
3. **`MarginFactor::propagate_with_alpha` parity**
(`src/factor/margin.rs` test mod):
- Same shape as #2, on a `MarginFactor` step.
4. **`run_chain` honours `ConvergenceOptions::max_iter`**
(in an existing or new game-level test):
- Construct a 4-team ranked game that normally converges in ~5 iterations.
- Set `opts.convergence.max_iter = 1`. Assert the per-iteration
`step` returned (or observable indirectly via posterior delta vs.
the converged answer) is non-zero — i.e. the loop stopped early.
- Set `opts.convergence.max_iter = 30`. Assert posteriors match the
baseline within `epsilon`.
5. **Damping default is `1.0` and produces bit-equal output**
(smoke test, can be a single assertion in an existing test):
- `assert_eq!(ConvergenceOptions::default().alpha, 1.0);`
- Existing goldens prove the bit-equality.
No oscillation-stabilisation test (would require constructing a
pathological graph specifically to oscillate; out of scope for a
minimal ship).
## Verification gates
Per task:
```bash
cargo +nightly fmt
cargo clippy --all-targets -- -D warnings
cargo test --lib
cargo test
```
All must succeed. Test count grows by exactly the new tests above
(roughly +58 lib tests).
## Risks
- **Marginal-update change is subtle.** Switching the variable write
from `trunc` to `cavity * damped` is intentionally a no-op when
`alpha = 1.0` (since `cavity * new_msg = trunc`), but it changes the
arithmetic path. If `Gaussian` arithmetic has any non-associativity
in floating-point that the old form happened to dodge, goldens could
shift by 1 ULP. Mitigation: TDD — write the regression test (run all
existing tests with `alpha = 1.0`) **first**, before changing the
variable-write line.
- **`run_chain` signature change ripples to two callers.** Trivial
but must be done atomically with the field addition on `Game` /
`OwnedGame`.
- **`alpha` validation only in debug builds.** A release build will
silently accept `alpha = 0.0` or `alpha > 1.0` and produce nonsense.
This matches the existing pattern (`debug_assert!` for input
validation in `Game::ranked_with_arena`); upgrading to `Result` is
out of scope.
## Out-of-scope follow-ups (logged for future plans)
- Wire `Schedule` into `run_chain` (so `Damped` lands as a real
`Schedule` impl alongside `EpsilonOrMax`).
- Switch convergence check to `(|Δpi|, |Δtau|)` per spec
§"Stopping in natural-param space".
- Oscillation auto-detect (engage `alpha < 1.0` only after N
non-monotone steps).
- `Residual` schedule (priority queue).
- `SynergyFactor`, `ScoreFactor` (new EP factor types).
+1
View File
@@ -48,6 +48,7 @@ fn main() {
.convergence(trueskill_tt::ConvergenceOptions { .convergence(trueskill_tt::ConvergenceOptions {
max_iter: 10, max_iter: 10,
epsilon: 0.01, epsilon: 0.01,
alpha: 1.0,
}) })
.build(); .build();
+17
View File
@@ -8,6 +8,11 @@ use smallvec::SmallVec;
pub struct ConvergenceOptions { pub struct ConvergenceOptions {
pub max_iter: usize, pub max_iter: usize,
pub epsilon: f64, pub epsilon: f64,
/// EP damping factor in natural-parameter space: each per-factor
/// update writes `α·new + (1−α)·old`. `1.0` is undamped (default);
/// `< 1.0` stabilises oscillating fixed-point loops at the cost of
/// more iterations. Must be in `(0.0, 1.0]`.
pub alpha: f64,
} }
impl Default for ConvergenceOptions { impl Default for ConvergenceOptions {
@@ -15,6 +20,7 @@ impl Default for ConvergenceOptions {
Self { Self {
max_iter: crate::ITERATIONS, max_iter: crate::ITERATIONS,
epsilon: crate::EPSILON, epsilon: crate::EPSILON,
alpha: 1.0,
} }
} }
} }
@@ -29,3 +35,14 @@ pub struct ConvergenceReport {
pub per_iteration_time: SmallVec<[Duration; 32]>, pub per_iteration_time: SmallVec<[Duration; 32]>,
pub slices_skipped: usize, pub slices_skipped: usize,
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_alpha_is_one_for_undamped_behavior() {
let opts = ConvergenceOptions::default();
assert_eq!(opts.alpha, 1.0);
}
}
+60 -6
View File
@@ -32,8 +32,11 @@ impl MarginFactor {
} }
} }
impl Factor for MarginFactor { impl MarginFactor {
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) { /// Propagate this factor's message, optionally damping the update in
/// natural-parameter space. `alpha = 1.0` matches `Factor::propagate`
/// exactly; `alpha < 1.0` writes `α·new_msg + (1−α)·old_msg`.
pub(crate) fn propagate_with_alpha(&mut self, vars: &mut VarStore, alpha: f64) -> (f64, f64) {
let marginal = vars.get(self.diff); let marginal = vars.get(self.diff);
let cavity = marginal / self.msg; let cavity = marginal / self.msg;
@@ -42,12 +45,18 @@ impl Factor for MarginFactor {
} }
let new_msg = Gaussian::from_ms(self.m_obs, self.sigma); let new_msg = Gaussian::from_ms(self.m_obs, self.sigma);
let new_marginal = cavity * new_msg; let damped = self.msg.damp_natural(new_msg, alpha);
let old_msg = self.msg; let old_msg = self.msg;
self.msg = new_msg; self.msg = damped;
vars.set(self.diff, new_marginal); vars.set(self.diff, cavity * damped);
old_msg.delta(new_msg) old_msg.delta(damped)
}
}
impl Factor for MarginFactor {
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
self.propagate_with_alpha(vars, 1.0)
} }
fn log_evidence(&self, _vars: &VarStore) -> f64 { fn log_evidence(&self, _vars: &VarStore) -> f64 {
@@ -120,4 +129,49 @@ mod tests {
let logz = f.log_evidence(&vars); let logz = f.log_evidence(&vars);
assert!((logz - (-3.062235327364623)).abs() < 1e-10); assert!((logz - (-3.062235327364623)).abs() < 1e-10);
} }
#[test]
fn propagate_with_alpha_one_matches_undamped_propagate() {
let mut vars_a = VarStore::new();
let diff_a = vars_a.alloc(Gaussian::from_ms(0.0, 6.0));
let mut f_a = MarginFactor::new(diff_a, 5.0, 1.0);
let delta_a = f_a.propagate(&mut vars_a);
let result_a = vars_a.get(diff_a);
let mut vars_b = VarStore::new();
let diff_b = vars_b.alloc(Gaussian::from_ms(0.0, 6.0));
let mut f_b = MarginFactor::new(diff_b, 5.0, 1.0);
let delta_b = f_b.propagate_with_alpha(&mut vars_b, 1.0);
let result_b = vars_b.get(diff_b);
assert_eq!(result_a.pi(), result_b.pi());
assert_eq!(result_a.tau(), result_b.tau());
assert_eq!(delta_a, delta_b);
assert_eq!(f_a.msg.pi(), f_b.msg.pi());
assert_eq!(f_a.msg.tau(), f_b.msg.tau());
}
#[test]
fn propagate_with_alpha_half_blends_msg_in_natural_params() {
// Run undamped to capture (initial_msg, undamped_new_msg).
let mut vars_full = VarStore::new();
let diff_full = vars_full.alloc(Gaussian::from_ms(0.0, 6.0));
let mut f_full = MarginFactor::new(diff_full, 5.0, 1.0);
let initial_msg_pi = f_full.msg.pi();
let initial_msg_tau = f_full.msg.tau();
f_full.propagate(&mut vars_full);
let undamped_msg_pi = f_full.msg.pi();
let undamped_msg_tau = f_full.msg.tau();
// Run damped at α = 0.5 from the same initial state.
let mut vars_half = VarStore::new();
let diff_half = vars_half.alloc(Gaussian::from_ms(0.0, 6.0));
let mut f_half = MarginFactor::new(diff_half, 5.0, 1.0);
f_half.propagate_with_alpha(&mut vars_half, 0.5);
let expected_pi = 0.5 * undamped_msg_pi + 0.5 * initial_msg_pi;
let expected_tau = 0.5 * undamped_msg_tau + 0.5 * initial_msg_tau;
assert!((f_half.msg.pi() - expected_pi).abs() < 1e-12);
assert!((f_half.msg.tau() - expected_tau).abs() < 1e-12);
}
} }
+64 -11
View File
@@ -33,29 +33,37 @@ impl TruncFactor {
} }
} }
impl Factor for TruncFactor { impl TruncFactor {
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) { /// Propagate this factor's message, optionally damping the update in
/// natural-parameter space. `alpha = 1.0` matches `Factor::propagate`
/// exactly; `alpha < 1.0` writes `α·new_msg + (1−α)·old_msg`.
pub(crate) fn propagate_with_alpha(&mut self, vars: &mut VarStore, alpha: f64) -> (f64, f64) {
let marginal = vars.get(self.diff); let marginal = vars.get(self.diff);
// Cavity: marginal divided by our outgoing message.
let cavity = marginal / self.msg; let cavity = marginal / self.msg;
// First-time-only: cache the evidence contribution from the cavity.
if self.evidence_cached.is_none() { if self.evidence_cached.is_none() {
self.evidence_cached = Some(cavity_evidence(cavity, self.margin, self.tie)); self.evidence_cached = Some(cavity_evidence(cavity, self.margin, self.tie));
} }
// Apply the truncation approximation to the cavity.
let trunc = approx(cavity, self.margin, self.tie); let trunc = approx(cavity, self.margin, self.tie);
// New outgoing message such that cavity * new_msg = trunc.
let new_msg = trunc / cavity; let new_msg = trunc / cavity;
let damped = self.msg.damp_natural(new_msg, alpha);
let old_msg = self.msg; let old_msg = self.msg;
self.msg = new_msg; self.msg = damped;
// Update the marginal: marginal_new = cavity * new_msg = trunc. // marginal_new = cavity * stored_msg. With alpha = 1.0 this equals
vars.set(self.diff, trunc); // `trunc` (since cavity * new_msg = trunc by construction); with
// alpha < 1.0 it reflects the partially-applied update.
vars.set(self.diff, cavity * damped);
old_msg.delta(new_msg) old_msg.delta(damped)
}
}
impl Factor for TruncFactor {
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
self.propagate_with_alpha(vars, 1.0)
} }
fn log_evidence(&self, _vars: &VarStore) -> f64 { fn log_evidence(&self, _vars: &VarStore) -> f64 {
@@ -127,4 +135,49 @@ mod tests {
let ev = f.evidence_cached.unwrap(); let ev = f.evidence_cached.unwrap();
assert!(ev > 0.35 && ev < 0.42); assert!(ev > 0.35 && ev < 0.42);
} }
#[test]
fn propagate_with_alpha_one_matches_undamped_propagate() {
let mut vars_a = VarStore::new();
let diff_a = vars_a.alloc(Gaussian::from_ms(2.0, 3.0));
let mut f_a = TruncFactor::new(diff_a, 0.0, false);
let delta_a = f_a.propagate(&mut vars_a);
let result_a = vars_a.get(diff_a);
let mut vars_b = VarStore::new();
let diff_b = vars_b.alloc(Gaussian::from_ms(2.0, 3.0));
let mut f_b = TruncFactor::new(diff_b, 0.0, false);
let delta_b = f_b.propagate_with_alpha(&mut vars_b, 1.0);
let result_b = vars_b.get(diff_b);
assert_eq!(result_a.pi(), result_b.pi());
assert_eq!(result_a.tau(), result_b.tau());
assert_eq!(delta_a, delta_b);
assert_eq!(f_a.msg.pi(), f_b.msg.pi());
assert_eq!(f_a.msg.tau(), f_b.msg.tau());
}
#[test]
fn propagate_with_alpha_half_blends_msg_in_natural_params() {
// Run undamped to capture (initial_msg, undamped_new_msg).
let mut vars_full = VarStore::new();
let diff_full = vars_full.alloc(Gaussian::from_ms(2.0, 3.0));
let mut f_full = TruncFactor::new(diff_full, 0.0, false);
let initial_msg_pi = f_full.msg.pi();
let initial_msg_tau = f_full.msg.tau();
f_full.propagate(&mut vars_full);
let undamped_msg_pi = f_full.msg.pi();
let undamped_msg_tau = f_full.msg.tau();
// Run damped at α = 0.5 from the same initial state.
let mut vars_half = VarStore::new();
let diff_half = vars_half.alloc(Gaussian::from_ms(2.0, 3.0));
let mut f_half = TruncFactor::new(diff_half, 0.0, false);
f_half.propagate_with_alpha(&mut vars_half, 0.5);
let expected_pi = 0.5 * undamped_msg_pi + 0.5 * initial_msg_pi;
let expected_tau = 0.5 * undamped_msg_tau + 0.5 * initial_msg_tau;
assert!((f_half.msg.pi() - expected_pi).abs() < 1e-12);
assert!((f_half.msg.tau() - expected_tau).abs() < 1e-12);
}
} }
+189 -16
View File
@@ -44,11 +44,14 @@ impl DiffFactor {
} }
} }
pub(crate) fn propagate(&mut self, vars: &mut crate::factor::VarStore) -> (f64, f64) { pub(crate) fn propagate(
use crate::factor::Factor; &mut self,
vars: &mut crate::factor::VarStore,
alpha: f64,
) -> (f64, f64) {
match self { match self {
Self::Trunc(f) => f.propagate(vars), Self::Trunc(f) => f.propagate_with_alpha(vars, alpha),
Self::Margin(f) => f.propagate(vars), Self::Margin(f) => f.propagate_with_alpha(vars, alpha),
} }
} }
} }
@@ -87,6 +90,7 @@ pub struct OwnedGame<T: Time, D: Drift<T>> {
result: Vec<f64>, result: Vec<f64>,
weights: Vec<Vec<f64>>, weights: Vec<Vec<f64>>,
p_draw: f64, p_draw: f64,
pub(crate) convergence: crate::ConvergenceOptions,
pub(crate) likelihoods: Vec<Vec<Gaussian>>, pub(crate) likelihoods: Vec<Vec<Gaussian>>,
pub(crate) evidence: f64, pub(crate) evidence: f64,
} }
@@ -97,9 +101,17 @@ impl<T: Time, D: Drift<T>> OwnedGame<T, D> {
result: Vec<f64>, result: Vec<f64>,
weights: Vec<Vec<f64>>, weights: Vec<Vec<f64>>,
p_draw: f64, p_draw: f64,
convergence: crate::ConvergenceOptions,
) -> Self { ) -> Self {
let mut arena = ScratchArena::new(); let mut arena = ScratchArena::new();
let g = Game::ranked_with_arena(teams.clone(), &result, &weights, p_draw, &mut arena); let g = Game::ranked_with_arena(
teams.clone(),
&result,
&weights,
p_draw,
convergence,
&mut arena,
);
let likelihoods = g.likelihoods; let likelihoods = g.likelihoods;
let evidence = g.evidence; let evidence = g.evidence;
Self { Self {
@@ -107,6 +119,7 @@ impl<T: Time, D: Drift<T>> OwnedGame<T, D> {
result, result,
weights, weights,
p_draw, p_draw,
convergence,
likelihoods, likelihoods,
evidence, evidence,
} }
@@ -117,9 +130,17 @@ impl<T: Time, D: Drift<T>> OwnedGame<T, D> {
scores: Vec<f64>, scores: Vec<f64>,
weights: Vec<Vec<f64>>, weights: Vec<Vec<f64>>,
score_sigma: f64, score_sigma: f64,
convergence: crate::ConvergenceOptions,
) -> Self { ) -> Self {
let mut arena = ScratchArena::new(); let mut arena = ScratchArena::new();
let g = Game::scored_with_arena(teams.clone(), &scores, &weights, score_sigma, &mut arena); let g = Game::scored_with_arena(
teams.clone(),
&scores,
&weights,
score_sigma,
convergence,
&mut arena,
);
let likelihoods = g.likelihoods; let likelihoods = g.likelihoods;
let evidence = g.evidence; let evidence = g.evidence;
Self { Self {
@@ -127,6 +148,7 @@ impl<T: Time, D: Drift<T>> OwnedGame<T, D> {
result: scores, result: scores,
weights, weights,
p_draw: 0.0, p_draw: 0.0,
convergence,
likelihoods, likelihoods,
evidence, evidence,
} }
@@ -151,6 +173,7 @@ pub struct Game<'a, T: Time = i64, D: Drift<T> = crate::drift::ConstantDrift> {
result: &'a [f64], result: &'a [f64],
weights: &'a [Vec<f64>], weights: &'a [Vec<f64>],
p_draw: f64, p_draw: f64,
pub(crate) convergence: crate::ConvergenceOptions,
pub(crate) likelihoods: Vec<Vec<Gaussian>>, pub(crate) likelihoods: Vec<Vec<Gaussian>>,
pub(crate) evidence: f64, pub(crate) evidence: f64,
} }
@@ -161,6 +184,7 @@ impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
result: &'a [f64], result: &'a [f64],
weights: &'a [Vec<f64>], weights: &'a [Vec<f64>],
p_draw: f64, p_draw: f64,
convergence: crate::ConvergenceOptions,
arena: &mut ScratchArena, arena: &mut ScratchArena,
) -> Self { ) -> Self {
debug_assert!( debug_assert!(
@@ -186,12 +210,17 @@ impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
}, },
"draw must be > 0.0 if there are teams with draw" "draw must be > 0.0 if there are teams with draw"
); );
debug_assert!(
convergence.alpha > 0.0 && convergence.alpha <= 1.0,
"convergence alpha must be in (0.0, 1.0]"
);
let mut this = Self { let mut this = Self {
teams, teams,
result, result,
weights, weights,
p_draw, p_draw,
convergence,
likelihoods: Vec::new(), likelihoods: Vec::new(),
evidence: 0.0, evidence: 0.0,
}; };
@@ -205,6 +234,7 @@ impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
scores: &'a [f64], scores: &'a [f64],
weights: &'a [Vec<f64>], weights: &'a [Vec<f64>],
score_sigma: f64, score_sigma: f64,
convergence: crate::ConvergenceOptions,
arena: &mut ScratchArena, arena: &mut ScratchArena,
) -> Self { ) -> Self {
debug_assert!( debug_assert!(
@@ -219,12 +249,17 @@ impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
"weights must have the same dimensions as teams" "weights must have the same dimensions as teams"
); );
debug_assert!(score_sigma > 0.0, "score_sigma must be positive"); debug_assert!(score_sigma > 0.0, "score_sigma must be positive");
debug_assert!(
convergence.alpha > 0.0 && convergence.alpha <= 1.0,
"convergence alpha must be in (0.0, 1.0]"
);
let mut this = Self { let mut this = Self {
teams, teams,
result: scores, result: scores,
weights, weights,
p_draw: 0.0, p_draw: 0.0,
convergence,
likelihoods: Vec::new(), likelihoods: Vec::new(),
evidence: 0.0, evidence: 0.0,
}; };
@@ -239,6 +274,10 @@ impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
{ {
arena.reset(); arena.reset();
let alpha = self.convergence.alpha;
let epsilon = self.convergence.epsilon;
let max_iter = self.convergence.max_iter;
let n_teams = self.teams.len(); let n_teams = self.teams.len();
arena.sort_buf.extend(0..n_teams); arena.sort_buf.extend(0..n_teams);
@@ -267,7 +306,7 @@ impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
let mut step = (f64::INFINITY, f64::INFINITY); let mut step = (f64::INFINITY, f64::INFINITY);
let mut iter = 0; let mut iter = 0;
while tuple_gt(step, 1e-6) && iter < 10 { while tuple_gt(step, epsilon) && iter < max_iter {
step = (0.0_f64, 0.0_f64); step = (0.0_f64, 0.0_f64);
for (e, lf) in links[..n_diffs.saturating_sub(1)].iter_mut().enumerate() { for (e, lf) in links[..n_diffs.saturating_sub(1)].iter_mut().enumerate() {
@@ -275,7 +314,7 @@ impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
let pl = arena.team_prior[e + 1] * arena.lhood_win[e + 1]; let pl = arena.team_prior[e + 1] * arena.lhood_win[e + 1];
let raw = pw - pl; let raw = pw - pl;
arena.vars.set(lf.diff(), raw * lf.msg()); arena.vars.set(lf.diff(), raw * lf.msg());
let d = lf.propagate(&mut arena.vars); let d = lf.propagate(&mut arena.vars, alpha);
step = tuple_max(step, d); step = tuple_max(step, d);
let new_ll = pw - lf.msg(); let new_ll = pw - lf.msg();
@@ -289,7 +328,7 @@ impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
let pl = arena.team_prior[e + 1] * arena.lhood_win[e + 1]; let pl = arena.team_prior[e + 1] * arena.lhood_win[e + 1];
let raw = pw - pl; let raw = pw - pl;
arena.vars.set(lf.diff(), raw * lf.msg()); arena.vars.set(lf.diff(), raw * lf.msg());
let d = lf.propagate(&mut arena.vars); let d = lf.propagate(&mut arena.vars, alpha);
step = tuple_max(step, d); step = tuple_max(step, d);
let new_lw = pl + lf.msg(); let new_lw = pl + lf.msg();
@@ -305,7 +344,7 @@ impl<'a, T: Time, D: Drift<T>> Game<'a, T, D> {
let raw = (arena.team_prior[0] * arena.lhood_lose[0]) let raw = (arena.team_prior[0] * arena.lhood_lose[0])
- (arena.team_prior[1] * arena.lhood_win[1]); - (arena.team_prior[1] * arena.lhood_win[1]);
arena.vars.set(links[0].diff(), raw * links[0].msg()); arena.vars.set(links[0].diff(), raw * links[0].msg());
links[0].propagate(&mut arena.vars); links[0].propagate(&mut arena.vars, alpha);
} }
// Boundary updates: close the chain at both ends. // Boundary updates: close the chain at both ends.
@@ -429,7 +468,13 @@ impl<T: Time, D: Drift<T>> Game<'_, T, D> {
let teams_owned: Vec<Vec<Rating<T, D>>> = teams.iter().map(|t| t.to_vec()).collect(); let teams_owned: Vec<Vec<Rating<T, D>>> = teams.iter().map(|t| t.to_vec()).collect();
let weights: Vec<Vec<f64>> = teams.iter().map(|t| vec![1.0; t.len()]).collect(); let weights: Vec<Vec<f64>> = teams.iter().map(|t| vec![1.0; t.len()]).collect();
Ok(OwnedGame::new(teams_owned, result, weights, options.p_draw)) Ok(OwnedGame::new(
teams_owned,
result,
weights,
options.p_draw,
options.convergence,
))
} }
pub fn scored( pub fn scored(
@@ -465,6 +510,7 @@ impl<T: Time, D: Drift<T>> Game<'_, T, D> {
scores, scores,
weights, weights,
options.score_sigma, options.score_sigma,
options.convergence,
)) ))
} }
@@ -526,6 +572,7 @@ mod tests {
&[0.0, 1.0], &[0.0, 1.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -553,6 +600,7 @@ mod tests {
&[0.0, 1.0], &[0.0, 1.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -572,6 +620,7 @@ mod tests {
&[0.0, 1.0], &[0.0, 1.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
@@ -605,6 +654,7 @@ mod tests {
&[1.0, 2.0, 0.0], &[1.0, 2.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -621,6 +671,7 @@ mod tests {
&[2.0, 1.0, 0.0], &[2.0, 1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -632,7 +683,14 @@ mod tests {
assert_ulps_eq!(b, Gaussian::from_ms(25.000000, 6.238469), epsilon = 1e-6); assert_ulps_eq!(b, Gaussian::from_ms(25.000000, 6.238469), epsilon = 1e-6);
let w = [vec![1.0], vec![1.0], vec![1.0]]; let w = [vec![1.0], vec![1.0], vec![1.0]];
let g = Game::ranked_with_arena(teams, &[1.0, 2.0, 0.0], &w, 0.5, &mut ScratchArena::new()); let g = Game::ranked_with_arena(
teams,
&[1.0, 2.0, 0.0],
&w,
0.5,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(),
);
let p = g.posteriors(); let p = g.posteriors();
let a = p[0][0]; let a = p[0][0];
@@ -664,6 +722,7 @@ mod tests {
&[0.0, 0.0], &[0.0, 0.0],
&w, &w,
0.25, 0.25,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -691,6 +750,7 @@ mod tests {
&[0.0, 0.0], &[0.0, 0.0],
&w, &w,
0.25, 0.25,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -726,6 +786,7 @@ mod tests {
&[0.0, 0.0, 0.0], &[0.0, 0.0, 0.0],
&w, &w,
0.25, 0.25,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -762,6 +823,7 @@ mod tests {
&[0.0, 0.0, 0.0], &[0.0, 0.0, 0.0],
&w, &w,
0.25, 0.25,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -813,6 +875,7 @@ mod tests {
&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0],
&w, &w,
0.25, 0.25,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -846,6 +909,7 @@ mod tests {
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -870,6 +934,7 @@ mod tests {
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -894,6 +959,7 @@ mod tests {
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -921,6 +987,7 @@ mod tests {
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -948,6 +1015,7 @@ mod tests {
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -967,8 +1035,8 @@ mod tests {
let mut t = DiffFactor::Trunc(TruncFactor::new(dt, 0.0, false)); let mut t = DiffFactor::Trunc(TruncFactor::new(dt, 0.0, false));
let mut m = DiffFactor::Margin(MarginFactor::new(dm, 5.0, 1.0)); let mut m = DiffFactor::Margin(MarginFactor::new(dm, 5.0, 1.0));
let _ = t.propagate(&mut vars); let _ = t.propagate(&mut vars, 1.0);
let _ = m.propagate(&mut vars); let _ = m.propagate(&mut vars, 1.0);
// Smoke: both diffs got written; their msgs are non-N_INF. // Smoke: both diffs got written; their msgs are non-N_INF.
assert!(t.msg().pi() > 0.0); assert!(t.msg().pi() > 0.0);
@@ -989,7 +1057,11 @@ mod tests {
let weights = [vec![1.0], vec![1.0]]; let weights = [vec![1.0], vec![1.0]];
let mut arena = ScratchArena::new(); let mut arena = ScratchArena::new();
let g = Game::scored_with_arena( let g = Game::scored_with_arena(
teams, &result, &weights, 1.0, // score_sigma teams,
&result,
&weights,
1.0,
crate::ConvergenceOptions::default(),
&mut arena, &mut arena,
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -1008,7 +1080,8 @@ mod tests {
vec![vec![prior], vec![prior]], vec![vec![prior], vec![prior]],
&result, &result,
&weights, &weights,
0.1, // tighter score_sigma 0.1,
crate::ConvergenceOptions::default(),
&mut arena2, &mut arena2,
); );
let p_tight = g_tight.posteriors(); let p_tight = g_tight.posteriors();
@@ -1116,6 +1189,7 @@ mod tests {
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -1150,6 +1224,7 @@ mod tests {
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -1184,6 +1259,7 @@ mod tests {
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -1222,6 +1298,7 @@ mod tests {
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let post_2vs1 = g.posteriors(); let post_2vs1 = g.posteriors();
@@ -1235,6 +1312,7 @@ mod tests {
&[1.0, 0.0], &[1.0, 0.0],
&w, &w,
0.0, 0.0,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
); );
let p = g.posteriors(); let p = g.posteriors();
@@ -1244,4 +1322,99 @@ mod tests {
assert_ulps_eq!(p[1][0], post_2vs1[1][0], epsilon = 1e-6); assert_ulps_eq!(p[1][0], post_2vs1[1][0], epsilon = 1e-6);
assert_ulps_eq!(p[1][1], t_b[1].prior, epsilon = 1e-6); assert_ulps_eq!(p[1][1], t_b[1].prior, epsilon = 1e-6);
} }
#[test]
fn run_chain_honours_max_iter_in_convergence_options() {
let players: Vec<R> = (0..4).map(|_| R::default()).collect();
let teams: Vec<Vec<_>> = players.iter().map(|p| vec![*p]).collect();
let result = vec![3.0, 2.0, 1.0, 0.0];
let weights = vec![vec![1.0]; 4];
// Capped at 1 iteration: cannot fully propagate down a 4-team chain.
let mut arena = ScratchArena::new();
let g_capped = Game::ranked_with_arena(
teams.clone(),
&result,
&weights,
0.0,
crate::ConvergenceOptions {
max_iter: 1,
..crate::ConvergenceOptions::default()
},
&mut arena,
);
let posteriors_capped = g_capped.posteriors();
// Same inputs, plenty of iterations: fully converged.
let mut arena = ScratchArena::new();
let g_full = Game::ranked_with_arena(
teams,
&result,
&weights,
0.0,
crate::ConvergenceOptions::default(),
&mut arena,
);
let posteriors_full = g_full.posteriors();
// The two posteriors should differ — capped did not converge.
let mut max_diff: f64 = 0.0;
for (team_capped, team_full) in posteriors_capped.iter().zip(posteriors_full.iter()) {
for (g_capped, g_full) in team_capped.iter().zip(team_full.iter()) {
max_diff = max_diff.max((g_capped.mu() - g_full.mu()).abs());
max_diff = max_diff.max((g_capped.sigma() - g_full.sigma()).abs());
}
}
assert!(
max_diff > 1e-6,
"max_iter=1 should differ from full convergence; max_diff={max_diff}"
);
}
#[test]
fn run_chain_with_damping_converges_to_same_posterior() {
let players: Vec<R> = (0..4).map(|_| R::default()).collect();
let teams: Vec<Vec<_>> = players.iter().map(|p| vec![*p]).collect();
let result = vec![3.0, 2.0, 1.0, 0.0];
let weights = vec![vec![1.0]; 4];
let mut arena = ScratchArena::new();
let g_undamped = Game::ranked_with_arena(
teams.clone(),
&result,
&weights,
0.0,
crate::ConvergenceOptions::default(),
&mut arena,
);
let posteriors_undamped = g_undamped.posteriors();
// alpha=0.5 with extra iterations: should reach the same fixed point.
let mut arena = ScratchArena::new();
let g_damped = Game::ranked_with_arena(
teams,
&result,
&weights,
0.0,
crate::ConvergenceOptions {
alpha: 0.5,
max_iter: 100,
..crate::ConvergenceOptions::default()
},
&mut arena,
);
let posteriors_damped = g_damped.posteriors();
let mut max_diff: f64 = 0.0;
for (team_u, team_d) in posteriors_undamped.iter().zip(posteriors_damped.iter()) {
for (g_u, g_d) in team_u.iter().zip(team_d.iter()) {
max_diff = max_diff.max((g_u.mu() - g_d.mu()).abs());
max_diff = max_diff.max((g_u.sigma() - g_d.sigma()).abs());
}
}
assert!(
max_diff < 1e-4,
"α=0.5 should reach the same fixed point as α=1.0; max_diff={max_diff}"
);
}
} }
+41
View File
@@ -96,6 +96,18 @@ impl Gaussian {
let var = self.sigma().powi(2) + variance_delta; let var = self.sigma().powi(2) + variance_delta;
Self::from_ms(self.mu(), var.sqrt()) Self::from_ms(self.mu(), var.sqrt())
} }
/// EP damping in natural-parameter space: `α·new + (1−α)·self`.
///
/// Used by within-game inference to stabilise oscillating fixed-point
/// loops on hard graphs. `alpha = 1.0` returns `new` exactly;
/// `alpha < 1.0` shrinks each per-step update.
pub fn damp_natural(self, new: Gaussian, alpha: f64) -> Gaussian {
Gaussian::from_natural(
alpha * new.pi() + (1.0 - alpha) * self.pi(),
alpha * new.tau() + (1.0 - alpha) * self.tau(),
)
}
} }
impl Default for Gaussian { impl Default for Gaussian {
@@ -231,4 +243,33 @@ mod tests {
assert!((r.pi() - expected_pi).abs() < 1e-15); assert!((r.pi() - expected_pi).abs() < 1e-15);
assert!((r.tau() - expected_tau).abs() < 1e-15); assert!((r.tau() - expected_tau).abs() < 1e-15);
} }
#[test]
fn damp_natural_alpha_one_returns_new() {
let old = Gaussian::from_ms(1.0, 2.0);
let new = Gaussian::from_ms(5.0, 0.5);
let damped = old.damp_natural(new, 1.0);
assert_eq!(damped.pi(), new.pi());
assert_eq!(damped.tau(), new.tau());
}
#[test]
fn damp_natural_alpha_zero_returns_self() {
let old = Gaussian::from_ms(1.0, 2.0);
let new = Gaussian::from_ms(5.0, 0.5);
let damped = old.damp_natural(new, 0.0);
assert_eq!(damped.pi(), old.pi());
assert_eq!(damped.tau(), old.tau());
}
#[test]
fn damp_natural_alpha_half_is_midpoint_in_natural_params() {
let old = Gaussian::from_ms(1.0, 2.0);
let new = Gaussian::from_ms(5.0, 0.5);
let damped = old.damp_natural(new, 0.5);
let expected_pi = 0.5 * new.pi() + 0.5 * old.pi();
let expected_tau = 0.5 * new.tau() + 0.5 * old.tau();
assert!((damped.pi() - expected_pi).abs() < 1e-12);
assert!((damped.tau() - expected_tau).abs() < 1e-12);
}
} }
+3
View File
@@ -838,6 +838,7 @@ mod tests {
&[0.0, 1.0], &[0.0, 1.0],
&w, &w,
P_DRAW, P_DRAW,
crate::ConvergenceOptions::default(),
&mut ScratchArena::new(), &mut ScratchArena::new(),
) )
.posteriors(); .posteriors();
@@ -1368,6 +1369,7 @@ mod tests {
h.convergence = ConvergenceOptions { h.convergence = ConvergenceOptions {
max_iter: 11, max_iter: 11,
epsilon: EPSILON, epsilon: EPSILON,
alpha: 1.0,
}; };
h.converge().unwrap(); h.converge().unwrap();
@@ -1685,6 +1687,7 @@ mod tests {
.convergence(ConvergenceOptions { .convergence(ConvergenceOptions {
max_iter: 30, max_iter: 30,
epsilon: 1e-6, epsilon: 1e-6,
alpha: 1.0,
}) })
.build(); .build();
+36 -14
View File
@@ -138,12 +138,22 @@ impl Event {
let teams = self.within_priors(false, false, skills, agents); let teams = self.within_priors(false, false, skills, agents);
let result = self.outputs(); let result = self.outputs();
let g = match self.kind { let g = match self.kind {
EventKind::Ranked => { EventKind::Ranked => Game::ranked_with_arena(
Game::ranked_with_arena(teams, &result, &self.weights, p_draw, arena) teams,
} &result,
EventKind::Scored { score_sigma } => { &self.weights,
Game::scored_with_arena(teams, &result, &self.weights, score_sigma, arena) p_draw,
} crate::ConvergenceOptions::default(),
arena,
),
EventKind::Scored { score_sigma } => Game::scored_with_arena(
teams,
&result,
&self.weights,
score_sigma,
crate::ConvergenceOptions::default(),
arena,
),
}; };
for (t, team) in self.teams.iter_mut().enumerate() { for (t, team) in self.teams.iter_mut().enumerate() {
@@ -322,6 +332,7 @@ impl<T: Time> TimeSlice<T> {
&result, &result,
&event.weights, &event.weights,
self.p_draw, self.p_draw,
crate::ConvergenceOptions::default(),
&mut self.arena, &mut self.arena,
), ),
EventKind::Scored { score_sigma } => Game::scored_with_arena( EventKind::Scored { score_sigma } => Game::scored_with_arena(
@@ -329,6 +340,7 @@ impl<T: Time> TimeSlice<T> {
&result, &result,
&event.weights, &event.weights,
score_sigma, score_sigma,
crate::ConvergenceOptions::default(),
&mut self.arena, &mut self.arena,
), ),
}; };
@@ -504,16 +516,26 @@ impl<T: Time> TimeSlice<T> {
let teams = event.within_priors(online, forward, &self.skills, agents); let teams = event.within_priors(online, forward, &self.skills, agents);
let result = event.outputs(); let result = event.outputs();
match event.kind { match event.kind {
EventKind::Ranked => { EventKind::Ranked => Game::ranked_with_arena(
Game::ranked_with_arena(teams, &result, &event.weights, self.p_draw, arena) teams,
&result,
&event.weights,
self.p_draw,
crate::ConvergenceOptions::default(),
arena,
)
.evidence .evidence
.ln() .ln(),
} EventKind::Scored { score_sigma } => Game::scored_with_arena(
EventKind::Scored { score_sigma } => { teams,
Game::scored_with_arena(teams, &result, &event.weights, score_sigma, arena) &result,
&event.weights,
score_sigma,
crate::ConvergenceOptions::default(),
arena,
)
.evidence .evidence
.ln() .ln(),
}
} }
}; };
+1
View File
@@ -15,6 +15,7 @@ fn add_events_bulk_via_iter() {
.convergence(ConvergenceOptions { .convergence(ConvergenceOptions {
max_iter: 30, max_iter: 30,
epsilon: 1e-6, epsilon: 1e-6,
alpha: 1.0,
}) })
.build(); .build();
+1
View File
@@ -16,6 +16,7 @@ fn build_and_converge(seed: u64) -> Vec<(i64, trueskill_tt::Gaussian)> {
.convergence(ConvergenceOptions { .convergence(ConvergenceOptions {
max_iter: 30, max_iter: 30,
epsilon: 1e-6, epsilon: 1e-6,
alpha: 1.0,
}) })
.build(); .build();
+1
View File
@@ -10,6 +10,7 @@ fn record_winner_builds_history() {
.convergence(ConvergenceOptions { .convergence(ConvergenceOptions {
max_iter: 30, max_iter: 30,
epsilon: 1e-6, epsilon: 1e-6,
alpha: 1.0,
}) })
.build(); .build();