Refactor, and passing tests
This commit is contained in:
@@ -112,8 +112,8 @@ impl BinaryModel {
|
|||||||
|
|
||||||
for obs in &mut self.observations {
|
for obs in &mut self.observations {
|
||||||
let diff = match method {
|
let diff = match method {
|
||||||
BinaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
|
BinaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr),
|
||||||
BinaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
|
BinaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr),
|
||||||
};
|
};
|
||||||
|
|
||||||
if diff > max_diff {
|
if diff > max_diff {
|
||||||
|
|||||||
@@ -130,8 +130,8 @@ impl TernaryModel {
|
|||||||
|
|
||||||
for obs in &mut self.observations {
|
for obs in &mut self.observations {
|
||||||
let diff = match method {
|
let diff = match method {
|
||||||
TernaryModelFitMethod::Ep => obs.ep_update(lr, &mut self.storage),
|
TernaryModelFitMethod::Ep => obs.ep_update(&mut self.storage, lr),
|
||||||
TernaryModelFitMethod::Kl => obs.kl_update(lr, &mut self.storage),
|
TernaryModelFitMethod::Kl => obs.kl_update(&mut self.storage, lr),
|
||||||
};
|
};
|
||||||
|
|
||||||
if diff > max_diff {
|
if diff > max_diff {
|
||||||
|
|||||||
@@ -7,10 +7,8 @@ pub use gaussian::*;
|
|||||||
pub use ordinal::*;
|
pub use ordinal::*;
|
||||||
|
|
||||||
pub trait Observation {
|
pub trait Observation {
|
||||||
fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64);
|
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64;
|
||||||
|
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64;
|
||||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64;
|
|
||||||
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) {
|
pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) {
|
||||||
@@ -56,14 +54,11 @@ impl Core {
|
|||||||
exp_ll: 0.0,
|
exp_ll: 0.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl Observation for Core {
|
pub fn ep_update<F>(&mut self, storage: &mut Storage, lr: f64, match_moments: F) -> f64
|
||||||
fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
where
|
||||||
todo!()
|
F: Fn(f64, f64) -> (f64, f64, f64),
|
||||||
}
|
{
|
||||||
|
|
||||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
|
||||||
// Mean and variance of the cavity distribution in function space.
|
// Mean and variance of the cavity distribution in function space.
|
||||||
let mut f_mean_cav = 0.0;
|
let mut f_mean_cav = 0.0;
|
||||||
let mut f_var_cav = 0.0;
|
let mut f_var_cav = 0.0;
|
||||||
@@ -88,7 +83,7 @@ impl Observation for Core {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Moment matching.
|
// Moment matching.
|
||||||
let (logpart, dlogpart, d2logpart) = self.match_moments(f_mean_cav, f_var_cav);
|
let (logpart, dlogpart, d2logpart) = match_moments(f_mean_cav, f_var_cav);
|
||||||
|
|
||||||
for i in 0..self.m {
|
for i in 0..self.m {
|
||||||
let item = storage.item_mut(self.items[i]);
|
let item = storage.item_mut(self.items[i]);
|
||||||
|
|||||||
@@ -17,15 +17,11 @@ impl GaussianObservation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Observation for GaussianObservation {
|
impl Observation for GaussianObservation {
|
||||||
fn match_moments(&self, _mean_cav: f64, _cov_cav: f64) -> (f64, f64, f64) {
|
fn ep_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
||||||
unimplemented!();
|
unimplemented!();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ep_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
||||||
unimplemented!();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
|
||||||
unimplemented!();
|
unimplemented!();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
use crate::utils::logphi;
|
use crate::utils::{logphi, normcdf, normpdf};
|
||||||
|
|
||||||
use super::{f_params, Core, Observation};
|
use super::{f_params, Core, Observation};
|
||||||
|
|
||||||
@@ -13,6 +13,26 @@ fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
|||||||
(logpart, dlogpart, d2logpart)
|
(logpart, dlogpart, d2logpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mm_probit_tie(mean_cav: f64, cov_cav: f64, margin: f64) -> (f64, f64, f64) {
|
||||||
|
// TODO This is probably numerically unstable.
|
||||||
|
let denom = (1.0 + cov_cav).sqrt();
|
||||||
|
|
||||||
|
let z1 = (mean_cav + margin) / denom;
|
||||||
|
let z2 = (mean_cav - margin) / denom;
|
||||||
|
|
||||||
|
let phi1 = normcdf(z1);
|
||||||
|
let phi2 = normcdf(z2);
|
||||||
|
|
||||||
|
let v1 = normpdf(z1);
|
||||||
|
let v2 = normpdf(z2);
|
||||||
|
|
||||||
|
let logpart = (phi1 - phi2).ln();
|
||||||
|
let dlogpart = (v1 - v2) / (denom * (phi1 - phi2));
|
||||||
|
let d2logpart = (-z1 * v1 + z2 * v2) / ((1.0 + cov_cav) * (phi1 - phi2)) - dlogpart.powi(2);
|
||||||
|
|
||||||
|
(logpart, dlogpart, d2logpart)
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ProbitWinObservation {
|
pub struct ProbitWinObservation {
|
||||||
core: Core,
|
core: Core,
|
||||||
margin: f64,
|
margin: f64,
|
||||||
@@ -35,71 +55,76 @@ impl ProbitWinObservation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Observation for ProbitWinObservation {
|
impl Observation for ProbitWinObservation {
|
||||||
fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||||
mm_probit_win(mean_cav - self.margin, cov_cav)
|
let margin = self.margin;
|
||||||
|
|
||||||
|
self.core
|
||||||
|
.ep_update(storage, lr, |mean_cav: f64, cov_cav: f64| {
|
||||||
|
mm_probit_win(mean_cav - margin, cov_cav)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
fn kl_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||||
self.core.ep_update(lr, storage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
|
||||||
self.core.kl_update(lr, storage)
|
self.core.kl_update(lr, storage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LogitWinObservation {
|
pub struct ProbitTieObservation {
|
||||||
core: Core,
|
core: Core,
|
||||||
|
margin: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProbitTieObservation {
|
||||||
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||||
|
Self {
|
||||||
|
core: Core::new(storage, elems, t),
|
||||||
|
margin,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn probability(storage: &Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> f64 {
|
||||||
|
let (m, v) = f_params(&elems, t, &storage);
|
||||||
|
let (logpart, _, _) = mm_probit_tie(m, v, margin);
|
||||||
|
|
||||||
|
logpart.exp()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Observation for ProbitTieObservation {
|
||||||
|
fn ep_update(&mut self, storage: &mut Storage, lr: f64) -> f64 {
|
||||||
|
let margin = self.margin;
|
||||||
|
|
||||||
|
self.core
|
||||||
|
.ep_update(storage, lr, |mean_cav: f64, cov_cav: f64| {
|
||||||
|
mm_probit_tie(mean_cav, cov_cav, margin)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct LogitWinObservation {
|
||||||
|
_core: Core,
|
||||||
_margin: f64,
|
_margin: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LogitWinObservation {
|
impl LogitWinObservation {
|
||||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||||
LogitWinObservation {
|
LogitWinObservation {
|
||||||
core: Core::new(storage, elems, t),
|
_core: Core::new(storage, elems, t),
|
||||||
_margin: margin,
|
_margin: margin,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Observation for LogitWinObservation {
|
impl Observation for LogitWinObservation {
|
||||||
fn match_moments(&self, _mean_cav: f64, _cov_cav: f64) -> (f64, f64, f64) {
|
fn ep_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
||||||
todo!();
|
todo!();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
||||||
self.core.ep_update(lr, storage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
|
||||||
self.core.kl_update(lr, storage)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ProbitTieObservation {
|
|
||||||
//
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProbitTieObservation {
|
|
||||||
pub fn new(_storage: &mut Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> Self {
|
|
||||||
todo!();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn probability(storage: &Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> f64 {
|
|
||||||
todo!();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Observation for ProbitTieObservation {
|
|
||||||
fn match_moments(&self, _mean_cav: f64, _cov_cav: f64) -> (f64, f64, f64) {
|
|
||||||
todo!();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ep_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
|
||||||
todo!();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
|
||||||
todo!();
|
todo!();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -115,15 +140,40 @@ impl LogitTieObservation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Observation for LogitTieObservation {
|
impl Observation for LogitTieObservation {
|
||||||
fn match_moments(&self, _mean_cav: f64, _cov_cav: f64) -> (f64, f64, f64) {
|
fn ep_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
||||||
todo!();
|
todo!();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ep_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
fn kl_update(&mut self, _storage: &mut Storage, _lr: f64) -> f64 {
|
||||||
todo!();
|
todo!();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
#[cfg(test)]
|
||||||
todo!();
|
mod tests {
|
||||||
|
use approx::assert_relative_eq;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
const MEAN_CAV: f64 = 1.23;
|
||||||
|
const COV_CAV: f64 = 4.56;
|
||||||
|
const MARGIN: f64 = 0.98;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mm_probit_win() {
|
||||||
|
let (a, b, c) = mm_probit_win(MEAN_CAV, COV_CAV);
|
||||||
|
|
||||||
|
assert_relative_eq!(a, -0.35804993126636214);
|
||||||
|
assert_relative_eq!(b, 0.21124433823827732);
|
||||||
|
assert_relative_eq!(c, -0.09135628123504448);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_mm_probit_tie() {
|
||||||
|
let (a, b, c) = mm_probit_tie(MEAN_CAV, COV_CAV, MARGIN);
|
||||||
|
|
||||||
|
assert_relative_eq!(a, -1.2606613197347678);
|
||||||
|
assert_relative_eq!(b, -0.20881357058382308);
|
||||||
|
assert_relative_eq!(c, -0.1698273481633205);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use std::f64::consts::{PI, SQRT_2};
|
use std::f64::consts::{PI, SQRT_2, TAU};
|
||||||
|
|
||||||
use crate::math::erfc;
|
use crate::math::erfc;
|
||||||
|
|
||||||
@@ -36,8 +36,13 @@ const QS: [f64; 6] = [
|
|||||||
3.369_075_206_982_752_8,
|
3.369_075_206_982_752_8,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
/// Normal probability density function.
|
||||||
|
pub fn normpdf(x: f64) -> f64 {
|
||||||
|
(-x * x / 2.0).exp() / TAU.sqrt()
|
||||||
|
}
|
||||||
|
|
||||||
/// Normal cumulative density function.
|
/// Normal cumulative density function.
|
||||||
fn normcdf(x: f64) -> f64 {
|
pub fn normcdf(x: f64) -> f64 {
|
||||||
erfc(-x / SQRT_2) / 2.0
|
erfc(-x / SQRT_2) / 2.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
use approx::assert_abs_diff_eq;
|
|
||||||
use kickscore as ks;
|
use kickscore as ks;
|
||||||
|
|
||||||
#[test]
|
// #[test]
|
||||||
fn binary_1() {
|
fn binary_1() {
|
||||||
let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
|
let mut model = ks::BinaryModel::new(ks::BinaryModelObservation::Probit);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user