Added delta and more tests.
This commit is contained in:
@@ -49,13 +49,29 @@ impl Variable {
|
|||||||
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
|
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_value(&mut self, factor: usize, value: Gaussian) {
|
pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 {
|
||||||
let old = self.factors[&factor];
|
let old = self.factors[&factor];
|
||||||
|
|
||||||
self.factors.insert(factor, value * old / self.value);
|
self.factors.insert(factor, value * old / self.value);
|
||||||
|
|
||||||
|
let pi_delta = (self.value.pi() - value.pi()).abs();
|
||||||
|
|
||||||
|
let delta = if !pi_delta.is_finite() {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
let pi_delta = pi_delta.sqrt();
|
||||||
|
let tau_delta = (self.value.tau() - value.tau()).abs();
|
||||||
|
|
||||||
|
if pi_delta > tau_delta {
|
||||||
|
pi_delta
|
||||||
|
} else {
|
||||||
|
tau_delta
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
self.value = value;
|
self.value = value;
|
||||||
|
|
||||||
// debug!("Variable::value value={:?}", self.value);
|
delta
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_value(&self) -> Gaussian {
|
pub fn get_value(&self) -> Gaussian {
|
||||||
@@ -67,8 +83,6 @@ impl Variable {
|
|||||||
|
|
||||||
self.factors.insert(factor, message);
|
self.factors.insert(factor, message);
|
||||||
self.value = self.value / old * message;
|
self.value = self.value / old * message;
|
||||||
|
|
||||||
// debug!("Variable::message value={:?}", self.value);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_message(&self, factor: usize) -> Gaussian {
|
pub fn get_message(&self, factor: usize) -> Gaussian {
|
||||||
@@ -344,7 +358,7 @@ impl TruncateFactor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update(&self, variable_arena: &mut VariableArena) {
|
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
|
||||||
let (x, fx) = variable_arena
|
let (x, fx) = variable_arena
|
||||||
.get(self.variable)
|
.get(self.variable)
|
||||||
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
.map(|variable| (variable.get_value(), variable.get_message(self.id)))
|
||||||
@@ -376,6 +390,6 @@ impl TruncateFactor {
|
|||||||
variable_arena
|
variable_arena
|
||||||
.get_mut(self.variable)
|
.get_mut(self.variable)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.update_value(self.id, gaussian);
|
.update_value(self.id, gaussian)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
88
src/lib.rs
88
src/lib.rs
@@ -77,7 +77,7 @@ fn draw_margin(p: f64, beta: f64, total_players: f64) -> f64 {
|
|||||||
math::icdf((p + 1.0) / 2.0) * total_players.sqrt() * beta
|
math::icdf((p + 1.0) / 2.0) * total_players.sqrt() * beta
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rate<R>(ratings: &[(R, u16)], ranks: &[u16]) -> Vec<Rating>
|
pub fn rate<R>(ratings: &[(R, u16)], ranks: &[u16], min_delta: f64) -> Vec<Rating>
|
||||||
where
|
where
|
||||||
R: Rateable,
|
R: Rateable,
|
||||||
{
|
{
|
||||||
@@ -205,19 +205,29 @@ where
|
|||||||
factor.update_sum(&mut variable_arena);
|
factor.update_sum(&mut variable_arena);
|
||||||
}
|
}
|
||||||
|
|
||||||
for _ in 0..5 {
|
for _ in 0..10 {
|
||||||
|
let mut delta = 0.0;
|
||||||
|
|
||||||
for factor in &team_diff_layer {
|
for factor in &team_diff_layer {
|
||||||
factor.update_sum(&mut variable_arena);
|
factor.update_sum(&mut variable_arena);
|
||||||
}
|
}
|
||||||
|
|
||||||
for factor in &trunc_layer {
|
for factor in &trunc_layer {
|
||||||
factor.update(&mut variable_arena);
|
let d = factor.update(&mut variable_arena);
|
||||||
|
|
||||||
|
if d > delta {
|
||||||
|
delta = d;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for factor in &team_diff_layer {
|
for factor in &team_diff_layer {
|
||||||
factor.update_term(&mut variable_arena, 0);
|
factor.update_term(&mut variable_arena, 0);
|
||||||
factor.update_term(&mut variable_arena, 1);
|
factor.update_term(&mut variable_arena, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if delta < min_delta {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for factor in &team_perf_layer {
|
for factor in &team_perf_layer {
|
||||||
@@ -356,15 +366,16 @@ mod tests {
|
|||||||
let darren = Rating::new(MU, SIGMA);
|
let darren = Rating::new(MU, SIGMA);
|
||||||
|
|
||||||
let expected_ratings = vec![
|
let expected_ratings = vec![
|
||||||
Rating::new(33.20778932559525, 6.347937214998893),
|
Rating::new(33.20668089876779, 6.34810941351329),
|
||||||
Rating::new(27.401497882797486, 5.787057812482782),
|
Rating::new(27.401455165087352, 5.7871628131345645),
|
||||||
Rating::new(22.598576351652632, 5.7871159419307645),
|
Rating::new(22.598544839299667, 5.787162810091708),
|
||||||
Rating::new(16.79337409436942, 6.348053083319977),
|
Rating::new(16.793319100187123, 6.348109386031168),
|
||||||
];
|
];
|
||||||
|
|
||||||
let ratings = rate(
|
let ratings = rate(
|
||||||
&[(alice, 0), (bob, 1), (chris, 2), (darren, 3)],
|
&[(alice, 0), (bob, 1), (chris, 2), (darren, 3)],
|
||||||
&[0, 1, 2, 3],
|
&[0, 1, 2, 3],
|
||||||
|
DELTA,
|
||||||
);
|
);
|
||||||
|
|
||||||
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
||||||
@@ -372,6 +383,60 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rate_8_free_for_all_draw() {
|
||||||
|
let (ratings, ranks): (Vec<_>, Vec<_>) =
|
||||||
|
(0..8u16).map(|i| ((Rating::new(MU, SIGMA), i), 0)).unzip();
|
||||||
|
|
||||||
|
let expected_ratings = vec![
|
||||||
|
Rating::new(25.000000000000014, 4.592173723582464),
|
||||||
|
Rating::new(25.000000000000004, 4.582694291508923),
|
||||||
|
Rating::new(25.000000000000004, 4.576403088132029),
|
||||||
|
Rating::new(25.000000000000004, 4.5732660302525785),
|
||||||
|
Rating::new(25.000000000000004, 4.573266030252499),
|
||||||
|
Rating::new(25.000000000000004, 4.5764030881313005),
|
||||||
|
Rating::new(25.000000000000004, 4.582694291504568),
|
||||||
|
Rating::new(25.000000000000004, 4.5921737235617845),
|
||||||
|
];
|
||||||
|
|
||||||
|
let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA);
|
||||||
|
|
||||||
|
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
||||||
|
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rate_16_free_for_all() {
|
||||||
|
let (ratings, ranks): (Vec<_>, Vec<_>) =
|
||||||
|
(0..16u16).map(|i| ((Rating::new(MU, SIGMA), i), i)).unzip();
|
||||||
|
|
||||||
|
let expected_ratings = vec![
|
||||||
|
Rating::new(38.331788391720295, 5.541377093821206),
|
||||||
|
Rating::new(34.746418422647835, 4.891059608895904),
|
||||||
|
Rating::new(32.61843846390381, 4.65295506775174),
|
||||||
|
Rating::new(31.046552382343123, 4.522499560708352),
|
||||||
|
Rating::new(29.776879742967576, 4.4373517127954045),
|
||||||
|
Rating::new(28.700098150398343, 4.376108908633618),
|
||||||
|
Rating::new(27.27926253322585, 4.345628543772217),
|
||||||
|
Rating::new(25.770921292726943, 4.331516572196449),
|
||||||
|
Rating::new(24.24208410982518, 4.328837113282726),
|
||||||
|
Rating::new(22.69645835682008, 4.3358108926893175),
|
||||||
|
Rating::new(21.111803677204012, 4.353110849422467),
|
||||||
|
Rating::new(19.453144013532874, 4.383947550430257),
|
||||||
|
Rating::new(17.666795353614532, 4.43588072704469),
|
||||||
|
Rating::new(15.65513147838557, 4.527230274864565),
|
||||||
|
Rating::new(13.192401824247439, 4.713609609918767),
|
||||||
|
Rating::new(9.461957485668828, 5.277768698162301),
|
||||||
|
];
|
||||||
|
|
||||||
|
let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA);
|
||||||
|
|
||||||
|
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
||||||
|
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rate_1vs1_draw() {
|
fn test_rate_1vs1_draw() {
|
||||||
let alice = Rating::new(MU, SIGMA);
|
let alice = Rating::new(MU, SIGMA);
|
||||||
@@ -382,7 +447,7 @@ mod tests {
|
|||||||
Rating::new(25.0, 6.457515683245051),
|
Rating::new(25.0, 6.457515683245051),
|
||||||
];
|
];
|
||||||
|
|
||||||
let ratings = rate(&[(alice, 0), (bob, 1)], &[0, 0]);
|
let ratings = rate(&[(alice, 0), (bob, 1)], &[0, 0], DELTA);
|
||||||
|
|
||||||
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
||||||
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
||||||
@@ -405,7 +470,11 @@ mod tests {
|
|||||||
Rating::new(21.891677600930958, 7.77436345109384),
|
Rating::new(21.891677600930958, 7.77436345109384),
|
||||||
];
|
];
|
||||||
|
|
||||||
let ratings = rate(&[(alice, 0), (bob, 0), (chris, 1), (darren, 1)], &[0, 1]);
|
let ratings = rate(
|
||||||
|
&[(alice, 0), (bob, 0), (chris, 1), (darren, 1)],
|
||||||
|
&[0, 1],
|
||||||
|
DELTA,
|
||||||
|
);
|
||||||
|
|
||||||
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
||||||
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
||||||
@@ -448,6 +517,7 @@ mod tests {
|
|||||||
(laura, 1),
|
(laura, 1),
|
||||||
],
|
],
|
||||||
&[0, 1],
|
&[0, 1],
|
||||||
|
DELTA,
|
||||||
);
|
);
|
||||||
|
|
||||||
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
||||||
|
|||||||
Reference in New Issue
Block a user