Clean up and dryify
This commit is contained in:
@@ -26,3 +26,96 @@ pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64)
|
|||||||
|
|
||||||
(m, v)
|
(m, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) struct Core {
|
||||||
|
pub m: usize,
|
||||||
|
pub items: Vec<usize>,
|
||||||
|
pub coeffs: Vec<f64>,
|
||||||
|
pub indices: Vec<usize>,
|
||||||
|
pub ns_cav: Vec<f64>,
|
||||||
|
pub xs_cav: Vec<f64>,
|
||||||
|
pub t: f64,
|
||||||
|
pub logpart: f64,
|
||||||
|
pub exp_ll: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Core {
|
||||||
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64) -> Self {
|
||||||
|
Core {
|
||||||
|
m: elems.len(),
|
||||||
|
items: elems.iter().map(|(id, _)| id).cloned().collect(),
|
||||||
|
coeffs: elems.iter().map(|(_, sign)| sign).cloned().collect(),
|
||||||
|
indices: elems
|
||||||
|
.iter()
|
||||||
|
.map(|(id, _)| storage.get_item(*id).fitter.add_sample(t))
|
||||||
|
.collect(),
|
||||||
|
ns_cav: (0..elems.len()).map(|_| 0.0).collect(),
|
||||||
|
xs_cav: (0..elems.len()).map(|_| 0.0).collect(),
|
||||||
|
t,
|
||||||
|
logpart: 0.0,
|
||||||
|
exp_ll: 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Observation for Core {
|
||||||
|
fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||||
|
// Mean and variance of the cavity distribution in function space.
|
||||||
|
let mut f_mean_cav = 0.0;
|
||||||
|
let mut f_var_cav = 0.0;
|
||||||
|
|
||||||
|
for i in 0..self.m {
|
||||||
|
let item = storage.item(self.items[i]);
|
||||||
|
let idx = self.indices[i];
|
||||||
|
let coeff = self.coeffs[i];
|
||||||
|
|
||||||
|
// Compute the natural parameters of the cavity distribution.
|
||||||
|
let x_tot = 1.0 / item.fitter.vs(idx);
|
||||||
|
let n_tot = x_tot * item.fitter.ms(idx);
|
||||||
|
let x_cav = x_tot - item.fitter.xs(idx);
|
||||||
|
let n_cav = n_tot - item.fitter.ns(idx);
|
||||||
|
|
||||||
|
self.xs_cav[i] = x_cav;
|
||||||
|
self.ns_cav[i] = n_cav;
|
||||||
|
|
||||||
|
// Adjust the function-space cavity mean & variance.
|
||||||
|
f_mean_cav += coeff * n_cav / x_cav;
|
||||||
|
f_var_cav += coeff * coeff / x_cav;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Moment matching.
|
||||||
|
let (logpart, dlogpart, d2logpart) = self.match_moments(f_mean_cav, f_var_cav);
|
||||||
|
|
||||||
|
for i in 0..self.m {
|
||||||
|
let item = storage.item_mut(self.items[i]);
|
||||||
|
let idx = self.indices[i];
|
||||||
|
let coeff = self.coeffs[i];
|
||||||
|
|
||||||
|
let x_cav = self.xs_cav[i];
|
||||||
|
let n_cav = self.ns_cav[i];
|
||||||
|
|
||||||
|
// Update the elements' parameters.
|
||||||
|
let denom = 1.0 + coeff * coeff * d2logpart / x_cav;
|
||||||
|
let x = -coeff * coeff * d2logpart / denom;
|
||||||
|
let n = coeff * (dlogpart - coeff * (n_cav / x_cav) * d2logpart) / denom;
|
||||||
|
|
||||||
|
*item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x;
|
||||||
|
*item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n;
|
||||||
|
}
|
||||||
|
|
||||||
|
let diff = (self.logpart - logpart).abs();
|
||||||
|
|
||||||
|
// Save log partition function value for the log-likelihood.
|
||||||
|
self.logpart = logpart;
|
||||||
|
|
||||||
|
diff
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
||||||
|
todo!();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use crate::storage::Storage;
|
use crate::storage::Storage;
|
||||||
use crate::utils::logphi;
|
use crate::utils::logphi;
|
||||||
|
|
||||||
use super::{f_params, Observation};
|
use super::{f_params, Core, Observation};
|
||||||
|
|
||||||
pub fn probit_win_observation(
|
pub fn probit_win_observation(
|
||||||
elems: &[(usize, f64)],
|
elems: &[(usize, f64)],
|
||||||
@@ -26,37 +26,14 @@ fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct ProbitWinObservation {
|
pub struct ProbitWinObservation {
|
||||||
m: usize,
|
core: Core,
|
||||||
items: Vec<usize>,
|
|
||||||
coeffs: Vec<f64>,
|
|
||||||
indices: Vec<usize>,
|
|
||||||
ns_cav: Vec<f64>,
|
|
||||||
xs_cav: Vec<f64>,
|
|
||||||
_t: f64,
|
|
||||||
logpart: f64,
|
|
||||||
_exp_ll: usize,
|
|
||||||
margin: f64,
|
margin: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ProbitWinObservation {
|
impl ProbitWinObservation {
|
||||||
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||||
/*
|
|
||||||
assert len(elems) > 0, "need at least one item per observation"
|
|
||||||
*/
|
|
||||||
|
|
||||||
ProbitWinObservation {
|
ProbitWinObservation {
|
||||||
m: elems.len(),
|
core: Core::new(storage, elems, t),
|
||||||
items: elems.iter().map(|(id, _)| id).cloned().collect(),
|
|
||||||
coeffs: elems.iter().map(|(_, sign)| sign).cloned().collect(),
|
|
||||||
indices: elems
|
|
||||||
.iter()
|
|
||||||
.map(|(id, _)| storage.get_item(*id).fitter.add_sample(t))
|
|
||||||
.collect(),
|
|
||||||
ns_cav: (0..elems.len()).map(|_| 0.0).collect(),
|
|
||||||
xs_cav: (0..elems.len()).map(|_| 0.0).collect(),
|
|
||||||
_t: t,
|
|
||||||
logpart: 0.0,
|
|
||||||
_exp_ll: 0,
|
|
||||||
margin,
|
margin,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -68,69 +45,25 @@ impl Observation for ProbitWinObservation {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||||
// Mean and variance of the cavity distribution in function space.
|
self.core.ep_update(lr, storage)
|
||||||
let mut f_mean_cav = 0.0;
|
|
||||||
let mut f_var_cav = 0.0;
|
|
||||||
|
|
||||||
for i in 0..self.m {
|
|
||||||
let item = storage.item(self.items[i]);
|
|
||||||
let idx = self.indices[i];
|
|
||||||
let coeff = self.coeffs[i];
|
|
||||||
|
|
||||||
// Compute the natural parameters of the cavity distribution.
|
|
||||||
let x_tot = 1.0 / item.fitter.vs(idx);
|
|
||||||
let n_tot = x_tot * item.fitter.ms(idx);
|
|
||||||
let x_cav = x_tot - item.fitter.xs(idx);
|
|
||||||
let n_cav = n_tot - item.fitter.ns(idx);
|
|
||||||
|
|
||||||
self.xs_cav[i] = x_cav;
|
|
||||||
self.ns_cav[i] = n_cav;
|
|
||||||
|
|
||||||
// Adjust the function-space cavity mean & variance.
|
|
||||||
f_mean_cav += coeff * n_cav / x_cav;
|
|
||||||
f_var_cav += coeff * coeff / x_cav;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Moment matching.
|
|
||||||
let (logpart, dlogpart, d2logpart) = self.match_moments(f_mean_cav, f_var_cav);
|
|
||||||
|
|
||||||
for i in 0..self.m {
|
|
||||||
let item = storage.item_mut(self.items[i]);
|
|
||||||
let idx = self.indices[i];
|
|
||||||
let coeff = self.coeffs[i];
|
|
||||||
|
|
||||||
let x_cav = self.xs_cav[i];
|
|
||||||
let n_cav = self.ns_cav[i];
|
|
||||||
|
|
||||||
// Update the elements' parameters.
|
|
||||||
let denom = 1.0 + coeff * coeff * d2logpart / x_cav;
|
|
||||||
let x = -coeff * coeff * d2logpart / denom;
|
|
||||||
let n = coeff * (dlogpart - coeff * (n_cav / x_cav) * d2logpart) / denom;
|
|
||||||
|
|
||||||
*item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x;
|
|
||||||
*item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n;
|
|
||||||
}
|
|
||||||
|
|
||||||
let diff = (self.logpart - logpart).abs();
|
|
||||||
|
|
||||||
// Save log partition function value for the log-likelihood.
|
|
||||||
self.logpart = logpart;
|
|
||||||
|
|
||||||
diff
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||||
todo!();
|
self.core.kl_update(lr, storage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LogitWinObservation {
|
pub struct LogitWinObservation {
|
||||||
//
|
core: Core,
|
||||||
|
_margin: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LogitWinObservation {
|
impl LogitWinObservation {
|
||||||
pub fn new(_storage: &mut Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> Self {
|
pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self {
|
||||||
todo!();
|
LogitWinObservation {
|
||||||
|
core: Core::new(storage, elems, t),
|
||||||
|
_margin: margin,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -139,12 +72,12 @@ impl Observation for LogitWinObservation {
|
|||||||
todo!();
|
todo!();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ep_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||||
todo!();
|
self.core.ep_update(lr, storage)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 {
|
fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 {
|
||||||
todo!();
|
self.core.kl_update(lr, storage)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user