Skip to content

Commit 39bd469

Browse files
pzhan9facebook-github-bot
authored andcommitted
Add support in enable_transport so the user can set an explicit TCP address (#2142)
Summary: As explained in D89190087, Lightning needs a way to set and explicit TCP address in `enable_transport`. This diff is to implement that. Differential Revision: D89185959
1 parent f1a66d2 commit 39bd469

File tree

6 files changed

+211
-27
lines changed

6 files changed

+211
-27
lines changed

monarch_hyperactor/src/channel.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@ use hyperactor::channel::MetaTlsAddr;
1414
use hyperactor::channel::TcpMode;
1515
use hyperactor::channel::TlsMode;
1616
use pyo3::exceptions::PyRuntimeError;
17+
use pyo3::exceptions::PyTypeError;
1718
use pyo3::exceptions::PyValueError;
1819
use pyo3::prelude::*;
1920

2021
/// Python binding for [`hyperactor::channel::ChannelTransport`]
22+
///
23+
/// This enum represents the basic transport types that can be represented
24+
/// as simple enum variants. For more complex transports like `Explicit`,
25+
/// use string-based configuration via `PyChannelTransportConfig`.
2126
#[pyclass(
2227
name = "ChannelTransport",
2328
module = "monarch._rust_bindings.monarch_hyperactor.channel",
@@ -62,6 +67,94 @@ impl TryFrom<ChannelTransport> for PyChannelTransport {
6267
}
6368
}
6469

70+
/// A wrapper for ChannelTransport that can be created from either a
71+
/// PyChannelTransport enum or a string (for explicit transport).
72+
///
73+
/// We need this wrapper because Python's enum type does not support attaching
74+
/// data to enum variants like Rust does.
75+
#[pyclass(
76+
name = "ChannelTransportConfig",
77+
module = "monarch._rust_bindings.monarch_hyperactor.channel"
78+
)]
79+
#[derive(Clone, Debug, PartialEq)]
80+
pub struct PyChannelTransportConfig {
81+
inner: ChannelTransport,
82+
}
83+
84+
#[pymethods]
85+
impl PyChannelTransportConfig {
86+
/// Create a new PyChannelTransportConfig from either a ChannelTransport enum
87+
/// or a string representation.
88+
///
89+
/// Examples:
90+
/// PyChannelTransportConfig(ChannelTransport.Unix)
91+
/// PyChannelTransportConfig("explicit:tcp://127.0.0.1:8080")
92+
#[new]
93+
pub fn new(transport: &Bound<'_, PyAny>) -> PyResult<Self> {
94+
// First try to extract as PyChannelTransportConfig (for when passing an existing config)
95+
if let Ok(config) = transport.extract::<PyChannelTransportConfig>() {
96+
return Ok(config);
97+
}
98+
99+
// Then try to extract as PyChannelTransport enum
100+
if let Ok(py_transport) = transport.extract::<PyChannelTransport>() {
101+
return Ok(PyChannelTransportConfig {
102+
inner: py_transport.into(),
103+
});
104+
}
105+
106+
// Then try to extract as a string and parse it
107+
if let Ok(transport_str) = transport.extract::<String>() {
108+
if !transport_str.starts_with("explicit:") {
109+
return Err(PyValueError::new_err(format!(
110+
"string argument only supports explicit transport with \
111+
address in the zmq url format (e.g., 'explicit:tcp://127.0.0.1:8080'); \
112+
but got: {}",
113+
transport_str,
114+
)));
115+
}
116+
let addr_str = transport_str.strip_prefix("explicit:").unwrap();
117+
let addr = ChannelAddr::from_zmq_url(addr_str).map_err(|e| {
118+
PyValueError::new_err(format!(
119+
"invalid address string used for explicit transport '{}': {}",
120+
addr_str, e
121+
))
122+
})?;
123+
return Ok(PyChannelTransportConfig {
124+
inner: ChannelTransport::Explicit(addr),
125+
});
126+
}
127+
128+
Err(PyTypeError::new_err(
129+
"expected ChannelTransport enum, ChannelTransportConfig, or str",
130+
))
131+
}
132+
133+
fn __str__(&self) -> String {
134+
self.inner.to_string()
135+
}
136+
137+
fn __repr__(&self) -> String {
138+
format!("PyChannelTransportConfig({:?})", self.inner)
139+
}
140+
141+
fn __eq__(&self, other: &Self) -> bool {
142+
self.inner == other.inner
143+
}
144+
}
145+
146+
impl From<PyChannelTransportConfig> for ChannelTransport {
147+
fn from(config: PyChannelTransportConfig) -> Self {
148+
config.inner
149+
}
150+
}
151+
152+
impl From<ChannelTransport> for PyChannelTransportConfig {
153+
fn from(transport: ChannelTransport) -> Self {
154+
PyChannelTransportConfig { inner: transport }
155+
}
156+
}
157+
65158
#[pyclass(
66159
name = "ChannelAddr",
67160
module = "monarch._rust_bindings.monarch_hyperactor.channel"
@@ -149,6 +242,7 @@ impl From<PyChannelTransport> for ChannelTransport {
149242
#[pymodule]
150243
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
151244
hyperactor_mod.add_class::<PyChannelTransport>()?;
245+
hyperactor_mod.add_class::<PyChannelTransportConfig>()?;
152246
hyperactor_mod.add_class::<PyChannelAddr>()?;
153247
Ok(())
154248
}

