d9d 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
from ..component.program import (
|
|
2
|
+
PipelineProgramBuilder,
|
|
3
|
+
ScheduleStyle,
|
|
4
|
+
add_communication_ops,
|
|
5
|
+
build_stage_to_host_rank_topology,
|
|
6
|
+
)
|
|
7
|
+
from ..component.runtime import (
|
|
8
|
+
ActionBase,
|
|
9
|
+
BackwardFullInputComputeAction,
|
|
10
|
+
BackwardWeightComputeAction,
|
|
11
|
+
ForwardComputeAction,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ZeroBubbleVPipelineProgramBuilder(PipelineProgramBuilder):
    """
    Program builder for the Zero Bubble V (ZBV) pipeline schedule.

    Each pipeline rank hosts exactly two stages arranged in a V shape: rank
    ``r`` owns stage ``r`` on the descending arm and stage
    ``num_stages - 1 - r`` on the ascending arm. Backward work is split into
    an input-gradient (I) action and a weight-gradient (W) action, which lets
    W actions be postponed during cooldown.

    References:
        https://arxiv.org/pdf/2401.10241, Section 6
    """

    def __init__(self):
        """Creates the ZBV builder; it stores no configuration."""

    def compose(
        self, num_microbatches: int, pp_size: int
    ) -> dict[int, list[ActionBase]]:
        """
        Builds the complete per-rank program, including communication actions.

        Args:
            num_microbatches: Number of microbatches the program must process.
            pp_size: Number of pipeline ranks.

        Returns:
            Mapping from pipeline rank to its ordered list of actions, as
            returned by ``add_communication_ops``.
        """
        total_stages = pp_size * self.num_stages_per_rank

        # V placement: rank 0 hosts stage 0 and the last stage, rank 1 hosts
        # stage 1 and the second-to-last stage, and so on.
        placement = build_stage_to_host_rank_topology(
            pp_size=pp_size, num_stages=total_stages, style=ScheduleStyle.v
        )

        compute_only = {
            rank: self._generate_rank_schedule(
                rank=rank,
                pp_size=pp_size,
                num_stages=total_stages,
                target_microbatches=num_microbatches,
            )
            for rank in range(pp_size)
        }

        # Send/receive actions are layered on top of the pure compute program.
        return add_communication_ops(
            compute_actions=compute_only,
            stage_to_rank=placement,
            num_stages=total_stages,
        )

    def _generate_rank_schedule(  # noqa: C901
        self,
        rank: int,
        pp_size: int,
        num_stages: int,
        target_microbatches: int,
    ) -> list[ActionBase]:
        """
        Produces the compute-only action sequence for a single pipeline rank.

        The phase lengths assume a saturated pipeline. Therefore at least
        ``2 * pp_size - 1`` microbatches are simulated, and actions on
        microbatch indices ``>= target_microbatches`` are dropped at the end.

        Args:
            rank: Pipeline rank the sequence is generated for.
            pp_size: Number of pipeline ranks.
            num_stages: Total number of stages (two per rank).
            target_microbatches: Number of microbatches actually requested.

        Returns:
            Ordered list of forward, input-grad and weight-grad actions.

        Raises:
            RuntimeError: If, for either chunk, the numbers of emitted F, I
                and W actions disagree.
        """
        padded_microbatches = max(2 * pp_size - 1, target_microbatches)

        program: list[ActionBase] = []

        # Index 0 = descending-arm chunk, index 1 = ascending-arm chunk.
        stage_of_chunk = (rank, num_stages - 1 - rank)

        # Next microbatch index, per chunk, for Forward / Input-grad / Weight-grad.
        next_fwd = [0, 0]
        next_inp = [0, 0]
        next_wgt = [0, 0]

        def forward(chunk: int):
            program.append(
                ForwardComputeAction(
                    stage_idx=stage_of_chunk[chunk],
                    microbatch_idx=next_fwd[chunk],
                )
            )
            next_fwd[chunk] += 1

        def input_grad(chunk: int):
            program.append(
                BackwardFullInputComputeAction(
                    stage_idx=stage_of_chunk[chunk],
                    microbatch_idx=next_inp[chunk],
                    full_backward=False,
                )
            )
            next_inp[chunk] += 1

        def weight_grad(chunk: int):
            program.append(
                BackwardWeightComputeAction(
                    stage_idx=stage_of_chunk[chunk],
                    microbatch_idx=next_wgt[chunk],
                )
            )
            next_wgt[chunk] += 1

        def backward_pair(chunk: int):
            # Before the cooldown phases, I and W are only emitted through this
            # helper, so both counters agree and target the same microbatch.
            input_grad(chunk)
            weight_grad(chunk)

        # Warmup 1: a run of chunk-0 forwards.
        for _ in range(2 * (pp_size - rank) - 1):
            forward(0)

        # Warmup 2: alternate chunk-1 and chunk-0 forwards.
        for _ in range(rank):
            forward(1)
            forward(0)

        # Warmup 3: chunk-1 forward immediately followed by its I + W.
        for _ in range(pp_size - rank):
            forward(1)
            backward_pair(1)

        # Steady state: F0 (while padded microbatches remain), B0, F1, B1.
        while next_fwd[0] < padded_microbatches or next_fwd[1] < next_fwd[0]:
            if next_fwd[0] < padded_microbatches:
                forward(0)
            backward_pair(0)
            forward(1)
            backward_pair(1)

        # Cooldown 1: input grads only, alternating chunks; W is postponed.
        for _ in range(rank):
            input_grad(0)
            input_grad(1)

        # Cooldown 2: chunk-0 input grad paired with a postponed chunk-0 W.
        for _ in range(pp_size - rank):
            input_grad(0)
            weight_grad(0)

        # Drain the remaining postponed weight grads, chunk 1 before chunk 0.
        while next_wgt[1] < next_inp[1]:
            weight_grad(1)
        while next_wgt[0] < next_inp[0]:
            weight_grad(0)

        for chunk in (0, 1):
            if not (next_wgt[chunk] == next_inp[chunk] == next_fwd[chunk]):
                raise RuntimeError(
                    f"ZBV Schedule Failed (Chunk {chunk}): "
                    f"F={next_fwd[chunk]}, I={next_inp[chunk]}, W={next_wgt[chunk]}"
                )

        # Discard actions on padding microbatches the caller did not ask for.
        compute_types = (
            ForwardComputeAction,
            BackwardFullInputComputeAction,
            BackwardWeightComputeAction,
        )
        return [
            action
            for action in program
            if not isinstance(action, compute_types)
            or action.microbatch_idx < target_microbatches
        ]

    @property
    def num_stages_per_rank(self) -> int:
        """Number of stages hosted by each rank (two, forming the V)."""
        return 2

    @property
    def topology_style(self) -> ScheduleStyle:
        """Stage-to-rank placement style used by this schedule."""
        return ScheduleStyle.v
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
import dataclasses
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.distributed as dist
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclasses.dataclass(kw_only=True, slots=True)
|
|
8
|
+
class ReceiveStageInput:
|
|
9
|
+
"""
|
|
10
|
+
Instruction to receive a specific tensor from a previous stage (or next stage during backward).
|
|
11
|
+
|
|
12
|
+
Attributes:
|
|
13
|
+
name: A unique identifier for the communication operation.
|
|
14
|
+
from_stage: The stage index sending the data.
|
|
15
|
+
buffer: The pre-allocated tensor buffer where data will be received.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
name: str
|
|
19
|
+
from_stage: int
|
|
20
|
+
buffer: torch.Tensor
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclasses.dataclass()
class StartStageInput:
    """
    Marks a stage input that is not obtained through communication.

    A typical example is the first pipeline stage, whose inputs come from
    the data loader.
    """
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Per-input descriptor: either data received from another stage, or an input
# that is supplied without any communication.
StageInput = ReceiveStageInput | StartStageInput
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclasses.dataclass(kw_only=True, slots=True)
|
|
35
|
+
class SendStageOutput:
|
|
36
|
+
"""
|
|
37
|
+
Instruction to send a specific tensor to a next stage (or previous if backward).
|
|
38
|
+
|
|
39
|
+
Attributes:
|
|
40
|
+
to_stage: The stage index receiving the data.
|
|
41
|
+
"""
|
|
42
|
+
|
|
43
|
+
to_stage: int
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclasses.dataclass()
class EndStageOutput:
    """
    Marks a stage output that is not forwarded to any other stage.

    A typical example is the last pipeline stage, whose output feeds the
    loss computation.
    """
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Per-output descriptor: either data sent to another stage, or a terminal
# output that is not sent anywhere.
StageOutput = SendStageOutput | EndStageOutput
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class StageCommunicationHandler:
|
|
58
|
+
"""
|
|
59
|
+
Manages Point-to-Point (P2P) communication descriptors for a specific data flow direction within a pipeline stage.
|
|
60
|
+
|
|
61
|
+
This class handles the creation of P2P operations (send/recv) across multiple microbatches,
|
|
62
|
+
managing buffers and mapping logical stage indices to physical ranks.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
def __init__(
|
|
66
|
+
self,
|
|
67
|
+
|
|
68
|
+
name: str,
|
|
69
|
+
stage_index: int,
|
|
70
|
+
num_microbatches: int,
|
|
71
|
+
|
|
72
|
+
input_stage_index: int | None,
|
|
73
|
+
input_args: dict[str, torch.Tensor],
|
|
74
|
+
|
|
75
|
+
output_stage_index: int | None,
|
|
76
|
+
output_args: dict[str, torch.Tensor],
|
|
77
|
+
|
|
78
|
+
stage_idx_to_host_rank: dict[int, int],
|
|
79
|
+
group: dist.ProcessGroup
|
|
80
|
+
):
|
|
81
|
+
"""
|
|
82
|
+
Constructs a StageCommunicationHandler object.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
name: Name prefix for this handler (e.g., 'fwd', 'bwd').
|
|
86
|
+
stage_index: The logical index of the current stage.
|
|
87
|
+
num_microbatches: Total number of microbatches ("chunks") to schedule.
|
|
88
|
+
input_stage_index: The logical index of the stage providing inputs, or None if inputs are local.
|
|
89
|
+
input_args: Metadata (shapes/dtypes) for input tensors.
|
|
90
|
+
output_stage_index: The logical index of the stage consuming outputs, or None if outputs are terminal.
|
|
91
|
+
output_args: Metadata (shapes/dtypes) for output tensors.
|
|
92
|
+
stage_idx_to_host_rank: Mapping from logical stage indices to physical world ranks.
|
|
93
|
+
group: The process group strictly for pipeline communication.
|
|
94
|
+
"""
|
|
95
|
+
|
|
96
|
+
self._input_handlers = self._build_inputs(
|
|
97
|
+
name=name,
|
|
98
|
+
stage_index=stage_index,
|
|
99
|
+
num_microbatches=num_microbatches,
|
|
100
|
+
input_stage_index=input_stage_index,
|
|
101
|
+
input_args=input_args
|
|
102
|
+
)
|
|
103
|
+
self._output_handlers = self._build_outputs(
|
|
104
|
+
output_stage_index=output_stage_index,
|
|
105
|
+
output_args=output_args
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
self._stage_idx_to_host_rank = stage_idx_to_host_rank
|
|
109
|
+
self._group = group
|
|
110
|
+
|
|
111
|
+
@staticmethod
|
|
112
|
+
def _build_inputs(
|
|
113
|
+
name: str,
|
|
114
|
+
stage_index: int,
|
|
115
|
+
num_microbatches: int,
|
|
116
|
+
input_stage_index: int | None,
|
|
117
|
+
input_args: dict[str, torch.Tensor]
|
|
118
|
+
) -> dict[int, dict[str, StageInput]]:
|
|
119
|
+
handlers: dict[int, dict[str, StageInput]] = {}
|
|
120
|
+
|
|
121
|
+
for chunk_id in range(num_microbatches):
|
|
122
|
+
handlers[chunk_id] = {}
|
|
123
|
+
for input_name, input_tensor_meta in input_args.items():
|
|
124
|
+
if input_stage_index is None:
|
|
125
|
+
handlers[chunk_id][input_name] = StartStageInput()
|
|
126
|
+
else:
|
|
127
|
+
handlers[chunk_id][input_name] = ReceiveStageInput(
|
|
128
|
+
name=f"{name}_recv_from_{input_stage_index}_to_{stage_index}[{chunk_id}][{input_name}]",
|
|
129
|
+
from_stage=input_stage_index,
|
|
130
|
+
buffer=torch.empty(
|
|
131
|
+
input_tensor_meta.size(),
|
|
132
|
+
dtype=input_tensor_meta.dtype,
|
|
133
|
+
layout=input_tensor_meta.layout,
|
|
134
|
+
device="cuda" # force device
|
|
135
|
+
)
|
|
136
|
+
)
|
|
137
|
+
return handlers
|
|
138
|
+
|
|
139
|
+
@staticmethod
|
|
140
|
+
def _build_outputs(
|
|
141
|
+
output_stage_index: int | None,
|
|
142
|
+
output_args: dict[str, torch.Tensor]
|
|
143
|
+
) -> dict[str, StageOutput]:
|
|
144
|
+
handlers: dict[str, StageOutput] = {}
|
|
145
|
+
|
|
146
|
+
for output_name in output_args:
|
|
147
|
+
if output_stage_index is None:
|
|
148
|
+
handlers[output_name] = EndStageOutput()
|
|
149
|
+
else:
|
|
150
|
+
handlers[output_name] = SendStageOutput(
|
|
151
|
+
to_stage=output_stage_index
|
|
152
|
+
)
|
|
153
|
+
return handlers
|
|
154
|
+
|
|
155
|
+
def set_input_requires_grad_(self, requires_grad: bool):
|
|
156
|
+
"""
|
|
157
|
+
Sets the `requires_grad` flag for all internal input buffers.
|
|
158
|
+
|
|
159
|
+
Typically used to enable gradient flow from backward stages to forward stages.
|
|
160
|
+
|
|
161
|
+
Args:
|
|
162
|
+
requires_grad: Whether the buffers should require gradients.
|
|
163
|
+
"""
|
|
164
|
+
|
|
165
|
+
for inputs in self._input_handlers.values():
|
|
166
|
+
for info in inputs.values():
|
|
167
|
+
if isinstance(info, ReceiveStageInput):
|
|
168
|
+
info.buffer.requires_grad_(requires_grad)
|
|
169
|
+
|
|
170
|
+
def set_inputs_local(self, inputs: dict[str, torch.Tensor], microbatch_index: int):
|
|
171
|
+
"""
|
|
172
|
+
Manually fills the input buffer for a specific microbatch with local data.
|
|
173
|
+
|
|
174
|
+
This is used when the stage is the first in the pipeline or receives data
|
|
175
|
+
from a dataloader rather than via network communication.
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
inputs: Dictionary of input tensors.
|
|
179
|
+
microbatch_index: The microbatch identifier.
|
|
180
|
+
"""
|
|
181
|
+
|
|
182
|
+
for input_name, input_value in inputs.items():
|
|
183
|
+
handler = self._input_handlers[microbatch_index][input_name]
|
|
184
|
+
if not isinstance(handler, ReceiveStageInput):
|
|
185
|
+
raise RuntimeError("Tried to set a buffer of no-receive stage input")
|
|
186
|
+
prev_requires_grad = handler.buffer.requires_grad
|
|
187
|
+
handler.buffer = input_value.detach().requires_grad_(
|
|
188
|
+
prev_requires_grad)
|
|
189
|
+
|
|
190
|
+
def get_inputs(self, microbatch_index: int) -> dict[str, torch.Tensor]:
|
|
191
|
+
"""
|
|
192
|
+
Retrieves the input tensors for a specific microbatch from the internal buffers.
|
|
193
|
+
|
|
194
|
+
Args:
|
|
195
|
+
microbatch_index: The microbatch identifier.
|
|
196
|
+
|
|
197
|
+
Returns:
|
|
198
|
+
Dictionary mapping input names to tensors.
|
|
199
|
+
"""
|
|
200
|
+
outputs: dict[str, torch.Tensor] = {}
|
|
201
|
+
|
|
202
|
+
for input_name, input_info in self._input_handlers[microbatch_index].items():
|
|
203
|
+
if not isinstance(input_info, ReceiveStageInput):
|
|
204
|
+
raise RuntimeError("Tried to get a buffer of no receive stage input")
|
|
205
|
+
outputs[input_name] = input_info.buffer
|
|
206
|
+
|
|
207
|
+
return outputs
|
|
208
|
+
|
|
209
|
+
def create_receive_ops(self, microbatch_index: int) -> list[dist.P2POp]:
|
|
210
|
+
"""
|
|
211
|
+
Generates the PyTorch P2P receive operations for a specific microbatch.
|
|
212
|
+
|
|
213
|
+
Args:
|
|
214
|
+
microbatch_index: The microbatch identifier.
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
A list of `dist.P2POp` objects configured for `dist.irecv`.
|
|
218
|
+
"""
|
|
219
|
+
|
|
220
|
+
ops = []
|
|
221
|
+
|
|
222
|
+
inputs = self._input_handlers[microbatch_index]
|
|
223
|
+
# sort ops by parameter names to ensure receive ops are ordered the same for send and recv
|
|
224
|
+
for _input_name, input_info in sorted(inputs.items(), key=lambda x: x[0]):
|
|
225
|
+
match input_info:
|
|
226
|
+
case StartStageInput():
|
|
227
|
+
pass
|
|
228
|
+
case ReceiveStageInput():
|
|
229
|
+
peer_rank = self._stage_idx_to_host_rank[input_info.from_stage]
|
|
230
|
+
peer_global_rank = dist.get_global_rank(self._group, peer_rank)
|
|
231
|
+
op = dist.P2POp(dist.irecv, input_info.buffer, peer_global_rank, self._group)
|
|
232
|
+
ops.append(op)
|
|
233
|
+
case _:
|
|
234
|
+
raise ValueError()
|
|
235
|
+
|
|
236
|
+
return ops
|
|
237
|
+
|
|
238
|
+
def create_send_ops(self, send_contents: dict[str, torch.Tensor]) -> list[dist.P2POp]:
|
|
239
|
+
"""
|
|
240
|
+
Generates the PyTorch P2P send operations for the provided tensors.
|
|
241
|
+
|
|
242
|
+
Args:
|
|
243
|
+
send_contents: Dictionary of tensors to send.
|
|
244
|
+
|
|
245
|
+
Returns:
|
|
246
|
+
A list of `dist.P2POp` objects configured for `dist.isend`.
|
|
247
|
+
"""
|
|
248
|
+
|
|
249
|
+
ops = []
|
|
250
|
+
|
|
251
|
+
# sort ops by parameter names to ensure receive ops are ordered the same for send and recv
|
|
252
|
+
for output_name, output_info in sorted(self._output_handlers.items(), key=lambda x: x[0]):
|
|
253
|
+
output_tensor = send_contents[output_name]
|
|
254
|
+
|
|
255
|
+
match output_info:
|
|
256
|
+
case EndStageOutput():
|
|
257
|
+
pass
|
|
258
|
+
case SendStageOutput():
|
|
259
|
+
peer_rank = self._stage_idx_to_host_rank[output_info.to_stage]
|
|
260
|
+
peer_global_rank = dist.get_global_rank(self._group, peer_rank)
|
|
261
|
+
op = dist.P2POp(dist.isend, output_tensor, peer_global_rank, self._group)
|
|
262
|
+
ops.append(op)
|
|
263
|
+
case _:
|
|
264
|
+
raise ValueError()
|
|
265
|
+
|
|
266
|
+
return ops
|
|
267
|
+
|
|
268
|
+
def reset(self):
|
|
269
|
+
"""Resets the internal state, specifically clearing gradients on input buffers."""
|
|
270
|
+
|
|
271
|
+
for inp_handlers in self._input_handlers.values():
|
|
272
|
+
for inp_handler in inp_handlers.values():
|
|
273
|
+
if isinstance(inp_handler, ReceiveStageInput):
|
|
274
|
+
inp_handler.buffer.grad = None
|