torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/notebook.py ADDED
@@ -0,0 +1,761 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import argparse
9
+ import json
10
+ import logging
11
+ import os
12
+ import socket
13
+ import subprocess
14
+ import sys
15
+ import tempfile
16
+ import time
17
+ from getpass import getuser
18
+ from importlib.abc import SourceLoader
19
+ from importlib.machinery import ExtensionFileLoader, SourceFileLoader
20
+
21
+ from pathlib import Path
22
+ from pprint import pprint
23
+ from socket import gethostname
24
+ from subprocess import check_call, check_output
25
+ from tempfile import NamedTemporaryFile
26
+ from threading import Thread
27
+ from typing import Any, List, Optional
28
+
29
+ import zmq
30
+ from monarch.common.device_mesh import DeviceMesh
31
+
32
+ from monarch.common.mast import mast_get_jobs, MastJob
33
+ from monarch.common.remote import remote
34
+ from monarch.world_mesh import world_mesh
35
+ from monarch_supervisor import Context, get_message_queue, HostConnected
36
+ from monarch_supervisor.host import main as host_main
37
+ from monarch_supervisor.logging import initialize_logging
38
+
39
+ logger = logging.getLogger(__name__)
40
+
41
+ RESERVE_MAST_TASK_GROUP_NAME = "workers"
42
+ TORCHX_MAST_TASK_GROUP_NAME = "script"
43
+
44
+
45
class _Importer:
    """Meta-path finder that forwards import lookups to a remote
    importer server over a zmq REQ socket.

    Installed on a worker's sys.meta_path so modules missing on the
    worker are resolved by the controller.
    """

    def __init__(self, ctx: zmq.Context, endpoint):
        # REQ pairs with the server's REP socket; IPV6 is enabled to
        # match the server's transport.
        sock = ctx.socket(zmq.REQ)
        sock.setsockopt(zmq.IPV6, True)
        sock.connect(endpoint)
        self.socket = sock

    def find_spec(self, fullname, path, target=None):
        # One request / one reply: ship the query, return whatever spec
        # (or None) the server sends back.
        self.socket.send_pyobj((fullname, path, target))
        return self.socket.recv_pyobj()
55
+
56
+
57
+ class _SourceLoader(SourceLoader):
58
+ def __init__(self, name, path, data):
59
+ self.name = name
60
+ self.path = path
61
+ self.data = data
62
+
63
+ def get_filename(self, fullname):
64
+ return self.path
65
+
66
+ def get_data(self, path):
67
+ return self.data
68
+
69
+
70
+ class _ExtensionLoader:
71
+ def __init__(self, name, path, data):
72
+ self.name = name
73
+ self.path = path
74
+ self.data = data
75
+
76
+ def create_module(self, spec):
77
+ with NamedTemporaryFile("wb", delete=False) as f:
78
+ f.write(self.data)
79
+ self._loader = ExtensionFileLoader(self.name, f.name)
80
+ return self._loader.create_module(spec)
81
+
82
+ def exec_module(self, module):
83
+ return self._loader.exec_module(module)
84
+
85
+
86
class ControllerImporterServer:
    """zmq REP server that answers import lookups from remote workers.

    Workers install an `_Importer` on their sys.meta_path; each request
    is a pickled (fullname, path, target) triple, and the reply is a
    picklable ModuleSpec (its file-backed loader swapped for a
    bytes-carrying one) or None when the module cannot be shipped.
    """

    def __init__(self, context: zmq.Context):
        self.socket: zmq.Socket = context.socket(zmq.REP)
        self.socket.setsockopt(zmq.IPV6, True)
        self.hostname = socket.gethostname()
        # OS-assigned port so multiple servers can coexist on one host.
        self.port = self.socket.bind_to_random_port("tcp://*")
        self.endpoint = f"tcp://{self.hostname}:{self.port}"

    def run(self):
        # Serve forever; REQ/REP guarantees one reply per request.
        while True:
            fullname, path, target = self.socket.recv_pyobj()
            s = None
            # Resolve against the controller's own import machinery.
            for m in sys.meta_path:
                s = m.find_spec(fullname, path, target)
                if s is not None:
                    # print("SERVER FOUND", s.loader, s.loader.__dict__)
                    if isinstance(s.loader, SourceFileLoader):
                        # Capture the source bytes so the spec no longer
                        # references a local file path.
                        s.loader = _SourceLoader(
                            s.loader.name,
                            s.loader.path,
                            s.loader.get_data(s.loader.path),
                        )
                    elif isinstance(s.loader, ExtensionFileLoader):
                        # Ship the shared-object bytes themselves.
                        with open(s.loader.path, "rb") as f:
                            s.loader = _ExtensionLoader(
                                s.loader.name, s.loader.path, f.read()
                            )
                    else:
                        # Any other loader kind (builtin, frozen, ...)
                        # cannot be pickled/shipped: report "not found".
                        s = None
                    break
            self.socket.send_pyobj(s)
