torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/sim_mesh.py ADDED
@@ -0,0 +1,357 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ import importlib.resources
10
+ import logging
11
+ import os
12
+ import random
13
+ import string
14
+ import subprocess
15
+ import tempfile
16
+ import time
17
+ from pathlib import Path
18
+ from typing import (
19
+ Callable,
20
+ ContextManager as AbstractContextManager,
21
+ Dict,
22
+ Generic,
23
+ Iterable,
24
+ List,
25
+ Optional,
26
+ Tuple,
27
+ )
28
+
29
+ from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension
30
+ ClientActor,
31
+ )
32
+
33
+ from monarch._rust_bindings.monarch_extension.simulator_client import ( # @manual=//monarch/monarch_extension:monarch_extension
34
+ bootstrap_simulator_backend,
35
+ SimulatorClient,
36
+ )
37
+
38
+ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
39
+ ActorId,
40
+ init_proc,
41
+ Proc,
42
+ )
43
+ from monarch.common.client import Client
44
+ from monarch.common.constants import (
45
+ SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL,
46
+ SIM_MESH_CLIENT_TIMEOUT,
47
+ )
48
+ from monarch.common.device_mesh import DeviceMesh
49
+ from monarch.common.fake import fake_call
50
+ from monarch.common.future import Future, T
51
+ from monarch.common.invocation import DeviceException, RemoteException
52
+ from monarch.common.messages import Dims
53
+ from monarch.common.shape import NDSlice
54
+ from monarch.controller.rust_backend.controller import RustController
55
+ from monarch.rust_backend_mesh import MeshWorld
56
+
57
+
58
+ logger: logging.Logger = logging.getLogger(__name__)
59
+
60
+
61
def sim_mesh(
    n_meshes: int, hosts: int, gpus_per_host: int, proxy_addr: Optional[str] = None
) -> List[DeviceMesh]:
    """
    Creates `n_meshes` simulated device meshes, each shaped
    `hosts` x `gpus_per_host` with dimensions named ("host", "gpu").

    Args:
        n_meshes : number of device meshes to create.
        hosts : number of hosts in every mesh.
        gpus_per_host : number of gpus per host.
        proxy_addr : optional proxy address, forwarded unchanged to `Bootstrap`.

    Returns:
        The created meshes in index order; mesh `i` uses the worker world
        name `mesh_{i}_worker`.
    """
    world_size = hosts * gpus_per_host

    # Bootstrap fills this mapping with one entry per mesh world it spawns.
    mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = {}
    bootstrap: Bootstrap = Bootstrap(
        n_meshes,
        mesh_world_state,
        proxy_addr=proxy_addr,
        world_size=world_size,
    )

    client_proc: Proc = init_proc(
        proc_id="client[0]",
        bootstrap_addr=bootstrap.client_bootstrap_addr,
        timeout=SIM_MESH_CLIENT_TIMEOUT,  # unused
        supervision_update_interval=SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL,
        listen_addr=bootstrap.client_listen_addr,
    )
    root_client_actor: ClientActor = ClientActor(
        proc=client_proc, actor_name="root_client"
    )

    def _build_mesh(index: int) -> DeviceMesh:
        # One controller / client / SimMesh triple per mesh index.
        worker_world = f"mesh_{index}_worker"
        controller_id = ActorId(
            world_name=f"mesh_{index}_controller", rank=0, actor_name="root"
        )
        # Every mesh gets its own client actor parented to the root client.
        child_actor = ClientActor.new_with_parent(
            client_proc, root_client_actor.actor_id
        )
        backend_ctrl = RustController(
            proc=client_proc,
            client_actor=child_actor,
            controller_id=controller_id,
            worker_world_name=worker_world,
        )
        client = Client(backend_ctrl, world_size, gpus_per_host)
        # Row-major layout: consecutive ranks are the gpus of one host.
        layout = NDSlice(
            offset=0, sizes=[hosts, gpus_per_host], strides=[gpus_per_host, 1]
        )
        return SimMesh(
            client,
            layout,
            ("host", "gpu"),
            bootstrap._simulator_client,
            worker_world,
        )

    return [_build_mesh(i) for i in range(n_meshes)]
