d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0

d9d/pipelining/factory/config.py
@@ -0,0 +1,89 @@
+ from typing import Annotated, Literal
+
+ from pydantic import BaseModel, Field
+
+
+ class PipelineScheduleInferenceConfig(BaseModel):
+     """
+     Configuration for inference-only pipeline execution.
+
+     This schedule runs all forward passes sequentially without any backward passes.
+     """
+
+     schedule: Literal["inference"] = "inference"
+
+
+ class PipelineScheduleGPipeConfig(BaseModel):
+     """
+     Configuration for GPipe execution.
+
+     This assumes a single stage per rank and processes all microbatches for the
+     forward pass before switching to the backward pass.
+     """
+
+     schedule: Literal["gpipe"] = "gpipe"
+
+
+ class PipelineScheduleLoopedBFSConfig(BaseModel):
+     """
+     Configuration for Looped Breadth-First Search execution.
+
+     Similar to GPipe, but supports multiple stages per rank (virtualization).
+     It executes all available work for a specific stage before moving to the next.
+     """
+
+     schedule: Literal["looped_bfs"] = "looped_bfs"
+
+     num_stages_per_rank: int
+
+
+ class PipelineSchedule1F1BConfig(BaseModel):
+     """
+     Configuration for Interleaved 1F1B and Interleaved Zero Bubble execution.
+
+     Supports assigning multiple stages per rank and sharding backward to dI and dW
+     to reduce pipeline bubbles.
+     """
+
+     schedule: Literal["1f1b"] = "1f1b"
+
+     num_stages_per_rank: int
+     zero_bubble: bool
+
+
+ class PipelineScheduleZeroBubbleVConfig(BaseModel):
+     """
+     Configuration for Zero Bubble V (ZBV) execution.
+
+     A specialized V-shape topology schedule that splits backward passes into
+     Input and Weight gradients to maximize overlap. Requires exactly 2 stages per rank.
+     """
+     schedule: Literal["zero_bubble_v"] = "zero_bubble_v"
+
+
+ class PipelineScheduleDualPipeVConfig(BaseModel):
+     """
+     Configuration for DualPipeV execution.
+
+     A bidirectional pipeline schedule for high-throughput training, utilizing
+     V-shape topology and reciprocal forward/backward scheduling.
+     """
+
+     schedule: Literal["dual_pipe_v"] = "dual_pipe_v"
+
+
+ AnyPipelineScheduleConfig = Annotated[
+     PipelineScheduleInferenceConfig |
+     PipelineScheduleGPipeConfig |
+     PipelineScheduleLoopedBFSConfig |
+     PipelineSchedule1F1BConfig |
+     PipelineScheduleZeroBubbleVConfig |
+     PipelineScheduleDualPipeVConfig,
+     Field(discriminator="schedule")
+ ]
+ """Union of all supported pipeline schedule configuration types.
+
+ This type alias uses a Pydantic discriminator on the ``schedule`` field to allow
+ polymorphic validation and serialization of specific schedule configs (e.g.
+ Inference, GPipe, 1F1B, ZeroBubble, etc.).
+ """

d9d/pipelining/factory/factory.py
@@ -0,0 +1,114 @@
+ import dataclasses
+ from collections.abc import Callable
+
+ import torch
+ from torch import nn
+
+ from ...core.dist_context import REGULAR_DOMAIN, DistributedContext
+ from ..api import PipelineSchedule, PipelineStageInfo
+ from ..infra.schedule.component.program import (
+     build_stage_to_host_rank_topology,
+     invert_stage_to_host_rank_topology,
+ )
+ from ..infra.schedule.component.runtime import PipelineScheduleExecutor
+ from ..infra.stage import PipelineStage
+ from .config import (
+     AnyPipelineScheduleConfig,
+ )
+ from .registry import PIPELINE_PROGRAM_REGISTRY
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class PipelineScheduleInfo:
+     """Contains the built pipeline schedule and rank-specific metadata."""
+
+     schedule: PipelineSchedule
+     has_first_stage: bool
+     has_last_stage: bool
+
+
+ def build_schedule(
+     dist_context: DistributedContext,
+     n_microbatches: int,
+     schedule_config: AnyPipelineScheduleConfig,
+     model_provider: Callable[[PipelineStageInfo], nn.Module],
+     loss_fn: Callable[[dict[str, torch.Tensor], int], torch.Tensor] | None,
+ ) -> tuple[PipelineScheduleInfo, list[nn.Module]]:
+     """
+     Constructs the pipeline schedule and instantiates model stages.
+
+     This function coordinates the creation of the distributed pipeline. It:
+     1. Selects the appropriate `PipelineProgramBuilder` based on the config.
+     2. Calculates the global stage topology mapping stages to ranks.
+     3. Instantiates the local model stages for the current rank using `model_provider`.
+     4. Wraps models in `PipelineStage` containers.
+     5. Generates the execution program (action list).
+     6. Builds the runtime executor.
+
+     Args:
+         dist_context: The distributed context.
+         n_microbatches: Number of microbatches per global step.
+         schedule_config: Configuration object determining the schedule strategy.
+         model_provider: A factory function that accepts stage info and returns an `nn.Module`
+             for that specific stage.
+         loss_fn: Optional loss function. Required if training (backward pass needed).
+
+     Returns:
+         A tuple containing:
+             1. `PipelineScheduleInfo`: The executable schedule and metadata.
+             2. `list[nn.Module]`: The local PyTorch modules created for this rank.
+     """
+
+     program_builder = PIPELINE_PROGRAM_REGISTRY.program_for(schedule_config)
+     mesh = dist_context.mesh_for(REGULAR_DOMAIN)["pp"]
+
+     num_stages = program_builder.num_stages_per_rank * mesh.size()
+
+     stage_to_host = build_stage_to_host_rank_topology(
+         num_stages=num_stages,
+         pp_size=mesh.size(),
+         style=program_builder.topology_style
+     )
+     host_to_stage = invert_stage_to_host_rank_topology(stage_to_host)
+     this_rank_stages = host_to_stage[mesh.get_local_rank()]
+
+     stages = []
+     modules = []
+     has_first_stage = False
+     has_last_stage = False
+
+     for stage_idx in this_rank_stages:
+         stage_info = PipelineStageInfo(
+             num_stages=num_stages,
+             current_stage=stage_idx
+         )
+
+         if stage_info.is_current_stage_first:
+             has_first_stage = True
+         if stage_info.is_current_stage_last:
+             has_last_stage = True
+
+         model = model_provider(stage_info)
+         modules.append(model)
+         stage = PipelineStage(
+             info=stage_info,
+             module=model,
+             group=mesh.get_group(),
+             stage_to_host_topology=stage_to_host
+         )
+         stages.append(stage)
+
+     program = program_builder.compose(num_microbatches=n_microbatches, pp_size=mesh.size())
+     schedule = PipelineScheduleExecutor(
+         dist_context=dist_context,
+         stages=stages,
+         num_microbatches=n_microbatches,
+         loss_fn=loss_fn,
+         program=program
+     )
+
+     return PipelineScheduleInfo(
+         schedule=schedule,
+         has_first_stage=has_first_stage,
+         has_last_stage=has_last_stage
+     ), modules
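
The model_provider argument is the only model-specific hook in build_schedule: it receives a PipelineStageInfo and must return the nn.Module for that stage. A minimal sketch of such a factory, assuming PipelineStageInfo is importable from d9d.pipelining.api (as the relative import above suggests); the toy layer sizes are made up for illustration:

from torch import nn

from d9d.pipelining.api import PipelineStageInfo


def model_provider(stage_info: PipelineStageInfo) -> nn.Module:
    layers: list[nn.Module] = []
    if stage_info.is_current_stage_first:
        layers.append(nn.Embedding(1000, 64))  # hypothetical embedding owned by the first stage
    layers.append(nn.Linear(64, 64))           # body layer present on every stage
    if stage_info.is_current_stage_last:
        layers.append(nn.Linear(64, 1000))     # hypothetical head owned by the last stage
    return nn.Sequential(*layers)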

d9d/pipelining/factory/registry.py
@@ -0,0 +1,82 @@
+ from collections.abc import Callable
+ from typing import TypeVar, cast
+
+ from d9d.pipelining.factory import (
+     AnyPipelineScheduleConfig,
+     PipelineSchedule1F1BConfig,
+     PipelineScheduleDualPipeVConfig,
+     PipelineScheduleGPipeConfig,
+     PipelineScheduleInferenceConfig,
+     PipelineScheduleLoopedBFSConfig,
+     PipelineScheduleZeroBubbleVConfig,
+ )
+ from d9d.pipelining.infra.schedule.component.program import PipelineProgramBuilder
+ from d9d.pipelining.infra.schedule.program import (
+     DualPipeVPipelineProgramBuilder,
+     Interleaved1F1BPipelineProgramBuilder,
+     LoopedBFSPipelineProgramBuilder,
+     ZeroBubbleVPipelineProgramBuilder,
+ )
+
+ TConfig = TypeVar("TConfig", bound=AnyPipelineScheduleConfig)
+
+ TRegistryDict = dict[
+     type[AnyPipelineScheduleConfig],
+     Callable[[AnyPipelineScheduleConfig], PipelineProgramBuilder]
+ ]
+
+ TBoundRegistryFn = Callable[[TConfig], PipelineProgramBuilder]
+
+
+ class PipelineProgramRegistry:
+     def __init__(self) -> None:
+         self._registry: TRegistryDict = {}
+
+     def register_program(
+         self, config_cls: type[TConfig]
+     ) -> Callable[[TBoundRegistryFn], TBoundRegistryFn]:
+         def decorator(func: TBoundRegistryFn) -> TBoundRegistryFn:
+             config_cls_any = cast(type[AnyPipelineScheduleConfig], config_cls)
+             self._registry[config_cls_any] = func
+             return func
+
+         return decorator
+
+     def program_for(self, config: AnyPipelineScheduleConfig) -> PipelineProgramBuilder:
+         program_fn = self._registry[type(config)]
+         program = program_fn(config)
+         return program
+
+
+ PIPELINE_PROGRAM_REGISTRY = PipelineProgramRegistry()
+
+
+ @PIPELINE_PROGRAM_REGISTRY.register_program(PipelineScheduleGPipeConfig)
+ def _build_gpipe(_: PipelineScheduleGPipeConfig) -> PipelineProgramBuilder:
+     return LoopedBFSPipelineProgramBuilder(num_stages_per_rank=1, inference_mode=False)
+
+
+ @PIPELINE_PROGRAM_REGISTRY.register_program(PipelineScheduleInferenceConfig)
+ def _build_inference(_: PipelineScheduleInferenceConfig) -> PipelineProgramBuilder:
+     return LoopedBFSPipelineProgramBuilder(num_stages_per_rank=1, inference_mode=True)
+
+
+ @PIPELINE_PROGRAM_REGISTRY.register_program(PipelineScheduleLoopedBFSConfig)
+ def _build_looped_bfs(cfg: PipelineScheduleLoopedBFSConfig) -> PipelineProgramBuilder:
+     return LoopedBFSPipelineProgramBuilder(num_stages_per_rank=cfg.num_stages_per_rank, inference_mode=False)
+
+
+ @PIPELINE_PROGRAM_REGISTRY.register_program(PipelineSchedule1F1BConfig)
+ def _build_1f1b(cfg: PipelineSchedule1F1BConfig) -> PipelineProgramBuilder:
+     return Interleaved1F1BPipelineProgramBuilder(num_stages_per_rank=cfg.num_stages_per_rank,
+                                                  enable_zero_bubble=cfg.zero_bubble)
+
+
+ @PIPELINE_PROGRAM_REGISTRY.register_program(PipelineScheduleDualPipeVConfig)
+ def _build_dual_pipe_v(_: PipelineScheduleDualPipeVConfig) -> PipelineProgramBuilder:
+     return DualPipeVPipelineProgramBuilder()
+
+
+ @PIPELINE_PROGRAM_REGISTRY.register_program(PipelineScheduleZeroBubbleVConfig)
+ def _build_zero_bubble_v(_: PipelineScheduleZeroBubbleVConfig) -> PipelineProgramBuilder:
+     return ZeroBubbleVPipelineProgramBuilder()
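
Schedule selection is therefore a plain dictionary lookup keyed by the config type. A short usage sketch, using only names defined in this file and module paths as they appear in the file list above:

from d9d.pipelining.factory import PipelineSchedule1F1BConfig
from d9d.pipelining.factory.registry import PIPELINE_PROGRAM_REGISTRY

config = PipelineSchedule1F1BConfig(num_stages_per_rank=2, zero_bubble=True)
builder = PIPELINE_PROGRAM_REGISTRY.program_for(config)
# builder is an Interleaved1F1BPipelineProgramBuilder configured from the config fields.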
File without changes: d9d/pipelining/infra/__init__.py
File without changes: d9d/pipelining/infra/schedule/__init__.py
File without changes: d9d/pipelining/infra/schedule/component/__init__.py

d9d/pipelining/infra/schedule/component/program/__init__.py
@@ -0,0 +1,22 @@
+ """
+ Pipeline Schedule Building Components.
+
+ This package provides the core building blocks and compiler passes used to generate
+ execution schedules for distributed pipelines.
+ """
+
+ from .base import PipelineProgramBuilder
+ from .communications import add_communication_ops
+ from .topology import (
+     ScheduleStyle,
+     build_stage_to_host_rank_topology,
+     invert_stage_to_host_rank_topology,
+ )
+
+ __all__ = [
+     "PipelineProgramBuilder",
+     "ScheduleStyle",
+     "add_communication_ops",
+     "build_stage_to_host_rank_topology",
+     "invert_stage_to_host_rank_topology"
+ ]

d9d/pipelining/infra/schedule/component/program/base.py
@@ -0,0 +1,35 @@
+ import abc
+
+ from ..program.topology import ScheduleStyle
+ from ..runtime import ActionBase
+
+
+ class PipelineProgramBuilder(abc.ABC):
+     """Abstract interface for building pipeline execution schedules."""
+
+     @abc.abstractmethod
+     def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
+         """
+         Generates the execution program for all ranks in the pipeline.
+
+         Args:
+             num_microbatches: Number of microbatches per step.
+             pp_size: Number of pipeline parallel ranks.
+
+         Returns:
+             A dictionary mapping rank indices to their list of sequential actions.
+         """
+         ...
+
+     @property
+     @abc.abstractmethod
+     def num_stages_per_rank(self) -> int:
+         """Returns the number of model stages designated for each rank."""
+
+         ...
+
+     @property
+     @abc.abstractmethod
+     def topology_style(self) -> ScheduleStyle:
+         """Returns the topology style strategy used to assign stages to ranks."""
+         ...
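
A hypothetical forward-only builder may help make the contract concrete; it emits one ForwardComputeAction per microbatch for the single stage each rank owns. This is a sketch, not a schedule shipped by the package, and it assumes the exports shown in the program and runtime __init__ files in this diff:

from d9d.pipelining.infra.schedule.component.program import PipelineProgramBuilder, ScheduleStyle
from d9d.pipelining.infra.schedule.component.runtime import ActionBase, ForwardComputeAction


class ForwardOnlyProgramBuilder(PipelineProgramBuilder):
    @property
    def num_stages_per_rank(self) -> int:
        return 1

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.loop

    def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
        # With loop topology and one stage per rank, rank r owns stage r.
        return {
            rank: [ForwardComputeAction(rank, mb) for mb in range(num_microbatches)]
            for rank in range(pp_size)
        }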

d9d/pipelining/infra/schedule/component/program/communications.py
@@ -0,0 +1,203 @@
+ import copy
+ import dataclasses
+
+ from ..runtime.action import (
+     ActionBase,
+     BackwardFullInputComputeAction,
+     BackwardReceiveAction,
+     BackwardSendAction,
+     ComposeAction,
+     ForwardComputeAction,
+     ForwardReceiveAction,
+     ForwardSendAction,
+ )
+
+
+ def _get_sub_actions(action: ActionBase) -> tuple[ActionBase, ...]:
+     if isinstance(action, ComposeAction):
+         return action.actions
+     return (action,)
+
+
+ def _check_action_communication_dependencies_fulfilled(
+     action: ActionBase,
+     rank_events: set[ActionBase],
+     num_stages: int
+ ) -> bool:
+     match action:
+         case ForwardComputeAction():
+             if action.stage_idx == 0:
+                 return True
+             if ForwardReceiveAction(action.stage_idx, action.microbatch_idx) in rank_events:
+                 return True
+             if ForwardComputeAction(action.stage_idx - 1, action.microbatch_idx) in rank_events:
+                 return True
+             return False
+         case BackwardFullInputComputeAction():
+             if action.stage_idx == num_stages - 1:
+                 return True
+             if BackwardReceiveAction(action.stage_idx, action.microbatch_idx) in rank_events:
+                 return True
+
+             next_full = BackwardFullInputComputeAction(
+                 action.stage_idx + 1,
+                 action.microbatch_idx,
+                 full_backward=True
+             )
+             next_inp = BackwardFullInputComputeAction(
+                 action.stage_idx + 1,
+                 action.microbatch_idx,
+                 full_backward=False
+             )
+
+             if next_full in rank_events or next_inp in rank_events:
+                 return True
+             return False
+         case _:
+             return True
+
+
+ def check_action_communication_dependencies_fulfilled(
+     action: ActionBase,
+     rank_events: set[ActionBase],
+     num_stages: int
+ ) -> bool:
+     """
+     Checks if data dependencies (Receive or Local Compute) are met for an action.
+
+     This function determines if a compute action is allowed to run based on
+     whether its inputs are available in `rank_events`. Inputs are available
+     if they were either computed locally by a previous stage or received
+     from a remote rank.
+
+     Args:
+         action: The action to check.
+         rank_events: A set of actions already completed on this rank.
+         num_stages: Total number of stages in the pipeline.
+
+     Returns:
+         True if all dependencies are satisfied, False otherwise.
+     """
+
+     return all(
+         _check_action_communication_dependencies_fulfilled(sub, rank_events, num_stages)
+         for sub in _get_sub_actions(action)
+     )
+
+
+ @dataclasses.dataclass(kw_only=True)
+ class _CommunicationPackage:
+     send: ActionBase
+     recv: ActionBase
+     sends_to_rank: int
+
+
+ def _create_communications_for_action(
+     action: ActionBase,
+     num_stages: int,
+     stage_to_rank: dict[int, int],
+ ) -> _CommunicationPackage | None:
+     match action:
+         case ForwardComputeAction():
+             if action.stage_idx == num_stages - 1:
+                 return None
+
+             curr_rank, next_rank = stage_to_rank[action.stage_idx], stage_to_rank[action.stage_idx + 1]
+             if curr_rank == next_rank:
+                 return None
+
+             return _CommunicationPackage(
+                 send=ForwardSendAction(action.stage_idx, action.microbatch_idx),
+                 recv=ForwardReceiveAction(action.stage_idx + 1, action.microbatch_idx),
+                 sends_to_rank=next_rank
+             )
+         case BackwardFullInputComputeAction():
+             if action.stage_idx == 0:
+                 return None
+
+             curr_rank, prev_rank = stage_to_rank[action.stage_idx], stage_to_rank[action.stage_idx - 1]
+             if curr_rank == prev_rank:
+                 return None
+
+             return _CommunicationPackage(
+                 send=BackwardSendAction(action.stage_idx, action.microbatch_idx),
+                 recv=BackwardReceiveAction(action.stage_idx - 1, action.microbatch_idx),
+                 sends_to_rank=prev_rank
+             )
+         case _:
+             return None
+
+
+ def add_communication_ops(
+     compute_actions: dict[int, list[ActionBase]],
+     stage_to_rank: dict[int, int],
+     num_stages: int,
+ ) -> dict[int, list[ActionBase]]:
+     """
+     Injects communication actions into a computation-only schedule.
+
+     This function iterates through the provided compute schedule and simulates execution.
+     When a compute action produces a result needed by a different rank, it injects
+     Send/Receive pairs. It also reorders actions to ensure that Receive
+     operations occur before the Computes that depend on them, preventing deadlocks.
+
+     Args:
+         compute_actions: Initial schedule containing only compute operations.
+         stage_to_rank: Mapping from stage index to rank index.
+         num_stages: Total number of pipeline stages.
+
+     Returns:
+         A new schedule dictionary including both compute and communication actions.
+
+     Raises:
+         RuntimeError: If the schedule simulation enters a deadlock state.
+     """
+
+     compute_actions = copy.deepcopy(compute_actions)
+
+     full_actions: dict[int, list[ActionBase]] = {rank: [] for rank in compute_actions}
+     completed_events: dict[int, set[ActionBase]] = {rank: set() for rank in compute_actions}
+
+     while compute_actions:
+         progress = False
+
+         for rank in sorted(compute_actions.keys()):
+             if not compute_actions[rank]:
+                 del compute_actions[rank]
+                 continue
+
+             current_action = compute_actions[rank][0]
+             sub_actions = _get_sub_actions(current_action)
+
+             # Check readiness
+             if not check_action_communication_dependencies_fulfilled(
+                 current_action, completed_events[rank], num_stages
+             ):
+                 continue
+
+             # Execute
+             full_actions[rank].append(current_action)
+             compute_actions[rank].pop(0)
+             progress = True
+
+             for sub_action in sub_actions:
+                 completed_events[rank].add(sub_action)
+
+                 comm_pkg = _create_communications_for_action(
+                     sub_action,
+                     num_stages=num_stages,
+                     stage_to_rank=stage_to_rank
+                 )
+                 if comm_pkg:
+                     # Add Send locally
+                     full_actions[rank].append(comm_pkg.send)
+                     completed_events[rank].add(comm_pkg.send)
+
+                     # Add Recv remotely and unblock target
+                     full_actions[comm_pkg.sends_to_rank].append(comm_pkg.recv)
+                     completed_events[comm_pkg.sends_to_rank].add(comm_pkg.recv)
+
+         if not progress and compute_actions:
+             raise RuntimeError("Deadlock in schedule simulation")
+
+     return full_actions
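
A tiny worked example of this pass, using the positional action constructors exactly as this module does; the expected output follows from tracing the simulation loop by hand:

from d9d.pipelining.infra.schedule.component.program import add_communication_ops
from d9d.pipelining.infra.schedule.component.runtime import ForwardComputeAction

compute_only = {
    0: [ForwardComputeAction(0, 0)],  # rank 0 runs stage 0
    1: [ForwardComputeAction(1, 0)],  # rank 1 runs stage 1
}
full = add_communication_ops(compute_only, stage_to_rank={0: 0, 1: 1}, num_stages=2)
# Expected result:
#   rank 0: [ForwardComputeAction(0, 0), ForwardSendAction(0, 0)]
#   rank 1: [ForwardReceiveAction(1, 0), ForwardComputeAction(1, 0)]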

d9d/pipelining/infra/schedule/component/program/topology.py
@@ -0,0 +1,78 @@
+ from collections import defaultdict
+ from enum import StrEnum
+
+
+ class ScheduleStyle(StrEnum):
+     """
+     Defines the strategy for mapping logical stages to physical ranks.
+
+     Attributes:
+         loop: Assigns stages in a round-robin circular fashion (mod pp_size).
+         v: Assigns stages in a zig-zag V-shape pattern. Useful for interleaved 1F1B schedules.
+     """
+
+     loop = "loop"
+     v = "v"
+
+
+ def build_stage_to_host_rank_topology(
+     pp_size: int, num_stages: int, style: ScheduleStyle
+ ) -> dict[int, int]:
+     """
+     Constructs the mapping from stage index to rank index.
+
+     Args:
+         pp_size: Number of pipeline parallel ranks.
+         num_stages: Total number of model stages.
+         style: The topology style to use for assignment.
+
+     Returns:
+         A dictionary mapping stage IDs to Rank IDs.
+
+     Raises:
+         ValueError: If the style is unknown or if V-style parameters are invalid
+             (num_stages must be divisible by pp_size).
+     """
+
+     match style:
+         case ScheduleStyle.loop:
+             return {stage_index: stage_index % pp_size for stage_index in range(num_stages)}
+         case ScheduleStyle.v:
+             if num_stages % pp_size != 0:
+                 raise ValueError(
+                     f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules"
+                 )
+
+             result = {}
+             rank_index = 0
+             for stage_index in range(num_stages):
+                 result[stage_index] = rank_index
+                 if (stage_index + 1) % pp_size == 0:
+                     continue
+                 if (stage_index // pp_size) % 2 == 0:
+                     rank_index += 1
+                 else:
+                     rank_index -= 1
+             return result
+         case _:
+             raise ValueError()
+
+
+ def invert_stage_to_host_rank_topology(
+     stage_to_host: dict[int, int]
+ ) -> dict[int, list[int]]:
+     """
+     Inverts the topology mapping to list execution stages per rank.
+
+     Args:
+         stage_to_host: Mapping from stage index to rank index.
+
+     Returns:
+         A dictionary where keys are Rank IDs and values are lists of Stage IDs
+             managed by that rank.
+     """
+
+     host_to_stage = defaultdict(list)
+     for stage_idx, host in stage_to_host.items():
+         host_to_stage[host].append(stage_idx)
+     return dict(host_to_stage)
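
A quick worked example of the two styles for pp_size=2 and num_stages=4, derived directly from the code above:

from d9d.pipelining.infra.schedule.component.program import (
    ScheduleStyle,
    build_stage_to_host_rank_topology,
    invert_stage_to_host_rank_topology,
)

loop = build_stage_to_host_rank_topology(pp_size=2, num_stages=4, style=ScheduleStyle.loop)
# {0: 0, 1: 1, 2: 0, 3: 1}  (round-robin assignment)

v = build_stage_to_host_rank_topology(pp_size=2, num_stages=4, style=ScheduleStyle.v)
# {0: 0, 1: 1, 2: 1, 3: 0}  (rank 1 hosts the middle of the V, rank 0 its ends)

print(invert_stage_to_host_rank_topology(v))
# {0: [0, 3], 1: [1, 2]}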

d9d/pipelining/infra/schedule/component/runtime/__init__.py
@@ -0,0 +1,29 @@
+ """
+ Pipelining Runtime Package.
+ """
+
+ from .action import (
+     ActionBase,
+     BackwardFullInputComputeAction,
+     BackwardReceiveAction,
+     BackwardSendAction,
+     BackwardWeightComputeAction,
+     ComposeAction,
+     ForwardComputeAction,
+     ForwardReceiveAction,
+     ForwardSendAction,
+ )
+ from .executor import PipelineScheduleExecutor
+
+ __all__ = [
+     "ActionBase",
+     "BackwardFullInputComputeAction",
+     "BackwardReceiveAction",
+     "BackwardSendAction",
+     "BackwardWeightComputeAction",
+     "ComposeAction",
+     "ForwardComputeAction",
+     "ForwardReceiveAction",
+     "ForwardSendAction",
+     "PipelineScheduleExecutor",
+ ]