diff --git a/examples/atp.rs b/examples/atp.rs index d54ddd3..730e45f 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -107,9 +107,9 @@ fn main() { */ let mut h = History::new( - composition, - results, - times, + &composition, + &results, + ×, HashMap::new(), MU, 1.6, diff --git a/src/batch.rs b/src/batch.rs index 3405114..6372b93 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -37,6 +37,7 @@ impl Agent { } } + #[inline] pub fn receive(&self, elapsed: f64) -> Gaussian { if self.message != N_INF { self.message.forget(self.player.gamma, elapsed) @@ -191,12 +192,14 @@ impl Batch { self.iteration(from, agents); } + #[inline] pub fn posterior(&self, agent: &PlayerIndex) -> Gaussian { let skill = &self.skills[agent]; skill.likelihood * skill.backward * skill.forward } + #[inline] pub fn posteriors(&self) -> HashMap { self.skills .keys() @@ -204,6 +207,7 @@ impl Batch { .collect::>() } + #[inline] fn within_prior(&self, item: &Item, agents: &mut HashMap) -> Player { let r = &agents[&item.index].player; let g = self.posterior(&item.index) / item.likelihood; @@ -278,12 +282,14 @@ impl Batch { i } + #[inline] pub fn forward_prior_out(&self, agent: &PlayerIndex) -> Gaussian { let skill = &self.skills[agent]; skill.forward * skill.likelihood } + #[inline] pub fn backward_prior_out( &self, agent: &PlayerIndex, @@ -295,6 +301,7 @@ impl Batch { n.forget(agents[agent].player.gamma, skill.elapsed) } + #[inline] pub fn new_backward_info(&mut self, agents: &mut HashMap) { for (agent, skill) in self.skills.iter_mut() { skill.backward = agents[agent].message; @@ -303,6 +310,7 @@ impl Batch { self.iteration(0, agents); } + #[inline] pub fn new_forward_info(&mut self, agents: &mut HashMap) { for (agent, skill) in self.skills.iter_mut() { skill.forward = agents[agent].receive(skill.elapsed); diff --git a/src/game.rs b/src/game.rs index 083d4de..40845c4 100644 --- a/src/game.rs +++ b/src/game.rs @@ -72,7 +72,7 @@ impl Game { } let r = &self.result; - let o = utils::sortperm(r, true); + let o = utils::sort_perm(r); let t = (0..self.teams.len()) .map(|e| TeamVariable { diff --git a/src/history.rs b/src/history.rs index a51a3cb..632b420 100644 --- a/src/history.rs +++ b/src/history.rs @@ -6,9 +6,10 @@ pub struct History { size: usize, batches: Vec, agents: HashMap, - mu: f64, - sigma: f64, - gamma: f64, + // mu: f64, + // sigma: f64, + // beta: f64, + // gamma: f64, p_draw: f64, time: bool, pub epsilon: f64, @@ -18,9 +19,9 @@ pub struct History { impl History { pub fn new( - composition: Vec>>, - results: Vec>, - times: Vec, + composition: &[Vec>], + results: &[Vec], + times: &[f64], priors: HashMap, mu: f64, sigma: f64, @@ -57,9 +58,6 @@ impl History { size: composition.len(), batches: Vec::new(), agents, - mu, - sigma, - gamma, p_draw, time: !times.is_empty(), epsilon: 1e-6, @@ -74,12 +72,12 @@ impl History { fn trueskill( &mut self, - composition: Vec>>, - results: Vec>, - times: Vec, + composition: &[Vec>], + results: &[Vec], + times: &[f64], ) { let o = if self.time { - utils::sort_time(×, false) + utils::sort_time(times) } else { (0..composition.len()).collect::>() }; @@ -88,13 +86,13 @@ impl History { while i < self.size { let mut j = i + 1; - let t = if self.time { + let time = if self.time { times[o[i]] } else { i as f64 + 1.0 }; - while self.time && j < self.size && times[o[j]] == t { + while self.time && j < self.size && times[o[j]] == time { j += 1; } @@ -104,20 +102,17 @@ impl History { let results = (i..j).map(|e| results[o[e]].clone()).collect::>(); - let b = Batch::new( - composition, - results, - t as f64, - &mut self.agents, - self.p_draw, - ); + let b = Batch::new(composition, results, time, &mut self.agents, self.p_draw); - self.batches.push(b.clone()); + self.batches.push(b); + + let idx = self.batches.len() - 1; + let b = &mut self.batches[idx]; for a in b.skills.keys() { let agent = self.agents.get_mut(a).unwrap(); - agent.last_time = if self.time { t as f64 } else { f64::INFINITY }; + agent.last_time = if self.time { time } else { f64::INFINITY }; agent.message = b.forward_prior_out(a); } @@ -285,9 +280,9 @@ mod tests { } let mut h = History::new( - composition, - results, - vec![1.0, 2.0, 3.0], + &composition, + &results, + &[1.0, 2.0, 3.0], priors, MU, BETA, @@ -348,9 +343,9 @@ mod tests { } let mut h1 = History::new( - composition, - results, - times, + &composition, + &results, + ×, priors, MU, BETA, @@ -425,9 +420,9 @@ mod tests { } let mut h2 = History::new( - composition, - results, - times, + &composition, + &results, + ×, priors, MU, BETA, @@ -509,9 +504,9 @@ mod tests { } let mut h = History::new( - composition, - results, - times, + &composition, + &results, + ×, priors, MU, BETA, @@ -566,9 +561,9 @@ mod tests { let results = vec![vec![1, 0], vec![0, 1], vec![1, 0]]; let mut h = History::new( - composition, - results, - Vec::new(), + &composition, + &results, + &[], HashMap::new(), 25.0, 25.0 / 3.0, diff --git a/src/utils.rs b/src/utils.rs index 3dc92c7..e980d46 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -144,26 +144,20 @@ pub(crate) fn compute_margin(p_draw: f64, sd: f64) -> f64 { ppf(0.5 - p_draw / 2.0, 0.0, sd).abs() } -pub(crate) fn sortperm(xs: &[T], reverse: bool) -> Vec { +#[inline] +pub(crate) fn sort_perm(xs: &[u16]) -> Vec { let mut x = xs.iter().enumerate().collect::>(); - if reverse { - x.sort_unstable_by_key(|(_, x)| Reverse(*x)); - } else { - x.sort_unstable_by_key(|(_, x)| *x); - } + x.sort_unstable_by_key(|(_, x)| Reverse(*x)); x.into_iter().map(|(i, _)| i).collect() } -pub(crate) fn sort_time(xs: &[f64], reverse: bool) -> Vec { +#[inline] +pub(crate) fn sort_time(xs: &[f64]) -> Vec { let mut x = xs.iter().enumerate().collect::>(); - if reverse { - x.sort_unstable_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap()); - } else { - x.sort_unstable_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()); - } + x.sort_unstable_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()); x.into_iter().map(|(i, _)| i).collect() } @@ -244,13 +238,11 @@ mod tests { #[test] fn test_sortperm() { - assert_eq!(sortperm(&[0, 1, 2, 0], true), vec![2, 1, 0, 3]); - assert_eq!(sortperm(&[1, 1, 1], false), vec![0, 1, 2]); + assert_eq!(sort_perm(&[0, 1, 2, 0]), vec![2, 1, 0, 3]); } #[test] fn test_sort_time() { - assert_eq!(sort_time(&[0.0, 1.0, 2.0, 0.0], true), vec![2, 1, 0, 3]); - assert_eq!(sort_time(&[1.0, 1.0, 1.0], false), vec![0, 1, 2]); + assert_eq!(sort_time(&[1.0, 1.0, 1.0]), vec![0, 1, 2]); } }