diff --git a/src/convergence.rs b/src/convergence.rs new file mode 100644 index 0000000..d03359c --- /dev/null +++ b/src/convergence.rs @@ -0,0 +1,31 @@ +//! Convergence configuration and reporting. + +use std::time::Duration; + +use smallvec::SmallVec; + +#[derive(Clone, Copy, Debug)] +pub struct ConvergenceOptions { + pub max_iter: usize, + pub epsilon: f64, +} + +impl Default for ConvergenceOptions { + fn default() -> Self { + Self { + max_iter: crate::ITERATIONS, + epsilon: crate::EPSILON, + } + } +} + +/// Post-hoc summary of a `History::converge` call. +#[derive(Clone, Debug)] +pub struct ConvergenceReport { + pub iterations: usize, + pub final_step: (f64, f64), + pub log_evidence: f64, + pub converged: bool, + pub per_iteration_time: SmallVec<[Duration; 32]>, + pub slices_skipped: usize, +} diff --git a/src/history.rs b/src/history.rs index f162b3c..8545d8d 100644 --- a/src/history.rs +++ b/src/history.rs @@ -3,8 +3,11 @@ use std::collections::HashMap; use crate::{ BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA, competitor::{self, Competitor}, + convergence::{ConvergenceOptions, ConvergenceReport}, drift::{ConstantDrift, Drift}, + error::InferenceError, gaussian::Gaussian, + observer::{NullObserver, Observer}, rating::Rating, sort_time, storage::CompetitorStore, @@ -14,17 +17,20 @@ use crate::{ }; #[derive(Clone)] -pub struct HistoryBuilder = ConstantDrift> { +pub struct HistoryBuilder = ConstantDrift, O: Observer = NullObserver> +{ mu: f64, sigma: f64, beta: f64, drift: D, p_draw: f64, online: bool, + convergence: ConvergenceOptions, + observer: O, _time: std::marker::PhantomData, } -impl> HistoryBuilder { +impl, O: Observer> HistoryBuilder { pub fn mu(mut self, mu: f64) -> Self { self.mu = mu; self @@ -40,7 +46,7 @@ impl> HistoryBuilder { self } - pub fn drift>(self, drift: D2) -> HistoryBuilder { + pub fn drift>(self, drift: D2) -> HistoryBuilder { HistoryBuilder { drift, mu: self.mu, @@ -48,7 +54,9 @@ impl> HistoryBuilder { beta: self.beta, p_draw: self.p_draw, online: self.online, - _time: std::marker::PhantomData, + convergence: self.convergence, + observer: self.observer, + _time: self._time, } } @@ -62,7 +70,26 @@ impl> HistoryBuilder { self } - pub fn build(self) -> History { + pub fn convergence(mut self, opts: ConvergenceOptions) -> Self { + self.convergence = opts; + self + } + + pub fn observer>(self, observer: O2) -> HistoryBuilder { + HistoryBuilder { + mu: self.mu, + sigma: self.sigma, + beta: self.beta, + drift: self.drift, + p_draw: self.p_draw, + online: self.online, + convergence: self.convergence, + observer, + _time: self._time, + } + } + + pub fn build(self) -> History { History { size: 0, time_slices: Vec::new(), @@ -73,18 +100,20 @@ impl> HistoryBuilder { drift: self.drift, p_draw: self.p_draw, online: self.online, + convergence: self.convergence, + observer: self.observer, } } } -impl HistoryBuilder { +impl> HistoryBuilder { pub fn gamma(mut self, gamma: f64) -> Self { self.drift = ConstantDrift(gamma); self } } -impl Default for HistoryBuilder { +impl Default for HistoryBuilder { fn default() -> Self { Self { mu: MU, @@ -93,12 +122,14 @@ impl Default for HistoryBuilder { drift: ConstantDrift(GAMMA), p_draw: P_DRAW, online: false, + convergence: ConvergenceOptions::default(), + observer: NullObserver, _time: std::marker::PhantomData, } } } -pub struct History = ConstantDrift> { +pub struct History = ConstantDrift, O: Observer = NullObserver> { size: usize, pub(crate) time_slices: Vec>, pub(crate) agents: CompetitorStore, @@ -108,31 +139,23 @@ pub struct History = ConstantDrift> { drift: D, p_draw: f64, online: bool, + convergence: ConvergenceOptions, + observer: O, } -impl Default for History { +impl Default for History { fn default() -> Self { - Self { - size: 0, - time_slices: Vec::new(), - agents: CompetitorStore::new(), - mu: MU, - sigma: SIGMA, - beta: BETA, - drift: ConstantDrift(GAMMA), - p_draw: P_DRAW, - online: false, - } + HistoryBuilder::default().build() } } -impl History { - pub fn builder() -> HistoryBuilder { +impl History { + pub fn builder() -> HistoryBuilder { HistoryBuilder::default() } } -impl> History { +impl, O: Observer> History { fn iteration(&mut self) -> (f64, f64) { let mut step = (0.0, 0.0); @@ -243,9 +266,39 @@ impl> History { .map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents)) .sum() } + + /// Run the full forward+backward convergence loop and return a summary. + pub fn converge(&mut self) -> Result { + use std::time::Instant; + + use smallvec::SmallVec; + + let opts = self.convergence; + let mut step = (f64::INFINITY, f64::INFINITY); + let mut i = 0; + let mut per_iter: SmallVec<[std::time::Duration; 32]> = SmallVec::new(); + while tuple_gt(step, opts.epsilon) && i < opts.max_iter { + let t0 = Instant::now(); + step = self.iteration(); + per_iter.push(t0.elapsed()); + i += 1; + self.observer.on_iteration_end(i, step); + } + let converged = !tuple_gt(step, opts.epsilon); + let log_evidence = self.log_evidence(false, &[]); + self.observer.on_converged(i, step, converged); + Ok(ConvergenceReport { + iterations: i, + final_step: step, + log_evidence, + converged, + per_iteration_time: per_iter, + slices_skipped: 0, + }) + } } -impl> History { +impl, O: Observer> History { pub fn add_events( &mut self, composition: Vec>>, @@ -1287,4 +1340,39 @@ mod tests { assert_ulps_eq!(lc[&a][1].1, lc[&a][0].1, epsilon = 1e-6); assert_ulps_eq!(lc[&b][1].1, lc[&a][0].1, epsilon = 1e-6); } + + #[test] + fn test_converge_returns_report() { + use crate::ConvergenceOptions; + + let mut index_map = crate::KeyTable::new(); + let a = index_map.get_or_create("a"); + let b = index_map.get_or_create("b"); + let c = index_map.get_or_create("c"); + let composition = vec![ + vec![vec![a], vec![b]], + vec![vec![a], vec![c]], + vec![vec![b], vec![c]], + ]; + let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]]; + let times: Vec = vec![1, 2, 3]; + + let mut h = History::builder() + .mu(0.0) + .sigma(2.0) + .beta(1.0) + .drift(ConstantDrift(0.0)) + .convergence(ConvergenceOptions { + max_iter: 30, + epsilon: 1e-6, + }) + .build(); + h.add_events(composition, results, times, vec![]); + + let report = h.converge().unwrap(); + assert!(report.converged); + assert!(report.iterations > 0); + assert!(report.iterations < 30); + assert!(report.final_step.0 <= 1e-6); + } } diff --git a/src/lib.rs b/src/lib.rs index d4efa2d..5a397fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ mod time; mod time_slice; pub use time_slice::TimeSlice; mod competitor; +mod convergence; pub mod drift; mod error; mod event; @@ -26,6 +27,7 @@ pub(crate) mod schedule; pub mod storage; pub use competitor::Competitor; +pub use convergence::{ConvergenceOptions, ConvergenceReport}; pub use drift::{ConstantDrift, Drift}; pub use error::InferenceError; pub use event::{Event, Member, Team};