T0 + T1 + T2: engine redesign through new API surface #1
31
src/convergence.rs
Normal file
31
src/convergence.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
//! Convergence configuration and reporting.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use smallvec::SmallVec;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct ConvergenceOptions {
|
||||
pub max_iter: usize,
|
||||
pub epsilon: f64,
|
||||
}
|
||||
|
||||
impl Default for ConvergenceOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_iter: crate::ITERATIONS,
|
||||
epsilon: crate::EPSILON,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Post-hoc summary of a `History::converge` call.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ConvergenceReport {
|
||||
pub iterations: usize,
|
||||
pub final_step: (f64, f64),
|
||||
pub log_evidence: f64,
|
||||
pub converged: bool,
|
||||
pub per_iteration_time: SmallVec<[Duration; 32]>,
|
||||
pub slices_skipped: usize,
|
||||
}
|
||||
136
src/history.rs
136
src/history.rs
@@ -3,8 +3,11 @@ use std::collections::HashMap;
|
||||
use crate::{
|
||||
BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA,
|
||||
competitor::{self, Competitor},
|
||||
convergence::{ConvergenceOptions, ConvergenceReport},
|
||||
drift::{ConstantDrift, Drift},
|
||||
error::InferenceError,
|
||||
gaussian::Gaussian,
|
||||
observer::{NullObserver, Observer},
|
||||
rating::Rating,
|
||||
sort_time,
|
||||
storage::CompetitorStore,
|
||||
@@ -14,17 +17,20 @@ use crate::{
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct HistoryBuilder<T: Time = i64, D: Drift<T> = ConstantDrift> {
|
||||
pub struct HistoryBuilder<T: Time = i64, D: Drift<T> = ConstantDrift, O: Observer<T> = NullObserver>
|
||||
{
|
||||
mu: f64,
|
||||
sigma: f64,
|
||||
beta: f64,
|
||||
drift: D,
|
||||
p_draw: f64,
|
||||
online: bool,
|
||||
convergence: ConvergenceOptions,
|
||||
observer: O,
|
||||
_time: std::marker::PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
|
||||
impl<T: Time, D: Drift<T>, O: Observer<T>> HistoryBuilder<T, D, O> {
|
||||
pub fn mu(mut self, mu: f64) -> Self {
|
||||
self.mu = mu;
|
||||
self
|
||||
@@ -40,7 +46,7 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn drift<D2: Drift<T>>(self, drift: D2) -> HistoryBuilder<T, D2> {
|
||||
pub fn drift<D2: Drift<T>>(self, drift: D2) -> HistoryBuilder<T, D2, O> {
|
||||
HistoryBuilder {
|
||||
drift,
|
||||
mu: self.mu,
|
||||
@@ -48,7 +54,9 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
|
||||
beta: self.beta,
|
||||
p_draw: self.p_draw,
|
||||
online: self.online,
|
||||
_time: std::marker::PhantomData,
|
||||
convergence: self.convergence,
|
||||
observer: self.observer,
|
||||
_time: self._time,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +70,26 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> History<T, D> {
|
||||
pub fn convergence(mut self, opts: ConvergenceOptions) -> Self {
|
||||
self.convergence = opts;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn observer<O2: Observer<T>>(self, observer: O2) -> HistoryBuilder<T, D, O2> {
|
||||
HistoryBuilder {
|
||||
mu: self.mu,
|
||||
sigma: self.sigma,
|
||||
beta: self.beta,
|
||||
drift: self.drift,
|
||||
p_draw: self.p_draw,
|
||||
online: self.online,
|
||||
convergence: self.convergence,
|
||||
observer,
|
||||
_time: self._time,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build(self) -> History<T, D, O> {
|
||||
History {
|
||||
size: 0,
|
||||
time_slices: Vec::new(),
|
||||
@@ -73,18 +100,20 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
|
||||
drift: self.drift,
|
||||
p_draw: self.p_draw,
|
||||
online: self.online,
|
||||
convergence: self.convergence,
|
||||
observer: self.observer,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HistoryBuilder<i64, ConstantDrift> {
|
||||
impl<O: Observer<i64>> HistoryBuilder<i64, ConstantDrift, O> {
|
||||
pub fn gamma(mut self, gamma: f64) -> Self {
|
||||
self.drift = ConstantDrift(gamma);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HistoryBuilder<i64, ConstantDrift> {
|
||||
impl Default for HistoryBuilder<i64, ConstantDrift, NullObserver> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mu: MU,
|
||||
@@ -93,12 +122,14 @@ impl Default for HistoryBuilder<i64, ConstantDrift> {
|
||||
drift: ConstantDrift(GAMMA),
|
||||
p_draw: P_DRAW,
|
||||
online: false,
|
||||
convergence: ConvergenceOptions::default(),
|
||||
observer: NullObserver,
|
||||
_time: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct History<T: Time = i64, D: Drift<T> = ConstantDrift> {
|
||||
pub struct History<T: Time = i64, D: Drift<T> = ConstantDrift, O: Observer<T> = NullObserver> {
|
||||
size: usize,
|
||||
pub(crate) time_slices: Vec<TimeSlice<T>>,
|
||||
pub(crate) agents: CompetitorStore<T, D>,
|
||||
@@ -108,31 +139,23 @@ pub struct History<T: Time = i64, D: Drift<T> = ConstantDrift> {
|
||||
drift: D,
|
||||
p_draw: f64,
|
||||
online: bool,
|
||||
convergence: ConvergenceOptions,
|
||||
observer: O,
|
||||
}
|
||||
|
||||
impl Default for History<i64, ConstantDrift> {
|
||||
impl Default for History<i64, ConstantDrift, NullObserver> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
size: 0,
|
||||
time_slices: Vec::new(),
|
||||
agents: CompetitorStore::new(),
|
||||
mu: MU,
|
||||
sigma: SIGMA,
|
||||
beta: BETA,
|
||||
drift: ConstantDrift(GAMMA),
|
||||
p_draw: P_DRAW,
|
||||
online: false,
|
||||
}
|
||||
HistoryBuilder::default().build()
|
||||
}
|
||||
}
|
||||
|
||||
impl History<i64, ConstantDrift> {
|
||||
pub fn builder() -> HistoryBuilder<i64, ConstantDrift> {
|
||||
impl History<i64, ConstantDrift, NullObserver> {
|
||||
pub fn builder() -> HistoryBuilder<i64, ConstantDrift, NullObserver> {
|
||||
HistoryBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Time, D: Drift<T>> History<T, D> {
|
||||
impl<T: Time, D: Drift<T>, O: Observer<T>> History<T, D, O> {
|
||||
fn iteration(&mut self) -> (f64, f64) {
|
||||
let mut step = (0.0, 0.0);
|
||||
|
||||
@@ -243,9 +266,39 @@ impl<T: Time, D: Drift<T>> History<T, D> {
|
||||
.map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents))
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Run the full forward+backward convergence loop and return a summary.
|
||||
pub fn converge(&mut self) -> Result<ConvergenceReport, InferenceError> {
|
||||
use std::time::Instant;
|
||||
|
||||
use smallvec::SmallVec;
|
||||
|
||||
let opts = self.convergence;
|
||||
let mut step = (f64::INFINITY, f64::INFINITY);
|
||||
let mut i = 0;
|
||||
let mut per_iter: SmallVec<[std::time::Duration; 32]> = SmallVec::new();
|
||||
while tuple_gt(step, opts.epsilon) && i < opts.max_iter {
|
||||
let t0 = Instant::now();
|
||||
step = self.iteration();
|
||||
per_iter.push(t0.elapsed());
|
||||
i += 1;
|
||||
self.observer.on_iteration_end(i, step);
|
||||
}
|
||||
let converged = !tuple_gt(step, opts.epsilon);
|
||||
let log_evidence = self.log_evidence(false, &[]);
|
||||
self.observer.on_converged(i, step, converged);
|
||||
Ok(ConvergenceReport {
|
||||
iterations: i,
|
||||
final_step: step,
|
||||
log_evidence,
|
||||
converged,
|
||||
per_iteration_time: per_iter,
|
||||
slices_skipped: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Drift<i64>> History<i64, D> {
|
||||
impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
|
||||
pub fn add_events(
|
||||
&mut self,
|
||||
composition: Vec<Vec<Vec<Index>>>,
|
||||
@@ -1287,4 +1340,39 @@ mod tests {
|
||||
assert_ulps_eq!(lc[&a][1].1, lc[&a][0].1, epsilon = 1e-6);
|
||||
assert_ulps_eq!(lc[&b][1].1, lc[&a][0].1, epsilon = 1e-6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_converge_returns_report() {
|
||||
use crate::ConvergenceOptions;
|
||||
|
||||
let mut index_map = crate::KeyTable::new();
|
||||
let a = index_map.get_or_create("a");
|
||||
let b = index_map.get_or_create("b");
|
||||
let c = index_map.get_or_create("c");
|
||||
let composition = vec![
|
||||
vec![vec![a], vec![b]],
|
||||
vec![vec![a], vec![c]],
|
||||
vec![vec![b], vec![c]],
|
||||
];
|
||||
let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
|
||||
let times: Vec<i64> = vec![1, 2, 3];
|
||||
|
||||
let mut h = History::builder()
|
||||
.mu(0.0)
|
||||
.sigma(2.0)
|
||||
.beta(1.0)
|
||||
.drift(ConstantDrift(0.0))
|
||||
.convergence(ConvergenceOptions {
|
||||
max_iter: 30,
|
||||
epsilon: 1e-6,
|
||||
})
|
||||
.build();
|
||||
h.add_events(composition, results, times, vec![]);
|
||||
|
||||
let report = h.converge().unwrap();
|
||||
assert!(report.converged);
|
||||
assert!(report.iterations > 0);
|
||||
assert!(report.iterations < 30);
|
||||
assert!(report.final_step.0 <= 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ mod time;
|
||||
mod time_slice;
|
||||
pub use time_slice::TimeSlice;
|
||||
mod competitor;
|
||||
mod convergence;
|
||||
pub mod drift;
|
||||
mod error;
|
||||
mod event;
|
||||
@@ -26,6 +27,7 @@ pub(crate) mod schedule;
|
||||
pub mod storage;
|
||||
|
||||
pub use competitor::Competitor;
|
||||
pub use convergence::{ConvergenceOptions, ConvergenceReport};
|
||||
pub use drift::{ConstantDrift, Drift};
|
||||
pub use error::InferenceError;
|
||||
pub use event::{Event, Member, Team};
|
||||
|
||||
Reference in New Issue
Block a user