Added two test cases from sublee/trueskill.
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::f64;
|
||||||
|
|
||||||
use gaussian::Gaussian;
|
use gaussian::Gaussian;
|
||||||
use math;
|
use math;
|
||||||
@@ -71,6 +72,8 @@ impl Variable {
|
|||||||
|
|
||||||
self.value = value;
|
self.value = value;
|
||||||
|
|
||||||
|
debug!("Variable::value old={:?}, new={:?}", old, value);
|
||||||
|
|
||||||
delta
|
delta
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,6 +86,8 @@ impl Variable {
|
|||||||
|
|
||||||
self.factors.insert(factor, message);
|
self.factors.insert(factor, message);
|
||||||
self.value = self.value / old * message;
|
self.value = self.value / old * message;
|
||||||
|
|
||||||
|
debug!("Variable::message old={:?}, new={:?}", old, message);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_message(&self, factor: usize) -> Gaussian {
|
pub fn get_message(&self, factor: usize) -> Gaussian {
|
||||||
@@ -229,13 +234,15 @@ impl SumFactor {
|
|||||||
fy: Vec<Gaussian>,
|
fy: Vec<Gaussian>,
|
||||||
a: &[f64],
|
a: &[f64],
|
||||||
) {
|
) {
|
||||||
let (sum_pi, sum_tau) =
|
let (sum_pi, sum_tau) = a.iter()
|
||||||
a.iter()
|
|
||||||
.zip(y.iter().zip(fy.iter()))
|
.zip(y.iter().zip(fy.iter()))
|
||||||
.fold((0.0, 0.0), |(pi, tau), (a, (y, fy))| {
|
.fold((0.0, 0.0f64), |(pi, tau), (a, (y, fy))| {
|
||||||
let x = *y / *fy;
|
let x = *y / *fy;
|
||||||
|
|
||||||
(pi + a.powi(2) / x.pi(), tau + a * x.tau() / x.pi())
|
let new_pi = a.powi(2) / x.pi();
|
||||||
|
let new_tau = a * x.mu();
|
||||||
|
|
||||||
|
(pi + new_pi, tau + new_tau)
|
||||||
});
|
});
|
||||||
|
|
||||||
let new_pi = 1.0 / sum_pi;
|
let new_pi = 1.0 / sum_pi;
|
||||||
@@ -278,12 +285,10 @@ impl SumFactor {
|
|||||||
.coeffs
|
.coeffs
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, coeff)| {
|
.map(|(i, coeff)| if i == index {
|
||||||
if i == index {
|
|
||||||
1.0 / idx_coeff
|
1.0 / idx_coeff
|
||||||
} else {
|
} else {
|
||||||
-coeff / idx_coeff
|
-coeff / idx_coeff
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ pub struct Gaussian {
|
|||||||
|
|
||||||
impl fmt::Debug for Gaussian {
|
impl fmt::Debug for Gaussian {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
write!(f, "N(mu={:.03}, sigma={:.03})", self.mu(), self.sigma())
|
write!(f, "N(mu={:.03}, sigma={:.03}, pi={:.03}, tau={:.03})", self.mu(), self.sigma(), self.pi(), self.tau())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
210
src/lib.rs
210
src/lib.rs
@@ -346,6 +346,18 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn generate_free_for_all(players: u16) -> (Vec<(Rating, u16)>, Vec<u16>) {
|
||||||
|
(0..players)
|
||||||
|
.map(|i| ((Rating::new(MU, SIGMA), i), i))
|
||||||
|
.unzip()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn generate_free_for_all_draw(players: u16) -> (Vec<(Rating, u16)>, Vec<u16>) {
|
||||||
|
(0..players)
|
||||||
|
.map(|i| ((Rating::new(MU, SIGMA), i), 0))
|
||||||
|
.unzip()
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_quality_1vs1() {
|
fn test_quality_1vs1() {
|
||||||
let alice = Rating::new(MU, SIGMA);
|
let alice = Rating::new(MU, SIGMA);
|
||||||
@@ -367,9 +379,9 @@ mod tests {
|
|||||||
|
|
||||||
let expected_ratings = vec![
|
let expected_ratings = vec![
|
||||||
Rating::new(33.20668089876779, 6.34810941351329),
|
Rating::new(33.20668089876779, 6.34810941351329),
|
||||||
Rating::new(27.401455165087352, 5.7871628131345645),
|
Rating::new(27.40145516508735, 5.78716281313456),
|
||||||
Rating::new(22.598544839299667, 5.787162810091708),
|
Rating::new(22.59854483929966, 5.78716281009170),
|
||||||
Rating::new(16.793319100187123, 6.348109386031168),
|
Rating::new(16.79331910018712, 6.34810938603116),
|
||||||
];
|
];
|
||||||
|
|
||||||
let ratings = rate(
|
let ratings = rate(
|
||||||
@@ -385,18 +397,17 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rate_8_free_for_all_draw() {
|
fn test_rate_8_free_for_all_draw() {
|
||||||
let (ratings, ranks): (Vec<_>, Vec<_>) =
|
let (ratings, ranks) = generate_free_for_all_draw(8);
|
||||||
(0..8u16).map(|i| ((Rating::new(MU, SIGMA), i), 0)).unzip();
|
|
||||||
|
|
||||||
let expected_ratings = vec![
|
let expected_ratings = vec![
|
||||||
Rating::new(25.000000000000014, 4.592173723582464),
|
Rating::new(25.00000000000001, 4.59217372358246),
|
||||||
Rating::new(25.000000000000004, 4.582694291508923),
|
Rating::new(25.00000000000000, 4.58269429150892),
|
||||||
Rating::new(25.000000000000004, 4.576403088132029),
|
Rating::new(25.00000000000000, 4.57640308813202),
|
||||||
Rating::new(25.000000000000004, 4.5732660302525785),
|
Rating::new(25.00000000000000, 4.57326603025257),
|
||||||
Rating::new(25.000000000000004, 4.573266030252499),
|
Rating::new(25.00000000000000, 4.57326603025249),
|
||||||
Rating::new(25.000000000000004, 4.5764030881313005),
|
Rating::new(25.00000000000000, 4.57640308813130),
|
||||||
Rating::new(25.000000000000004, 4.582694291504568),
|
Rating::new(25.00000000000000, 4.58269429150456),
|
||||||
Rating::new(25.000000000000004, 4.5921737235617845),
|
Rating::new(25.00000000000000, 4.59217372356178),
|
||||||
];
|
];
|
||||||
|
|
||||||
let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA);
|
let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA);
|
||||||
@@ -408,26 +419,25 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rate_16_free_for_all() {
|
fn test_rate_16_free_for_all() {
|
||||||
let (ratings, ranks): (Vec<_>, Vec<_>) =
|
let (ratings, ranks) = generate_free_for_all(16);
|
||||||
(0..16u16).map(|i| ((Rating::new(MU, SIGMA), i), i)).unzip();
|
|
||||||
|
|
||||||
let expected_ratings = vec![
|
let expected_ratings = vec![
|
||||||
Rating::new(38.331788391720295, 5.541377093821206),
|
Rating::new(38.33178839172029, 5.54137709382120),
|
||||||
Rating::new(34.746418422647835, 4.891059608895904),
|
Rating::new(34.74641842264783, 4.89105960889590),
|
||||||
Rating::new(32.61843846390381, 4.65295506775174),
|
Rating::new(32.61843846390381, 4.65295506775174),
|
||||||
Rating::new(31.046552382343123, 4.522499560708352),
|
Rating::new(31.04655238234312, 4.52249956070835),
|
||||||
Rating::new(29.776879742967576, 4.4373517127954045),
|
Rating::new(29.77687974296757, 4.43735171279540),
|
||||||
Rating::new(28.700098150398343, 4.376108908633618),
|
Rating::new(28.70009815039834, 4.37610890863361),
|
||||||
Rating::new(27.27926253322585, 4.345628543772217),
|
Rating::new(27.27926253322585, 4.34562854377221),
|
||||||
Rating::new(25.770921292726943, 4.331516572196449),
|
Rating::new(25.77092129272694, 4.33151657219644),
|
||||||
Rating::new(24.24208410982518, 4.328837113282726),
|
Rating::new(24.24208410982518, 4.32883711328272),
|
||||||
Rating::new(22.69645835682008, 4.3358108926893175),
|
Rating::new(22.69645835682008, 4.33581089268931),
|
||||||
Rating::new(21.111803677204012, 4.353110849422467),
|
Rating::new(21.11180367720401, 4.35311084942246),
|
||||||
Rating::new(19.453144013532874, 4.383947550430257),
|
Rating::new(19.45314401353287, 4.38394755043025),
|
||||||
Rating::new(17.666795353614532, 4.43588072704469),
|
Rating::new(17.66679535361453, 4.43588072704469),
|
||||||
Rating::new(15.65513147838557, 4.527230274864565),
|
Rating::new(15.65513147838557, 4.52723027486456),
|
||||||
Rating::new(13.192401824247439, 4.713609609918767),
|
Rating::new(13.19240182424743, 4.71360960991876),
|
||||||
Rating::new(9.461957485668828, 5.277768698162301),
|
Rating::new(9.461957485668828, 5.27776869816230),
|
||||||
];
|
];
|
||||||
|
|
||||||
let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA);
|
let ratings = rate(ratings.as_ref(), ranks.as_ref(), DELTA);
|
||||||
@@ -443,8 +453,8 @@ mod tests {
|
|||||||
let bob = Rating::new(MU, SIGMA);
|
let bob = Rating::new(MU, SIGMA);
|
||||||
|
|
||||||
let expected_ratings = vec![
|
let expected_ratings = vec![
|
||||||
Rating::new(25.0, 6.457515683245051),
|
Rating::new(25.00000000000000, 6.45751568324505),
|
||||||
Rating::new(25.0, 6.457515683245051),
|
Rating::new(25.00000000000000, 6.45751568324505),
|
||||||
];
|
];
|
||||||
|
|
||||||
let ratings = rate(&[(alice, 0), (bob, 1)], &[0, 0], DELTA);
|
let ratings = rate(&[(alice, 0), (bob, 1)], &[0, 0], DELTA);
|
||||||
@@ -456,18 +466,16 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rate_2vs2() {
|
fn test_rate_2vs2() {
|
||||||
let _ = env_logger::try_init();
|
|
||||||
|
|
||||||
let alice = Rating::new(MU, SIGMA);
|
let alice = Rating::new(MU, SIGMA);
|
||||||
let bob = Rating::new(MU, SIGMA);
|
let bob = Rating::new(MU, SIGMA);
|
||||||
let chris = Rating::new(MU, SIGMA);
|
let chris = Rating::new(MU, SIGMA);
|
||||||
let darren = Rating::new(MU, SIGMA);
|
let darren = Rating::new(MU, SIGMA);
|
||||||
|
|
||||||
let expected_ratings = vec![
|
let expected_ratings = vec![
|
||||||
Rating::new(28.108322399069035, 7.77436345109384),
|
Rating::new(28.10832239906903, 7.77436345109384),
|
||||||
Rating::new(28.108322399069035, 7.77436345109384),
|
Rating::new(28.10832239906903, 7.77436345109384),
|
||||||
Rating::new(21.891677600930958, 7.77436345109384),
|
Rating::new(21.89167760093095, 7.77436345109384),
|
||||||
Rating::new(21.891677600930958, 7.77436345109384),
|
Rating::new(21.89167760093095, 7.77436345109384),
|
||||||
];
|
];
|
||||||
|
|
||||||
let ratings = rate(
|
let ratings = rate(
|
||||||
@@ -483,8 +491,6 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_rate_4vs4() {
|
fn test_rate_4vs4() {
|
||||||
let _ = env_logger::try_init();
|
|
||||||
|
|
||||||
let alice = Rating::new(MU, SIGMA);
|
let alice = Rating::new(MU, SIGMA);
|
||||||
let bob = Rating::new(MU, SIGMA);
|
let bob = Rating::new(MU, SIGMA);
|
||||||
let chris = Rating::new(MU, SIGMA);
|
let chris = Rating::new(MU, SIGMA);
|
||||||
@@ -495,14 +501,14 @@ mod tests {
|
|||||||
let laura = Rating::new(MU, SIGMA);
|
let laura = Rating::new(MU, SIGMA);
|
||||||
|
|
||||||
let expected_ratings = vec![
|
let expected_ratings = vec![
|
||||||
Rating::new(27.19791584649575, 8.058911711843994),
|
Rating::new(27.19791584649575, 8.05891171184399),
|
||||||
Rating::new(27.19791584649575, 8.058911711843994),
|
Rating::new(27.19791584649575, 8.05891171184399),
|
||||||
Rating::new(27.19791584649575, 8.058911711843994),
|
Rating::new(27.19791584649575, 8.05891171184399),
|
||||||
Rating::new(27.19791584649575, 8.058911711843994),
|
Rating::new(27.19791584649575, 8.05891171184399),
|
||||||
Rating::new(22.802084153504236, 8.058911711843994),
|
Rating::new(22.80208415350423, 8.05891171184399),
|
||||||
Rating::new(22.802084153504236, 8.058911711843994),
|
Rating::new(22.80208415350423, 8.05891171184399),
|
||||||
Rating::new(22.802084153504236, 8.058911711843994),
|
Rating::new(22.80208415350423, 8.05891171184399),
|
||||||
Rating::new(22.802084153504236, 8.058911711843994),
|
Rating::new(22.80208415350423, 8.05891171184399),
|
||||||
];
|
];
|
||||||
|
|
||||||
let ratings = rate(
|
let ratings = rate(
|
||||||
@@ -524,4 +530,112 @@ mod tests {
|
|||||||
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rate_sublee_trueskill_issue_3_case_1() {
|
||||||
|
let ratings = vec![
|
||||||
|
(Rating::new(42.234, 3.728), 0),
|
||||||
|
(Rating::new(43.290, 3.842), 0),
|
||||||
|
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
(Rating::new(16.667, 0.500), 1),
|
||||||
|
];
|
||||||
|
|
||||||
|
let expected_ratings = vec![
|
||||||
|
Rating::new(49.04740037944730, 3.64544755829110),
|
||||||
|
Rating::new(50.52625967591528, 3.75146367648882),
|
||||||
|
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
Rating::new(16.54109749124066, 0.50668947816812),
|
||||||
|
];
|
||||||
|
|
||||||
|
let ratings = rate(
|
||||||
|
ratings.as_ref(),
|
||||||
|
&[0, 1],
|
||||||
|
DELTA,
|
||||||
|
);
|
||||||
|
|
||||||
|
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
||||||
|
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rate_sublee_trueskill_issue_3_case_2() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
|
|
||||||
|
let ratings = vec![
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
(Rating::new(25.000, 0.500), 0),
|
||||||
|
|
||||||
|
(Rating::new(42.234, 3.728), 1),
|
||||||
|
(Rating::new(43.290, 3.842), 1),
|
||||||
|
];
|
||||||
|
|
||||||
|
let expected_ratings = vec![
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
Rating::new(25.00000000000000, 0.50689687752485),
|
||||||
|
|
||||||
|
Rating::new(42.23400000000000, 3.72893127376255),
|
||||||
|
Rating::new(43.29000000000000, 3.84290364756188),
|
||||||
|
];
|
||||||
|
|
||||||
|
let ratings = rate(
|
||||||
|
ratings.as_ref(),
|
||||||
|
&[0, 1],
|
||||||
|
DELTA,
|
||||||
|
);
|
||||||
|
|
||||||
|
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
|
||||||
|
assert_relative_eq!(rating, expected, epsilon = EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user