torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/sim_mesh.py ADDED
@@ -0,0 +1,359 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ import importlib.resources
10
+ import logging
11
+ import os
12
+ import random
13
+ import string
14
+ import subprocess
15
+ import tempfile
16
+ import time
17
+ from pathlib import Path
18
+ from typing import (
19
+ Callable,
20
+ ContextManager as AbstractContextManager,
21
+ Dict,
22
+ Generic,
23
+ Iterable,
24
+ List,
25
+ Optional,
26
+ Tuple,
27
+ )
28
+
29
+ from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension
30
+ ClientActor,
31
+ )
32
+
33
+ from monarch._rust_bindings.monarch_extension.simulator_client import ( # @manual=//monarch/monarch_extension:monarch_extension
34
+ bootstrap_simulator_backend,
35
+ SimulatorClient,
36
+ )
37
+
38
+ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
39
+ ActorId,
40
+ init_proc,
41
+ Proc,
42
+ )
43
+ from monarch.common.client import Client
44
+ from monarch.common.constants import (
45
+ SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL,
46
+ SIM_MESH_CLIENT_TIMEOUT,
47
+ )
48
+ from monarch.common.device_mesh import DeviceMesh
49
+ from monarch.common.fake import fake_call
50
+ from monarch.common.future import Future, T
51
+ from monarch.common.invocation import DeviceException, RemoteException
52
+ from monarch.common.messages import Dims
53
+ from monarch.common.shape import NDSlice
54
+ from monarch.controller.rust_backend.controller import RustController
55
+ from monarch.rust_backend_mesh import MeshWorld
56
+
57
+
58
+ logger: logging.Logger = logging.getLogger(__name__)
59
+
60
+
61
def sim_mesh(
    n_meshes: int, hosts: int, gpus_per_host: int, proxy_addr: Optional[str] = None
) -> List[DeviceMesh]:
    """
    Bootstrap the simulator backend and build ``n_meshes`` simulated device meshes.

    Args:
        n_meshes : number of device meshes to create.
        hosts : number of hosts in each mesh, primarily used for simulating
            multiple machines locally.
        gpus_per_host : number of gpus per host.
        proxy_addr : optional proxy address forwarded to ``Bootstrap``.

    Returns:
        One ``SimMesh`` per mesh index, each with dimensions ("host", "gpu").
    """
    world_size = hosts * gpus_per_host

    # Bootstrap fills this mapping with one entry per spawned mesh world.
    mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = {}
    bootstrap: Bootstrap = Bootstrap(
        n_meshes,
        mesh_world_state,
        proxy_addr=proxy_addr,
        world_size=world_size,
    )

    client_proc: Proc = init_proc(
        proc_id="client[0]",
        bootstrap_addr=bootstrap.client_bootstrap_addr,
        timeout=SIM_MESH_CLIENT_TIMEOUT,  # unused
        supervision_update_interval=SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL,
        listen_addr=bootstrap.client_listen_addr,
    )
    root_client_actor: ClientActor = ClientActor(
        proc=client_proc, actor_name="root_client"
    )

    meshes = []
    for index in range(n_meshes):
        worker_world = f"mesh_{index}_worker"
        controller_id = ActorId(
            world_name=f"mesh_{index}_controller", rank=0, actor_name="root"
        )
        # Every mesh gets its own client actor parented to the root client actor.
        mesh_client_actor = ClientActor.new_with_parent(
            client_proc, root_client_actor.actor_id
        )
        controller = RustController(
            proc=client_proc,
            client_actor=mesh_client_actor,
            controller_id=controller_id,
            worker_world_name=worker_world,
        )
        client = Client(controller, world_size, gpus_per_host)
        # Row-major layout: one row per host, one column per gpu.
        layout = NDSlice(
            offset=0, sizes=[hosts, gpus_per_host], strides=[gpus_per_host, 1]
        )
        meshes.append(
            SimMesh(
                client,
                layout,
                ("host", "gpu"),
                bootstrap._simulator_client,
                worker_world,
            )
        )

    return meshes
