Handle case where there is no time

This commit is contained in:
2022-06-13 21:55:43 +02:00
parent 82e7b22443
commit 1b6e07225b
7 changed files with 447347 additions and 59 deletions

View File

@@ -46,19 +46,19 @@ impl Agent {
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Item {
name: String,
likelihood: Gaussian,
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Team {
items: Vec<Item>,
output: u16,
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Event {
teams: Vec<Team>,
pub evidence: f64,
@@ -91,7 +91,7 @@ fn compute_elapsed(last_time: f64, actual_time: f64) -> f64 {
}
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Batch {
pub skills: HashMap<String, Skill>,
pub events: Vec<Event>,
@@ -100,8 +100,8 @@ pub struct Batch {
}
impl Batch {
pub fn new(
composition: Vec<Vec<Vec<&str>>>,
pub fn new<C: AsRef<str> + Clone>(
composition: Vec<Vec<Vec<C>>>,
results: Vec<Vec<u16>>,
time: f64,
agents: &mut HashMap<String, Agent>,
@@ -119,9 +119,9 @@ impl Batch {
this
}
pub fn add_events(
pub fn add_events<C: AsRef<str> + Clone>(
&mut self,
composition: Vec<Vec<Vec<&str>>>,
composition: Vec<Vec<Vec<C>>>,
results: Vec<Vec<u16>>,
agents: &mut HashMap<String, Agent>,
) {
@@ -129,20 +129,21 @@ impl Batch {
.iter()
.flatten()
.flatten()
.cloned()
.map(AsRef::as_ref)
.collect::<HashSet<_>>();
for a in this_agent {
let elapsed = compute_elapsed(agents[a].last_time, self.time);
if let Some(skill) = self.skills.get_mut(a) {
skill.elapsed = elapsed;
skill.forward = agents[a].receive(elapsed);
skill.elapsed = elapsed;
} else {
self.skills.insert(
a.to_string(),
Skill {
forward: agents[a].receive(elapsed),
elapsed,
..Default::default()
},
);
@@ -156,7 +157,7 @@ impl Batch {
.map(|t| {
let items = (0..composition[e][t].len())
.map(|a| Item {
name: composition[e][t][a].to_string(),
name: composition[e][t][a].as_ref().to_string(),
likelihood: N_INF,
})
.collect::<Vec<_>>();

View File

@@ -4,8 +4,8 @@ use crate::{utils, Agent, Batch, Gaussian, Player, N_INF};
pub struct History {
size: usize,
batches: Vec<Batch>,
agents: HashMap<String, Agent>,
pub batches: Vec<Batch>,
pub agents: HashMap<String, Agent>,
mu: f64,
sigma: f64,
gamma: f64,
@@ -14,14 +14,14 @@ pub struct History {
}
impl History {
pub fn new(
composition: Vec<Vec<Vec<&str>>>,
pub fn new<C: AsRef<str> + Clone>(
composition: Vec<Vec<Vec<C>>>,
results: Vec<Vec<u16>>,
times: Vec<u64>,
priors: HashMap<String, Player>,
mu: f64,
beta: f64,
sigma: f64,
beta: f64,
gamma: f64,
p_draw: f64,
) -> Self {
@@ -29,7 +29,7 @@ impl History {
.iter()
.flatten()
.flatten()
.cloned()
.map(AsRef::as_ref)
.collect::<HashSet<_>>();
let agents = this_agent
@@ -67,20 +67,25 @@ impl History {
this
}
fn trueskill(
fn trueskill<C: AsRef<str> + Clone>(
&mut self,
composition: Vec<Vec<Vec<&str>>>,
composition: Vec<Vec<Vec<C>>>,
results: Vec<Vec<u16>>,
times: Vec<u64>,
) {
let o = utils::sortperm(&times, false);
let o = if self.time {
utils::sortperm(&times, false)
} else {
(0..composition.len()).collect::<Vec<_>>()
};
let mut i = 0;
while i < self.size {
let mut j = i + 1;
let t = times[o[i]];
let t = if self.time { times[o[i]] } else { i as u64 + 1 };
while j < self.size && times[o[j]] == t {
while self.time && j < self.size && times[o[j]] == t {
j += 1;
}
@@ -103,7 +108,7 @@ impl History {
for a in b.skills.keys() {
let agent = self.agents.get_mut(a).unwrap();
agent.last_time = t as f64;
agent.last_time = if self.time { t as f64 } else { f64::INFINITY };
agent.message = b.forward_prior_out(a);
}
@@ -443,22 +448,22 @@ mod tests {
assert_ulps_eq!(
h2.batches[2].posterior("aj").mu(),
24.99999999,
24.99866831022851,
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("aj").sigma(),
5.419212002,
5.420053708148435,
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("cj").mu(),
24.99999999,
25.000532179593538,
epsilon = 0.000001
);
assert_ulps_eq!(
h2.batches[2].posterior("cj").sigma(),
5.419212002,
5.419827012784138,
epsilon = 0.000001
);
}
@@ -510,22 +515,80 @@ mod tests {
assert_ulps_eq!(
lc["aj"][aj_e - 1].1.mu(),
24.99999999569006,
24.99866831022851,
epsilon = 0.000001
);
assert_ulps_eq!(
lc["aj"][aj_e - 1].1.sigma(),
5.419212002171145,
5.420053708148435,
epsilon = 0.000001
);
assert_ulps_eq!(
lc["cj"][cj_e - 1].1.mu(),
24.999999998686533,
25.000532179593538,
epsilon = 0.000001
);
assert_ulps_eq!(
lc["cj"][cj_e - 1].1.sigma(),
5.419212002245715,
5.419827012784138,
epsilon = 0.000001
);
}
#[test]
fn test_env_ttt() {
let composition = vec![
vec![vec!["a"], vec!["b"]],
vec![vec!["a"], vec!["c"]],
vec![vec!["b"], vec!["c"]],
];
let results = vec![vec![1, 0], vec![0, 1], vec![1, 0]];
let mut h = History::new(
composition,
results,
Vec::new(),
HashMap::new(),
25.0,
25.0 / 3.0,
25.0 / 6.0,
25.0 / 300.0,
0.0,
);
let (_step, _i) = h.convergence();
assert_eq!(h.batches[2].skills["b"].elapsed, 1.0);
assert_eq!(h.batches[2].skills["c"].elapsed, 1.0);
assert_ulps_eq!(
h.batches[0].posterior("a").mu(),
25.0002673,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("a").sigma(),
5.41938162,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("b").mu(),
24.999465,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[0].posterior("b").sigma(),
5.419425831,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[2].posterior("b").mu(),
25.00053219,
epsilon = 0.000001
);
assert_ulps_eq!(
h.batches[2].posterior("b").sigma(),
5.419696790,
epsilon = 0.000001
);
}

View File

@@ -4,37 +4,25 @@ use trueskill_tt::*;
fn main() {
let composition = vec![
vec![vec!["aj"], vec!["bj"]],
vec![vec!["bj"], vec!["cj"]],
vec![vec!["cj"], vec!["aj"]],
vec![vec!["a"], vec!["b"]],
vec![vec!["a"], vec!["c"]],
vec![vec!["b"], vec!["c"]],
];
let results = vec![vec![1, 0], vec![1, 0], vec![1, 0]];
let times = vec![1, 2, 3];
let results = vec![vec![1, 0], vec![0, 1], vec![1, 0]];
let mut priors = HashMap::new();
for k in ["aj", "bj", "cj"] {
let player = Player::new(
Gaussian::new(25.0, 25.0 / 3.0),
25.0 / 6.0,
25.0 / 300.0,
N_INF,
);
priors.insert(k.to_string(), player);
}
let mut h2 = History::new(
let mut h = History::new(
composition,
results,
times,
priors,
MU,
BETA,
SIGMA,
GAMMA,
P_DRAW,
vec![],
HashMap::new(),
25.0,
25.0 / 3.0,
25.0 / 6.0,
25.0 / 300.0,
0.0,
);
let (_step, _i) = h2.convergence();
let (_step, _i) = h.convergence();
println!("{:#?}", h.batches);
}

View File

@@ -69,7 +69,7 @@ pub(crate) fn mu_sigma(tau: f64, pi: f64) -> (f64, f64) {
}
if pi + 1e-5 < 0.0 {
panic!("sigma should be greater than 0");
panic!("pi should be greater than 0, got: {}", pi + 1e-5);
}
(0.0, f64::INFINITY)