T0 + T1 + T2: engine redesign through new API surface #1
@@ -40,7 +40,8 @@ fn main() {
|
|||||||
|
|
||||||
let mut hist = History::builder().sigma(1.6).gamma(0.036).build();
|
let mut hist = History::builder().sigma(1.6).gamma(0.036).build();
|
||||||
|
|
||||||
hist.add_events(composition, results, times, vec![]);
|
hist.add_events(composition, results, times, vec![])
|
||||||
|
.unwrap();
|
||||||
hist.convergence(10, 0.01, true);
|
hist.convergence(10, 0.01, true);
|
||||||
|
|
||||||
let players = [
|
let players = [
|
||||||
|
|||||||
33
src/error.rs
33
src/error.rs
@@ -2,12 +2,45 @@ use std::fmt;
|
|||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub enum InferenceError {
|
pub enum InferenceError {
|
||||||
|
/// Expected and actual lengths of some array-shaped input differ.
|
||||||
|
MismatchedShape {
|
||||||
|
kind: &'static str,
|
||||||
|
expected: usize,
|
||||||
|
got: usize,
|
||||||
|
},
|
||||||
|
/// A probability value is outside `[0, 1]`.
|
||||||
|
InvalidProbability { value: f64 },
|
||||||
|
/// Convergence exceeded `max_iter` without falling below `epsilon`.
|
||||||
|
ConvergenceFailed {
|
||||||
|
last_step: (f64, f64),
|
||||||
|
iterations: usize,
|
||||||
|
},
|
||||||
|
/// Negative precision: a Gaussian with `pi < 0` slipped into an API call.
|
||||||
NegativePrecision { pi: f64 },
|
NegativePrecision { pi: f64 },
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for InferenceError {
|
impl fmt::Display for InferenceError {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
|
Self::MismatchedShape {
|
||||||
|
kind,
|
||||||
|
expected,
|
||||||
|
got,
|
||||||
|
} => {
|
||||||
|
write!(f, "{kind}: expected length {expected}, got {got}")
|
||||||
|
}
|
||||||
|
Self::InvalidProbability { value } => {
|
||||||
|
write!(f, "probability must be in [0, 1]; got {value}")
|
||||||
|
}
|
||||||
|
Self::ConvergenceFailed {
|
||||||
|
last_step,
|
||||||
|
iterations,
|
||||||
|
} => {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"convergence failed after {iterations} iterations; last step = {last_step:?}"
|
||||||
|
)
|
||||||
|
}
|
||||||
Self::NegativePrecision { pi } => {
|
Self::NegativePrecision { pi } => {
|
||||||
write!(f, "precision must be non-negative; got {pi}")
|
write!(f, "precision must be non-negative; got {pi}")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -305,7 +305,7 @@ impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
|
|||||||
results: Vec<Vec<f64>>,
|
results: Vec<Vec<f64>>,
|
||||||
times: Vec<i64>,
|
times: Vec<i64>,
|
||||||
weights: Vec<Vec<Vec<f64>>>,
|
weights: Vec<Vec<Vec<f64>>>,
|
||||||
) {
|
) -> Result<(), InferenceError> {
|
||||||
self.add_events_with_prior(composition, results, times, weights, HashMap::new())
|
self.add_events_with_prior(composition, results, times, weights, HashMap::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -316,19 +316,28 @@ impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
|
|||||||
times: Vec<i64>,
|
times: Vec<i64>,
|
||||||
weights: Vec<Vec<Vec<f64>>>,
|
weights: Vec<Vec<Vec<f64>>>,
|
||||||
mut priors: HashMap<Index, Rating<i64, D>>,
|
mut priors: HashMap<Index, Rating<i64, D>>,
|
||||||
) {
|
) -> Result<(), InferenceError> {
|
||||||
assert!(
|
if !results.is_empty() && results.len() != composition.len() {
|
||||||
results.is_empty() || results.len() == composition.len(),
|
return Err(InferenceError::MismatchedShape {
|
||||||
"(length(results) > 0) & (length(composition) != length(results))"
|
kind: "results",
|
||||||
);
|
expected: composition.len(),
|
||||||
assert!(
|
got: results.len(),
|
||||||
times.len() == composition.len(),
|
});
|
||||||
"length(times) must equal length(composition)"
|
}
|
||||||
);
|
if times.len() != composition.len() {
|
||||||
assert!(
|
return Err(InferenceError::MismatchedShape {
|
||||||
weights.is_empty() || weights.len() == composition.len(),
|
kind: "times",
|
||||||
"(length(weights) > 0) & (length(composition) != length(weights))"
|
expected: composition.len(),
|
||||||
);
|
got: times.len(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if !weights.is_empty() && weights.len() != composition.len() {
|
||||||
|
return Err(InferenceError::MismatchedShape {
|
||||||
|
kind: "weights",
|
||||||
|
expected: composition.len(),
|
||||||
|
got: weights.len(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
competitor::clean(self.agents.values_mut(), true);
|
competitor::clean(self.agents.values_mut(), true);
|
||||||
|
|
||||||
@@ -467,6 +476,7 @@ impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
self.size += n;
|
self.size += n;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -510,7 +520,8 @@ mod tests {
|
|||||||
|
|
||||||
let mut h = History::default();
|
let mut h = History::default();
|
||||||
|
|
||||||
h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors);
|
h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let p0 = h.time_slices[0].posteriors();
|
let p0 = h.time_slices[0].posteriors();
|
||||||
|
|
||||||
@@ -586,7 +597,8 @@ mod tests {
|
|||||||
|
|
||||||
let mut h1 = History::default();
|
let mut h1 = History::default();
|
||||||
|
|
||||||
h1.add_events_with_prior(composition, results, times, vec![], priors);
|
h1.add_events_with_prior(composition, results, times, vec![], priors)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h1.time_slices[0].skills.get(a).unwrap().posterior(),
|
h1.time_slices[0].skills.get(a).unwrap().posterior(),
|
||||||
@@ -635,7 +647,8 @@ mod tests {
|
|||||||
|
|
||||||
let mut h2 = History::default();
|
let mut h2 = History::default();
|
||||||
|
|
||||||
h2.add_events_with_prior(composition, results, times, vec![], priors);
|
h2.add_events_with_prior(composition, results, times, vec![], priors)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
h2.time_slices[2].skills.get(a).unwrap().posterior(),
|
h2.time_slices[2].skills.get(a).unwrap().posterior(),
|
||||||
@@ -693,7 +706,8 @@ mod tests {
|
|||||||
|
|
||||||
let mut h = History::default();
|
let mut h = History::default();
|
||||||
|
|
||||||
h.add_events_with_prior(composition, results, times, vec![], priors);
|
h.add_events_with_prior(composition, results, times, vec![], priors)
|
||||||
|
.unwrap();
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
let lc = h.learning_curves();
|
let lc = h.learning_curves();
|
||||||
@@ -740,7 +754,7 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition, results, times, vec![]);
|
h.add_events(composition, results, times, vec![]).unwrap();
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
@@ -791,7 +805,7 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition, results, times, vec![]);
|
h.add_events(composition, results, times, vec![]).unwrap();
|
||||||
|
|
||||||
let trueskill_log_evidence = h.log_evidence(false, &[]);
|
let trueskill_log_evidence = h.log_evidence(false, &[]);
|
||||||
let trueskill_log_evidence_online = h.log_evidence(true, &[]);
|
let trueskill_log_evidence_online = h.log_evidence(true, &[]);
|
||||||
@@ -879,7 +893,8 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition.clone(), results.clone(), times, vec![]);
|
h.add_events(composition.clone(), results.clone(), times, vec![])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
@@ -903,7 +918,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
||||||
h.add_events(composition, results, times2, vec![]);
|
h.add_events(composition, results, times2, vec![]).unwrap();
|
||||||
|
|
||||||
assert_eq!(h.time_slices.len(), 6);
|
assert_eq!(h.time_slices.len(), 6);
|
||||||
|
|
||||||
@@ -970,7 +985,8 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition.clone(), results.clone(), times, vec![]);
|
h.add_events(composition.clone(), results.clone(), times, vec![])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
@@ -994,7 +1010,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
|
||||||
h.add_events(composition, results, times2, vec![]);
|
h.add_events(composition, results, times2, vec![]).unwrap();
|
||||||
|
|
||||||
assert_eq!(h.time_slices.len(), 6);
|
assert_eq!(h.time_slices.len(), 6);
|
||||||
|
|
||||||
@@ -1050,7 +1066,8 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition.clone(), vec![], times.clone(), vec![]);
|
h.add_events(composition.clone(), vec![], times.clone(), vec![])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0;
|
let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0;
|
||||||
|
|
||||||
@@ -1089,7 +1106,7 @@ mod tests {
|
|||||||
|
|
||||||
let mut h = History::builder().build();
|
let mut h = History::builder().build();
|
||||||
|
|
||||||
h.add_events(composition, vec![], times, vec![]);
|
h.add_events(composition, vec![], times, vec![]).unwrap();
|
||||||
|
|
||||||
assert_ulps_eq!(
|
assert_ulps_eq!(
|
||||||
((0.5f64 * 0.1765).ln() / 2.0).exp(),
|
((0.5f64 * 0.1765).ln() / 2.0).exp(),
|
||||||
@@ -1125,11 +1142,13 @@ mod tests {
|
|||||||
results.clone(),
|
results.clone(),
|
||||||
vec![0, 10, 20],
|
vec![0, 10, 20],
|
||||||
vec![],
|
vec![],
|
||||||
);
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
h.add_events(composition, results, vec![15, 10, 0], vec![]);
|
h.add_events(composition, results, vec![15, 10, 0], vec![])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(h.time_slices.len(), 4);
|
assert_eq!(h.time_slices.len(), 4);
|
||||||
|
|
||||||
@@ -1213,11 +1232,13 @@ mod tests {
|
|||||||
.gamma(0.0)
|
.gamma(0.0)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![]);
|
h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
h.convergence(ITERATIONS, EPSILON, false);
|
h.convergence(ITERATIONS, EPSILON, false);
|
||||||
|
|
||||||
h.add_events(composition, vec![], vec![15, 10, 0], vec![]);
|
h.add_events(composition, vec![], vec![15, 10, 0], vec![])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(h.time_slices.len(), 4);
|
assert_eq!(h.time_slices.len(), 4);
|
||||||
|
|
||||||
@@ -1306,7 +1327,7 @@ mod tests {
|
|||||||
|
|
||||||
let n = composition.len();
|
let n = composition.len();
|
||||||
let times: Vec<i64> = (1..=n as i64).collect();
|
let times: Vec<i64> = (1..=n as i64).collect();
|
||||||
h.add_events(composition, vec![], times, weights);
|
h.add_events(composition, vec![], times, weights).unwrap();
|
||||||
|
|
||||||
let lc = h.learning_curves();
|
let lc = h.learning_curves();
|
||||||
|
|
||||||
@@ -1367,7 +1388,7 @@ mod tests {
|
|||||||
epsilon: 1e-6,
|
epsilon: 1e-6,
|
||||||
})
|
})
|
||||||
.build();
|
.build();
|
||||||
h.add_events(composition, results, times, vec![]);
|
h.add_events(composition, results, times, vec![]).unwrap();
|
||||||
|
|
||||||
let report = h.converge().unwrap();
|
let report = h.converge().unwrap();
|
||||||
assert!(report.converged);
|
assert!(report.converged);
|
||||||
|
|||||||
Reference in New Issue
Block a user