T0 + T1 + T2: engine redesign through new API surface #1
@@ -40,7 +40,8 @@ fn main() {
|
||||
|
||||
let mut hist = History::builder().sigma(1.6).gamma(0.036).build();
|
||||
|
||||
hist.add_events(composition, results, times, vec![]);
|
||||
hist.add_events(composition, results, times, vec![])
|
||||
.unwrap();
|
||||
hist.convergence(10, 0.01, true);
|
||||
|
||||
let players = [
|
||||
|
||||
33
src/error.rs
33
src/error.rs
@@ -2,12 +2,45 @@ use std::fmt;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum InferenceError {
|
||||
/// Expected and actual lengths of some array-shaped input differ.
|
||||
MismatchedShape {
|
||||
kind: &'static str,
|
||||
expected: usize,
|
||||
got: usize,
|
||||
},
|
||||
/// A probability value is outside `[0, 1]`.
|
||||
InvalidProbability { value: f64 },
|
||||
/// Convergence exceeded `max_iter` without falling below `epsilon`.
|
||||
ConvergenceFailed {
|
||||
last_step: (f64, f64),
|
||||
iterations: usize,
|
||||
},
|
||||
/// Negative precision: a Gaussian with `pi < 0` slipped into an API call.
|
||||
NegativePrecision { pi: f64 },
|
||||
}
|
||||
|
||||
impl fmt::Display for InferenceError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::MismatchedShape {
|
||||
kind,
|
||||
expected,
|
||||
got,
|
||||
} => {
|
||||
write!(f, "{kind}: expected length {expected}, got {got}")
|
||||
}
|
||||
Self::InvalidProbability { value } => {
|
||||
write!(f, "probability must be in [0, 1]; got {value}")
|
||||
}
|
||||
Self::ConvergenceFailed {
|
||||
last_step,
|
||||
iterations,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"convergence failed after {iterations} iterations; last step = {last_step:?}"
|
||||
)
|
||||
}
|
||||
Self::NegativePrecision { pi } => {
|
||||
write!(f, "precision must be non-negative; got {pi}")
|
||||
}
|
||||
|
||||
@@ -305,7 +305,7 @@ impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
|
||||
results: Vec<Vec<f64>>,
|
||||
times: Vec<i64>,
|
||||
weights: Vec<Vec<Vec<f64>>>,
|
||||
) {
|
||||
) -> Result<(), InferenceError> {
|
||||
self.add_events_with_prior(composition, results, times, weights, HashMap::new())
|
||||
}
|
||||
|
||||
@@ -316,19 +316,28 @@ impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
|
||||
times: Vec<i64>,
|
||||
weights: Vec<Vec<Vec<f64>>>,
|
||||
mut priors: HashMap<Index, Rating<i64, D>>,
|
||||
) {
|
||||
assert!(
|
||||
results.is_empty() || results.len() == composition.len(),
|
||||
"(length(results) > 0) & (length(composition) != length(results))"
|
||||
);
|
||||
assert!(
|
||||
times.len() == composition.len(),
|
||||
"length(times) must equal length(composition)"
|
||||
);
|
||||
assert!(
|
||||
weights.is_empty() || weights.len() == composition.len(),
|
||||
"(length(weights) > 0) & (length(composition) != length(weights))"
|
||||
);
|
||||
) -> Result<(), InferenceError> {
|
||||
if !results.is_empty() && results.len() != composition.len() {
|
||||
return Err(InferenceError::MismatchedShape {
|
||||
kind: "results",
|
||||
expected: composition.len(),
|
||||
got: results.len(),
|
||||
});
|
||||
}
|
||||
if times.len() != composition.len() {
|
||||
return Err(InferenceError::MismatchedShape {
|
||||
kind: "times",
|
||||
expected: composition.len(),
|
||||
got: times.len(),
|
||||
});
|
||||
}
|
||||
if !weights.is_empty() && weights.len() != composition.len() {
|
||||
return Err(InferenceError::MismatchedShape {
|
||||
kind: "weights",
|
||||
expected: composition.len(),
|
||||
got: weights.len(),
|
||||
});
|
||||
}
|
||||
|
||||
competitor::clean(self.agents.values_mut(), true);
|
||||
|
||||
@@ -467,6 +476,7 @@ impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
|
||||
}
|
||||
|
||||
self.size += n;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,7 +520,8 @@ mod tests {
|
||||
|
||||
let mut h = History::default();
|
||||
|
||||
h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors);
|
||||
h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors)
|
||||
.unwrap();
|
||||
|
||||
let p0 = h.time_slices[0].posteriors();
|
||||
|
||||
@@ -586,7 +597,8 @@ mod tests {
|
||||
|
||||
let mut h1 = History::default();
|
||||
|
||||
h1.add_events_with_prior(composition, results, times, vec![], priors);
|
||||
h1.add_events_with_prior(composition, results, times, vec![], priors)
|
||||
.unwrap();
|
||||
|
||||
assert_ulps_eq!(
|
||||
h1.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||
@@ -635,7 +647,8 @@ mod tests {
|
||||
|
||||
let mut h2 = History::default();
|
||||
|
||||
h2.add_events_with_prior(composition, results, times, vec![], priors);
|
||||
h2.add_events_with_prior(composition, results, times, vec![], priors)
|
||||
.unwrap();
|
||||
|
||||
assert_ulps_eq!(
|
||||
h2.time_slices[2].skills.get(a).unwrap().posterior(),
|
||||
@@ -693,7 +706,8 @@ mod tests {
|
||||
|
||||
let mut h = History::default();
|
||||
|
||||
h.add_events_with_prior(composition, results, times, vec![], priors);
|
||||
h.add_events_with_prior(composition, results, times, vec![], priors)
|
||||
.unwrap();
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
let lc = h.learning_curves();
|
||||
@@ -740,7 +754,7 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition, results, times, vec![]);
|
||||
h.add_events(composition, results, times, vec![]).unwrap();
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
@@ -791,7 +805,7 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition, results, times, vec![]);
|
||||
h.add_events(composition, results, times, vec![]).unwrap();
|
||||
|
||||
let trueskill_log_evidence = h.log_evidence(false, &[]);
|
||||
let trueskill_log_evidence_online = h.log_evidence(true, &[]);
|
||||
@@ -879,7 +893,8 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition.clone(), results.clone(), times, vec![]);
|
||||
h.add_events(composition.clone(), results.clone(), times, vec![])
|
||||
.unwrap();
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
@@ -903,7 +918,7 @@ mod tests {
|
||||
);
|
||||
|
||||
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
||||
h.add_events(composition, results, times2, vec![]);
|
||||
h.add_events(composition, results, times2, vec![]).unwrap();
|
||||
|
||||
assert_eq!(h.time_slices.len(), 6);
|
||||
|
||||
@@ -970,7 +985,8 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition.clone(), results.clone(), times, vec![]);
|
||||
h.add_events(composition.clone(), results.clone(), times, vec![])
|
||||
.unwrap();
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
@@ -994,7 +1010,7 @@ mod tests {
|
||||
);
|
||||
|
||||
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
||||
h.add_events(composition, results, times2, vec![]);
|
||||
h.add_events(composition, results, times2, vec![]).unwrap();
|
||||
|
||||
assert_eq!(h.time_slices.len(), 6);
|
||||
|
||||
@@ -1050,7 +1066,8 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition.clone(), vec![], times.clone(), vec![]);
|
||||
h.add_events(composition.clone(), vec![], times.clone(), vec![])
|
||||
.unwrap();
|
||||
|
||||
let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0;
|
||||
|
||||
@@ -1089,7 +1106,7 @@ mod tests {
|
||||
|
||||
let mut h = History::builder().build();
|
||||
|
||||
h.add_events(composition, vec![], times, vec![]);
|
||||
h.add_events(composition, vec![], times, vec![]).unwrap();
|
||||
|
||||
assert_ulps_eq!(
|
||||
((0.5f64 * 0.1765).ln() / 2.0).exp(),
|
||||
@@ -1125,11 +1142,13 @@ mod tests {
|
||||
results.clone(),
|
||||
vec![0, 10, 20],
|
||||
vec![],
|
||||
);
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
h.add_events(composition, results, vec![15, 10, 0], vec![]);
|
||||
h.add_events(composition, results, vec![15, 10, 0], vec![])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(h.time_slices.len(), 4);
|
||||
|
||||
@@ -1213,11 +1232,13 @@ mod tests {
|
||||
.gamma(0.0)
|
||||
.build();
|
||||
|
||||
h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![]);
|
||||
h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![])
|
||||
.unwrap();
|
||||
|
||||
h.convergence(ITERATIONS, EPSILON, false);
|
||||
|
||||
h.add_events(composition, vec![], vec![15, 10, 0], vec![]);
|
||||
h.add_events(composition, vec![], vec![15, 10, 0], vec![])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(h.time_slices.len(), 4);
|
||||
|
||||
@@ -1306,7 +1327,7 @@ mod tests {
|
||||
|
||||
let n = composition.len();
|
||||
let times: Vec<i64> = (1..=n as i64).collect();
|
||||
h.add_events(composition, vec![], times, weights);
|
||||
h.add_events(composition, vec![], times, weights).unwrap();
|
||||
|
||||
let lc = h.learning_curves();
|
||||
|
||||
@@ -1367,7 +1388,7 @@ mod tests {
|
||||
epsilon: 1e-6,
|
||||
})
|
||||
.build();
|
||||
h.add_events(composition, results, times, vec![]);
|
||||
h.add_events(composition, results, times, vec![]).unwrap();
|
||||
|
||||
let report = h.converge().unwrap();
|
||||
assert!(report.converged);
|
||||
|
||||
Reference in New Issue
Block a user