T0 + T1 + T2: engine redesign through new API surface #1
@@ -47,6 +47,56 @@ impl VarStore {
|
||||
}
|
||||
}
|
||||
|
||||
/// A factor in the EP graph.
|
||||
///
|
||||
/// Factors hold their own outgoing messages and propagate them by reading
|
||||
/// connected variable marginals from a `VarStore` and writing back updated
|
||||
/// marginals.
|
||||
pub(crate) trait Factor {
|
||||
/// Update outgoing messages and write back to the var store.
|
||||
///
|
||||
/// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this
|
||||
/// propagation. Used by the `Schedule` to detect convergence.
|
||||
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64);
|
||||
|
||||
/// Optional log-evidence contribution. Default 0.0 (no contribution).
|
||||
fn log_evidence(&self, _vars: &VarStore) -> f64 {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Enum dispatcher for the built-in factor types.
|
||||
///
|
||||
/// Using an enum instead of `Box<dyn Factor>` keeps factor data inline and
|
||||
/// avoids virtual-call overhead in the hot inference loop.
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum BuiltinFactor {
|
||||
TeamSum(team_sum::TeamSumFactor),
|
||||
RankDiff(rank_diff::RankDiffFactor),
|
||||
Trunc(trunc::TruncFactor),
|
||||
}
|
||||
|
||||
impl Factor for BuiltinFactor {
|
||||
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
|
||||
match self {
|
||||
Self::TeamSum(f) => f.propagate(vars),
|
||||
Self::RankDiff(f) => f.propagate(vars),
|
||||
Self::Trunc(f) => f.propagate(vars),
|
||||
}
|
||||
}
|
||||
|
||||
fn log_evidence(&self, vars: &VarStore) -> f64 {
|
||||
match self {
|
||||
Self::Trunc(f) => f.log_evidence(vars),
|
||||
_ => 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) mod rank_diff;
|
||||
pub(crate) mod team_sum;
|
||||
pub(crate) mod trunc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
14
src/factor/rank_diff.rs
Normal file
14
src/factor/rank_diff.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use crate::factor::{Factor, VarId, VarStore};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct RankDiffFactor {
|
||||
pub(crate) team_a: VarId,
|
||||
pub(crate) team_b: VarId,
|
||||
pub(crate) diff: VarId,
|
||||
}
|
||||
|
||||
impl Factor for RankDiffFactor {
|
||||
fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) {
|
||||
unimplemented!("RankDiffFactor stub — implemented in Task 5")
|
||||
}
|
||||
}
|
||||
16
src/factor/team_sum.rs
Normal file
16
src/factor/team_sum.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use crate::{
|
||||
factor::{Factor, VarId, VarStore},
|
||||
gaussian::Gaussian,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct TeamSumFactor {
|
||||
pub(crate) inputs: Vec<(Gaussian, f64)>,
|
||||
pub(crate) out: VarId,
|
||||
}
|
||||
|
||||
impl Factor for TeamSumFactor {
|
||||
fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) {
|
||||
unimplemented!("TeamSumFactor stub — implemented in Task 4")
|
||||
}
|
||||
}
|
||||
32
src/factor/trunc.rs
Normal file
32
src/factor/trunc.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use crate::{
|
||||
N_INF,
|
||||
factor::{Factor, VarId, VarStore},
|
||||
gaussian::Gaussian,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct TruncFactor {
|
||||
pub(crate) diff: VarId,
|
||||
pub(crate) margin: f64,
|
||||
pub(crate) tie: bool,
|
||||
pub(crate) msg: Gaussian,
|
||||
pub(crate) evidence_cached: Option<f64>,
|
||||
}
|
||||
|
||||
impl TruncFactor {
|
||||
pub(crate) fn new(diff: VarId, margin: f64, tie: bool) -> Self {
|
||||
Self {
|
||||
diff,
|
||||
margin,
|
||||
tie,
|
||||
msg: N_INF,
|
||||
evidence_cached: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Factor for TruncFactor {
|
||||
fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) {
|
||||
unimplemented!("TruncFactor stub — implemented in Task 6")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user