This commit is contained in:
2018-10-26 13:15:35 +02:00
parent 3424b5f45f
commit e3e6ced26f
3 changed files with 25 additions and 27 deletions

View File

@@ -234,16 +234,17 @@ impl SumFactor {
fy: Vec<Gaussian>, fy: Vec<Gaussian>,
a: &[f64], a: &[f64],
) { ) {
let (sum_pi, sum_tau) = a.iter() let (sum_pi, sum_tau) =
.zip(y.iter().zip(fy.iter())) a.iter()
.fold((0.0, 0.0f64), |(pi, tau), (a, (y, fy))| { .zip(y.iter().zip(fy.iter()))
let x = *y / *fy; .fold((0.0, 0.0f64), |(pi, tau), (a, (y, fy))| {
let x = *y / *fy;
let new_pi = a.powi(2) / x.pi(); let new_pi = a.powi(2) / x.pi();
let new_tau = a * x.mu(); let new_tau = a * x.mu();
(pi + new_pi, tau + new_tau) (pi + new_pi, tau + new_tau)
}); });
let new_pi = 1.0 / sum_pi; let new_pi = 1.0 / sum_pi;
let new_tau = new_pi * sum_tau; let new_tau = new_pi * sum_tau;
@@ -285,10 +286,12 @@ impl SumFactor {
.coeffs .coeffs
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, coeff)| if i == index { .map(|(i, coeff)| {
1.0 / idx_coeff if i == index {
} else { 1.0 / idx_coeff
-coeff / idx_coeff } else {
-coeff / idx_coeff
}
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();

View File

@@ -9,7 +9,14 @@ pub struct Gaussian {
impl fmt::Debug for Gaussian { impl fmt::Debug for Gaussian {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "N(mu={:.03}, sigma={:.03}, pi={:.03}, tau={:.03})", self.mu(), self.sigma(), self.pi(), self.tau()) write!(
f,
"N(mu={:.03}, sigma={:.03}, pi={:.03}, tau={:.03})",
self.mu(),
self.sigma(),
self.pi(),
self.tau()
)
} }
} }

View File

@@ -536,7 +536,6 @@ mod tests {
let ratings = vec![ let ratings = vec![
(Rating::new(42.234, 3.728), 0), (Rating::new(42.234, 3.728), 0),
(Rating::new(43.290, 3.842), 0), (Rating::new(43.290, 3.842), 0),
(Rating::new(16.667, 0.500), 1), (Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1), (Rating::new(16.667, 0.500), 1),
(Rating::new(16.667, 0.500), 1), (Rating::new(16.667, 0.500), 1),
@@ -556,7 +555,6 @@ mod tests {
let expected_ratings = vec![ let expected_ratings = vec![
Rating::new(49.04740037944730, 3.64544755829110), Rating::new(49.04740037944730, 3.64544755829110),
Rating::new(50.52625967591528, 3.75146367648882), Rating::new(50.52625967591528, 3.75146367648882),
Rating::new(16.54109749124066, 0.50668947816812), Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812), Rating::new(16.54109749124066, 0.50668947816812),
Rating::new(16.54109749124066, 0.50668947816812), Rating::new(16.54109749124066, 0.50668947816812),
@@ -573,11 +571,7 @@ mod tests {
Rating::new(16.54109749124066, 0.50668947816812), Rating::new(16.54109749124066, 0.50668947816812),
]; ];
let ratings = rate( let ratings = rate(ratings.as_ref(), &[0, 1], DELTA);
ratings.as_ref(),
&[0, 1],
DELTA,
);
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
assert_relative_eq!(rating, expected, epsilon = EPSILON); assert_relative_eq!(rating, expected, epsilon = EPSILON);
@@ -603,7 +597,6 @@ mod tests {
(Rating::new(25.000, 0.500), 0), (Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0), (Rating::new(25.000, 0.500), 0),
(Rating::new(25.000, 0.500), 0), (Rating::new(25.000, 0.500), 0),
(Rating::new(42.234, 3.728), 1), (Rating::new(42.234, 3.728), 1),
(Rating::new(43.290, 3.842), 1), (Rating::new(43.290, 3.842), 1),
]; ];
@@ -623,16 +616,11 @@ mod tests {
Rating::new(25.00000000000000, 0.50689687752485), Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485), Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(25.00000000000000, 0.50689687752485), Rating::new(25.00000000000000, 0.50689687752485),
Rating::new(42.23400000000000, 3.72893127376255), Rating::new(42.23400000000000, 3.72893127376255),
Rating::new(43.29000000000000, 3.84290364756188), Rating::new(43.29000000000000, 3.84290364756188),
]; ];
let ratings = rate( let ratings = rate(ratings.as_ref(), &[0, 1], DELTA);
ratings.as_ref(),
&[0, 1],
DELTA,
);
for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) { for (rating, expected) in ratings.iter().zip(expected_ratings.iter()) {
assert_relative_eq!(rating, expected, epsilon = EPSILON); assert_relative_eq!(rating, expected, epsilon = EPSILON);