Use and Index struct instead of str and String for player id
This commit is contained in:
@@ -12,7 +12,7 @@ Rust port of [TrueSkillThroughTime.py](https://github.com/glandfried/TrueSkillTh
|
|||||||
## Todo
|
## Todo
|
||||||
|
|
||||||
- [x] Implement approx for Gaussian
|
- [x] Implement approx for Gaussian
|
||||||
- [ ] Add more tests from `TrueSkillThroughTime.jl`
|
- [x] Add more tests from `TrueSkillThroughTime.jl`
|
||||||
- [ ] Time needs to be an enum so we can have multiple states (see `batch::compute_elapsed()`)
|
- [ ] Time needs to be an enum so we can have multiple states (see `batch::compute_elapsed()`)
|
||||||
- [ ] Add examples (use same TrueSkillThroughTime.(py|jl))
|
- [ ] Add examples (use same TrueSkillThroughTime.(py|jl))
|
||||||
- [ ] Add Observer (see [argmin](https://docs.rs/argmin/latest/argmin/core/trait.Observe.html) for inspiration)
|
- [ ] Add Observer (see [argmin](https://docs.rs/argmin/latest/argmin/core/trait.Observe.html) for inspiration)
|
||||||
|
|||||||
237
src/batch.rs
237
src/batch.rs
@@ -1,7 +1,7 @@
|
|||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
agent::Agent, game::Game, gaussian::Gaussian, player::Player, tuple_gt, tuple_max, N_INF,
|
agent::Agent, game::Game, gaussian::Gaussian, player::Player, tuple_gt, tuple_max, Index, N_INF,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -27,7 +27,7 @@ impl Default for Skill {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Item {
|
struct Item {
|
||||||
agent: String,
|
agent: Index,
|
||||||
likelihood: Gaussian,
|
likelihood: Gaussian,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,19 +56,19 @@ impl Event {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Batch {
|
pub struct Batch {
|
||||||
pub(crate) events: Vec<Event>,
|
pub(crate) events: Vec<Event>,
|
||||||
pub(crate) skills: HashMap<String, Skill>,
|
pub(crate) skills: HashMap<Index, Skill>,
|
||||||
pub(crate) time: u64,
|
pub(crate) time: u64,
|
||||||
p_draw: f64,
|
p_draw: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Batch {
|
impl Batch {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
composition: Vec<Vec<Vec<&str>>>,
|
composition: Vec<Vec<Vec<Index>>>,
|
||||||
results: Vec<Vec<f64>>,
|
results: Vec<Vec<f64>>,
|
||||||
weights: Vec<Vec<Vec<f64>>>,
|
weights: Vec<Vec<Vec<f64>>>,
|
||||||
time: u64,
|
time: u64,
|
||||||
p_draw: f64,
|
p_draw: f64,
|
||||||
agents: &mut HashMap<String, Agent>,
|
agents: &mut HashMap<Index, Agent>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
assert!(
|
assert!(
|
||||||
results.is_empty() || results.len() == composition.len(),
|
results.is_empty() || results.len() == composition.len(),
|
||||||
@@ -88,17 +88,17 @@ impl Batch {
|
|||||||
|
|
||||||
let elapsed = this_agent
|
let elapsed = this_agent
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&a| (a, compute_elapsed(agents[a].last_time, time)))
|
.map(|&idx| (idx, compute_elapsed(agents[&idx].last_time, time)))
|
||||||
.collect::<HashMap<_, _>>();
|
.collect::<HashMap<_, _>>();
|
||||||
|
|
||||||
let skills = this_agent
|
let skills = this_agent
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&a| {
|
.map(|&idx| {
|
||||||
(
|
(
|
||||||
a.to_string(),
|
idx,
|
||||||
Skill {
|
Skill {
|
||||||
forward: agents[a].receive(elapsed[a]),
|
forward: agents[&idx].receive(elapsed[&idx]),
|
||||||
elapsed: elapsed[a],
|
elapsed: elapsed[&idx],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -111,7 +111,7 @@ impl Batch {
|
|||||||
.map(|t| {
|
.map(|t| {
|
||||||
let items = (0..composition[e][t].len())
|
let items = (0..composition[e][t].len())
|
||||||
.map(|a| Item {
|
.map(|a| Item {
|
||||||
agent: composition[e][t][a].to_string(),
|
agent: composition[e][t][a],
|
||||||
likelihood: N_INF,
|
likelihood: N_INF,
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
@@ -153,10 +153,10 @@ impl Batch {
|
|||||||
|
|
||||||
pub(crate) fn add_events(
|
pub(crate) fn add_events(
|
||||||
&mut self,
|
&mut self,
|
||||||
composition: Vec<Vec<Vec<&str>>>,
|
composition: Vec<Vec<Vec<Index>>>,
|
||||||
results: Vec<Vec<f64>>,
|
results: Vec<Vec<f64>>,
|
||||||
weights: Vec<Vec<Vec<f64>>>,
|
weights: Vec<Vec<Vec<f64>>>,
|
||||||
agents: &mut HashMap<String, Agent>,
|
agents: &mut HashMap<Index, Agent>,
|
||||||
) {
|
) {
|
||||||
let this_agent = composition
|
let this_agent = composition
|
||||||
.iter()
|
.iter()
|
||||||
@@ -165,17 +165,17 @@ impl Batch {
|
|||||||
.cloned()
|
.cloned()
|
||||||
.collect::<HashSet<_>>();
|
.collect::<HashSet<_>>();
|
||||||
|
|
||||||
for a in this_agent {
|
for idx in this_agent {
|
||||||
let elapsed = compute_elapsed(agents[a].last_time, self.time);
|
let elapsed = compute_elapsed(agents[&idx].last_time, self.time);
|
||||||
|
|
||||||
if let Some(skill) = self.skills.get_mut(a) {
|
if let Some(skill) = self.skills.get_mut(&idx) {
|
||||||
skill.elapsed = elapsed;
|
skill.elapsed = elapsed;
|
||||||
skill.forward = agents[a].receive(elapsed);
|
skill.forward = agents[&idx].receive(elapsed);
|
||||||
} else {
|
} else {
|
||||||
self.skills.insert(
|
self.skills.insert(
|
||||||
a.to_string(),
|
idx,
|
||||||
Skill {
|
Skill {
|
||||||
forward: agents[a].receive(elapsed),
|
forward: agents[&idx].receive(elapsed),
|
||||||
elapsed,
|
elapsed,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
@@ -190,7 +190,7 @@ impl Batch {
|
|||||||
.map(|t| {
|
.map(|t| {
|
||||||
let items = (0..composition[e][t].len())
|
let items = (0..composition[e][t].len())
|
||||||
.map(|a| Item {
|
.map(|a| Item {
|
||||||
agent: composition[e][t][a].to_string(),
|
agent: composition[e][t][a],
|
||||||
likelihood: N_INF,
|
likelihood: N_INF,
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
@@ -222,16 +222,16 @@ impl Batch {
|
|||||||
self.iteration(from, agents);
|
self.iteration(from, agents);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn posterior(&self, agent: &str) -> Gaussian {
|
pub fn posterior(&self, agent: Index) -> Gaussian {
|
||||||
let skill = &self.skills[agent];
|
let skill = &self.skills[&agent];
|
||||||
|
|
||||||
skill.likelihood * skill.backward * skill.forward
|
skill.likelihood * skill.backward * skill.forward
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn posteriors(&self) -> HashMap<String, Gaussian> {
|
pub(crate) fn posteriors(&self) -> HashMap<Index, Gaussian> {
|
||||||
self.skills
|
self.skills
|
||||||
.keys()
|
.keys()
|
||||||
.map(|a| (a.to_string(), self.posterior(a)))
|
.map(|&idx| (idx, self.posterior(idx)))
|
||||||
.collect::<HashMap<_, _>>()
|
.collect::<HashMap<_, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,7 +240,7 @@ impl Batch {
|
|||||||
item: &Item,
|
item: &Item,
|
||||||
online: bool,
|
online: bool,
|
||||||
forward: bool,
|
forward: bool,
|
||||||
agents: &mut HashMap<String, Agent>,
|
agents: &mut HashMap<Index, Agent>,
|
||||||
) -> Player {
|
) -> Player {
|
||||||
let r = &agents[&item.agent].player;
|
let r = &agents[&item.agent].player;
|
||||||
|
|
||||||
@@ -249,7 +249,7 @@ impl Batch {
|
|||||||
} else if forward {
|
} else if forward {
|
||||||
Player::new(self.skills[&item.agent].forward, r.beta, r.gamma)
|
Player::new(self.skills[&item.agent].forward, r.beta, r.gamma)
|
||||||
} else {
|
} else {
|
||||||
let wp = self.posterior(&item.agent) / item.likelihood;
|
let wp = self.posterior(item.agent) / item.likelihood;
|
||||||
|
|
||||||
Player::new(wp, r.beta, r.gamma)
|
Player::new(wp, r.beta, r.gamma)
|
||||||
}
|
}
|
||||||
@@ -260,7 +260,7 @@ impl Batch {
|
|||||||
event: usize,
|
event: usize,
|
||||||
online: bool,
|
online: bool,
|
||||||
forward: bool,
|
forward: bool,
|
||||||
agents: &mut HashMap<String, Agent>,
|
agents: &mut HashMap<Index, Agent>,
|
||||||
) -> Vec<Vec<Player>> {
|
) -> Vec<Vec<Player>> {
|
||||||
self.events[event]
|
self.events[event]
|
||||||
.teams
|
.teams
|
||||||
@@ -274,7 +274,7 @@ impl Batch {
|
|||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn iteration(&mut self, from: usize, agents: &mut HashMap<String, Agent>) {
|
pub(crate) fn iteration(&mut self, from: usize, agents: &mut HashMap<Index, Agent>) {
|
||||||
for e in from..self.events.len() {
|
for e in from..self.events.len() {
|
||||||
let teams = self.within_priors(e, false, false, agents);
|
let teams = self.within_priors(e, false, false, agents);
|
||||||
let result = self.events[e].outputs();
|
let result = self.events[e].outputs();
|
||||||
@@ -295,7 +295,7 @@ impl Batch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn convergence(&mut self, agents: &mut HashMap<String, Agent>) -> usize {
|
pub(crate) fn convergence(&mut self, agents: &mut HashMap<Index, Agent>) -> usize {
|
||||||
let epsilon = 1e-6;
|
let epsilon = 1e-6;
|
||||||
let iterations = 20;
|
let iterations = 20;
|
||||||
|
|
||||||
@@ -319,7 +319,7 @@ impl Batch {
|
|||||||
i
|
i
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn forward_prior_out(&self, agent: &str) -> Gaussian {
|
pub(crate) fn forward_prior_out(&self, agent: &Index) -> Gaussian {
|
||||||
let skill = &self.skills[agent];
|
let skill = &self.skills[agent];
|
||||||
|
|
||||||
skill.forward * skill.likelihood
|
skill.forward * skill.likelihood
|
||||||
@@ -327,8 +327,8 @@ impl Batch {
|
|||||||
|
|
||||||
pub(crate) fn backward_prior_out(
|
pub(crate) fn backward_prior_out(
|
||||||
&self,
|
&self,
|
||||||
agent: &str,
|
agent: &Index,
|
||||||
agents: &mut HashMap<String, Agent>,
|
agents: &mut HashMap<Index, Agent>,
|
||||||
) -> Gaussian {
|
) -> Gaussian {
|
||||||
let skill = &self.skills[agent];
|
let skill = &self.skills[agent];
|
||||||
let n = skill.likelihood * skill.backward;
|
let n = skill.likelihood * skill.backward;
|
||||||
@@ -336,7 +336,7 @@ impl Batch {
|
|||||||
n.forget(agents[agent].player.gamma, skill.elapsed)
|
n.forget(agents[agent].player.gamma, skill.elapsed)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn new_backward_info(&mut self, agents: &mut HashMap<String, Agent>) {
|
pub(crate) fn new_backward_info(&mut self, agents: &mut HashMap<Index, Agent>) {
|
||||||
for (agent, skill) in self.skills.iter_mut() {
|
for (agent, skill) in self.skills.iter_mut() {
|
||||||
skill.backward = agents[agent].message;
|
skill.backward = agents[agent].message;
|
||||||
}
|
}
|
||||||
@@ -344,7 +344,7 @@ impl Batch {
|
|||||||
self.iteration(0, agents);
|
self.iteration(0, agents);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn new_forward_info(&mut self, agents: &mut HashMap<String, Agent>) {
|
pub(crate) fn new_forward_info(&mut self, agents: &mut HashMap<Index, Agent>) {
|
||||||
for (agent, skill) in self.skills.iter_mut() {
|
for (agent, skill) in self.skills.iter_mut() {
|
||||||
skill.forward = agents[agent].receive(skill.elapsed);
|
skill.forward = agents[agent].receive(skill.elapsed);
|
||||||
}
|
}
|
||||||
@@ -352,14 +352,14 @@ impl Batch {
|
|||||||
self.iteration(0, agents);
|
self.iteration(0, agents);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn log_evidence2(
|
pub(crate) fn log_evidence(
|
||||||
&self,
|
&self,
|
||||||
online: bool,
|
online: bool,
|
||||||
agents2: &Vec<&str>,
|
targets: &[Index],
|
||||||
forward: bool,
|
forward: bool,
|
||||||
agents: &mut HashMap<String, Agent>,
|
agents: &mut HashMap<Index, Agent>,
|
||||||
) -> f64 {
|
) -> f64 {
|
||||||
if agents2.is_empty() {
|
if targets.is_empty() {
|
||||||
if online || forward {
|
if online || forward {
|
||||||
self.events
|
self.events
|
||||||
.iter()
|
.iter()
|
||||||
@@ -388,7 +388,7 @@ impl Batch {
|
|||||||
.teams
|
.teams
|
||||||
.iter()
|
.iter()
|
||||||
.flat_map(|team| &team.items)
|
.flat_map(|team| &team.items)
|
||||||
.any(|item| agents2.contains(&item.agent.as_str()))
|
.any(|item| targets.contains(&item.agent))
|
||||||
})
|
})
|
||||||
.map(|(e, event)| {
|
.map(|(e, event)| {
|
||||||
Game::new(
|
Game::new(
|
||||||
@@ -409,7 +409,7 @@ impl Batch {
|
|||||||
.teams
|
.teams
|
||||||
.iter()
|
.iter()
|
||||||
.flat_map(|team| &team.items)
|
.flat_map(|team| &team.items)
|
||||||
.any(|item| agents2.contains(&item.agent.as_str()))
|
.any(|item| targets.contains(&item.agent))
|
||||||
})
|
})
|
||||||
.map(|event| event.evidence.ln())
|
.map(|event| event.evidence.ln())
|
||||||
.sum()
|
.sum()
|
||||||
@@ -417,19 +417,14 @@ impl Batch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn get_composition(&self) -> Vec<Vec<Vec<&str>>> {
|
pub(crate) fn get_composition(&self) -> Vec<Vec<Vec<Index>>> {
|
||||||
self.events
|
self.events
|
||||||
.iter()
|
.iter()
|
||||||
.map(|event| {
|
.map(|event| {
|
||||||
event
|
event
|
||||||
.teams
|
.teams
|
||||||
.iter()
|
.iter()
|
||||||
.map(|team| {
|
.map(|team| team.items.iter().map(|item| item.agent).collect::<Vec<_>>())
|
||||||
team.items
|
|
||||||
.iter()
|
|
||||||
.map(|item| item.agent.as_str())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
@@ -463,17 +458,26 @@ pub(crate) fn compute_elapsed(last_time: u64, actual_time: u64) -> u64 {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use approx::assert_ulps_eq;
|
use approx::assert_ulps_eq;
|
||||||
|
|
||||||
use crate::{agent::Agent, player::Player};
|
use crate::{agent::Agent, player::Player, IndexMap};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_one_event_each() {
|
fn test_one_event_each() {
|
||||||
|
let mut index_map = IndexMap::new();
|
||||||
|
|
||||||
|
let a = index_map.get_or_create("a");
|
||||||
|
let b = index_map.get_or_create("b");
|
||||||
|
let c = index_map.get_or_create("c");
|
||||||
|
let d = index_map.get_or_create("d");
|
||||||
|
let e = index_map.get_or_create("e");
|
||||||
|
let f = index_map.get_or_create("f");
|
||||||
|
|
||||||
let mut agents = HashMap::new();
|
let mut agents = HashMap::new();
|
||||||
|
|
||||||
for agent in ["a", "b", "c", "d", "e", "f"] {
|
for agent in [a, b, c, d, e, f] {
|
||||||
agents.insert(
|
agents.insert(
|
||||||
agent.to_string(),
|
agent,
|
||||||
Agent {
|
Agent {
|
||||||
player: Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0),
|
player: Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -483,9 +487,9 @@ mod tests {
|
|||||||
|
|
||||||
let mut batch = Batch::new(
|
let mut batch = Batch::new(
|
||||||
vec![
|
vec![
|
||||||
vec![vec!["a"], vec!["b"]],
|
vec![vec![a], vec![b]],
|
||||||
vec![vec!["c"], vec!["d"]],
|
vec![vec![c], vec![d]],
|
||||||
vec![vec!["e"], vec!["f"]],
|
vec![vec![e], vec![f]],
|
||||||
],
|
],
|
||||||
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]],
|
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]],
|
||||||
vec![],
|
vec![],
|
||||||
@@ -496,34 +500,32 @@ mod tests {
|
|||||||
|
|
||||||
let post = batch.posteriors();
|
let post = batch.posteriors();
|
||||||
|
|
||||||
assert_ulps_eq!(post["a"].mu, 29.205220743876975, epsilon = 0.000001);
|
assert_ulps_eq!(post[&a], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6);
|
||||||
assert_ulps_eq!(post["a"].sigma, 7.194481422570443, epsilon = 0.000001);
|
assert_ulps_eq!(post[&b], Gaussian::new(20.794779, 7.194481), epsilon = 1e-6);
|
||||||
|
assert_ulps_eq!(post[&c], Gaussian::new(20.794779, 7.194481), epsilon = 1e-6);
|
||||||
assert_ulps_eq!(post["b"].mu, 20.79477925612302, epsilon = 0.000001);
|
assert_ulps_eq!(post[&d], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6);
|
||||||
assert_ulps_eq!(post["b"].sigma, 7.194481422570443, epsilon = 0.000001);
|
assert_ulps_eq!(post[&e], Gaussian::new(29.205220, 7.194481), epsilon = 1e-6);
|
||||||
|
assert_ulps_eq!(post[&f], Gaussian::new(20.794779, 7.194481), epsilon = 1e-6);
|
||||||
assert_ulps_eq!(post["c"].mu, 20.79477925612302, epsilon = 0.000001);
|
|
||||||
assert_ulps_eq!(post["c"].sigma, 7.194481422570443, epsilon = 0.000001);
|
|
||||||
|
|
||||||
assert_ulps_eq!(post["d"].mu, 29.205220743876975, epsilon = 0.000001);
|
|
||||||
assert_ulps_eq!(post["d"].sigma, 7.194481422570443, epsilon = 0.000001);
|
|
||||||
|
|
||||||
assert_ulps_eq!(post["e"].mu, 29.205220743876975, epsilon = 0.000001);
|
|
||||||
assert_ulps_eq!(post["e"].sigma, 7.194481422570443, epsilon = 0.000001);
|
|
||||||
|
|
||||||
assert_ulps_eq!(post["f"].mu, 20.79477925612302, epsilon = 0.000001);
|
|
||||||
assert_ulps_eq!(post["f"].sigma, 7.194481422570443, epsilon = 0.000001);
|
|
||||||
|
|
||||||
assert_eq!(batch.convergence(&mut agents), 1);
|
assert_eq!(batch.convergence(&mut agents), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_same_strength() {
|
fn test_same_strength() {
|
||||||
|
let mut index_map = IndexMap::new();
|
||||||
|
|
||||||
|
let a = index_map.get_or_create("a");
|
||||||
|
let b = index_map.get_or_create("b");
|
||||||
|
let c = index_map.get_or_create("c");
|
||||||
|
let d = index_map.get_or_create("d");
|
||||||
|
let e = index_map.get_or_create("e");
|
||||||
|
let f = index_map.get_or_create("f");
|
||||||
|
|
||||||
let mut agents = HashMap::new();
|
let mut agents = HashMap::new();
|
||||||
|
|
||||||
for agent in ["a", "b", "c", "d", "e", "f"] {
|
for agent in [a, b, c, d, e, f] {
|
||||||
agents.insert(
|
agents.insert(
|
||||||
agent.to_string(),
|
agent,
|
||||||
Agent {
|
Agent {
|
||||||
player: Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0),
|
player: Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -533,9 +535,9 @@ mod tests {
|
|||||||
|
|
||||||
let mut batch = Batch::new(
|
let mut batch = Batch::new(
|
||||||
vec![
|
vec![
|
||||||
vec![vec!["a"], vec!["b"]],
|
vec![vec![a], vec![b]],
|
||||||
vec![vec!["a"], vec!["c"]],
|
vec![vec![a], vec![c]],
|
||||||
vec![vec!["b"], vec!["c"]],
|
vec![vec![b], vec![c]],
|
||||||
],
|
],
|
||||||
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]],
|
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]],
|
||||||
vec![],
|
vec![],
|
||||||
@@ -546,36 +548,35 @@ mod tests {
|
|||||||
|
|
||||||
let post = batch.posteriors();
|
let post = batch.posteriors();
|
||||||
|
|
||||||
assert_ulps_eq!(post["a"].mu, 24.96097857478182, epsilon = 0.000001);
|
assert_ulps_eq!(post[&a], Gaussian::new(24.960978, 6.298544), epsilon = 1e-6);
|
||||||
assert_ulps_eq!(post["a"].sigma, 6.298544763358269, epsilon = 0.000001);
|
assert_ulps_eq!(post[&b], Gaussian::new(27.095590, 6.010330), epsilon = 1e-6);
|
||||||
|
assert_ulps_eq!(post[&c], Gaussian::new(24.889681, 5.866311), epsilon = 1e-6);
|
||||||
assert_ulps_eq!(post["b"].mu, 27.095590669107086, epsilon = 0.000001);
|
|
||||||
assert_ulps_eq!(post["b"].sigma, 6.010330439043099, epsilon = 0.000001);
|
|
||||||
|
|
||||||
assert_ulps_eq!(post["c"].mu, 24.88968178743119, epsilon = 0.000001);
|
|
||||||
assert_ulps_eq!(post["c"].sigma, 5.866311348102562, epsilon = 0.000001);
|
|
||||||
|
|
||||||
assert!(batch.convergence(&mut agents) > 1);
|
assert!(batch.convergence(&mut agents) > 1);
|
||||||
|
|
||||||
let post = batch.posteriors();
|
let post = batch.posteriors();
|
||||||
|
|
||||||
assert_ulps_eq!(post["a"].mu, 25.000000, epsilon = 0.000001);
|
assert_ulps_eq!(post[&a], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6);
|
||||||
assert_ulps_eq!(post["a"].sigma, 5.4192120, epsilon = 0.000001);
|
assert_ulps_eq!(post[&b], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6);
|
||||||
|
assert_ulps_eq!(post[&c], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6);
|
||||||
assert_ulps_eq!(post["b"].mu, 25.000000, epsilon = 0.000001);
|
|
||||||
assert_ulps_eq!(post["b"].sigma, 5.4192120, epsilon = 0.000001);
|
|
||||||
|
|
||||||
assert_ulps_eq!(post["c"].mu, 25.000000, epsilon = 0.000001);
|
|
||||||
assert_ulps_eq!(post["c"].sigma, 5.4192120, epsilon = 0.000001);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_add_events() {
|
fn test_add_events() {
|
||||||
|
let mut index_map = IndexMap::new();
|
||||||
|
|
||||||
|
let a = index_map.get_or_create("a");
|
||||||
|
let b = index_map.get_or_create("b");
|
||||||
|
let c = index_map.get_or_create("c");
|
||||||
|
let d = index_map.get_or_create("d");
|
||||||
|
let e = index_map.get_or_create("e");
|
||||||
|
let f = index_map.get_or_create("f");
|
||||||
|
|
||||||
let mut agents = HashMap::new();
|
let mut agents = HashMap::new();
|
||||||
|
|
||||||
for agent in ["a", "b", "c", "d", "e", "f"] {
|
for agent in [a, b, c, d, e, f] {
|
||||||
agents.insert(
|
agents.insert(
|
||||||
agent.to_string(),
|
agent,
|
||||||
Agent {
|
Agent {
|
||||||
player: Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0),
|
player: Player::new(Gaussian::new(25.0, 25.0 / 3.0), 25.0 / 6.0, 25.0 / 300.0),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
@@ -585,9 +586,9 @@ mod tests {
|
|||||||
|
|
||||||
let mut batch = Batch::new(
|
let mut batch = Batch::new(
|
||||||
vec![
|
vec![
|
||||||
vec![vec!["a"], vec!["b"]],
|
vec![vec![a], vec![b]],
|
||||||
vec![vec!["a"], vec!["c"]],
|
vec![vec![a], vec![c]],
|
||||||
vec![vec!["b"], vec!["c"]],
|
vec![vec![b], vec![c]],
|
||||||
],
|
],
|
||||||
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]],
|
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]],
|
||||||
vec![],
|
vec![],
|
||||||
@@ -600,29 +601,15 @@ mod tests {
|
|||||||
|
|
||||||
let post = batch.posteriors();
|
let post = batch.posteriors();
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(post[&a], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6);
|
||||||
post["a"],
|
assert_ulps_eq!(post[&b], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6);
|
||||||
Gaussian::new(25.000000, 5.4192120),
|
assert_ulps_eq!(post[&c], Gaussian::new(25.000000, 5.419212), epsilon = 1e-6);
|
||||||
epsilon = 0.000001
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_ulps_eq!(
|
|
||||||
post["b"],
|
|
||||||
Gaussian::new(25.000000, 5.4192120),
|
|
||||||
epsilon = 0.000001
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_ulps_eq!(
|
|
||||||
post["c"],
|
|
||||||
Gaussian::new(25.000000, 5.4192120),
|
|
||||||
epsilon = 0.000001
|
|
||||||
);
|
|
||||||
|
|
||||||
batch.add_events(
|
batch.add_events(
|
||||||
vec![
|
vec![
|
||||||
vec![vec!["a"], vec!["b"]],
|
vec![vec![a], vec![b]],
|
||||||
vec![vec!["a"], vec!["c"]],
|
vec![vec![a], vec![c]],
|
||||||
vec![vec!["b"], vec!["c"]],
|
vec![vec![b], vec![c]],
|
||||||
],
|
],
|
||||||
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]],
|
vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]],
|
||||||
vec![],
|
vec![],
|
||||||
@@ -635,20 +622,8 @@ mod tests {
|
|||||||
|
|
||||||
let post = batch.posteriors();
|
let post = batch.posteriors();
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(post[&a], Gaussian::new(25.000003, 3.880150), epsilon = 1e-6);
|
||||||
post["a"],
|
assert_ulps_eq!(post[&b], Gaussian::new(25.000003, 3.880150), epsilon = 1e-6);
|
||||||
Gaussian::new(25.00000315330858, 3.880150268080797),
|
assert_ulps_eq!(post[&c], Gaussian::new(25.000003, 3.880150), epsilon = 1e-6);
|
||||||
epsilon = 0.000001
|
|
||||||
);
|
|
||||||
assert_ulps_eq!(
|
|
||||||
post["b"],
|
|
||||||
Gaussian::new(25.00000315330858, 3.880150268080797),
|
|
||||||
epsilon = 0.000001
|
|
||||||
);
|
|
||||||
assert_ulps_eq!(
|
|
||||||
post["c"],
|
|
||||||
Gaussian::new(25.00000315330858, 3.880150268080797),
|
|
||||||
epsilon = 0.000001
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
679
src/history.rs
679
src/history.rs
File diff suppressed because it is too large
Load Diff
45
src/lib.rs
45
src/lib.rs
@@ -1,5 +1,8 @@
|
|||||||
|
use std::borrow::{Borrow, ToOwned};
|
||||||
use std::cmp::Reverse;
|
use std::cmp::Reverse;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2};
|
use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2};
|
||||||
|
use std::hash::Hash;
|
||||||
|
|
||||||
mod agent;
|
mod agent;
|
||||||
#[cfg(feature = "approx")]
|
#[cfg(feature = "approx")]
|
||||||
@@ -27,19 +30,55 @@ pub const ITERATIONS: usize = 30;
|
|||||||
|
|
||||||
const SQRT_TAU: f64 = 2.5066282746310002;
|
const SQRT_TAU: f64 = 2.5066282746310002;
|
||||||
|
|
||||||
const N01: Gaussian = Gaussian {
|
pub const N01: Gaussian = Gaussian {
|
||||||
mu: 0.0,
|
mu: 0.0,
|
||||||
sigma: 1.0,
|
sigma: 1.0,
|
||||||
};
|
};
|
||||||
pub(crate) const N00: Gaussian = Gaussian {
|
pub const N00: Gaussian = Gaussian {
|
||||||
mu: 0.0,
|
mu: 0.0,
|
||||||
sigma: 0.0,
|
sigma: 0.0,
|
||||||
};
|
};
|
||||||
pub(crate) const N_INF: Gaussian = Gaussian {
|
pub const N_INF: Gaussian = Gaussian {
|
||||||
mu: 0.0,
|
mu: 0.0,
|
||||||
sigma: f64::INFINITY,
|
sigma: f64::INFINITY,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Default, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)]
|
||||||
|
pub struct Index(usize);
|
||||||
|
|
||||||
|
impl From<usize> for Index {
|
||||||
|
fn from(ix: usize) -> Self {
|
||||||
|
Self(ix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct IndexMap<K>(HashMap<K, Index>);
|
||||||
|
|
||||||
|
impl<K> IndexMap<K>
|
||||||
|
where
|
||||||
|
K: Eq + Hash,
|
||||||
|
{
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self(HashMap::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_or_create<Q: ?Sized>(&mut self, k: &Q) -> Index
|
||||||
|
where
|
||||||
|
K: Borrow<Q>,
|
||||||
|
Q: Hash + Eq + ToOwned<Owned = K>,
|
||||||
|
{
|
||||||
|
if let Some(idx) = self.0.get(k) {
|
||||||
|
*idx
|
||||||
|
} else {
|
||||||
|
let idx = Index::from(self.0.len());
|
||||||
|
|
||||||
|
self.0.insert(k.to_owned(), idx);
|
||||||
|
|
||||||
|
idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn erfc(x: f64) -> f64 {
|
fn erfc(x: f64) -> f64 {
|
||||||
let z = x.abs();
|
let z = x.abs();
|
||||||
let t = 1.0 / (1.0 + z / 2.0);
|
let t = 1.0 / (1.0 + z / 2.0);
|
||||||
|
|||||||
Reference in New Issue
Block a user