Remove need to box kernel on model

This commit is contained in:
2022-04-26 22:47:28 +02:00
parent 6f91c0a765
commit 6665362417
10 changed files with 20 additions and 20 deletions

View File

@@ -11,7 +11,7 @@ fn main() {
Box::new(ks::kernel::Matern52::new(0.5, 1.0)), Box::new(ks::kernel::Matern52::new(0.5, 1.0)),
]; ];
model.add_item(player, Box::new(kernel)); model.add_item(player, kernel);
} }
model.observe(&["A"], &["B"], 0.0); model.observe(&["A"], &["B"], 0.0);

View File

@@ -21,9 +21,9 @@ fn main() {
]; ];
// Now we are ready to add the items in the model. // Now we are ready to add the items in the model.
model.add_item("Spike", Box::new(k_spike)); model.add_item("Spike", k_spike);
model.add_item("Tom", Box::new(k_tom)); model.add_item("Tom", k_tom);
model.add_item("Jerry", Box::new(k_jerry)); model.add_item("Jerry", k_jerry);
// At first, Jerry beats Tom a couple of times. // At first, Jerry beats Tom a couple of times.
model.observe(&["Jerry"], &["Tom"], 0.0); model.observe(&["Jerry"], &["Tom"], 0.0);

View File

@@ -62,7 +62,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Box::new(ks::kernel::Matern32::new(0.138, 1.753 * seconds_in_year)), Box::new(ks::kernel::Matern32::new(0.138, 1.753 * seconds_in_year)),
]; ];
model.add_item(&team, Box::new(kernel)); model.add_item(&team, kernel);
} }
for (winner, loser, t) in observations { for (winner, loser, t) in observations {

View File

@@ -7,9 +7,9 @@ use crate::kernel::Kernel;
use super::Fitter; use super::Fitter;
pub struct Recursive { pub struct Recursive<K> {
ts_new: Vec<f64>, ts_new: Vec<f64>,
kernel: Box<dyn Kernel>, kernel: K,
ts: Vec<f64>, ts: Vec<f64>,
ms: Vec<f64>, ms: Vec<f64>,
vs: Vec<f64>, vs: Vec<f64>,
@@ -28,8 +28,8 @@ pub struct Recursive {
p_s: Vec<Array2<f64>>, p_s: Vec<Array2<f64>>,
} }
impl Recursive { impl<K: Kernel> Recursive<K> {
pub fn new(kernel: Box<dyn Kernel>) -> Self { pub fn new(kernel: K) -> Self {
let m = kernel.order(); let m = kernel.order();
let h = kernel.measurement_vector(); let h = kernel.measurement_vector();
@@ -56,7 +56,7 @@ impl Recursive {
} }
} }
impl fmt::Debug for Recursive { impl<K: Kernel> fmt::Debug for Recursive<K> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RecursiveFitter") f.debug_struct("RecursiveFitter")
.field("ts_new", &self.ts_new) .field("ts_new", &self.ts_new)
@@ -80,7 +80,7 @@ impl fmt::Debug for Recursive {
} }
} }
impl Fitter for Recursive { impl<K: Kernel> Fitter for Recursive<K> {
fn add_sample(&mut self, t: f64) -> usize { fn add_sample(&mut self, t: f64) -> usize {
let idx = self.ts.len() + self.ts_new.len(); let idx = self.ts.len() + self.ts_new.len();

View File

@@ -42,7 +42,7 @@ impl Binary {
} }
} }
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) { pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
if self.storage.contains_key(name) { if self.storage.contains_key(name) {
panic!("item '{}' already added", name); panic!("item '{}' already added", name);
} }

View File

@@ -31,7 +31,7 @@ impl DifferenceModel {
} }
} }
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) { pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
if self.storage.contains_key(name) { if self.storage.contains_key(name) {
panic!("item '{}' already added", name); panic!("item '{}' already added", name);
} }

View File

@@ -39,7 +39,7 @@ impl TernaryModel {
} }
} }
pub fn add_item(&mut self, name: &str, kernel: Box<dyn Kernel>) { pub fn add_item<K: Kernel + 'static>(&mut self, name: &str, kernel: K) {
if self.storage.contains_key(name) { if self.storage.contains_key(name) {
panic!("item '{}' already added", name); panic!("item '{}' already added", name);
} }

View File

@@ -9,8 +9,8 @@ fn binary_1() {
let k_audrey = ks::kernel::Matern52::new(1.0, 2.0); let k_audrey = ks::kernel::Matern52::new(1.0, 2.0);
let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0); let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0);
model.add_item("audrey", Box::new(k_audrey)); model.add_item("audrey", k_audrey);
model.add_item("benjamin", Box::new(k_benjamin)); model.add_item("benjamin", k_benjamin);
model.observe(&["audrey"], &["benjamin"], 0.0); model.observe(&["audrey"], &["benjamin"], 0.0);
model.observe(&["audrey"], &["benjamin"], 1.0); model.observe(&["audrey"], &["benjamin"], 1.0);

View File

@@ -16,9 +16,9 @@ fn kickscore_basic() {
Box::new(ks::kernel::Matern52::new(0.5, 1.0)), Box::new(ks::kernel::Matern52::new(0.5, 1.0)),
]; ];
model.add_item("Spike", Box::new(k_spike)); model.add_item("Spike", k_spike);
model.add_item("Tom", Box::new(k_tom)); model.add_item("Tom", k_tom);
model.add_item("Jerry", Box::new(k_jerry)); model.add_item("Jerry", k_jerry);
model.observe(&["Jerry"], &["Tom"], 0.0); model.observe(&["Jerry"], &["Tom"], 0.0);
model.observe(&["Jerry"], &["Tom"], 0.9); model.observe(&["Jerry"], &["Tom"], 0.9);

View File

@@ -66,7 +66,7 @@ fn nba_history() {
Box::new(ks::kernel::Matern32::new(0.138, 1.753 * seconds_in_year)), Box::new(ks::kernel::Matern32::new(0.138, 1.753 * seconds_in_year)),
]; ];
model.add_item(&team, Box::new(kernel)); model.add_item(&team, kernel);
} }
for (winner, loser, t) in observations { for (winner, loser, t) in observations {