torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,280 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+ import logging
9
+ import time
10
+ from logging import Logger
11
+ from typing import Any, Callable, Optional, Protocol
12
+
13
+ from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension
14
+ ClientActor,
15
+ SystemSnapshotFilter,
16
+ )
17
+
18
+ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
19
+ ActorId,
20
+ init_proc,
21
+ Proc,
22
+ )
23
+ from monarch.common.client import Client
24
+ from monarch.common.device_mesh import DeviceMesh, DeviceMeshStatus
25
+ from monarch.common.invocation import DeviceException, RemoteException
26
+ from monarch.common.mast import MastJob
27
+ from monarch.common.shape import NDSlice
28
+ from monarch.controller.rust_backend.controller import RustController
29
+
30
+ TORCHX_MAST_TASK_GROUP_NAME = "script"
31
+
32
+ logger: Logger = logging.getLogger(__name__)
33
+
34
+ # A world tuple contains a worker world name and a controller actor id
35
+ # The pair forms a functional world that can be used to create a device mesh
36
+ MeshWorld = tuple[str, ActorId]
37
+
38
+ # Taken from //monarch/controller/src/bootstrap.rs
39
+ WORLD_WORKER_LABEL = "world.monarch.meta.com/worker"
40
+ WORLD_CONTROLLER_LABEL = "world.monarch.meta.com/controllerActorId"
41
+ WORLD_CONTROLLER_IP = "world.monarch.meta.com/ip_addr"
42
+
43
+
44
class IBootstrap(Protocol):
    """Structural interface for a bootstrap instance that manages mesh worlds."""

    def get_mesh_worlds(self) -> list[MeshWorld]:
        """Return every mesh world known to this bootstrap instance."""
        ...

    def kill_mesh(self, mesh_world: MeshWorld) -> None:
        """Kill the mesh identified by ``mesh_world`` in this bootstrap instance."""
        ...

    def spawn_mesh(self, mesh_world: MeshWorld) -> None:
        """Spawn the mesh identified by ``mesh_world`` in this bootstrap instance."""
        ...
56
+
57
+
58
class IPoolDeviceMeshProvider(Protocol):
    """Structural interface for objects that can hand out device meshes."""

    def new_mesh(self, timeout_in_sec: Optional[int] = None) -> DeviceMesh:
        """
        Provide a device mesh, optionally bounded by ``timeout_in_sec`` seconds.

        This stub always raises NotImplementedError.
        """
        raise NotImplementedError