119
+
120
+
121
class OriginalFutureWrapper(Generic[T]):
    """
    Holds the ``Future.result`` and ``Future._set_result`` functions as they
    were when this class body was executed, so they can still be called (and
    reinstated) after ``Future`` has been monkey patched.
    """

    # Signature: (future, timeout) -> value
    result: Callable[[Future[T], float | None], T] = Future.result
    # Signature: (future, value) -> None
    _set_result: Callable[[Future[T], T], None] = Future._set_result
130
+
131
+
132
class SimMesh(DeviceMesh, Generic[T]):
    """
    A ``DeviceMesh`` that reports the training script state to a simulator.

    Activating the mesh monkey patches ``Future.result`` and
    ``Future._set_result`` so the simulator client is notified when the script
    starts waiting on a future and when a future receives its result.
    ``exit`` puts the original functions back.
    """

    def __init__(
        self,
        client: "Client",
        processes: "NDSlice",
        names: Dims,
        simulator_client: SimulatorClient,
        mesh_name: str = "default",
    ) -> None:
        """
        Args:
            client: client passed through to ``DeviceMesh``.
            processes: process layout passed through to ``DeviceMesh``.
            names: dimension names passed through to ``DeviceMesh``.
            simulator_client: client notified of training script state changes.
            mesh_name: name passed through to ``DeviceMesh``.
        """
        super().__init__(client, processes, names, mesh_name)
        self.simulator_client: SimulatorClient = simulator_client

    # monkey patch Future.result and Future._set_result to hook into set_training_script_state_{running,waiting}
    def activate(self) -> AbstractContextManager[DeviceMesh]:
        """
        Install the simulator hooks on ``Future`` and activate the mesh.

        The hooks are assigned on the ``Future`` class itself, so they apply to
        every ``Future`` instance, not only those created through this mesh.
        """

        def sim_result(fut: Future[T], timeout: float | None = None) -> T:
            # Tell the simulator the script is about to block on a result.
            self.simulator_client.set_training_script_state_waiting()
            return OriginalFutureWrapper.result(fut, timeout)

        def sim_set_result(fut: Future[T], result: T) -> None:
            # Tell the simulator the script resumes once a result is delivered.
            self.simulator_client.set_training_script_state_running()
            return OriginalFutureWrapper._set_result(fut, result)

        # pyre-ignore
        Future.result = sim_result
        Future._set_result = sim_set_result

        return super().activate()

    # restore Future.result and Future._set_result to their previous values
    def exit(
        self,
        error: Optional[RemoteException | DeviceException | Exception] = None,
    ) -> None:
        """
        Shut down the client and restore the original ``Future`` methods.

        Args:
            error: optional error forwarded to ``client.shutdown``.
        """
        try:
            self.client.shutdown(True, error)
        finally:
            # Restore even when shutdown raises, so the class-level patch on
            # Future does not outlive this mesh.
            # NOTE: the wrapper attribute is `result`; the previous `_result`
            # does not exist and raised AttributeError here.
            # pyre-ignore
            Future.result = OriginalFutureWrapper.result
            Future._set_result = OriginalFutureWrapper._set_result
