diff --git a/src/event_builder.rs b/src/event_builder.rs new file mode 100644 index 0000000..d415e16 --- /dev/null +++ b/src/event_builder.rs @@ -0,0 +1,94 @@ +use smallvec::SmallVec; + +use crate::{ + InferenceError, Outcome, + drift::Drift, + event::{Event, Member, Team}, + history::History, + observer::Observer, + time::Time, +}; + +pub struct EventBuilder<'h, T, D, O, K> +where + T: Time, + D: Drift, + O: Observer, + K: Eq + std::hash::Hash + Clone, +{ + history: &'h mut History, + event: Event, + current_team_idx: Option, +} + +impl<'h, T, D, O, K> EventBuilder<'h, T, D, O, K> +where + T: Time, + D: Drift, + O: Observer, + K: Eq + std::hash::Hash + Clone, +{ + pub(crate) fn new(history: &'h mut History, time: T) -> Self { + Self { + history, + event: Event { + time, + teams: SmallVec::new(), + outcome: Outcome::Ranked(SmallVec::new()), + }, + current_team_idx: None, + } + } + + /// Add a team by its member keys (weight 1.0 each, no prior overrides). + pub fn team>(mut self, keys: I) -> Self { + let members: SmallVec<[Member; 4]> = keys.into_iter().map(Member::new).collect(); + self.event.teams.push(Team { members }); + self.current_team_idx = Some(self.event.teams.len() - 1); + self + } + + /// Set per-member weights for the most recently added team. + /// + /// Panics in debug builds if called before `.team(...)` or if the length + /// doesn't match the team's member count. + pub fn weights>(mut self, weights: I) -> Self { + let idx = self + .current_team_idx + .expect(".weights(...) called before any .team(...)"); + let ws: Vec = weights.into_iter().collect(); + let team = &mut self.event.teams[idx]; + debug_assert_eq!( + ws.len(), + team.members.len(), + "weights length must match team size" + ); + for (m, w) in team.members.iter_mut().zip(ws) { + m.weight = w; + } + self + } + + /// Set explicit ranks per team (length must equal number of teams). + pub fn ranking>(mut self, ranks: I) -> Self { + self.event.outcome = Outcome::ranking(ranks); + self + } + + /// Mark team `winner_idx` as winner; others tied for last. + pub fn winner(mut self, winner_idx: u32) -> Self { + self.event.outcome = Outcome::winner(winner_idx, self.event.teams.len() as u32); + self + } + + /// All teams tied. + pub fn draw(mut self) -> Self { + self.event.outcome = Outcome::draw(self.event.teams.len() as u32); + self + } + + /// Commit the event to the history. + pub fn commit(self) -> Result<(), InferenceError> { + self.history.add_events(std::iter::once(self.event)) + } +} diff --git a/src/history.rs b/src/history.rs index add2ba0..7ba717a 100644 --- a/src/history.rs +++ b/src/history.rs @@ -535,6 +535,11 @@ impl, O: Observer, K: Eq + Hash + Clone> History crate::event_builder::EventBuilder<'_, T, D, O, K> { + crate::event_builder::EventBuilder::new(self, time) + } + /// Bulk-ingest typed events. pub fn add_events(&mut self, events: I) -> Result<(), InferenceError> where diff --git a/src/lib.rs b/src/lib.rs index 5a397fc..63ab8a0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ mod convergence; pub mod drift; mod error; mod event; +mod event_builder; pub(crate) mod factor; mod game; pub mod gaussian; @@ -31,6 +32,7 @@ pub use convergence::{ConvergenceOptions, ConvergenceReport}; pub use drift::{ConstantDrift, Drift}; pub use error::InferenceError; pub use event::{Event, Member, Team}; +pub use event_builder::EventBuilder; pub use game::Game; pub use gaussian::Gaussian; pub use history::History; diff --git a/tests/api_shape.rs b/tests/api_shape.rs index 3a8dec4..886be48 100644 --- a/tests/api_shape.rs +++ b/tests/api_shape.rs @@ -82,3 +82,63 @@ fn add_events_rejects_mismatched_outcome_ranks() { let err = h.add_events(events).unwrap_err(); assert!(matches!(err, InferenceError::MismatchedShape { .. })); } + +#[test] +fn fluent_event_builder_basic() { + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .p_draw(0.0) + .build(); + + h.event(1) + .team(["alice", "bob"]) + .weights([1.0, 0.7]) + .team(["carol"]) + .ranking([1, 0]) + .commit() + .unwrap(); + + let report = h.converge().unwrap(); + assert!(report.converged); + assert!(h.lookup(&"alice").is_some()); + assert!(h.lookup(&"bob").is_some()); + assert!(h.lookup(&"carol").is_some()); +} + +#[test] +fn fluent_event_builder_winner_convenience() { + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .p_draw(0.0) + .build(); + + h.event(1) + .team(["alice"]) + .team(["bob"]) + .winner(0) + .commit() + .unwrap(); + h.converge().unwrap(); +} + +#[test] +fn fluent_event_builder_draw() { + let mut h = History::builder() + .mu(25.0) + .sigma(25.0 / 3.0) + .beta(25.0 / 6.0) + .p_draw(0.25) + .build(); + + h.event(1) + .team(["alice"]) + .team(["bob"]) + .draw() + .commit() + .unwrap(); + h.converge().unwrap(); +}