T0 + T1 + T2: engine redesign through new API surface #1
31
src/convergence.rs
Normal file
31
src/convergence.rs
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
//! Convergence configuration and reporting.
|
||||||
|
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use smallvec::SmallVec;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub struct ConvergenceOptions {
|
||||||
|
pub max_iter: usize,
|
||||||
|
pub epsilon: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ConvergenceOptions {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_iter: crate::ITERATIONS,
|
||||||
|
epsilon: crate::EPSILON,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Post-hoc summary of a `History::converge` call.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ConvergenceReport {
|
||||||
|
pub iterations: usize,
|
||||||
|
pub final_step: (f64, f64),
|
||||||
|
pub log_evidence: f64,
|
||||||
|
pub converged: bool,
|
||||||
|
pub per_iteration_time: SmallVec<[Duration; 32]>,
|
||||||
|
pub slices_skipped: usize,
|
||||||
|
}
|
||||||
136
src/history.rs
136
src/history.rs
@@ -3,8 +3,11 @@ use std::collections::HashMap;
|
|||||||
use crate::{
|
use crate::{
|
||||||
BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA,
|
BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA,
|
||||||
competitor::{self, Competitor},
|
competitor::{self, Competitor},
|
||||||
|
convergence::{ConvergenceOptions, ConvergenceReport},
|
||||||
drift::{ConstantDrift, Drift},
|
drift::{ConstantDrift, Drift},
|
||||||
|
error::InferenceError,
|
||||||
gaussian::Gaussian,
|
gaussian::Gaussian,
|
||||||
|
observer::{NullObserver, Observer},
|
||||||
rating::Rating,
|
rating::Rating,
|
||||||
sort_time,
|
sort_time,
|
||||||
storage::CompetitorStore,
|
storage::CompetitorStore,
|
||||||
@@ -14,17 +17,20 @@ use crate::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct HistoryBuilder<T: Time = i64, D: Drift<T> = ConstantDrift> {
|
pub struct HistoryBuilder<T: Time = i64, D: Drift<T> = ConstantDrift, O: Observer<T> = NullObserver>
|
||||||
|
{
|
||||||
mu: f64,
|
mu: f64,
|
||||||
sigma: f64,
|
sigma: f64,
|
||||||
beta: f64,
|
beta: f64,
|
||||||
drift: D,
|
drift: D,
|
||||||
p_draw: f64,
|
p_draw: f64,
|
||||||
online: bool,
|
online: bool,
|
||||||
|
convergence: ConvergenceOptions,
|
||||||
|
observer: O,
|
||||||
_time: std::marker::PhantomData<T>,
|
_time: std::marker::PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
|
impl<T: Time, D: Drift<T>, O: Observer<T>> HistoryBuilder<T, D, O> {
|
||||||
pub fn mu(mut self, mu: f64) -> Self {
|
pub fn mu(mut self, mu: f64) -> Self {
|
||||||
self.mu = mu;
|
self.mu = mu;
|
||||||
self
|
self
|
||||||
@@ -40,7 +46,7 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn drift<D2: Drift<T>>(self, drift: D2) -> HistoryBuilder<T, D2> {
|
pub fn drift<D2: Drift<T>>(self, drift: D2) -> HistoryBuilder<T, D2, O> {
|
||||||
HistoryBuilder {
|
HistoryBuilder {
|
||||||
drift,
|
drift,
|
||||||
mu: self.mu,
|
mu: self.mu,
|
||||||
@@ -48,7 +54,9 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
|
|||||||
beta: self.beta,
|
beta: self.beta,
|
||||||
p_draw: self.p_draw,
|
p_draw: self.p_draw,
|
||||||
online: self.online,
|
online: self.online,
|
||||||
_time: std::marker::PhantomData,
|
convergence: self.convergence,
|
||||||
|
observer: self.observer,
|
||||||
|
_time: self._time,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +70,26 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn build(self) -> History<T, D> {
|
pub fn convergence(mut self, opts: ConvergenceOptions) -> Self {
|
||||||
|
self.convergence = opts;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn observer<O2: Observer<T>>(self, observer: O2) -> HistoryBuilder<T, D, O2> {
|
||||||
|
HistoryBuilder {
|
||||||
|
mu: self.mu,
|
||||||
|
sigma: self.sigma,
|
||||||
|
beta: self.beta,
|
||||||
|
drift: self.drift,
|
||||||
|
p_draw: self.p_draw,
|
||||||
|
online: self.online,
|
||||||
|
convergence: self.convergence,
|
||||||
|
observer,
|
||||||
|
_time: self._time,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(self) -> History<T, D, O> {
|
||||||
History {
|
History {
|
||||||
size: 0,
|
size: 0,
|
||||||
time_slices: Vec::new(),
|
time_slices: Vec::new(),
|
||||||
@@ -73,18 +100,20 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
|
|||||||
drift: self.drift,
|
drift: self.drift,
|
||||||
p_draw: self.p_draw,
|
p_draw: self.p_draw,
|
||||||
online: self.online,
|
online: self.online,
|
||||||
|
convergence: self.convergence,
|
||||||
|
observer: self.observer,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HistoryBuilder<i64, ConstantDrift> {
|
impl<O: Observer<i64>> HistoryBuilder<i64, ConstantDrift, O> {
|
||||||
pub fn gamma(mut self, gamma: f64) -> Self {
|
pub fn gamma(mut self, gamma: f64) -> Self {
|
||||||
self.drift = ConstantDrift(gamma);
|
self.drift = ConstantDrift(gamma);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for HistoryBuilder<i64, ConstantDrift> {
|
impl Default for HistoryBuilder<i64, ConstantDrift, NullObserver> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
mu: MU,
|
mu: MU,
|
||||||
@@ -93,12 +122,14 @@ impl Default for HistoryBuilder<i64, ConstantDrift> {
|
|||||||
drift: ConstantDrift(GAMMA),
|
drift: ConstantDrift(GAMMA),
|
||||||
p_draw: P_DRAW,
|
p_draw: P_DRAW,
|
||||||
online: false,
|
online: false,
|
||||||
|
convergence: ConvergenceOptions::default(),
|
||||||
|
observer: NullObserver,
|
||||||
_time: std::marker::PhantomData,
|
_time: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct History<T: Time = i64, D: Drift<T> = ConstantDrift> {
|
pub struct History<T: Time = i64, D: Drift<T> = ConstantDrift, O: Observer<T> = NullObserver> {
|
||||||
size: usize,
|
size: usize,
|
||||||
pub(crate) time_slices: Vec<TimeSlice<T>>,
|
pub(crate) time_slices: Vec<TimeSlice<T>>,
|
||||||
pub(crate) agents: CompetitorStore<T, D>,
|
pub(crate) agents: CompetitorStore<T, D>,
|
||||||
@@ -108,31 +139,23 @@ pub struct History<T: Time = i64, D: Drift<T> = ConstantDrift> {
|
|||||||
drift: D,
|
drift: D,
|
||||||
p_draw: f64,
|
p_draw: f64,
|
||||||
online: bool,
|
online: bool,
|
||||||
|
convergence: ConvergenceOptions,
|
||||||
|
observer: O,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for History<i64, ConstantDrift> {
|
impl Default for History<i64, ConstantDrift, NullObserver> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
HistoryBuilder::default().build()
|
||||||
size: 0,
|
|
||||||
time_slices: Vec::new(),
|
|
||||||
agents: CompetitorStore::new(),
|
|
||||||
mu: MU,
|
|
||||||
sigma: SIGMA,
|
|
||||||
beta: BETA,
|
|
||||||
drift: ConstantDrift(GAMMA),
|
|
||||||
p_draw: P_DRAW,
|
|
||||||
online: false,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl History<i64, ConstantDrift> {
|
impl History<i64, ConstantDrift, NullObserver> {
|
||||||
pub fn builder() -> HistoryBuilder<i64, ConstantDrift> {
|
pub fn builder() -> HistoryBuilder<i64, ConstantDrift, NullObserver> {
|
||||||
HistoryBuilder::default()
|
HistoryBuilder::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Time, D: Drift<T>> History<T, D> {
|
impl<T: Time, D: Drift<T>, O: Observer<T>> History<T, D, O> {
|
||||||
fn iteration(&mut self) -> (f64, f64) {
|
fn iteration(&mut self) -> (f64, f64) {
|
||||||
let mut step = (0.0, 0.0);
|
let mut step = (0.0, 0.0);
|
||||||
|
|
||||||
@@ -243,9 +266,39 @@ impl<T: Time, D: Drift<T>> History<T, D> {
|
|||||||
.map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents))
|
.map(|ts| ts.log_evidence(self.online, targets, forward, &self.agents))
|
||||||
.sum()
|
.sum()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Run the full forward+backward convergence loop and return a summary.
|
||||||
|
pub fn converge(&mut self) -> Result<ConvergenceReport, InferenceError> {
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use smallvec::SmallVec;
|
||||||
|
|
||||||
|
let opts = self.convergence;
|
||||||
|
let mut step = (f64::INFINITY, f64::INFINITY);
|
||||||
|
let mut i = 0;
|
||||||
|
let mut per_iter: SmallVec<[std::time::Duration; 32]> = SmallVec::new();
|
||||||
|
while tuple_gt(step, opts.epsilon) && i < opts.max_iter {
|
||||||
|
let t0 = Instant::now();
|
||||||
|
step = self.iteration();
|
||||||
|
per_iter.push(t0.elapsed());
|
||||||
|
i += 1;
|
||||||
|
self.observer.on_iteration_end(i, step);
|
||||||
|
}
|
||||||
|
let converged = !tuple_gt(step, opts.epsilon);
|
||||||
|
let log_evidence = self.log_evidence(false, &[]);
|
||||||
|
self.observer.on_converged(i, step, converged);
|
||||||
|
Ok(ConvergenceReport {
|
||||||
|
iterations: i,
|
||||||
|
final_step: step,
|
||||||
|
log_evidence,
|
||||||
|
converged,
|
||||||
|
per_iteration_time: per_iter,
|
||||||
|
slices_skipped: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<D: Drift<i64>> History<i64, D> {
|
impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
|
||||||
pub fn add_events(
|
pub fn add_events(
|
||||||
&mut self,
|
&mut self,
|
||||||
composition: Vec<Vec<Vec<Index>>>,
|
composition: Vec<Vec<Vec<Index>>>,
|
||||||
@@ -1287,4 +1340,39 @@ mod tests {
|
|||||||
assert_ulps_eq!(lc[&a][1].1, lc[&a][0].1, epsilon = 1e-6);
|
assert_ulps_eq!(lc[&a][1].1, lc[&a][0].1, epsilon = 1e-6);
|
||||||
assert_ulps_eq!(lc[&b][1].1, lc[&a][0].1, epsilon = 1e-6);
|
assert_ulps_eq!(lc[&b][1].1, lc[&a][0].1, epsilon = 1e-6);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_converge_returns_report() {
|
||||||
|
use crate::ConvergenceOptions;
|
||||||
|
|
||||||
|
let mut index_map = crate::KeyTable::new();
|
||||||
|
let a = index_map.get_or_create("a");
|
||||||
|
let b = index_map.get_or_create("b");
|
||||||
|
let c = index_map.get_or_create("c");
|
||||||
|
let composition = vec![
|
||||||
|
vec![vec![a], vec![b]],
|
||||||
|
vec![vec![a], vec![c]],
|
||||||
|
vec![vec![b], vec![c]],
|
||||||
|
];
|
||||||
|
let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
|
||||||
|
let times: Vec<i64> = vec![1, 2, 3];
|
||||||
|
|
||||||
|
let mut h = History::builder()
|
||||||
|
.mu(0.0)
|
||||||
|
.sigma(2.0)
|
||||||
|
.beta(1.0)
|
||||||
|
.drift(ConstantDrift(0.0))
|
||||||
|
.convergence(ConvergenceOptions {
|
||||||
|
max_iter: 30,
|
||||||
|
epsilon: 1e-6,
|
||||||
|
})
|
||||||
|
.build();
|
||||||
|
h.add_events(composition, results, times, vec![]);
|
||||||
|
|
||||||
|
let report = h.converge().unwrap();
|
||||||
|
assert!(report.converged);
|
||||||
|
assert!(report.iterations > 0);
|
||||||
|
assert!(report.iterations < 30);
|
||||||
|
assert!(report.final_step.0 <= 1e-6);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ mod time;
|
|||||||
mod time_slice;
|
mod time_slice;
|
||||||
pub use time_slice::TimeSlice;
|
pub use time_slice::TimeSlice;
|
||||||
mod competitor;
|
mod competitor;
|
||||||
|
mod convergence;
|
||||||
pub mod drift;
|
pub mod drift;
|
||||||
mod error;
|
mod error;
|
||||||
mod event;
|
mod event;
|
||||||
@@ -26,6 +27,7 @@ pub(crate) mod schedule;
|
|||||||
pub mod storage;
|
pub mod storage;
|
||||||
|
|
||||||
pub use competitor::Competitor;
|
pub use competitor::Competitor;
|
||||||
|
pub use convergence::{ConvergenceOptions, ConvergenceReport};
|
||||||
pub use drift::{ConstantDrift, Drift};
|
pub use drift::{ConstantDrift, Drift};
|
||||||
pub use error::InferenceError;
|
pub use error::InferenceError;
|
||||||
pub use event::{Event, Member, Team};
|
pub use event::{Event, Member, Team};
|
||||||
|
|||||||
Reference in New Issue
Block a user