Rustify some code.

This commit is contained in:
2018-10-24 20:15:29 +02:00
parent 77523e9db4
commit bcdabf9fbb
3 changed files with 85 additions and 118 deletions

View File

@@ -52,9 +52,7 @@ impl Variable {
pub fn update_value(&mut self, factor: usize, value: Gaussian) { pub fn update_value(&mut self, factor: usize, value: Gaussian) {
let old = self.factors[&factor]; let old = self.factors[&factor];
let intermediate = value * old; self.factors.insert(factor, value * old / self.value);
self.factors.insert(factor, intermediate / self.value);
self.value = value; self.value = value;
} }
@@ -65,12 +63,8 @@ impl Variable {
pub fn update_message(&mut self, factor: usize, message: Gaussian) { pub fn update_message(&mut self, factor: usize, message: Gaussian) {
let old = self.factors[&factor]; let old = self.factors[&factor];
let intermediate = self.value / old;
let value = intermediate * message;
self.value = value;
self.factors.insert(factor, message); self.factors.insert(factor, message);
self.value = self.value / old * message;
} }
pub fn get_message(&self, factor: usize) -> Gaussian { pub fn get_message(&self, factor: usize) -> Gaussian {
@@ -135,18 +129,15 @@ impl LikelihoodFactor {
} }
pub fn update_mean(&self, variable_arena: &mut VariableArena) { pub fn update_mean(&self, variable_arena: &mut VariableArena) {
let x = variable_arena let (x, fx) = variable_arena
.get(self.value) .get(self.value)
.map(|variable| variable.get_value()) .map(|variable| (
.unwrap(); variable.get_value(),
variable.get_message(self.id)
let fx = variable_arena ))
.get_mut(self.value)
.map(|variable| variable.get_message(self.id))
.unwrap(); .unwrap();
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau())); let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
variable_arena variable_arena
@@ -156,18 +147,15 @@ impl LikelihoodFactor {
} }
pub fn update_value(&self, variable_arena: &mut VariableArena) { pub fn update_value(&self, variable_arena: &mut VariableArena) {
let y = variable_arena let (y, fy) = variable_arena
.get(self.mean) .get(self.mean)
.map(|variable| variable.get_value()) .map(|variable| (
.unwrap(); variable.get_value(),
variable.get_message(self.id)
let fy = variable_arena ))
.get(self.mean)
.map(|variable| variable.get_message(self.id))
.unwrap(); .unwrap();
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau())); let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
variable_arena variable_arena
@@ -212,21 +200,17 @@ impl SumFactor {
variable: VariableId, variable: VariableId,
y: Vec<Gaussian>, y: Vec<Gaussian>,
fy: Vec<Gaussian>, fy: Vec<Gaussian>,
a: &Vec<f64>, a: &[f64],
) { ) {
let size = a.len(); let (sum_pi, sum_tau) = a.iter().zip(y.iter().zip(fy.iter()))
.fold((0.0, 0.0), |(pi, tau), (a, (y, fy))| {
let x = *y / *fy;
let mut sum_pi = 0.0; (
let mut sum_tau = 0.0; pi + a.powi(2) / x.pi(),
tau + a * x.tau() / x.pi()
for i in 0..size { )
let da = a[i]; });
let gy = y[i];
let gfy = fy[i];
sum_pi += da.powi(2) / (gy.pi() - gfy.pi());
sum_tau += da * (gy.tau() - gfy.tau()) / (gy.pi() - gfy.pi());
}
let new_pi = 1.0 / sum_pi; let new_pi = 1.0 / sum_pi;
let new_tau = new_pi * sum_tau; let new_tau = new_pi * sum_tau;
@@ -240,69 +224,57 @@ impl SumFactor {
} }
pub fn update_sum(&self, variable_arena: &mut VariableArena) { pub fn update_sum(&self, variable_arena: &mut VariableArena) {
let mut y = Vec::new(); let (y, fy) = self.terms
.iter()
for term in &self.terms { .map(|term| {
let value = variable_arena variable_arena
.get(*term) .get(*term)
.map(|variable| variable.get_value()) .map(|variable| (
.unwrap(); variable.get_value(),
variable.get_message(self.id)
y.push(value); ))
} .unwrap()
})
let mut fy = Vec::new(); .unzip();
for term in &self.terms {
let value = variable_arena
.get(*term)
.map(|variable| variable.get_message(self.id))
.unwrap();
fy.push(value);
}
self.internal_update(variable_arena, self.sum, y, fy, &self.coeffs); self.internal_update(variable_arena, self.sum, y, fy, &self.coeffs);
} }
pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) { pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) {
let size = self.coeffs.len(); let idx_term = self.terms[index];
let idx_coeff = self.coeffs[index]; let idx_coeff = self.coeffs[index];
let mut a = vec![0.0; size]; let a = self.coeffs
.iter()
.enumerate()
.map(|(i, coeff)| {
if i == index {
1.0 / idx_coeff
} else {
-coeff / idx_coeff
}
})
.collect::<Vec<_>>();
for i in 0..size { let (y, fy) = self.terms
if i != index { .iter()
a[i] = -self.coeffs[i] / idx_coeff; .enumerate()
} .map(|(i, term)| {
} let variable = if i == index {
self.sum
} else {
*term
};
a[index] = 1.0 / idx_coeff; variable_arena
.get(variable)
let idx_term = self.terms[index]; .map(|variable| (
variable.get_value(),
let mut y = Vec::new(); variable.get_message(self.id)
let mut fy = Vec::new(); ))
.unwrap()
let mut v = self.terms.clone(); })
.unzip();
v[index] = self.sum;
for term in &v {
let value = variable_arena
.get(*term)
.map(|variable| variable.get_value())
.unwrap();
y.push(value);
let value = variable_arena
.get(*term)
.map(|variable| variable.get_message(self.id))
.unwrap();
fy.push(value);
}
self.internal_update(variable_arena, idx_term, y, fy, &a); self.internal_update(variable_arena, idx_term, y, fy, &a);
} }
@@ -356,14 +328,12 @@ impl TruncateFactor {
} }
pub fn update(&self, variable_arena: &mut VariableArena) { pub fn update(&self, variable_arena: &mut VariableArena) {
let x = variable_arena let (x, fx) = variable_arena
.get(self.variable) .get(self.variable)
.map(|variable| variable.get_value()) .map(|variable| (
.unwrap(); variable.get_value(),
variable.get_message(self.id)
let fx = variable_arena ))
.get_mut(self.variable)
.map(|variable| variable.get_message(self.id))
.unwrap(); .unwrap();
let c = x.pi() - fx.pi(); let c = x.pi() - fx.pi();

View File

@@ -295,8 +295,6 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::f64;
use approx::{AbsDiffEq, RelativeEq}; use approx::{AbsDiffEq, RelativeEq};
use super::*; use super::*;
@@ -306,12 +304,10 @@ mod tests {
impl AbsDiffEq for Rating { impl AbsDiffEq for Rating {
type Epsilon = f64; type Epsilon = f64;
#[inline]
fn default_epsilon() -> Self::Epsilon { fn default_epsilon() -> Self::Epsilon {
f64::default_epsilon() f64::default_epsilon()
} }
#[inline]
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool { fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
self.mu.abs_diff_eq(&other.mu, epsilon) && self.sigma.abs_diff_eq(&other.sigma, epsilon) self.mu.abs_diff_eq(&other.mu, epsilon) && self.sigma.abs_diff_eq(&other.sigma, epsilon)
} }
@@ -338,7 +334,11 @@ mod tests {
let alice = Rating::new(MU, SIGMA); let alice = Rating::new(MU, SIGMA);
let bob = Rating::new(MU, SIGMA); let bob = Rating::new(MU, SIGMA);
assert_relative_eq!(quality(&[&[alice], &[bob]]), 0.4472135954999579, epsilon = EPSILON); assert_relative_eq!(
quality(&[&[alice], &[bob]]),
0.4472135954999579,
epsilon = EPSILON
);
} }
#[test] #[test]

View File

@@ -88,41 +88,38 @@ fn p1evl(x: f64, coef: &[f64], n: usize) -> f64 {
} }
fn ndtri(y0: f64) -> f64 { fn ndtri(y0: f64) -> f64 {
let mut code = 1;
let mut y = y0; let mut y = y0;
if y > (1.0 - 0.13533528323661269189) { let code = if y > (1.0 - 0.13533528323661269189) {
y = 1.0 - y; y = 1.0 - y;
code = 0;
} false
} else {
true
};
if y > 0.13533528323661269189 { if y > 0.13533528323661269189 {
y = y - 0.5; y = y - 0.5;
let y2 = y * y; let y2 = y * y;
let x = y + y * (y2 * polevl(y2, &P0, 4) / p1evl(y2, &Q0, 8));
let x = x * S2PI;
return x; return (y + y * (y2 * polevl(y2, &P0, 4) / p1evl(y2, &Q0, 8))) * S2PI;
} }
let x = (-2.0 * y.ln()).sqrt(); let x = (-2.0 * y.ln()).sqrt();
let x0 = x - x.ln() / x;
let z = 1.0 / x; let z = 1.0 / x;
let x0 = x - x.ln() / x;
let x1 = if x < 8.0 { let x1 = if x < 8.0 {
z * polevl(z, &P1, 8) / p1evl(z, &Q1, 8) z * polevl(z, &P1, 8) / p1evl(z, &Q1, 8)
} else { } else {
z * polevl(z, &P2, 8) / p1evl(z, &Q2, 8) z * polevl(z, &P2, 8) / p1evl(z, &Q2, 8)
}; };
let mut x = x0 - x1; if code {
x1 - x0
if code != 0 { } else {
x = -x; x0 - x1
} }
x
} }
pub fn cdf(x: f64) -> f64 { pub fn cdf(x: f64) -> f64 {