393 lines
9.1 KiB
Rust
393 lines
9.1 KiB
Rust
use std::collections::HashMap;
|
|
|
|
use gaussian::Gaussian;
|
|
use math;
|
|
|
|
/// Opaque handle to a [`Variable`] stored in a [`VariableArena`].
///
/// A handle is just an index into the arena that issued it. It is cheap to
/// copy and can be compared, hashed, and used as a key in maps or sets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct VariableId {
    index: usize,
}
|
|
|
|
pub struct VariableArena {
|
|
variables: Vec<Variable>,
|
|
}
|
|
|
|
impl VariableArena {
|
|
pub fn new() -> VariableArena {
|
|
VariableArena {
|
|
variables: Vec::new(),
|
|
}
|
|
}
|
|
|
|
pub fn create(&mut self) -> VariableId {
|
|
let index = self.variables.len();
|
|
|
|
self.variables.push(Variable {
|
|
value: Gaussian::from_pi_tau(0.0, 0.0),
|
|
factors: HashMap::new(),
|
|
});
|
|
|
|
VariableId { index }
|
|
}
|
|
|
|
pub fn get(&mut self, id: VariableId) -> Option<&Variable> {
|
|
self.variables.get(id.index)
|
|
}
|
|
|
|
pub fn get_mut(&mut self, id: VariableId) -> Option<&mut Variable> {
|
|
self.variables.get_mut(id.index)
|
|
}
|
|
}
|
|
|
|
pub struct Variable {
|
|
value: Gaussian,
|
|
factors: HashMap<usize, Gaussian>,
|
|
}
|
|
|
|
impl Variable {
|
|
pub fn attach_factor(&mut self, factor: usize) {
|
|
self.factors.insert(factor, Gaussian::from_pi_tau(0.0, 0.0));
|
|
}
|
|
|
|
pub fn update_value(&mut self, factor: usize, value: Gaussian) {
|
|
let old = self.factors[&factor];
|
|
|
|
let intermediate = value * old;
|
|
self.factors.insert(factor, intermediate / self.value);
|
|
|
|
self.value = value;
|
|
}
|
|
|
|
pub fn get_value(&self) -> Gaussian {
|
|
self.value
|
|
}
|
|
|
|
pub fn update_message(&mut self, factor: usize, message: Gaussian) {
|
|
let old = self.factors[&factor];
|
|
|
|
let intermediate = self.value / old;
|
|
let value = intermediate * message;
|
|
|
|
self.value = value;
|
|
|
|
self.factors.insert(factor, message);
|
|
}
|
|
|
|
pub fn get_message(&self, factor: usize) -> Gaussian {
|
|
self.factors[&factor]
|
|
}
|
|
}
|
|
|
|
pub struct PriorFactor {
|
|
id: usize,
|
|
variable: VariableId,
|
|
gaussian: Gaussian,
|
|
}
|
|
|
|
impl PriorFactor {
|
|
pub fn new(
|
|
variable_arena: &mut VariableArena,
|
|
id: usize,
|
|
variable: VariableId,
|
|
gaussian: Gaussian,
|
|
) -> PriorFactor {
|
|
variable_arena.get_mut(variable).unwrap().attach_factor(id);
|
|
|
|
PriorFactor {
|
|
id,
|
|
variable,
|
|
gaussian,
|
|
}
|
|
}
|
|
|
|
pub fn start(&self, variable_arena: &mut VariableArena) {
|
|
variable_arena
|
|
.get_mut(self.variable)
|
|
.unwrap()
|
|
.update_value(self.id, self.gaussian);
|
|
}
|
|
}
|
|
|
|
pub struct LikelihoodFactor {
|
|
id: usize,
|
|
mean: VariableId,
|
|
value: VariableId,
|
|
variance: f64,
|
|
}
|
|
|
|
impl LikelihoodFactor {
|
|
pub fn new(
|
|
variable_arena: &mut VariableArena,
|
|
id: usize,
|
|
mean: VariableId,
|
|
value: VariableId,
|
|
variance: f64,
|
|
) -> LikelihoodFactor {
|
|
variable_arena.get_mut(mean).unwrap().attach_factor(id);
|
|
variable_arena.get_mut(value).unwrap().attach_factor(id);
|
|
|
|
LikelihoodFactor {
|
|
id,
|
|
mean,
|
|
value,
|
|
variance,
|
|
}
|
|
}
|
|
|
|
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
|
|
let x = variable_arena
|
|
.get(self.value)
|
|
.map(|variable| variable.get_value())
|
|
.unwrap();
|
|
|
|
let fx = variable_arena
|
|
.get_mut(self.value)
|
|
.map(|variable| variable.get_message(self.id))
|
|
.unwrap();
|
|
|
|
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
|
|
|
|
let gaussian = Gaussian::from_pi_tau(a * (x.pi() - fx.pi()), a * (x.tau() - fx.tau()));
|
|
|
|
variable_arena
|
|
.get_mut(self.mean)
|
|
.unwrap()
|
|
.update_message(self.id, gaussian);
|
|
}
|
|
|
|
pub fn update_value(&self, variable_arena: &mut VariableArena) {
|
|
let y = variable_arena
|
|
.get(self.mean)
|
|
.map(|variable| variable.get_value())
|
|
.unwrap();
|
|
|
|
let fy = variable_arena
|
|
.get(self.mean)
|
|
.map(|variable| variable.get_message(self.id))
|
|
.unwrap();
|
|
|
|
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
|
|
|
|
let gaussian = Gaussian::from_pi_tau(a * (y.pi() - fy.pi()), a * (y.tau() - fy.tau()));
|
|
|
|
variable_arena
|
|
.get_mut(self.value)
|
|
.unwrap()
|
|
.update_message(self.id, gaussian);
|
|
}
|
|
}
|
|
|
|
pub struct SumFactor {
|
|
id: usize,
|
|
sum: VariableId,
|
|
terms: Vec<VariableId>,
|
|
coeffs: Vec<f64>,
|
|
}
|
|
|
|
impl SumFactor {
|
|
pub fn new(
|
|
variable_arena: &mut VariableArena,
|
|
id: usize,
|
|
sum: VariableId,
|
|
terms: Vec<VariableId>,
|
|
coeffs: Vec<f64>,
|
|
) -> SumFactor {
|
|
variable_arena.get_mut(sum).unwrap().attach_factor(id);
|
|
|
|
for term in &terms {
|
|
variable_arena.get_mut(*term).unwrap().attach_factor(id);
|
|
}
|
|
|
|
SumFactor {
|
|
id,
|
|
sum,
|
|
terms,
|
|
coeffs,
|
|
}
|
|
}
|
|
|
|
fn internal_update(
|
|
&self,
|
|
variable_arena: &mut VariableArena,
|
|
variable: VariableId,
|
|
y: Vec<Gaussian>,
|
|
fy: Vec<Gaussian>,
|
|
a: &Vec<f64>,
|
|
) {
|
|
let size = a.len();
|
|
|
|
let mut sum_pi = 0.0;
|
|
let mut sum_tau = 0.0;
|
|
|
|
for i in 0..size {
|
|
let da = a[i];
|
|
let gy = y[i];
|
|
let gfy = fy[i];
|
|
|
|
sum_pi += da.powi(2) / (gy.pi() - gfy.pi());
|
|
sum_tau += da * (gy.tau() - gfy.tau()) / (gy.pi() - gfy.pi());
|
|
}
|
|
|
|
let new_pi = 1.0 / sum_pi;
|
|
let new_tau = new_pi * sum_tau;
|
|
|
|
let gaussian = Gaussian::from_pi_tau(new_pi, new_tau);
|
|
|
|
variable_arena
|
|
.get_mut(variable)
|
|
.unwrap()
|
|
.update_message(self.id, gaussian);
|
|
}
|
|
|
|
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
|
|
let mut y = Vec::new();
|
|
|
|
for term in &self.terms {
|
|
let value = variable_arena
|
|
.get(*term)
|
|
.map(|variable| variable.get_value())
|
|
.unwrap();
|
|
|
|
y.push(value);
|
|
}
|
|
|
|
let mut fy = Vec::new();
|
|
|
|
for term in &self.terms {
|
|
let value = variable_arena
|
|
.get(*term)
|
|
.map(|variable| variable.get_message(self.id))
|
|
.unwrap();
|
|
|
|
fy.push(value);
|
|
}
|
|
|
|
self.internal_update(variable_arena, self.sum, y, fy, &self.coeffs);
|
|
}
|
|
|
|
pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) {
|
|
let size = self.coeffs.len();
|
|
let idx_coeff = self.coeffs[index];
|
|
|
|
let mut a = vec![0.0; size];
|
|
|
|
for i in 0..size {
|
|
if i != index {
|
|
a[i] = -self.coeffs[i] / idx_coeff;
|
|
}
|
|
}
|
|
|
|
a[index] = 1.0 / idx_coeff;
|
|
|
|
let idx_term = self.terms[index];
|
|
|
|
let mut y = Vec::new();
|
|
let mut fy = Vec::new();
|
|
|
|
let mut v = self.terms.clone();
|
|
|
|
v[index] = self.sum;
|
|
|
|
for term in &v {
|
|
let value = variable_arena
|
|
.get(*term)
|
|
.map(|variable| variable.get_value())
|
|
.unwrap();
|
|
|
|
y.push(value);
|
|
|
|
let value = variable_arena
|
|
.get(*term)
|
|
.map(|variable| variable.get_message(self.id))
|
|
.unwrap();
|
|
|
|
fy.push(value);
|
|
}
|
|
|
|
self.internal_update(variable_arena, idx_term, y, fy, &a);
|
|
}
|
|
}
|
|
|
|
fn v_win(t: f64, e: f64) -> f64 {
|
|
math::pdf(t - e) / math::cdf(t - e)
|
|
}
|
|
|
|
fn w_win(t: f64, e: f64) -> f64 {
|
|
let vwin = v_win(t, e);
|
|
|
|
vwin * (vwin + t - e)
|
|
}
|
|
|
|
fn v_draw(t: f64, e: f64) -> f64 {
|
|
(math::pdf(-e - t) - math::pdf(e - t)) / (math::cdf(e - t) - math::cdf(-e - t))
|
|
}
|
|
|
|
fn w_draw(t: f64, e: f64) -> f64 {
|
|
let vdraw = v_draw(t, e);
|
|
let n = (vdraw * vdraw) + ((e - t) * math::pdf(e - t) + (e + t) * math::pdf(e + t));
|
|
let d = math::cdf(e - t) - math::cdf(-e - t);
|
|
|
|
n / d
|
|
}
|
|
|
|
pub struct TruncateFactor {
|
|
id: usize,
|
|
variable: VariableId,
|
|
epsilon: f64,
|
|
draw: bool,
|
|
}
|
|
|
|
impl TruncateFactor {
|
|
pub fn new(
|
|
variable_arena: &mut VariableArena,
|
|
id: usize,
|
|
variable: VariableId,
|
|
epsilon: f64,
|
|
draw: bool,
|
|
) -> TruncateFactor {
|
|
variable_arena.get_mut(variable).unwrap().attach_factor(id);
|
|
|
|
TruncateFactor {
|
|
id,
|
|
variable,
|
|
epsilon,
|
|
draw,
|
|
}
|
|
}
|
|
|
|
pub fn update(&self, variable_arena: &mut VariableArena) {
|
|
let x = variable_arena
|
|
.get(self.variable)
|
|
.map(|variable| variable.get_value())
|
|
.unwrap();
|
|
|
|
let fx = variable_arena
|
|
.get_mut(self.variable)
|
|
.map(|variable| variable.get_message(self.id))
|
|
.unwrap();
|
|
|
|
let c = x.pi() - fx.pi();
|
|
let d = x.tau() - fx.tau();
|
|
|
|
let sqrt_c = c.sqrt();
|
|
|
|
let t = d / sqrt_c;
|
|
let e = self.epsilon * sqrt_c;
|
|
|
|
let (v, w) = if self.draw {
|
|
(v_draw(t, e), w_draw(t, e))
|
|
} else {
|
|
(v_win(t, e), w_win(t, e))
|
|
};
|
|
|
|
let m_w = 1.0 - w;
|
|
|
|
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
|
|
|
|
variable_arena
|
|
.get_mut(self.variable)
|
|
.unwrap()
|
|
.update_value(self.id, gaussian);
|
|
}
|
|
}
|