T0 + T1 + T2: engine redesign through new API surface #1

Merged
logaritmisk merged 45 commits from t2-new-api-surface into main 2026-04-24 11:20:04 +00:00
3 changed files with 145 additions and 24 deletions
Showing only changes of commit a6e008f8ff - Show all commits

31
src/convergence.rs Normal file
View File

@@ -0,0 +1,31 @@
//! Convergence configuration and reporting.
use std::time::Duration;
use smallvec::SmallVec;
#[derive(Clone, Copy, Debug)]
pub struct ConvergenceOptions {
pub max_iter: usize,
pub epsilon: f64,
}
impl Default for ConvergenceOptions {
fn default() -> Self {
Self {
max_iter: crate::ITERATIONS,
epsilon: crate::EPSILON,
}
}
}
/// Post-hoc summary of a `History::converge` call.
#[derive(Clone, Debug)]
pub struct ConvergenceReport {
pub iterations: usize,
pub final_step: (f64, f64),
pub log_evidence: f64,
pub converged: bool,
pub per_iteration_time: SmallVec<[Duration; 32]>,
pub slices_skipped: usize,
}

View File

@@ -3,8 +3,11 @@ use std::collections::HashMap;
use crate::{ use crate::{
BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA, BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA,
competitor::{self, Competitor}, competitor::{self, Competitor},
convergence::{ConvergenceOptions, ConvergenceReport},
drift::{ConstantDrift, Drift}, drift::{ConstantDrift, Drift},
error::InferenceError,
gaussian::Gaussian, gaussian::Gaussian,
observer::{NullObserver, Observer},
rating::Rating, rating::Rating,
sort_time, sort_time,
storage::CompetitorStore, storage::CompetitorStore,
@@ -14,17 +17,20 @@ use crate::{
}; };
#[derive(Clone)] #[derive(Clone)]
pub struct HistoryBuilder<T: Time = i64, D: Drift<T> = ConstantDrift> { pub struct HistoryBuilder<T: Time = i64, D: Drift<T> = ConstantDrift, O: Observer<T> = NullObserver>
{
mu: f64, mu: f64,
sigma: f64, sigma: f64,
beta: f64, beta: f64,
drift: D, drift: D,
p_draw: f64, p_draw: f64,
online: bool, online: bool,
convergence: ConvergenceOptions,
observer: O,
_time: std::marker::PhantomData<T>, _time: std::marker::PhantomData<T>,
} }
impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> { impl<T: Time, D: Drift<T>, O: Observer<T>> HistoryBuilder<T, D, O> {
pub fn mu(mut self, mu: f64) -> Self { pub fn mu(mut self, mu: f64) -> Self {
self.mu = mu; self.mu = mu;
self self
@@ -40,7 +46,7 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
self self
} }
pub fn drift<D2: Drift<T>>(self, drift: D2) -> HistoryBuilder<T, D2> { pub fn drift<D2: Drift<T>>(self, drift: D2) -> HistoryBuilder<T, D2, O> {
HistoryBuilder { HistoryBuilder {
drift, drift,
mu: self.mu, mu: self.mu,
@@ -48,7 +54,9 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
beta: self.beta, beta: self.beta,
p_draw: self.p_draw, p_draw: self.p_draw,
online: self.online, online: self.online,
_time: std::marker::PhantomData, convergence: self.convergence,
observer: self.observer,
_time: self._time,
} }
} }
@@ -62,7 +70,26 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
self self
} }
pub fn build(self) -> History<T, D> { pub fn convergence(mut self, opts: ConvergenceOptions) -> Self {
self.convergence = opts;
self
}
pub fn observer<O2: Observer<T>>(self, observer: O2) -> HistoryBuilder<T, D, O2> {
HistoryBuilder {
mu: self.mu,
sigma: self.sigma,
beta: self.beta,
drift: self.drift,
p_draw: self.p_draw,
online: self.online,
convergence: self.convergence,
observer,
_time: self._time,
}
}
pub fn build(self) -> History<T, D, O> {
History { History {
size: 0, size: 0,
time_slices: Vec::new(), time_slices: Vec::new(),
@@ -73,18 +100,20 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
drift: self.drift, drift: self.drift,
p_draw: self.p_draw, p_draw: self.p_draw,
online: self.online, online: self.online,
convergence: self.convergence,
observer: self.observer,
} }
} }
} }
impl HistoryBuilder<i64, ConstantDrift> { impl<O: Observer<i64>> HistoryBuilder<i64, ConstantDrift, O> {
pub fn gamma(mut self, gamma: f64) -> Self { pub fn gamma(mut self, gamma: f64) -> Self {
self.drift = ConstantDrift(gamma); self.drift = ConstantDrift(gamma);
self self
} }
} }
impl Default for HistoryBuilder<i64, ConstantDrift> { impl Default for HistoryBuilder<i64, ConstantDrift, NullObserver> {
fn default() -> Self { fn default() -> Self {
Self { Self {
mu: MU, mu: MU,
@@ -93,12 +122,14 @@ impl Default for HistoryBuilder<i64, ConstantDrift> {
drift: ConstantDrift(GAMMA), drift: ConstantDrift(GAMMA),
p_draw: P_DRAW, p_draw: P_DRAW,
online: false, online: false,
convergence: ConvergenceOptions::default(),
observer: NullObserver,
_time: std::marker::PhantomData, _time: std::marker::PhantomData,
} }
} }
} }
pub struct History<T: Time = i64, D: Drift<T> = ConstantDrift> { pub struct History<T: Time = i64, D: Drift<T> = ConstantDrift, O: Observer<T> = NullObserver> {
size: usize, size: usize,
pub(crate) time_slices: Vec<TimeSlice<T>>, pub(crate) time_slices: Vec<TimeSlice<T>>,
pub(crate) agents: CompetitorStore<T, D>, pub(crate) agents: CompetitorStore<T, D>,
@@ -108,31 +139,23 @@ pub struct History<T: Time = i64, D: Drift<T> = ConstantDrift> {
drift: D, drift: D,
p_draw: f64, p_draw: f64,
online: bool, online: bool,
convergence: ConvergenceOptions,
observer: O,
} }
impl Default for History<i64, ConstantDrift> { impl Default for History<i64, ConstantDrift, NullObserver> {
fn default() -> Self { fn default() -> Self {
Self { HistoryBuilder::default().build()
size: 0,
time_slices: Vec::new(),
agents: CompetitorStore::new(),
mu: MU,
sigma: SIGMA,
beta: BETA,
drift: ConstantDrift(GAMMA),
p_draw: P_DRAW,
online: false,
}
} }
} }
impl History<i64, ConstantDrift> { impl History<i64, ConstantDrift, NullObserver> {
pub fn builder() -> HistoryBuilder<i64, ConstantDrift> { pub fn builder() -> HistoryBuilder<i64, ConstantDrift, NullObserver> {
HistoryBuilder::default() HistoryBuilder::default()
} }
} }
impl<T: Time, D: Drift<T>> History<T, D> { impl<T: Time, D: Drift<T>, O: Observer<T>> History<T, D, O> {
fn iteration(&mut self) -> (f64, f64) { fn iteration(&mut self) -> (f64, f64) {
let mut step = (0.0, 0.0); let mut step = (0.0, 0.0);
@@ -243,9 +266,39 @@ impl<T: Time, D: Drift<T>> History<T, D> {
.map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents)) .map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents))
.sum() .sum()
} }
/// Run the full forward+backward convergence loop and return a summary.
pub fn converge(&mut self) -> Result<ConvergenceReport, InferenceError> {
use std::time::Instant;
use smallvec::SmallVec;
let opts = self.convergence;
let mut step = (f64::INFINITY, f64::INFINITY);
let mut i = 0;
let mut per_iter: SmallVec<[std::time::Duration; 32]> = SmallVec::new();
while tuple_gt(step, opts.epsilon) && i < opts.max_iter {
let t0 = Instant::now();
step = self.iteration();
per_iter.push(t0.elapsed());
i += 1;
self.observer.on_iteration_end(i, step);
}
let converged = !tuple_gt(step, opts.epsilon);
let log_evidence = self.log_evidence(false, &[]);
self.observer.on_converged(i, step, converged);
Ok(ConvergenceReport {
iterations: i,
final_step: step,
log_evidence,
converged,
per_iteration_time: per_iter,
slices_skipped: 0,
})
}
} }
impl<D: Drift<i64>> History<i64, D> { impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
pub fn add_events( pub fn add_events(
&mut self, &mut self,
composition: Vec<Vec<Vec<Index>>>, composition: Vec<Vec<Vec<Index>>>,
@@ -1287,4 +1340,39 @@ mod tests {
assert_ulps_eq!(lc[&a][1].1, lc[&a][0].1, epsilon = 1e-6); assert_ulps_eq!(lc[&a][1].1, lc[&a][0].1, epsilon = 1e-6);
assert_ulps_eq!(lc[&b][1].1, lc[&a][0].1, epsilon = 1e-6); assert_ulps_eq!(lc[&b][1].1, lc[&a][0].1, epsilon = 1e-6);
} }
#[test]
fn test_converge_returns_report() {
use crate::ConvergenceOptions;
let mut index_map = crate::KeyTable::new();
let a = index_map.get_or_create("a");
let b = index_map.get_or_create("b");
let c = index_map.get_or_create("c");
let composition = vec![
vec![vec![a], vec![b]],
vec![vec![a], vec![c]],
vec![vec![b], vec![c]],
];
let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
let times: Vec<i64> = vec![1, 2, 3];
let mut h = History::builder()
.mu(0.0)
.sigma(2.0)
.beta(1.0)
.drift(ConstantDrift(0.0))
.convergence(ConvergenceOptions {
max_iter: 30,
epsilon: 1e-6,
})
.build();
h.add_events(composition, results, times, vec![]);
let report = h.converge().unwrap();
assert!(report.converged);
assert!(report.iterations > 0);
assert!(report.iterations < 30);
assert!(report.final_step.0 <= 1e-6);
}
} }

View File

@@ -10,6 +10,7 @@ mod time;
mod time_slice; mod time_slice;
pub use time_slice::TimeSlice; pub use time_slice::TimeSlice;
mod competitor; mod competitor;
mod convergence;
pub mod drift; pub mod drift;
mod error; mod error;
mod event; mod event;
@@ -26,6 +27,7 @@ pub(crate) mod schedule;
pub mod storage; pub mod storage;
pub use competitor::Competitor; pub use competitor::Competitor;
pub use convergence::{ConvergenceOptions, ConvergenceReport};
pub use drift::{ConstantDrift, Drift}; pub use drift::{ConstantDrift, Drift};
pub use error::InferenceError; pub use error::InferenceError;
pub use event::{Event, Member, Team}; pub use event::{Event, Member, Team};