119
+
120
+
121
class OriginalFutureWrapper(Generic[T]):
    # Keeps references to `Future.result` and `Future._set_result` as they are
    # when this class body is evaluated, so they stay reachable after the
    # attributes on `Future` have been reassigned.
    result: Callable[[Future[T], float | None], T] = Future.result
    _set_result: Callable[[Future[T], T], None] = Future._set_result
130
+
131
+
132
class SimMesh(DeviceMesh, Generic[T]):
    """
    A DeviceMesh that carries a SimulatorClient.

    `activate` monkey patches `Future.result` and `Future._set_result` so the
    simulator client is notified around them; `exit` shuts the client down and
    puts the original `Future` methods back.
    """

    def __init__(
        self,
        client: "Client",
        processes: "NDSlice",
        names: Dims,
        simulator_client: SimulatorClient,
        mesh_name: str = "default",
    ) -> None:
        """
        Args:
            client: forwarded to `DeviceMesh.__init__`.
            processes: forwarded to `DeviceMesh.__init__`.
            names: forwarded to `DeviceMesh.__init__`.
            simulator_client: client notified of training script state changes.
            mesh_name: forwarded to `DeviceMesh.__init__`.
        """
        super().__init__(client, processes, names, mesh_name)
        self.simulator_client: SimulatorClient = simulator_client

    # monkey patch Future.result and Future._set_result to hook into set_training_script_state_{running,waiting}
    def activate(self) -> AbstractContextManager[DeviceMesh]:
        """
        Patch `Future.result` / `Future._set_result` and activate the mesh.

        Returns:
            Whatever `DeviceMesh.activate()` returns.
        """

        def sim_result(fut: Future[T], timeout: float | None = None) -> T:
            # Report "waiting" before delegating to the original Future.result.
            self.simulator_client.set_training_script_state_waiting()
            return OriginalFutureWrapper.result(fut, timeout)

        def sim_set_result(fut: Future[T], result: T) -> None:
            # Report "running" before delegating to the original _set_result.
            self.simulator_client.set_training_script_state_running()
            return OriginalFutureWrapper._set_result(fut, result)

        # pyre-ignore
        Future.result = sim_result
        Future._set_result = sim_set_result

        return super().activate()

    # restore Future.result and Future._set_result to their previous values
    def exit(
        self,
        error: Optional[RemoteException | DeviceException | Exception] = None,
    ) -> None:
        """
        Shut the client down and undo the patches installed by `activate`.

        Args:
            error: optional error passed on to `client.shutdown`.
        """
        try:
            self.client.shutdown(True, error)
        finally:
            # Restore even when shutdown raises, so Future is never left patched.
            # The saved attribute is `result`; `OriginalFutureWrapper` has no
            # `_result`, so reading it raised AttributeError on every exit.
            # pyre-ignore
            Future.result = OriginalFutureWrapper.result
            Future._set_result = OriginalFutureWrapper._set_result
169
+
170
+
171
def _random_id(length: int = 14) -> str:
    """
    Return a random string of `length` lowercase ASCII letters.
    """
    alphabet = string.ascii_lowercase
    picked = []
    for _ in range(length):
        # One random.choice call per character.
        picked.append(random.choice(alphabet))
    return "".join(picked)