169
+
170
+
171
+ def _random_id(length: int = 14) -> str:
172
+ """
173
+ A simple random id generator.
174
+ """
175
+ return "".join(random.choice(string.ascii_lowercase) for _ in range(length))
176
+
177
+
178
class Bootstrap:
    """
    Starts the simulator backend and spawns the controller/worker worlds for a
    set of simulated meshes.
    """

    def __init__(
        self,
        num_meshes: int,
        mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]],
        proxy_addr: Optional[str] = None,
        world_size: int = 1,
        startup_wait_sec: float = 10,
    ) -> None:
        """
        Bootstraps a SimMesh.
        Args:
            num_meshes: int - number of meshes to create.
            mesh_world_state: a state of the meshes. Keys are the MeshWorld and
                values are the DeviceMesh created for that world, or None while
                no mesh has been created. Every spawned world is registered
                here with a value of None.
            proxy_addr: Option[str] - the proxy address of the simulation process.
                A random abstract unix address is generated when omitted.
            world_size: int - forwarded to ``bootstrap_simulator_backend``.
            startup_wait_sec: float - seconds to wait after spawning so the worker
                and controller tasks are spawned and ready. Values <= 0 skip
                the wait.
        """
        # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later
        fake_call(lambda: 0)

        env = os.environ.copy()
        env["HYPERACTOR_MANAGED_SUBPROCESS"] = "1"
        self.env: dict[str, str] = env

        self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state

        proxy_addr = proxy_addr or f"unix!@{_random_id()}-proxy"
        self.bootstrap_addr: str = f"sim!unix!@system,{proxy_addr}"

        client_proxy_addr = f"unix!@{_random_id()}-proxy"
        self.client_listen_addr: str = f"sim!unix!@client,{client_proxy_addr}"
        self.client_bootstrap_addr: str = (
            f"sim!unix!@client,{client_proxy_addr},unix!@system,{proxy_addr}"
        )
        bootstrap_simulator_backend(self.bootstrap_addr, proxy_addr, world_size)

        self._simulator_client = SimulatorClient(proxy_addr)
        for i in range(num_meshes):
            mesh_name: str = f"mesh_{i}"
            controller_world: str = f"{mesh_name}_controller"
            worker_world: str = f"{mesh_name}_worker"
            controller_id: ActorId = ActorId(
                world_name=controller_world,
                rank=0,
                actor_name="root",
            )
            mesh_world = (worker_world, controller_id)
            # None marks the world as spawned but without a DeviceMesh yet.
            self._mesh_world_state[mesh_world] = None
            self.spawn_mesh(mesh_world)

        # Wait for the worker and controller tasks to be spawned and ready.
        # The duration used to be a fixed 10 seconds; it is now configurable.
        if startup_wait_sec > 0:
            time.sleep(startup_wait_sec)

    def get_mesh_worlds(self) -> List[MeshWorld]:
        """Currently a stub: always returns an empty list."""
        return []

    def kill_mesh(self, mesh_world: MeshWorld) -> None:
        """Currently a stub: does nothing."""
        pass

    def spawn_mesh(self, mesh_world: MeshWorld) -> None:
        """
        Ask the simulator client to spawn the given mesh world.

        Args:
            mesh_world: a (worker world name, controller ActorId) pair.
        """
        worker_world, controller_id = mesh_world
        controller_world = controller_id.world_name
        self._simulator_client.spawn_mesh(
            self.bootstrap_addr, f"{controller_world}[0].root", worker_world
        )
