d9d 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- d9d/__init__.py +0 -0
- d9d/core/__init__.py +0 -0
- d9d/core/autograd/__init__.py +7 -0
- d9d/core/autograd/grad_context.py +85 -0
- d9d/core/dist_context/__init__.py +19 -0
- d9d/core/dist_context/configured.py +215 -0
- d9d/core/dist_context/device_mesh_domains.py +185 -0
- d9d/core/dist_context/log.py +30 -0
- d9d/core/dist_context/params.py +113 -0
- d9d/core/dist_ops/__init__.py +16 -0
- d9d/core/dist_ops/object.py +68 -0
- d9d/core/dist_ops/tensor.py +192 -0
- d9d/core/protocol/__init__.py +8 -0
- d9d/core/protocol/training.py +38 -0
- d9d/core/sharding/__init__.py +15 -0
- d9d/core/sharding/auto_spec.py +66 -0
- d9d/core/sharding/shard.py +154 -0
- d9d/core/sharding/spec.py +28 -0
- d9d/core/sharding/unshard.py +117 -0
- d9d/core/types/__init__.py +12 -0
- d9d/core/types/data.py +14 -0
- d9d/core/types/pytree.py +26 -0
- d9d/dataset/__init__.py +17 -0
- d9d/dataset/buffer_sorted.py +143 -0
- d9d/dataset/padding.py +79 -0
- d9d/dataset/sharded.py +195 -0
- d9d/internals/__init__.py +0 -0
- d9d/internals/determinism/__init__.py +10 -0
- d9d/internals/determinism/seed.py +63 -0
- d9d/internals/grad_norm/__init__.py +8 -0
- d9d/internals/grad_norm/group.py +87 -0
- d9d/internals/grad_norm/norm.py +169 -0
- d9d/internals/grad_sync/__init__.py +14 -0
- d9d/internals/grad_sync/bucket.py +317 -0
- d9d/internals/grad_sync/placement_helper.py +23 -0
- d9d/internals/grad_sync/synchronizer.py +257 -0
- d9d/internals/pipeline_state/__init__.py +14 -0
- d9d/internals/pipeline_state/api.py +45 -0
- d9d/internals/pipeline_state/handler.py +111 -0
- d9d/internals/pipeline_state/storage.py +236 -0
- d9d/internals/profiling/__init__.py +7 -0
- d9d/internals/profiling/profile.py +112 -0
- d9d/internals/state/__init__.py +6 -0
- d9d/internals/state/main_process.py +44 -0
- d9d/kernel/__init__.py +0 -0
- d9d/kernel/cce/__init__.py +5 -0
- d9d/kernel/cce/cce.py +298 -0
- d9d/kernel/cce/main.py +282 -0
- d9d/kernel/general/__init__.py +5 -0
- d9d/kernel/general/get_int_dtype.py +7 -0
- d9d/kernel/gmm/__init__.py +5 -0
- d9d/kernel/gmm/function.py +78 -0
- d9d/kernel/moe/__init__.py +8 -0
- d9d/kernel/moe/indices_to_multihot.py +268 -0
- d9d/kernel/moe/permute_with_probs.py +1035 -0
- d9d/kernel/stochastic/__init__.py +11 -0
- d9d/kernel/stochastic/adamw_step.py +204 -0
- d9d/kernel/stochastic/copy.py +104 -0
- d9d/kernel/stochastic/ops/__init__.py +5 -0
- d9d/kernel/stochastic/ops/round.py +22 -0
- d9d/kernel/swiglu/__init__.py +5 -0
- d9d/kernel/swiglu/function.py +36 -0
- d9d/kernel/swiglu/op.py +167 -0
- d9d/loop/__init__.py +0 -0
- d9d/loop/auto/__init__.py +9 -0
- d9d/loop/auto/auto_lr_scheduler.py +46 -0
- d9d/loop/auto/auto_optimizer.py +196 -0
- d9d/loop/component/__init__.py +35 -0
- d9d/loop/component/batch_maths.py +106 -0
- d9d/loop/component/checkpointer.py +172 -0
- d9d/loop/component/data_loader_factory.py +258 -0
- d9d/loop/component/garbage_collector.py +94 -0
- d9d/loop/component/gradient_clipper.py +89 -0
- d9d/loop/component/gradient_manager.py +149 -0
- d9d/loop/component/job_logger.py +146 -0
- d9d/loop/component/job_profiler.py +62 -0
- d9d/loop/component/loss_computer.py +86 -0
- d9d/loop/component/model_stage_exporter.py +37 -0
- d9d/loop/component/model_stage_factory.py +261 -0
- d9d/loop/component/optimizer_factory.py +88 -0
- d9d/loop/component/stepper.py +52 -0
- d9d/loop/component/timeout_manager.py +54 -0
- d9d/loop/component/train_task_operator.py +152 -0
- d9d/loop/config/__init__.py +36 -0
- d9d/loop/config/config.py +225 -0
- d9d/loop/config/types.py +24 -0
- d9d/loop/control/__init__.py +61 -0
- d9d/loop/control/dataset_provider.py +58 -0
- d9d/loop/control/lr_scheduler_provider.py +47 -0
- d9d/loop/control/model_provider.py +162 -0
- d9d/loop/control/optimizer_provider.py +45 -0
- d9d/loop/control/task.py +304 -0
- d9d/loop/run/__init__.py +6 -0
- d9d/loop/run/train.py +355 -0
- d9d/loop/state.py +143 -0
- d9d/lr_scheduler/__init__.py +9 -0
- d9d/lr_scheduler/piecewise/__init__.py +18 -0
- d9d/lr_scheduler/piecewise/builder.py +152 -0
- d9d/lr_scheduler/piecewise/config.py +176 -0
- d9d/lr_scheduler/piecewise/curves.py +75 -0
- d9d/lr_scheduler/piecewise/engine.py +76 -0
- d9d/lr_scheduler/visualizer.py +74 -0
- d9d/metric/__init__.py +10 -0
- d9d/metric/abc.py +79 -0
- d9d/metric/impl/__init__.py +7 -0
- d9d/metric/impl/compose.py +54 -0
- d9d/metric/impl/mean.py +94 -0
- d9d/model_state/__init__.py +0 -0
- d9d/model_state/io/__init__.py +21 -0
- d9d/model_state/io/dto.py +30 -0
- d9d/model_state/io/module_reader.py +75 -0
- d9d/model_state/io/module_writer.py +123 -0
- d9d/model_state/io/reader.py +125 -0
- d9d/model_state/io/writer.py +309 -0
- d9d/model_state/mapper/__init__.py +10 -0
- d9d/model_state/mapper/abc.py +70 -0
- d9d/model_state/mapper/adapters/__init__.py +12 -0
- d9d/model_state/mapper/adapters/mapper.py +27 -0
- d9d/model_state/mapper/adapters/module.py +22 -0
- d9d/model_state/mapper/compose/__init__.py +17 -0
- d9d/model_state/mapper/compose/helper.py +22 -0
- d9d/model_state/mapper/compose/parallel.py +58 -0
- d9d/model_state/mapper/compose/sequential.py +131 -0
- d9d/model_state/mapper/compose/shard.py +36 -0
- d9d/model_state/mapper/leaf/__init__.py +18 -0
- d9d/model_state/mapper/leaf/dtensor.py +56 -0
- d9d/model_state/mapper/leaf/identity.py +23 -0
- d9d/model_state/mapper/leaf/rename.py +26 -0
- d9d/model_state/mapper/leaf/select_child.py +37 -0
- d9d/model_state/mapper/leaf/stack.py +29 -0
- d9d/module/__init__.py +0 -0
- d9d/module/base/__init__.py +7 -0
- d9d/module/base/late_init.py +10 -0
- d9d/module/block/__init__.py +0 -0
- d9d/module/block/attention/__init__.py +7 -0
- d9d/module/block/attention/grouped_query.py +139 -0
- d9d/module/block/attention/sdpa/__init__.py +5 -0
- d9d/module/block/attention/sdpa/flash.py +52 -0
- d9d/module/block/embedding/__init__.py +7 -0
- d9d/module/block/embedding/shard_token_embedding.py +103 -0
- d9d/module/block/ffn/__init__.py +5 -0
- d9d/module/block/ffn/swiglu.py +60 -0
- d9d/module/block/head/__init__.py +6 -0
- d9d/module/block/head/language_modelling.py +87 -0
- d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
- d9d/module/block/hidden_states_aggregator/base.py +35 -0
- d9d/module/block/hidden_states_aggregator/factory.py +48 -0
- d9d/module/block/hidden_states_aggregator/mean.py +61 -0
- d9d/module/block/hidden_states_aggregator/noop.py +27 -0
- d9d/module/block/moe/__init__.py +13 -0
- d9d/module/block/moe/communications/__init__.py +11 -0
- d9d/module/block/moe/communications/base.py +58 -0
- d9d/module/block/moe/communications/deepep.py +300 -0
- d9d/module/block/moe/communications/naive.py +68 -0
- d9d/module/block/moe/grouped_experts.py +81 -0
- d9d/module/block/moe/grouped_linear.py +78 -0
- d9d/module/block/moe/layer.py +122 -0
- d9d/module/block/moe/router.py +103 -0
- d9d/module/block/positional/__init__.py +8 -0
- d9d/module/block/positional/rope.py +150 -0
- d9d/module/model/__init__.py +0 -0
- d9d/module/model/qwen3_moe/__init__.py +16 -0
- d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
- d9d/module/model/qwen3_moe/model.py +373 -0
- d9d/module/model/qwen3_moe/params.py +69 -0
- d9d/module/parallelism/__init__.py +0 -0
- d9d/module/parallelism/api/__init__.py +18 -0
- d9d/module/parallelism/api/expert_parallel.py +36 -0
- d9d/module/parallelism/api/fully_sharded.py +43 -0
- d9d/module/parallelism/api/hybrid_sharded.py +49 -0
- d9d/module/parallelism/api/replicate_parallel.py +33 -0
- d9d/module/parallelism/model/__init__.py +0 -0
- d9d/module/parallelism/model/qwen3_moe.py +99 -0
- d9d/module/parallelism/style/__init__.py +7 -0
- d9d/module/parallelism/style/shard_experts.py +60 -0
- d9d/module/parallelism/style/to_local.py +86 -0
- d9d/optim/__init__.py +0 -0
- d9d/optim/stochastic/__init__.py +5 -0
- d9d/optim/stochastic/adamw.py +158 -0
- d9d/peft/__init__.py +13 -0
- d9d/peft/all/__init__.py +12 -0
- d9d/peft/all/config.py +31 -0
- d9d/peft/all/method.py +76 -0
- d9d/peft/applicator.py +47 -0
- d9d/peft/base.py +70 -0
- d9d/peft/full_tune/__init__.py +11 -0
- d9d/peft/full_tune/config.py +20 -0
- d9d/peft/full_tune/method.py +46 -0
- d9d/peft/lora/__init__.py +15 -0
- d9d/peft/lora/config.py +35 -0
- d9d/peft/lora/layer.py +177 -0
- d9d/peft/lora/method.py +132 -0
- d9d/pipelining/__init__.py +0 -0
- d9d/pipelining/api/__init__.py +19 -0
- d9d/pipelining/api/module.py +149 -0
- d9d/pipelining/api/schedule.py +50 -0
- d9d/pipelining/api/sharding.py +9 -0
- d9d/pipelining/factory/__init__.py +21 -0
- d9d/pipelining/factory/config.py +89 -0
- d9d/pipelining/factory/factory.py +114 -0
- d9d/pipelining/factory/registry.py +82 -0
- d9d/pipelining/infra/__init__.py +0 -0
- d9d/pipelining/infra/schedule/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/__init__.py +0 -0
- d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
- d9d/pipelining/infra/schedule/component/program/base.py +35 -0
- d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
- d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
- d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
- d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
- d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
- d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
- d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
- d9d/pipelining/infra/schedule/program/__init__.py +15 -0
- d9d/pipelining/infra/schedule/program/bfs.py +86 -0
- d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
- d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
- d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
- d9d/pipelining/infra/stage/__init__.py +5 -0
- d9d/pipelining/infra/stage/communications.py +274 -0
- d9d/pipelining/infra/stage/computations.py +317 -0
- d9d/pipelining/infra/stage/splitgrad.py +377 -0
- d9d/pipelining/infra/stage/stage.py +321 -0
- d9d/pipelining/infra/stage/struct_helper.py +46 -0
- d9d/pipelining/training/__init__.py +7 -0
- d9d/pipelining/training/optimizer.py +41 -0
- d9d/pipelining/training/scheduler.py +34 -0
- d9d/tracker/__init__.py +14 -0
- d9d/tracker/base.py +124 -0
- d9d/tracker/factory.py +57 -0
- d9d/tracker/provider/__init__.py +0 -0
- d9d/tracker/provider/aim/__init__.py +0 -0
- d9d/tracker/provider/aim/config.py +23 -0
- d9d/tracker/provider/aim/tracker.py +114 -0
- d9d/tracker/provider/null.py +61 -0
- d9d-0.1.0.dist-info/METADATA +90 -0
- d9d-0.1.0.dist-info/RECORD +238 -0
- d9d-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
from torch import nn
|
|
4
|
+
|
|
5
|
+
from d9d.module.base import ModuleLateInit
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TopKRouter(nn.Module, ModuleLateInit):
    """
    Selects the top-K experts based on a learned gating mechanism.

    This router:

    1. Projects input tokens into expert space
    2. Applies softmax, optionally adds expert bias to influence selection
    3. Selects the experts with the highest probabilities
    4. Selected probabilities are then re-normalized to sum to 1 if requested
       (see `renormalize_probabilities`).
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int,
        renormalize_probabilities: bool,
        enable_expert_bias: bool = False
    ):
        """
        Constructs the TopKRouter.

        Args:
            dim: Input feature dimensionality.
            num_experts: Total number of experts to choose from.
            top_k: Number of experts to select for each token.
            renormalize_probabilities: If True, probabilities of selected experts will be renormalized to sum up
                to 1. If False, the raw softmax probabilities of the selected experts are returned unchanged.
            enable_expert_bias: If True, adds a bias term to the routing scores before top-k selection. This can be
                used for loss-free load balancing. The bias only influences *which* experts are selected; the
                returned weights are taken from the unbiased probabilities.
        """

        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)

        self.expert_bias: nn.Buffer | None
        if enable_expert_bias:
            # float32 to match the dtype of the softmax probabilities it is added to
            self.expert_bias = nn.Buffer(
                torch.empty(num_experts, dtype=torch.float32),
                persistent=True,
            )
        else:
            self.expert_bias = None

        self._num_experts = num_experts
        self._top_k = top_k
        self._renormalize_probabilities = renormalize_probabilities

    def forward(
        self,
        hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Calculates routing decisions for the input tokens.

        Args:
            hidden_states: Input tokens. Shape: `(num_tokens, dim)`.

        Returns:
            A tuple containing:

            - Selected expert indices. Shape: `(num_tokens, top_k)`.
            - float32 routing weights for the selected experts. Shape: `(num_tokens, top_k)`.
              Renormalized to sum to 1 per token only when `renormalize_probabilities` is True.
        """

        # gate logits, shape (num_tokens, num_experts)
        scores = self.gate(hidden_states)

        # softmax happens before top-k so the expert bias can be applied to probabilities
        scores = F.softmax(scores, dim=-1, dtype=torch.float32)

        # select top-k
        if self.expert_bias is None:
            scores, selected_experts_indices = torch.topk(
                scores, k=self._top_k, dim=-1
            )
        else:
            # the bias steers selection only; weights come from the unbiased probabilities
            _, selected_experts_indices = torch.topk(
                scores + self.expert_bias, k=self._top_k, dim=-1
            )
            scores = scores.gather(dim=-1, index=selected_experts_indices)

        if self._renormalize_probabilities:
            # the tiny epsilon guards against division by zero if all selected probabilities underflow
            denominator = scores.sum(dim=-1, keepdim=True) + 1e-20
            scores = scores / denominator

        return selected_experts_indices, scores

    def reset_parameters(self):
        """Resets module parameters: zeroes the expert bias (if enabled) and re-initializes the gate."""
        if self.expert_bias is not None:
            nn.init.zeros_(self.expert_bias)

        self.gate.reset_parameters()
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
|
|
4
|
+
from d9d.module.base import ModuleLateInit
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _prepare_rope_inverse_frequencies(
    rope_base: int,
    inside_dim: int
) -> torch.Tensor:
    """
    Computes RoPE inverse frequencies, i.e. ``1 / rope_base ** (2i / inside_dim)``.

    Args:
        rope_base: Base of the geometric progression of frequencies.
        inside_dim: Size of the rotated dimension (attention head size); expected to be even.

    Returns:
        A 1-D float32 tensor of length ``ceil(inside_dim / 2)`` holding one inverse
        frequency per pair of channels.
    """

    even_channel_ids = torch.arange(0, inside_dim, 2, dtype=torch.int64).to(dtype=torch.float)
    exponents = even_channel_ids / inside_dim
    return 1.0 / (rope_base ** exponents)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def prepare_rotary_cos_sin_emb(
    rope_base: int,
    head_dim: int,
    max_position_ids: int,
    device: torch.device,
    dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Builds the rotary cosine and sine lookup tables for positions ``0 .. max_position_ids - 1``.

    Angles are computed in float32 and only cast to ``dtype`` / moved to ``device`` at the end.

    Args:
        rope_base: Base frequency for calculation.
        head_dim: Dimensionality of the attention head (E).
        max_position_ids: Number of positions to precompute (S).
        device: Target device for the returned tensors.
        dtype: Target data type for the returned tensors.

    Returns:
        A ``(cos, sin)`` pair, each of shape [S, E].
    """

    inv_freq = _prepare_rope_inverse_frequencies(rope_base, head_dim)
    positions = torch.arange(0, max_position_ids, dtype=torch.long)

    # rank-1 product: [E/2, 1] @ [1, S] -> [E/2, S], transposed to [S, E/2]
    half_angles = (inv_freq[:, None] @ positions[None, :].float()).T

    # both halves of the head use the same angles, matching the rotate-half layout
    angles = torch.cat([half_angles, half_angles], dim=-1)

    cos_table = angles.cos().to(device=device, dtype=dtype)
    sin_table = angles.sin().to(device=device, dtype=dtype)
    return cos_table, sin_table
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class RotaryEmbeddingProvider(nn.Module, ModuleLateInit):
    """
    Owns precomputed Rotary Positional Embedding tables and serves rows of them by position.

    The cosine and sine tables are non-persistent buffers of shape `(max_position_ids, head_dim)`.
    They are allocated uninitialized by the constructor and filled by `reset_parameters`.
    """

    def __init__(self, rope_base: int, head_dim: int, max_position_ids: int):
        """
        Constructs the RotaryEmbeddingProvider.

        Args:
            rope_base: Base of the RoPE frequency progression.
            head_dim: Dimensionality of the attention head.
            max_position_ids: Number of positions covered by the tables.
        """

        super().__init__()
        self._rope_base = rope_base
        self._head_dim = head_dim
        self._max_position_ids = max_position_ids

        table_shape = (max_position_ids, head_dim)
        self.cos_emb = nn.Buffer(torch.empty(table_shape), persistent=False)
        self.sin_emb = nn.Buffer(torch.empty(table_shape), persistent=False)

    def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Looks up cached cosine and sine rows for the requested positions.

        Args:
            position_ids: Tensor of position indices into the tables.

        Returns:
            A `(cos, sin)` pair whose leading dimensions follow `position_ids`, with a trailing `head_dim` axis.
        """

        cos_table, sin_table = self.cos_emb, self.sin_emb
        return cos_table[position_ids], sin_table[position_ids]

    def reset_parameters(self):
        """
        Recomputes both tables and assigns them to the buffers.

        Both tables take the device and dtype of the current `cos_emb` buffer.
        """

        target_device = self.cos_emb.device
        target_dtype = self.cos_emb.dtype
        with torch.no_grad():
            cos, sin = prepare_rotary_cos_sin_emb(
                rope_base=self._rope_base,
                head_dim=self._head_dim,
                max_position_ids=self._max_position_ids,
                device=target_device,
                dtype=target_dtype,
            )
            self.cos_emb.data = cos
            self.sin_emb.data = sin
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Splits the last dimension at its midpoint and returns ``[-second_half, first_half]``.

    For an odd-sized last dimension the first part has ``size // 2`` elements.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Applies the rotate-half form of RoPE, ``t * cos + rotate_half(t) * sin``, to both `q` and `k`.

    A singleton axis is inserted at dim 1 of `cos`/`sin` before broadcasting. This presumably
    is the heads axis, assuming q/k are `(batch, heads, seq, head_dim)` and cos/sin are
    `(batch, seq, head_dim)`; confirm against callers.
    """
    cos_b = cos.unsqueeze(1)
    sin_b = sin.unsqueeze(1)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        return (t * cos_b) + (_rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class RotaryEmbeddingApplicator(nn.Module):
    """Stateless module that rotates query and key projections with Rotary Positional Embeddings (RoPE)."""

    def __init__(self):
        """
        Constructs RotaryEmbeddingApplicator object. The module holds no parameters or buffers.
        """

        super().__init__()

    def forward(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        position_embedding_cos: torch.Tensor,
        position_embedding_sin: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Rotates query and key states with the supplied cosine and sine tables.

        Args:
            query_states: Query tensor. Shape: `(batch, n_heads, seq_len, head_dim)`.
            key_states: Key tensor. Shape: `(batch, n_kv_heads, seq_len, head_dim)`.
            position_embedding_cos: Cosine values for positions.
                Shape: `(batch, seq_len, head_dim)`.
            position_embedding_sin: Sine values for positions.
                Shape: `(batch, seq_len, head_dim)`.

        Returns:
            A `(rotated_query, rotated_key)` tuple with the same shapes as the inputs.
        """

        return _apply_rotary_pos_emb(
            query_states,
            key_states,
            position_embedding_cos,
            position_embedding_sin,
        )
|
|
File without changes
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from .decoder_layer import Qwen3MoELayer
|
|
2
|
+
from .model import Qwen3MoEForCausalLM, Qwen3MoEModel
|
|
3
|
+
from .params import (
|
|
4
|
+
Qwen3MoEForCausalLMParameters,
|
|
5
|
+
Qwen3MoELayerParameters,
|
|
6
|
+
Qwen3MoEParameters,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
# Public API of the `qwen3_moe` package: the decoder layer, the model classes and their parameter classes.
__all__ = [
    "Qwen3MoEForCausalLM",
    "Qwen3MoEForCausalLMParameters",
    "Qwen3MoELayer",
    "Qwen3MoELayerParameters",
    "Qwen3MoEModel",
    "Qwen3MoEParameters"
]
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
|
|
4
|
+
from d9d.module.base import ModuleLateInit
|
|
5
|
+
from d9d.module.block.attention import GroupedQueryAttention
|
|
6
|
+
from d9d.module.block.moe import MoELayer
|
|
7
|
+
|
|
8
|
+
from .params import Qwen3MoELayerParameters
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Qwen3MoELayer(nn.Module, ModuleLateInit):
    """
    A single Qwen3 Mixture-of-Experts (MoE) transformer decoder layer.

    Pre-norm structure with two residual branches:
    RMSNorm -> Grouped Query Attention -> residual add, then
    RMSNorm -> MoE MLP -> residual add.
    """

    def __init__(
        self,
        params: Qwen3MoELayerParameters
    ):
        """
        Builds the attention block, the MoE block and the two RMSNorm layers.

        Args:
            params: Configuration parameters for the layer.
        """

        super().__init__()

        hidden_size = params.hidden_size
        norm_eps = params.rms_norm_eps

        self.self_attn = GroupedQueryAttention(
            hidden_size=hidden_size,
            num_attention_heads=params.num_attention_heads,
            num_key_value_heads=params.num_key_value_heads,
            is_causal=True,
            qk_norm_eps=norm_eps,
            head_dim=params.head_dim,
        )

        self.mlp = MoELayer(
            hidden_dim=hidden_size,
            num_grouped_experts=params.num_experts,
            intermediate_dim_grouped=params.intermediate_size,
            top_k=params.experts_top_k,
            router_renormalize_probabilities=True,
        )

        self.input_layernorm = nn.RMSNorm(hidden_size, eps=norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """
        Runs the attention branch and then the MoE branch, each with a residual connection.

        Args:
            hidden_states: Input tensor of shape `(batch, seq_len, hidden_dim)`.
            position_embeddings: Tuple containing RoPE precomputed embeddings (cos, sin).

        Returns:
            Output tensor after attention and MoE blocks, shape `(batch, seq_len, hidden_dim)`.
        """

        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=position_embeddings,
            # no explicit mask; the attention block is constructed with is_causal=True
            attention_mask=None,
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out

    def reset_moe_stats(self):
        """
        Resets the MoE block's routing statistics by calling `self.mlp.reset_stats()`.
        """

        self.mlp.reset_stats()

    @property
    def moe_tokens_per_expert(self) -> torch.Tensor:
        """
        Per-expert token counts exposed by the MoE block (`self.mlp.tokens_per_expert`).
        """

        return self.mlp.tokens_per_expert

    def reset_parameters(self):
        """
        Re-initializes all sub-modules. The order is attention, MoE, input norm, post-attention norm.
        """

        # order preserved so any random initialization consumes the RNG in the same sequence
        self.self_attn.reset_parameters()
        self.mlp.reset_parameters()
        self.input_layernorm.reset_parameters()
        self.post_attention_layernorm.reset_parameters()
|