torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/parallel/pipelining/runtime.py
@@ -0,0 +1,847 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import copy
import importlib
import sys
from itertools import chain
from logging import getLogger
from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from monarch import fetch_shard, no_mesh, OpaqueRef, remote, Stream, Tensor
from monarch.common.device_mesh import DeviceMesh
from monarch.opaque_module import OpaqueModule

from .schedule_ir import (
    _Action,
    _format_pipeline_order,
    B,
    BW,
    F,
    RECV_B,
    RECV_F,
    SEND_B,
    SEND_B_RECV_F,
    SEND_F,
    SEND_F_RECV_B,
    W,
)
from .scheduler import generate_schedule

logger = getLogger()

run_forward_udf = remote(
    "monarch.parallel.pipelining.runtime.run_forward_impl",
    propagate=lambda stage, input_tensor, model_chunk_id, microbatch_id: input_tensor,
)

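# Note on the `remote` pattern above: `remote` registers the import path of the
# worker-side implementation together with a `propagate` rule that the
# controller evaluates locally to infer output metadata (shape/dtype/mesh)
# without running worker code. Here the forward output is assumed to share the
# input's metadata (the lambda just returns `input_tensor`). An illustrative
# call-site sketch, assuming `stage_ref` is an OpaqueRef on the active mesh:
#
#     out = run_forward_udf(stage_ref, x_mb, model_chunk_id=0, microbatch_id=0)
#     # `out` is a monarch Tensor whose metadata mirrors `x_mb`; the real
#     # forward runs on the workers when the message executes.
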
def run_forward_impl(
    stage: nn.Module | OpaqueRef,
    input_tensor: torch.Tensor,
    model_chunk_id: int,
    microbatch_id: int,
) -> torch.Tensor:
    """
    Run the forward function for one model chunk.

    Args:
        stage: The current stage of the model.
        input_tensor: The input tensor for the forward pass.
        model_chunk_id: Identifier for the model chunk.
        microbatch_id: Identifier for the microbatch.

    Returns:
        The output tensor after the forward pass.
    """
    if isinstance(stage, OpaqueRef):
        worker_stage = stage.value
    else:
        assert isinstance(stage, nn.Module)
        worker_stage = stage
    input_tensor.requires_grad_(True)
    with torch.enable_grad():
        output = worker_stage(
            input_tensor=input_tensor,
        )
    return output

def _run_backward_udf(
    input_tensor: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
    y: torch.Tensor,
    loss_layer: OpaqueRef,
    loss_list: OpaqueRef,
    model_chunk_id: int,
    microbatch_id: int,
    num_microbatches: int,
    is_last_stage: bool,
    is_last_microbatch: bool,
) -> Optional[torch.Tensor]:
    return input_tensor


run_backward_udf = remote(
    "monarch.parallel.pipelining.runtime.run_backward_impl", propagate=_run_backward_udf
)

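# The propagate stub above returns `input_tensor` because the value the
# backward pass produces is the gradient with respect to the stage input,
# which has the same shape and dtype as that input. A named function is used
# instead of a lambda only because the signature is long; the controller calls
# it purely for metadata propagation, never for real compute.
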
def run_backward_impl(
    input_tensor: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
    y: torch.Tensor,
    loss_layer: nn.Module | OpaqueRef,
    loss_list: List[torch.Tensor] | OpaqueRef,
    model_chunk_id: int,
    microbatch_id: int,
    num_microbatches: int,
    is_last_stage: bool,
    is_last_microbatch: bool,
) -> Optional[torch.Tensor]:
    """
    Run the backward function for one model chunk.

    Args:
        input_tensor: The input tensor for the backward pass.
        output_tensor: The output tensor from the forward pass.
        output_tensor_grad: The gradient of the output tensor.
        y: The target tensor.
        loss_layer: The loss layer used to compute the loss.
        loss_list: A list to store the computed loss values.
        model_chunk_id: Identifier for the model chunk.
        microbatch_id: Identifier for the microbatch.
        num_microbatches: Total number of microbatches.
        is_last_stage: Flag indicating if this is the last stage.
        is_last_microbatch: Flag indicating if this is the last microbatch.

    Returns:
        The gradient of the input tensor if it requires gradient, otherwise None.
    """
    input_tensor.requires_grad_(True)
    if is_last_stage:
        if isinstance(loss_layer, OpaqueRef):
            worker_loss_layer = loss_layer.value
        else:
            worker_loss_layer = loss_layer
        with torch.enable_grad():
            loss = worker_loss_layer(output_tensor, y).mean() / num_microbatches
            worker_loss_list = (
                loss_list.value if isinstance(loss_list, OpaqueRef) else loss_list
            )
            worker_loss_list.append(loss)
            if is_last_microbatch:
                all_loss = torch.stack(worker_loss_list).sum()
                worker_loss_list.clear()
                worker_loss_list.append(all_loss)

        output = loss

    if input_tensor is not None and input_tensor.requires_grad:
        input_tensor.retain_grad()
    if output_tensor_grad is None:
        assert is_last_stage
        output.backward(retain_graph=True)
    else:
        torch.autograd.backward(
            output_tensor, grad_tensors=output_tensor_grad, retain_graph=True
        )
    if input_tensor is not None and input_tensor.requires_grad:
        return input_tensor.grad
    return None

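# Loss-scaling note (illustrative arithmetic, not new behavior): each
# microbatch loss is divided by num_microbatches, so summing the stored
# per-microbatch losses on the last microbatch reproduces the mean loss over
# the full batch. With batch_size=4 and microbatch_size=1 there are 4
# microbatches, and sum_i(loss_i / 4) equals the mean of the 4 microbatch
# losses, matching what a single full-batch pass with a mean-reduced criterion
# would have produced.
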
# Get the parameter with the given name from a module.
get_parameter_udf = remote(
    "monarch.parallel.pipelining.runtime.get_parameter",
    propagate=lambda module, param_name, param_shape: torch.randn(param_shape),
)


def get_parameter(
    module_ref: nn.Module | OpaqueRef,
    param_name: str,
    param_shape: tuple,
):
    """
    Retrieves a parameter from a PyTorch module.

    Args:
        module_ref (nn.Module | OpaqueRef): The PyTorch module (or an opaque
            reference to one) to retrieve the parameter from.
        param_name (str): The name of the parameter to retrieve.
        param_shape (tuple): The expected shape of the parameter; used only by
            the remote propagation rule.

    Returns:
        torch.Tensor: The retrieved parameter as a tensor.

    Raises:
        AttributeError: If the parameter does not exist in the module.
    """
    if isinstance(module_ref, OpaqueRef):
        module = module_ref.value
    else:
        module = module_ref
    for name, param in module.named_parameters():
        if name == param_name:
            return param
    raise AttributeError(
        f"Module '{module.__class__.__name__}' has no attribute '{param_name}'"
    )

# Retrieves the loss for the batch.
get_loss_udf = remote(
    "monarch.parallel.pipelining.runtime.get_loss_impl",
    propagate=lambda loss_list: torch.tensor(0.0),
)


def get_loss_impl(loss_list):
    """
    Get the loss for the batch.

    Args:
        loss_list: A list containing loss values.

    Returns:
        The first loss value from the list.
    """
    if isinstance(loss_list, OpaqueRef):
        worker_loss_list = loss_list.value
    else:
        worker_loss_list = loss_list
    loss = worker_loss_list[0]
    worker_loss_list.clear()
    return loss

class PipelineParallelism:
    """
    Utility class for generating schedule actions based on a pipeline parallelism schedule.
    This class is not a core Monarch primitive, but rather a helper utility to simplify the
    process of creating and executing pipeline parallelism schedules.
    It reuses a similar abstraction to the PyTorch pipelining API
    (https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/distributed/pipelining/_IR.py#L1200).

    This class handles the following functionality of pipeline parallelism:
    1. Initialization: Takes a list of modules as model pipeline stages and a
       list of meshes as devices to execute pipeline stages. Initializes the
       model stages on the meshes in the initialize() function.
    2. Pipeline IR Schedule Generation: Generates a pipeline parallelism
       schedule according to the user-selected algorithm in the
       generate_pipeline_ir_schedule() function.
    3. Schedule Dispatch: Dispatches actions from the pipeline parallelism
       schedule to all stages in the dispatch_pp_schedule() function.
    4. Action Execution: Executes individual actions on a pipeline stage in the
       run_action() function.
    """

    def __init__(
        self,
        meshes: List[DeviceMesh],
        stages: List[List[nn.Module | OpaqueRef | OpaqueModule]],
        compute_stream: Stream,
        p2p_stream: Stream,
        schedule: str = "dora-dfs",
        batch_size: int = 4,
        microbatch_size: int = 1,
        loss_fn: Optional[nn.Module] = None,
        loss_list=None,
    ):
        self.stages = stages
        self.meshes = meshes
        self.schedule = schedule
        self.batch_size = batch_size
        self.microbatch_size = microbatch_size
        self.compute_stream = compute_stream
        self.p2p_stream = p2p_stream or Stream("p2p_stream")

        # TODO(dongli): clean up buffers eagerly to save memory.
        self.input_tensors = {}
        self.output_tensors = {}
        self.output_tensor_grads = {}
        self.input_tensor_grads = {}
        self.fwd_send_handles = {}
        self.bwd_send_handles = {}
        self.input_tensors_borrowed = {}

        self.num_microbatches = self.batch_size // self.microbatch_size
        self.num_model_chunks = len(self.stages)
        self.pipeline_parallel_size = len(self.meshes)
        self.loss_list = loss_list

        for i in range(self.num_microbatches):
            self.output_tensor_grads[(self.num_model_chunks - 1, i)] = None

        with self.meshes[-1].activate():
            self.loss_layer = loss_fn

        self.all_rank_actions, self.stage_to_rank_map = (
            self.generate_pipeline_ir_schedule(
                schedule_name=self.schedule,
                total_num_model_chunks=len(self.stages),
                pipeline_parallel_size=len(self.meshes),
                batch_size=self.batch_size,
                microbatch_size=self.microbatch_size,
            )
        )
        logger.info(
            f"PipelineParallelism: pp_ir_schedule:\n{_format_pipeline_order(self.all_rank_actions)} \n{self.stage_to_rank_map=}"
        )
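    # Construction sketch (illustrative; `meshes` and the stage modules are
    # assumptions standing in for your own setup):
    #
    #     compute = Stream("compute")
    #     p2p = Stream("p2p_stream")
    #     pp = PipelineParallelism(
    #         meshes=meshes,                # one DeviceMesh per pipeline rank
    #         stages=[[stage0], [stage1]],  # one list of modules per mesh
    #         compute_stream=compute,
    #         p2p_stream=p2p,
    #         batch_size=4,
    #         microbatch_size=1,
    #         loss_fn=nn.MSELoss(),
    #         loss_list=[],
    #     )
    #     pp.initialize()
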
    def stage_to_rank(self, stage_idx):
        return self.stage_to_rank_map[stage_idx]

    def initialize(
        self,
    ):
        pp_stages = self.stages
        assert len(pp_stages) == len(self.meshes)
        for stage_idx, stage in enumerate(pp_stages):
            for module in stage:
                state_dict = module.state_dict()
                for k, v in state_dict.items():
                    if isinstance(v, Tensor):
                        state_dict[k] = v.to_mesh(self.meshes[stage_idx])
                module.load_state_dict(state_dict, assign=True)
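    # initialize() moves each stage's weights onto its own mesh: `to_mesh`
    # ships every monarch Tensor in the state dict to the stage's DeviceMesh,
    # and `load_state_dict(..., assign=True)` rebinds the module's parameters
    # to the moved tensors instead of copying values into the existing
    # (wrong-mesh) parameters.
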
    def copy_params_to_new_model(
        self,
        ref_model: List[nn.Module],
    ):
        pp_stages = self.stages
        assert len(pp_stages) == len(self.meshes)
        for stage_idx, stage in enumerate(pp_stages):
            assert len(stage) == 1
            module = stage[0]
            ref_module = ref_model[stage_idx]
            ref_model_state_dict = ref_module.state_dict()

            src_params = {}
            ref_params_shape = {}
            with self.meshes[0].activate():
                for ref_name, ref_param in ref_model_state_dict.items():
                    ref_params_shape[ref_name] = ref_param.shape
            with self.meshes[stage_idx].activate():
                for ref_name, _ in ref_model_state_dict.items():
                    ref_param_shape = ref_params_shape[ref_name]
                    if isinstance(module, OpaqueRef):
                        param = get_parameter_udf(module, ref_name, ref_param_shape)
                    elif isinstance(module, OpaqueModule):
                        # TODO: implement named_parameters() for OpaqueModule
                        param = get_parameter_udf(
                            module._object, ref_name, ref_param_shape
                        )
                    elif isinstance(module, nn.Module):
                        param = get_parameter(module, ref_name, ref_param_shape)
                    else:
                        raise ValueError(f"Unknown module type: {module}")

                    param_local = fetch_shard(param).result()
                    with no_mesh.activate():
                        src_params[ref_name] = param_local.detach().cpu().numpy()
            for name, _ in ref_model_state_dict.items():
                param_value = src_params[name]
                with self.meshes[0].activate():
                    new_param = torch.tensor(param_value)
                    ref_model_state_dict[name] = new_param

            ref_module.load_state_dict(ref_model_state_dict, assign=True)
    def configure_optimizers(self, config, config_fn):
        optimizers = []

        for stage in self.stages:
            params = list(chain(*[list(m.parameters()) for m in stage]))
            optimizers.append(
                config_fn(
                    config.weight_decay,
                    config.learning_rate,
                    (config.beta1, config.beta2),
                    config.device_type,
                    config.optimizer,
                    params,
                )
            )

        return optimizers
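    # `config_fn` is caller-supplied; a minimal sketch of a compatible factory
    # (the name `make_adamw` and the config fields are assumptions, matching
    # only the positional order used above):
    #
    #     def make_adamw(weight_decay, lr, betas, device_type, optimizer, params):
    #         return optim.AdamW(params, lr=lr, betas=betas, weight_decay=weight_decay)
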
    def generate_pipeline_ir_schedule(
        self,
        schedule_name,
        total_num_model_chunks,
        pipeline_parallel_size,
        batch_size,
        microbatch_size,
    ):
        assert (
            batch_size % microbatch_size == 0
        ), "Batch size should be divisible by microbatch size."
        num_microbatches = batch_size // microbatch_size
        num_round = max(num_microbatches // pipeline_parallel_size, 1)
        assert (
            num_microbatches % num_round == 0
        ), "Number of microbatches should be divisible by number of pipeline rounds."
        num_microbatch_per_round = num_microbatches // num_round

        num_model_chunks = total_num_model_chunks // pipeline_parallel_size
        total_num_microbatches = num_microbatches * num_model_chunks
        zero_bubble = True
        all_rank_actions, stage_to_rank = generate_schedule(
            schedule_name,
            num_model_chunks,
            pipeline_parallel_size,
            num_round,
            num_microbatch_per_round,
            zero_bubble,
            total_num_microbatches,
            num_microbatches,
        )
        return all_rank_actions, stage_to_rank
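    # Worked sizing example (arithmetic only): with batch_size=8,
    # microbatch_size=1, pipeline_parallel_size=2, and 2 total model chunks:
    #     num_microbatches         = 8 // 1 = 8
    #     num_round                = max(8 // 2, 1) = 4
    #     num_microbatch_per_round = 8 // 4 = 2
    #     num_model_chunks         = 2 // 2 = 1 chunk per rank
    #     total_num_microbatches   = 8 * 1 = 8
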
    def split_inputs_outputs(self, x, y):
        microbatch_x = x.split(self.microbatch_size, dim=0)
        for i, _x in enumerate(microbatch_x):
            _x = _x.to_mesh(self.meshes[0])
            self.input_tensors[(0, i)] = _x

        y = y.to_mesh(self.meshes[-1])
        with self.meshes[-1].activate():
            microbatch_y = y.split(self.microbatch_size, dim=0)
        return microbatch_x, microbatch_y
    def run(self, x, y):
        self.loss = None
        microbatch_x, microbatch_y = self.split_inputs_outputs(x, y)
        self.dispatch_pp_schedule(
            pipeline_order=self.all_rank_actions,
            stage_to_rank=self.stage_to_rank,
            num_stages=len(self.stages),
            microbatch_x=microbatch_x,
            microbatch_y=microbatch_y,
        )
        self.loss = (
            get_loss_udf(self.loss_list)
            if isinstance(self.loss_list, OpaqueRef)
            else get_loss_impl(self.loss_list)
        )
        return self.loss
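    # Per-step usage sketch (illustrative; `pp` is the instance from the
    # construction sketch above, and `opt_refs` are OpaqueRef optimizer chunks
    # such as those returned by build_optimizer_chunk below, with the
    # zero-grad/step helpers typically wrapped with `remote` so they execute
    # on the workers):
    #
    #     for x, y in batches:
    #         for opt_ref in opt_refs:
    #             optimizer_zero_grad(opt_ref)
    #         loss = pp.run(x, y)
    #         for opt_ref in opt_refs:
    #             optimizer_step(opt_ref)
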
    def dispatch_pp_schedule(
        self,
        pipeline_order,
        stage_to_rank: Callable[[int], int],
        num_stages: int,
        microbatch_x,
        microbatch_y,
    ):
        pipeline_order = {
            rank: [a for a in pipeline_order[rank] if a is not None]
            for rank in sorted(pipeline_order)
        }
        schedule: Dict[int, List[_Action | None]] = {
            rank: [] for rank in sorted(pipeline_order)
        }

        def _prev_ops(stage_idx):
            rank = stage_to_rank(stage_idx)
            ops = copy.deepcopy(schedule[rank])
            return ops

        def _ready_to_schedule(action: Optional[_Action]) -> bool:
            if action is None:
                return True
            stage_idx = action.stage_index
            if action.computation_type == F:
                if action.stage_index == 0:
                    return True
                for p in _prev_ops(stage_idx):
                    if p is None:
                        continue
                    elif (
                        p.computation_type == F
                        and p.stage_index + 1 == action.stage_index
                        and p.microbatch_index == action.microbatch_index
                    ):
                        return True
                    elif (
                        p.computation_type == RECV_F
                        and p.stage_index == action.stage_index
                        and p.microbatch_index == action.microbatch_index
                    ):
                        return True
                    elif (
                        p.computation_type == SEND_B_RECV_F
                        and p.other_stage_index == action.stage_index
                        and p.other_microbatch_index == action.microbatch_index
                    ):
                        return True
                return False
            elif action.computation_type in (B, BW):
                if action.stage_index == num_stages - 1:
                    return True

                for p in _prev_ops(stage_idx):
                    if p is None:
                        continue
                    elif (
                        p.computation_type == RECV_B
                        and p.stage_index == action.stage_index
                        and p.microbatch_index == action.microbatch_index
                    ):
                        return True
                    elif (
                        p.computation_type == SEND_F_RECV_B
                        and p.other_stage_index == action.stage_index
                        and p.other_microbatch_index == action.microbatch_index
                    ):
                        return True
                    elif (
                        p.computation_type in (B, BW)
                        and p.stage_index - 1 == action.stage_index
                        and p.microbatch_index == action.microbatch_index
                    ):
                        return True
                return False
            elif action.computation_type == W:
                return True
            elif action.computation_type == SEND_F:
                expected_f = _Action(action.stage_index, F, action.microbatch_index)
                return expected_f in _prev_ops(stage_idx)
            elif action.computation_type == RECV_F:
                peer_stage_idx = stage_idx - 1
                expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index)
                return expected_send in _prev_ops(peer_stage_idx)
            elif action.computation_type == SEND_B:
                expected_b = _Action(action.stage_index, B, action.microbatch_index)
                expected_bw = _Action(action.stage_index, BW, action.microbatch_index)
                return expected_b in _prev_ops(stage_idx) or expected_bw in _prev_ops(
                    stage_idx
                )
            elif action.computation_type == RECV_B:
                peer_stage_idx = stage_idx + 1
                expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index)
                return expected_send in _prev_ops(peer_stage_idx)
            elif action.computation_type == SEND_F_RECV_B:
                peer_stage_idx = stage_idx + 1
                for p in _prev_ops(peer_stage_idx):
                    if p is None:
                        continue
                    elif (
                        p.computation_type == SEND_B_RECV_F
                        and action.other_stage_index is not None
                        and p.stage_index == action.other_stage_index + 1
                        and p.other_stage_index is not None
                        and p.other_stage_index == action.stage_index + 1
                        and p.microbatch_index == action.other_microbatch_index
                        and p.other_microbatch_index == action.microbatch_index
                    ):
                        return True
                return False
            elif action.computation_type == SEND_B_RECV_F:
                peer_stage_idx = action.stage_index - 1
                for p in _prev_ops(peer_stage_idx):
                    if p is None:
                        continue
                    elif (
                        p.computation_type == SEND_F_RECV_B
                        and p.stage_index + 1 == action.other_stage_index
                        and p.other_stage_index + 1 == action.stage_index
                        and p.microbatch_index == action.other_microbatch_index
                        and p.other_microbatch_index == action.microbatch_index
                    ):
                        return True
                return False

            else:
                raise ValueError(f"Unsupported action type {action}")

        while pipeline_order:
            is_progressing = False
            for rank in sorted(pipeline_order):
                if len(pipeline_order[rank]) == 0:
                    continue

                action = pipeline_order[rank][0]
                if _ready_to_schedule(action):
                    if action is not None:
                        schedule[rank].append(action)
                        self.run_action(action, microbatch_x, microbatch_y)
                    pipeline_order[rank].pop(0)
                    is_progressing = True
                else:
                    schedule[rank].append(None)

            for i in sorted(pipeline_order, reverse=True):
                if len(pipeline_order[i]) == 0:
                    del pipeline_order[i]

            if not is_progressing:
                logger.error(
                    "WIP comms schedule:\n%s", _format_pipeline_order(schedule)
                )
                for rank in pipeline_order:
                    print(f"{rank=} next action={pipeline_order[rank][0]}")
                raise ValueError("Schedule is not progressing")
        return schedule
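    # Readiness in one example: an F(stage=1, mb=0) action may run only once
    # its rank's history already contains F(stage=0, mb=0), RECV_F(stage=1,
    # mb=0), or a fused SEND_B_RECV_F whose receive half targets stage 1,
    # microbatch 0. The outer loop keeps sweeping ranks until every action
    # queue drains; a full sweep with no progress means the generated schedule
    # has a dependency cycle, which is reported rather than deadlocking.
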
    def run_action(
        self,
        action,
        microbatch_x,
        microbatch_y,
    ):
        logger.info(f"running --------> {action=}")
        comp_type = action.computation_type
        mb_index: int = (
            action.microbatch_index if action.microbatch_index is not None else -1
        )
        stage_idx = action.stage_index
        model_chunk_id = stage_idx
        microbatch_id = mb_index

        pipeline_parallel_rank = self.stage_to_rank(stage_idx)

        num_model_chunks = self.num_model_chunks

        is_last_stage = stage_idx == num_model_chunks - 1
        is_last_microbatch = microbatch_id == self.num_microbatches - 1
        try:
            with torch.profiler.record_function(
                f"r{pipeline_parallel_rank}/{model_chunk_id}_{comp_type}_{microbatch_id}"
            ):
                match str(comp_type):
                    case "SEND_F":
                        output_tensor = (
                            self.output_tensors[(model_chunk_id, microbatch_id)]
                            .clone()
                            .detach()
                        )
                        other_rank = self.stage_to_rank(stage_idx + 1)
                        borrow_output_tensor, borrow = self.p2p_stream.borrow(
                            output_tensor
                        )
                        self.fwd_send_handles[(model_chunk_id + 1, microbatch_id)] = (
                            borrow
                        )
                        with self.p2p_stream.activate():
                            self.input_tensors[(model_chunk_id + 1, microbatch_id)] = (
                                borrow_output_tensor.to_mesh(self.meshes[other_rank])
                            )
                    case "SEND_B":
                        other_rank = self.stage_to_rank(stage_idx - 1)
                        input_tensor_grad = self.input_tensor_grads[
                            (model_chunk_id, microbatch_id)
                        ]
                        borrow_input_tensor_grad, borrow = self.p2p_stream.borrow(
                            input_tensor_grad
                        )
                        self.bwd_send_handles[(model_chunk_id - 1, microbatch_id)] = (
                            borrow
                        )
                        with self.p2p_stream.activate():
                            self.output_tensor_grads[
                                (model_chunk_id - 1, microbatch_id)
                            ] = borrow_input_tensor_grad.to_mesh(
                                self.meshes[other_rank]
                            )
                        if model_chunk_id > 0:
                            (
                                input_tensor,
                                borrow,
                            ) = self.input_tensors_borrowed[
                                (model_chunk_id, microbatch_id)
                            ]
                            borrow.drop()
                    case "RECV_F":
                        assert (model_chunk_id, microbatch_id) in self.input_tensors
                        borrow = self.fwd_send_handles[(model_chunk_id, microbatch_id)]
                        borrow.drop()
                    case "RECV_B":
                        assert (
                            model_chunk_id,
                            microbatch_id,
                        ) in self.output_tensor_grads
                        borrow = self.bwd_send_handles[(model_chunk_id, microbatch_id)]
                        borrow.drop()
                    case "F":
                        with self.meshes[stage_idx].activate():
                            stage = self.stages[model_chunk_id][0]
                            input_tensor = self.input_tensors[
                                (model_chunk_id, microbatch_id)
                            ]
                            if model_chunk_id > 0:
                                input_tensor_borrowed, borrow = (
                                    self.compute_stream.borrow(input_tensor)
                                )
                                self.input_tensors_borrowed[
                                    (model_chunk_id, microbatch_id)
                                ] = (
                                    input_tensor_borrowed,
                                    borrow,
                                )
                            else:
                                input_tensor_borrowed = input_tensor
                            if isinstance(stage, (OpaqueRef, OpaqueModule)):
                                fwd_func = run_forward_udf
                            else:
                                fwd_func = run_forward_impl

                            output_tensor = fwd_func(
                                stage._object
                                if isinstance(stage, OpaqueModule)
                                else stage,
                                input_tensor_borrowed,
                                model_chunk_id=model_chunk_id,
                                microbatch_id=microbatch_id,
                            )
                            self.output_tensors[(model_chunk_id, microbatch_id)] = (
                                output_tensor
                            )
                    case "BW":
                        with self.meshes[stage_idx].activate():
                            stage = self.stages[model_chunk_id][0]
                            if model_chunk_id > 0:
                                (
                                    input_tensor,
                                    borrow,
                                ) = self.input_tensors_borrowed[
                                    (model_chunk_id, microbatch_id)
                                ]
                            else:
                                input_tensor = self.input_tensors[
                                    (model_chunk_id, microbatch_id)
                                ]
                                borrow = None

                            output_tensor = self.output_tensors[
                                (model_chunk_id, microbatch_id)
                            ]
                            output_tensor_grad = self.output_tensor_grads[
                                (model_chunk_id, microbatch_id)
                            ]
                            if output_tensor_grad is not None:
                                borrow_output_tensor_grad, output_tensor_grad_borrow = (
                                    self.compute_stream.borrow(output_tensor_grad)
                                )
                            else:
                                borrow_output_tensor_grad = None
                            if isinstance(self.loss_list, OpaqueRef):
                                bwd_func = run_backward_udf
                            else:
                                bwd_func = run_backward_impl
                            input_tensor_grad = bwd_func(
                                input_tensor=input_tensor,
                                output_tensor=output_tensor,
                                output_tensor_grad=borrow_output_tensor_grad,
                                y=microbatch_y[microbatch_id]
                                if is_last_stage
                                else None,
                                loss_layer=self.loss_layer if is_last_stage else None,
                                loss_list=self.loss_list if is_last_stage else None,
                                model_chunk_id=model_chunk_id,
                                microbatch_id=microbatch_id,
                                num_microbatches=self.num_microbatches,
                                is_last_stage=is_last_stage,
                                is_last_microbatch=is_last_microbatch,
                            )
                            self.input_tensor_grads[
                                (model_chunk_id, microbatch_id)
                            ] = input_tensor_grad
                            if output_tensor_grad is not None:
                                output_tensor_grad_borrow.drop()
                    case _:
                        raise ValueError(f"{action=} is unknown or unsupported")

        except Exception:
            logger.exception(
                "PipelineParallelism caught an exception while running action %s",
                action,
            )

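# Borrow/drop pattern used above (as understood from monarch streams): a
# tensor owned by one stream must be borrowed before another stream may use
# it, and the returned handle keeps the borrow alive until drop() releases it.
# SEND_* actions borrow onto the p2p stream before `to_mesh`, and the matching
# RECV_* action on the destination drops the handle once the transfer is done,
# so compute and communication can overlap without aliasing hazards.
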
def add_sys_path_impl(new_directory):
    if new_directory not in sys.path:
        sys.path.append(new_directory)

def build_module_chunk(module_name_or_path, *args, **kwargs):
    """
    Builds a module chunk for pipeline parallelism.

    Args:
        module_name_or_path (str): Fully qualified "package.module.ClassName"
            path of the nn.Module class to instantiate.
        *args: Positional arguments forwarded to the class constructor.
        **kwargs: Keyword arguments forwarded to the class constructor.

    Returns:
        OpaqueRef: An opaque reference to the instantiated module chunk.
    """
    module_path, class_name = module_name_or_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    module_class = getattr(module, class_name)

    model_chunk = module_class(*args, **kwargs)
    model_chunk.train()
    model_chunk.to("cuda")
    return OpaqueRef(model_chunk)

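# Usage sketch (illustrative; "mypkg.models.MLPStage" is a hypothetical class
# path, and these builders would normally be invoked through `remote` so they
# run on the workers):
#
#     stage_ref = build_module_chunk("mypkg.models.MLPStage", 1024, 1024)
#     opt_ref = build_optimizer_chunk(stage_ref, lr=1e-3)
#     loss_list_ref = build_loss_list()
#     loss_layer_ref = build_pp_loss_layer()
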
def build_loss_list():
    loss_list = []
    return OpaqueRef(loss_list)


def build_pp_loss_layer():
    loss = nn.MSELoss()
    return OpaqueRef(loss)


def build_optimizer_chunk(model_chunk, lr):
    """
    Builds an optimizer chunk for pipeline parallelism.

    Args:
        model_chunk (OpaqueRef): An opaque reference to the module chunk.
        lr (float): The learning rate.

    Returns:
        OpaqueRef: An opaque reference to the optimizer chunk.
    """
    optimizer_chunk = optim.SGD(model_chunk.value.parameters(), lr=lr)
    return OpaqueRef(optimizer_chunk)


def optimizer_zero_grad(optimizer_chunk):
    """
    Zeros the gradients of the optimizer chunk.

    Args:
        optimizer_chunk (OpaqueRef): An opaque reference to the optimizer chunk.
    """
    optimizer_chunk.value.zero_grad()


def optimizer_step(optimizer_chunk):
    """
    Performs a step of the optimizer chunk.

    Args:
        optimizer_chunk (OpaqueRef): An opaque reference to the optimizer chunk.
    """
    optimizer_chunk.value.step()