117
+
118
+
119
def _start_importer_server(context: zmq.Context):
    """Start a ControllerImporterServer on a daemon thread.

    Returns (thread, endpoint); the endpoint string is what remote
    workers pass to `_register_importer`.
    """
    server = ControllerImporterServer(context)
    serve_thread = Thread(target=server.run, daemon=True)
    serve_thread.start()
    return serve_thread, server.endpoint
126
+
127
+
128
def _create_fbpkg(name, directory):
    # Build an ephemeral fbpkg from the contents of `directory` and
    # return its [package, version] pair parsed from `fbpkg build`
    # output.

    def create_json(data):
        # Write the materialized fbpkg config JSON into a fresh temp
        # directory laid out the way `fbpkg build --configerator-path`
        # expects.
        temp_dir = tempfile.mkdtemp()
        d = os.path.join(temp_dir, "materialized_configs")
        os.makedirs(d)
        json_file_path = os.path.join(d, f"{name}.fbpkg.materialized_JSON")
        with open(json_file_path, "w") as json_file:
            json.dump(data, json_file)
        return temp_dir

    package_json = {
        # Package every top-level entry of the directory.
        "paths": os.listdir(directory),
        "build_command": "",
    }
    path = create_json(package_json)
    # NOTE(review): the package:version identifier is assumed to be on
    # the second-to-last line of the command's stdout — confirm against
    # current fbpkg output format.
    package_version = (
        check_output(
            [
                "fbpkg",
                "build",
                "--yes",
                "--ephemeral",
                "--configerator-path",
                path,
                name,
                "--expire",
                "4w",
            ],
            cwd=directory,
        )
        .decode("utf-8")
        .split("\n")[-2]
    )
    return package_version.split(":")
