From 7edc0b6b6d9dd0636eb391dcc9f240ffc4f71cc5 Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Mon, 4 Jan 2021 15:00:00 +0100 Subject: [PATCH] Small fixes --- examples/nba-history.rs | 15 +++++++++------ src/observation/gaussian.rs | 14 +------------- tests/binary-1.rs | 28 ++++++++++++++++++++++++++++ tests/nba-history.rs | 15 +++++++++------ 4 files changed, 47 insertions(+), 25 deletions(-) create mode 100644 tests/binary-1.rs diff --git a/examples/nba-history.rs b/examples/nba-history.rs index 91e8da8..ba023ee 100644 --- a/examples/nba-history.rs +++ b/examples/nba-history.rs @@ -1,5 +1,6 @@ extern crate intel_mkl_src; +use std::cmp::Ordering; use std::collections::HashSet; use std::fs; use std::io::{self, BufRead}; @@ -39,12 +40,14 @@ fn main() -> Result<(), Box> { let score_1: u16 = data[3].parse()?; let score_2: u16 = data[4].parse()?; - if score_1 > score_2 { - observations.push((data[1].to_string(), data[2].to_string(), t)); - } else if score_1 < score_2 { - observations.push((data[2].to_string(), data[1].to_string(), t)); - } else { - panic!("there shouldn't be any tie games"); + match score_1.cmp(&score_2) { + Ordering::Greater => { + observations.push((data[1].to_string(), data[2].to_string(), t)); + } + Ordering::Less => { + observations.push((data[2].to_string(), data[1].to_string(), t)); + } + _ => panic!("there shouldn't be any tie games"), } } diff --git a/src/observation/gaussian.rs b/src/observation/gaussian.rs index 4d90892..7bfe7aa 100644 --- a/src/observation/gaussian.rs +++ b/src/observation/gaussian.rs @@ -2,19 +2,7 @@ use crate::storage::Storage; use super::Observation; -pub struct GaussianObservation { - /* -m: usize, -items: Vec, -coeffs: Vec, -indices: Vec, -ns_cav: Vec, -xs_cav: Vec, -t: f64, -logpart: f64, -exp_ll: usize, -margin: f64, -*/} +pub struct GaussianObservation; impl GaussianObservation { pub fn new( diff --git a/tests/binary-1.rs b/tests/binary-1.rs new file mode 100644 index 0000000..406c51f --- /dev/null +++ b/tests/binary-1.rs @@ -0,0 +1,28 @@ +extern crate intel_mkl_src; + +use approx::assert_abs_diff_eq; +use kickscore as ks; + +#[test] +fn binary_1() { + let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit); + + let k_audrey = ks::kernel::Matern52::new(1.0, 2.0); + let k_benjamin = ks::kernel::Matern52::new(1.0, 2.0); + + model.add_item("audrey", Box::new(k_audrey)); + model.add_item("benjamin", Box::new(k_benjamin)); + + model.observe(&["audrey"], &["benjamin"], 0.0); + model.observe(&["audrey"], &["benjamin"], 1.0); + model.observe(&["audrey"], &["benjamin"], 2.0); + + model.observe(&["audrey"], &["benjamin"], 3.0); + + model.observe(&["benjamin"], &["audrey"], 4.0); + model.observe(&["benjamin"], &["audrey"], 5.0); + model.observe(&["benjamin"], &["audrey"], 6.0); + model.observe(&["benjamin"], &["audrey"], 7.0); + + model.fit(); +} diff --git a/tests/nba-history.rs b/tests/nba-history.rs index bcc3ea3..f21c57c 100644 --- a/tests/nba-history.rs +++ b/tests/nba-history.rs @@ -1,5 +1,6 @@ extern crate intel_mkl_src; +use std::cmp::Ordering; use std::collections::HashSet; use std::fs; use std::io::{self, BufRead}; @@ -43,12 +44,14 @@ fn nba_history() { let score_1: u16 = data[3].parse().unwrap(); let score_2: u16 = data[4].parse().unwrap(); - if score_1 > score_2 { - observations.push((data[1].to_string(), data[2].to_string(), t)); - } else if score_1 < score_2 { - observations.push((data[2].to_string(), data[1].to_string(), t)); - } else { - panic!("there shouldn't be any tie games"); + match score_1.cmp(&score_2) { + Ordering::Greater => { + observations.push((data[1].to_string(), data[2].to_string(), t)); + } + Ordering::Less => { + observations.push((data[2].to_string(), data[1].to_string(), t)); + } + _ => panic!("there shouldn't be any tie games"), } }