diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..8a1fa1d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,45 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Commands + +```bash +cargo build # Build the library +cargo test --lib # Run all library tests +cargo test --lib <test_name> # Run a single test by name +cargo test --lib -- --nocapture # Run tests with stdout output +cargo clippy # Lint +cargo bench # Run benchmarks (criterion) +``` + +The `approx` feature enables `approx::AbsDiffEq` for `Gaussian`: +```bash +cargo test --features approx +``` + +## Architecture + +This is a Rust port of [TrueSkillThroughTime.py](https://github.com/glandfried/TrueSkillThroughTime.py) — a Bayesian skill rating system that tracks skill evolution over time using Gaussian message passing. + +### Data flow + +``` +History → Batch[] → Game[] → teams/players +``` + +- **`History`** (`history.rs`) — top-level container. Organizes games by time into `Batch`es, runs forward/backward message passing across batches, and exposes `learning_curves()` and `log_evidence()`. +- **`Batch`** (`batch.rs`) — all games at a single time step. Runs `iteration()` to update skill estimates via `Game::posteriors()`, collecting `Skill` distributions per player. +- **`Game`** (`game.rs`) — a single match. Given teams (vectors of `Player`), computes posterior skill distributions using Gaussian factor graphs and `message.rs` helpers. +- **`Agent`** (`agent.rs`) — wraps a `Player` with temporal state (`last_time`, `message`). `receive()` widens the stored message by the player's drift (`drift.variance_delta(elapsed)`) when the player reappears after a gap. +- **`Player`** (`player.rs`) — static configuration: prior `Gaussian`, `beta` (performance noise), `drift` (a `Drift` implementation; defaults to `ConstantDrift(GAMMA)`). +- **`Gaussian`** (`gaussian.rs`) — core probability type. Stored as `mu` and `sigma`; the natural parameters (`pi = 1/sigma²`, `tau = mu/sigma²`) follow from them. 
Arithmetic ops implement message multiplication/division in the factor graph. +- **`message.rs`** — `TeamMessage` and `DiffMessage`: intermediate factor graph messages used inside `Game`. +- **`lib.rs`** — exports the public API (`Game`, `Gaussian`, `History`, `Player`, `Drift`, `ConstantDrift`) and standalone functions (`quality()`, `pdf()`, `cdf()`, `erfc()`). Also defines global defaults: `MU=0.0`, `SIGMA=6.0`, `BETA=1.0`, `GAMMA=0.03`, `P_DRAW=0.0`, `EPSILON=1e-6`, `ITERATIONS=30`. + +### Key design points + +- `History` uses `IndexMap` (defined in `lib.rs`) to map arbitrary player keys to `Agent` state. +- Convergence is measured by the maximum `delta()` across all skill distributions; iteration stops when below `EPSILON` or after `ITERATIONS` rounds. +- The `approx` feature gates `AbsDiffEq` on `Gaussian` for use in tests — the feature is optional and only needed for approximate equality assertions. +- `time` in `History`/`Batch` is currently an `i64`; the README notes it needs to become an enum to support richer temporal states. diff --git a/NOTEPAD.md b/NOTEPAD.md index f0e91f2..7554c18 100644 --- a/NOTEPAD.md +++ b/NOTEPAD.md @@ -6,3 +6,11 @@ let mut history = History::new(); let agent_a = history.new_agent(); let agent_b = history.new_agent_with_prior(Prior::new(Gaussian::default(), BETA, GAMMA)); ``` + +```rust +trait Team { + fn players(&self) -> impl Iterator; + fn weights(&self) -> impl Iterator; + fn score(&self) -> u16; +} +``` diff --git a/README.md b/README.md index cfc412e..84a190e 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,66 @@ Rust port of [TrueSkillThroughTime.py](https://github.com/glandfried/TrueSkillTh - [TrueSkill Through Time: Revisiting the History of Chess](https://www.microsoft.com/en-us/research/wp-content/uploads/2008/01/NIPS2007_0931.pdf) - [TrueSkill Through Time. 
The full scientific documentation](https://glandfried.github.io/publication/landfried2021-learning/) +## Drift + +Skill drift models how a player's true skill can change between appearances. Each time a player reappears after a gap, their skill uncertainty is widened by the drift model before the new evidence is incorporated. + +Drift is represented by the `Drift` trait: + +```rust +pub trait Drift: Copy + Debug { + fn variance_delta(&self, elapsed: i64) -> f64; +} +``` + +`variance_delta` returns the amount to add to `σ²` given the elapsed time since the player last played. Internally, `Gaussian::forget` uses this to compute the new sigma: `σ_new = sqrt(σ² + variance_delta)`. + +### ConstantDrift + +The built-in `ConstantDrift` models a Gaussian random walk, in which the skill variance grows in proportion to elapsed time: + +``` +variance_delta = elapsed * γ² +``` + +This is the standard TrueSkill Through Time model. Use it by passing a `ConstantDrift(gamma)` when constructing a `Player`: + +```rust +use trueskill_tt::{Player, Gaussian, drift::ConstantDrift}; + +// gamma = 0.1 adds 0.1² to the skill variance per time unit +let player = Player::new(Gaussian::from_ms(0.0, 6.0), 1.0, ConstantDrift(0.1)); +``` + +### Custom drift + +Implement `Drift` to express any other model. 
For example, a drift whose added variance grows with the square root of elapsed time instead of linearly, so uncertainty widens more slowly over a long absence: + +```rust +use trueskill_tt::{Gaussian, Player, drift::Drift}; + +#[derive(Clone, Copy, Debug)] +struct SqrtDrift { + gamma: f64, +} + +impl Drift for SqrtDrift { + fn variance_delta(&self, elapsed: i64) -> f64 { + (elapsed as f64).sqrt() * self.gamma * self.gamma + } +} + +let player = Player::new(Gaussian::from_ms(0.0, 6.0), 1.0, SqrtDrift { gamma: 0.5 }); +``` + +To use a custom drift type with `History`, use the `.drift()` builder method instead of `.gamma()`: + +```rust +let h = History::builder() + .drift(SqrtDrift { gamma: 0.5 }) + .build(); +``` + ## Todo - [x] Implement approx for Gaussian diff --git a/benches/batch.rs b/benches/batch.rs index 4088818..637a2fb 100644 --- a/benches/batch.rs +++ b/benches/batch.rs @@ -1,9 +1,9 @@ use std::collections::HashMap; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use trueskill_tt::{ - agent::Agent, batch::Batch, gaussian::Gaussian, player::Player, IndexMap, BETA, GAMMA, MU, - P_DRAW, SIGMA, + BETA, GAMMA, IndexMap, MU, P_DRAW, SIGMA, agent::Agent, batch::Batch, drift::ConstantDrift, + gaussian::Gaussian, player::Player, }; fn criterion_benchmark(criterion: &mut Criterion) { @@ -19,21 +19,21 @@ fn criterion_benchmark(criterion: &mut Criterion) { map.insert( a, Agent { - player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, GAMMA), + player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)), ..Default::default() }, ); map.insert( b, Agent { - player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, GAMMA), + player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)), ..Default::default() }, ); map.insert( c, Agent { - player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, GAMMA), + player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)), ..Default::default() }, ); 
diff --git a/graph.d2 b/graph.d2 new file mode 100644 index 0000000..bf8267b --- /dev/null +++ b/graph.d2 @@ -0,0 +1,64 @@ +vars: { + d2-config: { + layout-engine: elk + # Terminal theme code + theme-id: 300 + } +} + +History: { + shape: class + + agents: "HashMap" + batches: "Vec" +} + +Batch: { + shape: class + + skills: "HashMap" + events: "Vec" + time: "i64" + p_draw: "f64" +} + +Event: { + shape: class + + teams: "Vec" + weights: "Vec>" + evidence: "f64" +} + +Team: { + shape: class + + items: "Vec" + output: "f64" +} + +Item: { + shape: class + + agent: "Index" + likelihood: "Gaussian" +} + +Skill: { + shape: class + + forward: "Gaussian" + backward: "Gaussian" + likelihood: "Gaussian" + elapsed: "i64" + online: "Gaussian" +} + +History -> Batch + +Batch -> Skill +Batch -> Event + +Event -> Team + +Team -> Item diff --git a/src/agent.rs b/src/agent.rs index c77cac6..e8073b9 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,23 +1,29 @@ -use crate::{gaussian::Gaussian, player::Player, N_INF}; +use crate::{ + N_INF, + drift::{ConstantDrift, Drift}, + gaussian::Gaussian, + player::Player, +}; #[derive(Debug)] -pub struct Agent { - pub player: Player, +pub struct Agent { + pub player: Player, pub message: Gaussian, pub last_time: i64, } -impl Agent { +impl Agent { pub(crate) fn receive(&self, elapsed: i64) -> Gaussian { if self.message != N_INF { - self.message.forget(self.player.gamma, elapsed) + self.message + .forget(self.player.drift.variance_delta(elapsed)) } else { self.player.prior } } } -impl Default for Agent { +impl Default for Agent { fn default() -> Self { Self { player: Player::default(), @@ -27,7 +33,10 @@ impl Default for Agent { } } -pub(crate) fn clean<'a, A: Iterator>(agents: A, last_time: bool) { +pub(crate) fn clean<'a, D: Drift + 'a, A: Iterator>>( + agents: A, + last_time: bool, +) { for a in agents { a.message = N_INF; diff --git a/src/batch.rs b/src/batch.rs index 8a58007..e4ec886 100644 --- a/src/batch.rs +++ b/src/batch.rs @@ -1,7 +1,8 
@@ use std::collections::HashMap; use crate::{ - agent::Agent, game::Game, gaussian::Gaussian, player::Player, tuple_gt, tuple_max, Index, N_INF, + Index, N_INF, agent::Agent, drift::Drift, game::Game, gaussian::Gaussian, player::Player, + tuple_gt, tuple_max, }; #[derive(Debug)] @@ -38,22 +39,22 @@ struct Item { } impl Item { - fn within_prior( + fn within_prior( &self, online: bool, forward: bool, skills: &HashMap, - agents: &HashMap, - ) -> Player { + agents: &HashMap>, + ) -> Player { let r = &agents[&self.agent].player; let skill = &skills[&self.agent]; if online { - Player::new(skill.online, r.beta, r.gamma) + Player::new(skill.online, r.beta, r.drift) } else if forward { - Player::new(skill.forward, r.beta, r.gamma) + Player::new(skill.forward, r.beta, r.drift) } else { - Player::new(skill.posterior() / self.likelihood, r.beta, r.gamma) + Player::new(skill.posterior() / self.likelihood, r.beta, r.drift) } } } @@ -79,13 +80,13 @@ impl Event { .collect::>() } - pub(crate) fn within_priors( + pub(crate) fn within_priors( &self, online: bool, forward: bool, skills: &HashMap, - agents: &HashMap, - ) -> Vec> { + agents: &HashMap>, + ) -> Vec>> { self.teams .iter() .map(|team| { @@ -116,12 +117,12 @@ impl Batch { } } - pub fn add_events( + pub fn add_events( &mut self, composition: Vec>>, results: Vec>, weights: Vec>>, - agents: &HashMap, + agents: &HashMap>, ) { let mut unique = Vec::with_capacity(10); @@ -207,7 +208,7 @@ impl Batch { .collect::>() } - pub fn iteration(&mut self, from: usize, agents: &HashMap) { + pub fn iteration(&mut self, from: usize, agents: &HashMap>) { for event in self.events.iter_mut().skip(from) { let teams = event.within_priors(false, false, &self.skills, agents); let result = event.outputs(); @@ -229,7 +230,7 @@ impl Batch { } #[allow(dead_code)] - pub(crate) fn convergence(&mut self, agents: &HashMap) -> usize { + pub(crate) fn convergence(&mut self, agents: &HashMap>) -> usize { let epsilon = 1e-6; let iterations = 20; @@ -259,18 
+260,18 @@ impl Batch { skill.forward * skill.likelihood } - pub(crate) fn backward_prior_out( + pub(crate) fn backward_prior_out( &self, agent: &Index, - agents: &HashMap, + agents: &HashMap>, ) -> Gaussian { let skill = &self.skills[agent]; let n = skill.likelihood * skill.backward; - n.forget(agents[agent].player.gamma, skill.elapsed) + n.forget(agents[agent].player.drift.variance_delta(skill.elapsed)) } - pub(crate) fn new_backward_info(&mut self, agents: &HashMap) { + pub(crate) fn new_backward_info(&mut self, agents: &HashMap>) { for (agent, skill) in self.skills.iter_mut() { skill.backward = agents[agent].message; } @@ -278,7 +279,7 @@ impl Batch { self.iteration(0, agents); } - pub(crate) fn new_forward_info(&mut self, agents: &HashMap) { + pub(crate) fn new_forward_info(&mut self, agents: &HashMap>) { for (agent, skill) in self.skills.iter_mut() { skill.forward = agents[agent].receive(skill.elapsed); } @@ -286,12 +287,12 @@ impl Batch { self.iteration(0, agents); } - pub(crate) fn log_evidence( + pub(crate) fn log_evidence( &self, online: bool, targets: &[Index], forward: bool, - agents: &HashMap, + agents: &HashMap>, ) -> f64 { if targets.is_empty() { if online || forward { @@ -390,7 +391,7 @@ pub(crate) fn compute_elapsed(last_time: i64, actual_time: i64) -> i64 { mod tests { use approx::assert_ulps_eq; - use crate::{agent::Agent, player::Player, IndexMap}; + use crate::{IndexMap, agent::Agent, drift::ConstantDrift, player::Player}; use super::*; @@ -414,7 +415,7 @@ mod tests { player: Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ), ..Default::default() }, @@ -490,7 +491,7 @@ mod tests { player: Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ), ..Default::default() }, @@ -569,7 +570,7 @@ mod tests { player: Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ), 
..Default::default() }, diff --git a/src/drift.rs b/src/drift.rs new file mode 100644 index 0000000..5c7107e --- /dev/null +++ b/src/drift.rs @@ -0,0 +1,14 @@ +use std::fmt::Debug; + +pub trait Drift: Copy + Debug { + fn variance_delta(&self, elapsed: i64) -> f64; +} + +#[derive(Clone, Copy, Debug)] +pub struct ConstantDrift(pub f64); + +impl Drift for ConstantDrift { + fn variance_delta(&self, elapsed: i64) -> f64 { + elapsed as f64 * self.0 * self.0 + } +} diff --git a/src/game.rs b/src/game.rs index 7fe08f4..a43c8c6 100644 --- a/src/game.rs +++ b/src/game.rs @@ -1,14 +1,16 @@ use crate::{ - approx, compute_margin, evidence, + N_INF, N00, approx, compute_margin, + drift::Drift, + evidence, gaussian::Gaussian, message::{DiffMessage, TeamMessage}, player::Player, - sort_perm, tuple_gt, tuple_max, N00, N_INF, + sort_perm, tuple_gt, tuple_max, }; #[derive(Debug)] -pub struct Game<'a> { - teams: Vec>, +pub struct Game<'a, D: Drift> { + teams: Vec>>, result: &'a [f64], weights: &'a [Vec], p_draw: f64, @@ -16,9 +18,9 @@ pub struct Game<'a> { pub(crate) evidence: f64, } -impl<'a> Game<'a> { +impl<'a, D: Drift> Game<'a, D> { pub fn new( - teams: Vec>, + teams: Vec>>, result: &'a [f64], weights: &'a [Vec], p_draw: f64, @@ -176,7 +178,7 @@ impl<'a> Game<'a> { .zip(w.iter()) .map(|(p, &w)| { ((m - performance.exclude(p.performance() * w)) * (1.0 / w)) - .forget(p.beta, 1) + .forget(p.beta.powi(2)) }) .collect::>() }) @@ -201,7 +203,7 @@ impl<'a> Game<'a> { mod tests { use ::approx::assert_ulps_eq; - use crate::{Gaussian, Player, GAMMA, N_INF}; + use crate::{ConstantDrift, GAMMA, Gaussian, N_INF, Player}; use super::*; @@ -210,12 +212,12 @@ mod tests { let t_a = Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ); let t_b = Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ); let w = [vec![1.0], vec![1.0]]; @@ -228,8 +230,16 @@ mod tests { assert_ulps_eq!(a, 
Gaussian::from_ms(20.794779, 7.194481), epsilon = 1e-6); assert_ulps_eq!(b, Gaussian::from_ms(29.205220, 7.194481), epsilon = 1e-6); - let t_a = Player::new(Gaussian::from_ms(29.0, 1.0), 25.0 / 6.0, GAMMA); - let t_b = Player::new(Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, GAMMA); + let t_a = Player::new( + Gaussian::from_ms(29.0, 1.0), + 25.0 / 6.0, + ConstantDrift(GAMMA), + ); + let t_b = Player::new( + Gaussian::from_ms(25.0, 25.0 / 3.0), + 25.0 / 6.0, + ConstantDrift(GAMMA), + ); let w = [vec![1.0], vec![1.0]]; let g = Game::new(vec![vec![t_a], vec![t_b]], &[0.0, 1.0], &w, 0.0); @@ -241,8 +251,8 @@ mod tests { assert_ulps_eq!(a, Gaussian::from_ms(28.896475, 0.996604), epsilon = 1e-6); assert_ulps_eq!(b, Gaussian::from_ms(32.189211, 6.062063), epsilon = 1e-6); - let t_a = Player::new(Gaussian::from_ms(1.139, 0.531), 1.0, 0.2125); - let t_b = Player::new(Gaussian::from_ms(15.568, 0.51), 1.0, 0.2125); + let t_a = Player::new(Gaussian::from_ms(1.139, 0.531), 1.0, ConstantDrift(0.2125)); + let t_b = Player::new(Gaussian::from_ms(15.568, 0.51), 1.0, ConstantDrift(0.2125)); let w = [vec![1.0], vec![1.0]]; let g = Game::new(vec![vec![t_a], vec![t_b]], &[0.0, 1.0], &w, 0.0); @@ -257,17 +267,17 @@ mod tests { vec![Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), )], vec![Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), )], vec![Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), )], ]; @@ -309,12 +319,12 @@ mod tests { let t_a = Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ); let t_b = Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ); let w = [vec![1.0], vec![1.0]]; @@ -327,8 +337,16 @@ mod tests { assert_ulps_eq!(a, Gaussian::from_ms(24.999999, 6.469480), epsilon = 
1e-6); assert_ulps_eq!(b, Gaussian::from_ms(24.999999, 6.469480), epsilon = 1e-6); - let t_a = Player::new(Gaussian::from_ms(25.0, 3.0), 25.0 / 6.0, 25.0 / 300.0); - let t_b = Player::new(Gaussian::from_ms(29.0, 2.0), 25.0 / 6.0, 25.0 / 300.0); + let t_a = Player::new( + Gaussian::from_ms(25.0, 3.0), + 25.0 / 6.0, + ConstantDrift(25.0 / 300.0), + ); + let t_b = Player::new( + Gaussian::from_ms(29.0, 2.0), + 25.0 / 6.0, + ConstantDrift(25.0 / 300.0), + ); let w = [vec![1.0], vec![1.0]]; let g = Game::new(vec![vec![t_a], vec![t_b]], &[0.0, 0.0], &w, 0.25); @@ -346,17 +364,17 @@ mod tests { let t_a = Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ); let t_b = Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ); let t_c = Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ); let w = [vec![1.0], vec![1.0], vec![1.0]]; @@ -376,9 +394,21 @@ mod tests { assert_ulps_eq!(b, Gaussian::from_ms(25.000000, 5.707423), epsilon = 1e-6); assert_ulps_eq!(c, Gaussian::from_ms(24.999999, 5.729068), epsilon = 1e-6); - let t_a = Player::new(Gaussian::from_ms(25.0, 3.0), 25.0 / 6.0, 25.0 / 300.0); - let t_b = Player::new(Gaussian::from_ms(25.0, 3.0), 25.0 / 6.0, 25.0 / 300.0); - let t_c = Player::new(Gaussian::from_ms(29.0, 2.0), 25.0 / 6.0, 25.0 / 300.0); + let t_a = Player::new( + Gaussian::from_ms(25.0, 3.0), + 25.0 / 6.0, + ConstantDrift(25.0 / 300.0), + ); + let t_b = Player::new( + Gaussian::from_ms(25.0, 3.0), + 25.0 / 6.0, + ConstantDrift(25.0 / 300.0), + ); + let t_c = Player::new( + Gaussian::from_ms(29.0, 2.0), + 25.0 / 6.0, + ConstantDrift(25.0 / 300.0), + ); let w = [vec![1.0], vec![1.0], vec![1.0]]; let g = Game::new( @@ -401,17 +431,33 @@ mod tests { #[test] fn test_2vs1vs2_mixed() { let t_a = vec![ - Player::new(Gaussian::from_ms(12.0, 3.0), 25.0 / 6.0, 25.0 / 300.0), - 
Player::new(Gaussian::from_ms(18.0, 3.0), 25.0 / 6.0, 25.0 / 300.0), + Player::new( + Gaussian::from_ms(12.0, 3.0), + 25.0 / 6.0, + ConstantDrift(25.0 / 300.0), + ), + Player::new( + Gaussian::from_ms(18.0, 3.0), + 25.0 / 6.0, + ConstantDrift(25.0 / 300.0), + ), ]; let t_b = vec![Player::new( Gaussian::from_ms(30.0, 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), )]; let t_c = vec![ - Player::new(Gaussian::from_ms(14.0, 3.0), 25.0 / 6.0, 25.0 / 300.0), - Player::new(Gaussian::from_ms(16., 3.0), 25.0 / 6.0, 25.0 / 300.0), + Player::new( + Gaussian::from_ms(14.0, 3.0), + 25.0 / 6.0, + ConstantDrift(25.0 / 300.0), + ), + Player::new( + Gaussian::from_ms(16., 3.0), + 25.0 / 6.0, + ConstantDrift(25.0 / 300.0), + ), ]; let w = [vec![1.0, 1.0], vec![1.0], vec![1.0, 1.0]]; @@ -433,12 +479,12 @@ mod tests { let t_a = vec![Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 0.0, + ConstantDrift(0.0), )]; let t_b = vec![Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 0.0, + ConstantDrift(0.0), )]; let w = [w_a, w_b]; @@ -495,8 +541,16 @@ mod tests { let w_a = vec![1.0]; let w_b = vec![0.0]; - let t_a = vec![Player::new(Gaussian::from_ms(2.0, 6.0), 1.0, 0.0)]; - let t_b = vec![Player::new(Gaussian::from_ms(2.0, 6.0), 1.0, 0.0)]; + let t_a = vec![Player::new( + Gaussian::from_ms(2.0, 6.0), + 1.0, + ConstantDrift(0.0), + )]; + let t_b = vec![Player::new( + Gaussian::from_ms(2.0, 6.0), + 1.0, + ConstantDrift(0.0), + )]; let w = [w_a, w_b]; let g = Game::new(vec![t_a, t_b], &[1.0, 0.0], &w, 0.0); @@ -516,8 +570,16 @@ mod tests { let w_a = vec![1.0]; let w_b = vec![-1.0]; - let t_a = vec![Player::new(Gaussian::from_ms(2.0, 6.0), 1.0, 0.0)]; - let t_b = vec![Player::new(Gaussian::from_ms(2.0, 6.0), 1.0, 0.0)]; + let t_a = vec![Player::new( + Gaussian::from_ms(2.0, 6.0), + 1.0, + ConstantDrift(0.0), + )]; + let t_b = vec![Player::new( + Gaussian::from_ms(2.0, 6.0), + 1.0, + ConstantDrift(0.0), + )]; let w = [w_a, w_b]; let g = 
Game::new(vec![t_a, t_b], &[1.0, 0.0], &w, 0.0); @@ -529,14 +591,30 @@ mod tests { #[test] fn test_2vs2_weighted() { let t_a = vec![ - Player::new(Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, 0.0), - Player::new(Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, 0.0), + Player::new( + Gaussian::from_ms(25.0, 25.0 / 3.0), + 25.0 / 6.0, + ConstantDrift(0.0), + ), + Player::new( + Gaussian::from_ms(25.0, 25.0 / 3.0), + 25.0 / 6.0, + ConstantDrift(0.0), + ), ]; let w_a = vec![0.4, 0.8]; let t_b = vec![ - Player::new(Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, 0.0), - Player::new(Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, 0.0), + Player::new( + Gaussian::from_ms(25.0, 25.0 / 3.0), + 25.0 / 6.0, + ConstantDrift(0.0), + ), + Player::new( + Gaussian::from_ms(25.0, 25.0 / 3.0), + 25.0 / 6.0, + ConstantDrift(0.0), + ), ]; let w_b = vec![0.9, 0.6]; @@ -628,7 +706,7 @@ mod tests { vec![Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 0.0, + ConstantDrift(0.0), )], ], &[1.0, 0.0], diff --git a/src/gaussian.rs b/src/gaussian.rs index b839d88..8e43099 100644 --- a/src/gaussian.rs +++ b/src/gaussian.rs @@ -40,10 +40,10 @@ impl Gaussian { } } - pub(crate) fn forget(&self, gamma: f64, t: i64) -> Self { + pub(crate) fn forget(&self, variance_delta: f64) -> Self { Self { mu: self.mu, - sigma: (self.sigma.powi(2) + t as f64 * gamma.powi(2)).sqrt(), + sigma: (self.sigma.powi(2) + variance_delta).sqrt(), } } } diff --git a/src/history.rs b/src/history.rs index a5c23bf..edd161f 100644 --- a/src/history.rs +++ b/src/history.rs @@ -1,25 +1,27 @@ use std::collections::HashMap; use crate::{ + BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA, agent::{self, Agent}, batch::{self, Batch}, + drift::{ConstantDrift, Drift}, gaussian::Gaussian, player::Player, - sort_time, tuple_gt, tuple_max, Index, BETA, GAMMA, MU, P_DRAW, SIGMA, + sort_time, tuple_gt, tuple_max, }; #[derive(Clone)] -pub struct HistoryBuilder { +pub struct HistoryBuilder { time: bool, mu: f64, sigma: f64, 
beta: f64, - gamma: f64, + drift: D, p_draw: f64, online: bool, } -impl HistoryBuilder { +impl HistoryBuilder { pub fn time(mut self, time: bool) -> Self { self.time = time; self @@ -40,9 +42,16 @@ impl HistoryBuilder { self } - pub fn gamma(mut self, gamma: f64) -> Self { - self.gamma = gamma; - self + pub fn drift(self, drift: D2) -> HistoryBuilder { + HistoryBuilder { + drift, + time: self.time, + mu: self.mu, + sigma: self.sigma, + beta: self.beta, + p_draw: self.p_draw, + online: self.online, + } } pub fn p_draw(mut self, p_draw: f64) -> Self { @@ -55,7 +64,7 @@ impl HistoryBuilder { self } - pub fn build(self) -> History { + pub fn build(self) -> History { History { size: 0, batches: Vec::new(), @@ -64,41 +73,48 @@ impl HistoryBuilder { mu: self.mu, sigma: self.sigma, beta: self.beta, - gamma: self.gamma, + drift: self.drift, p_draw: self.p_draw, online: self.online, } } } -impl Default for HistoryBuilder { +impl HistoryBuilder { + pub fn gamma(mut self, gamma: f64) -> Self { + self.drift = ConstantDrift(gamma); + self + } +} + +impl Default for HistoryBuilder { fn default() -> Self { Self { time: true, mu: MU, sigma: SIGMA, beta: BETA, - gamma: GAMMA, + drift: ConstantDrift(GAMMA), p_draw: P_DRAW, online: false, } } } -pub struct History { +pub struct History { size: usize, pub(crate) batches: Vec, - agents: HashMap, + agents: HashMap>, time: bool, mu: f64, sigma: f64, beta: f64, - gamma: f64, + drift: D, p_draw: f64, online: bool, } -impl Default for History { +impl Default for History { fn default() -> Self { Self { size: 0, @@ -108,18 +124,20 @@ impl Default for History { mu: MU, sigma: SIGMA, beta: BETA, - gamma: GAMMA, + drift: ConstantDrift(GAMMA), p_draw: P_DRAW, online: false, } } } -impl History { - pub fn builder() -> HistoryBuilder { +impl History { + pub fn builder() -> HistoryBuilder { HistoryBuilder::default() } +} +impl History { fn iteration(&mut self) -> (f64, f64) { let mut step = (0.0, 0.0); @@ -247,7 +265,7 @@ impl History { results: 
Vec>, times: Vec, weights: Vec>>, - mut priors: HashMap, + mut priors: HashMap>, ) { assert!(times.is_empty() || self.time, "length(times)>0 but !h.time"); assert!( @@ -286,10 +304,11 @@ impl History { Player::new( Gaussian::from_ms(self.mu, self.sigma), self.beta, - self.gamma, + self.drift, ) }), - ..Default::default() + message: N_INF, + last_time: i64::MIN, }, ); } @@ -414,7 +433,7 @@ impl History { mod tests { use approx::assert_ulps_eq; - use crate::{Game, Gaussian, IndexMap, Player, EPSILON, ITERATIONS, P_DRAW}; + use crate::{ConstantDrift, EPSILON, Game, Gaussian, ITERATIONS, IndexMap, P_DRAW, Player}; use super::*; @@ -441,7 +460,7 @@ mod tests { Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 0.15 * 25.0 / 3.0, + ConstantDrift(0.15 * 25.0 / 3.0), ), ); } @@ -503,7 +522,7 @@ mod tests { Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 0.15 * 25.0 / 3.0, + ConstantDrift(0.15 * 25.0 / 3.0), ), ); } @@ -552,7 +571,7 @@ mod tests { Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ), ); } @@ -610,7 +629,7 @@ mod tests { Player::new( Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, - 25.0 / 300.0, + ConstantDrift(25.0 / 300.0), ), ); } diff --git a/src/lib.rs b/src/lib.rs index f1e4bb4..80761cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ pub mod agent; #[cfg(feature = "approx")] mod approx; pub mod batch; +pub mod drift; mod game; pub mod gaussian; mod history; @@ -15,6 +16,7 @@ mod matrix; mod message; pub mod player; +pub use drift::{ConstantDrift, Drift}; pub use game::Game; pub use gaussian::Gaussian; pub use history::History; diff --git a/src/message.rs b/src/message.rs index a73fb96..c6fd9bc 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,5 +1,4 @@ -use crate::gaussian::Gaussian; -use crate::N_INF; +use crate::{N_INF, gaussian::Gaussian}; pub(crate) struct TeamMessage { pub(crate) prior: Gaussian, diff --git a/src/player.rs b/src/player.rs index 
3e1ee45..c4ebe33 100644 --- a/src/player.rs +++ b/src/player.rs @@ -1,35 +1,32 @@ -use crate::{gaussian::Gaussian, BETA, GAMMA}; +use crate::{ + BETA, GAMMA, + drift::{ConstantDrift, Drift}, + gaussian::Gaussian, +}; #[derive(Clone, Copy, Debug)] -pub struct Player { +pub struct Player { pub(crate) prior: Gaussian, pub(crate) beta: f64, - pub(crate) gamma: f64, - // pub(crate) draw: Gaussian, + pub(crate) drift: D, } -impl Player { - pub fn new(prior: Gaussian, beta: f64, gamma: f64) -> Self { - Self { - prior, - beta, - gamma, - // draw: N_INF, - } +impl Player { + pub fn new(prior: Gaussian, beta: f64, drift: D) -> Self { + Self { prior, beta, drift } } pub(crate) fn performance(&self) -> Gaussian { - self.prior.forget(self.beta, 1) + self.prior.forget(self.beta.powi(2)) } } -impl Default for Player { +impl Default for Player { fn default() -> Self { Self { prior: Gaussian::default(), beta: BETA, - gamma: GAMMA, - // draw: N_INF, + drift: ConstantDrift(GAMMA), } } }