240
+
241
+
242
+ def _validate_proccesses_end(
243
+ processes: Iterable[subprocess.Popen[bytes]],
244
+ timeout_in_sec: int = 1,
245
+ raise_on_abnormal_exit: bool = True,
246
+ ) -> list[int]:
247
+ """
248
+ Check if processes have ended properly. Raise errors immediately
249
+ if any process has ended with a non-zero return code.
250
+ Return a list of process indices that have not ended yet.
251
+ """
252
+ running = []
253
+ start_time = time.time()
254
+ for i, process in enumerate(processes):
255
+ try:
256
+ current_time = time.time()
257
+ elapsed_time = current_time - start_time
258
+ # The processes are running in parallel. No need to wait for
259
+ # `timeout_in_sec` for each process. Only count the remaining ones.
260
+ wait_in_sec = max(0, timeout_in_sec - elapsed_time)
261
+ return_code = process.wait(timeout=wait_in_sec)
262
+ if return_code != 0:
263
+ error_message: str = (
264
+ f"Process[{i}] {process.pid} exited with "
265
+ f"return code {return_code}. Command:\n "
266
+ f"{process.args!r}"
267
+ )
268
+ if raise_on_abnormal_exit:
269
+ raise RuntimeError(error_message)
270
+ else:
271
+ logger.error(error_message)
272
+ except subprocess.TimeoutExpired:
273
+ running.append(i)
274
+
275
+ return running
276
+
277
+
278
class PoolDeviceMeshProvider:
    """
    Hands out ``SimMesh`` instances for mesh worlds recorded in a shared
    ``mesh_world_state`` mapping, one world per ``new_mesh`` call.
    """

    def __init__(
        self,
        hosts_per_mesh: int,
        gpus_per_host: int,
        client_proc: Proc,
        mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]],
        simulator_client: SimulatorClient,
    ) -> None:
        """
        Args:
            hosts_per_mesh: number of hosts in every mesh created.
            gpus_per_host: number of gpus per host.
            client_proc: proc used for the root client actor and controllers.
            mesh_world_state: mapping from mesh world to its DeviceMesh; a
                falsy value marks a world that has no mesh yet.
            simulator_client: simulator client handed to each ``SimMesh``.
        """
        self._hosts_per_mesh = hosts_per_mesh
        self._gpus_per_host = gpus_per_host
        self._client_proc = client_proc
        # Parent actor for the per-mesh client actors created in new_mesh().
        self._root_client_actor: ClientActor = ClientActor(
            proc=client_proc, actor_name="root_client"
        )
        self._mesh_world_state = mesh_world_state
        self._simulator_client = simulator_client

    def new_mesh(self, timeout_in_sec: Optional[int] = None) -> DeviceMesh:
        """
        Create a ``SimMesh`` for the first mesh world that has none yet.

        Args:
            timeout_in_sec: accepted but not used by this implementation.

        Returns:
            The new mesh, which is also stored in ``mesh_world_state``.

        Raises:
            AssertionError: if every known mesh world already has a mesh.
        """
        # Pick the first world, in mapping order, whose entry is still falsy.
        mesh_world_to_create = None
        for candidate, existing_mesh in self._mesh_world_state.items():
            if not existing_mesh:
                mesh_world_to_create = candidate
                break
        assert mesh_world_to_create is not None, "No mesh world to create"

        worker_world, controller_id = mesh_world_to_create
        hosts = self._hosts_per_mesh
        gpus = self._gpus_per_host

        mesh_client_actor = ClientActor.new_with_parent(
            self._client_proc, self._root_client_actor.actor_id
        )
        controller = RustController(
            proc=self._client_proc,
            client_actor=mesh_client_actor,
            controller_id=controller_id,
            worker_world_name=worker_world,
        )
        client = Client(controller, hosts * gpus, gpus)
        # Row-major layout: one row per host, one column per gpu.
        layout = NDSlice(offset=0, sizes=[hosts, gpus], strides=[gpus, 1])
        mesh = SimMesh(
            client,
            layout,
            ("host", "gpu"),
            self._simulator_client,
            worker_world,
        )
        self._mesh_world_state[mesh_world_to_create] = mesh

        return mesh
336
+
337
+
338
def sim_mesh_provider(
    num_meshes: int, hosts_per_mesh: int, gpus_per_host: int
) -> Tuple[PoolDeviceMeshProvider, Bootstrap]:
    """
    Bootstrap the simulator and return a provider that creates meshes lazily.

    Args:
        num_meshes: number of mesh worlds to spawn during bootstrap.
        hosts_per_mesh: number of hosts in each mesh the provider creates.
        gpus_per_host: number of gpus per host.

    Returns:
        A ``(provider, bootstrap)`` pair sharing one mesh world state mapping.
    """
    # Shared between Bootstrap (which registers worlds) and the provider
    # (which fills in the meshes).
    shared_state = {}
    # NOTE(review): world_size is left at Bootstrap's default of 1 here, while
    # sim_mesh() passes hosts * gpus_per_host - confirm this is intentional.
    bootstrap = Bootstrap(num_meshes, shared_state)

    client_proc: Proc = init_proc(
        proc_id="client[0]",
        bootstrap_addr=bootstrap.client_bootstrap_addr,
        timeout=SIM_MESH_CLIENT_TIMEOUT,  # unused
        supervision_update_interval=SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL,
        listen_addr=bootstrap.client_listen_addr,
    )
    provider = PoolDeviceMeshProvider(
        hosts_per_mesh,
        gpus_per_host,
        client_proc,
        shared_state,
        bootstrap._simulator_client,
    )
    return provider, bootstrap
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict