Skip to content

Commit acc7606

Browse files
colin2328facebook-github-bot
authored and committed
add monarch serve torchx command to launch the (MAST) job and cache the command inside the jobs .pkl
Summary: Introduce monarch serve torchx ... as proposed here https://docs.google.com/document/d/1F3m3oDBX3sipHCxsp_2ghSvgOBKgmFRWql3A5StMh1I/edit?tab=t.0#heading=h.zb5e0il0bn6a Here, we create the job and create the .pkl file of the command. We then add run_spmd so we can run python -c "from monarch.job import job_load; job = job_load(); job.run_spmd()" Differential Revision: D88515552
1 parent c493edd commit acc7606

File tree

4 files changed

+293
-0
lines changed

4 files changed

+293
-0
lines changed

python/monarch/_src/job/spmd.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, Dict, List, Optional
8+
9+
from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
10+
from monarch._rust_bindings.monarch_hyperactor.config import configure
11+
from monarch._src.job.job import JobState, JobTrait
12+
from monarch._src.spmd import SPMDActor
13+
14+
15+
def create_job_for_scheduler(
    scheduler: str,
    scheduler_cfg: Dict[str, Any],
    num_hosts: int,
    host_type: str,
    workspace: Optional[str] = None,
) -> JobTrait:
    """
    Build the JobTrait implementation that matches the requested scheduler.

    Args:
        scheduler: Scheduler name (e.g., "mast", "mast_conda", "slurm")
        scheduler_cfg: Scheduler configuration dict with keys like hpcIdentity, etc.
        num_hosts: Number of hosts to allocate
        host_type: Host type (e.g., "gtt_any")
        workspace: Optional local workspace directory to pack

    Returns:
        JobTrait instance configured for the scheduler

    Raises:
        NotImplementedError: If scheduler is not yet supported
        ValueError: If scheduler is unsupported
    """
    if scheduler == "mast_conda":
        # Imported lazily so non-MAST schedulers don't pull in Meta-internal deps.
        from monarch._src.job.meta import MASTJob

        mast_job = MASTJob(
            hpcIdentity=scheduler_cfg["hpcIdentity"],
            hpcJobOncall=scheduler_cfg["hpcJobOncall"],
            rmAttribution=scheduler_cfg["rmAttribution"],
            hpcClusterUuid=scheduler_cfg.get("hpcClusterUuid", "MastProdCluster"),
        )
        mast_job.add_mesh("workers", num_hosts, host_type)

        # Pack the workspace (if any) at the root of WORKSPACE_DIR.
        if workspace:
            mast_job.add_directory(workspace, "")

        return mast_job

    if scheduler == "slurm":
        raise NotImplementedError(f"Scheduler {scheduler} not yet supported")

    raise ValueError(f"Unsupported scheduler: {scheduler}")
