Refactor, and passing tests
This commit is contained in:
@@ -112,8 +112,8 @@ impl BinaryModel {
|
||||
|
||||
for obs in &mut self.observations {
|
||||
let diff = match method {
|
||||
BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
|
||||
BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
|
||||
BinaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr),
|
||||
BinaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr),
|
||||
};
|
||||
|
||||
if diff > max_diff {
|
||||
|
||||
@@ -130,8 +130,8 @@ impl TernaryModel {
|
||||
|
||||
for obs in &mut self.observations {
|
||||
let diff = match method {
|
||||
TernaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
|
||||
TernaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
|
||||
TernaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr),
|
||||
TernaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr),
|
||||
};
|
||||
|
||||
if diff > max_diff {
|
||||
|
||||
Reference in New Issue
Block a user