torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/parallel/pipelining/schedule_ir.py
@@ -0,0 +1,692 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import copy
import csv
import itertools
import re
from collections import defaultdict
from enum import Enum
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple


# We reuse the IR definition and optimizations from FairInternal/XLFormers' implementation of pipeline parallelism,
# originally found in core/parallelism/pipeline_parallel/schedule_ir.py.
# TODO: Investigate how to adapt this code for reuse after further integration
class _ComputationType(Enum):
    # TODO(whc) rename to _ActType?
    FORWARD = 1
    BACKWARD = 2
    WEIGHT = 3
    UNSHARD = 4
    RESHARD = 5
    SEND_F = 6
    RECV_F = 7
    SEND_B = 8
    RECV_B = 9
    SEND_F_RECV_B = 10
    SEND_B_RECV_F = 11
    # TODO- probably want to reconsider naming backward_input 'B' and having 'FULL_BACKWARD'.
    # instead, B = full backward, Bx, Bw are the partials?
    FULL_BACKWARD = 12

    def __str__(self):
        str_map = {
            _ComputationType.FORWARD: "F",
            _ComputationType.BACKWARD: "B",
            _ComputationType.WEIGHT: "W",
            _ComputationType.UNSHARD: "UNSHARD",
            _ComputationType.RESHARD: "RESHARD",
            _ComputationType.SEND_F: "SEND_F",
            _ComputationType.RECV_F: "RECV_F",
            _ComputationType.SEND_B: "SEND_B",
            _ComputationType.RECV_B: "RECV_B",
            _ComputationType.SEND_F_RECV_B: "SEND_F_RECV_B",
            _ComputationType.SEND_B_RECV_F: "SEND_B_RECV_F",
            _ComputationType.FULL_BACKWARD: "BW",
        }
        return str_map[self]

    @staticmethod
    def from_str(action):
        if action == "F":
            return _ComputationType.FORWARD
        elif action == "B":
            return _ComputationType.BACKWARD
        elif action == "W":
            return _ComputationType.WEIGHT
        elif action == "UNSHARD":
            return _ComputationType.UNSHARD
        elif action == "RESHARD":
            return _ComputationType.RESHARD
        elif action == "SEND_F":
            return _ComputationType.SEND_F
        elif action == "RECV_F":
            return _ComputationType.RECV_F
        elif action == "SEND_B":
            return _ComputationType.SEND_B
        elif action == "RECV_B":
            return _ComputationType.RECV_B
        elif action == "SEND_F_RECV_B":
            return _ComputationType.SEND_F_RECV_B
        elif action == "SEND_B_RECV_F":
            return _ComputationType.SEND_B_RECV_F
        elif action == "BW":
            return _ComputationType.FULL_BACKWARD
        else:
            raise RuntimeError(f"Invalid computation type {action}")


FORWARD = _ComputationType.FORWARD
BACKWARD = _ComputationType.BACKWARD
WEIGHT = _ComputationType.WEIGHT
UNSHARD = _ComputationType.UNSHARD
RESHARD = _ComputationType.RESHARD
SEND_F = _ComputationType.SEND_F
RECV_F = _ComputationType.RECV_F
SEND_B = _ComputationType.SEND_B
RECV_B = _ComputationType.RECV_B
SEND_F_RECV_B = _ComputationType.SEND_F_RECV_B
SEND_B_RECV_F = _ComputationType.SEND_B_RECV_F
FULL_BACKWARD = _ComputationType.FULL_BACKWARD

# Convenience shorthand for compute actions only since they are used in 'simple schedule format'
F = FORWARD
B = BACKWARD
W = WEIGHT
BW = FULL_BACKWARD

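# --- Editor's illustrative sketch; not part of the packaged file. ---
# The enum's string forms round-trip through __str__/from_str, and the
# single-letter shorthands above are the 'simple schedule format' spellings:
assert str(BW) == "BW" and _ComputationType.from_str("BW") is FULL_BACKWARD
assert _ComputationType.from_str(str(F)) is FORWARD
# --- end sketch ---
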
# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
_action_regex = re.compile(
    r"(\d+)(F|BW|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B){0,1}(\d*)(_(\d*)(RECV_B|RECV_F)(\d)){0,1}"
)

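# --- Editor's illustrative sketch; not part of the packaged file. ---
# The regex yields up to seven groups; group 4 is the combined capture for an
# optional second (batched) comm action:
#   _action_regex.match("2F0").groups()
#     -> ('2', 'F', '0', None, None, None, None)
#   _action_regex.match("1SEND_B0_2RECV_F0").groups()
#     -> ('1', 'SEND_B', '0', '_2RECV_F0', '2', 'RECV_F', '0')
# --- end sketch ---
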
class _Action(NamedTuple):
    stage_index: int
    computation_type: _ComputationType
    microbatch_index: Optional[int] = None
    # Used only for batched comms, for the second comm
    other_stage_index: Optional[int] = None
    other_microbatch_index: Optional[int] = None
    # Indicates whether to call the post-backward reduce-scatter for W/BW actions.
    require_reduce_scatter: Optional[bool] = False

    def __repr__(self):
        repr = str(self.stage_index)
        if self.computation_type == SEND_B_RECV_F:
            assert (
                self.microbatch_index is not None
            ), "SEND_B_RECV_F requires microbatch_index"
            assert (
                self.other_stage_index is not None
            ), "SEND_B_RECV_F requires other_stage_index"
            assert (
                self.other_microbatch_index is not None
            ), "SEND_B_RECV_F requires other_microbatch_index"
            repr += str(SEND_B) + str(self.microbatch_index)
            repr += "_" + str(self.other_stage_index)
            repr += str(RECV_F) + str(self.other_microbatch_index)
        elif self.computation_type == SEND_F_RECV_B:
            assert (
                self.microbatch_index is not None
            ), "SEND_F_RECV_B requires microbatch_index"
            assert (
                self.other_stage_index is not None
            ), "SEND_F_RECV_B requires other_stage_index"
            assert (
                self.other_microbatch_index is not None
            ), "SEND_F_RECV_B requires other_microbatch_index"
            repr += str(SEND_F) + str(self.microbatch_index)
            repr += "_" + str(self.other_stage_index)
            repr += str(RECV_B) + str(self.other_microbatch_index)
        else:
            repr += str(self.computation_type)
            if self.microbatch_index is not None:
                repr += str(self.microbatch_index)
            require_reduce_scatter = (
                hasattr(self, "require_reduce_scatter") and self.require_reduce_scatter
            )
            if require_reduce_scatter and self.computation_type in [
                WEIGHT,
                FULL_BACKWARD,
            ]:
                repr += "_rs"
        return repr

    @staticmethod
    def from_str(str):
        """
        Reverse of __repr__

        String should be formatted as [stage][action type][(microbatch)]
        e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
        """
        if match := _action_regex.match(str):
            # the _ is for the combined group that captures the whole second action
            (
                stage_index,
                computation_type,
                microbatch_index,
                _,
                other_stage_index,
                other_computation_type,
                other_microbatch_index,
            ) = match.groups()
            if other_computation_type is not None:
                assert (
                    other_stage_index is not None and other_microbatch_index is not None
                )
                return _Action(
                    int(stage_index),
                    _ComputationType.from_str(
                        f"{computation_type}_{other_computation_type}"
                    ),
                    int(microbatch_index) if len(microbatch_index) else None,
                    int(other_stage_index),
                    int(other_microbatch_index),
                )
            return _Action(
                int(stage_index),
                _ComputationType.from_str(computation_type),
                int(microbatch_index) if len(microbatch_index) else None,
            )
        elif str == "" or str.isspace():
            return None
        raise RuntimeError(
            f"Invalid action string: {str}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
        )

    def get_pair_commu_action(self) -> Optional[_Action]:
        """
        Returns the corresponding communication action on another rank.
        """
        if self.computation_type not in [RECV_F, RECV_B, SEND_F, SEND_B]:
            return None
        stage_id = self.stage_index
        op = self.computation_type
        microbatch_id = self.microbatch_index
        if op == RECV_F:
            other_stage = stage_id - 1
            other_op = SEND_F
        elif op == RECV_B:
            other_stage = stage_id + 1
            other_op = SEND_B
        elif op == SEND_F:
            other_stage = stage_id + 1
            other_op = RECV_F
        else:
            assert op == SEND_B
            other_stage = stage_id - 1
            other_op = RECV_B
        return _Action(other_stage, other_op, microbatch_id)

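# --- Editor's illustrative sketch; not part of the packaged file. ---
# _Action repr/from_str round-trip, and the matching comm the peer rank runs:
assert repr(_Action(2, F, 0)) == "2F0"
assert _Action.from_str("2F0") == _Action(2, FORWARD, 0)
assert _Action(3, SEND_F, 1).get_pair_commu_action() == _Action(4, RECV_F, 1)
# --- end sketch ---
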
def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
    """
    Formats the pipeline order in a timestep (row) x rank (column) grid of actions
    and returns the formatted string
    """
    # Replace None with ""
    for rank in pipeline_order:
        for i in range(len(pipeline_order[rank])):
            if pipeline_order[rank][i] is None:
                # TODO make a real 'None action' that prints as empty string and make mypy happy
                pipeline_order[rank][i] = ""  # type: ignore[call-overload]
    # Calculate the maximum number of steps across all ranks
    num_steps = max(len(actions) for actions in pipeline_order.values())
    step_labels = [
        "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
    ]
    # Sorting the dictionary by keys and retrieving values in that order
    rank_actions = [
        pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
    ]
    # Transpose the list of lists (rows to columns)
    transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
    # Generate column labels for ranks
    num_ranks = len(pipeline_order)
    rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
    # Calculate the maximum length of each column, considering labels
    max_lengths = [
        max(len(str(item)) if item is not None else 0 for item in col)
        for col in zip(step_labels, *transposed_actions)
    ]
    # Format the header row with rank labels
    header_row = " " * (len(step_labels[0]) + 2) + " ".join(
        f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
    )
    # Format each row with its corresponding label
    formatted_rows = [
        f"{label}: "
        + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
        for label, row in zip(step_labels, transposed_actions)
    ]
    # Join the rows into a single string
    formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
    return formatted_table

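# --- Editor's illustrative sketch; not part of the packaged file. ---
# For two ranks the grid renders one row per timestep, one column per rank,
# roughly:
#           Rank 0  Rank 1
#   Step 0: 0F0
#   Step 1: 0F1     1F0
# --- end sketch ---
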
def _add_send_recv(
    compute_actions: Dict[int, List[_Action]],
    stage_to_rank: Callable[[int], int],
    num_stages: int,
    batch_send_recv: bool = False,
) -> Dict[int, List[_Action]]:
    comm_actions: Dict[int, List[_Action]] = {rank: [] for rank in compute_actions}

    def _has_comms(action: _Action) -> bool:
        if action.computation_type == F:
            return action.stage_index != num_stages - 1 and stage_to_rank(
                action.stage_index + 1
            ) != stage_to_rank(action.stage_index)
        elif action.computation_type in (B, BW):
            return action.stage_index != 0 and stage_to_rank(
                action.stage_index - 1
            ) != stage_to_rank(action.stage_index)
        return False

    def _get_comms(action: _Action) -> Tuple[_Action, _Action]:
        assert _has_comms(action), f"{action} is not a valid comm action"
        stage_idx = action.stage_index
        ctype = action.computation_type
        mb_idx = action.microbatch_index
        send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
        recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
        recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
        return send, recv

    def _peer_rank(action: _Action) -> int:
        # TODO asserts for invalid stage ids (RECV_F for stage 0)
        if action.computation_type == SEND_F:
            return stage_to_rank(action.stage_index + 1)
        elif action.computation_type == SEND_B:
            return stage_to_rank(action.stage_index - 1)
        elif action.computation_type == RECV_F:
            return stage_to_rank(action.stage_index - 1)
        elif action.computation_type == RECV_B:
            return stage_to_rank(action.stage_index + 1)
        else:
            raise ValueError("unsupported action for peer rank")

    def _ready_to_schedule(
        action: Optional[_Action], prev_actions: List[_Action]
    ) -> bool:
        """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
        This helps ensure a sane (non-hanging) ordering of sends and recvs.
        But it also means we might not be able to schedule our next compute action yet.
        """
        if action is None:
            return True
        elif action.computation_type == F and not action.stage_index == 0:
            for p in prev_actions:
                if (
                    p.computation_type == RECV_F
                    and p.stage_index == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == SEND_B_RECV_F
                    and p.other_stage_index == action.stage_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == FORWARD
                    and p.stage_index == action.stage_index - 1
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
            return False
        elif (
            action.computation_type in (B, BW)
            and not action.stage_index == num_stages - 1
        ):
            for p in prev_actions:
                if (
                    p.computation_type == RECV_B
                    and p.stage_index == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == SEND_F_RECV_B
                    and p.other_stage_index == action.stage_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type in (B, BW)
                    and p.stage_index == action.stage_index + 1
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
            return False
        else:
            return True

    while compute_actions:
        progress = False
        # go in order of ranks even if dict keys aren't ordered
        new_comms: Dict[int, defaultdict[int, list]] = {
            rank: defaultdict(list) for rank in sorted(compute_actions)
        }
        for rank in sorted(compute_actions):
            if rank not in compute_actions:
                continue

            assert len(compute_actions[rank]) > 0
            action = compute_actions[rank][0]
            if not _ready_to_schedule(action, comm_actions[rank]):
                continue

            if action is not None:
                comm_actions[rank].append(action)
                if _has_comms(action):
                    send, recv = _get_comms(action)
                    # TODO we can avoid send/recv if the 2 stages are on the same rank.
                    # should we avoid that in the runtime or here?
                    new_comms[rank][_peer_rank(send)].append(send)
                    new_comms[stage_to_rank(recv.stage_index)][rank].append(recv)

            compute_actions[rank].pop(0)
            if len(compute_actions[rank]) == 0:
                del compute_actions[rank]
            progress = True

        if not progress:
            print("WIP comms schedule:\n", _format_pipeline_order(comm_actions))  # type: ignore[arg-type]
            print("remaining compute actions:\n", compute_actions)
        assert progress, "Malformed compute schedule, can't schedule sends/recvs"

        # comm batching needs to be done carefully to avoid reordering comms and causing a hang
        # algorithm:
        # Process sends/recvs in pairs. Processing means consuming from 'new_comms' and adding the final schedule
        # processing batches is done the same way except 4 ops at a time are consumed and 2 are written
        # rules:
        # 1- if we batch ops for one rank, we also batch matching ops for another rank
        # 2- when we create a batch, we append the batches to both ranks' schedules at the same time
        # 3- we remove individual sends/recvs from 'new_comms' when we consume them in a batch
        # 4- append individual (unbatchable) sends/recvs
        for rank in new_comms:
            for peer in new_comms[rank]:
                if rank == peer:
                    continue
                # we batch and process all the operations between rank and peer.
                # this should symmetrically consume all actions from new_comms[rank][peer] and new_comms[peer][rank]
                ops = new_comms[rank][peer]
                peer_ops = new_comms[peer][rank]
                if len(ops) == 0:
                    assert (
                        len(peer_ops) == 0
                    ), f"ops was empty but peer_ops was not, {peer_ops}"

                batched_ops = list(ops)
                batched_peer_ops = list(peer_ops)
                # TODO - refactor so that it is not necessary to consume/clear ops/peer_ops
                ops.clear()
                peer_ops.clear()
                comm_actions[rank].extend(batched_ops)
                comm_actions[peer].extend(batched_peer_ops)

    # # Run extra optimizations to adjust send/recv scheduling.
    # optimized_comm_actions = _optimize_communication_ops(
    #     comm_actions,
    # )
    return comm_actions

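# --- Editor's illustrative sketch; not part of the packaged file. ---
# With one stage per rank, inserting comms into a 2-stage/1-microbatch
# compute schedule:
#   compute = {0: [_Action(0, F, 0), _Action(0, B, 0)],
#              1: [_Action(1, F, 0), _Action(1, B, 0)]}
#   _add_send_recv(compute, stage_to_rank=lambda s: s, num_stages=2)
# yields {0: [0F0, 0SEND_F0, 0RECV_B0, 0B0], 1: [1RECV_F0, 1F0, 1B0, 1SEND_B0]}:
# each rank's recv is placed by the sending rank's pass, never by itself.
# --- end sketch ---
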
def _simulate_comms_compute(
    pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int
):
    pipeline_order = {
        rank: [a for a in pipeline_order[rank] if a is not None]
        for rank in sorted(pipeline_order)
    }
    schedule: Dict[int, List[_Action | None]] = {
        rank: [] for rank in sorted(pipeline_order)
    }

    def _prev_ops(stage_idx):
        rank = stage_to_rank(stage_idx)
        ops = copy.deepcopy(schedule[rank])
        if len(pipeline_order[rank]):
            # batched comm ops may need to be jointly scheduled (e.g. send_f_recv_b depends on and is a dep of send_b_recv_f)
            # assuming we iterate in sorted rank order, peeking at the next unscheduled action for later ranks should unblock us
            ops.append(pipeline_order[rank][0])

        return ops

    def _ready_to_schedule(action: Optional[_Action]) -> bool:
        if action is None:
            return True

        stage_idx = action.stage_index
        if action.computation_type == F:
            if action.stage_index == 0:
                return True
            for p in _prev_ops(stage_idx):
                if p is None:
                    continue
                elif (
                    p.computation_type == F
                    and p.stage_index + 1 == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == RECV_F
                    and p.stage_index == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == SEND_B_RECV_F
                    and p.other_stage_index == action.stage_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
            return False
        elif action.computation_type in (B, BW):
            if action.stage_index == num_stages - 1:
                return True

            for p in _prev_ops(stage_idx):
                if p is None:
                    continue
                elif (
                    p.computation_type == RECV_B
                    and p.stage_index == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type == SEND_F_RECV_B
                    and p.other_stage_index == action.stage_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
                elif (
                    p.computation_type in (B, BW)
                    and p.stage_index - 1 == action.stage_index
                    and p.microbatch_index == action.microbatch_index
                ):
                    return True
            return False
        elif action.computation_type == W:
            return True
        elif action.computation_type == SEND_F:
            expected_f = _Action(action.stage_index, F, action.microbatch_index)
            return expected_f in _prev_ops(stage_idx)
        elif action.computation_type == RECV_F:
            peer_stage_idx = stage_idx - 1
            expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index)
            return expected_send in _prev_ops(peer_stage_idx)
        elif action.computation_type == SEND_B:
            expected_b = _Action(action.stage_index, B, action.microbatch_index)
            expected_bw = _Action(action.stage_index, BW, action.microbatch_index)
            return expected_b in _prev_ops(stage_idx) or expected_bw in _prev_ops(
                stage_idx
            )
        elif action.computation_type == RECV_B:
            peer_stage_idx = stage_idx + 1
            expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index)
            return expected_send in _prev_ops(peer_stage_idx)
        elif action.computation_type == SEND_F_RECV_B:
            # though the stage_index may not be the same between the SEND and the RECV, the rank must be
            peer_stage_idx = stage_idx + 1
            for p in _prev_ops(peer_stage_idx):
                if p is None:
                    continue
                elif (
                    p.computation_type == SEND_B_RECV_F
                    and action.other_stage_index is not None
                    and p.stage_index == action.other_stage_index + 1
                    and p.other_stage_index is not None
                    and p.other_stage_index == action.stage_index + 1
                    and p.microbatch_index == action.other_microbatch_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
            return False
        elif action.computation_type == SEND_B_RECV_F:
            # though the stage_index may not be the same between the SEND and the RECV, the rank must be
            peer_stage_idx = action.stage_index - 1
            for p in _prev_ops(peer_stage_idx):
                # if p is not None and str(p) == "0SEND_F14-16RECV_B0":
                #     breakpoint()
                if p is None:
                    continue
                elif (
                    p.computation_type == SEND_F_RECV_B
                    and p.stage_index + 1 == action.other_stage_index
                    and p.other_stage_index + 1 == action.stage_index
                    and p.microbatch_index == action.other_microbatch_index
                    and p.other_microbatch_index == action.microbatch_index
                ):
                    return True
            return False

        else:
            raise ValueError(f"Unsupported action type {action}")

    while pipeline_order:
        progress = False
        for rank in sorted(pipeline_order):
            if len(pipeline_order[rank]) == 0:
                continue

            action = pipeline_order[rank][0]
            if _ready_to_schedule(action):
                if action is not None:
                    schedule[rank].append(action)
                pipeline_order[rank].pop(0)
                progress = True
            else:
                schedule[rank].append(None)

        for i in sorted(pipeline_order, reverse=True):
            if len(pipeline_order[i]) == 0:
                del pipeline_order[i]

        # hacky, but do a second pass to replace any 'none' at this timestep with a real action, if it got unblocked
        # by one of the later ranks
        for rank in sorted(pipeline_order):
            if len(pipeline_order[rank]) == 0:
                continue

            if schedule[rank][-1] is not None:
                continue

            action = pipeline_order[rank][0]
            if _ready_to_schedule(action):
                if action is not None:
                    schedule[rank][-1] = action
                pipeline_order[rank].pop(0)

        for i in sorted(pipeline_order, reverse=True):
            if len(pipeline_order[i]) == 0:
                del pipeline_order[i]

        if not progress:
            print("WIP comms schedule:\n", _format_pipeline_order(schedule))
            for rank in pipeline_order:
                print(f"{rank=} next action= {pipeline_order[rank][0]}")
            raise ValueError("Schedule is not progressing")

    return schedule

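# --- Editor's illustrative sketch; not part of the packaged file. ---
# _simulate_comms_compute replays a comms schedule timestep by timestep,
# recording None (a bubble) for a rank whose next action is still blocked:
#   schedule = _simulate_comms_compute(order, stage_to_rank=lambda s: s, num_stages=4)
# The result maps rank -> per-timestep actions aligned across ranks; a
# schedule that can never progress raises ValueError("Schedule is not progressing").
# --- end sketch ---
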
def _dump_chrometrace(schedule, filename):
    events = []
    for rank in sorted(schedule):
        for timestep, action in enumerate(schedule[rank]):
            if action is None:
                continue
            events.append(
                {
                    "name": str(action),
                    "cat": (
                        "computation"
                        if action.computation_type in (F, B, W)
                        else "communication"
                    ),
                    "ph": "X",
                    "pid": rank,
                    "tid": rank,
                    "ts": timestep,
                    "dur": 1,
                }
            )
    import json

    with open(filename, "w") as f:
        json.dump({"traceEvents": events}, f)


def _dump_csv(pipeline_order_with_comms, filename: str):
    """Dump a CSV representation of the compute + comms schedule into a file with the provided filename."""
    with open(filename, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        for rank in pipeline_order_with_comms:
            writer.writerow(pipeline_order_with_comms[rank])

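# --- Editor's illustrative sketch; not part of the packaged file. ---
# Both dump helpers consume the rank -> actions mapping produced above:
#   _dump_chrometrace(schedule, "schedule.json")  # viewable in chrome://tracing
#   _dump_csv(schedule, "schedule.csv")           # one row of actions per rank
# --- end sketch ---
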
def _merge_bw(
    compute_actions: List[Optional[_Action]],
) -> List[_Action]:
    """Given a basic schedule involving only compute actions (F,B,W), merge adjacent B and W ops into BW ops.

    BW refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient
    in some cases.
    """
    merged_actions = []
    while compute_actions:
        action = compute_actions.pop(0)
        if action is None:
            continue

        while len(compute_actions) and (next_action := compute_actions[0]) is None:
            # remove any None actions between 'action' and 'next_action'
            compute_actions.pop(0)

        if (
            action.computation_type == B
            and next_action is not None
            and next_action.computation_type == W
            and action.stage_index == next_action.stage_index
            and action.microbatch_index == next_action.microbatch_index
        ):
            merged_actions.append(
                _Action(action.stage_index, BW, action.microbatch_index)
            )
            compute_actions.pop(0)
        else:
            merged_actions.append(action)
    return merged_actions
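
# --- Editor's illustrative sketch; not part of the packaged file. ---
# Adjacent B and W for the same stage and microbatch collapse into one BW:
assert _merge_bw(
    [_Action(0, B, 0), None, _Action(0, W, 0), _Action(0, F, 1)]
) == [_Action(0, BW, 0), _Action(0, F, 1)]
# --- end sketch ---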