More progress.
This commit is contained in:
@@ -32,7 +32,7 @@ fn main() {
|
|||||||
model.observe(&["Jerry"], &["Tom"], 3.0);
|
model.observe(&["Jerry"], &["Tom"], 3.0);
|
||||||
model.observe(&["Jerry"], &["Tom", "Spike"], 3.5);
|
model.observe(&["Jerry"], &["Tom", "Spike"], 3.5);
|
||||||
|
|
||||||
model.fit(true);
|
model.fit();
|
||||||
|
|
||||||
// We can predict a future outcome...
|
// We can predict a future outcome...
|
||||||
let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0);
|
let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0);
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
mod batch;
|
mod batch;
|
||||||
|
mod recursive;
|
||||||
|
|
||||||
pub use batch::BatchFitter;
|
pub use batch::BatchFitter;
|
||||||
|
pub use recursive::RecursiveFitter;
|
||||||
|
|
||||||
pub trait Fitter {
|
pub trait Fitter {
|
||||||
fn add_sample(&mut self, t: f64) -> usize;
|
fn add_sample(&mut self, t: f64) -> usize;
|
||||||
|
fn allocate(&mut self);
|
||||||
|
fn fit(&mut self);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,4 +37,26 @@ impl Fitter for BatchFitter {
|
|||||||
|
|
||||||
idx
|
idx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn allocate(&mut self) {
|
||||||
|
todo!();
|
||||||
|
|
||||||
|
let n_new = self.ts_new.len();
|
||||||
|
// let zeros =
|
||||||
|
/*
|
||||||
|
n_new = len(self.ts_new)
|
||||||
|
zeros = np.zeros(n_new)
|
||||||
|
self.ts = np.concatenate((self.ts, self.ts_new))
|
||||||
|
self.ms = np.concatenate((self.ms, zeros))
|
||||||
|
self.vs = np.concatenate((self.vs, self.kernel.k_diag(self.ts_new)))
|
||||||
|
self.ns = np.concatenate((self.ns, zeros))
|
||||||
|
self.xs = np.concatenate((self.xs, zeros))
|
||||||
|
# Clear the list of pending samples.
|
||||||
|
self.ts_new = list()
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fit(&mut self) {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
43
src/fitter/recursive.rs
Normal file
43
src/fitter/recursive.rs
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
use crate::kernel::Kernel;
|
||||||
|
|
||||||
|
use super::Fitter;
|
||||||
|
|
||||||
|
pub struct RecursiveFitter {
|
||||||
|
ts_new: Vec<f64>,
|
||||||
|
kernel: Box<dyn Kernel>,
|
||||||
|
ts: Vec<usize>,
|
||||||
|
ms: Vec<usize>,
|
||||||
|
vs: Vec<usize>,
|
||||||
|
ns: Vec<usize>,
|
||||||
|
xs: Vec<usize>,
|
||||||
|
is_fitted: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RecursiveFitter {
|
||||||
|
pub fn new(kernel: Box<dyn Kernel>) -> Self {
|
||||||
|
RecursiveFitter {
|
||||||
|
ts_new: Vec::new(),
|
||||||
|
kernel,
|
||||||
|
ts: Vec::new(),
|
||||||
|
ms: Vec::new(),
|
||||||
|
vs: Vec::new(),
|
||||||
|
ns: Vec::new(),
|
||||||
|
xs: Vec::new(),
|
||||||
|
is_fitted: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Fitter for RecursiveFitter {
|
||||||
|
fn add_sample(&mut self, t: f64) -> usize {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn allocate(&mut self) {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fit(&mut self) {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
}
|
||||||
83
src/model.rs
83
src/model.rs
@@ -1,6 +1,6 @@
|
|||||||
use std::f64;
|
use std::f64;
|
||||||
|
|
||||||
use crate::fitter::BatchFitter;
|
use crate::fitter::RecursiveFitter;
|
||||||
use crate::item::Item;
|
use crate::item::Item;
|
||||||
use crate::kernel::Kernel;
|
use crate::kernel::Kernel;
|
||||||
use crate::observation::*;
|
use crate::observation::*;
|
||||||
@@ -11,11 +11,18 @@ pub enum BinaryModelObservation {
|
|||||||
Logit,
|
Logit,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub enum BinaryModelFitMethod {
|
||||||
|
Ep,
|
||||||
|
Kl,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct BinaryModel {
|
pub struct BinaryModel {
|
||||||
storage: Storage,
|
storage: Storage,
|
||||||
last_t: f64,
|
last_t: f64,
|
||||||
win_obs: BinaryModelObservation,
|
win_obs: BinaryModelObservation,
|
||||||
observations: Vec<Box<dyn Observation>>,
|
observations: Vec<Box<dyn Observation>>,
|
||||||
|
last_method: Option<BinaryModelFitMethod>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BinaryModel {
|
impl BinaryModel {
|
||||||
@@ -25,6 +32,7 @@ impl BinaryModel {
|
|||||||
last_t: f64::NEG_INFINITY,
|
last_t: f64::NEG_INFINITY,
|
||||||
win_obs,
|
win_obs,
|
||||||
observations: Vec::new(),
|
observations: Vec::new(),
|
||||||
|
last_method: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,7 +43,7 @@ impl BinaryModel {
|
|||||||
|
|
||||||
self.storage.insert(
|
self.storage.insert(
|
||||||
name.to_string(),
|
name.to_string(),
|
||||||
Item::new(Box::new(BatchFitter::new(kernel))),
|
Item::new(Box::new(RecursiveFitter::new(kernel))),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,7 +73,76 @@ impl BinaryModel {
|
|||||||
self.last_t = t;
|
self.last_t = t;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fit(&mut self, verbose: bool) {
|
pub fn fit(&mut self) -> bool {
|
||||||
|
// method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False):
|
||||||
|
|
||||||
|
let method = BinaryModelFitMethod::Ep;
|
||||||
|
let lr = 1.0;
|
||||||
|
let tol = 1e-3;
|
||||||
|
let max_iter = 100;
|
||||||
|
let verbose = true;
|
||||||
|
|
||||||
|
self.last_method = Some(method);
|
||||||
|
|
||||||
|
for item in self.storage.items_mut() {
|
||||||
|
item.fitter.allocate();
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..max_iter {
|
||||||
|
let mut max_diff = 0.0;
|
||||||
|
|
||||||
|
for obs in &mut self.observations {
|
||||||
|
let diff = match method {
|
||||||
|
BinaryModelFitMethod::Ep => obs.ep_update(lr),
|
||||||
|
BinaryModelFitMethod::Kl => obs.kl_update(lr),
|
||||||
|
};
|
||||||
|
|
||||||
|
if diff > max_diff {
|
||||||
|
max_diff = diff;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for item in self.storage.items_mut() {
|
||||||
|
item.fitter.fit();
|
||||||
|
}
|
||||||
|
|
||||||
|
if verbose {
|
||||||
|
println!("iteration {}, max diff: {:.5}", i + 1, max_diff);
|
||||||
|
}
|
||||||
|
|
||||||
|
if max_diff < tol {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
|
||||||
|
/*
|
||||||
|
if method == "ep":
|
||||||
|
update = lambda obs: obs.ep_update(lr=lr)
|
||||||
|
elif method == "kl":
|
||||||
|
update = lambda obs: obs.kl_update(lr=lr)
|
||||||
|
else:
|
||||||
|
raise ValueError("'method' should be one of: 'ep', 'kl'")
|
||||||
|
self._last_method = method
|
||||||
|
for item in self._item.values():
|
||||||
|
item.fitter.allocate()
|
||||||
|
for i in range(max_iter):
|
||||||
|
max_diff = 0.0
|
||||||
|
# Recompute the Gaussian pseudo-observations.
|
||||||
|
for obs in self.observations:
|
||||||
|
diff = update(obs)
|
||||||
|
max_diff = max(max_diff, diff)
|
||||||
|
# Recompute the posterior of the score processes.
|
||||||
|
for item in self.item.values():
|
||||||
|
item.fitter.fit()
|
||||||
|
if verbose:
|
||||||
|
print("iteration {}, max diff: {:.5f}".format(
|
||||||
|
i+1, max_diff), flush=True)
|
||||||
|
if max_diff < tol:
|
||||||
|
return True
|
||||||
|
return False # Did not converge after `max_iter`.
|
||||||
|
*/
|
||||||
//
|
//
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,5 +3,6 @@ mod ordinal;
|
|||||||
pub use ordinal::*;
|
pub use ordinal::*;
|
||||||
|
|
||||||
pub trait Observation {
|
pub trait Observation {
|
||||||
//
|
fn ep_update(&mut self, lr: f64) -> f64;
|
||||||
|
fn kl_update(&mut self, lr: f64) -> f64;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,13 @@ impl ProbitWinObservation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Observation for ProbitWinObservation {
|
impl Observation for ProbitWinObservation {
|
||||||
//
|
fn ep_update(&mut self, lr: f64) -> f64 {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kl_update(&mut self, lr: f64) -> f64 {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LogitWinObservation {
|
pub struct LogitWinObservation {
|
||||||
@@ -49,10 +55,16 @@ pub struct LogitWinObservation {
|
|||||||
|
|
||||||
impl LogitWinObservation {
|
impl LogitWinObservation {
|
||||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||||
LogitWinObservation {}
|
todo!();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Observation for LogitWinObservation {
|
impl Observation for LogitWinObservation {
|
||||||
//
|
fn ep_update(&mut self, lr: f64) -> f64 {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kl_update(&mut self, lr: f64) -> f64 {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,4 +33,8 @@ impl Storage {
|
|||||||
pub fn get_item(&mut self, id: usize) -> &mut Item {
|
pub fn get_item(&mut self, id: usize) -> &mut Item {
|
||||||
&mut self.items[id]
|
&mut self.items[id]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn items_mut(&mut self) -> impl Iterator<Item = &mut Item> {
|
||||||
|
self.items.iter_mut()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user