d9d-0.1.0-py3-none-any.whl
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/pipelining/infra/schedule/program/bfs.py
@@ -0,0 +1,86 @@
from ..component.program import (
    PipelineProgramBuilder,
    ScheduleStyle,
    add_communication_ops,
    build_stage_to_host_rank_topology,
)
from ..component.runtime import (
    ActionBase,
    BackwardFullInputComputeAction,
    ForwardComputeAction,
)


class LoopedBFSPipelineProgramBuilder(PipelineProgramBuilder):
    """
    Builder for the Breadth-First Pipeline Parallelism schedule.

    This schedule runs all available forward microbatches for local stages first.
    If configured for training, it then runs backwards in reverse topological order.

    References:
        https://arxiv.org/pdf/2211.05953
    """

    def __init__(self, num_stages_per_rank: int, inference_mode: bool = False):
        """
        Constructs the LoopedBFS builder.

        Args:
            num_stages_per_rank: Number of stages per rank.
            inference_mode: If True, only forward passes are scheduled. If False,
                both forward and backward passes are scheduled.
        """
        self._num_stages_per_rank = num_stages_per_rank
        self._inference_mode = inference_mode

    def compose(self, num_microbatches: int, pp_size: int) -> dict[int, list[ActionBase]]:
        num_stages = self._num_stages_per_rank * pp_size
        stage_to_rank = build_stage_to_host_rank_topology(
            pp_size=pp_size,
            num_stages=num_stages,
            style=ScheduleStyle.loop
        )

        compute_actions: dict[int, list[ActionBase]] = {r: [] for r in range(pp_size)}

        for rank in range(pp_size):
            my_stages = [s for s in range(num_stages) if stage_to_rank[s] == rank]

            # Schedule all Forwards
            # In Breadth-First loops, we finish all microbatches for the current stage
            # before moving to the next stage assigned to this rank.
            for stage_idx in my_stages:
                for mb_idx in range(num_microbatches):
                    compute_actions[rank].append(
                        ForwardComputeAction(
                            stage_idx=stage_idx,
                            microbatch_idx=mb_idx
                        )
                    )

            # Schedule all Backwards (reverse order) - only if training
            if not self._inference_mode:
                for stage_idx in reversed(my_stages):
                    for mb_idx in reversed(range(num_microbatches)):
                        compute_actions[rank].append(
                            BackwardFullInputComputeAction(
                                stage_idx=stage_idx,
                                microbatch_idx=mb_idx,
                                full_backward=True
                            )
                        )

        return add_communication_ops(
            compute_actions=compute_actions,
            stage_to_rank=stage_to_rank,
            num_stages=num_stages
        )

    @property
    def num_stages_per_rank(self) -> int:
        return self._num_stages_per_rank

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.loop
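For orientation, a minimal usage sketch of the BFS builder follows. The import path and the example pp_size / microbatch values are assumptions not shown in this diff; the constructor arguments, the compose() signature, and the loop-topology behaviour are taken from the code above.

# Hypothetical usage sketch; the import path is assumed from the file listing.
from d9d.pipelining.infra.schedule.program.bfs import LoopedBFSPipelineProgramBuilder

builder = LoopedBFSPipelineProgramBuilder(num_stages_per_rank=2)

# With pp_size=2 there are 4 stages; under the loop topology rank 0 hosts stages 0 and 2.
# Each rank schedules every forward microbatch for its stages first, then all backwards
# in reverse order, and add_communication_ops injects the send/recv actions.
program = builder.compose(num_microbatches=4, pp_size=2)
for rank, actions in program.items():
    print(rank, [type(a).__name__ for a in actions])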
d9d/pipelining/infra/schedule/program/dualpipev.py
@@ -0,0 +1,234 @@
from collections import deque

from ..component.program import (
    PipelineProgramBuilder,
    ScheduleStyle,
    add_communication_ops,
    build_stage_to_host_rank_topology,
)
from ..component.runtime import (
    ActionBase,
    BackwardFullInputComputeAction,
    BackwardWeightComputeAction,
    ComposeAction,
    ForwardComputeAction,
)


class DualPipeVPipelineProgramBuilder(PipelineProgramBuilder):
    """
    Builder for the DualPipeV Pipeline Parallelism schedule.

    DualPipeV is a specialized bi-directional pipeline schedule designed for high
    throughput training. It requires exactly 2 stages per pipeline rank (V-shape)
    and utilizes split backward passes (Input gradients vs Weight gradients)
    to fill pipeline bubbles.

    References:
        https://github.com/deepseek-ai/DualPipe
        https://hackmd.io/@ufotalent/r1lVXsa9Jg
    """

    def __init__(self):
        """
        Constructs the DualPipeV builder.
        """

    @staticmethod
    def _build_for_rank(  # noqa: C901
        rank: int, stage_to_rank: dict[int, int], num_microbatches: int, pp_size: int
    ) -> list[ActionBase]:
        compute_actions: list[ActionBase] = []

        # Identify local stages: s0 is Phase 0, s1 is Phase 1
        my_stages = sorted([s for s, r in stage_to_rank.items() if r == rank])
        s0, s1 = my_stages[0], my_stages[1]

        # Track microbatch indices for each stage and operation type
        # f_idx: Next Forward microbatch
        # b_idx: Next Backward microbatch (Input or Full)
        f_idx = {s0: 0, s1: 0}
        b_idx = {s0: 0, s1: 0}

        # Queue for Zero Bubble optimization: stores (stage, mb_idx) for deferred weight grads
        weight_queue: deque[tuple[int, int]] = deque()

        # --- Helper Functions for Action Emission ---

        def _add_f(stage: int):
            compute_actions.append(
                ForwardComputeAction(stage_idx=stage, microbatch_idx=f_idx[stage])
            )
            f_idx[stage] += 1

        def _add_b_full(stage: int):
            compute_actions.append(
                BackwardFullInputComputeAction(
                    stage_idx=stage,
                    microbatch_idx=b_idx[stage],
                    full_backward=True,
                )
            )
            b_idx[stage] += 1

        def _add_b_input(stage: int):
            mb = b_idx[stage]
            compute_actions.append(
                BackwardFullInputComputeAction(
                    stage_idx=stage,
                    microbatch_idx=mb,
                    full_backward=False,
                )
            )
            weight_queue.append((stage, mb))
            b_idx[stage] += 1

        def _pop_w():
            if not weight_queue:
                return
            s, mb = weight_queue.popleft()
            compute_actions.append(
                BackwardWeightComputeAction(stage_idx=s, microbatch_idx=mb)
            )

        def _add_overlap_f_b(stage_f: int, stage_b: int, b_is_full: bool):
            """Emit overlapped Forward and Backward actions."""
            mb_f = f_idx[stage_f]
            mb_b = b_idx[stage_b]

            act_f = ForwardComputeAction(stage_idx=stage_f, microbatch_idx=mb_f)

            act_b = BackwardFullInputComputeAction(
                stage_idx=stage_b, microbatch_idx=mb_b, full_backward=b_is_full
            )
            if not b_is_full:
                weight_queue.append((stage_b, mb_b))

            f_idx[stage_f] += 1
            b_idx[stage_b] += 1

            # Note: d9d infra treats ComposeAction sequentially in simulation,
            # but runtime may overlap them.
            compute_actions.append(ComposeAction(actions=(act_f, act_b)))

        # Step 1: nF0 (Startup Phase 0)
        step_1 = (pp_size - rank - 1) * 2
        for _ in range(step_1):
            _add_f(s0)

        # Step 2: nF0F1 (Forward fill)
        step_2 = rank + 1
        for _ in range(step_2):
            _add_f(s0)
            _add_f(s1)

        # Step 3: nI1W1F1 (Mixed Phase with Zero Bubble)
        step_3 = pp_size - rank - 1
        for _ in range(step_3):
            _add_b_input(s1)  # Backward Input Phase 1
            _pop_w()          # Weight Phase (accumulated from prev)
            _add_f(s1)        # Forward Phase 1

        # Step 4: The Main Loop (Interleaved Forward/Backward)
        step_4 = num_microbatches - 2 * pp_size + rank + 1
        for i in range(step_4):
            # Sub-step A: F0 & B1
            if i == 0 and rank == pp_size - 1:
                # Specific case for last rank on first iter: do not overlap
                _add_f(s0)
                _add_b_full(s1)
            else:
                # Overlap F0 and B1 (usually full backward unless we were in ZB mode,
                # but DualPipeV main loop defaults to full for simplicity unless tuned)
                # DeepSeek impl uses standard backward here (zb=False).
                _add_overlap_f_b(stage_f=s0, stage_b=s1, b_is_full=True)

            # Sub-step B: F1 & B0
            # Overlap F1 and B0 (Full)
            _add_overlap_f_b(stage_f=s1, stage_b=s0, b_is_full=True)

        # Step 5: Cooldown F1/B0
        step_5 = pp_size - rank - 1
        for _ in range(step_5):
            _add_b_full(s1)
            _add_overlap_f_b(stage_f=s1, stage_b=s0, b_is_full=True)

        # Step 6: Cooldown B1/B0 with Zero Bubble ramp-up
        step_6 = rank + 1
        enable_zb = False
        for i in range(step_6):
            # Phase 1 Backward
            if i == step_6 // 2 and rank % 2 == 1:
                enable_zb = True

            if enable_zb:
                _add_b_input(s1)
            else:
                _add_b_full(s1)

            # Phase 0 Backward
            if i == step_6 // 2 and rank % 2 == 0:
                enable_zb = True

            if enable_zb:
                _add_b_input(s0)
            else:
                _add_b_full(s0)

        # Step 7: Zero Bubble Weights + B0
        step_7 = pp_size - rank - 1
        for _ in range(step_7):
            _pop_w()
            # DeepSeek source explicitly uses enable_zb=True here for chunk 0
            _add_b_input(s0)

        # Step 8: Flush Weights
        step_8 = rank + 1
        for _ in range(step_8):
            _pop_w()

        return compute_actions

    def compose(
        self, num_microbatches: int, pp_size: int
    ) -> dict[int, list[ActionBase]]:
        num_stages = self.num_stages_per_rank * pp_size

        if num_microbatches < num_stages:
            raise ValueError(
                f"DualPipeV requires num_microbatches ({num_microbatches}) >= "
                f"num_stages ({num_stages})."
            )

        # Ranks hold stages in a V pattern (e.g., Rank 0 holds Stage 0 and Stage N-1).
        # We rely on the sorted order of local stages to determine Phase 0 (Forward-going)
        # and Phase 1 (Backward-coming).
        stage_to_rank = build_stage_to_host_rank_topology(
            pp_size=pp_size, num_stages=num_stages, style=ScheduleStyle.v
        )

        compute_actions: dict[int, list[ActionBase]] = {r: [] for r in range(pp_size)}

        for rank in range(pp_size):
            compute_actions[rank] = self._build_for_rank(
                rank=rank,
                pp_size=pp_size,
                num_microbatches=num_microbatches,
                stage_to_rank=stage_to_rank
            )

        # Inject Communication Operations
        # This wrapper handles dependency analysis and inserts Send/Recv/Wait ops.
        return add_communication_ops(
            compute_actions=compute_actions,
            stage_to_rank=stage_to_rank,
            num_stages=num_stages
        )

    @property
    def num_stages_per_rank(self) -> int:
        return 2

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.v
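To make the eight-step structure concrete, here is a small worked example of the per-rank step counts. The formulas are copied from _build_for_rank above; the specific pp_size, rank, and microbatch values are illustrative assumptions only.

# Worked example of the DualPipeV per-rank step counts (formulas copied from
# _build_for_rank above); the chosen pp_size / rank / microbatch values are illustrative.
pp_size, rank, num_microbatches = 4, 1, 8

step_1 = (pp_size - rank - 1) * 2                    # 4  startup forwards on the phase-0 stage
step_2 = rank + 1                                    # 2  (F0, F1) forward-fill pairs
step_3 = pp_size - rank - 1                          # 2  (I1, W, F1) mixed-phase triples
step_4 = num_microbatches - 2 * pp_size + rank + 1   # 2  main-loop iterations
step_5 = pp_size - rank - 1                          # 2  cooldown (B1 + overlapped F1/B0)
step_6 = rank + 1                                    # 2  cooldown B1/B0 with ZB ramp-up
step_7 = pp_size - rank - 1                          # 2  weight pops + B0 input gradients
step_8 = rank + 1                                    # 2  final weight flushes
print(step_1, step_2, step_3, step_4, step_5, step_6, step_7, step_8)

Note that the forward counts balance: the phase-0 stage sees step_1 + step_2 + step_4 = 8 forwards and the phase-1 stage sees step_2 + step_3 + step_4 + step_5 = 8, matching num_microbatches.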
d9d/pipelining/infra/schedule/program/interleaved.py
@@ -0,0 +1,240 @@
from collections import defaultdict, deque

from ..component.program import (
    PipelineProgramBuilder,
    ScheduleStyle,
    add_communication_ops,
    build_stage_to_host_rank_topology,
)
from ..component.runtime import (
    ActionBase,
    BackwardFullInputComputeAction,
    BackwardWeightComputeAction,
    ForwardComputeAction,
)


class Interleaved1F1BPipelineProgramBuilder(PipelineProgramBuilder):
    """
    Builder for Interleaved Pipeline Parallelism schedules.

    This builder supports:

    1. **Standard Interleaved 1F1B**: Assigns multiple stages per rank and prioritizes
       depth-first execution. (See https://arxiv.org/pdf/2104.04473)
    2. **Interleaved Zero Bubble (ZB1P)**: Extends 1F1B by splitting backward passes
       into Input Gradients and Weight Gradients. Weight gradients are delayed
       to fill pipeline bubbles. (See https://arxiv.org/pdf/2401.10241)
    """

    def __init__(self, num_stages_per_rank: int, enable_zero_bubble: bool = False):
        """
        Constructs the Interleaved 1F1B builder.

        Args:
            num_stages_per_rank: Number of stages per rank.
            enable_zero_bubble: If True, uses the ZB1P schedule variant which
                splits backward passes to reduce bubble size.
        """
        self._num_stages_per_rank = num_stages_per_rank
        self._enable_zero_bubble = enable_zero_bubble

    def _get_warmup_ops(
        self,
        rank: int,
        microbatches_per_round: int,
        pp_size: int,
        n_microbatches: int,
        multiply_factor: int,
    ) -> int:
        """
        Calculates the number of warmup steps required before entering steady state.
        """
        warmups_ops_last_stage = (self._num_stages_per_rank - 1) * microbatches_per_round
        warmup_ops = warmups_ops_last_stage + multiply_factor * ((pp_size - 1) - rank)
        return min(warmup_ops, n_microbatches * self._num_stages_per_rank)

    def compose(
        self, num_microbatches: int, pp_size: int
    ) -> dict[int, list[ActionBase]]:
        """
        Generates the execution program for all ranks.

        Args:
            num_microbatches: Total microbatches. Must be divisible by the derived
                number of rounds.
            pp_size: Number of pipeline ranks.

        Returns:
            A dictionary mapping rank indices to their list of sequential actions.
        """
        num_stages = self.num_stages_per_rank * pp_size

        if num_stages % pp_size != 0:
            raise ValueError(
                f"num_stages ({num_stages}) must be divisible by pp_size ({pp_size}) "
                "for interleaved schedules."
            )

        # 1. Topology Setup
        # Use Loop/Round-Robin assignment: Rank 0 gets Stage 0, PP, 2*PP...
        stage_to_rank = build_stage_to_host_rank_topology(
            pp_size=pp_size, num_stages=num_stages, style=ScheduleStyle.loop
        )

        num_rounds = max(1, num_microbatches // pp_size)

        if num_microbatches % num_rounds != 0:
            raise ValueError(
                f"microbatches ({num_microbatches}) must be divisible by rounds ({num_rounds})."
            )

        microbatches_per_round = num_microbatches // num_rounds

        # 2. Schedule Generation
        actions: dict[int, list[ActionBase]] = {}

        # Zero Bubble 1f1b uses a shorter warmup heuristic (factor 1) than Standard (factor 2)
        warmup_multiplier = 1 if self._enable_zero_bubble else 2

        for rank in range(pp_size):
            actions[rank] = self._generate_rank_schedule(
                rank=rank,
                pp_size=pp_size,
                n_microbatches=num_microbatches,
                microbatches_per_round=microbatches_per_round,
                multiply_factor=warmup_multiplier,
            )

        # 3. Communication Injection
        return add_communication_ops(
            compute_actions=actions,
            stage_to_rank=stage_to_rank,
            num_stages=num_stages,
        )

    def _generate_rank_schedule(  # noqa: C901
        self,
        rank: int,
        pp_size: int,
        n_microbatches: int,
        microbatches_per_round: int,
        multiply_factor: int,
    ) -> list[ActionBase]:
        """
        Generates the sequential list of compute actions for a specific rank.
        """
        rank_actions: list[ActionBase] = []

        # -- State Tracking --
        # Map: stage_idx -> next_microbatch_idx
        fwd_counters: dict[int, int] = defaultdict(int)
        bwd_counters: dict[int, int] = defaultdict(int)

        # FIFO Queue for deferred weight gradients in Zero Bubble
        # Stores: (stage_idx, microbatch_idx)
        pending_weights: deque[tuple[int, int]] = deque()

        # -- Helpers --

        def get_global_stage(local_idx: int) -> int:
            """Converts a local virtual stage index (0..N) to global stage ID."""
            return (local_idx * pp_size) + rank

        def get_fwd_local_idx(op_idx: int) -> int:
            return (op_idx // microbatches_per_round) % self._num_stages_per_rank

        def get_bwd_local_idx(op_idx: int, warmup_offset: int) -> int:
            return (self._num_stages_per_rank
                    - 1
                    - ((op_idx - warmup_offset) // microbatches_per_round) % self._num_stages_per_rank)

        def emit_forward(op_idx: int):
            local_idx = get_fwd_local_idx(op_idx)
            stage = get_global_stage(local_idx)
            mb = fwd_counters[stage]

            rank_actions.append(ForwardComputeAction(stage_idx=stage, microbatch_idx=mb))
            fwd_counters[stage] += 1

        def emit_backward(op_idx: int, warmup_offset: int):
            local_idx = get_bwd_local_idx(op_idx, warmup_offset)
            stage = get_global_stage(local_idx)
            mb = bwd_counters[stage]

            # In Zero Bubble, we split: Backward Input (now) + Backward Weight (later)
            # In Standard 1F1B, we do full backward now.
            is_full = not self._enable_zero_bubble

            rank_actions.append(
                BackwardFullInputComputeAction(
                    stage_idx=stage,
                    microbatch_idx=mb,
                    full_backward=is_full
                )
            )

            if self._enable_zero_bubble:
                pending_weights.append((stage, mb))

            bwd_counters[stage] += 1

        def try_emit_weight_zb(op_idx: int, warmup_offset: int):
            if not self._enable_zero_bubble or not pending_weights:
                return

            steps_into_1f1b = op_idx - warmup_offset
            # The earliest reasonable time to start weaving in weights is proportional to rank depth
            if steps_into_1f1b >= rank:
                w_stage, w_mb = pending_weights.popleft()
                rank_actions.append(
                    BackwardWeightComputeAction(stage_idx=w_stage, microbatch_idx=w_mb)
                )

        # -- Execution Phase Math --

        warmup_ops = self._get_warmup_ops(
            rank, microbatches_per_round, pp_size, n_microbatches, multiply_factor
        )
        total_microbatch_ops = self._num_stages_per_rank * n_microbatches
        fwd_bwd_ops = total_microbatch_ops - warmup_ops
        cooldown_ops = total_microbatch_ops - fwd_bwd_ops

        # Combine into one sequence for iteration, but handle logic per phase
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops

        # -- Main Schedule Loop --

        for op in range(total_ops):

            # Phase 1: Warmup (Forward Only)
            if op < warmup_ops:
                emit_forward(op)

            # Phase 2: Steady State (1F1B)
            elif op < warmup_ops + fwd_bwd_ops:
                emit_forward(op)
                emit_backward(op, warmup_offset=warmup_ops)
                try_emit_weight_zb(op, warmup_offset=warmup_ops)

            # Phase 3: Cooldown (Backward Only)
            else:
                emit_backward(op, warmup_offset=warmup_ops)
                try_emit_weight_zb(op, warmup_offset=warmup_ops)

        # -- Post-Loop: Flush Remaining Weights (ZB only) --
        while pending_weights:
            w_stage, w_mb = pending_weights.popleft()
            rank_actions.append(
                BackwardWeightComputeAction(stage_idx=w_stage, microbatch_idx=w_mb)
            )

        return rank_actions

    @property
    def num_stages_per_rank(self) -> int:
        return self._num_stages_per_rank

    @property
    def topology_style(self) -> ScheduleStyle:
        return ScheduleStyle.loop