From 7e2576085f779ffa85d17638dc67a91aac96077b Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Thu, 26 Oct 2023 11:11:54 +0200 Subject: [PATCH] Make quality a free standing function instead --- src/history.rs | 62 -------------------------------------------------- src/lib.rs | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 62 deletions(-) diff --git a/src/history.rs b/src/history.rs index 0d26b8c..4a09dc4 100644 --- a/src/history.rs +++ b/src/history.rs @@ -4,7 +4,6 @@ use crate::{ agent::{self, Agent}, batch::{self, Batch}, gaussian::Gaussian, - matrix::Matrix, player::Player, sort_time, tuple_gt, tuple_max, Index, BETA, GAMMA, MU, P_DRAW, SIGMA, }; @@ -418,67 +417,6 @@ impl History { self.size += n; } - - pub fn quality(&self, rating_groups: &[&[Gaussian]]) -> f64 { - let flatten_ratings = rating_groups - .iter() - .flat_map(|group| group.iter()) - .collect::>(); - - let flatten_weights = vec![1.0; flatten_ratings.len()].into_boxed_slice(); - - let length = flatten_ratings.len(); - - let mut mean_matrix = Matrix::new(length, 1); - - for (i, rating) in flatten_ratings.iter().enumerate() { - mean_matrix[(i, 0)] = rating.mu; - } - - let mut variance_matrix = Matrix::new(length, length); - - for (i, rating) in flatten_ratings.iter().enumerate() { - variance_matrix[(i, i)] = rating.sigma.powi(2); - } - - let mut rotated_a_matrix = Matrix::new(rating_groups.len() - 1, length); - - let mut t = 0; - let mut x = 0; - - for (row, group) in rating_groups.windows(2).enumerate() { - let current = group[0]; - let next = group[1]; - - for n in t..t + current.len() { - rotated_a_matrix[(row, n)] = flatten_weights[n]; - - x += 1; - } - - t += current.len(); - - for n in x..x + next.len() { - rotated_a_matrix[(row, n)] = -flatten_weights[n]; - } - - x += next.len(); - } - - let a_matrix = rotated_a_matrix.transpose(); - - let ata = self.beta.powi(2) * &rotated_a_matrix * &a_matrix; - let atsa = &rotated_a_matrix * &variance_matrix * &a_matrix; - - let start = mean_matrix.transpose() * &a_matrix; - let middle = &ata + &atsa; - let end = &rotated_a_matrix * &mean_matrix; - - let e_arg = (-0.5 * &start * &middle.inverse() * &end).determinant(); - let s_arg = ata.determinant() / middle.determinant(); - - e_arg.exp() * s_arg.sqrt() - } } #[cfg(test)] diff --git a/src/lib.rs b/src/lib.rs index 76fb17b..2a74c02 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,7 @@ pub mod player; pub use game::Game; pub use gaussian::Gaussian; pub use history::History; +use matrix::Matrix; use message::DiffMessage; pub use player::Player; @@ -255,6 +256,67 @@ pub(crate) fn evidence(d: &[DiffMessage], margin: &[f64], tie: &[bool], e: usize } } +pub fn quality(rating_groups: &[&[Gaussian]], beta: f64) -> f64 { + let flatten_ratings = rating_groups + .iter() + .flat_map(|group| group.iter()) + .collect::>(); + + let flatten_weights = vec![1.0; flatten_ratings.len()].into_boxed_slice(); + + let length = flatten_ratings.len(); + + let mut mean_matrix = Matrix::new(length, 1); + + for (i, rating) in flatten_ratings.iter().enumerate() { + mean_matrix[(i, 0)] = rating.mu; + } + + let mut variance_matrix = Matrix::new(length, length); + + for (i, rating) in flatten_ratings.iter().enumerate() { + variance_matrix[(i, i)] = rating.sigma.powi(2); + } + + let mut rotated_a_matrix = Matrix::new(rating_groups.len() - 1, length); + + let mut t = 0; + let mut x = 0; + + for (row, group) in rating_groups.windows(2).enumerate() { + let current = group[0]; + let next = group[1]; + + for n in t..t + current.len() { + rotated_a_matrix[(row, n)] = flatten_weights[n]; + + x += 1; + } + + t += current.len(); + + for n in x..x + next.len() { + rotated_a_matrix[(row, n)] = -flatten_weights[n]; + } + + x += next.len(); + } + + let a_matrix = rotated_a_matrix.transpose(); + + let ata = beta.powi(2) * &rotated_a_matrix * &a_matrix; + let atsa = &rotated_a_matrix * &variance_matrix * &a_matrix; + + let start = mean_matrix.transpose() * &a_matrix; + let middle = &ata + &atsa; + let end = &rotated_a_matrix * &mean_matrix; + + let e_arg = (-0.5 * &start * &middle.inverse() * &end).determinant(); + let s_arg = ata.determinant() / middle.determinant(); + + e_arg.exp() * s_arg.sqrt() +} + #[cfg(test)] mod tests { use super::*;