162
+
163
+
164
+ def _nfs_start(
165
+ endpoint: str,
166
+ datacenter: str,
167
+ twmount="/packages/nfs.twmount/twmount",
168
+ dest="/mnt/gen_ai_input_data_nfs",
169
+ options="vers=4.2,port=20492,proto=tcp6,nosharecache",
170
+ ):
171
+ addr_map = {"eag": "[2401:db00:3ff:c0c3::2a]:/ai"}
172
+ addr = addr_map[datacenter]
173
+ cmds = [
174
+ f"{twmount} mount -t nfs4 -s {addr} -d {dest} -o {options}",
175
+ "mkdir -p /mnt/aidev/$MAST_JOB_OWNER_UNIXNAME",
176
+ f"mount --bind {dest}/aidev/$MAST_JOB_OWNER_UNIXNAME /mnt/aidev/$MAST_JOB_OWNER_UNIXNAME",
177
+ f"{sys.executable} -m monarch.notebook worker --endpoint {endpoint}",
178
+ ]
179
+ return " && ".join(cmds)
180
+
181
+
182
+ def _ephemeral_package(package, version):
183
+ return {
184
+ "fbpkgIdentifier": f"{package}:{version}",
185
+ "name": package,
186
+ "version": {"ephemeralId": version},
187
+ }
188
+
189
+
190
+ def _package(package: str, version: str):
191
+ return {
192
+ "fbpkgIdentifier": f"{package}:{version}",
193
+ "name": package,
194
+ }
195
+
196
+
197
def launch_mast(
    base_name: str,
    packages: List[Any],
    hosts: int,
    command: str,
    run: bool = True,
    base_image: Optional[str] = None,
    datacenter: Optional[str] = None,
):
    """Build a MAST jobspec for `hosts` worker hosts running `command`
    and (when `run` is True) schedule it via the `mast` CLI.

    Returns the generated job name (base_name plus a nanosecond
    timestamp suffix, which makes it unique per call).
    """
    name = f"{base_name}_{time.time_ns()}"
    jobspec = {
        "applicationMetadata": {},
        "enableGracefulPreemption": False,
        "hpcClusterUuid": "MastGenAICluster",
        "hpcTaskGroups": [
            {
                "hardwareSpecificTaskGroupOverride": {},
                "name": RESERVE_MAST_TASK_GROUP_NAME,
                "spec": {
                    "applicationPackages": [
                        # ttls_so is always included so TransparentTls
                        # can be LD_PRELOADed (see env below).
                        {
                            "fbpkgIdentifier": "ttls_so:1125",
                            "name": "ttls_so",
                        },
                        *packages,
                    ],
                    "arguments": [],
                    "bindMounts": [],
                    "command": command,
                    "env": {
                        "LD_PRELOAD": "/packages/ttls_so/TransparentTls3.so",
                        "TTLS_ENABLED": "1",
                    },
                    "machineConstraints": {"types": {"serverTypes": [100]}},
                    "networkAffinity": {"fallbackScope": 1, "preferredScope": 2},
                    "oncallShortname": "pytorch_distributed",
                    "opecTag": 0,
                    "packages": [],
                    "resourceLimit": {
                        "compute": {"cpu": 15, "gpu": 0},
                        "enableSwapAndSenpai": False,
                        "limitType": 0,
                        "ramMB": 54272,
                        # Reserve whole hosts even though gpu is 0 above.
                        "wholeHost": True,
                    },
                    "restartPolicy": {
                        "failoverOnHostFailures": False,
                        "maxTotalFailures": 10,
                        "scope": 0,
                    },
                    # 30 days in seconds.
                    "runningTimeoutSec": 2592000,
                    "unixUser": "root",
                    "ttlsConfig": {"enable": True},
                },
                "taskCount": hosts,
                "taskCountPerHost": 1,
            }
        ],
        "identity": {"name": "oncall_pytorch_distributed"},
        "jobType": 0,
        "maxJobFailures": 0,
        "name": name,
        "networkAffinity": {"fallbackScope": 1, "preferredScope": 2},
        "owner": {
            "oncallShortname": "pytorch_distributed",
            "unixname": os.environ["USER"],
        },
        "aiTrainingMetadata": {
            "launcher": 3,
            "trainingFramework": 10,
            "jobType": None,
            "modelTypeName": "gen_ai_conda",
            "jobPurpose": None,
            "entitlement": "pytorch_distributed",
            "tenantPath": None,
            "tenantPriority": None,
            "productGroup": None,
            "rootWorkflowRunID": None,
            "mastJobID": name,
            "mastWorkflowRunID": None,
            "productGroupMetadata": None,
            "modelIDs": None,
            "model_lifecycle_status": {},
        },
    }

    if base_image is not None:
        # pyre-fixme[16]: Item `bool` of `Union[Dict[typing.Any, typing.Any], Dict[st...
        jobspec["hpcTaskGroups"][0]["baseImage"] = {
            "baseImagePackage": {
                "fbpkgIdentifier": base_image,
            }
        }
    if datacenter is not None:
        jobspec["localityConstraints"] = {"locality": 1, "options": [datacenter]}

    pprint(jobspec)
    # delete=False: the file must survive so `mast schedule` can read it.
    with tempfile.NamedTemporaryFile(delete=False, mode="w") as job_spec_file:
        json.dump(jobspec, job_spec_file)
    print(job_spec_file.name)
    if run:
        check_call(["mast", "schedule", job_spec_file.name])
    return name
