99import json
1010import sys
1111
12+ from monarch .job .spmd import ( # @manual=//monarch/python/monarch/job:job
13+ create_job_for_scheduler ,
14+ SPMDJob ,
15+ )
1216from monarch .tools .commands import (
1317 bounce ,
1418 component_args_from_cli ,
2529)
2630
2731from monarch .tools .debug_env import _get_debug_server_host , _get_debug_server_port
32+ from torchx .cli .cmd_run import _parse_component_name_and_args , CmdRun
33+ from torchx .components .fb import parse_j
34+ from torchx .runner import get_runner
2835from torchx .specs .finder import get_component
2936
3037
@@ -163,6 +170,141 @@ def run(self, args: argparse.Namespace) -> None:
163170 debug (args .host , args .port )
164171
165172
class ServeCmd:
    """
    Parse a ``torchx run`` command line and launch it as a monarch SPMD job.

    The command expects its arguments to start with the literal words
    ``torchx run``. What follows is parsed with torchx's own ``CmdRun``
    argument parser; the job size (``-j/--job_spec``) and host type
    (``-h/--host_type``) are then read from the component arguments and used
    to build the underlying job, which is wrapped in an ``SPMDJob``.

    Example:
        monarch serve torchx run -s conda_mast -j1x8 train.py -- --lr 0.001
    """

    # Host type used when the component arguments carry no -h/--host_type.
    _DEFAULT_HOST_TYPE = "gtt_any"

    # (short flag, long flag, result key) for the options read out of the
    # component arguments.
    _JOB_OPTIONS = (
        ("-j", "--job_spec", "job_spec"),
        ("-h", "--host_type", "host_type"),
    )

    def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
        """Register the catch-all positional that captures the torchx command."""
        subparser.add_argument(
            "torchx_args",
            nargs=argparse.REMAINDER,
            help="torchx command arguments (e.g., 'run -s mast_conda -j1x8 train.py -- --lr 0.001')",
        )

    @staticmethod
    def _split_script_args(component_args: "list[str]") -> "tuple[list[str], list[str]]":
        """Split ``component_args`` at the first ``--`` delimiter.

        Returns:
            ``(option_args, script_args)``: the tokens before the delimiter
            and the tokens after it. The delimiter itself is dropped. Without
            a delimiter, every token is an option token and ``script_args``
            is empty.
        """
        try:
            delimiter_idx = component_args.index("--")
        except ValueError:
            return list(component_args), []
        return component_args[:delimiter_idx], component_args[delimiter_idx + 1 :]

    @staticmethod
    def _inline_value(token: str, short_flag: str, long_flag: str) -> "str | None":
        """Return the value glued onto ``token``, or ``None`` if there is none.

        Recognises ``--long=VALUE``, ``-sVALUE`` and ``-s=VALUE``, mirroring
        the forms argparse accepts for an option that takes one argument.
        """
        if token.startswith(long_flag + "="):
            return token[len(long_flag) + 1 :]
        if token.startswith(short_flag) and len(token) > len(short_flag):
            attached = token[len(short_flag) :]
            return attached[1:] if attached.startswith("=") else attached
        return None

    @classmethod
    def _extract_job_options(cls, component_args: "list[str]") -> "tuple[str | None, str]":
        """Read the job spec and host type out of the component arguments.

        Only tokens before the ``--`` delimiter are inspected: everything
        after it belongs to the user's script and must not be mistaken for
        ``-j``/``-h``. When an option is repeated, the last occurrence wins.

        Returns:
            ``(job_spec, host_type)``. ``job_spec`` is ``None`` when the
            option is absent; ``host_type`` falls back to the class default.

        Raises:
            ValueError: if a detached ``-j``/``-h`` is the final option token
                and therefore has no value. The message is ready to be shown
                to the user after an ``Error: `` prefix.
        """
        option_args, _ = cls._split_script_args(component_args)
        found = {}
        idx = 0
        while idx < len(option_args):
            token = option_args[idx]
            idx += 1
            for short_flag, long_flag, key in cls._JOB_OPTIONS:
                if token in (short_flag, long_flag):
                    if idx >= len(option_args):
                        raise ValueError(f"{short_flag} requires a value")
                    found[key] = option_args[idx]
                    idx += 1  # the value is consumed; do not rescan it as a flag
                    break
                inline = cls._inline_value(token, short_flag, long_flag)
                if inline is not None:
                    found[key] = inline
                    break
        return found.get("job_spec"), found.get("host_type", cls._DEFAULT_HOST_TYPE)

    def run(self, args: argparse.Namespace) -> None:
        """Validate, parse and launch the torchx command held in ``args.torchx_args``.

        Every user-facing failure is reported on stderr and ends the process
        through ``sys.exit(1)``.
        """
        # Validate input: the remainder must start with the words "torchx run".
        if (args.torchx_args or [])[:2] != ["torchx", "run"]:
            print("Error: Expected 'torchx run ...' command", file=sys.stderr)
            print(
                "Usage: monarch serve torchx run --scheduler SCHEDULER [--scheduler_args ARGS] COMPONENT [COMPONENT_ARGS] [-- SCRIPT_ARGS]",
                file=sys.stderr,
            )
            sys.exit(1)

        # Create torchx CmdRun to reuse its parser.
        cmd_run = CmdRun()
        parser = argparse.ArgumentParser()
        cmd_run.add_arguments(parser)

        # Parse everything after the leading "torchx run" with torchx's parser.
        try:
            parsed = parser.parse_args(args.torchx_args[2:])
        except SystemExit as exc:
            # argparse also exits (status 0) after printing --help; that is a
            # successful outcome and must not be reported as a parse failure.
            if not exc.code:
                raise
            print("Error: Failed to parse torchx arguments", file=sys.stderr)
            sys.exit(1)

        # Get runner to turn the scheduler_args string into a scheduler config.
        runner = get_runner()
        scheduler_opts = runner.scheduler_run_opts(parsed.scheduler)
        scheduler_cfg = scheduler_opts.cfg_from_str(parsed.scheduler_args or "")

        # Parse component name and args using the torchx helper.
        component_name, component_args = _parse_component_name_and_args(
            parsed.component_name_and_args, parser
        )

        # Script args are everything after the "--" delimiter.
        _, script_args = self._split_script_args(component_args)

        # Extract -j and -h from the component's own arguments.
        try:
            job_spec, host_type = self._extract_job_options(component_args)
        except ValueError as exc:
            print(f"Error: {exc}", file=sys.stderr)
            sys.exit(1)

        if not job_spec:
            print(
                "Error: -j/--job_spec required in component arguments", file=sys.stderr
            )
            sys.exit(1)

        # Parse job_spec using torchx's parse_j.
        try:
            nnodes, nproc_per_node = parse_j(job_spec)
        except Exception as e:
            # CLI boundary: any parse failure is reported to the user.
            print(f"Error: Failed to parse job spec '{job_spec} ': {e} ", file=sys.stderr)
            sys.exit(1)

        print(f"Scheduler: {parsed.scheduler} ")
        print(f"Component: {component_name} ")
        print(
            f"Job spec: {job_spec} ({nnodes} node(s) x {nproc_per_node} proc(s) per node)"
        )
        print(f"Host type: {host_type} ")
        if parsed.workspace:
            print(f"Workspace: {parsed.workspace} ")

        # Create underlying job based on scheduler type.
        underlying_job = create_job_for_scheduler(
            scheduler=parsed.scheduler,
            scheduler_cfg=scheduler_cfg,
            num_hosts=nnodes,
            host_type=host_type,
            workspace=parsed.workspace,
        )

        # Wrap in SPMDJob. component_args is forwarded in full, including the
        # delimiter and script arguments.
        spmd_job = SPMDJob(
            job=underlying_job,
            scheduler=parsed.scheduler,
            nnodes=nnodes,
            nproc_per_node=nproc_per_node,
            component=component_name,
            component_args=component_args,
            script_args=script_args,
            workspace=parsed.workspace,
            scheduler_args=scheduler_cfg,
        )

        # Launch job. NOTE(review): the original author's comment says that
        # state() applies the job and caches it; confirm against SPMDJob.
        print(f"\n Launching {parsed.scheduler} job...")
        spmd_job.state()
        print("✓ Job launched successfully and cached to .monarch/job_state.pkl")
306+
307+
166308def get_parser () -> argparse .ArgumentParser :
167309 parser = argparse .ArgumentParser (description = "Monarch CLI" )
168310 subparser = parser .add_subparsers (title = "COMMANDS" )
@@ -172,6 +314,7 @@ def get_parser() -> argparse.ArgumentParser:
172314 "info" : InfoCmd (),
173315 "kill" : KillCmd (),
174316 "debug" : DebugCmd (),
317+ "serve" : ServeCmd (),
175318 # --- placeholder subcommands (not yet implemented) ---
176319 "bounce" : BounceCmd (),
177320 "stop" : StopCmd (),
0 commit comments