torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,395 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+ import pickle
10
+ import subprocess
11
+ import traceback
12
+
13
+ from typing import Any, Dict, List, Literal, Sequence, TypedDict
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class TraceEvent:
    """
    Represents a trace event in the simulation.

    Args:
        start (int): The start time, in nanoseconds, of the event.
        runtime (int): The runtime, in nanoseconds, of the event.
        meta (list): A list of metadata associated with the event.
        command_id (int): The associated command id of this task.
        traceback: Stack frames captured where the command originated.
    """

    def __init__(
        self,
        start: int,
        runtime: int,
        meta: List[str],
        command_id: int,
        traceback: Sequence[traceback.FrameSummary],
    ):
        self.start, self.runtime = start, runtime
        # End time is derived once so consumers never recompute it.
        self.end = start + runtime
        self.meta = meta
        self.command_id = command_id
        self.traceback = traceback

    def __repr__(self):
        return "E(meta={}, start={:.2f}, end={:.1f})".format(
            self.meta, self.start, self.end
        )
47
+
48
+
49
def visualize_events(worker_events):
    """Render per-worker event timelines as a stacked horizontal Gantt chart.

    Writes the interactive figure to ``sim.html`` in the current directory.

    Args:
        worker_events: Mapping from a process/worker key to a list of
            TraceEvent-like objects exposing ``meta``, ``start`` and ``end``.
    """
    import pandas as pd
    import plotly.graph_objs as go

    # Flatten {process: [events]} into one record per event.
    records = [
        {
            "Process": key,
            "Event": event.meta,
            "Start": event.start,
            "End": event.end,
            "Duration": event.end - event.start,
        }
        for key, events in worker_events.items()
        for event in events
    ]
    df = pd.DataFrame(records)

    fig = go.Figure()

    # Blue shades for forward ("fw") events.
    fw_list = [
        "#0000FF",  # Blue
        "#1E90FF",  # Dodger Blue
        "#00BFFF",  # Deep Sky Blue
        "#5F9EA0",  # Cadet Blue
        "#4682B4",  # Steel Blue
        "#87CEFA",  # Light Sky Blue
        "#6495ED",  # Cornflower Blue
        "#4169E1",  # Royal Blue
    ]
    # Red/pink shades for backward ("bw") events.
    bw_list = [
        "#FF0000",  # Red
        "#FF4500",  # Orange Red
        "#FF1493",  # Deep Pink
        "#FF69B4",  # Hot Pink
        "#DB7093",  # Pale Violet Red
        "#B22222",  # Firebrick
        "#8B0000",  # Dark Red
        "#FF6347",  # Tomato
    ]

    def get_color(metas):
        # Pick a palette shade keyed on the first numeric meta tag;
        # fall back to plain red when no tag matches.
        if "fw" in metas:
            for meta in metas:
                if meta.isdigit():
                    return fw_list[int(meta) % len(fw_list)]
        elif "bw" in metas:
            for meta in metas:
                if meta.isdigit():
                    # BUG FIX: index modulo the *bw* palette size (was
                    # len(fw_list); only worked because both lists happen
                    # to have the same length).
                    return bw_list[int(meta) % len(bw_list)]
        return "red"

    # One horizontal bar per event, one row per process.
    for process in df["Process"].unique():
        process_df = df[df["Process"] == process]
        for _, row in process_df.iterrows():
            color = get_color(row["Event"])
            fig.add_trace(
                go.Bar(
                    x=[row["Duration"]],
                    y=[str(process)],
                    base=[row["Start"]],
                    orientation="h",
                    name=" ".join(row["Event"]),
                    hoverinfo="name+x",
                    marker={
                        "color": color,
                    },
                    showlegend=False,  # Hide default legend
                )
            )

    fig.update_layout(
        title="Timeline Visualization",
        xaxis_title="Time",
        yaxis_title="Process",
        barmode="stack",
        showlegend=False,  # Disable the default legend
        yaxis={"autorange": "reversed"},  # Reverse the y-axis
    )

    # Persist the plot; use fig.show() instead for interactive sessions.
    fig.write_html("sim.html")
