diff --git a/examples/atp.rs b/examples/atp.rs index 88170e5..a88bf8c 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -34,7 +34,7 @@ fn main() { .assume_utc() .unix_timestamp(); - times.push(time as u64); + times.push(time as f64 / (60 * 60 * 24) as f64); } let mut h = History::new( @@ -48,6 +48,10 @@ fn main() { 0.036, P_DRAW, ); + + h.epsilon = 0.01; + h.iterations = 10; + h.convergence(); /* diff --git a/src/history.rs b/src/history.rs index 21cb442..7bedfb1 100644 --- a/src/history.rs +++ b/src/history.rs @@ -11,13 +11,16 @@ pub struct History { gamma: f64, p_draw: f64, time: bool, + pub epsilon: f64, + pub iterations: usize, + pub verbose: bool, } impl History { pub fn new + Clone>( composition: Vec>>, results: Vec>, - times: Vec, + times: Vec, priors: HashMap, mu: f64, sigma: f64, @@ -60,6 +63,9 @@ impl History { gamma, p_draw, time: !times.is_empty(), + epsilon: 1e-6, + iterations: 30, + verbose: true, }; this.trueskill(composition, results, times); @@ -71,10 +77,10 @@ impl History { &mut self, composition: Vec>>, results: Vec>, - times: Vec, + times: Vec, ) { let o = if self.time { - utils::sortperm(×, false) + utils::sort_time(×, false) } else { (0..composition.len()).collect::>() }; @@ -83,7 +89,11 @@ impl History { while i < self.size { let mut j = i + 1; - let t = if self.time { times[o[i]] } else { i as u64 + 1 }; + let t = if self.time { + times[o[i]] + } else { + i as f64 + 1.0 + }; while self.time && j < self.size && times[o[j]] == t { j += 1; @@ -173,15 +183,11 @@ impl History { } pub fn convergence(&mut self) -> ((f64, f64), usize) { - let epsilon = 1e-6; - let iterations = 30; - let verbose = true; - let mut step = (f64::INFINITY, f64::INFINITY); let mut i = 0; - while (step.0 > epsilon || step.1 > epsilon) && i < iterations { - if verbose { + while (step.0 > self.epsilon || step.1 > self.epsilon) && i < self.iterations { + if self.verbose { print!("Iteration = {}", i); } @@ -189,12 +195,12 @@ impl History { i += 1; - if verbose { + if self.verbose { println!(", step = {:?}", step); } } - if verbose { + if self.verbose { println!("End"); } @@ -278,7 +284,7 @@ mod tests { let mut h = History::new( composition, results, - vec![1, 2, 3], + vec![1.0, 2.0, 3.0], priors, MU, BETA, @@ -319,7 +325,7 @@ mod tests { vec![vec!["cj"], vec!["aj"]], ]; let results = vec![vec![1, 0], vec![1, 0], vec![1, 0]]; - let times = vec![1, 1, 1]; + let times = vec![1.0, 1.0, 1.0]; let mut priors = HashMap::new(); @@ -396,7 +402,7 @@ mod tests { vec![vec!["cj"], vec!["aj"]], ]; let results = vec![vec![1, 0], vec![1, 0], vec![1, 0]]; - let times = vec![1, 2, 3]; + let times = vec![1.0, 2.0, 3.0]; let mut priors = HashMap::new(); @@ -476,7 +482,7 @@ mod tests { vec![vec!["cj"], vec!["aj"]], ]; let results = vec![vec![1, 0], vec![1, 0], vec![1, 0]]; - let times = vec![5, 6, 7]; + let times = vec![5.0, 6.0, 7.0]; let mut priors = HashMap::new(); diff --git a/src/utils.rs b/src/utils.rs index aa54779..ad5b3cd 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,4 @@ -use std::cmp::Reverse; +use std::cmp::{Ordering, Reverse}; use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2}; use crate::Gaussian; @@ -146,6 +146,18 @@ pub(crate) fn sortperm(xs: &[T], reverse: bool) -> Vec { x.into_iter().map(|(i, _)| i).collect() } +pub(crate) fn sort_time(xs: &[f64], reverse: bool) -> Vec { + let mut x = xs.iter().enumerate().collect::>(); + + if reverse { + x.sort_unstable_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap()); + } else { + x.sort_unstable_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()); + } + + x.into_iter().map(|(i, _)| i).collect() +} + #[cfg(test)] mod tests { use crate::{Gaussian, N01}; @@ -225,4 +237,10 @@ mod tests { assert_eq!(sortperm(&[0, 1, 2, 0], true), vec![2, 1, 0, 3]); assert_eq!(sortperm(&[1, 1, 1], false), vec![0, 1, 2]); } + + #[test] + fn test_sort_time() { + assert_eq!(sort_time(&[0.0, 1.0, 2.0, 0.0], true), vec![2, 1, 0, 3]); + assert_eq!(sort_time(&[1.0, 1.0, 1.0], false), vec![0, 1, 2]); + } }