Added structure for diff model.
This commit is contained in:
@@ -161,12 +161,6 @@ impl Kernel for Vec<Box<dyn Kernel>> {
|
|||||||
transition
|
transition
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
def transition(self, t1, t2):
|
|
||||||
mats = [k.transition(t1, t2) for k in self.parts]
|
|
||||||
return block_diag(*mats)
|
|
||||||
*/
|
|
||||||
|
|
||||||
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64> {
|
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64> {
|
||||||
let data = self
|
let data = self
|
||||||
.iter()
|
.iter()
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
mod binary;
|
mod binary;
|
||||||
|
mod difference;
|
||||||
mod ternary;
|
mod ternary;
|
||||||
|
|
||||||
pub use binary::*;
|
pub use binary::*;
|
||||||
|
pub use difference::*;
|
||||||
pub use ternary::*;
|
pub use ternary::*;
|
||||||
|
|||||||
100
src/model/difference.rs
Normal file
100
src/model/difference.rs
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
use std::f64;
|
||||||
|
|
||||||
|
use crate::fitter::RecursiveFitter;
|
||||||
|
use crate::item::Item;
|
||||||
|
use crate::kernel::Kernel;
|
||||||
|
use crate::observation::*;
|
||||||
|
use crate::storage::Storage;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub enum DifferenceModelFitMethod {
|
||||||
|
Ep,
|
||||||
|
Kl,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DifferenceModel {
|
||||||
|
storage: Storage,
|
||||||
|
last_t: f64,
|
||||||
|
observations: Vec<GaussianObservation>,
|
||||||
|
last_method: Option<DifferenceModelFitMethod>,
|
||||||
|
var: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DifferenceModel {
|
||||||
|
pub fn new(var: f64) -> Self {
|
||||||
|
DifferenceModel {
|
||||||
|
storage: Storage::new(),
|
||||||
|
last_t: f64::NEG_INFINITY,
|
||||||
|
observations: Vec::new(),
|
||||||
|
last_method: None,
|
||||||
|
var, // default = 1.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) {
|
||||||
|
if self.storage.contains_key(name) {
|
||||||
|
panic!("item '{}' already added", name);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.storage.insert(
|
||||||
|
name.to_string(),
|
||||||
|
Item::new(Box::new(RecursiveFitter::new(kernel))),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn contains_item(&self, name: &str) -> bool {
|
||||||
|
self.storage.contains_key(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn item_score(&self, name: &str, t: f64) -> (f64, f64) {
|
||||||
|
let id = self.storage.get_id(name);
|
||||||
|
let (ms, vs) = self.storage.item(id).fitter.predict(&[t]);
|
||||||
|
|
||||||
|
(ms[0], vs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn observe(
|
||||||
|
&mut self,
|
||||||
|
winners: &[&str],
|
||||||
|
losers: &[&str],
|
||||||
|
diff: f64,
|
||||||
|
t: f64,
|
||||||
|
var: Option<f64>,
|
||||||
|
) {
|
||||||
|
if t < self.last_t {
|
||||||
|
panic!("observations must be added in chronological order");
|
||||||
|
}
|
||||||
|
|
||||||
|
let var = var.unwrap_or_else(|| self.var);
|
||||||
|
|
||||||
|
let mut elems = self.process_items(winners, 1.0);
|
||||||
|
elems.extend(self.process_items(losers, -1.0));
|
||||||
|
|
||||||
|
let obs = GaussianObservation::new(&mut self.storage, &elems, diff, t, var);
|
||||||
|
|
||||||
|
self.observations.push(obs);
|
||||||
|
|
||||||
|
self.last_t = t;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fit(&mut self) -> bool {
|
||||||
|
unimplemented!();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn probabilities(
|
||||||
|
&mut self,
|
||||||
|
team_1: &[&str],
|
||||||
|
team_2: &[&str],
|
||||||
|
t: f64,
|
||||||
|
margin: Option<f64>,
|
||||||
|
) -> (f64, f64, f64) {
|
||||||
|
unimplemented!();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn process_items(&self, items: &[&str], sign: f64) -> Vec<(usize, f64)> {
|
||||||
|
items
|
||||||
|
.iter()
|
||||||
|
.map(|key| (self.storage.get_id(&key), sign))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
|
|
||||||
|
mod gaussian;
|
||||||
mod ordinal;
|
mod ordinal;
|
||||||
|
|
||||||
|
pub use gaussian::*;
|
||||||
pub use ordinal::*;
|
pub use ordinal::*;
|
||||||
|
|
||||||
pub trait Observation {
|
pub trait Observation {
|
||||||
|
|||||||
36
src/observation/gaussian.rs
Normal file
36
src/observation/gaussian.rs
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
use crate::storage::Storage;
|
||||||
|
|
||||||
|
use super::Observation;
|
||||||
|
|
||||||
|
pub struct GaussianObservation {
|
||||||
|
m: usize,
|
||||||
|
items: Vec<usize>,
|
||||||
|
coeffs: Vec<f64>,
|
||||||
|
indices: Vec<usize>,
|
||||||
|
ns_cav: Vec<f64>,
|
||||||
|
xs_cav: Vec<f64>,
|
||||||
|
t: f64,
|
||||||
|
logpart: f64,
|
||||||
|
exp_ll: usize,
|
||||||
|
margin: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GaussianObservation {
|
||||||
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], diff: f64, t: f64, var: f64) -> Self {
|
||||||
|
unimplemented!();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Observation for GaussianObservation {
|
||||||
|
fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
||||||
|
unimplemented!();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||||
|
unimplemented!();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||||
|
unimplemented!();
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user