From 8f60258dba670fc3b521ededde02a5938daed588 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Fri, 24 Apr 2026 07:08:20 +0200 Subject: [PATCH] refactor(batch): replace HashMap with dense SkillStore SkillStore is a Vec-backed dense store with a parallel present mask, indexed directly by Index.0. Eliminates per-iteration hashing in the within-slice convergence loop; O(1) array lookup replaces O(1) amortised hash lookup with better cache behaviour. Iteration order is now ascending-by-Index (was arbitrary for HashMap); EP fixed point is order-independent so posteriors are unchanged. Part of T0 engine redesign. --- src/batch.rs | 41 ++++----- src/history.rs | 170 +++++++++++++++++++------------------ src/lib.rs | 1 + src/storage/mod.rs | 3 + src/storage/skill_store.rs | 128 ++++++++++++++++++++++++++++ 5 files changed, 239 insertions(+), 104 deletions(-) create mode 100644 src/storage/mod.rs create mode 100644 src/storage/skill_store.rs diff --git a/src/batch.rs b/src/batch.rs index 4e2ebf4..24d8c6b 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use crate::{ Index, N_INF, agent::Agent, drift::Drift, game::Game, gaussian::Gaussian, player::Player, - tuple_gt, tuple_max, + storage::SkillStore, tuple_gt, tuple_max, }; #[derive(Debug)] @@ -43,11 +43,11 @@ impl Item { &self, online: bool, forward: bool, - skills: &HashMap, + skills: &SkillStore, agents: &HashMap>, ) -> Player { let r = &agents[&self.agent].player; - let skill = &skills[&self.agent]; + let skill = skills.get(self.agent).unwrap(); if online { Player::new(skill.online, r.beta, r.drift) @@ -84,7 +84,7 @@ impl Event { &self, online: bool, forward: bool, - skills: &HashMap, + skills: &SkillStore, agents: &HashMap>, ) -> Vec>> { self.teams @@ -102,7 +102,7 @@ impl Event { #[derive(Debug)] pub struct Batch { pub(crate) events: Vec, - pub(crate) skills: HashMap, + pub(crate) skills: SkillStore, pub(crate) time: i64, p_draw: f64, } @@ -111,7 +111,7 @@ impl Batch { pub fn new(time: i64, p_draw: f64) -> Self { Self { events: Vec::new(), - skills: HashMap::new(), + skills: SkillStore::new(), time, p_draw, } @@ -137,16 +137,16 @@ impl Batch { }); for idx in this_agent { - let elapsed = compute_elapsed(agents[&idx].last_time, self.time); + let elapsed = compute_elapsed(agents[idx].last_time, self.time); - if let Some(skill) = self.skills.get_mut(idx) { + if let Some(skill) = self.skills.get_mut(*idx) { skill.elapsed = elapsed; - skill.forward = agents[&idx].receive(elapsed); + skill.forward = agents[idx].receive(elapsed); } else { self.skills.insert( *idx, Skill { - forward: agents[&idx].receive(elapsed), + forward: agents[idx].receive(elapsed), elapsed, ..Default::default() }, @@ -204,7 +204,7 @@ impl Batch { pub(crate) fn posteriors(&self) -> HashMap { self.skills .iter() - .map(|(&idx, skill)| (idx, skill.posterior())) + .map(|(idx, skill)| (idx, skill.posterior())) .collect::>() } @@ -217,10 +217,9 @@ impl Batch { for (t, team) in event.teams.iter_mut().enumerate() { for (i, item) in team.items.iter_mut().enumerate() { - self.skills.get_mut(&item.agent).unwrap().likelihood = - (self.skills[&item.agent].likelihood / item.likelihood) - * g.likelihoods[t][i]; - + let old_likelihood = self.skills.get(item.agent).unwrap().likelihood; + let new_likelihood = (old_likelihood / item.likelihood) * g.likelihoods[t][i]; + self.skills.get_mut(item.agent).unwrap().likelihood = new_likelihood; item.likelihood = g.likelihoods[t][i]; } } @@ -255,8 +254,7 @@ impl Batch { } pub(crate) fn forward_prior_out(&self, agent: &Index) -> Gaussian { - let skill = &self.skills[agent]; - + let skill = self.skills.get(*agent).unwrap(); skill.forward * skill.likelihood } @@ -265,25 +263,22 @@ impl Batch { agent: &Index, agents: &HashMap>, ) -> Gaussian { - let skill = &self.skills[agent]; + let skill = self.skills.get(*agent).unwrap(); let n = skill.likelihood * skill.backward; - n.forget(agents[agent].player.drift.variance_delta(skill.elapsed)) } pub(crate) fn new_backward_info(&mut self, agents: &HashMap>) { for (agent, skill) in self.skills.iter_mut() { - skill.backward = agents[agent].message; + skill.backward = agents[&agent].message; } - self.iteration(0, agents); } pub(crate) fn new_forward_info(&mut self, agents: &HashMap>) { for (agent, skill) in self.skills.iter_mut() { - skill.forward = agents[agent].receive(skill.elapsed); + skill.forward = agents[&agent].receive(skill.elapsed); } - self.iteration(0, agents); } diff --git a/src/history.rs b/src/history.rs index 583da74..b4be9ee 100644 --- a/src/history.rs +++ b/src/history.rs @@ -145,8 +145,8 @@ impl History { for j in (0..self.batches.len() - 1).rev() { for agent in self.batches[j + 1].skills.keys() { - self.agents.get_mut(agent).unwrap().message = - self.batches[j + 1].backward_prior_out(agent, &self.agents); + self.agents.get_mut(&agent).unwrap().message = + self.batches[j + 1].backward_prior_out(&agent, &self.agents); } let old = self.batches[j].posteriors(); @@ -164,8 +164,8 @@ impl History { for j in 1..self.batches.len() { for agent in self.batches[j - 1].skills.keys() { - self.agents.get_mut(agent).unwrap().message = - self.batches[j - 1].forward_prior_out(agent); + self.agents.get_mut(&agent).unwrap().message = + self.batches[j - 1].forward_prior_out(&agent); } let old = self.batches[j].posteriors(); @@ -231,10 +231,10 @@ impl History { for (agent, skill) in b.skills.iter() { let point = (b.time, skill.posterior()); - if let Some(entry) = data.get_mut(agent) { + if let Some(entry) = data.get_mut(&agent) { entry.push(point); } else { - data.insert(*agent, vec![point]); + data.insert(agent, vec![point]); } } } @@ -343,7 +343,7 @@ impl History { // TODO: Is it faster to iterate over agents in batch instead? for agent_idx in &this_agent { - if let Some(skill) = batch.skills.get_mut(agent_idx) { + if let Some(skill) = batch.skills.get_mut(*agent_idx) { skill.elapsed = batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time); @@ -378,10 +378,10 @@ impl History { batch.add_events(composition, results, weights, &self.agents); for agent_idx in batch.skills.keys() { - let agent = self.agents.get_mut(agent_idx).unwrap(); + let agent = self.agents.get_mut(&agent_idx).unwrap(); agent.last_time = if self.time { t } else { i64::MAX }; - agent.message = batch.forward_prior_out(agent_idx); + agent.message = batch.forward_prior_out(&agent_idx); } } else { let mut batch: Batch = Batch::new(t, self.p_draw); @@ -392,10 +392,10 @@ impl History { let batch = &self.batches[k]; for agent_idx in batch.skills.keys() { - let agent = self.agents.get_mut(agent_idx).unwrap(); + let agent = self.agents.get_mut(&agent_idx).unwrap(); agent.last_time = if self.time { t } else { i64::MAX }; - agent.message = batch.forward_prior_out(agent_idx); + agent.message = batch.forward_prior_out(&agent_idx); } k += 1; @@ -411,7 +411,7 @@ impl History { // TODO: Is it faster to iterate over agents in batch instead? for agent_idx in &this_agent { - if let Some(skill) = batch.skills.get_mut(agent_idx) { + if let Some(skill) = batch.skills.get_mut(*agent_idx) { skill.elapsed = batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time); @@ -476,13 +476,21 @@ mod tests { epsilon = 1e-6 ); - let observed = h.batches[1].skills[&a].forward.sigma(); + let observed = h.batches[1].skills.get(a).unwrap().forward.sigma(); let gamma: f64 = 0.15 * 25.0 / 3.0; - let expected = (gamma.powi(2) + h.batches[0].skills[&a].posterior().sigma().powi(2)).sqrt(); + let expected = (gamma.powi(2) + + h.batches[0] + .skills + .get(a) + .unwrap() + .posterior() + .sigma() + .powi(2)) + .sqrt(); assert_ulps_eq!(observed, expected, epsilon = 0.000001); - let observed = h.batches[1].skills[&a].posterior(); + let observed = h.batches[1].skills.get(a).unwrap().posterior(); let w = [vec![1.0], vec![1.0]]; let p = Game::new( @@ -531,12 +539,12 @@ mod tests { h1.add_events_with_prior(composition, results, times, vec![], priors); assert_ulps_eq!( - h1.batches[0].skills[&a].posterior(), + h1.batches[0].skills.get(a).unwrap().posterior(), Gaussian::from_ms(22.904409, 6.010330), epsilon = 1e-6 ); assert_ulps_eq!( - h1.batches[0].skills[&c].posterior(), + h1.batches[0].skills.get(c).unwrap().posterior(), Gaussian::from_ms(25.110318, 5.866311), epsilon = 1e-6 ); @@ -544,12 +552,12 @@ mod tests { h1.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h1.batches[0].skills[&a].posterior(), + h1.batches[0].skills.get(a).unwrap().posterior(), Gaussian::from_ms(25.000000, 5.419212), epsilon = 1e-6 ); assert_ulps_eq!( - h1.batches[0].skills[&c].posterior(), + h1.batches[0].skills.get(c).unwrap().posterior(), Gaussian::from_ms(25.000000, 5.419212), epsilon = 1e-6 ); @@ -580,12 +588,12 @@ mod tests { h2.add_events_with_prior(composition, results, times, vec![], priors); assert_ulps_eq!( - h2.batches[2].skills[&a].posterior(), + h2.batches[2].skills.get(a).unwrap().posterior(), Gaussian::from_ms(22.903522, 6.011017), epsilon = 1e-6 ); assert_ulps_eq!( - h2.batches[2].skills[&c].posterior(), + h2.batches[2].skills.get(c).unwrap().posterior(), Gaussian::from_ms(25.110702, 5.866811), epsilon = 1e-6 ); @@ -593,12 +601,12 @@ mod tests { h2.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h2.batches[2].skills[&a].posterior(), + h2.batches[2].skills.get(a).unwrap().posterior(), Gaussian::from_ms(24.998668, 5.420053), epsilon = 1e-6 ); assert_ulps_eq!( - h2.batches[2].skills[&c].posterior(), + h2.batches[2].skills.get(c).unwrap().posterior(), Gaussian::from_ms(25.000532, 5.419827), epsilon = 1e-6 ); @@ -685,21 +693,21 @@ mod tests { h.convergence(ITERATIONS, EPSILON, false); - assert_eq!(h.batches[2].skills[&b].elapsed, 1); - assert_eq!(h.batches[2].skills[&c].elapsed, 1); + assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1); + assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1); assert_ulps_eq!( - h.batches[0].skills[&a].posterior(), + h.batches[0].skills.get(a).unwrap().posterior(), Gaussian::from_ms(25.000267, 5.419381), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].skills[&b].posterior(), + h.batches[0].skills.get(b).unwrap().posterior(), Gaussian::from_ms(24.999465, 5.419425), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[2].skills[&b].posterior(), + h.batches[2].skills.get(b).unwrap().posterior(), Gaussian::from_ms(25.000532, 5.419696), epsilon = 1e-6 ); @@ -743,8 +751,8 @@ mod tests { ); assert_ulps_eq!( - h.batches[0].skills[&b].posterior().mu(), - -1.0 * h.batches[0].skills[&c].posterior().mu(), + h.batches[0].skills.get(b).unwrap().posterior().mu(), + -1.0 * h.batches[0].skills.get(c).unwrap().posterior().mu(), epsilon = 1e-6 ); @@ -763,33 +771,33 @@ mod tests { assert_ulps_eq!(p_d_m_hat, 0.172432, epsilon = 1e-6); assert_ulps_eq!( - h.batches[0].skills[&a].posterior(), - h.batches[0].skills[&b].posterior(), + h.batches[0].skills.get(a).unwrap().posterior(), + h.batches[0].skills.get(b).unwrap().posterior(), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].skills[&c].posterior(), - h.batches[0].skills[&d].posterior(), + h.batches[0].skills.get(c).unwrap().posterior(), + h.batches[0].skills.get(d).unwrap().posterior(), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[1].skills[&e].posterior(), - h.batches[1].skills[&f].posterior(), + h.batches[1].skills.get(e).unwrap().posterior(), + h.batches[1].skills.get(f).unwrap().posterior(), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].skills[&a].posterior(), + h.batches[0].skills.get(a).unwrap().posterior(), Gaussian::from_ms(4.084902, 5.106919), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].skills[&c].posterior(), + h.batches[0].skills.get(c).unwrap().posterior(), Gaussian::from_ms(-0.533029, 5.106919), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[2].skills[&e].posterior(), + h.batches[2].skills.get(e).unwrap().posterior(), Gaussian::from_ms(-3.551872, 5.154569), epsilon = 1e-6 ); @@ -822,21 +830,21 @@ mod tests { h.convergence(ITERATIONS, EPSILON, false); - assert_eq!(h.batches[2].skills[&b].elapsed, 1); - assert_eq!(h.batches[2].skills[&c].elapsed, 1); + assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1); + assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1); assert_ulps_eq!( - h.batches[0].skills[&a].posterior(), + h.batches[0].skills.get(a).unwrap().posterior(), Gaussian::from_ms(0.000000, 1.300610), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].skills[&b].posterior(), + h.batches[0].skills.get(b).unwrap().posterior(), Gaussian::from_ms(0.000000, 1.300610), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[2].skills[&b].posterior(), + h.batches[2].skills.get(b).unwrap().posterior(), Gaussian::from_ms(0.000000, 1.300610), epsilon = 1e-6 ); @@ -863,22 +871,22 @@ mod tests { h.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h.batches[0].skills[&a].posterior(), + h.batches[0].skills.get(a).unwrap().posterior(), Gaussian::from_ms(0.000000, 0.931236), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[3].skills[&a].posterior(), + h.batches[3].skills.get(a).unwrap().posterior(), Gaussian::from_ms(0.000000, 0.931236), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[3].skills[&b].posterior(), + h.batches[3].skills.get(b).unwrap().posterior(), Gaussian::from_ms(0.000000, 0.931236), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[5].skills[&b].posterior(), + h.batches[5].skills.get(b).unwrap().posterior(), Gaussian::from_ms(0.000000, 0.931236), epsilon = 1e-6 ); @@ -911,21 +919,21 @@ mod tests { h.convergence(ITERATIONS, EPSILON, false); - assert_eq!(h.batches[2].skills[&b].elapsed, 1); - assert_eq!(h.batches[2].skills[&c].elapsed, 1); + assert_eq!(h.batches[2].skills.get(b).unwrap().elapsed, 1); + assert_eq!(h.batches[2].skills.get(c).unwrap().elapsed, 1); assert_ulps_eq!( - h.batches[0].skills[&a].posterior(), + h.batches[0].skills.get(a).unwrap().posterior(), Gaussian::from_ms(0.000000, 1.300610), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].skills[&b].posterior(), + h.batches[0].skills.get(b).unwrap().posterior(), Gaussian::from_ms(0.000000, 1.300610), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[2].skills[&b].posterior(), + h.batches[2].skills.get(b).unwrap().posterior(), Gaussian::from_ms(0.000000, 1.300610), epsilon = 1e-6 ); @@ -952,22 +960,22 @@ mod tests { h.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h.batches[0].skills[&a].posterior(), + h.batches[0].skills.get(a).unwrap().posterior(), Gaussian::from_ms(0.000000, 0.931236), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[3].skills[&a].posterior(), + h.batches[3].skills.get(a).unwrap().posterior(), Gaussian::from_ms(0.000000, 0.931236), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[3].skills[&b].posterior(), + h.batches[3].skills.get(b).unwrap().posterior(), Gaussian::from_ms(0.000000, 0.931236), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[5].skills[&b].posterior(), + h.batches[5].skills.get(b).unwrap().posterior(), Gaussian::from_ms(0.000000, 0.931236), epsilon = 1e-6 ); @@ -1103,32 +1111,32 @@ mod tests { let end = h.batches.len() - 1; - assert_eq!(h.batches[0].skills[&c].elapsed, 0); - assert_eq!(h.batches[end].skills[&c].elapsed, 10); + assert_eq!(h.batches[0].skills.get(c).unwrap().elapsed, 0); + assert_eq!(h.batches[end].skills.get(c).unwrap().elapsed, 10); - assert_eq!(h.batches[0].skills[&a].elapsed, 0); - assert_eq!(h.batches[2].skills[&a].elapsed, 5); + assert_eq!(h.batches[0].skills.get(a).unwrap().elapsed, 0); + assert_eq!(h.batches[2].skills.get(a).unwrap().elapsed, 5); - assert_eq!(h.batches[0].skills[&b].elapsed, 0); - assert_eq!(h.batches[end].skills[&b].elapsed, 5); + assert_eq!(h.batches[0].skills.get(b).unwrap().elapsed, 0); + assert_eq!(h.batches[end].skills.get(b).unwrap().elapsed, 5); h.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h.batches[0].skills[&b].posterior(), - h.batches[end].skills[&b].posterior(), + h.batches[0].skills.get(b).unwrap().posterior(), + h.batches[end].skills.get(b).unwrap().posterior(), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].skills[&c].posterior(), - h.batches[end].skills[&c].posterior(), + h.batches[0].skills.get(c).unwrap().posterior(), + h.batches[end].skills.get(c).unwrap().posterior(), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].skills[&c].posterior(), - h.batches[0].skills[&b].posterior(), + h.batches[0].skills.get(c).unwrap().posterior(), + h.batches[0].skills.get(b).unwrap().posterior(), epsilon = 1e-6 ); @@ -1191,32 +1199,32 @@ mod tests { let end = h.batches.len() - 1; - assert_eq!(h.batches[0].skills[&c].elapsed, 0); - assert_eq!(h.batches[end].skills[&c].elapsed, 10); + assert_eq!(h.batches[0].skills.get(c).unwrap().elapsed, 0); + assert_eq!(h.batches[end].skills.get(c).unwrap().elapsed, 10); - assert_eq!(h.batches[0].skills[&a].elapsed, 0); - assert_eq!(h.batches[2].skills[&a].elapsed, 5); + assert_eq!(h.batches[0].skills.get(a).unwrap().elapsed, 0); + assert_eq!(h.batches[2].skills.get(a).unwrap().elapsed, 5); - assert_eq!(h.batches[0].skills[&b].elapsed, 0); - assert_eq!(h.batches[end].skills[&b].elapsed, 5); + assert_eq!(h.batches[0].skills.get(b).unwrap().elapsed, 0); + assert_eq!(h.batches[end].skills.get(b).unwrap().elapsed, 5); h.convergence(ITERATIONS, EPSILON, false); assert_ulps_eq!( - h.batches[0].skills[&b].posterior(), - h.batches[end].skills[&b].posterior(), + h.batches[0].skills.get(b).unwrap().posterior(), + h.batches[end].skills.get(b).unwrap().posterior(), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].skills[&c].posterior(), - h.batches[end].skills[&c].posterior(), + h.batches[0].skills.get(c).unwrap().posterior(), + h.batches[end].skills.get(c).unwrap().posterior(), epsilon = 1e-6 ); assert_ulps_eq!( - h.batches[0].skills[&c].posterior(), - h.batches[0].skills[&b].posterior(), + h.batches[0].skills.get(c).unwrap().posterior(), + h.batches[0].skills.get(b).unwrap().posterior(), epsilon = 1e-6 ); } diff --git a/src/lib.rs b/src/lib.rs index 9579032..b3f904a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,7 @@ mod history; mod matrix; mod message; pub mod player; +pub(crate) mod storage; pub use drift::{ConstantDrift, Drift}; pub use error::InferenceError; diff --git a/src/storage/mod.rs b/src/storage/mod.rs new file mode 100644 index 0000000..ac9b62c --- /dev/null +++ b/src/storage/mod.rs @@ -0,0 +1,3 @@ +mod skill_store; + +pub(crate) use skill_store::SkillStore; diff --git a/src/storage/skill_store.rs b/src/storage/skill_store.rs new file mode 100644 index 0000000..14d9147 --- /dev/null +++ b/src/storage/skill_store.rs @@ -0,0 +1,128 @@ +use crate::Index; +use crate::batch::Skill; + +/// Dense Vec-backed store for per-agent skill state within a TimeSlice. +/// +/// Indexed directly by Index.0, eliminating HashMap hashing in the inner +/// convergence loop. Uses a parallel `present` mask so iteration skips +/// absent slots without incurring per-slot Option overhead in the hot path. +#[derive(Debug, Default)] +pub struct SkillStore { + skills: Vec, + present: Vec, + n_present: usize, +} + +impl SkillStore { + pub fn new() -> Self { + Self::default() + } + + fn ensure_capacity(&mut self, idx: usize) { + if idx >= self.skills.len() { + self.skills.resize_with(idx + 1, Skill::default); + self.present.resize(idx + 1, false); + } + } + + pub fn insert(&mut self, idx: Index, skill: Skill) { + self.ensure_capacity(idx.0); + if !self.present[idx.0] { + self.n_present += 1; + } + self.skills[idx.0] = skill; + self.present[idx.0] = true; + } + + pub fn get(&self, idx: Index) -> Option<&Skill> { + if idx.0 < self.present.len() && self.present[idx.0] { + Some(&self.skills[idx.0]) + } else { + None + } + } + + pub fn get_mut(&mut self, idx: Index) -> Option<&mut Skill> { + if idx.0 < self.present.len() && self.present[idx.0] { + Some(&mut self.skills[idx.0]) + } else { + None + } + } + + pub fn contains(&self, idx: Index) -> bool { + idx.0 < self.present.len() && self.present[idx.0] + } + + pub fn len(&self) -> usize { + self.n_present + } + + pub fn is_empty(&self) -> bool { + self.n_present == 0 + } + + pub fn iter(&self) -> impl Iterator { + self.present.iter().enumerate().filter_map(|(i, &p)| { + if p { + Some((Index(i), &self.skills[i])) + } else { + None + } + }) + } + + pub fn iter_mut(&mut self) -> impl Iterator { + self.skills + .iter_mut() + .zip(self.present.iter()) + .enumerate() + .filter_map(|(i, (s, &p))| if p { Some((Index(i), s)) } else { None }) + } + + pub fn keys(&self) -> impl Iterator + '_ { + self.present + .iter() + .enumerate() + .filter_map(|(i, &p)| if p { Some(Index(i)) } else { None }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn insert_then_get() { + let mut store = SkillStore::new(); + let idx = Index(3); + store.insert(idx, Skill::default()); + assert!(store.contains(idx)); + assert_eq!(store.len(), 1); + assert!(store.get(idx).is_some()); + } + + #[test] + fn missing_returns_none() { + let store = SkillStore::new(); + assert!(store.get(Index(0)).is_none()); + assert!(!store.contains(Index(42))); + } + + #[test] + fn iter_skips_absent_slots() { + let mut store = SkillStore::new(); + store.insert(Index(0), Skill::default()); + store.insert(Index(5), Skill::default()); + let keys: Vec = store.keys().collect(); + assert_eq!(keys, vec![Index(0), Index(5)]); + } + + #[test] + fn double_insert_does_not_double_count() { + let mut store = SkillStore::new(); + store.insert(Index(2), Skill::default()); + store.insert(Index(2), Skill::default()); + assert_eq!(store.len(), 1); + } +}