T0 + T1 + T2: engine redesign through new API surface #1
94
src/factor/mod.rs
Normal file
94
src/factor/mod.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
//! Factor graph machinery for within-game inference.
|
||||
|
||||
use crate::gaussian::Gaussian;
|
||||
|
||||
/// Identifier for a variable in a `VarStore`.
|
||||
///
|
||||
/// Variables hold the current Gaussian marginal and are owned by exactly one
|
||||
/// `VarStore`. `VarId` is meaningful only within its owning store.
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub(crate) struct VarId(pub(crate) u32);
|
||||
|
||||
/// Flat storage of variable marginals.
|
||||
///
|
||||
/// Variables are allocated by `alloc()` and accessed by `VarId`. The store is
|
||||
/// reused across `Game::new` calls (it lives in the `ScratchArena`); call
|
||||
/// `clear()` before reuse.
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct VarStore {
|
||||
pub(crate) marginals: Vec<Gaussian>,
|
||||
}
|
||||
|
||||
impl VarStore {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub(crate) fn clear(&mut self) {
|
||||
self.marginals.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.marginals.len()
|
||||
}
|
||||
|
||||
pub(crate) fn alloc(&mut self, init: Gaussian) -> VarId {
|
||||
let id = VarId(self.marginals.len() as u32);
|
||||
self.marginals.push(init);
|
||||
id
|
||||
}
|
||||
|
||||
pub(crate) fn get(&self, id: VarId) -> Gaussian {
|
||||
self.marginals[id.0 as usize]
|
||||
}
|
||||
|
||||
pub(crate) fn set(&mut self, id: VarId, g: Gaussian) {
|
||||
self.marginals[id.0 as usize] = g;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::N_INF;
|
||||
|
||||
#[test]
|
||||
fn alloc_assigns_sequential_ids() {
|
||||
let mut store = VarStore::new();
|
||||
let a = store.alloc(N_INF);
|
||||
let b = store.alloc(N_INF);
|
||||
let c = store.alloc(N_INF);
|
||||
assert_eq!(a, VarId(0));
|
||||
assert_eq!(b, VarId(1));
|
||||
assert_eq!(c, VarId(2));
|
||||
assert_eq!(store.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_returns_initial_value() {
|
||||
let mut store = VarStore::new();
|
||||
let g = Gaussian::from_ms(2.5, 1.0);
|
||||
let id = store.alloc(g);
|
||||
assert_eq!(store.get(id), g);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_updates_value() {
|
||||
let mut store = VarStore::new();
|
||||
let id = store.alloc(N_INF);
|
||||
let new = Gaussian::from_ms(3.0, 0.5);
|
||||
store.set(id, new);
|
||||
assert_eq!(store.get(id), new);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clear_resets_length_keeping_capacity() {
|
||||
let mut store = VarStore::new();
|
||||
store.alloc(N_INF);
|
||||
store.alloc(N_INF);
|
||||
let cap = store.marginals.capacity();
|
||||
store.clear();
|
||||
assert_eq!(store.len(), 0);
|
||||
assert_eq!(store.marginals.capacity(), cap);
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ pub(crate) mod arena;
|
||||
pub mod batch;
|
||||
pub mod drift;
|
||||
mod error;
|
||||
pub(crate) mod factor;
|
||||
mod game;
|
||||
pub mod gaussian;
|
||||
mod history;
|
||||
|
||||
Reference in New Issue
Block a user