diff --git a/src/model/binary.rs b/src/model/binary.rs index 97f0041..b090af8 100644 --- a/src/model/binary.rs +++ b/src/model/binary.rs @@ -142,7 +142,9 @@ impl BinaryModel { elems.extend(self.process_items(team_2, -1.0)); let prob = match self.win_obs { - BinaryModelObservation::Probit => probit_win_observation(&elems, t, 0.0, &self.storage), + BinaryModelObservation::Probit => { + ProbitWinObservation::probability(&self.storage, &elems, t, 0.0) + } BinaryModelObservation::Logit => todo!(), }; diff --git a/src/model/ternary.rs b/src/model/ternary.rs index 729f616..ba577e7 100644 --- a/src/model/ternary.rs +++ b/src/model/ternary.rs @@ -169,14 +169,14 @@ impl TernaryModel { let prob_1 = match self.obs { TernaryModelObservation::Probit => { - probit_win_observation(&elems, t, margin, &self.storage) + ProbitWinObservation::probability(&self.storage, &elems, t, margin) } TernaryModelObservation::Logit => unimplemented!(), }; let prob_2 = match self.obs { TernaryModelObservation::Probit => { - probit_tie_observation(&elems, t, margin, &self.storage) + ProbitTieObservation::probability(&self.storage, &elems, t, margin) } TernaryModelObservation::Logit => unimplemented!(), }; diff --git a/src/observation/ordinal.rs b/src/observation/ordinal.rs index 370829d..473ea22 100644 --- a/src/observation/ordinal.rs +++ b/src/observation/ordinal.rs @@ -3,18 +3,6 @@ use crate::utils::logphi; use super::{f_params, Core, Observation}; -pub fn probit_win_observation( - elems: &[(usize, f64)], - t: f64, - margin: f64, - storage: &Storage, -) -> f64 { - let (m, v) = f_params(&elems, t, &storage); - let (logpart, _, _) = mm_probit_win(m - margin, v); - - logpart.exp() -} - fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) { // Adapted from the GPML function `likErf.m`. let z = mean_cav / (1.0 + cov_cav).sqrt(); @@ -37,6 +25,13 @@ impl ProbitWinObservation { margin, } } + + pub fn probability(storage: &Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> f64 { + let (m, v) = f_params(&elems, t, &storage); + let (logpart, _, _) = mm_probit_win(m - margin, v); + + logpart.exp() + } } impl Observation for ProbitWinObservation { @@ -81,15 +76,6 @@ impl Observation for LogitWinObservation { } } -pub fn probit_tie_observation( - _elems: &[(usize, f64)], - _t: f64, - _margin: f64, - _storage: &Storage, -) -> f64 { - unimplemented!(); -} - pub struct ProbitTieObservation { // } @@ -98,6 +84,10 @@ impl ProbitTieObservation { pub fn new(_storage: &mut Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> Self { todo!(); } + + pub fn probability(storage: &Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> f64 { + todo!(); + } } impl Observation for ProbitTieObservation {