d9d-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. d9d/__init__.py +0 -0
  2. d9d/core/__init__.py +0 -0
  3. d9d/core/autograd/__init__.py +7 -0
  4. d9d/core/autograd/grad_context.py +85 -0
  5. d9d/core/dist_context/__init__.py +19 -0
  6. d9d/core/dist_context/configured.py +215 -0
  7. d9d/core/dist_context/device_mesh_domains.py +185 -0
  8. d9d/core/dist_context/log.py +30 -0
  9. d9d/core/dist_context/params.py +113 -0
  10. d9d/core/dist_ops/__init__.py +16 -0
  11. d9d/core/dist_ops/object.py +68 -0
  12. d9d/core/dist_ops/tensor.py +192 -0
  13. d9d/core/protocol/__init__.py +8 -0
  14. d9d/core/protocol/training.py +38 -0
  15. d9d/core/sharding/__init__.py +15 -0
  16. d9d/core/sharding/auto_spec.py +66 -0
  17. d9d/core/sharding/shard.py +154 -0
  18. d9d/core/sharding/spec.py +28 -0
  19. d9d/core/sharding/unshard.py +117 -0
  20. d9d/core/types/__init__.py +12 -0
  21. d9d/core/types/data.py +14 -0
  22. d9d/core/types/pytree.py +26 -0
  23. d9d/dataset/__init__.py +17 -0
  24. d9d/dataset/buffer_sorted.py +143 -0
  25. d9d/dataset/padding.py +79 -0
  26. d9d/dataset/sharded.py +195 -0
  27. d9d/internals/__init__.py +0 -0
  28. d9d/internals/determinism/__init__.py +10 -0
  29. d9d/internals/determinism/seed.py +63 -0
  30. d9d/internals/grad_norm/__init__.py +8 -0
  31. d9d/internals/grad_norm/group.py +87 -0
  32. d9d/internals/grad_norm/norm.py +169 -0
  33. d9d/internals/grad_sync/__init__.py +14 -0
  34. d9d/internals/grad_sync/bucket.py +317 -0
  35. d9d/internals/grad_sync/placement_helper.py +23 -0
  36. d9d/internals/grad_sync/synchronizer.py +257 -0
  37. d9d/internals/pipeline_state/__init__.py +14 -0
  38. d9d/internals/pipeline_state/api.py +45 -0
  39. d9d/internals/pipeline_state/handler.py +111 -0
  40. d9d/internals/pipeline_state/storage.py +236 -0
  41. d9d/internals/profiling/__init__.py +7 -0
  42. d9d/internals/profiling/profile.py +112 -0
  43. d9d/internals/state/__init__.py +6 -0
  44. d9d/internals/state/main_process.py +44 -0
  45. d9d/kernel/__init__.py +0 -0
  46. d9d/kernel/cce/__init__.py +5 -0
  47. d9d/kernel/cce/cce.py +298 -0
  48. d9d/kernel/cce/main.py +282 -0
  49. d9d/kernel/general/__init__.py +5 -0
  50. d9d/kernel/general/get_int_dtype.py +7 -0
  51. d9d/kernel/gmm/__init__.py +5 -0
  52. d9d/kernel/gmm/function.py +78 -0
  53. d9d/kernel/moe/__init__.py +8 -0
  54. d9d/kernel/moe/indices_to_multihot.py +268 -0
  55. d9d/kernel/moe/permute_with_probs.py +1035 -0
  56. d9d/kernel/stochastic/__init__.py +11 -0
  57. d9d/kernel/stochastic/adamw_step.py +204 -0
  58. d9d/kernel/stochastic/copy.py +104 -0
  59. d9d/kernel/stochastic/ops/__init__.py +5 -0
  60. d9d/kernel/stochastic/ops/round.py +22 -0
  61. d9d/kernel/swiglu/__init__.py +5 -0
  62. d9d/kernel/swiglu/function.py +36 -0
  63. d9d/kernel/swiglu/op.py +167 -0
  64. d9d/loop/__init__.py +0 -0
  65. d9d/loop/auto/__init__.py +9 -0
  66. d9d/loop/auto/auto_lr_scheduler.py +46 -0
  67. d9d/loop/auto/auto_optimizer.py +196 -0
  68. d9d/loop/component/__init__.py +35 -0
  69. d9d/loop/component/batch_maths.py +106 -0
  70. d9d/loop/component/checkpointer.py +172 -0
  71. d9d/loop/component/data_loader_factory.py +258 -0
  72. d9d/loop/component/garbage_collector.py +94 -0
  73. d9d/loop/component/gradient_clipper.py +89 -0
  74. d9d/loop/component/gradient_manager.py +149 -0
  75. d9d/loop/component/job_logger.py +146 -0
  76. d9d/loop/component/job_profiler.py +62 -0
  77. d9d/loop/component/loss_computer.py +86 -0
  78. d9d/loop/component/model_stage_exporter.py +37 -0
  79. d9d/loop/component/model_stage_factory.py +261 -0
  80. d9d/loop/component/optimizer_factory.py +88 -0
  81. d9d/loop/component/stepper.py +52 -0
  82. d9d/loop/component/timeout_manager.py +54 -0
  83. d9d/loop/component/train_task_operator.py +152 -0
  84. d9d/loop/config/__init__.py +36 -0
  85. d9d/loop/config/config.py +225 -0
  86. d9d/loop/config/types.py +24 -0
  87. d9d/loop/control/__init__.py +61 -0
  88. d9d/loop/control/dataset_provider.py +58 -0
  89. d9d/loop/control/lr_scheduler_provider.py +47 -0
  90. d9d/loop/control/model_provider.py +162 -0
  91. d9d/loop/control/optimizer_provider.py +45 -0
  92. d9d/loop/control/task.py +304 -0
  93. d9d/loop/run/__init__.py +6 -0
  94. d9d/loop/run/train.py +355 -0
  95. d9d/loop/state.py +143 -0
  96. d9d/lr_scheduler/__init__.py +9 -0
  97. d9d/lr_scheduler/piecewise/__init__.py +18 -0
  98. d9d/lr_scheduler/piecewise/builder.py +152 -0
  99. d9d/lr_scheduler/piecewise/config.py +176 -0
  100. d9d/lr_scheduler/piecewise/curves.py +75 -0
  101. d9d/lr_scheduler/piecewise/engine.py +76 -0
  102. d9d/lr_scheduler/visualizer.py +74 -0
  103. d9d/metric/__init__.py +10 -0
  104. d9d/metric/abc.py +79 -0
  105. d9d/metric/impl/__init__.py +7 -0
  106. d9d/metric/impl/compose.py +54 -0
  107. d9d/metric/impl/mean.py +94 -0
  108. d9d/model_state/__init__.py +0 -0
  109. d9d/model_state/io/__init__.py +21 -0
  110. d9d/model_state/io/dto.py +30 -0
  111. d9d/model_state/io/module_reader.py +75 -0
  112. d9d/model_state/io/module_writer.py +123 -0
  113. d9d/model_state/io/reader.py +125 -0
  114. d9d/model_state/io/writer.py +309 -0
  115. d9d/model_state/mapper/__init__.py +10 -0
  116. d9d/model_state/mapper/abc.py +70 -0
  117. d9d/model_state/mapper/adapters/__init__.py +12 -0
  118. d9d/model_state/mapper/adapters/mapper.py +27 -0
  119. d9d/model_state/mapper/adapters/module.py +22 -0
  120. d9d/model_state/mapper/compose/__init__.py +17 -0
  121. d9d/model_state/mapper/compose/helper.py +22 -0
  122. d9d/model_state/mapper/compose/parallel.py +58 -0
  123. d9d/model_state/mapper/compose/sequential.py +131 -0
  124. d9d/model_state/mapper/compose/shard.py +36 -0
  125. d9d/model_state/mapper/leaf/__init__.py +18 -0
  126. d9d/model_state/mapper/leaf/dtensor.py +56 -0
  127. d9d/model_state/mapper/leaf/identity.py +23 -0
  128. d9d/model_state/mapper/leaf/rename.py +26 -0
  129. d9d/model_state/mapper/leaf/select_child.py +37 -0
  130. d9d/model_state/mapper/leaf/stack.py +29 -0
  131. d9d/module/__init__.py +0 -0
  132. d9d/module/base/__init__.py +7 -0
  133. d9d/module/base/late_init.py +10 -0
  134. d9d/module/block/__init__.py +0 -0
  135. d9d/module/block/attention/__init__.py +7 -0
  136. d9d/module/block/attention/grouped_query.py +139 -0
  137. d9d/module/block/attention/sdpa/__init__.py +5 -0
  138. d9d/module/block/attention/sdpa/flash.py +52 -0
  139. d9d/module/block/embedding/__init__.py +7 -0
  140. d9d/module/block/embedding/shard_token_embedding.py +103 -0
  141. d9d/module/block/ffn/__init__.py +5 -0
  142. d9d/module/block/ffn/swiglu.py +60 -0
  143. d9d/module/block/head/__init__.py +6 -0
  144. d9d/module/block/head/language_modelling.py +87 -0
  145. d9d/module/block/hidden_states_aggregator/__init__.py +12 -0
  146. d9d/module/block/hidden_states_aggregator/base.py +35 -0
  147. d9d/module/block/hidden_states_aggregator/factory.py +48 -0
  148. d9d/module/block/hidden_states_aggregator/mean.py +61 -0
  149. d9d/module/block/hidden_states_aggregator/noop.py +27 -0
  150. d9d/module/block/moe/__init__.py +13 -0
  151. d9d/module/block/moe/communications/__init__.py +11 -0
  152. d9d/module/block/moe/communications/base.py +58 -0
  153. d9d/module/block/moe/communications/deepep.py +300 -0
  154. d9d/module/block/moe/communications/naive.py +68 -0
  155. d9d/module/block/moe/grouped_experts.py +81 -0
  156. d9d/module/block/moe/grouped_linear.py +78 -0
  157. d9d/module/block/moe/layer.py +122 -0
  158. d9d/module/block/moe/router.py +103 -0
  159. d9d/module/block/positional/__init__.py +8 -0
  160. d9d/module/block/positional/rope.py +150 -0
  161. d9d/module/model/__init__.py +0 -0
  162. d9d/module/model/qwen3_moe/__init__.py +16 -0
  163. d9d/module/model/qwen3_moe/decoder_layer.py +110 -0
  164. d9d/module/model/qwen3_moe/model.py +373 -0
  165. d9d/module/model/qwen3_moe/params.py +69 -0
  166. d9d/module/parallelism/__init__.py +0 -0
  167. d9d/module/parallelism/api/__init__.py +18 -0
  168. d9d/module/parallelism/api/expert_parallel.py +36 -0
  169. d9d/module/parallelism/api/fully_sharded.py +43 -0
  170. d9d/module/parallelism/api/hybrid_sharded.py +49 -0
  171. d9d/module/parallelism/api/replicate_parallel.py +33 -0
  172. d9d/module/parallelism/model/__init__.py +0 -0
  173. d9d/module/parallelism/model/qwen3_moe.py +99 -0
  174. d9d/module/parallelism/style/__init__.py +7 -0
  175. d9d/module/parallelism/style/shard_experts.py +60 -0
  176. d9d/module/parallelism/style/to_local.py +86 -0
  177. d9d/optim/__init__.py +0 -0
  178. d9d/optim/stochastic/__init__.py +5 -0
  179. d9d/optim/stochastic/adamw.py +158 -0
  180. d9d/peft/__init__.py +13 -0
  181. d9d/peft/all/__init__.py +12 -0
  182. d9d/peft/all/config.py +31 -0
  183. d9d/peft/all/method.py +76 -0
  184. d9d/peft/applicator.py +47 -0
  185. d9d/peft/base.py +70 -0
  186. d9d/peft/full_tune/__init__.py +11 -0
  187. d9d/peft/full_tune/config.py +20 -0
  188. d9d/peft/full_tune/method.py +46 -0
  189. d9d/peft/lora/__init__.py +15 -0
  190. d9d/peft/lora/config.py +35 -0
  191. d9d/peft/lora/layer.py +177 -0
  192. d9d/peft/lora/method.py +132 -0
  193. d9d/pipelining/__init__.py +0 -0
  194. d9d/pipelining/api/__init__.py +19 -0
  195. d9d/pipelining/api/module.py +149 -0
  196. d9d/pipelining/api/schedule.py +50 -0
  197. d9d/pipelining/api/sharding.py +9 -0
  198. d9d/pipelining/factory/__init__.py +21 -0
  199. d9d/pipelining/factory/config.py +89 -0
  200. d9d/pipelining/factory/factory.py +114 -0
  201. d9d/pipelining/factory/registry.py +82 -0
  202. d9d/pipelining/infra/__init__.py +0 -0
  203. d9d/pipelining/infra/schedule/__init__.py +0 -0
  204. d9d/pipelining/infra/schedule/component/__init__.py +0 -0
  205. d9d/pipelining/infra/schedule/component/program/__init__.py +22 -0
  206. d9d/pipelining/infra/schedule/component/program/base.py +35 -0
  207. d9d/pipelining/infra/schedule/component/program/communications.py +203 -0
  208. d9d/pipelining/infra/schedule/component/program/topology.py +78 -0
  209. d9d/pipelining/infra/schedule/component/runtime/__init__.py +29 -0
  210. d9d/pipelining/infra/schedule/component/runtime/action.py +361 -0
  211. d9d/pipelining/infra/schedule/component/runtime/communications.py +101 -0
  212. d9d/pipelining/infra/schedule/component/runtime/executor.py +113 -0
  213. d9d/pipelining/infra/schedule/component/runtime/loss.py +55 -0
  214. d9d/pipelining/infra/schedule/program/__init__.py +15 -0
  215. d9d/pipelining/infra/schedule/program/bfs.py +86 -0
  216. d9d/pipelining/infra/schedule/program/dualpipev.py +234 -0
  217. d9d/pipelining/infra/schedule/program/interleaved.py +240 -0
  218. d9d/pipelining/infra/schedule/program/zerobubblev.py +227 -0
  219. d9d/pipelining/infra/stage/__init__.py +5 -0
  220. d9d/pipelining/infra/stage/communications.py +274 -0
  221. d9d/pipelining/infra/stage/computations.py +317 -0
  222. d9d/pipelining/infra/stage/splitgrad.py +377 -0
  223. d9d/pipelining/infra/stage/stage.py +321 -0
  224. d9d/pipelining/infra/stage/struct_helper.py +46 -0
  225. d9d/pipelining/training/__init__.py +7 -0
  226. d9d/pipelining/training/optimizer.py +41 -0
  227. d9d/pipelining/training/scheduler.py +34 -0
  228. d9d/tracker/__init__.py +14 -0
  229. d9d/tracker/base.py +124 -0
  230. d9d/tracker/factory.py +57 -0
  231. d9d/tracker/provider/__init__.py +0 -0
  232. d9d/tracker/provider/aim/__init__.py +0 -0
  233. d9d/tracker/provider/aim/config.py +23 -0
  234. d9d/tracker/provider/aim/tracker.py +114 -0
  235. d9d/tracker/provider/null.py +61 -0
  236. d9d-0.1.0.dist-info/METADATA +90 -0
  237. d9d-0.1.0.dist-info/RECORD +238 -0
  238. d9d-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,58 @@
+ import abc
+
+ import torch
+
+
+ class ExpertCommunicationHandler(abc.ABC):
+     """Abstract base class for Mixture-of-Experts communication strategies."""
+
+     @abc.abstractmethod
+     def dispatch(
+         self,
+         hidden_states: torch.Tensor,
+         topk_ids: torch.Tensor,
+         topk_weights: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Prepares and routes local hidden states to their target experts (possibly on other workers).
+
+         This process involves:
+
+         1. All-to-All Communication: Transfers hidden states to workers containing the assigned experts. States
+            assigned to multiple experts are replicated.
+
+         2. Permutation: Sorts tokens by expert ID to prepare for Grouped GEMM.
+
+         Args:
+             hidden_states: Input tokens. Shape: `(num_tokens, hidden_size)`.
+             topk_ids: Indices of the top-k experts selected for each token. Shape: `(num_tokens, k)`.
+             topk_weights: Routing weights associated with the selected experts. Shape: `(num_tokens, k)`.
+
+         Returns:
+             A tuple containing:
+
+             - Permuted hidden states received by this rank. Shape: `(num_received_tokens, hidden_size)`.
+             - Permuted weights matching the hidden states order. Shape: `(num_received_tokens)`.
+             - Expert count tensor indicating how many tokens each local expert received. Shape: `(num_local_experts)`.
+         """
+
+         ...
+
+     @abc.abstractmethod
+     def combine(
+         self,
+         hidden_states: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Restores hidden states to their original order and location.
+
+         Undoes the permutation and performs the reverse All-to-All communication
+         to return processed results to the workers that originated the requests.
+
+         Args:
+             hidden_states: The processed hidden states. Shape: `(num_received_tokens, hidden_size)`.
+
+         Returns:
+             The combined hidden states with the original shape and order. Shape: `(num_tokens, hidden_size)`.
+         """
+         ...
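
Editor's note: the contract above is easiest to see in a minimal, hypothetical implementation (not part of the package). Assuming a single local expert and k = 1, a pass-through handler satisfies the interface:

    import torch

    from d9d.module.block.moe.communications import ExpertCommunicationHandler


    class PassthroughHandler(ExpertCommunicationHandler):
        """Hypothetical handler: one local expert, k = 1, no permutation, no communication."""

        def dispatch(self, hidden_states, topk_ids, topk_weights):
            # All tokens stay on this rank and go to the single local expert.
            tokens_per_expert = torch.tensor([hidden_states.size(0)])
            return hidden_states, topk_weights.flatten(), tokens_per_expert

        def combine(self, hidden_states):
            # Nothing was permuted or sent anywhere, so there is nothing to restore.
            return hidden_states
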
@@ -0,0 +1,300 @@
+ from typing import Any
+
+ import torch
+ from deep_ep import Buffer, EventOverlap
+ from torch.autograd.function import FunctionCtx
+
+ from d9d.kernel.moe.indices_to_multihot import fused_indices_to_multihot
+ from d9d.kernel.moe.permute_with_probs import moe_permute_with_probs, moe_unpermute_mask
+ from d9d.module.block.moe.communications import ExpertCommunicationHandler
+
+ # see https://github.com/deepseek-ai/DeepEP/blob/main/README.md for examples
+ # TODO: implement computation/communication overlap for PP case
+
+ _buffer: Buffer | None = None
+
+
+ def get_hidden_state_bytes(x: torch.Tensor) -> int:
+     """
+     Calculates the byte size of a hidden state tensor row.
+
+     Args:
+         x: Input tensor. Shape: `(?, hidden_size)`.
+     """
+
+     return x.size(1) * max(x.element_size(), 2)
+
+
+ def init_deepep_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int):
+     """
+     Initializes or expands the global DeepEP communication buffer.
+
+     Checks if the existing buffer is sufficient for the required hidden dimension
+     and process group size. If not, it allocates a new buffer.
+
+     Args:
+         group: The process group intended for communication.
+         hidden_bytes: Size of a single hidden state vector in bytes.
+     """
+
+     global _buffer  # noqa: PLW0603
+     num_nvl_bytes, num_rdma_bytes = 0, 0
+     for config in (
+         Buffer.get_dispatch_config(group.size()),
+         Buffer.get_combine_config(group.size()),
+     ):
+         num_nvl_bytes = max(
+             config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
+         )
+         num_rdma_bytes = max(
+             config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
+         )
+
+     # Allocate a new buffer if none exists yet or the existing one is too small
+     if (
+         _buffer is None
+         or _buffer.group != group
+         or _buffer.num_nvl_bytes < num_nvl_bytes
+         or _buffer.num_rdma_bytes < num_rdma_bytes
+     ):
+         _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
+
+
+ class DeepEpDispatch(torch.autograd.Function):
+     """Autograd function for the DeepEP Dispatch operation."""
+
+     @staticmethod
+     def forward(
+         ctx: FunctionCtx,
+         x: torch.Tensor,
+         topk_idx: torch.Tensor,
+         topk_weights: torch.Tensor,
+         num_experts: int
+     ) -> tuple[
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         list,
+         tuple,
+         EventOverlap
+     ]:
+         previous_event = Buffer.capture()
+         (
+             num_tokens_per_rank,
+             num_tokens_per_rdma_rank,
+             num_tokens_per_expert,
+             is_token_in_rank,
+             previous_event
+         ) = _buffer.get_dispatch_layout(
+             topk_idx, num_experts,
+             previous_event=previous_event,
+             async_finish=True,
+             allocate_on_comm_stream=True
+         )
+
+         (
+             recv_x,
+             recv_topk_idx,
+             recv_topk_weights,
+             num_recv_tokens_per_expert_list,
+             handle,
+             event
+         ) = _buffer.dispatch(
+             x,
+             topk_idx=topk_idx,
+             topk_weights=topk_weights,
+             num_tokens_per_rank=num_tokens_per_rank,
+             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+             is_token_in_rank=is_token_in_rank,
+             num_tokens_per_expert=num_tokens_per_expert,
+             previous_event=previous_event,
+             async_finish=True,
+             allocate_on_comm_stream=True
+         )
+
+         event.current_stream_wait()
+
+         num_recv_tokens_per_expert_list = torch.tensor(num_recv_tokens_per_expert_list)
+
+         ctx.handle = handle
+
+         return (
+             recv_x,
+             recv_topk_idx,
+             recv_topk_weights,
+             num_recv_tokens_per_expert_list,
+             handle
+         )
+
+     @staticmethod
+     def backward(
+         ctx: FunctionCtx,
+         grad_recv_x: torch.Tensor,
+         grad_recv_topk_idx: torch.Tensor,
+         grad_recv_topk_weights: torch.Tensor,
+         grad_num_recv_tokens_per_expert_list: list,
+         grad_handle: Any
+     ) -> tuple[
+         torch.Tensor,
+         None,
+         torch.Tensor,
+         None
+     ]:
+         handle = ctx.handle
+
+         prev_event = Buffer.capture()
+
+         (
+             combined_grad_x,
+             combined_grad_recv_topk_weights,
+             event
+         ) = _buffer.combine(
+             grad_recv_x.contiguous(),
+             handle,
+             topk_weights=grad_recv_topk_weights,
+             async_finish=True,
+             previous_event=prev_event,
+             allocate_on_comm_stream=True
+         )
+
+         event.current_stream_wait()
+
+         return combined_grad_x, None, combined_grad_recv_topk_weights, None
+
+
+ class DeepEpCombine(torch.autograd.Function):
+     """Autograd function for the DeepEP Combine operation."""
+
+     @staticmethod
+     def forward(
+         ctx: FunctionCtx,
+         x: torch.Tensor,
+         handle: Any
+     ) -> torch.Tensor:
+         previous_event = Buffer.capture()
+
+         combined_x, _, event = _buffer.combine(
+             x,
+             handle,
+             async_finish=True,
+             previous_event=previous_event,
+             allocate_on_comm_stream=True
+         )
+
+         event.current_stream_wait()
+
+         ctx.handle = handle
+
+         return combined_x
+
+     @staticmethod
+     def backward(ctx: FunctionCtx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]:
+         handle = ctx.handle
+
+         previous_event = Buffer.capture()
+
+         grad_x, _, _, _, _, event = _buffer.dispatch(
+             grad_output.contiguous(),
+             handle=handle,
+             async_finish=True,
+             previous_event=previous_event,
+             allocate_on_comm_stream=True
+         )
+
+         event.current_stream_wait()
+
+         return grad_x, None
+
+
+ class DeepEpCommunicationHandler(ExpertCommunicationHandler):
+     """Handles MoE communication using the high-performance DeepEP library."""
+
+     def __init__(self, num_experts: int):
+         """Constructs the DeepEpCommunicationHandler."""
+
+         self._num_experts = num_experts
+         self._num_experts_per_shard = None  # late-initialization
+
+         # == fields saved for post-dispatch ==
+
+         self._handle = None
+         self._hidden_shape_before_permute = None
+         self._unpermute_mapping = None
+
+     def setup(self, group: torch.distributed.ProcessGroup, hidden_size: int, hidden_dtype: torch.dtype):
+         """
+         Initializes the backend buffer and calculates expert sharding.
+
+         Args:
+             group: The process group containing all experts.
+             hidden_size: Dimensionality of the hidden states.
+             hidden_dtype: Data type of the hidden states.
+         """
+
+         init_deepep_buffer(group, hidden_size * hidden_dtype.itemsize)
+
+         if self._num_experts % group.size() != 0:
+             raise ValueError("num_experts must be divisible by distributed group size")
+
+         self._num_experts_per_shard = self._num_experts // group.size()
+
+     def dispatch(
+         self,
+         hidden_states: torch.Tensor,
+         topk_ids: torch.Tensor,
+         topk_weights: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         (
+             hidden_states,
+             topk_ids,
+             topk_weights,
+             tokens_per_expert,
+             handle
+         ) = DeepEpDispatch.apply(
+             hidden_states,
+             topk_ids,
+             topk_weights,
+             self._num_experts
+         )
+
+         routing_map, routing_probs = fused_indices_to_multihot(
+             topk_ids, topk_weights, self._num_experts_per_shard
+         )
+
+         self._hidden_shape_before_permute = hidden_states.shape
+
+         hidden_states, routing_probs, reverse_permute_map = moe_permute_with_probs(
+             hidden_states,
+             routing_probs,
+             routing_map,
+             num_out_tokens=tokens_per_expert.sum().item()
+         )
+
+         self._handle = handle
+         self._unpermute_mapping = reverse_permute_map
+
+         return hidden_states, routing_probs, tokens_per_expert
+
+     def combine(
+         self,
+         hidden_states: torch.Tensor
+     ) -> torch.Tensor:
+         if self._handle is None:
+             raise ValueError("MoE communication order violated: dispatch() must be called before combine()")
+
+         hidden_states = moe_unpermute_mask(
+             hidden_states,
+             self._unpermute_mapping,
+             restore_shape=self._hidden_shape_before_permute,
+         )
+
+         hidden_states = DeepEpCombine.apply(
+             hidden_states,
+             self._handle
+         )
+
+         self._handle = None
+         self._unpermute_mapping = None
+         self._hidden_shape_before_permute = None
+
+         return hidden_states
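
Editor's sketch of how this handler is wired up by hand (hypothetical values; in the package this is done by MoELayer.enable_distributed_communicator, shown further below). It assumes a torchrun job where torch.distributed is already initialized:

    import torch
    import torch.distributed as dist

    ep_group = dist.new_group()  # expert-parallel ranks; here simply all ranks
    handler = DeepEpCommunicationHandler(num_experts=64)
    handler.setup(ep_group, hidden_size=4096, hidden_dtype=torch.bfloat16)
    # The handler is then driven as dispatch -> grouped expert computation -> combine.
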
@@ -0,0 +1,68 @@
+ from typing import cast
+
+ import torch
+ from torch import Size
+
+ from d9d.kernel.moe import (
+     fused_indices_to_multihot,
+     moe_permute_with_probs,
+     moe_unpermute_mask,
+ )
+ from d9d.module.block.moe.communications import ExpertCommunicationHandler
+
+
+ class NoCommunicationHandler(ExpertCommunicationHandler):
+     """
+     Handles MoE routing within a single device or when no cross-device routing is needed.
+
+     This handler performs no network operations; it only permutes tokens locally
+     to group them by expert, e.g. for single-device runs or debugging.
+     """
+
+     def __init__(self, num_experts: int):
+         """Constructs the NoCommunicationHandler."""
+         self._num_experts = num_experts
+
+         self._hidden_shape_before_permute: Size | None = None
+         self._unpermute_mapping: torch.Tensor | None = None
+
+     def dispatch(
+         self,
+         hidden_states: torch.Tensor,
+         topk_ids: torch.Tensor,
+         topk_weights: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         with torch.no_grad():
+             tokens_per_expert = torch.bincount(topk_ids.flatten(), minlength=self._num_experts).cpu()
+
+         routing_map, routing_probs = fused_indices_to_multihot(
+             topk_ids, topk_weights, self._num_experts
+         )
+
+         self._hidden_shape_before_permute = hidden_states.shape
+
+         hidden_states, routing_probs, reverse_permute_map = moe_permute_with_probs(
+             hidden_states,
+             routing_probs,
+             routing_map,
+             num_out_tokens=cast(int, tokens_per_expert.sum().item())
+         )
+
+         self._unpermute_mapping = reverse_permute_map
+
+         return hidden_states, routing_probs, tokens_per_expert
+
+     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         if self._unpermute_mapping is None:
+             raise ValueError("Cannot run combine before running dispatch!")
+
+         hidden_states = moe_unpermute_mask(
+             hidden_states,
+             self._unpermute_mapping,
+             restore_shape=self._hidden_shape_before_permute,
+         )
+
+         self._unpermute_mapping = None
+         self._hidden_shape_before_permute = None
+
+         return hidden_states
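
Editor's reference sketch: the local permute/unpermute round trip that dispatch and combine perform can be reproduced in plain PyTorch (the package uses fused Triton kernels instead; everything below is illustrative):

    import torch

    hidden = torch.randn(4, 8)                               # (num_tokens, hidden_size)
    topk_ids = torch.tensor([[1], [0], [1], [0]])            # k = 1 for brevity
    order = torch.argsort(topk_ids.flatten(), stable=True)   # sort tokens by expert id
    permuted = hidden[order]                                 # "dispatch": group tokens by expert
    restored = torch.empty_like(permuted)
    restored[order] = permuted                               # "combine": undo the permutation
    assert torch.equal(restored, hidden)
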
@@ -0,0 +1,81 @@
+ import torch
+ from torch import nn
+
+ from d9d.kernel.swiglu import silu_mul
+ from d9d.module.base import ModuleLateInit
+
+ from .grouped_linear import GroupedLinear
+
+
+ class GroupedSwiGLU(nn.Module, ModuleLateInit):
+     """
+     Executes a collection of SwiGLU experts efficiently using Grouped GEMM.
+
+     This module implements the architectural pattern: `down_proj(SiLU(gate_proj(x)) * up_proj(x))`.
+     It applies this operation across multiple discrete experts in parallel without padding or masking.
+     """
+
+     def __init__(
+         self,
+         hidden_dim: int,
+         intermediate_dim: int,
+         num_experts: int
+     ):
+         """
+         Constructs the GroupedSwiGLU module.
+
+         Args:
+             hidden_dim: Dimensionality of the input and output hidden states.
+             intermediate_dim: Dimensionality of the intermediate projection.
+             num_experts: Total number of experts managed by this local instance.
+         """
+
+         super().__init__()
+         self._num_experts = num_experts
+
+         self.gate_proj = GroupedLinear(num_experts, hidden_dim, intermediate_dim)
+         self.up_proj = GroupedLinear(num_experts, hidden_dim, intermediate_dim)
+         self.down_proj = GroupedLinear(num_experts, intermediate_dim, hidden_dim)
+
+     def forward(
+         self,
+         permuted_x: torch.Tensor,
+         permuted_probs: torch.Tensor,
+         tokens_per_expert: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Computes expert outputs for sorted input tokens.
+
+         Args:
+             permuted_x: Input tokens sorted by their assigned expert.
+                 Shape: `(total_tokens, hidden_dim)`.
+             permuted_probs: Routing weights/probabilities corresponding to the sorted tokens.
+                 Shape: `(total_tokens)`.
+             tokens_per_expert: Number of tokens assigned to each consecutive expert. It is a CPU tensor.
+                 Shape: `(num_experts)`.
+
+         Returns:
+             The computed and weighted output tokens (still permuted).
+             Shape: `(total_tokens, hidden_dim)`.
+         """
+
+         if permuted_x.numel() == 0:  # handle the case when no tokens were routed to this instance
+             return permuted_x
+
+         probs = permuted_probs[:, None].to(permuted_x.dtype)
+         values = self.down_proj(
+             silu_mul(
+                 self.gate_proj(permuted_x, tokens_per_expert),
+                 self.up_proj(permuted_x, tokens_per_expert)
+             ),
+             tokens_per_expert
+         )
+
+         return probs * values
+
+     def reset_parameters(self):
+         """Resets parameters for all internal linear projections."""
+
+         self.gate_proj.reset_parameters()
+         self.up_proj.reset_parameters()
+         self.down_proj.reset_parameters()
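
Editor's reference sketch of the computation GroupedSwiGLU performs, written as an explicit per-expert loop in plain PyTorch (hypothetical helper, not the fused kernels the package uses):

    import torch
    import torch.nn.functional as F

    def grouped_swiglu_reference(x, probs, tokens_per_expert, gate_w, up_w, down_w):
        # x: (total_tokens, hidden); gate_w/up_w: (E, hidden, inter); down_w: (E, inter, hidden)
        outs = []
        for chunk, g, u, d in zip(x.split(tokens_per_expert.tolist()), gate_w, up_w, down_w):
            outs.append((F.silu(chunk @ g) * (chunk @ u)) @ d)  # down(silu(gate(x)) * up(x))
        return probs[:, None] * torch.cat(outs)
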
@@ -0,0 +1,78 @@
+ import math
+
+ import torch
+ from torch import nn
+ from torch.distributed.tensor import DTensor
+
+ from d9d.core.autograd import GradDirection
+ from d9d.kernel.gmm import gmm
+ from d9d.module.base import ModuleLateInit
+
+
+ class GroupedLinear(nn.Module, ModuleLateInit):
+     """
+     Applies a linear transformation using Grouped GEMM (General Matrix Multiplication).
+
+     This module allows efficient execution of multiple linear layers (experts) in parallel, where each expert
+     processes a variable number of tokens.
+     It is the computational core of the Mixture-of-Experts layer.
+     """
+
+     def __init__(
+         self,
+         n_groups: int,
+         in_features: int,
+         out_features: int,
+         device: torch.device | str | None = None,
+         dtype: torch.dtype | None = None
+     ):
+         """
+         Constructs the GroupedLinear layer.
+
+         Args:
+             n_groups: Number of groups (experts).
+             in_features: Input hidden size.
+             out_features: Output hidden size.
+             device: Target device.
+             dtype: Target data type.
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.empty(n_groups, in_features, out_features,
+                                                device=device, dtype=dtype))
+
+         self.n_groups = n_groups
+         self.in_features = in_features
+         self.out_features = out_features
+
+         self.reset_parameters()
+
+     def forward(self, x: torch.Tensor, x_groups: torch.Tensor) -> torch.Tensor:
+         """
+         Performs the grouped matrix multiplication.
+
+         Args:
+             x: Flattened input tensor containing tokens for all groups.
+                 Shape: `(total_tokens, in_features)`.
+             x_groups: CPU Tensor indicating the number of tokens assigned to each group.
+                 Must sum to `total_tokens`. Shape: `(n_groups,)`.
+
+         Returns:
+             The output tensor. Shape: `(total_tokens, out_features)`.
+         """
+
+         weight: torch.Tensor = self.weight
+
+         if isinstance(weight, DTensor):
+             weight = weight.to_local()
+
+         return gmm(
+             x,
+             weight,
+             x_groups,
+             a_grad_direction=GradDirection.inputs,
+             b_grad_direction=GradDirection.weight
+         )
+
+     def reset_parameters(self):
+         """Initializes weights using a uniform distribution based on input features."""
+         nn.init.uniform_(self.weight, -1 / math.sqrt(self.in_features), 1 / math.sqrt(self.in_features))
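
Editor's reference sketch of the grouped GEMM this layer performs, expressed as an explicit per-group loop (hypothetical helper; the package instead calls the fused d9d.kernel.gmm kernel with configurable gradient directions):

    import torch

    def gmm_reference(x, weight, x_groups):
        # x: (total_tokens, in_features); weight: (n_groups, in_features, out_features)
        # x_groups: CPU int tensor of per-group token counts that sums to total_tokens.
        chunks = x.split(x_groups.tolist())
        return torch.cat([chunk @ w for chunk, w in zip(chunks, weight)])
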
@@ -0,0 +1,122 @@
+ import torch
+ from torch import nn
+ from torch.distributed import ProcessGroup
+
+ from d9d.module.base import ModuleLateInit
+
+ from .communications import (
+     DeepEpCommunicationHandler,
+     ExpertCommunicationHandler,
+     NoCommunicationHandler,
+ )
+ from .grouped_experts import GroupedSwiGLU
+ from .router import TopKRouter
+
+ # TODO: implement expert bias
+ # TODO: shared experts
+
+
+ class MoELayer(nn.Module, ModuleLateInit):
+     """
+     A complete Mixture-of-Experts (MoE) block comprising routing, communication, and computation.
+
+     This layer integrates:
+
+     1. **Router**: Selects experts for each token.
+     2. **Communicator**: Handles token dispatch to local or remote experts (EP).
+     3. **Experts**: Performs parallelized computation (Grouped SwiGLU).
+     """
+
+     def __init__(
+         self,
+         hidden_dim: int,
+         intermediate_dim_grouped: int,
+         num_grouped_experts: int,
+         top_k: int,
+         router_renormalize_probabilities: bool
+     ):
+         """
+         Constructs the MoELayer.
+
+         Args:
+             hidden_dim: Hidden size.
+             intermediate_dim_grouped: Intermediate dimension for the Expert FFNs.
+             num_grouped_experts: Total number of experts.
+             top_k: Number of experts to route each token to.
+             router_renormalize_probabilities: Configures router probability normalization behavior.
+         """
+
+         super().__init__()
+         self.router = TopKRouter(
+             dim=hidden_dim, num_experts=num_grouped_experts, top_k=top_k,
+             renormalize_probabilities=router_renormalize_probabilities
+         )
+         self.grouped_experts = GroupedSwiGLU(
+             hidden_dim=hidden_dim,
+             intermediate_dim=intermediate_dim_grouped,
+             num_experts=num_grouped_experts
+         )
+         self._communicator: ExpertCommunicationHandler = NoCommunicationHandler(num_grouped_experts)
+
+         self._num_grouped_experts = num_grouped_experts
+         self._hidden_dim = hidden_dim
+
+         self.tokens_per_expert = nn.Buffer(torch.empty((num_grouped_experts,), dtype=torch.int64), persistent=False)
+
+     def enable_distributed_communicator(self, group: ProcessGroup):
+         """
+         Switches from local no-op communication to distributed DeepEP communication.
+
+         This should be called during model initialization if the model is running in a
+         distributed Expert Parallel environment.
+
+         Args:
+             group: The PyTorch process group spanning the expert parallel ranks.
+         """
+
+         communicator = DeepEpCommunicationHandler(num_experts=self._num_grouped_experts)
+         communicator.setup(group, self._hidden_dim, self.router.gate.weight.dtype)
+         self._communicator = communicator
+
+     @torch.no_grad()
+     def _update_tokens_per_expert(self, expert_indices: torch.Tensor):
+         self.tokens_per_expert.add_(expert_indices.view(-1).bincount(minlength=self._num_grouped_experts))
+
+     @torch.no_grad()
+     def reset_stats(self):
+         """Resets the expert load balancing counters."""
+         self.tokens_per_expert.zero_()
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor
+     ) -> torch.Tensor:
+         """
+         Routes tokens to experts, computes, and combines results.
+
+         Args:
+             hidden_states: Input tensor. Shape: `(batch_size, seq_len, hidden_dim)`.
+
+         Returns:
+             Output tensor combined from experts. Shape: `(batch_size, seq_len, hidden_dim)`.
+         """
+
+         old_shape = hidden_states.shape
+         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
+         expert_indices, expert_scores = self.router(hidden_states)
+         self._update_tokens_per_expert(expert_indices)
+         hidden_states, expert_scores, expert_count = self._communicator.dispatch(
+             hidden_states, expert_indices, expert_scores
+         )
+         hidden_states = self.grouped_experts(hidden_states, expert_scores, expert_count)
+         hidden_states = self._communicator.combine(hidden_states)
+         hidden_states = hidden_states.reshape(*old_shape)
+
+         return hidden_states
+
+     def reset_parameters(self):
+         """Resets module parameters."""
+         self.router.reset_parameters()
+         self.grouped_experts.reset_parameters()
+
+         nn.init.zeros_(self.tokens_per_expert)
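
Editor's usage sketch for the layer above (single-process path with the default NoCommunicationHandler; parameter values are placeholders, and the underlying grouped-GEMM/Triton kernels require a CUDA device, so treat this as illustrative rather than a test):

    import torch

    layer = MoELayer(
        hidden_dim=1024,
        intermediate_dim_grouped=2048,
        num_grouped_experts=8,
        top_k=2,
        router_renormalize_probabilities=True,
    ).cuda()

    layer.reset_stats()                          # zero the per-expert load counters
    x = torch.randn(2, 16, 1024, device="cuda")  # (batch_size, seq_len, hidden_dim)
    out = layer(x)                               # same shape as the input
    load = layer.tokens_per_expert               # tokens routed to each expert this step
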