T0 + T1 + T2: engine redesign through new API surface #1
@@ -1,9 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use criterion::{Criterion, criterion_group, criterion_main};
|
||||
use trueskill_tt::{
|
||||
BETA, GAMMA, IndexMap, MU, P_DRAW, SIGMA, agent::Agent, batch::Batch, drift::ConstantDrift,
|
||||
gaussian::Gaussian, player::Player,
|
||||
gaussian::Gaussian, player::Player, storage::AgentStore,
|
||||
};
|
||||
|
||||
fn criterion_benchmark(criterion: &mut Criterion) {
|
||||
@@ -13,33 +11,17 @@ fn criterion_benchmark(criterion: &mut Criterion) {
|
||||
let b = index.get_or_create("b");
|
||||
let c = index.get_or_create("c");
|
||||
|
||||
let agents = {
|
||||
let mut map = HashMap::new();
|
||||
let mut agents: AgentStore<ConstantDrift> = AgentStore::new();
|
||||
|
||||
map.insert(
|
||||
a,
|
||||
for agent in [a, b, c] {
|
||||
agents.insert(
|
||||
agent,
|
||||
Agent {
|
||||
player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
map.insert(
|
||||
b,
|
||||
Agent {
|
||||
player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
map.insert(
|
||||
c,
|
||||
Agent {
|
||||
player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
map
|
||||
};
|
||||
}
|
||||
|
||||
let mut composition = Vec::new();
|
||||
let mut results = Vec::new();
|
||||
|
||||
52
src/batch.rs
52
src/batch.rs
@@ -1,8 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
Index, N_INF, agent::Agent, drift::Drift, game::Game, gaussian::Gaussian, player::Player,
|
||||
storage::SkillStore, tuple_gt, tuple_max,
|
||||
Index, N_INF,
|
||||
agent::Agent,
|
||||
drift::Drift,
|
||||
game::Game,
|
||||
gaussian::Gaussian,
|
||||
player::Player,
|
||||
storage::{AgentStore, SkillStore},
|
||||
tuple_gt, tuple_max,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -44,9 +50,9 @@ impl Item {
|
||||
online: bool,
|
||||
forward: bool,
|
||||
skills: &SkillStore,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
agents: &AgentStore<D>,
|
||||
) -> Player<D> {
|
||||
let r = &agents[&self.agent].player;
|
||||
let r = &agents[self.agent].player;
|
||||
let skill = skills.get(self.agent).unwrap();
|
||||
|
||||
if online {
|
||||
@@ -85,7 +91,7 @@ impl Event {
|
||||
online: bool,
|
||||
forward: bool,
|
||||
skills: &SkillStore,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
agents: &AgentStore<D>,
|
||||
) -> Vec<Vec<Player<D>>> {
|
||||
self.teams
|
||||
.iter()
|
||||
@@ -122,7 +128,7 @@ impl Batch {
|
||||
composition: Vec<Vec<Vec<Index>>>,
|
||||
results: Vec<Vec<f64>>,
|
||||
weights: Vec<Vec<Vec<f64>>>,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
agents: &AgentStore<D>,
|
||||
) {
|
||||
let mut unique = Vec::with_capacity(10);
|
||||
|
||||
@@ -137,16 +143,16 @@ impl Batch {
|
||||
});
|
||||
|
||||
for idx in this_agent {
|
||||
let elapsed = compute_elapsed(agents[idx].last_time, self.time);
|
||||
let elapsed = compute_elapsed(agents[*idx].last_time, self.time);
|
||||
|
||||
if let Some(skill) = self.skills.get_mut(*idx) {
|
||||
skill.elapsed = elapsed;
|
||||
skill.forward = agents[idx].receive(elapsed);
|
||||
skill.forward = agents[*idx].receive(elapsed);
|
||||
} else {
|
||||
self.skills.insert(
|
||||
*idx,
|
||||
Skill {
|
||||
forward: agents[idx].receive(elapsed),
|
||||
forward: agents[*idx].receive(elapsed),
|
||||
elapsed,
|
||||
..Default::default()
|
||||
},
|
||||
@@ -208,7 +214,7 @@ impl Batch {
|
||||
.collect::<HashMap<_, _>>()
|
||||
}
|
||||
|
||||
pub fn iteration<D: Drift>(&mut self, from: usize, agents: &HashMap<Index, Agent<D>>) {
|
||||
pub fn iteration<D: Drift>(&mut self, from: usize, agents: &AgentStore<D>) {
|
||||
for event in self.events.iter_mut().skip(from) {
|
||||
let teams = event.within_priors(false, false, &self.skills, agents);
|
||||
let result = event.outputs();
|
||||
@@ -229,7 +235,7 @@ impl Batch {
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn convergence<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) -> usize {
|
||||
pub(crate) fn convergence<D: Drift>(&mut self, agents: &AgentStore<D>) -> usize {
|
||||
let epsilon = 1e-6;
|
||||
let iterations = 20;
|
||||
|
||||
@@ -261,23 +267,23 @@ impl Batch {
|
||||
pub(crate) fn backward_prior_out<D: Drift>(
|
||||
&self,
|
||||
agent: &Index,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
agents: &AgentStore<D>,
|
||||
) -> Gaussian {
|
||||
let skill = self.skills.get(*agent).unwrap();
|
||||
let n = skill.likelihood * skill.backward;
|
||||
n.forget(agents[agent].player.drift.variance_delta(skill.elapsed))
|
||||
n.forget(agents[*agent].player.drift.variance_delta(skill.elapsed))
|
||||
}
|
||||
|
||||
pub(crate) fn new_backward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
|
||||
pub(crate) fn new_backward_info<D: Drift>(&mut self, agents: &AgentStore<D>) {
|
||||
for (agent, skill) in self.skills.iter_mut() {
|
||||
skill.backward = agents[&agent].message;
|
||||
skill.backward = agents[agent].message;
|
||||
}
|
||||
self.iteration(0, agents);
|
||||
}
|
||||
|
||||
pub(crate) fn new_forward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
|
||||
pub(crate) fn new_forward_info<D: Drift>(&mut self, agents: &AgentStore<D>) {
|
||||
for (agent, skill) in self.skills.iter_mut() {
|
||||
skill.forward = agents[&agent].receive(skill.elapsed);
|
||||
skill.forward = agents[agent].receive(skill.elapsed);
|
||||
}
|
||||
self.iteration(0, agents);
|
||||
}
|
||||
@@ -287,7 +293,7 @@ impl Batch {
|
||||
online: bool,
|
||||
targets: &[Index],
|
||||
forward: bool,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
agents: &AgentStore<D>,
|
||||
) -> f64 {
|
||||
if targets.is_empty() {
|
||||
if online || forward {
|
||||
@@ -387,7 +393,9 @@ mod tests {
|
||||
use approx::assert_ulps_eq;
|
||||
|
||||
use super::*;
|
||||
use crate::{IndexMap, agent::Agent, drift::ConstantDrift, player::Player};
|
||||
use crate::{
|
||||
IndexMap, agent::Agent, drift::ConstantDrift, player::Player, storage::AgentStore,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_one_event_each() {
|
||||
@@ -400,7 +408,7 @@ mod tests {
|
||||
let e = index_map.get_or_create("e");
|
||||
let f = index_map.get_or_create("f");
|
||||
|
||||
let mut agents = HashMap::new();
|
||||
let mut agents: AgentStore<ConstantDrift> = AgentStore::new();
|
||||
|
||||
for agent in [a, b, c, d, e, f] {
|
||||
agents.insert(
|
||||
@@ -476,7 +484,7 @@ mod tests {
|
||||
let e = index_map.get_or_create("e");
|
||||
let f = index_map.get_or_create("f");
|
||||
|
||||
let mut agents = HashMap::new();
|
||||
let mut agents: AgentStore<ConstantDrift> = AgentStore::new();
|
||||
|
||||
for agent in [a, b, c, d, e, f] {
|
||||
agents.insert(
|
||||
@@ -555,7 +563,7 @@ mod tests {
|
||||
let e = index_map.get_or_create("e");
|
||||
let f = index_map.get_or_create("f");
|
||||
|
||||
let mut agents = HashMap::new();
|
||||
let mut agents: AgentStore<ConstantDrift> = AgentStore::new();
|
||||
|
||||
for agent in [a, b, c, d, e, f] {
|
||||
agents.insert(
|
||||
|
||||
@@ -7,7 +7,9 @@ use crate::{
|
||||
drift::{ConstantDrift, Drift},
|
||||
gaussian::Gaussian,
|
||||
player::Player,
|
||||
sort_time, tuple_gt, tuple_max,
|
||||
sort_time,
|
||||
storage::AgentStore,
|
||||
tuple_gt, tuple_max,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -68,7 +70,7 @@ impl<D: Drift> HistoryBuilder<D> {
|
||||
History {
|
||||
size: 0,
|
||||
batches: Vec::new(),
|
||||
agents: HashMap::new(),
|
||||
agents: AgentStore::new(),
|
||||
time: self.time,
|
||||
mu: self.mu,
|
||||
sigma: self.sigma,
|
||||
@@ -104,7 +106,7 @@ impl Default for HistoryBuilder<ConstantDrift> {
|
||||
pub struct History<D: Drift = ConstantDrift> {
|
||||
size: usize,
|
||||
pub(crate) batches: Vec<Batch>,
|
||||
agents: HashMap<Index, Agent<D>>,
|
||||
agents: AgentStore<D>,
|
||||
time: bool,
|
||||
mu: f64,
|
||||
sigma: f64,
|
||||
@@ -119,7 +121,7 @@ impl Default for History<ConstantDrift> {
|
||||
Self {
|
||||
size: 0,
|
||||
batches: Vec::new(),
|
||||
agents: HashMap::new(),
|
||||
agents: AgentStore::new(),
|
||||
time: true,
|
||||
mu: MU,
|
||||
sigma: SIGMA,
|
||||
@@ -145,7 +147,7 @@ impl<D: Drift> History<D> {
|
||||
|
||||
for j in (0..self.batches.len() - 1).rev() {
|
||||
for agent in self.batches[j + 1].skills.keys() {
|
||||
self.agents.get_mut(&agent).unwrap().message =
|
||||
self.agents.get_mut(agent).unwrap().message =
|
||||
self.batches[j + 1].backward_prior_out(&agent, &self.agents);
|
||||
}
|
||||
|
||||
@@ -164,7 +166,7 @@ impl<D: Drift> History<D> {
|
||||
|
||||
for j in 1..self.batches.len() {
|
||||
for agent in self.batches[j - 1].skills.keys() {
|
||||
self.agents.get_mut(&agent).unwrap().message =
|
||||
self.agents.get_mut(agent).unwrap().message =
|
||||
self.batches[j - 1].forward_prior_out(&agent);
|
||||
}
|
||||
|
||||
@@ -296,7 +298,7 @@ impl<D: Drift> History<D> {
|
||||
|
||||
this_agent.push(*agent);
|
||||
|
||||
if !self.agents.contains_key(agent) {
|
||||
if !self.agents.contains(*agent) {
|
||||
self.agents.insert(
|
||||
*agent,
|
||||
Agent {
|
||||
@@ -345,9 +347,9 @@ impl<D: Drift> History<D> {
|
||||
for agent_idx in &this_agent {
|
||||
if let Some(skill) = batch.skills.get_mut(*agent_idx) {
|
||||
skill.elapsed =
|
||||
batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time);
|
||||
batch::compute_elapsed(self.agents[*agent_idx].last_time, batch.time);
|
||||
|
||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
||||
let agent = self.agents.get_mut(*agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { batch.time } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(agent_idx);
|
||||
@@ -378,7 +380,7 @@ impl<D: Drift> History<D> {
|
||||
batch.add_events(composition, results, weights, &self.agents);
|
||||
|
||||
for agent_idx in batch.skills.keys() {
|
||||
let agent = self.agents.get_mut(&agent_idx).unwrap();
|
||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { t } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(&agent_idx);
|
||||
@@ -392,7 +394,7 @@ impl<D: Drift> History<D> {
|
||||
let batch = &self.batches[k];
|
||||
|
||||
for agent_idx in batch.skills.keys() {
|
||||
let agent = self.agents.get_mut(&agent_idx).unwrap();
|
||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { t } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(&agent_idx);
|
||||
@@ -413,9 +415,9 @@ impl<D: Drift> History<D> {
|
||||
for agent_idx in &this_agent {
|
||||
if let Some(skill) = batch.skills.get_mut(*agent_idx) {
|
||||
skill.elapsed =
|
||||
batch::compute_elapsed(self.agents[agent_idx].last_time, batch.time);
|
||||
batch::compute_elapsed(self.agents[*agent_idx].last_time, batch.time);
|
||||
|
||||
let agent = self.agents.get_mut(agent_idx).unwrap();
|
||||
let agent = self.agents.get_mut(*agent_idx).unwrap();
|
||||
|
||||
agent.last_time = if self.time { batch.time } else { i64::MAX };
|
||||
agent.message = batch.forward_prior_out(agent_idx);
|
||||
|
||||
@@ -18,7 +18,7 @@ mod history;
|
||||
mod matrix;
|
||||
mod message;
|
||||
pub mod player;
|
||||
pub(crate) mod storage;
|
||||
pub mod storage;
|
||||
|
||||
pub use drift::{ConstantDrift, Drift};
|
||||
pub use error::InferenceError;
|
||||
|
||||
125
src/storage/agent_store.rs
Normal file
125
src/storage/agent_store.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
use crate::{Index, agent::Agent, drift::Drift};
|
||||
|
||||
/// Dense Vec-backed store for agent state in History.
|
||||
///
|
||||
/// Indexed directly by Index.0, eliminating HashMap hashing in the
|
||||
/// forward/backward sweep. Uses `Vec<Option<Agent<D>>>` so slots can be
|
||||
/// absent without an explicit present mask.
|
||||
#[derive(Debug)]
|
||||
pub struct AgentStore<D: Drift> {
|
||||
agents: Vec<Option<Agent<D>>>,
|
||||
n_present: usize,
|
||||
}
|
||||
|
||||
impl<D: Drift> Default for AgentStore<D> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
agents: Vec::new(),
|
||||
n_present: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Drift> AgentStore<D> {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn ensure_capacity(&mut self, idx: usize) {
|
||||
if idx >= self.agents.len() {
|
||||
self.agents.resize_with(idx + 1, || None);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, idx: Index, agent: Agent<D>) {
|
||||
self.ensure_capacity(idx.0);
|
||||
if self.agents[idx.0].is_none() {
|
||||
self.n_present += 1;
|
||||
}
|
||||
self.agents[idx.0] = Some(agent);
|
||||
}
|
||||
|
||||
pub fn get(&self, idx: Index) -> Option<&Agent<D>> {
|
||||
self.agents.get(idx.0).and_then(|slot| slot.as_ref())
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self, idx: Index) -> Option<&mut Agent<D>> {
|
||||
self.agents.get_mut(idx.0).and_then(|slot| slot.as_mut())
|
||||
}
|
||||
|
||||
pub fn contains(&self, idx: Index) -> bool {
|
||||
self.get(idx).is_some()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.n_present
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.n_present == 0
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = (Index, &Agent<D>)> {
|
||||
self.agents
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, slot)| slot.as_ref().map(|a| (Index(i), a)))
|
||||
}
|
||||
|
||||
pub fn iter_mut(&mut self) -> impl Iterator<Item = (Index, &mut Agent<D>)> {
|
||||
self.agents
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.filter_map(|(i, slot)| slot.as_mut().map(|a| (Index(i), a)))
|
||||
}
|
||||
|
||||
pub fn values_mut(&mut self) -> impl Iterator<Item = &mut Agent<D>> {
|
||||
self.agents.iter_mut().filter_map(|s| s.as_mut())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Drift> std::ops::Index<Index> for AgentStore<D> {
|
||||
type Output = Agent<D>;
|
||||
fn index(&self, idx: Index) -> &Agent<D> {
|
||||
self.get(idx).expect("agent not found at index")
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Drift> std::ops::IndexMut<Index> for AgentStore<D> {
|
||||
fn index_mut(&mut self, idx: Index) -> &mut Agent<D> {
|
||||
self.get_mut(idx).expect("agent not found at index")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{agent::Agent, drift::ConstantDrift, player::Player};
|
||||
|
||||
#[test]
|
||||
fn insert_then_get() {
|
||||
let mut store: AgentStore<ConstantDrift> = AgentStore::new();
|
||||
let idx = Index(7);
|
||||
store.insert(idx, Agent::default());
|
||||
assert!(store.contains(idx));
|
||||
assert_eq!(store.len(), 1);
|
||||
assert!(store.get(idx).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iter_in_index_order() {
|
||||
let mut store: AgentStore<ConstantDrift> = AgentStore::new();
|
||||
store.insert(Index(2), Agent::default());
|
||||
store.insert(Index(0), Agent::default());
|
||||
store.insert(Index(5), Agent::default());
|
||||
let keys: Vec<Index> = store.iter().map(|(i, _)| i).collect();
|
||||
assert_eq!(keys, vec![Index(0), Index(2), Index(5)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_operator_works() {
|
||||
let mut store: AgentStore<ConstantDrift> = AgentStore::new();
|
||||
store.insert(Index(3), Agent::default());
|
||||
let _ = &store[Index(3)];
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
mod agent_store;
|
||||
mod skill_store;
|
||||
|
||||
pub use agent_store::AgentStore;
|
||||
pub(crate) use skill_store::SkillStore;
|
||||
|
||||
Reference in New Issue
Block a user