Clean up and dryify
This commit is contained in:
@@ -26,3 +26,96 @@ pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64)
|
||||
|
||||
(m, v)
|
||||
}
|
||||
|
||||
pub(crate) struct Core {
|
||||
pub m: usize,
|
||||
pub items: Vec<usize>,
|
||||
pub coeffs: Vec<f64>,
|
||||
pub indices: Vec<usize>,
|
||||
pub ns_cav: Vec<f64>,
|
||||
pub xs_cav: Vec<f64>,
|
||||
pub t: f64,
|
||||
pub logpart: f64,
|
||||
pub exp_ll: f64,
|
||||
}
|
||||
|
||||
impl Core {
|
||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64) -> Self {
|
||||
Core {
|
||||
m: elems.len(),
|
||||
items: elems.iter().map(|(id, _)| id).cloned().collect(),
|
||||
coeffs: elems.iter().map(|(_, sign)| sign).cloned().collect(),
|
||||
indices: elems
|
||||
.iter()
|
||||
.map(|(id, _)| storage.get_item(*id).fitter.add_sample(t))
|
||||
.collect(),
|
||||
ns_cav: (0..elems.len()).map(|_| 0.0).collect(),
|
||||
xs_cav: (0..elems.len()).map(|_| 0.0).collect(),
|
||||
t,
|
||||
logpart: 0.0,
|
||||
exp_ll: 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Observation for Core {
|
||||
fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||
// Mean and variance of the cavity distribution in function space.
|
||||
let mut f_mean_cav = 0.0;
|
||||
let mut f_var_cav = 0.0;
|
||||
|
||||
for i in 0..self.m {
|
||||
let item = storage.item(self.items[i]);
|
||||
let idx = self.indices[i];
|
||||
let coeff = self.coeffs[i];
|
||||
|
||||
// Compute the natural parameters of the cavity distribution.
|
||||
let x_tot = 1.0 / item.fitter.vs(idx);
|
||||
let n_tot = x_tot * item.fitter.ms(idx);
|
||||
let x_cav = x_tot - item.fitter.xs(idx);
|
||||
let n_cav = n_tot - item.fitter.ns(idx);
|
||||
|
||||
self.xs_cav[i] = x_cav;
|
||||
self.ns_cav[i] = n_cav;
|
||||
|
||||
// Adjust the function-space cavity mean & variance.
|
||||
f_mean_cav += coeff * n_cav / x_cav;
|
||||
f_var_cav += coeff * coeff / x_cav;
|
||||
}
|
||||
|
||||
// Moment matching.
|
||||
let (logpart, dlogpart, d2logpart) = self.match_moments(f_mean_cav, f_var_cav);
|
||||
|
||||
for i in 0..self.m {
|
||||
let item = storage.item_mut(self.items[i]);
|
||||
let idx = self.indices[i];
|
||||
let coeff = self.coeffs[i];
|
||||
|
||||
let x_cav = self.xs_cav[i];
|
||||
let n_cav = self.ns_cav[i];
|
||||
|
||||
// Update the elements' parameters.
|
||||
let denom = 1.0 + coeff * coeff * d2logpart / x_cav;
|
||||
let x = -coeff * coeff * d2logpart / denom;
|
||||
let n = coeff * (dlogpart - coeff * (n_cav / x_cav) * d2logpart) / denom;
|
||||
|
||||
*item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x;
|
||||
*item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n;
|
||||
}
|
||||
|
||||
let diff = (self.logpart - logpart).abs();
|
||||
|
||||
// Save log partition function value for the log-likelihood.
|
||||
self.logpart = logpart;
|
||||
|
||||
diff
|
||||
}
|
||||
|
||||
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::storage::Storage;
|
||||
use crate::utils::logphi;
|
||||
|
||||
use super::{f_params, Observation};
|
||||
use super::{f_params, Core, Observation};
|
||||
|
||||
pub fn probit_win_observation(
|
||||
elems: &[(usize, f64)],
|
||||
@@ -26,37 +26,14 @@ fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
||||
}
|
||||
|
||||
pub struct ProbitWinObservation {
|
||||
m: usize,
|
||||
items: Vec<usize>,
|
||||
coeffs: Vec<f64>,
|
||||
indices: Vec<usize>,
|
||||
ns_cav: Vec<f64>,
|
||||
xs_cav: Vec<f64>,
|
||||
_t: f64,
|
||||
logpart: f64,
|
||||
_exp_ll: usize,
|
||||
core: Core,
|
||||
margin: f64,
|
||||
}
|
||||
|
||||
impl ProbitWinObservation {
|
||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||
/*
|
||||
assert len(elems) > 0, "need at least one item per observation"
|
||||
*/
|
||||
|
||||
ProbitWinObservation {
|
||||
m: elems.len(),
|
||||
items: elems.iter().map(|(id, _)| id).cloned().collect(),
|
||||
coeffs: elems.iter().map(|(_, sign)| sign).cloned().collect(),
|
||||
indices: elems
|
||||
.iter()
|
||||
.map(|(id, _)| storage.get_item(*id).fitter.add_sample(t))
|
||||
.collect(),
|
||||
ns_cav: (0..elems.len()).map(|_| 0.0).collect(),
|
||||
xs_cav: (0..elems.len()).map(|_| 0.0).collect(),
|
||||
_t: t,
|
||||
logpart: 0.0,
|
||||
_exp_ll: 0,
|
||||
core: Core::new(storage, elems, t),
|
||||
margin,
|
||||
}
|
||||
}
|
||||
@@ -68,69 +45,25 @@ impl Observation for ProbitWinObservation {
|
||||
}
|
||||
|
||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||
// Mean and variance of the cavity distribution in function space.
|
||||
let mut f_mean_cav = 0.0;
|
||||
let mut f_var_cav = 0.0;
|
||||
|
||||
for i in 0..self.m {
|
||||
let item = storage.item(self.items[i]);
|
||||
let idx = self.indices[i];
|
||||
let coeff = self.coeffs[i];
|
||||
|
||||
// Compute the natural parameters of the cavity distribution.
|
||||
let x_tot = 1.0 / item.fitter.vs(idx);
|
||||
let n_tot = x_tot * item.fitter.ms(idx);
|
||||
let x_cav = x_tot - item.fitter.xs(idx);
|
||||
let n_cav = n_tot - item.fitter.ns(idx);
|
||||
|
||||
self.xs_cav[i] = x_cav;
|
||||
self.ns_cav[i] = n_cav;
|
||||
|
||||
// Adjust the function-space cavity mean & variance.
|
||||
f_mean_cav += coeff * n_cav / x_cav;
|
||||
f_var_cav += coeff * coeff / x_cav;
|
||||
}
|
||||
|
||||
// Moment matching.
|
||||
let (logpart, dlogpart, d2logpart) = self.match_moments(f_mean_cav, f_var_cav);
|
||||
|
||||
for i in 0..self.m {
|
||||
let item = storage.item_mut(self.items[i]);
|
||||
let idx = self.indices[i];
|
||||
let coeff = self.coeffs[i];
|
||||
|
||||
let x_cav = self.xs_cav[i];
|
||||
let n_cav = self.ns_cav[i];
|
||||
|
||||
// Update the elements' parameters.
|
||||
let denom = 1.0 + coeff * coeff * d2logpart / x_cav;
|
||||
let x = -coeff * coeff * d2logpart / denom;
|
||||
let n = coeff * (dlogpart - coeff * (n_cav / x_cav) * d2logpart) / denom;
|
||||
|
||||
*item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x;
|
||||
*item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n;
|
||||
}
|
||||
|
||||
let diff = (self.logpart - logpart).abs();
|
||||
|
||||
// Save log partition function value for the log-likelihood.
|
||||
self.logpart = logpart;
|
||||
|
||||
diff
|
||||
self.core.ep_update(lr, storage)
|
||||
}
|
||||
|
||||
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
||||
todo!();
|
||||
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||
self.core.kl_update(lr, storage)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LogitWinObservation {
|
||||
//
|
||||
core: Core,
|
||||
_margin: f64,
|
||||
}
|
||||
|
||||
impl LogitWinObservation {
|
||||
pub fn new(_storage: &mut Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> Self {
|
||||
todo!();
|
||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||
LogitWinObservation {
|
||||
core: Core::new(storage, elems, t),
|
||||
_margin: margin,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,12 +72,12 @@ impl Observation for LogitWinObservation {
|
||||
todo!();
|
||||
}
|
||||
|
||||
fn ep_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
||||
todo!();
|
||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||
self.core.ep_update(lr, storage)
|
||||
}
|
||||
|
||||
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
||||
todo!();
|
||||
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||
self.core.kl_update(lr, storage)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user