d9d-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/module/model/qwen3_moe/model.py
@@ -0,0 +1,373 @@
from collections.abc import Mapping
from typing import cast

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

from d9d.module.base import ModuleLateInit
from d9d.module.block.embedding import SplitTokenEmbeddings
from d9d.module.block.head import SplitLanguageModellingHead
from d9d.module.block.hidden_states_aggregator import HiddenStatesAggregationMode, create_hidden_states_aggregator
from d9d.module.block.positional import RotaryEmbeddingProvider
from d9d.pipelining.api import (
    ModuleSupportsPipelining,
    PipelineStageInfo,
    distribute_layers_for_pipeline_stage,
)

from .decoder_layer import Qwen3MoELayer
from .params import Qwen3MoEForCausalLMParameters, Qwen3MoEParameters


class Qwen3MoEModel(nn.Module, ModuleLateInit, ModuleSupportsPipelining):
    """
    The Qwen3 Mixture-of-Experts (MoE) Transformer Decoder backbone.

    It is designed to be split across multiple pipeline stages.
    """

    def __init__(
        self,
        params: Qwen3MoEParameters,
        stage: PipelineStageInfo,
        hidden_states_snapshot_mode: HiddenStatesAggregationMode,
        enable_checkpointing: bool
    ):
        """
        Constructs the Qwen3MoEModel object.

        Args:
            params: Configuration parameters for the full model.
            stage: Information about the pipeline stage this instance belongs to.
            hidden_states_snapshot_mode: Configures the intermediate hidden state aggregation & snapshotting mode.
            enable_checkpointing: If True, enables activation checkpointing for transformer layers to save memory.
        """

        super().__init__()

        if stage.is_current_stage_first:
            self.embed_tokens = SplitTokenEmbeddings(
                hidden_size=params.layer.hidden_size,
                split_vocab_size=params.split_vocab_size,
                split_order=params.split_vocab_order
            )

        # we use ModuleDict here to properly handle pipelining and loading weights after the model
        # was pipelined
        layer_start, layer_end = distribute_layers_for_pipeline_stage(
            num_layers=params.num_hidden_layers,
            num_virtual_layers_pre=params.pipeline_num_virtual_layers_pre,  # embeddings
            num_virtual_layers_post=params.pipeline_num_virtual_layers_post,  # LM head
            stage=stage
        )

        self._num_layers_before = layer_start
        self._layers_iter = list(map(str, range(layer_start, layer_end)))
        layers = nn.ModuleDict({
            str(layer_idx): Qwen3MoELayer(params=params.layer) for layer_idx in self._layers_iter
        })
        self.layers: Mapping[str, Qwen3MoELayer] = cast(Mapping[str, Qwen3MoELayer], layers)

        self.rope_provider = RotaryEmbeddingProvider(
            max_position_ids=params.max_position_ids,
            rope_base=params.rope_base,
            head_dim=params.layer.head_dim
        )

        if stage.is_current_stage_last:
            self.norm = nn.RMSNorm(
                normalized_shape=params.layer.hidden_size,
                eps=params.layer.rms_norm_eps
            )

        self._stage = stage
        self._hidden_states_snapshot_mode = hidden_states_snapshot_mode
        self._hidden_size = params.layer.hidden_size
        self._enable_checkpointing = enable_checkpointing

    def output_dtype(self) -> torch.dtype:
        """
        Returns the data type of the model output hidden states.
        """
        return self.layers[self._layers_iter[0]].input_layernorm.weight.dtype

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        hidden_states: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        hidden_states_snapshot: torch.Tensor | None = None,
        hidden_states_agg_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Executes the forward pass for the current pipeline stage.

        Args:
            input_ids: Indices of input sequence tokens. Required if this is the
                first pipeline stage.
            hidden_states: Hidden states from the previous pipeline stage. Required
                if this is not the first pipeline stage.
            position_ids: Indices of positions of each input sequence token in the
                position embeddings.
            hidden_states_snapshot: Accumulated tensor of aggregated hidden states
                from previous stages. Used if snapshotting is enabled.
            hidden_states_agg_mask: Mask used to aggregate hidden states for
                snapshots.

        Returns:
            A dictionary containing:
            * 'hidden_states': The output of the last layer in this stage.
            * 'hidden_states_snapshot': (Optional) The updated snapshot tensor.
        """
        state_aggregator = create_hidden_states_aggregator(self._hidden_states_snapshot_mode, hidden_states_agg_mask)

        if input_ids is not None:
            last_hidden_states = self.embed_tokens(input_ids)
            state_aggregator.add_hidden_states(last_hidden_states)
        else:
            last_hidden_states = hidden_states

        rope_params = self.rope_provider(position_ids)

        for decoder_layer_name in self._layers_iter:
            decoder_layer = self.layers[decoder_layer_name]

            if self._enable_checkpointing:
                last_hidden_states = checkpoint(
                    decoder_layer, last_hidden_states, rope_params,
                    use_reentrant=False
                )
            else:
                last_hidden_states = decoder_layer(last_hidden_states, rope_params)

            state_aggregator.add_hidden_states(last_hidden_states)

        if self._stage.is_current_stage_last:
            last_hidden_states = self.norm(last_hidden_states)

        return {
            "hidden_states": last_hidden_states,
            "hidden_states_snapshot": state_aggregator.pack_with_snapshot(hidden_states_snapshot)
        }

    def reset_moe_stats(self):
        """
        Resets routing statistics for all MoE layers in this stage.
        """

        for layer_name in self._layers_iter:
            self.layers[layer_name].reset_moe_stats()

    @property
    def moe_tokens_per_expert(self) -> torch.Tensor:
        """
        Retrieves the number of tokens routed to each expert across all layers.

        Returns:
            A tensor of shape (num_local_layers, num_experts) containing counts.
        """

        return torch.stack(
            [self.layers[layer_name].moe_tokens_per_expert for layer_name in self._layers_iter],
            dim=0
        )

    def reset_parameters(self):
        """Resets module parameters"""

        if self._stage.is_current_stage_first:
            self.embed_tokens.reset_parameters()

        self.rope_provider.reset_parameters()

        for decoder_layer_name in self._layers_iter:
            decoder_layer = self.layers[decoder_layer_name]
            decoder_layer.reset_parameters()

        if self._stage.is_current_stage_last:
            self.norm.reset_parameters()

    def infer_stage_inputs_from_pipeline_inputs(
        self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        input_ids = inputs["input_ids"]

        pp_inputs = {}

        # for calculation - input ids or prev hidden state
        if self._stage.is_current_stage_first:
            pp_inputs["input_ids"] = torch.empty(
                (input_ids.shape[0] // n_microbatches, input_ids.shape[1]),
                dtype=torch.long,
                device=input_ids.device
            )
        else:
            pp_inputs["hidden_states"] = torch.empty(
                (input_ids.shape[0] // n_microbatches, input_ids.shape[1], self._hidden_size),
                dtype=self.output_dtype(),
                device=input_ids.device
            )
        if self._hidden_states_snapshot_mode != HiddenStatesAggregationMode.no:
            num_layers_before = self._num_layers_before + 1  # 1 for embedding
            pp_inputs["hidden_states_snapshot"] = torch.empty(
                (num_layers_before, input_ids.shape[0] // n_microbatches, self._hidden_size),
                dtype=self.output_dtype(),
                device=input_ids.device
            )

        return pp_inputs

    def infer_stage_outputs_from_pipeline_inputs(
        self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        input_ids = inputs["input_ids"]

        # for calculation - last hidden state
        pp_outputs = {
            "hidden_states": torch.empty(
                (input_ids.shape[0] // n_microbatches, input_ids.shape[1], self._hidden_size),
                dtype=self.output_dtype(),
                device=input_ids.device
            )
        }

        # for state caching
        if self._hidden_states_snapshot_mode != HiddenStatesAggregationMode.no:
            num_layers_before = self._num_layers_before + 1
            num_layers_current = len(self.layers)
            num_layers_after = num_layers_before + num_layers_current
            pp_outputs["hidden_states_snapshot"] = torch.empty(
                (num_layers_after, input_ids.shape[0] // n_microbatches, self._hidden_size),
                dtype=self.output_dtype(),
                device=input_ids.device
            )

        return pp_outputs


class Qwen3MoEForCausalLM(nn.Module, ModuleLateInit, ModuleSupportsPipelining):
    """
    A Qwen3 MoE model wrapped with a Causal Language Modeling head.

    It is designed to be split across multiple pipeline stages.
    """

    def __init__(
        self,
        params: Qwen3MoEForCausalLMParameters,
        stage: PipelineStageInfo,
        hidden_states_snapshot_mode: HiddenStatesAggregationMode,
        enable_checkpointing: bool
    ):
        """
        Constructs the Qwen3MoEForCausalLM object.

        Args:
            params: Full model configuration parameters.
            stage: Pipeline stage information for this instance.
            hidden_states_snapshot_mode: Configures the intermediate hidden state aggregation & snapshotting mode.
            enable_checkpointing: Whether to enable activation checkpointing.
        """

        super().__init__()

        self.model = Qwen3MoEModel(
            params.model,
            stage,
            hidden_states_snapshot_mode=hidden_states_snapshot_mode,
            enable_checkpointing=enable_checkpointing
        )

        if stage.is_current_stage_last:
            self.lm_head = SplitLanguageModellingHead(
                split_vocab_size=params.model.split_vocab_size,
                split_order=params.model.split_vocab_order,
                hidden_size=params.model.layer.hidden_size
            )

        self._stage = stage
        self._hidden_size = params.model.layer.hidden_size

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        hidden_states: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        hidden_states_snapshot: torch.Tensor | None = None,
        hidden_states_agg_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Executes the model forward pass.

        If this is the last stage, it expects `labels` to be provided and computes
        the cross-entropy loss (returned as 'logps', typically representing the per-token loss).

        Args:
            input_ids: Input token IDs (for Stage 0).
            hidden_states: Hidden states from previous stage (for Stage > 0).
            position_ids: Positional indices for RoPE.
            hidden_states_snapshot: Intermediate state collector.
            hidden_states_agg_mask: Mask for state aggregation.
            labels: Target tokens for loss computation (Last Stage).

        Returns:
            Dictionary containing 'hidden_states', optionally 'hidden_states_snapshot',
            and per-token 'logps' if on the last stage.
        """

        model_outputs = self.model(
            input_ids=input_ids,
            hidden_states=hidden_states,
            position_ids=position_ids,
            hidden_states_snapshot=hidden_states_snapshot,
            hidden_states_agg_mask=hidden_states_agg_mask
        )
        if self._stage.is_current_stage_last:
            lm_out = self.lm_head(
                hidden_states=model_outputs["hidden_states"],
                labels=labels
            )
            model_outputs["logps"] = lm_out
        return model_outputs

    def reset_parameters(self):
        """
        Resets module parameters.
        """

        self.model.reset_parameters()

        if self._stage.is_current_stage_last:
            self.lm_head.reset_parameters()

    def reset_moe_stats(self):
        """
        Resets MoE routing statistics in the backbone.
        """

        self.model.reset_moe_stats()

    @property
    def moe_tokens_per_expert(self) -> torch.Tensor:
        """
        Accesses MoE routing statistics from the backbone.
        """

        return self.model.moe_tokens_per_expert

    def infer_stage_inputs_from_pipeline_inputs(
        self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        return self.model.infer_stage_inputs_from_pipeline_inputs(inputs, n_microbatches)

    def infer_stage_outputs_from_pipeline_inputs(
        self, inputs: dict[str, torch.Tensor], n_microbatches: int
    ) -> dict[str, torch.Tensor]:
        pp_outputs = self.model.infer_stage_outputs_from_pipeline_inputs(inputs, n_microbatches)

        if self._stage.is_current_stage_last:
            pp_outputs["logps"] = torch.empty(inputs["input_ids"].shape, dtype=torch.float32)

        return pp_outputs
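The `infer_stage_*` helpers above pre-allocate the tensors each pipeline stage exchanges, with the microbatch size derived as `input_ids.shape[0] // n_microbatches`. A minimal shape sketch, assuming an illustrative global batch of 8 sequences of length 512, hidden size 2048, bfloat16 activations (the real dtype comes from `output_dtype()`), and 4 microbatches; none of these values are taken from the package:

import torch

# Illustrative values only.
batch, seq, hidden, n_microbatches = 8, 512, 2048, 4
micro = batch // n_microbatches  # per-microbatch sequence count

# First stage receives token ids; later stages receive the previous stage's hidden states.
first_stage_inputs = {"input_ids": torch.empty(micro, seq, dtype=torch.long)}
later_stage_inputs = {"hidden_states": torch.empty(micro, seq, hidden, dtype=torch.bfloat16)}

# Every stage emits hidden states; the last stage additionally emits per-token
# "logps" of shape (micro, seq) in float32.
stage_outputs = {"hidden_states": torch.empty(micro, seq, hidden, dtype=torch.bfloat16)}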
d9d/module/model/qwen3_moe/params.py
@@ -0,0 +1,69 @@
from pydantic import BaseModel


class Qwen3MoELayerParameters(BaseModel):
    """
    Configuration parameters for a single Qwen3 MoE layer.

    Attributes:
        hidden_size: Dimension of the model's hidden states.
        intermediate_size: Dimension of the feed-forward hidden state.
        num_experts: Total number of experts in the MoE layer.
        experts_top_k: Number of experts to route tokens to.
        num_attention_heads: Number of attention heads for the query.
        num_key_value_heads: Number of attention heads for key and value.
        rms_norm_eps: Epsilon value used in the RMSNorm layers.
        head_dim: Dimension of a single attention head.
    """

    hidden_size: int
    intermediate_size: int
    num_experts: int
    experts_top_k: int
    num_attention_heads: int
    num_key_value_heads: int
    rms_norm_eps: float
    head_dim: int


class Qwen3MoEParameters(BaseModel):
    """
    Configuration parameters for the Qwen3 Mixture-of-Experts model backbone.

    Attributes:
        layer: Configuration shared across all transformer layers.
        num_hidden_layers: The total number of transformer layers.
        rope_base: Base value for RoPE frequency calculation.
        max_position_ids: Maximum sequence length.
        split_vocab_size: A dictionary mapping vocabulary segment names to their sizes.
        split_vocab_order: The sequence in which vocabulary splits are correctly ordered.
        pipeline_num_virtual_layers_pre: The number of 'virtual' layers representing the
            computational cost of modules on the *first* stage, before the main
            layers (e.g., token and positional embeddings).
        pipeline_num_virtual_layers_post: The number of 'virtual' layers representing the
            computational cost of modules on the *last* stage, after the main
            layers (e.g., the final layer normalization and LM head).
    """

    layer: Qwen3MoELayerParameters

    num_hidden_layers: int
    rope_base: int
    max_position_ids: int

    split_vocab_size: dict[str, int]
    split_vocab_order: list[str]

    pipeline_num_virtual_layers_pre: int = 0
    pipeline_num_virtual_layers_post: int = 0


class Qwen3MoEForCausalLMParameters(BaseModel):
    """
    Configuration parameters for Qwen3 Mixture-of-Experts model with a Causal Language Modeling head.

    Attributes:
        model: The configuration for the underlying Qwen3 MoE model.
    """

    model: Qwen3MoEParameters
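Since these are plain pydantic models, a configuration can be built directly in Python. A hedged sketch; every numeric size and the vocabulary split names below are hypothetical placeholders, not values shipped with the package:

from d9d.module.model.qwen3_moe.params import (
    Qwen3MoEForCausalLMParameters,
    Qwen3MoELayerParameters,
    Qwen3MoEParameters,
)

# All sizes are illustrative placeholders.
params = Qwen3MoEForCausalLMParameters(
    model=Qwen3MoEParameters(
        layer=Qwen3MoELayerParameters(
            hidden_size=2048,
            intermediate_size=768,
            num_experts=64,
            experts_top_k=8,
            num_attention_heads=32,
            num_key_value_heads=4,
            rms_norm_eps=1e-6,
            head_dim=128,
        ),
        num_hidden_layers=24,
        rope_base=1_000_000,
        max_position_ids=32_768,
        split_vocab_size={"text": 150_000, "special": 1_536},
        split_vocab_order=["text", "special"],
        pipeline_num_virtual_layers_pre=1,   # embeddings
        pipeline_num_virtual_layers_post=1,  # final norm + LM head
    )
)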
d9d/module/parallelism/__init__.py — file without changes (empty __init__.py)
d9d/module/parallelism/api/__init__.py
@@ -0,0 +1,18 @@
"""
Horizontal parallelism strategies and utilities for d9d modules.

This package provides high-level helper functions to apply specific distributed
parallelism strategies to PyTorch modules compatible with the d9d ecosystem.
"""

from .expert_parallel import parallelize_expert_parallel
from .fully_sharded import parallelize_fsdp
from .hybrid_sharded import parallelize_hsdp
from .replicate_parallel import parallelize_replicate

__all__ = [
    "parallelize_expert_parallel",
    "parallelize_fsdp",
    "parallelize_hsdp",
    "parallelize_replicate"
]
d9d/module/parallelism/api/expert_parallel.py
@@ -0,0 +1,36 @@
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import parallelize_module

from d9d.module.block.moe import MoELayer
from d9d.module.parallelism.style import ShardMoESparseExpertsParallel, ToLocalParallel


def parallelize_expert_parallel(
    module: MoELayer,
    mesh_experts: DeviceMesh,
    expert_shard_dim: str = "ep_shard"
):
    """
    Applies Expert Parallelism to a MoE layer.

    This function configures the provided Mixture of Experts layer for distributed
    execution.

    It partitions the sparse experts across the specified dimension
    of the device mesh (Expert Parallelism) and replicates along other dims.

    Simultaneously, it configures the router to be fully replicated across
    the mesh.

    Args:
        module: The MoE layer instance to parallelize.
        mesh_experts: The device mesh containing the expert parallel resources.
        expert_shard_dim: The name of the mesh dimension where experts should be sharded.
    """

    parallelize_module(module, mesh_experts, ShardMoESparseExpertsParallel(shard_dim_name=expert_shard_dim))
    parallelize_module(module.router, mesh_experts, ToLocalParallel(
        param_placement=tuple(Replicate() for _ in range(mesh_experts.ndim)),
        grad_placement=tuple(Replicate() for _ in range(mesh_experts.ndim))
    ))
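A hedged usage sketch under assumed conditions: the default process group is already initialised (e.g. via torchrun), the 8-way mesh size is illustrative, and `moe_layer` is a placeholder for an already-constructed `MoELayer` whose constructor arguments are not shown here:

from torch.distributed.device_mesh import init_device_mesh

from d9d.module.block.moe import MoELayer
from d9d.module.parallelism.api import parallelize_expert_parallel

# One mesh dimension, named to match the default expert_shard_dim.
mesh_experts = init_device_mesh("cuda", (8,), mesh_dim_names=("ep_shard",))

moe_layer: MoELayer = ...  # placeholder: build the MoE layer from your model config here
parallelize_expert_parallel(moe_layer, mesh_experts=mesh_experts, expert_shard_dim="ep_shard")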
d9d/module/parallelism/api/fully_sharded.py
@@ -0,0 +1,43 @@
from typing import Any

from torch import nn
from torch.distributed import DeviceMesh
from torch.distributed.fsdp import FSDPModule, fully_shard


def _force_fsdp_grad_reduction_policy(module: FSDPModule):
    module.set_force_sum_reduction_for_comms(enable=True)
    module.set_gradient_divide_factor(1.0)
    module.set_requires_all_reduce(False)


def parallelize_fsdp(
    module: nn.Module,
    mesh: DeviceMesh,
    *args: Any,
    **kwargs: Any
):
    """
    Applies Fully Sharded Data Parallel (FSDP) with forced gradient summation.

    This function wraps the provided module with PyTorch's ``fully_shard`` API using
    the specified device mesh. Unlike standard FSDP usage, this function explicitly
    configures the module to sum gradients across the mesh instead of averaging them,
    and disables FSDP's internal gradient all-reduce.
    This is intended so that d9d can handle gradient normalization and reduction across replicas externally.

    Args:
        module: The module to shard.
        mesh: The device mesh over which to shard the module.
        *args: Additional positional arguments passed to ``fully_shard``.
        **kwargs: Additional keyword arguments passed to ``fully_shard``.
    """

    if mesh.ndim != 1:
        raise ValueError("FSDP mesh should contain exactly one dimension - for HSDP, please apply "
                         "parallelize_replicate(...) first!")

    fully_shard(module, *args, mesh=mesh, **kwargs)
    if not isinstance(module, FSDPModule):
        raise RuntimeError("Torch FSDP did not convert the module into FSDPModule")
    _force_fsdp_grad_reduction_policy(module)
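A minimal usage sketch, assuming the default process group has already been initialised (e.g. via torchrun); the toy module, the 8-way mesh and the dimension name are illustrative:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from d9d.module.parallelism.api import parallelize_fsdp

# parallelize_fsdp expects a one-dimensional mesh (see the ndim check above).
mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_cp_shard",))

module = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
parallelize_fsdp(module, mesh)  # extra *args/**kwargs are forwarded to fully_shard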
d9d/module/parallelism/api/hybrid_sharded.py
@@ -0,0 +1,49 @@
from typing import Any

from torch import nn
from torch.distributed import DeviceMesh

from .fully_sharded import parallelize_fsdp
from .replicate_parallel import parallelize_replicate


def parallelize_hsdp(
    module: nn.Module,
    mesh: DeviceMesh,
    shard_dim: str = "dp_cp_shard",
    *fsdp_args: Any,
    **fsdp_kwargs: Any
):
    """
    Applies Hybrid Sharded Data Parallelism (HSDP) to a module.

    This function decomposes the provided device mesh into sharding dimensions
    and replication dimensions. It applies replication parallelism
    across the replication dimensions and Fully Sharded Data Parallelism (FSDP)
    across the specified shard dimension.

    Args:
        module: The module to parallelize.
        mesh: The device mesh over which to distribute the module.
        shard_dim: The name of the mesh dimension used for FSDP sharding. Any
            dimension in the mesh not matching this name will be treated as a
            replication dimension.
        *fsdp_args: Positional arguments passed to the underlying FSDP parallelizer.
        **fsdp_kwargs: Keyword arguments passed to the underlying FSDP parallelizer.

    Raises:
        ValueError: If the device mesh does not have named dimensions.
    """

    replicate_dims = mesh.mesh_dim_names

    if replicate_dims is None:
        raise ValueError("Cannot use with unnamed device meshes")

    replicate_dims = tuple(x for x in replicate_dims if x != shard_dim and mesh[x].size() > 1)

    if len(replicate_dims) > 0:
        parallelize_replicate(module, mesh[replicate_dims])

    if mesh[shard_dim].size() != 1:
        parallelize_fsdp(module, mesh[shard_dim], *fsdp_args, **fsdp_kwargs)
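A hedged sketch of the hybrid setup, assuming 16 GPUs arranged as 2 replica groups of 8 shards each; the shard dimension name follows the `shard_dim` default above, while "dp_replicate" and the module are illustrative:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from d9d.module.parallelism.api import parallelize_hsdp

# 2-way replication x 8-way sharding; any dimension not named "dp_cp_shard"
# is treated as a replication dimension by parallelize_hsdp.
mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("dp_replicate", "dp_cp_shard"))

module = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
parallelize_hsdp(module, mesh, shard_dim="dp_cp_shard")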
d9d/module/parallelism/api/replicate_parallel.py
@@ -0,0 +1,33 @@
from torch import nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import parallelize_module

from d9d.module.parallelism.style import ToLocalParallel


def parallelize_replicate(
    module: nn.Module,
    mesh: DeviceMesh,
):
    """
    Applies replicated parallelism to the module.

    This function configures the provided module to be fully replicated across the
    given device mesh. It utilizes the ``ToLocalParallel`` style, which manages
    ``DTensor`` wrapping for parameters and gradients (via ``Replicate`` placements)
    while ensuring that the underlying computation sees standard local tensors during the forward pass.

    This approach is effectively Data Parallelism managed via the DTensor
    APIs, allowing seamless integration of modules that require local tensor inputs
    into a broader distributed mesh context.

    Args:
        module: The module to parallelize.
        mesh: The device mesh over which to replicate the module.
    """

    parallelize_module(module, mesh, ToLocalParallel(
        param_placement=tuple(Replicate() for _ in range(mesh.ndim)),
        grad_placement=tuple(Replicate() for _ in range(mesh.ndim))
    ))
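And the pure replication case, again assuming an initialised process group; the 4-way mesh, its dimension name and the small module are illustrative:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from d9d.module.parallelism.api import parallelize_replicate

mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("dp_replicate",))

router = nn.Linear(1024, 64)  # e.g. a small module kept identical on every rank
parallelize_replicate(router, mesh)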
d9d/module/parallelism/model/__init__.py — file without changes (empty __init__.py)