feat(factor): implement RankDiffFactor
Maintains diff = team_a - team_b across three variables. On each propagation, reads the team-perf marginals (which may have been updated by neighboring factors) and computes the new diff via Gaussian Sub (variance addition).
This commit is contained in:
@@ -1,5 +1,17 @@
|
|||||||
use crate::factor::{Factor, VarId, VarStore};
|
use crate::factor::{Factor, VarId, VarStore};
|
||||||
|
|
||||||
|
/// Maintains the constraint `diff = team_a - team_b` between three vars.
|
||||||
|
///
|
||||||
|
/// On each propagation:
|
||||||
|
/// - Reads marginals at `team_a` and `team_b` (which already incorporate any
|
||||||
|
/// incoming messages from neighboring factors).
|
||||||
|
/// - Computes `new_diff = team_a - team_b` (variance addition; see Gaussian::Sub).
|
||||||
|
/// - Writes the new marginal to `diff`.
|
||||||
|
/// - Returns the delta against the previous diff value.
|
||||||
|
///
|
||||||
|
/// This factor does NOT store an outgoing message; the diff variable is
|
||||||
|
/// effectively replaced on each propagation. The TruncFactor on the same diff
|
||||||
|
/// var holds the EP-divide message that produces the cavity.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct RankDiffFactor {
|
pub(crate) struct RankDiffFactor {
|
||||||
pub(crate) team_a: VarId,
|
pub(crate) team_a: VarId,
|
||||||
@@ -8,7 +20,76 @@ pub(crate) struct RankDiffFactor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Factor for RankDiffFactor {
|
impl Factor for RankDiffFactor {
|
||||||
fn propagate(&mut self, _vars: &mut VarStore) -> (f64, f64) {
|
fn propagate(&mut self, vars: &mut VarStore) -> (f64, f64) {
|
||||||
unimplemented!("RankDiffFactor stub — implemented in Task 5")
|
let a = vars.get(self.team_a);
|
||||||
|
let b = vars.get(self.team_b);
|
||||||
|
let new_diff = a - b;
|
||||||
|
let old = vars.get(self.diff);
|
||||||
|
vars.set(self.diff, new_diff);
|
||||||
|
old.delta(new_diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::{N_INF, gaussian::Gaussian};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn diff_of_two_known_gaussians() {
|
||||||
|
let mut vars = VarStore::new();
|
||||||
|
let team_a = vars.alloc(Gaussian::from_ms(25.0, 3.0));
|
||||||
|
let team_b = vars.alloc(Gaussian::from_ms(20.0, 4.0));
|
||||||
|
let diff = vars.alloc(N_INF);
|
||||||
|
|
||||||
|
let mut f = RankDiffFactor {
|
||||||
|
team_a,
|
||||||
|
team_b,
|
||||||
|
diff,
|
||||||
|
};
|
||||||
|
f.propagate(&mut vars);
|
||||||
|
|
||||||
|
let result = vars.get(diff);
|
||||||
|
// mu = 25 - 20 = 5; var = 9 + 16 = 25; sigma = 5
|
||||||
|
assert!((result.mu() - 5.0).abs() < 1e-12);
|
||||||
|
assert!((result.sigma() - 5.0).abs() < 1e-12);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn delta_zero_on_repeat() {
|
||||||
|
let mut vars = VarStore::new();
|
||||||
|
let team_a = vars.alloc(Gaussian::from_ms(10.0, 2.0));
|
||||||
|
let team_b = vars.alloc(Gaussian::from_ms(8.0, 1.0));
|
||||||
|
let diff = vars.alloc(N_INF);
|
||||||
|
|
||||||
|
let mut f = RankDiffFactor {
|
||||||
|
team_a,
|
||||||
|
team_b,
|
||||||
|
diff,
|
||||||
|
};
|
||||||
|
f.propagate(&mut vars);
|
||||||
|
let (dmu, dsig) = f.propagate(&mut vars);
|
||||||
|
assert!(dmu < 1e-12);
|
||||||
|
assert!(dsig < 1e-12);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn delta_reflects_team_change() {
|
||||||
|
let mut vars = VarStore::new();
|
||||||
|
let team_a = vars.alloc(Gaussian::from_ms(10.0, 1.0));
|
||||||
|
let team_b = vars.alloc(Gaussian::from_ms(0.0, 1.0));
|
||||||
|
let diff = vars.alloc(N_INF);
|
||||||
|
|
||||||
|
let mut f = RankDiffFactor {
|
||||||
|
team_a,
|
||||||
|
team_b,
|
||||||
|
diff,
|
||||||
|
};
|
||||||
|
f.propagate(&mut vars);
|
||||||
|
|
||||||
|
// change team_a, repropagate; delta should be positive
|
||||||
|
vars.set(team_a, Gaussian::from_ms(15.0, 1.0));
|
||||||
|
let (dmu, _dsig) = f.propagate(&mut vars);
|
||||||
|
assert!(dmu > 4.0, "expected ~5 delta, got {}", dmu);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user