From 448cb78fc0fade23280fc7705ada7a1349f352ce Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Tue, 16 Dec 2025 10:29:44 -0800 Subject: [PATCH 1/4] Add Channel::Alias (#2141) Summary: In the Lightning integration, there is a need that we want to explicitly to config the client and worker hosts to: 1. use public IP address and an explicit port as the `Host`'s frontend address; 2. bind the port to INADDR_ANY, since AWS does not allow the port being bound to the public IP address. In order to enable this configuration, this diff adds a new variant, `Alias` to channel. This variant provides 2 fields: 1. dial_to, which can be used by Lightning to set as the public IP address. 2. bind_to, which is used to specify how to bind the port. It is Lightning's responsibility to ensure the network is configured in a way that the packages sent to `dial_to` would be routed to `bind_to`. Reviewed By: mariusae Differential Revision: D89190085 --- hyperactor/src/channel.rs | 135 +++++++++++++++++++++++++++++++++++--- 1 file changed, 127 insertions(+), 8 deletions(-) diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index 76d2095e1..fc8bc17d3 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -505,6 +505,27 @@ pub enum ChannelAddr { /// A unix domain socket address. Supports both absolute path names as /// well as "abstract" names per https://manpages.debian.org/unstable/manpages/unix.7.en.html#Abstract_sockets Unix(net::unix::SocketAddr), + + /// A pair of addresses, one for the client and one for the server: + /// - The client should dial to the `dial_to` address. + /// - The server should bind to the `bind_to` address. + /// + /// The user is responsible for ensuring the traffic to the `dial_to` address + /// is routed to the `bind_to` address. + /// + /// This is useful for scenarios where the network is configured in a way, + /// that the bound address is not directly accessible from the client. 
+ /// + /// For example, in AWS, the client could be provided with the public IP + /// address, yet the server is bound to a private IP address or simply + /// INADDR_ANY. Traffic to the public IP address is mapped to the private + /// IP address through network address translation (NAT). + Alias { + /// The address to which the client should dial to. + dial_to: Box, + /// The address to which the server should bind to. + bind_to: Box, + }, } impl From for ChannelAddr { @@ -602,6 +623,9 @@ impl ChannelAddr { Self::Local(_) => ChannelTransport::Local, Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())), Self::Unix(_) => ChannelTransport::Unix, + // bind_to's transport is what is actually used in communication. + // Therefore we use its transport to represent the Alias. + Self::Alias { bind_to, .. } => bind_to.transport(), } } } @@ -614,6 +638,9 @@ impl fmt::Display for ChannelAddr { Self::Local(index) => write!(f, "local:{}", index), Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr), Self::Unix(addr) => write!(f, "unix:{}", addr), + Self::Alias { dial_to, bind_to } => { + write!(f, "alias:dial_to={};bind_to={}", dial_to, bind_to) + } } } } @@ -634,6 +661,11 @@ impl FromStr for ChannelAddr { Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()), Some(("sim", rest)) => sim::parse(rest).map_err(|e| e.into()), Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)), + Some(("alias", _)) => Err(anyhow::anyhow!( + "detect possible alias address, but we currently do not support \ + parsing alias' string representation since we only want to \ + support parsing its zmq url format." 
+ )), Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")), None => Err(anyhow::anyhow!("no channel type specified")), } @@ -647,7 +679,38 @@ impl ChannelAddr { /// - inproc://endpoint-name (equivalent to local) /// - ipc://path (equivalent to unix) /// - metatls://hostname:port or metatls://*:port + /// - Alias format: dial_to_url@bind_to_url (e.g., tcp://host:port@tcp://host:port) + /// Note: Alias format is currently only supported for TCP addresses pub fn from_zmq_url(address: &str) -> Result { + // Check for Alias format: dial_to_url@bind_to_url + // The @ character separates two valid ZMQ URLs + if let Some(at_pos) = address.find('@') { + let dial_to_str = &address[..at_pos]; + let bind_to_str = &address[at_pos + 1..]; + + // Validate that both addresses use TCP scheme + if !dial_to_str.starts_with("tcp://") { + return Err(anyhow::anyhow!( + "alias format is only supported for TCP addresses, got dial_to: {}", + dial_to_str + )); + } + if !bind_to_str.starts_with("tcp://") { + return Err(anyhow::anyhow!( + "alias format is only supported for TCP addresses, got bind_to: {}", + bind_to_str + )); + } + + let dial_to = Self::from_zmq_url(dial_to_str)?; + let bind_to = Self::from_zmq_url(bind_to_str)?; + + return Ok(Self::Alias { + dial_to: Box::new(dial_to), + bind_to: Box::new(bind_to), + }); + } + // Try ZMQ-style URL format first (scheme://...) let (scheme, address) = address.split_once("://").ok_or_else(|| { anyhow::anyhow!("address must be in url form scheme://endppoint {}", address) @@ -850,6 +913,7 @@ pub fn dial(addr: ChannelAddr) -> Result, Channel ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?), ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::(sim_addr)?), ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)), + ChannelAddr::Alias { dial_to, .. 
} => dial(*dial_to)?.inner, }; Ok(ChannelTx { inner }) } @@ -862,6 +926,19 @@ pub fn serve( addr: ChannelAddr, ) -> Result<(ChannelAddr, ChannelRx), ChannelError> { let caller = Location::caller(); + serve_inner(addr).map(|(addr, inner)| { + tracing::debug!( + name = "serve", + %addr, + %caller, + ); + (addr, ChannelRx { inner }) + }) +} + +fn serve_inner( + addr: ChannelAddr, +) -> Result<(ChannelAddr, ChannelRxKind), ChannelError> { match addr { ChannelAddr::Tcp(addr) => { let (addr, rx) = net::tcp::serve::(addr)?; @@ -887,15 +964,15 @@ pub fn serve( "invalid local addr: {}", a ))), + ChannelAddr::Alias { dial_to, bind_to } => { + let (bound_addr, rx) = serve_inner::(*bind_to)?; + let alias_addr = ChannelAddr::Alias { + dial_to, + bind_to: Box::new(bound_addr), + }; + Ok((alias_addr, rx)) + } } - .map(|(addr, inner)| { - tracing::debug!( - name = "serve", - %addr, - %caller, - ); - (addr, ChannelRx { inner }) - }) } /// Serve on the local address. The server is turned down @@ -1066,6 +1143,48 @@ mod tests { assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err()); } + #[test] + fn test_zmq_style_alias_channel_addr() { + // Test Alias format: dial_to_url@bind_to_url + // The format is: dial_to_url@bind_to_url where both are valid ZMQ URLs + // Note: Alias format is only supported for TCP addresses + + // Test Alias with tcp on both sides + let alias_addr = ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::]:8800").unwrap(); + match alias_addr { + ChannelAddr::Alias { dial_to, bind_to } => { + assert_eq!( + *dial_to, + ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap()) + ); + assert_eq!(*bind_to, ChannelAddr::Tcp("[::]:8800".parse().unwrap())); + } + _ => panic!("Expected Alias"), + } + + // Test error: alias with non-tcp dial_to (not supported) + assert!( + ChannelAddr::from_zmq_url("metatls://example.com:443@tcp://127.0.0.1:8080").is_err() + ); + + // Test error: alias with non-tcp bind_to (not supported) + assert!( + 
ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@metatls://example.com:443").is_err() + ); + + // Test error: invalid dial_to URL in Alias + assert!(ChannelAddr::from_zmq_url("invalid://scheme@tcp://127.0.0.1:8080").is_err()); + + // Test error: invalid bind_to URL in Alias + assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@invalid://scheme").is_err()); + + // Test error: missing port in dial_to + assert!(ChannelAddr::from_zmq_url("tcp://host@tcp://127.0.0.1:8080").is_err()); + + // Test error: missing port in bind_to + assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@tcp://example.com").is_err()); + } + #[tokio::test] async fn test_multiple_connections() { for addr in ChannelTransport::all().map(ChannelAddr::any) { From e4e8c38cb25aceea4414bc0ea52ddcd71d4e927e Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Tue, 16 Dec 2025 10:29:44 -0800 Subject: [PATCH 2/4] Add enum BindSpec and use it in enable_transport (#2140) Summary: As explained in D89190085, Lightning wants to explicitly set the frontend address both client and worker hosts. Since the client's frontend address is controlled through `enable_transport`, this diff adds a new enum BindSpec, and plumb it to `enable_transport`, so the user can pass an explicit address string to do that. 
Differential Revision: D89190087 --- hyperactor/src/channel.rs | 59 +++++++++++++ hyperactor_mesh/src/proc_mesh.rs | 20 ++++- monarch_hyperactor/src/actor.rs | 5 +- monarch_hyperactor/src/channel.rs | 83 +++++++++++++++++++ monarch_hyperactor/src/config.rs | 27 ++++-- .../monarch_hyperactor/channel.pyi | 26 ++++++ .../monarch_hyperactor/config.pyi | 6 +- python/monarch/_src/actor/actor_mesh.py | 36 ++++++-- python/tests/test_config.py | 37 +++++++-- 9 files changed, 266 insertions(+), 33 deletions(-) diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index fc8bc17d3..0925ceb2b 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -445,6 +445,65 @@ impl AttrValue for ChannelTransport { } } +/// Specifies how to bind a channel server. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)] +pub enum BindSpec { + /// Bind to any available address for the given transport. + Any(ChannelTransport), + + /// Bind to a specific channel address. + Addr(ChannelAddr), +} + +impl BindSpec { + /// Return an "any" address for this bind spec. 
+ pub fn any(&self) -> ChannelAddr { + match self { + BindSpec::Any(transport) => ChannelAddr::any(transport.clone()), + BindSpec::Addr(addr) => addr.clone(), + } + } +} + +impl From for BindSpec { + fn from(transport: ChannelTransport) -> Self { + BindSpec::Any(transport) + } +} + +impl From for BindSpec { + fn from(addr: ChannelAddr) -> Self { + BindSpec::Addr(addr) + } +} + +impl fmt::Display for BindSpec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Any(transport) => write!(f, "any({})", transport), + Self::Addr(addr) => write!(f, "addr({})", addr), + } + } +} + +impl AttrValue for BindSpec { + fn display(&self) -> String { + self.to_string() + } + + fn parse(s: &str) -> Result { + if let Some(inner) = s.strip_prefix("addr(").and_then(|s| s.strip_suffix(")")) { + let addr = ChannelAddr::from_str(inner)?; + Ok(BindSpec::Addr(addr)) + } else if let Some(inner) = s.strip_prefix("any(").and_then(|s| s.strip_suffix(")")) { + let transport = ChannelTransport::from_str(inner)?; + Ok(BindSpec::Any(transport)) + } else { + Err(anyhow::anyhow!("invalid bind spec: {}", s)) + } + } +} + /// The type of (TCP) hostnames. pub type Hostname = String; diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index 0135ac429..54801be5d 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -30,6 +30,7 @@ use hyperactor::WorldId; use hyperactor::actor::ActorStatus; use hyperactor::actor::remote::Remote; use hyperactor::channel; +use hyperactor::channel::BindSpec; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; use hyperactor::context; @@ -96,11 +97,24 @@ declare_attrs! 
{ env_name: Some("HYPERACTOR_MESH_DEFAULT_TRANSPORT".to_string()), py_name: Some("default_transport".to_string()), }) - pub attr DEFAULT_TRANSPORT: ChannelTransport = ChannelTransport::Unix; + pub attr DEFAULT_TRANSPORT: BindSpec = BindSpec::Any(ChannelTransport::Unix); } -/// Get the default transport type to use across the application. +/// Temporary: used to support the legacy allocator-based V1 bootstrap. Should +/// be removed once we fully migrate to simple bootstrap. +/// +/// Get the default transport to use across the application. Panic if BindSpec::Addr +/// is set as default transport. Since we expect BindSpec::Addr to be used only +/// with simple bootstrap, we should not see this panic in production. pub fn default_transport() -> ChannelTransport { + match default_bind_spec() { + BindSpec::Any(transport) => transport, + BindSpec::Addr(addr) => panic!("default_bind_spec() returned BindSpec::Addr({addr})"), + } +} + +/// Get the default bind spec to use across the application. 
+pub fn default_bind_spec() -> BindSpec { global::get_cloned(DEFAULT_TRANSPORT) } @@ -187,7 +201,7 @@ pub fn global_root_client() -> &'static Instance { )> = OnceLock::new(); &GLOBAL_INSTANCE.get_or_init(|| { let client_proc = Proc::direct_with_default( - ChannelAddr::any(default_transport()), + default_bind_spec().any(), "mesh_root_client_proc".into(), router::global().clone().boxed(), ) diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 62f920fff..f728d3b25 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -27,7 +27,6 @@ use hyperactor::RemoteSpawn; use hyperactor::actor::ActorError; use hyperactor::actor::ActorErrorKind; use hyperactor::actor::ActorStatus; -use hyperactor::channel::ChannelAddr; use hyperactor::mailbox::BoxableMailboxSender; use hyperactor::mailbox::MessageEnvelope; use hyperactor::mailbox::Undeliverable; @@ -40,7 +39,7 @@ use hyperactor_config::Attrs; use hyperactor_mesh::actor_mesh::CAST_ACTOR_MESH_ID; use hyperactor_mesh::comm::multicast::CAST_ORIGINATING_SENDER; use hyperactor_mesh::comm::multicast::CastInfo; -use hyperactor_mesh::proc_mesh::default_transport; +use hyperactor_mesh::proc_mesh::default_bind_spec; use hyperactor_mesh::reference::ActorMeshId; use hyperactor_mesh::router; use hyperactor_mesh::supervision::SupervisionFailureMessage; @@ -525,7 +524,7 @@ impl PythonActor { static ROOT_CLIENT_INSTANCE: OnceLock> = OnceLock::new(); let client_proc = Proc::direct_with_default( - ChannelAddr::any(default_transport()), + default_bind_spec().any(), "mesh_root_client_proc".into(), router::global().clone().boxed(), ) diff --git a/monarch_hyperactor/src/channel.rs b/monarch_hyperactor/src/channel.rs index e3bdebdfe..4eace31b2 100644 --- a/monarch_hyperactor/src/channel.rs +++ b/monarch_hyperactor/src/channel.rs @@ -8,16 +8,21 @@ use std::str::FromStr; +use hyperactor::channel::BindSpec; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; use 
hyperactor::channel::MetaTlsAddr; use hyperactor::channel::TcpMode; use hyperactor::channel::TlsMode; use pyo3::exceptions::PyRuntimeError; +use pyo3::exceptions::PyTypeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; /// Python binding for [`hyperactor::channel::ChannelTransport`] +/// +/// This enum represents the basic transport types that can be represented +/// as simple enum variants. For explicit addresses, use `PyBindSpec`. #[pyclass( name = "ChannelTransport", module = "monarch._rust_bindings.monarch_hyperactor.channel", @@ -62,6 +67,83 @@ impl TryFrom for PyChannelTransport { } } +/// Python binding for [`hyperactor::channel::BindSpec`] +#[pyclass( + name = "BindSpec", + module = "monarch._rust_bindings.monarch_hyperactor.channel" +)] +#[derive(Clone, Debug, PartialEq)] +pub struct PyBindSpec { + inner: BindSpec, +} + +#[pymethods] +impl PyBindSpec { + /// Create a new PyBindSpec from a ChannelTransport enum, a string representation, + /// or another PyBindSpec object. 
+ /// + /// Examples: + /// PyBindSpec(ChannelTransport.Unix) + /// PyBindSpec("tcp://127.0.0.1:8080") + /// PyBindSpec(PyBindSpec(ChannelTransport.Unix)) + #[new] + pub fn new(spec: &Bound<'_, PyAny>) -> PyResult { + // First try to extract as PyBindSpec (for when passing an existing spec) + if let Ok(bind_spec) = spec.extract::() { + return Ok(bind_spec); + } + + // Then try to extract as PyChannelTransport enum + if let Ok(py_transport) = spec.extract::() { + let transport: ChannelTransport = py_transport.into(); + return Ok(PyBindSpec { + inner: BindSpec::Any(transport), + }); + } + + // Then try to extract as a string and parse it as a ZMQ URL + if let Ok(spec_str) = spec.extract::() { + let addr = ChannelAddr::from_zmq_url(&spec_str).map_err(|e| { + PyValueError::new_err(format!( + "invalid ZMQ URL for address binding '{}': {}", + spec_str, e + )) + })?; + return Ok(PyBindSpec { + inner: BindSpec::Addr(addr), + }); + } + + Err(PyTypeError::new_err( + "expected ChannelTransport enum, BindSpec, or str", + )) + } + + fn __str__(&self) -> String { + self.inner.to_string() + } + + fn __repr__(&self) -> String { + format!("PyBindSpec({:?})", self.inner) + } + + fn __eq__(&self, other: &Self) -> bool { + self.inner == other.inner + } +} + +impl From for BindSpec { + fn from(spec: PyBindSpec) -> Self { + spec.inner + } +} + +impl From for PyBindSpec { + fn from(spec: BindSpec) -> Self { + PyBindSpec { inner: spec } + } +} + #[pyclass( name = "ChannelAddr", module = "monarch._rust_bindings.monarch_hyperactor.channel" @@ -149,6 +231,7 @@ impl From for ChannelTransport { #[pymodule] pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) } diff --git a/monarch_hyperactor/src/config.rs b/monarch_hyperactor/src/config.rs index 4136e80ee..84daf8282 100644 --- a/monarch_hyperactor/src/config.rs +++ b/monarch_hyperactor/src/config.rs @@ -21,7 
+21,7 @@ use std::fmt::Debug; use std::time::Duration; use hyperactor::Named; -use hyperactor::channel::ChannelTransport; +use hyperactor::channel::BindSpec; use hyperactor_config::AttrValue; use hyperactor_config::CONFIG; use hyperactor_config::ConfigAttr; @@ -30,12 +30,13 @@ use hyperactor_config::attrs::Attrs; use hyperactor_config::attrs::ErasedKey; use hyperactor_config::attrs::declare_attrs; use hyperactor_config::global::Source; +use pyo3::conversion::IntoPyObject; use pyo3::conversion::IntoPyObjectExt; use pyo3::exceptions::PyTypeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use crate::channel::PyChannelTransport; +use crate::channel::PyBindSpec; /// Python wrapper for Duration, using humantime format strings. /// @@ -289,9 +290,9 @@ inventory::collect!(PythonConfigTypeInfo); /// like `String` that are convertible directly to/from PyObjects, /// you can just use `declare_py_config_type!(String)`. For types /// that must first be converted to/from a rust python wrapper -/// (e.g., keys with type `ChannelTransport` must use `PyChannelTransport` +/// (e.g., keys with type `BindSpec` must use `PyBindSpec` /// as an intermediate step), the usage is -/// `declare_py_config_type!(PyChannelTransport as ChannelTransport)`. +/// `declare_py_config_type!(PyBindSpec as BindSpec)`. macro_rules! declare_py_config_type { ($($ty:ty),+ $(,)?) => { hyperactor::paste! { @@ -341,7 +342,7 @@ macro_rules! 
declare_py_config_type { }; } -declare_py_config_type!(PyChannelTransport as ChannelTransport); +declare_py_config_type!(PyBindSpec as BindSpec); declare_py_config_type!(PyDuration as Duration); declare_py_config_type!( i8, i16, i32, i64, u8, u16, u32, u64, usize, f32, f64, bool, String @@ -367,9 +368,19 @@ declare_py_config_type!( fn configure(py: Python<'_>, kwargs: Option>) -> PyResult<()> { kwargs .map(|kwargs| { - kwargs - .into_iter() - .try_for_each(|(key, val)| configure_kwarg(py, &key, val)) + kwargs.into_iter().try_for_each(|(key, val)| { + // Special handling for default_transport: convert ChannelTransport + // enum or string to PyBindSpec before processing + let val = if key == "default_transport" { + PyBindSpec::new(val.bind(py))? + .into_pyobject(py)? + .into_any() + .unbind() + } else { + val + }; + configure_kwarg(py, &key, val) + }) }) .transpose()?; Ok(()) diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi index fea7a1f5f..7b15fa40b 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi @@ -9,6 +9,11 @@ from enum import Enum class ChannelTransport(Enum): + """ + Enum representing basic transport types for channels. + For explicit address binding, use BindSpec instead. + """ + TcpWithLocalhost = "tcp(localhost)" TcpWithHostname = "tcp(hostname)" MetaTlsWithHostname = "metatls(hostname)" @@ -17,6 +22,27 @@ class ChannelTransport(Enum): Unix = "unix" # Sim # TODO add support +class BindSpec: + """ + Specify how to bind a channel server. + + Can be created from either a ChannelTransport enum for "any" binding, + or a ZMQ-style URL string for explicit address. + + Note: This class is for internal use only. Users should pass ChannelTransport + enum directly or use the ZMQ-style URL string for explicit address. + """ + + def __init__(self, spec: ChannelTransport | str) -> None: ... 
+ """ + Basic transport types supported by ChannelTransport should be used directly as enum values. + For explicit address binding, use a ZMQ-style URL string. e.g.: + - "tcp://127.0.0.1:8080" + """ + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + class ChannelAddr: @staticmethod def any(transport: ChannelTransport) -> str: diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi index 7ae4446a7..b384b5cc8 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi @@ -32,7 +32,7 @@ def reset_config_to_defaults() -> None: ... def configure( - default_transport: ChannelTransport = ..., + default_transport: ChannelTransport | str = ..., enable_log_forwarding: bool = ..., enable_file_capture: bool = ..., tail_log_lines: int = ..., @@ -50,7 +50,9 @@ def configure( plus any additional CONFIG-marked keys passed via **kwargs. Args: - default_transport: Default channel transport for communication + default_transport: Default channel transport for communication. 
Can be: + - A ChannelTransport enum value (e.g., ChannelTransport.Unix) + - An explicit address string in the ZMQ-style URL format (e.g., "tcp://127.0.0.1:8080") enable_log_forwarding: Whether to forward logs from actors enable_file_capture: Whether to capture file output tail_log_lines: Number of log lines to tail diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 6eeddff27..d366a2477 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -53,7 +53,7 @@ ) from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer -from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport +from monarch._rust_bindings.monarch_hyperactor.channel import BindSpec, ChannelTransport from monarch._rust_bindings.monarch_hyperactor.config import configure from monarch._rust_bindings.monarch_hyperactor.context import Instance as HyInstance from monarch._rust_bindings.monarch_hyperactor.mailbox import ( @@ -410,7 +410,7 @@ def context() -> Context: return c -_transport: Optional[ChannelTransport] = None +_transport: Optional[BindSpec] = None _transport_lock = threading.Lock() @@ -424,17 +424,33 @@ def enable_transport(transport: "ChannelTransport | str") -> None: Currently only one transport type may be enabled at one time. In the future we may allow multiple to be enabled. + Supported transport values: + - ChannelTransport enum: ChannelTransport.Unix, ChannelTransport.TcpWithHostname, etc. 
+ - string shortcuts for the ChannelTransport enum: + - "tcp": ChannelTransport.TcpWithHostname + - "ipc": ChannelTransport.Unix + - "metatls": ChannelTransport.MetaTlsWithIpV6 + - "metatls-hostname": ChannelTransport.MetaTlsWithHostname + - ZMQ-style URL format string for explicit address, e.g.: + - "tcp://127.0.0.1:8080" + For Meta usage, use metatls-hostname """ if isinstance(transport, str): - transport = { + # Handle string shortcuts for the ChannelTransport enum. + resolved = { "tcp": ChannelTransport.TcpWithHostname, "ipc": ChannelTransport.Unix, "metatls": ChannelTransport.MetaTlsWithIpV6, "metatls-hostname": ChannelTransport.MetaTlsWithHostname, }.get(transport) - if transport is None: - raise ValueError(f"unknown transport: {transport}") + if resolved is not None: + transport_config = BindSpec(resolved) + else: + transport_config = BindSpec(transport) + else: + # ChannelTransport enum + transport_config = BindSpec(transport) if _context.get(None) is not None: raise RuntimeError( @@ -445,14 +461,16 @@ def enable_transport(transport: "ChannelTransport | str") -> None: global _transport with _transport_lock: - if _transport is not None and _transport != transport: + if _transport is not None and _transport != transport_config: raise RuntimeError( f"Only one transport type may be enabled at one time. " f"Currently enabled transport type is `{_transport}`. " - f"Attempted to enable transport type `{transport}`." + f"Attempted to enable transport type `{transport_config}`." ) - _transport = transport - configure(default_transport=transport) + _transport = transport_config + # pyre-ignore[6]: BindSpec is accepted by configure. We just do not expose + # it in the method's signature since BindSpec is not a public type. 
+ configure(default_transport=transport_config) @dataclass diff --git a/python/tests/test_config.py b/python/tests/test_config.py index 45df1f89c..188b82374 100644 --- a/python/tests/test_config.py +++ b/python/tests/test_config.py @@ -10,7 +10,8 @@ import monarch import pytest -from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport + +from monarch._rust_bindings.monarch_hyperactor.channel import BindSpec, ChannelTransport from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError from monarch.actor import Actor, endpoint, this_proc from monarch.config import configured, get_global_config @@ -34,12 +35,12 @@ def test_get_set_transport() -> None: ChannelTransport.MetaTlsWithHostname, ): with configured(default_transport=transport) as config: - assert config["default_transport"] == transport + assert config["default_transport"] == BindSpec(transport) # Succeed even if we don't specify the transport, but does not change the # previous value. with configured() as config: - assert config["default_transport"] == ChannelTransport.Unix - with pytest.raises(TypeError): + assert config["default_transport"] == BindSpec(ChannelTransport.Unix) + with pytest.raises(ValueError): with configured(default_transport="unix"): # type: ignore pass with pytest.raises(TypeError): @@ -50,6 +51,22 @@ def test_get_set_transport() -> None: pass +def test_get_set_explicit_transport() -> None: + # Test explicit transport with a TCP address + with configured(default_transport="tcp://127.0.0.1:8080") as config: + assert config["default_transport"] == BindSpec("tcp://127.0.0.1:8080") + + # Test that invalid explicit transport strings raise an error + with pytest.raises(ValueError): + with configured(default_transport="invalid://scheme"): + pass + + # Test that random strings (not ZMQ URL format) raise an error + with pytest.raises(ValueError): + with configured(default_transport="random_string"): + pass + + def test_nonexistent_config_key() -> None: 
with pytest.raises(ValueError): with configured(does_not_exist=42): # type: ignore @@ -64,13 +81,15 @@ def test_get_set_multiple() -> None: assert config["enable_log_forwarding"] assert config["enable_file_capture"] assert config["tail_log_lines"] == 100 - assert config["default_transport"] == ChannelTransport.TcpWithLocalhost + assert config["default_transport"] == BindSpec( + ChannelTransport.TcpWithLocalhost + ) # Make sure the previous values are restored. config = get_global_config() assert not config["enable_log_forwarding"] assert not config["enable_file_capture"] assert config["tail_log_lines"] == 0 - assert config["default_transport"] == ChannelTransport.Unix + assert config["default_transport"] == BindSpec(ChannelTransport.Unix) # This test tries to allocate too much memory for the GitHub actions @@ -219,7 +238,9 @@ def test_duration_config_multiple() -> None: enable_log_forwarding=True, tail_log_lines=100, ) as config: - assert config["default_transport"] == ChannelTransport.TcpWithLocalhost + assert config["default_transport"] == BindSpec( + ChannelTransport.TcpWithLocalhost + ) assert config["host_spawn_ready_timeout"] == "10m" assert config["message_delivery_timeout"] == "5m" assert config["mesh_proc_spawn_max_idle"] == "2m" @@ -228,7 +249,7 @@ def test_duration_config_multiple() -> None: # Verify all values are restored config = get_global_config() - assert config["default_transport"] == ChannelTransport.Unix + assert config["default_transport"] == BindSpec(ChannelTransport.Unix) assert config["host_spawn_ready_timeout"] == "30s" assert config["message_delivery_timeout"] == "30s" assert config["mesh_proc_spawn_max_idle"] == "30s" From 0a4a86bc4a056f9e752db3d87ca82ce765b8bf2a Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Tue, 16 Dec 2025 10:29:44 -0800 Subject: [PATCH 3/4] make client this_proc() real (#2144) Summary: This builds on the previous change to implement a *real* in-process host and local proc mesh. 
Thus, nothing is fake anymore in the v1 path: this_proc() behaves the same everywhere, and is managed in precisely the same way. This also means that we can do things like hand out references to the local host (or proc), and have remote actors spawn (procs, actors) on it! It will also simplify integration: since the client proc is now a host, it has a single front-end address that multiplexes all its procs (including the local one). We keep around the old codepaths for the benefit of v0; this can be significantly simplified again once we drop v0 support. Differential Revision: D89196041 --- .../proptest-regressions/actor_mesh.txt | 7 ++ hyperactor_mesh/src/bootstrap.rs | 5 +- hyperactor_mesh/src/proc_mesh/mesh_agent.rs | 20 +++++ hyperactor_mesh/src/v1.rs | 6 +- hyperactor_mesh/src/v1/host_mesh.rs | 26 ++++++ hyperactor_mesh/src/v1/proc_mesh.rs | 17 +++- monarch_hyperactor/src/actor.rs | 21 ++++- monarch_hyperactor/src/context.rs | 11 +++ monarch_hyperactor/src/v1/host_mesh.rs | 88 +++++++++++++++++++ .../monarch_hyperactor/v1/host_mesh.pyi | 12 +++ python/monarch/_src/actor/actor_mesh.py | 83 +++++++++-------- python/monarch/_src/actor/host_mesh.py | 3 +- python/monarch/_src/actor/v1/host_mesh.py | 45 ++++++---- python/monarch/_src/actor/v1/proc_mesh.py | 31 +++---- python/tests/test_host_mesh.py | 16 ++-- python/tests/test_proc_mesh.py | 10 +-- 16 files changed, 306 insertions(+), 95 deletions(-) create mode 100644 hyperactor_mesh/proptest-regressions/actor_mesh.txt diff --git a/hyperactor_mesh/proptest-regressions/actor_mesh.txt b/hyperactor_mesh/proptest-regressions/actor_mesh.txt new file mode 100644 index 000000000..ed2fe0912 --- /dev/null +++ b/hyperactor_mesh/proptest-regressions/actor_mesh.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. 
+# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 7da0353e3138258986cf2598af0836656a1f7c3399a9ffa18ca93cf983b3e64c # shrinks to extent = Extent { inner: ExtentData { labels: ["d/0"], sizes: [1] } } diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index 43a5aad6a..accb9d7d0 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -266,7 +266,10 @@ async fn halt() -> R { unreachable!() } -/// Bootstrap a host in this process, returning a handle to the mesh agent: +/// Bootstrap a host in this process, returning a handle to the mesh agent. +/// +/// To obtain the local proc, use `GetLocalProc` on the returned host mesh agent, +/// then use `GetProc` on the returned proc mesh agent. /// /// - `addr`: the listening address of the host; this is used to bind the frontend address; /// - `command`: optional bootstrap command to spawn procs, otherwise [`BootstrapProcManager::current`]; diff --git a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs index 5a9d49a73..01f200769 100644 --- a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs +++ b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs @@ -795,6 +795,26 @@ impl Handler for ProcMeshAgent { } } +/// A handler to get a clone of the proc managed by this agent. +/// This is used to obtain the local proc from a host mesh. +#[derive(Debug, hyperactor::Handler, hyperactor::HandleClient)] +pub struct GetProc { + #[reply] + pub proc: PortHandle, +} + +#[async_trait] +impl Handler for ProcMeshAgent { + async fn handle( + &mut self, + _cx: &Context, + GetProc { proc }: GetProc, + ) -> anyhow::Result<()> { + proc.send(self.proc.clone())?; + Ok(()) + } +} + /// A mailbox sender that initially queues messages, and then relays them to /// an underlying sender once configured. 
#[derive(Clone)] diff --git a/hyperactor_mesh/src/v1.rs b/hyperactor_mesh/src/v1.rs index cf7b95821..f0475453c 100644 --- a/hyperactor_mesh/src/v1.rs +++ b/hyperactor_mesh/src/v1.rs @@ -28,6 +28,7 @@ pub use host_mesh::HostMeshRef; use hyperactor::ActorId; use hyperactor::ActorRef; use hyperactor::Named; +use hyperactor::ProcId; use hyperactor::host::HostError; use hyperactor::mailbox::MailboxSenderError; use hyperactor::reference; @@ -149,6 +150,9 @@ pub enum Error { #[error("error spawning controller actor for mesh {0}: {1}")] ControllerActorSpawnError(Name, anyhow::Error), + #[error("proc {0} must be direct-addressable")] + RankedProc(ProcId), + #[error("error: {0} does not exist")] NotExist(Name), @@ -253,7 +257,7 @@ impl Name { } /// Create a Reserved `Name` with no uuid. Only for use by system actors. - pub(crate) fn new_reserved(name: impl Into) -> Result { + pub fn new_reserved(name: impl Into) -> Result { Ok(Self::new_with_uuid(name, None)?) } diff --git a/hyperactor_mesh/src/v1/host_mesh.rs b/hyperactor_mesh/src/v1/host_mesh.rs index 9e7bbb785..32f615f36 100644 --- a/hyperactor_mesh/src/v1/host_mesh.rs +++ b/hyperactor_mesh/src/v1/host_mesh.rs @@ -147,6 +147,18 @@ impl HostRef { } } +impl TryFrom> for HostRef { + type Error = v1::Error; + + fn try_from(value: ActorRef) -> Result { + let proc_id = value.actor_id().proc_id(); + match proc_id.as_direct() { + Some((addr, _)) => Ok(HostRef(addr.clone())), + None => Err(v1::Error::RankedProc(proc_id.clone())), + } + } +} + impl std::fmt::Display for HostRef { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) @@ -734,6 +746,20 @@ impl HostMeshRef { } } + /// Create a new HostMeshRef from an arbitrary set of host mesh agents. 
+ pub fn from_host_agents(name: Name, agents: Vec>) -> v1::Result { + Ok(Self { + name, + region: extent!(hosts = agents.len()).into(), + ranks: Arc::new( + agents + .into_iter() + .map(HostRef::try_from) + .collect::>()?, + ), + }) + } + /// Spawn a ProcMesh onto this host mesh. The per_host extent specifies the shape /// of the procs to spawn on each host. /// diff --git a/hyperactor_mesh/src/v1/proc_mesh.rs b/hyperactor_mesh/src/v1/proc_mesh.rs index a9c5b3619..c51a186bb 100644 --- a/hyperactor_mesh/src/v1/proc_mesh.rs +++ b/hyperactor_mesh/src/v1/proc_mesh.rs @@ -106,7 +106,8 @@ pub struct ProcRef { } impl ProcRef { - pub(crate) fn new(proc_id: ProcId, create_rank: usize, agent: ActorRef) -> Self { + /// Create a new proc ref from the provided id, create rank and agent. + pub fn new(proc_id: ProcId, create_rank: usize, agent: ActorRef) -> Self { Self { proc_id, create_rank, @@ -713,6 +714,20 @@ impl ProcMeshRef { }) } + /// Create a singleton ProcMeshRef, given the provided ProcRef and name. + /// This is used to support creating local singleton proc meshes to support `this_proc()` + /// in python client actors. + pub fn new_singleton(name: Name, proc_ref: ProcRef) -> Self { + Self { + name, + region: Extent::unity().into(), + ranks: Arc::new(vec![proc_ref]), + host_mesh: None, + root_region: None, + root_comm_actor: None, + } + } + pub(crate) fn root_comm_actor(&self) -> Option<&ActorRef> { self.root_comm_actor.as_ref() } diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index f728d3b25..b5d50aedd 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -520,6 +520,8 @@ impl PythonActor { }) } + /// Bootstrap the root client actor, creating a new proc for it. + /// This is the legacy entry point that creates its own proc. 
pub(crate) fn bootstrap_client(py: Python<'_>) -> (&'static Instance, ActorHandle) { static ROOT_CLIENT_INSTANCE: OnceLock> = OnceLock::new(); @@ -530,8 +532,21 @@ impl PythonActor { ) .unwrap(); + Self::bootstrap_client_inner(py, client_proc, &ROOT_CLIENT_INSTANCE) + } + + /// Bootstrap the client proc, storing the root client instance in given static. + /// This is passed in because we require storage, as the instance is shared. + /// This can be simplified when we remove v0. + pub(crate) fn bootstrap_client_inner( + py: Python<'_>, + client_proc: Proc, + root_client_instance: &'static OnceLock>, + ) -> (&'static Instance, ActorHandle) { // Make this proc reachable through the global router, so that we can use the // same client in both direct-addressed and ranked-addressed modes. + // + // DEPRECATE after v0 removal router::global().bind(client_proc.proc_id().clone().into(), client_proc.clone()); let actor_mesh_mod = py @@ -557,7 +572,7 @@ impl PythonActor { ) .expect("root instance create"); - ROOT_CLIENT_INSTANCE + root_client_instance .set(client) .map_err(|_| "already initialized root client instance") .unwrap(); @@ -577,7 +592,7 @@ impl PythonActor { ) .expect("initialize root client"); - let instance = ROOT_CLIENT_INSTANCE.get().unwrap(); + let instance = root_client_instance.get().unwrap(); get_tokio_runtime().spawn(async move { let mut signal_rx = signal_rx; @@ -626,7 +641,7 @@ impl PythonActor { instance.proc().handle_supervision_event(event); }); - (ROOT_CLIENT_INSTANCE.get().unwrap(), handle) + (root_client_instance.get().unwrap(), handle) } } diff --git a/monarch_hyperactor/src/context.rs b/monarch_hyperactor/src/context.rs index 38e2067f7..162935c78 100644 --- a/monarch_hyperactor/src/context.rs +++ b/monarch_hyperactor/src/context.rs @@ -126,6 +126,17 @@ impl PyContext { rank: Extent::unity().point_of_rank(0).unwrap(), }) } + + /// Create a context from an existing instance. 
+ /// This is used when the root client was bootstrapped via bootstrap_host() + /// instead of the default bootstrap_client(). + #[staticmethod] + fn _from_instance(py: Python<'_>, instance: PyInstance) -> PyResult { + Ok(PyContext { + instance: instance.into_pyobject(py)?.into(), + rank: Extent::unity().point_of_rank(0).unwrap(), + }) + } } impl PyContext { diff --git a/monarch_hyperactor/src/v1/host_mesh.rs b/monarch_hyperactor/src/v1/host_mesh.rs index 58e1666ba..bf5a4bcad 100644 --- a/monarch_hyperactor/src/v1/host_mesh.rs +++ b/monarch_hyperactor/src/v1/host_mesh.rs @@ -9,11 +9,20 @@ use std::collections::HashMap; use std::ops::Deref; use std::path::PathBuf; +use std::sync::OnceLock; +use hyperactor::Instance; +use hyperactor::Proc; use hyperactor_mesh::bootstrap::BootstrapCommand; +use hyperactor_mesh::bootstrap::host; +use hyperactor_mesh::proc_mesh::default_transport; +use hyperactor_mesh::proc_mesh::mesh_agent::GetProcClient; use hyperactor_mesh::shared_cell::SharedCell; +use hyperactor_mesh::v1::ProcMeshRef; use hyperactor_mesh::v1::host_mesh::HostMesh; use hyperactor_mesh::v1::host_mesh::HostMeshRef; +use hyperactor_mesh::v1::host_mesh::mesh_agent::GetLocalProcClient; +use hyperactor_mesh::v1::proc_mesh::ProcRef; use ndslice::View; use ndslice::view::RankedSliceable; use pyo3::IntoPyObjectExt; @@ -24,6 +33,7 @@ use pyo3::prelude::*; use pyo3::types::PyBytes; use pyo3::types::PyType; +use crate::actor::PythonActor; use crate::actor::to_py_error; use crate::alloc::PyAlloc; use crate::context::PyInstance; @@ -240,6 +250,76 @@ impl PyHostMeshRefImpl { } } +/// Static storage for the root client instance when using host-based bootstrap. +static ROOT_CLIENT_INSTANCE_FOR_HOST: OnceLock> = OnceLock::new(); + +/// Bootstrap the client host and root client actor. +/// +/// This creates a proper Host with BootstrapProcManager, spawns the root client +/// actor on the Host's local_proc. 
+/// +/// Returns a tuple of (HostMesh, ProcMesh, PyInstance) where: +/// - PyHostMesh: the bootstrapped (local) host mesh; and +/// - PyProcMesh: the local ProcMesh on this HostMesh; and +/// - PyInstance: the root client actor instance, on the ProcMesh. +/// +/// The HostMesh is served on the default transport. +/// +/// This should be called only once, at process initialization +#[pyfunction] +fn bootstrap_host(bootstrap_cmd: Option) -> PyResult { + let bootstrap_cmd = match bootstrap_cmd { + Some(cmd) => cmd.to_rust(), + None => BootstrapCommand::current().map_err(|e| PyException::new_err(e.to_string()))?, + }; + + PyPythonTask::new(async move { + let host_mesh_agent = host(default_transport().any(), Some(bootstrap_cmd), None) + .await + .map_err(|e| PyException::new_err(e.to_string()))?; + + let host_mesh_name = hyperactor_mesh::v1::Name::new_reserved("local").unwrap(); + let host_mesh = HostMeshRef::from_host_agents(host_mesh_name, vec![host_mesh_agent.bind()]) + .map_err(|e| PyException::new_err(e.to_string()))?; + + // We require a temporary instance to make a call to the host/proc agent. 
+ let temp_proc = Proc::local(); + let (temp_instance, _) = temp_proc + .instance("temp") + .map_err(|e| PyException::new_err(e.to_string()))?; + + let local_proc_agent = host_mesh_agent + .get_local_proc(&temp_instance) + .await + .map_err(|e| PyException::new_err(e.to_string()))?; + + let proc_mesh_name = hyperactor_mesh::v1::Name::new_reserved("local").unwrap(); + let proc_mesh = ProcMeshRef::new_singleton( + proc_mesh_name, + ProcRef::new( + local_proc_agent.actor_id().proc_id().clone(), + 0, + local_proc_agent.bind(), + ), + ); + + let local_proc = local_proc_agent + .get_proc(&temp_instance) + .await + .map_err(|e| PyException::new_err(e.to_string()))?; + + let (instance, _handle) = Python::with_gil(|py| { + PythonActor::bootstrap_client_inner(py, local_proc, &ROOT_CLIENT_INSTANCE_FOR_HOST) + }); + + Ok(( + PyHostMesh::new_ref(host_mesh), + PyProcMesh::new_ref(proc_mesh), + PyInstance::from(instance), + )) + }) +} + #[pyfunction] fn py_host_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult { let r: PyResult = bincode::deserialize(bytes.as_bytes()) @@ -254,6 +334,14 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul "monarch._rust_bindings.monarch_hyperactor.v1.host_mesh", )?; hyperactor_mod.add_function(f)?; + + let f2 = wrap_pyfunction!(bootstrap_host, hyperactor_mod)?; + f2.setattr( + "__module__", + "monarch._rust_bindings.monarch_hyperactor.v1.host_mesh", + )?; + hyperactor_mod.add_function(f2)?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi index adf78bef9..4b52bc866 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi @@ -101,3 +101,15 @@ class BootstrapCommand: ... def __repr__(self) -> str: ... 
+ +def bootstrap_host( + bootstrap_cmd: BootstrapCommand | None, +) -> PythonTask[(HostMesh, ProcMesh, Instance)]: + """ + Bootstrap a host mesh in this process, returning the host mesh, + proc mesh, and client instance. + + Arguments: + - `bootstrap_cmd`: The bootstrap command to use to bootstrap the host. + """ + ... diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index d366a2477..4e594b7f6 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -264,6 +264,9 @@ def message_rank(self) -> Point: @staticmethod def _root_client_context() -> "Context": ... + @staticmethod + def _from_instance(instance: Instance) -> "Context": ... + _context: contextvars.ContextVar[Context] = contextvars.ContextVar( "monarch.actor_mesh._context" @@ -307,9 +310,12 @@ def _patched_addHandler(self: logging.Logger, hdlr: logging.Handler) -> None: logging.Logger.addHandler = _patched_addHandler -def _set_context(c: Context) -> None: +def _set_context(c: Context) -> contextvars.Token[Context]: _init_context_log_handler() - _context.set(c) + return _context.set(c) + +def _reset_context(c: contextvars.Token[Context]): + _context.reset(c) T = TypeVar("T") @@ -331,15 +337,34 @@ def try_get(self) -> Optional[T]: return self._val -def _init_this_host_for_fake_in_process_host() -> "HostMesh": - from monarch._src.actor.host_mesh import create_local_host_mesh +def _init_client_context() -> Context: + """ + Create a client context that bootstraps an actor instance running on a real + local proc mesh on a real local host mesh. 
+ """ + from monarch._rust_bindings.monarch_hyperactor.v1.host_mesh import bootstrap_host + from monarch._src.actor.host_mesh import HostMesh + from monarch._src.actor.proc_mesh import ProcMesh + from monarch._src.actor.v1.host_mesh import _bootstrap_cmd + + rust_host_mesh, rust_proc_mesh, py_instance = bootstrap_host( + _bootstrap_cmd() + ).block_on() + + ctx = Context._from_instance(py_instance) + # Set the context here to avoid recursive context creation: + token = _set_context(ctx) + try: + py_host_mesh = HostMesh._from_rust(rust_host_mesh) + py_proc_mesh = ProcMesh._from_rust(rust_proc_mesh, py_host_mesh) + finally: + _reset_context(token) - return create_local_host_mesh() + ctx.actor_instance.proc_mesh = py_proc_mesh + return ctx -_this_host_for_fake_in_process_host: _Lazy["HostMesh"] = _Lazy( - _init_this_host_for_fake_in_process_host -) +_client_context: _Lazy[Context] = _Lazy(_init_client_context) def shutdown_context() -> "Future[None]": @@ -355,9 +380,10 @@ def shutdown_context() -> "Future[None]": """ from monarch._src.actor.future import Future - local_host = _this_host_for_fake_in_process_host.try_get() - if local_host is not None: - return local_host.shutdown() + client_host_ctx = _client_context.try_get() + if client_host_ctx is not None: + host_mesh = client_host_ctx.actor_instance.proc_mesh.host_mesh + return host_mesh.shutdown() # Nothing to shutdown - return a completed future async def noop() -> None: @@ -366,47 +392,26 @@ async def noop() -> None: return Future(coro=noop()) -def _init_root_proc_mesh() -> "ProcMesh": - from monarch._src.actor.host_mesh import fake_in_process_host - - return fake_in_process_host()._spawn_nonblocking( - name="root_client_proc_mesh", - per_host=Extent([], []), - setup=None, - _attach_controller_controller=False, # can't attach the controller controller because it doesn't exist yet - ) - - -_root_proc_mesh: _Lazy["ProcMesh"] = _Lazy(_init_root_proc_mesh) - - def context() -> Context: c = _context.get(None) if c is 
None: - c = Context._root_client_context() - _set_context(c) - - from monarch._src.actor.host_mesh import create_local_host_mesh from monarch._src.actor.proc_mesh import _get_controller_controller from monarch._src.actor.v1 import enabled as v1_enabled if not v1_enabled: + from monarch._src.actor.host_mesh import create_local_host_mesh + + c = Context._root_client_context() + _set_context(c) c.actor_instance.proc_mesh, c.actor_instance._controller_controller = ( _get_controller_controller() ) c.actor_instance.proc_mesh._host_mesh = create_local_host_mesh() # type: ignore else: - c.actor_instance.proc_mesh = _root_proc_mesh.get() - - # This needs to be initialized when the root client context is initialized. - # Otherwise, it will be initialized inside an actor endpoint running inside - # a fake in-process host. That will fail with an "unroutable mesh" error, - # because the hyperactor Proc being used to spawn the local host mesh - # won't have the correct type of forwarder. - _this_host_for_fake_in_process_host.get() - - c.actor_instance._controller_controller = _get_controller_controller()[1] + c = _client_context.get() + _set_context(c) + _, c.actor_instance._controller_controller = _get_controller_controller() return c diff --git a/python/monarch/_src/actor/host_mesh.py b/python/monarch/_src/actor/host_mesh.py index b8918bfa5..76a55ea01 100644 --- a/python/monarch/_src/actor/host_mesh.py +++ b/python/monarch/_src/actor/host_mesh.py @@ -26,7 +26,6 @@ from monarch._src.actor.v1.host_mesh import ( _bootstrap_cmd, # noqa: F401 create_local_host_mesh as create_local_host_mesh_v1, - fake_in_process_host as fake_in_process_host_v1, host_mesh_from_alloc as host_mesh_from_alloc_v1, HostMesh as HostMeshV1, hosts_from_config as hosts_from_config_v1, @@ -215,7 +214,7 @@ def host_mesh_from_alloc_v0( this_host = this_host_v1 this_proc = this_proc_v1 create_local_host_mesh = create_local_host_mesh_v1 - fake_in_process_host = fake_in_process_host_v1 + fake_in_process_host = 
this_host_v1 HostMesh = HostMeshV1 hosts_from_config = hosts_from_config_v1 host_mesh_from_alloc = host_mesh_from_alloc_v1 diff --git a/python/monarch/_src/actor/v1/host_mesh.py b/python/monarch/_src/actor/v1/host_mesh.py index dec78736f..cdeb6a00c 100644 --- a/python/monarch/_src/actor/v1/host_mesh.py +++ b/python/monarch/_src/actor/v1/host_mesh.py @@ -266,6 +266,36 @@ async def task() -> HyHostMesh: None, ) + @classmethod + def _from_rust(cls, hy_host_mesh: HyHostMesh) -> "HostMesh": + """ + Create a HostMesh from a Rust HyHostMesh. + + This is used when the host was bootstrapped via bootstrap_host() + instead of being allocated through an allocator. + """ + return cls._from_initialized_hy_host_mesh( + hy_host_mesh, + hy_host_mesh.region, + stream_logs=False, + is_fake_in_process=False, + ) + + def _local_proc_mesh(self) -> "ProcMesh": + """ + Returns the local singleton proc mesh for this host. + + This is the proc mesh that contains the root client actor. + It's a singleton proc mesh (no dimensions) spawned on the host's local_proc. + """ + # Create a singleton proc mesh on this host + return self._spawn_nonblocking( + name="local_proc", + per_host=Extent([], []), + setup=None, + _attach_controller_controller=False, + ) + def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]: return HostMesh._from_initialized_hy_host_mesh, ( self._initialized_mesh(), @@ -349,21 +379,6 @@ async def task() -> Literal[True]: return Future(coro=task()) -def fake_in_process_host() -> "HostMesh": - """ - Create a host mesh for testing and development using a local allocator. - - Returns: - HostMesh: A host mesh configured with local allocation for in-process use. - """ - return HostMesh.allocate_nonblocking( - "fake_host", - Extent([], []), - LocalAllocator(), - bootstrap_cmd=_bootstrap_cmd(), - ) - - def hosts_from_config(name: str) -> HostMesh: """ Get the host mesh 'name' from the monarch configuration for the project. 
diff --git a/python/monarch/_src/actor/v1/proc_mesh.py b/python/monarch/_src/actor/v1/proc_mesh.py index 6d0e572a9..eca189cbc 100644 --- a/python/monarch/_src/actor/v1/proc_mesh.py +++ b/python/monarch/_src/actor/v1/proc_mesh.py @@ -41,14 +41,7 @@ from monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh import ( ProcMesh as HyProcMesh, ) -from monarch._src.actor.actor_mesh import ( - _Actor, - _Lazy, - _this_host_for_fake_in_process_host, - Actor, - ActorMesh, - context, -) +from monarch._src.actor.actor_mesh import _Actor, _Lazy, Actor, ActorMesh, context from monarch._src.actor.allocator import AllocHandle, SimAllocator from monarch._src.actor.code_sync import ( CodeSyncMeshClient, @@ -154,15 +147,7 @@ def host_mesh(self) -> "HostMesh": raise NotImplementedError( "`ProcMesh.host_mesh` is not yet supported for non-singleton proc meshes." ) - elif self._host_mesh.is_fake_in_process: - host_mesh = _this_host_for_fake_in_process_host.try_get() - assert host_mesh is not None, ( - "Attempted to get `_this_host_for_fake_in_process_host` before the root client context " - "initialized it. This should not be possible." - ) - return host_mesh - else: - return self._host(0) + return self._host(0) @property def _ndslice(self) -> Slice: @@ -435,6 +420,18 @@ async def __aexit__( if not self._stopped: await self.stop() + @classmethod + def _from_rust(cls, hy_proc_mesh: HyProcMesh, host_mesh: "HostMesh") -> "ProcMesh": + """ + Create a ProcMesh from a Rust HyProcMesh and its parent HostMesh. 
+ """ + return cls._from_initialized_hy_proc_mesh( + hy_proc_mesh, + host_mesh, + hy_proc_mesh.region, + hy_proc_mesh.region, + ) + @classmethod def _from_initialized_hy_proc_mesh( cls, diff --git a/python/tests/test_host_mesh.py b/python/tests/test_host_mesh.py index 5bb32201b..b91ab72fc 100644 --- a/python/tests/test_host_mesh.py +++ b/python/tests/test_host_mesh.py @@ -15,11 +15,7 @@ import monarch._src.actor.host_mesh import pytest from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape, Slice -from monarch._src.actor.actor_mesh import ( - _this_host_for_fake_in_process_host, - Actor, - context, -) +from monarch._src.actor.actor_mesh import _client_context, Actor, context from monarch._src.actor.endpoint import endpoint from monarch._src.actor.host_mesh import ( create_local_host_mesh, @@ -209,13 +205,11 @@ def test_this_host_on_controllers_can_spawn_actual_os_processes() -> None: @pytest.mark.timeout(60) def test_root_client_does_not_leak_host_meshes() -> None: - orig_get_in_process_host = _this_host_for_fake_in_process_host.get - with patch.object( - _this_host_for_fake_in_process_host, "get" - ) as mock_get_in_process_host, patch.object( + orig_get_client_context = _client_context.get + with patch.object(_client_context, "get") as mock_get_client_context, patch.object( monarch._src.actor.host_mesh, "create_local_host_mesh" ) as mock_create_local: - mock_get_in_process_host.side_effect = orig_get_in_process_host + mock_get_client_context.side_effect = orig_get_client_context def sync_sleep_then_context(): time.sleep(0.1) @@ -230,7 +224,7 @@ def sync_sleep_then_context(): for t in threads: t.join() - assert mock_get_in_process_host.call_count == 100 + assert mock_get_client_context.call_count == 100 # If this test is run in isolation, the local host mesh will # be created once. 
But if it runs with other tests, the host mesh # will have already been initialized and the function never gets diff --git a/python/tests/test_proc_mesh.py b/python/tests/test_proc_mesh.py index 03a0aaf2f..d5d95cb6e 100644 --- a/python/tests/test_proc_mesh.py +++ b/python/tests/test_proc_mesh.py @@ -20,7 +20,7 @@ from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape, Slice from monarch._src.actor.actor_mesh import ( - _root_proc_mesh, + _client_context, Actor, ActorMesh, context, @@ -297,11 +297,11 @@ def test_context_proc_mesh_in_controller_spawns_actor_in_client_os_process() -> @pytest.mark.timeout(60) def test_root_client_does_not_leak_proc_meshes() -> None: - orig_get_root_proc_mesh = _root_proc_mesh.get - with patch.object(_root_proc_mesh, "get") as mock_get_root_proc_mesh, patch.object( + orig_get_client_context = _client_context.get + with patch.object(_client_context, "get") as mock_get_client_context, patch.object( monarch._src.actor.host_mesh, "fake_in_process_host" ) as mock_fake_in_process_host: - mock_get_root_proc_mesh.side_effect = orig_get_root_proc_mesh + mock_get_client_context.side_effect = orig_get_client_context def sync_sleep_then_context(): time.sleep(0.1) @@ -316,7 +316,7 @@ def sync_sleep_then_context(): for t in threads: t.join() - assert mock_get_root_proc_mesh.call_count == 100 + assert mock_get_client_context.call_count == 100 # If this test is run in isolation, the local host mesh will # be created once. 
But if it runs with other tests, the host mesh # will have already been initialized and the function never gets From 37be47e196179823b9965636610690aaa3c2a043 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Tue, 16 Dec 2025 10:29:44 -0800 Subject: [PATCH 4/4] tmp: add a local job for testing (#2155) Summary: Pull Request resolved: https://github.com/meta-pytorch/monarch/pull/2155 Differential Revision: D89195085 --- monarch_hyperactor/src/v1/host_mesh.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monarch_hyperactor/src/v1/host_mesh.rs b/monarch_hyperactor/src/v1/host_mesh.rs index bf5a4bcad..09cd56be0 100644 --- a/monarch_hyperactor/src/v1/host_mesh.rs +++ b/monarch_hyperactor/src/v1/host_mesh.rs @@ -15,7 +15,7 @@ use hyperactor::Instance; use hyperactor::Proc; use hyperactor_mesh::bootstrap::BootstrapCommand; use hyperactor_mesh::bootstrap::host; -use hyperactor_mesh::proc_mesh::default_transport; +use hyperactor_mesh::proc_mesh::default_bind_spec; use hyperactor_mesh::proc_mesh::mesh_agent::GetProcClient; use hyperactor_mesh::shared_cell::SharedCell; use hyperactor_mesh::v1::ProcMeshRef; @@ -274,7 +274,7 @@ fn bootstrap_host(bootstrap_cmd: Option) -> PyResult