diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index 76d2095e1..0925ceb2b 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -445,6 +445,65 @@ impl AttrValue for ChannelTransport { } } +/// Specifies how to bind a channel server. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Named)] +pub enum BindSpec { + /// Bind to any available address for the given transport. + Any(ChannelTransport), + + /// Bind to a specific channel address. + Addr(ChannelAddr), +} + +impl BindSpec { + /// Return an "any" address for this bind spec. + pub fn any(&self) -> ChannelAddr { + match self { + BindSpec::Any(transport) => ChannelAddr::any(transport.clone()), + BindSpec::Addr(addr) => addr.clone(), + } + } +} + +impl From for BindSpec { + fn from(transport: ChannelTransport) -> Self { + BindSpec::Any(transport) + } +} + +impl From for BindSpec { + fn from(addr: ChannelAddr) -> Self { + BindSpec::Addr(addr) + } +} + +impl fmt::Display for BindSpec { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Any(transport) => write!(f, "any({})", transport), + Self::Addr(addr) => write!(f, "addr({})", addr), + } + } +} + +impl AttrValue for BindSpec { + fn display(&self) -> String { + self.to_string() + } + + fn parse(s: &str) -> Result { + if let Some(inner) = s.strip_prefix("addr(").and_then(|s| s.strip_suffix(")")) { + let addr = ChannelAddr::from_str(inner)?; + Ok(BindSpec::Addr(addr)) + } else if let Some(inner) = s.strip_prefix("any(").and_then(|s| s.strip_suffix(")")) { + let transport = ChannelTransport::from_str(inner)?; + Ok(BindSpec::Any(transport)) + } else { + Err(anyhow::anyhow!("invalid bind spec: {}", s)) + } + } +} + /// The type of (TCP) hostnames. pub type Hostname = String; @@ -505,6 +564,27 @@ pub enum ChannelAddr { /// A unix domain socket address. 
Supports both absolute path names as /// well as "abstract" names per https://manpages.debian.org/unstable/manpages/unix.7.en.html#Abstract_sockets Unix(net::unix::SocketAddr), + + /// A pair of addresses, one for the client and one for the server: + /// - The client should dial to the `dial_to` address. + /// - The server should bind to the `bind_to` address. + /// + /// The user is responsible for ensuring the traffic to the `dial_to` address + /// is routed to the `bind_to` address. + /// + /// This is useful for scenarios where the network is configured in a way, + /// that the bound address is not directly accessible from the client. + /// + /// For example, in AWS, the client could be provided with the public IP + /// address, yet the server is bound to a private IP address or simply + /// INADDR_ANY. Traffic to the public IP address is mapped to the private + /// IP address through network address translation (NAT). + Alias { + /// The address to which the client should dial to. + dial_to: Box, + /// The address to which the server should bind to. + bind_to: Box, + }, } impl From for ChannelAddr { @@ -602,6 +682,9 @@ impl ChannelAddr { Self::Local(_) => ChannelTransport::Local, Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())), Self::Unix(_) => ChannelTransport::Unix, + // bind_to's transport is what is actually used in communication. + // Therefore we use its transport to represent the Alias. + Self::Alias { bind_to, .. 
} => bind_to.transport(), } } } @@ -614,6 +697,9 @@ impl fmt::Display for ChannelAddr { Self::Local(index) => write!(f, "local:{}", index), Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr), Self::Unix(addr) => write!(f, "unix:{}", addr), + Self::Alias { dial_to, bind_to } => { + write!(f, "alias:dial_to={};bind_to={}", dial_to, bind_to) + } } } } @@ -634,6 +720,11 @@ impl FromStr for ChannelAddr { Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()), Some(("sim", rest)) => sim::parse(rest).map_err(|e| e.into()), Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)), + Some(("alias", _)) => Err(anyhow::anyhow!( + "detect possible alias address, but we currently do not support \ + parsing alias' string representation since we only want to \ + support parsing its zmq url format." + )), Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")), None => Err(anyhow::anyhow!("no channel type specified")), } @@ -647,7 +738,38 @@ impl ChannelAddr { /// - inproc://endpoint-name (equivalent to local) /// - ipc://path (equivalent to unix) /// - metatls://hostname:port or metatls://*:port + /// - Alias format: dial_to_url@bind_to_url (e.g., tcp://host:port@tcp://host:port) + /// Note: Alias format is currently only supported for TCP addresses pub fn from_zmq_url(address: &str) -> Result { + // Check for Alias format: dial_to_url@bind_to_url + // The @ character separates two valid ZMQ URLs + if let Some(at_pos) = address.find('@') { + let dial_to_str = &address[..at_pos]; + let bind_to_str = &address[at_pos + 1..]; + + // Validate that both addresses use TCP scheme + if !dial_to_str.starts_with("tcp://") { + return Err(anyhow::anyhow!( + "alias format is only supported for TCP addresses, got dial_to: {}", + dial_to_str + )); + } + if !bind_to_str.starts_with("tcp://") { + return Err(anyhow::anyhow!( + "alias format is only supported for TCP addresses, got bind_to: {}", + bind_to_str + )); + } + + let dial_to = 
Self::from_zmq_url(dial_to_str)?; + let bind_to = Self::from_zmq_url(bind_to_str)?; + + return Ok(Self::Alias { + dial_to: Box::new(dial_to), + bind_to: Box::new(bind_to), + }); + } + // Try ZMQ-style URL format first (scheme://...) let (scheme, address) = address.split_once("://").ok_or_else(|| { anyhow::anyhow!("address must be in url form scheme://endppoint {}", address) @@ -850,6 +972,7 @@ pub fn dial(addr: ChannelAddr) -> Result, Channel ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?), ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::(sim_addr)?), ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)), + ChannelAddr::Alias { dial_to, .. } => dial(*dial_to)?.inner, }; Ok(ChannelTx { inner }) } @@ -862,6 +985,19 @@ pub fn serve( addr: ChannelAddr, ) -> Result<(ChannelAddr, ChannelRx), ChannelError> { let caller = Location::caller(); + serve_inner(addr).map(|(addr, inner)| { + tracing::debug!( + name = "serve", + %addr, + %caller, + ); + (addr, ChannelRx { inner }) + }) +} + +fn serve_inner( + addr: ChannelAddr, +) -> Result<(ChannelAddr, ChannelRxKind), ChannelError> { match addr { ChannelAddr::Tcp(addr) => { let (addr, rx) = net::tcp::serve::(addr)?; @@ -887,15 +1023,15 @@ pub fn serve( "invalid local addr: {}", a ))), + ChannelAddr::Alias { dial_to, bind_to } => { + let (bound_addr, rx) = serve_inner::(*bind_to)?; + let alias_addr = ChannelAddr::Alias { + dial_to, + bind_to: Box::new(bound_addr), + }; + Ok((alias_addr, rx)) + } } - .map(|(addr, inner)| { - tracing::debug!( - name = "serve", - %addr, - %caller, - ); - (addr, ChannelRx { inner }) - }) } /// Serve on the local address. 
The server is turned down @@ -1066,6 +1202,48 @@ mod tests { assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err()); } + #[test] + fn test_zmq_style_alias_channel_addr() { + // Test Alias format: dial_to_url@bind_to_url + // The format is: dial_to_url@bind_to_url where both are valid ZMQ URLs + // Note: Alias format is only supported for TCP addresses + + // Test Alias with tcp on both sides + let alias_addr = ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::]:8800").unwrap(); + match alias_addr { + ChannelAddr::Alias { dial_to, bind_to } => { + assert_eq!( + *dial_to, + ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap()) + ); + assert_eq!(*bind_to, ChannelAddr::Tcp("[::]:8800".parse().unwrap())); + } + _ => panic!("Expected Alias"), + } + + // Test error: alias with non-tcp dial_to (not supported) + assert!( + ChannelAddr::from_zmq_url("metatls://example.com:443@tcp://127.0.0.1:8080").is_err() + ); + + // Test error: alias with non-tcp bind_to (not supported) + assert!( + ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@metatls://example.com:443").is_err() + ); + + // Test error: invalid dial_to URL in Alias + assert!(ChannelAddr::from_zmq_url("invalid://scheme@tcp://127.0.0.1:8080").is_err()); + + // Test error: invalid bind_to URL in Alias + assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@invalid://scheme").is_err()); + + // Test error: missing port in dial_to + assert!(ChannelAddr::from_zmq_url("tcp://host@tcp://127.0.0.1:8080").is_err()); + + // Test error: missing port in bind_to + assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@tcp://example.com").is_err()); + } + #[tokio::test] async fn test_multiple_connections() { for addr in ChannelTransport::all().map(ChannelAddr::any) { diff --git a/hyperactor_mesh/proptest-regressions/actor_mesh.txt b/hyperactor_mesh/proptest-regressions/actor_mesh.txt new file mode 100644 index 000000000..ed2fe0912 --- /dev/null +++ b/hyperactor_mesh/proptest-regressions/actor_mesh.txt @@ -0,0 
+1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 7da0353e3138258986cf2598af0836656a1f7c3399a9ffa18ca93cf983b3e64c # shrinks to extent = Extent { inner: ExtentData { labels: ["d/0"], sizes: [1] } } diff --git a/hyperactor_mesh/src/bootstrap.rs b/hyperactor_mesh/src/bootstrap.rs index 43a5aad6a..accb9d7d0 100644 --- a/hyperactor_mesh/src/bootstrap.rs +++ b/hyperactor_mesh/src/bootstrap.rs @@ -266,7 +266,10 @@ async fn halt() -> R { unreachable!() } -/// Bootstrap a host in this process, returning a handle to the mesh agent: +/// Bootstrap a host in this process, returning a handle to the mesh agent. +/// +/// To obtain the local proc, use `GetLocalProc` on the returned host mesh agent, +/// then use `GetProc` on the returned proc mesh agent. /// /// - `addr`: the listening address of the host; this is used to bind the frontend address; /// - `command`: optional bootstrap command to spawn procs, otherwise [`BootstrapProcManager::current`]; diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index 0135ac429..54801be5d 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -30,6 +30,7 @@ use hyperactor::WorldId; use hyperactor::actor::ActorStatus; use hyperactor::actor::remote::Remote; use hyperactor::channel; +use hyperactor::channel::BindSpec; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; use hyperactor::context; @@ -96,11 +97,24 @@ declare_attrs! 
{ env_name: Some("HYPERACTOR_MESH_DEFAULT_TRANSPORT".to_string()), py_name: Some("default_transport".to_string()), }) - pub attr DEFAULT_TRANSPORT: ChannelTransport = ChannelTransport::Unix; + pub attr DEFAULT_TRANSPORT: BindSpec = BindSpec::Any(ChannelTransport::Unix); } -/// Get the default transport type to use across the application. +/// Temporary: used to support the legacy allocator-based V1 bootstrap. Should +/// be removed once we fully migrate to simple bootstrap. +/// +/// Get the default transport to use across the application. Panic if BindSpec::Addr +/// is set as default transport. Since we expect BindSpec::Addr to be used only +/// with simple bootstrap, we should not see this panic in production. pub fn default_transport() -> ChannelTransport { + match default_bind_spec() { + BindSpec::Any(transport) => transport, + BindSpec::Addr(addr) => panic!("default_bind_spec() returned BindSpec::Addr({addr})"), + } +} + +/// Get the default bind spec to use across the application. +pub fn default_bind_spec() -> BindSpec { global::get_cloned(DEFAULT_TRANSPORT) } @@ -187,7 +201,7 @@ pub fn global_root_client() -> &'static Instance { )> = OnceLock::new(); &GLOBAL_INSTANCE.get_or_init(|| { let client_proc = Proc::direct_with_default( - ChannelAddr::any(default_transport()), + default_bind_spec().any(), "mesh_root_client_proc".into(), router::global().clone().boxed(), ) diff --git a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs index 5a9d49a73..01f200769 100644 --- a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs +++ b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs @@ -795,6 +795,26 @@ impl Handler for ProcMeshAgent { } } +/// A handler to get a clone of the proc managed by this agent. +/// This is used to obtain the local proc from a host mesh. 
+#[derive(Debug, hyperactor::Handler, hyperactor::HandleClient)] +pub struct GetProc { + #[reply] + pub proc: PortHandle, +} + +#[async_trait] +impl Handler for ProcMeshAgent { + async fn handle( + &mut self, + _cx: &Context, + GetProc { proc }: GetProc, + ) -> anyhow::Result<()> { + proc.send(self.proc.clone())?; + Ok(()) + } +} + /// A mailbox sender that initially queues messages, and then relays them to /// an underlying sender once configured. #[derive(Clone)] diff --git a/hyperactor_mesh/src/v1.rs b/hyperactor_mesh/src/v1.rs index cf7b95821..f0475453c 100644 --- a/hyperactor_mesh/src/v1.rs +++ b/hyperactor_mesh/src/v1.rs @@ -28,6 +28,7 @@ pub use host_mesh::HostMeshRef; use hyperactor::ActorId; use hyperactor::ActorRef; use hyperactor::Named; +use hyperactor::ProcId; use hyperactor::host::HostError; use hyperactor::mailbox::MailboxSenderError; use hyperactor::reference; @@ -149,6 +150,9 @@ pub enum Error { #[error("error spawning controller actor for mesh {0}: {1}")] ControllerActorSpawnError(Name, anyhow::Error), + #[error("proc {0} must be direct-addressable")] + RankedProc(ProcId), + #[error("error: {0} does not exist")] NotExist(Name), @@ -253,7 +257,7 @@ impl Name { } /// Create a Reserved `Name` with no uuid. Only for use by system actors. - pub(crate) fn new_reserved(name: impl Into) -> Result { + pub fn new_reserved(name: impl Into) -> Result { Ok(Self::new_with_uuid(name, None)?) 
} diff --git a/hyperactor_mesh/src/v1/host_mesh.rs b/hyperactor_mesh/src/v1/host_mesh.rs index 9e7bbb785..32f615f36 100644 --- a/hyperactor_mesh/src/v1/host_mesh.rs +++ b/hyperactor_mesh/src/v1/host_mesh.rs @@ -147,6 +147,18 @@ impl HostRef { } } +impl TryFrom> for HostRef { + type Error = v1::Error; + + fn try_from(value: ActorRef) -> Result { + let proc_id = value.actor_id().proc_id(); + match proc_id.as_direct() { + Some((addr, _)) => Ok(HostRef(addr.clone())), + None => Err(v1::Error::RankedProc(proc_id.clone())), + } + } +} + impl std::fmt::Display for HostRef { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) @@ -734,6 +746,20 @@ impl HostMeshRef { } } + /// Create a new HostMeshRef from an arbitrary set of host mesh agents. + pub fn from_host_agents(name: Name, agents: Vec>) -> v1::Result { + Ok(Self { + name, + region: extent!(hosts = agents.len()).into(), + ranks: Arc::new( + agents + .into_iter() + .map(HostRef::try_from) + .collect::>()?, + ), + }) + } + /// Spawn a ProcMesh onto this host mesh. The per_host extent specifies the shape /// of the procs to spawn on each host. /// diff --git a/hyperactor_mesh/src/v1/proc_mesh.rs b/hyperactor_mesh/src/v1/proc_mesh.rs index a9c5b3619..c51a186bb 100644 --- a/hyperactor_mesh/src/v1/proc_mesh.rs +++ b/hyperactor_mesh/src/v1/proc_mesh.rs @@ -106,7 +106,8 @@ pub struct ProcRef { } impl ProcRef { - pub(crate) fn new(proc_id: ProcId, create_rank: usize, agent: ActorRef) -> Self { + /// Create a new proc ref from the provided id, create rank and agent. + pub fn new(proc_id: ProcId, create_rank: usize, agent: ActorRef) -> Self { Self { proc_id, create_rank, @@ -713,6 +714,20 @@ impl ProcMeshRef { }) } + /// Create a singleton ProcMeshRef, given the provided ProcRef and name. + /// This is used to support creating local singleton proc meshes to support `this_proc()` + /// in python client actors. 
+ pub fn new_singleton(name: Name, proc_ref: ProcRef) -> Self { + Self { + name, + region: Extent::unity().into(), + ranks: Arc::new(vec![proc_ref]), + host_mesh: None, + root_region: None, + root_comm_actor: None, + } + } + pub(crate) fn root_comm_actor(&self) -> Option<&ActorRef> { self.root_comm_actor.as_ref() } diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 62f920fff..b5d50aedd 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -27,7 +27,6 @@ use hyperactor::RemoteSpawn; use hyperactor::actor::ActorError; use hyperactor::actor::ActorErrorKind; use hyperactor::actor::ActorStatus; -use hyperactor::channel::ChannelAddr; use hyperactor::mailbox::BoxableMailboxSender; use hyperactor::mailbox::MessageEnvelope; use hyperactor::mailbox::Undeliverable; @@ -40,7 +39,7 @@ use hyperactor_config::Attrs; use hyperactor_mesh::actor_mesh::CAST_ACTOR_MESH_ID; use hyperactor_mesh::comm::multicast::CAST_ORIGINATING_SENDER; use hyperactor_mesh::comm::multicast::CastInfo; -use hyperactor_mesh::proc_mesh::default_transport; +use hyperactor_mesh::proc_mesh::default_bind_spec; use hyperactor_mesh::reference::ActorMeshId; use hyperactor_mesh::router; use hyperactor_mesh::supervision::SupervisionFailureMessage; @@ -521,18 +520,33 @@ impl PythonActor { }) } + /// Bootstrap the root client actor, creating a new proc for it. + /// This is the legacy entry point that creates its own proc. pub(crate) fn bootstrap_client(py: Python<'_>) -> (&'static Instance, ActorHandle) { static ROOT_CLIENT_INSTANCE: OnceLock> = OnceLock::new(); let client_proc = Proc::direct_with_default( - ChannelAddr::any(default_transport()), + default_bind_spec().any(), "mesh_root_client_proc".into(), router::global().clone().boxed(), ) .unwrap(); + Self::bootstrap_client_inner(py, client_proc, &ROOT_CLIENT_INSTANCE) + } + + /// Bootstrap the client proc, storing the root client instance in given static. 
+ /// This is passed in because we require storage, as the instance is shared. + /// This can be simplified when we remove v0. + pub(crate) fn bootstrap_client_inner( + py: Python<'_>, + client_proc: Proc, + root_client_instance: &'static OnceLock>, + ) -> (&'static Instance, ActorHandle) { // Make this proc reachable through the global router, so that we can use the // same client in both direct-addressed and ranked-addressed modes. + // + // DEPRECATE after v0 removal router::global().bind(client_proc.proc_id().clone().into(), client_proc.clone()); let actor_mesh_mod = py @@ -558,7 +572,7 @@ impl PythonActor { ) .expect("root instance create"); - ROOT_CLIENT_INSTANCE + root_client_instance .set(client) .map_err(|_| "already initialized root client instance") .unwrap(); @@ -578,7 +592,7 @@ impl PythonActor { ) .expect("initialize root client"); - let instance = ROOT_CLIENT_INSTANCE.get().unwrap(); + let instance = root_client_instance.get().unwrap(); get_tokio_runtime().spawn(async move { let mut signal_rx = signal_rx; @@ -627,7 +641,7 @@ impl PythonActor { instance.proc().handle_supervision_event(event); }); - (ROOT_CLIENT_INSTANCE.get().unwrap(), handle) + (root_client_instance.get().unwrap(), handle) } } diff --git a/monarch_hyperactor/src/channel.rs b/monarch_hyperactor/src/channel.rs index e3bdebdfe..4eace31b2 100644 --- a/monarch_hyperactor/src/channel.rs +++ b/monarch_hyperactor/src/channel.rs @@ -8,16 +8,21 @@ use std::str::FromStr; +use hyperactor::channel::BindSpec; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; use hyperactor::channel::MetaTlsAddr; use hyperactor::channel::TcpMode; use hyperactor::channel::TlsMode; use pyo3::exceptions::PyRuntimeError; +use pyo3::exceptions::PyTypeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; /// Python binding for [`hyperactor::channel::ChannelTransport`] +/// +/// This enum represents the basic transport types that can be represented +/// as simple enum variants. 
For explicit addresses, use `PyBindSpec`. #[pyclass( name = "ChannelTransport", module = "monarch._rust_bindings.monarch_hyperactor.channel", @@ -62,6 +67,83 @@ impl TryFrom for PyChannelTransport { } } +/// Python binding for [`hyperactor::channel::BindSpec`] +#[pyclass( + name = "BindSpec", + module = "monarch._rust_bindings.monarch_hyperactor.channel" +)] +#[derive(Clone, Debug, PartialEq)] +pub struct PyBindSpec { + inner: BindSpec, +} + +#[pymethods] +impl PyBindSpec { + /// Create a new PyBindSpec from a ChannelTransport enum, a string representation, + /// or another PyBindSpec object. + /// + /// Examples: + /// PyBindSpec(ChannelTransport.Unix) + /// PyBindSpec("tcp://127.0.0.1:8080") + /// PyBindSpec(PyBindSpec(ChannelTransport.Unix)) + #[new] + pub fn new(spec: &Bound<'_, PyAny>) -> PyResult { + // First try to extract as PyBindSpec (for when passing an existing spec) + if let Ok(bind_spec) = spec.extract::() { + return Ok(bind_spec); + } + + // Then try to extract as PyChannelTransport enum + if let Ok(py_transport) = spec.extract::() { + let transport: ChannelTransport = py_transport.into(); + return Ok(PyBindSpec { + inner: BindSpec::Any(transport), + }); + } + + // Then try to extract as a string and parse it as a ZMQ URL + if let Ok(spec_str) = spec.extract::() { + let addr = ChannelAddr::from_zmq_url(&spec_str).map_err(|e| { + PyValueError::new_err(format!( + "invalid ZMQ URL for address binding '{}': {}", + spec_str, e + )) + })?; + return Ok(PyBindSpec { + inner: BindSpec::Addr(addr), + }); + } + + Err(PyTypeError::new_err( + "expected ChannelTransport enum, BindSpec, or str", + )) + } + + fn __str__(&self) -> String { + self.inner.to_string() + } + + fn __repr__(&self) -> String { + format!("PyBindSpec({:?})", self.inner) + } + + fn __eq__(&self, other: &Self) -> bool { + self.inner == other.inner + } +} + +impl From for BindSpec { + fn from(spec: PyBindSpec) -> Self { + spec.inner + } +} + +impl From for PyBindSpec { + fn from(spec: BindSpec) 
-> Self { + PyBindSpec { inner: spec } + } +} + #[pyclass( name = "ChannelAddr", module = "monarch._rust_bindings.monarch_hyperactor.channel" @@ -149,6 +231,7 @@ impl From for ChannelTransport { #[pymodule] pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) } diff --git a/monarch_hyperactor/src/config.rs b/monarch_hyperactor/src/config.rs index 4136e80ee..84daf8282 100644 --- a/monarch_hyperactor/src/config.rs +++ b/monarch_hyperactor/src/config.rs @@ -21,7 +21,7 @@ use std::fmt::Debug; use std::time::Duration; use hyperactor::Named; -use hyperactor::channel::ChannelTransport; +use hyperactor::channel::BindSpec; use hyperactor_config::AttrValue; use hyperactor_config::CONFIG; use hyperactor_config::ConfigAttr; @@ -30,12 +30,13 @@ use hyperactor_config::attrs::Attrs; use hyperactor_config::attrs::ErasedKey; use hyperactor_config::attrs::declare_attrs; use hyperactor_config::global::Source; +use pyo3::conversion::IntoPyObject; use pyo3::conversion::IntoPyObjectExt; use pyo3::exceptions::PyTypeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use crate::channel::PyChannelTransport; +use crate::channel::PyBindSpec; /// Python wrapper for Duration, using humantime format strings. /// @@ -289,9 +290,9 @@ inventory::collect!(PythonConfigTypeInfo); /// like `String` that are convertible directly to/from PyObjects, /// you can just use `declare_py_config_type!(String)`. For types /// that must first be converted to/from a rust python wrapper -/// (e.g., keys with type `ChannelTransport` must use `PyChannelTransport` +/// (e.g., keys with type `BindSpec` must use `PyBindSpec` /// as an intermediate step), the usage is -/// `declare_py_config_type!(PyChannelTransport as ChannelTransport)`. +/// `declare_py_config_type!(PyBindSpec as BindSpec)`. macro_rules! declare_py_config_type { ($($ty:ty),+ $(,)?) 
=> { hyperactor::paste! { @@ -341,7 +342,7 @@ macro_rules! declare_py_config_type { }; } -declare_py_config_type!(PyChannelTransport as ChannelTransport); +declare_py_config_type!(PyBindSpec as BindSpec); declare_py_config_type!(PyDuration as Duration); declare_py_config_type!( i8, i16, i32, i64, u8, u16, u32, u64, usize, f32, f64, bool, String @@ -367,9 +368,19 @@ declare_py_config_type!( fn configure(py: Python<'_>, kwargs: Option>) -> PyResult<()> { kwargs .map(|kwargs| { - kwargs - .into_iter() - .try_for_each(|(key, val)| configure_kwarg(py, &key, val)) + kwargs.into_iter().try_for_each(|(key, val)| { + // Special handling for default_transport: convert ChannelTransport + // enum or string to PyBindSpec before processing + let val = if key == "default_transport" { + PyBindSpec::new(val.bind(py))? + .into_pyobject(py)? + .into_any() + .unbind() + } else { + val + }; + configure_kwarg(py, &key, val) + }) }) .transpose()?; Ok(()) diff --git a/monarch_hyperactor/src/context.rs b/monarch_hyperactor/src/context.rs index 38e2067f7..162935c78 100644 --- a/monarch_hyperactor/src/context.rs +++ b/monarch_hyperactor/src/context.rs @@ -126,6 +126,17 @@ impl PyContext { rank: Extent::unity().point_of_rank(0).unwrap(), }) } + + /// Create a context from an existing instance. + /// This is used when the root client was bootstrapped via bootstrap_host() + /// instead of the default bootstrap_client(). 
+ #[staticmethod] + fn _from_instance(py: Python<'_>, instance: PyInstance) -> PyResult { + Ok(PyContext { + instance: instance.into_pyobject(py)?.into(), + rank: Extent::unity().point_of_rank(0).unwrap(), + }) + } } impl PyContext { diff --git a/monarch_hyperactor/src/v1/host_mesh.rs b/monarch_hyperactor/src/v1/host_mesh.rs index 58e1666ba..09cd56be0 100644 --- a/monarch_hyperactor/src/v1/host_mesh.rs +++ b/monarch_hyperactor/src/v1/host_mesh.rs @@ -9,11 +9,20 @@ use std::collections::HashMap; use std::ops::Deref; use std::path::PathBuf; +use std::sync::OnceLock; +use hyperactor::Instance; +use hyperactor::Proc; use hyperactor_mesh::bootstrap::BootstrapCommand; +use hyperactor_mesh::bootstrap::host; +use hyperactor_mesh::proc_mesh::default_bind_spec; +use hyperactor_mesh::proc_mesh::mesh_agent::GetProcClient; use hyperactor_mesh::shared_cell::SharedCell; +use hyperactor_mesh::v1::ProcMeshRef; use hyperactor_mesh::v1::host_mesh::HostMesh; use hyperactor_mesh::v1::host_mesh::HostMeshRef; +use hyperactor_mesh::v1::host_mesh::mesh_agent::GetLocalProcClient; +use hyperactor_mesh::v1::proc_mesh::ProcRef; use ndslice::View; use ndslice::view::RankedSliceable; use pyo3::IntoPyObjectExt; @@ -24,6 +33,7 @@ use pyo3::prelude::*; use pyo3::types::PyBytes; use pyo3::types::PyType; +use crate::actor::PythonActor; use crate::actor::to_py_error; use crate::alloc::PyAlloc; use crate::context::PyInstance; @@ -240,6 +250,76 @@ impl PyHostMeshRefImpl { } } +/// Static storage for the root client instance when using host-based bootstrap. +static ROOT_CLIENT_INSTANCE_FOR_HOST: OnceLock> = OnceLock::new(); + +/// Bootstrap the client host and root client actor. +/// +/// This creates a proper Host with BootstrapProcManager, spawns the root client +/// actor on the Host's local_proc. 
+/// +/// Returns a tuple of (HostMesh, ProcMesh, PyInstance) where: +/// - PyHostMesh: the bootstrapped (local) host mesh; and +/// - PyProcMesh: the local ProcMesh on this HostMesh; and +/// - PyInstance: the root client actor instance, on the ProcMesh. +/// +/// The HostMesh is served on the default transport. +/// +/// This should be called only once, at process initialization +#[pyfunction] +fn bootstrap_host(bootstrap_cmd: Option) -> PyResult { + let bootstrap_cmd = match bootstrap_cmd { + Some(cmd) => cmd.to_rust(), + None => BootstrapCommand::current().map_err(|e| PyException::new_err(e.to_string()))?, + }; + + PyPythonTask::new(async move { + let host_mesh_agent = host(default_bind_spec().any(), Some(bootstrap_cmd), None) + .await + .map_err(|e| PyException::new_err(e.to_string()))?; + + let host_mesh_name = hyperactor_mesh::v1::Name::new_reserved("local").unwrap(); + let host_mesh = HostMeshRef::from_host_agents(host_mesh_name, vec![host_mesh_agent.bind()]) + .map_err(|e| PyException::new_err(e.to_string()))?; + + // We require a temporary instance to make a call to the host/proc agent. 
+ let temp_proc = Proc::local(); + let (temp_instance, _) = temp_proc + .instance("temp") + .map_err(|e| PyException::new_err(e.to_string()))?; + + let local_proc_agent = host_mesh_agent + .get_local_proc(&temp_instance) + .await + .map_err(|e| PyException::new_err(e.to_string()))?; + + let proc_mesh_name = hyperactor_mesh::v1::Name::new_reserved("local").unwrap(); + let proc_mesh = ProcMeshRef::new_singleton( + proc_mesh_name, + ProcRef::new( + local_proc_agent.actor_id().proc_id().clone(), + 0, + local_proc_agent.bind(), + ), + ); + + let local_proc = local_proc_agent + .get_proc(&temp_instance) + .await + .map_err(|e| PyException::new_err(e.to_string()))?; + + let (instance, _handle) = Python::with_gil(|py| { + PythonActor::bootstrap_client_inner(py, local_proc, &ROOT_CLIENT_INSTANCE_FOR_HOST) + }); + + Ok(( + PyHostMesh::new_ref(host_mesh), + PyProcMesh::new_ref(proc_mesh), + PyInstance::from(instance), + )) + }) +} + #[pyfunction] fn py_host_mesh_from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult { let r: PyResult = bincode::deserialize(bytes.as_bytes()) @@ -254,6 +334,14 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul "monarch._rust_bindings.monarch_hyperactor.v1.host_mesh", )?; hyperactor_mod.add_function(f)?; + + let f2 = wrap_pyfunction!(bootstrap_host, hyperactor_mod)?; + f2.setattr( + "__module__", + "monarch._rust_bindings.monarch_hyperactor.v1.host_mesh", + )?; + hyperactor_mod.add_function(f2)?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi index fea7a1f5f..7b15fa40b 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi @@ -9,6 +9,11 @@ from enum import Enum class ChannelTransport(Enum): + """ + Enum representing basic transport types for channels. 
+ For explicit address binding, use BindSpec instead. + """ + TcpWithLocalhost = "tcp(localhost)" TcpWithHostname = "tcp(hostname)" MetaTlsWithHostname = "metatls(hostname)" @@ -17,6 +22,27 @@ class ChannelTransport(Enum): Unix = "unix" # Sim # TODO add support +class BindSpec: + """ + Specify how to bind a channel server. + + Can be created from either a ChannelTransport enum for "any" binding, + or a ZMQ-style URL string for explicit address. + + Note: This class is for internal use only. Users should pass ChannelTransport + enum directly or use the ZMQ-style URL string for explicit address. + """ + + def __init__(self, spec: ChannelTransport | str) -> None: ... + """ + Basic transport types supported by ChannelTransport should be used directly as enum values. + For explicit address binding, use a ZMQ-style URL string. e.g.: + - "tcp://127.0.0.1:8080" + """ + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + class ChannelAddr: @staticmethod def any(transport: ChannelTransport) -> str: diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi index 7ae4446a7..b384b5cc8 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi @@ -32,7 +32,7 @@ def reset_config_to_defaults() -> None: ... def configure( - default_transport: ChannelTransport = ..., + default_transport: ChannelTransport | str = ..., enable_log_forwarding: bool = ..., enable_file_capture: bool = ..., tail_log_lines: int = ..., @@ -50,7 +50,9 @@ def configure( plus any additional CONFIG-marked keys passed via **kwargs. Args: - default_transport: Default channel transport for communication + default_transport: Default channel transport for communication. 
Can be: + - A ChannelTransport enum value (e.g., ChannelTransport.Unix) + - A explicit address string in the ZMQ-style URL format (e.g., "tcp://127.0.0.1:8080") enable_log_forwarding: Whether to forward logs from actors enable_file_capture: Whether to capture file output tail_log_lines: Number of log lines to tail diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi index adf78bef9..4b52bc866 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/v1/host_mesh.pyi @@ -101,3 +101,15 @@ class BootstrapCommand: ... def __repr__(self) -> str: ... + +def bootstrap_host( + bootstrap_cmd: BootstrapCommand | None, +) -> PythonTask[(HostMesh, ProcMesh, Instance)]: + """ + Bootstrap a host mesh in this process, returning the host mesh, + proc mesh, and client instance. + + Arguments: + - `bootstrap_cmd`: The bootstrap command to use to bootstrap the host. + """ + ... diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 6eeddff27..4e594b7f6 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -53,7 +53,7 @@ ) from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer -from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport +from monarch._rust_bindings.monarch_hyperactor.channel import BindSpec, ChannelTransport from monarch._rust_bindings.monarch_hyperactor.config import configure from monarch._rust_bindings.monarch_hyperactor.context import Instance as HyInstance from monarch._rust_bindings.monarch_hyperactor.mailbox import ( @@ -264,6 +264,9 @@ def message_rank(self) -> Point: @staticmethod def _root_client_context() -> "Context": ... 
+ @staticmethod + def _from_instance(instance: Instance) -> "Context": ... + _context: contextvars.ContextVar[Context] = contextvars.ContextVar( "monarch.actor_mesh._context" @@ -307,9 +310,12 @@ def _patched_addHandler(self: logging.Logger, hdlr: logging.Handler) -> None: logging.Logger.addHandler = _patched_addHandler -def _set_context(c: Context) -> None: +def _set_context(c: Context) -> contextvars.Token[Context]: _init_context_log_handler() - _context.set(c) + return _context.set(c) + +def _reset_context(c: contextvars.Token[Context]): + _context.reset(c) T = TypeVar("T") @@ -331,15 +337,34 @@ def try_get(self) -> Optional[T]: return self._val -def _init_this_host_for_fake_in_process_host() -> "HostMesh": - from monarch._src.actor.host_mesh import create_local_host_mesh +def _init_client_context() -> Context: + """ + Create a client context that bootstraps an actor instance running on a real + local proc mesh on a real local host mesh. + """ + from monarch._rust_bindings.monarch_hyperactor.v1.host_mesh import bootstrap_host + from monarch._src.actor.host_mesh import HostMesh + from monarch._src.actor.proc_mesh import ProcMesh + from monarch._src.actor.v1.host_mesh import _bootstrap_cmd - return create_local_host_mesh() + rust_host_mesh, rust_proc_mesh, py_instance = bootstrap_host( + _bootstrap_cmd() + ).block_on() + ctx = Context._from_instance(py_instance) + # Set the context here to avoid recursive context creation: + token = _set_context(ctx) + try: + py_host_mesh = HostMesh._from_rust(rust_host_mesh) + py_proc_mesh = ProcMesh._from_rust(rust_proc_mesh, py_host_mesh) + finally: + _reset_context(token) -_this_host_for_fake_in_process_host: _Lazy["HostMesh"] = _Lazy( - _init_this_host_for_fake_in_process_host -) + ctx.actor_instance.proc_mesh = py_proc_mesh + return ctx + + +_client_context: _Lazy[Context] = _Lazy(_init_client_context) def shutdown_context() -> "Future[None]": @@ -355,9 +380,10 @@ def shutdown_context() -> "Future[None]": """ from 
monarch._src.actor.future import Future - local_host = _this_host_for_fake_in_process_host.try_get() - if local_host is not None: - return local_host.shutdown() + client_host_ctx = _client_context.try_get() + if client_host_ctx is not None: + host_mesh = client_host_ctx.actor_instance.proc_mesh.host_mesh + return host_mesh.shutdown() # Nothing to shutdown - return a completed future async def noop() -> None: @@ -366,51 +392,30 @@ async def noop() -> None: return Future(coro=noop()) -def _init_root_proc_mesh() -> "ProcMesh": - from monarch._src.actor.host_mesh import fake_in_process_host - - return fake_in_process_host()._spawn_nonblocking( - name="root_client_proc_mesh", - per_host=Extent([], []), - setup=None, - _attach_controller_controller=False, # can't attach the controller controller because it doesn't exist yet - ) - - -_root_proc_mesh: _Lazy["ProcMesh"] = _Lazy(_init_root_proc_mesh) - - def context() -> Context: c = _context.get(None) if c is None: - c = Context._root_client_context() - _set_context(c) - - from monarch._src.actor.host_mesh import create_local_host_mesh from monarch._src.actor.proc_mesh import _get_controller_controller from monarch._src.actor.v1 import enabled as v1_enabled if not v1_enabled: + from monarch._src.actor.host_mesh import create_local_host_mesh + + c = Context._root_client_context() + _set_context(c) c.actor_instance.proc_mesh, c.actor_instance._controller_controller = ( _get_controller_controller() ) c.actor_instance.proc_mesh._host_mesh = create_local_host_mesh() # type: ignore else: - c.actor_instance.proc_mesh = _root_proc_mesh.get() - - # This needs to be initialized when the root client context is initialized. - # Otherwise, it will be initialized inside an actor endpoint running inside - # a fake in-process host. That will fail with an "unroutable mesh" error, - # because the hyperactor Proc being used to spawn the local host mesh - # won't have the correct type of forwarder. 
- _this_host_for_fake_in_process_host.get() - - c.actor_instance._controller_controller = _get_controller_controller()[1] + c = _client_context.get() + _set_context(c) + _, c.actor_instance._controller_controller = _get_controller_controller() return c -_transport: Optional[ChannelTransport] = None +_transport: Optional[BindSpec] = None _transport_lock = threading.Lock() @@ -424,17 +429,33 @@ def enable_transport(transport: "ChannelTransport | str") -> None: Currently only one transport type may be enabled at one time. In the future we may allow multiple to be enabled. + Supported transport values: + - ChannelTransport enum: ChannelTransport.Unix, ChannelTransport.TcpWithHostname, etc. + - string shortcuts for the ChannelTransport enum: + - "tcp": ChannelTransport.TcpWithHostname + - "ipc": ChannelTransport.Unix + - "metatls": ChannelTransport.MetaTlsWithIpV6 + - "metatls-hostname": ChannelTransport.MetaTlsWithHostname + - ZMQ-style URL format string for explicit address, e.g.: + - "tcp://127.0.0.1:8080" + For Meta usage, use metatls-hostname """ if isinstance(transport, str): - transport = { + # Handle string shortcuts for the ChannelTransport enum. + resolved = { "tcp": ChannelTransport.TcpWithHostname, "ipc": ChannelTransport.Unix, "metatls": ChannelTransport.MetaTlsWithIpV6, "metatls-hostname": ChannelTransport.MetaTlsWithHostname, }.get(transport) - if transport is None: - raise ValueError(f"unknown transport: {transport}") + if resolved is not None: + transport_config = BindSpec(resolved) + else: + transport_config = BindSpec(transport) + else: + # ChannelTransport enum + transport_config = BindSpec(transport) if _context.get(None) is not None: raise RuntimeError( @@ -445,14 +466,16 @@ def enable_transport(transport: "ChannelTransport | str") -> None: global _transport with _transport_lock: - if _transport is not None and _transport != transport: + if _transport is not None and _transport != transport_config: raise RuntimeError( f"Only one transport type may 
be enabled at one time. " f"Currently enabled transport type is `{_transport}`. " - f"Attempted to enable transport type `{transport}`." + f"Attempted to enable transport type `{transport_config}`." ) - _transport = transport - configure(default_transport=transport) + _transport = transport_config + # pyre-ignore[6]: BindSpec is accepted by configure. We just do not expose + # it in the method's signature since BindSpec is not a public type. + configure(default_transport=transport_config) @dataclass diff --git a/python/monarch/_src/actor/host_mesh.py b/python/monarch/_src/actor/host_mesh.py index b8918bfa5..76a55ea01 100644 --- a/python/monarch/_src/actor/host_mesh.py +++ b/python/monarch/_src/actor/host_mesh.py @@ -26,7 +26,6 @@ from monarch._src.actor.v1.host_mesh import ( _bootstrap_cmd, # noqa: F401 create_local_host_mesh as create_local_host_mesh_v1, - fake_in_process_host as fake_in_process_host_v1, host_mesh_from_alloc as host_mesh_from_alloc_v1, HostMesh as HostMeshV1, hosts_from_config as hosts_from_config_v1, @@ -215,7 +214,7 @@ def host_mesh_from_alloc_v0( this_host = this_host_v1 this_proc = this_proc_v1 create_local_host_mesh = create_local_host_mesh_v1 - fake_in_process_host = fake_in_process_host_v1 + fake_in_process_host = this_host_v1 HostMesh = HostMeshV1 hosts_from_config = hosts_from_config_v1 host_mesh_from_alloc = host_mesh_from_alloc_v1 diff --git a/python/monarch/_src/actor/v1/host_mesh.py b/python/monarch/_src/actor/v1/host_mesh.py index dec78736f..cdeb6a00c 100644 --- a/python/monarch/_src/actor/v1/host_mesh.py +++ b/python/monarch/_src/actor/v1/host_mesh.py @@ -266,6 +266,36 @@ async def task() -> HyHostMesh: None, ) + @classmethod + def _from_rust(cls, hy_host_mesh: HyHostMesh) -> "HostMesh": + """ + Create a HostMesh from a Rust HyHostMesh. + + This is used when the host was bootstrapped via bootstrap_host() + instead of being allocated through an allocator. 
+ """ + return cls._from_initialized_hy_host_mesh( + hy_host_mesh, + hy_host_mesh.region, + stream_logs=False, + is_fake_in_process=False, + ) + + def _local_proc_mesh(self) -> "ProcMesh": + """ + Returns the local singleton proc mesh for this host. + + This is the proc mesh that contains the root client actor. + It's a singleton proc mesh (no dimensions) spawned on the host's local_proc. + """ + # Create a singleton proc mesh on this host + return self._spawn_nonblocking( + name="local_proc", + per_host=Extent([], []), + setup=None, + _attach_controller_controller=False, + ) + def __reduce_ex__(self, protocol: ...) -> Tuple[Any, Tuple[Any, ...]]: return HostMesh._from_initialized_hy_host_mesh, ( self._initialized_mesh(), @@ -349,21 +379,6 @@ async def task() -> Literal[True]: return Future(coro=task()) -def fake_in_process_host() -> "HostMesh": - """ - Create a host mesh for testing and development using a local allocator. - - Returns: - HostMesh: A host mesh configured with local allocation for in-process use. - """ - return HostMesh.allocate_nonblocking( - "fake_host", - Extent([], []), - LocalAllocator(), - bootstrap_cmd=_bootstrap_cmd(), - ) - - def hosts_from_config(name: str) -> HostMesh: """ Get the host mesh 'name' from the monarch configuration for the project. 
diff --git a/python/monarch/_src/actor/v1/proc_mesh.py b/python/monarch/_src/actor/v1/proc_mesh.py index 6d0e572a9..eca189cbc 100644 --- a/python/monarch/_src/actor/v1/proc_mesh.py +++ b/python/monarch/_src/actor/v1/proc_mesh.py @@ -41,14 +41,7 @@ from monarch._rust_bindings.monarch_hyperactor.v1.proc_mesh import ( ProcMesh as HyProcMesh, ) -from monarch._src.actor.actor_mesh import ( - _Actor, - _Lazy, - _this_host_for_fake_in_process_host, - Actor, - ActorMesh, - context, -) +from monarch._src.actor.actor_mesh import _Actor, _Lazy, Actor, ActorMesh, context from monarch._src.actor.allocator import AllocHandle, SimAllocator from monarch._src.actor.code_sync import ( CodeSyncMeshClient, @@ -154,15 +147,7 @@ def host_mesh(self) -> "HostMesh": raise NotImplementedError( "`ProcMesh.host_mesh` is not yet supported for non-singleton proc meshes." ) - elif self._host_mesh.is_fake_in_process: - host_mesh = _this_host_for_fake_in_process_host.try_get() - assert host_mesh is not None, ( - "Attempted to get `_this_host_for_fake_in_process_host` before the root client context " - "initialized it. This should not be possible." - ) - return host_mesh - else: - return self._host(0) + return self._host(0) @property def _ndslice(self) -> Slice: @@ -435,6 +420,18 @@ async def __aexit__( if not self._stopped: await self.stop() + @classmethod + def _from_rust(cls, hy_proc_mesh: HyProcMesh, host_mesh: "HostMesh") -> "ProcMesh": + """ + Create a ProcMesh from a Rust HyProcMesh and its parent HostMesh. 
+ """ + return cls._from_initialized_hy_proc_mesh( + hy_proc_mesh, + host_mesh, + hy_proc_mesh.region, + hy_proc_mesh.region, + ) + @classmethod def _from_initialized_hy_proc_mesh( cls, diff --git a/python/tests/test_config.py b/python/tests/test_config.py index 45df1f89c..188b82374 100644 --- a/python/tests/test_config.py +++ b/python/tests/test_config.py @@ -10,7 +10,8 @@ import monarch import pytest -from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport + +from monarch._rust_bindings.monarch_hyperactor.channel import BindSpec, ChannelTransport from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError from monarch.actor import Actor, endpoint, this_proc from monarch.config import configured, get_global_config @@ -34,12 +35,12 @@ def test_get_set_transport() -> None: ChannelTransport.MetaTlsWithHostname, ): with configured(default_transport=transport) as config: - assert config["default_transport"] == transport + assert config["default_transport"] == BindSpec(transport) # Succeed even if we don't specify the transport, but does not change the # previous value. 
with configured() as config: - assert config["default_transport"] == ChannelTransport.Unix - with pytest.raises(TypeError): + assert config["default_transport"] == BindSpec(ChannelTransport.Unix) + with pytest.raises(ValueError): with configured(default_transport="unix"): # type: ignore pass with pytest.raises(TypeError): @@ -50,6 +51,22 @@ def test_get_set_transport() -> None: pass +def test_get_set_explicit_transport() -> None: + # Test explicit transport with a TCP address + with configured(default_transport="tcp://127.0.0.1:8080") as config: + assert config["default_transport"] == BindSpec("tcp://127.0.0.1:8080") + + # Test that invalid explicit transport strings raise an error + with pytest.raises(ValueError): + with configured(default_transport="invalid://scheme"): + pass + + # Test that random strings (not ZMQ URL format) raise an error + with pytest.raises(ValueError): + with configured(default_transport="random_string"): + pass + + def test_nonexistent_config_key() -> None: with pytest.raises(ValueError): with configured(does_not_exist=42): # type: ignore @@ -64,13 +81,15 @@ def test_get_set_multiple() -> None: assert config["enable_log_forwarding"] assert config["enable_file_capture"] assert config["tail_log_lines"] == 100 - assert config["default_transport"] == ChannelTransport.TcpWithLocalhost + assert config["default_transport"] == BindSpec( + ChannelTransport.TcpWithLocalhost + ) # Make sure the previous values are restored. 
config = get_global_config() assert not config["enable_log_forwarding"] assert not config["enable_file_capture"] assert config["tail_log_lines"] == 0 - assert config["default_transport"] == ChannelTransport.Unix + assert config["default_transport"] == BindSpec(ChannelTransport.Unix) # This test tries to allocate too much memory for the GitHub actions @@ -219,7 +238,9 @@ def test_duration_config_multiple() -> None: enable_log_forwarding=True, tail_log_lines=100, ) as config: - assert config["default_transport"] == ChannelTransport.TcpWithLocalhost + assert config["default_transport"] == BindSpec( + ChannelTransport.TcpWithLocalhost + ) assert config["host_spawn_ready_timeout"] == "10m" assert config["message_delivery_timeout"] == "5m" assert config["mesh_proc_spawn_max_idle"] == "2m" @@ -228,7 +249,7 @@ def test_duration_config_multiple() -> None: # Verify all values are restored config = get_global_config() - assert config["default_transport"] == ChannelTransport.Unix + assert config["default_transport"] == BindSpec(ChannelTransport.Unix) assert config["host_spawn_ready_timeout"] == "30s" assert config["message_delivery_timeout"] == "30s" assert config["mesh_proc_spawn_max_idle"] == "30s" diff --git a/python/tests/test_host_mesh.py b/python/tests/test_host_mesh.py index 5bb32201b..b91ab72fc 100644 --- a/python/tests/test_host_mesh.py +++ b/python/tests/test_host_mesh.py @@ -15,11 +15,7 @@ import monarch._src.actor.host_mesh import pytest from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape, Slice -from monarch._src.actor.actor_mesh import ( - _this_host_for_fake_in_process_host, - Actor, - context, -) +from monarch._src.actor.actor_mesh import _client_context, Actor, context from monarch._src.actor.endpoint import endpoint from monarch._src.actor.host_mesh import ( create_local_host_mesh, @@ -209,13 +205,11 @@ def test_this_host_on_controllers_can_spawn_actual_os_processes() -> None: @pytest.mark.timeout(60) def 
test_root_client_does_not_leak_host_meshes() -> None: - orig_get_in_process_host = _this_host_for_fake_in_process_host.get - with patch.object( - _this_host_for_fake_in_process_host, "get" - ) as mock_get_in_process_host, patch.object( + orig_get_client_context = _client_context.get + with patch.object(_client_context, "get") as mock_get_client_context, patch.object( monarch._src.actor.host_mesh, "create_local_host_mesh" ) as mock_create_local: - mock_get_in_process_host.side_effect = orig_get_in_process_host + mock_get_client_context.side_effect = orig_get_client_context def sync_sleep_then_context(): time.sleep(0.1) @@ -230,7 +224,7 @@ def sync_sleep_then_context(): for t in threads: t.join() - assert mock_get_in_process_host.call_count == 100 + assert mock_get_client_context.call_count == 100 # If this test is run in isolation, the local host mesh will # be created once. But if it runs with other tests, the host mesh # will have already been initialized and the function never gets diff --git a/python/tests/test_proc_mesh.py b/python/tests/test_proc_mesh.py index 03a0aaf2f..d5d95cb6e 100644 --- a/python/tests/test_proc_mesh.py +++ b/python/tests/test_proc_mesh.py @@ -20,7 +20,7 @@ from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask from monarch._rust_bindings.monarch_hyperactor.shape import Extent, Shape, Slice from monarch._src.actor.actor_mesh import ( - _root_proc_mesh, + _client_context, Actor, ActorMesh, context, @@ -297,11 +297,11 @@ def test_context_proc_mesh_in_controller_spawns_actor_in_client_os_process() -> @pytest.mark.timeout(60) def test_root_client_does_not_leak_proc_meshes() -> None: - orig_get_root_proc_mesh = _root_proc_mesh.get - with patch.object(_root_proc_mesh, "get") as mock_get_root_proc_mesh, patch.object( + orig_get_client_context = _client_context.get + with patch.object(_client_context, "get") as mock_get_client_context, patch.object( monarch._src.actor.host_mesh, "fake_in_process_host" ) as 
mock_fake_in_process_host: - mock_get_root_proc_mesh.side_effect = orig_get_root_proc_mesh + mock_get_client_context.side_effect = orig_get_client_context def sync_sleep_then_context(): time.sleep(0.1) @@ -316,7 +316,7 @@ def sync_sleep_then_context(): for t in threads: t.join() - assert mock_get_root_proc_mesh.call_count == 100 + assert mock_get_client_context.call_count == 100 # If this test is run in isolation, the local host mesh will # be created once. But if it runs with other tests, the host mesh # will have already been initialized and the function never gets