//! Factor graph machinery for within-game inference. use crate::gaussian::Gaussian; /// Identifier for a variable in a `VarStore`. /// /// Variables hold the current Gaussian marginal and are owned by exactly one /// `VarStore`. `VarId` is meaningful only within its owning store. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct VarId(pub u32); /// Flat storage of variable marginals. /// /// Variables are allocated by `alloc()` and accessed by `VarId`. The store is /// reused across `Game::ranked_with_arena` calls (it lives in the `ScratchArena`); call /// `clear()` before reuse. #[derive(Debug, Default)] pub struct VarStore { pub(crate) marginals: Vec, } impl VarStore { pub fn new() -> Self { Self::default() } pub fn clear(&mut self) { self.marginals.clear(); } pub fn len(&self) -> usize { self.marginals.len() } pub fn is_empty(&self) -> bool { self.marginals.is_empty() } pub fn alloc(&mut self, init: Gaussian) -> VarId { let id = VarId(self.marginals.len() as u32); self.marginals.push(init); id } pub fn get(&self, id: VarId) -> Gaussian { self.marginals[id.0 as usize] } pub fn set(&mut self, id: VarId, g: Gaussian) { self.marginals[id.0 as usize] = g; } } /// A factor in the EP graph. /// /// Factors hold their own outgoing messages and propagate them by reading /// connected variable marginals from a `VarStore` and writing back updated /// marginals. pub trait Factor: Send + Sync { /// Update outgoing messages and write back to the var store. /// /// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this /// propagation. Used by the `Schedule` to detect convergence. fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64); /// Optional log-evidence contribution. Default 0.0 (no contribution). fn log_evidence(&self, _vars: &VarStore) -> f64 { 0.0 } } /// Enum dispatcher for the built-in factor types. /// /// Using an enum instead of `Box` keeps factor data inline and /// avoids virtual-call overhead in the hot inference loop. 
#[derive(Debug)]
pub enum BuiltinFactor {
    /// Wraps a [`team_sum::TeamSumFactor`].
    TeamSum(team_sum::TeamSumFactor),
    /// Wraps a [`rank_diff::RankDiffFactor`].
    RankDiff(rank_diff::RankDiffFactor),
    /// Wraps a [`trunc::TruncFactor`].
    Trunc(trunc::TruncFactor),
    /// Wraps a [`margin::MarginFactor`].
    Margin(margin::MarginFactor),
}

impl Factor for BuiltinFactor {
    /// Forwards to the wrapped factor's `propagate`.
    fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
        match self {
            Self::TeamSum(f) => f.propagate(vars),
            Self::RankDiff(f) => f.propagate(vars),
            Self::Trunc(f) => f.propagate(vars),
            Self::Margin(f) => f.propagate(vars),
        }
    }

    /// Forwards to the wrapped factor for `Trunc` and `Margin`; `TeamSum` and
    /// `RankDiff` contribute 0.0.
    ///
    /// Every variant is listed explicitly (no `_` arm), so adding a new variant
    /// is a compile error here instead of silently contributing zero evidence.
    fn log_evidence(&self, vars: &VarStore) -> f64 {
        match self {
            Self::Trunc(f) => f.log_evidence(vars),
            Self::Margin(f) => f.log_evidence(vars),
            Self::TeamSum(_) | Self::RankDiff(_) => 0.0,
        }
    }
}

pub mod margin;
pub mod rank_diff;
pub mod team_sum;
pub mod trunc;

#[cfg(test)]
mod tests {
    use super::*;
    use crate::N_INF;

    #[test]
    fn alloc_assigns_sequential_ids() {
        let mut store = VarStore::new();
        let a = store.alloc(N_INF);
        let b = store.alloc(N_INF);
        let c = store.alloc(N_INF);
        assert_eq!(a, VarId(0));
        assert_eq!(b, VarId(1));
        assert_eq!(c, VarId(2));
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn get_returns_initial_value() {
        let mut store = VarStore::new();
        let g = Gaussian::from_ms(2.5, 1.0);
        let id = store.alloc(g);
        assert_eq!(store.get(id), g);
    }

    #[test]
    fn set_updates_value() {
        let mut store = VarStore::new();
        let id = store.alloc(N_INF);
        let new = Gaussian::from_ms(3.0, 0.5);
        store.set(id, new);
        assert_eq!(store.get(id), new);
    }

    #[test]
    fn clear_resets_length_keeping_capacity() {
        let mut store = VarStore::new();
        store.alloc(N_INF);
        store.alloc(N_INF);
        let cap = store.marginals.capacity();
        store.clear();
        assert_eq!(store.len(), 0);
        assert_eq!(store.marginals.capacity(), cap);
    }

    #[test]
    fn builtin_factor_dispatches_to_margin() {
        use super::margin::MarginFactor;
        let mut vars = VarStore::new();
        let diff = vars.alloc(Gaussian::from_ms(0.0, 6.0));
        let mut f = BuiltinFactor::Margin(MarginFactor::new(diff, 5.0, 1.0));
        f.propagate(&mut vars);
        let result = vars.get(diff);
        assert!((result.mu() - 4.864864864864865).abs() < 1e-12);
        let logz = f.log_evidence(&vars);
        assert!((logz - (-3.062235327364623)).abs() < 1e-10);
    }
}