refactor(history): replace HashMap<Index, Agent<D>> with dense AgentStore<D>
AgentStore<D> is a Vec<Option<Agent<D>>>-backed store indexed directly by Index.0, eliminating per-iteration hashing in the cross-history forward/backward sweep. Implements Index<Index>/IndexMut<Index> for ergonomic agent access. AgentStore is public (so benches/batch.rs can use it). SkillStore remains pub(crate) since Skill is pub(crate) in batch.rs. HashMap<Index, _> is now only used for the posteriors() return value (temporary; will be replaced in T2 with a proper typed return) and for the add_events_with_prior(priors: HashMap<Index, Player<D>>) API (also T2 target). Part of T0 engine redesign.
This commit is contained in:
@@ -7,7 +7,9 @@ use crate::{
|
||||
drift::{ConstantDrift, Drift},
|
||||
gaussian::Gaussian,
|
||||
player::Player,
|
||||
sort_time, tuple_gt, tuple_max,
|
||||
sort_time,
|
||||
storage::AgentStore,
|
||||
tuple_gt, tuple_max,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -68,7 +70,7 @@ impl<D: Drift> HistoryBuilder<D> {
|
||||
History {
|
||||
size: 0,
|
||||
batches: Vec::new(),
|
||||
agents: HashMap::new(),
|
||||
agents: AgentStore::new(),
|
||||
time: self.time,
|
||||
mu: self.mu,
|
||||
sigma: self.sigma,
|
||||
@@ -104,7 +106,7 @@ impl Default for HistoryBuilder<ConstantDrift> {
|
||||
pub struct History<D: Drift = ConstantDrift> {
|
||||
size: usize,
|
||||
pub(crate) batches: Vec<Batch>,
|
||||
agents: HashMap<Index, Agent<D>>,
|
||||
agents: AgentStore<D>,
|
||||
time: bool,
|
||||
mu: f64,
|
||||
sigma: f64,
|
||||
@@ -119,7 +121,7 @@ impl Default for History<ConstantDrift> {
|
||||
Self {
|
||||
size: 0,
|
||||
batches: Vec::new(),
|
||||
agents: HashMap::new(),
|
||||
agents: AgentStore::new(),
|
||||
time: true,
|
||||
mu: MU,
|
||||
sigma: SIGMA,
|
||||
@@ -145,7 +147,7 @@ impl<D: Drift> History<D> {
|
||||
|
||||
for j in (0..self.batches.len() - 1).rev() {
|
||||
for agent in self.batches[j + 1].skills.keys() {
|
||||
self.agents.get_mut(&agent).unwrap().message =
|
||||
self.agents.get_mut(agent).unwrap().message =
|
||||
self.batches[j + 1].backward_prior_out(&agent, &self.agents);
|
||||
}
|
||||
|
||||
@@ -164,7 +166,7 @@ impl<D: Drift> History<D> {
|
||||
|
||||
for j in 1..self.batches.len() {
|
||||
for agent in self.batches[j - 1].skills.keys() {
|
||||
self.agents.get_mut(&agent).unwrap().message =
|
||||
self.agents.get_mut(agent).unwrap().message =
|
||||
self.batches[j - 1].forward_prior_out(&agent);
|
||||
}
|
||||
|
||||
@@ -296,7 +298,7 @@ impl<D: Drift> History<D> {
|
||||
|
||||
this_agent.push(*agent);
|
||||
|
||||
if !self.agents.contains_key(agent) {
|
||||
if !self.agents.contains(*agent) {
|
||||
self.agents.insert(
|
||||
*agent,
|
||||
Agent {
|
||||
@@ -345,9 +347,9 @@ impl<D: Drift> History<D> {
|
||||
for agent_idx in &this_agent {
|
||||
if let Some(skill) = batch.skills.get_mut(*agent_idx) {
|
||||
skill.elapsed =
|
||||
batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time);
|
||||
batch::compute_elapsed(self.agents[*agent_idx].last_time, batch.time);
|
||||
|
||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
||||
let agent = self.agents.get_mut(*agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { batch.time } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(agent_idx);
|
||||
@@ -378,7 +380,7 @@ impl<D: Drift> History<D> {
|
||||
batch.add_events(composition, results, weights, &self.agents);
|
||||
|
||||
for agent_idx in batch.skills.keys() {
|
||||
let agent = self.agents.get_mut(&agent_idx).unwrap();
|
||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { t } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(&agent_idx);
|
||||
@@ -392,7 +394,7 @@ impl<D: Drift> History<D> {
|
||||
let batch = &self.batches[k];
|
||||
|
||||
for agent_idx in batch.skills.keys() {
|
||||
let agent = self.agents.get_mut(&agent_idx).unwrap();
|
||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { t } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(&agent_idx);
|
||||
@@ -413,9 +415,9 @@ impl<D: Drift> History<D> {
|
||||
for agent_idx in &this_agent {
|
||||
if let Some(skill) = batch.skills.get_mut(*agent_idx) {
|
||||
skill.elapsed =
|
||||
batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time);
|
||||
batch::compute_elapsed(self.agents[*agent_idx].last_time, batch.time);
|
||||
|
||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
||||
let agent = self.agents.get_mut(*agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { batch.time } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(agent_idx);
|
||||
|
||||
Reference in New Issue
Block a user