T0 + T1 + T2: engine redesign through new API surface #1

Merged
logaritmisk merged 45 commits from t2-new-api-surface into main 2026-04-24 11:20:04 +00:00
2 changed files with 44 additions and 76 deletions
Showing only changes of commit 33a7d90b89 - Show all commits

View File

@@ -15,7 +15,6 @@ use crate::{
#[derive(Clone)] #[derive(Clone)]
pub struct HistoryBuilder<T: Time = i64, D: Drift<T> = ConstantDrift> { pub struct HistoryBuilder<T: Time = i64, D: Drift<T> = ConstantDrift> {
time: bool,
mu: f64, mu: f64,
sigma: f64, sigma: f64,
beta: f64, beta: f64,
@@ -26,11 +25,6 @@ pub struct HistoryBuilder<T: Time = i64, D: Drift<T> = ConstantDrift> {
} }
impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> { impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
pub fn time(mut self, time: bool) -> Self {
self.time = time;
self
}
pub fn mu(mut self, mu: f64) -> Self { pub fn mu(mut self, mu: f64) -> Self {
self.mu = mu; self.mu = mu;
self self
@@ -49,7 +43,6 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
pub fn drift<D2: Drift<T>>(self, drift: D2) -> HistoryBuilder<T, D2> { pub fn drift<D2: Drift<T>>(self, drift: D2) -> HistoryBuilder<T, D2> {
HistoryBuilder { HistoryBuilder {
drift, drift,
time: self.time,
mu: self.mu, mu: self.mu,
sigma: self.sigma, sigma: self.sigma,
beta: self.beta, beta: self.beta,
@@ -74,7 +67,6 @@ impl<T: Time, D: Drift<T>> HistoryBuilder<T, D> {
size: 0, size: 0,
time_slices: Vec::new(), time_slices: Vec::new(),
agents: CompetitorStore::new(), agents: CompetitorStore::new(),
time: self.time,
mu: self.mu, mu: self.mu,
sigma: self.sigma, sigma: self.sigma,
beta: self.beta, beta: self.beta,
@@ -95,7 +87,6 @@ impl HistoryBuilder<i64, ConstantDrift> {
impl Default for HistoryBuilder<i64, ConstantDrift> { impl Default for HistoryBuilder<i64, ConstantDrift> {
fn default() -> Self { fn default() -> Self {
Self { Self {
time: true,
mu: MU, mu: MU,
sigma: SIGMA, sigma: SIGMA,
beta: BETA, beta: BETA,
@@ -111,7 +102,6 @@ pub struct History<T: Time = i64, D: Drift<T> = ConstantDrift> {
size: usize, size: usize,
pub(crate) time_slices: Vec<TimeSlice<T>>, pub(crate) time_slices: Vec<TimeSlice<T>>,
pub(crate) agents: CompetitorStore<T, D>, pub(crate) agents: CompetitorStore<T, D>,
time: bool,
mu: f64, mu: f64,
sigma: f64, sigma: f64,
beta: f64, beta: f64,
@@ -126,7 +116,6 @@ impl Default for History<i64, ConstantDrift> {
size: 0, size: 0,
time_slices: Vec::new(), time_slices: Vec::new(),
agents: CompetitorStore::new(), agents: CompetitorStore::new(),
time: true,
mu: MU, mu: MU,
sigma: SIGMA, sigma: SIGMA,
beta: BETA, beta: BETA,
@@ -275,18 +264,13 @@ impl<D: Drift<i64>> History<i64, D> {
weights: Vec<Vec<Vec<f64>>>, weights: Vec<Vec<Vec<f64>>>,
mut priors: HashMap<Index, Rating<i64, D>>, mut priors: HashMap<Index, Rating<i64, D>>,
) { ) {
assert!(times.is_empty() || self.time, "length(times)>0 but !h.time");
assert!(
!times.is_empty() || !self.time,
"length(times)==0 but h.time"
);
assert!( assert!(
results.is_empty() || results.len() == composition.len(), results.is_empty() || results.len() == composition.len(),
"(length(results) > 0) & (length(composition) != length(results))" "(length(results) > 0) & (length(composition) != length(results))"
); );
assert!( assert!(
times.is_empty() || times.len() == composition.len(), times.len() == composition.len(),
"length(times) > 0) & (length(composition) != length(times))" "length(times) must equal length(composition)"
); );
assert!( assert!(
weights.is_empty() || weights.len() == composition.len(), weights.is_empty() || weights.len() == composition.len(),
@@ -323,26 +307,20 @@ impl<D: Drift<i64>> History<i64, D> {
} }
let n = composition.len(); let n = composition.len();
let o = if self.time { let o = sort_time(&times, false);
sort_time(&times, false)
} else {
(0..composition.len()).collect::<Vec<_>>()
};
let mut i = 0; let mut i = 0;
let mut k = 0; let mut k = 0;
while i < n { while i < n {
let mut j = i + 1; let mut j = i + 1;
let t = if self.time { times[o[i]] } else { i as i64 + 1 }; let t = times[o[i]];
while self.time && j < n && times[o[j]] == t { while j < n && times[o[j]] == t {
j += 1; j += 1;
} }
while (!self.time && (self.size > k)) while self.time_slices.len() > k && self.time_slices[k].time < t {
|| (self.time && self.time_slices.len() > k && self.time_slices[k].time < t)
{
let time_slice = &mut self.time_slices[k]; let time_slice = &mut self.time_slices[k];
if k > 0 { if k > 0 {
@@ -363,16 +341,6 @@ impl<D: Drift<i64>> History<i64, D> {
} }
} }
if !self.time {
let slice_time = time_slice.time;
for agent_idx in &this_agent {
let c = self.agents.get_mut(*agent_idx).unwrap();
if c.last_time.is_some() {
c.last_time = Some(slice_time);
}
}
}
k += 1; k += 1;
} }
@@ -392,7 +360,7 @@ impl<D: Drift<i64>> History<i64, D> {
(i..j).map(|e| weights[o[e]].clone()).collect::<Vec<_>>() (i..j).map(|e| weights[o[e]].clone()).collect::<Vec<_>>()
}; };
if self.time && self.time_slices.len() > k && self.time_slices[k].time == t { if self.time_slices.len() > k && self.time_slices[k].time == t {
let time_slice = &mut self.time_slices[k]; let time_slice = &mut self.time_slices[k];
time_slice.add_events(composition, results, weights, &self.agents); time_slice.add_events(composition, results, weights, &self.agents);
@@ -417,22 +385,13 @@ impl<D: Drift<i64>> History<i64, D> {
agent.message = time_slice.forward_prior_out(&agent_idx); agent.message = time_slice.forward_prior_out(&agent_idx);
} }
if !self.time {
for agent_idx in &this_agent {
let c = self.agents.get_mut(*agent_idx).unwrap();
if c.last_time.is_some() {
c.last_time = Some(t);
}
}
}
k += 1; k += 1;
} }
i = j; i = j;
} }
while self.time && self.time_slices.len() > k { while self.time_slices.len() > k {
let time_slice = &mut self.time_slices[k]; let time_slice = &mut self.time_slices[k];
time_slice.new_forward_info(&self.agents); time_slice.new_forward_info(&self.agents);
@@ -724,29 +683,30 @@ mod tests {
.sigma(25.0 / 3.0) .sigma(25.0 / 3.0)
.beta(25.0 / 6.0) .beta(25.0 / 6.0)
.gamma(25.0 / 300.0) .gamma(25.0 / 300.0)
.time(false)
.build(); .build();
h.add_events(composition, results, vec![], vec![]); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition, results, times, vec![]);
h.convergence(ITERATIONS, EPSILON, false); h.convergence(ITERATIONS, EPSILON, false);
assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 1); assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 2);
assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1); assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);
assert_ulps_eq!( assert_ulps_eq!(
h.time_slices[0].skills.get(a).unwrap().posterior(), h.time_slices[0].skills.get(a).unwrap().posterior(),
Gaussian::from_ms(25.000267, 5.419381), Gaussian::from_ms(25.000267, 5.419423),
epsilon = 1e-6 epsilon = 1e-6
); );
assert_ulps_eq!( assert_ulps_eq!(
h.time_slices[0].skills.get(b).unwrap().posterior(), h.time_slices[0].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(24.999465, 5.419425), Gaussian::from_ms(24.999198, 5.419512),
epsilon = 1e-6 epsilon = 1e-6
); );
assert_ulps_eq!( assert_ulps_eq!(
h.time_slices[2].skills.get(b).unwrap().posterior(), h.time_slices[2].skills.get(b).unwrap().posterior(),
Gaussian::from_ms(25.000532, 5.419696), Gaussian::from_ms(25.001332, 5.420054),
epsilon = 1e-6 epsilon = 1e-6
); );
} }
@@ -774,10 +734,11 @@ mod tests {
.sigma(6.0) .sigma(6.0)
.beta(1.0) .beta(1.0)
.gamma(0.0) .gamma(0.0)
.time(false)
.build(); .build();
h.add_events(composition, results, vec![], vec![]); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition, results, times, vec![]);
let trueskill_log_evidence = h.log_evidence(false, &[]); let trueskill_log_evidence = h.log_evidence(false, &[]);
let trueskill_log_evidence_online = h.log_evidence(true, &[]); let trueskill_log_evidence_online = h.log_evidence(true, &[]);
@@ -861,14 +822,15 @@ mod tests {
.sigma(2.0) .sigma(2.0)
.beta(1.0) .beta(1.0)
.gamma(0.0) .gamma(0.0)
.time(false)
.build(); .build();
h.add_events(composition.clone(), results.clone(), vec![], vec![]); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition.clone(), results.clone(), times, vec![]);
h.convergence(ITERATIONS, EPSILON, false); h.convergence(ITERATIONS, EPSILON, false);
assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 1); assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 2);
assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1); assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);
assert_ulps_eq!( assert_ulps_eq!(
@@ -887,7 +849,8 @@ mod tests {
epsilon = 1e-6 epsilon = 1e-6
); );
h.add_events(composition, results, vec![], vec![]); let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
h.add_events(composition, results, times2, vec![]);
assert_eq!(h.time_slices.len(), 6); assert_eq!(h.time_slices.len(), 6);
@@ -950,14 +913,15 @@ mod tests {
.sigma(2.0) .sigma(2.0)
.beta(1.0) .beta(1.0)
.gamma(0.0) .gamma(0.0)
.time(false)
.build(); .build();
h.add_events(composition.clone(), results.clone(), vec![], vec![]); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition.clone(), results.clone(), times, vec![]);
h.convergence(ITERATIONS, EPSILON, false); h.convergence(ITERATIONS, EPSILON, false);
assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 1); assert_eq!(h.time_slices[2].skills.get(b).unwrap().elapsed, 2);
assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1); assert_eq!(h.time_slices[2].skills.get(c).unwrap().elapsed, 1);
assert_ulps_eq!( assert_ulps_eq!(
@@ -976,7 +940,8 @@ mod tests {
epsilon = 1e-6 epsilon = 1e-6
); );
h.add_events(composition, results, vec![], vec![]); let times2: Vec<i64> = (n as i64 + 1..=2 * n as i64).collect();
h.add_events(composition, results, times2, vec![]);
assert_eq!(h.time_slices.len(), 6); assert_eq!(h.time_slices.len(), 6);
@@ -1028,9 +993,11 @@ mod tests {
let composition = vec![vec![vec![a], vec![b]], vec![vec![b], vec![a]]]; let composition = vec![vec![vec![a], vec![b]], vec![vec![b], vec![a]]];
let mut h = History::builder().time(false).build(); let mut h = History::builder().build();
h.add_events(composition.clone(), vec![], vec![], vec![]); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition.clone(), vec![], times.clone(), vec![]);
let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0; let p_d_m_2 = h.log_evidence(false, &[]).exp() * 2.0;
@@ -1067,9 +1034,9 @@ mod tests {
epsilon = 1e-4 epsilon = 1e-4
); );
let mut h = History::builder().time(false).build(); let mut h = History::builder().build();
h.add_events(composition, vec![], vec![], vec![]); h.add_events(composition, vec![], times, vec![]);
assert_ulps_eq!( assert_ulps_eq!(
((0.5f64 * 0.1765).ln() / 2.0).exp(), ((0.5f64 * 0.1765).ln() / 2.0).exp(),
@@ -1282,10 +1249,11 @@ mod tests {
.sigma(6.0) .sigma(6.0)
.beta(1.0) .beta(1.0)
.gamma(0.0) .gamma(0.0)
.time(false)
.build(); .build();
h.add_events(composition, vec![], vec![], weights); let n = composition.len();
let times: Vec<i64> = (1..=n as i64).collect();
h.add_events(composition, vec![], times, weights);
let lc = h.learning_curves(); let lc = h.learning_curves();

View File

@@ -172,13 +172,13 @@ pub(crate) fn tuple_gt(t: (f64, f64), e: f64) -> bool {
t.0 > e || t.1 > e t.0 > e || t.1 > e
} }
pub(crate) fn sort_time(xs: &[i64], reverse: bool) -> Vec<usize> { pub(crate) fn sort_time<T: Copy + Ord>(xs: &[T], reverse: bool) -> Vec<usize> {
let mut x = xs.iter().enumerate().collect::<Vec<_>>(); let mut x: Vec<(usize, T)> = xs.iter().enumerate().map(|(i, &t)| (i, t)).collect();
if reverse { if reverse {
x.sort_by_key(|&(_, x)| Reverse(x)); x.sort_by_key(|&(_, t)| Reverse(t));
} else { } else {
x.sort_by_key(|&(_, x)| x); x.sort_by_key(|&(_, t)| t);
} }
x.into_iter().map(|(i, _)| i).collect() x.into_iter().map(|(i, _)| i).collect()
@@ -254,7 +254,7 @@ mod tests {
#[test] #[test]
fn test_sort_time() { fn test_sort_time() {
assert_eq!(sort_time(&[0, 1, 2, 0], true), vec![2, 1, 0, 3]); assert_eq!(sort_time(&[0i64, 1, 2, 0], true), vec![2, 1, 0, 3]);
} }
#[test] #[test]