From 49d2b317dad0d20582c37ec62a8c126d9cd4f723 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Fri, 24 Apr 2026 07:15:21 +0200 Subject: [PATCH] refactor(history): replace HashMap> with dense AgentStore AgentStore is a Vec>>-backed store indexed directly by Index.0, eliminating per-iteration hashing in the cross-history forward/backward sweep. Implements Index/IndexMut for ergonomic agent access. AgentStore is public (so benches/batch.rs can use it). SkillStore remains pub(crate) since Skill is pub(crate) in batch.rs. HashMap is now only used for the posteriors() return value (temporary; will be replaced in T2 with a proper typed return) and for the add_events_with_prior(priors: HashMap>) API (also T2 target). Part of T0 engine redesign. --- benches/batch.rs | 30 ++------- src/batch.rs | 52 ++++++++------- src/history.rs | 28 +++++---- src/lib.rs | 2 +- src/storage/agent_store.rs | 125 +++++++++++++++++++++++++++++++++++++ src/storage/mod.rs | 2 + 6 files changed, 179 insertions(+), 60 deletions(-) create mode 100644 src/storage/agent_store.rs diff --git a/benches/batch.rs b/benches/batch.rs index 637a2fb..c1554af 100644 --- a/benches/batch.rs +++ b/benches/batch.rs @@ -1,9 +1,7 @@ -use std::collections::HashMap; - use criterion::{Criterion, criterion_group, criterion_main}; use trueskill_tt::{ BETA, GAMMA, IndexMap, MU, P_DRAW, SIGMA, agent::Agent, batch::Batch, drift::ConstantDrift, - gaussian::Gaussian, player::Player, + gaussian::Gaussian, player::Player, storage::AgentStore, }; fn criterion_benchmark(criterion: &mut Criterion) { @@ -13,33 +11,17 @@ fn criterion_benchmark(criterion: &mut Criterion) { let b = index.get_or_create("b"); let c = index.get_or_create("c"); - let agents = { - let mut map = HashMap::new(); + let mut agents: AgentStore = AgentStore::new(); - map.insert( - a, + for agent in [a, b, c] { + agents.insert( + agent, Agent { player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)), ..Default::default() }, ); - map.insert( - b, - Agent { - player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)), - ..Default::default() - }, - ); - map.insert( - c, - Agent { - player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)), - ..Default::default() - }, - ); - - map - }; + } let mut composition = Vec::new(); let mut results = Vec::new(); diff --git a/src/batch.rs b/src/batch.rs index 24d8c6b..8637251 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -1,8 +1,14 @@ use std::collections::HashMap; use crate::{ - Index, N_INF, agent::Agent, drift::Drift, game::Game, gaussian::Gaussian, player::Player, - storage::SkillStore, tuple_gt, tuple_max, + Index, N_INF, + agent::Agent, + drift::Drift, + game::Game, + gaussian::Gaussian, + player::Player, + storage::{AgentStore, SkillStore}, + tuple_gt, tuple_max, }; #[derive(Debug)] @@ -44,9 +50,9 @@ impl Item { online: bool, forward: bool, skills: &SkillStore, - agents: &HashMap>, + agents: &AgentStore, ) -> Player { - let r = &agents[&self.agent].player; + let r = &agents[self.agent].player; let skill = skills.get(self.agent).unwrap(); if online { @@ -85,7 +91,7 @@ impl Event { online: bool, forward: bool, skills: &SkillStore, - agents: &HashMap>, + agents: &AgentStore, ) -> Vec>> { self.teams .iter() @@ -122,7 +128,7 @@ impl Batch { composition: Vec>>, results: Vec>, weights: Vec>>, - agents: &HashMap>, + agents: &AgentStore, ) { let mut unique = Vec::with_capacity(10); @@ -137,16 +143,16 @@ impl Batch { }); for idx in this_agent { - let elapsed = compute_elapsed(agents[idx].last_time, self.time); + let elapsed = compute_elapsed(agents[*idx].last_time, self.time); if let Some(skill) = self.skills.get_mut(*idx) { skill.elapsed = elapsed; - skill.forward = agents[idx].receive(elapsed); + skill.forward = agents[*idx].receive(elapsed); } else { self.skills.insert( *idx, Skill { - forward: agents[idx].receive(elapsed), + forward: agents[*idx].receive(elapsed), elapsed, ..Default::default() }, @@ -208,7 +214,7 @@ impl Batch { .collect::>() } - pub fn iteration(&mut self, from: usize, agents: &HashMap>) { + pub fn iteration(&mut self, from: usize, agents: &AgentStore) { for event in self.events.iter_mut().skip(from) { let teams = event.within_priors(false, false, &self.skills, agents); let result = event.outputs(); @@ -229,7 +235,7 @@ impl Batch { } #[allow(dead_code)] - pub(crate) fn convergence(&mut self, agents: &HashMap>) -> usize { + pub(crate) fn convergence(&mut self, agents: &AgentStore) -> usize { let epsilon = 1e-6; let iterations = 20; @@ -261,23 +267,23 @@ impl Batch { pub(crate) fn backward_prior_out( &self, agent: &Index, - agents: &HashMap>, + agents: &AgentStore, ) -> Gaussian { let skill = self.skills.get(*agent).unwrap(); let n = skill.likelihood * skill.backward; - n.forget(agents[agent].player.drift.variance_delta(skill.elapsed)) + n.forget(agents[*agent].player.drift.variance_delta(skill.elapsed)) } - pub(crate) fn new_backward_info(&mut self, agents: &HashMap>) { + pub(crate) fn new_backward_info(&mut self, agents: &AgentStore) { for (agent, skill) in self.skills.iter_mut() { - skill.backward = agents[&agent].message; + skill.backward = agents[agent].message; } self.iteration(0, agents); } - pub(crate) fn new_forward_info(&mut self, agents: &HashMap>) { + pub(crate) fn new_forward_info(&mut self, agents: &AgentStore) { for (agent, skill) in self.skills.iter_mut() { - skill.forward = agents[&agent].receive(skill.elapsed); + skill.forward = agents[agent].receive(skill.elapsed); } self.iteration(0, agents); } @@ -287,7 +293,7 @@ impl Batch { online: bool, targets: &[Index], forward: bool, - agents: &HashMap>, + agents: &AgentStore, ) -> f64 { if targets.is_empty() { if online || forward { @@ -387,7 +393,9 @@ mod tests { use approx::assert_ulps_eq; use super::*; - use crate::{IndexMap, agent::Agent, drift::ConstantDrift, player::Player}; + use crate::{ + IndexMap, agent::Agent, drift::ConstantDrift, player::Player, storage::AgentStore, + }; #[test] fn test_one_event_each() { @@ -400,7 +408,7 @@ mod tests { let e = index_map.get_or_create("e"); let f = index_map.get_or_create("f"); - let mut agents = HashMap::new(); + let mut agents: AgentStore = AgentStore::new(); for agent in [a, b, c, d, e, f] { agents.insert( @@ -476,7 +484,7 @@ mod tests { let e = index_map.get_or_create("e"); let f = index_map.get_or_create("f"); - let mut agents = HashMap::new(); + let mut agents: AgentStore = AgentStore::new(); for agent in [a, b, c, d, e, f] { agents.insert( @@ -555,7 +563,7 @@ mod tests { let e = index_map.get_or_create("e"); let f = index_map.get_or_create("f"); - let mut agents = HashMap::new(); + let mut agents: AgentStore = AgentStore::new(); for agent in [a, b, c, d, e, f] { agents.insert( diff --git a/src/history.rs b/src/history.rs index b4be9ee..f2283d0 100644 --- a/src/history.rs +++ b/src/history.rs @@ -7,7 +7,9 @@ use crate::{ drift::{ConstantDrift, Drift}, gaussian::Gaussian, player::Player, - sort_time, tuple_gt, tuple_max, + sort_time, + storage::AgentStore, + tuple_gt, tuple_max, }; #[derive(Clone)] @@ -68,7 +70,7 @@ impl HistoryBuilder { History { size: 0, batches: Vec::new(), - agents: HashMap::new(), + agents: AgentStore::new(), time: self.time, mu: self.mu, sigma: self.sigma, @@ -104,7 +106,7 @@ impl Default for HistoryBuilder { pub struct History { size: usize, pub(crate) batches: Vec, - agents: HashMap>, + agents: AgentStore, time: bool, mu: f64, sigma: f64, @@ -119,7 +121,7 @@ impl Default for History { Self { size: 0, batches: Vec::new(), - agents: HashMap::new(), + agents: AgentStore::new(), time: true, mu: MU, sigma: SIGMA, @@ -145,7 +147,7 @@ impl History { for j in (0..self.batches.len() - 1).rev() { for agent in self.batches[j + 1].skills.keys() { - self.agents.get_mut(&agent).unwrap().message = + self.agents.get_mut(agent).unwrap().message = self.batches[j + 1].backward_prior_out(&agent, &self.agents); } @@ -164,7 +166,7 @@ impl History { for j in 1..self.batches.len() { for agent in self.batches[j - 1].skills.keys() { - self.agents.get_mut(&agent).unwrap().message = + self.agents.get_mut(agent).unwrap().message = self.batches[j - 1].forward_prior_out(&agent); } @@ -296,7 +298,7 @@ impl History { this_agent.push(*agent); - if !self.agents.contains_key(agent) { + if !self.agents.contains(*agent) { self.agents.insert( *agent, Agent { @@ -345,9 +347,9 @@ impl History { for agent_idx in &this_agent { if let Some(skill) = batch.skills.get_mut(*agent_idx) { skill.elapsed = - batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time); + batch::compute_elapsed(self.agents[*agent_idx].last_time, batch.time); - let agent = self.agents.get_mut(agent_idx).unwrap(); + let agent = self.agents.get_mut(*agent_idx).unwrap(); agent.last_time = if self.time { batch.time } else { i64::MAX }; agent.message = batch.forward_prior_out(agent_idx); @@ -378,7 +380,7 @@ impl History { batch.add_events(composition, results, weights, &self.agents); for agent_idx in batch.skills.keys() { - let agent = self.agents.get_mut(&agent_idx).unwrap(); + let agent = self.agents.get_mut(agent_idx).unwrap(); agent.last_time = if self.time { t } else { i64::MAX }; agent.message = batch.forward_prior_out(&agent_idx); @@ -392,7 +394,7 @@ impl History { let batch = &self.batches[k]; for agent_idx in batch.skills.keys() { - let agent = self.agents.get_mut(&agent_idx).unwrap(); + let agent = self.agents.get_mut(agent_idx).unwrap(); agent.last_time = if self.time { t } else { i64::MAX }; agent.message = batch.forward_prior_out(&agent_idx); @@ -413,9 +415,9 @@ impl History { for agent_idx in &this_agent { if let Some(skill) = batch.skills.get_mut(*agent_idx) { skill.elapsed = - batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time); + batch::compute_elapsed(self.agents[*agent_idx].last_time, batch.time); - let agent = self.agents.get_mut(agent_idx).unwrap(); + let agent = self.agents.get_mut(*agent_idx).unwrap(); agent.last_time = if self.time { batch.time } else { i64::MAX }; agent.message = batch.forward_prior_out(agent_idx); diff --git a/src/lib.rs b/src/lib.rs index b3f904a..b6c2924 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,7 @@ mod history; mod matrix; mod message; pub mod player; -pub(crate) mod storage; +pub mod storage; pub use drift::{ConstantDrift, Drift}; pub use error::InferenceError; diff --git a/src/storage/agent_store.rs b/src/storage/agent_store.rs new file mode 100644 index 0000000..e52d394 --- /dev/null +++ b/src/storage/agent_store.rs @@ -0,0 +1,125 @@ +use crate::{Index, agent::Agent, drift::Drift}; + +/// Dense Vec-backed store for agent state in History. +/// +/// Indexed directly by Index.0, eliminating HashMap hashing in the +/// forward/backward sweep. Uses `Vec>>` so slots can be +/// absent without an explicit present mask. +#[derive(Debug)] +pub struct AgentStore { + agents: Vec>>, + n_present: usize, +} + +impl Default for AgentStore { + fn default() -> Self { + Self { + agents: Vec::new(), + n_present: 0, + } + } +} + +impl AgentStore { + pub fn new() -> Self { + Self::default() + } + + fn ensure_capacity(&mut self, idx: usize) { + if idx >= self.agents.len() { + self.agents.resize_with(idx + 1, || None); + } + } + + pub fn insert(&mut self, idx: Index, agent: Agent) { + self.ensure_capacity(idx.0); + if self.agents[idx.0].is_none() { + self.n_present += 1; + } + self.agents[idx.0] = Some(agent); + } + + pub fn get(&self, idx: Index) -> Option<&Agent> { + self.agents.get(idx.0).and_then(|slot| slot.as_ref()) + } + + pub fn get_mut(&mut self, idx: Index) -> Option<&mut Agent> { + self.agents.get_mut(idx.0).and_then(|slot| slot.as_mut()) + } + + pub fn contains(&self, idx: Index) -> bool { + self.get(idx).is_some() + } + + pub fn len(&self) -> usize { + self.n_present + } + + pub fn is_empty(&self) -> bool { + self.n_present == 0 + } + + pub fn iter(&self) -> impl Iterator)> { + self.agents + .iter() + .enumerate() + .filter_map(|(i, slot)| slot.as_ref().map(|a| (Index(i), a))) + } + + pub fn iter_mut(&mut self) -> impl Iterator)> { + self.agents + .iter_mut() + .enumerate() + .filter_map(|(i, slot)| slot.as_mut().map(|a| (Index(i), a))) + } + + pub fn values_mut(&mut self) -> impl Iterator> { + self.agents.iter_mut().filter_map(|s| s.as_mut()) + } +} + +impl std::ops::Index for AgentStore { + type Output = Agent; + fn index(&self, idx: Index) -> &Agent { + self.get(idx).expect("agent not found at index") + } +} + +impl std::ops::IndexMut for AgentStore { + fn index_mut(&mut self, idx: Index) -> &mut Agent { + self.get_mut(idx).expect("agent not found at index") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{agent::Agent, drift::ConstantDrift, player::Player}; + + #[test] + fn insert_then_get() { + let mut store: AgentStore = AgentStore::new(); + let idx = Index(7); + store.insert(idx, Agent::default()); + assert!(store.contains(idx)); + assert_eq!(store.len(), 1); + assert!(store.get(idx).is_some()); + } + + #[test] + fn iter_in_index_order() { + let mut store: AgentStore = AgentStore::new(); + store.insert(Index(2), Agent::default()); + store.insert(Index(0), Agent::default()); + store.insert(Index(5), Agent::default()); + let keys: Vec = store.iter().map(|(i, _)| i).collect(); + assert_eq!(keys, vec![Index(0), Index(2), Index(5)]); + } + + #[test] + fn index_operator_works() { + let mut store: AgentStore = AgentStore::new(); + store.insert(Index(3), Agent::default()); + let _ = &store[Index(3)]; + } +} diff --git a/src/storage/mod.rs b/src/storage/mod.rs index ac9b62c..a77963d 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,3 +1,5 @@ +mod agent_store; mod skill_store; +pub use agent_store::AgentStore; pub(crate) use skill_store::SkillStore;