T0 + T1 + T2: engine redesign through new API surface #1

Merged
logaritmisk merged 45 commits from t2-new-api-surface into main 2026-04-24 11:20:04 +00:00
3 changed files with 88 additions and 33 deletions
Showing only changes of commit a83c9acacb - Show all commits

View File

@@ -40,7 +40,8 @@ fn main() {
let mut hist = History::builder().sigma(1.6).gamma(0.036).build();
hist.add_events(composition, results, times, vec![]);
hist.add_events(composition, results, times, vec![])
.unwrap();
hist.convergence(10, 0.01, true);
let players = [

View File

@@ -2,12 +2,45 @@ use std::fmt;
#[derive(Debug, Clone, PartialEq)]
pub enum InferenceError {
/// Expected and actual lengths of some array-shaped input differ.
MismatchedShape {
kind: &'static str,
expected: usize,
got: usize,
},
/// A probability value is outside `[0, 1]`.
InvalidProbability { value: f64 },
/// Convergence exceeded `max_iter` without falling below `epsilon`.
ConvergenceFailed {
last_step: (f64, f64),
iterations: usize,
},
/// Negative precision: a Gaussian with `pi < 0` slipped into an API call.
NegativePrecision { pi: f64 },
}
impl fmt::Display for InferenceError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::MismatchedShape {
kind,
expected,
got,
} => {
write!(f, "{kind}: expected length {expected}, got {got}")
}
Self::InvalidProbability { value } => {
write!(f, "probability must be in [0, 1]; got {value}")
}
Self::ConvergenceFailed {
last_step,
iterations,
} => {
write!(
f,
"convergence failed after {iterations} iterations; last step = {last_step:?}"
)
}
Self::NegativePrecision { pi } => {
write!(f, "precision must be non-negative; got {pi}")
}

View File

@@ -305,7 +305,7 @@ impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
results: Vec<Vec<f64>>,
times: Vec<i64>,
weights: Vec<Vec<Vec<f64>>>,
) {
) -> Result<(), InferenceError> {
self.add_events_with_prior(composition, results, times, weights, HashMap::new())
}
@@ -316,19 +316,28 @@ impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
times: Vec<i64>,
weights: Vec<Vec<Vec<f64>>>,
mut priors: HashMap<Index, Rating<i64, D>>,
) {
assert!(
results.is_empty() || results.len() == composition.len(),
"(length(results) > 0) & (length(composition) != length(results))"
);
assert!(
times.len() == composition.len(),
"length(times) must equal length(composition)"
);
assert!(
weights.is_empty() || weights.len() == composition.len(),
"(length(weights) > 0) & (length(composition) != length(weights))"
);
) -> Result<(), InferenceError> {
if !results.is_empty() && results.len() != composition.len() {
return Err(InferenceError::MismatchedShape {
kind: "results",
expected: composition.len(),
got: results.len(),
});
}
if times.len() != composition.len() {
return Err(InferenceError::MismatchedShape {
kind: "times",
expected: composition.len(),
got: times.len(),
});
}
if !weights.is_empty() && weights.len() != composition.len() {
return Err(InferenceError::MismatchedShape {
kind: "weights",
expected: composition.len(),
got: weights.len(),
});
}
competitor::clean(self.agents.values_mut(), true);
@@ -467,6 +476,7 @@ impl<D: Drift<i64>, O: Observer<i64>> History<i64, D, O> {
}
self.size += n;
Ok(())
}
}
@@ -510,7 +520,8 @@ mod tests {
let mut h = History::default();
h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors);
h.add_events_with_prior(composition, results, vec![1, 2, 3], vec![], priors)
.unwrap();
let p0 = h.time_slices[0].posteriors();
@@ -586,7 +597,8 @@ mod tests {
let mut h1 = History::default();
h1.add_events_with_prior(composition, results, times, vec![], priors);
h1.add_events_with_prior(composition, results, times, vec![], priors)
.unwrap();
assert_ulps_eq!(
h1.time_slices[0].skills.get(a).unwrap().posterior(),
@@ -635,7 +647,8 @@ mod tests {
let mut h2 = History::default();
h2.add_events_with_prior(composition, results, times, vec![], priors);
h2.add_events_with_prior(composition, results, times, vec![], priors)
.unwrap();
assert_ulps_eq!(
h2.time_slices[2].skills.get(a).unwrap().posterior(),
@@ -693,7 +706,8 @@ mod tests {
let mut h = History::default();
h.add_events_with_prior(composition, results, times, vec![], priors);
h.add_events_with_prior(composition, results, times, vec![], priors)
.unwrap();
h.convergence(ITERATIONS, EPSILON, false);
let lc = h.learning_curves();
@@ -740,7 +754,7 @@ mod tests {
let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition, results, times, vec![]);
h.add_events(composition, results, times, vec![]).unwrap();
h.convergence(ITERATIONS, EPSILON, false);
@@ -791,7 +805,7 @@ mod tests {
let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition, results, times, vec![]);
h.add_events(composition, results, times, vec![]).unwrap();
let trueskill_log_evidence = h.log_evidence(false, &[]);
let trueskill_log_evidence_online = h.log_evidence(true, &[]);
@@ -879,7 +893,8 @@ mod tests {
let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition.clone(), results.clone(), times, vec![]);
h.add_events(composition.clone(), results.clone(), times, vec![])
.unwrap();
h.convergence(ITERATIONS, EPSILON, false);
@@ -903,7 +918,7 @@ mod tests {
);
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
h.add_events(composition, results, times2, vec![]);
h.add_events(composition, results, times2, vec![]).unwrap();
assert_eq!(h.time_slices.len(), 6);
@@ -970,7 +985,8 @@ mod tests {
let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition.clone(), results.clone(), times, vec![]);
h.add_events(composition.clone(), results.clone(), times, vec![])
.unwrap();
h.convergence(ITERATIONS, EPSILON, false);
@@ -994,7 +1010,7 @@ mod tests {
);
let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
h.add_events(composition, results, times2, vec![]);
h.add_events(composition, results, times2, vec![]).unwrap();
assert_eq!(h.time_slices.len(), 6);
@@ -1050,7 +1066,8 @@ mod tests {
let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition.clone(), vec![], times.clone(), vec![]);
h.add_events(composition.clone(), vec![], times.clone(), vec![])
.unwrap();
let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0;
@@ -1089,7 +1106,7 @@ mod tests {
let mut h = History::builder().build();
h.add_events(composition, vec![], times, vec![]);
h.add_events(composition, vec![], times, vec![]).unwrap();
assert_ulps_eq!(
((0.5f64 * 0.1765).ln() / 2.0).exp(),
@@ -1125,11 +1142,13 @@ mod tests {
results.clone(),
vec![0, 10, 20],
vec![],
);
)
.unwrap();
h.convergence(ITERATIONS, EPSILON, false);
h.add_events(composition, results, vec![15, 10, 0], vec![]);
h.add_events(composition, results, vec![15, 10, 0], vec![])
.unwrap();
assert_eq!(h.time_slices.len(), 4);
@@ -1213,11 +1232,13 @@ mod tests {
.gamma(0.0)
.build();
h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![]);
h.add_events(composition.clone(), vec![], vec![0, 10, 20], vec![])
.unwrap();
h.convergence(ITERATIONS, EPSILON, false);
h.add_events(composition, vec![], vec![15, 10, 0], vec![]);
h.add_events(composition, vec![], vec![15, 10, 0], vec![])
.unwrap();
assert_eq!(h.time_slices.len(), 4);
@@ -1306,7 +1327,7 @@ mod tests {
let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition, vec![], times, weights);
h.add_events(composition, vec![], times, weights).unwrap();
let lc = h.learning_curves();
@@ -1367,7 +1388,7 @@ mod tests {
epsilon: 1e-6,
})
.build();
h.add_events(composition, results, times, vec![]);
h.add_events(composition, results, times, vec![]).unwrap();
let report = h.converge().unwrap();
assert!(report.converged);