diff --git a/src/fitter.rs b/src/fitter.rs index ea67001..07eab9c 100644 --- a/src/fitter.rs +++ b/src/fitter.rs @@ -2,7 +2,7 @@ mod recursive; pub use recursive::Recursive; -pub trait Fitter { +pub trait Fitter: FitterClone { fn add_sample(&mut self, t: f64) -> usize; fn allocate(&mut self); @@ -26,3 +26,22 @@ pub trait Fitter { fn ns(&self, idx: usize) -> f64; fn ns_mut(&mut self, idx: usize) -> &mut f64; } + +pub trait FitterClone { + fn clone_box(&self) -> Box; +} + +impl FitterClone for T +where + T: 'static + Fitter + Clone, +{ + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.clone_box() + } +} diff --git a/src/fitter/recursive.rs b/src/fitter/recursive.rs index 2993dfb..5afef4d 100644 --- a/src/fitter/recursive.rs +++ b/src/fitter/recursive.rs @@ -81,7 +81,7 @@ impl fmt::Debug for Recursive { } } -impl Fitter for Recursive { +impl Fitter for Recursive { fn add_sample(&mut self, t: f64) -> usize { let idx = self.ts.len() + self.ts_new.len(); diff --git a/src/item.rs b/src/item.rs index c7b892b..53cfdb6 100644 --- a/src/item.rs +++ b/src/item.rs @@ -1,5 +1,6 @@ use crate::fitter::Fitter; +#[derive(Clone)] pub struct Item { pub fitter: Box, } diff --git a/src/kernel.rs b/src/kernel.rs index 94cc591..a455be5 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -28,7 +28,7 @@ pub(crate) fn distance(ts1: &[f64], ts2: &[f64]) -> Array2 { r } -pub trait Kernel { +pub trait Kernel: KernelClone { fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2; fn k_diag(&self, ts: &[f64]) -> Array1; fn order(&self) -> usize; @@ -41,6 +41,25 @@ pub trait Kernel { fn noise_cov(&self, t0: f64, t1: f64) -> Array2; } +pub trait KernelClone { + fn clone_box(&self) -> Box; +} + +impl KernelClone for T +where + T: 'static + Kernel + Clone, +{ + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.clone_box() + } +} + impl Kernel for Vec> { fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2 { let n = ts1.len(); @@ -533,7 +552,7 @@ mod tests { macro_rules! tuple_impls { ( $( $name:ident )+ ) => { - impl<$($name: Kernel),+> Kernel for ($($name,)+) { + impl<$($name: 'static + Kernel + Clone),+> Kernel for ($($name,)+) { fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2 { let n = ts1.len(); let m = ts2.map_or(n, |ts| ts.len()); diff --git a/src/kernel/affine.rs b/src/kernel/affine.rs index 258b093..24aa578 100644 --- a/src/kernel/affine.rs +++ b/src/kernel/affine.rs @@ -2,6 +2,7 @@ use ndarray::prelude::*; use super::Kernel; +#[derive(Clone, Copy)] pub struct Affine { var_offset: f64, var_slope: f64, diff --git a/src/kernel/constant.rs b/src/kernel/constant.rs index 1340e1f..0929c91 100644 --- a/src/kernel/constant.rs +++ b/src/kernel/constant.rs @@ -2,6 +2,7 @@ use ndarray::prelude::*; use super::Kernel; +#[derive(Clone, Copy)] pub struct Constant { var: f64, } diff --git a/src/kernel/exponential.rs b/src/kernel/exponential.rs index cba50b0..4f9f87b 100644 --- a/src/kernel/exponential.rs +++ b/src/kernel/exponential.rs @@ -2,6 +2,7 @@ use ndarray::prelude::*; use super::Kernel; +#[derive(Clone, Copy)] pub struct Exponential { var: f64, l_scale: f64, diff --git a/src/kernel/matern32.rs b/src/kernel/matern32.rs index cf63391..84af987 100644 --- a/src/kernel/matern32.rs +++ b/src/kernel/matern32.rs @@ -2,7 +2,7 @@ use ndarray::prelude::*; use super::Kernel; -#[derive(Clone)] +#[derive(Clone, Copy)] pub struct Matern32 { var: f64, l_scale: f64, diff --git a/src/kernel/matern52.rs b/src/kernel/matern52.rs index d4533e0..63376b9 100644 --- a/src/kernel/matern52.rs +++ b/src/kernel/matern52.rs @@ -2,6 +2,7 @@ use ndarray::prelude::*; use super::Kernel; +#[derive(Clone, Copy)] pub struct Matern52 { var: f64, l_scale: f64, diff --git a/src/kernel/wiener.rs b/src/kernel/wiener.rs index 1a86d0a..025571b 100644 --- a/src/kernel/wiener.rs +++ b/src/kernel/wiener.rs @@ -2,6 +2,7 @@ use ndarray::prelude::*; use super::Kernel; +#[derive(Clone, Copy)] pub struct Wiener { var: f64, t0: f64, diff --git a/src/model/binary.rs b/src/model/binary.rs index a3dadc4..0069be8 100644 --- a/src/model/binary.rs +++ b/src/model/binary.rs @@ -12,6 +12,7 @@ trait BinaryObeservation: Observation { } */ +#[derive(Clone, Copy)] pub enum Observation { Probit, Logit, @@ -23,6 +24,7 @@ pub enum FitMethod { Kl, } +#[derive(Clone)] pub struct Binary { pub storage: Storage, last_t: f64, @@ -50,7 +52,7 @@ impl Binary { Self::new(Observation::Logit) } - pub fn add_item(&mut self, name: &str, kernel: K) { + pub fn add_item(&mut self, name: &str, kernel: K) { if self.storage.contains_key(name) { panic!("item '{}' already added", name); } diff --git a/src/model/difference.rs b/src/model/difference.rs index f7f0c28..1145503 100644 --- a/src/model/difference.rs +++ b/src/model/difference.rs @@ -12,6 +12,7 @@ pub enum DifferenceModelFitMethod { Kl, } +#[derive(Clone)] pub struct DifferenceModel { storage: Storage, last_t: f64, @@ -31,7 +32,7 @@ impl DifferenceModel { } } - pub fn add_item(&mut self, name: &str, kernel: K) { + pub fn add_item(&mut self, name: &str, kernel: K) { if self.storage.contains_key(name) { panic!("item '{}' already added", name); } diff --git a/src/model/ternary.rs b/src/model/ternary.rs index ea7411c..7dae213 100644 --- a/src/model/ternary.rs +++ b/src/model/ternary.rs @@ -18,6 +18,7 @@ pub enum FitMethod { Kl, } +#[derive(Clone)] pub struct TernaryModel { storage: Storage, last_t: f64, @@ -47,7 +48,7 @@ impl TernaryModel { Self::new(Observation::Logit, margin) } - pub fn add_item(&mut self, name: &str, kernel: K) { + pub fn add_item(&mut self, name: &str, kernel: K) { if self.storage.contains_key(name) { panic!("item '{}' already added", name); } diff --git a/src/observation.rs b/src/observation.rs index 11b3bda..2859b3b 100644 --- a/src/observation.rs +++ b/src/observation.rs @@ -6,11 +6,30 @@ mod ordinal; pub use gaussian::*; pub use ordinal::*; -pub trait Observation { +pub trait Observation: ObservationClone { fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64; fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64; } +pub trait ObservationClone { + fn clone_box(&self) -> Box; +} + +impl ObservationClone for T +where + T: 'static + Observation + Clone, +{ + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.clone_box() + } +} + pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) { let mut m = 0.0; let mut v = 0.0; @@ -25,6 +44,7 @@ pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) (m, v) } +#[derive(Clone)] pub(crate) struct Core { pub m: usize, pub items: Vec, diff --git a/src/observation/gaussian.rs b/src/observation/gaussian.rs index 0185448..ea982e4 100644 --- a/src/observation/gaussian.rs +++ b/src/observation/gaussian.rs @@ -14,6 +14,7 @@ fn mm_gaussian(mean_cav: f64, var_cav: f64, diff: f64, var_obs: f64) -> (f64, f6 (logpart, dlogpart, d2logpart) } +#[derive(Clone)] pub struct GaussianObservation { core: Core, diff: f64, diff --git a/src/observation/ordinal.rs b/src/observation/ordinal.rs index a5885cb..5556f15 100644 --- a/src/observation/ordinal.rs +++ b/src/observation/ordinal.rs @@ -87,6 +87,7 @@ where (exp_ll, alpha, beta) } +#[derive(Clone)] pub struct ProbitWinObservation { core: Core, margin: f64, @@ -126,6 +127,7 @@ impl Observation for ProbitWinObservation { } } +#[derive(Clone)] pub struct ProbitTieObservation { core: Core, margin: f64, @@ -165,6 +167,7 @@ impl Observation for ProbitTieObservation { } } +#[derive(Clone)] pub struct LogitWinObservation { core: Core, margin: f64, @@ -197,6 +200,7 @@ impl Observation for LogitWinObservation { } } +#[derive(Clone)] pub struct LogitTieObservation { core: Core, margin: f64, diff --git a/src/storage.rs b/src/storage.rs index 154c4de..48ce3ac 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -22,7 +22,7 @@ pub trait Backend { } */ -#[derive(Default)] +#[derive(Clone, Default)] pub struct Storage { keys: HashMap, items: Vec, diff --git a/tests/kickscore-basics.rs b/tests/kickscore-basics.rs index b1d9753..9e9d1d6 100644 --- a/tests/kickscore-basics.rs +++ b/tests/kickscore-basics.rs @@ -35,9 +35,9 @@ fn kickscore_basic() { let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], 2.0); - assert_abs_diff_eq!(p_win, 0.4455095363120037, epsilon = f64::EPSILON); + assert_abs_diff_eq!(p_win, 0.4705389648502623, epsilon = f64::EPSILON); let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], -1.0); - assert_abs_diff_eq!(p_win, 0.9037560799725326, epsilon = f64::EPSILON); + assert_abs_diff_eq!(p_win, 0.7030407811954662, epsilon = f64::EPSILON); } diff --git a/tests/nba-history.rs b/tests/nba-history.rs index ac1aeb2..011153d 100644 --- a/tests/nba-history.rs +++ b/tests/nba-history.rs @@ -84,7 +84,7 @@ fn nba_history() { .unix_timestamp() as f64, ); - assert_abs_diff_eq!(p_win, 0.9002599772490479, epsilon = f64::EPSILON); + assert_abs_diff_eq!(p_win, 0.8987931627209078, epsilon = f64::EPSILON); let (p_win, _) = model.probabilities( &["CHI"], @@ -95,7 +95,7 @@ fn nba_history() { .unix_timestamp() as f64, ); - assert_abs_diff_eq!(p_win, 0.22837870685441986, epsilon = f64::EPSILON); + assert_abs_diff_eq!(p_win, 0.22890824995747874, epsilon = f64::EPSILON); let (p_win, _) = model.probabilities( &["CHI"],