T0 + T1 + T2: engine redesign through new API surface #1

Merged
logaritmisk merged 45 commits from t2-new-api-surface into main 2026-04-24 11:20:04 +00:00
7 changed files with 45 additions and 35 deletions
Showing only changes of commit fe6f028127 - Show all commits

View File

@@ -7,7 +7,7 @@ use crate::gaussian::Gaussian;
/// Variables hold the current Gaussian marginal and are owned by exactly one /// Variables hold the current Gaussian marginal and are owned by exactly one
/// `VarStore`. `VarId` is meaningful only within its owning store. /// `VarStore`. `VarId` is meaningful only within its owning store.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct VarId(pub(crate) u32); pub struct VarId(pub u32);
/// Flat storage of variable marginals. /// Flat storage of variable marginals.
/// ///
@@ -15,36 +15,38 @@ pub(crate) struct VarId(pub(crate) u32);
/// reused across `Game::new` calls (it lives in the `ScratchArena`); call /// reused across `Game::new` calls (it lives in the `ScratchArena`); call
/// `clear()` before reuse. /// `clear()` before reuse.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(crate) struct VarStore { pub struct VarStore {
pub(crate) marginals: Vec<Gaussian>, pub(crate) marginals: Vec<Gaussian>,
} }
impl VarStore { impl VarStore {
#[allow(dead_code)] pub fn new() -> Self {
pub(crate) fn new() -> Self {
Self::default() Self::default()
} }
pub(crate) fn clear(&mut self) { pub fn clear(&mut self) {
self.marginals.clear(); self.marginals.clear();
} }
#[allow(dead_code)] pub fn len(&self) -> usize {
pub(crate) fn len(&self) -> usize {
self.marginals.len() self.marginals.len()
} }
pub(crate) fn alloc(&mut self, init: Gaussian) -> VarId { pub fn is_empty(&self) -> bool {
self.marginals.is_empty()
}
pub fn alloc(&mut self, init: Gaussian) -> VarId {
let id = VarId(self.marginals.len() as u32); let id = VarId(self.marginals.len() as u32);
self.marginals.push(init); self.marginals.push(init);
id id
} }
pub(crate) fn get(&self, id: VarId) -> Gaussian { pub fn get(&self, id: VarId) -> Gaussian {
self.marginals[id.0 as usize] self.marginals[id.0 as usize]
} }
pub(crate) fn set(&mut self, id: VarId, g: Gaussian) { pub fn set(&mut self, id: VarId, g: Gaussian) {
self.marginals[id.0 as usize] = g; self.marginals[id.0 as usize] = g;
} }
} }
@@ -54,7 +56,7 @@ impl VarStore {
/// Factors hold their own outgoing messages and propagate them by reading /// Factors hold their own outgoing messages and propagate them by reading
/// connected variable marginals from a `VarStore` and writing back updated /// connected variable marginals from a `VarStore` and writing back updated
/// marginals. /// marginals.
pub(crate) trait Factor { pub trait Factor {
/// Update outgoing messages and write back to the var store. /// Update outgoing messages and write back to the var store.
/// ///
/// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this /// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this
@@ -62,7 +64,6 @@ pub(crate) trait Factor {
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64); fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64);
/// Optional log-evidence contribution. Default 0.0 (no contribution). /// Optional log-evidence contribution. Default 0.0 (no contribution).
#[allow(dead_code)]
fn log_evidence(&self, _vars: &VarStore) -> f64 { fn log_evidence(&self, _vars: &VarStore) -> f64 {
0.0 0.0
} }
@@ -73,8 +74,7 @@ pub(crate) trait Factor {
/// Using an enum instead of `Box<dyn Factor>` keeps factor data inline and /// Using an enum instead of `Box<dyn Factor>` keeps factor data inline and
/// avoids virtual-call overhead in the hot inference loop. /// avoids virtual-call overhead in the hot inference loop.
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] pub enum BuiltinFactor {
pub(crate) enum BuiltinFactor {
TeamSum(team_sum::TeamSumFactor), TeamSum(team_sum::TeamSumFactor),
RankDiff(rank_diff::RankDiffFactor), RankDiff(rank_diff::RankDiffFactor),
Trunc(trunc::TruncFactor), Trunc(trunc::TruncFactor),
@@ -97,9 +97,9 @@ impl Factor for BuiltinFactor {
} }
} }
pub(crate) mod rank_diff; pub mod rank_diff;
pub(crate) mod team_sum; pub mod team_sum;
pub(crate) mod trunc; pub mod trunc;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

View File

@@ -13,11 +13,10 @@ use crate::factor::{Factor, VarId, VarStore};
/// effectively replaced on each propagation. The TruncFactor on the same diff /// effectively replaced on each propagation. The TruncFactor on the same diff
/// var holds the EP-divide message that produces the cavity. /// var holds the EP-divide message that produces the cavity.
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] pub struct RankDiffFactor {
pub(crate) struct RankDiffFactor { pub team_a: VarId,
pub(crate) team_a: VarId, pub team_b: VarId,
pub(crate) team_b: VarId, pub diff: VarId,
pub(crate) diff: VarId,
} }
impl Factor for RankDiffFactor { impl Factor for RankDiffFactor {

View File

@@ -10,10 +10,9 @@ use crate::{
/// already with beta² noise added via `Rating::performance()`). The factor /// already with beta² noise added via `Rating::performance()`). The factor
/// runs once per game and writes the weighted sum to the output var. /// runs once per game and writes the weighted sum to the output var.
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] pub struct TeamSumFactor {
pub(crate) struct TeamSumFactor { pub inputs: Vec<(Gaussian, f64)>,
pub(crate) inputs: Vec<(Gaussian, f64)>, pub out: VarId,
pub(crate) out: VarId,
} }
impl Factor for TeamSumFactor { impl Factor for TeamSumFactor {

View File

@@ -11,10 +11,10 @@ use crate::{
/// Stores its outgoing message to the diff variable so the cavity computation /// Stores its outgoing message to the diff variable so the cavity computation
/// produces the correct EP message on each propagation. /// produces the correct EP message on each propagation.
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct TruncFactor { pub struct TruncFactor {
pub(crate) diff: VarId, pub diff: VarId,
pub(crate) margin: f64, pub margin: f64,
pub(crate) tie: bool, pub tie: bool,
/// Outgoing message to the diff variable (initial: N_INF, the EP identity). /// Outgoing message to the diff variable (initial: N_INF, the EP identity).
pub(crate) msg: Gaussian, pub(crate) msg: Gaussian,
/// Cached evidence (linear, not log) computed from the cavity on first propagation. /// Cached evidence (linear, not log) computed from the cavity on first propagation.
@@ -22,7 +22,7 @@ pub(crate) struct TruncFactor {
} }
impl TruncFactor { impl TruncFactor {
pub(crate) fn new(diff: VarId, margin: f64, tie: bool) -> Self { pub fn new(diff: VarId, margin: f64, tie: bool) -> Self {
Self { Self {
diff, diff,
margin, margin,

13
src/factors.rs Normal file
View File

@@ -0,0 +1,13 @@
//! Factor-graph public API.
//!
//! Power users can construct custom factor graphs via `Game::custom` (T2
//! minimal; full ergonomics in T4) and drive them with custom `Schedule`
//! implementations.
pub use crate::{
factor::{
BuiltinFactor, Factor, VarId, VarStore, rank_diff::RankDiffFactor, team_sum::TeamSumFactor,
trunc::TruncFactor,
},
schedule::{EpsilonOrMax, Schedule, ScheduleReport},
};

View File

@@ -16,6 +16,7 @@ mod error;
mod event; mod event;
mod event_builder; mod event_builder;
pub(crate) mod factor; pub(crate) mod factor;
pub mod factors;
mod game; mod game;
pub mod gaussian; pub mod gaussian;
mod history; mod history;

View File

@@ -16,8 +16,7 @@ pub struct ScheduleReport {
} }
/// Drives factor propagation to convergence. /// Drives factor propagation to convergence.
#[allow(dead_code)] pub trait Schedule {
pub(crate) trait Schedule {
fn run(&self, factors: &mut [BuiltinFactor], vars: &mut VarStore) -> ScheduleReport; fn run(&self, factors: &mut [BuiltinFactor], vars: &mut VarStore) -> ScheduleReport;
} }
@@ -26,8 +25,7 @@ pub(crate) trait Schedule {
/// Matches the existing `Game::likelihoods` loop bit-for-bit when given the /// Matches the existing `Game::likelihoods` loop bit-for-bit when given the
/// same factor layout (TeamSums first, then alternating RankDiff/Trunc pairs). /// same factor layout (TeamSums first, then alternating RankDiff/Trunc pairs).
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
#[allow(dead_code)] pub struct EpsilonOrMax {
pub(crate) struct EpsilonOrMax {
pub eps: f64, pub eps: f64,
pub max: usize, pub max: usize,
} }