refactor(batch): replace HashMap<Index, Skill> with dense SkillStore
SkillStore is a Vec<Skill>-backed dense store with a parallel present mask, indexed directly by Index.0. Eliminates per-iteration hashing in the within-slice convergence loop; O(1) array lookup replaces O(1) amortised hash lookup with better cache behaviour. Iteration order is now ascending-by-Index (was arbitrary for HashMap); EP fixed point is order-independent so posteriors are unchanged. Part of T0 engine redesign.
This commit is contained in:
41
src/batch.rs
41
src/batch.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
Index, N_INF, agent::Agent, drift::Drift, game::Game, gaussian::Gaussian, player::Player,
|
Index, N_INF, agent::Agent, drift::Drift, game::Game, gaussian::Gaussian, player::Player,
|
||||||
tuple_gt, tuple_max,
|
storage::SkillStore, tuple_gt, tuple_max,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -43,11 +43,11 @@ impl Item {
|
|||||||
&self,
|
&self,
|
||||||
online: bool,
|
online: bool,
|
||||||
forward: bool,
|
forward: bool,
|
||||||
skills: &HashMap<Index, Skill>,
|
skills: &SkillStore,
|
||||||
agents: &HashMap<Index, Agent<D>>,
|
agents: &HashMap<Index, Agent<D>>,
|
||||||
) -> Player<D> {
|
) -> Player<D> {
|
||||||
let r = &agents[&self.agent].player;
|
let r = &agents[&self.agent].player;
|
||||||
let skill = &skills[&self.agent];
|
let skill = skills.get(self.agent).unwrap();
|
||||||
|
|
||||||
if online {
|
if online {
|
||||||
Player::new(skill.online, r.beta, r.drift)
|
Player::new(skill.online, r.beta, r.drift)
|
||||||
@@ -84,7 +84,7 @@ impl Event {
|
|||||||
&self,
|
&self,
|
||||||
online: bool,
|
online: bool,
|
||||||
forward: bool,
|
forward: bool,
|
||||||
skills: &HashMap<Index, Skill>,
|
skills: &SkillStore,
|
||||||
agents: &HashMap<Index, Agent<D>>,
|
agents: &HashMap<Index, Agent<D>>,
|
||||||
) -> Vec<Vec<Player<D>>> {
|
) -> Vec<Vec<Player<D>>> {
|
||||||
self.teams
|
self.teams
|
||||||
@@ -102,7 +102,7 @@ impl Event {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Batch {
|
pub struct Batch {
|
||||||
pub(crate) events: Vec<Event>,
|
pub(crate) events: Vec<Event>,
|
||||||
pub(crate) skills: HashMap<Index, Skill>,
|
pub(crate) skills: SkillStore,
|
||||||
pub(crate) time: i64,
|
pub(crate) time: i64,
|
||||||
p_draw: f64,
|
p_draw: f64,
|
||||||
}
|
}
|
||||||
@@ -111,7 +111,7 @@ impl Batch {
|
|||||||
pub fn new(time: i64, p_draw: f64) -> Self {
|
pub fn new(time: i64, p_draw: f64) -> Self {
|
||||||
Self {
|
Self {
|
||||||
events: Vec::new(),
|
events: Vec::new(),
|
||||||
skills: HashMap::new(),
|
skills: SkillStore::new(),
|
||||||
time,
|
time,
|
||||||
p_draw,
|
p_draw,
|
||||||
}
|
}
|
||||||
@@ -137,16 +137,16 @@ impl Batch {
|
|||||||
});
|
});
|
||||||
|
|
||||||
for idx in this_agent {
|
for idx in this_agent {
|
||||||
let elapsed = compute_elapsed(agents[&idx].last_time, self.time);
|
let elapsed = compute_elapsed(agents[idx].last_time, self.time);
|
||||||
|
|
||||||
if let Some(skill) = self.skills.get_mut(idx) {
|
if let Some(skill) = self.skills.get_mut(*idx) {
|
||||||
skill.elapsed = elapsed;
|
skill.elapsed = elapsed;
|
||||||
skill.forward = agents[&idx].receive(elapsed);
|
skill.forward = agents[idx].receive(elapsed);
|
||||||
} else {
|
} else {
|
||||||
self.skills.insert(
|
self.skills.insert(
|
||||||
*idx,
|
*idx,
|
||||||
Skill {
|
Skill {
|
||||||
forward: agents[&idx].receive(elapsed),
|
forward: agents[idx].receive(elapsed),
|
||||||
elapsed,
|
elapsed,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
@@ -204,7 +204,7 @@ impl Batch {
|
|||||||
pub(crate) fn posteriors(&self) -> HashMap<Index, Gaussian> {
|
pub(crate) fn posteriors(&self) -> HashMap<Index, Gaussian> {
|
||||||
self.skills
|
self.skills
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(&idx, skill)| (idx, skill.posterior()))
|
.map(|(idx, skill)| (idx, skill.posterior()))
|
||||||
.collect::<HashMap<_, _>>()
|
.collect::<HashMap<_, _>>()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,10 +217,9 @@ impl Batch {
|
|||||||
|
|
||||||
for (t, team) in event.teams.iter_mut().enumerate() {
|
for (t, team) in event.teams.iter_mut().enumerate() {
|
||||||
for (i, item) in team.items.iter_mut().enumerate() {
|
for (i, item) in team.items.iter_mut().enumerate() {
|
||||||
self.skills.get_mut(&item.agent).unwrap().likelihood =
|
let old_likelihood = self.skills.get(item.agent).unwrap().likelihood;
|
||||||
(self.skills[&item.agent].likelihood / item.likelihood)
|
let new_likelihood = (old_likelihood / item.likelihood) * g.likelihoods[t][i];
|
||||||
* g.likelihoods[t][i];
|
self.skills.get_mut(item.agent).unwrap().likelihood = new_likelihood;
|
||||||
|
|
||||||
item.likelihood = g.likelihoods[t][i];
|
item.likelihood = g.likelihoods[t][i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -255,8 +254,7 @@ impl Batch {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn forward_prior_out(&self, agent: &Index) -> Gaussian {
|
pub(crate) fn forward_prior_out(&self, agent: &Index) -> Gaussian {
|
||||||
let skill = &self.skills[agent];
|
let skill = self.skills.get(*agent).unwrap();
|
||||||
|
|
||||||
skill.forward * skill.likelihood
|
skill.forward * skill.likelihood
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -265,25 +263,22 @@ impl Batch {
|
|||||||
agent: &Index,
|
agent: &Index,
|
||||||
agents: &HashMap<Index, Agent<D>>,
|
agents: &HashMap<Index, Agent<D>>,
|
||||||
) -> Gaussian {
|
) -> Gaussian {
|
||||||
let skill = &self.skills[agent];
|
let skill = self.skills.get(*agent).unwrap();
|
||||||
let n = skill.likelihood * skill.backward;
|
let n = skill.likelihood * skill.backward;
|
||||||
|
|
||||||
n.forget(agents[agent].player.drift.variance_delta(skill.elapsed))
|
n.forget(agents[agent].player.drift.variance_delta(skill.elapsed))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn new_backward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
|
pub(crate) fn new_backward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
|
||||||
for (agent, skill) in self.skills.iter_mut() {
|
for (agent, skill) in self.skills.iter_mut() {
|
||||||
skill.backward = agents[agent].message;
|
skill.backward = agents[&agent].message;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.iteration(0, agents);
|
self.iteration(0, agents);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn new_forward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
|
pub(crate) fn new_forward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
|
||||||
for (agent, skill) in self.skills.iter_mut() {
|
for (agent, skill) in self.skills.iter_mut() {
|
||||||
skill.forward = agents[agent].receive(skill.elapsed);
|
skill.forward = agents[&agent].receive(skill.elapsed);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.iteration(0, agents);
|
self.iteration(0, agents);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
170
src/history.rs
170
src/history.rs
@@ -145,8 +145,8 @@ impl<D: Drift> History<D> {
|
|||||||
|
|
||||||
for j in (0..self.batches.len() - 1).rev() {
|
for j in (0..self.batches.len() - 1).rev() {
|
||||||
for agent in self.batches[j + 1].skills.keys() {
|
for agent in self.batches[j + 1].skills.keys() {
|
||||||
self.agents.get_mut(agent).unwrap().message =
|
self.agents.get_mut(&agent).unwrap().message =
|
||||||
self.batches[j + 1].backward_prior_out(agent, &self.agents);
|
self.batches[j + 1].backward_prior_out(&agent, &self.agents);
|
||||||
}
|
}
|
||||||
|
|
||||||
let old = self.batches[j].posteriors();
|
let old = self.batches[j].posteriors();
|
||||||
@@ -164,8 +164,8 @@ impl<D: Drift> History<D> {
|
|||||||
|
|
||||||
for j in 1..self.batches.len() {
|
for j in 1..self.batches.len() {
|
||||||
for agent in self.batches[j - 1].skills.keys() {
|
for agent in self.batches[j - 1].skills.keys() {
|
||||||
self.agents.get_mut(agent).unwrap().message =
|
self.agents.get_mut(&agent).unwrap().message =
|
||||||
self.batches[j - 1].forward_prior_out(agent);
|
self.batches[j - 1].forward_prior_out(&agent);
|
||||||
}
|
}
|
||||||
|
|
||||||
let old = self.batches[j].posteriors();
|
let old = self.batches[j].posteriors();
|
||||||
@@ -231,10 +231,10 @@ impl<D: Drift> History<D> {
|
|||||||
for (agent, skill) in b.skills.iter() {
|
for (agent, skill) in b.skills.iter() {
|
||||||
let point = (b.time, skill.posterior());
|
let point = (b.time, skill.posterior());
|
||||||
|
|
||||||
if let Some(entry) = data.get_mut(agent) {
|
if let Some(entry) = data.get_mut(&agent) {
|
||||||
entry.push(point);
|
entry.push(point);
|
||||||
} else {
|
} else {
|
||||||
data.insert(*agent, vec![point]);
|
data.insert(agent, vec![point]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -343,7 +343,7 @@ impl<D: Drift> History<D> {
|
|||||||
|
|
||||||
// TODO: Is it faster to iterate over agents in batch instead?
|
// TODO: Is it faster to iterate over agents in batch instead?
|
||||||
for agent_idx in &this_agent {
|
for agent_idx in &this_agent {
|
||||||
if let Some(skill) = batch.skills.get_mut(agent_idx) {
|
if let Some(skill) = batch.skills.get_mut(*agent_idx) {
|
||||||
skill.elapsed =
|
skill.elapsed =
|
||||||
batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time);
|
batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time);
|
||||||
|
|
||||||
@@ -378,10 +378,10 @@ impl<D: Drift> History<D> {
|
|||||||
batch.add_events(composition, results, weights, &self.agents);
|
batch.add_events(composition, results, weights, &self.agents);
|
||||||
|
|
||||||
for agent_idx in batch.skills.keys() {
|
for agent_idx in batch.skills.keys() {
|
||||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
let agent = self.agents.get_mut(&agent_idx).unwrap();
|
||||||
|
|
||||||
agent.last_time = if self.time { t } else { i64::MAX };
|
agent.last_time = if self.time { t } else { i64::MAX };
|
||||||
agent.message = batch.forward_prior_out(agent_idx);
|
agent.message = batch.forward_prior_out(&agent_idx);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let mut batch: Batch = Batch::new(t, self.p_draw);
|
let mut batch: Batch = Batch::new(t, self.p_draw);
|
||||||
@@ -392,10 +392,10 @@ impl<D: Drift> History<D> {
|
|||||||
let batch = &self.batches[k];
|
let batch = &self.batches[k];
|
||||||
|
|
||||||
for agent_idx in batch.skills.keys() {
|
for agent_idx in batch.skills.keys() {
|
||||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
let agent = self.agents.get_mut(&agent_idx).unwrap();
|
||||||
|
|
||||||
agent.last_time = if self.time { t } else { i64::MAX };
|
agent.last_time = if self.time { t } else { i64::MAX };
|
||||||
agent.message = batch.forward_prior_out(agent_idx);
|
agent.message = batch.forward_prior_out(&agent_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
k += 1;
|
k += 1;
|
||||||
@@ -411,7 +411,7 @@ impl<D: Drift> History<D> {
|
|||||||
|
|
||||||
// TODO: Is it faster to iterate over agents in batch instead?
|
// TODO: Is it faster to iterate over agents in batch instead?
|
||||||
for agent_idx in &this_agent {
|
for agent_idx in &this_agent {
|
||||||
if let Some(skill) = batch.skills.get_mut(agent_idx) {
|
if let Some(skill) = batch.skills.get_mut(*agent_idx) {
|
||||||
skill.elapsed =
|
skill.elapsed =
|
||||||
batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time);
|
batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time);
|
||||||
|
|
||||||
@@ -476,13 +476,21 @@ mod tests {
|
|||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
|
|
||||||
let observed = h.batches[1].skills[&a].forward.sigma();
|
let observed = h.batches[1].skills.get(a).unwrap().forward.sigma();
|
||||||
let gamma: f64 = 0.15 * 25.0 / 3.0;
|
let gamma: f64 = 0.15 * 25.0 / 3.0;
|
||||||
let expected = (gamma.powi(2) + h.batches[0].skills[&a].posterior().sigma().powi(2)).sqrt();
|
let expected = (gamma.powi(2)
|
||||||
|
+ h.batches[0]
|
||||||
|
.skills
|
||||||
|
.get(a)
|
||||||
|
.unwrap()
|
||||||
|
.posterior()
|
||||||
|
.sigma()
|
||||||
|
.powi(2))
|
||||||
|
.sqrt();
|
||||||
|
|
||||||
assert_ulps_eq!(observed, expected, epsilon = 0.000001);
|
assert_ulps_eq!(observed, expected, epsilon = 0.000001);
|
||||||
|
|
||||||
let observed = h.batches[1].skills[&a].posterior();
|
let observed = h.batches[1].skills.get(a).unwrap().posterior();
|
||||||
|
|
||||||
let w = [vec![1.0], vec![1.0]];
|
let w = [vec![1.0], vec![1.0]];
|
||||||
let p = Game::new(
|
let p = Game::new(
|
||||||
@@ -531,12 +539,12 @@ mod tests {
|
|||||||
h1.add_events_with_prior(composition, results, times, vec![], priors);
|
h1.add_events_with_prior(composition, results, times, vec![], priors);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h1.batches[0].skills[&a].posterior(),
|
h1.batches[0].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(22.904409, 6.010330),
|
Gaussian::from_ms(22.904409, 6.010330),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h1.batches[0].skills[&c].posterior(),
|
h1.batches[0].skills.get(c).unwrap().posterior(),
|
||||||
Gaussian::from_ms(25.110318, 5.866311),
|
Gaussian::from_ms(25.110318, 5.866311),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
@@ -544,12 +552,12 @@ mod tests {
|
|||||||
h1.convergence(ITERATIONS, EPSILON, false);
|
h1.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h1.batches[0].skills[&a].posterior(),
|
h1.batches[0].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(25.000000, 5.419212),
|
Gaussian::from_ms(25.000000, 5.419212),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h1.batches[0].skills[&c].posterior(),
|
h1.batches[0].skills.get(c).unwrap().posterior(),
|
||||||
Gaussian::from_ms(25.000000, 5.419212),
|
Gaussian::from_ms(25.000000, 5.419212),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
@@ -580,12 +588,12 @@ mod tests {
|
|||||||
h2.add_events_with_prior(composition, results, times, vec![], priors);
|
h2.add_events_with_prior(composition, results, times, vec![], priors);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h2.batches[2].skills[&a].posterior(),
|
h2.batches[2].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(22.903522, 6.011017),
|
Gaussian::from_ms(22.903522, 6.011017),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h2.batches[2].skills[&c].posterior(),
|
h2.batches[2].skills.get(c).unwrap().posterior(),
|
||||||
Gaussian::from_ms(25.110702, 5.866811),
|
Gaussian::from_ms(25.110702, 5.866811),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
@@ -593,12 +601,12 @@ mod tests {
|
|||||||
h2.convergence(ITERATIONS, EPSILON, false);
|
h2.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h2.batches[2].skills[&a].posterior(),
|
h2.batches[2].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(24.998668, 5.420053),
|
Gaussian::from_ms(24.998668, 5.420053),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h2.batches[2].skills[&c].posterior(),
|
h2.batches[2].skills.get(c).unwrap().posterior(),
|
||||||
Gaussian::from_ms(25.000532, 5.419827),
|
Gaussian::from_ms(25.000532, 5.419827),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
@@ -685,21 +693,21 @@ mod tests {
|
|||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
assert_eq!(h.batches[2].skills[&b].elapsed, 1);
|
assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1);
|
||||||
assert_eq!(h.batches[2].skills[&c].elapsed, 1);
|
assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&a].posterior(),
|
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(25.000267, 5.419381),
|
Gaussian::from_ms(25.000267, 5.419381),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&b].posterior(),
|
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||||
Gaussian::from_ms(24.999465, 5.419425),
|
Gaussian::from_ms(24.999465, 5.419425),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[2].skills[&b].posterior(),
|
h.batches[2].skills.get(b).unwrap().posterior(),
|
||||||
Gaussian::from_ms(25.000532, 5.419696),
|
Gaussian::from_ms(25.000532, 5.419696),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
@@ -743,8 +751,8 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&b].posterior().mu(),
|
h.batches[0].skills.get(b).unwrap().posterior().mu(),
|
||||||
-1.0 * h.batches[0].skills[&c].posterior().mu(),
|
-1.0 * h.batches[0].skills.get(c).unwrap().posterior().mu(),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -763,33 +771,33 @@ mod tests {
|
|||||||
assert_ulps_eq!(p_d_m_hat, 0.172432, epsilon = 1e-6);
|
assert_ulps_eq!(p_d_m_hat, 0.172432, epsilon = 1e-6);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&a].posterior(),
|
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||||
h.batches[0].skills[&b].posterior(),
|
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&c].posterior(),
|
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||||
h.batches[0].skills[&d].posterior(),
|
h.batches[0].skills.get(d).unwrap().posterior(),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[1].skills[&e].posterior(),
|
h.batches[1].skills.get(e).unwrap().posterior(),
|
||||||
h.batches[1].skills[&f].posterior(),
|
h.batches[1].skills.get(f).unwrap().posterior(),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&a].posterior(),
|
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(4.084902, 5.106919),
|
Gaussian::from_ms(4.084902, 5.106919),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&c].posterior(),
|
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||||
Gaussian::from_ms(-0.533029, 5.106919),
|
Gaussian::from_ms(-0.533029, 5.106919),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[2].skills[&e].posterior(),
|
h.batches[2].skills.get(e).unwrap().posterior(),
|
||||||
Gaussian::from_ms(-3.551872, 5.154569),
|
Gaussian::from_ms(-3.551872, 5.154569),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
@@ -822,21 +830,21 @@ mod tests {
|
|||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
assert_eq!(h.batches[2].skills[&b].elapsed, 1);
|
assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1);
|
||||||
assert_eq!(h.batches[2].skills[&c].elapsed, 1);
|
assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&a].posterior(),
|
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 1.300610),
|
Gaussian::from_ms(0.000000, 1.300610),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&b].posterior(),
|
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 1.300610),
|
Gaussian::from_ms(0.000000, 1.300610),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[2].skills[&b].posterior(),
|
h.batches[2].skills.get(b).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 1.300610),
|
Gaussian::from_ms(0.000000, 1.300610),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
@@ -863,22 +871,22 @@ mod tests {
|
|||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&a].posterior(),
|
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 0.931236),
|
Gaussian::from_ms(0.000000, 0.931236),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[3].skills[&a].posterior(),
|
h.batches[3].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 0.931236),
|
Gaussian::from_ms(0.000000, 0.931236),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[3].skills[&b].posterior(),
|
h.batches[3].skills.get(b).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 0.931236),
|
Gaussian::from_ms(0.000000, 0.931236),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[5].skills[&b].posterior(),
|
h.batches[5].skills.get(b).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 0.931236),
|
Gaussian::from_ms(0.000000, 0.931236),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
@@ -911,21 +919,21 @@ mod tests {
|
|||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
assert_eq!(h.batches[2].skills[&b].elapsed, 1);
|
assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1);
|
||||||
assert_eq!(h.batches[2].skills[&c].elapsed, 1);
|
assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&a].posterior(),
|
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 1.300610),
|
Gaussian::from_ms(0.000000, 1.300610),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&b].posterior(),
|
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 1.300610),
|
Gaussian::from_ms(0.000000, 1.300610),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[2].skills[&b].posterior(),
|
h.batches[2].skills.get(b).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 1.300610),
|
Gaussian::from_ms(0.000000, 1.300610),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
@@ -952,22 +960,22 @@ mod tests {
|
|||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&a].posterior(),
|
h.batches[0].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 0.931236),
|
Gaussian::from_ms(0.000000, 0.931236),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[3].skills[&a].posterior(),
|
h.batches[3].skills.get(a).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 0.931236),
|
Gaussian::from_ms(0.000000, 0.931236),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[3].skills[&b].posterior(),
|
h.batches[3].skills.get(b).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 0.931236),
|
Gaussian::from_ms(0.000000, 0.931236),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[5].skills[&b].posterior(),
|
h.batches[5].skills.get(b).unwrap().posterior(),
|
||||||
Gaussian::from_ms(0.000000, 0.931236),
|
Gaussian::from_ms(0.000000, 0.931236),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
@@ -1103,32 +1111,32 @@ mod tests {
|
|||||||
|
|
||||||
let end = h.batches.len() - 1;
|
let end = h.batches.len() - 1;
|
||||||
|
|
||||||
assert_eq!(h.batches[0].skills[&c].elapsed, 0);
|
assert_eq!(h.batches[0].skills.get(c).unwrap().elapsed, 0);
|
||||||
assert_eq!(h.batches[end].skills[&c].elapsed, 10);
|
assert_eq!(h.batches[end].skills.get(c).unwrap().elapsed, 10);
|
||||||
|
|
||||||
assert_eq!(h.batches[0].skills[&a].elapsed, 0);
|
assert_eq!(h.batches[0].skills.get(a).unwrap().elapsed, 0);
|
||||||
assert_eq!(h.batches[2].skills[&a].elapsed, 5);
|
assert_eq!(h.batches[2].skills.get(a).unwrap().elapsed, 5);
|
||||||
|
|
||||||
assert_eq!(h.batches[0].skills[&b].elapsed, 0);
|
assert_eq!(h.batches[0].skills.get(b).unwrap().elapsed, 0);
|
||||||
assert_eq!(h.batches[end].skills[&b].elapsed, 5);
|
assert_eq!(h.batches[end].skills.get(b).unwrap().elapsed, 5);
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&b].posterior(),
|
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||||
h.batches[end].skills[&b].posterior(),
|
h.batches[end].skills.get(b).unwrap().posterior(),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&c].posterior(),
|
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||||
h.batches[end].skills[&c].posterior(),
|
h.batches[end].skills.get(c).unwrap().posterior(),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&c].posterior(),
|
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||||
h.batches[0].skills[&b].posterior(),
|
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -1191,32 +1199,32 @@ mod tests {
|
|||||||
|
|
||||||
let end = h.batches.len() - 1;
|
let end = h.batches.len() - 1;
|
||||||
|
|
||||||
assert_eq!(h.batches[0].skills[&c].elapsed, 0);
|
assert_eq!(h.batches[0].skills.get(c).unwrap().elapsed, 0);
|
||||||
assert_eq!(h.batches[end].skills[&c].elapsed, 10);
|
assert_eq!(h.batches[end].skills.get(c).unwrap().elapsed, 10);
|
||||||
|
|
||||||
assert_eq!(h.batches[0].skills[&a].elapsed, 0);
|
assert_eq!(h.batches[0].skills.get(a).unwrap().elapsed, 0);
|
||||||
assert_eq!(h.batches[2].skills[&a].elapsed, 5);
|
assert_eq!(h.batches[2].skills.get(a).unwrap().elapsed, 5);
|
||||||
|
|
||||||
assert_eq!(h.batches[0].skills[&b].elapsed, 0);
|
assert_eq!(h.batches[0].skills.get(b).unwrap().elapsed, 0);
|
||||||
assert_eq!(h.batches[end].skills[&b].elapsed, 5);
|
assert_eq!(h.batches[end].skills.get(b).unwrap().elapsed, 5);
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&b].posterior(),
|
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||||
h.batches[end].skills[&b].posterior(),
|
h.batches[end].skills.get(b).unwrap().posterior(),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&c].posterior(),
|
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||||
h.batches[end].skills[&c].posterior(),
|
h.batches[end].skills.get(c).unwrap().posterior(),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h.batches[0].skills[&c].posterior(),
|
h.batches[0].skills.get(c).unwrap().posterior(),
|
||||||
h.batches[0].skills[&b].posterior(),
|
h.batches[0].skills.get(b).unwrap().posterior(),
|
||||||
epsilon = 1e-6
|
epsilon = 1e-6
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ mod history;
|
|||||||
mod matrix;
|
mod matrix;
|
||||||
mod message;
|
mod message;
|
||||||
pub mod player;
|
pub mod player;
|
||||||
|
pub(crate) mod storage;
|
||||||
|
|
||||||
pub use drift::{ConstantDrift, Drift};
|
pub use drift::{ConstantDrift, Drift};
|
||||||
pub use error::InferenceError;
|
pub use error::InferenceError;
|
||||||
|
|||||||
3
src/storage/mod.rs
Normal file
3
src/storage/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
mod skill_store;
|
||||||
|
|
||||||
|
pub(crate) use skill_store::SkillStore;
|
||||||
128
src/storage/skill_store.rs
Normal file
128
src/storage/skill_store.rs
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
use crate::Index;
|
||||||
|
use crate::batch::Skill;
|
||||||
|
|
||||||
|
/// Dense Vec-backed store for per-agent skill state within a TimeSlice.
|
||||||
|
///
|
||||||
|
/// Indexed directly by Index.0, eliminating HashMap hashing in the inner
|
||||||
|
/// convergence loop. Uses a parallel `present` mask so iteration skips
|
||||||
|
/// absent slots without incurring per-slot Option overhead in the hot path.
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct SkillStore {
|
||||||
|
skills: Vec<Skill>,
|
||||||
|
present: Vec<bool>,
|
||||||
|
n_present: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SkillStore {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ensure_capacity(&mut self, idx: usize) {
|
||||||
|
if idx >= self.skills.len() {
|
||||||
|
self.skills.resize_with(idx + 1, Skill::default);
|
||||||
|
self.present.resize(idx + 1, false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn insert(&mut self, idx: Index, skill: Skill) {
|
||||||
|
self.ensure_capacity(idx.0);
|
||||||
|
if !self.present[idx.0] {
|
||||||
|
self.n_present += 1;
|
||||||
|
}
|
||||||
|
self.skills[idx.0] = skill;
|
||||||
|
self.present[idx.0] = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, idx: Index) -> Option<&Skill> {
|
||||||
|
if idx.0 < self.present.len() && self.present[idx.0] {
|
||||||
|
Some(&self.skills[idx.0])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_mut(&mut self, idx: Index) -> Option<&mut Skill> {
|
||||||
|
if idx.0 < self.present.len() && self.present[idx.0] {
|
||||||
|
Some(&mut self.skills[idx.0])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn contains(&self, idx: Index) -> bool {
|
||||||
|
idx.0 < self.present.len() && self.present[idx.0]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.n_present
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.n_present == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iter(&self) -> impl Iterator<Item = (Index, &Skill)> {
|
||||||
|
self.present.iter().enumerate().filter_map(|(i, &p)| {
|
||||||
|
if p {
|
||||||
|
Some((Index(i), &self.skills[i]))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iter_mut(&mut self) -> impl Iterator<Item = (Index, &mut Skill)> {
|
||||||
|
self.skills
|
||||||
|
.iter_mut()
|
||||||
|
.zip(self.present.iter())
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(i, (s, &p))| if p { Some((Index(i), s)) } else { None })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn keys(&self) -> impl Iterator<Item = Index> + '_ {
|
||||||
|
self.present
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(i, &p)| if p { Some(Index(i)) } else { None })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn insert_then_get() {
|
||||||
|
let mut store = SkillStore::new();
|
||||||
|
let idx = Index(3);
|
||||||
|
store.insert(idx, Skill::default());
|
||||||
|
assert!(store.contains(idx));
|
||||||
|
assert_eq!(store.len(), 1);
|
||||||
|
assert!(store.get(idx).is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn missing_returns_none() {
|
||||||
|
let store = SkillStore::new();
|
||||||
|
assert!(store.get(Index(0)).is_none());
|
||||||
|
assert!(!store.contains(Index(42)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn iter_skips_absent_slots() {
|
||||||
|
let mut store = SkillStore::new();
|
||||||
|
store.insert(Index(0), Skill::default());
|
||||||
|
store.insert(Index(5), Skill::default());
|
||||||
|
let keys: Vec<Index> = store.keys().collect();
|
||||||
|
assert_eq!(keys, vec![Index(0), Index(5)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn double_insert_does_not_double_count() {
|
||||||
|
let mut store = SkillStore::new();
|
||||||
|
store.insert(Index(2), Skill::default());
|
||||||
|
store.insert(Index(2), Skill::default());
|
||||||
|
assert_eq!(store.len(), 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user