This commit is contained in:
2018-10-24 11:13:06 +02:00
parent be74c1eac7
commit 61f8a9ccce
3 changed files with 25 additions and 11 deletions

View File

@@ -101,7 +101,10 @@ impl PriorFactor {
}
pub fn start(&self, variable_arena: &mut VariableArena) {
variable_arena.get_mut(self.variable).unwrap().update_value(self.id, self.gaussian);
variable_arena
.get_mut(self.variable)
.unwrap()
.update_value(self.id, self.gaussian);
}
}
@@ -146,7 +149,10 @@ impl LikelihoodFactor {
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
variable_arena.get_mut(self.mean).unwrap().update_message(self.id, gaussian);
variable_arena
.get_mut(self.mean)
.unwrap()
.update_message(self.id, gaussian);
}
pub fn update_value(&self, variable_arena: &mut VariableArena) {
@@ -164,7 +170,10 @@ impl LikelihoodFactor {
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
variable_arena.get_mut(self.value).unwrap().update_message(self.id, gaussian);
variable_arena
.get_mut(self.value)
.unwrap()
.update_message(self.id, gaussian);
}
}
@@ -224,7 +233,10 @@ impl SumFactor {
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
variable_arena.get_mut(variable).unwrap().update_message(self.id, gaussian);
variable_arena
.get_mut(variable)
.unwrap()
.update_message(self.id, gaussian);
}
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
@@ -372,6 +384,9 @@ impl TruncateFactor {
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
variable_arena.get_mut(self.variable).unwrap().update_value(self.id, gaussian);
variable_arena
.get_mut(self.variable)
.unwrap()
.update_value(self.id, gaussian);
}
}

View File

@@ -1,5 +1,5 @@
extern crate statrs;
extern crate noisy_float;
extern crate statrs;
mod factor_graph;
mod gaussian;
@@ -28,7 +28,6 @@ const DRAW_PROBABILITY: f64 = 0.10;
/// A basis to check reliability of the result.
const DELTA: f64 = 0.0001;
fn draw_margin(p: f64, beta: f64, total_players: f64) -> f64 {
math::icdf((p + 1.0) / 2.0) * total_players.sqrt() * beta
}
@@ -78,7 +77,8 @@ pub fn rate(rating_groups: &[&[Gaussian]]) {
for (i, rating) in flatten_ratings.iter().enumerate() {
let variable = ss[i];
let gaussian = Gaussian::from_mu_sigma(rating.mu(), (rating.sigma().powi(2) + tau_sqr).sqrt());
let gaussian =
Gaussian::from_mu_sigma(rating.mu(), (rating.sigma().powi(2) + tau_sqr).sqrt());
skill.push(PriorFactor::new(
&mut variable_arena,

View File

@@ -1,4 +1,4 @@
use statrs::distribution::{Normal, Univariate, Continuous};
use statrs::distribution::{Continuous, Normal, Univariate};
const S2PI: f64 = 2.50662827463100050242E0;
@@ -137,7 +137,6 @@ pub fn icdf(x: f64) -> f64 {
ndtri(x)
}
#[cfg(test)]
mod tests {
use super::*;