Make models cloneable
This commit is contained in:
@@ -2,7 +2,7 @@ mod recursive;
|
||||
|
||||
pub use recursive::Recursive;
|
||||
|
||||
pub trait Fitter {
|
||||
pub trait Fitter: FitterClone {
|
||||
fn add_sample(&mut self, t: f64) -> usize;
|
||||
|
||||
fn allocate(&mut self);
|
||||
@@ -26,3 +26,22 @@ pub trait Fitter {
|
||||
fn ns(&self, idx: usize) -> f64;
|
||||
fn ns_mut(&mut self, idx: usize) -> &mut f64;
|
||||
}
|
||||
|
||||
pub trait FitterClone {
|
||||
fn clone_box(&self) -> Box<dyn Fitter>;
|
||||
}
|
||||
|
||||
impl<T> FitterClone for T
|
||||
where
|
||||
T: 'static + Fitter + Clone,
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn Fitter> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Box<dyn Fitter> {
|
||||
fn clone(&self) -> Box<dyn Fitter> {
|
||||
self.clone_box()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ impl<K: Kernel> fmt::Debug for Recursive<K> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Kernel> Fitter for Recursive<K> {
|
||||
impl<K: 'static + Kernel + Clone> Fitter for Recursive<K> {
|
||||
fn add_sample(&mut self, t: f64) -> usize {
|
||||
let idx = self.ts.len() + self.ts_new.len();
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::fitter::Fitter;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Item {
|
||||
pub fitter: Box<dyn Fitter>,
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ pub(crate) fn distance(ts1: &[f64], ts2: &[f64]) -> Array2<f64> {
|
||||
r
|
||||
}
|
||||
|
||||
pub trait Kernel {
|
||||
pub trait Kernel: KernelClone {
|
||||
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64>;
|
||||
fn k_diag(&self, ts: &[f64]) -> Array1<f64>;
|
||||
fn order(&self) -> usize;
|
||||
@@ -41,6 +41,25 @@ pub trait Kernel {
|
||||
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64>;
|
||||
}
|
||||
|
||||
pub trait KernelClone {
|
||||
fn clone_box(&self) -> Box<dyn Kernel>;
|
||||
}
|
||||
|
||||
impl<T> KernelClone for T
|
||||
where
|
||||
T: 'static + Kernel + Clone,
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn Kernel> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Box<dyn Kernel> {
|
||||
fn clone(&self) -> Box<dyn Kernel> {
|
||||
self.clone_box()
|
||||
}
|
||||
}
|
||||
|
||||
impl Kernel for Vec<Box<dyn Kernel>> {
|
||||
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
|
||||
let n = ts1.len();
|
||||
@@ -533,7 +552,7 @@ mod tests {
|
||||
|
||||
macro_rules! tuple_impls {
|
||||
( $( $name:ident )+ ) => {
|
||||
impl<$($name: Kernel),+> Kernel for ($($name,)+) {
|
||||
impl<$($name: 'static + Kernel + Clone),+> Kernel for ($($name,)+) {
|
||||
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
|
||||
let n = ts1.len();
|
||||
let m = ts2.map_or(n, |ts| ts.len());
|
||||
|
||||
@@ -2,6 +2,7 @@ use ndarray::prelude::*;
|
||||
|
||||
use super::Kernel;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Affine {
|
||||
var_offset: f64,
|
||||
var_slope: f64,
|
||||
|
||||
@@ -2,6 +2,7 @@ use ndarray::prelude::*;
|
||||
|
||||
use super::Kernel;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Constant {
|
||||
var: f64,
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ use ndarray::prelude::*;
|
||||
|
||||
use super::Kernel;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Exponential {
|
||||
var: f64,
|
||||
l_scale: f64,
|
||||
|
||||
@@ -2,7 +2,7 @@ use ndarray::prelude::*;
|
||||
|
||||
use super::Kernel;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Matern32 {
|
||||
var: f64,
|
||||
l_scale: f64,
|
||||
|
||||
@@ -2,6 +2,7 @@ use ndarray::prelude::*;
|
||||
|
||||
use super::Kernel;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Matern52 {
|
||||
var: f64,
|
||||
l_scale: f64,
|
||||
|
||||
@@ -2,6 +2,7 @@ use ndarray::prelude::*;
|
||||
|
||||
use super::Kernel;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub struct Wiener {
|
||||
var: f64,
|
||||
t0: f64,
|
||||
|
||||
@@ -12,6 +12,7 @@ trait BinaryObeservation: Observation {
|
||||
}
|
||||
*/
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Observation {
|
||||
Probit,
|
||||
Logit,
|
||||
@@ -23,6 +24,7 @@ pub enum FitMethod {
|
||||
Kl,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Binary {
|
||||
pub storage: Storage,
|
||||
last_t: f64,
|
||||
@@ -50,7 +52,7 @@ impl Binary {
|
||||
Self::new(Observation::Logit)
|
||||
}
|
||||
|
||||
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
||||
pub fn add_item<K: Kernel + 'static + Clone>(&mut self, name: &str, kernel: K) {
|
||||
if self.storage.contains_key(name) {
|
||||
panic!("item '{}' already added", name);
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ pub enum DifferenceModelFitMethod {
|
||||
Kl,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct DifferenceModel {
|
||||
storage: Storage,
|
||||
last_t: f64,
|
||||
@@ -31,7 +32,7 @@ impl DifferenceModel {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
||||
pub fn add_item<K: Kernel + 'static + Clone>(&mut self, name: &str, kernel: K) {
|
||||
if self.storage.contains_key(name) {
|
||||
panic!("item '{}' already added", name);
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ pub enum FitMethod {
|
||||
Kl,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TernaryModel {
|
||||
storage: Storage,
|
||||
last_t: f64,
|
||||
@@ -47,7 +48,7 @@ impl TernaryModel {
|
||||
Self::new(Observation::Logit, margin)
|
||||
}
|
||||
|
||||
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
||||
pub fn add_item<K: Kernel + 'static + Clone>(&mut self, name: &str, kernel: K) {
|
||||
if self.storage.contains_key(name) {
|
||||
panic!("item '{}' already added", name);
|
||||
}
|
||||
|
||||
@@ -6,11 +6,30 @@ mod ordinal;
|
||||
pub use gaussian::*;
|
||||
pub use ordinal::*;
|
||||
|
||||
pub trait Observation {
|
||||
pub trait Observation: ObservationClone {
|
||||
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64;
|
||||
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64;
|
||||
}
|
||||
|
||||
pub trait ObservationClone {
|
||||
fn clone_box(&self) -> Box<dyn Observation>;
|
||||
}
|
||||
|
||||
impl<T> ObservationClone for T
|
||||
where
|
||||
T: 'static + Observation + Clone,
|
||||
{
|
||||
fn clone_box(&self) -> Box<dyn Observation> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Box<dyn Observation> {
|
||||
fn clone(&self) -> Box<dyn Observation> {
|
||||
self.clone_box()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) {
|
||||
let mut m = 0.0;
|
||||
let mut v = 0.0;
|
||||
@@ -25,6 +44,7 @@ pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64)
|
||||
(m, v)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct Core {
|
||||
pub m: usize,
|
||||
pub items: Vec<usize>,
|
||||
|
||||
@@ -14,6 +14,7 @@ fn mm_gaussian(mean_cav: f64, var_cav: f64, diff: f64, var_obs: f64) -> (f64, f6
|
||||
(logpart, dlogpart, d2logpart)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct GaussianObservation {
|
||||
core: Core,
|
||||
diff: f64,
|
||||
|
||||
@@ -87,6 +87,7 @@ where
|
||||
(exp_ll, alpha, beta)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProbitWinObservation {
|
||||
core: Core,
|
||||
margin: f64,
|
||||
@@ -126,6 +127,7 @@ impl Observation for ProbitWinObservation {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProbitTieObservation {
|
||||
core: Core,
|
||||
margin: f64,
|
||||
@@ -165,6 +167,7 @@ impl Observation for ProbitTieObservation {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LogitWinObservation {
|
||||
core: Core,
|
||||
margin: f64,
|
||||
@@ -197,6 +200,7 @@ impl Observation for LogitWinObservation {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LogitTieObservation {
|
||||
core: Core,
|
||||
margin: f64,
|
||||
|
||||
@@ -22,7 +22,7 @@ pub trait Backend {
|
||||
}
|
||||
*/
|
||||
|
||||
#[derive(Default)]
|
||||
#[derive(Clone, Default)]
|
||||
pub struct Storage {
|
||||
keys: HashMap<String, usize>,
|
||||
items: Vec<Item>,
|
||||
|
||||
@@ -35,9 +35,9 @@ fn kickscore_basic() {
|
||||
|
||||
let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], 2.0);
|
||||
|
||||
assert_abs_diff_eq!(p_win, 0.4455095363120037, epsilon = f64::EPSILON);
|
||||
assert_abs_diff_eq!(p_win, 0.4705389648502623, epsilon = f64::EPSILON);
|
||||
|
||||
let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], -1.0);
|
||||
|
||||
assert_abs_diff_eq!(p_win, 0.9037560799725326, epsilon = f64::EPSILON);
|
||||
assert_abs_diff_eq!(p_win, 0.7030407811954662, epsilon = f64::EPSILON);
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ fn nba_history() {
|
||||
.unix_timestamp() as f64,
|
||||
);
|
||||
|
||||
assert_abs_diff_eq!(p_win, 0.9002599772490479, epsilon = f64::EPSILON);
|
||||
assert_abs_diff_eq!(p_win, 0.8987931627209078, epsilon = f64::EPSILON);
|
||||
|
||||
let (p_win, _) = model.probabilities(
|
||||
&["CHI"],
|
||||
@@ -95,7 +95,7 @@ fn nba_history() {
|
||||
.unix_timestamp() as f64,
|
||||
);
|
||||
|
||||
assert_abs_diff_eq!(p_win, 0.22837870685441986, epsilon = f64::EPSILON);
|
||||
assert_abs_diff_eq!(p_win, 0.22890824995747874, epsilon = f64::EPSILON);
|
||||
|
||||
let (p_win, _) = model.probabilities(
|
||||
&["CHI"],
|
||||
|
||||
Reference in New Issue
Block a user