From e87937798018ac457c44ed275bff65de74fe871f Mon Sep 17 00:00:00 2001 From: Anders Olsson Date: Tue, 7 Jun 2022 16:04:04 +0200 Subject: [PATCH] Make verbose an argument for the fit function --- examples/abcdef.rs | 10 +++++----- examples/kickscore-basics.rs | 2 +- examples/nba-history.rs | 2 +- src/model/binary.rs | 3 +-- src/model/ternary.rs | 3 +-- tests/binary-1.rs | 2 +- tests/kickscore-basics.rs | 2 +- tests/nba-history.rs | 2 +- 8 files changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/abcdef.rs b/examples/abcdef.rs index 98acdcd..5df70be 100644 --- a/examples/abcdef.rs +++ b/examples/abcdef.rs @@ -16,7 +16,7 @@ fn main() { model.observe(&["A"], &["B"], 0.0); - model.fit(); + model.fit(true); for player in ["A", "B", "C", "D", "E", "F"] { let (mu, sigma) = model.item_score(player, 1.25); @@ -26,7 +26,7 @@ fn main() { model.observe(&["C"], &["D"], 0.25); - model.fit(); + model.fit(true); for player in ["A", "B", "C", "D", "E", "F"] { let (mu, sigma) = model.item_score(player, 1.25); @@ -36,7 +36,7 @@ fn main() { model.observe(&["E"], &["F"], 0.50); - model.fit(); + model.fit(true); for player in ["A", "B", "C", "D", "E", "F"] { let (mu, sigma) = model.item_score(player, 1.25); @@ -46,7 +46,7 @@ fn main() { model.observe(&["B"], &["C"], 0.75); - model.fit(); + model.fit(true); for player in ["A", "B", "C", "D", "E", "F"] { let (mu, sigma) = model.item_score(player, 1.25); @@ -56,7 +56,7 @@ fn main() { model.observe(&["D"], &["E"], 1.00); - model.fit(); + model.fit(true); for player in ["A", "B", "C", "D", "E", "F"] { let (mu, sigma) = model.item_score(player, 1.25); diff --git a/examples/kickscore-basics.rs b/examples/kickscore-basics.rs index 1b6bb6b..1f26cd5 100644 --- a/examples/kickscore-basics.rs +++ b/examples/kickscore-basics.rs @@ -31,7 +31,7 @@ fn main() { model.observe(&["Jerry"], &["Tom"], 3.0); model.observe(&["Jerry"], &["Tom", "Spike"], 3.5); - model.fit(); + model.fit(true); // We can predict a future outcome... let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], 4.0); diff --git a/examples/nba-history.rs b/examples/nba-history.rs index 25f9efd..ab3ec68 100644 --- a/examples/nba-history.rs +++ b/examples/nba-history.rs @@ -69,7 +69,7 @@ fn main() -> Result<(), Box> { model.observe(&[&winner], &[&loser], t); } - model.fit(); + model.fit(true); println!("Probability that CHI beats BOS..."); diff --git a/src/model/binary.rs b/src/model/binary.rs index f69209d..0d1ae89 100644 --- a/src/model/binary.rs +++ b/src/model/binary.rs @@ -102,14 +102,13 @@ impl Binary { self.last_t = t; } - pub fn fit(&mut self) -> bool { + pub fn fit(&mut self, verbose: bool) -> bool { // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False): let method = FitMethod::Ep; let lr = 1.0; let tol = 1e-3; let max_iter = 100; - let verbose = true; self.last_method = Some(method); diff --git a/src/model/ternary.rs b/src/model/ternary.rs index 52a348f..7037010 100644 --- a/src/model/ternary.rs +++ b/src/model/ternary.rs @@ -119,14 +119,13 @@ impl TernaryModel { self.last_t = t; } - pub fn fit(&mut self) -> bool { + pub fn fit(&mut self, verbose: bool) -> bool { // method="ep", lr=1.0, tol=1e-3, max_iter=100, verbose=False): let method = FitMethod::Ep; let lr = 1.0; let tol = 1e-3; let max_iter = 100; - let verbose = true; self.last_method = Some(method); diff --git a/tests/binary-1.rs b/tests/binary-1.rs index 1645208..1cec029 100644 --- a/tests/binary-1.rs +++ b/tests/binary-1.rs @@ -23,5 +23,5 @@ fn binary_1() { model.observe(&["benjamin"], &["audrey"], 6.0); model.observe(&["benjamin"], &["audrey"], 7.0); - model.fit(); + model.fit(true); } diff --git a/tests/kickscore-basics.rs b/tests/kickscore-basics.rs index 9e9d1d6..7521a5e 100644 --- a/tests/kickscore-basics.rs +++ b/tests/kickscore-basics.rs @@ -27,7 +27,7 @@ fn kickscore_basic() { model.observe(&["Jerry"], &["Tom"], 3.0); model.observe(&["Jerry"], &["Tom", "Spike"], 3.5); - model.fit(); + model.fit(true); let (p_win, _p_los) = model.probabilities(&["Jerry"], &["Tom"], 4.0); diff --git a/tests/nba-history.rs b/tests/nba-history.rs index 011153d..b044164 100644 --- a/tests/nba-history.rs +++ b/tests/nba-history.rs @@ -73,7 +73,7 @@ fn nba_history() { model.observe(&[&winner], &[&loser], t); } - model.fit(); + model.fit(true); let (p_win, _) = model.probabilities( &["CHI"],