Rename recursive fitter and binary model

This commit is contained in:
2021-10-27 11:05:43 +02:00
parent b2bf871500
commit 37746f6c02
15 changed files with 41 additions and 91 deletions

View File

@@ -1,6 +1,6 @@
use std::f64;
use crate::fitter::RecursiveFitter;
use crate::fitter::Recursive;
use crate::item::Item;
use crate::kernel::Kernel;
use crate::observation::*;
@@ -12,28 +12,28 @@ trait BinaryObeservation: Observation {
}
*/
pub enum BinaryModelObservation {
pub enum Observation {
Probit,
Logit,
}
#[derive(Clone, Copy)]
pub enum BinaryModelFitMethod {
pub enum FitMethod {
Ep,
Kl,
}
pub struct BinaryModel {
pub struct Binary {
pub storage: Storage,
last_t: f64,
win_obs: BinaryModelObservation,
observations: Vec<Box<dyn Observation>>,
last_method: Option<BinaryModelFitMethod>,
win_obs: Observation,
observations: Vec<Box<dyn crate::observation::Observation>>,
last_method: Option<FitMethod>,
}
impl BinaryModel {
pub fn new(win_obs: BinaryModelObservation) -> Self {
BinaryModel {
impl Binary {
pub fn new(win_obs: Observation) -> Self {
Binary {
storage: Storage::default(),
last_t: f64::NEG_INFINITY,
win_obs,
@@ -49,7 +49,7 @@ impl BinaryModel {
self.storage.insert(
name.to_string(),
Item::new(Box::new(RecursiveFitter::new(kernel))),
Item::new(Box::new(Recursive::new(kernel))),
);
}
@@ -72,11 +72,11 @@ impl BinaryModel {
let mut elems = self.process_items(winners, 1.0);
elems.extend(self.process_items(losers, -1.0));
let obs: Box<dyn Observation> = match self.win_obs {
BinaryModelObservation::Probit => {
let obs: Box<dyn crate::observation::Observation> = match self.win_obs {
Observation::Probit => {
Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
}
BinaryModelObservation::Logit => {
Observation::Logit => {
Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
}
};
@@ -95,7 +95,7 @@ impl BinaryModel {
pub fn fit(&mut self) -> bool {
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
let method = BinaryModelFitMethod::Ep;
let method = FitMethod::Ep;
let lr = 1.0;
let tol = 1e-3;
let max_iter = 100;
@@ -112,8 +112,8 @@ impl BinaryModel {
for obs in &mut self.observations {
let diff = match method {
BinaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr),
BinaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr),
FitMethod::Ep => obs.ep_update(&mut self.storage, lr),
FitMethod::Kl => obs.kl_update(&mut self.storage, lr),
};
if diff > max_diff {
@@ -142,10 +142,8 @@ impl BinaryModel {
elems.extend(self.process_items(team_2, -1.0));
let prob = match self.win_obs {
BinaryModelObservation::Probit => {
ProbitWinObservation::probability(&self.storage, &elems, t, 0.0)
}
BinaryModelObservation::Logit => todo!(),
Observation::Probit => ProbitWinObservation::probability(&self.storage, &elems, t, 0.0),
Observation::Logit => todo!(),
};
(prob, 1.0 - prob)

View File

@@ -1,6 +1,6 @@
use std::f64;
use crate::fitter::RecursiveFitter;
use crate::fitter::Recursive;
use crate::item::Item;
use crate::kernel::Kernel;
use crate::observation::*;
@@ -38,7 +38,7 @@ impl DifferenceModel {
self.storage.insert(
name.to_string(),
Item::new(Box::new(RecursiveFitter::new(kernel))),
Item::new(Box::new(Recursive::new(kernel))),
);
}

View File

@@ -1,6 +1,6 @@
use std::f64;
use crate::fitter::RecursiveFitter;
use crate::fitter::Recursive;
use crate::item::Item;
use crate::kernel::Kernel;
use crate::observation::*;
@@ -46,7 +46,7 @@ impl TernaryModel {
self.storage.insert(
name.to_string(),
Item::new(Box::new(RecursiveFitter::new(kernel))),
Item::new(Box::new(Recursive::new(kernel))),
);
}