More refactoring
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
use std::f64;
|
||||
use std::ops;
|
||||
|
||||
use log::*;
|
||||
|
||||
use crate::gaussian::Gaussian;
|
||||
use crate::graph::{MessageArena, MessageId, VariableArena, VariableId};
|
||||
use crate::math;
|
||||
|
||||
mod message;
|
||||
mod variable;
|
||||
|
||||
pub use message::*;
|
||||
pub use variable::*;
|
||||
|
||||
pub struct PriorFactor {
|
||||
id: MessageId,
|
||||
variable: VariableId,
|
||||
@@ -126,10 +127,8 @@ impl SumFactor {
|
||||
|
||||
fn internal_update(
|
||||
&self,
|
||||
variable_arena: &mut VariableArena,
|
||||
message_arena: &mut MessageArena,
|
||||
variable: VariableId,
|
||||
message: MessageId,
|
||||
variable: &mut Gaussian,
|
||||
message: &mut Gaussian,
|
||||
y: &[Gaussian],
|
||||
fy: &[Gaussian],
|
||||
a: &[f64],
|
||||
@@ -151,10 +150,10 @@ impl SumFactor {
|
||||
|
||||
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
|
||||
|
||||
let old = message_arena[message];
|
||||
let old = *message;
|
||||
|
||||
message_arena[message] = gaussian;
|
||||
variable_arena[variable] = variable_arena[variable] / old * gaussian;
|
||||
*message = gaussian;
|
||||
*variable = *variable / old * gaussian;
|
||||
}
|
||||
|
||||
pub fn update_sum(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
|
||||
@@ -166,10 +165,8 @@ impl SumFactor {
|
||||
.unzip();
|
||||
|
||||
self.internal_update(
|
||||
variable_arena,
|
||||
message_arena,
|
||||
self.sum,
|
||||
self.sum_msg,
|
||||
&mut variable_arena[self.sum],
|
||||
&mut message_arena[self.sum_msg],
|
||||
&y,
|
||||
&fy,
|
||||
&self.coeffs,
|
||||
@@ -214,10 +211,8 @@ impl SumFactor {
|
||||
.unzip();
|
||||
|
||||
self.internal_update(
|
||||
variable_arena,
|
||||
message_arena,
|
||||
idx_term,
|
||||
idx_term_msg,
|
||||
&mut variable_arena[idx_term],
|
||||
&mut message_arena[idx_term_msg],
|
||||
&y,
|
||||
&fy,
|
||||
&a,
|
||||
|
||||
420
src/graph.rs
420
src/graph.rs
@@ -1,420 +0,0 @@
|
||||
mod message;
|
||||
mod variable;
|
||||
|
||||
pub use message::*;
|
||||
pub use variable::*;
|
||||
|
||||
/*
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub struct FactorId(usize);
|
||||
|
||||
pub struct Graph {
|
||||
variables: VariableArena,
|
||||
factors: Vec<Factor>,
|
||||
messages: MessageArena,
|
||||
}
|
||||
|
||||
impl Graph {
|
||||
#[inline(always)]
|
||||
pub fn add_variable(&mut self) -> VariableId {
|
||||
self.variables.create()
|
||||
}
|
||||
|
||||
pub fn add_factor(&mut self, factor: Factor) -> FactorId {
|
||||
let idx = self.factors.len();
|
||||
self.factors.push(factor);
|
||||
|
||||
FactorId(idx)
|
||||
}
|
||||
}
|
||||
|
||||
pub enum Factor {
|
||||
Prior { gaussian: Gaussian },
|
||||
}
|
||||
*/
|
||||
// FACTORS
|
||||
|
||||
/*
|
||||
pub struct PriorFactor {
|
||||
id: usize,
|
||||
variable: VariableId,
|
||||
gaussian: Gaussian,
|
||||
}
|
||||
|
||||
impl PriorFactor {
|
||||
pub fn new(
|
||||
variable_arena: &mut VariableArena,
|
||||
id: usize,
|
||||
variable: VariableId,
|
||||
gaussian: Gaussian,
|
||||
) -> PriorFactor {
|
||||
variable_arena[variable].attach_factor(id);
|
||||
|
||||
PriorFactor {
|
||||
id,
|
||||
variable,
|
||||
gaussian,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&self, variable_arena: &mut VariableArena) {
|
||||
debug!(
|
||||
"Prior::down var={:?}, value={:?}",
|
||||
self.variable.index, self.gaussian
|
||||
);
|
||||
|
||||
variable_arena[self.variable].update_value(self.id, self.gaussian);
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
/*
|
||||
pub struct Variable {
|
||||
value: Gaussian,
|
||||
factors: HashMap<usize, Gaussian>,
|
||||
}
|
||||
|
||||
impl Variable {
|
||||
pub fn attach_factor(&mut self, factor: usize) {
|
||||
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
|
||||
}
|
||||
|
||||
pub fn update_value(&mut self, factor: usize, value: Gaussian) -> f64 {
|
||||
let old = self.factors[&factor];
|
||||
|
||||
self.factors.insert(factor, value * old / self.value);
|
||||
|
||||
let pi_delta = (self.value.pi() - value.pi()).abs();
|
||||
|
||||
let delta = if !pi_delta.is_finite() {
|
||||
0.0
|
||||
} else {
|
||||
let pi_delta = pi_delta.sqrt();
|
||||
let tau_delta = (self.value.tau() - value.tau()).abs();
|
||||
|
||||
if pi_delta > tau_delta {
|
||||
pi_delta
|
||||
} else {
|
||||
tau_delta
|
||||
}
|
||||
};
|
||||
|
||||
self.value = value;
|
||||
|
||||
debug!("Variable::value old={:?}, new={:?}", old, value);
|
||||
|
||||
delta
|
||||
}
|
||||
|
||||
pub fn get_value(&self) -> Gaussian {
|
||||
self.value
|
||||
}
|
||||
|
||||
pub fn get_message(&self, factor: usize) -> Gaussian {
|
||||
self.factors[&factor]
|
||||
}
|
||||
|
||||
pub fn update_message(&mut self, factor: usize, message: Gaussian) {
|
||||
let old = self.factors[&factor];
|
||||
|
||||
self.factors.insert(factor, message);
|
||||
self.value = self.value / old * message;
|
||||
|
||||
debug!("Variable::message old={:?}, new={:?}", old, message);
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
/*
|
||||
pub struct PriorFactor {
|
||||
id: usize,
|
||||
variable: VariableId,
|
||||
gaussian: Gaussian,
|
||||
}
|
||||
|
||||
impl PriorFactor {
|
||||
pub fn new(
|
||||
variable_arena: &mut VariableArena,
|
||||
id: usize,
|
||||
variable: VariableId,
|
||||
gaussian: Gaussian,
|
||||
) -> PriorFactor {
|
||||
variable_arena[variable].attach_factor(id);
|
||||
|
||||
PriorFactor {
|
||||
id,
|
||||
variable,
|
||||
gaussian,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&self, variable_arena: &mut VariableArena) {
|
||||
debug!(
|
||||
"Prior::down var={:?}, value={:?}",
|
||||
self.variable.index, self.gaussian
|
||||
);
|
||||
|
||||
variable_arena[self.variable].update_value(self.id, self.gaussian);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LikelihoodFactor {
|
||||
id: usize,
|
||||
mean: VariableId,
|
||||
value: VariableId,
|
||||
variance: f64,
|
||||
}
|
||||
|
||||
impl LikelihoodFactor {
|
||||
pub fn new(
|
||||
variable_arena: &mut VariableArena,
|
||||
id: usize,
|
||||
mean: VariableId,
|
||||
value: VariableId,
|
||||
variance: f64,
|
||||
) -> LikelihoodFactor {
|
||||
variable_arena[mean].attach_factor(id);
|
||||
variable_arena[value].attach_factor(id);
|
||||
|
||||
LikelihoodFactor {
|
||||
id,
|
||||
mean,
|
||||
value,
|
||||
variance,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
|
||||
let (x, fx) = {
|
||||
let variable = &variable_arena[self.value];
|
||||
|
||||
(variable.get_value(), variable.get_message(self.id))
|
||||
};
|
||||
|
||||
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
||||
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
|
||||
|
||||
debug!(
|
||||
"Likelihood::up var={:?}, value={:?}",
|
||||
self.mean.index, gaussian
|
||||
);
|
||||
|
||||
variable_arena[self.mean].update_message(self.id, gaussian);
|
||||
}
|
||||
|
||||
pub fn update_value(&self, variable_arena: &mut VariableArena) {
|
||||
let (y, fy) = {
|
||||
let variable = &variable_arena[self.mean];
|
||||
|
||||
(variable.get_value(), variable.get_message(self.id))
|
||||
};
|
||||
|
||||
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
||||
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
|
||||
|
||||
debug!(
|
||||
"Likelihood::down var={:?}, value={:?}",
|
||||
self.value.index, gaussian
|
||||
);
|
||||
|
||||
variable_arena[self.value].update_message(self.id, gaussian);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SumFactor {
|
||||
id: usize,
|
||||
sum: VariableId,
|
||||
terms: Vec<VariableId>,
|
||||
coeffs: Vec<f64>,
|
||||
}
|
||||
|
||||
impl SumFactor {
|
||||
pub fn new(
|
||||
variable_arena: &mut VariableArena,
|
||||
id: usize,
|
||||
sum: VariableId,
|
||||
terms: Vec<VariableId>,
|
||||
coeffs: Vec<f64>,
|
||||
) -> SumFactor {
|
||||
variable_arena[sum].attach_factor(id);
|
||||
|
||||
for term in &terms {
|
||||
variable_arena[*term].attach_factor(id);
|
||||
}
|
||||
|
||||
SumFactor {
|
||||
id,
|
||||
sum,
|
||||
terms,
|
||||
coeffs,
|
||||
}
|
||||
}
|
||||
|
||||
fn internal_update(
|
||||
&self,
|
||||
variable_arena: &mut VariableArena,
|
||||
variable: VariableId,
|
||||
y: &[Gaussian],
|
||||
fy: &[Gaussian],
|
||||
a: &[f64],
|
||||
) {
|
||||
let (sum_pi, sum_tau) =
|
||||
a.iter()
|
||||
.zip(y.iter().zip(fy.iter()))
|
||||
.fold((0.0, 0.0f64), |(pi, tau), (a, (y, fy))| {
|
||||
let x = *y / *fy;
|
||||
|
||||
let new_pi = a.powi(2) / x.pi();
|
||||
let new_tau = a * x.mu();
|
||||
|
||||
(pi + new_pi, tau + new_tau)
|
||||
});
|
||||
|
||||
let new_pi = 1.0 / sum_pi;
|
||||
let new_tau = new_pi * sum_tau;
|
||||
|
||||
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
|
||||
|
||||
if variable == self.sum {
|
||||
debug!("Sum::down var={:?}, value={:?}", variable.index, gaussian);
|
||||
} else {
|
||||
debug!("Sum::up var={:?}, value={:?}", variable.index, gaussian);
|
||||
}
|
||||
|
||||
variable_arena[variable].update_message(self.id, gaussian);
|
||||
}
|
||||
|
||||
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
|
||||
let (y, fy): (Vec<_>, Vec<_>) = self
|
||||
.terms
|
||||
.iter()
|
||||
.map(|term| {
|
||||
let variable = &variable_arena[*term];
|
||||
|
||||
(variable.get_value(), variable.get_message(self.id))
|
||||
})
|
||||
.unzip();
|
||||
|
||||
self.internal_update(variable_arena, self.sum, &y, &fy, &self.coeffs);
|
||||
}
|
||||
|
||||
pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) {
|
||||
let idx_term = self.terms[index];
|
||||
let idx_coeff = self.coeffs[index];
|
||||
|
||||
let a = self
|
||||
.coeffs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, coeff)| {
|
||||
if i == index {
|
||||
1.0 / idx_coeff
|
||||
} else {
|
||||
-coeff / idx_coeff
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let (y, fy): (Vec<_>, Vec<_>) = self
|
||||
.terms
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, term)| {
|
||||
let variable_id = if i == index { self.sum } else { *term };
|
||||
let variable = &variable_arena[variable_id];
|
||||
|
||||
(variable.get_value(), variable.get_message(self.id))
|
||||
})
|
||||
.unzip();
|
||||
|
||||
self.internal_update(variable_arena, idx_term, &y, &fy, &a);
|
||||
}
|
||||
|
||||
pub fn update_all_terms(&self, variable_arena: &mut VariableArena) {
|
||||
for i in 0..self.terms.len() {
|
||||
self.update_term(variable_arena, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn v_win(t: f64, e: f64) -> f64 {
|
||||
math::pdf(t - e) / math::cdf(t - e)
|
||||
}
|
||||
|
||||
fn w_win(t: f64, e: f64) -> f64 {
|
||||
let vwin = v_win(t, e);
|
||||
|
||||
vwin * (vwin + t - e)
|
||||
}
|
||||
|
||||
fn v_draw(t: f64, e: f64) -> f64 {
|
||||
(math::pdf(-e - t) - math::pdf(e - t)) / (math::cdf(e - t) - math::cdf(-e - t))
|
||||
}
|
||||
|
||||
fn w_draw(t: f64, e: f64) -> f64 {
|
||||
let vdraw = v_draw(t, e);
|
||||
let n = (vdraw * vdraw) + ((e - t) * math::pdf(e - t) + (e + t) * math::pdf(e + t));
|
||||
let d = math::cdf(e - t) - math::cdf(-e - t);
|
||||
|
||||
n / d
|
||||
}
|
||||
|
||||
pub struct TruncateFactor {
|
||||
id: usize,
|
||||
variable: VariableId,
|
||||
epsilon: f64,
|
||||
draw: bool,
|
||||
}
|
||||
|
||||
impl TruncateFactor {
|
||||
pub fn new(
|
||||
variable_arena: &mut VariableArena,
|
||||
id: usize,
|
||||
variable: VariableId,
|
||||
epsilon: f64,
|
||||
draw: bool,
|
||||
) -> TruncateFactor {
|
||||
variable_arena[variable].attach_factor(id);
|
||||
|
||||
TruncateFactor {
|
||||
id,
|
||||
variable,
|
||||
epsilon,
|
||||
draw,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&self, variable_arena: &mut VariableArena) -> f64 {
|
||||
let (x, fx) = {
|
||||
let variable = &variable_arena[self.variable];
|
||||
|
||||
(variable.get_value(), variable.get_message(self.id))
|
||||
};
|
||||
|
||||
let c = x.pi() - fx.pi();
|
||||
let d = x.tau() - fx.tau();
|
||||
|
||||
let sqrt_c = c.sqrt();
|
||||
|
||||
let t = d / sqrt_c;
|
||||
let e = self.epsilon * sqrt_c;
|
||||
|
||||
let (v, w) = if self.draw {
|
||||
(v_draw(t, e), w_draw(t, e))
|
||||
} else {
|
||||
(v_win(t, e), w_win(t, e))
|
||||
};
|
||||
|
||||
let m_w = 1.0 - w;
|
||||
|
||||
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
|
||||
|
||||
debug!(
|
||||
"Trunc::up var={:?}, value={:?}",
|
||||
self.variable.index, gaussian
|
||||
);
|
||||
|
||||
variable_arena[self.variable].update_value(self.id, gaussian)
|
||||
}
|
||||
}
|
||||
*/
|
||||
@@ -1,42 +0,0 @@
|
||||
use std::marker::PhantomData;
|
||||
use std::ops;
|
||||
|
||||
use crate::gaussian::Gaussian;
|
||||
|
||||
pub struct DistributionArena<T> {
|
||||
arena: Vec<Gaussian>,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> DistributionArena<T>
|
||||
where
|
||||
T: From<usize>,
|
||||
{
|
||||
pub fn create(&mut self) -> T {
|
||||
let idx = self.arena.len();
|
||||
|
||||
self.arena.push(Gaussian::from_pi_tau(0.0, 0.0));
|
||||
|
||||
idx.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ops::Index<T> for DistributionArena<T>
|
||||
where
|
||||
T: Into<usize>,
|
||||
{
|
||||
type Output = Gaussian;
|
||||
|
||||
fn index(&self, id: T) -> &Self::Output {
|
||||
&self.arena[id.into()]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ops::IndexMut<T> for DistributionArena<T>
|
||||
where
|
||||
T: Into<usize>,
|
||||
{
|
||||
fn index_mut(&mut self, id: T) -> &mut Self::Output {
|
||||
&mut self.arena[id.into()]
|
||||
}
|
||||
}
|
||||
@@ -2,13 +2,11 @@
|
||||
|
||||
mod factor_graph;
|
||||
mod gaussian;
|
||||
mod graph;
|
||||
mod math;
|
||||
mod matrix;
|
||||
|
||||
use crate::factor_graph::*;
|
||||
use crate::gaussian::Gaussian;
|
||||
use crate::graph::{MessageArena, VariableArena};
|
||||
use crate::matrix::Matrix;
|
||||
|
||||
/// Default initial mean of ratings.
|
||||
|
||||
Reference in New Issue
Block a user