Clean up and dryify

This commit is contained in:
2021-05-25 11:16:35 +02:00
parent c0846609e3
commit a2ecd9b268
2 changed files with 110 additions and 84 deletions

View File

@@ -26,3 +26,96 @@ pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64)
(m, v)
}
pub(crate) struct Core {
pub m: usize,
pub items: Vec<usize>,
pub coeffs: Vec<f64>,
pub indices: Vec<usize>,
pub ns_cav: Vec<f64>,
pub xs_cav: Vec<f64>,
pub t: f64,
pub logpart: f64,
pub exp_ll: f64,
}
impl Core {
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64) -> Self {
Core {
m: elems.len(),
items: elems.iter().map(|(id, _)| id).cloned().collect(),
coeffs: elems.iter().map(|(_, sign)| sign).cloned().collect(),
indices: elems
.iter()
.map(|(id, _)| storage.get_item(*id).fitter.add_sample(t))
.collect(),
ns_cav: (0..elems.len()).map(|_| 0.0).collect(),
xs_cav: (0..elems.len()).map(|_| 0.0).collect(),
t,
logpart: 0.0,
exp_ll: 0.0,
}
}
}
impl Observation for Core {
fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
todo!()
}
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
// Mean and variance of the cavity distribution in function space.
let mut f_mean_cav = 0.0;
let mut f_var_cav = 0.0;
for i in 0..self.m {
let item = storage.item(self.items[i]);
let idx = self.indices[i];
let coeff = self.coeffs[i];
// Compute the natural parameters of the cavity distribution.
let x_tot = 1.0 / item.fitter.vs(idx);
let n_tot = x_tot * item.fitter.ms(idx);
let x_cav = x_tot - item.fitter.xs(idx);
let n_cav = n_tot - item.fitter.ns(idx);
self.xs_cav[i] = x_cav;
self.ns_cav[i] = n_cav;
// Adjust the function-space cavity mean & variance.
f_mean_cav += coeff * n_cav / x_cav;
f_var_cav += coeff * coeff / x_cav;
}
// Moment matching.
let (logpart, dlogpart, d2logpart) = self.match_moments(f_mean_cav, f_var_cav);
for i in 0..self.m {
let item = storage.item_mut(self.items[i]);
let idx = self.indices[i];
let coeff = self.coeffs[i];
let x_cav = self.xs_cav[i];
let n_cav = self.ns_cav[i];
// Update the elements' parameters.
let denom = 1.0 + coeff * coeff * d2logpart / x_cav;
let x = -coeff * coeff * d2logpart / denom;
let n = coeff * (dlogpart - coeff * (n_cav / x_cav) * d2logpart) / denom;
*item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x;
*item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n;
}
let diff = (self.logpart - logpart).abs();
// Save log partition function value for the log-likelihood.
self.logpart = logpart;
diff
}
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
todo!();
}
}

View File

@@ -1,7 +1,7 @@
use crate::storage::Storage;
use crate::utils::logphi;
use super::{f_params, Observation};
use super::{f_params, Core, Observation};
pub fn probit_win_observation(
elems: &[(usize, f64)],
@@ -26,37 +26,14 @@ fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
}
pub struct ProbitWinObservation {
m: usize,
items: Vec<usize>,
coeffs: Vec<f64>,
indices: Vec<usize>,
ns_cav: Vec<f64>,
xs_cav: Vec<f64>,
_t: f64,
logpart: f64,
_exp_ll: usize,
core: Core,
margin: f64,
}
impl ProbitWinObservation {
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
/*
assert len(elems) > 0, "need at least one item per observation"
*/
ProbitWinObservation {
m: elems.len(),
items: elems.iter().map(|(id, _)| id).cloned().collect(),
coeffs: elems.iter().map(|(_, sign)| sign).cloned().collect(),
indices: elems
.iter()
.map(|(id, _)| storage.get_item(*id).fitter.add_sample(t))
.collect(),
ns_cav: (0..elems.len()).map(|_| 0.0).collect(),
xs_cav: (0..elems.len()).map(|_| 0.0).collect(),
_t: t,
logpart: 0.0,
_exp_ll: 0,
core: Core::new(storage, elems, t),
margin,
}
}
@@ -68,69 +45,25 @@ impl Observation for ProbitWinObservation {
}
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
// Mean and variance of the cavity distribution in function space.
let mut f_mean_cav = 0.0;
let mut f_var_cav = 0.0;
for i in 0..self.m {
let item = storage.item(self.items[i]);
let idx = self.indices[i];
let coeff = self.coeffs[i];
// Compute the natural parameters of the cavity distribution.
let x_tot = 1.0 / item.fitter.vs(idx);
let n_tot = x_tot * item.fitter.ms(idx);
let x_cav = x_tot - item.fitter.xs(idx);
let n_cav = n_tot - item.fitter.ns(idx);
self.xs_cav[i] = x_cav;
self.ns_cav[i] = n_cav;
// Adjust the function-space cavity mean & variance.
f_mean_cav += coeff * n_cav / x_cav;
f_var_cav += coeff * coeff / x_cav;
}
// Moment matching.
let (logpart, dlogpart, d2logpart) = self.match_moments(f_mean_cav, f_var_cav);
for i in 0..self.m {
let item = storage.item_mut(self.items[i]);
let idx = self.indices[i];
let coeff = self.coeffs[i];
let x_cav = self.xs_cav[i];
let n_cav = self.ns_cav[i];
// Update the elements' parameters.
let denom = 1.0 + coeff * coeff * d2logpart / x_cav;
let x = -coeff * coeff * d2logpart / denom;
let n = coeff * (dlogpart - coeff * (n_cav / x_cav) * d2logpart) / denom;
*item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x;
*item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n;
}
let diff = (self.logpart - logpart).abs();
// Save log partition function value for the log-likelihood.
self.logpart = logpart;
diff
self.core.ep_update(lr, storage)
}
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
todo!();
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
self.core.kl_update(lr, storage)
}
}
pub struct LogitWinObservation {
//
core: Core,
_margin: f64,
}
impl LogitWinObservation {
pub fn new(_storage: &mut Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> Self {
todo!();
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
LogitWinObservation {
core: Core::new(storage, elems, t),
_margin: margin,
}
}
}
@@ -139,12 +72,12 @@ impl Observation for LogitWinObservation {
todo!();
}
fn ep_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
todo!();
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
self.core.ep_update(lr, storage)
}
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
todo!();
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
self.core.kl_update(lr, storage)
}
}