d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0

d9d/pipelining/infra/schedule/component/runtime/action.py
@@ -0,0 +1,361 @@
+import abc
+import dataclasses
+from enum import StrEnum
+from typing import Any
+
+import torch
+
+from d9d.pipelining.infra.stage import PipelineStage
+
+from .communications import PipelineCommunicationHandler
+from .loss import PipelineLossHandler
+
+
+@dataclasses.dataclass(kw_only=True, slots=True)
+class ActionContext:
+    """
+    Holds the runtime context required to execute a pipeline action.
+
+    Attributes:
+        pipeline_inputs_microbatches: The global inputs sharded by microbatch.
+        pipeline_kwargs_microbatches: The global keyword arguments sharded by microbatch.
+        stages: A mapping of stage indices to their active PipelineStage instances.
+        communications: The handler for P2P communications.
+        loss: The handler for loss computation, or None if not available.
+    """
+
+    pipeline_inputs_microbatches: tuple[dict[str, torch.Tensor], ...]
+    pipeline_kwargs_microbatches: tuple[dict[str, Any], ...]
+
+    stages: dict[int, PipelineStage]
+    communications: PipelineCommunicationHandler
+    loss: PipelineLossHandler | None
+
+
+class ActionWorkType(StrEnum):
+    """
+    Classifies the type of work performed by an action.
+
+    Attributes:
+        compute: Indicates the action involves computation components (forward, backward).
+        communicate: Indicates the action involves network I/O components (send, receive).
+    """
+
+    compute = "compute"
+    communicate = "communicate"
+
+
+class ActionBase(abc.ABC):
+    """
+    Abstract base class for all pipeline schedule actions.
+
+    An action represents an atomic unit of work in a pipeline schedule,
+    such as computing a microbatch or sending/receiving a tensor.
+    """
+
+    @abc.abstractmethod
+    def apply(self, ctx: ActionContext):
+        """
+        Executes the action logic using the provided context.
+
+        Args:
+            ctx: The runtime context containing stages, data, and communication handlers.
+        """
+
+        ...
+
+    @property
+    @abc.abstractmethod
+    def work_type(self) -> ActionWorkType:
+        """Returns the classification of work this action performs."""
+        ...
+
+    @property
+    @abc.abstractmethod
+    def has_backward_work(self) -> bool:
+        """Returns True if this action involves backward pass computations."""
+        ...
+
+    @abc.abstractmethod
+    def __str__(self) -> str:
+        """Returns a short string representation of the action for logging/visualization."""
+        ...
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class ForwardSendAction(ActionBase):
+    """
+    Action to schedule a forward pass tensor send operation.
+
+    Attributes:
+        stage_idx: The integer index of the pipeline stage initiating the send operation.
+        microbatch_idx: The integer index of the microbatch being sent.
+    """
+
+    stage_idx: int
+    microbatch_idx: int
+
+    def apply(self, ctx: ActionContext):
+        ctx.communications.schedule_fwd_send(self.stage_idx, self.microbatch_idx)
+
+    @property
+    def work_type(self) -> ActionWorkType:
+        return ActionWorkType.communicate
+
+    @property
+    def has_backward_work(self) -> bool:
+        return False
+
+    def __str__(self) -> str:
+        return f"{self.stage_idx}SEND_F{self.microbatch_idx}"
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class BackwardSendAction(ActionBase):
+    """
+    Action to schedule a backward pass gradient send operation.
+
+    Attributes:
+        stage_idx: The integer index of the pipeline stage initiating the send operation.
+        microbatch_idx: The integer index of the microbatch being sent.
+    """
+
+    stage_idx: int
+    microbatch_idx: int
+
+    def apply(self, ctx: ActionContext):
+        ctx.communications.schedule_bwd_send(self.stage_idx, self.microbatch_idx)
+
+    @property
+    def work_type(self) -> ActionWorkType:
+        return ActionWorkType.communicate
+
+    @property
+    def has_backward_work(self) -> bool:
+        return True
+
+    def __str__(self) -> str:
+        return f"{self.stage_idx}SEND_B{self.microbatch_idx}"
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class ForwardReceiveAction(ActionBase):
+    """
+    Action to schedule a forward pass tensor receive operation.
+
+    Attributes:
+        stage_idx: The integer index of the pipeline stage expecting the receive operation.
+        microbatch_idx: The integer index of the microbatch being received.
+    """
+
+    stage_idx: int
+    microbatch_idx: int
+
+    def apply(self, ctx: ActionContext):
+        ctx.communications.schedule_fwd_recv(self.stage_idx, self.microbatch_idx)
+
+    @property
+    def work_type(self) -> ActionWorkType:
+        return ActionWorkType.communicate
+
+    @property
+    def has_backward_work(self) -> bool:
+        return True
+
+    def __str__(self) -> str:
+        return f"{self.stage_idx}RECV_F{self.microbatch_idx}"
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class BackwardReceiveAction(ActionBase):
+    """
+    Action to schedule a backward pass gradient receive operation.
+
+    Attributes:
+        stage_idx: The integer index of the pipeline stage expecting the receive operation.
+        microbatch_idx: The integer index of the microbatch being received.
+    """
+
+    stage_idx: int
+    microbatch_idx: int
+
+    def apply(self, ctx: ActionContext):
+        ctx.communications.schedule_bwd_recv(self.stage_idx, self.microbatch_idx)
+
+    @property
+    def work_type(self) -> ActionWorkType:
+        return ActionWorkType.communicate
+
+    @property
+    def has_backward_work(self) -> bool:
+        return True
+
+    def __str__(self) -> str:
+        return f"{self.stage_idx}RECV_B{self.microbatch_idx}"
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class ForwardComputeAction(ActionBase):
+    """
+    Action to perform forward computation for a specific microbatch.
+
+    Attributes:
+        stage_idx: The integer index of the pipeline stage.
+        microbatch_idx: The integer index of the microbatch to compute.
+    """
+
+    stage_idx: int
+    microbatch_idx: int
+
+    def apply(self, ctx: ActionContext):
+        # todo check unsharded
+        stage = ctx.stages[self.stage_idx]
+
+        if not stage.info.is_current_stage_first and self.stage_idx - 1 not in ctx.stages:
+            ctx.communications.wait_fwd_recv(self.stage_idx, self.microbatch_idx)
+
+        stage.forward_one_chunk(
+            microbatch_index=self.microbatch_idx,
+            pipeline_inputs=ctx.pipeline_inputs_microbatches[self.microbatch_idx],
+            pipeline_kwargs=ctx.pipeline_kwargs_microbatches[self.microbatch_idx]
+        )
+        result = stage.get_local_fwd_output(self.microbatch_idx)
+
+        if stage.info.is_current_stage_last and ctx.loss is not None:
+            ctx.loss.compute_loss(result, self.microbatch_idx)
+
+        if not stage.info.is_current_stage_last and self.stage_idx + 1 in ctx.stages:
+            ctx.stages[self.stage_idx + 1].set_local_fwd_input(
+                inputs=result,
+                microbatch_index=self.microbatch_idx
+            )
+
+    @property
+    def work_type(self) -> ActionWorkType:
+        return ActionWorkType.compute
+
+    @property
+    def has_backward_work(self) -> bool:
+        return False
+
+    def __str__(self) -> str:
+        return f"{self.stage_idx}F{self.microbatch_idx}"
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class BackwardFullInputComputeAction(ActionBase):
+    """
+    Action to perform backward computation with respect to inputs.
+
+    Attributes:
+        stage_idx: The integer index of the pipeline stage.
+        microbatch_idx: The integer index of the microbatch to compute.
+        full_backward: If True, performs a full backward pass including inputs
+            and weights. If False, may only compute gradients w.r.t. inputs
+            (depending on schedule implementation).
+    """
+
+    stage_idx: int
+    microbatch_idx: int
+    full_backward: bool
+
+    def apply(self, ctx: ActionContext):
+        # todo unshard
+        stage = ctx.stages[self.stage_idx]
+
+        if not stage.info.is_current_stage_last and self.stage_idx + 1 not in ctx.stages:
+            ctx.communications.wait_bwd_recv(self.stage_idx, self.microbatch_idx)
+
+        if stage.info.is_current_stage_last and ctx.loss is not None:
+            loss = ctx.loss.acquire_loss(self.microbatch_idx)
+        else:
+            loss = None
+
+        stage.backward_one_chunk(
+            microbatch_index=self.microbatch_idx,
+            full_backward=self.full_backward,
+            loss=loss
+        )
+
+        if not stage.info.is_current_stage_first and self.stage_idx - 1 in ctx.stages:
+            ctx.stages[self.stage_idx - 1].set_local_bwd_input(
+                microbatch_index=self.microbatch_idx,
+                inputs=stage.pop_local_bwd_output(self.microbatch_idx)
+            )
+
+    @property
+    def work_type(self) -> ActionWorkType:
+        return ActionWorkType.compute
+
+    @property
+    def has_backward_work(self) -> bool:
+        return True
+
+    def __str__(self) -> str:
+        letter = "B" if self.full_backward else "I"
+        return f"{self.stage_idx}{letter}{self.microbatch_idx}"
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class BackwardWeightComputeAction(ActionBase):
+    """
+    Action to perform gradient accumulation on weights.
+
+    Attributes:
+        stage_idx: The integer index of the pipeline stage.
+        microbatch_idx: The integer index of the microbatch to compute.
+    """
+
+    stage_idx: int
+    microbatch_idx: int
+
+    def apply(self, ctx: ActionContext):
+        # todo unshard
+        stage = ctx.stages[self.stage_idx]
+
+        stage.backward_weight_one_chunk(
+            microbatch_index=self.microbatch_idx
+        )
+
+    @property
+    def work_type(self) -> ActionWorkType:
+        return ActionWorkType.compute
+
+    @property
+    def has_backward_work(self) -> bool:
+        return True
+
+    def __str__(self) -> str:
+        return f"{self.stage_idx}W{self.microbatch_idx}"
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class ComposeAction(ActionBase):
+    """
+    Composite action scheduling multiple sub-actions sequentially.
+
+    Used for forward/backward overlapping.
+
+    Attributes:
+        actions: A tuple of sub-actions to be executed sequentially.
+    """
+
+    actions: tuple[ActionBase, ...]
+
+    def apply(self, ctx: ActionContext):
+        for act in self.actions:
+            act.apply(ctx)
+
+    @property
+    def work_type(self) -> ActionWorkType:
+        sub_work_types = {x.work_type for x in self.actions}
+        if len(sub_work_types) != 1:
+            raise ValueError("ComposeAction requires all sub-actions to share the same work type")
+        return next(iter(sub_work_types))
+
+    @property
+    def has_backward_work(self) -> bool:
+        return any(x.has_backward_work for x in self.actions)
+
+    def __str__(self) -> str:
+        return "|".join(map(str, self.actions))

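For orientation, a minimal sketch of how these action dataclasses are meant to be combined (not part of the package diff; it assumes the wheel above is installed so the module path in this hunk imports, and the stage/microbatch indices are arbitrary placeholders):

from d9d.pipelining.infra.schedule.component.runtime.action import (
    BackwardFullInputComputeAction,
    ComposeAction,
    ForwardComputeAction,
)

# Overlap a forward chunk for one microbatch with a full backward chunk for
# another, expressed as a single composite schedule slot (both are compute work).
overlap = ComposeAction(actions=(
    ForwardComputeAction(stage_idx=3, microbatch_idx=4),
    BackwardFullInputComputeAction(stage_idx=0, microbatch_idx=1, full_backward=True),
))

print(overlap)                    # "3F4|0B1" -- the compact trace used in logs/profiles
print(overlap.work_type)          # "compute": both sub-actions agree, so no ValueError
print(overlap.has_backward_work)  # True, because the backward sub-action reports it

Real programs of this shape are presumably produced by the builders under d9d/pipelining/infra/schedule/program/ and handed to the executor shown two hunks below.
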
d9d/pipelining/infra/schedule/component/runtime/communications.py
@@ -0,0 +1,101 @@
+import torch.distributed as dist
+
+from d9d.pipelining.infra.stage import PipelineStage
+
+
+def _schedule_batched_p2p(ops: list[dist.P2POp]) -> list[dist.Work]:
+    if not len(ops):
+        return []
+
+    return dist.batch_isend_irecv(ops)
+
+
+def _wait_batched_p2p(work: list[dist.Work]):
+    for work_item in work:
+        work_item.wait()
+
+
+class PipelineCommunicationHandler:
+    """Manages point-to-point communications between pipeline stages."""
+
+    def __init__(self, stages: dict[int, PipelineStage]):
+        """
+        Constructs the communication handler.
+
+        Args:
+            stages: Mapping of stage indices to PipelineStage instances.
+        """
+
+        self._stages = stages
+
+        self._forward_receive_ops: dict[tuple[int, int], list[dist.Work]] = {}
+        self._backward_receive_ops: dict[tuple[int, int], list[dist.Work]] = {}
+
+        self._send_ops: list[list[dist.Work]] = []
+
+    def schedule_fwd_send(self, stage_idx: int, microbatch_idx: int):
+        """Schedules a non-blocking send of forward pass outputs."""
+
+        stage = self._stages[stage_idx]
+        work = _schedule_batched_p2p(stage.get_fwd_send_ops(microbatch_idx))
+        self._send_ops.append(work)
+
+    def schedule_bwd_send(self, stage_idx: int, microbatch_idx: int):
+        """Schedules a non-blocking send of backward pass outputs."""
+
+        stage = self._stages[stage_idx]
+        work = _schedule_batched_p2p(stage.get_bwd_send_ops(microbatch_idx))
+        self._send_ops.append(work)
+
+    def schedule_fwd_recv(self, stage_idx: int, microbatch_idx: int):
+        """
+        Schedules a non-blocking receive of forward pass inputs.
+
+        Raises:
+            ValueError: If a receive op is already pending for this stage/microbatch.
+        """
+        stage = self._stages[stage_idx]
+        key = (stage_idx, microbatch_idx)
+
+        if key in self._forward_receive_ops:
+            raise ValueError(f"Forward receive already pending for stage/microbatch {key}")
+
+        work = _schedule_batched_p2p(stage.get_fwd_recv_ops(microbatch_idx))
+        self._forward_receive_ops[key] = work
+
+    def wait_fwd_recv(self, stage_idx: int, microbatch_idx: int):
+        """Blocks until the forward pass receive operation completes."""
+        key = (stage_idx, microbatch_idx)
+        _wait_batched_p2p(self._forward_receive_ops.pop(key))
+
+    def schedule_bwd_recv(self, stage_idx: int, microbatch_idx: int):
+        """
+        Schedules a non-blocking receive of backward pass inputs.
+
+        Raises:
+            ValueError: If a receive op is already pending for this stage/microbatch.
+        """
+
+        stage = self._stages[stage_idx]
+        key = (stage_idx, microbatch_idx)
+
+        if key in self._backward_receive_ops:
+            raise ValueError(f"Backward receive already pending for stage/microbatch {key}")
+
+        work = _schedule_batched_p2p(stage.get_bwd_recv_ops(microbatch_idx))
+
+        self._backward_receive_ops[key] = work
+
+    def wait_bwd_recv(self, stage_idx: int, microbatch_idx: int):
+        """Blocks until the backward pass receive operation completes."""
+
+        key = (stage_idx, microbatch_idx)
+        _wait_batched_p2p(self._backward_receive_ops.pop(key))
+
+    def wait_send_all(self):
+        """Blocks until all pending send operations are completed."""
+
+        while self._send_ops:
+            ops = self._send_ops.pop()
+            for op in ops:
+                op.wait()

d9d/pipelining/infra/schedule/component/runtime/executor.py
@@ -0,0 +1,113 @@
+from typing import Any
+
+import torch
+from torch.autograd.profiler import record_function
+
+from d9d.core.dist_context import REGULAR_DOMAIN, DistributedContext
+from d9d.core.sharding import ShardingSpec, shard_spec_on_dim, shard_tree
+from d9d.pipelining.api import PipelineSchedule, PipelineShardingSpec
+from d9d.pipelining.infra.stage import PipelineStage
+
+from .action import ActionBase, ActionContext
+from .communications import PipelineCommunicationHandler
+from .loss import LossFn, PipelineLossHandler
+
+
+class PipelineScheduleExecutor(PipelineSchedule):
+    """Executes a defined pipeline schedule by interpreting a sequence of actions."""
+
+    def __init__(
+        self,
+        dist_context: DistributedContext,
+        stages: list[PipelineStage],
+        num_microbatches: int,
+        loss_fn: LossFn | None,
+        program: dict[int, list[ActionBase]]
+    ):
+        """
+        Constructs the schedule executor.
+
+        Args:
+            dist_context: The distributed context.
+            stages: List of stages managed by this executor.
+            num_microbatches: Number of microbatches the global batch is split into.
+            loss_fn: Function to compute loss.
+            program: The execution plan mapping rank ID to a list of actions.
+        """
+
+        self._dist_ctx = dist_context
+        self._stages = {stage.info.current_stage: stage for stage in stages}
+        self._num_microbatches = num_microbatches
+        self._program = program
+
+        self._has_backward = any(any(
+            action.has_backward_work for action in sub_program
+        ) for sub_program in program.values())
+
+        self._comm_handler = PipelineCommunicationHandler(self._stages)
+        if loss_fn is None:
+            self._loss_handler = None
+        else:
+            self._loss_handler = PipelineLossHandler(loss_fn)
+
+        self._input_data_sharding_spec: ShardingSpec | None = None
+        self._input_kwargs_sharding_spec: ShardingSpec | None = None
+
+    def configure_buffers(
+        self,
+        inputs: dict[str, torch.Tensor],
+        kwargs: dict[str, Any],
+        sharding_spec: PipelineShardingSpec | None
+    ):
+        if sharding_spec is None or sharding_spec.input_data is None:
+            self._input_data_sharding_spec = shard_spec_on_dim(inputs, dim=0)
+        if sharding_spec is None or sharding_spec.input_kwargs is None:
+            self._input_kwargs_sharding_spec = shard_spec_on_dim(kwargs, dim=0)
+
+        for stage in self._stages.values():
+            stage.configure_buffers(
+                num_microbatches=self._num_microbatches,
+                pipeline_inputs=inputs,
+                has_backward=self._has_backward
+            )
+
+    def step(self, inputs: dict[str, torch.Tensor], kwargs: dict[str, Any]):
+        if self._input_data_sharding_spec is None or self._input_kwargs_sharding_spec is None:
+            raise ValueError("Please configure sharding specs first")
+
+        self._dist_ctx.logger.debug("Begin pipeline step")
+        pp_group = self._dist_ctx.mesh_for(REGULAR_DOMAIN).get_group("pp")
+
+        for stage in self._stages.values():
+            stage.reset()
+
+        # Shard inputs and kwargs into microbatches
+        inputs_shard = shard_tree(
+            inputs,
+            num_shards=self._num_microbatches,
+            sharding_spec=self._input_data_sharding_spec,
+            enforce_even_split=True
+        )
+        kwargs_shard = shard_tree(
+            kwargs,
+            num_shards=self._num_microbatches,
+            sharding_spec=self._input_kwargs_sharding_spec,
+            enforce_even_split=True
+        )
+
+        my_program = self._program[pp_group.rank()]
+
+        for action in my_program:
+            with record_function(str(action)):
+                self._dist_ctx.logger.debug(f"Running pipeline action {action}")
+                action.apply(ActionContext(
+                    loss=self._loss_handler,
+                    stages=self._stages,
+                    communications=self._comm_handler,
+                    pipeline_inputs_microbatches=inputs_shard,
+                    pipeline_kwargs_microbatches=kwargs_shard
+                ))
+
+        self._dist_ctx.logger.debug("Waiting for potentially hanging PP send comms")
+        self._comm_handler.wait_send_all()  # finalize just in case
+        self._dist_ctx.logger.debug("End pipeline step")

d9d/pipelining/infra/schedule/component/runtime/loss.py
@@ -0,0 +1,55 @@
+from collections.abc import Callable
+
+import torch
+
+LossFn = Callable[[dict[str, torch.Tensor], int], torch.Tensor]
+
+
+class PipelineLossHandler:
+    """Manages loss computation and state caching across forward and backward passes."""
+
+    def __init__(self, loss_fn: LossFn):
+        """
+        Constructs the loss handler.
+
+        Args:
+            loss_fn: The callable that computes loss from model outputs.
+        """
+
+        self._loss_fn = loss_fn
+        self._cached_values: dict[int, torch.Tensor] = {}
+
+    def compute_loss(self, forward_result: dict[str, torch.Tensor], microbatch_index: int) -> torch.Tensor:
+        """
+        Computes loss for a given microbatch result and caches it.
+
+        Args:
+            forward_result: The output from the last stage of the model.
+            microbatch_index: The index of the microbatch being processed.
+
+        Returns:
+            The computed loss tensor.
+        """
+
+        result = self._loss_fn(forward_result, microbatch_index)
+        self._cached_values[microbatch_index] = result
+        return result
+
+    def acquire_loss(self, microbatch_index: int) -> torch.Tensor:
+        """
+        Retrieves the cached loss tensor for the backward pass.
+
+        Args:
+            microbatch_index: The index of the microbatch.
+
+        Returns:
+            The previously computed loss tensor.
+
+        Raises:
+            ValueError: If the loss for this microbatch hasn't been computed yet.
+        """
+
+        if microbatch_index not in self._cached_values:
+            raise ValueError(f"Loss for microbatch {microbatch_index} has not been computed yet")
+
+        return self._cached_values[microbatch_index]

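As a usage illustration only (again not part of the diff), a callable matching the LossFn alias above, together with the cache round-trip the executor relies on, could look like this; the "logits" key and the mean reduction are hypothetical stand-ins rather than package conventions:

import torch

from d9d.pipelining.infra.schedule.component.runtime.loss import PipelineLossHandler


def demo_loss_fn(forward_result: dict[str, torch.Tensor], microbatch_index: int) -> torch.Tensor:
    # Hypothetical reduction: average whatever the last stage emitted under "logits".
    return forward_result["logits"].float().mean()


handler = PipelineLossHandler(demo_loss_fn)

# Last-stage forward (see ForwardComputeAction.apply): compute and cache the loss.
loss = handler.compute_loss({"logits": torch.randn(2, 8)}, microbatch_index=0)

# Last-stage backward (see BackwardFullInputComputeAction.apply): fetch it back.
assert handler.acquire_loss(0) is loss
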
d9d/pipelining/infra/schedule/program/__init__.py
@@ -0,0 +1,15 @@
+"""
+Pipeline Schedule Implementations
+"""
+
+from .bfs import LoopedBFSPipelineProgramBuilder
+from .dualpipev import DualPipeVPipelineProgramBuilder
+from .interleaved import Interleaved1F1BPipelineProgramBuilder
+from .zerobubblev import ZeroBubbleVPipelineProgramBuilder
+
+__all__ = [
+    "DualPipeVPipelineProgramBuilder",
+    "Interleaved1F1BPipelineProgramBuilder",
+    "LoopedBFSPipelineProgramBuilder",
+    "ZeroBubbleVPipelineProgramBuilder"
+]