Implemented more functions

This commit is contained in:
2021-05-28 14:35:21 +02:00
parent 32dafd4e54
commit 7deeabfd55
5 changed files with 193 additions and 27 deletions

View File

@@ -33,7 +33,7 @@ impl RecursiveFitter {
let m = kernel.order();
let h = kernel.measurement_vector();
RecursiveFitter {
Self {
ts_new: Vec::new(),
kernel,
ts: Vec::new(),
@@ -215,6 +215,7 @@ impl Fitter for RecursiveFitter {
self.is_fitted = true;
}
#[allow(clippy::many_single_char_names)]
fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>) {
if !self.is_fitted {
panic!("new data since last call to `fit()`");

View File

@@ -110,7 +110,45 @@ impl Core {
diff
}
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
todo!();
pub fn kl_update<F>(&mut self, storage: &mut Storage, lr: f64, cvi_expectations: F) -> f64
where
F: Fn(f64, f64) -> (f64, f64, f64),
{
// Mean and variance in function space.
let mut f_mean: f64 = 0.0;
let mut f_var: f64 = 0.0;
for i in 0..self.m {
let item = storage.item_mut(self.items[i]);
let idx = self.indices[i];
let coeff = self.coeffs[i];
// Adjust the function-space mean & variance.
f_mean += coeff * item.fitter.ms(idx);
f_var += coeff * coeff * item.fitter.vs(idx);
}
// Compute the derivatives of the exp. log-lik. w.r.t. mean parameters.
let (exp_ll, alpha, beta) = cvi_expectations(f_mean, f_var);
for i in 0..self.m {
let item = storage.item_mut(self.items[i]);
let idx = self.indices[i];
let coeff = self.coeffs[i];
// Update the elements' parameters.
let x = -2.0 * coeff * coeff * beta;
let n = coeff * (alpha - 2.0 * item.fitter.ms(idx) * coeff * beta);
*item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x;
*item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n;
}
let diff = (self.exp_ll - exp_ll).abs();
// Save the expected log-likelihood.
self.exp_ll = exp_ll;
diff
}
}

View File

@@ -1,11 +1,14 @@
use std::f64::consts::TAU;
use crate::storage::Storage;
use crate::utils::{logphi, normcdf, normpdf};
use crate::utils::{logphi, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS};
use super::{f_params, Core, Observation};
fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
// Adapted from the GPML function `likErf.m`.
let z = mean_cav / (1.0 + cov_cav).sqrt();
let (logpart, val) = logphi(z);
let dlogpart = val / (1.0 + cov_cav).sqrt(); // 1st derivative w.r.t. mean.
let d2logpart = -val * (z + val) / (1.0 + cov_cav);
@@ -33,6 +36,45 @@ fn mm_probit_tie(mean_cav: f64, cov_cav: f64, margin: f64) -> (f64, f64, f64) {
(logpart, dlogpart, d2logpart)
}
fn ll_probit_win(x: f64, margin: f64) -> f64 {
logphi(x - margin).0
}
fn ll_probit_tie(x: f64, margin: f64) -> f64 {
let x = -x.abs();
let z = logphi(x + margin).0;
let a = logphi(x - margin).0 - z;
if a > -0.693 {
z + (-a.exp_m1()).ln()
} else {
z + (-a.exp()).ln_1p()
}
}
fn cvi_expectations<F>(mean: f64, var: f64, ll_fct: F) -> (f64, f64, f64)
where
F: Fn(f64) -> f64,
{
const N: usize = 30;
let std = var.sqrt();
let mut exp_ll = 0.0;
let mut alpha = 0.0;
let mut beta = 0.0;
for i in 0..N {
let val =
(ROOTS_HERMITENORM_WS[i] / TAU.sqrt()) * ll_fct(std * ROOTS_HERMITENORM_XS[i] + mean);
exp_ll += val;
alpha += (ROOTS_HERMITENORM_XS[i] / std) * val;
beta += ((ROOTS_HERMITENORM_XS[i].powi(2) - 1.0) / (2.0 * var)) * val;
}
(exp_ll, alpha, beta)
}
pub struct ProbitWinObservation {
core: Core,
margin: f64,
@@ -40,7 +82,7 @@ pub struct ProbitWinObservation {
impl ProbitWinObservation {
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
ProbitWinObservation {
Self {
core: Core::new(storage, elems, t),
margin,
}
@@ -58,14 +100,17 @@ impl Observation for ProbitWinObservation {
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
let margin = self.margin;
self.core
.ep_update(storage, lr, |mean_cav: f64, cov_cav: f64| {
mm_probit_win(mean_cav - margin, cov_cav)
})
self.core.ep_update(storage, lr, |mean_cav, cov_cav| {
mm_probit_win(mean_cav - margin, cov_cav)
})
}
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
self.core.kl_update(lr, storage)
let margin = self.margin;
self.core.kl_update(storage, lr, |mean, var| {
cvi_expectations(mean, var, |x| ll_probit_win(x, margin))
})
}
}
@@ -94,26 +139,29 @@ impl Observation for ProbitTieObservation {
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
let margin = self.margin;
self.core
.ep_update(storage, lr, |mean_cav: f64, cov_cav: f64| {
mm_probit_tie(mean_cav, cov_cav, margin)
})
self.core.ep_update(storage, lr, |mean_cav, cov_cav| {
mm_probit_tie(mean_cav, cov_cav, margin)
})
}
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
todo!();
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
let margin = self.margin;
self.core.kl_update(storage, lr, |mean, var| {
cvi_expectations(mean, var, |x| ll_probit_tie(x, margin))
})
}
}
pub struct LogitWinObservation {
_core: Core,
core: Core,
_margin: f64,
}
impl LogitWinObservation {
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
LogitWinObservation {
_core: Core::new(storage, elems, t),
Self {
core: Core::new(storage, elems, t),
_margin: margin,
}
}
@@ -124,18 +172,22 @@ impl Observation for LogitWinObservation {
todo!();
}
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
todo!();
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
self.core.kl_update(storage, lr, |mean, var| todo!())
}
}
pub struct LogitTieObservation {
//
core: Core,
_margin: f64,
}
impl LogitTieObservation {
pub fn new(_storage: &mut Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> Self {
todo!();
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
Self {
core: Core::new(storage, elems, t),
_margin: margin,
}
}
}
@@ -144,8 +196,8 @@ impl Observation for LogitTieObservation {
todo!();
}
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
todo!();
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
self.core.kl_update(storage, lr, |mean, var| todo!())
}
}
@@ -176,4 +228,13 @@ mod tests {
assert_relative_eq!(b, -0.20881357058382308);
assert_relative_eq!(c, -0.1698273481633205);
}
#[test]
fn test_cvi_expectations() {
let (a, b, c) = cvi_expectations(0.3, 2.7, |x| logphi(x).0);
assert_relative_eq!(a, -1.198109740470403);
assert_relative_eq!(b, 0.8970390100143342);
assert_relative_eq!(c, -0.2565392538236815);
}
}

