From a6e008f8ff4ec7c2a8fa78015f4421ee4644e166 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Fri, 24 Apr 2026 12:20:24 +0200 Subject: [PATCH] feat(api): add ConvergenceOptions, ConvergenceReport, History::converge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New public types: - ConvergenceOptions { max_iter, epsilon } — config for the loop - ConvergenceReport { iterations, final_step, log_evidence, converged, per_iteration_time, slices_skipped } — post-hoc summary History and HistoryBuilder gain a third generic parameter O: Observer = NullObserver. Builder methods: - .convergence(opts) sets the ConvergenceOptions - .observer(o) plugs in an Observer (reshapes the builder's O param) History::converge() runs the existing iteration loop driven by the stored opts, emits observer callbacks on each iteration end and on completion, and returns Result. The old convergence(iters, eps, verbose) stays — gets removed in Task 20 after tests are translated. Part of T2 of docs/superpowers/specs/2026-04-23-trueskill-engine-redesign-design.md. --- src/convergence.rs | 31 +++++++++++ src/history.rs | 136 +++++++++++++++++++++++++++++++++++++-------- src/lib.rs | 2 + 3 files changed, 145 insertions(+), 24 deletions(-) create mode 100644 src/convergence.rs diff --git a/src/convergence.rs b/src/convergence.rs new file mode 100644 index 0000000..d03359c --- /dev/null +++ b/src/convergence.rs @@ -0,0 +1,31 @@ +//! Convergence configuration and reporting. + +use std::time::Duration; + +use smallvec::SmallVec; + +#[derive(Clone, Copy, Debug)] +pub struct ConvergenceOptions { + pub max_iter: usize, + pub epsilon: f64, +} + +impl Default for ConvergenceOptions { + fn default() -> Self { + Self { + max_iter: crate::ITERATIONS, + epsilon: crate::EPSILON, + } + } +} + +/// Post-hoc summary of a `History::converge` call. +#[derive(Clone, Debug)] +pub struct ConvergenceReport { + pub iterations: usize, + pub final_step: (f64, f64), + pub log_evidence: f64, + pub converged: bool, + pub per_iteration_time: SmallVec<[Duration; 32]>, + pub slices_skipped: usize, +} diff --git a/src/history.rs b/src/history.rs index f162b3c..8545d8d 100644 --- a/src/history.rs +++ b/src/history.rs @@ -3,8 +3,11 @@ use std::collections::HashMap; use crate::{ BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA, competitor::{self, Competitor}, + convergence::{ConvergenceOptions, ConvergenceReport}, drift::{ConstantDrift, Drift}, + error::InferenceError, gaussian::Gaussian, + observer::{NullObserver, Observer}, rating::Rating, sort_time, storage::CompetitorStore, @@ -14,17 +17,20 @@ use crate::{ }; #[derive(Clone)] -pub struct HistoryBuilder = ConstantDrift> { +pub struct HistoryBuilder = ConstantDrift, O: Observer = NullObserver> +{ mu: f64, sigma: f64, beta: f64, drift: D, p_draw: f64, online: bool, + convergence: ConvergenceOptions, + observer: O, _time: std::marker::PhantomData, } -impl> HistoryBuilder { +impl, O: Observer> HistoryBuilder { pub fn mu(mut self, mu: f64) -> Self { self.mu = mu; self @@ -40,7 +46,7 @@ impl> HistoryBuilder { self } - pub fn drift>(self, drift: D2) -> HistoryBuilder { + pub fn drift>(self, drift: D2) -> HistoryBuilder { HistoryBuilder { drift, mu: self.mu, @@ -48,7 +54,9 @@ impl> HistoryBuilder { beta: self.beta, p_draw: self.p_draw, online: self.online, - _time: std::marker::PhantomData, + convergence: self.convergence, + observer: self.observer, + _time: self._time, } } @@ -62,7 +70,26 @@ impl> HistoryBuilder { self } - pub fn build(self) -> History { + pub fn convergence(mut self, opts: ConvergenceOptions) -> Self { + self.convergence = opts; + self + } + + pub fn observer>(self, observer: O2) -> HistoryBuilder { + HistoryBuilder { + mu: self.mu, + sigma: self.sigma, + beta: self.beta, + drift: self.drift, + p_draw: self.p_draw, + online: self.online, + convergence: self.convergence, + observer, + _time: self._time, + } + } + + pub fn build(self) -> History { History { size: 0, time_slices: Vec::new(), @@ -73,18 +100,20 @@ impl> HistoryBuilder { drift: self.drift, p_draw: self.p_draw, online: self.online, + convergence: self.convergence, + observer: self.observer, } } } -impl HistoryBuilder { +impl> HistoryBuilder { pub fn gamma(mut self, gamma: f64) -> Self { self.drift = ConstantDrift(gamma); self } } -impl Default for HistoryBuilder { +impl Default for HistoryBuilder { fn default() -> Self { Self { mu: MU, @@ -93,12 +122,14 @@ impl Default for HistoryBuilder { drift: ConstantDrift(GAMMA), p_draw: P_DRAW, online: false, + convergence: ConvergenceOptions::default(), + observer: NullObserver, _time: std::marker::PhantomData, } } } -pub struct History = ConstantDrift> { +pub struct History = ConstantDrift, O: Observer = NullObserver> { size: usize, pub(crate) time_slices: Vec>, pub(crate) agents: CompetitorStore, @@ -108,31 +139,23 @@ pub struct History = ConstantDrift> { drift: D, p_draw: f64, online: bool, + convergence: ConvergenceOptions, + observer: O, } -impl Default for History { +impl Default for History { fn default() -> Self { - Self { - size: 0, - time_slices: Vec::new(), - agents: CompetitorStore::new(), - mu: MU, - sigma: SIGMA, - beta: BETA, - drift: ConstantDrift(GAMMA), - p_draw: P_DRAW, - online: false, - } + HistoryBuilder::default().build() } } -impl History { - pub fn builder() -> HistoryBuilder { +impl History { + pub fn builder() -> HistoryBuilder { HistoryBuilder::default() } } -impl> History { +impl, O: Observer> History { fn iteration(&mut self) -> (f64, f64) { let mut step = (0.0, 0.0); @@ -243,9 +266,39 @@ impl> History { .map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents)) .sum() } + + /// Run the full forward+backward convergence loop and return a summary. + pub fn converge(&mut self) -> Result { + use std::time::Instant; + + use smallvec::SmallVec; + + let opts = self.convergence; + let mut step = (f64::INFINITY, f64::INFINITY); + let mut i = 0; + let mut per_iter: SmallVec<[std::time::Duration; 32]> = SmallVec::new(); + while tuple_gt(step, opts.epsilon) && i < opts.max_iter { + let t0 = Instant::now(); + step = self.iteration(); + per_iter.push(t0.elapsed()); + i += 1; + self.observer.on_iteration_end(i, step); + } + let converged = !tuple_gt(step, opts.epsilon); + let log_evidence = self.log_evidence(false, &[]); + self.observer.on_converged(i, step, converged); + Ok(ConvergenceReport { + iterations: i, + final_step: step, + log_evidence, + converged, + per_iteration_time: per_iter, + slices_skipped: 0, + }) + } } -impl> History { +impl, O: Observer> History { pub fn add_events( &mut self, composition: Vec>>, @@ -1287,4 +1340,39 @@ mod tests { assert_ulps_eq!(lc[&a][1].1, lc[&a][0].1, epsilon = 1e-6); assert_ulps_eq!(lc[&b][1].1, lc[&a][0].1, epsilon = 1e-6); } + + #[test] + fn test_converge_returns_report() { + use crate::ConvergenceOptions; + + let mut index_map = crate::KeyTable::new(); + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let composition = vec![ + vec![vec![a], vec![b]], + vec![vec![a], vec![c]], + vec![vec![b], vec![c]], + ]; + let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; + let times: Vec = vec![1, 2, 3]; + + let mut h = History::builder() + .mu(0.0) + .sigma(2.0) + .beta(1.0) + .drift(ConstantDrift(0.0)) + .convergence(ConvergenceOptions { + max_iter: 30, + epsilon: 1e-6, + }) + .build(); + h.add_events(composition, results, times, vec![]); + + let report = h.converge().unwrap(); + assert!(report.converged); + assert!(report.iterations > 0); + assert!(report.iterations < 30); + assert!(report.final_step.0 <= 1e-6); + } } diff --git a/src/lib.rs b/src/lib.rs index d4efa2d..5a397fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ mod time; mod time_slice; pub use time_slice::TimeSlice; mod competitor; +mod convergence; pub mod drift; mod error; mod event; @@ -26,6 +27,7 @@ pub(crate) mod schedule; pub mod storage; pub use competitor::Competitor; +pub use convergence::{ConvergenceOptions, ConvergenceReport}; pub use drift::{ConstantDrift, Drift}; pub use error::InferenceError; pub use event::{Event, Member, Team};