d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
d9d/module/block/attention/grouped_query.py
@@ -0,0 +1,139 @@
+ import torch
+ from torch import nn
+
+ from d9d.module.base import ModuleLateInit
+ from d9d.module.block.attention.sdpa import FlashSdpa
+ from d9d.module.block.positional import RotaryEmbeddingApplicator
+
+
+ class GroupedQueryAttention(nn.Module, ModuleLateInit):
+     """
+     Implements Grouped Query Attention (GQA) with RoPE and optional QK Normalization.
+
+     This module performs the full attention mechanism pipeline:
+     1. Linear projection to Q, K, V.
+     2. Optional RMS Normalization on Q and K.
+     3. Rotary Positional Embedding (RoPE) application.
+     4. Scaled Dot Product Attention (via FlashAttention).
+     5. Output projection.
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_attention_heads: int,
+         num_key_value_heads: int,
+         head_dim: int,
+         qk_norm_eps: float | None,
+         is_causal: bool
+     ):
+         """
+         Constructs the GroupedQueryAttention layer.
+
+         Args:
+             hidden_size: Hidden size.
+             num_attention_heads: Number of Query heads.
+             num_key_value_heads: Number of Key/Value heads. If less than `num_attention_heads`, GQA/MQA is enabled.
+             head_dim: Dimensionality of a single attention head.
+             qk_norm_eps: Epsilon for the RMSNorm applied to Q and K. If None, normalization is disabled.
+             is_causal: Whether to apply a causal mask (auto-regressive constraint).
+         """
+
+         super().__init__()
+
+         self._head_dim = head_dim
+         self._num_key_value_groups = num_attention_heads // num_key_value_heads
+         self._scaling = head_dim ** -0.5
+
+         self.q_proj = nn.Linear(
+             hidden_size, num_attention_heads * head_dim, bias=False
+         )
+
+         self.k_proj = nn.Linear(
+             hidden_size, num_key_value_heads * head_dim, bias=False
+         )
+
+         self.v_proj = nn.Linear(
+             hidden_size, num_key_value_heads * head_dim, bias=False
+         )
+
+         self.o_proj = nn.Linear(
+             num_attention_heads * head_dim, hidden_size, bias=False
+         )
+
+         self.q_norm: nn.RMSNorm | None
+         self.k_norm: nn.RMSNorm | None
+
+         if qk_norm_eps is not None:
+             self.q_norm = nn.RMSNorm(normalized_shape=head_dim,
+                                      eps=qk_norm_eps)
+             self.k_norm = nn.RMSNorm(normalized_shape=head_dim,
+                                      eps=qk_norm_eps)
+         else:
+             self.q_norm = None
+             self.k_norm = None
+
+         self.rope = RotaryEmbeddingApplicator()
+         self.kernel = FlashSdpa()
+         self._is_causal = is_causal
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor | None,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor]
+     ) -> torch.Tensor:
+         """
+         Computes the attention operation.
+
+         Args:
+             hidden_states: Input tensor. Shape: `(batch, seq_len, hidden_size)`.
+             attention_mask: Optional mask associated with the inputs.
+             position_embeddings: Tuple of `(cos, sin)` tensors for RoPE application.
+                 Each tensor should be of shape `(batch, seq_len, head_dim)`.
+
+         Returns:
+             The attention output tensor. Shape: `(batch, seq_len, hidden_size)`.
+         """
+
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self._head_dim)
+
+         query_states = self.q_proj(hidden_states).view(hidden_shape)
+         if self.q_norm is not None:
+             query_states = self.q_norm(query_states)
+         query_states = query_states.transpose(1, 2)
+
+         key_states = self.k_proj(hidden_states).view(hidden_shape)
+         if self.k_norm is not None:
+             key_states = self.k_norm(key_states)
+         key_states = key_states.transpose(1, 2)
+
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         query_states, key_states = self.rope(query_states, key_states, position_embeddings[0], position_embeddings[1])
+
+         outputs = self.kernel(
+             query_states,
+             key_states,
+             value_states,
+             attention_mask=attention_mask,
+             is_causal=self._is_causal,
+             scale=self._scaling
+         )
+
+         outputs = outputs.reshape(*input_shape, -1).contiguous()
+         outputs = self.o_proj(outputs)
+         return outputs
+
+     def reset_parameters(self):
+         """Resets module parameters."""
+
+         self.q_proj.reset_parameters()
+         self.k_proj.reset_parameters()
+         self.v_proj.reset_parameters()
+         self.o_proj.reset_parameters()
+         if self.q_norm is not None:
+             self.q_norm.reset_parameters()
+         if self.k_norm is not None:
+             self.k_norm.reset_parameters()
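
A minimal usage sketch of the layer above. The sizes, the device/dtype choices, and the constant cos/sin tensors are illustrative assumptions, not values from the package; the FlashAttention backend enforced by FlashSdpa generally requires a CUDA device with fp16/bf16 inputs.

    import torch
    from d9d.module.block.attention.grouped_query import GroupedQueryAttention

    attn = GroupedQueryAttention(
        hidden_size=1024,
        num_attention_heads=16,
        num_key_value_heads=4,   # fewer KV heads than Q heads -> grouped-query attention
        head_dim=64,
        qk_norm_eps=1e-6,
        is_causal=True,
    ).cuda().to(torch.bfloat16)

    batch, seq_len = 2, 128
    hidden = torch.randn(batch, seq_len, 1024, device="cuda", dtype=torch.bfloat16)
    # Placeholder RoPE tables; in the real model these come from the rotary embedding module.
    cos = torch.ones(batch, seq_len, 64, device="cuda", dtype=torch.bfloat16)
    sin = torch.zeros(batch, seq_len, 64, device="cuda", dtype=torch.bfloat16)

    out = attn(hidden, attention_mask=None, position_embeddings=(cos, sin))
    # out.shape == (2, 128, 1024)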
d9d/module/block/attention/sdpa/__init__.py
@@ -0,0 +1,5 @@
+ from .flash import FlashSdpa
+
+ __all__ = [
+     "FlashSdpa"
+ ]
d9d/module/block/attention/sdpa/flash.py
@@ -0,0 +1,52 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+
+
+ class FlashSdpa(nn.Module):
+     """Executes Scaled Dot Product Attention (SDPA) enforcing the FlashAttention backend."""
+
+     def __init__(self):
+         """
+         Constructs the FlashSdpa object.
+         """
+         super().__init__()
+
+     def forward(
+         self,
+         query_states: torch.Tensor,
+         key_states: torch.Tensor,
+         value_states: torch.Tensor,
+         attention_mask: torch.Tensor | None,
+         is_causal: bool,
+         scale: float
+     ) -> torch.Tensor:
+         """
+         Computes Scaled Dot-Product Attention using FlashAttention.
+
+         Args:
+             query_states: Query tensor. Shape: `(batch, n_q_heads, seq_len, head_dim)`.
+             key_states: Key tensor. Shape: `(batch, n_kv_heads, seq_len, head_dim)`.
+             value_states: Value tensor. Shape: `(batch, n_kv_heads, seq_len, head_dim)`.
+             attention_mask: Optional attention mask (usually not needed for FlashAttn with causal=True).
+             is_causal: If True, applies a causal mask (upper triangular masking).
+             scale: Scaling factor applied to the dot products (usually `1 / sqrt(head_dim)`).
+
+         Returns:
+             The attention output tensor, permuted to channel-last format.
+             Shape: `(batch, seq_len, n_q_heads, head_dim)`.
+         """
+
+         with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+             results = F.scaled_dot_product_attention(
+                 query_states,
+                 key_states,
+                 value_states,
+                 attn_mask=attention_mask,
+                 dropout_p=0.0,
+                 is_causal=is_causal,
+                 scale=scale,
+                 enable_gqa=query_states.shape[1] != key_states.shape[1]
+             )
+         return results.transpose(1, 2).contiguous()
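
For reference, FlashSdpa can also be called directly; the shapes below are illustrative, and because the FLASH_ATTENTION backend is enforced this only runs on a CUDA device with half-precision tensors.

    import torch
    from d9d.module.block.attention.sdpa import FlashSdpa

    sdpa = FlashSdpa()
    q = torch.randn(2, 16, 128, 64, device="cuda", dtype=torch.bfloat16)  # (batch, n_q_heads, seq_len, head_dim)
    k = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.bfloat16)   # fewer KV heads -> enable_gqa is set
    v = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.bfloat16)

    out = sdpa(q, k, v, attention_mask=None, is_causal=True, scale=64 ** -0.5)
    # out.shape == (2, 128, 16, 64): transposed back to (batch, seq_len, n_q_heads, head_dim)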
d9d/module/block/embedding/__init__.py
@@ -0,0 +1,7 @@
+ """Package providing various embedding layer implementations"""
+
+ from .shard_token_embedding import SplitTokenEmbeddings
+
+ __all__ = [
+     "SplitTokenEmbeddings"
+ ]
d9d/module/block/embedding/shard_token_embedding.py
@@ -0,0 +1,103 @@
+ from collections.abc import Mapping, Sequence
+ from typing import cast
+
+ import torch
+ from torch import nn
+
+ from d9d.module.base import ModuleLateInit
+
+
+ def _build_token_start_end_indices(
+     split_vocab_size: dict[str, int], split_order: Sequence[str]
+ ) -> tuple[dict[str, int], dict[str, int]]:
+     offset = 0
+     starts = {}
+     ends = {}
+     for split in split_order:
+         current_size = split_vocab_size[split]
+
+         starts[split] = offset
+         ends[split] = offset + current_size
+
+         offset += current_size
+     return starts, ends
+
+
+ class SplitTokenEmbeddings(nn.Module, ModuleLateInit):
+     """
+     A token embedding layer composed of multiple named, independent embedding tables.
+
+     This class maintains a dictionary of embedding layers, mapping contiguous
+     ranges of global vocabulary indices to specific named splits (e.g., 'orig',
+     'special', 'prompt_prefix'). This is useful for model adaptation strategies where
+     different sets of tokens require different initialization or training behaviors.
+     """
+
+     def __init__(
+         self,
+         split_vocab_size: dict[str, int],
+         split_order: Sequence[str],
+         hidden_size: int
+     ):
+         """
+         Constructs the SplitTokenEmbeddings object.
+
+         Args:
+             split_vocab_size: A dictionary mapping split names to their vocabulary sizes.
+             split_order: A sequence defining the order in which splits are concatenated
+                 to form the global vocabulary. Keys provided here must exist in
+                 split_vocab_size.
+             hidden_size: The dimensionality of the embedding vectors.
+         """
+
+         super().__init__()
+
+         token_embedding = nn.ModuleDict({
+             split_name: nn.Embedding(vocab_size, hidden_size)
+             for split_name, vocab_size in split_vocab_size.items()
+         })
+         self.token_embedding: Mapping[str, nn.Embedding] = cast(Mapping[str, nn.Embedding], token_embedding)
+
+         self._id_start, self._id_end = _build_token_start_end_indices(split_vocab_size, split_order)
+         self._hidden_size = hidden_size
+         self._split_order = split_order
+
+     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+         """
+         Retrieves embeddings for the input indices by routing them to the appropriate internal layers.
+
+         Args:
+             input_ids: Tensor of arbitrary shape containing global vocabulary indices.
+
+         Returns:
+             Tensor with the same shape as `input_ids` plus a trailing dimension of size `hidden_size`.
+         """
+
+         output_embeds: torch.Tensor | None = None
+
+         for split_name in self._split_order:
+             start_idx = self._id_start[split_name]
+             end_idx = self._id_end[split_name]
+             layer = self.token_embedding[split_name]
+             mask = (input_ids >= start_idx) & (input_ids < end_idx)
+
+             safe_ids = torch.where(mask, input_ids - start_idx, 0)
+             masked_embed = layer(safe_ids) * mask[..., None]
+
+             if output_embeds is None:
+                 output_embeds = masked_embed
+             else:
+                 output_embeds = output_embeds + masked_embed
+
+         if output_embeds is None:
+             raise ValueError("Embeddings are empty - perhaps no splits were configured")
+
+         return output_embeds
+
+     def reset_parameters(self):
+         """
+         Resets parameters for all registered embedding splits.
+         """
+
+         for layer in self.token_embedding.values():
+             layer.reset_parameters()
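
A small routing sketch for the class above. The split names and sizes are invented for illustration: with split_order=("orig", "special") and sizes 100 and 4, global ids 0-99 are looked up in the 'orig' table, while ids 100-103 land in the 'special' table at local offsets 0-3.

    import torch
    from d9d.module.block.embedding import SplitTokenEmbeddings

    emb = SplitTokenEmbeddings(
        split_vocab_size={"orig": 100, "special": 4},
        split_order=("orig", "special"),
        hidden_size=32,
    )

    input_ids = torch.tensor([[5, 99, 100, 103]])  # the last two ids fall into the 'special' range
    out = emb(input_ids)
    # out.shape == (1, 4, 32); ids 100 and 103 hit the 'special' table at offsets 0 and 3.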
d9d/module/block/ffn/__init__.py
@@ -0,0 +1,5 @@
+ from .swiglu import SwiGLU
+
+ __all__ = [
+     "SwiGLU"
+ ]
d9d/module/block/ffn/swiglu.py
@@ -0,0 +1,60 @@
+ import torch
+ from torch import nn
+
+ from d9d.kernel.swiglu import silu_mul
+ from d9d.module.base import ModuleLateInit
+
+
+ class SwiGLU(nn.Module, ModuleLateInit):
+     """
+     Implements the SwiGLU Feed-Forward Network (FFN).
+
+     This module applies the gated activation function: `down(SiLU(gate(x)) * up(x))`.
+     It corresponds to the standard MLP block used in architectures like LLaMA.
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int
+     ):
+         """
+         Constructs a SwiGLU object.
+
+         Args:
+             hidden_size: The hidden dim size.
+             intermediate_size: The intermediate dim size of the FFN.
+         """
+
+         super().__init__()
+         self.gate_proj = nn.Linear(hidden_size, intermediate_size)
+         self.up_proj = nn.Linear(hidden_size, intermediate_size)
+         self.down_proj = nn.Linear(intermediate_size, hidden_size)
+
+     def forward(
+         self,
+         x: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Applies the SwiGLU FFN to the input.
+
+         Args:
+             x: Input tensor. Shape: `(batch_size, seq_len, hidden_dim)`.
+
+         Returns:
+             Output tensor. Shape: `(batch_size, seq_len, hidden_dim)`.
+         """
+
+         return self.down_proj(
+             silu_mul(
+                 self.gate_proj(x),
+                 self.up_proj(x)
+             )
+         )
+
+     def reset_parameters(self):
+         """Resets module parameters."""
+
+         self.gate_proj.reset_parameters()
+         self.up_proj.reset_parameters()
+         self.down_proj.reset_parameters()
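
The fused silu_mul kernel itself is not part of this excerpt; as an unfused reference matching the docstring's `down(SiLU(gate(x)) * up(x))` description (a sketch, not the kernel), the same computation in plain PyTorch looks like this:

    import torch
    import torch.nn.functional as F
    from torch import nn

    def silu_mul_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        # Unfused equivalent of the gating that d9d's kernel fuses into one op.
        return F.silu(gate) * up

    hidden_size, intermediate_size = 16, 32
    gate_proj = nn.Linear(hidden_size, intermediate_size)
    up_proj = nn.Linear(hidden_size, intermediate_size)
    down_proj = nn.Linear(intermediate_size, hidden_size)

    x = torch.randn(2, 8, hidden_size)
    out = down_proj(silu_mul_reference(gate_proj(x), up_proj(x)))  # same shape as x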
d9d/module/block/head/__init__.py
@@ -0,0 +1,6 @@
+ from .language_modelling import LM_IGNORE_INDEX, SplitLanguageModellingHead
+
+ __all__ = [
+     "LM_IGNORE_INDEX",
+     "SplitLanguageModellingHead"
+ ]
d9d/module/block/head/language_modelling.py
@@ -0,0 +1,87 @@
+ from collections.abc import Mapping, Sequence
+ from typing import cast
+
+ import torch
+ from torch import nn
+
+ from d9d.kernel.cce import linear_cross_entropy
+ from d9d.module.base import ModuleLateInit
+
+ LM_IGNORE_INDEX = -100
+ """Index ignored by LM head while calculating logps"""
+
+
+ class SplitLanguageModellingHead(nn.Module, ModuleLateInit):
+     """
+     A segmented language modeling head that computes per-token cross-entropy loss values using a composed weight matrix.
+
+     This class maintains separate linear layers for different segments of the vocabulary
+     (e.g., regular vs. special tokens). During the forward pass, it concatenates the
+     weights to form a unified projection matrix and computes the cross-entropy loss
+     efficiently, typically using a fused kernel to avoid materializing full logits.
+
+     The concatenation order of the weights is determined by `split_order`, which ensures
+     consistency with the global vocabulary indices.
+     """
+
+     def __init__(
+         self,
+         split_vocab_size: dict[str, int],
+         split_order: Sequence[str],
+         hidden_size: int
+     ):
+         """
+         Constructs the SplitLanguageModellingHead object.
+
+         Args:
+             split_vocab_size: A dictionary mapping split names to their output vocabulary sizes.
+             split_order: A sequence defining the order in which vocabulary segments should be
+                 concatenated. This determines the mapping of global indices to specific heads.
+             hidden_size: The input dimensionality (hidden state size).
+         """
+
+         super().__init__()
+
+         lm_head = nn.ModuleDict({
+             split_name: nn.Linear(hidden_size, vocab_size, bias=False)
+             for split_name, vocab_size in split_vocab_size.items()
+         })
+
+         self.lm_head: Mapping[str, nn.Linear] = cast(Mapping[str, nn.Linear], lm_head)
+         self._split_order = split_order
+         self._hidden_size = hidden_size
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         labels: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Computes the cross-entropy loss for the given hidden states and labels.
+
+         Args:
+             hidden_states: Input tensor of shape `(B, S, H)`.
+             labels: Target label tensor of shape `(B, S)`. Indices must correspond
+                 to the global vocabulary formed by concatenating splits in `split_order`.
+
+         Returns:
+             A tensor containing per-token loss values (reduction='none'), matching the
+             shape of the labels tensor.
+         """
+
+         lm_head_weight = torch.cat([self.lm_head[split_name].weight for split_name in self._split_order], dim=0)
+
+         losses = linear_cross_entropy(
+             hidden_states,
+             lm_head_weight,
+             labels,
+             ignore_index=LM_IGNORE_INDEX,
+             reduction="none"
+         )
+         return losses
+
+     def reset_parameters(self):
+         """Resets module parameters."""
+
+         for head in self.lm_head.values():
+             head.reset_parameters()
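
The fused linear_cross_entropy call avoids materializing the full logits; an unfused reference of the same computation (illustrative split sizes, not the cce kernel) is:

    import torch
    import torch.nn.functional as F
    from torch import nn

    B, S, H = 2, 16, 64
    heads = {"orig": nn.Linear(H, 100, bias=False), "special": nn.Linear(H, 4, bias=False)}

    hidden_states = torch.randn(B, S, H)
    labels = torch.randint(0, 104, (B, S))
    labels[0, :4] = -100  # positions set to LM_IGNORE_INDEX contribute zero loss

    # Concatenate the split weights in split_order to form the global projection.
    weight = torch.cat([heads["orig"].weight, heads["special"].weight], dim=0)  # (104, H)
    logits = hidden_states @ weight.T  # (B, S, 104) -- materialized here, which the fused kernel avoids
    losses = F.cross_entropy(
        logits.flatten(0, 1), labels.flatten(), ignore_index=-100, reduction="none"
    ).view(B, S)
    # losses has the same shape as labels; ignored positions are exactly 0.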
d9d/module/block/hidden_states_aggregator/__init__.py
@@ -0,0 +1,12 @@
+ """
+ Aggregation utilities for model hidden states.
+ """
+
+ from .base import BaseHiddenStatesAggregator
+ from .factory import HiddenStatesAggregationMode, create_hidden_states_aggregator
+
+ __all__ = [
+     "BaseHiddenStatesAggregator",
+     "HiddenStatesAggregationMode",
+     "create_hidden_states_aggregator"
+ ]
d9d/module/block/hidden_states_aggregator/base.py
@@ -0,0 +1,35 @@
+ import abc
+
+ import torch
+
+
+ class BaseHiddenStatesAggregator(abc.ABC):
+     """Abstract base class for hidden states aggregation strategies.
+
+     This interface defines how hidden states should be collected (added) and
+     how they should be finalized (packed) and combined with optional historical snapshots.
+     """
+
+     @abc.abstractmethod
+     def add_hidden_states(self, hidden_states: torch.Tensor) -> None:
+         """Accumulates a batch of hidden states into the aggregator.
+
+         Args:
+             hidden_states: The tensor containing the hidden states to process.
+         """
+
+     @abc.abstractmethod
+     def pack_with_snapshot(self, snapshot: torch.Tensor | None) -> torch.Tensor | None:
+         """Finalizes the aggregation and combines it with an optional previous snapshot.
+
+         This method typically retrieves the accumulated states, processes them
+         (if not done during addition), and concatenates them with the snapshot.
+
+         Args:
+             snapshot: An optional tensor representing previously aggregated states
+                 to be prepended to the current collection.
+
+         Returns:
+             The combined result of the snapshot and the newly aggregated states,
+             or None if no states were collected.
+         """
d9d/module/block/hidden_states_aggregator/factory.py
@@ -0,0 +1,48 @@
+ from enum import StrEnum
+
+ import torch
+
+ from .base import BaseHiddenStatesAggregator
+ from .mean import HiddenStatesAggregatorMean
+ from .noop import HiddenStatesAggregatorNoOp
+
+
+ class HiddenStatesAggregationMode(StrEnum):
+     """Enumeration of available hidden state aggregation strategies.
+
+     Attributes:
+         no: Performs no aggregation (No-Op).
+         mean: Computes the mean of hidden states, taking a mask into account.
+     """
+
+     no = "no"
+     mean = "mean"
+
+
+ def create_hidden_states_aggregator(
+     mode: HiddenStatesAggregationMode, agg_mask: torch.Tensor | None
+ ) -> BaseHiddenStatesAggregator:
+     """Factory function to create a hidden states aggregator.
+
+     Args:
+         mode: The specific aggregation mode to instantiate.
+         agg_mask: A tensor mask required for specific modes.
+             Can be None if the selected mode does not require masking.
+
+     Returns:
+         An instance of a concrete BaseHiddenStatesAggregator subclass.
+
+     Raises:
+         ValueError: If 'mean' mode is selected but 'agg_mask' is None, or if
+             an unknown mode is provided.
+     """
+
+     match mode:
+         case HiddenStatesAggregationMode.no:
+             return HiddenStatesAggregatorNoOp()
+         case HiddenStatesAggregationMode.mean:
+             if agg_mask is None:
+                 raise ValueError("You have to specify aggregation mask")
+             return HiddenStatesAggregatorMean(agg_mask)
+         case _:
+             raise ValueError("Unknown hidden states aggregation mode")
d9d/module/block/hidden_states_aggregator/mean.py
@@ -0,0 +1,61 @@
+ import torch
+
+ from .base import BaseHiddenStatesAggregator
+
+
+ def _aggregate_hidden_states(
+     hidden_states: torch.Tensor, agg_mask: torch.Tensor
+ ) -> torch.Tensor:
+     orig_dtype = hidden_states.dtype
+     hidden_states = hidden_states.float()
+     num_tokens = agg_mask.sum(dim=1)[:, None]
+     masked_states = hidden_states * agg_mask[:, :, None]
+     averaged_states = masked_states.sum(dim=1) / num_tokens
+     return averaged_states.to(orig_dtype)
+
+
+ class HiddenStatesAggregatorMean(BaseHiddenStatesAggregator):
+     """Aggregator that computes the mean of hidden states using a validity mask."""
+
+     def __init__(self, agg_mask: torch.Tensor) -> None:
+         """Constructs the mean aggregator with the given mask.
+
+         Args:
+             agg_mask: A tensor used to mask out padding or invalid tokens
+                 during average calculation.
+         """
+         self._agg_mask = agg_mask
+         self._collected_states: list[torch.Tensor] = []
+
+     def add_hidden_states(self, hidden_states: torch.Tensor) -> None:
+         """Calculates the masked mean immediately and stores the result.
+
+         Args:
+             hidden_states: The raw hidden states to be averaged and stored.
+         """
+         agg = _aggregate_hidden_states(
+             hidden_states=hidden_states,
+             agg_mask=self._agg_mask
+         )
+         self._collected_states.append(agg)
+
+     def pack_with_snapshot(self, snapshot: torch.Tensor | None) -> torch.Tensor | None:
+         """Stacks the collected masked averages and appends them to the snapshot.
+
+         This operation clears the internal buffer of collected states.
+
+         Args:
+             snapshot: Previous states to prepend.
+
+         Returns:
+             A tensor containing the snapshot followed by the stacked collected states,
+             or None if nothing was collected.
+         """
+         if len(self._collected_states) == 0:
+             return None
+
+         stacked = torch.stack(self._collected_states, dim=0)
+         self._collected_states.clear()
+         if snapshot is not None:
+             stacked = torch.cat([snapshot, stacked], dim=0)
+         return stacked
d9d/module/block/hidden_states_aggregator/noop.py
@@ -0,0 +1,27 @@
+ import torch
+
+ from .base import BaseHiddenStatesAggregator
+
+
+ class HiddenStatesAggregatorNoOp(BaseHiddenStatesAggregator):
+     """Aggregator implementation that performs no operations.
+
+     This acts as a null object for cases where aggregation is disabled in the configuration.
+     """
+
+     def add_hidden_states(self, hidden_states: torch.Tensor) -> None:
+         """Does nothing.
+
+         Args:
+             hidden_states: Ignored.
+         """
+
+     def pack_with_snapshot(self, snapshot: torch.Tensor | None) -> torch.Tensor | None:
+         """Does nothing.
+
+         Args:
+             snapshot: Ignored.
+
+         Returns:
+             None.
+         """
d9d/module/block/moe/__init__.py
@@ -0,0 +1,13 @@
+ """Provides building blocks for Mixture-of-Experts (MoE) architectures."""
+
+ from .grouped_experts import GroupedSwiGLU
+ from .grouped_linear import GroupedLinear
+ from .layer import MoELayer
+ from .router import TopKRouter
+
+ __all__ = [
+     "GroupedLinear",
+     "GroupedSwiGLU",
+     "MoELayer",
+     "TopKRouter"
+ ]
d9d/module/block/moe/communications/__init__.py
@@ -0,0 +1,11 @@
+ """Provides communication strategies for Mixture-of-Experts routing operations."""
+
+ from .base import ExpertCommunicationHandler
+ from .deepep import DeepEpCommunicationHandler
+ from .naive import NoCommunicationHandler
+
+ __all__ = [
+     "DeepEpCommunicationHandler",
+     "ExpertCommunicationHandler",
+     "NoCommunicationHandler"
+ ]