Patch summary (free-form preamble before the first "diff --git" header).

Modifies examples/basic.rs, src/lib.rs and src/model.rs; adds src/fitter.rs,
src/fitter/batch.rs, src/item.rs, src/observation.rs,
src/observation/ordinal.rs and src/storage.rs.

BinaryModel gets a `new` constructor that takes a BinaryModelObservation,
`add_item` now takes a boxed kernel and stores an Item backed by a
BatchFitter, and `observe` builds a Probit/Logit win observation that
registers one sample per involved item through Fitter::add_sample.

NOTE(review): this patch was recovered from a whitespace-collapsed copy in
which generic arguments had been stripped. They are restored from usage
(Box<dyn Kernel>, Box<dyn Fitter>, Vec<Box<dyn Observation>>,
HashMap<String, usize>, Vec<Item>, Vec<usize>, Vec<f64>). The element type of
ts, ms, vs, ns, xs, ns_cav and xs_cav is assumed to be f64 - TODO confirm
against the repository.
NOTE(review): blank context lines were lost as well and have been re-inserted
so that every hunk matches its @@ line counts; the position of the single
blank context line in the src/lib.rs hunk is inferred - TODO confirm.

diff --git a/examples/basic.rs b/examples/basic.rs
index 86df4f5..76e7776 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -1,7 +1,7 @@
 use kickscore as ks;
 
 fn main() {
-    let mut model = ks::BinaryModel {};
+    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
 
     // Spike's skill does not change over time.
     let k_spike = ks::kernel::Constant::new(0.5);
@@ -16,9 +16,9 @@ fn main() {
     ];
 
     // Now we are ready to add the items in the model.
-    model.add_item("Spike", k_spike);
-    model.add_item("Tom", k_tom);
-    model.add_item("Jerry", k_jerry);
+    model.add_item("Spike", Box::new(k_spike));
+    model.add_item("Tom", Box::new(k_tom));
+    model.add_item("Jerry", Box::new(k_jerry));
 
     // At first, Jerry beats Tom a couple of times.
     model.observe(&["Jerry"], &["Tom"], 0.0);
@@ -35,20 +35,20 @@ fn main() {
     model.fit(true);
 
     // We can predict a future outcome...
-    let (p_win, p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0);
+    let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0);
     println!(
         "Chances that Jerry beats Tom at t = 4.0: {:.1}%",
         100.0 * p_win
     );
 
     // ... or simulate what could have happened in the past.
-    let (p_win, p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 2.0);
+    let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 2.0);
     println!(
         "Chances that Jerry beats Tom at t = 2.0: {:.1}%",
         100.0 * p_win
     );
 
