This commit is contained in:
2018-10-25 21:41:23 +02:00
parent 4fc27841e9
commit 2367a8e47f
3 changed files with 37 additions and 33 deletions

View File

@@ -99,7 +99,10 @@ impl PriorFactor {
} }
pub fn start(&self, variable_arena: &mut VariableArena) { pub fn start(&self, variable_arena: &mut VariableArena) {
debug!("Prior::down var={:?}, value={:?}", self.variable.index, self.gaussian); debug!(
"Prior::down var={:?}, value={:?}",
self.variable.index, self.gaussian
);
variable_arena variable_arena
.get_mut(self.variable) .get_mut(self.variable)
@@ -143,7 +146,10 @@ impl LikelihoodFactor {
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau())); let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
debug!("Likelihood::up var={:?}, value={:?}", self.mean.index, gaussian); debug!(
"Likelihood::up var={:?}, value={:?}",
self.mean.index, gaussian
);
variable_arena variable_arena
.get_mut(self.mean) .get_mut(self.mean)
@@ -160,7 +166,10 @@ impl LikelihoodFactor {
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau())); let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
debug!("Likelihood::down var={:?}, value={:?}", self.value.index, gaussian); debug!(
"Likelihood::down var={:?}, value={:?}",
self.value.index, gaussian
);
variable_arena variable_arena
.get_mut(self.value) .get_mut(self.value)
@@ -359,7 +368,10 @@ impl TruncateFactor {
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w); let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
debug!("Trunc::up var={:?}, value={:?}", self.variable.index, gaussian); debug!(
"Trunc::up var={:?}, value={:?}",
self.variable.index, gaussian
);
variable_arena variable_arena
.get_mut(self.variable) .get_mut(self.variable)

View File

@@ -13,7 +13,6 @@ impl fmt::Debug for Gaussian {
} }
} }
impl Gaussian { impl Gaussian {
pub fn from_pi_tau(pi: f64, tau: f64) -> Gaussian { pub fn from_pi_tau(pi: f64, tau: f64) -> Gaussian {
Gaussian { pi, tau } Gaussian { pi, tau }

View File

@@ -92,16 +92,20 @@ where
let rating_count = ratings.len(); let rating_count = ratings.len();
let team_count = ranks.len(); let team_count = ranks.len();
let rating_vars = (0..rating_count).map(|_| variable_arena.create()).collect::<Vec<_>>(); let rating_vars = (0..rating_count)
.map(|_| variable_arena.create())
.collect::<Vec<_>>();
let perf_vars = ratings let perf_vars = ratings
.iter() .iter()
.map(|(_, team)| { .map(|(_, team)| (variable_arena.create(), *team))
(variable_arena.create(), *team)
})
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let team_perf_vars = (0..team_count).map(|_| variable_arena.create()).collect::<Vec<_>>(); let team_perf_vars = (0..team_count)
let team_diff_vars = (0..team_count - 1).map(|_| variable_arena.create()).collect::<Vec<_>>(); .map(|_| variable_arena.create())
.collect::<Vec<_>>();
let team_diff_vars = (0..team_count - 1)
.map(|_| variable_arena.create())
.collect::<Vec<_>>();
let mut factor_id = 0; let mut factor_id = 0;
@@ -109,16 +113,12 @@ where
.iter() .iter()
.zip(ratings.iter().map(|(rating, _)| rating)) .zip(ratings.iter().map(|(rating, _)| rating))
.map(|(rating_var, rating)| { .map(|(rating_var, rating)| {
let gaussian = Gaussian::from_mu_sigma(rating.mu(), (rating.sigma().powi(2) + tau_sqr).sqrt()); let gaussian =
Gaussian::from_mu_sigma(rating.mu(), (rating.sigma().powi(2) + tau_sqr).sqrt());
factor_id += 1; factor_id += 1;
PriorFactor::new( PriorFactor::new(&mut variable_arena, factor_id, *rating_var, gaussian)
&mut variable_arena,
factor_id,
*rating_var,
gaussian,
)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -128,13 +128,7 @@ where
.map(|(rating_var, perf)| { .map(|(rating_var, perf)| {
factor_id += 1; factor_id += 1;
LikelihoodFactor::new( LikelihoodFactor::new(&mut variable_arena, factor_id, *rating_var, *perf, beta_sqr)
&mut variable_arena,
factor_id,
*rating_var,
*perf,
beta_sqr,
)
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -175,7 +169,6 @@ where
teams.to_vec(), teams.to_vec(),
vec![1.0, -1.0], vec![1.0, -1.0],
) )
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -197,7 +190,6 @@ where
draw_margin(DRAW_PROBABILITY, BETA, player_count as f64), draw_margin(DRAW_PROBABILITY, BETA, player_count as f64),
ranks[i] == ranks[i + 1], ranks[i] == ranks[i + 1],
) )
}) })
.collect::<Vec<_>>();; .collect::<Vec<_>>();;
@@ -239,11 +231,9 @@ where
rating_vars rating_vars
.iter() .iter()
.map(|variable| variable_arena.get(*variable).unwrap().get_value()) .map(|variable| variable_arena.get(*variable).unwrap().get_value())
.map(|value| { .map(|value| Rating {
Rating {
mu: value.mu(), mu: value.mu(),
sigma: value.sigma(), sigma: value.sigma(),
}
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
} }
@@ -372,7 +362,10 @@ mod tests {
Rating::new(16.79337409436942, 6.348053083319977), Rating::new(16.79337409436942, 6.348053083319977),
]; ];
let ratings = rate(&[(alice, 0), (bob, 1), (chris, 2), (darren, 3)], &[0, 1, 2, 3]); let ratings = rate(
&[(alice, 0), (bob, 1), (chris, 2), (darren, 3)],
&[0, 1, 2, 3],
);
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
assert_relative_eq!(rating, expected, epsilon = EPSILON); assert_relative_eq!(rating, expected, epsilon = EPSILON);