diff --git a/src/fitter/recursive.rs b/src/fitter/recursive.rs index 0bfddd7..348d688 100644 --- a/src/fitter/recursive.rs +++ b/src/fitter/recursive.rs @@ -33,7 +33,7 @@ impl RecursiveFitter { let m = kernel.order(); let h = kernel.measurement_vector(); - RecursiveFitter { + Self { ts_new: Vec::new(), kernel, ts: Vec::new(), @@ -215,6 +215,7 @@ impl Fitter for RecursiveFitter { self.is_fitted = true; } + #[allow(clippy::many_single_char_names)] fn predict(&self, ts: &[f64]) -> (Vec, Vec) { if !self.is_fitted { panic!("new data since last call to `fit()`"); diff --git a/src/observation.rs b/src/observation.rs index e5382f6..a0e0fb5 100644 --- a/src/observation.rs +++ b/src/observation.rs @@ -110,7 +110,45 @@ impl Core { diff } - fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 { - todo!(); + pub fn kl_update(&mut self, storage: &mut Storage, lr: f64, cvi_expectations: F) -> f64 + where + F: Fn(f64, f64) -> (f64, f64, f64), + { + // Mean and variance in function space. + let mut f_mean: f64 = 0.0; + let mut f_var: f64 = 0.0; + + for i in 0..self.m { + let item = storage.item_mut(self.items[i]); + let idx = self.indices[i]; + let coeff = self.coeffs[i]; + + // Adjust the function-space mean & variance. + f_mean += coeff * item.fitter.ms(idx); + f_var += coeff * coeff * item.fitter.vs(idx); + } + + // Compute the derivatives of the exp. log-lik. w.r.t. mean parameters. + let (exp_ll, alpha, beta) = cvi_expectations(f_mean, f_var); + + for i in 0..self.m { + let item = storage.item_mut(self.items[i]); + let idx = self.indices[i]; + let coeff = self.coeffs[i]; + + // Update the elements' parameters. + let x = -2.0 * coeff * coeff * beta; + let n = coeff * (alpha - 2.0 * item.fitter.ms(idx) * coeff * beta); + + *item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x; + *item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n; + } + + let diff = (self.exp_ll - exp_ll).abs(); + + // Save the expected log-likelihood. + self.exp_ll = exp_ll; + + diff } } diff --git a/src/observation/ordinal.rs b/src/observation/ordinal.rs index ef0e6de..289f049 100644 --- a/src/observation/ordinal.rs +++ b/src/observation/ordinal.rs @@ -1,11 +1,14 @@ +use std::f64::consts::TAU; + use crate::storage::Storage; -use crate::utils::{logphi, normcdf, normpdf}; +use crate::utils::{logphi, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS}; use super::{f_params, Core, Observation}; fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) { // Adapted from the GPML function `likErf.m`. let z = mean_cav / (1.0 + cov_cav).sqrt(); + let (logpart, val) = logphi(z); let dlogpart = val / (1.0 + cov_cav).sqrt(); // 1st derivative w.r.t. mean. let d2logpart = -val * (z + val) / (1.0 + cov_cav); @@ -33,6 +36,45 @@ fn mm_probit_tie(mean_cav: f64, cov_cav: f64, margin: f64) -> (f64, f64, f64) { (logpart, dlogpart, d2logpart) } +fn ll_probit_win(x: f64, margin: f64) -> f64 { + logphi(x - margin).0 +} + +fn ll_probit_tie(x: f64, margin: f64) -> f64 { + let x = -x.abs(); + let z = logphi(x + margin).0; + let a = logphi(x - margin).0 - z; + + if a > -0.693 { + z + (-a.exp_m1()).ln() + } else { + z + (-a.exp()).ln_1p() + } +} + +fn cvi_expectations(mean: f64, var: f64, ll_fct: F) -> (f64, f64, f64) +where + F: Fn(f64) -> f64, +{ + const N: usize = 30; + + let std = var.sqrt(); + let mut exp_ll = 0.0; + let mut alpha = 0.0; + let mut beta = 0.0; + + for i in 0..N { + let val = + (ROOTS_HERMITENORM_WS[i] / TAU.sqrt()) * ll_fct(std * ROOTS_HERMITENORM_XS[i] + mean); + + exp_ll += val; + alpha += (ROOTS_HERMITENORM_XS[i] / std) * val; + beta += ((ROOTS_HERMITENORM_XS[i].powi(2) - 1.0) / (2.0 * var)) * val; + } + + (exp_ll, alpha, beta) +} + pub struct ProbitWinObservation { core: Core, margin: f64, @@ -40,7 +82,7 @@ pub struct ProbitWinObservation { impl ProbitWinObservation { pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self { - ProbitWinObservation { + Self { core: Core::new(storage, elems, t), margin, } @@ -58,14 +100,17 @@ impl Observation for ProbitWinObservation { fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 { let margin = self.margin; - self.core - .ep_update(storage, lr, |mean_cav: f64, cov_cav: f64| { - mm_probit_win(mean_cav - margin, cov_cav) - }) + self.core.ep_update(storage, lr, |mean_cav, cov_cav| { + mm_probit_win(mean_cav - margin, cov_cav) + }) } fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 { - self.core.kl_update(lr, storage) + let margin = self.margin; + + self.core.kl_update(storage, lr, |mean, var| { + cvi_expectations(mean, var, |x| ll_probit_win(x, margin)) + }) } } @@ -94,26 +139,29 @@ impl Observation for ProbitTieObservation { fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 { let margin = self.margin; - self.core - .ep_update(storage, lr, |mean_cav: f64, cov_cav: f64| { - mm_probit_tie(mean_cav, cov_cav, margin) - }) + self.core.ep_update(storage, lr, |mean_cav, cov_cav| { + mm_probit_tie(mean_cav, cov_cav, margin) + }) } - fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 { - todo!(); + fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 { + let margin = self.margin; + + self.core.kl_update(storage, lr, |mean, var| { + cvi_expectations(mean, var, |x| ll_probit_tie(x, margin)) + }) } } pub struct LogitWinObservation { - _core: Core, + core: Core, _margin: f64, } impl LogitWinObservation { pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self { - LogitWinObservation { - _core: Core::new(storage, elems, t), + Self { + core: Core::new(storage, elems, t), _margin: margin, } } @@ -124,18 +172,22 @@ impl Observation for LogitWinObservation { todo!(); } - fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 { - todo!(); + fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 { + self.core.kl_update(storage, lr, |mean, var| todo!()) } } pub struct LogitTieObservation { - // + core: Core, + _margin: f64, } impl LogitTieObservation { - pub fn new(_storage: &mut Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> Self { - todo!(); + pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self { + Self { + core: Core::new(storage, elems, t), + _margin: margin, + } } } @@ -144,8 +196,8 @@ impl Observation for LogitTieObservation { todo!(); } - fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 { - todo!(); + fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 { + self.core.kl_update(storage, lr, |mean, var| todo!()) } } @@ -176,4 +228,13 @@ mod tests { assert_relative_eq!(b, -0.20881357058382308); assert_relative_eq!(c, -0.1698273481633205); } + + #[test] + fn test_cvi_expectations() { + let (a, b, c) = cvi_expectations(0.3, 2.7, |x| logphi(x).0); + + assert_relative_eq!(a, -1.198109740470403); + assert_relative_eq!(b, 0.8970390100143342); + assert_relative_eq!(c, -0.2565392538236815); + } } diff --git a/src/utils.rs b/src/utils.rs index b50a153..9b4ee72 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -87,3 +87,69 @@ pub fn logphi(z: f64) -> (f64, f64) { (res, dres) } } + +pub const ROOTS_HERMITENORM_XS: [f64; 30] = [ + -9.706236, + -8.68083772, + -7.82505174, + -7.05539687, + -6.33999769, + -5.66238185, + -5.0126006, + -4.38402037, + -3.77189442, + -3.17263464, + -2.5834021, + -2.00185861, + -1.42600566, + -0.85407335, + -0.28443876, + 0.28443876, + 0.85407335, + 1.42600566, + 2.00185861, + 2.5834021, + 3.17263464, + 3.77189442, + 4.38402037, + 5.0126006, + 5.66238185, + 6.33999769, + 7.05539687, + 7.82505174, + 8.68083772, + 9.706236, +]; + +pub const ROOTS_HERMITENORM_WS: [f64; 30] = [ + 4.11289324e-21, + 3.97441190e-17, + 4.07096517e-14, + 1.14638786e-11, + 1.29804729e-09, + 7.22454173e-08, + 2.23317741e-06, + 4.15598507e-05, + 4.92584902e-04, + 3.87200709e-03, + 2.07943554e-02, + 7.79856428e-02, + 2.07515826e-01, + 3.96164962e-01, + 5.46444893e-01, + 5.46444893e-01, + 3.96164962e-01, + 2.07515826e-01, + 7.79856428e-02, + 2.07943554e-02, + 3.87200709e-03, + 4.92584902e-04, + 4.15598507e-05, + 2.23317741e-06, + 7.22454173e-08, + 1.29804729e-09, + 1.14638786e-11, + 4.07096517e-14, + 3.97441190e-17, + 4.11289324e-21, +]; diff --git a/tests/binary-1.rs b/tests/binary-1.rs index c541984..0e5dd1e 100644 --- a/tests/binary-1.rs +++ b/tests/binary-1.rs @@ -2,7 +2,7 @@ extern crate intel_mkl_src; use kickscore as ks; -// #[test] +#[test] fn binary_1() { let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);