Make models cloneable
This commit is contained in:
@@ -2,7 +2,7 @@ mod recursive;
|
|||||||
|
|
||||||
pub use recursive::Recursive;
|
pub use recursive::Recursive;
|
||||||
|
|
||||||
pub trait Fitter {
|
pub trait Fitter: FitterClone {
|
||||||
fn add_sample(&mut self, t: f64) -> usize;
|
fn add_sample(&mut self, t: f64) -> usize;
|
||||||
|
|
||||||
fn allocate(&mut self);
|
fn allocate(&mut self);
|
||||||
@@ -26,3 +26,22 @@ pub trait Fitter {
|
|||||||
fn ns(&self, idx: usize) -> f64;
|
fn ns(&self, idx: usize) -> f64;
|
||||||
fn ns_mut(&mut self, idx: usize) -> &mut f64;
|
fn ns_mut(&mut self, idx: usize) -> &mut f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait FitterClone {
|
||||||
|
fn clone_box(&self) -> Box<dyn Fitter>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> FitterClone for T
|
||||||
|
where
|
||||||
|
T: 'static + Fitter + Clone,
|
||||||
|
{
|
||||||
|
fn clone_box(&self) -> Box<dyn Fitter> {
|
||||||
|
Box::new(self.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for Box<dyn Fitter> {
|
||||||
|
fn clone(&self) -> Box<dyn Fitter> {
|
||||||
|
self.clone_box()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ impl<K: Kernel> fmt::Debug for Recursive<K> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<K: Kernel> Fitter for Recursive<K> {
|
impl<K: 'static + Kernel + Clone> Fitter for Recursive<K> {
|
||||||
fn add_sample(&mut self, t: f64) -> usize {
|
fn add_sample(&mut self, t: f64) -> usize {
|
||||||
let idx = self.ts.len() + self.ts_new.len();
|
let idx = self.ts.len() + self.ts_new.len();
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use crate::fitter::Fitter;
|
use crate::fitter::Fitter;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct Item {
|
pub struct Item {
|
||||||
pub fitter: Box<dyn Fitter>,
|
pub fitter: Box<dyn Fitter>,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ pub(crate) fn distance(ts1: &[f64], ts2: &[f64]) -> Array2<f64> {
|
|||||||
r
|
r
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Kernel {
|
pub trait Kernel: KernelClone {
|
||||||
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64>;
|
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64>;
|
||||||
fn k_diag(&self, ts: &[f64]) -> Array1<f64>;
|
fn k_diag(&self, ts: &[f64]) -> Array1<f64>;
|
||||||
fn order(&self) -> usize;
|
fn order(&self) -> usize;
|
||||||
@@ -41,6 +41,25 @@ pub trait Kernel {
|
|||||||
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64>;
|
fn noise_cov(&self, t0: f64, t1: f64) -> Array2<f64>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait KernelClone {
|
||||||
|
fn clone_box(&self) -> Box<dyn Kernel>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> KernelClone for T
|
||||||
|
where
|
||||||
|
T: 'static + Kernel + Clone,
|
||||||
|
{
|
||||||
|
fn clone_box(&self) -> Box<dyn Kernel> {
|
||||||
|
Box::new(self.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for Box<dyn Kernel> {
|
||||||
|
fn clone(&self) -> Box<dyn Kernel> {
|
||||||
|
self.clone_box()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Kernel for Vec<Box<dyn Kernel>> {
|
impl Kernel for Vec<Box<dyn Kernel>> {
|
||||||
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
|
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
|
||||||
let n = ts1.len();
|
let n = ts1.len();
|
||||||
@@ -533,7 +552,7 @@ mod tests {
|
|||||||
|
|
||||||
macro_rules! tuple_impls {
|
macro_rules! tuple_impls {
|
||||||
( $( $name:ident )+ ) => {
|
( $( $name:ident )+ ) => {
|
||||||
impl<$($name: Kernel),+> Kernel for ($($name,)+) {
|
impl<$($name: 'static + Kernel + Clone),+> Kernel for ($($name,)+) {
|
||||||
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
|
fn k_mat(&self, ts1: &[f64], ts2: Option<&[f64]>) -> Array2<f64> {
|
||||||
let n = ts1.len();
|
let n = ts1.len();
|
||||||
let m = ts2.map_or(n, |ts| ts.len());
|
let m = ts2.map_or(n, |ts| ts.len());
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use ndarray::prelude::*;
|
|||||||
|
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub struct Affine {
|
pub struct Affine {
|
||||||
var_offset: f64,
|
var_offset: f64,
|
||||||
var_slope: f64,
|
var_slope: f64,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use ndarray::prelude::*;
|
|||||||
|
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub struct Constant {
|
pub struct Constant {
|
||||||
var: f64,
|
var: f64,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use ndarray::prelude::*;
|
|||||||
|
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub struct Exponential {
|
pub struct Exponential {
|
||||||
var: f64,
|
var: f64,
|
||||||
l_scale: f64,
|
l_scale: f64,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ use ndarray::prelude::*;
|
|||||||
|
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone, Copy)]
|
||||||
pub struct Matern32 {
|
pub struct Matern32 {
|
||||||
var: f64,
|
var: f64,
|
||||||
l_scale: f64,
|
l_scale: f64,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use ndarray::prelude::*;
|
|||||||
|
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub struct Matern52 {
|
pub struct Matern52 {
|
||||||
var: f64,
|
var: f64,
|
||||||
l_scale: f64,
|
l_scale: f64,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use ndarray::prelude::*;
|
|||||||
|
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub struct Wiener {
|
pub struct Wiener {
|
||||||
var: f64,
|
var: f64,
|
||||||
t0: f64,
|
t0: f64,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ trait BinaryObeservation: Observation {
|
|||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub enum Observation {
|
pub enum Observation {
|
||||||
Probit,
|
Probit,
|
||||||
Logit,
|
Logit,
|
||||||
@@ -23,6 +24,7 @@ pub enum FitMethod {
|
|||||||
Kl,
|
Kl,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct Binary {
|
pub struct Binary {
|
||||||
pub storage: Storage,
|
pub storage: Storage,
|
||||||
last_t: f64,
|
last_t: f64,
|
||||||
@@ -50,7 +52,7 @@ impl Binary {
|
|||||||
Self::new(Observation::Logit)
|
Self::new(Observation::Logit)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
pub fn add_item<K: Kernel + 'static + Clone>(&mut self, name: &str, kernel: K) {
|
||||||
if self.storage.contains_key(name) {
|
if self.storage.contains_key(name) {
|
||||||
panic!("item '{}' already added", name);
|
panic!("item '{}' already added", name);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ pub enum DifferenceModelFitMethod {
|
|||||||
Kl,
|
Kl,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct DifferenceModel {
|
pub struct DifferenceModel {
|
||||||
storage: Storage,
|
storage: Storage,
|
||||||
last_t: f64,
|
last_t: f64,
|
||||||
@@ -31,7 +32,7 @@ impl DifferenceModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
pub fn add_item<K: Kernel + 'static + Clone>(&mut self, name: &str, kernel: K) {
|
||||||
if self.storage.contains_key(name) {
|
if self.storage.contains_key(name) {
|
||||||
panic!("item '{}' already added", name);
|
panic!("item '{}' already added", name);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ pub enum FitMethod {
|
|||||||
Kl,
|
Kl,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct TernaryModel {
|
pub struct TernaryModel {
|
||||||
storage: Storage,
|
storage: Storage,
|
||||||
last_t: f64,
|
last_t: f64,
|
||||||
@@ -47,7 +48,7 @@ impl TernaryModel {
|
|||||||
Self::new(Observation::Logit, margin)
|
Self::new(Observation::Logit, margin)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
|
pub fn add_item<K: Kernel + 'static + Clone>(&mut self, name: &str, kernel: K) {
|
||||||
if self.storage.contains_key(name) {
|
if self.storage.contains_key(name) {
|
||||||
panic!("item '{}' already added", name);
|
panic!("item '{}' already added", name);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,30 @@ mod ordinal;
|
|||||||
pub use gaussian::*;
|
pub use gaussian::*;
|
||||||
pub use ordinal::*;
|
pub use ordinal::*;
|
||||||
|
|
||||||
pub trait Observation {
|
pub trait Observation: ObservationClone {
|
||||||
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64;
|
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64;
|
||||||
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64;
|
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait ObservationClone {
|
||||||
|
fn clone_box(&self) -> Box<dyn Observation>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> ObservationClone for T
|
||||||
|
where
|
||||||
|
T: 'static + Observation + Clone,
|
||||||
|
{
|
||||||
|
fn clone_box(&self) -> Box<dyn Observation> {
|
||||||
|
Box::new(self.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for Box<dyn Observation> {
|
||||||
|
fn clone(&self) -> Box<dyn Observation> {
|
||||||
|
self.clone_box()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) {
|
pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) {
|
||||||
let mut m = 0.0;
|
let mut m = 0.0;
|
||||||
let mut v = 0.0;
|
let mut v = 0.0;
|
||||||
@@ -25,6 +44,7 @@ pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64)
|
|||||||
(m, v)
|
(m, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub(crate) struct Core {
|
pub(crate) struct Core {
|
||||||
pub m: usize,
|
pub m: usize,
|
||||||
pub items: Vec<usize>,
|
pub items: Vec<usize>,
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ fn mm_gaussian(mean_cav: f64, var_cav: f64, diff: f64, var_obs: f64) -> (f64, f6
|
|||||||
(logpart, dlogpart, d2logpart)
|
(logpart, dlogpart, d2logpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct GaussianObservation {
|
pub struct GaussianObservation {
|
||||||
core: Core,
|
core: Core,
|
||||||
diff: f64,
|
diff: f64,
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ where
|
|||||||
(exp_ll, alpha, beta)
|
(exp_ll, alpha, beta)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct ProbitWinObservation {
|
pub struct ProbitWinObservation {
|
||||||
core: Core,
|
core: Core,
|
||||||
margin: f64,
|
margin: f64,
|
||||||
@@ -126,6 +127,7 @@ impl Observation for ProbitWinObservation {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct ProbitTieObservation {
|
pub struct ProbitTieObservation {
|
||||||
core: Core,
|
core: Core,
|
||||||
margin: f64,
|
margin: f64,
|
||||||
@@ -165,6 +167,7 @@ impl Observation for ProbitTieObservation {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct LogitWinObservation {
|
pub struct LogitWinObservation {
|
||||||
core: Core,
|
core: Core,
|
||||||
margin: f64,
|
margin: f64,
|
||||||
@@ -197,6 +200,7 @@ impl Observation for LogitWinObservation {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct LogitTieObservation {
|
pub struct LogitTieObservation {
|
||||||
core: Core,
|
core: Core,
|
||||||
margin: f64,
|
margin: f64,
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ pub trait Backend {
|
|||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Clone, Default)]
|
||||||
pub struct Storage {
|
pub struct Storage {
|
||||||
keys: HashMap<String, usize>,
|
keys: HashMap<String, usize>,
|
||||||
items: Vec<Item>,
|
items: Vec<Item>,
|
||||||
|
|||||||
@@ -35,9 +35,9 @@ fn kickscore_basic() {
|
|||||||
|
|
||||||
let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], 2.0);
|
let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], 2.0);
|
||||||
|
|
||||||
assert_abs_diff_eq!(p_win, 0.4455095363120037, epsilon = f64::EPSILON);
|
assert_abs_diff_eq!(p_win, 0.4705389648502623, epsilon = f64::EPSILON);
|
||||||
|
|
||||||
let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], -1.0);
|
let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], -1.0);
|
||||||
|
|
||||||
assert_abs_diff_eq!(p_win, 0.9037560799725326, epsilon = f64::EPSILON);
|
assert_abs_diff_eq!(p_win, 0.7030407811954662, epsilon = f64::EPSILON);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ fn nba_history() {
|
|||||||
.unix_timestamp() as f64,
|
.unix_timestamp() as f64,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_abs_diff_eq!(p_win, 0.9002599772490479, epsilon = f64::EPSILON);
|
assert_abs_diff_eq!(p_win, 0.8987931627209078, epsilon = f64::EPSILON);
|
||||||
|
|
||||||
let (p_win, _) = model.probabilities(
|
let (p_win, _) = model.probabilities(
|
||||||
&["CHI"],
|
&["CHI"],
|
||||||
@@ -95,7 +95,7 @@ fn nba_history() {
|
|||||||
.unix_timestamp() as f64,
|
.unix_timestamp() as f64,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_abs_diff_eq!(p_win, 0.22837870685441986, epsilon = f64::EPSILON);
|
assert_abs_diff_eq!(p_win, 0.22890824995747874, epsilon = f64::EPSILON);
|
||||||
|
|
||||||
let (p_win, _) = model.probabilities(
|
let (p_win, _) = model.probabilities(
|
||||||
&["CHI"],
|
&["CHI"],
|
||||||
|
|||||||
Reference in New Issue
Block a user