Rustify some code.
This commit is contained in:
@@ -52,9 +52,7 @@ impl Variable {
|
|||||||
pub fn update_value(&mut self, factor: usize, value: Gaussian) {
|
pub fn update_value(&mut self, factor: usize, value: Gaussian) {
|
||||||
let old = self.factors[&factor];
|
let old = self.factors[&factor];
|
||||||
|
|
||||||
let intermediate = value * old;
|
self.factors.insert(factor, value * old / self.value);
|
||||||
self.factors.insert(factor, intermediate / self.value);
|
|
||||||
|
|
||||||
self.value = value;
|
self.value = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,12 +63,8 @@ impl Variable {
|
|||||||
pub fn update_message(&mut self, factor: usize, message: Gaussian) {
|
pub fn update_message(&mut self, factor: usize, message: Gaussian) {
|
||||||
let old = self.factors[&factor];
|
let old = self.factors[&factor];
|
||||||
|
|
||||||
let intermediate = self.value / old;
|
|
||||||
let value = intermediate * message;
|
|
||||||
|
|
||||||
self.value = value;
|
|
||||||
|
|
||||||
self.factors.insert(factor, message);
|
self.factors.insert(factor, message);
|
||||||
|
self.value = self.value / old * message;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_message(&self, factor: usize) -> Gaussian {
|
pub fn get_message(&self, factor: usize) -> Gaussian {
|
||||||
@@ -135,18 +129,15 @@ impl LikelihoodFactor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
|
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
|
||||||
let x = variable_arena
|
let (x, fx) = variable_arena
|
||||||
.get(self.value)
|
.get(self.value)
|
||||||
.map(|variable| variable.get_value())
|
.map(|variable| (
|
||||||
.unwrap();
|
variable.get_value(),
|
||||||
|
variable.get_message(self.id)
|
||||||
let fx = variable_arena
|
))
|
||||||
.get_mut(self.value)
|
|
||||||
.map(|variable| variable.get_message(self.id))
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
||||||
|
|
||||||
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
|
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
|
||||||
|
|
||||||
variable_arena
|
variable_arena
|
||||||
@@ -156,18 +147,15 @@ impl LikelihoodFactor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_value(&self, variable_arena: &mut VariableArena) {
|
pub fn update_value(&self, variable_arena: &mut VariableArena) {
|
||||||
let y = variable_arena
|
let (y, fy) = variable_arena
|
||||||
.get(self.mean)
|
.get(self.mean)
|
||||||
.map(|variable| variable.get_value())
|
.map(|variable| (
|
||||||
.unwrap();
|
variable.get_value(),
|
||||||
|
variable.get_message(self.id)
|
||||||
let fy = variable_arena
|
))
|
||||||
.get(self.mean)
|
|
||||||
.map(|variable| variable.get_message(self.id))
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
||||||
|
|
||||||
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
|
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
|
||||||
|
|
||||||
variable_arena
|
variable_arena
|
||||||
@@ -212,21 +200,17 @@ impl SumFactor {
|
|||||||
variable: VariableId,
|
variable: VariableId,
|
||||||
y: Vec<Gaussian>,
|
y: Vec<Gaussian>,
|
||||||
fy: Vec<Gaussian>,
|
fy: Vec<Gaussian>,
|
||||||
a: &Vec<f64>,
|
a: &[f64],
|
||||||
) {
|
) {
|
||||||
let size = a.len();
|
let (sum_pi, sum_tau) = a.iter().zip(y.iter().zip(fy.iter()))
|
||||||
|
.fold((0.0, 0.0), |(pi, tau), (a, (y, fy))| {
|
||||||
|
let x = *y / *fy;
|
||||||
|
|
||||||
let mut sum_pi = 0.0;
|
(
|
||||||
let mut sum_tau = 0.0;
|
pi + a.powi(2) / x.pi(),
|
||||||
|
tau + a * x.tau() / x.pi()
|
||||||
for i in 0..size {
|
)
|
||||||
let da = a[i];
|
});
|
||||||
let gy = y[i];
|
|
||||||
let gfy = fy[i];
|
|
||||||
|
|
||||||
sum_pi += da.powi(2) / (gy.pi() - gfy.pi());
|
|
||||||
sum_tau += da * (gy.tau() - gfy.tau()) / (gy.pi() - gfy.pi());
|
|
||||||
}
|
|
||||||
|
|
||||||
let new_pi = 1.0 / sum_pi;
|
let new_pi = 1.0 / sum_pi;
|
||||||
let new_tau = new_pi * sum_tau;
|
let new_tau = new_pi * sum_tau;
|
||||||
@@ -240,69 +224,57 @@ impl SumFactor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
|
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
|
||||||
let mut y = Vec::new();
|
let (y, fy) = self.terms
|
||||||
|
.iter()
|
||||||
for term in &self.terms {
|
.map(|term| {
|
||||||
let value = variable_arena
|
variable_arena
|
||||||
.get(*term)
|
.get(*term)
|
||||||
.map(|variable| variable.get_value())
|
.map(|variable| (
|
||||||
.unwrap();
|
variable.get_value(),
|
||||||
|
variable.get_message(self.id)
|
||||||
y.push(value);
|
))
|
||||||
}
|
.unwrap()
|
||||||
|
})
|
||||||
let mut fy = Vec::new();
|
.unzip();
|
||||||
|
|
||||||
for term in &self.terms {
|
|
||||||
let value = variable_arena
|
|
||||||
.get(*term)
|
|
||||||
.map(|variable| variable.get_message(self.id))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
fy.push(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.internal_update(variable_arena, self.sum, y, fy, &self.coeffs);
|
self.internal_update(variable_arena, self.sum, y, fy, &self.coeffs);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) {
|
pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) {
|
||||||
let size = self.coeffs.len();
|
let idx_term = self.terms[index];
|
||||||
let idx_coeff = self.coeffs[index];
|
let idx_coeff = self.coeffs[index];
|
||||||
|
|
||||||
let mut a = vec![0.0; size];
|
let a = self.coeffs
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, coeff)| {
|
||||||
|
if i == index {
|
||||||
|
1.0 / idx_coeff
|
||||||
|
} else {
|
||||||
|
-coeff / idx_coeff
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
for i in 0..size {
|
let (y, fy) = self.terms
|
||||||
if i != index {
|
.iter()
|
||||||
a[i] = -self.coeffs[i] / idx_coeff;
|
.enumerate()
|
||||||
}
|
.map(|(i, term)| {
|
||||||
}
|
let variable = if i == index {
|
||||||
|
self.sum
|
||||||
|
} else {
|
||||||
|
*term
|
||||||
|
};
|
||||||
|
|
||||||
a[index] = 1.0 / idx_coeff;
|
variable_arena
|
||||||
|
.get(variable)
|
||||||
let idx_term = self.terms[index];
|
.map(|variable| (
|
||||||
|
variable.get_value(),
|
||||||
let mut y = Vec::new();
|
variable.get_message(self.id)
|
||||||
let mut fy = Vec::new();
|
))
|
||||||
|
.unwrap()
|
||||||
let mut v = self.terms.clone();
|
})
|
||||||
|
.unzip();
|
||||||
v[index] = self.sum;
|
|
||||||
|
|
||||||
for term in &v {
|
|
||||||
let value = variable_arena
|
|
||||||
.get(*term)
|
|
||||||
.map(|variable| variable.get_value())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
y.push(value);
|
|
||||||
|
|
||||||
let value = variable_arena
|
|
||||||
.get(*term)
|
|
||||||
.map(|variable| variable.get_message(self.id))
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
fy.push(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.internal_update(variable_arena, idx_term, y, fy, &a);
|
self.internal_update(variable_arena, idx_term, y, fy, &a);
|
||||||
}
|
}
|
||||||
@@ -356,14 +328,12 @@ impl TruncateFactor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn update(&self, variable_arena: &mut VariableArena) {
|
pub fn update(&self, variable_arena: &mut VariableArena) {
|
||||||
let x = variable_arena
|
let (x, fx) = variable_arena
|
||||||
.get(self.variable)
|
.get(self.variable)
|
||||||
.map(|variable| variable.get_value())
|
.map(|variable| (
|
||||||
.unwrap();
|
variable.get_value(),
|
||||||
|
variable.get_message(self.id)
|
||||||
let fx = variable_arena
|
))
|
||||||
.get_mut(self.variable)
|
|
||||||
.map(|variable| variable.get_message(self.id))
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let c = x.pi() - fx.pi();
|
let c = x.pi() - fx.pi();
|
||||||
|
|||||||
10
src/lib.rs
10
src/lib.rs
@@ -295,8 +295,6 @@ where
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::f64;
|
|
||||||
|
|
||||||
use approx::{AbsDiffEq, RelativeEq};
|
use approx::{AbsDiffEq, RelativeEq};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -306,12 +304,10 @@ mod tests {
|
|||||||
impl AbsDiffEq for Rating {
|
impl AbsDiffEq for Rating {
|
||||||
type Epsilon = f64;
|
type Epsilon = f64;
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn default_epsilon() -> Self::Epsilon {
|
fn default_epsilon() -> Self::Epsilon {
|
||||||
f64::default_epsilon()
|
f64::default_epsilon()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
|
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
|
||||||
self.mu.abs_diff_eq(&other.mu, epsilon) && self.sigma.abs_diff_eq(&other.sigma, epsilon)
|
self.mu.abs_diff_eq(&other.mu, epsilon) && self.sigma.abs_diff_eq(&other.sigma, epsilon)
|
||||||
}
|
}
|
||||||
@@ -338,7 +334,11 @@ mod tests {
|
|||||||
let alice = Rating::new(MU, SIGMA);
|
let alice = Rating::new(MU, SIGMA);
|
||||||
let bob = Rating::new(MU, SIGMA);
|
let bob = Rating::new(MU, SIGMA);
|
||||||
|
|
||||||
assert_relative_eq!(quality(&[&[alice], &[bob]]), 0.4472135954999579, epsilon = EPSILON);
|
assert_relative_eq!(
|
||||||
|
quality(&[&[alice], &[bob]]),
|
||||||
|
0.4472135954999579,
|
||||||
|
epsilon = EPSILON
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
27
src/math.rs
27
src/math.rs
@@ -88,41 +88,38 @@ fn p1evl(x: f64, coef: &[f64], n: usize) -> f64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn ndtri(y0: f64) -> f64 {
|
fn ndtri(y0: f64) -> f64 {
|
||||||
let mut code = 1;
|
|
||||||
let mut y = y0;
|
let mut y = y0;
|
||||||
|
|
||||||
if y > (1.0 - 0.13533528323661269189) {
|
let code = if y > (1.0 - 0.13533528323661269189) {
|
||||||
y = 1.0 - y;
|
y = 1.0 - y;
|
||||||
code = 0;
|
|
||||||
}
|
false
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
};
|
||||||
|
|
||||||
if y > 0.13533528323661269189 {
|
if y > 0.13533528323661269189 {
|
||||||
y = y - 0.5;
|
y = y - 0.5;
|
||||||
let y2 = y * y;
|
let y2 = y * y;
|
||||||
let x = y + y * (y2 * polevl(y2, &P0, 4) / p1evl(y2, &Q0, 8));
|
|
||||||
let x = x * S2PI;
|
|
||||||
|
|
||||||
return x;
|
return (y + y * (y2 * polevl(y2, &P0, 4) / p1evl(y2, &Q0, 8))) * S2PI;
|
||||||
}
|
}
|
||||||
|
|
||||||
let x = (-2.0 * y.ln()).sqrt();
|
let x = (-2.0 * y.ln()).sqrt();
|
||||||
let x0 = x - x.ln() / x;
|
|
||||||
|
|
||||||
let z = 1.0 / x;
|
let z = 1.0 / x;
|
||||||
|
|
||||||
|
let x0 = x - x.ln() / x;
|
||||||
let x1 = if x < 8.0 {
|
let x1 = if x < 8.0 {
|
||||||
z * polevl(z, &P1, 8) / p1evl(z, &Q1, 8)
|
z * polevl(z, &P1, 8) / p1evl(z, &Q1, 8)
|
||||||
} else {
|
} else {
|
||||||
z * polevl(z, &P2, 8) / p1evl(z, &Q2, 8)
|
z * polevl(z, &P2, 8) / p1evl(z, &Q2, 8)
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut x = x0 - x1;
|
if code {
|
||||||
|
x1 - x0
|
||||||
if code != 0 {
|
} else {
|
||||||
x = -x;
|
x0 - x1
|
||||||
}
|
}
|
||||||
|
|
||||||
x
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cdf(x: f64) -> f64 {
|
pub fn cdf(x: f64) -> f64 {
|
||||||
|
|||||||
Reference in New Issue
Block a user