monarch_hyperactor/src/config.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ use hyperactor_config::attrs::Attrs;
3030
use hyperactor_config::attrs::ErasedKey;
3131
use hyperactor_config::attrs::declare_attrs;
3232
use hyperactor_config::global::Source;
33+
use pyo3::conversion::IntoPyObject;
3334
use pyo3::conversion::IntoPyObjectExt;
3435
use pyo3::exceptions::PyTypeError;
3536
use pyo3::exceptions::PyValueError;
3637
use pyo3::prelude::*;
3738

38-
use crate::channel::PyChannelTransport;
39+
use crate::channel::PyChannelTransportConfig;
3940

4041
/// Python wrapper for Duration, using humantime format strings.
4142
///
@@ -289,9 +290,9 @@ inventory::collect!(PythonConfigTypeInfo);
289290
/// like `String` that are convertible directly to/from PyObjects,
290291
/// you can just use `declare_py_config_type!(String)`. For types
291292
/// that must first be converted to/from a rust python wrapper
292-
/// (e.g., keys with type `ChannelTransport` must use `PyChannelTransport`
293+
/// (e.g., keys with type `ChannelTransport` must use `PyChannelTransportConfig`
293294
/// as an intermediate step), the usage is
294-
/// `declare_py_config_type!(PyChannelTransport as ChannelTransport)`.
295+
/// `declare_py_config_type!(PyChannelTransportConfig as ChannelTransport)`.
295296
macro_rules! declare_py_config_type {
296297
($($ty:ty),+ $(,)?) => {
297298
hyperactor::paste! {
@@ -341,7 +342,7 @@ macro_rules! declare_py_config_type {
341342
};
342343
}
343344

344-
declare_py_config_type!(PyChannelTransport as ChannelTransport);
345+
declare_py_config_type!(PyChannelTransportConfig as ChannelTransport);
345346
declare_py_config_type!(PyDuration as Duration);
346347
declare_py_config_type!(
347348
i8, i16, i32, i64, u8, u16, u32, u64, usize, f32, f64, bool, String
@@ -367,9 +368,20 @@ declare_py_config_type!(
367368
fn configure(py: Python<'_>, kwargs: Option<HashMap<String, PyObject>>) -> PyResult<()> {
368369
kwargs
369370
.map(|kwargs| {
370-
kwargs
371-
.into_iter()
372-
.try_for_each(|(key, val)| configure_kwarg(py, &key, val))
371+
kwargs.into_iter().try_for_each(|(key, val)| {
372+
// Special handling for default_transport: convert ChannelTransport
373+
// enum or explicit transport string to PyChannelTransportConfig
374+
// before processing
375+
let val = if key == "default_transport" {
376+
PyChannelTransportConfig::new(val.bind(py))?
377+
.into_pyobject(py)?
378+
.into_any()
379+
.unbind()
380+
} else {
381+
val
382+
};
383+
configure_kwarg(py, &key, val)
384+
})
373385
})
374386
.transpose()?;
375387
Ok(())

python/monarch/_rust_bindings/monarch_hyperactor/channel.pyi

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
from enum import Enum
1010

1111
class ChannelTransport(Enum):
12+
"""
13+
Enum representing basic transport types for channels.
14+
For the explicit transport type, use ChannelTransportConfig instead.
15+
"""
16+
1217
TcpWithLocalhost = "tcp(localhost)"
1318
TcpWithHostname = "tcp(hostname)"
1419
MetaTlsWithHostname = "metatls(hostname)"
@@ -17,6 +22,25 @@ class ChannelTransport(Enum):
1722
Unix = "unix"
1823
# Sim # TODO add support
1924

25+
class ChannelTransportConfig:
26+
"""
27+
Internal wrapper for ChannelTransport that accepts either a ChannelTransport enum
28+
or a string for complex transports.
29+
30+
Note: This class is for internal use only. Users should pass ChannelTransport
31+
enum values or strings directly to enable_transport().
32+
"""
33+
34+
def __init__(self, transport: ChannelTransport | str) -> None: ...
35+
"""
36+
Basic transport types supported by ChannelTransport should be used directly as enum values.
37+
For the explicit transport type, use ChannelTransportConfig instead.
38+
- "explicit:<addr>": Use a specific channel address with zmq url format, e.g. "explicit:tcp://127.0.0.1:8080"
39+
"""
40+
def __str__(self) -> str: ...
41+
def __repr__(self) -> str: ...
42+
def __eq__(self, other: object) -> bool: ...
43+
2044
class ChannelAddr:
2145
@staticmethod
2246
def any(transport: ChannelTransport) -> str:

python/monarch/_rust_bindings/monarch_hyperactor/config.pyi

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ Type hints for the monarch_hyperactor.config Rust bindings.
1212

1313
from typing import Any, Dict
1414

15-
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
15+
from monarch._rust_bindings.monarch_hyperactor.channel import (
16+
ChannelTransport,
17+
ChannelTransportConfig,
18+
)
1619

1720
def reload_config_from_env() -> None:
1821
"""
@@ -32,7 +35,7 @@ def reset_config_to_defaults() -> None:
3235
...
3336

3437
def configure(
35-
default_transport: ChannelTransport = ...,
38+
default_transport: ChannelTransportConfig | ChannelTransport | str = ...,
3639
enable_log_forwarding: bool = ...,
3740
enable_file_capture: bool = ...,
3841
tail_log_lines: int = ...,
@@ -50,7 +53,10 @@ def configure(
5053
plus any additional CONFIG-marked keys passed via **kwargs.
5154
5255
Args:
53-
default_transport: Default channel transport for communication
56+
default_transport: Default channel transport for communication. Can be:
57+
- A ChannelTransportConfig object
58+
- A ChannelTransport enum value (e.g., ChannelTransport.Unix)
59+
- A string for explicit transport (e.g., "explicit::tcp://127.0.0.1:8080")
5460
enable_log_forwarding: Whether to forward logs from actors
5561
enable_file_capture: Whether to capture file output
5662
tail_log_lines: Number of log lines to tail

python/monarch/_src/actor/actor_mesh.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@
5353
)
5454
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
5555
from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer
56-
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
56+
from monarch._rust_bindings.monarch_hyperactor.channel import (
57+
ChannelTransport,
58+
ChannelTransportConfig,
59+
)
5760
from monarch._rust_bindings.monarch_hyperactor.config import configure
5861
from monarch._rust_bindings.monarch_hyperactor.context import Instance as HyInstance
5962
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
@@ -410,7 +413,7 @@ def context() -> Context:
410413
return c
411414

412415

413-
_transport: Optional[ChannelTransport] = None
416+
_transport: Optional[ChannelTransportConfig] = None
414417
_transport_lock = threading.Lock()
415418

416419

@@ -424,17 +427,33 @@ def enable_transport(transport: "ChannelTransport | str") -> None:
424427
Currently only one transport type may be enabled at one time.
425428
In the future we may allow multiple to be enabled.
426429
430+
Supported transport values:
431+
- ChannelTransport enum: ChannelTransport.Unix, ChannelTransport.TcpWithHostname, etc.
432+
- string short cuts for the ChannelTransport enum:
433+
- "tcp": ChannelTransport.TcpWithHostname
434+
- "ipc": ChannelTransport.Unix
435+
- "metatls": ChannelTransport.MetaTlsWithIpV6
436+
- "metatls-hostname": ChannelTransport.MetaTlsWithHostname
437+
- string for explicit transport. e.g.:
438+
- "explicit:tcp://127.0.0.1:8080"
439+
427440
For Meta usage, use metatls-hostname
428441
"""
429442
if isinstance(transport, str):
430-
transport = {
443+
# Handle string shortcuts for the ChannelTransport enum,
444+
resolved = {
431445
"tcp": ChannelTransport.TcpWithHostname,
432446
"ipc": ChannelTransport.Unix,
433447
"metatls": ChannelTransport.MetaTlsWithIpV6,
434448
"metatls-hostname": ChannelTransport.MetaTlsWithHostname,
435449
}.get(transport)
436-
if transport is None:
437-
raise ValueError(f"unknown transport: {transport}")
450+
if resolved is not None:
451+
transport_config = ChannelTransportConfig(resolved)
452+
else:
453+
transport_config = ChannelTransportConfig(transport)
454+
else:
455+
# ChannelTransport enum
456+
transport_config = ChannelTransportConfig(transport)
438457

439458
if _context.get(None) is not None:
440459
raise RuntimeError(
@@ -445,14 +464,14 @@ def enable_transport(transport: "ChannelTransport | str") -> None:
445464

446465
global _transport
447466
with _transport_lock:
448-
if _transport is not None and _transport != transport:
467+
if _transport is not None and _transport != transport_config:
449468
raise RuntimeError(
450469
f"Only one transport type may be enabled at one time. "
451470
f"Currently enabled transport type is `{_transport}`. "
452-
f"Attempted to enable transport type `{transport}`."
471+
f"Attempted to enable transport type `{transport_config}`."
453472
)
454-
_transport = transport
455-
configure(default_transport=transport)
473+
_transport = transport_config
474+
configure(default_transport=transport_config)
456475

457476

458477
@dataclass

0 commit comments

Comments
 (0)