diff --git a/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs b/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs index 50ccfd144..3545973f4 100644 --- a/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs +++ b/hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs @@ -20,6 +20,7 @@ use hyperactor::ActorHandle; use hyperactor::ActorId; use hyperactor::ActorRef; use hyperactor::Context; +use hyperactor::HandleClient; use hyperactor::Handler; use hyperactor::Instance; use hyperactor::Named; @@ -328,7 +329,7 @@ impl Handler for HostMeshAgent { } } -#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)] +#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient, HandleClient)] pub struct ShutdownHost { /// Grace window: send SIGTERM and wait this long before /// escalating. diff --git a/monarch_hyperactor/src/v1/host_mesh.rs b/monarch_hyperactor/src/v1/host_mesh.rs index 55c55ab32..f49ba281a 100644 --- a/monarch_hyperactor/src/v1/host_mesh.rs +++ b/monarch_hyperactor/src/v1/host_mesh.rs @@ -10,7 +10,9 @@ use std::collections::HashMap; use std::ops::Deref; use std::path::PathBuf; use std::sync::OnceLock; +use std::time::Duration; +use hyperactor::ActorHandle; use hyperactor::Instance; use hyperactor::Proc; use hyperactor_mesh::bootstrap::BootstrapCommand; @@ -22,6 +24,8 @@ use hyperactor_mesh::v1::ProcMeshRef; use hyperactor_mesh::v1::host_mesh::HostMesh; use hyperactor_mesh::v1::host_mesh::HostMeshRef; use hyperactor_mesh::v1::host_mesh::mesh_agent::GetLocalProcClient; +use hyperactor_mesh::v1::host_mesh::mesh_agent::HostMeshAgent; +use hyperactor_mesh::v1::host_mesh::mesh_agent::ShutdownHostClient; use hyperactor_mesh::v1::proc_mesh::ProcRef; use ndslice::View; use ndslice::view::RankedSliceable; @@ -253,6 +257,9 @@ impl PyHostMeshRefImpl { /// Static storage for the root client instance when using host-based bootstrap. static ROOT_CLIENT_INSTANCE_FOR_HOST: OnceLock> = OnceLock::new(); +/// Static storage for the host mesh agent created by bootstrap_host(). +static HOST_MESH_AGENT_FOR_HOST: OnceLock> = OnceLock::new(); + /// Bootstrap the client host and root client actor. /// /// This creates a proper Host with BootstrapProcManager, spawns the root client @@ -282,6 +289,11 @@ fn bootstrap_host(bootstrap_cmd: Option) -> PyResult) -> PyResult { r.map(PyHostMesh::new_ref) } +#[pyfunction] +fn shutdown_local_host_mesh() -> PyResult { + let agent = HOST_MESH_AGENT_FOR_HOST + .get() + .ok_or_else(|| PyException::new_err("No local host mesh to shutdown"))? + .clone(); + + PyPythonTask::new(async move { + // Create a temporary instance to send the shutdown message + let temp_proc = hyperactor::Proc::local(); + let (instance, _) = temp_proc + .instance("shutdown_requester") + .map_err(|e| PyException::new_err(e.to_string()))?; + + // Use same defaults as HostMesh::shutdown(): + // - MESH_TERMINATE_TIMEOUT = 10 seconds + // - MESH_TERMINATE_CONCURRENCY = 16 + agent + .shutdown_host(&instance, Duration::from_secs(10), 16) + .await + .map_err(|e| PyException::new_err(e.to_string()))?; + + Ok(()) + }) +} + pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { let f = wrap_pyfunction!(py_host_mesh_from_bytes, hyperactor_mod)?; f.setattr( @@ -348,6 +386,13 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul )?; hyperactor_mod.add_function(f2)?; + let f3 = wrap_pyfunction!(shutdown_local_host_mesh, hyperactor_mod)?; + f3.setattr( + "__module__", + "monarch._rust_bindings.monarch_hyperactor.v1.host_mesh", + )?; + hyperactor_mod.add_function(f3)?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi index e0a9227f4..cef221754 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi @@ -113,3 +113,16 @@ def bootstrap_host( - `bootstrap_cmd`: The bootstrap command to use to bootstrap the host. """ ... + +def shutdown_local_host_mesh() -> PythonTask[None]: + """ + Shutdown the local host mesh created by bootstrap_host(). + + Sends ShutdownHost message to the local host mesh agent with: + - timeout: 10 seconds grace period before SIGTERM escalation + - max_in_flight: 16 concurrent child terminations + + Raises: + RuntimeError: If no local host mesh exists (bootstrap_host not called) + """ + ... diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index f18de2f0c..e310ece31 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -380,11 +380,17 @@ def shutdown_context() -> "Future[None]": completion. """ from monarch._src.actor.future import Future + from monarch._src.actor.v1 import enabled as v1_enabled - client_host_ctx = _client_context.try_get() - if client_host_ctx is not None: - host_mesh = client_host_ctx.actor_instance.proc_mesh.host_mesh - return host_mesh.shutdown() + if v1_enabled: + try: + from monarch._rust_bindings.monarch_hyperactor.v1.host_mesh import ( + shutdown_local_host_mesh, + ) + return Future(coro=shutdown_local_host_mesh()) + except RuntimeError: + # No local host mesh to shutdown + pass # Nothing to shutdown - return a completed future async def noop() -> None: