
Commit af7e888

colin2328 authored and facebook-github-bot committed
add monarch serve torchx command to launch the (MAST) job and cache the command inside the job's .pkl (#2097)
Summary: Introduce monarch serve torchx ... as proposed in https://docs.google.com/document/d/1F3m3oDBX3sipHCxsp_2ghSvgOBKgmFRWql3A5StMh1I/edit?tab=t.0#heading=h.zb5e0il0bn6a. Here, we create the (MAST) job and cache the launch command in the job's .pkl file. We then add run_spmd so the cached job can be run with: python -c "from monarch.job import job_load; job = job_load(); job.run_spmd()"

Differential Revision: D88515552
1 parent 5b545db commit af7e888
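
A minimal sketch of the flow this commit enables, assuming the new monarch serve torchx CLI hands the argument list to SPMDJob.serve_from_command (added below) and pickles the returned job; the scheduler config values here are illustrative placeholders, not real settings:

# Client side: build the job from the torchx command (hedged sketch).
from monarch._src.job.spmd import SPMDJob

command = [
    "torchx", "run",
    "-s", "mast_conda",
    "-cfg", "hpcIdentity=my_identity,hpcJobOncall=my_oncall,rmAttribution=my_attribution",  # placeholders
    "dist.ddp", "-j", "1x8", "--script", "train.py",
    "--", "--lr", "0.001",
]
job = SPMDJob.serve_from_command(command)  # builds the underlying MASTJob plus the AppDef

# On the job side, reload the cached .pkl and run the SPMD entrypoint, as in the summary:
#   python -c "from monarch.job import job_load; job = job_load(); job.run_spmd()"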

File tree

4 files changed: +468, -3 lines changed


python/monarch/_src/job/spmd.py

Lines changed: 369 additions & 0 deletions
@@ -0,0 +1,369 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
import sys
from typing import Any, Dict, List, Optional

from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
from monarch._rust_bindings.monarch_hyperactor.config import configure
from monarch._src.job.job import JobState, JobTrait
from monarch._src.spmd import SPMDActor
from torchx.runner import get_runner
from torchx.specs import AppDef
from torchx.specs.finder import get_component


def _extract_nproc_per_node_from_appdef(appdef: AppDef) -> int:
    """
    Extract nproc_per_node from the torchrun command in AppDef.

    NOTE: This is for Phase B implementation where we parse torchrun
    and spawn GPU workers as monarch actors. Currently unused in Phase A.

    The AppDef typically has entrypoint="bash" with args=["-c", "torchrun ... --nproc_per_node N ..."].
    This function parses the torchrun command to extract N.
    """
    if not appdef.roles or len(appdef.roles) == 0:
        raise ValueError("AppDef has no roles")

    role = appdef.roles[0]

    # For bash entrypoint, the command is in args[1] (after "-c")
    if role.entrypoint == "bash" and len(role.args) >= 2 and role.args[0] == "-c":
        command = role.args[1]
    # For python entrypoint, the command is in the args list
    elif role.entrypoint == "python":
        command = " ".join(role.args)
    else:
        # Fallback: join all args
        command = " ".join(role.args) if role.args else ""

    # Parse --nproc_per_node or --nproc-per-node
    match = re.search(r"--nproc[_-]per[_-]node[=\s]+(\d+)", command)
    if match:
        return int(match.group(1))

    raise ValueError(f"Could not extract nproc_per_node from AppDef command: {command}")


def _parse_cli_args_to_kwargs(cli_args: List[str]) -> tuple[List[str], List[str]]:
    """
    Convert CLI-style arguments to key=value format expected by component_args_from_cli.

    Args:
        cli_args: Arguments in CLI format, e.g. ['-j', '1x8', '--script', 'train.py', '--', '--lr', '0.001']

    Returns:
        Tuple of (kwargs_format, script_args) where:
        - kwargs_format: ['j=1x8', 'script=train.py']
        - script_args: ['--lr', '0.001'] (arguments after -- delimiter)
    """
    kwargs_format = []
    i = 0
    script_args = []
    found_delimiter = False

    while i < len(cli_args):
        arg = cli_args[i]

        # Check for -- delimiter
        if arg == "--":
            found_delimiter = True
            # Everything after -- becomes script_args
            script_args = cli_args[i + 1 :]
            break

        # Parse flag arguments
        if arg.startswith("--"):
            key = arg[2:]  # Remove --
            if i + 1 < len(cli_args) and not cli_args[i + 1].startswith("-"):
                value = cli_args[i + 1]
                kwargs_format.append(f"{key}={value}")
                i += 2
            else:
                # Boolean flag
                kwargs_format.append(f"{key}=true")
                i += 1
        elif arg.startswith("-"):
            key = arg[1:]  # Remove -
            if i + 1 < len(cli_args) and not cli_args[i + 1].startswith("-"):
                value = cli_args[i + 1]
                kwargs_format.append(f"{key}={value}")
                i += 2
            else:
                # Boolean flag
                kwargs_format.append(f"{key}=true")
                i += 1
        else:
            i += 1

    return kwargs_format, script_args


def create_job_for_scheduler(
    scheduler: str,
    scheduler_cfg: Dict[str, Any],
    num_hosts: int,
    host_type: str,
    workspace: Optional[str] = None,
) -> JobTrait:
    """
    Create appropriate job based on scheduler type.

    Args:
        scheduler: Scheduler name (e.g., "mast", "mast_conda", "slurm")
        scheduler_cfg: Scheduler configuration dict with keys like hpcIdentity, etc.
        num_hosts: Number of hosts to allocate
        host_type: Host type (e.g., "gtt_any")
        workspace: Optional local workspace directory to pack

    Returns:
        JobTrait instance configured for the scheduler

    Raises:
        NotImplementedError: If scheduler is not yet supported
        ValueError: If scheduler is unsupported
    """
    match scheduler:
        case "mast_conda":
            from monarch._src.job.meta import MASTJob

            job = MASTJob(
                hpcIdentity=scheduler_cfg["hpcIdentity"],
                hpcJobOncall=scheduler_cfg["hpcJobOncall"],
                rmAttribution=scheduler_cfg["rmAttribution"],
                hpcClusterUuid=scheduler_cfg.get("hpcClusterUuid", "MastProdCluster"),
            )
            job.add_mesh("workers", num_hosts, host_type)

            # Add workspace if provided (pack to root of WORKSPACE_DIR)
            if workspace:
                job.add_directory(workspace, "")

            return job

        case "slurm":
            raise NotImplementedError(f"Scheduler {scheduler} not yet supported")

        case _:
            raise ValueError(f"Unsupported scheduler: {scheduler}")


class SPMDJob(JobTrait):
    """
    SPMD (Single Program Multiple Data) job that wraps any JobTrait.

    This job type is created via `monarch serve torchx ...` CLI and stores
    both the underlying job (e.g., MASTJob) and the AppDef from the torchx component.
    """

    def __init__(
        self,
        job: JobTrait,
        scheduler: str,
        appdef: AppDef,
        workspace: Optional[str] = None,
        scheduler_args: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self._job = job
        self._scheduler = scheduler
        self._appdef = appdef
        self._workspace = workspace
        self._scheduler_args = scheduler_args or {}

    @classmethod
    def serve_from_command(cls, command: List[str]) -> "SPMDJob":
        """
        Create an SPMDJob from a torchx command.

        Args:
            command: List of command arguments starting with 'torchx run'
                Example: ['torchx', 'run', '-s', 'mast_conda', '-cfg', 'key=val',
                          'dist.ddp', '-j', '1x8', '--script', 'train.py', '--', '--lr', '0.001']

        Returns:
            SPMDJob instance ready to be launched

        Raises:
            ValueError: If command format is invalid or required args are missing
        """
        # Validate input
        if len(command) < 2 or command[0] != "torchx" or command[1] != "run":
            raise ValueError(
                "Command must start with 'torchx run'. "
                f"Got: {' '.join(command[:2]) if len(command) >= 2 else command}"
            )

        # Remove 'torchx run' from beginning
        torchx_args = command[2:]

        # Manually parse scheduler args and component args
        scheduler = None
        scheduler_args_str = ""
        workspace = None
        component_start_idx = None

        i = 0
        while i < len(torchx_args):
            if torchx_args[i] in ["-s", "--scheduler"]:
                if i + 1 < len(torchx_args):
                    scheduler = torchx_args[i + 1]
                    i += 2
                else:
                    raise ValueError("-s/--scheduler requires a value")
            elif torchx_args[i] in ["-cfg", "--scheduler_args"]:
                if i + 1 < len(torchx_args):
                    scheduler_args_str = torchx_args[i + 1]
                    i += 2
                else:
                    raise ValueError("-cfg/--scheduler_args requires a value")
            elif torchx_args[i] == "--workspace":
                if i + 1 < len(torchx_args):
                    workspace = torchx_args[i + 1]
                    i += 2
                else:
                    raise ValueError("--workspace requires a value")
            else:
                # This is the start of component name and args
                component_start_idx = i
                break

        if not scheduler:
            raise ValueError("-s/--scheduler is required")

        if component_start_idx is None or component_start_idx >= len(torchx_args):
            raise ValueError("Component name is required")

        # Get component name and remaining args
        component_name = torchx_args[component_start_idx]
        component_args = torchx_args[component_start_idx + 1 :]

        # Parse scheduler args
        runner = get_runner()
        scheduler_opts = runner.scheduler_run_opts(scheduler)
        scheduler_cfg = scheduler_opts.cfg_from_str(scheduler_args_str)

        # Get component function and call it
        component_fn = get_component(component_name).fn

        # Convert CLI-style args to key=value format and extract script_args
        component_args_kwformat, script_args = _parse_cli_args_to_kwargs(component_args)
        print(f"DEBUG: Original component_args: {component_args}")
        print(f"DEBUG: Converted to kwformat: {component_args_kwformat}")
        print(f"DEBUG: Script args: {script_args}")

        # Parse kwargs manually (bypass component_args_from_cli to avoid *args limitation)
        component_kwargs = {}
        for arg in component_args_kwformat:
            if "=" in arg:
                key, value = arg.split("=", 1)
                component_kwargs[key] = value

        print(f"DEBUG: Component kwargs: {component_kwargs}")

        # Call component function to get AppDef
        try:
            # Pass script_args as positional arguments (*script_args)
            appdef: AppDef = component_fn(*script_args, **component_kwargs)
        except Exception as e:
            raise ValueError(
                f"Failed to call component function '{component_name}': {e}"
            )

        # Extract num_hosts from AppDef
        if not appdef.roles or len(appdef.roles) == 0:
            raise ValueError("AppDef has no roles")

        num_hosts = appdef.roles[0].num_replicas

        # Extract host_type from component_args (if provided)
        host_type = "gtt_any"
        for i, arg in enumerate(component_args):
            if arg in ["-h", "--host_type"] and i + 1 < len(component_args):
                host_type = component_args[i + 1]
                break

        # Workspace: CLI overrides AppDef
        if workspace is None and appdef.roles[0].env:
            workspace = appdef.roles[0].env.get("WORKSPACE_DIR")

        # Auto-detect workspace from script path if not specified
        if workspace is None and "script" in component_kwargs:
            script_path = component_kwargs["script"]
            # If script is a relative path, use current directory as workspace
            if script_path and not script_path.startswith("/"):
                import os

                workspace = os.getcwd()
                print(f"Auto-detected workspace from relative script path: {workspace}")

        # Create underlying job based on scheduler type
        underlying_job = create_job_for_scheduler(
            scheduler=scheduler,
            scheduler_cfg=scheduler_cfg,
            num_hosts=num_hosts,
            host_type=host_type,
            workspace=workspace,
        )

        # Return SPMDJob with AppDef
        return cls(
            job=underlying_job,
            scheduler=scheduler,
            appdef=appdef,
            workspace=workspace,
            scheduler_args=scheduler_cfg,
        )

    def _create(self, client_script: Optional[str] = None):
        self._job._create(client_script)

    def can_run(self, spec: "JobTrait") -> bool:
        if not isinstance(spec, SPMDJob):
            return False
        return (
            self._scheduler == spec._scheduler
            and self._appdef == spec._appdef
            and self._workspace == spec._workspace
            and self._scheduler_args == spec._scheduler_args
            and self._job.can_run(spec._job)
        )

    def _state(self) -> JobState:
        return self._job._state()

    def _kill(self):
        self._job._kill()

    def run_spmd(self):
        """
        Phase A: Execute torchrun command once per host.
        torchrun will spawn the GPU worker processes as child processes.

        Phase B (TODO): Parse torchrun command and spawn GPU workers as monarch actors.
        """
        configure(default_transport=ChannelTransport.MetaTlsWithHostname)
        job_state = self._state()
        workers = job_state.workers

        # Phase A: Spawn 1 actor per host (torchrun will handle GPU processes)
        pm = workers.spawn_procs()
        am = pm.spawn("_SPMDActor", SPMDActor)

        # Extract execution components from AppDef
        role = self._appdef.roles[0]
        entrypoint = role.entrypoint
        args = role.args or []
        env = role.env or {}

        print("Phase A: Running torchrun command on each host")
        print(f"  entrypoint: {entrypoint}")
        print(f"  args: {args[:3] if len(args) > 3 else args}...")

        # Run command on all hosts - torchrun handles coordination
        am.run_command.call(entrypoint, args, env).get()
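
For reference, a hedged sketch (not part of this diff) of the AppDef shape that run_spmd consumes in Phase A and that _extract_nproc_per_node_from_appdef is meant to parse in Phase B; the image and command strings are placeholders:

from torchx.specs import AppDef, Role

appdef = AppDef(
    name="train",
    roles=[
        Role(
            name="trainer",
            image="example_image",  # placeholder image
            entrypoint="bash",
            args=["-c", "torchrun --nproc_per_node 8 train.py --lr 0.001"],
            num_replicas=2,  # serve_from_command reads this as num_hosts
        )
    ],
)

# Phase A: run_spmd runs the bash/torchrun command above once per host via SPMDActor.
# Phase B: _extract_nproc_per_node_from_appdef(appdef) would return 8 from the torchrun
# command, so that many GPU workers could be spawned as monarch actors instead.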

0 commit comments
