This commit is contained in:
2023-01-13 10:56:27 +01:00
parent ababb58055
commit e7ee1e8b39

View File

@@ -10,7 +10,7 @@ pub use message::*;
pub use variable::*; pub use variable::*;
pub struct PriorFactor { pub struct PriorFactor {
variable_msg: MessageId, message: MessageId,
variable: VariableId, variable: VariableId,
gaussian: Gaussian, gaussian: Gaussian,
} }
@@ -18,17 +18,17 @@ pub struct PriorFactor {
impl PriorFactor { impl PriorFactor {
pub fn new(message_arena: &mut MessageArena, variable: VariableId, gaussian: Gaussian) -> Self { pub fn new(message_arena: &mut MessageArena, variable: VariableId, gaussian: Gaussian) -> Self {
Self { Self {
variable_msg: message_arena.create(), message: message_arena.create(),
variable, variable,
gaussian, gaussian,
} }
} }
pub fn start(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) { pub fn start(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
let old = message_arena[self.variable_msg]; let old = message_arena[self.message];
let value = variable_arena[self.variable]; let value = variable_arena[self.variable];
message_arena[self.variable_msg] = self.gaussian * old / value; message_arena[self.message] = self.gaussian * old / value;
variable_arena[self.variable] = self.gaussian; variable_arena[self.variable] = self.gaussian;
} }
} }
@@ -36,8 +36,8 @@ impl PriorFactor {
pub struct LikelihoodFactor { pub struct LikelihoodFactor {
mean_msg: MessageId, mean_msg: MessageId,
value_msg: MessageId, value_msg: MessageId,
mean: VariableId, mean_var: VariableId,
value: VariableId, value_var: VariableId,
variance: f64, variance: f64,
} }
@@ -51,8 +51,8 @@ impl LikelihoodFactor {
Self { Self {
mean_msg: message_arena.create(), mean_msg: message_arena.create(),
value_msg: message_arena.create(), value_msg: message_arena.create(),
mean, mean_var: mean,
value, value_var: value,
variance, variance,
} }
} }
@@ -62,7 +62,7 @@ impl LikelihoodFactor {
variable_arena: &mut VariableArena, variable_arena: &mut VariableArena,
message_arena: &mut MessageArena, message_arena: &mut MessageArena,
) { ) {
let x = variable_arena[self.value]; let x = variable_arena[self.value_var];
let fx = message_arena[self.value_msg]; let fx = message_arena[self.value_msg];
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi())); let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
@@ -71,7 +71,7 @@ impl LikelihoodFactor {
let old = message_arena[self.mean_msg]; let old = message_arena[self.mean_msg];
message_arena[self.mean_msg] = gaussian; message_arena[self.mean_msg] = gaussian;
variable_arena[self.mean] = variable_arena[self.mean] / old * gaussian; variable_arena[self.mean_var] = variable_arena[self.mean_var] / old * gaussian;
} }
pub fn update_value( pub fn update_value(
@@ -79,7 +79,7 @@ impl LikelihoodFactor {
variable_arena: &mut VariableArena, variable_arena: &mut VariableArena,
message_arena: &mut MessageArena, message_arena: &mut MessageArena,
) { ) {
let y = variable_arena[self.mean]; let y = variable_arena[self.mean_var];
let fy = message_arena[self.mean_msg]; let fy = message_arena[self.mean_msg];
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi())); let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
@@ -88,7 +88,7 @@ impl LikelihoodFactor {
let old = message_arena[self.value_msg]; let old = message_arena[self.value_msg];
message_arena[self.value_msg] = gaussian; message_arena[self.value_msg] = gaussian;
variable_arena[self.value] = variable_arena[self.value] / old * gaussian; variable_arena[self.value_var] = variable_arena[self.value_var] / old * gaussian;
} }
} }
@@ -225,7 +225,7 @@ impl SumFactor {
} }
pub struct TruncateFactor { pub struct TruncateFactor {
variable_msg: MessageId, message: MessageId,
variable: VariableId, variable: VariableId,
epsilon: f64, epsilon: f64,
draw: bool, draw: bool,
@@ -239,7 +239,7 @@ impl TruncateFactor {
draw: bool, draw: bool,
) -> Self { ) -> Self {
Self { Self {
variable_msg: message_arena.create(), message: message_arena.create(),
variable, variable,
epsilon, epsilon,
draw, draw,
@@ -252,7 +252,7 @@ impl TruncateFactor {
message_arena: &mut MessageArena, message_arena: &mut MessageArena,
) -> f64 { ) -> f64 {
let x = variable_arena[self.variable]; let x = variable_arena[self.variable];
let fx = message_arena[self.variable_msg]; let fx = message_arena[self.message];
let c = x.pi() - fx.pi(); let c = x.pi() - fx.pi();
let d = x.tau() - fx.tau(); let d = x.tau() - fx.tau();
@@ -272,10 +272,10 @@ impl TruncateFactor {
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w); let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
let old = message_arena[self.variable_msg]; let old = message_arena[self.message];
let value = variable_arena[self.variable]; let value = variable_arena[self.variable];
message_arena[self.variable_msg] = gaussian * old / value; message_arena[self.message] = gaussian * old / value;
let pi_delta = (value.pi() - gaussian.pi()).abs(); let pi_delta = (value.pi() - gaussian.pi()).abs();