62+
63+
64+
class SPMDJob(JobTrait):
    """
    SPMD (Single Program Multiple Data) job wrapping an underlying JobTrait.

    Instances are produced by the `monarch serve torchx ...` CLI; each one
    carries both the scheduler-specific job (e.g., MASTJob) and the torchx
    command metadata needed to replay the SPMD run later.
    """

    def __init__(
        self,
        job: JobTrait,
        scheduler: str,
        nnodes: int,
        nproc_per_node: int,
        component: str,
        component_args: List[str],
        script_args: List[str],
        workspace: Optional[str] = None,
        scheduler_args: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        # Underlying scheduler job plus the cached torchx command metadata.
        self._job = job
        self._scheduler = scheduler
        self._nnodes = nnodes
        self._nproc_per_node = nproc_per_node
        self._component = component
        self._component_args = component_args
        self._script_args = script_args
        self._workspace = workspace
        self._scheduler_args = scheduler_args or {}

    def _create(self, client_script: Optional[str] = None):
        # Delegate creation to the wrapped job.
        self._job._create(client_script)

    def can_run(self, spec: "JobTrait") -> bool:
        """Return True if *spec* describes the same SPMD job this job serves."""
        if not isinstance(spec, SPMDJob):
            return False
        own_key = (
            self._scheduler,
            self._nnodes,
            self._nproc_per_node,
            self._component,
            self._component_args,
            self._script_args,
            self._workspace,
            self._scheduler_args,
        )
        spec_key = (
            spec._scheduler,
            spec._nnodes,
            spec._nproc_per_node,
            spec._component,
            spec._component_args,
            spec._script_args,
            spec._workspace,
            spec._scheduler_args,
        )
        return own_key == spec_key and self._job.can_run(spec._job)

    def _state(self) -> JobState:
        # State comes straight from the wrapped job.
        return self._job._state()

    def _kill(self):
        self._job._kill()

    def run_spmd(self):
        """Spawn one SPMDActor per GPU slot and run the cached SPMD command."""
        configure(default_transport=ChannelTransport.MetaTlsWithHostname)
        workers = self._state().workers
        procs = workers.spawn_procs(per_host={"gpus": self._nproc_per_node})
        actors = procs.spawn("_SPMDActor", SPMDActor)

        # Index 0 along every mesh label — presumably the rank-0 actor — is
        # asked for the rendezvous host/port used by all ranks.
        rank_zero = {label: 0 for label in procs._labels}
        master_addr, master_port = (
            actors.slice(**rank_zero).get_host_port.call_one(None).get()
        )

        print("Calling SPMDActor.main with:")
        print(f" master_addr: {master_addr}")
        print(f" master_port: {master_port}")
        print(f" script_args: {self._script_args}")

        # Broadcast main() to every actor and block until completion.
        actors.main.call(master_addr, master_port, self._script_args).get()

python/monarch/_src/spmd/actor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,11 @@ def main(self, master_addr: str, master_port: int, script_args: list[str]) -> bo
112112
"""
113113
self._setup_env(master_addr, master_port)
114114

115+
# Change to workspace directory if available
116+
workspace_dir = os.environ.get("WORKSPACE_DIR")
117+
if workspace_dir and os.path.exists(workspace_dir):
118+
os.chdir(workspace_dir)
119+
115120
if script_args and script_args[0] == "-m":
116121
module_name = script_args[1]
117122
sys.argv = [module_name] + list(script_args[2:])

python/monarch/job/spmd.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from monarch._src.job.spmd import create_job_for_scheduler, SPMDJob
8+
9+
__all__ = ["SPMDJob", "create_job_for_scheduler"]

python/monarch/tools/cli.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
import json
1010
import sys
1111

12+
from monarch.job.spmd import ( # @manual=//monarch/python/monarch/job:job
13+
create_job_for_scheduler,
14+
SPMDJob,
15+
)
1216
from monarch.tools.commands import (
1317
bounce,
1418
component_args_from_cli,
@@ -25,6 +29,9 @@
2529
)
2630

2731
from monarch.tools.debug_env import _get_debug_server_host, _get_debug_server_port
32+
from torchx.cli.cmd_run import _parse_component_name_and_args, CmdRun
33+
from torchx.components.fb import parse_j
34+
from torchx.runner import get_runner
2835
from torchx.specs.finder import get_component
2936

3037

@@ -163,6 +170,141 @@ def run(self, args: argparse.Namespace) -> None:
163170
debug(args.host, args.port)
164171

165172

173+
class ServeCmd:
    """
    Parse and cache a torchx command for monarch execution.

    Example:
        monarch serve torchx run -s conda_mast -j1x8 train.py -- --lr 0.001
    """

    def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
        """Register the raw torchx command remainder on *subparser*."""
        subparser.add_argument(
            "torchx_args",
            nargs=argparse.REMAINDER,
            help="torchx command arguments (e.g., 'run -s mast_conda -j1x8 train.py -- --lr 0.001')",
        )

    def run(self, args: argparse.Namespace) -> None:
        """Parse the torchx command, launch the job, and cache it to .pkl.

        Exits the process with status 1 on any validation or parse failure.
        """
        # Validate input: the remainder must start with "torchx run".
        if (
            not args.torchx_args
            or len(args.torchx_args) < 2
            or args.torchx_args[0] != "torchx"
            or args.torchx_args[1] != "run"
        ):
            print("Error: Expected 'torchx run ...' command", file=sys.stderr)
            print(
                "Usage: monarch serve torchx run --scheduler SCHEDULER [--scheduler_args ARGS] COMPONENT [COMPONENT_ARGS] [-- SCRIPT_ARGS]",
                file=sys.stderr,
            )
            sys.exit(1)

        # Create torchx CmdRun to reuse its parser
        cmd_run = CmdRun()
        parser = argparse.ArgumentParser()
        cmd_run.add_arguments(parser)

        # Remove 'torchx run' from beginning
        torchx_args = args.torchx_args[2:]

        # Parse using torchx's parser (argparse exits on error; convert that
        # into our own error message + exit).
        try:
            parsed = parser.parse_args(torchx_args)
        except SystemExit:
            print("Error: Failed to parse torchx arguments", file=sys.stderr)
            sys.exit(1)

        # Get runner to parse scheduler_args
        runner = get_runner()
        scheduler_opts = runner.scheduler_run_opts(parsed.scheduler)
        scheduler_cfg = scheduler_opts.cfg_from_str(parsed.scheduler_args or "")

        # Parse component name and args using torchx helper
        component_name, component_args = _parse_component_name_and_args(
            parsed.component_name_and_args, parser
        )

        # Split component options from script args at the "--" delimiter.
        # BUGFIX: only the options *before* "--" may be scanned for -j/-h;
        # previously the scan ran over the script args too, so a script
        # argument like "-j" or "-h" after "--" was silently consumed (or
        # triggered a bogus "requires a value" error).
        try:
            delimiter_idx = component_args.index("--")
        except ValueError:
            option_args = component_args
            script_args = []
        else:
            option_args = component_args[:delimiter_idx]
            script_args = component_args[delimiter_idx + 1 :]

        # Extract -j and -h from the component options.
        job_spec = None
        host_type = "gtt_any"
        i = 0
        while i < len(option_args):
            if option_args[i] in ["-j", "--job_spec"]:
                if i + 1 < len(option_args):
                    job_spec = option_args[i + 1]
                    i += 2
                else:
                    print("Error: -j requires a value", file=sys.stderr)
                    sys.exit(1)
            elif option_args[i] in ["-h", "--host_type"]:
                if i + 1 < len(option_args):
                    host_type = option_args[i + 1]
                    i += 2
                else:
                    print("Error: -h requires a value", file=sys.stderr)
                    sys.exit(1)
            else:
                i += 1

        if not job_spec:
            print(
                "Error: -j/--job_spec required in component arguments", file=sys.stderr
            )
            sys.exit(1)

        # Parse job_spec using torchx's parse_j
        try:
            nnodes, nproc_per_node = parse_j(job_spec)
        except Exception as e:
            print(f"Error: Failed to parse job spec '{job_spec}': {e}", file=sys.stderr)
            sys.exit(1)

        print(f"Scheduler: {parsed.scheduler}")
        print(f"Component: {component_name}")
        print(
            f"Job spec: {job_spec} ({nnodes} node(s) x {nproc_per_node} proc(s) per node)"
        )
        print(f"Host type: {host_type}")
        if parsed.workspace:
            print(f"Workspace: {parsed.workspace}")

        # Create underlying job based on scheduler type
        underlying_job = create_job_for_scheduler(
            scheduler=parsed.scheduler,
            scheduler_cfg=scheduler_cfg,
            num_hosts=nnodes,
            host_type=host_type,
            workspace=parsed.workspace,
        )

        # Wrap in SPMDJob; store the full component_args (options + script
        # args) so the cached command can be reconstructed verbatim.
        spmd_job = SPMDJob(
            job=underlying_job,
            scheduler=parsed.scheduler,
            nnodes=nnodes,
            nproc_per_node=nproc_per_node,
            component=component_name,
            component_args=component_args,
            script_args=script_args,
            workspace=parsed.workspace,
            scheduler_args=scheduler_cfg,
        )

        # Launch job (calls apply + caches)
        print(f"\nLaunching {parsed.scheduler} job...")
        spmd_job.state()
        print("✓ Job launched successfully and cached to .monarch/job_state.pkl")
306+
307+
166308
def get_parser() -> argparse.ArgumentParser:
167309
parser = argparse.ArgumentParser(description="Monarch CLI")
168310
subparser = parser.add_subparsers(title="COMMANDS")
@@ -172,6 +314,7 @@ def get_parser() -> argparse.ArgumentParser:
172314
"info": InfoCmd(),
173315
"kill": KillCmd(),
174316
"debug": DebugCmd(),
317+
"serve": ServeCmd(),
175318
# --- placeholder subcommands (not yet implemented) ---
176319
"bounce": BounceCmd(),
177320
"stop": StopCmd(),

0 commit comments

Comments
 (0)