From ebccc7b454124a3327832340ce64f2315ab6bc89 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Fri, 24 Apr 2026 08:14:00 +0200 Subject: [PATCH] feat(factor): introduce Factor trait and BuiltinFactor enum Adds the trait that all factors implement and the enum dispatcher used by the schedule to drive heterogeneous factors without dynamic dispatch in the hot loop. The three built-in factors (TeamSum, RankDiff, Trunc) are stubbed out; concrete implementations follow in tasks 4-6. --- src/factor/mod.rs | 50 +++++++++++++++++++++++++++++++++++++++++ src/factor/rank_diff.rs | 14 ++++++++++++ src/factor/team_sum.rs | 16 +++++++++++++ src/factor/trunc.rs | 32 ++++++++++++++++++++++++++ 4 files changed, 112 insertions(+) create mode 100644 src/factor/rank_diff.rs create mode 100644 src/factor/team_sum.rs create mode 100644 src/factor/trunc.rs diff --git a/src/factor/mod.rs b/src/factor/mod.rs index 3580fa4..b0ce1b9 100644 --- a/src/factor/mod.rs +++ b/src/factor/mod.rs @@ -47,6 +47,56 @@ impl VarStore { } } +/// A factor in the EP graph. +/// +/// Factors hold their own outgoing messages and propagate them by reading +/// connected variable marginals from a `VarStore` and writing back updated +/// marginals. +pub(crate) trait Factor { + /// Update outgoing messages and write back to the var store. + /// + /// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this + /// propagation. Used by the `Schedule` to detect convergence. + fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64); + + /// Optional log-evidence contribution. Default 0.0 (no contribution). + fn log_evidence(&self, _vars: &VarStore) -> f64 { + 0.0 + } +} + +/// Enum dispatcher for the built-in factor types. +/// +/// Using an enum instead of `Box` keeps factor data inline and +/// avoids virtual-call overhead in the hot inference loop. +#[derive(Debug)] +pub(crate) enum BuiltinFactor { + TeamSum(team_sum::TeamSumFactor), + RankDiff(rank_diff::RankDiffFactor), + Trunc(trunc::TruncFactor), +} + +impl Factor for BuiltinFactor { + fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) { + match self { + Self::TeamSum(f) => f.propagate(vars), + Self::RankDiff(f) => f.propagate(vars), + Self::Trunc(f) => f.propagate(vars), + } + } + + fn log_evidence(&self, vars: &VarStore) -> f64 { + match self { + Self::Trunc(f) => f.log_evidence(vars), + _ => 0.0, + } + } +} + +pub(crate) mod rank_diff; +pub(crate) mod team_sum; +pub(crate) mod trunc; + #[cfg(test)] mod tests { use super::*; diff --git a/src/factor/rank_diff.rs b/src/factor/rank_diff.rs new file mode 100644 index 0000000..9ecf995 --- /dev/null +++ b/src/factor/rank_diff.rs @@ -0,0 +1,14 @@ +use crate::factor::{Factor, VarId, VarStore}; + +#[derive(Debug)] +pub(crate) struct RankDiffFactor { + pub(crate) team_a: VarId, + pub(crate) team_b: VarId, + pub(crate) diff: VarId, +} + +impl Factor for RankDiffFactor { + fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) { + unimplemented!("RankDiffFactor stub — implemented in Task 5") + } +} diff --git a/src/factor/team_sum.rs b/src/factor/team_sum.rs new file mode 100644 index 0000000..1619ce0 --- /dev/null +++ b/src/factor/team_sum.rs @@ -0,0 +1,16 @@ +use crate::{ + factor::{Factor, VarId, VarStore}, + gaussian::Gaussian, +}; + +#[derive(Debug)] +pub(crate) struct TeamSumFactor { + pub(crate) inputs: Vec<(Gaussian, f64)>, + pub(crate) out: VarId, +} + +impl Factor for TeamSumFactor { + fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) { + unimplemented!("TeamSumFactor stub — implemented in Task 4") + } +} diff --git a/src/factor/trunc.rs b/src/factor/trunc.rs new file mode 100644 index 0000000..f5b6dfe --- /dev/null +++ b/src/factor/trunc.rs @@ -0,0 +1,32 @@ +use crate::{ + N_INF, + factor::{Factor, VarId, VarStore}, + gaussian::Gaussian, +}; + +#[derive(Debug)] +pub(crate) struct TruncFactor { + pub(crate) diff: VarId, + pub(crate) margin: f64, + pub(crate) tie: bool, + pub(crate) msg: Gaussian, + pub(crate) evidence_cached: Option, +} + +impl TruncFactor { + pub(crate) fn new(diff: VarId, margin: f64, tie: bool) -> Self { + Self { + diff, + margin, + tie, + msg: N_INF, + evidence_cached: None, + } + } +} + +impl Factor for TruncFactor { + fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) { + unimplemented!("TruncFactor stub — implemented in Task 6") + } +}