torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,692 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from __future__ import annotations
8
+
9
+ import copy
10
+ import csv
11
+ import itertools
12
+ import re
13
+ from collections import defaultdict
14
+ from enum import Enum
15
+ from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
16
+
17
+
18
+ # We reuse the IR definition and optimizations from FairInternal/XLFormers' implementation of pipeline parallelism,
19
+ # originally found in core/parallelism/pipeline_parallel/schedule_ir.py.
20
+ # TODO: Investigate how to adapt this code for reuse after further integration
21
+ class _ComputationType(Enum):
22
+ # TODO(whc) rename to _ActType?
23
+ FORWARD = 1
24
+ BACKWARD = 2
25
+ WEIGHT = 3
26
+ UNSHARD = 4
27
+ RESHARD = 5
28
+ SEND_F = 6
29
+ RECV_F = 7
30
+ SEND_B = 8
31
+ RECV_B = 9
32
+ SEND_F_RECV_B = 10
33
+ SEND_B_RECV_F = 11
34
+ # TODO- probably want to reconsider naming backward_input 'B' and having 'FULL_BACKWARD'.
35
+ # instead, B = full backward, Bx, Bw are the partials?
36
+ FULL_BACKWARD = 12
37
+
38
+ def __str__(self):
39
+ str_map = {
40
+ _ComputationType.FORWARD: "F",
41
+ _ComputationType.BACKWARD: "B",
42
+ _ComputationType.WEIGHT: "W",
43
+ _ComputationType.UNSHARD: "UNSHARD",
44
+ _ComputationType.RESHARD: "RESHARD",
45
+ _ComputationType.SEND_F: "SEND_F",
46
+ _ComputationType.RECV_F: "RECV_F",
47
+ _ComputationType.SEND_B: "SEND_B",
48
+ _ComputationType.RECV_B: "RECV_B",
49
+ _ComputationType.SEND_F_RECV_B: "SEND_F_RECV_B",
50
+ _ComputationType.SEND_B_RECV_F: "SEND_B_RECV_F",
51
+ _ComputationType.FULL_BACKWARD: "BW",
52
+ }
53
+ return str_map[self]
54
+
55
+ @staticmethod
56
+ def from_str(action):
57
+ if action == "F":
58
+ return _ComputationType.FORWARD
59
+ elif action == "B":
60
+ return _ComputationType.BACKWARD
61
+ elif action == "W":
62
+ return _ComputationType.WEIGHT
63
+ elif action == "UNSHARD":
64
+ return _ComputationType.UNSHARD
65
+ elif action == "RESHARD":
66
+ return _ComputationType.RESHARD
67
+ elif action == "SEND_F":
68
+ return _ComputationType.SEND_F
69
+ elif action == "RECV_F":
70
+ return _ComputationType.RECV_F
71
+ elif action == "SEND_B":
72
+ return _ComputationType.SEND_B
73
+ elif action == "RECV_B":
74
+ return _ComputationType.RECV_B
75
+ elif action == "SEND_F_RECV_B":
76
+ return _ComputationType.SEND_F_RECV_B
77
+ elif action == "SEND_B_RECV_F":
78
+ return _ComputationType.SEND_B_RECV_F
79
+ elif action == "BW":
80
+ return _ComputationType.FULL_BACKWARD
81
+ else:
82
+ raise RuntimeError(f"Invalid computation type {action}")
83
+
84
+
85
# Module-level aliases so schedule code can refer to computation types tersely
# (e.g. `SEND_F` instead of `_ComputationType.SEND_F`).
FORWARD = _ComputationType.FORWARD
BACKWARD = _ComputationType.BACKWARD
WEIGHT = _ComputationType.WEIGHT
UNSHARD = _ComputationType.UNSHARD
RESHARD = _ComputationType.RESHARD
SEND_F = _ComputationType.SEND_F
RECV_F = _ComputationType.RECV_F
SEND_B = _ComputationType.SEND_B
RECV_B = _ComputationType.RECV_B
SEND_F_RECV_B = _ComputationType.SEND_F_RECV_B
SEND_B_RECV_F = _ComputationType.SEND_B_RECV_F
FULL_BACKWARD = _ComputationType.FULL_BACKWARD

# Convenience shorthand for compute actions only since they are used in 'simple schedule format'
F = FORWARD
B = BACKWARD
W = WEIGHT
BW = FULL_BACKWARD

# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index).
# Groups: (stage)(op)(microbatch) with an optional trailing "_(stage)(RECV_B|RECV_F)(microbatch)"
# second half used by the batched comm actions (see _Action.__repr__). The `{0,1}`
# suffixes make the op and the whole second half optional.
_action_regex = re.compile(
    r"(\d+)(F|BW|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B){0,1}(\d*)(_(\d*)(RECV_B|RECV_F)(\d)){0,1}"
)
108
+
109
+
110
class _Action(NamedTuple):
    """One entry of a pipeline schedule: a compute or communication op applied
    by a stage to a microbatch.

    Serialized form (see ``__repr__`` / ``from_str``) is
    ``<stage><op>[<microbatch>]``, e.g. ``2F0`` or ``1UNSHARD``. Batched comm
    ops (SEND_F_RECV_B / SEND_B_RECV_F) append the second half as
    ``_<other_stage><RECV_*><other_microbatch>``.
    """

    stage_index: int
    computation_type: _ComputationType
    microbatch_index: Optional[int] = None
    # Used only for batched comms, for the second comm
    other_stage_index: Optional[int] = None
    other_microbatch_index: Optional[int] = None
    # Indicates whether to call the post-backward reduce-scatter for W/BW actions.
    require_reduce_scatter: Optional[bool] = False

    def __repr__(self):
        # Batched comm actions spell out both halves explicitly; everything else
        # is "<stage><op>[<microbatch>]" plus an optional "_rs" suffix.
        repr = str(self.stage_index)
        if self.computation_type == SEND_B_RECV_F:
            assert (
                self.microbatch_index is not None
            ), "SEND_B_RECV_F requires microbatch_index"
            assert (
                self.other_stage_index is not None
            ), "SEND_B_RECV_F requires other_stage_index"
            assert (
                self.other_microbatch_index is not None
            ), "SEND_B_RECV_F requires other_microbatch_index"
            repr += str(SEND_B) + str(self.microbatch_index)
            repr += "_" + str(self.other_stage_index)
            repr += str(RECV_F) + str(self.other_microbatch_index)
        elif self.computation_type == SEND_F_RECV_B:
            assert (
                self.microbatch_index is not None
            ), "SEND_F_RECV_B requires microbatch_index"
            assert (
                self.other_stage_index is not None
            ), "SEND_F_RECV_B requires other_stage_index"
            assert (
                self.other_microbatch_index is not None
            ), "SEND_F_RECV_B requires other_microbatch_index"
            repr += str(SEND_F) + str(self.microbatch_index)
            repr += "_" + str(self.other_stage_index)
            repr += str(RECV_B) + str(self.other_microbatch_index)
        else:
            repr += str(self.computation_type)
            if self.microbatch_index is not None:
                repr += str(self.microbatch_index)
            # hasattr guard: tolerates tuples constructed before this field
            # existed (e.g. unpickled) — presumably backward-compat; confirm.
            require_reduce_scatter = (
                hasattr(self, "require_reduce_scatter") and self.require_reduce_scatter
            )
            if require_reduce_scatter and self.computation_type in [
                WEIGHT,
                FULL_BACKWARD,
            ]:
                repr += "_rs"
        return repr

    @staticmethod
    def from_str(str):
        """
        Reverse of __repr__

        String should be formatted as [stage][action type][(microbatch)]
        e.g. `2F0`, `1UNSHARD`, `3SEND_F1`

        Returns None for an empty/whitespace string; raises RuntimeError for
        anything else that does not parse.
        """
        if match := _action_regex.match(str):
            # the _ is for the combined group that captures the whole second action
            (
                stage_index,
                computation_type,
                microbatch_index,
                _,
                other_stage_index,
                other_computation_type,
                other_microbatch_index,
            ) = match.groups()
            if other_computation_type is not None:
                # Batched comm: recombine the two halves into e.g. "SEND_B_RECV_F".
                assert (
                    other_stage_index is not None and other_microbatch_index is not None
                )
                return _Action(
                    int(stage_index),
                    _ComputationType.from_str(
                        f"{computation_type}_{other_computation_type}"
                    ),
                    int(microbatch_index) if len(microbatch_index) else None,
                    int(other_stage_index),
                    int(other_microbatch_index),
                )
            return _Action(
                int(stage_index),
                _ComputationType.from_str(computation_type),
                int(microbatch_index) if len(microbatch_index) else None,
            )
        elif str == "" or str.isspace():
            return None
        raise RuntimeError(
            f"Invalid action string: {str}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
        )

    def get_pair_commu_action(self) -> Optional[_Action]:
        """
        Return the matching communication action on the peer rank's schedule
        (e.g. this stage's SEND_F pairs with the next stage's RECV_F), or None
        if this action is not a point-to-point comm.
        """
        if self.computation_type not in [RECV_F, RECV_B, SEND_F, SEND_B]:
            return None
        stage_id = self.stage_index
        op = self.computation_type
        microbatch_id = self.microbatch_index
        # Forward activations flow stage -> stage+1; backward grads stage -> stage-1.
        if op == RECV_F:
            other_stage = stage_id - 1
            other_op = SEND_F
        elif op == RECV_B:
            other_stage = stage_id + 1
            other_op = SEND_B
        elif op == SEND_F:
            other_stage = stage_id + 1
            other_op = RECV_F
        else:
            assert op == SEND_B
            other_stage = stage_id - 1
            other_op = RECV_B
        return _Action(other_stage, other_op, microbatch_id)