143
+
144
+
145
def dump_process_name(trace: List[Dict[str, Any]], *, pid: int, name: str):
    """Append a Chrome-trace metadata ("M") record naming process ``pid``."""
    record = {
        "name": "process_name",
        "ph": "M",
        "pid": pid,
        "tid": 0,
        "args": {"name": name},
    }
    trace.append(record)
155
+
156
+
157
+ def _include_file(filename: str):
158
+ if "controller/" in filename:
159
+ return False
160
+ return True
161
+
162
+
163
def _filter_traceback(tb: Sequence[traceback.FrameSummary]):
    """Trim notebook bootstrap frames and controller frames, newest first."""
    # Drop everything up to (and including) the last IPython "run_code"
    # frame, if one is present.
    last_run_code = None
    for idx, frame in enumerate(tb):
        if frame.name == "run_code":
            last_run_code = idx
    if last_run_code is not None:
        tb = tb[last_run_code + 1 :]  # noqa: whitespace before ':'
    # Keep only interesting frames, then reverse so the innermost call
    # comes first.
    kept = [frame for frame in tb if _include_file(frame.filename)]
    kept.reverse()
    return kept
170
+
171
+
172
def _format_traceback(tb):
    """Format a filtered traceback as text, innermost call first."""
    lines = traceback.format_list(_filter_traceback(tb))
    return "Traceback (most recent call first)\n" + "".join(lines)
176
+
177
+
178
def dump_thread_event_trace(
    trace: List[Dict[str, Any]],
    events: List[TraceEvent],
    *,
    pid: int,
    tid: int,
    name: str,
) -> float:
    """Append Chrome-trace records for one simulated thread.

    Emits a thread_name metadata record followed by one complete ("X")
    event per TraceEvent. Returns the latest event end time seen, in
    microseconds (0.0 when ``events`` is empty).
    """
    trace.append(
        {
            "name": "thread_name",
            "ph": "M",
            "pid": pid,
            "tid": tid,
            "args": {"name": name},
        }
    )
    max_time = 0.0
    for event in events:
        label = " ".join(event.meta)
        # Offset ids by pid so correlation ids stay unique across
        # processes in the merged trace.
        correlation = event.command_id + pid * 10000
        trace.append(
            {
                "name": label,
                "cat": "compute",
                "ph": "X",
                # Timestamps/durations are recorded in ns; Chrome traces
                # expect microseconds.
                "ts": event.start / 1000,
                "dur": event.runtime / 1000,
                "pid": pid,
                "tid": tid,
                "args": {
                    "External id": correlation,
                    "correlation": correlation,
                    "cbid": event.command_id,
                    " traceback": _format_traceback(event.traceback),
                },
                "cname": "rail_animation" if "waiting" in label else None,
            }
        )
        max_time = max(max_time, (event.start + event.runtime) / 1000)

    return max_time
219
+
220
+
221
def dump_memory_trace(
    trace: List[Dict[str, Any]], *, pid: int, memory: int, ts: int, name: str
) -> None:
    """Append a Chrome-trace counter ("C") record of allocated memory.

    ``ts`` is converted from ns to us and ``memory`` from bytes to MB.
    """
    record = {
        "name": name,
        "cat": "memory",
        "ph": "C",
        "ts": ts / 1000,
        "pid": pid,
        "args": {
            "allocated": memory / 10**6,
        },
    }
    trace.append(record)
236
+
237
+
238
def upload_trace(file_path) -> None:
    """Upload a trace file via an internal Perfetto share script.

    Best-effort: runs the helper through the shell and prints its output;
    a non-zero exit only prints a failure message, nothing is raised or
    returned.
    """
    logger.info("Uploading the trace file to Manifold...")

    # Internal helper script; "~" is expanded by the shell (shell=True below).
    command_path = "~/fbsource/arvr/scripts/perfetto/share_trace.py"
    # NOTE(review): shell=True with an interpolated path is fine for a
    # trusted local file, but breaks/injects if file_path ever contains
    # spaces or shell metacharacters — consider shlex.quote(file_path).
    command = [f"{command_path} {file_path}"]
    result = subprocess.run(command, capture_output=True, text=True, shell=True)

    if result.returncode == 0:
        print(result.stdout)
    else:
        print("Failed to upload the file.")
        print(result.stdout)
        print(result.stderr)
