diff --git a/examples/basic.rs b/examples/basic.rs
index 76e7776..a35e7b4 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -32,7 +32,7 @@ fn main() {
     model.observe(&["Jerry"], &["Tom"], 3.0);
     model.observe(&["Jerry"], &["Tom", "Spike"], 3.5);
 
-    model.fit(true);
+    model.fit();
 
     // We can predict a future outcome...
     let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0);
diff --git a/src/fitter.rs b/src/fitter.rs
index 8d76303..fc464eb 100644
--- a/src/fitter.rs
+++ b/src/fitter.rs
@@ -1,7 +1,11 @@
 mod batch;
+mod recursive;
 
 pub use batch::BatchFitter;
+pub use recursive::RecursiveFitter;
 
 pub trait Fitter {
     fn add_sample(&mut self, t: f64) -> usize;
+    fn allocate(&mut self);
+    fn fit(&mut self);
 }
diff --git a/src/fitter/batch.rs b/src/fitter/batch.rs
index 3f55f22..fbed8c5 100644
--- a/src/fitter/batch.rs
+++ b/src/fitter/batch.rs
@@ -37,4 +37,26 @@ impl Fitter for BatchFitter {
 
         idx
     }
+
+    fn allocate(&mut self) {
+        todo!();
+
+        let n_new = self.ts_new.len();
+        // let zeros =
+        /*
+        n_new = len(self.ts_new)
+        zeros = np.zeros(n_new)
+        self.ts = np.concatenate((self.ts, self.ts_new))
+        self.ms = np.concatenate((self.ms, zeros))
+        self.vs = np.concatenate((self.vs, self.kernel.k_diag(self.ts_new)))
+        self.ns = np.concatenate((self.ns, zeros))
+        self.xs = np.concatenate((self.xs, zeros))
+        # Clear the list of pending samples.
+        self.ts_new = list()
+        */
+    }
+
+    fn fit(&mut self) {
+        todo!();
+    }
 }
diff --git a/src/fitter/recursive.rs b/src/fitter/recursive.rs
new file mode 100644
index 0000000..4dada67
--- /dev/null
+++ b/src/fitter/recursive.rs
@@ -0,0 +1,43 @@
+use crate::kernel::Kernel;
+
+use super::Fitter;
+
+pub struct RecursiveFitter {
+    ts_new: Vec<f64>,
+    kernel: Box<dyn Kernel>,
+    ts: Vec<f64>,
+    ms: Vec<f64>,
+    vs: Vec<f64>,
+    ns: Vec<f64>,
+    xs: Vec<f64>,
+    is_fitted: bool,
+}
+
+impl RecursiveFitter {
+    pub fn new(kernel: Box<dyn Kernel>) -> Self {
+        RecursiveFitter {
+            ts_new: Vec::new(),
+            kernel,
+            ts: Vec::new(),
+            ms: Vec::new(),
+            vs: Vec::new(),
+            ns: Vec::new(),
+            xs: Vec::new(),
+            is_fitted: true,
+        }
+    }
+}
+
+impl Fitter for RecursiveFitter {
+    fn add_sample(&mut self, t: f64) -> usize {
+        todo!();
+    }
+
+    fn allocate(&mut self) {
+        todo!();
+    }
+
+    fn fit(&mut self) {
+        todo!();
+    }
+}
diff --git a/src/model.rs b/src/model.rs
index 5e486f2..e3f7dbd 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -1,6 +1,6 @@
 use std::f64;
 
-use crate::fitter::BatchFitter;
+use crate::fitter::RecursiveFitter;
 use crate::item::Item;
 use crate::kernel::Kernel;
 use crate::observation::*;
@@ -11,11 +11,18 @@ pub enum BinaryModelObservation {
     Logit,
 }
 
+#[derive(Clone, Copy)]
+pub enum BinaryModelFitMethod {
+    Ep,
+    Kl,
+}
+
 pub struct BinaryModel {
     storage: Storage,
     last_t: f64,
     win_obs: BinaryModelObservation,
     observations: Vec<Box<dyn Observation>>,
+    last_method: Option<BinaryModelFitMethod>,
 }
 
 impl BinaryModel {
@@ -25,6 +32,7 @@
             last_t: f64::NEG_INFINITY,
             win_obs,
             observations: Vec::new(),
+            last_method: None,
         }
     }
 
@@ -35,7 +43,7 @@
 
         self.storage.insert(
             name.to_string(),
-            Item::new(Box::new(BatchFitter::new(kernel))),
+            Item::new(Box::new(RecursiveFitter::new(kernel))),
         );
     }
@@ -65,7 +73,76 @@
         self.last_t = t;
     }
 
-    pub fn fit(&mut self, verbose: bool) {
+    pub fn fit(&mut self) -> bool {
+        // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
+
+        let method = BinaryModelFitMethod::Ep;
+        let lr = 1.0;
+        let tol = 1e-3;
+        let max_iter = 100;
+        let verbose = true;
+
+        self.last_method = Some(method);
+
+        for item in self.storage.items_mut() {
+            item.fitter.allocate();
+        }
+
+        for i in 0..max_iter {
+            let mut max_diff = 0.0;
+
+            for obs in &mut self.observations {
+                let diff = match method {
+                    BinaryModelFitMethod::Ep => obs.ep_update(lr),
+                    BinaryModelFitMethod::Kl => obs.kl_update(lr),
+                };
+
+                if diff > max_diff {
+                    max_diff = diff;
+                }
+            }
+
+            for item in self.storage.items_mut() {
+                item.fitter.fit();
+            }
+
+            if verbose {
+                println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
+            }
+
+            if max_diff < tol {
+                return true;
+            }
+        }
+
+        false
+
+        /*
+        if method == "ep":
+            update = lambda obs: obs.ep_update(lr=lr)
+        elif method == "kl":
+            update = lambda obs: obs.kl_update(lr=lr)
+        else:
+            raise ValueError("'method' should be one of: 'ep', 'kl'")
+        self._last_method = method
+        for item in self._item.values():
+            item.fitter.allocate()
+        for i in range(max_iter):
+            max_diff = 0.0
+            # Recompute the Gaussian pseudo-observations.
+            for obs in self.observations:
+                diff = update(obs)
+                max_diff = max(max_diff, diff)
+            # Recompute the posterior of the score processes.
+            for item in self.item.values():
+                item.fitter.fit()
+            if verbose:
+                print("iteration {}, max diff: {:.5f}".format(
+                    i+1, max_diff), flush=True)
+            if max_diff < tol:
+                return True
+        return False  # Did not converge after `max_iter`.
+        */
         //
     }
 
diff --git a/src/observation.rs b/src/observation.rs
index ffadc72..87ce49c 100644
--- a/src/observation.rs
+++ b/src/observation.rs
@@ -3,5 +3,6 @@ mod ordinal;
 pub use ordinal::*;
 
 pub trait Observation {
-    //
+    fn ep_update(&mut self, lr: f64) -> f64;
+    fn kl_update(&mut self, lr: f64) -> f64;
 }
diff --git a/src/observation/ordinal.rs b/src/observation/ordinal.rs
index 58651c8..96c7822 100644
--- a/src/observation/ordinal.rs
+++ b/src/observation/ordinal.rs
@@ -40,7 +40,13 @@ impl ProbitWinObservation {
 }
 
 impl Observation for ProbitWinObservation {
-    //
+    fn ep_update(&mut self, lr: f64) -> f64 {
+        todo!();
+    }
+
+    fn kl_update(&mut self, lr: f64) -> f64 {
+        todo!();
+    }
 }
 
 pub struct LogitWinObservation {
@@ -49,10 +55,16 @@ pub struct LogitWinObservation {
 }
 impl LogitWinObservation {
     pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
-        LogitWinObservation {}
+        todo!();
     }
 }
 
 impl Observation for LogitWinObservation {
-    //
+    fn ep_update(&mut self, lr: f64) -> f64 {
+        todo!();
+    }
+
+    fn kl_update(&mut self, lr: f64) -> f64 {
+        todo!();
+    }
 }
diff --git a/src/storage.rs b/src/storage.rs
index 5fe7dec..9eb6dac 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -33,4 +33,8 @@ impl Storage {
     pub fn get_item(&mut self, id: usize) -> &mut Item {
         &mut self.items[id]
     }
+
+    pub fn items_mut(&mut self) -> impl Iterator<Item = &mut Item> + '_ {
+        self.items.iter_mut()
+    }
 }
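Note on the new `Fitter::allocate` stubs: the Python block kept as a comment in `src/fitter/batch.rs` is the reference implementation being ported. Below is a minimal sketch of what it could translate to in Rust, assuming the fitter's buffers are plain `Vec<f64>` (as declared for `RecursiveFitter`) and that `Kernel::k_diag` takes a slice of timestamps and returns the prior variance at each one; that Rust signature is an assumption, not something this diff defines.

    fn allocate(&mut self) {
        // Prior variance of the score process at each pending timestamp
        // (assumed signature: fn k_diag(&self, ts: &[f64]) -> Vec<f64>).
        let new_vs = self.kernel.k_diag(&self.ts_new);
        // Move the pending samples into the allocated arrays; `append` also
        // empties `self.ts_new`, mirroring `self.ts_new = list()` in the reference.
        self.ts.append(&mut self.ts_new);
        let n = self.ts.len();
        self.vs.extend(new_vs);
        // The remaining buffers are padded with zeros, as in the reference.
        self.ms.resize(n, 0.0);
        self.ns.resize(n, 0.0);
        self.xs.resize(n, 0.0);
    }

With something like this in place, `BinaryModel::fit` can call `allocate()` on every item before the update sweeps, as the ported Python loop does.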