300
+
301
+
302
+ def _endpoint():
303
+ hostname = socket.gethostname()
304
+ with socket.socket() as sock:
305
+ sock.bind(("", 0))
306
+ port = sock.getsockname()[1]
307
+ return hostname, port
308
+
309
+
310
def reserve(hosts, nfs=False, force_rebuild=False):
    """Reserve `hosts` MAST worker hosts for this controller.

    The job name encodes the controller's port (notebook_<port>), which
    `connect()` later parses back out. With nfs=True the workers mount
    the user's NFS conda environment; otherwise the current conda
    prefix is packaged as an ephemeral fbpkg (cached in the user's home
    directory; rebuilt when force_rebuild is set or no cache exists).

    Returns the scheduled MAST job name.
    """
    hostname, port = _endpoint()
    endpoint = f"tcp://{hostname}:{port}"
    # connect() recovers the port from this name via split("_")[1].
    name = f"notebook_{port}"
    if nfs:
        executable = Path(sys.executable)
        nfs_path = Path(f"/mnt/aidev/{getuser()}")
        # Workers mount the same NFS path, so the interpreter must live
        # under it to be runnable remotely.
        if nfs_path not in executable.parents:
            raise RuntimeError(
                f"conda environment ({executable}) must be installed in nfs path ({nfs_path}) to use nfs workers."
            )
        # NOTE(review): assumes hostnames look like host.<dc>xxx...;
        # takes the first 3 chars of the second dot-component.
        datacenter = socket.gethostname().split(".")[1][:3]
        jobname = launch_mast(
            base_name=name,
            packages=[_package("nfs.twmount", "stable")],
            hosts=hosts,
            command=_nfs_start(endpoint, datacenter),
            base_image="tupperware.image.sendstream.c9.flare",
            datacenter=datacenter,
        )
    else:
        environment = os.environ["CONDA_PREFIX"]
        # Cache file holds "package:version" of the last-built fbpkg.
        cache = f'{os.environ["HOME"]}/.controller_notebook_package'
        try:
            with open(cache, "r") as f:
                package, version = f.read().strip().split(":")
        except FileNotFoundError:
            package, version = None, None

        if force_rebuild or package is None:
            package, version = _create_fbpkg("ptd_supervisor_testbed", environment)
            with open(cache, "w") as f:
                f.write(f"{package}:{version}\n")
        jobname = launch_mast(
            base_name=name,
            packages=[_ephemeral_package(package, version)],
            hosts=hosts,
            command=f"/packages/{package}/bin/python -m monarch.notebook worker --endpoint {endpoint}",
        )

    return jobname
351
+
352
+
353
def _register_importer(endpoint: str):
    # Runs on a worker (via the `register_importer` remote below):
    # install a meta-path finder that resolves imports by querying the
    # controller's importer server at `endpoint`.
    sys.meta_path.append(_Importer(get_message_queue()._ctx, endpoint))
355
+
356
+
357
# Worker-side entry points the controller invokes remotely:
# `register_importer` installs the remote import hook defined above;
# `_chdir` changes the workers' working directory. propagate="inspect"
# means no value is propagated back.
register_importer = remote("monarch.notebook._register_importer", propagate="inspect")

_chdir = remote("os.chdir", propagate="inspect")
360
+
361
+
362
def mast_job_is_valid(job):
    """Return True when `job` is a monarch notebook worker job whose
    endpoint targets this controller host.

    A valid job was launched with arguments of the form
    `-mmonarch.notebook worker --endpoint tcp://<this-host>:<port>`.
    """
    args = job.get_arguments()
    if args[0:3] != ["-mmonarch.notebook", "worker", "--endpoint"]:
        return False
    # Fix: guard against a truncated argument list — indexing args[3]
    # on a 3-element list raised IndexError in the original.
    if len(args) < 4:
        return False
    maybe_host_and_port = args[3].removeprefix("tcp://").split(":")
    if len(maybe_host_and_port) != 2:
        return False
    host, port = maybe_host_and_port
    # Only jobs pointed at *this* host with a numeric port are ours.
    return host == socket.gethostname() and port.isdigit()
371
+
372
+
373
def get_mast_notebook_jobs(task_group):
    """Return this user's MAST jobs in `task_group` that look like
    monarch notebook workers targeting this host."""
    return [
        job
        for job in mast_get_jobs(task_group)
        if "monarch" in job.name() and mast_job_is_valid(job)
    ]