View File

@@ -87,3 +87,69 @@ pub fn logphi(z: f64) -> (f64, f64) {
(res, dres)
}
}
pub const ROOTS_HERMITENORM_XS: [f64; 30] = [
-9.706236,
-8.68083772,
-7.82505174,
-7.05539687,
-6.33999769,
-5.66238185,
-5.0126006,
-4.38402037,
-3.77189442,
-3.17263464,
-2.5834021,
-2.00185861,
-1.42600566,
-0.85407335,
-0.28443876,
0.28443876,
0.85407335,
1.42600566,
2.00185861,
2.5834021,
3.17263464,
3.77189442,
4.38402037,
5.0126006,
5.66238185,
6.33999769,
7.05539687,
7.82505174,
8.68083772,
9.706236,
];
pub const ROOTS_HERMITENORM_WS: [f64; 30] = [
4.11289324e-21,
3.97441190e-17,
4.07096517e-14,
1.14638786e-11,
1.29804729e-09,
7.22454173e-08,
2.23317741e-06,
4.15598507e-05,
4.92584902e-04,
3.87200709e-03,
2.07943554e-02,
7.79856428e-02,
2.07515826e-01,
3.96164962e-01,
5.46444893e-01,
5.46444893e-01,
3.96164962e-01,
2.07515826e-01,
7.79856428e-02,
2.07943554e-02,
3.87200709e-03,
4.92584902e-04,
4.15598507e-05,
2.23317741e-06,
7.22454173e-08,
1.29804729e-09,
1.14638786e-11,
4.07096517e-14,
3.97441190e-17,
4.11289324e-21,
];