@@ -10,7 +10,9 @@ use std::collections::HashMap;
1010use std:: ops:: Deref ;
1111use std:: path:: PathBuf ;
1212use std:: sync:: OnceLock ;
13+ use std:: time:: Duration ;
1314
15+ use hyperactor:: ActorHandle ;
1416use hyperactor:: Instance ;
1517use hyperactor:: Proc ;
1618use hyperactor_mesh:: bootstrap:: BootstrapCommand ;
@@ -22,6 +24,8 @@ use hyperactor_mesh::v1::ProcMeshRef;
2224use hyperactor_mesh:: v1:: host_mesh:: HostMesh ;
2325use hyperactor_mesh:: v1:: host_mesh:: HostMeshRef ;
2426use hyperactor_mesh:: v1:: host_mesh:: mesh_agent:: GetLocalProcClient ;
27+ use hyperactor_mesh:: v1:: host_mesh:: mesh_agent:: HostMeshAgent ;
28+ use hyperactor_mesh:: v1:: host_mesh:: mesh_agent:: ShutdownHostClient ;
2529use hyperactor_mesh:: v1:: proc_mesh:: ProcRef ;
2630use ndslice:: View ;
2731use ndslice:: view:: RankedSliceable ;
@@ -253,6 +257,9 @@ impl PyHostMeshRefImpl {
253257/// Static storage for the root client instance when using host-based bootstrap.
254258static ROOT_CLIENT_INSTANCE_FOR_HOST : OnceLock < Instance < PythonActor > > = OnceLock :: new ( ) ;
255259
260+ /// Static storage for the host mesh agent created by bootstrap_host().
261+ static HOST_MESH_AGENT_FOR_HOST : OnceLock < ActorHandle < HostMeshAgent > > = OnceLock :: new ( ) ;
262+
256263/// Bootstrap the client host and root client actor.
257264///
258265/// This creates a proper Host with BootstrapProcManager, spawns the root client
@@ -282,6 +289,9 @@ fn bootstrap_host(bootstrap_cmd: Option<PyBootstrapCommand>) -> PyResult<PyPytho
282289 . await
283290 . map_err ( |e| PyException :: new_err ( e. to_string ( ) ) ) ?;
284291
292+ // Store the agent for later shutdown
293+ HOST_MESH_AGENT_FOR_HOST . set ( host_mesh_agent. clone ( ) ) . ok ( ) ; // Ignore error if already set
294+
285295 let host_mesh_name = hyperactor_mesh:: v1:: Name :: new_reserved ( "local" ) . unwrap ( ) ;
286296 let host_mesh = HostMeshRef :: from_host_agent ( host_mesh_name, host_mesh_agent. bind ( ) )
287297 . map_err ( |e| PyException :: new_err ( e. to_string ( ) ) ) ?;
@@ -333,6 +343,32 @@ fn py_host_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<PyHostMesh> {
333343 r. map ( PyHostMesh :: new_ref)
334344}
335345
346+ #[ pyfunction]
347+ fn shutdown_local_host_mesh ( ) -> PyResult < PyPythonTask > {
348+ let agent = HOST_MESH_AGENT_FOR_HOST
349+ . get ( )
350+ . ok_or_else ( || PyException :: new_err ( "No local host mesh to shutdown" ) ) ?
351+ . clone ( ) ;
352+
353+ PyPythonTask :: new ( async move {
354+ // Create a temporary instance to send the shutdown message
355+ let temp_proc = hyperactor:: Proc :: local ( ) ;
356+ let ( instance, _) = temp_proc
357+ . instance ( "shutdown_requester" )
358+ . map_err ( |e| PyException :: new_err ( e. to_string ( ) ) ) ?;
359+
360+ // Use same defaults as HostMesh::shutdown():
361+ // - MESH_TERMINATE_TIMEOUT = 10 seconds
362+ // - MESH_TERMINATE_CONCURRENCY = 16
363+ agent
364+ . shutdown_host ( & instance, Duration :: from_secs ( 10 ) , 16 )
365+ . await
366+ . map_err ( |e| PyException :: new_err ( e. to_string ( ) ) ) ?;
367+
368+ Ok ( ( ) )
369+ } )
370+ }
371+
336372pub fn register_python_bindings ( hyperactor_mod : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
337373 let f = wrap_pyfunction ! ( py_host_mesh_from_bytes, hyperactor_mod) ?;
338374 f. setattr (
@@ -348,6 +384,13 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
348384 ) ?;
349385 hyperactor_mod. add_function ( f2) ?;
350386
387+ let f3 = wrap_pyfunction ! ( shutdown_local_host_mesh, hyperactor_mod) ?;
388+ f3. setattr (
389+ "__module__" ,
390+ "monarch._rust_bindings.monarch_hyperactor.v1.host_mesh" ,
391+ ) ?;
392+ hyperactor_mod. add_function ( f3) ?;
393+
351394 hyperactor_mod. add_class :: < PyHostMesh > ( ) ?;
352395 hyperactor_mod. add_class :: < PyBootstrapCommand > ( ) ?;
353396 Ok ( ( ) )
0 commit comments