-    let (p_win, p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], -1.0);
+    let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], -1.0);
     println!(
         "Chances that Jerry beats Tom at t = -1.0: {:.1}%",
         100.0 * p_win
diff --git a/src/fitter.rs b/src/fitter.rs
new file mode 100644
index 0000000..8d76303
--- /dev/null
+++ b/src/fitter.rs
@@ -0,0 +1,7 @@
+mod batch;
+
+pub use batch::BatchFitter;
+
+pub trait Fitter {
+    fn add_sample(&mut self, t: f64) -> usize;
+}
diff --git a/src/fitter/batch.rs b/src/fitter/batch.rs
new file mode 100644
index 0000000..3f55f22
--- /dev/null
+++ b/src/fitter/batch.rs
@@ -0,0 +1,40 @@
+use crate::kernel::Kernel;
+
+use super::Fitter;
+
+pub struct BatchFitter {
+    ts_new: Vec<f64>,
+    kernel: Box<dyn Kernel>,
+    ts: Vec<f64>,
+    ms: Vec<f64>,
+    vs: Vec<f64>,
+    ns: Vec<f64>,
+    xs: Vec<f64>,
+    is_fitted: bool,
+}
+
+impl BatchFitter {
+    pub fn new(kernel: Box<dyn Kernel>) -> Self {
+        BatchFitter {
+            ts_new: Vec::new(),
+            kernel,
+            ts: Vec::new(),
+            ms: Vec::new(),
+            vs: Vec::new(),
+            ns: Vec::new(),
+            xs: Vec::new(),
+            is_fitted: true,
+        }
+    }
+}
+
+impl Fitter for BatchFitter {
+    fn add_sample(&mut self, t: f64) -> usize {
+        let idx = self.ts.len() + self.ts_new.len();
+
+        self.ts_new.push(t);
+        self.is_fitted = false;
+
+        idx
+    }
+}
diff --git a/src/item.rs b/src/item.rs
new file mode 100644
index 0000000..c7b892b
--- /dev/null
+++ b/src/item.rs
@@ -0,0 +1,11 @@
+use crate::fitter::Fitter;
+
+pub struct Item {
+    pub fitter: Box<dyn Fitter>,
+}
+
+impl Item {
+    pub fn new(fitter: Box<dyn Fitter>) -> Self {
+        Item { fitter }
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 5ea3605..bcb0555 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,6 +1,10 @@
 // https://github.com/lucasmaystre/kickscore/tree/master/kickscore
+mod fitter;
+mod item;
 pub mod kernel;
 mod model;
+pub mod observation;
+mod storage;
 
 pub use kernel::Kernel;
 pub use model::*;
diff --git a/src/model.rs b/src/model.rs
index ec81ec3..5e486f2 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -1,23 +1,91 @@
+use std::f64;
+
+use crate::fitter::BatchFitter;
+use crate::item::Item;
 use crate::kernel::Kernel;
+use crate::observation::*;
+use crate::storage::Storage;
+
+pub enum BinaryModelObservation {
+    Probit,
+    Logit,
+}
 
 pub struct BinaryModel {
-    //
+    storage: Storage,
+    last_t: f64,
+    win_obs: BinaryModelObservation,
+    observations: Vec<Box<dyn Observation>>,
 }
 
 impl BinaryModel {
-    pub fn add_item(&mut self, name: &str, kernel: impl Kernel) {
-        //
+    pub fn new(win_obs: BinaryModelObservation) -> Self {
+        BinaryModel {
+            storage: Storage::new(),
+            last_t: f64::NEG_INFINITY,
+            win_obs,
+            observations: Vec::new(),
+        }
+    }
+
+    pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
+        if self.storage.contains_key(name) {
+            // raise ValueError("item '{}' already added".format(name))
+        }
+
+        self.storage.insert(
+            name.to_string(),
+            Item::new(Box::new(BatchFitter::new(kernel))),
+        );
     }
 
     pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
-        //
+        if t < self.last_t {
+            // raise ValueError("observations must be added in chronological order")
+        }
+
+        let mut elems = self.process_items(winners, 1.0);
+        elems.extend(self.process_items(losers, -1.0));
+
+        let obs: Box<dyn Observation> = match self.win_obs {
+            BinaryModelObservation::Probit => {
+                Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
+            }
+            BinaryModelObservation::Logit => {
+                Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
+            }
+        };
+
+        self.observations.push(obs);
+
+        for (item, _) in elems {
+            // item.link_observation(obs)
+        }
+
+        self.last_t = t;
     }
 
     pub fn fit(&mut self, verbose: bool) {
         //
     }
 
-    pub fn probabilities(&mut self, team1: &[&str], team2: &[&str], t: f64) -> (f64, f64) {
+    pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
         (0.0, 0.0)
     }
+
+    fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
+        /*
+        if isinstance(items, dict):
+            return [(self.item[k], sign * float(v)) for k, v in items.items()]
+        if isinstance(items, list) or isinstance(items, tuple):
+            return [(self.item[k], sign) for k in items]
+        else:
+            raise ValueError("items should be a list, a tuple or a dict")
+        */
+
+        items
+            .iter()
+            .map(|key| (self.storage.get_id(&key), sign))
+            .collect()
+    }
 }
diff --git a/src/observation.rs b/src/observation.rs
new file mode 100644
index 0000000..ffadc72
--- /dev/null
+++ b/src/observation.rs
@@ -0,0 +1,7 @@
+mod ordinal;
+
+pub use ordinal::*;
+
+pub trait Observation {
+    //
+}
diff --git a/src/observation/ordinal.rs b/src/observation/ordinal.rs
new file mode 100644
index 0000000..58651c8
--- /dev/null
+++ b/src/observation/ordinal.rs
@@ -0,0 +1,58 @@
+use crate::storage::Storage;
+
+use super::Observation;
+
+pub struct ProbitWinObservation {
+    m: usize,
+    items: Vec<usize>,
+    coeffs: Vec<f64>,
+    indices: Vec<usize>,
+    ns_cav: Vec<f64>,
+    xs_cav: Vec<f64>,
+    t: f64,
+    logpart: usize,
+    exp_ll: usize,
+    margin: f64,
+}
+
+impl ProbitWinObservation {
+    pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
+        /*
+        assert len(elems) > 0, "need at least one item per observation"
+        */
+
+        ProbitWinObservation {
+            m: elems.len(),
+            items: elems.iter().map(|(id, _)| id).cloned().collect(),
+            coeffs: elems.iter().map(|(_, sign)| sign).cloned().collect(),
+            indices: elems
+                .iter()
+                .map(|(id, _)| storage.get_item(*id).fitter.add_sample(t))
+                .collect(),
+            ns_cav: Vec::new(),
+            xs_cav: Vec::new(),
+            t,
+            logpart: 0,
+            exp_ll: 0,
+            margin,
+        }
+    }
+}
+
+impl Observation for ProbitWinObservation {
+    //
+}
+
+pub struct LogitWinObservation {
+    //
+}
+
+impl LogitWinObservation {
+    pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
+        LogitWinObservation {}
+    }
+}
+
+impl Observation for LogitWinObservation {
+    //
+}
diff --git a/src/storage.rs b/src/storage.rs
new file mode 100644
index 0000000..5fe7dec
--- /dev/null
+++ b/src/storage.rs
@@ -0,0 +1,36 @@
+use std::collections::HashMap;
+
+use crate::item::Item;
+
+pub struct Storage {
+    keys: HashMap<String, usize>,
+    items: Vec<Item>,
+}
+
+impl Storage {
+    pub fn new() -> Self {
+        Storage {
+            keys: HashMap::new(),
+            items: Vec::new(),
+        }
+    }
+
+    pub fn contains_key(&self, key: &str) -> bool {
+        self.keys.contains_key(key)
+    }
+
+    pub fn insert(&mut self, key: String, item: Item) {
+        let index = self.items.len();
+
+        self.items.push(item);
+        self.keys.insert(key, index);
+    }
+
+    pub fn get_id(&self, key: &str) -> usize {
+        self.keys[key]
+    }
+
+    pub fn get_item(&mut self, id: usize) -> &mut Item {
+        &mut self.items[id]
+    }
+}