feat(schedule): add Schedule trait and EpsilonOrMax impl
EpsilonOrMax mirrors today's Game::likelihoods loop: sweep forward then backward over iterating factors, capped at 10 iterations or step <= 1e-6. Setup factors (TeamSum) run exactly once before the loop begins. ScheduleReport is the only public surface from this module.
This commit is contained in:
126
src/schedule.rs
Normal file
126
src/schedule.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
//! Schedule trait and built-in implementations.
|
||||
//!
|
||||
//! A schedule drives factor propagation to convergence. The default
|
||||
//! `EpsilonOrMax` performs one TeamSum sweep (setup) then alternating
|
||||
//! forward/backward sweeps over the iterating factors until the max
|
||||
//! delta drops below epsilon or `max` iterations is reached.
|
||||
|
||||
use crate::factor::{BuiltinFactor, Factor, VarStore};
|
||||
|
||||
/// Result returned by a `Schedule::run` call.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ScheduleReport {
|
||||
pub iterations: usize,
|
||||
pub final_step: (f64, f64),
|
||||
pub converged: bool,
|
||||
}
|
||||
|
||||
/// Drives factor propagation to convergence.
|
||||
pub(crate) trait Schedule {
|
||||
fn run(&self, factors: &mut [BuiltinFactor], vars: &mut VarStore) -> ScheduleReport;
|
||||
}
|
||||
|
||||
/// Default schedule: sweep forward then backward until step ≤ eps or iter == max.
|
||||
///
|
||||
/// Matches the existing `Game::likelihoods` loop bit-for-bit when given the
|
||||
/// same factor layout (TeamSums first, then alternating RankDiff/Trunc pairs).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(crate) struct EpsilonOrMax {
|
||||
pub eps: f64,
|
||||
pub max: usize,
|
||||
}
|
||||
|
||||
impl Default for EpsilonOrMax {
|
||||
fn default() -> Self {
|
||||
// Matches today's hard-coded tolerance and iteration cap.
|
||||
Self { eps: 1e-6, max: 10 }
|
||||
}
|
||||
}
|
||||
|
||||
impl Schedule for EpsilonOrMax {
|
||||
fn run(&self, factors: &mut [BuiltinFactor], vars: &mut VarStore) -> ScheduleReport {
|
||||
// Partition: leading run of TeamSum factors run exactly once (setup).
|
||||
let n_setup = factors
|
||||
.iter()
|
||||
.position(|f| !matches!(f, BuiltinFactor::TeamSum(_)))
|
||||
.unwrap_or(factors.len());
|
||||
|
||||
for f in factors[..n_setup].iter_mut() {
|
||||
f.propagate(vars);
|
||||
}
|
||||
|
||||
let mut iterations = 0;
|
||||
let mut final_step = (f64::INFINITY, f64::INFINITY);
|
||||
let mut converged = false;
|
||||
|
||||
if n_setup < factors.len() {
|
||||
for _ in 0..self.max {
|
||||
let mut step = (0.0_f64, 0.0_f64);
|
||||
|
||||
// Forward sweep over iterating factors.
|
||||
for f in factors[n_setup..].iter_mut() {
|
||||
let d = f.propagate(vars);
|
||||
step.0 = step.0.max(d.0);
|
||||
step.1 = step.1.max(d.1);
|
||||
}
|
||||
|
||||
// Backward sweep.
|
||||
for f in factors[n_setup..].iter_mut().rev() {
|
||||
let d = f.propagate(vars);
|
||||
step.0 = step.0.max(d.0);
|
||||
step.1 = step.1.max(d.1);
|
||||
}
|
||||
|
||||
iterations += 1;
|
||||
final_step = step;
|
||||
|
||||
if step.0 <= self.eps && step.1 <= self.eps {
|
||||
converged = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ScheduleReport {
|
||||
iterations,
|
||||
final_step,
|
||||
converged,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{N_INF, factor::team_sum::TeamSumFactor, gaussian::Gaussian};
|
||||
|
||||
#[test]
|
||||
fn schedule_runs_setup_factors_once() {
|
||||
// Single TeamSum factor; schedule should propagate it exactly once and report 0 iterations.
|
||||
let mut vars = VarStore::new();
|
||||
let out = vars.alloc(N_INF);
|
||||
let mut factors = vec![BuiltinFactor::TeamSum(TeamSumFactor {
|
||||
inputs: vec![(Gaussian::from_ms(5.0, 1.0), 1.0)],
|
||||
out,
|
||||
})];
|
||||
let schedule = EpsilonOrMax::default();
|
||||
let report = schedule.run(&mut factors, &mut vars);
|
||||
assert_eq!(report.iterations, 0);
|
||||
// The team-perf var should hold the sum.
|
||||
let result = vars.get(out);
|
||||
assert!((result.mu() - 5.0).abs() < 1e-12);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn report_marks_converged_when_no_iterating_factors() {
|
||||
// No iterating factors → 0 iterations, converged stays false (loop never ran).
|
||||
let mut vars = VarStore::new();
|
||||
let out = vars.alloc(N_INF);
|
||||
let mut factors = vec![BuiltinFactor::TeamSum(TeamSumFactor {
|
||||
inputs: vec![(Gaussian::from_ms(0.0, 1.0), 1.0)],
|
||||
out,
|
||||
})];
|
||||
let report = EpsilonOrMax::default().run(&mut factors, &mut vars);
|
||||
assert_eq!(report.iterations, 0);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user