diff --git a/src/factor/mod.rs b/src/factor/mod.rs index 3580fa4..b0ce1b9 100644 --- a/src/factor/mod.rs +++ b/src/factor/mod.rs @@ -47,6 +47,56 @@ impl VarStore { } } +/// A factor in the EP graph. +/// +/// Factors hold their own outgoing messages and propagate them by reading +/// connected variable marginals from a `VarStore` and writing back updated +/// marginals. +pub(crate) trait Factor { + /// Update outgoing messages and write back to the var store. + /// + /// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this + /// propagation. Used by the `Schedule` to detect convergence. + fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64); + + /// Optional log-evidence contribution. Default 0.0 (no contribution). + fn log_evidence(&self, _vars: &VarStore) -> f64 { + 0.0 + } +} + +/// Enum dispatcher for the built-in factor types. +/// +/// Using an enum instead of `Box` keeps factor data inline and +/// avoids virtual-call overhead in the hot inference loop. +#[derive(Debug)] +pub(crate) enum BuiltinFactor { + TeamSum(team_sum::TeamSumFactor), + RankDiff(rank_diff::RankDiffFactor), + Trunc(trunc::TruncFactor), +} + +impl Factor for BuiltinFactor { + fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) { + match self { + Self::TeamSum(f) => f.propagate(vars), + Self::RankDiff(f) => f.propagate(vars), + Self::Trunc(f) => f.propagate(vars), + } + } + + fn log_evidence(&self, vars: &VarStore) -> f64 { + match self { + Self::Trunc(f) => f.log_evidence(vars), + _ => 0.0, + } + } +} + +pub(crate) mod rank_diff; +pub(crate) mod team_sum; +pub(crate) mod trunc; + #[cfg(test)] mod tests { use super::*; diff --git a/src/factor/rank_diff.rs b/src/factor/rank_diff.rs new file mode 100644 index 0000000..9ecf995 --- /dev/null +++ b/src/factor/rank_diff.rs @@ -0,0 +1,14 @@ +use crate::factor::{Factor, VarId, VarStore}; + +#[derive(Debug)] +pub(crate) struct RankDiffFactor { + pub(crate) team_a: VarId, + pub(crate) team_b: VarId, + pub(crate) diff: VarId, +} + +impl Factor for RankDiffFactor { + fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) { + unimplemented!("RankDiffFactor stub — implemented in Task 5") + } +} diff --git a/src/factor/team_sum.rs b/src/factor/team_sum.rs new file mode 100644 index 0000000..1619ce0 --- /dev/null +++ b/src/factor/team_sum.rs @@ -0,0 +1,16 @@ +use crate::{ + factor::{Factor, VarId, VarStore}, + gaussian::Gaussian, +}; + +#[derive(Debug)] +pub(crate) struct TeamSumFactor { + pub(crate) inputs: Vec<(Gaussian, f64)>, + pub(crate) out: VarId, +} + +impl Factor for TeamSumFactor { + fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) { + unimplemented!("TeamSumFactor stub — implemented in Task 4") + } +} diff --git a/src/factor/trunc.rs b/src/factor/trunc.rs new file mode 100644 index 0000000..f5b6dfe --- /dev/null +++ b/src/factor/trunc.rs @@ -0,0 +1,32 @@ +use crate::{ + N_INF, + factor::{Factor, VarId, VarStore}, + gaussian::Gaussian, +}; + +#[derive(Debug)] +pub(crate) struct TruncFactor { + pub(crate) diff: VarId, + pub(crate) margin: f64, + pub(crate) tie: bool, + pub(crate) msg: Gaussian, + pub(crate) evidence_cached: Option, +} + +impl TruncFactor { + pub(crate) fn new(diff: VarId, margin: f64, tie: bool) -> Self { + Self { + diff, + margin, + tie, + msg: N_INF, + evidence_cached: None, + } + } +} + +impl Factor for TruncFactor { + fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) { + unimplemented!("TruncFactor stub — implemented in Task 6") + } +}