176
+
177
+
178
class Bootstrap:
    def __init__(
        self,
        num_meshes: int,
        mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]],
        proxy_addr: Optional[str] = None,
        world_size: int = 1,
    ) -> None:
        """
        Bootstraps the simulator backend and spawns `num_meshes` meshes.

        Args:
            num_meshes: number of meshes to spawn.
            mesh_world_state: mapping populated here with one key per spawned
                mesh world; every value is initialised to None.
            proxy_addr: proxy address of the simulation process. When falsy, a
                random `unix!@<id>-proxy` address is generated.
            world_size: passed to `bootstrap_simulator_backend`.
        """
        # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later
        fake_call(lambda: 0)

        # Snapshot of the current environment plus the managed-subprocess flag.
        self.env: dict[str, str] = {
            **os.environ,
            "HYPERACTOR_MANAGED_SUBPROCESS": "1",
        }

        self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state

        if not proxy_addr:
            proxy_addr = f"unix!@{_random_id()}-proxy"
        self.bootstrap_addr: str = f"sim!unix!@system,{proxy_addr}"
        self.client_listen_addr: str = f"sim!unix!@client,{proxy_addr}"
        self.client_bootstrap_addr: str = (
            f"sim!unix!@client,{proxy_addr},unix!@system,{proxy_addr}"
        )
        bootstrap_simulator_backend(self.bootstrap_addr, world_size)

        self._simulator_client = SimulatorClient(proxy_addr)
        for mesh_index in range(num_meshes):
            prefix: str = f"mesh_{mesh_index}"
            root_controller: ActorId = ActorId(
                world_name=f"{prefix}_controller",
                rank=0,
                actor_name="root",
            )
            # A mesh world is the pair (worker world name, controller actor id).
            mesh_world = (f"{prefix}_worker", root_controller)
            self._mesh_world_state[mesh_world] = None
            self.spawn_mesh(mesh_world)
        # sleep for 10 sec for the worker and controller tasks to be spawned and ready.
        time.sleep(10)

    def get_mesh_worlds(self) -> List[MeshWorld]:
        # Always returns an empty list; tracked worlds are not reported here.
        return []

    def kill_mesh(self, mesh_world: MeshWorld) -> None:
        # Intentionally a no-op.
        pass

    def spawn_mesh(self, mesh_world: MeshWorld) -> None:
        """
        Ask the simulator client to spawn the controller and workers of
        `mesh_world` at this bootstrap address.
        """
        worker_world, controller_id = mesh_world
        controller_world = controller_id.world_name
        controller_actor = f"{controller_world}[0].root"
        self._simulator_client.spawn_mesh(
            self.bootstrap_addr, controller_actor, worker_world
        )
238
+
239
+
240
def _validate_proccesses_end(
    processes: Iterable[subprocess.Popen[bytes]],
    timeout_in_sec: int = 1,
    raise_on_abnormal_exit: bool = True,
) -> list[int]:
    """
    Wait for each process and report which ones are still running.

    Args:
        processes: processes to wait on, in order.
        timeout_in_sec: total wait budget shared by all processes.
        raise_on_abnormal_exit: when True, a non-zero return code raises
            immediately; when False, it is logged via `logger.error`.

    Returns:
        Indices (positions in `processes`) of the processes whose wait timed out.

    Raises:
        RuntimeError: a process exited non-zero and `raise_on_abnormal_exit`
            is True.
    """
    still_running: list[int] = []
    began_at = time.time()
    for index, proc in enumerate(processes):
        # The processes run in parallel, so the timeout is a single shared
        # budget: each wait only gets what is left of it (never negative).
        remaining = max(0, timeout_in_sec - (time.time() - began_at))
        try:
            exit_code = proc.wait(timeout=remaining)
        except subprocess.TimeoutExpired:
            still_running.append(index)
            continue

        if exit_code == 0:
            continue

        error_message: str = (
            f"Process[{index}] {proc.pid} exited with "
            f"return code {exit_code}. Command:\n "
            f"{proc.args!r}"
        )
        if not raise_on_abnormal_exit:
            logger.error(error_message)
            continue
        raise RuntimeError(error_message)

    return still_running
