It works?!

This commit is contained in:
2018-10-23 14:58:30 +02:00
parent d6ea5e3116
commit 1b840e737d
4 changed files with 607 additions and 62 deletions

View File

@@ -1,54 +1,396 @@
use std::cmp;
use std::collections::HashMap;
use gaussian::Gaussian;
use math;
#[derive(Clone, Copy)]
pub struct VariableId {
index: usize,
}
pub struct VariableArena {
variables: Vec<Variable>,
}
impl VariableArena {
pub fn new() -> VariableArena {
VariableArena {
variables: Vec::new(),
}
}
pub fn create(&mut self) -> VariableId {
let index = self.variables.len();
self.variables.push(Variable {
value: Gaussian::with_pi_tau(0.0, 0.0),
factors: HashMap::new(),
});
VariableId { index }
}
pub fn get(&mut self, id: VariableId) -> Option<&Variable> {
self.variables.get(id.index)
}
pub fn get_mut(&mut self, id: VariableId) -> Option<&mut Variable> {
self.variables.get_mut(id.index)
}
}
pub struct Variable {
gaussian: Gaussian,
value: Gaussian,
factors: HashMap<usize, Gaussian>,
}
impl Variable {
pub fn new() -> Variable {
Variable {
gaussian: Gaussian::new(0.0, 0.0),
}
pub fn attach_factor(&mut self, factor: usize) {
self.factors.insert(factor, Gaussian::new());
}
fn delta(&self, other: &Variable) -> f32 {
let pi_delta = self.gaussian.pi - other.gaussian.pi;
pub fn update_value(&mut self, factor: usize, value: Gaussian) {
let old = self.factors[&factor];
if pi_delta.is_infinite() {
0.0
} else {
let tau_delta = (self.gaussian.tau - other.gaussian.tau).abs();
let intermediate = value * old;
let value = intermediate / self.value;
if pi_delta > tau_delta {
pi_delta
} else {
tau_delta
}
}
self.value = value;
}
}
pub trait Factor {
fn down(&self) -> f32 {
0.0
pub fn get_value(&self) -> Gaussian {
self.value
}
pub fn update_message(&mut self, factor: usize, message: Gaussian) {
let old = self.factors[&factor];
let intermediate = self.value / old;
let value = intermediate * message;
self.value = value;
self.factors.insert(factor, message);
}
pub fn get_message(&self, factor: usize) -> Gaussian {
self.factors[&factor]
}
}
pub struct PriorFactor {
variable: Variable,
dynamic: f32
id: usize,
variable: VariableId,
gaussian: Gaussian,
}
impl PriorFactor {
pub fn new(variable: Variable, dynamic: f32) -> PriorFactor {
PriorFactor { variable, dynamic }
pub fn new(
variable_arena: &mut VariableArena,
id: usize,
variable: VariableId,
gaussian: Gaussian,
) -> PriorFactor {
if let Some(variable) = variable_arena.get_mut(variable) {
variable.attach_factor(id);
}
PriorFactor {
id,
variable,
gaussian,
}
}
pub fn start(&self, variable_arena: &mut VariableArena) {
if let Some(variable) = variable_arena.get_mut(self.variable) {
variable.update_value(self.id, self.gaussian);
}
}
}
impl Factor for PriorFactor {
fn down(&self) -> f32 {
0.0
pub struct LikelihoodFactor {
id: usize,
mean: VariableId,
value: VariableId,
variance: f32,
}
impl LikelihoodFactor {
pub fn new(
variable_arena: &mut VariableArena,
id: usize,
mean: VariableId,
value: VariableId,
variance: f32,
) -> LikelihoodFactor {
if let Some(variable) = variable_arena.get_mut(mean) {
variable.attach_factor(id);
}
if let Some(variable) = variable_arena.get_mut(value) {
variable.attach_factor(id);
}
LikelihoodFactor {
id,
mean,
value,
variance,
}
}
pub fn update_mean(&self, variable_arena: &mut VariableArena) {
let x = variable_arena
.get(self.value)
.map(|variable| variable.get_value())
.unwrap();
let fx = variable_arena
.get_mut(self.value)
.map(|variable| variable.get_message(self.id))
.unwrap();
let a = 1.0 / (1.0 + self.variance * (x.pi - fx.pi));
let gaussian = Gaussian::with_pi_tau(a * (x.pi - fx.pi), a * (x.tau - fx.tau));
if let Some(variable) = variable_arena.get_mut(self.mean) {
variable.update_message(self.id, gaussian);
}
}
pub fn update_value(&self, variable_arena: &mut VariableArena) {
let y = variable_arena
.get(self.mean)
.map(|variable| variable.get_value())
.unwrap();
let fy = variable_arena
.get(self.mean)
.map(|variable| variable.get_message(self.id))
.unwrap();
let a = 1.0 / (1.0 + self.variance * (y.pi - fy.pi));
let gaussian = Gaussian::with_pi_tau(a * (y.pi - fy.pi), a * (y.tau - fy.tau));
if let Some(variable) = variable_arena.get_mut(self.value) {
variable.update_message(self.id, gaussian);
}
}
}
pub struct SumFactor {
id: usize,
sum: VariableId,
terms: Vec<VariableId>,
coeffs: Vec<f32>,
}
impl SumFactor {
pub fn new(
variable_arena: &mut VariableArena,
id: usize,
sum: VariableId,
terms: Vec<VariableId>,
coeffs: Vec<f32>,
) -> SumFactor {
if let Some(variable) = variable_arena.get_mut(sum) {
variable.attach_factor(id);
}
for term in &terms {
if let Some(variable) = variable_arena.get_mut(*term) {
variable.attach_factor(id);
}
}
SumFactor {
id,
sum,
terms,
coeffs,
}
}
fn internal_update(
&self,
variable_arena: &mut VariableArena,
variable: VariableId,
y: Vec<Gaussian>,
fy: Vec<Gaussian>,
a: &Vec<f32>,
) {
let size = a.len();
let mut sum_pi = 0.0;
let mut sum_tau = 0.0;
for i in 0..size {
let da = a[i];
let gy = y[i];
let gfy = fy[i];
sum_pi += da.powi(2) / (gy.pi - gfy.pi);
sum_tau += da * (gy.tau - gfy.tau) / (gy.pi - gfy.pi);
}
let new_pi = 1.0 / sum_pi;
let new_tau = new_pi * sum_tau;
let gaussian = Gaussian::with_pi_tau(new_pi, new_tau);
if let Some(variable) = variable_arena.get_mut(variable) {
variable.update_value(self.id, gaussian);
}
}
pub fn update_sum(&self, variable_arena: &mut VariableArena) {
let mut y = Vec::new();
for term in &self.terms {
let value = variable_arena
.get(*term)
.map(|variable| variable.get_value())
.unwrap();
y.push(value);
}
let mut fy = Vec::new();
for term in &self.terms {
let value = variable_arena
.get(*term)
.map(|variable| variable.get_message(self.id))
.unwrap();
fy.push(value);
}
self.internal_update(variable_arena, self.sum, y, fy, &self.coeffs);
}
pub fn update_term(&self, variable_arena: &mut VariableArena, index: usize) {
let size = self.coeffs.len();
let idx_coeff = self.coeffs[index];
let mut a = vec![0.0; size];
for i in 0..size {
if i != index {
a[i] = -self.coeffs[i] / idx_coeff;
}
}
a[index] = 1.0 / idx_coeff;
let idx_term = self.terms[index];
let mut y = Vec::new();
let mut fy = Vec::new();
let mut v = self.terms.clone();
v[index] = self.sum;
for term in v {
let value = variable_arena
.get(term)
.map(|variable| variable.get_value())
.unwrap();
y.push(value);
let value = variable_arena
.get(term)
.map(|variable| variable.get_message(self.id))
.unwrap();
fy.push(value);
}
self.internal_update(variable_arena, idx_term, y, fy, &a);
}
}
fn v_win(t: f32, e: f32) -> f32 {
math::pdf(t - e) / math::cdf(t - e)
}
fn w_win(t: f32, e: f32) -> f32 {
let vwin = v_win(t, e);
vwin * (vwin + t - e)
}
fn v_draw(t: f32, e: f32) -> f32 {
(math::pdf(-e - t) - math::pdf(e - t)) / (math::cdf(e - t) - math::cdf(-e - t))
}
fn w_draw(t: f32, e: f32) -> f32 {
let vdraw = v_draw(t, e);
let n = (vdraw * vdraw) + ((e - t) * math::pdf(e - t) + (e + t) * math::pdf(e + t));
let d = math::cdf(e - t) - math::cdf(-e - t);
n / d
}
pub struct TruncateFactor {
id: usize,
variable: VariableId,
epsilon: f32,
draw: bool,
}
impl TruncateFactor {
pub fn new(
variable_arena: &mut VariableArena,
id: usize,
variable: VariableId,
epsilon: f32,
draw: bool,
) -> TruncateFactor {
if let Some(variable) = variable_arena.get_mut(variable) {
variable.attach_factor(id);
}
TruncateFactor {
id,
variable,
epsilon,
draw,
}
}
pub fn update(&self, variable_arena: &mut VariableArena) {
let x = variable_arena
.get(self.variable)
.map(|variable| variable.get_value())
.unwrap();
let fx = variable_arena
.get_mut(self.variable)
.map(|variable| variable.get_message(self.id))
.unwrap();
let c = x.pi - fx.pi;
let d = x.tau - fx.tau;
let sqrt_c = c.sqrt();
let t = d / sqrt_c;
let e = self.epsilon * sqrt_c;
let (v, w) = if self.draw {
(v_draw(t, e), w_draw(t, e))
} else {
(v_win(t, e), w_win(t, e))
};
let m_w = 1.0 - w;
let gaussian = Gaussian::with_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
if let Some(variable) = variable_arena.get_mut(self.variable) {
variable.update_value(self.id, gaussian);
}
}
}