diff --git a/src/kernel.rs b/src/kernel.rs index 3888343..6cf58e1 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -161,12 +161,6 @@ impl Kernel for Vec> { transition } - /* - def transition(self, t1, t2): - mats = [k.transition(t1, t2) for k in self.parts] - return block_diag(*mats) - */ - fn noise_cov(&self, t0: f64, t1: f64) -> Array2 { let data = self .iter() diff --git a/src/model.rs b/src/model.rs index f3b0869..536e417 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,5 +1,7 @@ mod binary; +mod difference; mod ternary; pub use binary::*; +pub use difference::*; pub use ternary::*; diff --git a/src/model/difference.rs b/src/model/difference.rs new file mode 100644 index 0000000..054a1e5 --- /dev/null +++ b/src/model/difference.rs @@ -0,0 +1,100 @@ +use std::f64; + +use crate::fitter::RecursiveFitter; +use crate::item::Item; +use crate::kernel::Kernel; +use crate::observation::*; +use crate::storage::Storage; + +#[derive(Clone, Copy)] +pub enum DifferenceModelFitMethod { + Ep, + Kl, +} + +pub struct DifferenceModel { + storage: Storage, + last_t: f64, + observations: Vec, + last_method: Option, + var: f64, +} + +impl DifferenceModel { + pub fn new(var: f64) -> Self { + DifferenceModel { + storage: Storage::new(), + last_t: f64::NEG_INFINITY, + observations: Vec::new(), + last_method: None, + var, // default = 1.0 + } + } + + pub fn add_item(&mut self, name: &str, kernel: Box) { + if self.storage.contains_key(name) { + panic!("item '{}' already added", name); + } + + self.storage.insert( + name.to_string(), + Item::new(Box::new(RecursiveFitter::new(kernel))), + ); + } + + pub fn contains_item(&self, name: &str) -> bool { + self.storage.contains_key(name) + } + + pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) { + let id = self.storage.get_id(name); + let (ms, vs) = self.storage.item(id).fitter.predict(&[t]); + + (ms[0], vs[0]) + } + + pub fn observe( + &mut self, + winners: &[&str], + losers: &[&str], + diff: f64, + t: f64, + var: Option, + ) { + if t < self.last_t { + panic!("observations must be added in chronological order"); + } + + let var = var.unwrap_or_else(|| self.var); + + let mut elems = self.process_items(winners, 1.0); + elems.extend(self.process_items(losers, -1.0)); + + let obs = GaussianObservation::new(&mut self.storage, &elems, diff, t, var); + + self.observations.push(obs); + + self.last_t = t; + } + + pub fn fit(&mut self) -> bool { + unimplemented!(); + } + + pub fn probabilities( + &mut self, + team_1: &[&str], + team_2: &[&str], + t: f64, + margin: Option, + ) -> (f64, f64, f64) { + unimplemented!(); + } + + fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> { + items + .iter() + .map(|key| (self.storage.get_id(&key), sign)) + .collect() + } +} diff --git a/src/observation.rs b/src/observation.rs index f4bc284..5582fb4 100644 --- a/src/observation.rs +++ b/src/observation.rs @@ -1,7 +1,9 @@ use crate::storage::Storage; +mod gaussian; mod ordinal; +pub use gaussian::*; pub use ordinal::*; pub trait Observation { diff --git a/src/observation/gaussian.rs b/src/observation/gaussian.rs new file mode 100644 index 0000000..38f470b --- /dev/null +++ b/src/observation/gaussian.rs @@ -0,0 +1,36 @@ +use crate::storage::Storage; + +use super::Observation; + +pub struct GaussianObservation { + m: usize, + items: Vec, + coeffs: Vec, + indices: Vec, + ns_cav: Vec, + xs_cav: Vec, + t: f64, + logpart: f64, + exp_ll: usize, + margin: f64, +} + +impl GaussianObservation { + pub fn new(storage: &mut Storage, elems: &[(usize, f64)], diff: f64, t: f64, var: f64) -> Self { + unimplemented!(); + } +} + +impl Observation for GaussianObservation { + fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) { + unimplemented!(); + } + + fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 { + unimplemented!(); + } + + fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 { + unimplemented!(); + } +}