diff --git a/src/factor/mod.rs b/src/factor/mod.rs new file mode 100644 index 0000000..3580fa4 --- /dev/null +++ b/src/factor/mod.rs @@ -0,0 +1,94 @@ +//! Factor graph machinery for within-game inference. + +use crate::gaussian::Gaussian; + +/// Identifier for a variable in a `VarStore`. +/// +/// Variables hold the current Gaussian marginal and are owned by exactly one +/// `VarStore`. `VarId` is meaningful only within its owning store. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub(crate) struct VarId(pub(crate) u32); + +/// Flat storage of variable marginals. +/// +/// Variables are allocated by `alloc()` and accessed by `VarId`. The store is +/// reused across `Game::new` calls (it lives in the `ScratchArena`); call +/// `clear()` before reuse. +#[derive(Debug, Default)] +pub(crate) struct VarStore { + pub(crate) marginals: Vec, +} + +impl VarStore { + pub(crate) fn new() -> Self { + Self::default() + } + + pub(crate) fn clear(&mut self) { + self.marginals.clear(); + } + + pub(crate) fn len(&self) -> usize { + self.marginals.len() + } + + pub(crate) fn alloc(&mut self, init: Gaussian) -> VarId { + let id = VarId(self.marginals.len() as u32); + self.marginals.push(init); + id + } + + pub(crate) fn get(&self, id: VarId) -> Gaussian { + self.marginals[id.0 as usize] + } + + pub(crate) fn set(&mut self, id: VarId, g: Gaussian) { + self.marginals[id.0 as usize] = g; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::N_INF; + + #[test] + fn alloc_assigns_sequential_ids() { + let mut store = VarStore::new(); + let a = store.alloc(N_INF); + let b = store.alloc(N_INF); + let c = store.alloc(N_INF); + assert_eq!(a, VarId(0)); + assert_eq!(b, VarId(1)); + assert_eq!(c, VarId(2)); + assert_eq!(store.len(), 3); + } + + #[test] + fn get_returns_initial_value() { + let mut store = VarStore::new(); + let g = Gaussian::from_ms(2.5, 1.0); + let id = store.alloc(g); + assert_eq!(store.get(id), g); + } + + #[test] + fn set_updates_value() { + let mut store = VarStore::new(); + let id = store.alloc(N_INF); + let new = Gaussian::from_ms(3.0, 0.5); + store.set(id, new); + assert_eq!(store.get(id), new); + } + + #[test] + fn clear_resets_length_keeping_capacity() { + let mut store = VarStore::new(); + store.alloc(N_INF); + store.alloc(N_INF); + let cap = store.marginals.capacity(); + store.clear(); + assert_eq!(store.len(), 0); + assert_eq!(store.marginals.capacity(), cap); + } +} diff --git a/src/lib.rs b/src/lib.rs index ca0ea06..fd1f27c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ pub(crate) mod arena; pub mod batch; pub mod drift; mod error; +pub(crate) mod factor; mod game; pub mod gaussian; mod history;