More tests, more code

This commit is contained in:
2022-06-19 23:52:52 +02:00
parent c9d9d59535
commit 0efb0312da
3 changed files with 472 additions and 121 deletions

View File

@@ -37,7 +37,7 @@ struct Team {
} }
#[derive(Debug)] #[derive(Debug)]
struct Event { pub(crate) struct Event {
teams: Vec<Team>, teams: Vec<Team>,
evidence: f64, evidence: f64,
weights: Vec<Vec<f64>>, weights: Vec<Vec<f64>>,
@@ -53,7 +53,7 @@ impl Event {
} }
pub(crate) struct Batch { pub(crate) struct Batch {
events: Vec<Event>, pub(crate) events: Vec<Event>,
pub(crate) skills: HashMap<String, Skill>, pub(crate) skills: HashMap<String, Skill>,
pub(crate) time: u64, pub(crate) time: u64,
p_draw: f64, p_draw: f64,
@@ -414,9 +414,40 @@ impl Batch {
} }
} }
} }
pub(crate) fn get_composition(&self) -> Vec<Vec<Vec<&str>>> {
self.events
.iter()
.map(|event| {
event
.teams
.iter()
.map(|team| {
team.items
.iter()
.map(|item| item.agent.as_str())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
}
pub(crate) fn get_results(&self) -> Vec<Vec<f64>> {
self.events
.iter()
.map(|event| {
event
.teams
.iter()
.map(|team| team.output)
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
}
} }
fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 { pub(crate) fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 {
if last_time == u64::MIN { if last_time == u64::MIN {
0 0
} else if last_time == u64::MAX { } else if last_time == u64::MAX {
@@ -425,6 +456,7 @@ fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 {
actual_time - last_time actual_time - last_time
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use approx::assert_ulps_eq; use approx::assert_ulps_eq;
@@ -566,14 +598,23 @@ mod tests {
let post = batch.posteriors(); let post = batch.posteriors();
assert_ulps_eq!(post["a"].mu, 25.000000, epsilon = 0.000001); assert_ulps_eq!(
assert_ulps_eq!(post["a"].sigma, 5.4192120, epsilon = 0.000001); post["a"],
Gaussian::new(25.000000, 5.4192120),
epsilon = 0.000001
);
assert_ulps_eq!(post["b"].mu, 25.000000, epsilon = 0.000001); assert_ulps_eq!(
assert_ulps_eq!(post["b"].sigma, 5.4192120, epsilon = 0.000001); post["b"],
Gaussian::new(25.000000, 5.4192120),
epsilon = 0.000001
);
assert_ulps_eq!(post["c"].mu, 25.000000, epsilon = 0.000001); assert_ulps_eq!(
assert_ulps_eq!(post["c"].sigma, 5.4192120, epsilon = 0.000001); post["c"],
Gaussian::new(25.000000, 5.4192120),
epsilon = 0.000001
);
batch.add_events( batch.add_events(
vec![ vec![
@@ -592,13 +633,20 @@ mod tests {
let post = batch.posteriors(); let post = batch.posteriors();
assert_ulps_eq!(post["a"].mu, 25.00000315330858, epsilon = 0.000001); assert_ulps_eq!(
assert_ulps_eq!(post["a"].sigma, 3.880150268080797, epsilon = 0.000001); post["a"],
Gaussian::new(25.00000315330858, 3.880150268080797),
assert_ulps_eq!(post["b"].mu, 25.00000315330858, epsilon = 0.000001); epsilon = 0.000001
assert_ulps_eq!(post["b"].sigma, 3.880150268080797, epsilon = 0.000001); );
assert_ulps_eq!(
assert_ulps_eq!(post["c"].mu, 25.00000315330858, epsilon = 0.000001); post["b"],
assert_ulps_eq!(post["c"].sigma, 3.880150268080797, epsilon = 0.000001); Gaussian::new(25.00000315330858, 3.880150268080797),
epsilon = 0.000001
);
assert_ulps_eq!(
post["c"],
Gaussian::new(25.00000315330858, 3.880150268080797),
epsilon = 0.000001
);
} }
} }

View File

@@ -24,12 +24,12 @@ impl Game {
) -> Self { ) -> Self {
assert!( assert!(
(result.is_empty() || result.len() == teams.len()), (result.is_empty() || result.len() == teams.len()),
"result.must be empty or the same length as teams" "result must be empty or the same length as teams"
); );
assert!( assert!(
(weights.is_empty() || weights.len() == teams.len()), (weights.is_empty() || weights.len() == teams.len()),
"weights.must be empty or the same length as teams" "weights must be empty or the same length as teams"
); );
assert!( assert!(
@@ -38,7 +38,7 @@ impl Game {
.iter() .iter()
.zip(teams.iter()) .zip(teams.iter())
.all(|(w, t)| w.len() == t.len()), .all(|(w, t)| w.len() == t.len()),
"weights.must be empty or has the same dimensions as teams" "weights must be empty or has the same dimensions as teams"
); );
assert!( assert!(
@@ -52,7 +52,7 @@ impl Game {
r.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); r.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
r.windows(2).all(|w| w[0] != w[1]) r.windows(2).all(|w| w[0] != w[1])
}, },
"draw.must be > 0.0 if there is teams with draw" "draw must be > 0.0 if there is teams with draw"
); );
if result.is_empty() { if result.is_empty() {

View File

@@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet};
use crate::{ use crate::{
agent::{self, Agent}, agent::{self, Agent},
batch::Batch, batch::{self, Batch},
gaussian::Gaussian, gaussian::Gaussian,
player::Player, player::Player,
sort_time, tuple_gt, tuple_max, sort_time, tuple_gt, tuple_max,
@@ -40,15 +40,15 @@ impl History {
) -> Self { ) -> Self {
assert!( assert!(
results.is_empty() || results.len() == composition.len(), results.is_empty() || results.len() == composition.len(),
"TODO: Add a comment here" "(length(results) > 0) & (length(composition) != length(results))"
); );
assert!( assert!(
times.is_empty() || times.len() == composition.len(), times.is_empty() || times.len() == composition.len(),
"TODO: Add a comment here" "length(times) > 0) & (length(composition) != length(times))"
); );
assert!( assert!(
weights.is_empty() || weights.len() == composition.len(), weights.is_empty() || weights.len() == composition.len(),
"TODO: Add a comment here" "(length(weights) > 0) & (length(composition) != length(weights))"
); );
let this_agent = composition let this_agent = composition
@@ -293,6 +293,183 @@ impl History {
.map(|batch| batch.log_evidence2(self.online, agents, forward, &mut self.agents)) .map(|batch| batch.log_evidence2(self.online, agents, forward, &mut self.agents))
.sum() .sum()
} }
pub fn add_events(
&mut self,
composition: Vec<Vec<Vec<&str>>>,
results: Vec<Vec<f64>>,
times: Vec<u64>,
weights: Vec<Vec<Vec<f64>>>,
priors: HashMap<String, Player>,
) {
assert!(times.is_empty() || self.time, "length(times)>0 but !h.time");
assert!(
!times.is_empty() || !self.time,
"length(times)==0 but h.time"
);
assert!(
results.is_empty() || results.len() == composition.len(),
"(length(results) > 0) & (length(composition) != length(results))"
);
assert!(
times.is_empty() || times.len() == composition.len(),
"length(times) > 0) & (length(composition) != length(times))"
);
assert!(
weights.is_empty() || weights.len() == composition.len(),
"(length(weights) > 0) & (length(composition) != length(weights))"
);
let this_agent = composition
.iter()
.flatten()
.flatten()
.cloned()
.collect::<HashSet<_>>();
for agent in &this_agent {
if !self.agents.contains_key(*agent) {
self.agents.insert(
agent.to_string(),
Agent {
player: priors.get(*agent).cloned().unwrap_or_else(|| {
Player::new(Gaussian::new(self.mu, self.sigma), self.beta, self.gamma)
}),
..Default::default()
},
);
}
}
agent::clean(self.agents.values_mut(), true);
let n = composition.len();
let o = if self.time {
sort_time(&times, false)
} else {
(0..composition.len()).collect::<Vec<_>>()
};
let mut i = 0;
let mut k = 0;
while i < n {
let mut j = i + 1;
let t = if self.time { times[o[i]] } else { i as u64 + 1 };
while self.time && j < self.size && times[o[j]] == t {
j += 1;
}
while (!self.time && (self.size > k))
|| (self.time && self.batches.len() > k && self.batches[k].time < t)
{
let b = &mut self.batches[k];
if k > 0 {
b.new_forward_info(&mut self.agents);
}
let intersect = this_agent
.iter()
.filter(|&&agent| b.skills.contains_key(agent))
.cloned()
.collect::<Vec<_>>();
for agent in intersect {
b.skills.get_mut(agent).unwrap().elapsed =
batch::compute_elapsed(self.agents[agent].last_time, b.time);
let a = self.agents.get_mut(agent).unwrap();
a.last_time = if self.time { b.time } else { u64::MAX };
a.message = b.forward_prior_out(agent);
}
k += 1;
}
let composition = (i..j)
.map(|e| composition[o[e]].clone())
.collect::<Vec<_>>();
let results = if results.is_empty() {
Vec::new()
} else {
(i..j).map(|e| results[o[e]].clone()).collect::<Vec<_>>()
};
let weights = if weights.is_empty() {
Vec::new()
} else {
(i..j).map(|e| weights[o[e]].clone()).collect::<Vec<_>>()
};
if self.time && self.batches.len() > k && self.batches[k].time == t {
let b = &mut self.batches[k];
b.add_events(composition, results, weights, &mut self.agents);
for a in b.skills.keys() {
let agent = self.agents.get_mut(a).unwrap();
agent.last_time = if self.time { t } else { u64::MAX };
agent.message = b.forward_prior_out(a);
}
} else {
let b = Batch::new(
composition,
results,
weights,
t,
self.p_draw,
&mut self.agents,
);
self.batches.insert(k, b);
let b = &mut self.batches[k];
for a in b.skills.keys() {
let agent = self.agents.get_mut(a).unwrap();
agent.last_time = if self.time { t } else { u64::MAX };
agent.message = b.forward_prior_out(a);
}
k += 1;
}
i = j;
}
while self.time && self.batches.len() > k {
let b = &mut self.batches[k];
b.new_backward_info(&mut self.agents);
let intersect = this_agent
.iter()
.filter(|&&agent| b.skills.contains_key(agent))
.cloned()
.collect::<Vec<_>>();
for agent in intersect {
b.skills.get_mut(agent).unwrap().elapsed =
batch::compute_elapsed(self.agents[agent].last_time, b.time);
let a = self.agents.get_mut(agent).unwrap();
a.last_time = if self.time { b.time } else { u64::MAX };
a.message = b.forward_prior_out(agent);
}
k += 1;
}
self.size += n;
self.iteration();
}
} }
#[cfg(test)] #[cfg(test)]
@@ -341,8 +518,11 @@ mod tests {
let p0 = h.batches[0].posteriors(); let p0 = h.batches[0].posteriors();
assert_ulps_eq!(p0["a"].mu, 29.205220743876975, epsilon = 0.000001); assert_ulps_eq!(
assert_ulps_eq!(p0["a"].sigma, 7.194481422570443, epsilon = 0.000001); p0["a"],
Gaussian::new(29.205220743876975, 7.194481422570443),
epsilon = 0.000001
);
let observed = h.batches[1].skills["a"].forward.sigma; let observed = h.batches[1].skills["a"].forward.sigma;
let gamma: f64 = 0.15 * 25.0 / 3.0; let gamma: f64 = 0.15 * 25.0 / 3.0;
@@ -351,6 +531,7 @@ mod tests {
assert_ulps_eq!(observed, expected, epsilon = 0.000001); assert_ulps_eq!(observed, expected, epsilon = 0.000001);
let observed = h.batches[1].posterior("a"); let observed = h.batches[1].posterior("a");
let p = Game::new( let p = Game::new(
h.batches[1].within_priors(0, false, false, &mut h.agents), h.batches[1].within_priors(0, false, false, &mut h.agents),
vec![0.0, 1.0], vec![0.0, 1.0],
@@ -360,8 +541,7 @@ mod tests {
.posteriors(); .posteriors();
let expected = p[0][0]; let expected = p[0][0];
assert_ulps_eq!(observed.mu, expected.mu, epsilon = 0.000001); assert_ulps_eq!(observed, expected, epsilon = 0.000001);
assert_ulps_eq!(observed.sigma, expected.sigma, epsilon = 0.000001);
} }
#[test] #[test]
@@ -401,46 +581,26 @@ mod tests {
); );
assert_ulps_eq!( assert_ulps_eq!(
h1.batches[0].posterior("a").mu, h1.batches[0].posterior("a"),
22.904409330892914, Gaussian::new(22.904409330892914, 6.0103304390431),
epsilon = 0.000001 epsilon = 0.000001
); );
assert_ulps_eq!( assert_ulps_eq!(
h1.batches[0].posterior("a").sigma, h1.batches[0].posterior("c"),
6.0103304390431, Gaussian::new(25.110318212568806, 5.866311348102563),
epsilon = 0.000001
);
assert_ulps_eq!(
h1.batches[0].posterior("c").mu,
25.110318212568806,
epsilon = 0.000001
);
assert_ulps_eq!(
h1.batches[0].posterior("c").sigma,
5.866311348102563,
epsilon = 0.000001 epsilon = 0.000001
); );
let (_step, _i) = h1.convergence(ITERATIONS, EPSILON, false); let (_step, _i) = h1.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!( assert_ulps_eq!(
h1.batches[0].posterior("a").mu, h1.batches[0].posterior("a"),
25.00000000, Gaussian::new(25.00000000, 5.41921200),
epsilon = 0.000001 epsilon = 0.000001
); );
assert_ulps_eq!( assert_ulps_eq!(
h1.batches[0].posterior("a").sigma, h1.batches[0].posterior("c"),
5.41921200, Gaussian::new(25.00000000, 5.41921200),
epsilon = 0.000001
);
assert_ulps_eq!(
h1.batches[0].posterior("c").mu,
25.00000000,
epsilon = 0.000001
);
assert_ulps_eq!(
h1.batches[0].posterior("c").sigma,
5.41921200,
epsilon = 0.000001 epsilon = 0.000001
); );
@@ -475,46 +635,26 @@ mod tests {
); );
assert_ulps_eq!( assert_ulps_eq!(
h2.batches[2].posterior("a").mu, h2.batches[2].posterior("a"),
22.90352227792141, Gaussian::new(22.90352227792141, 6.011017301320632),
epsilon = 0.000001 epsilon = 0.000001
); );
assert_ulps_eq!( assert_ulps_eq!(
h2.batches[2].posterior("a").sigma, h2.batches[2].posterior("c"),
6.011017301320632, Gaussian::new(25.110702468366718, 5.866811597660157),
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("c").mu,
25.110702468366718,
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("c").sigma,
5.866811597660157,
epsilon = 0.000001 epsilon = 0.000001
); );
let (_step, _i) = h2.convergence(ITERATIONS, EPSILON, false); h2.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!( assert_ulps_eq!(
h2.batches[2].posterior("a").mu, h2.batches[2].posterior("a"),
24.99866831022851, Gaussian::new(24.99866831022851, 5.420053708148435),
epsilon = 0.000001 epsilon = 0.000001
); );
assert_ulps_eq!( assert_ulps_eq!(
h2.batches[2].posterior("a").sigma, h2.batches[2].posterior("c"),
5.420053708148435, Gaussian::new(25.000532179593538, 5.419827012784138),
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("c").mu,
25.000532179593538,
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("c").sigma,
5.419827012784138,
epsilon = 0.000001 epsilon = 0.000001
); );
} }
@@ -562,23 +702,13 @@ mod tests {
assert_eq!(lc["a"][aj_e - 1].0, 7); assert_eq!(lc["a"][aj_e - 1].0, 7);
assert_ulps_eq!( assert_ulps_eq!(
lc["a"][aj_e - 1].1.mu, lc["a"][aj_e - 1].1,
24.99866831022851, Gaussian::new(24.99866831022851, 5.420053708148435),
epsilon = 0.000001 epsilon = 0.000001
); );
assert_ulps_eq!( assert_ulps_eq!(
lc["a"][aj_e - 1].1.sigma, lc["c"][cj_e - 1].1,
5.420053708148435, Gaussian::new(25.000532179593538, 5.419827012784138),
epsilon = 0.000001
);
assert_ulps_eq!(
lc["c"][cj_e - 1].1.mu,
25.000532179593538,
epsilon = 0.000001
);
assert_ulps_eq!(
lc["c"][cj_e - 1].1.sigma,
5.419827012784138,
epsilon = 0.000001 epsilon = 0.000001
); );
} }
@@ -612,33 +742,18 @@ mod tests {
assert_eq!(h.batches[2].skills["c"].elapsed, 1); assert_eq!(h.batches[2].skills["c"].elapsed, 1);
assert_ulps_eq!( assert_ulps_eq!(
h.batches[0].posterior("a").mu, h.batches[0].posterior("a"),
25.0002673, Gaussian::new(25.0002673, 5.41938162),
epsilon = 0.000001 epsilon = 0.000001
); );
assert_ulps_eq!( assert_ulps_eq!(
h.batches[0].posterior("a").sigma, h.batches[0].posterior("b"),
5.41938162, Gaussian::new(24.999465, 5.419425831),
epsilon = 0.000001 epsilon = 0.000001
); );
assert_ulps_eq!( assert_ulps_eq!(
h.batches[0].posterior("b").mu, h.batches[2].posterior("b"),
24.999465, Gaussian::new(25.00053219, 5.419696790),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("b").sigma,
5.419425831,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[2].posterior("b").mu,
25.00053219,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[2].posterior("b").sigma,
5.419696790,
epsilon = 0.000001 epsilon = 0.000001
); );
} }
@@ -727,4 +842,192 @@ mod tests {
epsilon = 0.000001 epsilon = 0.000001
); );
} }
#[test]
fn test_add_events() {
let composition = vec![
vec![vec!["a"], vec!["b"]],
vec![vec!["a"], vec!["c"]],
vec![vec!["b"], vec!["c"]],
];
let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
let mut h = History::new(
composition.clone(),
results.clone(),
vec![],
vec![],
HashMap::new(),
0.0,
2.0,
1.0,
0.0,
0.0,
false,
);
let (_step, _i) = h.convergence(ITERATIONS, EPSILON, false);
assert_eq!(h.batches[2].skills["b"].elapsed, 1);
assert_eq!(h.batches[2].skills["c"].elapsed, 1);
assert_ulps_eq!(
h.batches[0].posterior("a"),
Gaussian::new(0.0, 1.30061),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("b"),
Gaussian::new(0.0, 1.30061),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[2].posterior("b"),
Gaussian::new(0.0, 1.30061),
epsilon = 0.000001
);
h.add_events(composition, results, vec![], vec![], HashMap::new());
assert_eq!(h.batches.len(), 6);
assert_eq!(
h.batches
.iter()
.map(|b| b.get_composition())
.collect::<Vec<_>>(),
vec![
vec![vec![vec!["a"], vec!["b"]]],
vec![vec![vec!["a"], vec!["c"]]],
vec![vec![vec!["b"], vec!["c"]]],
vec![vec![vec!["a"], vec!["b"]]],
vec![vec![vec!["a"], vec!["c"]]],
vec![vec![vec!["b"], vec!["c"]]]
]
);
h.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h.batches[0].posterior("a"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[3].posterior("a"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[3].posterior("b"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[5].posterior("b"),
Gaussian::new(0.0, 0.9312360609998878),
epsilon = 0.000001
);
}
#[test]
fn test_add_events_with_time() {
let composition = vec![
vec![vec!["a"], vec!["b"]],
vec![vec!["a"], vec!["c"]],
vec![vec!["b"], vec!["c"]],
];
let results = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
let mut h = History::new(
composition.clone(),
results.clone(),
vec![0, 10, 20],
vec![],
HashMap::new(),
0.0,
2.0,
1.0,
0.0,
0.0,
false,
);
h.convergence(ITERATIONS, EPSILON, false);
h.add_events(
composition,
results,
vec![15, 10, 0],
vec![],
HashMap::new(),
);
assert_eq!(h.batches.len(), 4);
assert_eq!(
h.batches
.iter()
.map(|batch| batch.events.len())
.collect::<Vec<_>>(),
vec![2, 2, 1, 1]
);
assert_eq!(
h.batches
.iter()
.map(|b| b.get_composition())
.collect::<Vec<_>>(),
vec![
vec![vec![vec!["a"], vec!["b"]], vec![vec!["b"], vec!["c"]]],
vec![vec![vec!["a"], vec!["c"]], vec![vec!["a"], vec!["c"]]],
vec![vec![vec!["a"], vec!["b"]]],
vec![vec![vec!["b"], vec!["c"]]]
]
);
assert_eq!(
h.batches
.iter()
.map(|b| b.get_results())
.collect::<Vec<_>>(),
vec![
vec![vec![1.0, 0.0], vec![1.0, 0.0]],
vec![vec![0.0, 1.0], vec![0.0, 1.0]],
vec![vec![1.0, 0.0]],
vec![vec![1.0, 0.0]]
]
);
let end = h.batches.len() - 1;
assert_eq!(h.batches[0].skills["c"].elapsed, 0);
assert_eq!(h.batches[end].skills["c"].elapsed, 10);
assert_eq!(h.batches[0].skills["a"].elapsed, 0);
assert_eq!(h.batches[2].skills["a"].elapsed, 5);
assert_eq!(h.batches[0].skills["b"].elapsed, 0);
assert_eq!(h.batches[end].skills["b"].elapsed, 5);
h.convergence(ITERATIONS, EPSILON, false);
assert_ulps_eq!(
h.batches[0].posterior("b"),
h.batches[end].posterior("b"),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("c"),
h.batches[end].posterior("c"),
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("c"),
h.batches[0].posterior("b"),
epsilon = 0.000001
);
}
} }