61
+
62
+
63
class PoolDeviceMeshProvider:
    """
    Given a client actor, the device mesh provider discovers and keeps track of
    the world status and provides a device mesh given a healthy world.

    Discovered worlds are tracked in ``_mesh_map``, keyed by ``MeshWorld``. The
    value is the ``DeviceMesh`` handed out for that world, or ``None`` while no
    mesh has been allocated to it yet.
    """

    def __init__(
        self,
        hosts: int,
        gpus: int,
        proc: Proc,
    ) -> None:
        """
        Args:
            hosts: size of the "host" dimension of every mesh created.
            gpus: size of the "gpu" dimension of every mesh created.
            proc: the proc used to create the root client actor and the
                per-mesh client actors.
        """
        self._hosts = hosts
        self._gpus = gpus
        self._mesh_map: dict[MeshWorld, DeviceMesh | None] = {}
        self._proc = proc
        # Root client is not used to create device meshes.
        # It is only used to pull the world status.
        self._root_client: ClientActor = ClientActor(
            proc=self._proc,
            actor_name="root_client",  # The client name really doesn't matter
        )

    def new_mesh(self, timeout_in_sec: Optional[int] = None) -> DeviceMesh:
        """
        Creates a new device mesh based on the current world status.
        If no healthy world is found, the call will block until a healthy world
        is found or timeout_in_sec is reached. timeout_in_sec being None
        indicates no timeout.

        Args:
            timeout_in_sec: maximum number of seconds to keep polling for a
                healthy, unallocated world; None polls indefinitely.

        Returns:
            The device mesh allocated in the first healthy world that has no
            mesh assigned yet.

        Raises:
            TimeoutError: if no healthy world was found within timeout_in_sec.
        """

        logger.info("Trying to allocate a new mesh in its desired world...")

        # Use the monotonic clock so wall-clock adjustments cannot shorten or
        # extend the timeout.
        start = time.monotonic()
        while timeout_in_sec is None or time.monotonic() - start < timeout_in_sec:
            # Pull the fresh world status
            self._refresh_worlds()
            world_status = self._root_client.world_status()
            self._remove_evicted_worlds(world_status)

            # Find the next available world
            for mesh_world, mesh in self._mesh_map.items():
                if mesh is not None:
                    # Mesh has been allocated to this world, skip
                    continue

                worker_world, controller_id = mesh_world
                controller_world = controller_id.world_name

                if not (
                    self._is_world_healthy(world_status, worker_world)
                    and self._is_world_healthy(world_status, controller_world)
                ):
                    # Either controller world is not ready or worker world is not ready
                    continue

                # Only the value of an existing key is replaced, and we return
                # right away, so the dict is not resized during iteration.
                dm = self._create_mesh(mesh_world)

                logger.info("Mesh successfully allocated in world: %s", worker_world)

                return dm

            # TODO(T216841374): Change to healthy world push based checks
            sleep_sec = 0.05
            logger.debug("No healthy world found, sleeping for %ss...", sleep_sec)
            time.sleep(sleep_sec)

        raise TimeoutError(f"Could not find a healthy world in {timeout_in_sec}s!")

    @staticmethod
    def _is_world_healthy(world_status: dict[str, str], target_world: str) -> bool:
        """Return True when target_world is present in world_status and LIVE."""
        return (
            target_world in world_status
            and DeviceMeshStatus(world_status[target_world]) == DeviceMeshStatus.LIVE
        )

    def _create_mesh(self, mesh_world: MeshWorld) -> DeviceMesh:
        """Build a device mesh for mesh_world and record it in the mesh map."""
        worker_world, controller_id = mesh_world

        # Create a new device mesh
        backend_ctrl = RustController(
            proc=self._proc,
            client_actor=ClientActor.new_with_parent(
                self._proc, self._root_client.actor_id
            ),
            controller_id=controller_id,
            worker_world_name=worker_world,
        )
        client = Client(backend_ctrl, self._hosts * self._gpus, self._gpus)

        def _exit(
            error: Optional[RemoteException | DeviceException | Exception] = None,
        ) -> None:
            # Exiting the mesh shuts down its dedicated client.
            client.shutdown(True, error)

        # TODO: we need to consider hosts and gpus constraints as well
        dm = DeviceMesh(
            client,
            NDSlice(
                offset=0,
                sizes=[self._hosts, self._gpus],
                strides=[self._gpus, 1],
            ),
            ("host", "gpu"),
            worker_world,
        )
        dm.exit = _exit
        self._mesh_map[mesh_world] = dm
        return dm

    def _refresh_worlds(self) -> None:
        """
        Query the system for worlds carrying the worker label and register any
        newly discovered (worker world, controller actor id) pair as unallocated.
        """
        system_snapshot = self._root_client.world_state(
            filter=SystemSnapshotFilter(world_labels={WORLD_WORKER_LABEL: "1"})
        )
        for world_id, world_snapshot in system_snapshot.items():
            if WORLD_CONTROLLER_LABEL not in world_snapshot.labels:
                # Without a controller label no functional world can be formed.
                continue
            controller_actor_id = ActorId.from_string(
                world_snapshot.labels[WORLD_CONTROLLER_LABEL]
            )
            world_tuple = (world_id, controller_actor_id)
            if world_tuple not in self._mesh_map:
                logger.debug("Discovered new worker world %s", world_id)
                self._mesh_map[world_tuple] = None

    def _remove_evicted_worlds(self, world_status: dict[str, str]) -> None:
        """
        Go through the mesh map and remove the worlds that have already been
        evicted by the system, i.e. whose worker world or controller world no
        longer has an entry in world_status.
        """
        # Iterate over a snapshot of the keys so entries can be dropped in place.
        for mesh_world in list(self._mesh_map):
            worker_world, controller_id = mesh_world
            controller_world = controller_id.world_name

            if (
                world_status.get(worker_world) is None
                or world_status.get(controller_world) is None
            ):
                logger.debug("Removing Evicted world %s", mesh_world)
                del self._mesh_map[mesh_world]
