Implemented more functions
@@ -33,7 +33,7 @@ impl RecursiveFitter {
         let m = kernel.order();
         let h = kernel.measurement_vector();

-        RecursiveFitter {
+        Self {
             ts_new: Vec::new(),
             kernel,
             ts: Vec::new(),

@@ -215,6 +215,7 @@ impl Fitter for RecursiveFitter {
         self.is_fitted = true;
     }

+    #[allow(clippy::many_single_char_names)]
     fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>) {
         if !self.is_fitted {
             panic!("new data since last call to `fit()`");

@@ -110,7 +110,45 @@ impl Core {
         diff
     }

-    fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
-        todo!();
+    pub fn kl_update<F>(&mut self, storage: &mut Storage, lr: f64, cvi_expectations: F) -> f64
+    where
+        F: Fn(f64, f64) -> (f64, f64, f64),
+    {
+        // Mean and variance in function space.
+        let mut f_mean: f64 = 0.0;
+        let mut f_var: f64 = 0.0;
+
+        for i in 0..self.m {
+            let item = storage.item_mut(self.items[i]);
+            let idx = self.indices[i];
+            let coeff = self.coeffs[i];
+
+            // Adjust the function-space mean & variance.
+            f_mean += coeff * item.fitter.ms(idx);
+            f_var += coeff * coeff * item.fitter.vs(idx);
+        }
+
+        // Compute the derivatives of the exp. log-lik. w.r.t. mean parameters.
+        let (exp_ll, alpha, beta) = cvi_expectations(f_mean, f_var);
+
+        for i in 0..self.m {
+            let item = storage.item_mut(self.items[i]);
+            let idx = self.indices[i];
+            let coeff = self.coeffs[i];
+
+            // Update the elements' parameters.
+            let x = -2.0 * coeff * coeff * beta;
+            let n = coeff * (alpha - 2.0 * item.fitter.ms(idx) * coeff * beta);
+
+            *item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x;
+            *item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n;
+        }
+
+        let diff = (self.exp_ll - exp_ll).abs();
+
+        // Save the expected log-likelihood.
+        self.exp_ll = exp_ll;
+
+        diff
     }
 }

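Note, not part of the diff: the loop above is a damped natural-parameter (CVI-style) step. With per-element coefficient coeff = c, element mean m = item.fitter.ms(idx), and (alpha, beta) the derivatives of the expected log-likelihood with respect to the function-space mean and variance, the targets are

    x_target = -2 * c^2 * beta
    n_target = c * (alpha - 2 * m * c * beta)

and each pseudo-observation parameter is moved by the learning rate:

    x <- (1 - lr) * x + lr * x_target
    n <- (1 - lr) * n + lr * n_target

The returned value |exp_ll_old - exp_ll_new| gives callers a convergence measure.
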
@@ -1,11 +1,14 @@
+use std::f64::consts::TAU;
+
 use crate::storage::Storage;
-use crate::utils::{logphi, normcdf, normpdf};
+use crate::utils::{logphi, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS};

 use super::{f_params, Core, Observation};

 fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
     // Adapted from the GPML function `likErf.m`.
     let z = mean_cav / (1.0 + cov_cav).sqrt();

     let (logpart, val) = logphi(z);
     let dlogpart = val / (1.0 + cov_cav).sqrt(); // 1st derivative w.r.t. mean.
     let d2logpart = -val * (z + val) / (1.0 + cov_cav);

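For context, not part of the diff, and assuming `logphi(z)` returns (log Phi(z), N(z) / Phi(z)) as in GPML: with a Gaussian cavity N(mean_cav, cov_cav) and a probit likelihood Phi(f), the log-partition function is

    log Z = log Phi(z),    z = mean_cav / sqrt(1 + cov_cav)

and, writing v = N(z) / Phi(z), its derivatives with respect to the cavity mean are

    d log Z / d mean   = v / sqrt(1 + cov_cav)
    d^2 log Z / d mean^2 = -v * (z + v) / (1 + cov_cav)

which is exactly the triple (logpart, dlogpart, d2logpart) computed above.
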
@@ -33,6 +36,45 @@ fn mm_probit_tie(mean_cav: f64, cov_cav: f64, margin: f64) -> (f64, f64, f64) {
     (logpart, dlogpart, d2logpart)
 }

+fn ll_probit_win(x: f64, margin: f64) -> f64 {
+    logphi(x - margin).0
+}
+
+fn ll_probit_tie(x: f64, margin: f64) -> f64 {
+    let x = -x.abs();
+    let z = logphi(x + margin).0;
+    let a = logphi(x - margin).0 - z;
+
+    if a > -0.693 {
+        z + (-a.exp_m1()).ln()
+    } else {
+        z + (-a.exp()).ln_1p()
+    }
+}
+
+fn cvi_expectations<F>(mean: f64, var: f64, ll_fct: F) -> (f64, f64, f64)
+where
+    F: Fn(f64) -> f64,
+{
+    const N: usize = 30;
+
+    let std = var.sqrt();
+    let mut exp_ll = 0.0;
+    let mut alpha = 0.0;
+    let mut beta = 0.0;
+
+    for i in 0..N {
+        let val =
+            (ROOTS_HERMITENORM_WS[i] / TAU.sqrt()) * ll_fct(std * ROOTS_HERMITENORM_XS[i] + mean);
+
+        exp_ll += val;
+        alpha += (ROOTS_HERMITENORM_XS[i] / std) * val;
+        beta += ((ROOTS_HERMITENORM_XS[i].powi(2) - 1.0) / (2.0 * var)) * val;
+    }
+
+    (exp_ll, alpha, beta)
+}
+
 pub struct ProbitWinObservation {
     core: Core,
     margin: f64,

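For reference, not part of the diff: `cvi_expectations` approximates a Gaussian expectation with the 30-point Gauss-Hermite rule for the probabilists' weight exp(-x^2 / 2). With f ~ N(mean, var) and std = sqrt(var),

    E[ll(f)] ~= sum_i (w_i / sqrt(2 * pi)) * ll(std * x_i + mean)

and the accumulated alpha and beta are the matching quadrature estimates of the derivatives of E[ll(f)] with respect to mean and variance, via the identities

    d E[ll] / d mean = E[ (f - mean) / var * ll(f) ]
    d E[ll] / d var  = E[ ((f - mean)^2 / var - 1) / (2 * var) * ll(f) ]

which, after substituting f = std * x_i + mean, give the factors x_i / std and (x_i^2 - 1) / (2 * var) used in the loop. `ll_probit_tie` evaluates log(Phi(x + margin) - Phi(x - margin)) with the usual log1mexp branch at -ln 2 ~= -0.693 for numerical stability.
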
@@ -40,7 +82,7 @@ pub struct ProbitWinObservation {

 impl ProbitWinObservation {
     pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
-        ProbitWinObservation {
+        Self {
             core: Core::new(storage, elems, t),
             margin,
         }

@@ -58,14 +100,17 @@ impl Observation for ProbitWinObservation {
     fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
         let margin = self.margin;

-        self.core
-            .ep_update(storage, lr, |mean_cav: f64, cov_cav: f64| {
-                mm_probit_win(mean_cav - margin, cov_cav)
-            })
+        self.core.ep_update(storage, lr, |mean_cav, cov_cav| {
+            mm_probit_win(mean_cav - margin, cov_cav)
+        })
     }

     fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
-        self.core.kl_update(lr, storage)
+        let margin = self.margin;
+
+        self.core.kl_update(storage, lr, |mean, var| {
+            cvi_expectations(mean, var, |x| ll_probit_win(x, margin))
+        })
     }
 }

@@ -94,26 +139,29 @@ impl Observation for ProbitTieObservation {
     fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
         let margin = self.margin;

-        self.core
-            .ep_update(storage, lr, |mean_cav: f64, cov_cav: f64| {
-                mm_probit_tie(mean_cav, cov_cav, margin)
-            })
+        self.core.ep_update(storage, lr, |mean_cav, cov_cav| {
+            mm_probit_tie(mean_cav, cov_cav, margin)
+        })
     }

-    fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
-        todo!();
+    fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
+        let margin = self.margin;
+
+        self.core.kl_update(storage, lr, |mean, var| {
+            cvi_expectations(mean, var, |x| ll_probit_tie(x, margin))
+        })
     }
 }

 pub struct LogitWinObservation {
-    _core: Core,
+    core: Core,
     _margin: f64,
 }

 impl LogitWinObservation {
     pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
-        LogitWinObservation {
-            _core: Core::new(storage, elems, t),
+        Self {
+            core: Core::new(storage, elems, t),
             _margin: margin,
         }
     }

@@ -124,18 +172,22 @@ impl Observation for LogitWinObservation {
         todo!();
     }

-    fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
-        todo!();
+    fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
+        self.core.kl_update(storage, lr, |mean, var| todo!())
     }
 }

 pub struct LogitTieObservation {
-    //
+    core: Core,
+    _margin: f64,
 }

 impl LogitTieObservation {
-    pub fn new(_storage: &mut Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> Self {
-        todo!();
+    pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
+        Self {
+            core: Core::new(storage, elems, t),
+            _margin: margin,
+        }
     }
 }

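Hypothetical sketch, not part of this commit: the `todo!()` closures above still need a logit log-likelihood. By analogy with `ll_probit_win`, a numerically stable log-sigmoid such as the following could fill the win case; the name and placement are assumptions, not the author's code.

    // Hypothetical helper, not in this commit: log(sigmoid(x - margin)),
    // computed without overflow for large |x|.
    fn ll_logit_win(x: f64, margin: f64) -> f64 {
        let d = x - margin;
        if d > 0.0 {
            -(-d).exp().ln_1p()
        } else {
            d - d.exp().ln_1p()
        }
    }
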
@@ -144,8 +196,8 @@ impl Observation for LogitTieObservation {
         todo!();
     }

-    fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
-        todo!();
+    fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
+        self.core.kl_update(storage, lr, |mean, var| todo!())
     }
 }

@@ -176,4 +228,13 @@ mod tests {
         assert_relative_eq!(b, -0.20881357058382308);
         assert_relative_eq!(c, -0.1698273481633205);
     }
+
+    #[test]
+    fn test_cvi_expectations() {
+        let (a, b, c) = cvi_expectations(0.3, 2.7, |x| logphi(x).0);
+
+        assert_relative_eq!(a, -1.198109740470403);
+        assert_relative_eq!(b, 0.8970390100143342);
+        assert_relative_eq!(c, -0.2565392538236815);
+    }
 }

src/utils.rs (+66)

@@ -87,3 +87,69 @@ pub fn logphi(z: f64) -> (f64, f64) {
         (res, dres)
     }
 }
+
+pub const ROOTS_HERMITENORM_XS: [f64; 30] = [
+    -9.706236,
+    -8.68083772,
+    -7.82505174,
+    -7.05539687,
+    -6.33999769,
+    -5.66238185,
+    -5.0126006,
+    -4.38402037,
+    -3.77189442,
+    -3.17263464,
+    -2.5834021,
+    -2.00185861,
+    -1.42600566,
+    -0.85407335,
+    -0.28443876,
+    0.28443876,
+    0.85407335,
+    1.42600566,
+    2.00185861,
+    2.5834021,
+    3.17263464,
+    3.77189442,
+    4.38402037,
+    5.0126006,
+    5.66238185,
+    6.33999769,
+    7.05539687,
+    7.82505174,
+    8.68083772,
+    9.706236,
+];
+
+pub const ROOTS_HERMITENORM_WS: [f64; 30] = [
+    4.11289324e-21,
+    3.97441190e-17,
+    4.07096517e-14,
+    1.14638786e-11,
+    1.29804729e-09,
+    7.22454173e-08,
+    2.23317741e-06,
+    4.15598507e-05,
+    4.92584902e-04,
+    3.87200709e-03,
+    2.07943554e-02,
+    7.79856428e-02,
+    2.07515826e-01,
+    3.96164962e-01,
+    5.46444893e-01,
+    5.46444893e-01,
+    3.96164962e-01,
+    2.07515826e-01,
+    7.79856428e-02,
+    2.07943554e-02,
+    3.87200709e-03,
+    4.92584902e-04,
+    4.15598507e-05,
+    2.23317741e-06,
+    7.22454173e-08,
+    1.29804729e-09,
+    1.14638786e-11,
+    4.07096517e-14,
+    3.97441190e-17,
+    4.11289324e-21,
+];

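These appear to be the nodes and weights of the 30-point Gauss-Hermite rule for the probabilists' weight exp(-x^2 / 2), stored to about nine significant digits (e.g. as produced by scipy.special.roots_hermitenorm(30); the provenance is an assumption). A sketch of a sanity test that could live at the bottom of src/utils.rs, not part of this commit, checking the zeroth and second moments of the rule:

    #[cfg(test)]
    mod hermite_tests {
        use super::{ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS};
        use std::f64::consts::TAU;

        #[test]
        fn quadrature_moments() {
            // The rule integrates against exp(-x^2 / 2), so the weights
            // should sum to sqrt(2 * pi).
            let total: f64 = ROOTS_HERMITENORM_WS.iter().sum();
            assert!((total - TAU.sqrt()).abs() < 1e-6);

            // Second moment of a standard normal: sum_i w_i * x_i^2 / sqrt(2 * pi) = 1.
            let second: f64 = ROOTS_HERMITENORM_XS
                .iter()
                .zip(&ROOTS_HERMITENORM_WS)
                .map(|(x, w)| w * x * x)
                .sum::<f64>()
                / TAU.sqrt();
            assert!((second - 1.0).abs() < 1e-6);
        }
    }
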
@@ -2,7 +2,7 @@ extern crate intel_mkl_src;

 use kickscore as ks;

-// #[test]
+#[test]
 fn binary_1() {
     let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);