Making progress.

commit a54bb70138
parent dc545e4063
2020-02-12 22:31:06 +01:00

9 changed files with 243 additions and 12 deletions

@@ -1,7 +1,7 @@
 use kickscore as ks;

 fn main() {
-    let mut model = ks::BinaryModel {};
+    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);

     // Spike's skill does not change over time.
     let k_spike = ks::kernel::Constant::new(0.5);
@@ -16,9 +16,9 @@ fn main() {
     ];

     // Now we are ready to add the items in the model.
-    model.add_item("Spike", k_spike);
-    model.add_item("Tom", k_tom);
-    model.add_item("Jerry", k_jerry);
+    model.add_item("Spike", Box::new(k_spike));
+    model.add_item("Tom", Box::new(k_tom));
+    model.add_item("Jerry", Box::new(k_jerry));

     // At first, Jerry beats Tom a couple of times.
     model.observe(&["Jerry"], &["Tom"], 0.0);
@@ -35,20 +35,20 @@ fn main() {
     model.fit(true);

     // We can predict a future outcome...
-    let (p_win, p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0);
+    let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0);
     println!(
         "Chances that Jerry beats Tom at t = 4.0: {:.1}%",
         100.0 * p_win
     );

     // ... or simulate what could have happened in the past.
-    let (p_win, p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 2.0);
+    let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 2.0);
     println!(
         "Chances that Jerry beats Tom at t = 2.0: {:.1}%",
         100.0 * p_win
     );

-    let (p_win, p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], -1.0);
+    let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], -1.0);
     println!(
         "Chances that Jerry beats Tom at t = -1.0: {:.1}%",
         100.0 * p_win
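
A note on the example: `BinaryModel::new` now takes the `BinaryModelObservation` enum introduced in src/model.rs, so the same item/observe API also works with the `Logit` variant. A minimal sketch, not part of this commit, assuming the crate builds as in this diff (the logit observation is still a stub):

use kickscore as ks;

fn main() {
    // Sketch (not from this commit): same item API as above, but selecting Logit.
    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Logit);
    model.add_item("Tom", Box::new(ks::kernel::Constant::new(0.5)));
    model.add_item("Jerry", Box::new(ks::kernel::Constant::new(0.5)));
    model.observe(&["Jerry"], &["Tom"], 0.0);
}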

src/fitter.rs (new file, +7)

@@ -0,0 +1,7 @@
mod batch;

pub use batch::BatchFitter;

pub trait Fitter {
    fn add_sample(&mut self, t: f64) -> usize;
}

src/fitter/batch.rs (new file, +40)

@@ -0,0 +1,40 @@
use crate::kernel::Kernel;

use super::Fitter;

pub struct BatchFitter {
    ts_new: Vec<f64>,
    kernel: Box<dyn Kernel>,
    ts: Vec<usize>,
    ms: Vec<usize>,
    vs: Vec<usize>,
    ns: Vec<usize>,
    xs: Vec<usize>,
    is_fitted: bool,
}

impl BatchFitter {
    pub fn new(kernel: Box<dyn Kernel>) -> Self {
        BatchFitter {
            ts_new: Vec::new(),
            kernel,
            ts: Vec::new(),
            ms: Vec::new(),
            vs: Vec::new(),
            ns: Vec::new(),
            xs: Vec::new(),
            is_fitted: true,
        }
    }
}

impl Fitter for BatchFitter {
    fn add_sample(&mut self, t: f64) -> usize {
        let idx = self.ts.len() + self.ts_new.len();
        self.ts_new.push(t);
        self.is_fitted = false;
        idx
    }
}
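
To make the index bookkeeping in `add_sample` concrete: each call returns the next free slot (`ts.len() + ts_new.len()`) and marks the fitter as needing a refit. A minimal test sketch, not part of this commit, that could sit at the bottom of src/fitter/batch.rs; it assumes `kernel::Constant` from the example above implements `Kernel`:

#[cfg(test)]
mod tests {
    use super::BatchFitter;
    use crate::fitter::Fitter;
    use crate::kernel::Constant; // assumed: the Constant kernel used in the example implements Kernel

    #[test]
    fn add_sample_assigns_consecutive_indices() {
        // Sketch only, not part of this commit.
        let mut fitter = BatchFitter::new(Box::new(Constant::new(0.5)));
        assert!(fitter.is_fitted); // a fresh fitter starts out "fitted"

        assert_eq!(fitter.add_sample(0.0), 0);
        assert_eq!(fitter.add_sample(1.0), 1);
        assert_eq!(fitter.add_sample(2.5), 2);

        // New samples are buffered in ts_new and invalidate the previous fit.
        assert_eq!(fitter.ts_new, vec![0.0, 1.0, 2.5]);
        assert!(!fitter.is_fitted);
    }
}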

src/item.rs (new file, +11)

@@ -0,0 +1,11 @@
use crate::fitter::Fitter;

pub struct Item {
    pub fitter: Box<dyn Fitter>,
}

impl Item {
    pub fn new(fitter: Box<dyn Fitter>) -> Self {
        Item { fitter }
    }
}

@@ -1,6 +1,10 @@
 // https://github.com/lucasmaystre/kickscore/tree/master/kickscore
+mod fitter;
+mod item;
 pub mod kernel;
 mod model;
+pub mod observation;
+mod storage;

 pub use kernel::Kernel;
 pub use model::*;

@@ -1,23 +1,91 @@
+use std::f64;
+
+use crate::fitter::BatchFitter;
+use crate::item::Item;
 use crate::kernel::Kernel;
+use crate::observation::*;
+use crate::storage::Storage;
+
+pub enum BinaryModelObservation {
+    Probit,
+    Logit,
+}

 pub struct BinaryModel {
-    //
+    storage: Storage,
+    last_t: f64,
+    win_obs: BinaryModelObservation,
+    observations: Vec<Box<dyn Observation>>,
 }

 impl BinaryModel {
-    pub fn add_item(&mut self, name: &str, kernel: impl Kernel) {
-        //
+    pub fn new(win_obs: BinaryModelObservation) -> Self {
+        BinaryModel {
+            storage: Storage::new(),
+            last_t: f64::NEG_INFINITY,
+            win_obs,
+            observations: Vec::new(),
+        }
+    }
+
+    pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
+        if self.storage.contains_key(name) {
+            // raise ValueError("item '{}' already added".format(name))
+        }
+        self.storage.insert(
+            name.to_string(),
+            Item::new(Box::new(BatchFitter::new(kernel))),
+        );
     }

     pub fn observe(&mut self, winners: &[&str], losers: &[&str], t: f64) {
-        //
+        if t < self.last_t {
+            // raise ValueError("observations must be added in chronological order")
+        }
+        let mut elems = self.process_items(winners, 1.0);
+        elems.extend(self.process_items(losers, -1.0));
+        let obs: Box<dyn Observation> = match self.win_obs {
+            BinaryModelObservation::Probit => {
+                Box::new(ProbitWinObservation::new(&mut self.storage, &elems, t, 0.0))
+            }
+            BinaryModelObservation::Logit => {
+                Box::new(LogitWinObservation::new(&mut self.storage, &elems, t, 0.0))
+            }
+        };
+        self.observations.push(obs);
+        for (item, _) in elems {
+            // item.link_observation(obs)
+        }
+        self.last_t = t;
     }

     pub fn fit(&mut self, verbose: bool) {
         //
     }

-    pub fn probabilities(&mut self, team1: &[&str], team2: &[&str], t: f64) -> (f64, f64) {
+    pub fn probabilities(&mut self, team_1: &[&str], team_2: &[&str], t: f64) -> (f64, f64) {
         (0.0, 0.0)
     }
+
+    fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
+        /*
+        if isinstance(items, dict):
+            return [(self.item[k], sign * float(v)) for k, v in items.items()]
+        if isinstance(items, list) or isinstance(items, tuple):
+            return [(self.item[k], sign) for k in items]
+        else:
+            raise ValueError("items should be a list, a tuple or a dict")
+        */
+        items
+            .iter()
+            .map(|key| (self.storage.get_id(&key), sign))
+            .collect()
+    }
 }
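
The `if t < self.last_t` guard in `observe` encodes the invariant from the commented-out placeholder: observations must be added in chronological order (the error handling itself is not implemented yet in this commit). A minimal caller-side sketch, not part of this commit, using only the API shown in this diff:

use kickscore as ks;

fn main() {
    // Sketch (not from this commit): timestamps passed to observe() must be non-decreasing.
    let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
    model.add_item("Tom", Box::new(ks::kernel::Constant::new(0.5)));
    model.add_item("Jerry", Box::new(ks::kernel::Constant::new(0.5)));

    model.observe(&["Jerry"], &["Tom"], 0.0); // ok, last_t becomes 0.0
    model.observe(&["Jerry"], &["Tom"], 1.0); // ok, last_t becomes 1.0
    // model.observe(&["Tom"], &["Jerry"], 0.5); // out of order; will be rejected
    //                                           // once the ValueError placeholder is implemented
}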

src/observation.rs (new file, +7)

@@ -0,0 +1,7 @@
mod ordinal;

pub use ordinal::*;

pub trait Observation {
    //
}

@@ -0,0 +1,58 @@
use crate::storage::Storage;

use super::Observation;

pub struct ProbitWinObservation {
    m: usize,
    items: Vec<usize>,
    coeffs: Vec<f64>,
    indices: Vec<usize>,
    ns_cav: Vec<f64>,
    xs_cav: Vec<f64>,
    t: f64,
    logpart: usize,
    exp_ll: usize,
    margin: f64,
}

impl ProbitWinObservation {
    pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
        /*
        assert len(elems) > 0, "need at least one item per observation"
        */
        ProbitWinObservation {
            m: elems.len(),
            items: elems.iter().map(|(id, _)| id).cloned().collect(),
            coeffs: elems.iter().map(|(_, sign)| sign).cloned().collect(),
            indices: elems
                .iter()
                .map(|(id, _)| storage.get_item(*id).fitter.add_sample(t))
                .collect(),
            ns_cav: Vec::new(),
            xs_cav: Vec::new(),
            t,
            logpart: 0,
            exp_ll: 0,
            margin,
        }
    }
}

impl Observation for ProbitWinObservation {
    //
}

pub struct LogitWinObservation {
    //
}

impl LogitWinObservation {
    pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
        LogitWinObservation {}
    }
}

impl Observation for LogitWinObservation {
    //
}

src/storage.rs (new file, +36)

@@ -0,0 +1,36 @@
use std::collections::HashMap;

use crate::item::Item;

pub struct Storage {
    keys: HashMap<String, usize>,
    items: Vec<Item>,
}

impl Storage {
    pub fn new() -> Self {
        Storage {
            keys: HashMap::new(),
            items: Vec::new(),
        }
    }

    pub fn contains_key(&self, key: &str) -> bool {
        self.keys.contains_key(key)
    }

    pub fn insert(&mut self, key: String, item: Item) {
        let index = self.items.len();
        self.items.push(item);
        self.keys.insert(key, index);
    }

    pub fn get_id(&self, key: &str) -> usize {
        self.keys[key]
    }

    pub fn get_item(&mut self, id: usize) -> &mut Item {
        &mut self.items[id]
    }
}
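
A minimal test sketch, not part of this commit, showing the name → id → item round trip that `BinaryModel` relies on. It could live at the bottom of src/storage.rs and assumes `kernel::Constant` from the example above implements `Kernel`:

#[cfg(test)]
mod tests {
    use super::Storage;
    use crate::fitter::BatchFitter;
    use crate::item::Item;
    use crate::kernel::Constant; // assumed: the Constant kernel used in the example implements Kernel

    #[test]
    fn insert_then_look_up_by_name() {
        // Sketch only, not part of this commit.
        let mut storage = Storage::new();
        assert!(!storage.contains_key("Tom"));

        let item = Item::new(Box::new(BatchFitter::new(Box::new(Constant::new(0.5)))));
        storage.insert("Tom".to_string(), item);

        assert!(storage.contains_key("Tom"));
        let id = storage.get_id("Tom");
        assert_eq!(id, 0); // ids are handed out in insertion order

        // get_item returns a mutable reference, so an observation can call
        // add_sample on the item's fitter (as ProbitWinObservation::new does).
        let _tom = storage.get_item(id);
    }
}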