251
+
252
+
253
class Frame(TypedDict):
    """One stack frame of an allocation traceback (snapshot format)."""

    filename: str  # source file of the frame
    line: int  # line number within the file
    name: str  # function name
257
+
258
+
259
class Block(TypedDict):
    """A single allocation (or cached free block) within a Segment."""

    # A piece of memory returned from the allocator, or
    # currently cached but inactive.
    size: int
    requested_size: int  # size requested during malloc, may be smaller than
    # size due to rounding
    address: int
    state: Literal[
        "active_allocated",  # used by a tensor
        "active_awaiting_free",  # waiting for another stream to finish using
        # this, then it will become free
        "inactive",
    ]  # free for reuse
    frames: List[Frame]  # stack trace from where the allocation occurred
273
+
274
+
275
class Segment(TypedDict):
    """One cudaMalloc'd region in the snapshot, split into Blocks."""

    # Segments are memory returned from a cudaMalloc call.
    # The size of reserved memory is the sum of all Segments.
    # Segments are cached and reused for future allocations.
    # If the reuse is smaller than the segment, the segment
    # is split into more than one Block.
    # empty_cache() frees Segments that are entirely inactive.
    address: int
    total_size: int  # cudaMalloc'd size of segment
    stream: int
    segment_type: Literal["small", "large"]  # 'large' (>1MB)
    allocated_size: int  # size of memory in use
    active_size: int  # size of memory in use or in active_awaiting_free state
    device: int
    blocks: List[Block]
290
+
291
+
292
class TraceEntry(TypedDict):
    """One allocator action recorded in a device trace."""

    # When `torch.cuda.memory._record_memory_history()` is enabled,
    # the snapshot will contain TraceEntry objects that record each
    # action the allocator took.
    action: Literal[
        "alloc",  # memory allocated
        "free_requested",  # the allocator received a call to free memory
        "free_completed",  # the memory that was requested to be freed is now
        # able to be used in future allocation calls
        "segment_alloc",  # the caching allocator asked cudaMalloc for more memory
        # and added it as a segment in its cache
        "segment_free",  # the caching allocator called cudaFree to return memory
        # to cuda possibly trying to free up memory to
        # allocate more segments or because empty_caches was called
        "oom",  # the allocator threw an OOM exception. 'size' is
        # the requested number of bytes that did not succeed
        "snapshot",  # the allocator generated a memory snapshot
        # useful to correlate a previously taken
        # snapshot with this trace
    ]
    addr: int  # not present for OOM
    frames: List[Frame]
    size: int
    stream: int
316
+
317
+
318
class Snapshot(TypedDict):
    """Top-level memory snapshot: all segments plus per-device trace lists."""

    segments: List[Segment]
    device_traces: List[List[TraceEntry]]
