d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/pipelining/infra/schedule/component/runtime/action.py
@@ -0,0 +1,361 @@
+ import abc
+ import dataclasses
+ from enum import StrEnum
+ from typing import Any
+
+ import torch
+
+ from d9d.pipelining.infra.stage import PipelineStage
+
+ from .communications import PipelineCommunicationHandler
+ from .loss import PipelineLossHandler
+
+
+ @dataclasses.dataclass(kw_only=True, slots=True)
+ class ActionContext:
+     """
+     Holds the runtime context required to execute a pipeline action.
+
+     Attributes:
+         pipeline_inputs_microbatches: The global inputs sharded by microbatch.
+         pipeline_kwargs_microbatches: The global keyword arguments sharded by microbatch.
+         stages: A mapping of stage indices to their active PipelineStage instances.
+         communications: The handler for P2P communications.
+         loss: The handler for loss computation, or None if not available.
+     """
+
+     pipeline_inputs_microbatches: tuple[dict[str, torch.Tensor], ...]
+     pipeline_kwargs_microbatches: tuple[dict[str, Any], ...]
+
+     stages: dict[int, PipelineStage]
+     communications: PipelineCommunicationHandler
+     loss: PipelineLossHandler | None
+
+
+ class ActionWorkType(StrEnum):
+     """
+     Classifies the type of work performed by an action.
+
+     Attributes:
+         compute: Indicates the action involves computation components (forward, backward).
+         communicate: Indicates the action involves network I/O components (send, receive).
+     """
+
+     compute = "compute"
+     communicate = "communicate"
+
+
+ class ActionBase(abc.ABC):
+     """
+     Abstract base class for all pipeline schedule actions.
+
+     An action represents an atomic unit of work in a pipeline schedule,
+     such as computing a microbatch or sending/receiving a tensor.
+     """
+
+     @abc.abstractmethod
+     def apply(self, ctx: ActionContext):
+         """
+         Executes the action logic using the provided context.
+
+         Args:
+             ctx: The runtime context containing stages, data, and communication handlers.
+         """
+
+         ...
+
+     @property
+     @abc.abstractmethod
+     def work_type(self) -> ActionWorkType:
+         """Returns the classification of work this action performs."""
+         ...
+
+     @property
+     @abc.abstractmethod
+     def has_backward_work(self) -> bool:
+         """Returns True if this action involves backward pass computations."""
+         ...
+
+     @abc.abstractmethod
+     def __str__(self) -> str:
+         """Returns a short string representation of the action for logging/visualization."""
+         ...
+
+
+ @dataclasses.dataclass(frozen=True, slots=True)
+ class ForwardSendAction(ActionBase):
+     """
+     Action to schedule a forward pass tensor send operation.
+
+     Attributes:
+         stage_idx: The integer index of the pipeline stage initiating the send operation.
+         microbatch_idx: The integer index of the microbatch being sent.
+     """
+
+     stage_idx: int
+     microbatch_idx: int
+
+     def apply(self, ctx: ActionContext):
+         ctx.communications.schedule_fwd_send(self.stage_idx, self.microbatch_idx)
+
+     @property
+     def work_type(self) -> ActionWorkType:
+         return ActionWorkType.communicate
+
+     @property
+     def has_backward_work(self) -> bool:
+         return False
+
+     def __str__(self) -> str:
+         return f"{self.stage_idx}SEND_F{self.microbatch_idx}"
+
+
+ @dataclasses.dataclass(frozen=True, slots=True)
+ class BackwardSendAction(ActionBase):
+     """
+     Action to schedule a backward pass gradient send operation.
+
+     Attributes:
+         stage_idx: The integer index of the pipeline stage initiating the send operation.
+         microbatch_idx: The integer index of the microbatch being sent.
+     """
+
+     stage_idx: int
+     microbatch_idx: int
+
+     def apply(self, ctx: ActionContext):
+         ctx.communications.schedule_bwd_send(self.stage_idx, self.microbatch_idx)
+
+     @property
+     def work_type(self) -> ActionWorkType:
+         return ActionWorkType.communicate
+
+     @property
+     def has_backward_work(self) -> bool:
+         return True
+
+     def __str__(self) -> str:
+         return f"{self.stage_idx}SEND_B{self.microbatch_idx}"
+
+
+ @dataclasses.dataclass(frozen=True, slots=True)
+ class ForwardReceiveAction(ActionBase):
+     """
+     Action to schedule a forward pass tensor receive operation.
+
+     Attributes:
+         stage_idx: The integer index of the pipeline stage expecting the receive operation.
+         microbatch_idx: The integer index of the microbatch being received.
+     """
+
+     stage_idx: int
+     microbatch_idx: int
+
+     def apply(self, ctx: ActionContext):
+         ctx.communications.schedule_fwd_recv(self.stage_idx, self.microbatch_idx)
+
+     @property
+     def work_type(self) -> ActionWorkType:
+         return ActionWorkType.communicate
+
+     @property
+     def has_backward_work(self) -> bool:
+         return True
+
+     def __str__(self) -> str:
+         return f"{self.stage_idx}RECV_F{self.microbatch_idx}"
+
+
+ @dataclasses.dataclass(frozen=True, slots=True)
+ class BackwardReceiveAction(ActionBase):
+     """
+     Action to schedule a backward pass gradient receive operation.
+
+     Attributes:
+         stage_idx: The integer index of the pipeline stage expecting the receive operation.
+         microbatch_idx: The integer index of the microbatch being received.
+     """
+
+     stage_idx: int
+     microbatch_idx: int
+
+     def apply(self, ctx: ActionContext):
+         ctx.communications.schedule_bwd_recv(self.stage_idx, self.microbatch_idx)
+
+     @property
+     def work_type(self) -> ActionWorkType:
+         return ActionWorkType.communicate
+
+     @property
+     def has_backward_work(self) -> bool:
+         return True
+
+     def __str__(self) -> str:
+         return f"{self.stage_idx}RECV_B{self.microbatch_idx}"
+
+
+ @dataclasses.dataclass(frozen=True, slots=True)
+ class ForwardComputeAction(ActionBase):
+     """
+     Action to perform forward computation for a specific microbatch.
+
+     Attributes:
+         stage_idx: The integer index of the pipeline stage.
+         microbatch_idx: The integer index of the microbatch to compute.
+     """
+
+     stage_idx: int
+     microbatch_idx: int
+
+     def apply(self, ctx: ActionContext):
+         # todo check unsharded
+         stage = ctx.stages[self.stage_idx]
+
+         if not stage.info.is_current_stage_first and self.stage_idx - 1 not in ctx.stages:
+             ctx.communications.wait_fwd_recv(self.stage_idx, self.microbatch_idx)
+
+         stage.forward_one_chunk(
+             microbatch_index=self.microbatch_idx,
+             pipeline_inputs=ctx.pipeline_inputs_microbatches[self.microbatch_idx],
+             pipeline_kwargs=ctx.pipeline_kwargs_microbatches[self.microbatch_idx]
+         )
+         result = stage.get_local_fwd_output(self.microbatch_idx)
+
+         if stage.info.is_current_stage_last and ctx.loss is not None:
+             ctx.loss.compute_loss(result, self.microbatch_idx)
+
+         if not stage.info.is_current_stage_last and self.stage_idx + 1 in ctx.stages:
+             ctx.stages[self.stage_idx + 1].set_local_fwd_input(
+                 inputs=result,
+                 microbatch_index=self.microbatch_idx
+             )
+
+     @property
+     def work_type(self) -> ActionWorkType:
+         return ActionWorkType.compute
+
+     @property
+     def has_backward_work(self) -> bool:
+         return False
+
+     def __str__(self) -> str:
+         return f"{self.stage_idx}F{self.microbatch_idx}"
+
+
+ @dataclasses.dataclass(frozen=True, slots=True)
+ class BackwardFullInputComputeAction(ActionBase):
+     """
+     Action to perform backward computation with respect to inputs.
+
+     Attributes:
+         stage_idx: The integer index of the pipeline stage.
+         microbatch_idx: The integer index of the microbatch to compute.
+         full_backward: If True, performs a full backward pass including inputs
+             and weights. If False, may only compute gradients w.r.t inputs
+             (depending on schedule implementation).
+     """
+
+     stage_idx: int
+     microbatch_idx: int
+     full_backward: bool
+
+     def apply(self, ctx: ActionContext):
+         # todo unshard
+         stage = ctx.stages[self.stage_idx]
+
+         if not stage.info.is_current_stage_last and self.stage_idx + 1 not in ctx.stages:
+             ctx.communications.wait_bwd_recv(self.stage_idx, self.microbatch_idx)
+
+         if stage.info.is_current_stage_last and ctx.loss is not None:
+             loss = ctx.loss.acquire_loss(self.microbatch_idx)
+         else:
+             loss = None
+
+         stage.backward_one_chunk(
+             microbatch_index=self.microbatch_idx,
+             full_backward=self.full_backward,
+             loss=loss
+         )
+
+         if not stage.info.is_current_stage_first and self.stage_idx - 1 in ctx.stages:
+             ctx.stages[self.stage_idx - 1].set_local_bwd_input(
+                 microbatch_index=self.microbatch_idx,
+                 inputs=stage.pop_local_bwd_output(self.microbatch_idx)
+             )
+
+     @property
+     def work_type(self) -> ActionWorkType:
+         return ActionWorkType.compute
+
+     @property
+     def has_backward_work(self) -> bool:
+         return True
+
+     def __str__(self) -> str:
+         letter = "B" if self.full_backward else "I"
+         return f"{self.stage_idx}{letter}{self.microbatch_idx}"
+
+
+ @dataclasses.dataclass(frozen=True, slots=True)
+ class BackwardWeightComputeAction(ActionBase):
+     """
+     Action to perform gradient accumulation on weights.
+
+     Attributes:
+         stage_idx: The integer index of the pipeline stage.
+         microbatch_idx: The integer index of the microbatch to compute.
+     """
+
+     stage_idx: int
+     microbatch_idx: int
+
+     def apply(self, ctx: ActionContext):
+         # todo unshard
+         stage = ctx.stages[self.stage_idx]
+
+         stage.backward_weight_one_chunk(
+             microbatch_index=self.microbatch_idx
+         )
+
+     @property
+     def work_type(self) -> ActionWorkType:
+         return ActionWorkType.compute
+
+     @property
+     def has_backward_work(self) -> bool:
+         return True
+
+     def __str__(self) -> str:
+         return f"{self.stage_idx}W{self.microbatch_idx}"
+
+
+ @dataclasses.dataclass(frozen=True, slots=True)
+ class ComposeAction(ActionBase):
+     """
+     Composite action scheduling multiple sub-actions sequentially.
+
+     Used for forward/backward overlapping.
+
+     Attributes:
+         actions: A tuple of sub-actions to be executed sequentially.
+     """
+
+     actions: tuple[ActionBase, ...]
+
+     def apply(self, ctx: ActionContext):
+         for act in self.actions:
+             act.apply(ctx)
+
+     @property
+     def work_type(self) -> ActionWorkType:
+         sub_work_types = {x.work_type for x in self.actions}
+         if len(sub_work_types) != 1:
+             raise ValueError("")
+         return next(iter(sub_work_types))
+
+     @property
+     def has_backward_work(self) -> bool:
+         return any(x.has_backward_work for x in self.actions)
+
+     def __str__(self) -> str:
+         return "|".join(map(str, self.actions))
d9d/pipelining/infra/schedule/component/runtime/communications.py
@@ -0,0 +1,101 @@
+ import torch.distributed as dist
+
+ from d9d.pipelining.infra.stage import PipelineStage
+
+
+ def _schedule_batched_p2p(ops: list[dist.P2POp]) -> list[dist.Work]:
+     if not len(ops):
+         return []
+
+     return dist.batch_isend_irecv(ops)
+
+
+ def _wait_batched_p2p(work: list[dist.Work]):
+     for work_item in work:
+         work_item.wait()
+
+
+ class PipelineCommunicationHandler:
+     """Manages point-to-point communications between pipeline stages."""
+
+     def __init__(self, stages: dict[int, PipelineStage]):
+         """
+         Constructs the communication handler.
+
+         Args:
+             stages: Mapping of stage indices to PipelineStage instances.
+         """
+
+         self._stages = stages
+
+         self._forward_receive_ops: dict[tuple[int, int], list[dist.Work]] = {}
+         self._backward_receive_ops: dict[tuple[int, int], list[dist.Work]] = {}
+
+         self._send_ops: list[list[dist.Work]] = []
+
+     def schedule_fwd_send(self, stage_idx: int, microbatch_idx: int):
+         """Schedules a non-blocking send of forward pass outputs."""
+
+         stage = self._stages[stage_idx]
+         work = _schedule_batched_p2p(stage.get_fwd_send_ops(microbatch_idx))
+         self._send_ops.append(work)
+
+     def schedule_bwd_send(self, stage_idx: int, microbatch_idx: int):
+         """Schedules a non-blocking send of backward pass outputs."""
+
+         stage = self._stages[stage_idx]
+         work = _schedule_batched_p2p(stage.get_bwd_send_ops(microbatch_idx))
+         self._send_ops.append(work)
+
+     def schedule_fwd_recv(self, stage_idx: int, microbatch_idx: int):
+         """
+         Schedules a non-blocking receive of forward pass inputs.
+
+         Raises:
+             ValueError: If a receive op is already pending for this stage/microbatch.
+         """
+         stage = self._stages[stage_idx]
+         key = (stage_idx, microbatch_idx)
+
+         if key in self._forward_receive_ops:
+             raise ValueError()
+
+         work = _schedule_batched_p2p(stage.get_fwd_recv_ops(microbatch_idx))
+         self._forward_receive_ops[key] = work
+
+     def wait_fwd_recv(self, stage_idx: int, microbatch_idx: int):
+         """Blocks until the forward pass receive operation completes."""
+         key = (stage_idx, microbatch_idx)
+         _wait_batched_p2p(self._forward_receive_ops.pop(key))
+
+     def schedule_bwd_recv(self, stage_idx: int, microbatch_idx: int):
+         """
+         Schedules a non-blocking receive of backward pass inputs.
+
+         Raises:
+             ValueError: If a receive op is already pending for this stage/microbatch.
+         """
+
+         stage = self._stages[stage_idx]
+         key = (stage_idx, microbatch_idx)
+
+         if key in self._backward_receive_ops:
+             raise ValueError()
+
+         work = _schedule_batched_p2p(stage.get_bwd_recv_ops(microbatch_idx))
+
+         self._backward_receive_ops[key] = work
+
+     def wait_bwd_recv(self, stage_idx: int, microbatch_idx: int):
+         """Blocks until the backward pass receive operation completes."""
+
+         key = (stage_idx, microbatch_idx)
+         _wait_batched_p2p(self._backward_receive_ops.pop(key))
+
+     def wait_send_all(self):
+         """Blocks until all pending send operations are completed."""
+
+         while self._send_ops:
+             ops = self._send_ops.pop()
+             for op in ops:
+                 op.wait()
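The handler keys pending receives by (stage_idx, microbatch_idx) and treats sends as fire-and-forget until wait_send_all. A minimal sketch of that pairing, using a stub stage that returns no P2P ops so nothing here touches the network (illustrative only; the import path is assumed from the file layout, and real usage passes PipelineStage instances backed by an initialised torch.distributed process group):

from d9d.pipelining.infra.schedule.component.runtime.communications import (
    PipelineCommunicationHandler,
)


class _StubStage:
    # Hypothetical stand-in: returns no dist.P2POp objects, so the batched
    # send/recv helpers short-circuit and no process group is required.
    def get_fwd_recv_ops(self, microbatch_idx):
        return []

    def get_fwd_send_ops(self, microbatch_idx):
        return []


comms = PipelineCommunicationHandler(stages={0: _StubStage()})

comms.schedule_fwd_recv(stage_idx=0, microbatch_idx=0)  # post the receive early
comms.wait_fwd_recv(stage_idx=0, microbatch_idx=0)      # block only when needed; pops the op

comms.schedule_fwd_send(stage_idx=0, microbatch_idx=0)  # sends accumulate...
comms.wait_send_all()                                   # ...and are drained at step end

comms.schedule_fwd_recv(0, 1)
try:
    comms.schedule_fwd_recv(0, 1)  # second post for the same (stage, microbatch) key
except ValueError:
    print("duplicate receive rejected")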
d9d/pipelining/infra/schedule/component/runtime/executor.py
@@ -0,0 +1,113 @@
+ from typing import Any
+
+ import torch
+ from torch.autograd.profiler import record_function
+
+ from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
+ from d9d.core.sharding import ShardingSpec, shard_spec_on_dim, shard_tree
+ from d9d.pipelining.api import PipelineSchedule, PipelineShardingSpec
+ from d9d.pipelining.infra.stage import PipelineStage
+
+ from .action import ActionBase, ActionContext
+ from .communications import PipelineCommunicationHandler
+ from .loss import LossFn, PipelineLossHandler
+
+
+ class PipelineScheduleExecutor(PipelineSchedule):
+     """Executes a defined pipeline schedule by interpreting a sequence of actions."""
+
+     def __init__(
+         self,
+         dist_context: DistributedContext,
+         stages: list[PipelineStage],
+         num_microbatches: int,
+         loss_fn: LossFn | None,
+         program: dict[int, list[ActionBase]]
+     ):
+         """
+         Constructs the schedule executor.
+
+         Args:
+             dist_context: The distributed context.
+             stages: List of stages managed by this executor.
+             num_microbatches: Number of microbatches the global batch is split into.
+             loss_fn: Function to compute loss.
+             program: The execution plan mapping rank ID to a list of actions.
+         """
+
+         self._dist_ctx = dist_context
+         self._stages = {stage.info.current_stage: stage for stage in stages}
+         self._num_microbatches = num_microbatches
+         self._program = program
+
+         self._has_backward = any(any(
+             action.has_backward_work for action in sub_program
+         ) for sub_program in program.values())
+
+         self._comm_handler = PipelineCommunicationHandler(self._stages)
+         if loss_fn is None:
+             self._loss_handler = None
+         else:
+             self._loss_handler = PipelineLossHandler(loss_fn)
+
+         self._input_data_sharding_spec: ShardingSpec | None = None
+         self._input_kwargs_sharding_spec: ShardingSpec | None = None
+
+     def configure_buffers(
+         self,
+         inputs: dict[str, torch.Tensor],
+         kwargs: dict[str, Any],
+         sharding_spec: PipelineShardingSpec | None
+     ):
+         if sharding_spec is None or sharding_spec.input_data is None:
+             self._input_data_sharding_spec = shard_spec_on_dim(inputs, dim=0)
+         if sharding_spec is None or sharding_spec.input_kwargs is None:
+             self._input_kwargs_sharding_spec = shard_spec_on_dim(kwargs, dim=0)
+
+         for stage in self._stages.values():
+             stage.configure_buffers(
+                 num_microbatches=self._num_microbatches,
+                 pipeline_inputs=inputs,
+                 has_backward=self._has_backward
+             )
+
+     def step(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
+         if self._input_data_sharding_spec is None or self._input_kwargs_sharding_spec is None:
+             raise ValueError("Please configure sharding specs first")
+
+         self._dist_ctx.logger.debug("Begin pipeline step")
+         pp_group = self._dist_ctx.mesh_for(REGULAR_DOMAIN).get_group("pp")
+
+         for stage in self._stages.values():
+             stage.reset()
+
+         # Shard inputs and kwargs to microbatches
+         inputs_shard = shard_tree(
+             inputs,
+             num_shards=self._num_microbatches,
+             sharding_spec=self._input_data_sharding_spec,
+             enforce_even_split=True
+         )
+         kwargs_shard = shard_tree(
+             kwargs,
+             num_shards=self._num_microbatches,
+             sharding_spec=self._input_kwargs_sharding_spec,
+             enforce_even_split=True
+         )
+
+         my_program = self._program[pp_group.rank()]
+
+         for action in my_program:
+             with record_function(str(action)):
+                 self._dist_ctx.logger.debug(f"Running pipeline action {action}")
+                 action.apply(ActionContext(
+                     loss=self._loss_handler,
+                     stages=self._stages,
+                     communications=self._comm_handler,
+                     pipeline_inputs_microbatches=inputs_shard,
+                     pipeline_kwargs_microbatches=kwargs_shard
+                 ))
+
+         self._dist_ctx.logger.debug("Waiting for potentially hanging PP send comms")
+         self._comm_handler.wait_send_all()  # finalize just in case
+         self._dist_ctx.logger.debug("End pipeline step")
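The `program` argument is the whole schedule: inside `step`, each pipeline-parallel rank looks up its own action list and interprets it in order. A hand-written two-rank, one-microbatch program of that shape is sketched below (illustrative only; real programs come from the builders under d9d/pipelining/infra/schedule/program/, and the import path is assumed from the file layout):

from d9d.pipelining.infra.schedule.component.runtime.action import (
    ActionBase,
    BackwardFullInputComputeAction,
    BackwardReceiveAction,
    BackwardSendAction,
    ForwardComputeAction,
    ForwardReceiveAction,
    ForwardSendAction,
)

program: dict[int, list[ActionBase]] = {
    # pp rank 0 owns stage 0: compute, ship activations, then wait for grads.
    0: [
        ForwardComputeAction(stage_idx=0, microbatch_idx=0),
        ForwardSendAction(stage_idx=0, microbatch_idx=0),
        BackwardReceiveAction(stage_idx=0, microbatch_idx=0),
        BackwardFullInputComputeAction(stage_idx=0, microbatch_idx=0, full_backward=True),
    ],
    # pp rank 1 owns the last stage: receive, compute, backprop, send grads.
    1: [
        ForwardReceiveAction(stage_idx=1, microbatch_idx=0),
        ForwardComputeAction(stage_idx=1, microbatch_idx=0),
        BackwardFullInputComputeAction(stage_idx=1, microbatch_idx=0, full_backward=True),
        BackwardSendAction(stage_idx=1, microbatch_idx=0),
    ],
}

print("\n".join(f"rank {rank}: " + " ".join(map(str, acts)) for rank, acts in program.items()))
# rank 0: 0F0 0SEND_F0 0RECV_B0 0B0
# rank 1: 1RECV_F0 1F0 1B0 1SEND_B0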
d9d/pipelining/infra/schedule/component/runtime/loss.py
@@ -0,0 +1,55 @@
+ from collections.abc import Callable
+
+ import torch
+
+ LossFn = Callable[[dict[str, torch.Tensor], int], torch.Tensor]
+
+
+ class PipelineLossHandler:
+     """Manages loss computation and state caching across forward and backward passes."""
+
+     def __init__(self, loss_fn: LossFn):
+         """
+         Constructs the loss handler.
+
+         Args:
+             loss_fn: The callable that computes loss from model outputs.
+         """
+
+         self._loss_fn = loss_fn
+         self._cached_values: dict[int, torch.Tensor] = {}
+
+     def compute_loss(self, forward_result: dict[str, torch.Tensor], microbatch_index: int) -> torch.Tensor:
+         """
+         Computes loss for a given microbatch result and caches it.
+
+         Args:
+             forward_result: The output from the last stage of the model.
+             microbatch_index: The index of the microbatch being processed.
+
+         Returns:
+             The computed loss tensor.
+         """
+
+         result = self._loss_fn(forward_result, microbatch_index)
+         self._cached_values[microbatch_index] = result
+         return result
+
+     def acquire_loss(self, microbatch_index: int) -> torch.Tensor:
+         """
+         Retrieves the cached loss tensor for the backward pass and removes it from the cache.
+
+         Args:
+             microbatch_index: The index of the microbatch.
+
+         Returns:
+             The previously computed loss tensor.
+
+         Raises:
+             ValueError: If the loss for this microbatch hasn't been computed yet.
+         """
+
+         if microbatch_index not in self._cached_values:
+             raise ValueError()
+
+         return self._cached_values[microbatch_index]
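Because the handler only caches tensors keyed by microbatch index, it can be exercised on its own. A minimal, self-contained sketch of the forward/backward hand-off (the import path and the toy `mean_loss` function are assumptions, not part of the package):

import torch
from d9d.pipelining.infra.schedule.component.runtime.loss import PipelineLossHandler


def mean_loss(outputs: dict[str, torch.Tensor], microbatch_index: int) -> torch.Tensor:
    # Toy LossFn: average the logits of one microbatch.
    return outputs["logits"].mean()


handler = PipelineLossHandler(mean_loss)

# The forward action on the last stage computes and caches the loss...
logits = torch.randn(4, 8, requires_grad=True)
handler.compute_loss({"logits": logits}, microbatch_index=0)

# ...and the matching backward action later retrieves it to start backprop.
handler.acquire_loss(microbatch_index=0).backward()
print(logits.grad.shape)  # torch.Size([4, 8])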
d9d/pipelining/infra/schedule/program/__init__.py
@@ -0,0 +1,15 @@
+ """
+ Pipeline Schedule Implementations
+ """
+
+ from .bfs import LoopedBFSPipelineProgramBuilder
+ from .dualpipev import DualPipeVPipelineProgramBuilder
+ from .interleaved import Interleaved1F1BPipelineProgramBuilder
+ from .zerobubblev import ZeroBubbleVPipelineProgramBuilder
+
+ __all__ = [
+     "DualPipeVPipelineProgramBuilder",
+     "Interleaved1F1BPipelineProgramBuilder",
+     "LoopedBFSPipelineProgramBuilder",
+     "ZeroBubbleVPipelineProgramBuilder"
+ ]
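The builders re-exported here implement the individual schedules; their constructors live in the modules listed in the file list (bfs.py, dualpipev.py, interleaved.py, zerobubblev.py), which this diff excerpt does not show. A sketch of how a caller might key them by name (the mapping and its keys are hypothetical; the package's actual registration mechanism is in d9d/pipelining/factory/registry.py):

from d9d.pipelining.infra.schedule.program import (
    DualPipeVPipelineProgramBuilder,
    Interleaved1F1BPipelineProgramBuilder,
    LoopedBFSPipelineProgramBuilder,
    ZeroBubbleVPipelineProgramBuilder,
)

# Hypothetical name -> builder mapping for illustration only.
PROGRAM_BUILDERS = {
    "looped_bfs": LoopedBFSPipelineProgramBuilder,
    "interleaved_1f1b": Interleaved1F1BPipelineProgramBuilder,
    "dualpipe_v": DualPipeVPipelineProgramBuilder,
    "zero_bubble_v": ZeroBubbleVPipelineProgramBuilder,
}

print(sorted(PROGRAM_BUILDERS))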