Make models cloneable

2022-06-07 10:53:11 +02:00
parent 5a54d460c9
commit 174065bcf1
19 changed files with 87 additions and 14 deletions

View File

@@ -2,7 +2,7 @@ mod recursive;
 pub use recursive::Recursive;
 
-pub trait Fitter {
+pub trait Fitter: FitterClone {
     fn add_sample(&mut self, t: f64) -> usize;
     fn allocate(&mut self);
@@ -26,3 +26,22 @@ pub trait Fitter {
     fn ns(&self, idx: usize) -> f64;
     fn ns_mut(&mut self, idx: usize) -> &mut f64;
 }
+
+pub trait FitterClone {
+    fn clone_box(&self) -> Box<dyn Fitter>;
+}
+
+impl<T> FitterClone for T
+where
+    T: 'static + Fitter + Clone,
+{
+    fn clone_box(&self) -> Box<dyn Fitter> {
+        Box::new(self.clone())
+    }
+}
+
+impl Clone for Box<dyn Fitter> {
+    fn clone(&self) -> Box<dyn Fitter> {
+        self.clone_box()
+    }
+}
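The block added above is the standard clone-box pattern for making boxed trait objects cloneable; the same shape is repeated below for Kernel and Observation. A minimal, self-contained sketch of the idea, using a hypothetical Animal trait rather than this crate's types:

// Minimal sketch of the clone-box pattern, assuming a toy `Animal` trait
// (not part of this crate). `Clone` cannot be a supertrait of an
// object-safe trait, so cloning is routed through a helper trait that
// returns a boxed copy.
trait Animal: AnimalClone {
    fn speak(&self) -> String;
}

trait AnimalClone {
    fn clone_box(&self) -> Box<dyn Animal>;
}

// Blanket impl: any 'static + Clone type implementing Animal gets clone_box.
impl<T> AnimalClone for T
where
    T: 'static + Animal + Clone,
{
    fn clone_box(&self) -> Box<dyn Animal> {
        Box::new(self.clone())
    }
}

// With clone_box available, Box<dyn Animal> itself becomes Clone,
// which in turn lets structs holding such boxes derive Clone.
impl Clone for Box<dyn Animal> {
    fn clone(&self) -> Box<dyn Animal> {
        self.clone_box()
    }
}

#[derive(Clone)]
struct Dog {
    name: String,
}

impl Animal for Dog {
    fn speak(&self) -> String {
        format!("{} says woof", self.name)
    }
}

fn main() {
    let original: Box<dyn Animal> = Box::new(Dog { name: "Rex".into() });
    let copy = original.clone(); // dispatches through clone_box
    assert_eq!(original.speak(), copy.speak());
}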

View File

@@ -81,7 +81,7 @@ impl<K: Kernel> fmt::Debug for Recursive<K> {
     }
 }
 
-impl<K: Kernel> Fitter for Recursive<K> {
+impl<K: 'static + Kernel + Clone> Fitter for Recursive<K> {
     fn add_sample(&mut self, t: f64) -> usize {
         let idx = self.ts.len() + self.ts_new.len();

View File

@@ -1,5 +1,6 @@
 use crate::fitter::Fitter;
 
+#[derive(Clone)]
 pub struct Item {
     pub fitter: Box<dyn Fitter>,
 }
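Deriving Clone for Item only compiles because of the manual impl Clone for Box<dyn Fitter> added above: the derive clones each field, and the boxed fitter now supports that via clone_box.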

View File

@@ -28,7 +28,7 @@ pub(crate) fn distance(ts1: &[f64], ts2: &[f64]) -> Array2<f64> {
     r
 }
 
-pub trait Kernel {
+pub trait Kernel: KernelClone {
     fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64>;
     fn k_diag(&self, ts: &[f64]) -> Array1<f64>;
     fn order(&self) -> usize;
@@ -41,6 +41,25 @@ pub trait Kernel {
     fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64>;
 }
 
+pub trait KernelClone {
+    fn clone_box(&self) -> Box<dyn Kernel>;
+}
+
+impl<T> KernelClone for T
+where
+    T: 'static + Kernel + Clone,
+{
+    fn clone_box(&self) -> Box<dyn Kernel> {
+        Box::new(self.clone())
+    }
+}
+
+impl Clone for Box<dyn Kernel> {
+    fn clone(&self) -> Box<dyn Kernel> {
+        self.clone_box()
+    }
+}
+
 impl Kernel for Vec<Box<dyn Kernel>> {
     fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
         let n = ts1.len();
@@ -533,7 +552,7 @@ mod tests {
 macro_rules! tuple_impls {
     ( $( $name:ident )+ ) => {
-        impl<$($name: Kernel),+> Kernel for ($($name,)+) {
+        impl<$($name: 'static + Kernel + Clone),+> Kernel for ($($name,)+) {
             fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
                 let n = ts1.len();
                 let m = ts2.map_or(n, |ts| ts.len());
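The 'static bounds added here (and on Recursive's Fitter impl above) are needed because Box<dyn Kernel> is shorthand for Box<dyn Kernel + 'static>: only types that own their data, rather than borrowing it, can be moved into such a trait object by clone_box.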

View File

@@ -2,6 +2,7 @@ use ndarray::prelude::*;
 
 use super::Kernel;
 
+#[derive(Clone, Copy)]
 pub struct Affine {
     var_offset: f64,
     var_slope: f64,

View File

@@ -2,6 +2,7 @@ use ndarray::prelude::*;
 
 use super::Kernel;
 
+#[derive(Clone, Copy)]
 pub struct Constant {
     var: f64,
 }

View File

@@ -2,6 +2,7 @@ use ndarray::prelude::*;
 
 use super::Kernel;
 
+#[derive(Clone, Copy)]
 pub struct Exponential {
     var: f64,
     l_scale: f64,

View File

@@ -2,7 +2,7 @@ use ndarray::prelude::*;
 
 use super::Kernel;
 
-#[derive(Clone)]
+#[derive(Clone, Copy)]
 pub struct Matern32 {
     var: f64,
     l_scale: f64,

View File

@@ -2,6 +2,7 @@ use ndarray::prelude::*;
 
 use super::Kernel;
 
+#[derive(Clone, Copy)]
 pub struct Matern52 {
     var: f64,
     l_scale: f64,

View File

@@ -2,6 +2,7 @@ use ndarray::prelude::*;
 
 use super::Kernel;
 
+#[derive(Clone, Copy)]
 pub struct Wiener {
     var: f64,
     t0: f64,

View File

@@ -12,6 +12,7 @@ trait BinaryObeservation: Observation {
 }
 */
 
+#[derive(Clone, Copy)]
 pub enum Observation {
     Probit,
     Logit,
@@ -23,6 +24,7 @@ pub enum FitMethod {
     Kl,
 }
 
+#[derive(Clone)]
 pub struct Binary {
     pub storage: Storage,
     last_t: f64,
@@ -50,7 +52,7 @@ impl Binary {
         Self::new(Observation::Logit)
     }
 
-    pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
+    pub fn add_item<K: Kernel + 'static + Clone>(&mut self, name: &str, kernel: K) {
         if self.storage.contains_key(name) {
             panic!("item '{}' already added", name);
         }

View File

@@ -12,6 +12,7 @@ pub enum DifferenceModelFitMethod {
     Kl,
 }
 
+#[derive(Clone)]
 pub struct DifferenceModel {
     storage: Storage,
     last_t: f64,
@@ -31,7 +32,7 @@ impl DifferenceModel {
         }
     }
 
-    pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
+    pub fn add_item<K: Kernel + 'static + Clone>(&mut self, name: &str, kernel: K) {
         if self.storage.contains_key(name) {
             panic!("item '{}' already added", name);
         }

View File

@@ -18,6 +18,7 @@ pub enum FitMethod {
     Kl,
 }
 
+#[derive(Clone)]
 pub struct TernaryModel {
     storage: Storage,
     last_t: f64,
@@ -47,7 +48,7 @@ impl TernaryModel {
         Self::new(Observation::Logit, margin)
     }
 
-    pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
+    pub fn add_item<K: Kernel + 'static + Clone>(&mut self, name: &str, kernel: K) {
         if self.storage.contains_key(name) {
             panic!("item '{}' already added", name);
         }

View File

@@ -6,11 +6,30 @@ mod ordinal;
 pub use gaussian::*;
 pub use ordinal::*;
 
-pub trait Observation {
+pub trait Observation: ObservationClone {
     fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64;
     fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64;
 }
 
+pub trait ObservationClone {
+    fn clone_box(&self) -> Box<dyn Observation>;
+}
+
+impl<T> ObservationClone for T
+where
+    T: 'static + Observation + Clone,
+{
+    fn clone_box(&self) -> Box<dyn Observation> {
+        Box::new(self.clone())
+    }
+}
+
+impl Clone for Box<dyn Observation> {
+    fn clone(&self) -> Box<dyn Observation> {
+        self.clone_box()
+    }
+}
+
 pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) {
     let mut m = 0.0;
     let mut v = 0.0;
@@ -25,6 +44,7 @@ pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64)
     (m, v)
 }
 
+#[derive(Clone)]
 pub(crate) struct Core {
     pub m: usize,
     pub items: Vec<usize>,

View File

@@ -14,6 +14,7 @@ fn mm_gaussian(mean_cav: f64, var_cav: f64, diff: f64, var_obs: f64) -> (f64, f6
     (logpart, dlogpart, d2logpart)
 }
 
+#[derive(Clone)]
 pub struct GaussianObservation {
     core: Core,
     diff: f64,

View File

@@ -87,6 +87,7 @@ where
     (exp_ll, alpha, beta)
 }
 
+#[derive(Clone)]
 pub struct ProbitWinObservation {
     core: Core,
     margin: f64,
@@ -126,6 +127,7 @@ impl Observation for ProbitWinObservation {
     }
 }
 
+#[derive(Clone)]
 pub struct ProbitTieObservation {
     core: Core,
     margin: f64,
@@ -165,6 +167,7 @@ impl Observation for ProbitTieObservation {
     }
 }
 
+#[derive(Clone)]
 pub struct LogitWinObservation {
     core: Core,
     margin: f64,
@@ -197,6 +200,7 @@ impl Observation for LogitWinObservation {
     }
 }
 
+#[derive(Clone)]
 pub struct LogitTieObservation {
     core: Core,
     margin: f64,

View File

@@ -22,7 +22,7 @@ pub trait Backend {
 }
 */
 
-#[derive(Default)]
+#[derive(Clone, Default)]
 pub struct Storage {
     keys: HashMap<String, usize>,
     items: Vec<Item>,
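With Box<dyn Fitter> and Item now cloneable, Storage (whose items field is a Vec<Item>) can add Clone to its derive list, which in turn is what lets Binary, DifferenceModel, and TernaryModel above simply derive Clone.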

View File

@@ -35,9 +35,9 @@ fn kickscore_basic() {
 
     let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], 2.0);
-    assert_abs_diff_eq!(p_win, 0.4455095363120037, epsilon = f64::EPSILON);
+    assert_abs_diff_eq!(p_win, 0.4705389648502623, epsilon = f64::EPSILON);
 
     let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], -1.0);
-    assert_abs_diff_eq!(p_win, 0.9037560799725326, epsilon = f64::EPSILON);
+    assert_abs_diff_eq!(p_win, 0.7030407811954662, epsilon = f64::EPSILON);
 }

View File

@@ -84,7 +84,7 @@ fn nba_history() {
             .unix_timestamp() as f64,
     );
-    assert_abs_diff_eq!(p_win, 0.9002599772490479, epsilon = f64::EPSILON);
+    assert_abs_diff_eq!(p_win, 0.8987931627209078, epsilon = f64::EPSILON);
 
     let (p_win, _) = model.probabilities(
         &["CHI"],
@@ -95,7 +95,7 @@ fn nba_history() {
             .unix_timestamp() as f64,
     );
-    assert_abs_diff_eq!(p_win, 0.22837870685441986, epsilon = f64::EPSILON);
+    assert_abs_diff_eq!(p_win, 0.22890824995747874, epsilon = f64::EPSILON);
 
     let (p_win, _) = model.probabilities(
         &["CHI"],