379
+
380
+
381
def connect(jobname=None):
    """Connect to a previously reserved MAST worker job and return a
    DeviceMesh over its hosts (8 GPUs per host assumed).

    With jobname=None the most recently created "notebook_*" job is
    used; otherwise the named job must be one of the user's running
    notebook jobs.
    """
    job = None
    user_jobs = get_mast_notebook_jobs(RESERVE_MAST_TASK_GROUP_NAME)
    if jobname is None:
        # Pick the newest notebook_* job.
        for j in sorted(user_jobs, key=lambda x: x.get_create_time(), reverse=True):
            if j.name().startswith("notebook_"):
                jobname = j.name()
                job = j
                break
        if job is None:
            raise RuntimeError(
                "no valid jobs found, use monarch.notebook.reserve_workers to create one."
            )
    else:
        for j in user_jobs:
            if j.name() == jobname:
                job = j
                break
        if job is None:
            names = "\n".join([j.name() for j in user_jobs])
            raise RuntimeError(
                f"{jobname} is not one of your current running jobs. Choose from:\n{names}"
            )

    # Up to 100 minutes for the job to reach RUNNING.
    job.wait_for_running(600 * 10)

    N = job.get_task_count()
    uses_nfs = job.uses_nfs()
    # reserve() encoded the supervisor port in the job name.
    port = int(job.name().split("_")[1])

    ctx = Context(port=port)
    ctx.request_hosts(N)
    connections = ctx.messagefilter(HostConnected)
    # Wait for every host to phone home (30s timeout each).
    hosts = [connections.recv(timeout=30).sender for _ in range(N)]
    mesh = world_mesh(ctx, hosts, 8)
    if uses_nfs:
        # Workers share the NFS mount: just chdir them to our cwd when
        # it lives under the mount.
        nfs_path = Path(f"/mnt/aidev/{getuser()}")
        cwd = Path(os.getcwd())
        if nfs_path in cwd.parents:
            with mesh.activate():
                _chdir(str(cwd))
    else:
        # No shared filesystem: serve imports to the workers instead.
        _, importer_endpoint = _start_importer_server(ctx._context)
        with mesh.activate():
            register_importer(importer_endpoint)
    logger.info("connected to mast workers")
    return mesh
428
+
429
+
430
# Process-wide connection state managed by mast_mesh()/cleanup().
_ctx: Optional[Context] = None
_active_mesh: Optional[DeviceMesh] = None
# Guards against installing duplicate log handlers on repeated calls
# to reserve_torchx()/mast_mesh().
_is_logging_initialized = False


