Refactor factor graph

This commit is contained in:
2023-01-11 23:01:29 +01:00
parent ab0a7cb447
commit 8236496c4b
7 changed files with 695 additions and 113 deletions

View File

@@ -8,3 +8,4 @@ Implementation of TrueSkill™ in Rust.
* [JesseBuesking/trueskill](https://github.com/JesseBuesking/trueskill) * [JesseBuesking/trueskill](https://github.com/JesseBuesking/trueskill)
* [sublee/trueskill](https://github.com/sublee/trueskill) * [sublee/trueskill](https://github.com/sublee/trueskill)
* [moserware/Skills](https://github.com/moserware/Skills) * [moserware/Skills](https://github.com/moserware/Skills)
* [TrueSkill(TM): A Bayesian Skill Rating System](https://www.microsoft.com/en-us/research/wp-content/uploads/2007/01/NIPS2006_0688.pdf)

View File

@@ -5,6 +5,7 @@ use std::ops;
use log::*; use log::*;
use crate::gaussian::Gaussian; use crate::gaussian::Gaussian;
use crate::graph::{MessageArena, MessageId};
use crate::math; use crate::math;
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -90,6 +91,10 @@ impl Variable {
self.value self.value
} }
pub fn get_value_mut(&mut self) -> &mut Gaussian {
&mut self.value
}
pub fn update_message(&mut self, factor: usize, message: Gaussian) { pub fn update_message(&mut self, factor: usize, message: Gaussian) {
let old = self.factors[&factor]; let old = self.factors[&factor];
@@ -105,19 +110,18 @@ impl Variable {
} }
pub struct PriorFactor { pub struct PriorFactor {
id: usize, id: MessageId,
variable: VariableId, variable: VariableId,
gaussian: Gaussian, gaussian: Gaussian,
} }
impl PriorFactor { impl PriorFactor {
pub fn new( pub fn new(
variable_arena: &mut VariableArena, message_arena: &mut MessageArena,
id: usize,
variable: VariableId, variable: VariableId,
gaussian: Gaussian, gaussian: Gaussian,
) -> PriorFactor { ) -> PriorFactor {
variable_arena[variable].attach_factor(id); let id = message_arena.create();
PriorFactor { PriorFactor {
id, id,
@@ -126,18 +130,19 @@ impl PriorFactor {
} }
} }
pub fn start(&self, variable_arena: &mut VariableArena) { pub fn start(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
debug!( let old = message_arena[self.id];
"Prior::down var={:?}, value={:?}", let value = variable_arena[self.variable].get_value();
self.variable.index, self.gaussian
);
variable_arena[self.variable].update_value(self.id, self.gaussian); message_arena[self.id] = self.gaussian * old / value;
*variable_arena[self.variable].get_value_mut() = self.gaussian;
} }
} }
pub struct LikelihoodFactor { pub struct LikelihoodFactor {
id: usize, mean_msg: MessageId,
value_msg: MessageId,
mean: VariableId, mean: VariableId,
value: VariableId, value: VariableId,
variance: f64, variance: f64,
@@ -145,62 +150,60 @@ pub struct LikelihoodFactor {
impl LikelihoodFactor { impl LikelihoodFactor {
pub fn new( pub fn new(
variable_arena: &mut VariableArena, message_arena: &mut MessageArena,
id: usize,
mean: VariableId, mean: VariableId,
value: VariableId, value: VariableId,
variance: f64, variance: f64,
) -> LikelihoodFactor { ) -> LikelihoodFactor {
variable_arena[mean].attach_factor(id);
variable_arena[value].attach_factor(id);
LikelihoodFactor { LikelihoodFactor {
id, mean_msg: message_arena.create(),
value_msg: message_arena.create(),
mean, mean,
value, value,
variance, variance,
} }
} }
pub fn update_mean(&self, variable_arena: &mut VariableArena) { pub fn update_mean(
let (x, fx) = { &self,
let variable = &variable_arena[self.value]; variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
(variable.get_value(), variable.get_message(self.id)) ) {
}; let x = variable_arena[self.value].get_value();
let fx = message_arena[self.value_msg];
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau())); let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
debug!( let old = message_arena[self.mean_msg];
"Likelihood::up var={:?}, value={:?}",
self.mean.index, gaussian
);
variable_arena[self.mean].update_message(self.id, gaussian); message_arena[self.mean_msg] = gaussian;
*variable_arena[self.mean].get_value_mut() =
variable_arena[self.mean].get_value() / old * gaussian;
} }
pub fn update_value(&self, variable_arena: &mut VariableArena) { pub fn update_value(
let (y, fy) = { &self,
let variable = &variable_arena[self.mean]; variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
(variable.get_value(), variable.get_message(self.id)) ) {
}; let y = variable_arena[self.mean].get_value();
let fy = message_arena[self.mean_msg];
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau())); let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
debug!( let old = message_arena[self.value_msg];
"Likelihood::down var={:?}, value={:?}",
self.value.index, gaussian
);
variable_arena[self.value].update_message(self.id, gaussian); message_arena[self.value_msg] = gaussian;
*variable_arena[self.value].get_value_mut() =
variable_arena[self.value].get_value() / old * gaussian;
} }
} }
pub struct SumFactor { pub struct SumFactor {
id: usize, sum_msg: MessageId,
terms_msg: Vec<MessageId>,
sum: VariableId, sum: VariableId,
terms: Vec<VariableId>, terms: Vec<VariableId>,
coeffs: Vec<f64>, coeffs: Vec<f64>,
@@ -208,20 +211,17 @@ pub struct SumFactor {
impl SumFactor { impl SumFactor {
pub fn new( pub fn new(
variable_arena: &mut VariableArena, message_arena: &mut MessageArena,
id: usize,
sum: VariableId, sum: VariableId,
terms: Vec<VariableId>, terms: Vec<VariableId>,
coeffs: Vec<f64>, coeffs: Vec<f64>,
) -> SumFactor { ) -> SumFactor {
variable_arena[sum].attach_factor(id);
for term in &terms {
variable_arena[*term].attach_factor(id);
}
SumFactor { SumFactor {
id, sum_msg: message_arena.create(),
terms_msg: terms
.iter()
.map(|_| message_arena.create())
.collect::<Vec<_>>(),
sum, sum,
terms, terms,
coeffs, coeffs,
@@ -231,7 +231,9 @@ impl SumFactor {
fn internal_update( fn internal_update(
&self, &self,
variable_arena: &mut VariableArena, variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
variable: VariableId, variable: VariableId,
message: MessageId,
y: &[Gaussian], y: &[Gaussian],
fy: &[Gaussian], fy: &[Gaussian],
a: &[f64], a: &[f64],
@@ -259,25 +261,44 @@ impl SumFactor {
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian); debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
} }
variable_arena[variable].update_message(self.id, gaussian); let old = message_arena[message];
message_arena[message] = gaussian;
*variable_arena[variable].get_value_mut() =
variable_arena[variable].get_value() / old * gaussian;
} }
pub fn update_sum(&self, variable_arena: &mut VariableArena) { pub fn update_sum(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
let (y, fy): (Vec<_>, Vec<_>) = self let (y, fy): (Vec<_>, Vec<_>) = self
.terms .terms
.iter() .iter()
.map(|term| { .zip(self.terms_msg.iter())
.map(|(term, msg)| {
let variable = &variable_arena[*term]; let variable = &variable_arena[*term];
(variable.get_value(), variable.get_message(self.id)) (variable.get_value(), message_arena[*msg])
}) })
.unzip(); .unzip();
self.internal_update(variable_arena, self.sum, &y, &fy, &self.coeffs); self.internal_update(
variable_arena,
message_arena,
self.sum,
self.sum_msg,
&y,
&fy,
&self.coeffs,
);
} }
pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) { pub fn update_term(
&self,
variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
index: usize,
) {
let idx_term = self.terms[index]; let idx_term = self.terms[index];
let idx_term_msg = self.terms_msg[index];
let idx_coeff = self.coeffs[index]; let idx_coeff = self.coeffs[index];
let a = self let a = self
@@ -296,21 +317,39 @@ impl SumFactor {
let (y, fy): (Vec<_>, Vec<_>) = self let (y, fy): (Vec<_>, Vec<_>) = self
.terms .terms
.iter() .iter()
.zip(self.terms_msg.iter())
.enumerate() .enumerate()
.map(|(i, term)| { .map(|(i, (term, msg))| {
let variable_id = if i == index { self.sum } else { *term }; if i == index {
let variable = &variable_arena[variable_id]; let variable = &variable_arena[self.sum];
(variable.get_value(), variable.get_message(self.id)) (variable.get_value(), message_arena[self.sum_msg])
} else {
let variable = &variable_arena[*term];
(variable.get_value(), message_arena[*msg])
}
}) })
.unzip(); .unzip();
self.internal_update(variable_arena, idx_term, &y, &fy, &a); self.internal_update(
variable_arena,
message_arena,
idx_term,
idx_term_msg,
&y,
&fy,
&a,
);
} }
pub fn update_all_terms(&self, variable_arena: &mut VariableArena) { pub fn update_all_terms(
&self,
variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
) {
for i in 0..self.terms.len() { for i in 0..self.terms.len() {
self.update_term(variable_arena, i); self.update_term(variable_arena, message_arena, i);
} }
} }
} }
@@ -338,7 +377,7 @@ fn w_draw(t: f64, e: f64) -> f64 {
} }
pub struct TruncateFactor { pub struct TruncateFactor {
id: usize, variable_msg: MessageId,
variable: VariableId, variable: VariableId,
epsilon: f64, epsilon: f64,
draw: bool, draw: bool,
@@ -346,28 +385,26 @@ pub struct TruncateFactor {
impl TruncateFactor { impl TruncateFactor {
pub fn new( pub fn new(
variable_arena: &mut VariableArena, message_arena: &mut MessageArena,
id: usize,
variable: VariableId, variable: VariableId,
epsilon: f64, epsilon: f64,
draw: bool, draw: bool,
) -> TruncateFactor { ) -> TruncateFactor {
variable_arena[variable].attach_factor(id);
TruncateFactor { TruncateFactor {
id, variable_msg: message_arena.create(),
variable, variable,
epsilon, epsilon,
draw, draw,
} }
} }
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 { pub fn update(
let (x, fx) = { &self,
let variable = &variable_arena[self.variable]; variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
(variable.get_value(), variable.get_message(self.id)) ) -> f64 {
}; let x = variable_arena[self.variable].get_value();
let fx = message_arena[self.variable_msg];
let c = x.pi() - fx.pi(); let c = x.pi() - fx.pi();
let d = x.tau() - fx.tau(); let d = x.tau() - fx.tau();
@@ -387,11 +424,28 @@ impl TruncateFactor {
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w); let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
debug!( let old = message_arena[self.variable_msg];
"Trunc::up var={:?}, value={:?}", let value = variable_arena[self.variable].get_value();
self.variable.index, gaussian
);
variable_arena[self.variable].update_value(self.id, gaussian) message_arena[self.variable_msg] = gaussian * old / value;
let pi_delta = (value.pi() - gaussian.pi()).abs();
let delta = if !pi_delta.is_finite() {
0.0
} else {
let pi_delta = pi_delta.sqrt();
let tau_delta = (value.tau() - gaussian.tau()).abs();
if pi_delta > tau_delta {
pi_delta
} else {
tau_delta
}
};
*variable_arena[self.variable].get_value_mut() = gaussian;
delta
} }
} }

420
src/graph.rs Normal file
View File

@@ -0,0 +1,420 @@
mod message;
mod variable;
pub use message::*;
pub use variable::*;
/*
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FactorId(usize);
pub struct Graph {
variables: VariableArena,
factors: Vec<Factor>,
messages: MessageArena,
}
impl Graph {
#[inline(always)]
pub fn add_variable(&mut self) -> VariableId {
self.variables.create()
}
pub fn add_factor(&mut self, factor: Factor) -> FactorId {
let idx = self.factors.len();
self.factors.push(factor);
FactorId(idx)
}
}
pub enum Factor {
Prior { gaussian: Gaussian },
}
*/
// FACTORS
/*
pub struct PriorFactor {
id: usize,
variable: VariableId,
gaussian: Gaussian,
}
impl PriorFactor {
pub fn new(
variable_arena: &mut VariableArena,
id: usize,
variable: VariableId,
gaussian: Gaussian,
) -> PriorFactor {
variable_arena[variable].attach_factor(id);
PriorFactor {
id,
variable,
gaussian,
}
}
pub fn start(&self, variable_arena: &mut VariableArena) {
debug!(
"Prior::down var={:?}, value={:?}",
self.variable.index, self.gaussian
);
variable_arena[self.variable].update_value(self.id, self.gaussian);
}
}
*/
/*
pub struct Variable {
value: Gaussian,
factors: HashMap<usize, Gaussian>,
}
impl Variable {
pub fn attach_factor(&mut self, factor: usize) {
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
}
pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 {
let old = self.factors[&factor];
self.factors.insert(factor, value * old / self.value);
let pi_delta = (self.value.pi() - value.pi()).abs();
let delta = if !pi_delta.is_finite() {
0.0
} else {
let pi_delta = pi_delta.sqrt();
let tau_delta = (self.value.tau() - value.tau()).abs();
if pi_delta > tau_delta {
pi_delta
} else {
tau_delta
}
};
self.value = value;
debug!("Variable::value old={:?}, new={:?}", old, value);
delta
}
pub fn get_value(&self) -> Gaussian {
self.value
}
pub fn get_message(&self, factor: usize) -> Gaussian {
self.factors[&factor]
}
pub fn update_message(&mut self, factor: usize, message: Gaussian) {
let old = self.factors[&factor];
self.factors.insert(factor, message);
self.value = self.value / old * message;
debug!("Variable::message old={:?}, new={:?}", old, message);
}
}
*/
/*
pub struct PriorFactor {
id: usize,
variable: VariableId,
gaussian: Gaussian,
}
impl PriorFactor {
pub fn new(
variable_arena: &mut VariableArena,
id: usize,
variable: VariableId,
gaussian: Gaussian,
) -> PriorFactor {
variable_arena[variable].attach_factor(id);
PriorFactor {
id,
variable,
gaussian,
}
}
pub fn start(&self, variable_arena: &mut VariableArena) {
debug!(
"Prior::down var={:?}, value={:?}",
self.variable.index, self.gaussian
);
variable_arena[self.variable].update_value(self.id, self.gaussian);
}
}
pub struct LikelihoodFactor {
id: usize,
mean: VariableId,
value: VariableId,
variance: f64,
}
impl LikelihoodFactor {
pub fn new(
variable_arena: &mut VariableArena,
id: usize,
mean: VariableId,
value: VariableId,
variance: f64,
) -> LikelihoodFactor {
variable_arena[mean].attach_factor(id);
variable_arena[value].attach_factor(id);
LikelihoodFactor {
id,
mean,
value,
variance,
}
}
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
let (x, fx) = {
let variable = &variable_arena[self.value];
(variable.get_value(), variable.get_message(self.id))
};
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
debug!(
"Likelihood::up var={:?}, value={:?}",
self.mean.index, gaussian
);
variable_arena[self.mean].update_message(self.id, gaussian);
}
pub fn update_value(&self, variable_arena: &mut VariableArena) {
let (y, fy) = {
let variable = &variable_arena[self.mean];
(variable.get_value(), variable.get_message(self.id))
};
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
debug!(
"Likelihood::down var={:?}, value={:?}",
self.value.index, gaussian
);
variable_arena[self.value].update_message(self.id, gaussian);
}
}
pub struct SumFactor {
id: usize,
sum: VariableId,
terms: Vec<VariableId>,
coeffs: Vec<f64>,
}
impl SumFactor {
pub fn new(
variable_arena: &mut VariableArena,
id: usize,
sum: VariableId,
terms: Vec<VariableId>,
coeffs: Vec<f64>,
) -> SumFactor {
variable_arena[sum].attach_factor(id);
for term in &terms {
variable_arena[*term].attach_factor(id);
}
SumFactor {
id,
sum,
terms,
coeffs,
}
}
fn internal_update(
&self,
variable_arena: &mut VariableArena,
variable: VariableId,
y: &[Gaussian],
fy: &[Gaussian],
a: &[f64],
) {
let (sum_pi, sum_tau) =
a.iter()
.zip(y.iter().zip(fy.iter()))
.fold((0.0, 0.0f64), |(pi, tau), (a, (y, fy))| {
let x = *y / *fy;
let new_pi = a.powi(2) / x.pi();
let new_tau = a * x.mu();
(pi + new_pi, tau + new_tau)
});
let new_pi = 1.0 / sum_pi;
let new_tau = new_pi * sum_tau;
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
if variable == self.sum {
debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian);
} else {
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
}
variable_arena[variable].update_message(self.id, gaussian);
}
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
let (y, fy): (Vec<_>, Vec<_>) = self
.terms
.iter()
.map(|term| {
let variable = &variable_arena[*term];
(variable.get_value(), variable.get_message(self.id))
})
.unzip();
self.internal_update(variable_arena, self.sum, &y, &fy, &self.coeffs);
}
pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) {
let idx_term = self.terms[index];
let idx_coeff = self.coeffs[index];
let a = self
.coeffs
.iter()
.enumerate()
.map(|(i, coeff)| {
if i == index {
1.0 / idx_coeff
} else {
-coeff / idx_coeff
}
})
.collect::<Vec<_>>();
let (y, fy): (Vec<_>, Vec<_>) = self
.terms
.iter()
.enumerate()
.map(|(i, term)| {
let variable_id = if i == index { self.sum } else { *term };
let variable = &variable_arena[variable_id];
(variable.get_value(), variable.get_message(self.id))
})
.unzip();
self.internal_update(variable_arena, idx_term, &y, &fy, &a);
}
pub fn update_all_terms(&self, variable_arena: &mut VariableArena) {
for i in 0..self.terms.len() {
self.update_term(variable_arena, i);
}
}
}
fn v_win(t: f64, e: f64) -> f64 {
math::pdf(t - e) / math::cdf(t - e)
}
fn w_win(t: f64, e: f64) -> f64 {
let vwin = v_win(t, e);
vwin * (vwin + t - e)
}
fn v_draw(t: f64, e: f64) -> f64 {
(math::pdf(-e - t) - math::pdf(e - t)) / (math::cdf(e - t) - math::cdf(-e - t))
}
fn w_draw(t: f64, e: f64) -> f64 {
let vdraw = v_draw(t, e);
let n = (vdraw * vdraw) + ((e - t) * math::pdf(e - t) + (e + t) * math::pdf(e + t));
let d = math::cdf(e - t) - math::cdf(-e - t);
n / d
}
pub struct TruncateFactor {
id: usize,
variable: VariableId,
epsilon: f64,
draw: bool,
}
impl TruncateFactor {
pub fn new(
variable_arena: &mut VariableArena,
id: usize,
variable: VariableId,
epsilon: f64,
draw: bool,
) -> TruncateFactor {
variable_arena[variable].attach_factor(id);
TruncateFactor {
id,
variable,
epsilon,
draw,
}
}
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
let (x, fx) = {
let variable = &variable_arena[self.variable];
(variable.get_value(), variable.get_message(self.id))
};
let c = x.pi() - fx.pi();
let d = x.tau() - fx.tau();
let sqrt_c = c.sqrt();
let t = d / sqrt_c;
let e = self.epsilon * sqrt_c;
let (v, w) = if self.draw {
(v_draw(t, e), w_draw(t, e))
} else {
(v_win(t, e), w_win(t, e))
};
let m_w = 1.0 - w;
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
debug!(
"Trunc::up var={:?}, value={:?}",
self.variable.index, gaussian
);
variable_arena[self.variable].update_value(self.id, gaussian)
}
}
*/

42
src/graph/distribution.rs Normal file
View File

@@ -0,0 +1,42 @@
use std::marker::PhantomData;
use std::ops;
use crate::gaussian::Gaussian;
pub struct DistributionArena<T> {
arena: Vec<Gaussian>,
_phantom: PhantomData<T>,
}
impl<T> DistributionArena<T>
where
T: From<usize>,
{
pub fn create(&mut self) -> T {
let idx = self.arena.len();
self.arena.push(Gaussian::from_pi_tau(0.0, 0.0));
idx.into()
}
}
impl<T> ops::Index<T> for DistributionArena<T>
where
T: Into<usize>,
{
type Output = Gaussian;
fn index(&self, id: T) -> &Self::Output {
&self.arena[id.into()]
}
}
impl<T> ops::IndexMut<T> for DistributionArena<T>
where
T: Into<usize>,
{
fn index_mut(&mut self, id: T) -> &mut Self::Output {
&mut self.arena[id.into()]
}
}

41
src/graph/message.rs Normal file
View File

@@ -0,0 +1,41 @@
use std::ops;
use crate::gaussian::Gaussian;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct MessageId(usize);
pub struct MessageArena {
arena: Vec<Gaussian>,
}
impl MessageArena {
pub fn new() -> Self {
Self { arena: Vec::new() }
}
#[inline(always)]
pub fn create(&mut self) -> MessageId {
let idx = self.arena.len();
self.arena.push(Gaussian::from_pi_tau(0.0, 0.0));
MessageId(idx)
}
}
impl ops::Index<MessageId> for MessageArena {
type Output = Gaussian;
#[inline(always)]
fn index(&self, id: MessageId) -> &Self::Output {
&self.arena[id.0]
}
}
impl ops::IndexMut<MessageId> for MessageArena {
#[inline(always)]
fn index_mut(&mut self, id: MessageId) -> &mut Self::Output {
&mut self.arena[id.0]
}
}

41
src/graph/variable.rs Normal file
View File

@@ -0,0 +1,41 @@
use std::ops;
use crate::gaussian::Gaussian;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct VariableId(usize);
pub struct VariableArena {
arena: Vec<Gaussian>,
}
impl VariableArena {
pub fn new() -> Self {
Self { arena: Vec::new() }
}
#[inline(always)]
pub fn create(&mut self) -> VariableId {
let idx = self.arena.len();
self.arena.push(Gaussian::from_pi_tau(0.0, 0.0));
VariableId(idx)
}
}
impl ops::Index<VariableId> for VariableArena {
type Output = Gaussian;
#[inline(always)]
fn index(&self, id: VariableId) -> &Self::Output {
&self.arena[id.0]
}
}
impl ops::IndexMut<VariableId> for VariableArena {
#[inline(always)]
fn index_mut(&mut self, id: VariableId) -> &mut Self::Output {
&mut self.arena[id.0]
}
}

View File

@@ -2,11 +2,13 @@
mod factor_graph; mod factor_graph;
mod gaussian; mod gaussian;
mod graph;
mod math; mod math;
mod matrix; mod matrix;
use crate::factor_graph::*; use crate::factor_graph::*;
use crate::gaussian::Gaussian; use crate::gaussian::Gaussian;
use crate::graph::MessageArena;
use crate::matrix::Matrix; use crate::matrix::Matrix;
/// Default initial mean of ratings. /// Default initial mean of ratings.
@@ -128,6 +130,7 @@ impl TrueSkill {
let beta_sqr = self.beta.powi(2); let beta_sqr = self.beta.powi(2);
let mut variable_arena = VariableArena::new(); let mut variable_arena = VariableArena::new();
let mut message_arena = MessageArena::new();
let rating_count = ratings.len(); let rating_count = ratings.len();
let team_count = ranks.len(); let team_count = ranks.len();
@@ -149,8 +152,6 @@ impl TrueSkill {
.map(|_| variable_arena.create()) .map(|_| variable_arena.create())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let mut factor_id = 0;
let rating_layer = rating_vars let rating_layer = rating_vars
.iter() .iter()
.zip(ratings.iter().map(|(rating, _)| rating)) .zip(ratings.iter().map(|(rating, _)| rating))
@@ -158,9 +159,7 @@ impl TrueSkill {
let gaussian = let gaussian =
Gaussian::from_mu_sigma(rating.mu(), (rating.sigma().powi(2) + tau_sqr).sqrt()); Gaussian::from_mu_sigma(rating.mu(), (rating.sigma().powi(2) + tau_sqr).sqrt());
factor_id += 1; PriorFactor::new(&mut message_arena, *rating_var, gaussian)
PriorFactor::new(&mut variable_arena, factor_id, *rating_var, gaussian)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -168,9 +167,7 @@ impl TrueSkill {
.iter() .iter()
.zip(perf_vars.iter().map(|(variable, _)| variable)) .zip(perf_vars.iter().map(|(variable, _)| variable))
.map(|(rating_var, perf)| { .map(|(rating_var, perf)| {
factor_id += 1; LikelihoodFactor::new(&mut message_arena, *rating_var, *perf, beta_sqr)
LikelihoodFactor::new(&mut variable_arena, factor_id, *rating_var, *perf, beta_sqr)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -178,8 +175,6 @@ impl TrueSkill {
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, variable)| { .map(|(i, variable)| {
factor_id += 1;
let team = perf_vars let team = perf_vars
.iter() .iter()
.filter(|(_, team)| *team as usize == i) .filter(|(_, team)| *team as usize == i)
@@ -188,13 +183,7 @@ impl TrueSkill {
let team_count = team.len(); let team_count = team.len();
SumFactor::new( SumFactor::new(&mut message_arena, *variable, team, vec![1.0; team_count])
&mut variable_arena,
factor_id,
*variable,
team,
vec![1.0; team_count],
)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -202,11 +191,8 @@ impl TrueSkill {
.iter() .iter()
.zip(team_perf_vars.windows(2)) .zip(team_perf_vars.windows(2))
.map(|(variable, teams)| { .map(|(variable, teams)| {
factor_id += 1;
SumFactor::new( SumFactor::new(
&mut variable_arena, &mut message_arena,
factor_id,
*variable, *variable,
teams.to_vec(), teams.to_vec(),
vec![1.0, -1.0], vec![1.0, -1.0],
@@ -218,16 +204,13 @@ impl TrueSkill {
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, variable)| { .map(|(i, variable)| {
factor_id += 1;
let player_count = perf_vars let player_count = perf_vars
.iter() .iter()
.filter(|(_, team)| *team as usize == i || *team as usize == i + 1) .filter(|(_, team)| *team as usize == i || *team as usize == i + 1)
.count(); .count();
TruncateFactor::new( TruncateFactor::new(
&mut variable_arena, &mut message_arena,
factor_id,
*variable, *variable,
draw_margin(self.draw_probability, self.beta, player_count as f64), draw_margin(self.draw_probability, self.beta, player_count as f64),
ranks[i] == ranks[i + 1], ranks[i] == ranks[i + 1],
@@ -236,26 +219,26 @@ impl TrueSkill {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for factor in &rating_layer { for factor in &rating_layer {
factor.start(&mut variable_arena); factor.start(&mut variable_arena, &mut message_arena);
} }
for factor in &perf_layer { for factor in &perf_layer {
factor.update_value(&mut variable_arena); factor.update_value(&mut variable_arena, &mut message_arena);
} }
for factor in &team_perf_layer { for factor in &team_perf_layer {
factor.update_sum(&mut variable_arena); factor.update_sum(&mut variable_arena, &mut message_arena);
} }
for _ in 0..10 { for _ in 0..10 {
let mut delta = 0.0; let mut delta = 0.0;
for factor in &team_diff_layer { for factor in &team_diff_layer {
factor.update_sum(&mut variable_arena); factor.update_sum(&mut variable_arena, &mut message_arena);
} }
for factor in &trunc_layer { for factor in &trunc_layer {
let d = factor.update(&mut variable_arena); let d = factor.update(&mut variable_arena, &mut message_arena);
if d > delta { if d > delta {
delta = d; delta = d;
@@ -263,8 +246,8 @@ impl TrueSkill {
} }
for factor in &team_diff_layer { for factor in &team_diff_layer {
factor.update_term(&mut variable_arena, 0); factor.update_term(&mut variable_arena, &mut message_arena, 0);
factor.update_term(&mut variable_arena, 1); factor.update_term(&mut variable_arena, &mut message_arena, 1);
} }
if delta < min_delta { if delta < min_delta {
@@ -273,11 +256,11 @@ impl TrueSkill {
} }
for factor in &team_perf_layer { for factor in &team_perf_layer {
factor.update_all_terms(&mut variable_arena); factor.update_all_terms(&mut variable_arena, &mut message_arena);
} }
for factor in &perf_layer { for factor in &perf_layer {
factor.update_mean(&mut variable_arena); factor.update_mean(&mut variable_arena, &mut message_arena);
} }
rating_vars rating_vars