T0 + T1 + T2: engine redesign through new API surface #1
@@ -47,6 +47,56 @@ impl VarStore {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A factor in the EP graph.
|
||||||
|
///
|
||||||
|
/// Factors hold their own outgoing messages and propagate them by reading
|
||||||
|
/// connected variable marginals from a `VarStore` and writing back updated
|
||||||
|
/// marginals.
|
||||||
|
pub(crate) trait Factor {
|
||||||
|
/// Update outgoing messages and write back to the var store.
|
||||||
|
///
|
||||||
|
/// Returns the max delta `(|Δmu|, |Δsigma|)` across writes this
|
||||||
|
/// propagation. Used by the `Schedule` to detect convergence.
|
||||||
|
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64);
|
||||||
|
|
||||||
|
/// Optional log-evidence contribution. Default 0.0 (no contribution).
|
||||||
|
fn log_evidence(&self, _vars: &VarStore) -> f64 {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Enum dispatcher for the built-in factor types.
|
||||||
|
///
|
||||||
|
/// Using an enum instead of `Box<dyn Factor>` keeps factor data inline and
|
||||||
|
/// avoids virtual-call overhead in the hot inference loop.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) enum BuiltinFactor {
|
||||||
|
TeamSum(team_sum::TeamSumFactor),
|
||||||
|
RankDiff(rank_diff::RankDiffFactor),
|
||||||
|
Trunc(trunc::TruncFactor),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Factor for BuiltinFactor {
|
||||||
|
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
|
||||||
|
match self {
|
||||||
|
Self::TeamSum(f) => f.propagate(vars),
|
||||||
|
Self::RankDiff(f) => f.propagate(vars),
|
||||||
|
Self::Trunc(f) => f.propagate(vars),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn log_evidence(&self, vars: &VarStore) -> f64 {
|
||||||
|
match self {
|
||||||
|
Self::Trunc(f) => f.log_evidence(vars),
|
||||||
|
_ => 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) mod rank_diff;
|
||||||
|
pub(crate) mod team_sum;
|
||||||
|
pub(crate) mod trunc;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
14
src/factor/rank_diff.rs
Normal file
14
src/factor/rank_diff.rs
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
use crate::factor::{Factor, VarId, VarStore};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct RankDiffFactor {
|
||||||
|
pub(crate) team_a: VarId,
|
||||||
|
pub(crate) team_b: VarId,
|
||||||
|
pub(crate) diff: VarId,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Factor for RankDiffFactor {
|
||||||
|
fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) {
|
||||||
|
unimplemented!("RankDiffFactor stub — implemented in Task 5")
|
||||||
|
}
|
||||||
|
}
|
||||||
16
src/factor/team_sum.rs
Normal file
16
src/factor/team_sum.rs
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
use crate::{
|
||||||
|
factor::{Factor, VarId, VarStore},
|
||||||
|
gaussian::Gaussian,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct TeamSumFactor {
|
||||||
|
pub(crate) inputs: Vec<(Gaussian, f64)>,
|
||||||
|
pub(crate) out: VarId,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Factor for TeamSumFactor {
|
||||||
|
fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) {
|
||||||
|
unimplemented!("TeamSumFactor stub — implemented in Task 4")
|
||||||
|
}
|
||||||
|
}
|
||||||
32
src/factor/trunc.rs
Normal file
32
src/factor/trunc.rs
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
use crate::{
|
||||||
|
N_INF,
|
||||||
|
factor::{Factor, VarId, VarStore},
|
||||||
|
gaussian::Gaussian,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct TruncFactor {
|
||||||
|
pub(crate) diff: VarId,
|
||||||
|
pub(crate) margin: f64,
|
||||||
|
pub(crate) tie: bool,
|
||||||
|
pub(crate) msg: Gaussian,
|
||||||
|
pub(crate) evidence_cached: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TruncFactor {
|
||||||
|
pub(crate) fn new(diff: VarId, margin: f64, tie: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
diff,
|
||||||
|
margin,
|
||||||
|
tie,
|
||||||
|
msg: N_INF,
|
||||||
|
evidence_cached: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Factor for TruncFactor {
|
||||||
|
fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) {
|
||||||
|
unimplemented!("TruncFactor stub — implemented in Task 6")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user