torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/parallel/pipelining/runtime.py
@@ -0,0 +1,847 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import copy
import importlib
import sys
from itertools import chain
from logging import getLogger
from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from monarch import fetch_shard, no_mesh, OpaqueRef, remote, Stream, Tensor
from monarch.common.device_mesh import DeviceMesh
from monarch.opaque_module import OpaqueModule

from .schedule_ir import (
    _Action,
    _format_pipeline_order,
    B,
    BW,
    F,
    RECV_B,
    RECV_F,
    SEND_B,
    SEND_B_RECV_F,
    SEND_F,
    SEND_F_RECV_B,
    W,
)
from .scheduler import generate_schedule

logger = getLogger()


run_forward_udf = remote(
    "monarch.parallel.pipelining.runtime.run_forward_impl",
    propagate=lambda stage, input_tensor, model_chunk_id, microbatch_id: input_tensor,
)

def run_forward_impl(
    stage: nn.Module | OpaqueRef,
    input_tensor: torch.Tensor,
    model_chunk_id: int,
    microbatch_id: int,
) -> torch.Tensor:
    """
    Run the forward function for one model chunk.

    Args:
        stage: The current stage of the model.
        input_tensor: The input tensor for the forward pass.
        model_chunk_id: Identifier for the model chunk.
        microbatch_id: Identifier for the microbatch.

    Returns:
        The output tensor after the forward pass.
    """
    if isinstance(stage, OpaqueRef):
        worker_stage = stage.value
    else:
        assert isinstance(stage, nn.Module)
        worker_stage = stage
    input_tensor.requires_grad_(True)
    with torch.enable_grad():
        output = worker_stage(
            input_tensor=input_tensor,
        )
    return output

def _run_backward_udf(
    input_tensor: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
    y: torch.Tensor,
    loss_layer: OpaqueRef,
    loss_list: OpaqueRef,
    model_chunk_id: int,
    microbatch_id: int,
    num_microbatches: int,
    is_last_stage: bool,
    is_last_microbatch: bool,
) -> Optional[torch.Tensor]:
    return input_tensor


run_backward_udf = remote(
    "monarch.parallel.pipelining.runtime.run_backward_impl", propagate=_run_backward_udf
)


def run_backward_impl(
    input_tensor: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_grad: Optional[torch.Tensor],
    y: torch.Tensor,
    loss_layer: nn.Module | OpaqueRef,
    loss_list: List[torch.Tensor] | OpaqueRef,
    model_chunk_id: int,
    microbatch_id: int,
    num_microbatches: int,
    is_last_stage: bool,
    is_last_microbatch: bool,
) -> Optional[torch.Tensor]:
    """
    Run the backward function for one model chunk.

    Args:
        input_tensor: The input tensor for the backward pass.
        output_tensor: The output tensor from the forward pass.
        output_tensor_grad: The gradient of the output tensor.
        y: The target tensor.
        loss_layer: The loss layer used to compute the loss.
        loss_list: A list to store the computed loss values.
        model_chunk_id: Identifier for the model chunk.
        microbatch_id: Identifier for the microbatch.
        num_microbatches: Total number of microbatches.
        is_last_stage: Flag indicating if this is the last stage.
        is_last_microbatch: Flag indicating if this is the last microbatch.

    Returns:
        The gradient of the input tensor if it requires gradient, otherwise None.
    """
    input_tensor.requires_grad_(True)
    if is_last_stage:
        if isinstance(loss_layer, OpaqueRef):
            worker_loss_layer = loss_layer.value
        else:
            worker_loss_layer = loss_layer
        with torch.enable_grad():
            loss = worker_loss_layer(output_tensor, y).mean() / num_microbatches
        worker_loss_list = (
            loss_list.value if isinstance(loss_list, OpaqueRef) else loss_list
        )
        worker_loss_list.append(loss)
        if is_last_microbatch:
            all_loss = torch.stack(worker_loss_list).sum()
            worker_loss_list.clear()
            worker_loss_list.append(all_loss)

        output = loss

    if input_tensor is not None and input_tensor.requires_grad:
        input_tensor.retain_grad()
    if output_tensor_grad is None:
        assert is_last_stage
        output.backward(retain_graph=True)
    else:
        torch.autograd.backward(
            output_tensor, grad_tensors=output_tensor_grad, retain_graph=True
        )
    if input_tensor is not None and input_tensor.requires_grad:
        return input_tensor.grad
    return None

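
# Illustrative sketch (editorial addition, not part of the package): the two
# impl functions above also accept a plain nn.Module and a plain Python list,
# so they can be exercised locally without Monarch meshes. The toy stage below
# is hypothetical; note that run_forward_impl invokes the stage with the
# keyword argument `input_tensor`, so the stage's forward must accept that name.
def _local_forward_backward_sketch() -> Optional[torch.Tensor]:
    class _ToyStage(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
            return self.linear(input_tensor)

    stage = _ToyStage()
    x = torch.randn(2, 4)
    y = torch.randn(2, 4)
    loss_list: List[torch.Tensor] = []

    # Forward: the runtime stashes the (input, output) pair per microbatch.
    out = run_forward_impl(stage, x, model_chunk_id=0, microbatch_id=0)

    # Backward on the last (and only) stage: output_tensor_grad is None, so
    # the loss is computed here and backpropagated through the stashed output.
    grad = run_backward_impl(
        input_tensor=x,
        output_tensor=out,
        output_tensor_grad=None,
        y=y,
        loss_layer=nn.MSELoss(),
        loss_list=loss_list,
        model_chunk_id=0,
        microbatch_id=0,
        num_microbatches=1,
        is_last_stage=True,
        is_last_microbatch=True,
    )
    return grad  # d(loss)/d(x), as handed back to the pipeline runtime
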
# Get the parameter with the given name from a module.
get_parameter_udf = remote(
    "monarch.parallel.pipelining.runtime.get_parameter",
    propagate=lambda module, param_name, param_shape: torch.randn(param_shape),
)


def get_parameter(
    module_ref: nn.Module | OpaqueRef,
    param_name: str,
    param_shape: tuple,
):
    """
    Retrieves a parameter from a PyTorch module.

    Args:
        module_ref (nn.Module | OpaqueRef): The module (or an opaque reference
            to one) to retrieve the parameter from.
        param_name (str): The name of the parameter to retrieve.
        param_shape (tuple): The expected shape of the parameter (used by
            get_parameter_udf for shape propagation).

    Returns:
        torch.Tensor: The retrieved parameter as a tensor.

    Raises:
        AttributeError: If the parameter does not exist in the module.
    """
    if isinstance(module_ref, OpaqueRef):
        module = module_ref.value
    else:
        module = module_ref
    for name, param in module.named_parameters():
        if name == param_name:
            return param
    raise AttributeError(
        f"Module '{module.__class__.__name__}' has no attribute '{param_name}'"
    )

# Retrieves the loss for the batch.
get_loss_udf = remote(
    "monarch.parallel.pipelining.runtime.get_loss_impl",
    propagate=lambda loss_list: torch.tensor(0.0),
)


def get_loss_impl(loss_list):
    """
    Get the loss for the batch.

    Args:
        loss_list: A list containing loss values.

    Returns:
        The first loss value from the list.
    """
    if isinstance(loss_list, OpaqueRef):
        worker_loss_list = loss_list.value
    else:
        worker_loss_list = loss_list
    loss = worker_loss_list[0]
    worker_loss_list.clear()
    return loss

class PipelineParallelism:
    """
    Utility class for generating schedule actions based on a pipeline parallelism schedule.
    This class is not a core Monarch primitive, but rather a helper utility to simplify the
    process of creating and executing pipeline parallelism schedules.
    It reuses a similar abstraction to the PyTorch pipelining API
    (https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/distributed/pipelining/_IR.py#L1200).

    This class handles the following functionality of pipeline parallelism:
    1. Initialization: Takes a list of modules as model pipeline stages and a
       list of meshes as devices to execute pipeline stages. Initializes the
       model stages on the meshes in the initialize() function.
    2. Pipeline IR Schedule Generation: Generates a pipeline parallelism
       schedule according to the user-selected algorithm in the
       generate_pipeline_ir_schedule() function.
    3. Schedule Dispatch: Dispatches actions from the pipeline parallelism
       schedule to all stages in the dispatch_pp_schedule() function.
    4. Action Execution: Executes individual actions on a pipeline stage in
       the run_action() function.
    """

    def __init__(
        self,
        meshes: List[DeviceMesh],
        stages: List[List[nn.Module | OpaqueRef | OpaqueModule]],
        compute_stream: Stream,
        p2p_stream: Stream,
        schedule: str = "dora-dfs",
        batch_size: int = 4,
        microbatch_size: int = 1,
        loss_fn: Optional[nn.Module] = None,
        loss_list=None,
    ):
        self.stages = stages
        self.meshes = meshes
        self.schedule = schedule
        self.batch_size = batch_size
        self.microbatch_size = microbatch_size
        self.compute_stream = compute_stream
        self.p2p_stream = p2p_stream or Stream("p2p_stream")

        # TODO(dongli): clean up buffers eagerly to save memory.
        self.input_tensors = {}
        self.output_tensors = {}
        self.output_tensor_grads = {}
        self.input_tensor_grads = {}
        self.fwd_send_handles = {}
        self.bwd_send_handles = {}
        self.input_tensors_borrowed = {}

        self.num_microbatches = self.batch_size // self.microbatch_size
        self.num_model_chunks = len(self.stages)
        self.pipeline_parallel_size = len(self.meshes)
        self.loss_list = loss_list

        for i in range(self.num_microbatches):
            self.output_tensor_grads[(self.num_model_chunks - 1, i)] = None

        with self.meshes[-1].activate():
            self.loss_layer = loss_fn

        self.all_rank_actions, self.stage_to_rank_map = (
            self.generate_pipeline_ir_schedule(
                schedule_name=self.schedule,
                total_num_model_chunks=len(self.stages),
                pipeline_parallel_size=len(self.meshes),
                batch_size=self.batch_size,
                microbatch_size=self.microbatch_size,
            )
        )
        logger.info(
            f"PipelineParallelism: pp_ir_schedule:\n{_format_pipeline_order(self.all_rank_actions)} \n{self.stage_to_rank_map=}"
        )

    def stage_to_rank(self, stage_idx):
        return self.stage_to_rank_map[stage_idx]

    def initialize(self):
        pp_stages = self.stages
        assert len(pp_stages) == len(self.meshes)
        for stage_idx, stage in enumerate(pp_stages):
            for module in stage:
                state_dict = module.state_dict()
                for k, v in state_dict.items():
                    if isinstance(v, Tensor):
                        state_dict[k] = v.to_mesh(self.meshes[stage_idx])
                module.load_state_dict(state_dict, assign=True)

    def copy_params_to_new_model(
        self,
        ref_model: List[nn.Module],
    ):
        pp_stages = self.stages
        assert len(pp_stages) == len(self.meshes)
        for stage_idx, stage in enumerate(pp_stages):
            assert len(stage) == 1
            module = stage[0]
            ref_module = ref_model[stage_idx]
            ref_model_state_dict = ref_module.state_dict()

            src_params = {}
            ref_params_shape = {}
            with self.meshes[0].activate():
                for ref_name, ref_param in ref_model_state_dict.items():
                    ref_params_shape[ref_name] = ref_param.shape
            with self.meshes[stage_idx].activate():
                for ref_name, _ in ref_model_state_dict.items():
                    ref_param_shape = ref_params_shape[ref_name]
                    if isinstance(module, OpaqueRef):
                        param = get_parameter_udf(module, ref_name, ref_param_shape)
                    elif isinstance(module, OpaqueModule):
                        # TODO: implement named_parameters() for OpaqueModule
                        param = get_parameter_udf(
                            module._object, ref_name, ref_param_shape
                        )
                    elif isinstance(module, nn.Module):
                        param = get_parameter(module, ref_name, ref_param_shape)
                    else:
                        raise ValueError(f"Unknown module type: {module}")

                    param_local = fetch_shard(param).result()
                    with no_mesh.activate():
                        src_params[ref_name] = param_local.detach().cpu().numpy()
            for name, _ in ref_model_state_dict.items():
                param_value = src_params[name]
                with self.meshes[0].activate():
                    new_param = torch.tensor(param_value)
                ref_model_state_dict[name] = new_param

            ref_module.load_state_dict(ref_model_state_dict, assign=True)

    def configure_optimizers(self, config, config_fn):
        optimizers = []

        for stage in self.stages:
            params = list(chain(*[list(m.parameters()) for m in stage]))
            optimizers.append(
                config_fn(
                    config.weight_decay,
                    config.learning_rate,
                    (config.beta1, config.beta2),
                    config.device_type,
                    config.optimizer,
                    params,
                )
            )

        return optimizers

    def generate_pipeline_ir_schedule(
        self,
        schedule_name,
        total_num_model_chunks,
        pipeline_parallel_size,
        batch_size,
        microbatch_size,
    ):
        assert (
            batch_size % microbatch_size == 0
        ), "Batch size should be divisible by microbatch size."
        num_microbatches = batch_size // microbatch_size
        num_round = max(num_microbatches // pipeline_parallel_size, 1)
        assert (
            num_microbatches % num_round == 0
        ), "Number of microbatches should be divisible by number of pipeline rounds."
        num_microbatch_per_round = num_microbatches // num_round

        num_model_chunks = total_num_model_chunks // pipeline_parallel_size
        total_num_microbatches = num_microbatches * num_model_chunks
        zero_bubble = True
        all_rank_actions, stage_to_rank = generate_schedule(
            schedule_name,
            num_model_chunks,
            pipeline_parallel_size,
            num_round,
            num_microbatch_per_round,
            zero_bubble,
            total_num_microbatches,
            num_microbatches,
        )
        return all_rank_actions, stage_to_rank

    def split_inputs_outputs(self, x, y):
        microbatch_x = x.split(self.microbatch_size, dim=0)
        for i, _x in enumerate(microbatch_x):
            _x = _x.to_mesh(self.meshes[0])
            self.input_tensors[(0, i)] = _x

        y = y.to_mesh(self.meshes[-1])
        with self.meshes[-1].activate():
            microbatch_y = y.split(self.microbatch_size, dim=0)
        return microbatch_x, microbatch_y

    def run(self, x, y):
        self.loss = None
        microbatch_x, microbatch_y = self.split_inputs_outputs(x, y)
        self.dispatch_pp_schedule(
            pipeline_order=self.all_rank_actions,
            stage_to_rank=self.stage_to_rank,
            num_stages=len(self.stages),
            microbatch_x=microbatch_x,
            microbatch_y=microbatch_y,
        )
        self.loss = (
            get_loss_udf(self.loss_list)
            if isinstance(self.loss_list, OpaqueRef)
            else get_loss_impl(self.loss_list)
        )
        return self.loss

    def dispatch_pp_schedule(
        self,
        pipeline_order,
        stage_to_rank: Callable[[int], int],
        num_stages: int,
        microbatch_x,
        microbatch_y,
    ):
        pipeline_order = {
            rank: [a for a in pipeline_order[rank] if a is not None]
            for rank in sorted(pipeline_order)
        }
        schedule: Dict[int, List[_Action | None]] = {
            rank: [] for rank in sorted(pipeline_order)
        }

        def _prev_ops(stage_idx):
            rank = stage_to_rank(stage_idx)
            ops = copy.deepcopy(schedule[rank])
            return ops

        def _ready_to_schedule(action: Optional[_Action]) -> bool:
            if action is None:
                return True
            stage_idx = action.stage_index
            if action.computation_type == F:
                if action.stage_index == 0:
                    return True
                for p in _prev_ops(stage_idx):
                    if p is None:
                        continue
                    elif (
                        p.computation_type == F
                        and p.stage_index + 1 == action.stage_index
                        and p.microbatch_index == action.microbatch_index
                    ):
                        return True
                    elif (
                        p.computation_type == RECV_F
                        and p.stage_index == action.stage_index
                        and p.microbatch_index == action.microbatch_index
                    ):
                        return True
                    elif (
                        p.computation_type == SEND_B_RECV_F
                        and p.other_stage_index == action.stage_index
                        and p.other_microbatch_index == action.microbatch_index
                    ):
                        return True
                return False
            elif action.computation_type in (B, BW):
                if action.stage_index == num_stages - 1:
                    return True

                for p in _prev_ops(stage_idx):
                    if p is None:
                        continue
                    elif (
                        p.computation_type == RECV_B
                        and p.stage_index == action.stage_index
                        and p.microbatch_index == action.microbatch_index
                    ):
                        return True
                    elif (
                        p.computation_type == SEND_F_RECV_B
                        and p.other_stage_index == action.stage_index
                        and p.other_microbatch_index == action.microbatch_index
                    ):
                        return True
                    elif (
                        p.computation_type in (B, BW)
                        and p.stage_index - 1 == action.stage_index
                        and p.microbatch_index == action.microbatch_index
                    ):
                        return True
                return False
            elif action.computation_type == W:
                return True
            elif action.computation_type == SEND_F:
                expected_f = _Action(action.stage_index, F, action.microbatch_index)
                return expected_f in _prev_ops(stage_idx)
            elif action.computation_type == RECV_F:
                peer_stage_idx = stage_idx - 1
                expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index)
                return expected_send in _prev_ops(peer_stage_idx)
            elif action.computation_type == SEND_B:
                expected_b = _Action(action.stage_index, B, action.microbatch_index)
                expected_bw = _Action(action.stage_index, BW, action.microbatch_index)
                return expected_b in _prev_ops(stage_idx) or expected_bw in _prev_ops(
                    stage_idx
                )
            elif action.computation_type == RECV_B:
                peer_stage_idx = stage_idx + 1
                expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index)
                return expected_send in _prev_ops(peer_stage_idx)
            elif action.computation_type == SEND_F_RECV_B:
                peer_stage_idx = stage_idx + 1
                for p in _prev_ops(peer_stage_idx):
                    if p is None:
                        continue
                    elif (
                        p.computation_type == SEND_B_RECV_F
                        and action.other_stage_index is not None
                        and p.stage_index == action.other_stage_index + 1
                        and p.other_stage_index is not None
                        and p.other_stage_index == action.stage_index + 1
                        and p.microbatch_index == action.other_microbatch_index
                        and p.other_microbatch_index == action.microbatch_index
                    ):
                        return True
                return False
            elif action.computation_type == SEND_B_RECV_F:
                peer_stage_idx = action.stage_index - 1
                for p in _prev_ops(peer_stage_idx):
                    if p is None:
                        continue
                    elif (
                        p.computation_type == SEND_F_RECV_B
                        and p.stage_index + 1 == action.other_stage_index
                        and p.other_stage_index + 1 == action.stage_index
                        and p.microbatch_index == action.other_microbatch_index
                        and p.other_microbatch_index == action.microbatch_index
                    ):
                        return True
                return False
            else:
                raise ValueError(f"Unsupported action type {action}")

        while pipeline_order:
            is_progressing = False
            for rank in sorted(pipeline_order):
                if len(pipeline_order[rank]) == 0:
                    continue

                action = pipeline_order[rank][0]
                if _ready_to_schedule(action):
                    if action is not None:
                        schedule[rank].append(action)
                        self.run_action(action, microbatch_x, microbatch_y)
                    pipeline_order[rank].pop(0)
                    is_progressing = True
                else:
                    schedule[rank].append(None)

            for i in sorted(pipeline_order, reverse=True):
                if len(pipeline_order[i]) == 0:
                    del pipeline_order[i]

            if not is_progressing:
                logger.error(
                    "WIP comms schedule:\n%s", _format_pipeline_order(schedule)
                )
                for rank in pipeline_order:
                    print(f"{rank=} next action= {pipeline_order[rank][0]}")
                raise ValueError("Schedule is not progressing")
        return schedule

    def run_action(
        self,
        action,
        microbatch_x,
        microbatch_y,
    ):
        logger.info(f"running --------> {action=}")
        comp_type = action.computation_type
        mb_index: int = (
            action.microbatch_index if action.microbatch_index is not None else -1
        )
        stage_idx = action.stage_index
        model_chunk_id = stage_idx
        microbatch_id = mb_index

        pipeline_parallel_rank = self.stage_to_rank(stage_idx)

        num_model_chunks = self.num_model_chunks

        is_last_stage = stage_idx == num_model_chunks - 1
        is_last_microbatch = microbatch_id == self.num_microbatches - 1
        try:
            with torch.profiler.record_function(
                f"r{pipeline_parallel_rank}/{model_chunk_id}_{comp_type}_{microbatch_id}"
            ):
                match str(comp_type):
                    case "SEND_F":
                        output_tensor = (
                            self.output_tensors[(model_chunk_id, microbatch_id)]
                            .clone()
                            .detach()
                        )
                        other_rank = self.stage_to_rank(stage_idx + 1)
                        borrow_output_tensor, borrow = self.p2p_stream.borrow(
                            output_tensor
                        )
                        self.fwd_send_handles[(model_chunk_id + 1, microbatch_id)] = (
                            borrow
                        )
                        with self.p2p_stream.activate():
                            self.input_tensors[(model_chunk_id + 1, microbatch_id)] = (
                                borrow_output_tensor.to_mesh(self.meshes[other_rank])
                            )
                    case "SEND_B":
                        other_rank = self.stage_to_rank(stage_idx - 1)
                        input_tensor_grad = self.input_tensor_grads[
                            (model_chunk_id, microbatch_id)
                        ]
                        borrow_input_tensor_grad, borrow = self.p2p_stream.borrow(
                            input_tensor_grad
                        )
                        self.bwd_send_handles[(model_chunk_id - 1, microbatch_id)] = (
                            borrow
                        )
                        with self.p2p_stream.activate():
                            self.output_tensor_grads[
                                (model_chunk_id - 1, microbatch_id)
                            ] = borrow_input_tensor_grad.to_mesh(
                                self.meshes[other_rank]
                            )
                        if model_chunk_id > 0:
                            (
                                input_tensor,
                                borrow,
                            ) = self.input_tensors_borrowed[
                                (model_chunk_id, microbatch_id)
                            ]
                            borrow.drop()
                    case "RECV_F":
                        assert (model_chunk_id, microbatch_id) in self.input_tensors
                        borrow = self.fwd_send_handles[(model_chunk_id, microbatch_id)]
                        borrow.drop()
                    case "RECV_B":
                        assert (
                            model_chunk_id,
                            microbatch_id,
                        ) in self.output_tensor_grads
                        borrow = self.bwd_send_handles[(model_chunk_id, microbatch_id)]
                        borrow.drop()
                    case "F":
                        with self.meshes[stage_idx].activate():
                            stage = self.stages[model_chunk_id][0]
                            input_tensor = self.input_tensors[
                                (model_chunk_id, microbatch_id)
                            ]
                            if model_chunk_id > 0:
                                input_tensor_borrowed, borrow = (
                                    self.compute_stream.borrow(input_tensor)
                                )
                                self.input_tensors_borrowed[
                                    (model_chunk_id, microbatch_id)
                                ] = (
                                    input_tensor_borrowed,
                                    borrow,
                                )
                            else:
                                input_tensor_borrowed = input_tensor
                            if isinstance(stage, (OpaqueRef, OpaqueModule)):
                                fwd_func = run_forward_udf
                            else:
                                fwd_func = run_forward_impl

                            output_tensor = fwd_func(
                                stage._object
                                if isinstance(stage, OpaqueModule)
                                else stage,
                                input_tensor_borrowed,
                                model_chunk_id=model_chunk_id,
                                microbatch_id=microbatch_id,
                            )
                            self.output_tensors[(model_chunk_id, microbatch_id)] = (
                                output_tensor
                            )
                    case "BW":
                        with self.meshes[stage_idx].activate():
                            stage = self.stages[model_chunk_id][0]
                            if model_chunk_id > 0:
                                (
                                    input_tensor,
                                    borrow,
                                ) = self.input_tensors_borrowed[
                                    (model_chunk_id, microbatch_id)
                                ]
                            else:
                                input_tensor = self.input_tensors[
                                    (model_chunk_id, microbatch_id)
                                ]
                                borrow = None

                            output_tensor = self.output_tensors[
                                (model_chunk_id, microbatch_id)
                            ]
                            output_tensor_grad = self.output_tensor_grads[
                                (model_chunk_id, microbatch_id)
                            ]
                            if output_tensor_grad is not None:
                                borrow_output_tensor_grad, output_tensor_grad_borrow = (
                                    self.compute_stream.borrow(output_tensor_grad)
                                )
                            else:
                                borrow_output_tensor_grad = None
                            if isinstance(self.loss_list, OpaqueRef):
                                bwd_func = run_backward_udf
                            else:
                                bwd_func = run_backward_impl
                            input_tensor_grad = bwd_func(
                                input_tensor=input_tensor,
                                output_tensor=output_tensor,
                                output_tensor_grad=borrow_output_tensor_grad,
                                y=microbatch_y[microbatch_id]
                                if is_last_stage
                                else None,
                                loss_layer=self.loss_layer if is_last_stage else None,
                                loss_list=self.loss_list if is_last_stage else None,
                                model_chunk_id=model_chunk_id,
                                microbatch_id=microbatch_id,
                                num_microbatches=self.num_microbatches,
                                is_last_stage=is_last_stage,
                                is_last_microbatch=is_last_microbatch,
                            )
                            self.input_tensor_grads[(model_chunk_id, microbatch_id)] = (
                                input_tensor_grad
                            )
                            if output_tensor_grad is not None:
                                output_tensor_grad_borrow.drop()
                    case _:
                        raise ValueError(f"{action=} is unknown or unsupported")

        except Exception as e:
            logger.exception(
                "PipelineParallelism caught an exception while running action %s: %s",
                action,
                e,
            )

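
# Illustrative sketch (editorial addition, not part of the package): a
# stripped-down version of the dependency-driven loop in
# PipelineParallelism.dispatch_pp_schedule. Actions here are plain
# (stage, kind, microbatch) tuples, and an action runs once the op it waits
# on has been scheduled; the real code encodes these dependencies per
# computation type in _ready_to_schedule.
def _dispatch_loop_sketch() -> list:
    pipeline_order = {
        0: [(0, "F", 0), (0, "SEND_F", 0)],
        1: [(1, "RECV_F", 0), (1, "F", 0)],
    }
    done = set()

    def _ready(action) -> bool:
        stage, kind, mb = action
        if kind == "F":
            # Stage 0 has no upstream; later stages need their input received.
            return stage == 0 or (stage, "RECV_F", mb) in done
        if kind == "SEND_F":
            return (stage, "F", mb) in done
        if kind == "RECV_F":
            return (stage - 1, "SEND_F", mb) in done
        raise ValueError(f"Unsupported action type {action}")

    executed = []
    while any(pipeline_order.values()):
        is_progressing = False
        for rank in sorted(pipeline_order):
            if pipeline_order[rank] and _ready(pipeline_order[rank][0]):
                action = pipeline_order[rank].pop(0)
                done.add(action)
                executed.append(action)
                is_progressing = True
        if not is_progressing:
            raise ValueError("Schedule is not progressing")
    # [(0, 'F', 0), (0, 'SEND_F', 0), (1, 'RECV_F', 0), (1, 'F', 0)]
    return executed
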
def add_sys_path_impl(new_directory):
    if new_directory not in sys.path:
        sys.path.append(new_directory)


def build_module_chunk(module_name_or_path, *args, **kwargs):
    """
    Builds a module chunk for pipeline parallelism.

    Args:
        module_name_or_path (str): The fully qualified name of the module
            class, e.g. "my_package.my_module.MyModel".
        *args: Positional arguments forwarded to the class constructor.
        **kwargs: Keyword arguments forwarded to the class constructor.

    Returns:
        OpaqueRef: A reference to the constructed module chunk, set to
        training mode and moved to CUDA.
    """
    module_path, class_name = module_name_or_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    module_class = getattr(module, class_name)

    model_chunk = module_class(*args, **kwargs)
    model_chunk.train()
    model_chunk.to("cuda")
    return OpaqueRef(model_chunk)

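
# Illustrative sketch (editorial addition, not part of the package): the
# rsplit/importlib pattern used by build_module_chunk above, pointed at a
# standard-library class so it runs anywhere (no CUDA, no OpaqueRef).
def _dynamic_import_sketch():
    module_name_or_path = "collections.OrderedDict"
    module_path, class_name = module_name_or_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    module_class = getattr(module, class_name)
    return module_class(a=1)  # OrderedDict([('a', 1)])
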

def build_loss_list():
    loss_list = []
    return OpaqueRef(loss_list)


def build_pp_loss_layer():
    loss = nn.MSELoss()
    return OpaqueRef(loss)


def build_optimizer_chunk(model_chunk, lr):
    """
    Builds an optimizer chunk for pipeline parallelism.

    Args:
        model_chunk (OpaqueRef): A reference to the module chunk.
        lr (float): The learning rate.

    Returns:
        OpaqueRef: A reference to the SGD optimizer for the chunk.
    """
    optimizer_chunk = optim.SGD(model_chunk.value.parameters(), lr=lr)
    return OpaqueRef(optimizer_chunk)


def optimizer_zero_grad(optimizer_chunk):
    """
    Zeros the gradients of the optimizer chunk.

    Args:
        optimizer_chunk (OpaqueRef): A reference to the optimizer chunk.
    """
    optimizer_chunk.value.zero_grad()


def optimizer_step(optimizer_chunk):
    """
    Performs a step of the optimizer chunk.

    Args:
        optimizer_chunk (OpaqueRef): A reference to the optimizer chunk.
    """
    optimizer_chunk.value.step()
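

# Illustrative sketch (editorial addition, not part of the package): the
# zero_grad -> backward -> step cycle that the helpers above drive through
# OpaqueRef handles on the workers, written here against a plain local module
# and optimizer so it runs without Monarch.
def _local_optimizer_cycle_sketch() -> float:
    model = nn.Linear(4, 1)
    opt = optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    opt.zero_grad()                    # cf. optimizer_zero_grad(...)
    loss = nn.MSELoss()(model(x), y)   # cf. build_pp_loss_layer's MSELoss
    loss.backward()
    opt.step()                         # cf. optimizer_step(...)
    return loss.item()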