d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/module/block/moe/router.py
@@ -0,0 +1,103 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from d9d.module.base import ModuleLateInit
+
+
+ class TopKRouter(nn.Module, ModuleLateInit):
+     """
+     Selects the top-K experts based on a learned gating mechanism.
+
+     This router:
+
+     1. Projects input tokens into expert space
+     2. Applies softmax, optionally adding an expert bias to influence selection
+     3. Selects the experts with the highest probabilities
+     4. Re-normalizes the selected probabilities to sum to 1 if requested
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_experts: int,
+         top_k: int,
+         renormalize_probabilities: bool,
+         enable_expert_bias: bool = False
+     ):
+         """
+         Constructs the TopKRouter.
+
+         Args:
+             dim: Input feature dimensionality.
+             num_experts: Total number of experts to choose from.
+             top_k: Number of experts to select for each token.
+             renormalize_probabilities: If True, probabilities of the selected experts are renormalized to sum to 1.
+             enable_expert_bias: If True, adds a bias term to the routing scores before top-k selection. This can be
+                 used for loss-free load balancing.
+         """
+
+         super().__init__()
+         self.gate = nn.Linear(dim, num_experts, bias=False)
+
+         self.expert_bias: nn.Buffer | None
+         if enable_expert_bias:
+             self.expert_bias = nn.Buffer(
+                 torch.empty(num_experts, dtype=torch.float32),
+                 persistent=True,
+             )
+         else:
+             self.expert_bias = None
+
+         self._num_experts = num_experts
+         self._top_k = top_k
+         self._renormalize_probabilities = renormalize_probabilities
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Calculates routing decisions for the input tokens.
+
+         Args:
+             hidden_states: Input tokens. Shape: `(num_tokens, dim)`.
+
+         Returns:
+             A tuple containing:
+
+             - Selected expert indices. Shape: `(num_tokens, top_k)`.
+             - Normalized routing weights for the selected experts. Shape: `(num_tokens, top_k)`.
+         """
+
+         # scores shape: (num_tokens, num_experts)
+
+         # project tokens into expert space
+         scores = self.gate(hidden_states)
+
+         # softmax before top-k, so the expert bias can act on probabilities
+         scores = F.softmax(scores, dim=-1, dtype=torch.float32)
+
+         # select top-k; the bias only influences which experts are chosen
+         if self.expert_bias is None:
+             scores, selected_experts_indices = torch.topk(
+                 scores, k=self._top_k, dim=-1
+             )
+         else:
+             _, selected_experts_indices = torch.topk(
+                 scores + self.expert_bias, k=self._top_k, dim=-1
+             )
+             scores = scores.gather(dim=-1, index=selected_experts_indices)
+
+         # re-normalize the selected probabilities to sum to 1 if requested
+         if self._renormalize_probabilities:
+             scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
+
+         return selected_experts_indices, scores
+
+     def reset_parameters(self):
+         """Resets module parameters."""
+         if self.expert_bias is not None:
+             nn.init.zeros_(self.expert_bias)
+
+         self.gate.reset_parameters()
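
A minimal usage sketch of the router defined above. All concrete sizes and inputs are arbitrary demo values, not part of the package; only `TopKRouter` and its constructor arguments come from the diff:

import torch
from d9d.module.block.moe.router import TopKRouter

router = TopKRouter(
    dim=64,                      # demo feature size
    num_experts=8,
    top_k=2,
    renormalize_probabilities=True,
)

tokens = torch.randn(16, 64)     # (num_tokens, dim)
indices, weights = router(tokens)

print(indices.shape)             # torch.Size([16, 2]), int64 expert ids
print(weights.shape)             # torch.Size([16, 2]), float32 weights
print(weights.sum(dim=-1))       # each row sums to ~1.0 when renormalizing

With `enable_expert_bias=True`, the bias shifts only which experts win the top-k; the returned weights are still gathered from the unbiased softmax, which is what makes the bias usable for loss-free load balancing.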
d9d/module/block/positional/__init__.py
@@ -0,0 +1,8 @@
+ """Provides modules for positional embeddings, such as Rotary Positional Embeddings."""
+
+ from .rope import RotaryEmbeddingApplicator, RotaryEmbeddingProvider
+
+ __all__ = [
+     "RotaryEmbeddingApplicator",
+     "RotaryEmbeddingProvider"
+ ]
d9d/module/block/positional/rope.py
@@ -0,0 +1,150 @@
+ import torch
+ from torch import nn
+
+ from d9d.module.base import ModuleLateInit
+
+
+ def _prepare_rope_inverse_frequencies(
+     rope_base: int,
+     inside_dim: int
+ ) -> torch.Tensor:
+     """
+     Calculates inverse frequencies for RoPE calculation.
+
+     Args:
+         rope_base: Base for the geometric progression.
+         inside_dim: Dimension of the attention head (must be even).
+
+     Returns:
+         A tensor containing the inverse frequencies.
+     """
+
+     power = torch.arange(0, inside_dim, 2, dtype=torch.int64).to(dtype=torch.float) / inside_dim
+     freq = rope_base ** power
+     inv_freq = 1.0 / freq
+     return inv_freq
+
+
+ def prepare_rotary_cos_sin_emb(
+     rope_base: int,
+     head_dim: int,
+     max_position_ids: int,
+     device: torch.device,
+     dtype: torch.dtype
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Precomputes rotary cosine and sine embeddings.
+
+     Args:
+         rope_base: Base frequency for calculation.
+         head_dim: Dimensionality of the attention head (E).
+         max_position_ids: Maximum sequence length supported (S).
+         device: Target device for the tensors.
+         dtype: Target data type for the tensors.
+
+     Returns:
+         A tuple containing cosine and sine tensors, both of shape `(S, E)`.
+     """
+
+     position_ids = torch.arange(0, max_position_ids, dtype=torch.long)
+     freqs = _prepare_rope_inverse_frequencies(rope_base, head_dim)
+
+     arguments = (freqs[:, None] @ position_ids[None, :].float()).T
+
+     emb = torch.cat((arguments, arguments), dim=-1)
+     cos = emb.cos()
+     sin = emb.sin()
+     return cos.to(device=device, dtype=dtype), sin.to(device=device, dtype=dtype)
+
+
+ class RotaryEmbeddingProvider(nn.Module, ModuleLateInit):
+     """Module that manages and provides Rotary Positional Embeddings."""
+
+     def __init__(self, rope_base: int, head_dim: int, max_position_ids: int):
+         """Constructs the RotaryEmbeddingProvider."""
+
+         super().__init__()
+         self._rope_base = rope_base
+         self._head_dim = head_dim
+         self._max_position_ids = max_position_ids
+         self.cos_emb = nn.Buffer(torch.empty(max_position_ids, head_dim), persistent=False)
+         self.sin_emb = nn.Buffer(torch.empty(max_position_ids, head_dim), persistent=False)
+
+     def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Retrieves cached cosine and sine embeddings for specific positions.
+
+         Args:
+             position_ids: Tensor of position indices.
+
+         Returns:
+             A tuple of (cos, sin) tensors aligned with the input positions.
+         """
+
+         return self.cos_emb[position_ids], self.sin_emb[position_ids]
+
+     def reset_parameters(self):
+         with torch.no_grad():
+             cos, sin = prepare_rotary_cos_sin_emb(
+                 rope_base=self._rope_base,
+                 head_dim=self._head_dim,
+                 max_position_ids=self._max_position_ids,
+                 device=self.cos_emb.device,
+                 dtype=self.cos_emb.dtype
+             )
+             self.cos_emb.data = cos
+             self.sin_emb.data = sin
+
+
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2:]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def _apply_rotary_pos_emb(
+     q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     cos = cos.unsqueeze(1)
+     sin = sin.unsqueeze(1)
+     q_embed = (q * cos) + (_rotate_half(q) * sin)
+     k_embed = (k * cos) + (_rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class RotaryEmbeddingApplicator(nn.Module):
+     """Applies Rotary Positional Embeddings (RoPE) to Q and K projections."""
+
+     def __init__(self):
+         """
+         Constructs a RotaryEmbeddingApplicator.
+         """
+
+         super().__init__()
+
+     def forward(
+         self,
+         query_states: torch.Tensor,
+         key_states: torch.Tensor,
+         position_embedding_cos: torch.Tensor,
+         position_embedding_sin: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Rotates query and key states using the provided cosine and sine embeddings.
+
+         Args:
+             query_states: Query tensor. Shape: `(batch, n_heads, seq_len, head_dim)`.
+             key_states: Key tensor. Shape: `(batch, n_kv_heads, seq_len, head_dim)`.
+             position_embedding_cos: Cosine values for positions.
+                 Shape: `(batch, seq_len, head_dim)`.
+             position_embedding_sin: Sine values for positions.
+                 Shape: `(batch, seq_len, head_dim)`.
+
+         Returns:
+             A tuple containing the rotated query and key tensors.
+         """
+
+         query_states, key_states = _apply_rotary_pos_emb(query_states, key_states,
+                                                          position_embedding_cos, position_embedding_sin)
+
+         return query_states, key_states
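
A sketch of how the provider and applicator above fit together. The shapes are demo values; `unsqueeze(0)` adds the batch dimension the applicator's docstring expects:

import torch
from d9d.module.block.positional import RotaryEmbeddingApplicator, RotaryEmbeddingProvider

provider = RotaryEmbeddingProvider(rope_base=10000, head_dim=32, max_position_ids=128)
provider.reset_parameters()                # fills the cos/sin caches

cos, sin = provider(torch.arange(16))      # each: (16, 32) = (seq_len, head_dim)

q = torch.randn(1, 4, 16, 32)              # (batch, n_heads, seq_len, head_dim)
k = torch.randn(1, 2, 16, 32)              # (batch, n_kv_heads, seq_len, head_dim)

applicator = RotaryEmbeddingApplicator()
q_rot, k_rot = applicator(q, k, cos.unsqueeze(0), sin.unsqueeze(0))

# RoPE is a pure rotation: per-position vector norms are unchanged (cos^2 + sin^2 = 1)
print(torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5))  # True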
d9d/module/model/__init__.py
File without changes
d9d/module/model/qwen3_moe/__init__.py
@@ -0,0 +1,16 @@
+ from .decoder_layer import Qwen3MoELayer
+ from .model import Qwen3MoEForCausalLM, Qwen3MoEModel
+ from .params import (
+     Qwen3MoEForCausalLMParameters,
+     Qwen3MoELayerParameters,
+     Qwen3MoEParameters,
+ )
+
+ __all__ = [
+     "Qwen3MoEForCausalLM",
+     "Qwen3MoEForCausalLMParameters",
+     "Qwen3MoELayer",
+     "Qwen3MoELayerParameters",
+     "Qwen3MoEModel",
+     "Qwen3MoEParameters"
+ ]
d9d/module/model/qwen3_moe/decoder_layer.py
@@ -0,0 +1,110 @@
+ import torch
+ from torch import nn
+
+ from d9d.module.base import ModuleLateInit
+ from d9d.module.block.attention import GroupedQueryAttention
+ from d9d.module.block.moe import MoELayer
+
+ from .params import Qwen3MoELayerParameters
+
+
+ class Qwen3MoELayer(nn.Module, ModuleLateInit):
+     """
+     Implements a single Qwen3 Mixture-of-Experts (MoE) transformer layer.
+
+     This layer consists of a Grouped Query Attention mechanism followed by an MoE
+     MLP block, with pre-RMSNorm applied before each sub-layer.
+     """
+
+     def __init__(
+         self,
+         params: Qwen3MoELayerParameters
+     ):
+         """
+         Constructs a Qwen3MoELayer object.
+
+         Args:
+             params: Configuration parameters for the layer.
+         """
+
+         super().__init__()
+
+         self.self_attn = GroupedQueryAttention(
+             hidden_size=params.hidden_size,
+             num_attention_heads=params.num_attention_heads,
+             num_key_value_heads=params.num_key_value_heads,
+             is_causal=True,
+             qk_norm_eps=params.rms_norm_eps,
+             head_dim=params.head_dim
+         )
+
+         self.mlp = MoELayer(
+             hidden_dim=params.hidden_size,
+             num_grouped_experts=params.num_experts,
+             intermediate_dim_grouped=params.intermediate_size,
+             top_k=params.experts_top_k,
+             router_renormalize_probabilities=True
+         )
+
+         self.input_layernorm = nn.RMSNorm(params.hidden_size, eps=params.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(params.hidden_size, eps=params.rms_norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor]
+     ) -> torch.Tensor:
+         """
+         Performs the forward pass of the MoE layer.
+
+         Args:
+             hidden_states: Input tensor of shape `(batch, seq_len, hidden_dim)`.
+             position_embeddings: Tuple containing precomputed RoPE embeddings (cos, sin).
+
+         Returns:
+             Output tensor after the attention and MoE blocks, shape `(batch, seq_len, hidden_dim)`.
+         """
+
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(hidden_states)
+
+         hidden_states = self.self_attn(
+             hidden_states=hidden_states,
+             position_embeddings=position_embeddings,
+             attention_mask=None  # no explicit mask; attention is configured as causal above
+         )
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+
+         hidden_states = residual + hidden_states
+
+         return hidden_states
+
+     def reset_moe_stats(self):
+         """
+         Resets statistical counters inside the MoE router (e.g., token counts per expert).
+         """
+
+         self.mlp.reset_stats()
+
+     @property
+     def moe_tokens_per_expert(self) -> torch.Tensor:
+         """
+         Returns the number of tokens routed to each expert.
+         """
+
+         return self.mlp.tokens_per_expert
+
+     def reset_parameters(self):
+         """
+         Resets module parameters.
+         """
+
+         self.self_attn.reset_parameters()
+         self.mlp.reset_parameters()
+         self.input_layernorm.reset_parameters()
+         self.post_attention_layernorm.reset_parameters()
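
The layer above only consumes a params object and a (cos, sin) tuple, so it can be smoke-tested in isolation. A hypothetical sketch, assuming `Qwen3MoELayerParameters` (defined in params.py, which is not shown in this excerpt) accepts as keyword arguments the fields the layer reads:

import torch
from d9d.module.block.positional import RotaryEmbeddingProvider
from d9d.module.model.qwen3_moe import Qwen3MoELayer, Qwen3MoELayerParameters

params = Qwen3MoELayerParameters(   # field names inferred from __init__ above
    hidden_size=256,
    num_attention_heads=8,
    num_key_value_heads=2,
    head_dim=32,
    rms_norm_eps=1e-6,
    num_experts=8,
    intermediate_size=512,
    experts_top_k=2,
)
layer = Qwen3MoELayer(params)

provider = RotaryEmbeddingProvider(rope_base=10000, head_dim=32, max_position_ids=64)
provider.reset_parameters()
cos, sin = provider(torch.arange(16))

x = torch.randn(1, 16, 256)                           # (batch, seq_len, hidden_dim)
out = layer(x, (cos.unsqueeze(0), sin.unsqueeze(0)))  # add batch dim to cos/sin
print(out.shape)                                      # torch.Size([1, 16, 256])
print(layer.moe_tokens_per_expert)                    # per-expert routing counts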