T0 + T1 + T2: engine redesign through new API surface #1
@@ -7,7 +7,7 @@ use crate::gaussian::Gaussian;
|
|||||||
/// Variables hold the current Gaussian marginal and are owned by exactly one
|
/// Variables hold the current Gaussian marginal and are owned by exactly one
|
||||||
/// `VarStore`. `VarId` is meaningful only within its owning store.
|
/// `VarStore`. `VarId` is meaningful only within its owning store.
|
||||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
|
||||||
pub(crate) struct VarId(pub(crate) u32);
|
pub struct VarId(pub u32);
|
||||||
|
|
||||||
/// Flat storage of variable marginals.
|
/// Flat storage of variable marginals.
|
||||||
///
|
///
|
||||||
@@ -15,36 +15,38 @@ pub(crate) struct VarId(pub(crate) u32);
|
|||||||
/// reused across `Game::new` calls (it lives in the `ScratchArena`); call
|
/// reused across `Game::new` calls (it lives in the `ScratchArena`); call
|
||||||
/// `clear()` before reuse.
|
/// `clear()` before reuse.
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub(crate) struct VarStore {
|
pub struct VarStore {
|
||||||
pub(crate) marginals: Vec<Gaussian>,
|
pub(crate) marginals: Vec<Gaussian>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl VarStore {
|
impl VarStore {
|
||||||
#[allow(dead_code)]
|
pub fn new() -> Self {
|
||||||
pub(crate) fn new() -> Self {
|
|
||||||
Self::default()
|
Self::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn clear(&mut self) {
|
pub fn clear(&mut self) {
|
||||||
self.marginals.clear();
|
self.marginals.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
pub fn len(&self) -> usize {
|
||||||
pub(crate) fn len(&self) -> usize {
|
|
||||||
self.marginals.len()
|
self.marginals.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn alloc(&mut self, init: Gaussian) -> VarId {
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.marginals.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn alloc(&mut self, init: Gaussian) -> VarId {
|
||||||
let id = VarId(self.marginals.len() as u32);
|
let id = VarId(self.marginals.len() as u32);
|
||||||
self.marginals.push(init);
|
self.marginals.push(init);
|
||||||
id
|
id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn get(&self, id: VarId) -> Gaussian {
|
pub fn get(&self, id: VarId) -> Gaussian {
|
||||||
self.marginals[id.0 as usize]
|
self.marginals[id.0 as usize]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn set(&mut self, id: VarId, g: Gaussian) {
|
pub fn set(&mut self, id: VarId, g: Gaussian) {
|
||||||
self.marginals[id.0 as usize] = g;
|
self.marginals[id.0 as usize] = g;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -54,7 +56,7 @@ impl VarStore {
|
|||||||
/// Factors hold their own outgoing messages and propagate them by reading
|
/// Factors hold their own outgoing messages and propagate them by reading
|
||||||
/// connected variable marginals from a `VarStore` and writing back updated
|
/// connected variable marginals from a `VarStore` and writing back updated
|
||||||
/// marginals.
|
/// marginals.
|
||||||
pub(crate) trait Factor {
|
pub trait Factor {
|
||||||
/// Update outgoing messages and write back to the var store.
|
/// Update outgoing messages and write back to the var store.
|
||||||
///
|
///
|
||||||
/// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this
|
/// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this
|
||||||
@@ -62,7 +64,6 @@ pub(crate) trait Factor {
|
|||||||
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64);
|
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64);
|
||||||
|
|
||||||
/// Optional log-evidence contribution. Default 0.0 (no contribution).
|
/// Optional log-evidence contribution. Default 0.0 (no contribution).
|
||||||
#[allow(dead_code)]
|
|
||||||
fn log_evidence(&self, _vars: &VarStore) -> f64 {
|
fn log_evidence(&self, _vars: &VarStore) -> f64 {
|
||||||
0.0
|
0.0
|
||||||
}
|
}
|
||||||
@@ -73,8 +74,7 @@ pub(crate) trait Factor {
|
|||||||
/// Using an enum instead of `Box<dyn Factor>` keeps factor data inline and
|
/// Using an enum instead of `Box<dyn Factor>` keeps factor data inline and
|
||||||
/// avoids virtual-call overhead in the hot inference loop.
|
/// avoids virtual-call overhead in the hot inference loop.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[allow(dead_code)]
|
pub enum BuiltinFactor {
|
||||||
pub(crate) enum BuiltinFactor {
|
|
||||||
TeamSum(team_sum::TeamSumFactor),
|
TeamSum(team_sum::TeamSumFactor),
|
||||||
RankDiff(rank_diff::RankDiffFactor),
|
RankDiff(rank_diff::RankDiffFactor),
|
||||||
Trunc(trunc::TruncFactor),
|
Trunc(trunc::TruncFactor),
|
||||||
@@ -97,9 +97,9 @@ impl Factor for BuiltinFactor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) mod rank_diff;
|
pub mod rank_diff;
|
||||||
pub(crate) mod team_sum;
|
pub mod team_sum;
|
||||||
pub(crate) mod trunc;
|
pub mod trunc;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|||||||
@@ -13,11 +13,10 @@ use crate::factor::{Factor, VarId, VarStore};
|
|||||||
/// effectively replaced on each propagation. The TruncFactor on the same diff
|
/// effectively replaced on each propagation. The TruncFactor on the same diff
|
||||||
/// var holds the EP-divide message that produces the cavity.
|
/// var holds the EP-divide message that produces the cavity.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[allow(dead_code)]
|
pub struct RankDiffFactor {
|
||||||
pub(crate) struct RankDiffFactor {
|
pub team_a: VarId,
|
||||||
pub(crate) team_a: VarId,
|
pub team_b: VarId,
|
||||||
pub(crate) team_b: VarId,
|
pub diff: VarId,
|
||||||
pub(crate) diff: VarId,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Factor for RankDiffFactor {
|
impl Factor for RankDiffFactor {
|
||||||
|
|||||||
@@ -10,10 +10,9 @@ use crate::{
|
|||||||
/// already with beta² noise added via `Rating::performance()`). The factor
|
/// already with beta² noise added via `Rating::performance()`). The factor
|
||||||
/// runs once per game and writes the weighted sum to the output var.
|
/// runs once per game and writes the weighted sum to the output var.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[allow(dead_code)]
|
pub struct TeamSumFactor {
|
||||||
pub(crate) struct TeamSumFactor {
|
pub inputs: Vec<(Gaussian, f64)>,
|
||||||
pub(crate) inputs: Vec<(Gaussian, f64)>,
|
pub out: VarId,
|
||||||
pub(crate) out: VarId,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Factor for TeamSumFactor {
|
impl Factor for TeamSumFactor {
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ use crate::{
|
|||||||
/// Stores its outgoing message to the diff variable so the cavity computation
|
/// Stores its outgoing message to the diff variable so the cavity computation
|
||||||
/// produces the correct EP message on each propagation.
|
/// produces the correct EP message on each propagation.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct TruncFactor {
|
pub struct TruncFactor {
|
||||||
pub(crate) diff: VarId,
|
pub diff: VarId,
|
||||||
pub(crate) margin: f64,
|
pub margin: f64,
|
||||||
pub(crate) tie: bool,
|
pub tie: bool,
|
||||||
/// Outgoing message to the diff variable (initial: N_INF, the EP identity).
|
/// Outgoing message to the diff variable (initial: N_INF, the EP identity).
|
||||||
pub(crate) msg: Gaussian,
|
pub(crate) msg: Gaussian,
|
||||||
/// Cached evidence (linear, not log) computed from the cavity on first propagation.
|
/// Cached evidence (linear, not log) computed from the cavity on first propagation.
|
||||||
@@ -22,7 +22,7 @@ pub(crate) struct TruncFactor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TruncFactor {
|
impl TruncFactor {
|
||||||
pub(crate) fn new(diff: VarId, margin: f64, tie: bool) -> Self {
|
pub fn new(diff: VarId, margin: f64, tie: bool) -> Self {
|
||||||
Self {
|
Self {
|
||||||
diff,
|
diff,
|
||||||
margin,
|
margin,
|
||||||
|
|||||||
13
src/factors.rs
Normal file
13
src/factors.rs
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//! Factor-graph public API.
|
||||||
|
//!
|
||||||
|
//! Power users can construct custom factor graphs via `Game::custom` (T2
|
||||||
|
//! minimal; full ergonomics in T4) and drive them with custom `Schedule`
|
||||||
|
//! implementations.
|
||||||
|
|
||||||
|
pub use crate::{
|
||||||
|
factor::{
|
||||||
|
BuiltinFactor, Factor, VarId, VarStore, rank_diff::RankDiffFactor, team_sum::TeamSumFactor,
|
||||||
|
trunc::TruncFactor,
|
||||||
|
},
|
||||||
|
schedule::{EpsilonOrMax, Schedule, ScheduleReport},
|
||||||
|
};
|
||||||
@@ -16,6 +16,7 @@ mod error;
|
|||||||
mod event;
|
mod event;
|
||||||
mod event_builder;
|
mod event_builder;
|
||||||
pub(crate) mod factor;
|
pub(crate) mod factor;
|
||||||
|
pub mod factors;
|
||||||
mod game;
|
mod game;
|
||||||
pub mod gaussian;
|
pub mod gaussian;
|
||||||
mod history;
|
mod history;
|
||||||
|
|||||||
@@ -16,8 +16,7 @@ pub struct ScheduleReport {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Drives factor propagation to convergence.
|
/// Drives factor propagation to convergence.
|
||||||
#[allow(dead_code)]
|
pub trait Schedule {
|
||||||
pub(crate) trait Schedule {
|
|
||||||
fn run(&self, factors: &mut [BuiltinFactor], vars: &mut VarStore) -> ScheduleReport;
|
fn run(&self, factors: &mut [BuiltinFactor], vars: &mut VarStore) -> ScheduleReport;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,8 +25,7 @@ pub(crate) trait Schedule {
|
|||||||
/// Matches the existing `Game::likelihoods` loop bit-for-bit when given the
|
/// Matches the existing `Game::likelihoods` loop bit-for-bit when given the
|
||||||
/// same factor layout (TeamSums first, then alternating RankDiff/Trunc pairs).
|
/// same factor layout (TeamSums first, then alternating RankDiff/Trunc pairs).
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
#[allow(dead_code)]
|
pub struct EpsilonOrMax {
|
||||||
pub(crate) struct EpsilonOrMax {
|
|
||||||
pub eps: f64,
|
pub eps: f64,
|
||||||
pub max: usize,
|
pub max: usize,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user