274
+
275
+
276
class PoolDeviceMeshProvider:
    """
    Creates SimMesh instances for the mesh worlds listed in a shared
    `mesh_world_state` mapping, one not-yet-created world per `new_mesh` call.
    """

    def __init__(
        self,
        hosts_per_mesh: int,
        gpus_per_host: int,
        client_proc: Proc,
        mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]],
        simulator_client: SimulatorClient,
    ) -> None:
        """
        Args:
            hosts_per_mesh: number of hosts in every mesh handed out.
            gpus_per_host: number of gpus per host.
            client_proc: proc used for the root client actor and all controllers.
            mesh_world_state: mesh world -> created mesh (falsy until created).
            simulator_client: simulator client given to every SimMesh.
        """
        self._client_proc = client_proc
        self._simulator_client = simulator_client
        self._mesh_world_state = mesh_world_state
        self._hosts_per_mesh = hosts_per_mesh
        self._gpus_per_host = gpus_per_host
        # Parent of the per-mesh client actors created in new_mesh.
        self._root_client_actor: ClientActor = ClientActor(
            proc=client_proc, actor_name="root_client"
        )

    def new_mesh(self, timeout_in_sec: Optional[int] = None) -> DeviceMesh:
        """
        Build a SimMesh for the first mesh world whose state entry is falsy
        and record it in the state mapping.

        Args:
            timeout_in_sec: accepted but not read by this method.

        Returns:
            The newly created mesh.
        """
        # Pick the first world (in mapping order) that has no mesh yet.
        for candidate, existing in self._mesh_world_state.items():
            if not existing:
                mesh_world_to_create = candidate
                break
        else:
            mesh_world_to_create = None
        assert mesh_world_to_create is not None, "No mesh world to create"

        worker_world, controller_id = mesh_world_to_create
        hosts = self._hosts_per_mesh
        gpus = self._gpus_per_host

        # Create a new device mesh
        child_actor = ClientActor.new_with_parent(
            self._client_proc, self._root_client_actor.actor_id
        )
        backend_ctrl = RustController(
            proc=self._client_proc,
            client_actor=child_actor,
            controller_id=controller_id,
            worker_world_name=worker_world,
        )
        client = Client(backend_ctrl, hosts * gpus, gpus)
        layout = NDSlice(offset=0, sizes=[hosts, gpus], strides=[gpus, 1])
        dm = SimMesh(
            client,
            layout,
            ("host", "gpu"),
            self._simulator_client,
            worker_world,
        )
        self._mesh_world_state[mesh_world_to_create] = dm

        return dm
334
+
335
+
336
def sim_mesh_provider(
    num_meshes: int, hosts_per_mesh: int, gpus_per_host: int
) -> Tuple[PoolDeviceMeshProvider, Bootstrap]:
    """
    Bootstrap `num_meshes` simulated meshes and return a provider for them.

    Args:
        num_meshes: number of mesh worlds to spawn.
        hosts_per_mesh: number of hosts in every mesh the provider creates.
        gpus_per_host: number of gpus per host.

    Returns:
        A `(provider, bootstrap)` pair sharing one mesh world state mapping.
    """
    # Shared between Bootstrap (which fills it) and the provider (which reads it).
    shared_state = {}
    # NOTE(review): no world_size is passed, so Bootstrap uses its default of 1,
    # whereas sim_mesh passes hosts * gpus_per_host — confirm this is intended.
    bootstrap = Bootstrap(num_meshes, shared_state)

    client_proc: Proc = init_proc(
        proc_id="client[0]",
        bootstrap_addr=bootstrap.client_bootstrap_addr,
        timeout=SIM_MESH_CLIENT_TIMEOUT,  # unused
        supervision_update_interval=SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL,
        listen_addr=bootstrap.client_listen_addr,
    )
    provider = PoolDeviceMeshProvider(
        hosts_per_mesh,
        gpus_per_host,
        client_proc,
        shared_state,
        bootstrap._simulator_client,
    )
    return (provider, bootstrap)
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict