Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 138 additions & 8 deletions hyperactor/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,9 @@ pub enum ChannelTransport {

/// Transport over unix domain socket.
Unix,

/// Use a specific channel address for transport.
Explicit(ChannelAddr),
}

impl fmt::Display for ChannelTransport {
Expand All @@ -366,6 +369,7 @@ impl fmt::Display for ChannelTransport {
Self::Local => write!(f, "local"),
Self::Sim(transport) => write!(f, "sim({})", transport),
Self::Unix => write!(f, "unix"),
Self::Explicit(addr) => write!(f, "explicit({})", addr),
}
}
}
Expand Down Expand Up @@ -400,6 +404,11 @@ impl FromStr for ChannelTransport {
let mode = inner.parse()?;
Ok(ChannelTransport::MetaTls(mode))
}
s if s.starts_with("explicit(") && s.ends_with(")") => Err(anyhow::anyhow!(
"detect possible explicit transport, but we currently do not \
support parsing explicit's string representation since we
only want to support the zmq_url format."
)),
unknown => Err(anyhow::anyhow!("unknown channel transport: {}", unknown)),
}
}
Expand Down Expand Up @@ -433,6 +442,7 @@ impl ChannelTransport {
ChannelTransport::Local => false,
ChannelTransport::Sim(_) => false,
ChannelTransport::Unix => false,
ChannelTransport::Explicit(addr) => addr.transport().is_remote(),
}
}
}
Expand Down Expand Up @@ -507,6 +517,27 @@ pub enum ChannelAddr {
/// A unix domain socket address. Supports both absolute path names as
/// well as "abstract" names per https://manpages.debian.org/unstable/manpages/unix.7.en.html#Abstract_sockets
Unix(net::unix::SocketAddr),

/// A pair of addresses, one for the client and one for the server:
/// - The client should dial to the `dial_to` address.
/// - The server should bind to the `bind_to` address.
///
/// The user is responsible for ensuring the traffic to the `dial_to` address
/// is routed to the `bind_to` address.
///
/// This is useful for scenarios where the network is confgiured in a way,
/// that the bound address is not directly accessible from the client.
///
/// For example, in AWS, the client could be provided with the public IP
/// address, yet the server is bound to a private IP address or simply
/// INADDR_ANY. Traffic to the public IP address is mapped to the private
/// IP address through network address translation (NAT).
Alias {
/// The address to which the client should dial to.
dial_to: Box<ChannelAddr>,
/// The address to which the server should bind to.
bind_to: Box<ChannelAddr>,
},
}

impl From<SocketAddr> for ChannelAddr {
Expand Down Expand Up @@ -577,6 +608,7 @@ impl ChannelAddr {
ChannelTransport::Sim(transport) => sim::any(*transport),
// This works because the file will be deleted but we know we have a unique file by this point.
ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()),
ChannelTransport::Explicit(addr) => addr,
}
}

