diff --git a/src/history.rs b/src/history.rs index 1a83a31..7fa9229 100644 --- a/src/history.rs +++ b/src/history.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{borrow::Borrow, collections::HashMap, hash::Hash, marker::PhantomData}; use crate::{ BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA, @@ -7,6 +7,7 @@ use crate::{ drift::{ConstantDrift, Drift}, error::InferenceError, gaussian::Gaussian, + key_table::KeyTable, observer::{NullObserver, Observer}, rating::Rating, sort_time, @@ -17,8 +18,12 @@ use crate::{ }; #[derive(Clone)] -pub struct HistoryBuilder = ConstantDrift, O: Observer = NullObserver> -{ +pub struct HistoryBuilder< + T: Time = i64, + D: Drift = ConstantDrift, + O: Observer = NullObserver, + K: Eq + Hash + Clone = &'static str, +> { mu: f64, sigma: f64, beta: f64, @@ -27,10 +32,11 @@ pub struct HistoryBuilder = ConstantDrift, O: Observe online: bool, convergence: ConvergenceOptions, observer: O, - _time: std::marker::PhantomData, + _time: PhantomData, + _key: PhantomData, } -impl, O: Observer> HistoryBuilder { +impl, O: Observer, K: Eq + Hash + Clone> HistoryBuilder { pub fn mu(mut self, mu: f64) -> Self { self.mu = mu; self @@ -46,7 +52,7 @@ impl, O: Observer> HistoryBuilder { self } - pub fn drift>(self, drift: D2) -> HistoryBuilder { + pub fn drift>(self, drift: D2) -> HistoryBuilder { HistoryBuilder { drift, mu: self.mu, @@ -57,6 +63,7 @@ impl, O: Observer> HistoryBuilder { convergence: self.convergence, observer: self.observer, _time: self._time, + _key: self._key, } } @@ -75,7 +82,7 @@ impl, O: Observer> HistoryBuilder { self } - pub fn observer>(self, observer: O2) -> HistoryBuilder { + pub fn observer>(self, observer: O2) -> HistoryBuilder { HistoryBuilder { mu: self.mu, sigma: self.sigma, @@ -86,14 +93,16 @@ impl, O: Observer> HistoryBuilder { convergence: self.convergence, observer, _time: self._time, + _key: self._key, } } - pub fn build(self) -> History { + pub fn build(self) -> History { History { size: 0, time_slices: Vec::new(), agents: CompetitorStore::new(), + keys: KeyTable::new(), mu: self.mu, sigma: self.sigma, beta: self.beta, @@ -106,14 +115,14 @@ impl, O: Observer> HistoryBuilder { } } -impl> HistoryBuilder { +impl, K: Eq + Hash + Clone> HistoryBuilder { pub fn gamma(mut self, gamma: f64) -> Self { self.drift = ConstantDrift(gamma); self } } -impl Default for HistoryBuilder { +impl Default for HistoryBuilder { fn default() -> Self { Self { mu: MU, @@ -124,15 +133,22 @@ impl Default for HistoryBuilder { online: false, convergence: ConvergenceOptions::default(), observer: NullObserver, - _time: std::marker::PhantomData, + _time: PhantomData, + _key: PhantomData, } } } -pub struct History = ConstantDrift, O: Observer = NullObserver> { +pub struct History< + T: Time = i64, + D: Drift = ConstantDrift, + O: Observer = NullObserver, + K: Eq + Hash + Clone = &'static str, +> { size: usize, pub(crate) time_slices: Vec>, pub(crate) agents: CompetitorStore, + keys: KeyTable, mu: f64, sigma: f64, beta: f64, @@ -143,19 +159,37 @@ pub struct History = ConstantDrift, O: Observer = observer: O, } -impl Default for History { +impl Default for History { fn default() -> Self { HistoryBuilder::default().build() } } -impl History { - pub fn builder() -> HistoryBuilder { +impl History { + pub fn builder() -> HistoryBuilder { HistoryBuilder::default() } } -impl, O: Observer> History { +impl, O: Observer, K: Eq + Hash + Clone> History { + pub fn intern(&mut self, key: &Q) -> Index + where + K: Borrow, + Q: Hash + Eq + ToOwned + ?Sized, + { + self.keys.get_or_create(key) + } + + pub fn lookup(&self, key: &Q) -> Option + where + K: Borrow, + Q: Hash + Eq + ToOwned + ?Sized, + { + self.keys.get(key) + } +} + +impl, O: Observer, K: Eq + Hash + Clone> History { fn iteration(&mut self) -> (f64, f64) { let mut step = (0.0, 0.0); @@ -298,7 +332,7 @@ impl, O: Observer> History { } } -impl, O: Observer> History { +impl, O: Observer, K: Eq + Hash + Clone> History { pub fn add_events( &mut self, composition: Vec>>, @@ -478,6 +512,43 @@ impl, O: Observer> History { self.size += n; Ok(()) } + + pub fn record_winner( + &mut self, + winner: &Q, + loser: &Q, + time: i64, + ) -> Result<(), InferenceError> + where + K: Borrow, + Q: Hash + Eq + ToOwned + ?Sized, + { + let w = self.intern(winner); + let l = self.intern(loser); + self.add_events_with_prior( + vec![vec![vec![w], vec![l]]], + vec![vec![1.0, 0.0]], + vec![time], + vec![], + HashMap::new(), + ) + } + + pub fn record_draw(&mut self, a: &Q, b: &Q, time: i64) -> Result<(), InferenceError> + where + K: Borrow, + Q: Hash + Eq + ToOwned + ?Sized, + { + let a_idx = self.intern(a); + let b_idx = self.intern(b); + self.add_events_with_prior( + vec![vec![vec![a_idx], vec![b_idx]]], + vec![vec![0.0, 0.0]], + vec![time], + vec![], + HashMap::new(), + ) + } } #[cfg(test)] diff --git a/tests/record_winner.rs b/tests/record_winner.rs new file mode 100644 index 0000000..ae18058 --- /dev/null +++ b/tests/record_winner.rs @@ -0,0 +1,54 @@ +use trueskill_tt::{ConstantDrift, ConvergenceOptions, History}; + +#[test] +fn record_winner_builds_history() { + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .drift(ConstantDrift(25.0 / 300.0)) + .convergence(ConvergenceOptions { + max_iter: 30, + epsilon: 1e-6, + }) + .build(); + + h.record_winner(&"alice", &"bob", 1).unwrap(); + h.converge().unwrap(); + + let a_idx = h.lookup(&"alice").unwrap(); + let b_idx = h.lookup(&"bob").unwrap(); + + assert_ne!(a_idx, b_idx); +} + +#[test] +fn intern_is_idempotent() { + let mut h: History = History::builder().build(); + let a1 = h.intern(&"alice"); + let a2 = h.intern(&"alice"); + assert_eq!(a1, a2); +} + +#[test] +fn lookup_returns_none_for_missing() { + let h: History = History::builder().build(); + assert!(h.lookup(&"nobody").is_none()); +} + +#[test] +fn record_draw_with_p_draw_set() { + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .drift(ConstantDrift(25.0 / 300.0)) + .p_draw(0.25) + .build(); + + h.record_draw(&"alice", &"bob", 1).unwrap(); + h.converge().unwrap(); + + assert!(h.lookup(&"alice").is_some()); + assert!(h.lookup(&"bob").is_some()); +}