d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/pipelining/infra/schedule/program/zerobubblev.py
@@ -0,0 +1,227 @@
+ from ..component.program import (
+     PipelineProgramBuilder,
+     ScheduleStyle,
+     add_communication_ops,
+     build_stage_to_host_rank_topology,
+ )
+ from ..component.runtime import (
+     ActionBase,
+     BackwardFullInputComputeAction,
+     BackwardWeightComputeAction,
+     ForwardComputeAction,
+ )
+
+
+ class ZeroBubbleVPipelineProgramBuilder(PipelineProgramBuilder):
+     """
+     Builder for the Zero Bubble V (ZBV) Pipeline Schedule.
+
+     This schedule targets V-shape topologies with exactly two stages per rank and
+     applies the Zero Bubble optimization: backward passes are split into
+     input-gradient and weight-gradient computations to improve pipeline throughput.
+
+     References:
+         https://arxiv.org/pdf/2401.10241, Section 6
+     """
+
+     def __init__(self):
+         """Constructs the ZBV builder."""
+
+     def compose(
+         self, num_microbatches: int, pp_size: int
+     ) -> dict[int, list[ActionBase]]:
+         num_stages = self.num_stages_per_rank * pp_size
+
+         # 1. Topology
+         # V-style: Rank 0 gets Stage 0 & Stage N-1. Rank 1 gets Stage 1 & Stage N-2...
+         stage_to_rank = build_stage_to_host_rank_topology(
+             pp_size=pp_size, num_stages=num_stages, style=ScheduleStyle.v
+         )
+
+         actions: dict[int, list[ActionBase]] = {}
+
+         for rank in range(pp_size):
+             actions[rank] = self._generate_rank_schedule(
+                 rank=rank,
+                 pp_size=pp_size,
+                 num_stages=num_stages,
+                 target_microbatches=num_microbatches,
+             )
+
+         # 2. Inject Communications
+         return add_communication_ops(
+             compute_actions=actions,
+             stage_to_rank=stage_to_rank,
+             num_stages=num_stages
+         )
+
+     def _generate_rank_schedule(  # noqa: C901
+         self,
+         rank: int,
+         pp_size: int,
+         num_stages: int,
+         target_microbatches: int,
+     ) -> list[ActionBase]:
+         # ZBV logic assumes the pipeline is fully saturated to define the loop bounds.
+         # We simulate enough steps to cover the topology startup, then filter
+         # down to the user's requested microbatches at the end.
+         simulated_n_micro = max(2 * pp_size - 1, target_microbatches)
+
+         rank_ops: list[ActionBase] = []
+
+         # -- Stage Identification (V-Shape) --
+         # s0: The "Forward-going" chunk (e.g., Stage 0 for Rank 0)
+         # s1: The "Backward-coming" chunk (e.g., Stage N-1 for Rank 0)
+         s0 = rank
+         s1 = num_stages - 1 - rank
+
+         # -- Counters --
+         # Track the next microbatch index for each operation type on each chunk.
+         # F: Forward, I: Backward Input, W: Backward Weight
+         f0_cnt = 0
+         b0_cnt = 0  # Input Grad Counter (Chunk 0)
+         w0_cnt = 0  # Weight Grad Counter (Chunk 0)
+
+         f1_cnt = 0
+         b1_cnt = 0  # Input Grad Counter (Chunk 1)
+         w1_cnt = 0  # Weight Grad Counter (Chunk 1)
+
+         # -- Helpers --
+
+         def emit_f(stage: int, idx: int):
+             rank_ops.append(ForwardComputeAction(stage_idx=stage, microbatch_idx=idx))
+
+         def emit_i_and_w(stage: int, idx: int):
+             rank_ops.append(
+                 BackwardFullInputComputeAction(
+                     stage_idx=stage, microbatch_idx=idx, full_backward=False
+                 )
+             )
+             rank_ops.append(
+                 BackwardWeightComputeAction(stage_idx=stage, microbatch_idx=idx)
+             )
+
+         def emit_i(stage: int, idx: int):
+             rank_ops.append(
+                 BackwardFullInputComputeAction(
+                     stage_idx=stage, microbatch_idx=idx, full_backward=False
+                 )
+             )
+
+         def emit_w(stage: int, idx: int):
+             rank_ops.append(
+                 BackwardWeightComputeAction(stage_idx=stage, microbatch_idx=idx)
+             )
+
+         # -- Phase 1: Warmup 1 (Chunk 0 Forwards) --
+         warmup_n1 = 2 * (pp_size - rank) - 1
+         for _ in range(warmup_n1):
+             emit_f(s0, f0_cnt)
+             f0_cnt += 1
+
+         # -- Phase 2: Warmup 2 (Interleave F1, F0) --
+         warmup_n2 = rank
+         for _ in range(warmup_n2):
+             emit_f(s1, f1_cnt)
+             f1_cnt += 1
+             emit_f(s0, f0_cnt)
+             f0_cnt += 1
+
+         # -- Phase 3: Warmup 3 (F1, then B1 I+W) --
+         warmup_n3 = pp_size - rank
+         for _ in range(warmup_n3):
+             emit_f(s1, f1_cnt)
+             f1_cnt += 1
+
+             emit_i_and_w(s1, b1_cnt)
+             b1_cnt += 1
+             w1_cnt += 1
+
+         # -- Phase 4: Stable State --
+         while f1_cnt < f0_cnt or f0_cnt < simulated_n_micro:
+             # Emit F0 if within bounds
+             if f0_cnt < simulated_n_micro:
+                 emit_f(s0, f0_cnt)
+                 f0_cnt += 1
+
+             # Emit B0 (I+W)
+             emit_i_and_w(s0, b0_cnt)
+             b0_cnt += 1
+             w0_cnt += 1
+
+             # Emit F1
+             emit_f(s1, f1_cnt)
+             f1_cnt += 1
+
+             # Emit B1 (I+W)
+             emit_i_and_w(s1, b1_cnt)
+             b1_cnt += 1
+             w1_cnt += 1
+
+         # -- Phase 5: Cooldown 1 (Splitting I and W) --
+         # In cooldown, the I and W streams diverge to fill bubbles.
+         cooldown_n1 = rank
+         for _ in range(cooldown_n1):
+             emit_i(s0, b0_cnt)
+             b0_cnt += 1
+
+             emit_i(s1, b1_cnt)
+             b1_cnt += 1
+
+         # -- Phase 6: Cooldown 2 (I0, then W0) --
+         cooldown_n2 = pp_size - rank
+         for _ in range(cooldown_n2):
+             # Input Grad Chunk 0
+             emit_i(s0, b0_cnt)
+             b0_cnt += 1
+
+             # Weight Grad Chunk 0 (delayed from previous steps)
+             emit_w(s0, w0_cnt)
+             w0_cnt += 1
+
+         # -- Phase 7: Flush Remaining Weights --
+
+         # Flush W1
+         while w1_cnt < b1_cnt:
+             emit_w(s1, w1_cnt)
+             w1_cnt += 1
+
+         # Flush W0
+         while w0_cnt < b0_cnt:
+             emit_w(s0, w0_cnt)
+             w0_cnt += 1
+
+         # -- Integrity Check --
+         if not (w0_cnt == b0_cnt == f0_cnt):
+             raise RuntimeError(
+                 f"ZBV Schedule Failed (Chunk 0): F={f0_cnt}, I={b0_cnt}, W={w0_cnt}"
+             )
+         if not (w1_cnt == b1_cnt == f1_cnt):
+             raise RuntimeError(
+                 f"ZBV Schedule Failed (Chunk 1): F={f1_cnt}, I={b1_cnt}, W={w1_cnt}"
+             )
+
+         # -- Post-Process: Filter to Target Microbatches --
+         # Remove any actions involving simulated microbatches beyond the user's request.
+         final_ops: list[ActionBase] = []
+         for action in rank_ops:
+             if isinstance(action, (ForwardComputeAction,
+                                    BackwardFullInputComputeAction,
+                                    BackwardWeightComputeAction)):
+                 if action.microbatch_idx < target_microbatches:
+                     final_ops.append(action)
+             else:
+                 final_ops.append(action)
+
+         return final_ops
+
+     @property
+     def num_stages_per_rank(self) -> int:
+         return 2
+
+     @property
+     def topology_style(self) -> ScheduleStyle:
+         return ScheduleStyle.v
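
For orientation, a minimal standalone sketch of the warmup/cooldown arithmetic used by _generate_rank_schedule above. The formulas are copied from the hunk; the pp_size and num_microbatches values are assumptions chosen purely for illustration, and the snippet is not part of the package.

    pp_size = 4                    # pipeline-parallel world size (assumed)
    num_microbatches = 8           # requested microbatches (assumed)
    simulated = max(2 * pp_size - 1, num_microbatches)   # loop bound from the hunk above

    for rank in range(pp_size):
        warmup_n1 = 2 * (pp_size - rank) - 1   # Phase 1: chunk-0 forwards
        warmup_n2 = rank                       # Phase 2: interleaved F1/F0 pairs
        warmup_n3 = pp_size - rank             # Phase 3: F1 followed by I+W on chunk 1
        cooldown_n1 = rank                     # Phase 5: input-grad-only steps on both chunks
        cooldown_n2 = pp_size - rank           # Phase 6: I0 followed by delayed W0
        print(
            f"rank {rank}: warmup=({warmup_n1}, {warmup_n2}, {warmup_n3}), "
            f"cooldown=({cooldown_n1}, {cooldown_n2}), simulated_microbatches={simulated}"
        )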
d9d/pipelining/infra/stage/__init__.py
@@ -0,0 +1,5 @@
+ from .stage import PipelineStage
+
+ __all__ = [
+     "PipelineStage"
+ ]
d9d/pipelining/infra/stage/communications.py
@@ -0,0 +1,274 @@
+ import dataclasses
+
+ import torch
+ import torch.distributed as dist
+
+
+ @dataclasses.dataclass(kw_only=True, slots=True)
+ class ReceiveStageInput:
+     """
+     Instruction to receive a specific tensor from a previous stage (or next stage during backward).
+
+     Attributes:
+         name: A unique identifier for the communication operation.
+         from_stage: The stage index sending the data.
+         buffer: The pre-allocated tensor buffer where data will be received.
+     """
+
+     name: str
+     from_stage: int
+     buffer: torch.Tensor
+
+
+ @dataclasses.dataclass
+ class StartStageInput:
+     """
+     Instruction indicating that the input for this stage does not come from communication
+     (e.g., this is the first stage receiving data loader inputs).
+     """
+
+
+ StageInput = ReceiveStageInput | StartStageInput
+
+
+ @dataclasses.dataclass(kw_only=True, slots=True)
+ class SendStageOutput:
+     """
+     Instruction to send a specific tensor to the next stage (or the previous stage during backward).
+
+     Attributes:
+         to_stage: The stage index receiving the data.
+     """
+
+     to_stage: int
+
+
+ @dataclasses.dataclass
+ class EndStageOutput:
+     """
+     Instruction indicating that the output of this stage is not sent anywhere
+     (e.g., this is the last stage computing loss).
+     """
+
+
+ StageOutput = SendStageOutput | EndStageOutput
+
+
+ class StageCommunicationHandler:
+     """
+     Manages Point-to-Point (P2P) communication descriptors for a specific data flow direction within a pipeline stage.
+
+     This class handles the creation of P2P operations (send/recv) across multiple microbatches,
+     managing buffers and mapping logical stage indices to physical ranks.
+     """
+
+     def __init__(
+         self,
+
+         name: str,
+         stage_index: int,
+         num_microbatches: int,
+
+         input_stage_index: int | None,
+         input_args: dict[str, torch.Tensor],
+
+         output_stage_index: int | None,
+         output_args: dict[str, torch.Tensor],
+
+         stage_idx_to_host_rank: dict[int, int],
+         group: dist.ProcessGroup
+     ):
+         """
+         Constructs a StageCommunicationHandler object.
+
+         Args:
+             name: Name prefix for this handler (e.g., 'fwd', 'bwd').
+             stage_index: The logical index of the current stage.
+             num_microbatches: Total number of microbatches ("chunks") to schedule.
+             input_stage_index: The logical index of the stage providing inputs, or None if inputs are local.
+             input_args: Metadata (shapes/dtypes) for input tensors.
+             output_stage_index: The logical index of the stage consuming outputs, or None if outputs are terminal.
+             output_args: Metadata (shapes/dtypes) for output tensors.
+             stage_idx_to_host_rank: Mapping from logical stage indices to physical world ranks.
+             group: The process group strictly for pipeline communication.
+         """
+
+         self._input_handlers = self._build_inputs(
+             name=name,
+             stage_index=stage_index,
+             num_microbatches=num_microbatches,
+             input_stage_index=input_stage_index,
+             input_args=input_args
+         )
+         self._output_handlers = self._build_outputs(
+             output_stage_index=output_stage_index,
+             output_args=output_args
+         )
+
+         self._stage_idx_to_host_rank = stage_idx_to_host_rank
+         self._group = group
+
+     @staticmethod
+     def _build_inputs(
+         name: str,
+         stage_index: int,
+         num_microbatches: int,
+         input_stage_index: int | None,
+         input_args: dict[str, torch.Tensor]
+     ) -> dict[int, dict[str, StageInput]]:
+         handlers: dict[int, dict[str, StageInput]] = {}
+
+         for chunk_id in range(num_microbatches):
+             handlers[chunk_id] = {}
+             for input_name, input_tensor_meta in input_args.items():
+                 if input_stage_index is None:
+                     handlers[chunk_id][input_name] = StartStageInput()
+                 else:
+                     handlers[chunk_id][input_name] = ReceiveStageInput(
+                         name=f"{name}_recv_from_{input_stage_index}_to_{stage_index}[{chunk_id}][{input_name}]",
+                         from_stage=input_stage_index,
+                         buffer=torch.empty(
+                             input_tensor_meta.size(),
+                             dtype=input_tensor_meta.dtype,
+                             layout=input_tensor_meta.layout,
+                             device="cuda"  # force device
+                         )
+                     )
+         return handlers
+
+     @staticmethod
+     def _build_outputs(
+         output_stage_index: int | None,
+         output_args: dict[str, torch.Tensor]
+     ) -> dict[str, StageOutput]:
+         handlers: dict[str, StageOutput] = {}
+
+         for output_name in output_args:
+             if output_stage_index is None:
+                 handlers[output_name] = EndStageOutput()
+             else:
+                 handlers[output_name] = SendStageOutput(
+                     to_stage=output_stage_index
+                 )
+         return handlers
+
+     def set_input_requires_grad_(self, requires_grad: bool):
+         """
+         Sets the `requires_grad` flag for all internal input buffers.
+
+         Typically used to enable gradient flow from backward stages to forward stages.
+
+         Args:
+             requires_grad: Whether the buffers should require gradients.
+         """
+
+         for inputs in self._input_handlers.values():
+             for info in inputs.values():
+                 if isinstance(info, ReceiveStageInput):
+                     info.buffer.requires_grad_(requires_grad)
+
+     def set_inputs_local(self, inputs: dict[str, torch.Tensor], microbatch_index: int):
+         """
+         Manually fills the input buffer for a specific microbatch with local data.
+
+         This is used when the stage is the first in the pipeline or receives data
+         from a dataloader rather than via network communication.
+
+         Args:
+             inputs: Dictionary of input tensors.
+             microbatch_index: The microbatch identifier.
+         """
+
+         for input_name, input_value in inputs.items():
+             handler = self._input_handlers[microbatch_index][input_name]
+             if not isinstance(handler, ReceiveStageInput):
+                 raise RuntimeError("Tried to set a buffer of a non-receive stage input")
+             prev_requires_grad = handler.buffer.requires_grad
+             handler.buffer = input_value.detach().requires_grad_(prev_requires_grad)
+
+     def get_inputs(self, microbatch_index: int) -> dict[str, torch.Tensor]:
+         """
+         Retrieves the input tensors for a specific microbatch from the internal buffers.
+
+         Args:
+             microbatch_index: The microbatch identifier.
+
+         Returns:
+             Dictionary mapping input names to tensors.
+         """
+         outputs: dict[str, torch.Tensor] = {}
+
+         for input_name, input_info in self._input_handlers[microbatch_index].items():
+             if not isinstance(input_info, ReceiveStageInput):
+                 raise RuntimeError("Tried to get a buffer of a non-receive stage input")
+             outputs[input_name] = input_info.buffer
+
+         return outputs
+
+     def create_receive_ops(self, microbatch_index: int) -> list[dist.P2POp]:
+         """
+         Generates the PyTorch P2P receive operations for a specific microbatch.
+
+         Args:
+             microbatch_index: The microbatch identifier.
+
+         Returns:
+             A list of `dist.P2POp` objects configured for `dist.irecv`.
+         """
+
+         ops = []
+
+         inputs = self._input_handlers[microbatch_index]
+         # sort ops by parameter names to ensure receive ops are ordered the same for send and recv
+         for _input_name, input_info in sorted(inputs.items(), key=lambda x: x[0]):
+             match input_info:
+                 case StartStageInput():
+                     pass
+                 case ReceiveStageInput():
+                     peer_rank = self._stage_idx_to_host_rank[input_info.from_stage]
+                     peer_global_rank = dist.get_global_rank(self._group, peer_rank)
+                     op = dist.P2POp(dist.irecv, input_info.buffer, peer_global_rank, self._group)
+                     ops.append(op)
+                 case _:
+                     raise ValueError()
+
+         return ops
+
+     def create_send_ops(self, send_contents: dict[str, torch.Tensor]) -> list[dist.P2POp]:
+         """
+         Generates the PyTorch P2P send operations for the provided tensors.
+
+         Args:
+             send_contents: Dictionary of tensors to send.
+
+         Returns:
+             A list of `dist.P2POp` objects configured for `dist.isend`.
+         """
+
+         ops = []
+
+         # sort ops by parameter names to ensure send ops are ordered the same for send and recv
+         for output_name, output_info in sorted(self._output_handlers.items(), key=lambda x: x[0]):
+             output_tensor = send_contents[output_name]
+
+             match output_info:
+                 case EndStageOutput():
+                     pass
+                 case SendStageOutput():
+                     peer_rank = self._stage_idx_to_host_rank[output_info.to_stage]
+                     peer_global_rank = dist.get_global_rank(self._group, peer_rank)
+                     op = dist.P2POp(dist.isend, output_tensor, peer_global_rank, self._group)
+                     ops.append(op)
+                 case _:
+                     raise ValueError()
+
+         return ops
+
+     def reset(self):
+         """Resets the internal state, specifically clearing gradients on input buffers."""
+
+         for inp_handlers in self._input_handlers.values():
+             for inp_handler in inp_handlers.values():
+                 if isinstance(inp_handler, ReceiveStageInput):
+                     inp_handler.buffer.grad = None
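
As a usage sketch (assumed, not taken from the package): the lists returned by create_receive_ops and create_send_ops contain ordinary torch.distributed.P2POp objects, so a schedule executor would typically batch them with dist.batch_isend_irecv and wait on the returned work handles before reading the receive buffers. The helper below is hypothetical; it assumes an already-constructed StageCommunicationHandler, an initialized process group, and a stage whose inputs all arrive over the network (ReceiveStageInput), since get_inputs raises for local inputs.

    import torch.distributed as dist

    def exchange_microbatch(handler, microbatch_index, outputs_to_send):
        # Build isend ops for this stage's outputs and irecv ops into the
        # pre-allocated input buffers for the given microbatch.
        ops = handler.create_send_ops(outputs_to_send)
        ops += handler.create_receive_ops(microbatch_index)
        if ops:
            # batch_isend_irecv issues all P2P ops together and returns work handles.
            for work in dist.batch_isend_irecv(ops):
                work.wait()
        # The receive buffers are now filled and can be fed to the stage module.
        return handler.get_inputs(microbatch_index)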