Compare commits
7 Commits
refactor
...
04d5478ee4
| Author | SHA1 | Date | |
|---|---|---|---|
| 04d5478ee4 | |||
| 480467ac32 | |||
| dc47964310 | |||
| 61a5507f5c | |||
| a1f282a1c8 | |||
| 853f177fa8 | |||
| fc0efcdc52 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -4,3 +4,6 @@
|
||||
/temp
|
||||
.justfile
|
||||
*.svg
|
||||
NOTEPAD.md
|
||||
|
||||
/.claude
|
||||
|
||||
45
CLAUDE.md
Normal file
45
CLAUDE.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
cargo build # Build the library
|
||||
cargo test --lib # Run all library tests
|
||||
cargo test --lib <test_name> # Run a single test by name
|
||||
cargo test --lib -- --nocapture # Run tests with stdout output
|
||||
cargo clippy # Lint
|
||||
cargo bench # Run benchmarks (criterion)
|
||||
```
|
||||
|
||||
The `approx` feature enables `approx::AbsDiffEq` for `Gaussian`:
|
||||
```bash
|
||||
cargo test --features approx
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
This is a Rust port of [TrueSkillThroughTime.py](https://github.com/glandfried/TrueSkillThroughTime.py) — a Bayesian skill rating system that tracks skill evolution over time using Gaussian message passing.
|
||||
|
||||
### Data flow
|
||||
|
||||
```
|
||||
History → Batch[] → Game[] → teams/players
|
||||
```
|
||||
|
||||
- **`History`** (`history.rs`) — top-level container. Organizes games by time into `Batch`es, runs forward/backward message passing across batches, and exposes `learning_curves()` and `log_evidence()`.
|
||||
- **`Batch`** (`batch.rs`) — all games at a single time step. Runs `iteration()` to update skill estimates via `Game::posteriors()`, collecting `Skill` distributions per player.
|
||||
- **`Game`** (`game.rs`) — a single match. Given teams (slices of `Gaussian`), computes posterior skill distributions using Gaussian factor graphs and `message.rs` helpers.
|
||||
- **`Agent`** (`agent.rs`) — wraps a `Player` with temporal state (`last_time`, `message`). `receive()` applies time-decay (`gamma`) when the player reappears after a gap.
|
||||
- **`Player`** (`player.rs`) — static configuration: prior `Gaussian`, `beta` (performance noise), `gamma` (skill drift per time unit).
|
||||
- **`Gaussian`** (`gaussian.rs`) — core probability type. Stored as natural parameters (`pi = 1/sigma²`, `tau = mu/sigma²`). Arithmetic ops implement message multiplication/division in the factor graph.
|
||||
- **`message.rs`** — `TeamMessage` and `DiffMessage`: intermediate factor graph messages used inside `Game`.
|
||||
- **`lib.rs`** — exports the public API (`Game`, `Gaussian`, `History`, `Player`) and standalone functions (`quality()`, `pdf()`, `cdf()`, `erfc()`). Also defines global defaults: `MU=0.0`, `SIGMA=6.0`, `BETA=1.0`, `GAMMA=0.03`, `P_DRAW=0.0`, `EPSILON=1e-6`, `ITERATIONS=30`.
|
||||
|
||||
### Key design points
|
||||
|
||||
- `History` uses `IndexMap<K>` (defined in `lib.rs`) to map arbitrary player keys to `Agent` state.
|
||||
- Convergence is measured by the maximum `delta()` across all skill distributions; iteration stops when below `EPSILON` or after `ITERATIONS` rounds.
|
||||
- The `approx` feature gates `AbsDiffEq` on `Gaussian` for use in tests — the feature is optional and only needed for approximate equality assertions.
|
||||
- `time` in `History`/`Batch` is currently an `f64`; the README notes it needs to become an enum to support richer temporal states.
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "trueskill-tt"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
edition = "2024"
|
||||
|
||||
[lib]
|
||||
bench = false
|
||||
@@ -10,6 +10,10 @@ bench = false
|
||||
name = "batch"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "gaussian"
|
||||
harness = false
|
||||
|
||||
[dependencies]
|
||||
approx = { version = "0.5.1", optional = true }
|
||||
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
# History
|
||||
|
||||
```rust
|
||||
let mut history = History::new();
|
||||
|
||||
let agent_a = history.new_agent();
|
||||
let agent_b = history.new_agent_with_prior(Prior::new(Gaussian::default(), BETA, GAMMA));
|
||||
```
|
||||
60
README.md
60
README.md
@@ -11,6 +11,66 @@ Rust port of [TrueSkillThroughTime.py](https://github.com/glandfried/TrueSkillTh
|
||||
- [TrueSkill Through Time: Revisiting the History of Chess](https://www.microsoft.com/en-us/research/wp-content/uploads/2008/01/NIPS2007_0931.pdf)
|
||||
- [TrueSkill Through Time. The full scientific documentation](https://glandfried.github.io/publication/landfried2021-learning/)
|
||||
|
||||
## Drift
|
||||
|
||||
Skill drift models how a player's true skill can change between appearances. Each time a player reappears after a gap, their skill uncertainty is widened by the drift model before the new evidence is incorporated.
|
||||
|
||||
Drift is represented by the `Drift` trait:
|
||||
|
||||
```rust
|
||||
pub trait Drift: Copy + Debug {
|
||||
fn variance_delta(&self, elapsed: i64) -> f64;
|
||||
}
|
||||
```
|
||||
|
||||
`variance_delta` returns the amount to add to `σ²` given the elapsed time since the player last played. Internally, `Gaussian::forget` uses this to compute the new sigma: `σ_new = sqrt(σ² + variance_delta)`.
|
||||
|
||||
### ConstantDrift
|
||||
|
||||
The built-in `ConstantDrift` implements a linear random walk — skill uncertainty grows proportionally to time:
|
||||
|
||||
```
|
||||
variance_delta = elapsed * γ²
|
||||
```
|
||||
|
||||
This is the standard TrueSkill Through Time model. Use it by passing a `ConstantDrift(gamma)` when constructing a `Player`:
|
||||
|
||||
```rust
|
||||
use trueskill_tt::{Player, Gaussian, drift::ConstantDrift};
|
||||
|
||||
// gamma = 0.1 means skill can shift ~0.1 per time unit
|
||||
let player = Player::new(Gaussian::from_ms(0.0, 6.0), 1.0, ConstantDrift(0.1));
|
||||
```
|
||||
|
||||
### Custom drift
|
||||
|
||||
Implement `Drift` to express any other model. For example, a drift that saturates after a long absence (uncertainty grows with the square root of elapsed time instead of linearly):
|
||||
|
||||
```rust
|
||||
use trueskill_tt::drift::Drift;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct SqrtDrift {
|
||||
gamma: f64,
|
||||
}
|
||||
|
||||
impl Drift for SqrtDrift {
|
||||
fn variance_delta(&self, elapsed: i64) -> f64 {
|
||||
(elapsed as f64).sqrt() * self.gamma * self.gamma
|
||||
}
|
||||
}
|
||||
|
||||
let player = Player::new(Gaussian::from_ms(0.0, 6.0), 1.0, SqrtDrift { gamma: 0.5 });
|
||||
```
|
||||
|
||||
To use a custom drift type with `History`, use the `.drift()` builder method instead of `.gamma()`:
|
||||
|
||||
```rust
|
||||
let h = History::builder()
|
||||
.drift(SqrtDrift { gamma: 0.5 })
|
||||
.build();
|
||||
```
|
||||
|
||||
## Todo
|
||||
|
||||
- [x] Implement approx for Gaussian
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use criterion::{Criterion, criterion_group, criterion_main};
|
||||
use trueskill_tt::{
|
||||
agent::Agent, batch::Batch, gaussian::Gaussian, player::Player, IndexMap, BETA, GAMMA, MU,
|
||||
P_DRAW, SIGMA,
|
||||
BETA, GAMMA, IndexMap, MU, P_DRAW, SIGMA, agent::Agent, batch::Batch, drift::ConstantDrift,
|
||||
gaussian::Gaussian, player::Player,
|
||||
};
|
||||
|
||||
fn criterion_benchmark(criterion: &mut Criterion) {
|
||||
@@ -19,21 +19,21 @@ fn criterion_benchmark(criterion: &mut Criterion) {
|
||||
map.insert(
|
||||
a,
|
||||
Agent {
|
||||
player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, GAMMA),
|
||||
player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
map.insert(
|
||||
b,
|
||||
Agent {
|
||||
player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, GAMMA),
|
||||
player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
map.insert(
|
||||
c,
|
||||
Agent {
|
||||
player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, GAMMA),
|
||||
player: Player::new(Gaussian::from_ms(MU, SIGMA), BETA, ConstantDrift(GAMMA)),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
|
||||
50
benches/gaussian.rs
Normal file
50
benches/gaussian.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use criterion::{Criterion, criterion_group, criterion_main};
|
||||
use trueskill_tt::gaussian::Gaussian;
|
||||
|
||||
fn benchmark_gaussian_arithmetic(criterion: &mut Criterion) {
|
||||
// Define test Gaussians
|
||||
let g1 = Gaussian::from_ms(25.0, 25.0 / 3.0);
|
||||
let g2 = Gaussian::from_ms(0.0, 1.0);
|
||||
let g3 = Gaussian::from_ms(1.0, 1.0);
|
||||
|
||||
// Benchmark addition
|
||||
criterion.bench_function("Gaussian::add", |bencher| {
|
||||
bencher.iter(|| g1 + g2);
|
||||
});
|
||||
|
||||
// Benchmark subtraction
|
||||
criterion.bench_function("Gaussian::sub", |bencher| {
|
||||
bencher.iter(|| g1 - g3);
|
||||
});
|
||||
|
||||
// Benchmark multiplication
|
||||
criterion.bench_function("Gaussian::mul", |bencher| {
|
||||
bencher.iter(|| g1 * g2);
|
||||
});
|
||||
|
||||
// Benchmark division
|
||||
criterion.bench_function("Gaussian::div", |bencher| {
|
||||
bencher.iter(|| g1 / g2);
|
||||
});
|
||||
|
||||
// Benchmark natural parameter conversions
|
||||
criterion.bench_function("Gaussian::pi", |bencher| {
|
||||
bencher.iter(|| g1.pi());
|
||||
});
|
||||
|
||||
criterion.bench_function("Gaussian::tau", |bencher| {
|
||||
bencher.iter(|| g1.tau());
|
||||
});
|
||||
|
||||
// Benchmark combined pi/tau operations (used in mul/div)
|
||||
criterion.bench_function("Gaussian::pi_tau_combined", |bencher| {
|
||||
bencher.iter(|| {
|
||||
let pi = g1.pi();
|
||||
let tau = g1.tau();
|
||||
(pi, tau)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, benchmark_gaussian_arithmetic);
|
||||
criterion_main!(benches);
|
||||
65
cliff.toml
Normal file
65
cliff.toml
Normal file
@@ -0,0 +1,65 @@
|
||||
# git-cliff ~ configuration file
|
||||
# https://git-cliff.org/docs/configuration
|
||||
|
||||
[changelog]
|
||||
# A Tera template to be rendered as the changelog's header.
|
||||
# See https://keats.github.io/tera/docs/#introduction
|
||||
header = """
|
||||
# Changelog\n
|
||||
All notable changes to this project will be documented in this file.\n
|
||||
"""
|
||||
# A Tera template to be rendered for each release in the changelog.
|
||||
# See https://keats.github.io/tera/docs/#introduction
|
||||
body = """
|
||||
{% if version %}\
|
||||
## {{ version | trim_start_matches(pat="v") }} - {{ timestamp | date(format="%Y-%m-%d") }}
|
||||
{% else %}\
|
||||
## Unreleased
|
||||
{% endif %}\
|
||||
{% for group, commits in commits | group_by(attribute="group") %}
|
||||
### {{ group | upper_first }}
|
||||
{% for commit in commits %}
|
||||
- {{ commit.message | split(pat="\n") | first | trim_end }}\
|
||||
{% endfor %}
|
||||
{% endfor %}\n
|
||||
"""
|
||||
# A Tera template to be rendered as the changelog's footer.
|
||||
# See https://keats.github.io/tera/docs/#introduction
|
||||
footer = """
|
||||
<!-- generated by git-cliff -->
|
||||
"""
|
||||
# Remove leading and trailing whitespaces from the changelog's body.
|
||||
trim = true
|
||||
|
||||
|
||||
[git]
|
||||
# Parse commits according to the conventional commits specification.
|
||||
# See https://www.conventionalcommits.org
|
||||
conventional_commits = false
|
||||
# Exclude commits that do not match the conventional commits specification.
|
||||
filter_unconventional = false
|
||||
# Split commits on newlines, treating each line as an individual commit.
|
||||
split_commits = false
|
||||
# An array of regex based parsers for extracting data from the commit message.
|
||||
# Assigns commits to groups.
|
||||
# Optionally sets the commit's scope and can decide to exclude commits from further processing.
|
||||
commit_parsers = [
|
||||
{ message = "^feat", group = "Features" },
|
||||
{ message = "^fix", group = "Bug Fixes" },
|
||||
{ message = "^doc", group = "Documentation" },
|
||||
{ message = "^perf", group = "Performance" },
|
||||
{ message = "^refactor", group = "Refactor" },
|
||||
{ message = "^style", group = "Styling" },
|
||||
{ message = "^test", group = "Testing" },
|
||||
{ message = "^chore\\(release\\): prepare for", skip = true },
|
||||
{ message = "^chore", group = "Miscellaneous Tasks" },
|
||||
{ body = ".*security", group = "Security" },
|
||||
{ body = ".*", group = "Other (unconventional)" },
|
||||
]
|
||||
# Exclude commits that are not matched by any commit parser.
|
||||
filter_commits = false
|
||||
# Order releases topologically instead of chronologically.
|
||||
topo_order = false
|
||||
# Order of commits in each group/release within the changelog.
|
||||
# Allowed values: newest, oldest
|
||||
sort_commits = "oldest"
|
||||
@@ -159,10 +159,12 @@ fn main() {
|
||||
}
|
||||
|
||||
mod csv {
|
||||
use std::fs::File;
|
||||
use std::io::{self, BufRead, BufReader, Lines};
|
||||
use std::ops;
|
||||
use std::path::Path;
|
||||
use std::{
|
||||
fs::File,
|
||||
io::{self, BufRead, BufReader, Lines},
|
||||
ops,
|
||||
path::Path,
|
||||
};
|
||||
|
||||
pub struct Reader {
|
||||
header_map: Vec<String>,
|
||||
|
||||
64
graph.d2
Normal file
64
graph.d2
Normal file
@@ -0,0 +1,64 @@
|
||||
vars: {
|
||||
d2-config: {
|
||||
layout-engine: elk
|
||||
# Terminal theme code
|
||||
theme-id: 300
|
||||
}
|
||||
}
|
||||
|
||||
History: {
|
||||
shape: class
|
||||
|
||||
agents: "HashMap<Index, Agent>"
|
||||
batches: "Vec<Batch>"
|
||||
}
|
||||
|
||||
Batch: {
|
||||
shape: class
|
||||
|
||||
skills: "HashMap<Index, Skill>"
|
||||
events: "Vec<Event>"
|
||||
time: "i64"
|
||||
p_draw: "f64"
|
||||
}
|
||||
|
||||
Event: {
|
||||
shape: class
|
||||
|
||||
teams: "Vec<Team>"
|
||||
weights: "Vec<Vec<f64>>"
|
||||
evidence: "f64"
|
||||
}
|
||||
|
||||
Team: {
|
||||
shape: class
|
||||
|
||||
items: "Vec<Item>"
|
||||
output: "f64"
|
||||
}
|
||||
|
||||
Item: {
|
||||
shape: class
|
||||
|
||||
agent: "Index"
|
||||
likelihood: "Gaussian"
|
||||
}
|
||||
|
||||
Skill: {
|
||||
shape: class
|
||||
|
||||
forward: "Gaussian"
|
||||
backward: "Gaussian"
|
||||
likelihood: "Gaussian"
|
||||
elapsed: "i64"
|
||||
online: "Gaussian"
|
||||
}
|
||||
|
||||
History -> Batch
|
||||
|
||||
Batch -> Skill
|
||||
Batch -> Event
|
||||
|
||||
Event -> Team
|
||||
|
||||
Team -> Item
|
||||
1
release.toml
Normal file
1
release.toml
Normal file
@@ -0,0 +1 @@
|
||||
pre-release-hook = ["git", "cliff", "-o", "CHANGELOG.md", "--tag", "{{version}}"]
|
||||
2
rustfmt.toml
Normal file
2
rustfmt.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
imports_granularity = "Crate"
|
||||
group_imports = "StdExternalCrate"
|
||||
23
src/agent.rs
23
src/agent.rs
@@ -1,23 +1,29 @@
|
||||
use crate::{gaussian::Gaussian, player::Player, N_INF};
|
||||
use crate::{
|
||||
N_INF,
|
||||
drift::{ConstantDrift, Drift},
|
||||
gaussian::Gaussian,
|
||||
player::Player,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Agent {
|
||||
pub player: Player,
|
||||
pub struct Agent<D: Drift = ConstantDrift> {
|
||||
pub player: Player<D>,
|
||||
pub message: Gaussian,
|
||||
pub last_time: i64,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
impl<D: Drift> Agent<D> {
|
||||
pub(crate) fn receive(&self, elapsed: i64) -> Gaussian {
|
||||
if self.message != N_INF {
|
||||
self.message.forget(self.player.gamma, elapsed)
|
||||
self.message
|
||||
.forget(self.player.drift.variance_delta(elapsed))
|
||||
} else {
|
||||
self.player.prior
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Agent {
|
||||
impl Default for Agent<ConstantDrift> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
player: Player::default(),
|
||||
@@ -27,7 +33,10 @@ impl Default for Agent {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn clean<'a, A: Iterator<Item = &'a mut Agent>>(agents: A, last_time: bool) {
|
||||
pub(crate) fn clean<'a, D: Drift + 'a, A: Iterator<Item = &'a mut Agent<D>>>(
|
||||
agents: A,
|
||||
last_time: bool,
|
||||
) {
|
||||
for a in agents {
|
||||
a.message = N_INF;
|
||||
|
||||
|
||||
52
src/batch.rs
52
src/batch.rs
@@ -1,7 +1,8 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
agent::Agent, game::Game, gaussian::Gaussian, player::Player, tuple_gt, tuple_max, Index, N_INF,
|
||||
Index, N_INF, agent::Agent, drift::Drift, game::Game, gaussian::Gaussian, player::Player,
|
||||
tuple_gt, tuple_max,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -38,22 +39,22 @@ struct Item {
|
||||
}
|
||||
|
||||
impl Item {
|
||||
fn within_prior(
|
||||
fn within_prior<D: Drift>(
|
||||
&self,
|
||||
online: bool,
|
||||
forward: bool,
|
||||
skills: &HashMap<Index, Skill>,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
) -> Player {
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
) -> Player<D> {
|
||||
let r = &agents[&self.agent].player;
|
||||
let skill = &skills[&self.agent];
|
||||
|
||||
if online {
|
||||
Player::new(skill.online, r.beta, r.gamma)
|
||||
Player::new(skill.online, r.beta, r.drift)
|
||||
} else if forward {
|
||||
Player::new(skill.forward, r.beta, r.gamma)
|
||||
Player::new(skill.forward, r.beta, r.drift)
|
||||
} else {
|
||||
Player::new(skill.posterior() / self.likelihood, r.beta, r.gamma)
|
||||
Player::new(skill.posterior() / self.likelihood, r.beta, r.drift)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -79,13 +80,13 @@ impl Event {
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
pub(crate) fn within_priors(
|
||||
pub(crate) fn within_priors<D: Drift>(
|
||||
&self,
|
||||
online: bool,
|
||||
forward: bool,
|
||||
skills: &HashMap<Index, Skill>,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
) -> Vec<Vec<Player>> {
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
) -> Vec<Vec<Player<D>>> {
|
||||
self.teams
|
||||
.iter()
|
||||
.map(|team| {
|
||||
@@ -116,12 +117,12 @@ impl Batch {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_events(
|
||||
pub fn add_events<D: Drift>(
|
||||
&mut self,
|
||||
composition: Vec<Vec<Vec<Index>>>,
|
||||
results: Vec<Vec<f64>>,
|
||||
weights: Vec<Vec<Vec<f64>>>,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
) {
|
||||
let mut unique = Vec::with_capacity(10);
|
||||
|
||||
@@ -207,7 +208,7 @@ impl Batch {
|
||||
.collect::<HashMap<_, _>>()
|
||||
}
|
||||
|
||||
pub fn iteration(&mut self, from: usize, agents: &HashMap<Index, Agent>) {
|
||||
pub fn iteration<D: Drift>(&mut self, from: usize, agents: &HashMap<Index, Agent<D>>) {
|
||||
for event in self.events.iter_mut().skip(from) {
|
||||
let teams = event.within_priors(false, false, &self.skills, agents);
|
||||
let result = event.outputs();
|
||||
@@ -229,7 +230,7 @@ impl Batch {
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn convergence(&mut self, agents: &HashMap<Index, Agent>) -> usize {
|
||||
pub(crate) fn convergence<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) -> usize {
|
||||
let epsilon = 1e-6;
|
||||
let iterations = 20;
|
||||
|
||||
@@ -259,18 +260,18 @@ impl Batch {
|
||||
skill.forward * skill.likelihood
|
||||
}
|
||||
|
||||
pub(crate) fn backward_prior_out(
|
||||
pub(crate) fn backward_prior_out<D: Drift>(
|
||||
&self,
|
||||
agent: &Index,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
) -> Gaussian {
|
||||
let skill = &self.skills[agent];
|
||||
let n = skill.likelihood * skill.backward;
|
||||
|
||||
n.forget(agents[agent].player.gamma, skill.elapsed)
|
||||
n.forget(agents[agent].player.drift.variance_delta(skill.elapsed))
|
||||
}
|
||||
|
||||
pub(crate) fn new_backward_info(&mut self, agents: &HashMap<Index, Agent>) {
|
||||
pub(crate) fn new_backward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
|
||||
for (agent, skill) in self.skills.iter_mut() {
|
||||
skill.backward = agents[agent].message;
|
||||
}
|
||||
@@ -278,7 +279,7 @@ impl Batch {
|
||||
self.iteration(0, agents);
|
||||
}
|
||||
|
||||
pub(crate) fn new_forward_info(&mut self, agents: &HashMap<Index, Agent>) {
|
||||
pub(crate) fn new_forward_info<D: Drift>(&mut self, agents: &HashMap<Index, Agent<D>>) {
|
||||
for (agent, skill) in self.skills.iter_mut() {
|
||||
skill.forward = agents[agent].receive(skill.elapsed);
|
||||
}
|
||||
@@ -286,12 +287,12 @@ impl Batch {
|
||||
self.iteration(0, agents);
|
||||
}
|
||||
|
||||
pub(crate) fn log_evidence(
|
||||
pub(crate) fn log_evidence<D: Drift>(
|
||||
&self,
|
||||
online: bool,
|
||||
targets: &[Index],
|
||||
forward: bool,
|
||||
agents: &HashMap<Index, Agent>,
|
||||
agents: &HashMap<Index, Agent<D>>,
|
||||
) -> f64 {
|
||||
if targets.is_empty() {
|
||||
if online || forward {
|
||||
@@ -390,9 +391,8 @@ pub(crate) fn compute_elapsed(last_time: i64, actual_time: i64) -> i64 {
|
||||
mod tests {
|
||||
use approx::assert_ulps_eq;
|
||||
|
||||
use crate::{agent::Agent, player::Player, IndexMap};
|
||||
|
||||
use super::*;
|
||||
use crate::{IndexMap, agent::Agent, drift::ConstantDrift, player::Player};
|
||||
|
||||
#[test]
|
||||
fn test_one_event_each() {
|
||||
@@ -414,7 +414,7 @@ mod tests {
|
||||
player: Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
..Default::default()
|
||||
},
|
||||
@@ -490,7 +490,7 @@ mod tests {
|
||||
player: Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
..Default::default()
|
||||
},
|
||||
@@ -569,7 +569,7 @@ mod tests {
|
||||
player: Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
..Default::default()
|
||||
},
|
||||
|
||||
14
src/drift.rs
Normal file
14
src/drift.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
pub trait Drift: Copy + Debug {
|
||||
fn variance_delta(&self, elapsed: i64) -> f64;
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct ConstantDrift(pub f64);
|
||||
|
||||
impl Drift for ConstantDrift {
|
||||
fn variance_delta(&self, elapsed: i64) -> f64 {
|
||||
elapsed as f64 * self.0 * self.0
|
||||
}
|
||||
}
|
||||
165
src/game.rs
165
src/game.rs
@@ -1,14 +1,16 @@
|
||||
use crate::{
|
||||
approx, compute_margin, evidence,
|
||||
N_INF, N00, approx, compute_margin,
|
||||
drift::Drift,
|
||||
evidence,
|
||||
gaussian::Gaussian,
|
||||
message::{DiffMessage, TeamMessage},
|
||||
player::Player,
|
||||
sort_perm, tuple_gt, tuple_max, N00, N_INF,
|
||||
sort_perm, tuple_gt, tuple_max,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Game<'a> {
|
||||
teams: Vec<Vec<Player>>,
|
||||
pub struct Game<'a, D: Drift> {
|
||||
teams: Vec<Vec<Player<D>>>,
|
||||
result: &'a [f64],
|
||||
weights: &'a [Vec<f64>],
|
||||
p_draw: f64,
|
||||
@@ -16,9 +18,9 @@ pub struct Game<'a> {
|
||||
pub(crate) evidence: f64,
|
||||
}
|
||||
|
||||
impl<'a> Game<'a> {
|
||||
impl<'a, D: Drift> Game<'a, D> {
|
||||
pub fn new(
|
||||
teams: Vec<Vec<Player>>,
|
||||
teams: Vec<Vec<Player<D>>>,
|
||||
result: &'a [f64],
|
||||
weights: &'a [Vec<f64>],
|
||||
p_draw: f64,
|
||||
@@ -176,7 +178,7 @@ impl<'a> Game<'a> {
|
||||
.zip(w.iter())
|
||||
.map(|(p, &w)| {
|
||||
((m - performance.exclude(p.performance() * w)) * (1.0 / w))
|
||||
.forget(p.beta, 1)
|
||||
.forget(p.beta.powi(2))
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
@@ -201,21 +203,20 @@ impl<'a> Game<'a> {
|
||||
mod tests {
|
||||
use ::approx::assert_ulps_eq;
|
||||
|
||||
use crate::{Gaussian, Player, GAMMA, N_INF};
|
||||
|
||||
use super::*;
|
||||
use crate::{ConstantDrift, GAMMA, Gaussian, N_INF, Player};
|
||||
|
||||
#[test]
|
||||
fn test_1vs1() {
|
||||
let t_a = Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
let t_b = Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
|
||||
let w = [vec![1.0], vec![1.0]];
|
||||
@@ -228,8 +229,16 @@ mod tests {
|
||||
assert_ulps_eq!(a, Gaussian::from_ms(20.794779, 7.194481), epsilon = 1e-6);
|
||||
assert_ulps_eq!(b, Gaussian::from_ms(29.205220, 7.194481), epsilon = 1e-6);
|
||||
|
||||
let t_a = Player::new(Gaussian::from_ms(29.0, 1.0), 25.0 / 6.0, GAMMA);
|
||||
let t_b = Player::new(Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, GAMMA);
|
||||
let t_a = Player::new(
|
||||
Gaussian::from_ms(29.0, 1.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(GAMMA),
|
||||
);
|
||||
let t_b = Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(GAMMA),
|
||||
);
|
||||
|
||||
let w = [vec![1.0], vec![1.0]];
|
||||
let g = Game::new(vec![vec![t_a], vec![t_b]], &[0.0, 1.0], &w, 0.0);
|
||||
@@ -241,8 +250,8 @@ mod tests {
|
||||
assert_ulps_eq!(a, Gaussian::from_ms(28.896475, 0.996604), epsilon = 1e-6);
|
||||
assert_ulps_eq!(b, Gaussian::from_ms(32.189211, 6.062063), epsilon = 1e-6);
|
||||
|
||||
let t_a = Player::new(Gaussian::from_ms(1.139, 0.531), 1.0, 0.2125);
|
||||
let t_b = Player::new(Gaussian::from_ms(15.568, 0.51), 1.0, 0.2125);
|
||||
let t_a = Player::new(Gaussian::from_ms(1.139, 0.531), 1.0, ConstantDrift(0.2125));
|
||||
let t_b = Player::new(Gaussian::from_ms(15.568, 0.51), 1.0, ConstantDrift(0.2125));
|
||||
|
||||
let w = [vec![1.0], vec![1.0]];
|
||||
let g = Game::new(vec![vec![t_a], vec![t_b]], &[0.0, 1.0], &w, 0.0);
|
||||
@@ -257,17 +266,17 @@ mod tests {
|
||||
vec![Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
)],
|
||||
vec![Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
)],
|
||||
vec![Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
)],
|
||||
];
|
||||
|
||||
@@ -309,12 +318,12 @@ mod tests {
|
||||
let t_a = Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
let t_b = Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
|
||||
let w = [vec![1.0], vec![1.0]];
|
||||
@@ -327,8 +336,16 @@ mod tests {
|
||||
assert_ulps_eq!(a, Gaussian::from_ms(24.999999, 6.469480), epsilon = 1e-6);
|
||||
assert_ulps_eq!(b, Gaussian::from_ms(24.999999, 6.469480), epsilon = 1e-6);
|
||||
|
||||
let t_a = Player::new(Gaussian::from_ms(25.0, 3.0), 25.0 / 6.0, 25.0 / 300.0);
|
||||
let t_b = Player::new(Gaussian::from_ms(29.0, 2.0), 25.0 / 6.0, 25.0 / 300.0);
|
||||
let t_a = Player::new(
|
||||
Gaussian::from_ms(25.0, 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
let t_b = Player::new(
|
||||
Gaussian::from_ms(29.0, 2.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
|
||||
let w = [vec![1.0], vec![1.0]];
|
||||
let g = Game::new(vec![vec![t_a], vec![t_b]], &[0.0, 0.0], &w, 0.25);
|
||||
@@ -346,17 +363,17 @@ mod tests {
|
||||
let t_a = Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
let t_b = Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
let t_c = Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
|
||||
let w = [vec![1.0], vec![1.0], vec![1.0]];
|
||||
@@ -376,9 +393,21 @@ mod tests {
|
||||
assert_ulps_eq!(b, Gaussian::from_ms(25.000000, 5.707423), epsilon = 1e-6);
|
||||
assert_ulps_eq!(c, Gaussian::from_ms(24.999999, 5.729068), epsilon = 1e-6);
|
||||
|
||||
let t_a = Player::new(Gaussian::from_ms(25.0, 3.0), 25.0 / 6.0, 25.0 / 300.0);
|
||||
let t_b = Player::new(Gaussian::from_ms(25.0, 3.0), 25.0 / 6.0, 25.0 / 300.0);
|
||||
let t_c = Player::new(Gaussian::from_ms(29.0, 2.0), 25.0 / 6.0, 25.0 / 300.0);
|
||||
let t_a = Player::new(
|
||||
Gaussian::from_ms(25.0, 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
let t_b = Player::new(
|
||||
Gaussian::from_ms(25.0, 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
let t_c = Player::new(
|
||||
Gaussian::from_ms(29.0, 2.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
);
|
||||
|
||||
let w = [vec![1.0], vec![1.0], vec![1.0]];
|
||||
let g = Game::new(
|
||||
@@ -401,17 +430,33 @@ mod tests {
|
||||
#[test]
|
||||
fn test_2vs1vs2_mixed() {
|
||||
let t_a = vec![
|
||||
Player::new(Gaussian::from_ms(12.0, 3.0), 25.0 / 6.0, 25.0 / 300.0),
|
||||
Player::new(Gaussian::from_ms(18.0, 3.0), 25.0 / 6.0, 25.0 / 300.0),
|
||||
Player::new(
|
||||
Gaussian::from_ms(12.0, 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
Player::new(
|
||||
Gaussian::from_ms(18.0, 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
];
|
||||
let t_b = vec![Player::new(
|
||||
Gaussian::from_ms(30.0, 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
)];
|
||||
let t_c = vec![
|
||||
Player::new(Gaussian::from_ms(14.0, 3.0), 25.0 / 6.0, 25.0 / 300.0),
|
||||
Player::new(Gaussian::from_ms(16., 3.0), 25.0 / 6.0, 25.0 / 300.0),
|
||||
Player::new(
|
||||
Gaussian::from_ms(14.0, 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
Player::new(
|
||||
Gaussian::from_ms(16., 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
];
|
||||
|
||||
let w = [vec![1.0, 1.0], vec![1.0], vec![1.0, 1.0]];
|
||||
@@ -433,12 +478,12 @@ mod tests {
|
||||
let t_a = vec![Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
0.0,
|
||||
ConstantDrift(0.0),
|
||||
)];
|
||||
let t_b = vec![Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
0.0,
|
||||
ConstantDrift(0.0),
|
||||
)];
|
||||
|
||||
let w = [w_a, w_b];
|
||||
@@ -495,8 +540,16 @@ mod tests {
|
||||
let w_a = vec![1.0];
|
||||
let w_b = vec![0.0];
|
||||
|
||||
let t_a = vec![Player::new(Gaussian::from_ms(2.0, 6.0), 1.0, 0.0)];
|
||||
let t_b = vec![Player::new(Gaussian::from_ms(2.0, 6.0), 1.0, 0.0)];
|
||||
let t_a = vec![Player::new(
|
||||
Gaussian::from_ms(2.0, 6.0),
|
||||
1.0,
|
||||
ConstantDrift(0.0),
|
||||
)];
|
||||
let t_b = vec![Player::new(
|
||||
Gaussian::from_ms(2.0, 6.0),
|
||||
1.0,
|
||||
ConstantDrift(0.0),
|
||||
)];
|
||||
|
||||
let w = [w_a, w_b];
|
||||
let g = Game::new(vec![t_a, t_b], &[1.0, 0.0], &w, 0.0);
|
||||
@@ -516,8 +569,16 @@ mod tests {
|
||||
let w_a = vec![1.0];
|
||||
let w_b = vec![-1.0];
|
||||
|
||||
let t_a = vec![Player::new(Gaussian::from_ms(2.0, 6.0), 1.0, 0.0)];
|
||||
let t_b = vec![Player::new(Gaussian::from_ms(2.0, 6.0), 1.0, 0.0)];
|
||||
let t_a = vec![Player::new(
|
||||
Gaussian::from_ms(2.0, 6.0),
|
||||
1.0,
|
||||
ConstantDrift(0.0),
|
||||
)];
|
||||
let t_b = vec![Player::new(
|
||||
Gaussian::from_ms(2.0, 6.0),
|
||||
1.0,
|
||||
ConstantDrift(0.0),
|
||||
)];
|
||||
|
||||
let w = [w_a, w_b];
|
||||
let g = Game::new(vec![t_a, t_b], &[1.0, 0.0], &w, 0.0);
|
||||
@@ -529,14 +590,30 @@ mod tests {
|
||||
#[test]
|
||||
fn test_2vs2_weighted() {
|
||||
let t_a = vec![
|
||||
Player::new(Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, 0.0),
|
||||
Player::new(Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, 0.0),
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(0.0),
|
||||
),
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(0.0),
|
||||
),
|
||||
];
|
||||
let w_a = vec![0.4, 0.8];
|
||||
|
||||
let t_b = vec![
|
||||
Player::new(Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, 0.0),
|
||||
Player::new(Gaussian::from_ms(25.0, 25.0 / 3.0), 25.0 / 6.0, 0.0),
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(0.0),
|
||||
),
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
ConstantDrift(0.0),
|
||||
),
|
||||
];
|
||||
let w_b = vec![0.9, 0.6];
|
||||
|
||||
@@ -628,7 +705,7 @@ mod tests {
|
||||
vec![Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
0.0,
|
||||
ConstantDrift(0.0),
|
||||
)],
|
||||
],
|
||||
&[1.0, 0.0],
|
||||
|
||||
@@ -40,10 +40,10 @@ impl Gaussian {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn forget(&self, gamma: f64, t: i64) -> Self {
|
||||
pub(crate) fn forget(&self, variance_delta: f64) -> Self {
|
||||
Self {
|
||||
mu: self.mu,
|
||||
sigma: (self.sigma.powi(2) + t as f64 * gamma.powi(2)).sqrt(),
|
||||
sigma: (self.sigma.powi(2) + variance_delta).sqrt(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,27 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
BETA, GAMMA, Index, MU, N_INF, P_DRAW, SIGMA,
|
||||
agent::{self, Agent},
|
||||
batch::{self, Batch},
|
||||
drift::{ConstantDrift, Drift},
|
||||
gaussian::Gaussian,
|
||||
player::Player,
|
||||
sort_time, tuple_gt, tuple_max, Index, BETA, GAMMA, MU, P_DRAW, SIGMA,
|
||||
sort_time, tuple_gt, tuple_max,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct HistoryBuilder {
|
||||
pub struct HistoryBuilder<D: Drift = ConstantDrift> {
|
||||
time: bool,
|
||||
mu: f64,
|
||||
sigma: f64,
|
||||
beta: f64,
|
||||
gamma: f64,
|
||||
drift: D,
|
||||
p_draw: f64,
|
||||
online: bool,
|
||||
}
|
||||
|
||||
impl HistoryBuilder {
|
||||
impl<D: Drift> HistoryBuilder<D> {
|
||||
pub fn time(mut self, time: bool) -> Self {
|
||||
self.time = time;
|
||||
self
|
||||
@@ -40,9 +42,16 @@ impl HistoryBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn gamma(mut self, gamma: f64) -> Self {
|
||||
self.gamma = gamma;
|
||||
self
|
||||
pub fn drift<D2: Drift>(self, drift: D2) -> HistoryBuilder<D2> {
|
||||
HistoryBuilder {
|
||||
drift,
|
||||
time: self.time,
|
||||
mu: self.mu,
|
||||
sigma: self.sigma,
|
||||
beta: self.beta,
|
||||
p_draw: self.p_draw,
|
||||
online: self.online,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn p_draw(mut self, p_draw: f64) -> Self {
|
||||
@@ -55,7 +64,7 @@ impl HistoryBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> History {
|
||||
pub fn build(self) -> History<D> {
|
||||
History {
|
||||
size: 0,
|
||||
batches: Vec::new(),
|
||||
@@ -64,41 +73,48 @@ impl HistoryBuilder {
|
||||
mu: self.mu,
|
||||
sigma: self.sigma,
|
||||
beta: self.beta,
|
||||
gamma: self.gamma,
|
||||
drift: self.drift,
|
||||
p_draw: self.p_draw,
|
||||
online: self.online,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HistoryBuilder {
|
||||
impl HistoryBuilder<ConstantDrift> {
|
||||
pub fn gamma(mut self, gamma: f64) -> Self {
|
||||
self.drift = ConstantDrift(gamma);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HistoryBuilder<ConstantDrift> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
time: true,
|
||||
mu: MU,
|
||||
sigma: SIGMA,
|
||||
beta: BETA,
|
||||
gamma: GAMMA,
|
||||
drift: ConstantDrift(GAMMA),
|
||||
p_draw: P_DRAW,
|
||||
online: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct History {
|
||||
pub struct History<D: Drift = ConstantDrift> {
|
||||
size: usize,
|
||||
pub(crate) batches: Vec<Batch>,
|
||||
agents: HashMap<Index, Agent>,
|
||||
agents: HashMap<Index, Agent<D>>,
|
||||
time: bool,
|
||||
mu: f64,
|
||||
sigma: f64,
|
||||
beta: f64,
|
||||
gamma: f64,
|
||||
drift: D,
|
||||
p_draw: f64,
|
||||
online: bool,
|
||||
}
|
||||
|
||||
impl Default for History {
|
||||
impl Default for History<ConstantDrift> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
size: 0,
|
||||
@@ -108,18 +124,20 @@ impl Default for History {
|
||||
mu: MU,
|
||||
sigma: SIGMA,
|
||||
beta: BETA,
|
||||
gamma: GAMMA,
|
||||
drift: ConstantDrift(GAMMA),
|
||||
p_draw: P_DRAW,
|
||||
online: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl History {
|
||||
pub fn builder() -> HistoryBuilder {
|
||||
impl History<ConstantDrift> {
|
||||
pub fn builder() -> HistoryBuilder<ConstantDrift> {
|
||||
HistoryBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Drift> History<D> {
|
||||
fn iteration(&mut self) -> (f64, f64) {
|
||||
let mut step = (0.0, 0.0);
|
||||
|
||||
@@ -247,7 +265,7 @@ impl History {
|
||||
results: Vec<Vec<f64>>,
|
||||
times: Vec<i64>,
|
||||
weights: Vec<Vec<Vec<f64>>>,
|
||||
mut priors: HashMap<Index, Player>,
|
||||
mut priors: HashMap<Index, Player<D>>,
|
||||
) {
|
||||
assert!(times.is_empty() || self.time, "length(times)>0 but !h.time");
|
||||
assert!(
|
||||
@@ -286,10 +304,11 @@ impl History {
|
||||
Player::new(
|
||||
Gaussian::from_ms(self.mu, self.sigma),
|
||||
self.beta,
|
||||
self.gamma,
|
||||
self.drift,
|
||||
)
|
||||
}),
|
||||
..Default::default()
|
||||
message: N_INF,
|
||||
last_time: i64::MIN,
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -414,9 +433,8 @@ impl History {
|
||||
mod tests {
|
||||
use approx::assert_ulps_eq;
|
||||
|
||||
use crate::{Game, Gaussian, IndexMap, Player, EPSILON, ITERATIONS, P_DRAW};
|
||||
|
||||
use super::*;
|
||||
use crate::{ConstantDrift, EPSILON, Game, Gaussian, ITERATIONS, IndexMap, P_DRAW, Player};
|
||||
|
||||
#[test]
|
||||
fn test_init() {
|
||||
@@ -441,7 +459,7 @@ mod tests {
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
0.15 * 25.0 / 3.0,
|
||||
ConstantDrift(0.15 * 25.0 / 3.0),
|
||||
),
|
||||
);
|
||||
}
|
||||
@@ -503,7 +521,7 @@ mod tests {
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
0.15 * 25.0 / 3.0,
|
||||
ConstantDrift(0.15 * 25.0 / 3.0),
|
||||
),
|
||||
);
|
||||
}
|
||||
@@ -552,7 +570,7 @@ mod tests {
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
);
|
||||
}
|
||||
@@ -610,7 +628,7 @@ mod tests {
|
||||
Player::new(
|
||||
Gaussian::from_ms(25.0, 25.0 / 3.0),
|
||||
25.0 / 6.0,
|
||||
25.0 / 300.0,
|
||||
ConstantDrift(25.0 / 300.0),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
32
src/lib.rs
32
src/lib.rs
@@ -1,13 +1,16 @@
|
||||
use std::borrow::{Borrow, ToOwned};
|
||||
use std::cmp::Reverse;
|
||||
use std::collections::HashMap;
|
||||
use std::f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2};
|
||||
use std::hash::Hash;
|
||||
use std::{
|
||||
borrow::{Borrow, ToOwned},
|
||||
cmp::Reverse,
|
||||
collections::HashMap,
|
||||
f64::consts::{FRAC_1_SQRT_2, FRAC_2_SQRT_PI, SQRT_2},
|
||||
hash::Hash,
|
||||
};
|
||||
|
||||
pub mod agent;
|
||||
#[cfg(feature = "approx")]
|
||||
mod approx;
|
||||
pub mod batch;
|
||||
pub mod drift;
|
||||
mod game;
|
||||
pub mod gaussian;
|
||||
mod history;
|
||||
@@ -15,6 +18,7 @@ mod matrix;
|
||||
mod message;
|
||||
pub mod player;
|
||||
|
||||
pub use drift::{ConstantDrift, Drift};
|
||||
pub use game::Game;
|
||||
pub use gaussian::Gaussian;
|
||||
pub use history::History;
|
||||
@@ -82,7 +86,7 @@ where
|
||||
pub fn key(&self, idx: Index) -> Option<&K> {
|
||||
self.0
|
||||
.iter()
|
||||
.find(|(_, &value)| value == idx)
|
||||
.find(|&(_, value)| *value == idx)
|
||||
.map(|(key, _)| key)
|
||||
}
|
||||
|
||||
@@ -115,11 +119,7 @@ fn erfc(x: f64) -> f64 {
|
||||
|
||||
let r = t * (-z * z - 1.26551223 + t * h).exp();
|
||||
|
||||
if x >= 0.0 {
|
||||
r
|
||||
} else {
|
||||
2.0 - r
|
||||
}
|
||||
if x >= 0.0 { r } else { 2.0 - r }
|
||||
}
|
||||
|
||||
fn erfc_inv(mut y: f64) -> f64 {
|
||||
@@ -147,11 +147,7 @@ fn erfc_inv(mut y: f64) -> f64 {
|
||||
x += err / (FRAC_2_SQRT_PI * (-(x.powi(2))).exp() - x * err)
|
||||
}
|
||||
|
||||
if y < 1.0 {
|
||||
x
|
||||
} else {
|
||||
-x
|
||||
}
|
||||
if y < 1.0 { x } else { -x }
|
||||
}
|
||||
|
||||
fn ppf(p: f64, mu: f64, sigma: f64) -> f64 {
|
||||
@@ -239,9 +235,9 @@ pub(crate) fn sort_time(xs: &[i64], reverse: bool) -> Vec<usize> {
|
||||
let mut x = xs.iter().enumerate().collect::<Vec<_>>();
|
||||
|
||||
if reverse {
|
||||
x.sort_by_key(|(_, &x)| Reverse(x));
|
||||
x.sort_by_key(|&(_, x)| Reverse(x));
|
||||
} else {
|
||||
x.sort_by_key(|(_, &x)| x);
|
||||
x.sort_by_key(|&(_, x)| x);
|
||||
}
|
||||
|
||||
x.into_iter().map(|(i, _)| i).collect()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use crate::gaussian::Gaussian;
|
||||
use crate::N_INF;
|
||||
use crate::{N_INF, gaussian::Gaussian};
|
||||
|
||||
pub(crate) struct TeamMessage {
|
||||
pub(crate) prior: Gaussian,
|
||||
|
||||
@@ -1,35 +1,32 @@
|
||||
use crate::{gaussian::Gaussian, BETA, GAMMA};
|
||||
use crate::{
|
||||
BETA, GAMMA,
|
||||
drift::{ConstantDrift, Drift},
|
||||
gaussian::Gaussian,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Player {
|
||||
pub struct Player<D: Drift = ConstantDrift> {
|
||||
pub(crate) prior: Gaussian,
|
||||
pub(crate) beta: f64,
|
||||
pub(crate) gamma: f64,
|
||||
// pub(crate) draw: Gaussian,
|
||||
pub(crate) drift: D,
|
||||
}
|
||||
|
||||
impl Player {
|
||||
pub fn new(prior: Gaussian, beta: f64, gamma: f64) -> Self {
|
||||
Self {
|
||||
prior,
|
||||
beta,
|
||||
gamma,
|
||||
// draw: N_INF,
|
||||
}
|
||||
impl<D: Drift> Player<D> {
|
||||
pub fn new(prior: Gaussian, beta: f64, drift: D) -> Self {
|
||||
Self { prior, beta, drift }
|
||||
}
|
||||
|
||||
pub(crate) fn performance(&self) -> Gaussian {
|
||||
self.prior.forget(self.beta, 1)
|
||||
self.prior.forget(self.beta.powi(2))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Player {
|
||||
impl Default for Player<ConstantDrift> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
prior: Gaussian::default(),
|
||||
beta: BETA,
|
||||
gamma: GAMMA,
|
||||
// draw: N_INF,
|
||||
drift: ConstantDrift(GAMMA),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user