More progress.
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
mod batch;
|
||||
mod recursive;
|
||||
|
||||
pub use batch::BatchFitter;
|
||||
pub use recursive::RecursiveFitter;
|
||||
|
||||
pub trait Fitter {
|
||||
fn add_sample(&mut self, t: f64) -> usize;
|
||||
fn allocate(&mut self);
|
||||
fn fit(&mut self);
|
||||
}
|
||||
|
||||
@@ -37,4 +37,26 @@ impl Fitter for BatchFitter {
|
||||
|
||||
idx
|
||||
}
|
||||
|
||||
fn allocate(&mut self) {
|
||||
todo!();
|
||||
|
||||
let n_new = self.ts_new.len();
|
||||
// let zeros =
|
||||
/*
|
||||
n_new = len(self.ts_new)
|
||||
zeros = np.zeros(n_new)
|
||||
self.ts = np.concatenate((self.ts, self.ts_new))
|
||||
self.ms = np.concatenate((self.ms, zeros))
|
||||
self.vs = np.concatenate((self.vs, self.kernel.k_diag(self.ts_new)))
|
||||
self.ns = np.concatenate((self.ns, zeros))
|
||||
self.xs = np.concatenate((self.xs, zeros))
|
||||
# Clear the list of pending samples.
|
||||
self.ts_new = list()
|
||||
*/
|
||||
}
|
||||
|
||||
fn fit(&mut self) {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
43
src/fitter/recursive.rs
Normal file
43
src/fitter/recursive.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use crate::kernel::Kernel;
|
||||
|
||||
use super::Fitter;
|
||||
|
||||
pub struct RecursiveFitter {
|
||||
ts_new: Vec<f64>,
|
||||
kernel: Box<dyn Kernel>,
|
||||
ts: Vec<usize>,
|
||||
ms: Vec<usize>,
|
||||
vs: Vec<usize>,
|
||||
ns: Vec<usize>,
|
||||
xs: Vec<usize>,
|
||||
is_fitted: bool,
|
||||
}
|
||||
|
||||
impl RecursiveFitter {
|
||||
pub fn new(kernel: Box<dyn Kernel>) -> Self {
|
||||
RecursiveFitter {
|
||||
ts_new: Vec::new(),
|
||||
kernel,
|
||||
ts: Vec::new(),
|
||||
ms: Vec::new(),
|
||||
vs: Vec::new(),
|
||||
ns: Vec::new(),
|
||||
xs: Vec::new(),
|
||||
is_fitted: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Fitter for RecursiveFitter {
|
||||
fn add_sample(&mut self, t: f64) -> usize {
|
||||
todo!();
|
||||
}
|
||||
|
||||
fn allocate(&mut self) {
|
||||
todo!();
|
||||
}
|
||||
|
||||
fn fit(&mut self) {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
83
src/model.rs
83
src/model.rs
@@ -1,6 +1,6 @@
|
||||
use std::f64;
|
||||
|
||||
use crate::fitter::BatchFitter;
|
||||
use crate::fitter::RecursiveFitter;
|
||||
use crate::item::Item;
|
||||
use crate::kernel::Kernel;
|
||||
use crate::observation::*;
|
||||
@@ -11,11 +11,18 @@ pub enum BinaryModelObservation {
|
||||
Logit,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum BinaryModelFitMethod {
|
||||
Ep,
|
||||
Kl,
|
||||
}
|
||||
|
||||
pub struct BinaryModel {
|
||||
storage: Storage,
|
||||
last_t: f64,
|
||||
win_obs: BinaryModelObservation,
|
||||
observations: Vec<Box<dyn Observation>>,
|
||||
last_method: Option<BinaryModelFitMethod>,
|
||||
}
|
||||
|
||||
impl BinaryModel {
|
||||
@@ -25,6 +32,7 @@ impl BinaryModel {
|
||||
last_t: f64::NEG_INFINITY,
|
||||
win_obs,
|
||||
observations: Vec::new(),
|
||||
last_method: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +43,7 @@ impl BinaryModel {
|
||||
|
||||
self.storage.insert(
|
||||
name.to_string(),
|
||||
Item::new(Box::new(BatchFitter::new(kernel))),
|
||||
Item::new(Box::new(RecursiveFitter::new(kernel))),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -65,7 +73,76 @@ impl BinaryModel {
|
||||
self.last_t = t;
|
||||
}
|
||||
|
||||
pub fn fit(&mut self, verbose: bool) {
|
||||
pub fn fit(&mut self) -> bool {
|
||||
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
|
||||
|
||||
let method = BinaryModelFitMethod::Ep;
|
||||
let lr = 1.0;
|
||||
let tol = 1e-3;
|
||||
let max_iter = 100;
|
||||
let verbose = true;
|
||||
|
||||
self.last_method = Some(method);
|
||||
|
||||
for item in self.storage.items_mut() {
|
||||
item.fitter.allocate();
|
||||
}
|
||||
|
||||
for i in 0..max_iter {
|
||||
let mut max_diff = 0.0;
|
||||
|
||||
for obs in &mut self.observations {
|
||||
let diff = match method {
|
||||
BinaryModelFitMethod::Ep => obs.ep_update(lr),
|
||||
BinaryModelFitMethod::Kl => obs.kl_update(lr),
|
||||
};
|
||||
|
||||
if diff > max_diff {
|
||||
max_diff = diff;
|
||||
}
|
||||
}
|
||||
|
||||
for item in self.storage.items_mut() {
|
||||
item.fitter.fit();
|
||||
}
|
||||
|
||||
if verbose {
|
||||
println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
|
||||
}
|
||||
|
||||
if max_diff < tol {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
|
||||
/*
|
||||
if method == "ep":
|
||||
update = lambda obs: obs.ep_update(lr=lr)
|
||||
elif method == "kl":
|
||||
update = lambda obs: obs.kl_update(lr=lr)
|
||||
else:
|
||||
raise ValueError("'method' should be one of: 'ep', 'kl'")
|
||||
self._last_method = method
|
||||
for item in self._item.values():
|
||||
item.fitter.allocate()
|
||||
for i in range(max_iter):
|
||||
max_diff = 0.0
|
||||
# Recompute the Gaussian pseudo-observations.
|
||||
for obs in self.observations:
|
||||
diff = update(obs)
|
||||
max_diff = max(max_diff, diff)
|
||||
# Recompute the posterior of the score processes.
|
||||
for item in self.item.values():
|
||||
item.fitter.fit()
|
||||
if verbose:
|
||||
print("iteration {}, max diff: {:.5f}".format(
|
||||
i+1, max_diff), flush=True)
|
||||
if max_diff < tol:
|
||||
return True
|
||||
return False # Did not converge after `max_iter`.
|
||||
*/
|
||||
//
|
||||
}
|
||||
|
||||
|
||||
@@ -3,5 +3,6 @@ mod ordinal;
|
||||
pub use ordinal::*;
|
||||
|
||||
pub trait Observation {
|
||||
//
|
||||
fn ep_update(&mut self, lr: f64) -> f64;
|
||||
fn kl_update(&mut self, lr: f64) -> f64;
|
||||
}
|
||||
|
||||
@@ -40,7 +40,13 @@ impl ProbitWinObservation {
|
||||
}
|
||||
|
||||
impl Observation for ProbitWinObservation {
|
||||
//
|
||||
fn ep_update(&mut self, lr: f64) -> f64 {
|
||||
todo!();
|
||||
}
|
||||
|
||||
fn kl_update(&mut self, lr: f64) -> f64 {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LogitWinObservation {
|
||||
@@ -49,10 +55,16 @@ pub struct LogitWinObservation {
|
||||
|
||||
impl LogitWinObservation {
|
||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||
LogitWinObservation {}
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
impl Observation for LogitWinObservation {
|
||||
//
|
||||
fn ep_update(&mut self, lr: f64) -> f64 {
|
||||
todo!();
|
||||
}
|
||||
|
||||
fn kl_update(&mut self, lr: f64) -> f64 {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,4 +33,8 @@ impl Storage {
|
||||
pub fn get_item(&mut self, id: usize) -> &mut Item {
|
||||
&mut self.items[id]
|
||||
}
|
||||
|
||||
pub fn items_mut(&mut self) -> impl Iterator<Item = &mut Item> {
|
||||
self.items.iter_mut()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user