Fix tests
This commit is contained in:
45
tests/kickscore-basics.rs
Normal file
45
tests/kickscore-basics.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use approx::assert_abs_diff_eq;
|
||||
use kickscore as ks;
|
||||
|
||||
#[test]
|
||||
fn kickscore_basic() {
|
||||
let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
|
||||
|
||||
let k_spike = ks::kernel::Constant::new(0.5);
|
||||
|
||||
let k_tom = ks::kernel::Exponential::new(1.0, 1.0);
|
||||
|
||||
let k_jerry: Vec<Box<dyn ks::Kernel>> = vec![
|
||||
Box::new(ks::kernel::Constant::new(1.0)),
|
||||
Box::new(ks::kernel::Matern52::new(0.5, 1.0)),
|
||||
];
|
||||
|
||||
model.add_item("Spike", Box::new(k_spike));
|
||||
model.add_item("Tom", Box::new(k_tom));
|
||||
model.add_item("Jerry", Box::new(k_jerry));
|
||||
|
||||
model.observe(&["Jerry"], &["Tom"], 0.0);
|
||||
model.observe(&["Jerry"], &["Tom"], 0.9);
|
||||
|
||||
model.observe(&["Tom"], &["Spike"], 1.7);
|
||||
model.observe(&["Tom"], &["Jerry"], 2.1);
|
||||
|
||||
model.observe(&["Jerry"], &["Tom"], 3.0);
|
||||
model.observe(&["Jerry"], &["Tom", "Spike"], 3.5);
|
||||
|
||||
model.fit();
|
||||
|
||||
let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 4.0);
|
||||
|
||||
assert_abs_diff_eq!(p_win, 0.7299975928462964, epsilon = f64::EPSILON);
|
||||
|
||||
let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], 2.0);
|
||||
|
||||
assert_abs_diff_eq!(p_win, 0.4455095363120037, epsilon = f64::EPSILON);
|
||||
|
||||
let (p_win, _p_los) = model.probabilities(&[&"Jerry"], &[&"Tom"], -1.0);
|
||||
|
||||
assert_abs_diff_eq!(p_win, 0.903756079972532, epsilon = f64::EPSILON);
|
||||
}
|
||||
Reference in New Issue
Block a user