diff --git a/src/observation.rs b/src/observation.rs index 5582fb4..2dc7e54 100644 --- a/src/observation.rs +++ b/src/observation.rs @@ -26,3 +26,96 @@ pub fn f_params(elems: &[(usize, f64)], t: f64, storage: &Storage) -> (f64, f64) (m, v) } + +pub(crate) struct Core { + pub m: usize, + pub items: Vec, + pub coeffs: Vec, + pub indices: Vec, + pub ns_cav: Vec, + pub xs_cav: Vec, + pub t: f64, + pub logpart: f64, + pub exp_ll: f64, +} + +impl Core { + pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64) -> Self { + Core { + m: elems.len(), + items: elems.iter().map(|(id, _)| id).cloned().collect(), + coeffs: elems.iter().map(|(_, sign)| sign).cloned().collect(), + indices: elems + .iter() + .map(|(id, _)| storage.get_item(*id).fitter.add_sample(t)) + .collect(), + ns_cav: (0..elems.len()).map(|_| 0.0).collect(), + xs_cav: (0..elems.len()).map(|_| 0.0).collect(), + t, + logpart: 0.0, + exp_ll: 0.0, + } + } +} + +impl Observation for Core { + fn match_moments(&self, mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) { + todo!() + } + + fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 { + // Mean and variance of the cavity distribution in function space. + let mut f_mean_cav = 0.0; + let mut f_var_cav = 0.0; + + for i in 0..self.m { + let item = storage.item(self.items[i]); + let idx = self.indices[i]; + let coeff = self.coeffs[i]; + + // Compute the natural parameters of the cavity distribution. + let x_tot = 1.0 / item.fitter.vs(idx); + let n_tot = x_tot * item.fitter.ms(idx); + let x_cav = x_tot - item.fitter.xs(idx); + let n_cav = n_tot - item.fitter.ns(idx); + + self.xs_cav[i] = x_cav; + self.ns_cav[i] = n_cav; + + // Adjust the function-space cavity mean & variance. + f_mean_cav += coeff * n_cav / x_cav; + f_var_cav += coeff * coeff / x_cav; + } + + // Moment matching. + let (logpart, dlogpart, d2logpart) = self.match_moments(f_mean_cav, f_var_cav); + + for i in 0..self.m { + let item = storage.item_mut(self.items[i]); + let idx = self.indices[i]; + let coeff = self.coeffs[i]; + + let x_cav = self.xs_cav[i]; + let n_cav = self.ns_cav[i]; + + // Update the elements' parameters. + let denom = 1.0 + coeff * coeff * d2logpart / x_cav; + let x = -coeff * coeff * d2logpart / denom; + let n = coeff * (dlogpart - coeff * (n_cav / x_cav) * d2logpart) / denom; + + *item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x; + *item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n; + } + + let diff = (self.logpart - logpart).abs(); + + // Save log partition function value for the log-likelihood. + self.logpart = logpart; + + diff + } + + fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 { + todo!(); + } +} diff --git a/src/observation/ordinal.rs b/src/observation/ordinal.rs index c57e5e0..370829d 100644 --- a/src/observation/ordinal.rs +++ b/src/observation/ordinal.rs @@ -1,7 +1,7 @@ use crate::storage::Storage; use crate::utils::logphi; -use super::{f_params, Observation}; +use super::{f_params, Core, Observation}; pub fn probit_win_observation( elems: &[(usize, f64)], @@ -26,37 +26,14 @@ fn mm_probit_win(mean_cav: f64, cov_cav: f64) -> (f64, f64, f64) { } pub struct ProbitWinObservation { - m: usize, - items: Vec, - coeffs: Vec, - indices: Vec, - ns_cav: Vec, - xs_cav: Vec, - _t: f64, - logpart: f64, - _exp_ll: usize, + core: Core, margin: f64, } impl ProbitWinObservation { pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self { - /* - assert len(elems) > 0, "need at least one item per observation" - */ - ProbitWinObservation { - m: elems.len(), - items: elems.iter().map(|(id, _)| id).cloned().collect(), - coeffs: elems.iter().map(|(_, sign)| sign).cloned().collect(), - indices: elems - .iter() - .map(|(id, _)| storage.get_item(*id).fitter.add_sample(t)) - .collect(), - ns_cav: (0..elems.len()).map(|_| 0.0).collect(), - xs_cav: (0..elems.len()).map(|_| 0.0).collect(), - _t: t, - logpart: 0.0, - _exp_ll: 0, + core: Core::new(storage, elems, t), margin, } } @@ -68,69 +45,25 @@ impl Observation for ProbitWinObservation { } fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 { - // Mean and variance of the cavity distribution in function space. - let mut f_mean_cav = 0.0; - let mut f_var_cav = 0.0; - - for i in 0..self.m { - let item = storage.item(self.items[i]); - let idx = self.indices[i]; - let coeff = self.coeffs[i]; - - // Compute the natural parameters of the cavity distribution. - let x_tot = 1.0 / item.fitter.vs(idx); - let n_tot = x_tot * item.fitter.ms(idx); - let x_cav = x_tot - item.fitter.xs(idx); - let n_cav = n_tot - item.fitter.ns(idx); - - self.xs_cav[i] = x_cav; - self.ns_cav[i] = n_cav; - - // Adjust the function-space cavity mean & variance. - f_mean_cav += coeff * n_cav / x_cav; - f_var_cav += coeff * coeff / x_cav; - } - - // Moment matching. - let (logpart, dlogpart, d2logpart) = self.match_moments(f_mean_cav, f_var_cav); - - for i in 0..self.m { - let item = storage.item_mut(self.items[i]); - let idx = self.indices[i]; - let coeff = self.coeffs[i]; - - let x_cav = self.xs_cav[i]; - let n_cav = self.ns_cav[i]; - - // Update the elements' parameters. - let denom = 1.0 + coeff * coeff * d2logpart / x_cav; - let x = -coeff * coeff * d2logpart / denom; - let n = coeff * (dlogpart - coeff * (n_cav / x_cav) * d2logpart) / denom; - - *item.fitter.xs_mut(idx) = (1.0 - lr) * item.fitter.xs(idx) + lr * x; - *item.fitter.ns_mut(idx) = (1.0 - lr) * item.fitter.ns(idx) + lr * n; - } - - let diff = (self.logpart - logpart).abs(); - - // Save log partition function value for the log-likelihood. - self.logpart = logpart; - - diff + self.core.ep_update(lr, storage) } - fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 { - todo!(); + fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 { + self.core.kl_update(lr, storage) } } pub struct LogitWinObservation { - // + core: Core, + _margin: f64, } impl LogitWinObservation { - pub fn new(_storage: &mut Storage, _elems: &[(usize, f64)], _t: f64, _margin: f64) -> Self { - todo!(); + pub fn new(storage: &mut Storage, elems: &[(usize, f64)], t: f64, margin: f64) -> Self { + LogitWinObservation { + core: Core::new(storage, elems, t), + _margin: margin, + } } } @@ -139,12 +72,12 @@ impl Observation for LogitWinObservation { todo!(); } - fn ep_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 { - todo!(); + fn ep_update(&mut self, lr: f64, storage: &mut Storage) -> f64 { + self.core.ep_update(lr, storage) } - fn kl_update(&mut self, _lr: f64, _storage: &mut Storage) -> f64 { - todo!(); + fn kl_update(&mut self, lr: f64, storage: &mut Storage) -> f64 { + self.core.kl_update(lr, storage) } }