More progress.

This commit is contained in:
2020-02-13 10:17:20 +01:00
parent a54bb70138
commit dd5667d82c
8 changed files with 171 additions and 8 deletions

View File

@@ -32,7 +32,7 @@ fn main() {
model.observe(&["Jerry"], &["Tom"], 3.0); model.observe(&["Jerry"], &["Tom"], 3.0);
model.observe(&["Jerry"], &["Tom", "Spike"], 3.5); model.observe(&["Jerry"], &["Tom", "Spike"], 3.5);
model.fit(true); model.fit();
// We can predict a future outcome... // We can predict a future outcome...
let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0); let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0);

View File

@@ -1,7 +1,11 @@
mod batch; mod batch;
mod recursive;
pub use batch::BatchFitter; pub use batch::BatchFitter;
pub use recursive::RecursiveFitter;
pub trait Fitter { pub trait Fitter {
fn add_sample(&mut self, t: f64) -> usize; fn add_sample(&mut self, t: f64) -> usize;
fn allocate(&mut self);
fn fit(&mut self);
} }

View File

@@ -37,4 +37,26 @@ impl Fitter for BatchFitter {
idx idx
} }
fn allocate(&mut self) {
todo!();
let n_new = self.ts_new.len();
// let zeros =
/*
n_new = len(self.ts_new)
zeros = np.zeros(n_new)
self.ts = np.concatenate((self.ts, self.ts_new))
self.ms = np.concatenate((self.ms, zeros))
self.vs = np.concatenate((self.vs, self.kernel.k_diag(self.ts_new)))
self.ns = np.concatenate((self.ns, zeros))
self.xs = np.concatenate((self.xs, zeros))
# Clear the list of pending samples.
self.ts_new = list()
*/
}
fn fit(&mut self) {
todo!();
}
} }

43
src/fitter/recursive.rs Normal file
View File

@@ -0,0 +1,43 @@
use crate::kernel::Kernel;
use super::Fitter;
pub struct RecursiveFitter {
ts_new: Vec<f64>,
kernel: Box<dyn Kernel>,
ts: Vec<usize>,
ms: Vec<usize>,
vs: Vec<usize>,
ns: Vec<usize>,
xs: Vec<usize>,
is_fitted: bool,
}
impl RecursiveFitter {
pub fn new(kernel: Box<dyn Kernel>) -> Self {
RecursiveFitter {
ts_new: Vec::new(),
kernel,
ts: Vec::new(),
ms: Vec::new(),
vs: Vec::new(),
ns: Vec::new(),
xs: Vec::new(),
is_fitted: true,
}
}
}
impl Fitter for RecursiveFitter {
fn add_sample(&mut self, t: f64) -> usize {
todo!();
}
fn allocate(&mut self) {
todo!();
}
fn fit(&mut self) {
todo!();
}
}

View File

@@ -1,6 +1,6 @@
use std::f64; use std::f64;
use crate::fitter::BatchFitter; use crate::fitter::RecursiveFitter;
use crate::item::Item; use crate::item::Item;
use crate::kernel::Kernel; use crate::kernel::Kernel;
use crate::observation::*; use crate::observation::*;
@@ -11,11 +11,18 @@ pub enum BinaryModelObservation {
Logit, Logit,
} }
#[derive(Clone, Copy)]
pub enum BinaryModelFitMethod {
Ep,
Kl,
}
pub struct BinaryModel { pub struct BinaryModel {
storage: Storage, storage: Storage,
last_t: f64, last_t: f64,
win_obs: BinaryModelObservation, win_obs: BinaryModelObservation,
observations: Vec<Box<dyn Observation>>, observations: Vec<Box<dyn Observation>>,
last_method: Option<BinaryModelFitMethod>,
} }
impl BinaryModel { impl BinaryModel {
@@ -25,6 +32,7 @@ impl BinaryModel {
last_t: f64::NEG_INFINITY, last_t: f64::NEG_INFINITY,
win_obs, win_obs,
observations: Vec::new(), observations: Vec::new(),
last_method: None,
} }
} }
@@ -35,7 +43,7 @@ impl BinaryModel {
self.storage.insert( self.storage.insert(
name.to_string(), name.to_string(),
Item::new(Box::new(BatchFitter::new(kernel))), Item::new(Box::new(RecursiveFitter::new(kernel))),
); );
} }
@@ -65,7 +73,76 @@ impl BinaryModel {
self.last_t = t; self.last_t = t;
} }
pub fn fit(&mut self, verbose: bool) { pub fn fit(&mut self) -> bool {
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
let method = BinaryModelFitMethod::Ep;
let lr = 1.0;
let tol = 1e-3;
let max_iter = 100;
let verbose = true;
self.last_method = Some(method);
for item in self.storage.items_mut() {
item.fitter.allocate();
}
for i in 0..max_iter {
let mut max_diff = 0.0;
for obs in &mut self.observations {
let diff = match method {
BinaryModelFitMethod::Ep => obs.ep_update(lr),
BinaryModelFitMethod::Kl => obs.kl_update(lr),
};
if diff > max_diff {
max_diff = diff;
}
}
for item in self.storage.items_mut() {
item.fitter.fit();
}
if verbose {
println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
}
if max_diff < tol {
return true;
}
}
false
/*
if method == "ep":
update = lambda obs: obs.ep_update(lr=lr)
elif method == "kl":
update = lambda obs: obs.kl_update(lr=lr)
else:
raise ValueError("'method' should be one of: 'ep', 'kl'")
self._last_method = method
for item in self._item.values():
item.fitter.allocate()
for i in range(max_iter):
max_diff = 0.0
# Recompute the Gaussian pseudo-observations.
for obs in self.observations:
diff = update(obs)
max_diff = max(max_diff, diff)
# Recompute the posterior of the score processes.
for item in self.item.values():
item.fitter.fit()
if verbose:
print("iteration {}, max diff: {:.5f}".format(
i+1, max_diff), flush=True)
if max_diff < tol:
return True
return False # Did not converge after `max_iter`.
*/
// //
} }

View File

@@ -3,5 +3,6 @@ mod ordinal;
pub use ordinal::*; pub use ordinal::*;
pub trait Observation { pub trait Observation {
// fn ep_update(&mut self, lr: f64) -> f64;
fn kl_update(&mut self, lr: f64) -> f64;
} }

View File

@@ -40,7 +40,13 @@ impl ProbitWinObservation {
} }
impl Observation for ProbitWinObservation { impl Observation for ProbitWinObservation {
// fn ep_update(&mut self, lr: f64) -> f64 {
todo!();
}
fn kl_update(&mut self, lr: f64) -> f64 {
todo!();
}
} }
pub struct LogitWinObservation { pub struct LogitWinObservation {
@@ -49,10 +55,16 @@ pub struct LogitWinObservation {
impl LogitWinObservation { impl LogitWinObservation {
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self { pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
LogitWinObservation {} todo!();
} }
} }
impl Observation for LogitWinObservation { impl Observation for LogitWinObservation {
// fn ep_update(&mut self, lr: f64) -> f64 {
todo!();
}
fn kl_update(&mut self, lr: f64) -> f64 {
todo!();
}
} }

View File

@@ -33,4 +33,8 @@ impl Storage {
pub fn get_item(&mut self, id: usize) -> &mut Item { pub fn get_item(&mut self, id: usize) -> &mut Item {
&mut self.items[id] &mut self.items[id]
} }
pub fn items_mut(&mut self) -> impl Iterator<Item = &mut Item> {
self.items.iter_mut()
}
} }