Expand Down Expand Up @@ -604,6 +636,9 @@ impl ChannelAddr {
Self::Local(_) => ChannelTransport::Local,
Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())),
Self::Unix(_) => ChannelTransport::Unix,
// bind_to's transport is what is actually used in communcation.
// Therefore we use its transport to represent the Alias.
Self::Alias { bind_to, .. } => bind_to.transport(),
}
}
}
Expand All @@ -616,6 +651,9 @@ impl fmt::Display for ChannelAddr {
Self::Local(index) => write!(f, "local:{}", index),
Self::Sim(sim_addr) => write!(f, "sim:{}", sim_addr),
Self::Unix(addr) => write!(f, "unix:{}", addr),
Self::Alias { dial_to, bind_to } => {
write!(f, "alias:dial_to={};bind_to={}", dial_to, bind_to)
}
}
}
}
Expand All @@ -636,6 +674,11 @@ impl FromStr for ChannelAddr {
Some(("metatls", rest)) => net::meta::parse(rest).map_err(|e| e.into()),
Some(("sim", rest)) => sim::parse(rest).map_err(|e| e.into()),
Some(("unix", rest)) => Ok(Self::Unix(net::unix::SocketAddr::from_str(rest)?)),
Some(("alias", _)) => Err(anyhow::anyhow!(
"detect possible alias address, but we currently do not support \
parsing alias' string representation since we only want to \
support parsing its zmq url format."
)),
Some((r#type, _)) => Err(anyhow::anyhow!("no such channel type: {type}")),
None => Err(anyhow::anyhow!("no channel type specified")),
}
Expand All @@ -649,7 +692,38 @@ impl ChannelAddr {
/// - inproc://endpoint-name (equivalent to local)
/// - ipc://path (equivalent to unix)
/// - metatls://hostname:port or metatls://*:port
/// - Alias format: dial_to_url@bind_to_url (e.g., tcp://host:port@tcp://host:port)
/// Note: Alias format is currently only supported for TCP addresses
pub fn from_zmq_url(address: &str) -> Result<Self, anyhow::Error> {
// Check for Alias format: dial_to_url@bind_to_url
// The @ character separates two valid ZMQ URLs
if let Some(at_pos) = address.find('@') {
let dial_to_str = &address[..at_pos];
let bind_to_str = &address[at_pos + 1..];

// Validate that both addresses use TCP scheme
if !dial_to_str.starts_with("tcp://") {
return Err(anyhow::anyhow!(
"alias format is only supported for TCP addresses, got dial_to: {}",
dial_to_str
));
}
if !bind_to_str.starts_with("tcp://") {
return Err(anyhow::anyhow!(
"alias format is only supported for TCP addresses, got bind_to: {}",
bind_to_str
));
}

let dial_to = Self::from_zmq_url(dial_to_str)?;
let bind_to = Self::from_zmq_url(bind_to_str)?;

return Ok(Self::Alias {
dial_to: Box::new(dial_to),
bind_to: Box::new(bind_to),
});
}

// Try ZMQ-style URL format first (scheme://...)
let (scheme, address) = address.split_once("://").ok_or_else(|| {
anyhow::anyhow!("address must be in url form scheme://endppoint {}", address)
Expand Down Expand Up @@ -840,6 +914,7 @@ pub fn dial<M: RemoteMessage>(addr: ChannelAddr) -> Result<ChannelTx<M>, Channel
ChannelAddr::MetaTls(meta_addr) => ChannelTxKind::MetaTls(net::meta::dial(meta_addr)?),
ChannelAddr::Sim(sim_addr) => ChannelTxKind::Sim(sim::dial::<M>(sim_addr)?),
ChannelAddr::Unix(path) => ChannelTxKind::Unix(net::unix::dial(path)),
ChannelAddr::Alias { dial_to, .. } => dial(*dial_to)?.inner,
};
Ok(ChannelTx { inner })
}
Expand All @@ -852,6 +927,19 @@ pub fn serve<M: RemoteMessage>(
addr: ChannelAddr,
) -> Result<(ChannelAddr, ChannelRx<M>), ChannelError> {
let caller = Location::caller();
serve_inner(addr).map(|(addr, inner)| {
tracing::debug!(
name = "serve",
%addr,
%caller,
);
(addr, ChannelRx { inner })
})
}

fn serve_inner<M: RemoteMessage>(
addr: ChannelAddr,
) -> Result<(ChannelAddr, ChannelRxKind<M>), ChannelError> {
match addr {
ChannelAddr::Tcp(addr) => {
let (addr, rx) = net::tcp::serve::<M>(addr)?;
Expand All @@ -877,15 +965,15 @@ pub fn serve<M: RemoteMessage>(
"invalid local addr: {}",
a
))),
ChannelAddr::Alias { dial_to, bind_to } => {
let (bound_addr, rx) = serve_inner::<M>(*bind_to)?;
let alias_addr = ChannelAddr::Alias {
dial_to,
bind_to: Box::new(bound_addr),
};
Ok((alias_addr, rx))
}
}
.map(|(addr, inner)| {
tracing::debug!(
name = "serve",
%addr,
%caller,
);
(addr, ChannelRx { inner })
})
}

/// Serve on the local address. The server is turned down
Expand Down Expand Up @@ -1056,6 +1144,48 @@ mod tests {
assert!(ChannelAddr::from_zmq_url("inproc://not-a-number").is_err());
}

#[test]
fn test_zmq_style_alias_channel_addr() {
// Test Alias format: dial_to_url@bind_to_url
// The format is: dial_to_url@bind_to_url where both are valid ZMQ URLs
// Note: Alias format is only supported for TCP addresses

// Test Alias with tcp on both sides
let alias_addr = ChannelAddr::from_zmq_url("tcp://127.0.0.1:9000@tcp://[::]:8800").unwrap();
match alias_addr {
ChannelAddr::Alias { dial_to, bind_to } => {
assert_eq!(
*dial_to,
ChannelAddr::Tcp("127.0.0.1:9000".parse().unwrap())
);
assert_eq!(*bind_to, ChannelAddr::Tcp("[::]:8800".parse().unwrap()));
}
_ => panic!("Expected Alias"),
}

// Test error: alias with non-tcp dial_to (not supported)
assert!(
ChannelAddr::from_zmq_url("metatls://example.com:443@tcp://127.0.0.1:8080").is_err()
);

// Test error: alias with non-tcp bind_to (not supported)
assert!(
ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@metatls://example.com:443").is_err()
);

// Test error: invalid dial_to URL in Alias
assert!(ChannelAddr::from_zmq_url("invalid://scheme@tcp://127.0.0.1:8080").is_err());

// Test error: invalid bind_to URL in Alias
assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@invalid://scheme").is_err());

// Test error: missing port in dial_to
assert!(ChannelAddr::from_zmq_url("tcp://host@tcp://127.0.0.1:8080").is_err());

// Test error: missing port in bind_to
assert!(ChannelAddr::from_zmq_url("tcp://127.0.0.1:8080@tcp://example.com").is_err());
}

#[tokio::test]
async fn test_multiple_connections() {
for addr in ChannelTransport::all().map(ChannelAddr::any) {
Expand Down
94 changes: 94 additions & 0 deletions monarch_hyperactor/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ use hyperactor::channel::MetaTlsAddr;
use hyperactor::channel::TcpMode;
use hyperactor::channel::TlsMode;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::PyTypeError;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Python binding for [`hyperactor::channel::ChannelTransport`]
///
/// This enum represents the basic transport types that can be represented
/// as simple enum variants. For more complex transports like `Explicit`,
/// use string-based configuration via `PyChannelTransportConfig`.
#[pyclass(
name = "ChannelTransport",
module = "monarch._rust_bindings.monarch_hyperactor.channel",
Expand Down Expand Up @@ -62,6 +67,94 @@ impl TryFrom<ChannelTransport> for PyChannelTransport {
}
}

/// A wrapper for ChannelTransport that can be created from either a
/// PyChannelTransport enum or a string (for explicit transport).
///
/// We need this wrapper because Python's enum type does not support attaching
/// data to enum variants like Rust does.
#[pyclass(
name = "ChannelTransportConfig",
module = "monarch._rust_bindings.monarch_hyperactor.channel"
)]
#[derive(Clone, Debug, PartialEq)]
pub struct PyChannelTransportConfig {
inner: ChannelTransport,
}

#[pymethods]
impl PyChannelTransportConfig {
/// Create a new PyChannelTransportConfig from either a ChannelTransport enum
/// or a string representation.
///
/// Examples:
/// PyChannelTransportConfig(ChannelTransport.Unix)
/// PyChannelTransportConfig("explicit:tcp://127.0.0.1:8080")
#[new]
pub fn new(transport: &Bound<'_, PyAny>) -> PyResult<Self> {
// First try to extract as PyChannelTransportConfig (for when passing an existing config)
if let Ok(config) = transport.extract::<PyChannelTransportConfig>() {
return Ok(config);
}

// Then try to extract as PyChannelTransport enum
if let Ok(py_transport) = transport.extract::<PyChannelTransport>() {
return Ok(PyChannelTransportConfig {
inner: py_transport.into(),
});
}

// Then try to extract as a string and parse it
if let Ok(transport_str) = transport.extract::<String>() {
if !transport_str.starts_with("explicit:") {
return Err(PyValueError::new_err(format!(
"string argument only supports explicit transport with \
address in the zmq url format (e.g., 'explicit:tcp://127.0.0.1:8080'); \
but got: {}",
transport_str,
)));
}
let addr_str = transport_str.strip_prefix("explicit:").unwrap();
let addr = ChannelAddr::from_zmq_url(addr_str).map_err(|e| {
PyValueError::new_err(format!(
"invalid address string used for explicit transport '{}': {}",
addr_str, e
))
})?;
return Ok(PyChannelTransportConfig {
inner: ChannelTransport::Explicit(addr),
});
}

Err(PyTypeError::new_err(
"expected ChannelTransport enum, ChannelTransportConfig, or str",
))
}

fn __str__(&self) -> String {
self.inner.to_string()
}

fn __repr__(&self) -> String {
format!("PyChannelTransportConfig({:?})", self.inner)
}

fn __eq__(&self, other: &Self) -> bool {
self.inner == other.inner
}
}

impl From<PyChannelTransportConfig> for ChannelTransport {
fn from(config: PyChannelTransportConfig) -> Self {
config.inner
}
}

impl From<ChannelTransport> for PyChannelTransportConfig {
fn from(transport: ChannelTransport) -> Self {
PyChannelTransportConfig { inner: transport }
}
}

#[pyclass(
name = "ChannelAddr",
module = "monarch._rust_bindings.monarch_hyperactor.channel"
Expand Down Expand Up @@ -149,6 +242,7 @@ impl From<PyChannelTransport> for ChannelTransport {
#[pymodule]
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
hyperactor_mod.add_class::<PyChannelTransport>()?;
hyperactor_mod.add_class::<PyChannelTransportConfig>()?;
hyperactor_mod.add_class::<PyChannelAddr>()?;
Ok(())
}
Expand Down
Loading