228
+
229
+
230
+ def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
231
+ """
232
+ Formats the pipeline order in a timestep (row) x rank (column) grid of actions
233
+ and returns the formatted string
234
+ """
235
+ # Replace None with ""
236
+ for rank in pipeline_order:
237
+ for i in range(len(pipeline_order[rank])):
238
+ if pipeline_order[rank][i] is None:
239
+ # TODO make a real 'None action' that prints as empty string and make mypy happy
240
+ pipeline_order[rank][i] = "" # type: ignore[call-overload]
241
+ # Calculate the maximum number of steps across all ranks
242
+ num_steps = max(len(actions) for actions in pipeline_order.values())
243
+ step_labels = [
244
+ "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
245
+ ]
246
+ # Sorting the dictionary by keys and retrieving values in that order
247
+ rank_actions = [
248
+ pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
249
+ ]
250
+ # Transpose the list of lists (rows to columns)
251
+ transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
252
+ # Generate column labels for ranks
253
+ num_ranks = len(pipeline_order)
254
+ rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
255
+ # Calculate the maximum length of each column, considering labels
256
+ max_lengths = [
257
+ max(len(str(item)) if item is not None else 0 for item in col)
258
+ for col in zip(step_labels, *transposed_actions)
259
+ ]
260
+ # Format the header row with rank labels
261
+ header_row = " " * (len(step_labels[0]) + 2) + " ".join(
262
+ f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
263
+ )
264
+ # Format each row with its corresponding label
265
+ formatted_rows = [
266
+ f"{label}: "
267
+ + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
268
+ for label, row in zip(step_labels, transposed_actions)
269
+ ]
270
+ # Join the rows into a single string
271
+ formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
272
+ return formatted_table
273
+
274
+
275
def _add_send_recv(
    compute_actions: Dict[int, List[_Action]],
    stage_to_rank: Callable[[int], int],
    num_stages: int,
    batch_send_recv: bool = False,
) -> Dict[int, List[_Action]]:
    """Insert SEND/RECV communication actions into a compute-only schedule.

    Greedily walks each rank's compute actions in lockstep; for every F/B/BW
    whose peer stage lives on a different rank, appends the matching SEND on
    the producing rank and the RECV on the consuming rank. `compute_actions`
    is consumed (popped empty) as actions are scheduled. Returns a new
    rank -> actions mapping that includes the comm actions.

    Raises AssertionError if no rank can make progress (malformed schedule).

    NOTE(review): `batch_send_recv` is not read anywhere in this body —
    presumably reserved for the batching optimization referenced at the end;
    confirm before relying on it.
    """
    comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions}

    def _has_comms(action: _Action) -> bool:
        # True when this compute action's result must cross a rank boundary:
        # forward output goes to stage+1, backward grad goes to stage-1.
        if action.computation_type == F:
            return action.stage_index != num_stages - 1 and stage_to_rank(
                action.stage_index + 1
            ) != stage_to_rank(action.stage_index)
        elif action.computation_type in (B, BW):
            return action.stage_index != 0 and stage_to_rank(
                action.stage_index - 1
            ) != stage_to_rank(action.stage_index)
        return False

    def _get_comms(action: _Action) -> Tuple[_Action, _Action]:
        # Build the (send, recv) pair implied by a compute action; the recv is
        # attributed to the peer stage's index.
        assert _has_comms(action), f"{action} is not a valid comm action"
        stage_idx = action.stage_index
        ctype = action.computation_type
        mb_idx = action.microbatch_index
        send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
        recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
        recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
        return send, recv

    def _peer_rank(action: _Action) -> int:
        # Rank on the other end of a point-to-point comm action.
        # TODO asserts for invalid stage ids (RECV_F for stage 0)
        if action.computation_type == SEND_F:
            return stage_to_rank(action.stage_index + 1)
        elif action.computation_type == SEND_B:
            return stage_to_rank(action.stage_index - 1)
        elif action.computation_type == RECV_F:
            return stage_to_rank(action.stage_index - 1)
        elif action.computation_type == RECV_B:
            return stage_to_rank(action.stage_index + 1)
        else:
            raise ValueError("unsupported action for peer rank")

    def _ready_to_schedule(
        action: Optional[_Action], prev_actions: List[_Action]
    ) -> bool:
        """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
        This helps ensure a sane (non-hanging) ordering of sends and recvs.
        But it also means we might not be able to schedule our next compute action yet.
        """
        if action is None:
            return True
        elif action.computation_type == F and not action.stage_index == 0:
            # A non-first F needs its input: either a RECV_F already placed for
            # it, a batched comm delivering it, or the previous stage's F
            # locally on this rank.
            for p in prev_actions:
                if (
                    p.computation_type == RECV_F
                    and p.stage_index == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == SEND_B_RECV_F
                    and p.other_stage_index == action.stage_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == FORWARD
                    and p.stage_index == action.stage_index - 1
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
            return False
        elif (
            action.computation_type in (B, BW)
            and not action.stage_index == num_stages - 1
        ):
            # Mirror of the F case for the backward direction.
            for p in prev_actions:
                if (
                    p.computation_type == RECV_B
                    and p.stage_index == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == SEND_F_RECV_B
                    and p.other_stage_index == action.stage_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type in (B, BW)
                    and p.stage_index == action.stage_index + 1
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
            return False
        else:
            return True

    while compute_actions:
        progress = False
        # go in order of ranks even if dict keys aren't ordered
        new_comms: Dict[int, defaultdict[int, list]] = {
            rank: defaultdict(list) for rank in sorted(compute_actions)
        }
        for rank in sorted(compute_actions):
            if rank not in compute_actions:
                continue

            assert len(compute_actions[rank]) > 0
            action = compute_actions[rank][0]
            if not _ready_to_schedule(action, comm_actions[rank]):
                continue

            if action is not None:
                comm_actions[rank].append(action)
                if _has_comms(action):
                    send, recv = _get_comms(action)
                    # TODO we can avoid send/recv if the 2 stages are on the same rank.
                    # should we avoid that in the runtime or here?
                    new_comms[rank][_peer_rank(send)].append(send)
                    new_comms[stage_to_rank(recv.stage_index)][rank].append(recv)

            compute_actions[rank].pop(0)
            if len(compute_actions[rank]) == 0:
                del compute_actions[rank]
            progress = True

        if not progress:
            print("WIP comms schedule:\n", _format_pipeline_order(comm_actions))  # type: ignore[arg-type]
            print("remaining compute actions:\n", compute_actions)
        assert progress, "Malformed compute schedule, can't schedule sends/recvs"

        # comm batching needs to be done carefully to avoid reordering comms and causing a hang
        # algorithm:
        # Process sends/recvs in pairs. Processing means consuming from 'new_comms' and adding the final schedule
        # processing batches is done the same way except 4 ops at a time are consumed and 2 are written
        # rules:
        # 1- if we batch ops for one rank, we also batch matching ops for another rank
        # 2- when we create a batch, we append the batches to both ranks' schedules at the same time
        # 3- we remove individual sends/recvs from 'new_comms' when we consume them in a batch
        # 4- append individual (unbatchable) sends/recvs
        for rank in new_comms:
            for peer in new_comms[rank]:
                if rank == peer:
                    continue
                # we batch and process all the operations between rank and peer.
                # this should symmetrically consume all actions from new_comms[rank][peer] and new_comms[peer][rank]
                ops = new_comms[rank][peer]
                peer_ops = new_comms[peer][rank]
                if len(ops) == 0:
                    assert (
                        len(peer_ops) == 0
                    ), f"ops was empty but peer_ops was not, {peer_ops}"

                batched_ops = list(ops)
                batched_peer_ops = list(peer_ops)
                # TODO - refactor so that it is not necessary to consume/clear ops/peer_ops
                ops.clear()
                peer_ops.clear()
                comm_actions[rank].extend(batched_ops)
                comm_actions[peer].extend(batched_peer_ops)

    # TODO: run extra send/recv scheduling optimizations
    # (_optimize_communication_ops) over comm_actions before returning.
    return comm_actions
443
+
444
+
445
def _simulate_comms_compute(
    pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int
):
    """Simulate a comms+compute schedule, aligning actions into timesteps.

    Greedily assigns each rank's next pending action to the current timestep
    when its dependency (upstream compute, matching send on the peer rank, or
    the other half of a batched comm) is already satisfied; otherwise inserts
    None (a bubble) for that rank. Returns rank -> list of actions/None, one
    entry per timestep. Raises ValueError if no rank can make progress
    (a deadlocked/malformed schedule).
    """
    # Drop pre-existing bubbles; the simulation re-inserts them where needed.
    pipeline_order = {
        rank: [a for a in pipeline_order[rank] if a is not None]
        for rank in sorted(pipeline_order)
    }
    schedule: Dict[int, List[_Action | None]] = {
        rank: [] for rank in sorted(pipeline_order)
    }

    def _prev_ops(stage_idx):
        # Everything already scheduled on the rank owning stage_idx, plus a
        # peek at that rank's next unscheduled action (see comment below).
        rank = stage_to_rank(stage_idx)
        ops = copy.deepcopy(schedule[rank])
        if len(pipeline_order[rank]):
            # batched comm ops may need to be jointly scheduled (e.g. send_f_recv_b depends on and is a dep of send_b_recv_f)
            # assuming we iterate in sorted rank order, peeking at the next unscheduled action for later ranks should unblock us
            ops.append(pipeline_order[rank][0])

        return ops

    def _ready_to_schedule(action: Optional[_Action]) -> bool:
        # Dependency check for one action against what has been scheduled so far.
        if action is None:
            return True

        stage_idx = action.stage_index
        if action.computation_type == F:
            # First stage has no input dependency; otherwise we need the
            # activation: local upstream F, a completed RECV_F, or a batched
            # comm that delivered it.
            if action.stage_index == 0:
                return True
            for p in _prev_ops(stage_idx):
                if p is None:
                    continue
                elif (
                    p.computation_type == F
                    and p.stage_index + 1 == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == RECV_F
                    and p.stage_index == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == SEND_B_RECV_F
                    and p.other_stage_index == action.stage_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
            return False
        elif action.computation_type in (B, BW):
            # Mirror of the F case: last stage starts backward unconditionally.
            if action.stage_index == num_stages - 1:
                return True

            for p in _prev_ops(stage_idx):
                if p is None:
                    continue
                elif (
                    p.computation_type == RECV_B
                    and p.stage_index == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == SEND_F_RECV_B
                    and p.other_stage_index == action.stage_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type in (B, BW)
                    and p.stage_index - 1 == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
            return False
        elif action.computation_type == W:
            # Weight-grad work has no cross-rank dependency.
            return True
        elif action.computation_type == SEND_F:
            # Can't send an activation before computing it.
            expected_f = _Action(action.stage_index, F, action.microbatch_index)
            return expected_f in _prev_ops(stage_idx)
        elif action.computation_type == RECV_F:
            # A recv is ready once the peer stage's matching send is scheduled.
            peer_stage_idx = stage_idx - 1
            expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index)
            return expected_send in _prev_ops(peer_stage_idx)
        elif action.computation_type == SEND_B:
            # The grad may come from either a split backward (B) or a full one (BW).
            expected_b = _Action(action.stage_index, B, action.microbatch_index)
            expected_bw = _Action(action.stage_index, BW, action.microbatch_index)
            return expected_b in _prev_ops(stage_idx) or expected_bw in _prev_ops(
                stage_idx
            )
        elif action.computation_type == RECV_B:
            peer_stage_idx = stage_idx + 1
            expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index)
            return expected_send in _prev_ops(peer_stage_idx)
        elif action.computation_type == SEND_F_RECV_B:
            # though the stage_index may not be the same between the SEND and the RECV, the rank must be
            peer_stage_idx = stage_idx + 1
            for p in _prev_ops(peer_stage_idx):
                if p is None:
                    continue
                elif (
                    p.computation_type == SEND_B_RECV_F
                    and action.other_stage_index is not None
                    and p.stage_index == action.other_stage_index + 1
                    and p.other_stage_index is not None
                    and p.other_stage_index == action.stage_index + 1
                    and p.microbatch_index == action.other_microbatch_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
            return False
        elif action.computation_type == SEND_B_RECV_F:
            # though the stage_index may not be the same between the SEND and the RECV, the rank must be
            peer_stage_idx = action.stage_index - 1
            for p in _prev_ops(peer_stage_idx):
                if p is None:
                    continue
                elif (
                    p.computation_type == SEND_F_RECV_B
                    and p.stage_index + 1 == action.other_stage_index
                    and p.other_stage_index + 1 == action.stage_index
                    and p.microbatch_index == action.other_microbatch_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
            return False

        else:
            raise ValueError(f"Unsupported action type {action}")

    while pipeline_order:
        progress = False
        for rank in sorted(pipeline_order):
            if len(pipeline_order[rank]) == 0:
                continue

            action = pipeline_order[rank][0]
            if _ready_to_schedule(action):
                if action is not None:
                    schedule[rank].append(action)
                pipeline_order[rank].pop(0)
                progress = True
            else:
                # Blocked: record a bubble at this timestep.
                schedule[rank].append(None)

        for i in sorted(pipeline_order, reverse=True):
            if len(pipeline_order[i]) == 0:
                del pipeline_order[i]

        # hacky, but do a second pass to replace any 'none' at this timestep with a real action, if it got unblocked
        # by one of the later ranks
        for rank in sorted(pipeline_order):
            if len(pipeline_order[rank]) == 0:
                continue

            if schedule[rank][-1] is not None:
                continue

            action = pipeline_order[rank][0]
            if _ready_to_schedule(action):
                if action is not None:
                    schedule[rank][-1] = action
                    pipeline_order[rank].pop(0)

        for i in sorted(pipeline_order, reverse=True):
            if len(pipeline_order[i]) == 0:
                del pipeline_order[i]

        if not progress:
            print("WIP comms schedule:\n", _format_pipeline_order(schedule))
            for rank in pipeline_order:
                print(f"{rank=} next action= {pipeline_order[rank][0]}")
            raise ValueError("Schedule is not progressing")

    return schedule
624
+
625
+
626
def _dump_chrometrace(schedule, filename):
    """
    Write `schedule` (rank -> list of actions/None per timestep) to `filename`
    as a Chrome trace-event JSON file.

    Each action becomes a complete ("X") event of duration 1 whose timestamp is
    its timestep index; the rank doubles as pid and tid so each rank renders as
    its own row. None entries (bubbles) are skipped.
    """
    import json

    events = []
    for rank in sorted(schedule):
        for timestep, action in enumerate(schedule[rank]):
            if action is None:
                continue
            events.append(
                {
                    "name": str(action),
                    # BW (FULL_BACKWARD) is a compute action; previously only
                    # F/B/W were listed so merged backwards were mislabeled
                    # as "communication".
                    "cat": (
                        "computation"
                        if action.computation_type in (F, B, W, BW)
                        else "communication"
                    ),
                    "ph": "X",
                    "pid": rank,
                    "tid": rank,
                    "ts": timestep,
                    "dur": 1,
                }
            )

    with open(filename, "w") as f:
        json.dump({"traceEvents": events}, f)
651
+
652
+
653
+ def _dump_csv(pipeline_order_with_comms, filename: str):
654
+ """Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
655
+ with open(filename, "w", newline="") as csvfile:
656
+ writer = csv.writer(csvfile)
657
+ for rank in pipeline_order_with_comms:
658
+ writer.writerow(pipeline_order_with_comms[rank])
659
+
660
+
661
def _merge_bw(
    compute_actions: List[Optional[_Action]],
) -> List[_Action]:
    """Given a basic schedule involving only compute actions (F,B,W), merge adjacent B and W ops into BW ops.

    BW refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient
    in some cases.

    Consumes `compute_actions` (it is emptied in place) and returns the merged
    list with None entries dropped.
    """
    merged_actions = []
    while compute_actions:
        action = compute_actions.pop(0)
        if action is None:
            continue

        # remove any None padding so we can peek at the next real action
        while compute_actions and compute_actions[0] is None:
            compute_actions.pop(0)
        # Fix: the previous walrus-based peek left `next_action` unbound (or
        # stale from an earlier iteration) once the list was exhausted, so a
        # schedule ending in a B raised UnboundLocalError.
        next_action = compute_actions[0] if compute_actions else None

        if (
            action.computation_type == B
            and next_action is not None
            and next_action.computation_type == W
            and action.stage_index == next_action.stage_index
            and action.microbatch_index == next_action.microbatch_index
        ):
            # Adjacent B then W for the same stage/microbatch: fuse into one BW.
            merged_actions.append(
                _Action(action.stage_index, BW, action.microbatch_index)
            )
            compute_actions.pop(0)
        else:
            merged_actions.append(action)
    return merged_actions