This commit is contained in:
2023-01-13 10:56:27 +01:00
parent ababb58055
commit e7ee1e8b39

View File

@@ -10,7 +10,7 @@ pub use message::*;
pub use variable::*;
pub struct PriorFactor {
variable_msg: MessageId,
message: MessageId,
variable: VariableId,
gaussian: Gaussian,
}
@@ -18,17 +18,17 @@ pub struct PriorFactor {
impl PriorFactor {
pub fn new(message_arena: &mut MessageArena, variable: VariableId, gaussian: Gaussian) -> Self {
Self {
variable_msg: message_arena.create(),
message: message_arena.create(),
variable,
gaussian,
}
}
pub fn start(&self, variable_arena: &mut VariableArena, message_arena: &mut MessageArena) {
let old = message_arena[self.variable_msg];
let old = message_arena[self.message];
let value = variable_arena[self.variable];
message_arena[self.variable_msg] = self.gaussian * old / value;
message_arena[self.message] = self.gaussian * old / value;
variable_arena[self.variable] = self.gaussian;
}
}
@@ -36,8 +36,8 @@ impl PriorFactor {
pub struct LikelihoodFactor {
mean_msg: MessageId,
value_msg: MessageId,
mean: VariableId,
value: VariableId,
mean_var: VariableId,
value_var: VariableId,
variance: f64,
}
@@ -51,8 +51,8 @@ impl LikelihoodFactor {
Self {
mean_msg: message_arena.create(),
value_msg: message_arena.create(),
mean,
value,
mean_var: mean,
value_var: value,
variance,
}
}
@@ -62,7 +62,7 @@ impl LikelihoodFactor {
variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
) {
let x = variable_arena[self.value];
let x = variable_arena[self.value_var];
let fx = message_arena[self.value_msg];
let a = 1.0 / (1.0 + self.variance * (x.pi() - fx.pi()));
@@ -71,7 +71,7 @@ impl LikelihoodFactor {
let old = message_arena[self.mean_msg];
message_arena[self.mean_msg] = gaussian;
variable_arena[self.mean] = variable_arena[self.mean] / old * gaussian;
variable_arena[self.mean_var] = variable_arena[self.mean_var] / old * gaussian;
}
pub fn update_value(
@@ -79,7 +79,7 @@ impl LikelihoodFactor {
variable_arena: &mut VariableArena,
message_arena: &mut MessageArena,
) {
let y = variable_arena[self.mean];
let y = variable_arena[self.mean_var];
let fy = message_arena[self.mean_msg];
let a = 1.0 / (1.0 + self.variance * (y.pi() - fy.pi()));
@@ -88,7 +88,7 @@ impl LikelihoodFactor {
let old = message_arena[self.value_msg];
message_arena[self.value_msg] = gaussian;
variable_arena[self.value] = variable_arena[self.value] / old * gaussian;
variable_arena[self.value_var] = variable_arena[self.value_var] / old * gaussian;
}
}
@@ -225,7 +225,7 @@ impl SumFactor {
}
pub struct TruncateFactor {
variable_msg: MessageId,
message: MessageId,
variable: VariableId,
epsilon: f64,
draw: bool,
@@ -239,7 +239,7 @@ impl TruncateFactor {
draw: bool,
) -> Self {
Self {
variable_msg: message_arena.create(),
message: message_arena.create(),
variable,
epsilon,
draw,
@@ -252,7 +252,7 @@ impl TruncateFactor {
message_arena: &mut MessageArena,
) -> f64 {
let x = variable_arena[self.variable];
let fx = message_arena[self.variable_msg];
let fx = message_arena[self.message];
let c = x.pi() - fx.pi();
let d = x.tau() - fx.tau();
@@ -272,10 +272,10 @@ impl TruncateFactor {
let gaussian = Gaussian::from_pi_tau(c / m_w, (d + sqrt_c * v) / m_w);
let old = message_arena[self.variable_msg];
let old = message_arena[self.message];
let value = variable_arena[self.variable];
message_arena[self.variable_msg] = gaussian * old / value;
message_arena[self.message] = gaussian * old / value;
let pi_delta = (value.pi() - gaussian.pi()).abs();