# Defaults for reserve_torchx() and the CLI `reserve` subcommand.
_DEFAULT_TORCHX_WORKSPACE_PATH = (
    f"/data/users/{os.getenv('USER')}/fbsource/fbcode/monarch/examples"
)
_DEFAULT_LOCALITY_CONSTRAINTS = "region;pci"
_DEFAULT_RM_ATTRIBUTION = "gen_ai_rf_nextgen_infra"
_DEFAULT_RUNNING_TIMEOUT_SEC = 3600
441
+
442
+
443
def reserve_torchx(
    hosts: int,
    torchx_workspace_path: str = _DEFAULT_TORCHX_WORKSPACE_PATH,
    nfs_workspace_dir: Optional[str] = None,
    oilfs_workspace_dir: Optional[str] = None,
    workspace_dir: Optional[str] = None,
    conda_dir: Optional[str] = None,
    locality_constraints: str = _DEFAULT_LOCALITY_CONSTRAINTS,
    rm_attribution: str = _DEFAULT_RM_ATTRIBUTION,
    running_timeout_sec: int = _DEFAULT_RUNNING_TIMEOUT_SEC,
    additional_scheduler_args: Optional[str] = None,
) -> str:
    """Launch `hosts` monarch notebook workers on MAST via torchx and
    return the scheduled job's name.

    The workers are told to connect back to a freshly picked port on
    this host. Runs `torchx` from `torchx_workspace_path`, restoring
    the caller's cwd afterwards.
    """
    global _is_logging_initialized
    # Avoid initializing logging more than once. Otherwise we'll
    # get duplicate logs.
    if not _is_logging_initialized:
        initialize_logging()
        _is_logging_initialized = True

    hostname, port = _endpoint()
    old_cwd = os.getcwd()
    # torchx resolves mast.py:train relative to the workspace.
    os.chdir(torchx_workspace_path)

    scheduler_args = ",".join(
        [
            f"localityConstraints={locality_constraints}",
            f"rmAttribution={rm_attribution}",
            f"runningTimeoutSec={running_timeout_sec}",
        ]
    )
    if additional_scheduler_args:
        scheduler_args += "," + additional_scheduler_args

    # Nanosecond timestamp keeps job names unique per invocation.
    job_base_name = f"monarch_{time.time_ns()}"

    torchx_cmd = [
        "torchx",
        "run",
        "--scheduler_args",
        scheduler_args,
        "mast.py:train",
        "--name",
        job_base_name,
        "--nodes",
        str(hosts),
        "--enable_ttls",
        "True",
    ]

    if nfs_workspace_dir is not None:
        torchx_cmd.extend(["--nfs_workspace_dir", nfs_workspace_dir])

    if oilfs_workspace_dir is not None:
        torchx_cmd.extend(["--oilfs_workspace_dir", oilfs_workspace_dir])

    if workspace_dir is not None:
        torchx_cmd.extend(["--workspace_dir", workspace_dir])

    if conda_dir is not None:
        torchx_cmd.extend(["--conda_dir", conda_dir])

    # Everything after "--" is forwarded to `python -m monarch.notebook`.
    torchx_cmd.extend(
        [
            "--module",
            "monarch.notebook",
            "--",
            "worker",
            "--endpoint",
            f"tcp://{hostname}:{port}",
        ]
    )

    try:
        # NOTE(review): no check=True — a failed torchx run is only
        # surfaced when the job lookup below finds nothing (IndexError).
        subprocess.run(torchx_cmd)

        logger.info(
            f"Started mast workers with supervisor_addr: {f'tcp://{hostname}:{port}'}"
        )

        # Return the first scheduled job matching our base name.
        return [
            job.name()
            for job in get_mast_notebook_jobs(TORCHX_MAST_TASK_GROUP_NAME)
            if job.name().startswith(job_base_name)
        ][0]
    finally:
        # This gets called even if the try block succeeds and returns.
        os.chdir(old_cwd)
