diff --git a/examples/atp.rs b/examples/atp.rs index 0ebf845..236e8e3 100644 --- a/examples/atp.rs +++ b/examples/atp.rs @@ -40,7 +40,8 @@ fn main() { let mut hist = History::builder().sigma(1.6).gamma(0.036).build(); - hist.add_events(composition, results, times, vec![]); + hist.add_events(composition, results, times, vec![]) + .unwrap(); hist.convergence(10, 0.01, true); let players = [ diff --git a/src/error.rs b/src/error.rs index 3886451..e32a124 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,12 +2,45 @@ use std::fmt; #[derive(Debug, Clone, PartialEq)] pub enum InferenceError { + /// Expected and actual lengths of some array-shaped input differ. + MismatchedShape { + kind: &'static str, + expected: usize, + got: usize, + }, + /// A probability value is outside `[0, 1]`. + InvalidProbability { value: f64 }, + /// Convergence exceeded `max_iter` without falling below `epsilon`. + ConvergenceFailed { + last_step: (f64, f64), + iterations: usize, + }, + /// Negative precision: a Gaussian with `pi < 0` slipped into an API call. NegativePrecision { pi: f64 }, } impl fmt::Display for InferenceError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { + Self::MismatchedShape { + kind, + expected, + got, + } => { + write!(f, "{kind}: expected length {expected}, got {got}") + } + Self::InvalidProbability { value } => { + write!(f, "probability must be in [0, 1]; got {value}") + } + Self::ConvergenceFailed { + last_step, + iterations, + } => { + write!( + f, + "convergence failed after {iterations} iterations; last step = {last_step:?}" + ) + } Self::NegativePrecision { pi } => { write!(f, "precision must be non-negative; got {pi}") } diff --git a/src/history.rs b/src/history.rs index 8545d8d..1a83a31 100644 --- a/src/history.rs +++ b/src/history.rs @@ -305,7 +305,7 @@ impl, O: Observer> History { results: Vec>, times: Vec, weights: Vec>>, - ) { + ) -> Result<(), InferenceError> { self.add_events_with_prior(composition, results, times, weights, HashMap::new()) } @@ -316,19 +316,28 @@ impl, O: Observer> History { times: Vec, weights: Vec>>, mut priors: HashMap>, - ) { - assert!( - results.is_empty() || results.len() == composition.len(), - "(length(results) > 0) & (length(composition) != length(results))" - ); - assert!( - times.len() == composition.len(), - "length(times) must equal length(composition)" - ); - assert!( - weights.is_empty() || weights.len() == composition.len(), - "(length(weights) > 0) & (length(composition) != length(weights))" - ); + ) -> Result<(), InferenceError> { + if !results.is_empty() && results.len() != composition.len() { + return Err(InferenceError::MismatchedShape { + kind: "results", + expected: composition.len(), + got: results.len(), + }); + } + if times.len() != composition.len() { + return Err(InferenceError::MismatchedShape { + kind: "times", + expected: composition.len(), + got: times.len(), + }); + } + if !weights.is_empty() && weights.len() != composition.len() { + return Err(InferenceError::MismatchedShape { + kind: "weights", + expected: composition.len(), + got: weights.len(), + }); + } competitor::clean(self.agents.values_mut(), true); @@ -467,6 +476,7 @@ impl, O: Observer> History { } self.size += n; + Ok(()) } } @@ -510,7 +520,8 @@ mod tests { let mut h = History::default(); - h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors); + h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors) + .unwrap(); let p0 = h.time_slices[0].posteriors(); @@ -586,7 +597,8 @@ mod tests { let mut h1 = History::default(); - h1.add_events_with_prior(composition, results, times, vec![], priors); + h1.add_events_with_prior(composition, results, times, vec![], priors) + .unwrap(); assert_ulps_eq!( h1.time_slices[0].skills.get(a).unwrap().posterior(), @@ -635,7 +647,8 @@ mod tests { let mut h2 = History::default(); - h2.add_events_with_prior(composition, results, times, vec![], priors); + h2.add_events_with_prior(composition, results, times, vec![], priors) + .unwrap(); assert_ulps_eq!( h2.time_slices[2].skills.get(a).unwrap().posterior(), @@ -693,7 +706,8 @@ mod tests { let mut h = History::default(); - h.add_events_with_prior(composition, results, times, vec![], priors); + h.add_events_with_prior(composition, results, times, vec![], priors) + .unwrap(); h.convergence(ITERATIONS, EPSILON, false); let lc = h.learning_curves(); @@ -740,7 +754,7 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition, results, times, vec![]); + h.add_events(composition, results, times, vec![]).unwrap(); h.convergence(ITERATIONS, EPSILON, false); @@ -791,7 +805,7 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition, results, times, vec![]); + h.add_events(composition, results, times, vec![]).unwrap(); let trueskill_log_evidence = h.log_evidence(false, &[]); let trueskill_log_evidence_online = h.log_evidence(true, &[]); @@ -879,7 +893,8 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition.clone(), results.clone(), times, vec![]); + h.add_events(composition.clone(), results.clone(), times, vec![]) + .unwrap(); h.convergence(ITERATIONS, EPSILON, false); @@ -903,7 +918,7 @@ mod tests { ); let times2: Vec = (n as i64 + 1..=2 * n as i64).collect(); - h.add_events(composition, results, times2, vec![]); + h.add_events(composition, results, times2, vec![]).unwrap(); assert_eq!(h.time_slices.len(), 6); @@ -970,7 +985,8 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition.clone(), results.clone(), times, vec![]); + h.add_events(composition.clone(), results.clone(), times, vec![]) + .unwrap(); h.convergence(ITERATIONS, EPSILON, false); @@ -994,7 +1010,7 @@ mod tests { ); let times2: Vec = (n as i64 + 1..=2 * n as i64).collect(); - h.add_events(composition, results, times2, vec![]); + h.add_events(composition, results, times2, vec![]).unwrap(); assert_eq!(h.time_slices.len(), 6); @@ -1050,7 +1066,8 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition.clone(), vec![], times.clone(), vec![]); + h.add_events(composition.clone(), vec![], times.clone(), vec![]) + .unwrap(); let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0; @@ -1089,7 +1106,7 @@ mod tests { let mut h = History::builder().build(); - h.add_events(composition, vec![], times, vec![]); + h.add_events(composition, vec![], times, vec![]).unwrap(); assert_ulps_eq!( ((0.5f64 * 0.1765).ln() / 2.0).exp(), @@ -1125,11 +1142,13 @@ mod tests { results.clone(), vec![0, 10, 20], vec![], - ); + ) + .unwrap(); h.convergence(ITERATIONS, EPSILON, false); - h.add_events(composition, results, vec![15, 10, 0], vec![]); + h.add_events(composition, results, vec![15, 10, 0], vec![]) + .unwrap(); assert_eq!(h.time_slices.len(), 4); @@ -1213,11 +1232,13 @@ mod tests { .gamma(0.0) .build(); - h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![]); + h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![]) + .unwrap(); h.convergence(ITERATIONS, EPSILON, false); - h.add_events(composition, vec![], vec![15, 10, 0], vec![]); + h.add_events(composition, vec![], vec![15, 10, 0], vec![]) + .unwrap(); assert_eq!(h.time_slices.len(), 4); @@ -1306,7 +1327,7 @@ mod tests { let n = composition.len(); let times: Vec = (1..=n as i64).collect(); - h.add_events(composition, vec![], times, weights); + h.add_events(composition, vec![], times, weights).unwrap(); let lc = h.learning_curves(); @@ -1367,7 +1388,7 @@ mod tests { epsilon: 1e-6, }) .build(); - h.add_events(composition, results, times, vec![]); + h.add_events(composition, results, times, vec![]).unwrap(); let report = h.converge().unwrap(); assert!(report.converged);