Passing tests for Batch

This commit is contained in:
2022-06-12 21:43:35 +02:00
parent ae1c765dbb
commit 5a7053fb5d
3 changed files with 159 additions and 27 deletions

View File

@@ -3,6 +3,5 @@ name = "trueskill-tt"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dev-dependencies]
approx = "0.5.1"
[dependencies]

View File

@@ -140,7 +140,7 @@ impl Batch {
} }
} }
let from = self.events.len() + 1; let from = self.events.len();
for e in 0..composition.len() { for e in 0..composition.len() {
let teams = (0..composition[e].len()) let teams = (0..composition[e].len())
@@ -176,10 +176,10 @@ impl Batch {
skill.likelihood * skill.backward * skill.forward skill.likelihood * skill.backward * skill.forward
} }
fn posteriors(&self) -> HashMap<&str, Gaussian> { pub fn posteriors(&self) -> HashMap<String, Gaussian> {
self.skills self.skills
.keys() .keys()
.map(|a| (a.as_str(), self.posterior(a))) .map(|a| (a.clone(), self.posterior(a)))
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>()
} }
@@ -223,10 +223,65 @@ impl Batch {
self.events[e].evidence = g.evidence; self.events[e].evidence = g.evidence;
} }
} }
pub fn convergence(&mut self) -> usize {
let epsilon = 1e-6;
let iterations = 20;
let mut step = (f64::INFINITY, f64::INFINITY);
let mut i = 0;
while (step.0 > epsilon || step.1 > epsilon) && i < iterations {
let old = self.posteriors();
self.iteration(0);
let new = self.posteriors();
step = old.iter().fold((0.0, 0.0), |(o_l, o_r), (a, old)| {
let (n_l, n_r) = old.delta(new[a]);
(
if n_l > o_l { n_l } else { o_l },
if n_r > o_r { n_r } else { o_r },
)
});
i += 1;
}
i
}
/*
def convergence(self, epsilon=1e-6, iterations = 20):
step, i = (inf, inf), 0
while gr_tuple(step, epsilon) and (i < iterations):
old = self.posteriors().copy()
self.iteration()
step = dict_diff(old, self.posteriors())
i += 1
return i
def forward_prior_out(self, agent):
return self.skills[agent].forward * self.skills[agent].likelihood
def backward_prior_out(self, agent):
N = self.skills[agent].likelihood*self.skills[agent].backward
return N.forget(self.agents[agent].player.gamma, self.skills[agent].elapsed)
def new_backward_info(self):
for a in self.skills:
self.skills[a].backward = self.agents[a].message
return self.iteration()
def new_forward_info(self):
for a in self.skills:
self.skills[a].forward = self.agents[a].receive(self.skills[a].elapsed)
return self.iteration()
*/
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use approx::assert_ulps_eq;
use super::*; use super::*;
#[test] #[test]
@@ -248,7 +303,7 @@ mod tests {
agents.insert(k.to_string(), agent); agents.insert(k.to_string(), agent);
} }
let b = Batch::new( let mut b = Batch::new(
vec![ vec![
vec![vec!["a"], vec!["b"]], vec![vec!["a"], vec!["b"]],
vec![vec!["c"], vec!["d"]], vec![vec!["c"], vec!["d"]],
@@ -262,23 +317,71 @@ mod tests {
let post = b.posteriors(); let post = b.posteriors();
assert_eq!(post["a"].mu(), 29.205); assert_eq!(post["a"].mu(), 29.205220743876975);
assert_eq!(post["a"].sigma(), 7.194) assert_eq!(post["a"].sigma(), 7.194481422570443);
/* assert_eq!(post["b"].mu(), 20.79477925612302);
agents = dict() assert_eq!(post["b"].sigma(), 7.194481422570443);
for k in ["a", "b", "c", "d", "e", "f"]:
agents[k] = ttt.Agent(ttt.Player(ttt.Gaussian(25., 25.0/3), 25.0/6, 25.0/300 ) , ttt.Ninf, -ttt.inf)
b = ttt.Batch(composition=[ [["a"],["b"]], [["c"],["d"]] , [["e"],["f"]] ], results= [[1,0],[0,1],[1,0]], time = 0, agents=agents)
post = b.posteriors()
self.assertAlmostEqual(post["a"].mu,29.205,3)
self.assertAlmostEqual(post["a"].sigma,7.194,3)
self.assertAlmostEqual(post["b"].mu,20.795,3) assert_eq!(post["c"].mu(), 20.79477925612302);
self.assertAlmostEqual(post["b"].sigma,7.194,3) assert_eq!(post["c"].sigma(), 7.194481422570443);
self.assertAlmostEqual(post["c"].mu,20.795,3)
self.assertAlmostEqual(post["c"].sigma,7.194,3) assert_eq!(b.convergence(), 1);
self.assertEqual(b.convergence(),1) }
*/
#[test]
fn test_same_strength() {
let mut agents = HashMap::new();
for k in ["a", "b", "c", "d", "e", "f"] {
let agent = Agent::new(
Player::new(
Gaussian::new(25.0, 25.0 / 3.0),
25.0 / 6.0,
25.0 / 300.0,
N_INF,
),
N_INF,
f64::NEG_INFINITY,
);
agents.insert(k.to_string(), agent);
}
let mut b = Batch::new(
vec![
vec![vec!["a"], vec!["b"]],
vec![vec!["a"], vec!["c"]],
vec![vec!["b"], vec!["c"]],
],
vec![vec![1, 0], vec![0, 1], vec![1, 0]],
2.0,
agents,
0.0,
);
let post = b.posteriors();
assert_eq!(post["a"].mu(), 24.96097857478182);
assert_eq!(post["a"].sigma(), 6.298544763358269);
assert_eq!(post["b"].mu(), 27.095590669107086);
assert_eq!(post["b"].sigma(), 6.010330439043099);
assert_eq!(post["c"].mu(), 24.88968178743119);
assert_eq!(post["c"].sigma(), 5.866311348102562);
assert!(b.convergence() > 1);
let post = b.posteriors();
assert_ulps_eq!(post["a"].mu(), 25.000000, epsilon = 0.000001);
assert_ulps_eq!(post["a"].sigma(), 5.4192120, epsilon = 0.000001);
assert_ulps_eq!(post["b"].mu(), 25.000000, epsilon = 0.000001);
assert_ulps_eq!(post["b"].sigma(), 5.4192120, epsilon = 0.000001);
assert_ulps_eq!(post["c"].mu(), 25.000000, epsilon = 0.000001);
assert_ulps_eq!(post["c"].sigma(), 5.4192120, epsilon = 0.000001);
} }
} }

View File

@@ -1,9 +1,39 @@
use std::collections::HashMap;
use trueskill_tt::*; use trueskill_tt::*;
fn main() { fn main() {
let t_a = Player::new(Gaussian::new(29.0, 1.0), 25.0 / 6.0, GAMMA, N_INF); let mut agents = HashMap::new();
let t_b = Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, GAMMA, N_INF);
let g = Game::new(vec![vec![t_a], vec![t_b]], vec![0, 1], 0.0); for k in ["a", "b", "c", "d", "e", "f"] {
let p = g.posteriors(); let agent = Agent::new(
Player::new(
Gaussian::new(25.0, 25.0 / 3.0),
25.0 / 6.0,
25.0 / 300.0,
N_INF,
),
N_INF,
f64::NEG_INFINITY,
);
agents.insert(k.to_string(), agent);
}
let b = Batch::new(
vec![
vec![vec!["a"], vec!["b"]],
vec![vec!["c"], vec!["d"]],
vec![vec!["e"], vec!["f"]],
],
vec![vec![1, 0], vec![0, 1], vec![1, 0]],
0.0,
agents,
0.0,
);
let post = b.posteriors();
println!("{} {}", post["a"].mu(), 29.205);
println!("{} {}", post["a"].sigma(), 7.194)
} }