530
+
531
+
532
# Remote helper: runs the workers' logging function; nothing returned.
log = remote("monarch.worker._testing_function.log", propagate="inspect")
533
+
534
+
535
def mast_mesh(
    mast_job_name: str,
    hosts: Optional[int] = None,
    n_gpus_per_host: Optional[int] = None,
    max_retries: Optional[int] = None,
):
    """Connect to a torchx-launched MAST job and return a DeviceMesh.

    Blocks until the job is RUNNING, then retries the supervisor
    connection up to `max_retries` times (forever when None). `hosts`
    and `n_gpus_per_host` default to the job's full size and may only
    be reduced. Stores the connection in module globals so cleanup()
    can tear it down; `mesh.exit` is wired to cleanup.
    """
    global _ctx, _active_mesh, _is_logging_initialized
    # Avoid initializing logging more than once. Otherwise we'll
    # get duplicate logs.
    if not _is_logging_initialized:
        initialize_logging()
        _is_logging_initialized = True

    mast_job = MastJob(mast_job_name, TORCHX_MAST_TASK_GROUP_NAME)
    try:
        assert mast_job_is_valid(mast_job)
    except subprocess.CalledProcessError as e:
        raise RuntimeError("Failed to get mast job status") from e
    except AssertionError as e:
        raise RuntimeError(
            "Based on job name and args, this does not appear to be a "
            "mast job created by monarch's monarch.notebook module. "
            f"Your valid mast jobs are: {get_mast_notebook_jobs(TORCHX_MAST_TASK_GROUP_NAME)}"
        ) from e

    # Poll until the job (and all workers) reach RUNNING.
    while True:
        if mast_job.is_running():
            logger.info(f"Found running mast job {mast_job_name}. Connecting...")
            break
        else:
            logger.info(
                f"Waiting for mast job {mast_job_name} and all its workers "
                "to have status RUNNING. Sleeping for 10 seconds..."
            )
            time.sleep(10)

    if hosts is None:
        hosts = mast_job.get_num_hosts()
    else:
        assert (
            hosts <= mast_job.get_num_hosts()
        ), f"Requested {hosts} hosts, but job only has {mast_job.get_num_hosts()} hosts."

    if n_gpus_per_host is None:
        n_gpus_per_host = mast_job.get_gpus_per_host()
    else:
        assert n_gpus_per_host <= mast_job.get_gpus_per_host(), (
            f"Requested {n_gpus_per_host} gpus per host, but job only has "
            f"{mast_job.get_gpus_per_host()} gpus per host."
        )

    port = mast_job.get_port()

    retry = 0
    while max_retries is None or retry < max_retries:
        retry += 1
        try:
            _ctx = Context(port=port)
            ctx: Context = _ctx
            ctx.request_hosts(hosts)
            connections = ctx.messagefilter(HostConnected)
            logger.info(f"connections: {connections}")
            # 30s per-host timeout; TimeoutError triggers a retry below.
            ctx_hosts = [connections.recv(timeout=30).sender for _ in range(hosts)]
            logger.info(f"connections: {ctx_hosts}")
            logger.info(
                f"Connected to mast workers ({hosts} hosts, {n_gpus_per_host} gpus per host)"
            )
            _active_mesh = mesh = world_mesh(ctx, ctx_hosts, n_gpus_per_host)
            # Let mesh.exit() tear down the module-level connection state.
            mesh.exit = cleanup

            def remote_mount_activate(
                mesh,
                local_mount_home_dir,
                remote_mount_home_dir,
                remote_mount_workspace_dir,
            ):
                """
                This function does two things:
                1. If the mast workers are running in a remote mounted workspace directory,
                then add the local equivalent to sys.path so that the notebook can import
                modules relative to the remote workspace directory.
                2. If the (local) current working directory is inside a mounted file system
                (e.g. NFS or OILFS), and the workers are also running inside the same mounted
                file system, then change the workers' current working directory to the remote equivalent
                of the local current working directory. Additionally, add the empty string to the workers'
                sys.path so that they search their current working directory for modules.
                """
                if remote_mount_home_dir is None or remote_mount_workspace_dir is None:
                    return

                local_mount_home_path = Path(local_mount_home_dir)
                remote_mount_home_path = Path(remote_mount_home_dir)
                remote_mount_workspace_path = Path(remote_mount_workspace_dir)
                relative_workspace_path = remote_mount_workspace_path.relative_to(
                    remote_mount_home_path
                )
                local_mount_workspace_path = (
                    local_mount_home_path / relative_workspace_path
                )

                if str(local_mount_workspace_path) not in sys.path:
                    # So that the notebook can call remote functions defined in remote_mount_workspace_dir
                    # via call_remote even if the cwd isn't inside the local equivalent of remote_mount_workspace_dir.
                    sys.path.append(str(local_mount_workspace_path))

                cwd = Path(os.getcwd())
                if local_mount_home_path in cwd.parents or local_mount_home_path == cwd:
                    with mesh.activate():
                        # Append the empty string to sys.path on each of the workers so that they
                        # search their current working directory for modules.
                        remote(
                            lambda: (
                                sys.path.append("") if "" not in sys.path else None
                            )
                        )()
                        relative_cwd = cwd.relative_to(local_mount_home_path)
                        remote(lambda cwd: os.chdir(cwd))(
                            str(remote_mount_home_path / relative_cwd)
                        )

            # OILFS takes precedence over NFS when both are configured.
            if mast_job.get_oilfs_workspace_dir() is not None:
                remote_mount_activate(
                    mesh,
                    f"/home/{getuser()}/fuse-aidev",
                    mast_job.get_oilfs_home_dir(),
                    mast_job.get_oilfs_workspace_dir(),
                )
            elif mast_job.get_nfs_workspace_dir() is not None:
                remote_mount_activate(
                    mesh,
                    f"/mnt/aidev/{getuser()}",
                    mast_job.get_nfs_home_dir(),
                    mast_job.get_nfs_workspace_dir(),
                )

            return _active_mesh
        except TimeoutError:
            logger.warning(
                "Timed out waiting to connect to mast workers. "
                f"Tried {retry} out of {max_retries if max_retries is not None else '(inf)'} "
                "times."
            )
            cleanup()
        except Exception as e:
            # Any other failure is fatal: tear down and re-raise.
            cleanup()
            raise e
