Start working on factor graph.

This commit is contained in:
2018-10-22 07:29:18 +02:00
parent 14bb0e9bb8
commit d6ea5e3116
3 changed files with 99 additions and 0 deletions

54
src/factor_graph.rs Normal file
View File

@@ -0,0 +1,54 @@
use std::cmp;
use gaussian::Gaussian;
pub struct Variable {
gaussian: Gaussian,
}
impl Variable {
pub fn new() -> Variable {
Variable {
gaussian: Gaussian::new(0.0, 0.0),
}
}
fn delta(&self, other: &Variable) -> f32 {
let pi_delta = self.gaussian.pi - other.gaussian.pi;
if pi_delta.is_infinite() {
0.0
} else {
let tau_delta = (self.gaussian.tau - other.gaussian.tau).abs();
if pi_delta > tau_delta {
pi_delta
} else {
tau_delta
}
}
}
}
pub trait Factor {
fn down(&self) -> f32 {
0.0
}
}
pub struct PriorFactor {
variable: Variable,
dynamic: f32
}
impl PriorFactor {
pub fn new(variable: Variable, dynamic: f32) -> PriorFactor {
PriorFactor { variable, dynamic }
}
}
impl Factor for PriorFactor {
fn down(&self) -> f32 {
0.0
}
}

10
src/gaussian.rs Normal file
View File

@@ -0,0 +1,10 @@
pub struct Gaussian {
pub pi: f32,
pub tau: f32,
}
impl Gaussian {
pub fn new(pi: f32, tau: f32) -> Gaussian {
Gaussian { pi, tau }
}
}

View File

@@ -1,6 +1,9 @@
mod matrix; mod matrix;
mod factor_graph;
mod gaussian;
use matrix::Matrix; use matrix::Matrix;
use factor_graph::*;
/// Default initial mean of ratings. /// Default initial mean of ratings.
const MU: f32 = 25.0; const MU: f32 = 25.0;
@@ -35,6 +38,38 @@ impl Default for Rating {
} }
} }
fn _team_sizes(rating_groups: &[&[Rating]]) -> Vec<usize> {
let mut team_sizes = Vec::new();
for group in rating_groups {
let last = team_sizes.last().map(|v| *v).unwrap_or(0);
team_sizes.push(group.len() + last);
}
team_sizes
}
fn factor_graph_builders(rating_groups: &[&[Rating]]) {
let flatten_ratings = rating_groups
.iter()
.flat_map(|group| group.iter())
.collect::<Vec<_>>();
let flatten_weights = vec![1.0; flatten_ratings.len()].into_boxed_slice();
let size = flatten_ratings.len();
let group_size = rating_groups.len();
let rating_vars = (0..size).map(|_| Variable::new()).collect::<Vec<_>>();
let perf_vars = (0..size).map(|_| Variable::new()).collect::<Vec<_>>();
let team_perf_vars = (0..group_size).map(|_| Variable::new()).collect::<Vec<_>>();
let team_diff_vars = (0..group_size - 1).map(|_| Variable::new()).collect::<Vec<_>>();
let team_sizes = _team_sizes(rating_groups);
}
fn rate(rating_groups: &[&[Rating]]) { fn rate(rating_groups: &[&[Rating]]) {
let ranks = (0..rating_groups.len()).collect::<Vec<_>>(); let ranks = (0..rating_groups.len()).collect::<Vec<_>>();
let weights = vec![1.0; rating_groups.len()]; let weights = vec![1.0; rating_groups.len()];