use crate::{ backoff::Backoff, child::ChildHandle, logs::LogSink, policy::{RestartDecision, decide}, retry_window::RetryWindow, }; use std::time::{Duration, Instant}; use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::sleep; use tracing::{debug, info, warn}; use xy_protocol::{ServerConfig, ServerState}; pub enum SupervisorCmd { Start { ack: oneshot::Sender, }, Stop { ack: oneshot::Sender, }, Restart { ack: oneshot::Sender<()>, }, Reconfigure { new: ServerConfig, ack: oneshot::Sender<()>, }, Shutdown { ack: oneshot::Sender<()>, }, } #[derive(Debug, PartialEq, Eq)] pub enum StartAck { Started, AlreadyRunning, } #[derive(Debug, PartialEq, Eq)] pub enum StopAck { Stopped, NotRunning, } #[derive(Debug, Clone)] pub struct Status { pub state: ServerState, pub pid: Option, pub port: u16, pub uptime_secs: Option, pub restart_count: u32, pub last_exit: Option, } #[derive(Clone)] pub struct SupervisorHandle { pub name: String, pub tx: mpsc::Sender, pub status: watch::Receiver, pub log_sink: LogSink, } #[async_trait::async_trait] pub trait Spawner: Send + 'static { type Child: ChildHandle; async fn spawn(&self, cfg: &ServerConfig, sink: LogSink) -> std::io::Result; } pub struct SupervisorTask { cfg: ServerConfig, log_sink: LogSink, spawner: S, status_tx: watch::Sender, cmd_rx: mpsc::Receiver, backoff: Backoff, retry_window: RetryWindow, restart_count: u32, last_exit: Option, started_at: Option, current_pid: Option, } impl SupervisorTask { pub fn new( cfg: ServerConfig, log_sink: LogSink, spawner: S, status_tx: watch::Sender, cmd_rx: mpsc::Receiver, ) -> Self { let backoff = Backoff::new(cfg.restart.backoff_initial, cfg.restart.backoff_max); let retry_window = RetryWindow::new(Duration::from_secs(60), cfg.restart.max_retries_per_minute); Self { cfg, log_sink, spawner, status_tx, cmd_rx, backoff, retry_window, restart_count: 0, last_exit: None, started_at: None, current_pid: None, } } fn set_state(&mut self, s: ServerState) { let uptime_secs = self.started_at.map(|t| t.elapsed().as_secs()); let _ = self.status_tx.send(Status { state: s, pid: self.current_pid, port: self.cfg.port, uptime_secs, restart_count: self.restart_count, last_exit: self.last_exit, }); } pub async fn run(mut self) { let mut child: Option = None; loop { tokio::select! { cmd = self.cmd_rx.recv() => { let Some(cmd) = cmd else { break; }; match cmd { SupervisorCmd::Start { ack } => { if child.is_some() { let _ = ack.send(StartAck::AlreadyRunning); } else { match self.do_start().await { Ok(c) => { child = Some(c); let _ = ack.send(StartAck::Started); } Err(err) => { warn!(name = %self.cfg.name, error = %err, "spawn failed"); self.set_state(ServerState::Failed); let _ = ack.send(StartAck::Started); } } } } SupervisorCmd::Stop { ack } => { if let Some(c) = child.take() { self.do_stop(c).await; let _ = ack.send(StopAck::Stopped); } else { let _ = ack.send(StopAck::NotRunning); } } SupervisorCmd::Restart { ack } => { if let Some(c) = child.take() { self.set_state(ServerState::Restarting); self.do_stop(c).await; } match self.do_start().await { Ok(c) => child = Some(c), Err(err) => { warn!(name = %self.cfg.name, error = %err, "restart spawn failed"); self.set_state(ServerState::Failed); } } let _ = ack.send(()); } SupervisorCmd::Reconfigure { new, ack } => { self.cfg = new; self.backoff = Backoff::new(self.cfg.restart.backoff_initial, self.cfg.restart.backoff_max); self.retry_window = RetryWindow::new( Duration::from_secs(60), self.cfg.restart.max_retries_per_minute, ); let _ = ack.send(()); } SupervisorCmd::Shutdown { ack } => { if let Some(c) = child.take() { self.do_stop(c).await; } let _ = ack.send(()); return; } } } code = wait_child(&mut child) => { child = None; self.last_exit = code; self.current_pid = None; let now = Instant::now(); self.retry_window.record(now); let cap = self.retry_window.cap_reached(now); let decision = decide(self.cfg.restart.policy, code, cap); debug!(name = %self.cfg.name, ?code, ?decision, "child exited"); match decision { RestartDecision::StayStopped => { self.started_at = None; self.set_state(ServerState::Stopped); } RestartDecision::MarkFailed => { self.started_at = None; self.set_state(ServerState::Failed); } RestartDecision::Restart => { self.set_state(ServerState::Restarting); let delay = self.backoff.next(); sleep(delay).await; match self.do_start().await { Ok(c) => child = Some(c), Err(err) => { warn!(name = %self.cfg.name, error = %err, "restart spawn failed"); self.set_state(ServerState::Failed); } } } } } } } } async fn do_start(&mut self) -> std::io::Result { self.set_state(ServerState::Starting); let c = self.spawner.spawn(&self.cfg, self.log_sink.clone()).await?; self.restart_count = self.restart_count.saturating_add(1); self.started_at = Some(Instant::now()); self.current_pid = Some(c.pid()); self.backoff.reset(); self.set_state(ServerState::Running); info!(name = %self.cfg.name, pid = c.pid(), "started"); Ok(c) } async fn do_stop(&mut self, mut c: S::Child) { self.set_state(ServerState::Stopping); let _ = c.terminate(); let grace = self.cfg.stop.grace; match tokio::time::timeout(grace, c.wait()).await { Ok(_) => {} Err(_) => { let _ = c.kill(); let _ = c.wait().await; } } self.current_pid = None; self.started_at = None; self.set_state(ServerState::Stopped); } } async fn wait_child(slot: &mut Option) -> Option { match slot.as_mut() { Some(c) => c.wait().await.ok().flatten(), None => std::future::pending().await, } } pub struct RealSpawner; #[async_trait::async_trait] impl Spawner for RealSpawner { type Child = crate::child::RealChild; async fn spawn(&self, cfg: &ServerConfig, sink: LogSink) -> std::io::Result { crate::child::spawn_with_logs(cfg, sink) } } #[cfg(test)] mod tests { use super::*; use crate::child::MockChild; use crate::logs::{LogSink, RotatingLogWriter}; use std::sync::{Arc, Mutex}; use tempfile::tempdir; use xy_protocol::{RestartConfig, RestartPolicy, StopConfig}; struct QueueSpawner { queue: Arc>>, } #[async_trait::async_trait] impl Spawner for QueueSpawner { type Child = MockChild; async fn spawn(&self, _cfg: &ServerConfig, _sink: LogSink) -> std::io::Result { let mut q = self.queue.lock().unwrap(); Ok(q.remove(0)) } } fn cfg(name: &str, policy: RestartPolicy, max_retries: u32) -> ServerConfig { ServerConfig { name: name.to_string(), command: "/bin/true".into(), args: vec![], port: 1, env: Default::default(), working_dir: None, restart: RestartConfig { policy, backoff_initial: Duration::from_millis(1), backoff_max: Duration::from_millis(1), max_retries_per_minute: max_retries, }, stop: StopConfig { grace: Duration::from_millis(50), }, } } fn sink(name: &str) -> LogSink { let dir = tempdir().unwrap(); let writer = RotatingLogWriter::open(&dir.path().join("s.log"), 1024, 3).unwrap(); std::mem::forget(dir); LogSink::new(name.to_string(), writer, 1024) } fn initial_status(cfg: &ServerConfig) -> Status { Status { state: ServerState::Stopped, pid: None, port: cfg.port, uptime_secs: None, restart_count: 0, last_exit: None, } } async fn wait_for(rx: &mut watch::Receiver, want: ServerState) { let deadline = tokio::time::Instant::now() + Duration::from_secs(2); loop { if rx.borrow().state == want { return; } tokio::select! { _ = rx.changed() => {} _ = tokio::time::sleep_until(deadline) => panic!("never reached {want:?}, last={:?}", rx.borrow().state), } } } #[tokio::test] async fn start_runs_to_running_and_stop_to_stopped() { let cfg = cfg("x", RestartPolicy::Never, 5); let (mock, mut ctl) = MockChild::new(1); let queue = Arc::new(Mutex::new(vec![mock])); let spawner = QueueSpawner { queue }; let (status_tx, mut status_rx) = watch::channel(initial_status(&cfg)); let (cmd_tx, cmd_rx) = mpsc::channel(8); let task = SupervisorTask::new(cfg, sink("x"), spawner, status_tx, cmd_rx); let h = tokio::spawn(task.run()); let (ack_tx, ack_rx) = oneshot::channel(); cmd_tx .send(SupervisorCmd::Start { ack: ack_tx }) .await .unwrap(); assert_eq!(ack_rx.await.unwrap(), StartAck::Started); wait_for(&mut status_rx, ServerState::Running).await; ctl.exit_tx.take().unwrap().send(Some(0)).unwrap(); wait_for(&mut status_rx, ServerState::Stopped).await; let (ack_tx, ack_rx) = oneshot::channel(); cmd_tx .send(SupervisorCmd::Shutdown { ack: ack_tx }) .await .unwrap(); ack_rx.await.unwrap(); h.await.unwrap(); } }