Added structure for diff model.

This commit is contained in:
2020-03-06 16:29:57 +01:00
parent 67d1412af8
commit 9490bffd1e
5 changed files with 140 additions and 6 deletions

View File

@@ -161,12 +161,6 @@ impl Kernel for Vec<Box<dyn Kernel>> {
transition
}
/*
def transition(self, t1, t2):
mats = [k.transition(t1, t2) for k in self.parts]
return block_diag(*mats)
*/
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64> {
let data = self
.iter()

View File

@@ -1,5 +1,7 @@
mod binary;
mod difference;
mod ternary;
pub use binary::*;
pub use difference::*;
pub use ternary::*;

100
src/model/difference.rs Normal file
View File

@@ -0,0 +1,100 @@
use std::f64;
use crate::fitter::RecursiveFitter;
use crate::item::Item;
use crate::kernel::Kernel;
use crate::observation::*;
use crate::storage::Storage;
#[derive(Clone, Copy)]
pub enum DifferenceModelFitMethod {
Ep,
Kl,
}
pub struct DifferenceModel {
storage: Storage,
last_t: f64,
observations: Vec<GaussianObservation>,
last_method: Option<DifferenceModelFitMethod>,
var: f64,
}
impl DifferenceModel {
pub fn new(var: f64) -> Self {
DifferenceModel {
storage: Storage::new(),
last_t: f64::NEG_INFINITY,
observations: Vec::new(),
last_method: None,
var, // default = 1.0
}
}
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
if self.storage.contains_key(name) {
panic!("item '{}' already added", name);
}
self.storage.insert(
name.to_string(),
Item::new(Box::new(RecursiveFitter::new(kernel))),
);
}
pub fn contains_item(&self, name: &str) -> bool {
self.storage.contains_key(name)
}
pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
let id = self.storage.get_id(name);
let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
(ms[0], vs[0])
}
pub fn observe(
&mut self,
winners: &[&str],
losers: &[&str],
diff: f64,
t: f64,
var: Option<f64>,
) {
if t < self.last_t {
panic!("observations must be added in chronological order");
}
let var = var.unwrap_or_else(|| self.var);
let mut elems = self.process_items(winners, 1.0);
elems.extend(self.process_items(losers, -1.0));
let obs = GaussianObservation::new(&mut self.storage, &elems, diff, t, var);
self.observations.push(obs);
self.last_t = t;
}
pub fn fit(&mut self) -> bool {
unimplemented!();
}
pub fn probabilities(
&mut self,
team_1: &[&str],
team_2: &[&str],
t: f64,
margin: Option<f64>,
) -> (f64, f64, f64) {
unimplemented!();
}
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
items
.iter()
.map(|key| (self.storage.get_id(&key), sign))
.collect()
}
}

View File

@@ -1,7 +1,9 @@
use crate::storage::Storage;
mod gaussian;
mod ordinal;
pub use gaussian::*;
pub use ordinal::*;
pub trait Observation {

View File

@@ -0,0 +1,36 @@
use crate::storage::Storage;
use super::Observation;
pub struct GaussianObservation {
m: usize,
items: Vec<usize>,
coeffs: Vec<f64>,
indices: Vec<usize>,
ns_cav: Vec<f64>,
xs_cav: Vec<f64>,
t: f64,
logpart: f64,
exp_ll: usize,
margin: f64,
}
impl GaussianObservation {
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], diff: f64, t: f64, var: f64) -> Self {
unimplemented!();
}
}
impl Observation for GaussianObservation {
fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
unimplemented!();
}
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
unimplemented!();
}
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
unimplemented!();
}
}