T0 + T1 + T2: engine redesign through new API surface #1
@@ -20,6 +20,7 @@ mod history;
|
|||||||
mod matrix;
|
mod matrix;
|
||||||
mod message;
|
mod message;
|
||||||
pub mod player;
|
pub mod player;
|
||||||
|
pub(crate) mod schedule;
|
||||||
pub mod storage;
|
pub mod storage;
|
||||||
|
|
||||||
pub use drift::{ConstantDrift, Drift};
|
pub use drift::{ConstantDrift, Drift};
|
||||||
@@ -30,6 +31,7 @@ pub use history::History;
|
|||||||
use matrix::Matrix;
|
use matrix::Matrix;
|
||||||
use message::DiffMessage;
|
use message::DiffMessage;
|
||||||
pub use player::Player;
|
pub use player::Player;
|
||||||
|
pub use schedule::ScheduleReport;
|
||||||
|
|
||||||
pub const BETA: f64 = 1.0;
|
pub const BETA: f64 = 1.0;
|
||||||
pub const MU: f64 = 0.0;
|
pub const MU: f64 = 0.0;
|
||||||
|
|||||||
126
src/schedule.rs
Normal file
126
src/schedule.rs
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
//! Schedule trait and built-in implementations.
|
||||||
|
//!
|
||||||
|
//! A schedule drives factor propagation to convergence. The default
|
||||||
|
//! `EpsilonOrMax` performs one TeamSum sweep (setup) then alternating
|
||||||
|
//! forward/backward sweeps over the iterating factors until the max
|
||||||
|
//! delta drops below epsilon or `max` iterations is reached.
|
||||||
|
|
||||||
|
use crate::factor::{BuiltinFactor, Factor, VarStore};
|
||||||
|
|
||||||
|
/// Result returned by a `Schedule::run` call.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ScheduleReport {
|
||||||
|
pub iterations: usize,
|
||||||
|
pub final_step: (f64, f64),
|
||||||
|
pub converged: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Drives factor propagation to convergence.
|
||||||
|
pub(crate) trait Schedule {
|
||||||
|
fn run(&self, factors: &mut [BuiltinFactor], vars: &mut VarStore) -> ScheduleReport;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Default schedule: sweep forward then backward until step ≤ eps or iter == max.
|
||||||
|
///
|
||||||
|
/// Matches the existing `Game::likelihoods` loop bit-for-bit when given the
|
||||||
|
/// same factor layout (TeamSums first, then alternating RankDiff/Trunc pairs).
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub(crate) struct EpsilonOrMax {
|
||||||
|
pub eps: f64,
|
||||||
|
pub max: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for EpsilonOrMax {
|
||||||
|
fn default() -> Self {
|
||||||
|
// Matches today's hard-coded tolerance and iteration cap.
|
||||||
|
Self { eps: 1e-6, max: 10 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Schedule for EpsilonOrMax {
|
||||||
|
fn run(&self, factors: &mut [BuiltinFactor], vars: &mut VarStore) -> ScheduleReport {
|
||||||
|
// Partition: leading run of TeamSum factors run exactly once (setup).
|
||||||
|
let n_setup = factors
|
||||||
|
.iter()
|
||||||
|
.position(|f| !matches!(f, BuiltinFactor::TeamSum(_)))
|
||||||
|
.unwrap_or(factors.len());
|
||||||
|
|
||||||
|
for f in factors[..n_setup].iter_mut() {
|
||||||
|
f.propagate(vars);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut iterations = 0;
|
||||||
|
let mut final_step = (f64::INFINITY, f64::INFINITY);
|
||||||
|
let mut converged = false;
|
||||||
|
|
||||||
|
if n_setup < factors.len() {
|
||||||
|
for _ in 0..self.max {
|
||||||
|
let mut step = (0.0_f64, 0.0_f64);
|
||||||
|
|
||||||
|
// Forward sweep over iterating factors.
|
||||||
|
for f in factors[n_setup..].iter_mut() {
|
||||||
|
let d = f.propagate(vars);
|
||||||
|
step.0 = step.0.max(d.0);
|
||||||
|
step.1 = step.1.max(d.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backward sweep.
|
||||||
|
for f in factors[n_setup..].iter_mut().rev() {
|
||||||
|
let d = f.propagate(vars);
|
||||||
|
step.0 = step.0.max(d.0);
|
||||||
|
step.1 = step.1.max(d.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
iterations += 1;
|
||||||
|
final_step = step;
|
||||||
|
|
||||||
|
if step.0 <= self.eps && step.1 <= self.eps {
|
||||||
|
converged = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ScheduleReport {
|
||||||
|
iterations,
|
||||||
|
final_step,
|
||||||
|
converged,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{N_INF, factor::team_sum::TeamSumFactor, gaussian::Gaussian};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn schedule_runs_setup_factors_once() {
|
||||||
|
// Single TeamSum factor; schedule should propagate it exactly once and report 0 iterations.
|
||||||
|
let mut vars = VarStore::new();
|
||||||
|
let out = vars.alloc(N_INF);
|
||||||
|
let mut factors = vec![BuiltinFactor::TeamSum(TeamSumFactor {
|
||||||
|
inputs: vec![(Gaussian::from_ms(5.0, 1.0), 1.0)],
|
||||||
|
out,
|
||||||
|
})];
|
||||||
|
let schedule = EpsilonOrMax::default();
|
||||||
|
let report = schedule.run(&mut factors, &mut vars);
|
||||||
|
assert_eq!(report.iterations, 0);
|
||||||
|
// The team-perf var should hold the sum.
|
||||||
|
let result = vars.get(out);
|
||||||
|
assert!((result.mu() - 5.0).abs() < 1e-12);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn report_marks_converged_when_no_iterating_factors() {
|
||||||
|
// No iterating factors → 0 iterations, converged stays false (loop never ran).
|
||||||
|
let mut vars = VarStore::new();
|
||||||
|
let out = vars.alloc(N_INF);
|
||||||
|
let mut factors = vec![BuiltinFactor::TeamSum(TeamSumFactor {
|
||||||
|
inputs: vec![(Gaussian::from_ms(0.0, 1.0), 1.0)],
|
||||||
|
out,
|
||||||
|
})];
|
||||||
|
let report = EpsilonOrMax::default().run(&mut factors, &mut vars);
|
||||||
|
assert_eq!(report.iterations, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user