From fe790cc085b45f19c52801b5a182dfdd0142e024 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Mon, 15 Dec 2025 09:15:30 -0800 Subject: [PATCH 1/3] Add Channel::Alias Differential Revision: D89190085 --- hyperactor/src/channel.rs | 135 +++++++++++++++++++++++++++++++++++--- 1 file changed, 127 insertions(+), 8 deletions(-) diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index 94451d209..ea5604307 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -507,6 +507,27 @@ pub enum ChannelAddr { /// A unix domain socket address. Supports both absolute path names as /// well as "abstract" names per https://manpages.debian.org/unstable/manpages/unix.7.en.html#Abstract_sockets Unix(net::unix::SocketAddr), + + /// A pair of addresses, one for the client and one for the server: + /// - The client should dial to the `dial_to` address. + /// - The server should bind to the `bind_to` address. + /// + /// The user is responsible for ensuring the traffic to the `dial_to` address + /// is routed to the `bind_to` address. + /// + /// This is useful for scenarios where the network is confgiured in a way, + /// that the bound address is not directly accessible from the client. + /// + /// For example, in AWS, the client could be provided with the public IP + /// address, yet the server is bound to a private IP address or simply + /// INADDR_ANY. Traffic to the public IP address is mapped to the private + /// IP address through network address translation (NAT). + Alias { + /// The address to which the client should dial to. + dial_to: Box, + /// The address to which the server should bind to. + bind_to: Box, + }, } impl From for ChannelAddr { @@ -604,6 +625,9 @@ impl ChannelAddr { Self::Local(_) => ChannelTransport::Local, Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())), Self::Unix(_) => ChannelTransport::Unix, + // bind_to's transport is what is actually used in communcation. + // Therefore we use its transport to represent the Alias. + Self::Alias { bind_to, .. } => bind_to.transport(), } } } @@ -616,6 +640,9 @@ impl fmt::Display for ChannelAddr { Self::Local(index) => write!(f, "local:{}", index), Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr), Self::Unix(addr) => write!(f, "unix:{}", addr), + Self::Alias { dial_to, bind_to } => { + write!(f, "alias:dial_to={};bind_to={}", dial_to, bind_to) + } } } } @@ -636,6 +663,11 @@ impl FromStr for ChannelAddr { Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()), Some(("sim", rest)) => sim::parse(rest).map_err(|e| e.into()), Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)), + Some(("alias", _)) => Err(anyhow::anyhow!( + "detect possible alias address, but we currently do not support \ + parsing alias' string representation since we only want to \ + support parsing its zmq url format." + )), Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")), None => Err(anyhow::anyhow!("no channel type specified")), } @@ -649,7 +681,38 @@ impl ChannelAddr { /// - inproc://endpoint-name (equivalent to local) /// - ipc://path (equivalent to unix) /// - metatls://hostname:port or metatls://*:port + /// - Alias format: dial_to_url@bind_to_url (e.g., tcp://host:port@tcp://host:port) + /// Note: Alias format is currently only supported for TCP addresses pub fn from_zmq_url(address: &str) -> Result { + // Check for Alias format: dial_to_url@bind_to_url + // The @ character separates two valid ZMQ URLs + if let Some(at_pos) = address.find('@') { + let dial_to_str = &address[..at_pos]; + let bind_to_str = &address[at_pos + 1..]; + + // Validate that both addresses use TCP scheme + if !dial_to_str.starts_with("tcp://") { + return Err(anyhow::anyhow!( + "alias format is only supported for TCP addresses, got dial_to: {}", + dial_to_str + )); + } + if !bind_to_str.starts_with("tcp://") { + return Err(anyhow::anyhow!( + "alias format is only supported for TCP addresses, got bind_to: {}", + bind_to_str + )); + } + + let dial_to = Self::from_zmq_url(dial_to_str)?; + let bind_to = Self::from_zmq_url(bind_to_str)?; + + return Ok(Self::Alias { + dial_to: Box::new(dial_to), + bind_to: Box::new(bind_to), + }); + } + // Try ZMQ-style URL format first (scheme://...) let (scheme, address) = address.split_once("://").ok_or_else(|| { anyhow::anyhow!("address must be in url form scheme://endppoint {}", address) @@ -840,6 +903,7 @@ pub fn dial(addr: ChannelAddr) -> Result, Channel ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?), ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::(sim_addr)?), ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)), + ChannelAddr::Alias { dial_to, .. } => dial(*dial_to)?.inner, }; Ok(ChannelTx { inner }) } @@ -852,6 +916,19 @@ pub fn serve( addr: ChannelAddr, ) -> Result<(ChannelAddr, ChannelRx), ChannelError> { let caller = Location::caller(); + serve_inner(addr).map(|(addr, inner)| { + tracing::debug!( + name = "serve", + %addr, + %caller, + ); + (addr, ChannelRx { inner }) + }) +} + +fn serve_inner( + addr: ChannelAddr, +) -> Result<(ChannelAddr, ChannelRxKind), ChannelError> { match addr { ChannelAddr::Tcp(addr) => { let (addr, rx) = net::tcp::serve::(addr)?; @@ -877,15 +954,15 @@ pub fn serve( "invalid local addr: {}", a ))), + ChannelAddr::Alias { dial_to, bind_to } => { + let (bound_addr, rx) = serve_inner::(*bind_to)?; + let alias_addr = ChannelAddr::Alias { + dial_to, + bind_to: Box::new(bound_addr), + }; + Ok((alias_addr, rx)) + } } - .map(|(addr, inner)| { - tracing::debug!( - name = "serve", - %addr, - %caller, - ); - (addr, ChannelRx { inner }) - }) } /// Serve on the local address. The server is turned down @@ -1056,6 +1133,48 @@ mod tests { assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err()); } + #[test] + fn test_zmq_style_alias_channel_addr() { + // Test Alias format: dial_to_url@bind_to_url + // The format is: dial_to_url@bind_to_url where both are valid ZMQ URLs + // Note: Alias format is only supported for TCP addresses + + // Test Alias with tcp on both sides + let alias_addr = ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::]:8800").unwrap(); + match alias_addr { + ChannelAddr::Alias { dial_to, bind_to } => { + assert_eq!( + *dial_to, + ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap()) + ); + assert_eq!(*bind_to, ChannelAddr::Tcp("[::]:8800".parse().unwrap())); + } + _ => panic!("Expected Alias"), + } + + // Test error: alias with non-tcp dial_to (not supported) + assert!( + ChannelAddr::from_zmq_url("metatls://example.com:443@tcp://127.0.0.1:8080").is_err() + ); + + // Test error: alias with non-tcp bind_to (not supported) + assert!( + ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@metatls://example.com:443").is_err() + ); + + // Test error: invalid dial_to URL in Alias + assert!(ChannelAddr::from_zmq_url("invalid://scheme@tcp://127.0.0.1:8080").is_err()); + + // Test error: invalid bind_to URL in Alias + assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@invalid://scheme").is_err()); + + // Test error: missing port in dial_to + assert!(ChannelAddr::from_zmq_url("tcp://host@tcp://127.0.0.1:8080").is_err()); + + // Test error: missing port in bind_to + assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@tcp://example.com").is_err()); + } + #[tokio::test] async fn test_multiple_connections() { for addr in ChannelTransport::all().map(ChannelAddr::any) { From 5e4299d0e828dd6a693e0439b53c43fd5297ac4d Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Mon, 15 Dec 2025 09:15:30 -0800 Subject: [PATCH 2/3] Add ChannelTransport::Explicit Differential Revision: D89190087 --- hyperactor/src/channel.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index ea5604307..3584eb314 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -356,6 +356,9 @@ pub enum ChannelTransport { /// Transport over unix domain socket. Unix, + + /// Use a specific channel address for transport. + Explicit(ChannelAddr), } impl fmt::Display for ChannelTransport { @@ -366,6 +369,7 @@ impl fmt::Display for ChannelTransport { Self::Local => write!(f, "local"), Self::Sim(transport) => write!(f, "sim({})", transport), Self::Unix => write!(f, "unix"), + Self::Explicit(addr) => write!(f, "explicit({})", addr), } } } @@ -400,6 +404,11 @@ impl FromStr for ChannelTransport { let mode = inner.parse()?; Ok(ChannelTransport::MetaTls(mode)) } + s if s.starts_with("explicit(") && s.ends_with(")") => Err(anyhow::anyhow!( + "detect possible explicit transport, but we currently do not \ + support parsing explicit's string representation since we + only want to support the zmq_url format." + )), unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)), } } @@ -433,6 +442,7 @@ impl ChannelTransport { ChannelTransport::Local => false, ChannelTransport::Sim(_) => false, ChannelTransport::Unix => false, + ChannelTransport::Explicit(addr) => addr.transport().is_remote(), } } } @@ -598,6 +608,7 @@ impl ChannelAddr { ChannelTransport::Sim(transport) => sim::any(*transport), // This works because the file will be deleted but we know we have a unique file by this point. ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()), + ChannelTransport::Explicit(addr) => addr, } } From 5f6cceccb729378bf53f44b87fd7a6067597dab5 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Mon, 15 Dec 2025 09:15:30 -0800 Subject: [PATCH 3/3] Add ChannelTransportConfig wrapper Differential Revision: D89185959 --- monarch_hyperactor/src/channel.rs | 94 +++++++++++++++++++ monarch_hyperactor/src/config.rs | 26 +++-- .../monarch_hyperactor/channel.pyi | 24 +++++ .../monarch_hyperactor/config.pyi | 12 ++- python/monarch/_src/actor/actor_mesh.py | 37 ++++++-- python/tests/test_config.py | 45 +++++++-- 6 files changed, 211 insertions(+), 27 deletions(-) diff --git a/monarch_hyperactor/src/channel.rs b/monarch_hyperactor/src/channel.rs index e3bdebdfe..1fd6dd245 100644 --- a/monarch_hyperactor/src/channel.rs +++ b/monarch_hyperactor/src/channel.rs @@ -14,10 +14,15 @@ use hyperactor::channel::MetaTlsAddr; use hyperactor::channel::TcpMode; use hyperactor::channel::TlsMode; use pyo3::exceptions::PyRuntimeError; +use pyo3::exceptions::PyTypeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; /// Python binding for [`hyperactor::channel::ChannelTransport`] +/// +/// This enum represents the basic transport types that can be represented +/// as simple enum variants. For more complex transports like `Explicit`, +/// use string-based configuration via `PyChannelTransportConfig`. #[pyclass( name = "ChannelTransport", module = "monarch._rust_bindings.monarch_hyperactor.channel", @@ -62,6 +67,94 @@ impl TryFrom for PyChannelTransport { } } +/// A wrapper for ChannelTransport that can be created from either a +/// PyChannelTransport enum or a string (for explicit transport). +/// +/// We need this wrapper because Python's enum type does not support attaching +/// data to enum variants like Rust does. +#[pyclass( + name = "ChannelTransportConfig", + module = "monarch._rust_bindings.monarch_hyperactor.channel" +)] +#[derive(Clone, Debug, PartialEq)] +pub struct PyChannelTransportConfig { + inner: ChannelTransport, +} + +#[pymethods] +impl PyChannelTransportConfig { + /// Create a new PyChannelTransportConfig from either a ChannelTransport enum + /// or a string representation. + /// + /// Examples: + /// PyChannelTransportConfig(ChannelTransport.Unix) + /// PyChannelTransportConfig("explicit:tcp://127.0.0.1:8080") + #[new] + pub fn new(transport: &Bound<'_, PyAny>) -> PyResult { + // First try to extract as PyChannelTransportConfig (for when passing an existing config) + if let Ok(config) = transport.extract::() { + return Ok(config); + } + + // Then try to extract as PyChannelTransport enum + if let Ok(py_transport) = transport.extract::() { + return Ok(PyChannelTransportConfig { + inner: py_transport.into(), + }); + } + + // Then try to extract as a string and parse it + if let Ok(transport_str) = transport.extract::() { + if !transport_str.starts_with("explicit:") { + return Err(PyValueError::new_err(format!( + "string argument only supports explicit transport with \ + address in the zmq url format (e.g., 'explicit:tcp://127.0.0.1:8080'); \ + but got: {}", + transport_str, + ))); + } + let addr_str = transport_str.strip_prefix("explicit:").unwrap(); + let addr = ChannelAddr::from_zmq_url(addr_str).map_err(|e| { + PyValueError::new_err(format!( + "invalid address string used for explicit transport '{}': {}", + addr_str, e + )) + })?; + return Ok(PyChannelTransportConfig { + inner: ChannelTransport::Explicit(addr), + }); + } + + Err(PyTypeError::new_err( + "expected ChannelTransport enum, ChannelTransportConfig, or str", + )) + } + + fn __str__(&self) -> String { + self.inner.to_string() + } + + fn __repr__(&self) -> String { + format!("PyChannelTransportConfig({:?})", self.inner) + } + + fn __eq__(&self, other: &Self) -> bool { + self.inner == other.inner + } +} + +impl From for ChannelTransport { + fn from(config: PyChannelTransportConfig) -> Self { + config.inner + } +} + +impl From for PyChannelTransportConfig { + fn from(transport: ChannelTransport) -> Self { + PyChannelTransportConfig { inner: transport } + } +} + #[pyclass( name = "ChannelAddr", module = "monarch._rust_bindings.monarch_hyperactor.channel" @@ -149,6 +242,7 @@ impl From for ChannelTransport { #[pymodule] pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) } diff --git a/monarch_hyperactor/src/config.rs b/monarch_hyperactor/src/config.rs index 4136e80ee..32606429c 100644 --- a/monarch_hyperactor/src/config.rs +++ b/monarch_hyperactor/src/config.rs @@ -30,12 +30,13 @@ use hyperactor_config::attrs::Attrs; use hyperactor_config::attrs::ErasedKey; use hyperactor_config::attrs::declare_attrs; use hyperactor_config::global::Source; +use pyo3::conversion::IntoPyObject; use pyo3::conversion::IntoPyObjectExt; use pyo3::exceptions::PyTypeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; -use crate::channel::PyChannelTransport; +use crate::channel::PyChannelTransportConfig; /// Python wrapper for Duration, using humantime format strings. /// @@ -289,9 +290,9 @@ inventory::collect!(PythonConfigTypeInfo); /// like `String` that are convertible directly to/from PyObjects, /// you can just use `declare_py_config_type!(String)`. For types /// that must first be converted to/from a rust python wrapper -/// (e.g., keys with type `ChannelTransport` must use `PyChannelTransport` +/// (e.g., keys with type `ChannelTransport` must use `PyChannelTransportConfig` /// as an intermediate step), the usage is -/// `declare_py_config_type!(PyChannelTransport as ChannelTransport)`. +/// `declare_py_config_type!(PyChannelTransportConfig as ChannelTransport)`. macro_rules! declare_py_config_type { ($($ty:ty),+ $(,)?) => { hyperactor::paste! { @@ -341,7 +342,7 @@ macro_rules! declare_py_config_type { }; } -declare_py_config_type!(PyChannelTransport as ChannelTransport); +declare_py_config_type!(PyChannelTransportConfig as ChannelTransport); declare_py_config_type!(PyDuration as Duration); declare_py_config_type!( i8, i16, i32, i64, u8, u16, u32, u64, usize, f32, f64, bool, String @@ -367,9 +368,20 @@ declare_py_config_type!( fn configure(py: Python<'_>, kwargs: Option>) -> PyResult<()> { kwargs .map(|kwargs| { - kwargs - .into_iter() - .try_for_each(|(key, val)| configure_kwarg(py, &key, val)) + kwargs.into_iter().try_for_each(|(key, val)| { + // Special handling for default_transport: convert ChannelTransport + // enum or explicit transport string to PyChannelTransportConfig + // before processing + let val = if key == "default_transport" { + PyChannelTransportConfig::new(val.bind(py))? + .into_pyobject(py)? + .into_any() + .unbind() + } else { + val + }; + configure_kwarg(py, &key, val) + }) }) .transpose()?; Ok(()) diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi index fea7a1f5f..167bd380b 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi @@ -9,6 +9,11 @@ from enum import Enum class ChannelTransport(Enum): + """ + Enum representing basic transport types for channels. + For the explicit transport type, use ChannelTransportConfig instead. + """ + TcpWithLocalhost = "tcp(localhost)" TcpWithHostname = "tcp(hostname)" MetaTlsWithHostname = "metatls(hostname)" @@ -17,6 +22,25 @@ class ChannelTransport(Enum): Unix = "unix" # Sim # TODO add support +class ChannelTransportConfig: + """ + Internal wrapper for ChannelTransport that accepts either a ChannelTransport enum + or a string for complex transports. + + Note: This class is for internal use only. Users should pass ChannelTransport + enum values or strings directly to enable_transport(). + """ + + def __init__(self, transport: ChannelTransport | str) -> None: ... + """ + Basic transport types supported by ChannelTransport should be used directly as enum values. + For the explicit transport type, use ChannelTransportConfig instead. + - "explicit:": Use a specific channel address with zmq url format, e.g. "explicit:tcp://127.0.0.1:8080" + """ + def __str__(self) -> str: ... + def __repr__(self) -> str: ... + def __eq__(self, other: object) -> bool: ... + class ChannelAddr: @staticmethod def any(transport: ChannelTransport) -> str: diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi index 7ae4446a7..23c05ccc8 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/config.pyi @@ -12,7 +12,10 @@ Type hints for the monarch_hyperactor.config Rust bindings. from typing import Any, Dict -from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport +from monarch._rust_bindings.monarch_hyperactor.channel import ( + ChannelTransport, + ChannelTransportConfig, +) def reload_config_from_env() -> None: """ @@ -32,7 +35,7 @@ def reset_config_to_defaults() -> None: ... def configure( - default_transport: ChannelTransport = ..., + default_transport: ChannelTransportConfig | ChannelTransport | str = ..., enable_log_forwarding: bool = ..., enable_file_capture: bool = ..., tail_log_lines: int = ..., @@ -50,7 +53,10 @@ def configure( plus any additional CONFIG-marked keys passed via **kwargs. Args: - default_transport: Default channel transport for communication + default_transport: Default channel transport for communication. Can be: + - A ChannelTransportConfig object + - A ChannelTransport enum value (e.g., ChannelTransport.Unix) + - A string for explicit transport (e.g., "explicit::tcp://127.0.0.1:8080") enable_log_forwarding: Whether to forward logs from actors enable_file_capture: Whether to capture file output tail_log_lines: Number of log lines to tail diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 6eeddff27..282b89b2f 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -53,7 +53,10 @@ ) from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer -from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport +from monarch._rust_bindings.monarch_hyperactor.channel import ( + ChannelTransport, + ChannelTransportConfig, +) from monarch._rust_bindings.monarch_hyperactor.config import configure from monarch._rust_bindings.monarch_hyperactor.context import Instance as HyInstance from monarch._rust_bindings.monarch_hyperactor.mailbox import ( @@ -410,7 +413,7 @@ def context() -> Context: return c -_transport: Optional[ChannelTransport] = None +_transport: Optional[ChannelTransportConfig] = None _transport_lock = threading.Lock() @@ -424,17 +427,33 @@ def enable_transport(transport: "ChannelTransport | str") -> None: Currently only one transport type may be enabled at one time. In the future we may allow multiple to be enabled. + Supported transport values: + - ChannelTransport enum: ChannelTransport.Unix, ChannelTransport.TcpWithHostname, etc. + - string short cuts for the ChannelTransport enum: + - "tcp": ChannelTransport.TcpWithHostname + - "ipc": ChannelTransport.Unix + - "metatls": ChannelTransport.MetaTlsWithIpV6 + - "metatls-hostname": ChannelTransport.MetaTlsWithHostname + - string for explicit transport. e.g.: + - "explicit:tcp://127.0.0.1:8080" + For Meta usage, use metatls-hostname """ if isinstance(transport, str): - transport = { + # Handle string shortcuts for the ChannelTransport enum, + resolved = { "tcp": ChannelTransport.TcpWithHostname, "ipc": ChannelTransport.Unix, "metatls": ChannelTransport.MetaTlsWithIpV6, "metatls-hostname": ChannelTransport.MetaTlsWithHostname, }.get(transport) - if transport is None: - raise ValueError(f"unknown transport: {transport}") + if resolved is not None: + transport_config = ChannelTransportConfig(resolved) + else: + transport_config = ChannelTransportConfig(transport) + else: + # ChannelTransport enum + transport_config = ChannelTransportConfig(transport) if _context.get(None) is not None: raise RuntimeError( @@ -445,14 +464,14 @@ def enable_transport(transport: "ChannelTransport | str") -> None: global _transport with _transport_lock: - if _transport is not None and _transport != transport: + if _transport is not None and _transport != transport_config: raise RuntimeError( f"Only one transport type may be enabled at one time. " f"Currently enabled transport type is `{_transport}`. " - f"Attempted to enable transport type `{transport}`." + f"Attempted to enable transport type `{transport_config}`." ) - _transport = transport - configure(default_transport=transport) + _transport = transport_config + configure(default_transport=transport_config) @dataclass diff --git a/python/tests/test_config.py b/python/tests/test_config.py index 45df1f89c..b158a6065 100644 --- a/python/tests/test_config.py +++ b/python/tests/test_config.py @@ -10,7 +10,11 @@ import monarch import pytest -from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport + +from monarch._rust_bindings.monarch_hyperactor.channel import ( + ChannelTransport, + ChannelTransportConfig, +) from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError from monarch.actor import Actor, endpoint, this_proc from monarch.config import configured, get_global_config @@ -34,12 +38,14 @@ def test_get_set_transport() -> None: ChannelTransport.MetaTlsWithHostname, ): with configured(default_transport=transport) as config: - assert config["default_transport"] == transport + assert config["default_transport"] == ChannelTransportConfig(transport) # Succeed even if we don't specify the transport, but does not change the # previous value. with configured() as config: - assert config["default_transport"] == ChannelTransport.Unix - with pytest.raises(TypeError): + assert config["default_transport"] == ChannelTransportConfig( + ChannelTransport.Unix + ) + with pytest.raises(ValueError): with configured(default_transport="unix"): # type: ignore pass with pytest.raises(TypeError): @@ -50,6 +56,25 @@ def test_get_set_transport() -> None: pass +def test_get_set_explicit_transport() -> None: + # Test explicit transport with a TCP address + with configured(default_transport="explicit:tcp://127.0.0.1:8080") as config: + # The config value is a ChannelTransportConfig, so compare string representation + assert config["default_transport"] == ChannelTransportConfig( + "explicit:tcp://127.0.0.1:8080" + ) + + # Test that invalid explicit transport strings raise an error + with pytest.raises(ValueError): + with configured(default_transport="explicit:invalid"): + pass + + # Test that random strings (not explicit format) raise an error + with pytest.raises(ValueError): + with configured(default_transport="random_string"): + pass + + def test_nonexistent_config_key() -> None: with pytest.raises(ValueError): with configured(does_not_exist=42): # type: ignore @@ -64,13 +89,15 @@ def test_get_set_multiple() -> None: assert config["enable_log_forwarding"] assert config["enable_file_capture"] assert config["tail_log_lines"] == 100 - assert config["default_transport"] == ChannelTransport.TcpWithLocalhost + assert config["default_transport"] == ChannelTransportConfig( + ChannelTransport.TcpWithLocalhost + ) # Make sure the previous values are restored. config = get_global_config() assert not config["enable_log_forwarding"] assert not config["enable_file_capture"] assert config["tail_log_lines"] == 0 - assert config["default_transport"] == ChannelTransport.Unix + assert config["default_transport"] == ChannelTransportConfig(ChannelTransport.Unix) # This test tries to allocate too much memory for the GitHub actions @@ -219,7 +246,9 @@ def test_duration_config_multiple() -> None: enable_log_forwarding=True, tail_log_lines=100, ) as config: - assert config["default_transport"] == ChannelTransport.TcpWithLocalhost + assert config["default_transport"] == ChannelTransportConfig( + ChannelTransport.TcpWithLocalhost + ) assert config["host_spawn_ready_timeout"] == "10m" assert config["message_delivery_timeout"] == "5m" assert config["mesh_proc_spawn_max_idle"] == "2m" @@ -228,7 +257,7 @@ def test_duration_config_multiple() -> None: # Verify all values are restored config = get_global_config() - assert config["default_transport"] == ChannelTransport.Unix + assert config["default_transport"] == ChannelTransportConfig(ChannelTransport.Unix) assert config["host_spawn_ready_timeout"] == "30s" assert config["message_delivery_timeout"] == "30s" assert config["mesh_proc_spawn_max_idle"] == "30s"