Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hyperactor_mesh/src/v1/host_mesh/mesh_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use hyperactor::ActorHandle;
use hyperactor::ActorId;
use hyperactor::ActorRef;
use hyperactor::Context;
use hyperactor::HandleClient;
use hyperactor::Handler;
use hyperactor::Instance;
use hyperactor::Named;
Expand Down Expand Up @@ -328,7 +329,7 @@ impl Handler<resource::GetRankStatus> for HostMeshAgent {
}
}

#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)]
#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient, HandleClient)]
pub struct ShutdownHost {
/// Grace window: send SIGTERM and wait this long before
/// escalating.
Expand Down
45 changes: 45 additions & 0 deletions monarch_hyperactor/src/v1/host_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use std::collections::HashMap;
use std::ops::Deref;
use std::path::PathBuf;
use std::sync::OnceLock;
use std::time::Duration;

use hyperactor::ActorHandle;
use hyperactor::Instance;
use hyperactor::Proc;
use hyperactor_mesh::bootstrap::BootstrapCommand;
Expand All @@ -22,6 +24,8 @@ use hyperactor_mesh::v1::ProcMeshRef;
use hyperactor_mesh::v1::host_mesh::HostMesh;
use hyperactor_mesh::v1::host_mesh::HostMeshRef;
use hyperactor_mesh::v1::host_mesh::mesh_agent::GetLocalProcClient;
use hyperactor_mesh::v1::host_mesh::mesh_agent::HostMeshAgent;
use hyperactor_mesh::v1::host_mesh::mesh_agent::ShutdownHostClient;
use hyperactor_mesh::v1::proc_mesh::ProcRef;
use ndslice::View;
use ndslice::view::RankedSliceable;
Expand Down Expand Up @@ -253,6 +257,9 @@ impl PyHostMeshRefImpl {
/// Static storage for the root client instance when using host-based bootstrap.
static ROOT_CLIENT_INSTANCE_FOR_HOST: OnceLock<Instance<PythonActor>> = OnceLock::new();

/// Static storage for the host mesh agent created by bootstrap_host().
static HOST_MESH_AGENT_FOR_HOST: OnceLock<ActorHandle<HostMeshAgent>> = OnceLock::new();

/// Bootstrap the client host and root client actor.
///
/// This creates a proper Host with BootstrapProcManager, spawns the root client
Expand Down Expand Up @@ -282,6 +289,11 @@ fn bootstrap_host(bootstrap_cmd: Option<PyBootstrapCommand>) -> PyResult<PyPytho
.await
.map_err(|e| PyException::new_err(e.to_string()))?;

// Store the agent for later shutdown
HOST_MESH_AGENT_FOR_HOST
.set(host_mesh_agent.clone())
.ok(); // Ignore error if already set

let host_mesh_name = hyperactor_mesh::v1::Name::new_reserved("local").unwrap();
let host_mesh = HostMeshRef::from_host_agent(host_mesh_name, host_mesh_agent.bind())
.map_err(|e| PyException::new_err(e.to_string()))?;
Expand Down Expand Up @@ -333,6 +345,32 @@ fn py_host_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyHostMesh> {
r.map(PyHostMesh::new_ref)
}

#[pyfunction]
fn shutdown_local_host_mesh() -> PyResult<PyPythonTask> {
let agent = HOST_MESH_AGENT_FOR_HOST
.get()
.ok_or_else(|| PyException::new_err("No local host mesh to shutdown"))?
.clone();

PyPythonTask::new(async move {
// Create a temporary instance to send the shutdown message
let temp_proc = hyperactor::Proc::local();
let (instance, _) = temp_proc
.instance("shutdown_requester")
.map_err(|e| PyException::new_err(e.to_string()))?;

// Use same defaults as HostMesh::shutdown():
// - MESH_TERMINATE_TIMEOUT = 10 seconds
// - MESH_TERMINATE_CONCURRENCY = 16
agent
.shutdown_host(&instance, Duration::from_secs(10), 16)
.await
.map_err(|e| PyException::new_err(e.to_string()))?;

Ok(())
})
}

pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
let f = wrap_pyfunction!(py_host_mesh_from_bytes, hyperactor_mod)?;
f.setattr(
Expand All @@ -348,6 +386,13 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
)?;
hyperactor_mod.add_function(f2)?;

let f3 = wrap_pyfunction!(shutdown_local_host_mesh, hyperactor_mod)?;
f3.setattr(
"__module__",
"monarch._rust_bindings.monarch_hyperactor.v1.host_mesh",
)?;
hyperactor_mod.add_function(f3)?;

hyperactor_mod.add_class::<PyHostMesh>()?;
hyperactor_mod.add_class::<PyBootstrapCommand>()?;
Ok(())
Expand Down
13 changes: 13 additions & 0 deletions python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,16 @@ def bootstrap_host(
- `bootstrap_cmd`: The bootstrap command to use to bootstrap the host.
"""
...

def shutdown_local_host_mesh() -> PythonTask[None]:
    """
    Shut down the local host mesh created by bootstrap_host().

    Returns a task that, when run, sends a ShutdownHost message to the
    local host mesh agent with:
    - timeout: 10 seconds; the grace window to wait after SIGTERM before
      escalating
    - max_in_flight: 16 concurrent child terminations

    The lookup of the local host mesh happens when this function is
    called, so the "no local host mesh" error is raised by the call
    itself. Failures while sending the shutdown request surface when the
    returned task is awaited.

    Raises:
        RuntimeError: If no local host mesh exists (bootstrap_host not called).
            NOTE(review): confirm the Rust binding raises RuntimeError
            rather than a plain Exception.
    """
    ...
14 changes: 10 additions & 4 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,11 +380,17 @@ def shutdown_context() -> "Future[None]":
completion.
"""
from monarch._src.actor.future import Future
from monarch._src.actor.v1 import enabled as v1_enabled

client_host_ctx = _client_context.try_get()
if client_host_ctx is not None:
host_mesh = client_host_ctx.actor_instance.proc_mesh.host_mesh
return host_mesh.shutdown()
if v1_enabled:
try:
from monarch._rust_bindings.monarch_hyperactor.v1.host_mesh import (
shutdown_local_host_mesh,
)
return Future(coro=shutdown_local_host_mesh())
except RuntimeError:
# No local host mesh to shutdown
pass

# Nothing to shutdown - return a completed future
async def noop() -> None:
Expand Down