Implemented more functions
This commit is contained in:
@@ -33,7 +33,7 @@ impl RecursiveFitter {
|
|||||||
let m = kernel.order();
|
let m = kernel.order();
|
||||||
let h = kernel.measurement_vector();
|
let h = kernel.measurement_vector();
|
||||||
|
|
||||||
RecursiveFitter {
|
Self {
|
||||||
ts_new: Vec::new(),
|
ts_new: Vec::new(),
|
||||||
kernel,
|
kernel,
|
||||||
ts: Vec::new(),
|
ts: Vec::new(),
|
||||||
@@ -215,6 +215,7 @@ impl Fitter for RecursiveFitter {
|
|||||||
self.is_fitted = true;
|
self.is_fitted = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::many_single_char_names)]
|
||||||
fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>) {
|
fn predict(&self, ts: &[f64]) -> (Vec<f64>, Vec<f64>) {
|
||||||
if !self.is_fitted {
|
if !self.is_fitted {
|
||||||
panic!("new data since last call to `fit()`");
|
panic!("new data since last call to `fit()`");
|
||||||
|
|||||||
@@ -110,7 +110,45 @@ impl Core {
|
|||||||
diff
|
diff
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
pub fn kl_update<F>(&mut self, storage: &mut Storage, lr: f64, cvi_expectations: F) -> f64
|
||||||
todo!();
|
where
|
||||||
|
F: Fn(f64, f64) -> (f64, f64, f64),
|
||||||
|
{
|
||||||
|
// Mean and variance in function space.
|
||||||
|
let mut f_mean: f64 = 0.0;
|
||||||
|
let mut f_var: f64 = 0.0;
|
||||||
|
|
||||||
|
for i in 0..self.m {
|
||||||
|
let item = storage.item_mut(self.items[i]);
|
||||||
|
let idx = self.indices[i];
|
||||||
|
let coeff = self.coeffs[i];
|
||||||
|
|
||||||
|
// Adjust the function-space mean & variance.
|
||||||
|
f_mean += coeff * item.fitter.ms(idx);
|
||||||
|
f_var += coeff * coeff * item.fitter.vs(idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the derivatives of the exp. log-lik. w.r.t. mean parameters.
|
||||||
|
let (exp_ll, alpha, beta) = cvi_expectations(f_mean, f_var);
|
||||||
|
|
||||||
|
for i in 0..self.m {
|
||||||
|
let item = storage.item_mut(self.items[i]);
|
||||||
|
let idx = self.indices[i];
|
||||||
|
let coeff = self.coeffs[i];
|
||||||
|
|
||||||
|
// Update the elements' parameters.
|
||||||
|
let x = -2.0 * coeff * coeff * beta;
|
||||||
|
let n = coeff * (alpha - 2.0 * item.fitter.ms(idx) * coeff * beta);
|
||||||
|
|
||||||
|
*item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x;
|
||||||
|
*item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n;
|
||||||
|
}
|
||||||
|
|
||||||
|
let diff = (self.exp_ll - exp_ll).abs();
|
||||||
|
|
||||||
|
// Save the expected log-likelihood.
|
||||||
|
self.exp_ll = exp_ll;
|
||||||
|
|
||||||
|
diff
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
|
use std::f64::consts::TAU;
|
||||||
|
|
||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
use crate::utils::{logphi, normcdf, normpdf};
|
use crate::utils::{logphi, normcdf, normpdf, ROOTS_HERMITENORM_WS, ROOTS_HERMITENORM_XS};
|
||||||
|
|
||||||
use super::{f_params, Core, Observation};
|
use super::{f_params, Core, Observation};
|
||||||
|
|
||||||
fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
||||||
// Adapted from the GPML function `likErf.m`.
|
// Adapted from the GPML function `likErf.m`.
|
||||||
let z = mean_cav / (1.0 + cov_cav).sqrt();
|
let z = mean_cav / (1.0 + cov_cav).sqrt();
|
||||||
|
|
||||||
let (logpart, val) = logphi(z);
|
let (logpart, val) = logphi(z);
|
||||||
let dlogpart = val / (1.0 + cov_cav).sqrt(); // 1st derivative w.r.t. mean.
|
let dlogpart = val / (1.0 + cov_cav).sqrt(); // 1st derivative w.r.t. mean.
|
||||||
let d2logpart = -val * (z + val) / (1.0 + cov_cav);
|
let d2logpart = -val * (z + val) / (1.0 + cov_cav);
|
||||||
@@ -33,6 +36,45 @@ fn mm_probit_tie(mean_cav: f64, cov_cav: f64, margin: f64) -> (f64, f64, f64) {
|
|||||||
(logpart, dlogpart, d2logpart)
|
(logpart, dlogpart, d2logpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn ll_probit_win(x: f64, margin: f64) -> f64 {
|
||||||
|
logphi(x - margin).0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ll_probit_tie(x: f64, margin: f64) -> f64 {
|
||||||
|
let x = -x.abs();
|
||||||
|
let z = logphi(x + margin).0;
|
||||||
|
let a = logphi(x - margin).0 - z;
|
||||||
|
|
||||||
|
if a > -0.693 {
|
||||||
|
z + (-a.exp_m1()).ln()
|
||||||
|
} else {
|
||||||
|
z + (-a.exp()).ln_1p()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cvi_expectations<F>(mean: f64, var: f64, ll_fct: F) -> (f64, f64, f64)
|
||||||
|
where
|
||||||
|
F: Fn(f64) -> f64,
|
||||||
|
{
|
||||||
|
const N: usize = 30;
|
||||||
|
|
||||||
|
let std = var.sqrt();
|
||||||
|
let mut exp_ll = 0.0;
|
||||||
|
let mut alpha = 0.0;
|
||||||
|
let mut beta = 0.0;
|
||||||
|
|
||||||
|
for i in 0..N {
|
||||||
|
let val =
|
||||||
|
(ROOTS_HERMITENORM_WS[i] / TAU.sqrt()) * ll_fct(std * ROOTS_HERMITENORM_XS[i] + mean);
|
||||||
|
|
||||||
|
exp_ll += val;
|
||||||
|
alpha += (ROOTS_HERMITENORM_XS[i] / std) * val;
|
||||||
|
beta += ((ROOTS_HERMITENORM_XS[i].powi(2) - 1.0) / (2.0 * var)) * val;
|
||||||
|
}
|
||||||
|
|
||||||
|
(exp_ll, alpha, beta)
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ProbitWinObservation {
|
pub struct ProbitWinObservation {
|
||||||
core: Core,
|
core: Core,
|
||||||
margin: f64,
|
margin: f64,
|
||||||
@@ -40,7 +82,7 @@ pub struct ProbitWinObservation {
|
|||||||
|
|
||||||
impl ProbitWinObservation {
|
impl ProbitWinObservation {
|
||||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||||
ProbitWinObservation {
|
Self {
|
||||||
core: Core::new(storage, elems, t),
|
core: Core::new(storage, elems, t),
|
||||||
margin,
|
margin,
|
||||||
}
|
}
|
||||||
@@ -58,14 +100,17 @@ impl Observation for ProbitWinObservation {
|
|||||||
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||||
let margin = self.margin;
|
let margin = self.margin;
|
||||||
|
|
||||||
self.core
|
self.core.ep_update(storage, lr, |mean_cav, cov_cav| {
|
||||||
.ep_update(storage, lr, |mean_cav: f64, cov_cav: f64| {
|
mm_probit_win(mean_cav - margin, cov_cav)
|
||||||
mm_probit_win(mean_cav - margin, cov_cav)
|
})
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||||
self.core.kl_update(lr, storage)
|
let margin = self.margin;
|
||||||
|
|
||||||
|
self.core.kl_update(storage, lr, |mean, var| {
|
||||||
|
cvi_expectations(mean, var, |x| ll_probit_win(x, margin))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,26 +139,29 @@ impl Observation for ProbitTieObservation {
|
|||||||
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||||
let margin = self.margin;
|
let margin = self.margin;
|
||||||
|
|
||||||
self.core
|
self.core.ep_update(storage, lr, |mean_cav, cov_cav| {
|
||||||
.ep_update(storage, lr, |mean_cav: f64, cov_cav: f64| {
|
mm_probit_tie(mean_cav, cov_cav, margin)
|
||||||
mm_probit_tie(mean_cav, cov_cav, margin)
|
})
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||||
todo!();
|
let margin = self.margin;
|
||||||
|
|
||||||
|
self.core.kl_update(storage, lr, |mean, var| {
|
||||||
|
cvi_expectations(mean, var, |x| ll_probit_tie(x, margin))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LogitWinObservation {
|
pub struct LogitWinObservation {
|
||||||
_core: Core,
|
core: Core,
|
||||||
_margin: f64,
|
_margin: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LogitWinObservation {
|
impl LogitWinObservation {
|
||||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||||
LogitWinObservation {
|
Self {
|
||||||
_core: Core::new(storage, elems, t),
|
core: Core::new(storage, elems, t),
|
||||||
_margin: margin,
|
_margin: margin,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -124,18 +172,22 @@ impl Observation for LogitWinObservation {
|
|||||||
todo!();
|
todo!();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||||
todo!();
|
self.core.kl_update(storage, lr, |mean, var| todo!())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LogitTieObservation {
|
pub struct LogitTieObservation {
|
||||||
//
|
core: Core,
|
||||||
|
_margin: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LogitTieObservation {
|
impl LogitTieObservation {
|
||||||
pub fn new(_storage: &mut Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> Self {
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||||
todo!();
|
Self {
|
||||||
|
core: Core::new(storage, elems, t),
|
||||||
|
_margin: margin,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,8 +196,8 @@ impl Observation for LogitTieObservation {
|
|||||||
todo!();
|
todo!();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||||
todo!();
|
self.core.kl_update(storage, lr, |mean, var| todo!())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,4 +228,13 @@ mod tests {
|
|||||||
assert_relative_eq!(b, -0.20881357058382308);
|
assert_relative_eq!(b, -0.20881357058382308);
|
||||||
assert_relative_eq!(c, -0.1698273481633205);
|
assert_relative_eq!(c, -0.1698273481633205);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cvi_expectations() {
|
||||||
|
let (a, b, c) = cvi_expectations(0.3, 2.7, |x| logphi(x).0);
|
||||||
|
|
||||||
|
assert_relative_eq!(a, -1.198109740470403);
|
||||||
|
assert_relative_eq!(b, 0.8970390100143342);
|
||||||
|
assert_relative_eq!(c, -0.2565392538236815);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
66
src/utils.rs
66
src/utils.rs
@@ -87,3 +87,69 @@ pub fn logphi(z: f64) -> (f64, f64) {
|
|||||||
(res, dres)
|
(res, dres)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const ROOTS_HERMITENORM_XS: [f64; 30] = [
|
||||||
|
-9.706236,
|
||||||
|
-8.68083772,
|
||||||
|
-7.82505174,
|
||||||
|
-7.05539687,
|
||||||
|
-6.33999769,
|
||||||
|
-5.66238185,
|
||||||
|
-5.0126006,
|
||||||
|
-4.38402037,
|
||||||
|
-3.77189442,
|
||||||
|
-3.17263464,
|
||||||
|
-2.5834021,
|
||||||
|
-2.00185861,
|
||||||
|
-1.42600566,
|
||||||
|
-0.85407335,
|
||||||
|
-0.28443876,
|
||||||
|
0.28443876,
|
||||||
|
0.85407335,
|
||||||
|
1.42600566,
|
||||||
|
2.00185861,
|
||||||
|
2.5834021,
|
||||||
|
3.17263464,
|
||||||
|
3.77189442,
|
||||||
|
4.38402037,
|
||||||
|
5.0126006,
|
||||||
|
5.66238185,
|
||||||
|
6.33999769,
|
||||||
|
7.05539687,
|
||||||
|
7.82505174,
|
||||||
|
8.68083772,
|
||||||
|
9.706236,
|
||||||
|
];
|
||||||
|
|
||||||
|
pub const ROOTS_HERMITENORM_WS: [f64; 30] = [
|
||||||
|
4.11289324e-21,
|
||||||
|
3.97441190e-17,
|
||||||
|
4.07096517e-14,
|
||||||
|
1.14638786e-11,
|
||||||
|
1.29804729e-09,
|
||||||
|
7.22454173e-08,
|
||||||
|
2.23317741e-06,
|
||||||
|
4.15598507e-05,
|
||||||
|
4.92584902e-04,
|
||||||
|
3.87200709e-03,
|
||||||
|
2.07943554e-02,
|
||||||
|
7.79856428e-02,
|
||||||
|
2.07515826e-01,
|
||||||
|
3.96164962e-01,
|
||||||
|
5.46444893e-01,
|
||||||
|
5.46444893e-01,
|
||||||
|
3.96164962e-01,
|
||||||
|
2.07515826e-01,
|
||||||
|
7.79856428e-02,
|
||||||
|
2.07943554e-02,
|
||||||
|
3.87200709e-03,
|
||||||
|
4.92584902e-04,
|
||||||
|
4.15598507e-05,
|
||||||
|
2.23317741e-06,
|
||||||
|
7.22454173e-08,
|
||||||
|
1.29804729e-09,
|
||||||
|
1.14638786e-11,
|
||||||
|
4.07096517e-14,
|
||||||
|
3.97441190e-17,
|
||||||
|
4.11289324e-21,
|
||||||
|
];
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ extern crate intel_mkl_src;
|
|||||||
|
|
||||||
use kickscore as ks;
|
use kickscore as ks;
|
||||||
|
|
||||||
// #[test]
|
#[test]
|
||||||
fn binary_1() {
|
fn binary_1() {
|
||||||
let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
|
let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user