681
+
682
+
683
def list_mast_jobs(task_group: str = TORCHX_MAST_TASK_GROUP_NAME):
    """Print this user's monarch notebook jobs in `task_group`.

    Fix: get_mast_notebook_jobs requires a task-group argument; the
    original call passed none and raised TypeError. Defaults to the
    torchx task group used by reserve_torchx()/mast_mesh(); pass
    RESERVE_MAST_TASK_GROUP_NAME to list reserve()-style jobs.
    """
    for job in get_mast_notebook_jobs(task_group):
        print(job)
686
+
687
+
688
def cleanup():
    """Shut down the active mesh and supervisor context (if any) and
    reset the module-level handles so mast_mesh() can reconnect."""
    global _ctx, _active_mesh
    # Shut the mesh down before its underlying context.
    if _active_mesh:
        _active_mesh.client.shutdown()
        _active_mesh = None
    if _ctx:
        _ctx.shutdown()
        _ctx = None
696
+
697
+
698
if __name__ == "__main__":
    # CLI: `reserve` launches workers via torchx; `worker` runs on a
    # MAST host and supervises a host-manager process.
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command")

    reserve_parser = subparsers.add_parser("reserve")
    reserve_parser.add_argument("--hosts", type=int)
    reserve_parser.add_argument(
        "--torchx_workspace_path",
        type=str,
        default=_DEFAULT_TORCHX_WORKSPACE_PATH,
    )
    reserve_parser.add_argument(
        "--locality_constraints", type=str, default=_DEFAULT_LOCALITY_CONSTRAINTS
    )
    reserve_parser.add_argument("--nfs_workspace_dir", type=str, required=False)
    reserve_parser.add_argument("--oilfs_workspace_dir", type=str, required=False)
    reserve_parser.add_argument("--workspace_dir", type=str, required=False)
    reserve_parser.add_argument("--conda_dir", type=str, required=False)
    reserve_parser.add_argument(
        "--rm_attribution", type=str, required=False, default=_DEFAULT_RM_ATTRIBUTION
    )
    reserve_parser.add_argument(
        "--running_timeout_sec",
        type=int,
        required=False,
        default=_DEFAULT_RUNNING_TIMEOUT_SEC,
    )
    reserve_parser.add_argument("--additional_scheduler_args", type=str, required=False)

    worker_parser = subparsers.add_parser("worker")
    worker_parser.add_argument("--endpoint", type=str)

    args = parser.parse_args(sys.argv[1:])
    if args.command == "reserve":
        reserve_torchx(
            args.hosts,
            args.torchx_workspace_path,
            args.nfs_workspace_dir,
            args.oilfs_workspace_dir,
            args.workspace_dir,
            args.conda_dir,
            args.locality_constraints,
            args.rm_attribution,
            args.running_timeout_sec,
            args.additional_scheduler_args,
        )
        sys.exit(0)
    else:
        initialize_logging(f"{gethostname()} pid {os.getpid()} host-manager")
        endpoint = args.endpoint
        logger.info(f"Connecting to {endpoint}")
        # Fork a child per host-manager run; the parent re-forks a new
        # one whenever a child exits cleanly (status 0), and stops on
        # any abnormal exit.
        while True:
            pid = os.fork()
            if pid == 0:
                # Child: run the host manager; a ConnectionAbortedError
                # is treated as a clean exit so the parent restarts us.
                try:
                    host_main(endpoint)
                except ConnectionAbortedError:
                    logger.warning("host manager aborted, restarting new host manager")
                sys.exit(0)
            else:
                # Parent: NOTE(review) `status` is the raw os.wait()
                # status word (exit code and signal encoded), not a
                # plain exit code — any nonzero value stops the loop.
                exitpid, status = os.wait()
                if status != 0:
                    logger.warning("Abnormal exit, stopping")
                    break