d9d-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/pipelining/infra/schedule/program/bfs.py
@@ -0,0 +1,86 @@
+ from ..component.program import (
+     PipelineProgramBuilder,
+     ScheduleStyle,
+     add_communication_ops,
+     build_stage_to_host_rank_topology,
+ )
+ from ..component.runtime import (
+     ActionBase,
+     BackwardFullInputComputeAction,
+     ForwardComputeAction,
+ )
+
+
+ class LoopedBFSPipelineProgramBuilder(PipelineProgramBuilder):
+     """
+     Builder for the Breadth-First Pipeline Parallelism schedule.
+
+     This schedule runs all available forward microbatches for local stages first.
+     If configured for training, it then runs backwards in reverse topological order.
+
+     References:
+         https://arxiv.org/pdf/2211.05953
+     """
+
+     def __init__(self, num_stages_per_rank: int, inference_mode: bool = False):
+         """
+         Constructs the LoopedBFS builder.
+
+         Args:
+             num_stages_per_rank: Number of stages per rank.
+             inference_mode: If True, only forward passes are scheduled. If False,
+                 both forward and backward passes are scheduled.
+         """
+         self._num_stages_per_rank = num_stages_per_rank
+         self._inference_mode = inference_mode
+
+     def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
+         num_stages = self._num_stages_per_rank * pp_size
+         stage_to_rank = build_stage_to_host_rank_topology(
+             pp_size=pp_size,
+             num_stages=num_stages,
+             style=ScheduleStyle.loop
+         )
+
+         compute_actions: dict[int, list[ActionBase]] = {r: [] for r in range(pp_size)}
+
+         for rank in range(pp_size):
+             my_stages = [s for s in range(num_stages) if stage_to_rank[s] == rank]
+
+             # Schedule all Forwards
+             # In Breadth-First loops, we finish all microbatches for the current stage
+             # before moving to the next stage assigned to this rank.
+             for stage_idx in my_stages:
+                 for mb_idx in range(num_microbatches):
+                     compute_actions[rank].append(
+                         ForwardComputeAction(
+                             stage_idx=stage_idx,
+                             microbatch_idx=mb_idx
+                         )
+                     )
+
+             # Schedule all Backwards (Reverse order) - Only if training
+             if not self._inference_mode:
+                 for stage_idx in reversed(my_stages):
+                     for mb_idx in reversed(range(num_microbatches)):
+                         compute_actions[rank].append(
+                             BackwardFullInputComputeAction(
+                                 stage_idx=stage_idx,
+                                 microbatch_idx=mb_idx,
+                                 full_backward=True
+                             )
+                         )
+
+         return add_communication_ops(
+             compute_actions=compute_actions,
+             stage_to_rank=stage_to_rank,
+             num_stages=num_stages
+         )
+
+     @property
+     def num_stages_per_rank(self) -> int:
+         return self._num_stages_per_rank
+
+     @property
+     def topology_style(self) -> ScheduleStyle:
+         return ScheduleStyle.loop
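
The builder above only emits per-rank compute actions; `add_communication_ops` then weaves the transfer actions in between. A minimal usage sketch follows, assuming the class is importable from the file path listed above (the import path and configuration values are illustrative, not taken from the package):

    # Hypothetical driver for the looped BFS builder: 2 ranks, 2 stages per rank,
    # 4 microbatches. compose() returns {rank: [ActionBase, ...]} as defined above.
    from d9d.pipelining.infra.schedule.program.bfs import LoopedBFSPipelineProgramBuilder

    builder = LoopedBFSPipelineProgramBuilder(num_stages_per_rank=2, inference_mode=False)
    program = builder.compose(num_microbatches=4, pp_size=2)

    for rank, rank_actions in program.items():
        # Compute actions carry stage_idx/microbatch_idx; the communication actions
        # inserted by add_communication_ops appear interleaved between them.
        print(rank, [type(action).__name__ for action in rank_actions])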
d9d/pipelining/infra/schedule/program/dualpipev.py
@@ -0,0 +1,234 @@
+ from collections import deque
+
+ from ..component.program import (
+     PipelineProgramBuilder,
+     ScheduleStyle,
+     add_communication_ops,
+     build_stage_to_host_rank_topology,
+ )
+ from ..component.runtime import (
+     ActionBase,
+     BackwardFullInputComputeAction,
+     BackwardWeightComputeAction,
+     ComposeAction,
+     ForwardComputeAction,
+ )
+
+
+ class DualPipeVPipelineProgramBuilder(PipelineProgramBuilder):
+     """
+     Builder for the DualPipeV Pipeline Parallelism schedule.
+
+     DualPipeV is a specialized bi-directional pipeline schedule designed for high
+     throughput training. It requires exactly 2 stages per pipeline rank (V-shape)
+     and utilizes split backward passes (Input gradients vs Weight gradients)
+     to fill pipeline bubbles.
+
+     References:
+         https://github.com/deepseek-ai/DualPipe
+         https://hackmd.io/@ufotalent/r1lVXsa9Jg
+     """
+
+     def __init__(self):
+         """
+         Constructs the DualPipeV builder.
+         """
+
+     @staticmethod
+     def _build_for_rank(  # noqa: C901
+         rank: int, stage_to_rank: dict[int, int], num_microbatches: int, pp_size: int
+     ) -> list[ActionBase]:
+         compute_actions: list[ActionBase] = []
+
+         # Identify local stages: s0 is Phase 0, s1 is Phase 1
+         my_stages = sorted([s for s, r in stage_to_rank.items() if r == rank])
+         s0, s1 = my_stages[0], my_stages[1]
+
+         # Track microbatch indices for each stage and operation type
+         # f_idx: Next Forward microbatch
+         # b_idx: Next Backward microbatch (Input or Full)
+         f_idx = {s0: 0, s1: 0}
+         b_idx = {s0: 0, s1: 0}
+
+         # Queue for Zero Bubble optimization: stores (stage, mb_idx) for deferred weight grads
+         weight_queue: deque[tuple[int, int]] = deque()
+
+         # --- Helper Functions for Action Emission ---
+
+         def _add_f(stage: int):
+             compute_actions.append(
+                 ForwardComputeAction(stage_idx=stage, microbatch_idx=f_idx[stage])
+             )
+             f_idx[stage] += 1
+
+         def _add_b_full(stage: int):
+             compute_actions.append(
+                 BackwardFullInputComputeAction(
+                     stage_idx=stage,
+                     microbatch_idx=b_idx[stage],
+                     full_backward=True,
+                 )
+             )
+             b_idx[stage] += 1
+
+         def _add_b_input(stage: int):
+             mb = b_idx[stage]
+             compute_actions.append(
+                 BackwardFullInputComputeAction(
+                     stage_idx=stage,
+                     microbatch_idx=mb,
+                     full_backward=False,
+                 )
+             )
+             weight_queue.append((stage, mb))
+             b_idx[stage] += 1
+
+         def _pop_w():
+             if not weight_queue:
+                 return
+             s, mb = weight_queue.popleft()
+             compute_actions.append(
+                 BackwardWeightComputeAction(stage_idx=s, microbatch_idx=mb)
+             )
+
+         def _add_overlap_f_b(stage_f: int, stage_b: int, b_is_full: bool):
+             """Emit overlapped Forward and Backward actions."""
+             mb_f = f_idx[stage_f]
+             mb_b = b_idx[stage_b]
+
+             act_f = ForwardComputeAction(stage_idx=stage_f, microbatch_idx=mb_f)
+
+             act_b = BackwardFullInputComputeAction(
+                 stage_idx=stage_b, microbatch_idx=mb_b, full_backward=b_is_full
+             )
+             if not b_is_full:
+                 weight_queue.append((stage_b, mb_b))
+
+             f_idx[stage_f] += 1
+             b_idx[stage_b] += 1
+
+             # Note: d9d infra treats ComposeAction sequentially in simulation,
+             # but runtime may overlap them.
+             compute_actions.append(ComposeAction(actions=(act_f, act_b)))
+
+         # Step 1: nF0 (Startup Phase 0)
+         step_1 = (pp_size - rank - 1) * 2
+         for _ in range(step_1):
+             _add_f(s0)
+
+         # Step 2: nF0F1 (Forward fill)
+         step_2 = rank + 1
+         for _ in range(step_2):
+             _add_f(s0)
+             _add_f(s1)
+
+         # Step 3: nI1W1F1 (Mixed Phase with Zero Bubble)
+         step_3 = pp_size - rank - 1
+         for _ in range(step_3):
+             _add_b_input(s1)  # Backward Input Phase 1
+             _pop_w()          # Weight Phase (accumulated from prev)
+             _add_f(s1)        # Forward Phase 1
+
+         # Step 4: The Main Loop (Interleaved Forward/Backward)
+         step_4 = num_microbatches - 2 * pp_size + rank + 1
+         for i in range(step_4):
+             # Sub-step A: F0 & B1
+             if i == 0 and rank == pp_size - 1:
+                 # Specific case for last rank on first iter: do not overlap
+                 _add_f(s0)
+                 _add_b_full(s1)
+             else:
+                 # Overlap F0 and B1 (usually full backward unless we were in ZB mode,
+                 # but DualPipeV main loop defaults to full for simplicity unless tuned)
+                 # DeepSeek impl uses standard backward here (zb=False).
+                 _add_overlap_f_b(stage_f=s0, stage_b=s1, b_is_full=True)
+
+             # Sub-step B: F1 & B0
+             # Overlap F1 and B0 (Full)
+             _add_overlap_f_b(stage_f=s1, stage_b=s0, b_is_full=True)
+
+         # Step 5: Cooldown F1/B0
+         step_5 = pp_size - rank - 1
+         for _ in range(step_5):
+             _add_b_full(s1)
+             _add_overlap_f_b(stage_f=s1, stage_b=s0, b_is_full=True)
+
+         # Step 6: Cooldown B1/B0 with Zero Bubble ramp-up
+         step_6 = rank + 1
+         enable_zb = False
+         for i in range(step_6):
+             # Phase 1 Backward
+             if i == step_6 // 2 and rank % 2 == 1:
+                 enable_zb = True
+
+             if enable_zb:
+                 _add_b_input(s1)
+             else:
+                 _add_b_full(s1)
+
+             # Phase 0 Backward
+             if i == step_6 // 2 and rank % 2 == 0:
+                 enable_zb = True
+
+             if enable_zb:
+                 _add_b_input(s0)
+             else:
+                 _add_b_full(s0)
+
+         # Step 7: Zero Bubble Weights + B0
+         step_7 = pp_size - rank - 1
+         for _ in range(step_7):
+             _pop_w()
+             # DeepSeek source explicitly uses enable_zb=True here for chunk 0
+             _add_b_input(s0)
+
+         # Step 8: Flush Weights
+         step_8 = rank + 1
+         for _ in range(step_8):
+             _pop_w()
+
+         return compute_actions
+
+     def compose(
+         self, num_microbatches: int, pp_size: int
+     ) -> dict[int, list[ActionBase]]:
+         num_stages = self.num_stages_per_rank * pp_size
+
+         if num_microbatches < num_stages:
+             raise ValueError(
+                 f"DualPipeV requires num_microbatches ({num_microbatches}) >= "
+                 f"num_stages ({num_stages})."
+             )
+
+         # Ranks hold stages in a V pattern (e.g., Rank 0 holds Stage 0 and Stage N-1).
+         # We rely on the sorted order of local stages to determine Phase 0 (Forward-going)
+         # and Phase 1 (Backward-coming).
+         stage_to_rank = build_stage_to_host_rank_topology(
+             pp_size=pp_size, num_stages=num_stages, style=ScheduleStyle.v
+         )
+
+         compute_actions: dict[int, list[ActionBase]] = {r: [] for r in range(pp_size)}
+
+         for rank in range(pp_size):
+             compute_actions[rank] = self._build_for_rank(
+                 rank=rank,
+                 pp_size=pp_size,
+                 num_microbatches=num_microbatches,
+                 stage_to_rank=stage_to_rank
+             )
+
+         # Inject Communication Operations
+         # This wrapper handles dependency analysis and inserts Send/Recv/Wait ops.
+         return add_communication_ops(
+             compute_actions=compute_actions,
+             stage_to_rank=stage_to_rank,
+             num_stages=num_stages
+         )
+
+     @property
+     def num_stages_per_rank(self) -> int:
+         return 2
+
+     @property
+     def topology_style(self) -> ScheduleStyle:
+         return ScheduleStyle.v
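
The eight scheduling phases in `_build_for_rank` are sized purely from `rank`, `pp_size`, and `num_microbatches`. A small worked example using the formulas from the code above (the values pp_size=4 and num_microbatches=8 are chosen for illustration only):

    # Per-rank phase lengths for pp_size=4, num_microbatches=8 (the minimum allowed,
    # since num_stages = 2 * pp_size = 8). Formulas copied from _build_for_rank.
    pp_size, num_microbatches = 4, 8
    for rank in range(pp_size):
        phases = {
            "step_1_nF0": (pp_size - rank - 1) * 2,
            "step_2_nF0F1": rank + 1,
            "step_3_nI1W1F1": pp_size - rank - 1,
            "step_4_main_loop": num_microbatches - 2 * pp_size + rank + 1,
            "step_5_cooldown_F1B0": pp_size - rank - 1,
            "step_6_cooldown_B1B0": rank + 1,
            "step_7_weights_B0": pp_size - rank - 1,
            "step_8_flush_weights": rank + 1,
        }
        # Startup phases shrink toward the last rank while cooldown phases grow,
        # mirroring the V-shaped stage placement.
        print(rank, phases)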
d9d/pipelining/infra/schedule/program/interleaved.py
@@ -0,0 +1,240 @@
+ from collections import defaultdict, deque
+
+ from ..component.program import (
+     PipelineProgramBuilder,
+     ScheduleStyle,
+     add_communication_ops,
+     build_stage_to_host_rank_topology,
+ )
+ from ..component.runtime import (
+     ActionBase,
+     BackwardFullInputComputeAction,
+     BackwardWeightComputeAction,
+     ForwardComputeAction,
+ )
+
+
+ class Interleaved1F1BPipelineProgramBuilder(PipelineProgramBuilder):
+     """
+     Builder for Interleaved Pipeline Parallelism schedules.
+
+     This builder supports:
+
+     1. **Standard Interleaved 1F1B**: Assigns multiple stages per rank and prioritizes
+        depth-first execution. (See https://arxiv.org/pdf/2104.04473)
+     2. **Interleaved Zero Bubble (ZB1P)**: Extends 1F1B by splitting backward passes
+        into Input Gradients and Weight Gradients. Weight gradients are delayed
+        to fill pipeline bubbles. (See https://arxiv.org/pdf/2401.10241)
+     """
+
+     def __init__(self, num_stages_per_rank: int, enable_zero_bubble: bool = False):
+         """
+         Constructs the Interleaved 1F1B builder.
+
+         Args:
+             num_stages_per_rank: Number of stages per rank.
+             enable_zero_bubble: If True, uses the ZB1P schedule variant which
+                 splits backward passes to reduce bubble size.
+         """
+         self._num_stages_per_rank = num_stages_per_rank
+         self._enable_zero_bubble = enable_zero_bubble
+
+     def _get_warmup_ops(
+         self,
+         rank: int,
+         microbatches_per_round: int,
+         pp_size: int,
+         n_microbatches: int,
+         multiply_factor: int,
+     ) -> int:
+         """
+         Calculates the number of warmup steps required before entering steady state.
+         """
+         warmups_ops_last_stage = (self._num_stages_per_rank - 1) * microbatches_per_round
+         warmup_ops = warmups_ops_last_stage + multiply_factor * ((pp_size - 1) - rank)
+         return min(warmup_ops, n_microbatches * self._num_stages_per_rank)
+
+     def compose(
+         self, num_microbatches: int, pp_size: int
+     ) -> dict[int, list[ActionBase]]:
+         """
+         Generates the execution program for all ranks.
+
+         Args:
+             num_microbatches: Total microbatches. Must be divisible by the derived
+                 number of rounds.
+             pp_size: Number of pipeline ranks.
+
+         Returns:
+             A dictionary mapping rank indices to their list of sequential actions.
+         """
+         num_stages = self.num_stages_per_rank * pp_size
+
+         if num_stages % pp_size != 0:
+             raise ValueError(
+                 f"num_stages ({num_stages}) must be divisible by pp_size ({pp_size}) "
+                 "for interleaved schedules."
+             )
+
+         # 1. Topology Setup
+         # Use Loop/Round-Robin assignment: Rank 0 gets Stage 0, PP, 2*PP...
+         stage_to_rank = build_stage_to_host_rank_topology(
+             pp_size=pp_size, num_stages=num_stages, style=ScheduleStyle.loop
+         )
+
+         num_rounds = max(1, num_microbatches // pp_size)
+
+         if num_microbatches % num_rounds != 0:
+             raise ValueError(
+                 f"microbatches ({num_microbatches}) must be divisible by rounds ({num_rounds})."
+             )
+
+         microbatches_per_round = num_microbatches // num_rounds
+
+         # 2. Schedule Generation
+         actions: dict[int, list[ActionBase]] = {}
+
+         # Zero Bubble 1f1b uses a shorter warmup heuristic (factor 1) than Standard (factor 2)
+         warmup_multiplier = 1 if self._enable_zero_bubble else 2
+
+         for rank in range(pp_size):
+             actions[rank] = self._generate_rank_schedule(
+                 rank=rank,
+                 pp_size=pp_size,
+                 n_microbatches=num_microbatches,
+                 microbatches_per_round=microbatches_per_round,
+                 multiply_factor=warmup_multiplier,
+             )
+
+         # 3. Communication Injection
+         return add_communication_ops(
+             compute_actions=actions,
+             stage_to_rank=stage_to_rank,
+             num_stages=num_stages,
+         )
+
+     def _generate_rank_schedule(  # noqa: C901
+         self,
+         rank: int,
+         pp_size: int,
+         n_microbatches: int,
+         microbatches_per_round: int,
+         multiply_factor: int,
+     ) -> list[ActionBase]:
+         """
+         Generates the sequential list of compute actions for a specific rank.
+         """
+         rank_actions: list[ActionBase] = []
+
+         # -- State Tracking --
+         # Map: stage_idx -> next_microbatch_idx
+         fwd_counters: dict[int, int] = defaultdict(int)
+         bwd_counters: dict[int, int] = defaultdict(int)
+
+         # FIFO Queue for deferred weight gradients in Zero Bubble
+         # Stores: (stage_idx, microbatch_idx)
+         pending_weights: deque[tuple[int, int]] = deque()
+
+         # -- Helpers --
+
+         def get_global_stage(local_idx: int) -> int:
+             """Converts a local virtual stage index (0..N) to global stage ID."""
+             return (local_idx * pp_size) + rank
+
+         def get_fwd_local_idx(op_idx: int) -> int:
+             return (op_idx // microbatches_per_round) % self._num_stages_per_rank
+
+         def get_bwd_local_idx(op_idx: int, warmup_offset: int) -> int:
+             return (self._num_stages_per_rank
+                     - 1
+                     - ((op_idx - warmup_offset) // microbatches_per_round) % self._num_stages_per_rank)
+
+         def emit_forward(op_idx: int):
+             local_idx = get_fwd_local_idx(op_idx)
+             stage = get_global_stage(local_idx)
+             mb = fwd_counters[stage]
+
+             rank_actions.append(ForwardComputeAction(stage_idx=stage, microbatch_idx=mb))
+             fwd_counters[stage] += 1
+
+         def emit_backward(op_idx: int, warmup_offset: int):
+             local_idx = get_bwd_local_idx(op_idx, warmup_offset)
+             stage = get_global_stage(local_idx)
+             mb = bwd_counters[stage]
+
+             # In Zero Bubble, we split: Backward Input (Now) + Backward Weight (Later)
+             # In Standard 1F1B, we do full backward now.
+             is_full = not self._enable_zero_bubble
+
+             rank_actions.append(
+                 BackwardFullInputComputeAction(
+                     stage_idx=stage,
+                     microbatch_idx=mb,
+                     full_backward=is_full
+                 )
+             )
+
+             if self._enable_zero_bubble:
+                 pending_weights.append((stage, mb))
+
+             bwd_counters[stage] += 1
+
+         def try_emit_weight_zb(op_idx: int, warmup_offset: int):
+             if not self._enable_zero_bubble or not pending_weights:
+                 return
+
+             steps_into_1f1b = op_idx - warmup_offset
+             # The earliest reasonable time to start weaving in weights is proportional to rank depth
+             if steps_into_1f1b >= rank:
+                 w_stage, w_mb = pending_weights.popleft()
+                 rank_actions.append(
+                     BackwardWeightComputeAction(stage_idx=w_stage, microbatch_idx=w_mb)
+                 )
+
+         # -- Execution Phase Math --
+
+         warmup_ops = self._get_warmup_ops(
+             rank, microbatches_per_round, pp_size, n_microbatches, multiply_factor
+         )
+         total_microbatch_ops = self._num_stages_per_rank * n_microbatches
+         fwd_bwd_ops = total_microbatch_ops - warmup_ops
+         cooldown_ops = total_microbatch_ops - fwd_bwd_ops
+
+         # Combine into one sequence for iteration, but handle logic per phase
+         total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
+
+         # -- Main Schedule Loop --
+
+         for op in range(total_ops):
+
+             # Phase 1: Warmup (Forward Only)
+             if op < warmup_ops:
+                 emit_forward(op)
+
+             # Phase 2: Steady State (1F1B)
+             elif op < warmup_ops + fwd_bwd_ops:
+                 emit_forward(op)
+                 emit_backward(op, warmup_offset=warmup_ops)
+                 try_emit_weight_zb(op, warmup_offset=warmup_ops)
+
+             # Phase 3: Cooldown (Backward Only)
+             else:
+                 emit_backward(op, warmup_offset=warmup_ops)
+                 try_emit_weight_zb(op, warmup_offset=warmup_ops)
+
+         # -- Post-Loop: Flush Remaining Weights (ZB only) --
+         while pending_weights:
+             w_stage, w_mb = pending_weights.popleft()
+             rank_actions.append(
+                 BackwardWeightComputeAction(stage_idx=w_stage, microbatch_idx=w_mb)
+             )
+
+         return rank_actions
+
+     @property
+     def num_stages_per_rank(self) -> int:
+         return self._num_stages_per_rank
+
+     @property
+     def topology_style(self) -> ScheduleStyle:
+         return ScheduleStyle.loop
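
To make the warmup sizing concrete, the arithmetic in `_get_warmup_ops` can be traced by hand. A short sketch reproducing the formula outside the class, with illustrative values (pp_size=4, 2 stages per rank, 8 microbatches, hence microbatches_per_round=4):

    # Reproduces the warmup formula for both variants: factor 2 (standard 1F1B)
    # and factor 1 (zero bubble). Earlier ranks warm up longer; ZB warms up less.
    pp_size, num_stages_per_rank, n_microbatches = 4, 2, 8
    microbatches_per_round = 4  # n_microbatches // num_rounds, with num_rounds = 2

    for multiply_factor, label in ((2, "standard 1F1B"), (1, "zero bubble")):
        warmups = [
            min(
                (num_stages_per_rank - 1) * microbatches_per_round
                + multiply_factor * ((pp_size - 1) - rank),
                n_microbatches * num_stages_per_rank,
            )
            for rank in range(pp_size)
        ]
        print(label, warmups)  # [10, 8, 6, 4] for standard, [7, 6, 5, 4] for ZB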