diff --git a/benches/batch.rs b/benches/batch.rs index c1554af..505e2f9 100644 --- a/benches/batch.rs +++ b/benches/batch.rs @@ -1,11 +1,11 @@ use criterion::{Criterion, criterion_group, criterion_main}; use trueskill_tt::{ - BETA, GAMMA, IndexMap, MU, P_DRAW, SIGMA, agent::Agent, batch::Batch, drift::ConstantDrift, + BETA, GAMMA, KeyTable, MU, P_DRAW, SIGMA, agent::Agent, batch::Batch, drift::ConstantDrift, gaussian::Gaussian, player::Player, storage::AgentStore, }; fn criterion_benchmark(criterion: &mut Criterion) { - let mut index = IndexMap::new(); + let mut index = KeyTable::new(); let a = index.get_or_create("a"); let b = index.get_or_create("b"); diff --git a/examples/atp.rs b/examples/atp.rs index ebf5b05..0ebf845 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -1,6 +1,6 @@ use plotters::prelude::*; use time::{Date, Month}; -use trueskill_tt::{History, IndexMap}; +use trueskill_tt::{History, KeyTable}; fn main() { let mut csv = csv::Reader::open("examples/atp.csv").unwrap(); @@ -12,7 +12,7 @@ fn main() { let from = Date::from_calendar_date(1900, Month::January, 1).unwrap(); let time_format = time::format_description::parse("[year]-[month]-[day]").unwrap(); - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); for row in csv.records() { if &row["double"] == "t" { diff --git a/src/batch.rs b/src/batch.rs index 75d3f47..72a415c 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -400,12 +400,12 @@ mod tests { use super::*; use crate::{ - IndexMap, agent::Agent, drift::ConstantDrift, player::Player, storage::AgentStore, + KeyTable, agent::Agent, drift::ConstantDrift, player::Player, storage::AgentStore, }; #[test] fn test_one_event_each() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -481,7 +481,7 @@ mod tests { #[test] fn test_same_strength() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -560,7 +560,7 @@ mod tests { #[test] fn test_add_events() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); diff --git a/src/history.rs b/src/history.rs index cd28136..a867bc4 100644 --- a/src/history.rs +++ b/src/history.rs @@ -437,13 +437,13 @@ mod tests { use super::*; use crate::{ - ConstantDrift, EPSILON, Game, Gaussian, ITERATIONS, IndexMap, P_DRAW, Player, + ConstantDrift, EPSILON, Game, Gaussian, ITERATIONS, KeyTable, P_DRAW, Player, arena::ScratchArena, }; #[test] fn test_init() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -513,7 +513,7 @@ mod tests { #[test] fn test_one_batch() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -620,7 +620,7 @@ mod tests { #[test] fn test_learning_curves() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -674,7 +674,7 @@ mod tests { #[test] fn test_env_ttt() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -721,7 +721,7 @@ mod tests { #[test] fn test_teams() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -811,7 +811,7 @@ mod tests { #[test] fn test_add_events() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -900,7 +900,7 @@ mod tests { #[test] fn test_only_add_events() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -989,7 +989,7 @@ mod tests { #[test] fn test_log_evidence() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -1048,7 +1048,7 @@ mod tests { #[test] fn test_add_events_with_time() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); @@ -1237,7 +1237,7 @@ mod tests { #[test] fn test_1vs1_weighted() { - let mut index_map = IndexMap::new(); + let mut index_map = KeyTable::new(); let a = index_map.get_or_create("a"); let b = index_map.get_or_create("b"); diff --git a/src/key_table.rs b/src/key_table.rs new file mode 100644 index 0000000..8061654 --- /dev/null +++ b/src/key_table.rs @@ -0,0 +1,72 @@ +use std::{ + borrow::{Borrow, ToOwned}, + collections::HashMap, + hash::Hash, +}; + +use crate::Index; + +/// Maps user keys to internal `Index` handles. +/// +/// Renamed from the former `IndexMap` to avoid colliding with the `indexmap` +/// crate. Power users can promote `&K` to `Index` via `get_or_create` and +/// skip the lookup on subsequent hot-path calls. +#[derive(Debug)] +pub struct KeyTable(HashMap); + +impl KeyTable +where + K: Eq + Hash, +{ + pub fn new() -> Self { + Self(HashMap::new()) + } + + pub fn get>(&self, k: &Q) -> Option + where + K: Borrow, + { + self.0.get(k).cloned() + } + + pub fn get_or_create>(&mut self, k: &Q) -> Index + where + K: Borrow, + { + if let Some(idx) = self.0.get(k) { + *idx + } else { + let idx = Index::from(self.0.len()); + self.0.insert(k.to_owned(), idx); + idx + } + } + + pub fn key(&self, idx: Index) -> Option<&K> { + self.0 + .iter() + .find(|&(_, value)| *value == idx) + .map(|(key, _)| key) + } + + pub fn keys(&self) -> impl Iterator { + self.0.keys() + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +impl Default for KeyTable +where + K: Eq + Hash, +{ + fn default() -> Self { + KeyTable::new() + } +} diff --git a/src/lib.rs b/src/lib.rs index 3ddd8c0..0afdc64 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,6 @@ use std::{ - borrow::{Borrow, ToOwned}, cmp::Reverse, - collections::HashMap, f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2}, - hash::Hash, }; pub mod agent; @@ -17,6 +14,7 @@ pub(crate) mod factor; mod game; pub mod gaussian; mod history; +pub mod key_table; mod matrix; pub mod player; pub(crate) mod schedule; @@ -27,6 +25,7 @@ pub use error::InferenceError; pub use game::Game; pub use gaussian::Gaussian; pub use history::History; +pub use key_table::KeyTable; use matrix::Matrix; pub use player::Player; pub use schedule::ScheduleReport; @@ -54,59 +53,6 @@ impl From for Index { } } -pub struct IndexMap(HashMap); - -impl IndexMap -where - K: Eq + Hash, -{ - pub fn new() -> Self { - Self(HashMap::new()) - } - - pub fn get>(&self, k: &Q) -> Option - where - K: Borrow, - { - self.0.get(k).cloned() - } - - pub fn get_or_create>(&mut self, k: &Q) -> Index - where - K: Borrow, - { - if let Some(idx) = self.0.get(k) { - *idx - } else { - let idx = Index::from(self.0.len()); - - self.0.insert(k.to_owned(), idx); - - idx - } - } - - pub fn key(&self, idx: Index) -> Option<&K> { - self.0 - .iter() - .find(|&(_, value)| *value == idx) - .map(|(key, _)| key) - } - - pub fn keys(&self) -> impl Iterator { - self.0.keys() - } -} - -impl Default for IndexMap -where - K: Eq + Hash, -{ - fn default() -> Self { - IndexMap::new() - } -} - fn erfc(x: f64) -> f64 { let z = x.abs(); let t = 1.0 / (1.0 + z / 2.0);