321
+
322
+
323
class MemoryViewer:
    """Builds a torch.cuda-style memory snapshot from simulated
    alloc/free events so it can be inspected with the standard PyTorch
    memory visualizer tooling."""

    def __init__(self) -> None:
        # One synthetic Segment per stream for the current device.
        self.current_segments = {}
        self.snapshot: Snapshot = {"segments": [], "device_traces": []}
        # Maps real addresses to synthetic offsets within the stream's segment.
        self.addr_map = {}

    def next_device(self) -> None:
        """Start recording a new device; resets all per-device state."""
        self.addr_map.clear()
        self.current_segments.clear()
        self.snapshot["device_traces"].append([])

    def get_or_add_segment(self, stream: int):
        """Return the synthetic segment for ``stream``, creating it lazily."""
        try:
            return self.current_segments[stream]
        except KeyError:
            pass
        segment: Segment = {
            "address": 0,
            "total_size": 0,
            "stream": stream,
            "segment_type": "large",
            "allocated_size": 0,
            "active_size": 0,
            "blocks": [],
            # Device index is whichever device_traces slot is current.
            "device": len(self.snapshot["device_traces"]) - 1,
        }
        self.current_segments[stream] = segment
        self.snapshot["segments"].append(segment)
        return segment

    def add_trace(self, addr: int, delta: int, stream: int, traceback) -> None:
        """Record an allocation (delta > 0) or free (delta < 0) of |delta| bytes."""
        segment = self.get_or_add_segment(stream)
        if delta > 0:
            # Place the allocation at the current end of the segment and
            # remember where it went so the matching free lines up.
            maddr = self.addr_map[addr] = segment["allocated_size"]
            segment["allocated_size"] += delta
            action: Literal["alloc", "free_requested"] = "alloc"
        else:
            maddr = self.addr_map[addr]
            action: Literal["alloc", "free_requested"] = "free_requested"

        frames: List[Frame] = [
            {"filename": f.filename, "line": f.lineno, "name": f.name}
            for f in _filter_traceback(traceback)
        ]

        entry: TraceEntry = {
            "addr": maddr,
            "frames": frames,
            "size": abs(delta),
            "stream": stream,
            "action": action,
        }
        device_trace = self.snapshot["device_traces"][-1]
        device_trace.append(entry)
        if delta < 0:
            # The simulator frees instantly, so emit the completion
            # immediately after the request.
            # pyre-ignore
            device_trace.append({**entry, "action": "free_completed"})

    def dump(self, path: str) -> None:
        """Finalize segments and pickle the snapshot to ``path``."""
        for segment in self.snapshot["segments"]:
            size = segment["total_size"] = segment["allocated_size"]
            # A single covering "inactive" block makes the segment
            # well-formed for the visualizer.
            segment["blocks"].append(
                {
                    "address": 0,
                    "size": size,
                    "requested_size": size,
                    "state": "inactive",
                    "frames": [],
                }
            )

        with open(path, "wb") as fp:
            # @lint-ignore PYTHONPICKLEISBAD
            pickle.dump(self.snapshot, fp)
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import os
9
+
10
+ import numpy as np
11
+
12
+
13
def file_path_with_iter(file_path: str, iter_count: int) -> str:
    """Return ``file_path`` with ``_{iter_count}`` inserted before the extension.

    E.g. ``file_path_with_iter("/tmp/trace.json", 3)`` -> ``"/tmp/trace_3.json"``.

    Uses ``os.path.splitext`` so base names containing extra dots
    (e.g. ``run.v2.json``) are handled; the original two-value
    ``split(".")`` unpacking raised ValueError on them.
    """
    dir_path, base_name = os.path.split(file_path)
    root, ext = os.path.splitext(base_name)
    return os.path.join(dir_path, f"{root}_{iter_count}{ext}")
18
+
19
+
20
def compress_workers_range(workers) -> str:
    """Compress a collection of worker ids into "[a-b]" range strings.

    E.g. ``[3, 1, 2, 5]`` -> ``"[1-3] [5-5]"``. Returns "" for empty input.

    BUG FIX: the original sorted the ids into ``sorted_workers`` but then
    iterated the *unsorted* input, producing broken ranges whenever the
    ids were out of order; it also crashed (IndexError) on empty input.
    """
    if len(workers) == 0:
        return ""
    sorted_workers = np.sort(workers)
    regions = []
    start = end = sorted_workers[0]
    for worker in sorted_workers[1:]:
        if worker == end + 1:
            # Extend the current contiguous run.
            end = worker
        else:
            # Gap found: close out the current run and start a new one.
            regions.append(f"[{start}-{end}]")
            start = end = worker
    regions.append(f"[{start}-{end}]")
    return " ".join(regions)
34
+
35
+
36
def clean_name(name: str) -> str:
    """Shorten a torch op name, e.g. "torch.ops.aten.add.default" -> "aten.add"."""
    prefix = "torch.ops."
    suffix = ".default"
    # Only aten ops lose the "torch.ops." prefix; custom ops keep it.
    if name.startswith(prefix + "aten."):
        name = name[len(prefix) :]  # noqa: whitespace before ':'
    if name.endswith(suffix):
        name = name[: -len(suffix)]
    return name