More tests, more code

This commit is contained in:
2022-06-19 23:52:52 +02:00
parent c9d9d59535
commit 0efb0312da
3 changed files with 472 additions and 121 deletions

View File

@@ -37,7 +37,7 @@ struct Team {
}
#[derive(Debug)]
struct Event {
pub(crate) struct Event {
teams: Vec<Team>,
evidence: f64,
weights: Vec<Vec<f64>>,
@@ -53,7 +53,7 @@ impl Event {
}
pub(crate) struct Batch {
events: Vec<Event>,
pub(crate) events: Vec<Event>,
pub(crate) skills: HashMap<String, Skill>,
pub(crate) time: u64,
p_draw: f64,
@@ -414,9 +414,40 @@ impl Batch {
}
}
}
pub(crate) fn get_composition(&self) -> Vec<Vec<Vec<&str>>> {
self.events
.iter()
.map(|event| {
event
.teams
.iter()
.map(|team| {
team.items
.iter()
.map(|item| item.agent.as_str())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
}
pub(crate) fn get_results(&self) -> Vec<Vec<f64>> {
self.events
.iter()
.map(|event| {
event
.teams
.iter()
.map(|team| team.output)
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
}
}
fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 {
pub(crate) fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 {
if last_time == u64::MIN {
0
} else if last_time == u64::MAX {
@@ -425,6 +456,7 @@ fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 {
actual_time - last_time
}
}
#[cfg(test)]
mod tests {
use approx::assert_ulps_eq;
@@ -566,14 +598,23 @@ mod tests {
let post = batch.posteriors();
assert_ulps_eq!(post["a"].mu, 25.000000, epsilon = 0.000001);
assert_ulps_eq!(post["a"].sigma, 5.4192120, epsilon = 0.000001);
assert_ulps_eq!(
post["a"],
Gaussian::new(25.000000, 5.4192120),
epsilon = 0.000001
);
assert_ulps_eq!(post["b"].mu, 25.000000, epsilon = 0.000001);
assert_ulps_eq!(post["b"].sigma, 5.4192120, epsilon = 0.000001);
assert_ulps_eq!(
post["b"],
Gaussian::new(25.000000, 5.4192120),
epsilon = 0.000001
);
assert_ulps_eq!(post["c"].mu, 25.000000, epsilon = 0.000001);
assert_ulps_eq!(post["c"].sigma, 5.4192120, epsilon = 0.000001);
assert_ulps_eq!(
post["c"],
Gaussian::new(25.000000, 5.4192120),
epsilon = 0.000001
);
batch.add_events(
vec![
@@ -592,13 +633,20 @@ mod tests {
let post = batch.posteriors();
assert_ulps_eq!(post["a"].mu, 25.00000315330858, epsilon = 0.000001);
assert_ulps_eq!(post["a"].sigma, 3.880150268080797, epsilon = 0.000001);
assert_ulps_eq!(post["b"].mu, 25.00000315330858, epsilon = 0.000001);
assert_ulps_eq!(post["b"].sigma, 3.880150268080797, epsilon = 0.000001);
assert_ulps_eq!(post["c"].mu, 25.00000315330858, epsilon = 0.000001);
assert_ulps_eq!(post["c"].sigma, 3.880150268080797, epsilon = 0.000001);
assert_ulps_eq!(
post["a"],
Gaussian::new(25.00000315330858, 3.880150268080797),
epsilon = 0.000001
);
assert_ulps_eq!(
post["b"],
Gaussian::new(25.00000315330858, 3.880150268080797),
epsilon = 0.000001
);
assert_ulps_eq!(
post["c"],
Gaussian::new(25.00000315330858, 3.880150268080797),
epsilon = 0.000001
);
}
}

View File

@@ -24,12 +24,12 @@ impl Game {
) -> Self {
assert!(
(result.is_empty() || result.len() == teams.len()),
"result.must be empty or the same length as teams"
"result must be empty or the same length as teams"
);
assert!(
(weights.is_empty() || weights.len() == teams.len()),
"weights.must be empty or the same length as teams"
"weights must be empty or the same length as teams"
);
assert!(
@@ -38,7 +38,7 @@ impl Game {
.iter()
.zip(teams.iter())
.all(|(w, t)| w.len() == t.len()),
"weights.must be empty or has the same dimensions as teams"
"weights must be empty or has the same dimensions as teams"
);
assert!(
@@ -52,7 +52,7 @@ impl Game {
r.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
r.windows(2).all(|w| w[0] != w[1])
},
"draw.must be > 0.0 if there is teams with draw"
"draw must be > 0.0 if there is teams with draw"
);
if result.is_empty() {

View File

@@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};
use crate::{
agent::{self, Agent},
batch::Batch,
batch::{self, Batch},
gaussian::Gaussian,
player::Player,
sort_time, tuple_gt, tuple_max,
@@ -40,15 +40,15 @@ impl History {
) -> Self {
assert!(
results.is_empty() || results.len() == composition.len(),
"TODO: Add a comment here"
"(length(results) > 0) & (length(composition) != length(results))"
);
assert!(
times.is_empty() || times.len() == composition.len(),
"TODO: Add a comment here"
"length(times) > 0) & (length(composition) != length(times))"
);
assert!(
weights.is_empty() || weights.len() == composition.len(),
"TODO: Add a comment here"
"(length(weights) > 0) & (length(composition) != length(weights))"
);
let this_agent = composition
@@ -293,6 +293,183 @@ impl History {
.map(|batch| batch.log_evidence2(self.online, agents, forward, &mut self.agents))
.sum()
}
pub fn add_events(
&mut self,
composition: Vec<Vec<Vec<&str>>>,
results: Vec<Vec<f64>>,
times: Vec<u64>,
weights: Vec<Vec<Vec<f64>>>,
priors: HashMap<String, Player>,
) {
assert!(times.is_empty() || self.time, "length(times)>0 but !h.time");
assert!(
!times.is_empty() || !self.time,
"length(times)==0 but h.time"
);
assert!(
results.is_empty() || results.len() == composition.len(),
"(length(results) > 0) & (length(composition) != length(results))"
);
assert!(
times.is_empty() || times.len() == composition.len(),
"length(times) > 0) & (length(composition) != length(times))"
);
assert!(
weights.is_empty() || weights.len() == composition.len(),
"(length(weights) > 0) & (length(composition) != length(weights))"
);
let this_agent = composition
.iter()
.flatten()
.flatten()
.cloned()
.collect::<HashSet<_>>();
for agent in &this_agent {
if !self.agents.contains_key(*agent) {
self.agents.insert(
agent.to_string(),
Agent {
player: priors.get(*agent).cloned().unwrap_or_else(|| {
Player::new(Gaussian::new(self.mu, self.sigma), self.beta, self.gamma)
}),
..Default::default()
},
);
}
}
agent::clean(self.agents.values_mut(), true);
let n = composition.len();
let o = if self.time {
sort_time(&times, false)
} else {
(0..composition.len()).collect::<Vec<_>>()
};
let mut i = 0;
let mut k = 0;
while i < n {
let mut j = i + 1;
let t = if self.time { times[o[i]] } else { i as u64 + 1 };
while self.time && j < self.size && times[o[j]] == t {
j += 1;
}
while (!self.time && (self.size > k))
|| (self.time && self.batches.len() > k && self.batches[k].time < t)
{
let b = &mut self.batches[k];
if k > 0 {
b.new_forward_info(&mut self.agents);
}
let intersect = this_agent
.iter()
.filter(|&&agent| b.skills.contains_key(agent))
.cloned()
.collect::<Vec<_>>();
for agent in intersect {
b.skills.get_mut(agent).unwrap().elapsed =
batch::compute_elapsed(self.agents[agent].last_time, b.time);
let a = self.agents.get_mut(agent).unwrap();
a.last_time = if self.time { b.time } else { u64::MAX };
a.message = b.forward_prior_out(agent);
}
k += 1;
}
let composition = (i..j)
.map(|e| composition[o[e]].clone())
.collect::<Vec<_>>();
let results = if results.is_empty() {
Vec::new()
} else {
(i..j).map(|e| results[o[e]].clone()).collect::<Vec<_>>()
};
let weights = if weights.is_empty() {
Vec::new()
} else {
(i..j).map(|e| weights[o[e]].clone()).collect::<Vec<_>>()
};
if self.time && self.batches.len() > k && self.batches[k].time == t {
let b = &mut self.batches[k];
b.add_events(composition, results, weights, &mut self.agents);
for a in b.skills.keys() {
let agent = self.agents.get_mut(a).unwrap();
agent.last_time = if self.time { t } else { u64::MAX };
agent.message = b.forward_prior_out(a);
}
} else {
let b = Batch::new(
composition,
results,
weights,
t,
self.p_draw,
&mut self.agents,
);
self.batches.insert(k, b);
let b = &mut self.batches[k];
for a in b.skills.keys() {
let agent = self.agents.get_mut(a).unwrap();
agent.last_time = if self.time { t } else { u64::MAX };
agent.message = b.forward_prior_out(a);
}
k += 1;
}
i = j;
}
while self.time && self.batches.len() > k {
let b = &mut self.batches[k];
b.new_backward_info(&mut self.agents);
let intersect = this_agent
.iter()
.filter(|&&agent| b.skills.contains_key(agent))
.cloned()
.collect::<Vec<_>>();
for agent in intersect {
b.skills.get_mut(agent).unwrap().elapsed =
batch::compute_elapsed(self.agents[agent].last_time, b.time);
let a = self.agents.get_mut(agent).unwrap();
a.last_time = if self.time { b.time } else { u64::MAX };
a.message = b.forward_prior_out(agent);
}
k += 1;
}
self.size += n;
self.iteration();
}
}
#[cfg(test)]
@@ -341,8 +518,11 @@ mod tests {
let p0 = h.batches[0].posteriors();
assert_ulps_eq!(p0["a"].mu, 29.205220743876975, epsilon = 0.000001);
assert_ulps_eq!(p0["a"].sigma, 7.194481422570443, epsilon = 0.000001);
assert_ulps_eq!(
p0["a"],
Gaussian::new(29.205220743876975, 7.194481422570443),
epsilon = 0.000001
);
let observed = h.batches[1].skills["a"].forward.sigma;
let gamma: f64 = 0.15 * 25.0 / 3.0;
@@ -351,6 +531,7 @@ mod tests {
assert_ulps_eq!(observed, expected, epsilon = 0.000001);
let observed = h.batches[1].posterior("a");
let p = Game::new(
h.batches[1].within_priors(0, false, false, &mut h.agents),
vec![0.0, 1.0],
@@ -360,8 +541,7 @@ mod tests {
.posteriors();
let expected = p[0][0];
assert_ulps_eq!(observed.mu, expected.mu, epsilon = 0.000001);
assert_ulps_eq!(observed.sigma, expected.sigma, epsilon = 0.000001);
assert_ulps_eq!(observed, expected, epsilon = 0.000001);
}
#[test]
@@ -401,46 +581,26 @@ mod tests {
);
assert_ulps_eq!(
h1.batches[0].posterior("a").mu,
22.904409330892914,
h1.batches[0].posterior("a"),
Gaussian::new(22.904409330892914, 6.0103304390431),
epsilon = 0.000001
);
assert_ulps_eq!(
h1.batches[0].posterior("a").sigma,
6.0103304390431,
epsilon = 0.000001
);
assert_ulps_eq!(
h1.batches[0].posterior("c").mu,
25.110318212568806,
epsilon = 0.000001
);
assert_ulps_eq!(
h1.batches[0].posterior("c").sigma,
5.866311348102563,
h1.batches[0].posterior("c"),
Gaussian::new(25.110318212568806, 5.866311348102563),
epsilon = 0.000001
);
let (_step, _i) = h1.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h1.batches[0].posterior("a").mu,
25.00000000,
h1.batches[0].posterior("a"),
Gaussian::new(25.00000000, 5.41921200),
epsilon = 0.000001
);
assert_ulps_eq!(
h1.batches[0].posterior("a").sigma,
5.41921200,
epsilon = 0.000001
);
assert_ulps_eq!(
h1.batches[0].posterior("c").mu,
25.00000000,
epsilon = 0.000001
);
assert_ulps_eq!(
h1.batches[0].posterior("c").sigma,
5.41921200,
h1.batches[0].posterior("c"),
Gaussian::new(25.00000000, 5.41921200),
epsilon = 0.000001
);
@@ -475,46 +635,26 @@ mod tests {
);
assert_ulps_eq!(
h2.batches[2].posterior("a").mu,
22.90352227792141,
h2.batches[2].posterior("a"),
Gaussian::new(22.90352227792141, 6.011017301320632),
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("a").sigma,
6.011017301320632,
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("c").mu,
25.110702468366718,
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("c").sigma,
5.866811597660157,
h2.batches[2].posterior("c"),
Gaussian::new(25.110702468366718, 5.866811597660157),
epsilon = 0.000001
);
let (_step, _i) = h2.convergence(ITERATIONS, EPSILON, false);
h2.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h2.batches[2].posterior("a").mu,
24.99866831022851,
h2.batches[2].posterior("a"),
Gaussian::new(24.99866831022851, 5.420053708148435),
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("a").sigma,
5.420053708148435,
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("c").mu,
25.000532179593538,
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("c").sigma,
5.419827012784138,
h2.batches[2].posterior("c"),
Gaussian::new(25.000532179593538, 5.419827012784138),
epsilon = 0.000001
);
}
@@ -562,23 +702,13 @@ mod tests {
assert_eq!(lc["a"][aj_e - 1].0, 7);
assert_ulps_eq!(
lc["a"][aj_e - 1].1.mu,
24.99866831022851,
lc["a"][aj_e - 1].1,
Gaussian::new(24.99866831022851, 5.420053708148435),
epsilon = 0.000001
);
assert_ulps_eq!(
lc["a"][aj_e - 1].1.sigma,
5.420053708148435,
epsilon = 0.000001
);
assert_ulps_eq!(
lc["c"][cj_e - 1].1.mu,
25.000532179593538,
epsilon = 0.000001
);
assert_ulps_eq!(
lc["c"][cj_e - 1].1.sigma,
5.419827012784138,
lc["c"][cj_e - 1].1,
Gaussian::new(25.000532179593538, 5.419827012784138),
epsilon = 0.000001
);
}
@@ -612,33 +742,18 @@ mod tests {
assert_eq!(h.batches[2].skills["c"].elapsed, 1);
assert_ulps_eq!(
h.batches[0].posterior("a").mu,
25.0002673,
h.batches[0].posterior("a"),
Gaussian::new(25.0002673, 5.41938162),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("a").sigma,
5.41938162,
h.batches[0].posterior("b"),
Gaussian::new(24.999465, 5.419425831),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("b").mu,
24.999465,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("b").sigma,
5.419425831,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[2].posterior("b").mu,
25.00053219,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[2].posterior("b").sigma,
5.419696790,
h.batches[2].posterior("b"),
Gaussian::new(25.00053219, 5.419696790),
epsilon = 0.000001
);
}
@@ -727,4 +842,192 @@ mod tests {
epsilon = 0.000001
);
}
#[test]
fn test_add_events() {
let composition = vec![
vec![vec!["a"], vec!["b"]],
vec![vec!["a"], vec!["c"]],
vec![vec!["b"], vec!["c"]],
];
let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
let mut h = History::new(
composition.clone(),
results.clone(),
vec![],
vec![],
HashMap::new(),
0.0,
2.0,
1.0,
0.0,
0.0,
false,
);
let (_step, _i) = h.convergence(ITERATIONS, EPSILON, false);
assert_eq!(h.batches[2].skills["b"].elapsed, 1);
assert_eq!(h.batches[2].skills["c"].elapsed, 1);
assert_ulps_eq!(
h.batches[0].posterior("a"),
Gaussian::new(0.0, 1.30061),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("b"),
Gaussian::new(0.0, 1.30061),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[2].posterior("b"),
Gaussian::new(0.0, 1.30061),
epsilon = 0.000001
);
h.add_events(composition, results, vec![], vec![], HashMap::new());
assert_eq!(h.batches.len(), 6);
assert_eq!(
h.batches
.iter()
.map(|b| b.get_composition())
.collect::<Vec<_>>(),
vec![
vec![vec![vec!["a"], vec!["b"]]],
vec![vec![vec!["a"], vec!["c"]]],
vec![vec![vec!["b"], vec!["c"]]],
vec![vec![vec!["a"], vec!["b"]]],
vec![vec![vec!["a"], vec!["c"]]],
vec![vec![vec!["b"], vec!["c"]]]
]
);
h.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h.batches[0].posterior("a"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[3].posterior("a"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[3].posterior("b"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[5].posterior("b"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
}
#[test]
fn test_add_events_with_time() {
let composition = vec![
vec![vec!["a"], vec!["b"]],
vec![vec!["a"], vec!["c"]],
vec![vec!["b"], vec!["c"]],
];
let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
let mut h = History::new(
composition.clone(),
results.clone(),
vec![0, 10, 20],
vec![],
HashMap::new(),
0.0,
2.0,
1.0,
0.0,
0.0,
false,
);
h.convergence(ITERATIONS, EPSILON, false);
h.add_events(
composition,
results,
vec![15, 10, 0],
vec![],
HashMap::new(),
);
assert_eq!(h.batches.len(), 4);
assert_eq!(
h.batches
.iter()
.map(|batch| batch.events.len())
.collect::<Vec<_>>(),
vec![2, 2, 1, 1]
);
assert_eq!(
h.batches
.iter()
.map(|b| b.get_composition())
.collect::<Vec<_>>(),
vec![
vec![vec![vec!["a"], vec!["b"]], vec![vec!["b"], vec!["c"]]],
vec![vec![vec!["a"], vec!["c"]], vec![vec!["a"], vec!["c"]]],
vec![vec![vec!["a"], vec!["b"]]],
vec![vec![vec!["b"], vec!["c"]]]
]
);
assert_eq!(
h.batches
.iter()
.map(|b| b.get_results())
.collect::<Vec<_>>(),
vec![
vec![vec![1.0, 0.0], vec![1.0, 0.0]],
vec![vec![0.0, 1.0], vec![0.0, 1.0]],
vec![vec![1.0, 0.0]],
vec![vec![1.0, 0.0]]
]
);
let end = h.batches.len() - 1;
assert_eq!(h.batches[0].skills["c"].elapsed, 0);
assert_eq!(h.batches[end].skills["c"].elapsed, 10);
assert_eq!(h.batches[0].skills["a"].elapsed, 0);
assert_eq!(h.batches[2].skills["a"].elapsed, 5);
assert_eq!(h.batches[0].skills["b"].elapsed, 0);
assert_eq!(h.batches[end].skills["b"].elapsed, 5);
h.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h.batches[0].posterior("b"),
h.batches[end].posterior("b"),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("c"),
h.batches[end].posterior("c"),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("c"),
h.batches[0].posterior("b"),
epsilon = 0.000001
);
}
}