From bcdabf9fbbe54dc13a9e0312d6fca6e4b5cd22ec Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Wed, 24 Oct 2018 20:15:29 +0200 Subject: [PATCH] Rustify some code. --- src/factor_graph.rs | 166 ++++++++++++++++++-------------------------- src/lib.rs | 10 +-- src/math.rs | 27 ++++--- 3 files changed, 85 insertions(+), 118 deletions(-) diff --git a/src/factor_graph.rs b/src/factor_graph.rs index 4934b1d..c5d66bf 100644 --- a/src/factor_graph.rs +++ b/src/factor_graph.rs @@ -52,9 +52,7 @@ impl Variable { pub fn update_value(&mut self, factor: usize, value: Gaussian) { let old = self.factors[&factor]; - let intermediate = value * old; - self.factors.insert(factor, intermediate / self.value); - + self.factors.insert(factor, value * old / self.value); self.value = value; } @@ -65,12 +63,8 @@ impl Variable { pub fn update_message(&mut self, factor: usize, message: Gaussian) { let old = self.factors[&factor]; - let intermediate = self.value / old; - let value = intermediate * message; - - self.value = value; - self.factors.insert(factor, message); + self.value = self.value / old * message; } pub fn get_message(&self, factor: usize) -> Gaussian { @@ -135,18 +129,15 @@ impl LikelihoodFactor { } pub fn update_mean(&self, variable_arena: &mut VariableArena) { - let x = variable_arena + let (x, fx) = variable_arena .get(self.value) - .map(|variable| variable.get_value()) - .unwrap(); - - let fx = variable_arena - .get_mut(self.value) - .map(|variable| variable.get_message(self.id)) + .map(|variable| ( + variable.get_value(), + variable.get_message(self.id) + )) .unwrap(); let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); - let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau())); variable_arena @@ -156,18 +147,15 @@ impl LikelihoodFactor { } pub fn update_value(&self, variable_arena: &mut VariableArena) { - let y = variable_arena + let (y, fy) = variable_arena .get(self.mean) - .map(|variable| variable.get_value()) - .unwrap(); - - let fy = variable_arena - .get(self.mean) - .map(|variable| variable.get_message(self.id)) + .map(|variable| ( + variable.get_value(), + variable.get_message(self.id) + )) .unwrap(); let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); - let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau())); variable_arena @@ -212,21 +200,17 @@ impl SumFactor { variable: VariableId, y: Vec, fy: Vec, - a: &Vec, + a: &[f64], ) { - let size = a.len(); + let (sum_pi, sum_tau) = a.iter().zip(y.iter().zip(fy.iter())) + .fold((0.0, 0.0), |(pi, tau), (a, (y, fy))| { + let x = *y / *fy; - let mut sum_pi = 0.0; - let mut sum_tau = 0.0; - - for i in 0..size { - let da = a[i]; - let gy = y[i]; - let gfy = fy[i]; - - sum_pi += da.powi(2) / (gy.pi() - gfy.pi()); - sum_tau += da * (gy.tau() - gfy.tau()) / (gy.pi() - gfy.pi()); - } + ( + pi + a.powi(2) / x.pi(), + tau + a * x.tau() / x.pi() + ) + }); let new_pi = 1.0 / sum_pi; let new_tau = new_pi * sum_tau; @@ -240,69 +224,57 @@ impl SumFactor { } pub fn update_sum(&self, variable_arena: &mut VariableArena) { - let mut y = Vec::new(); - - for term in &self.terms { - let value = variable_arena - .get(*term) - .map(|variable| variable.get_value()) - .unwrap(); - - y.push(value); - } - - let mut fy = Vec::new(); - - for term in &self.terms { - let value = variable_arena - .get(*term) - .map(|variable| variable.get_message(self.id)) - .unwrap(); - - fy.push(value); - } + let (y, fy) = self.terms + .iter() + .map(|term| { + variable_arena + .get(*term) + .map(|variable| ( + variable.get_value(), + variable.get_message(self.id) + )) + .unwrap() + }) + .unzip(); self.internal_update(variable_arena, self.sum, y, fy, &self.coeffs); } pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) { - let size = self.coeffs.len(); + let idx_term = self.terms[index]; let idx_coeff = self.coeffs[index]; - let mut a = vec![0.0; size]; + let a = self.coeffs + .iter() + .enumerate() + .map(|(i, coeff)| { + if i == index { + 1.0 / idx_coeff + } else { + -coeff / idx_coeff + } + }) + .collect::>(); - for i in 0..size { - if i != index { - a[i] = -self.coeffs[i] / idx_coeff; - } - } + let (y, fy) = self.terms + .iter() + .enumerate() + .map(|(i, term)| { + let variable = if i == index { + self.sum + } else { + *term + }; - a[index] = 1.0 / idx_coeff; - - let idx_term = self.terms[index]; - - let mut y = Vec::new(); - let mut fy = Vec::new(); - - let mut v = self.terms.clone(); - - v[index] = self.sum; - - for term in &v { - let value = variable_arena - .get(*term) - .map(|variable| variable.get_value()) - .unwrap(); - - y.push(value); - - let value = variable_arena - .get(*term) - .map(|variable| variable.get_message(self.id)) - .unwrap(); - - fy.push(value); - } + variable_arena + .get(variable) + .map(|variable| ( + variable.get_value(), + variable.get_message(self.id) + )) + .unwrap() + }) + .unzip(); self.internal_update(variable_arena, idx_term, y, fy, &a); } @@ -356,14 +328,12 @@ impl TruncateFactor { } pub fn update(&self, variable_arena: &mut VariableArena) { - let x = variable_arena + let (x, fx) = variable_arena .get(self.variable) - .map(|variable| variable.get_value()) - .unwrap(); - - let fx = variable_arena - .get_mut(self.variable) - .map(|variable| variable.get_message(self.id)) + .map(|variable| ( + variable.get_value(), + variable.get_message(self.id) + )) .unwrap(); let c = x.pi() - fx.pi(); diff --git a/src/lib.rs b/src/lib.rs index 72d3f91..f1110e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -295,8 +295,6 @@ where #[cfg(test)] mod tests { - use std::f64; - use approx::{AbsDiffEq, RelativeEq}; use super::*; @@ -306,12 +304,10 @@ mod tests { impl AbsDiffEq for Rating { type Epsilon = f64; - #[inline] fn default_epsilon() -> Self::Epsilon { f64::default_epsilon() } - #[inline] fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool { self.mu.abs_diff_eq(&other.mu, epsilon) && self.sigma.abs_diff_eq(&other.sigma, epsilon) } @@ -338,7 +334,11 @@ mod tests { let alice = Rating::new(MU, SIGMA); let bob = Rating::new(MU, SIGMA); - assert_relative_eq!(quality(&[&[alice], &[bob]]), 0.4472135954999579, epsilon = EPSILON); + assert_relative_eq!( + quality(&[&[alice], &[bob]]), + 0.4472135954999579, + epsilon = EPSILON + ); } #[test] diff --git a/src/math.rs b/src/math.rs index 284cd83..6242a2e 100644 --- a/src/math.rs +++ b/src/math.rs @@ -88,41 +88,38 @@ fn p1evl(x: f64, coef: &[f64], n: usize) -> f64 { } fn ndtri(y0: f64) -> f64 { - let mut code = 1; let mut y = y0; - if y > (1.0 - 0.13533528323661269189) { + let code = if y > (1.0 - 0.13533528323661269189) { y = 1.0 - y; - code = 0; - } + + false + } else { + true + }; if y > 0.13533528323661269189 { y = y - 0.5; let y2 = y * y; - let x = y + y * (y2 * polevl(y2, &P0, 4) / p1evl(y2, &Q0, 8)); - let x = x * S2PI; - return x; + return (y + y * (y2 * polevl(y2, &P0, 4) / p1evl(y2, &Q0, 8))) * S2PI; } let x = (-2.0 * y.ln()).sqrt(); - let x0 = x - x.ln() / x; - let z = 1.0 / x; + let x0 = x - x.ln() / x; let x1 = if x < 8.0 { z * polevl(z, &P1, 8) / p1evl(z, &Q1, 8) } else { z * polevl(z, &P2, 8) / p1evl(z, &Q2, 8) }; - let mut x = x0 - x1; - - if code != 0 { - x = -x; + if code { + x1 - x0 + } else { + x0 - x1 } - - x } pub fn cdf(x: f64) -> f64 {