203
+
204
+
205
def rust_mast_mesh(
    job_name: str, system_port: int = 29500, **kwargs: Any
) -> DeviceMesh:
    """
    Build a device mesh backed by the MAST job named ``job_name``.

    If the job is not running yet, waits up to ten minutes for it. The system
    address is then derived from the job's first hostname and ``system_port``,
    and the remaining keyword arguments are forwarded to ``rust_backend_mesh``.
    """
    wait_for_running_sec = 10 * 60
    mast_job = MastJob(job_name, TORCHX_MAST_TASK_GROUP_NAME)
    if not mast_job.is_running():
        mast_job.wait_for_running(wait_for_running_sec)

    # The system is addressed through the first host of the job.
    first_host = mast_job.get_hostnames()[0]
    return rust_backend_mesh(
        f"metatls!{first_host}.facebook.com:{system_port}",
        **kwargs,
    )
217
+
218
+
219
def rust_backend_mesh(
    system_addr: str,
    hosts: int,
    gpus: int,
) -> DeviceMesh:
    """
    Create a single device mesh with ``hosts`` x ``gpus`` devices through the
    system at ``system_addr``.

    Convenience wrapper around ``rust_backend_meshes`` requesting exactly one mesh.
    """
    meshes = rust_backend_meshes(
        system_addr,
        hosts,
        gpus,
        requested_meshes=1,
    )
    # Exactly one mesh was requested, so exactly one is expected back.
    assert len(meshes) == 1
    return meshes[0]
232
+
233
+
234
def rust_backend_meshes(
    system_addr: str,
    hosts: int,
    gpus: int,
    requested_meshes: int = 1,
) -> list[DeviceMesh]:
    """
    Given the system address system_addr, discover the registered worlds and
    create a device mesh per world with hosts and gpus. The call will block
    until requested_meshes meshes are discovered and created, or the 1200s
    timeout is reached.

    Args:
        system_addr: the system address to connect to.
        hosts: number of hosts to create the device mesh with.
        gpus: number of gpus to create the device mesh with.
        requested_meshes: the number of meshes to create. A non-positive value
            yields an empty list.

    Returns:
        The list of created device meshes, of length requested_meshes.

    Raises:
        TimeoutError: if the requested meshes could not all be created within
            the 1200s budget.
    """
    mesh_provider = rust_backend_mesh_provider(system_addr, hosts, gpus)
    dms: list[DeviceMesh] = []

    max_timeout_in_sec = 1200
    timeout_message = (
        f"Timeout ({max_timeout_in_sec} sec) waiting for all worlds to be ready."
    )
    # Monotonic deadline shared by all allocations, immune to wall-clock jumps.
    deadline = time.monotonic() + max_timeout_in_sec
    while len(dms) < requested_meshes:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(timeout_message)
        try:
            # Hand the remaining budget to new_mesh so a world that never
            # becomes healthy cannot block past the overall timeout.
            # new_mesh takes whole seconds, so round down but wait at least 1s.
            mesh = mesh_provider.new_mesh(timeout_in_sec=max(1, int(remaining)))
        except TimeoutError as err:
            raise TimeoutError(timeout_message) from err
        dms.append(mesh)
    return dms
265
+
266
+
267
def rust_backend_mesh_provider(
    system_addr: str,
    hosts: int,
    gpus: int,
    client_proc_id: str = "client[0]",
    # NOTE(review): the fixme below names `DeviceMeshProvider`, but the return
    # annotation is `PoolDeviceMeshProvider` — the suppression looks stale; confirm.
    # pyre-fixme[11]: Annotation `DeviceMeshProvider` is not defined as a type.
) -> PoolDeviceMeshProvider:
    """
    Initialize a client proc against the system at ``system_addr`` and wrap it
    in a ``PoolDeviceMeshProvider`` that creates ``hosts`` x ``gpus`` meshes.

    Args:
        system_addr: bootstrap address passed to ``init_proc``.
        hosts: number of hosts per device mesh.
        gpus: number of gpus per device mesh.
        client_proc_id: proc id used for the client proc.
    """
    client_proc: Proc = init_proc(
        proc_id=client_proc_id,
        bootstrap_addr=system_addr,
        timeout=5,
        supervision_update_interval=5,
    )
    provider = PoolDeviceMeshProvider(hosts, gpus, client_proc)
    return provider