mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
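
The listing above is the wheel's Python surface: a single top-level `mslk` package (per `top_level.txt`) plus the compiled `mslk/mslk.so` extension. As a quick orientation, here is a minimal import smoke test. It is only a sketch and not part of the package: it assumes the wheel is installed into a CUDA-enabled environment, that its Python dependencies (torch, triton, fairscale, pyre_extensions) are available, and that the compiled extension registers the `mslk` custom ops on import.

    # Hypothetical smoke test: check that the modules listed above import cleanly
    # after `pip install mslk-cuda-nightly` and that the row-wise FP8 quantizer runs.
    import torch

    from mslk.moe.layers import BaselineMoE, MetaShufflingMoE, MoEArgs  # noqa: F401
    from mslk.quantize.triton.fp8_quantize import triton_quantize_fp8_row

    x = torch.randn(64, 4096, device="cuda", dtype=torch.bfloat16)
    xq, x_scale = triton_quantize_fp8_row(x)  # one scale per row, as used in layers.py
    print(xq.dtype, x_scale.shape)  # expected: an FP8 dtype and one scale per row
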
mslk/moe/layers.py ADDED
@@ -0,0 +1,1240 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ from abc import ABCMeta, abstractmethod
10
+ from collections.abc import Mapping
11
+ from dataclasses import dataclass
12
+ from functools import cached_property
13
+ from typing import Callable, Optional, Union
14
+
15
+ import torch
16
+ from fairscale.nn.model_parallel.initialize import get_model_parallel_world_size
17
+ from mslk.gemm.triton.grouped_gemm import grouped_gemm, grouped_gemm_fp8_rowwise
18
+ from mslk.moe.activation import silu_mul, silu_mul_quant
19
+ from mslk.moe.gather_scatter import (
20
+ gather_scale_dense_tokens,
21
+ gather_scale_quant_dense_tokens,
22
+ scatter_add_dense_tokens,
23
+ scatter_add_padded_tokens,
24
+ )
25
+ from mslk.moe.shuffling import combine_shuffling, split_shuffling
26
+ from mslk.quantize.triton.fp8_quantize import triton_quantize_fp8_row
27
+ from pyre_extensions import none_throws
28
+ from torch.distributed import get_rank, ProcessGroup
29
+
30
+ if torch.cuda.is_available():
31
+ index_shuffling = torch.ops.mslk.index_shuffling # noqa: F401
32
+ else:
33
+ index_shuffling = None
34
+
35
+
36
+ __all__ = ["MoEArgs", "BaselineMoE", "MetaShufflingMoE"]
37
+
38
+
39
+ @dataclass(frozen=True)
40
+ class MoEArgs:
41
+ precision: str
42
+ dim: int
43
+ hidden_dim: int
44
+ num_experts: int
45
+ top_k: int
46
+ mp_size: int
47
+ ep_size: int
48
+ mp_size_for_routed_experts: Optional[int]
49
+ use_fast_accum: bool
50
+ dedup_comm: bool
51
+
52
+ @cached_property
53
+ def num_local_experts(self) -> int:
54
+ return self.num_experts // self.ep_size
55
+
56
+
57
+ INIT_METHODS_TYPE = Mapping[
58
+ str,
59
+ Callable[[torch.Tensor], Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]],
60
+ ]
61
+
62
+
63
+ class ScaledParameter(torch.nn.Parameter):
64
+ def __new__(
65
+ cls,
66
+ data: torch.Tensor,
67
+ scale: Optional[torch.Tensor] = None,
68
+ ) -> "ScaledParameter":
69
+ return super().__new__(cls, data, False)
70
+
71
+ def __init__(
72
+ self,
73
+ data: torch.Tensor,
74
+ scale: Optional[torch.Tensor] = None,
75
+ ):
76
+ self._scale: Optional[torch.Tensor] = scale
77
+
78
+ @property
79
+ def weights(self) -> torch.Tensor:
80
+ return self.data
81
+
82
+ @property
83
+ def scales(self) -> torch.Tensor:
84
+ assert self._scale is not None
85
+ return self._scale
86
+
87
+ @scales.setter
88
+ def scales(self, s: torch.Tensor) -> None:
89
+ self._scale = s
90
+
91
+ @property
92
+ def is_scaled(self) -> bool:
93
+ return self._scale is not None
94
+
95
+
96
+ # Helper functions/modules to perform weight sharding and initialization.
97
+ def init_params(
98
+ key: str,
99
+ param: ScaledParameter,
100
+ init_methods: INIT_METHODS_TYPE,
101
+ ):
102
+ if key in init_methods:
103
+ ret = init_methods[key](param.data)
104
+ if isinstance(ret, torch.Tensor):
105
+ param.data = ret
106
+ else:
107
+ param.data, param.scales = ret
108
+ else:
109
+ torch.nn.init.kaiming_uniform_(param)
110
+
111
+
112
+ class Experts(torch.nn.Module, metaclass=ABCMeta):
113
+ def __init__(
114
+ self,
115
+ dim: int,
116
+ hidden_dim: int,
117
+ ):
118
+ super().__init__()
119
+
120
+ self.dim: int = dim
121
+ self.hidden_dim: int = hidden_dim
122
+
123
+ self.dtype: torch.dtype = torch.get_default_dtype()
124
+ self.divide_factor: int = get_model_parallel_world_size()
125
+
126
+ assert self.dim % self.divide_factor == 0
127
+ assert self.hidden_dim % self.divide_factor == 0
128
+
129
+ self._w13: Optional[ScaledParameter] = None
130
+ self._w2: Optional[ScaledParameter] = None
131
+
132
+ @abstractmethod
133
+ def build(self, init_methods: Optional[INIT_METHODS_TYPE] = None) -> "Experts":
134
+ pass
135
+
136
+ @property
137
+ def w13(self) -> ScaledParameter:
138
+ assert self._w13 is not None, "Parameters are not initialized!"
139
+ return self._w13
140
+
141
+ @property
142
+ def w2(self) -> ScaledParameter:
143
+ assert self._w2 is not None, "Parameters are not initialized!"
144
+ return self._w2
145
+
146
+ @property
147
+ def is_fp8_rowwise(self) -> bool:
148
+ return self.w13.dtype == torch.float8_e4m3fn
149
+
150
+
151
+ class RoutedExperts(Experts):
152
+ def __init__(
153
+ self,
154
+ num_local_experts: int,
155
+ dim: int,
156
+ hidden_dim: int,
157
+ ) -> None:
158
+ super().__init__(dim, hidden_dim)
159
+
160
+ self.num_local_experts: int = num_local_experts
161
+
162
+ def build(
163
+ self, init_methods: Optional[INIT_METHODS_TYPE] = None
164
+ ) -> "RoutedExperts":
165
+ init_methods = {} if init_methods is None else init_methods
166
+
167
+ moe_w_in_eDF: ScaledParameter = ScaledParameter(
168
+ torch.empty(
169
+ self.num_local_experts,
170
+ self.dim,
171
+ self.hidden_dim // self.divide_factor,
172
+ dtype=self.dtype,
173
+ )
174
+ )
175
+ init_params("moe_w_in_eDF", moe_w_in_eDF, init_methods)
176
+
177
+ moe_w_out_eFD: ScaledParameter = ScaledParameter(
178
+ torch.empty(
179
+ self.num_local_experts,
180
+ self.hidden_dim // self.divide_factor,
181
+ self.dim,
182
+ dtype=self.dtype,
183
+ )
184
+ )
185
+ init_params("moe_w_out_eFD", moe_w_out_eFD, init_methods)
186
+
187
+ moe_w_swiglu_eDF: ScaledParameter = ScaledParameter(
188
+ torch.empty(
189
+ self.num_local_experts,
190
+ self.dim,
191
+ self.hidden_dim // self.divide_factor,
192
+ dtype=self.dtype,
193
+ )
194
+ )
195
+ init_params("moe_w_swiglu_eDF", moe_w_swiglu_eDF, init_methods)
196
+
197
+ assert (
198
+ moe_w_in_eDF.dtype == moe_w_out_eFD.dtype
199
+ and moe_w_in_eDF.dtype == moe_w_swiglu_eDF.dtype
200
+ )
201
+ assert (
202
+ moe_w_in_eDF.is_scaled == moe_w_out_eFD.is_scaled
203
+ and moe_w_in_eDF.is_scaled == moe_w_swiglu_eDF.is_scaled
204
+ )
205
+
206
+ self._w13 = ScaledParameter(
207
+ data=torch.cat(
208
+ [
209
+ moe_w_in_eDF,
210
+ moe_w_swiglu_eDF,
211
+ ],
212
+ dim=-1,
213
+ )
214
+ .transpose(1, 2)
215
+ .contiguous(),
216
+ scale=(
217
+ torch.cat(
218
+ [
219
+ moe_w_in_eDF.scales,
220
+ moe_w_swiglu_eDF.scales,
221
+ ],
222
+ dim=-1,
223
+ ).contiguous()
224
+ if moe_w_in_eDF.is_scaled
225
+ else None
226
+ ),
227
+ )
228
+
229
+ del moe_w_in_eDF
230
+ del moe_w_swiglu_eDF
231
+
232
+ self._w2 = ScaledParameter(
233
+ data=moe_w_out_eFD.transpose(1, 2).contiguous(),
234
+ scale=(
235
+ moe_w_out_eFD.scales.contiguous() if moe_w_out_eFD.is_scaled else None
236
+ ),
237
+ )
238
+
239
+ del moe_w_out_eFD
240
+
241
+ return self
242
+
243
+
244
+ class SharedExperts(Experts):
245
+ def __init__(
246
+ self,
247
+ dim: int,
248
+ hidden_dim: int,
249
+ ):
250
+ super().__init__(dim, hidden_dim)
251
+
252
+ def build(
253
+ self, init_methods: Optional[INIT_METHODS_TYPE] = None
254
+ ) -> "SharedExperts":
255
+ init_methods = {} if init_methods is None else init_methods
256
+
257
+ w_in_shared_FD = ScaledParameter(
258
+ torch.empty(
259
+ (self.hidden_dim // self.divide_factor, self.dim), dtype=self.dtype
260
+ )
261
+ )
262
+ init_params("w_in_shared_FD", w_in_shared_FD, init_methods)
263
+
264
+ w_out_shared_DF = ScaledParameter(
265
+ torch.empty(
266
+ (self.dim, self.hidden_dim // self.divide_factor), dtype=self.dtype
267
+ )
268
+ )
269
+ init_params("w_out_shared_DF", w_out_shared_DF, init_methods)
270
+
271
+ w_swiglu_FD = ScaledParameter(
272
+ torch.empty(
273
+ (self.hidden_dim // self.divide_factor, self.dim), dtype=self.dtype
274
+ )
275
+ )
276
+ init_params("w_swiglu_FD", w_swiglu_FD, init_methods)
277
+
278
+ assert (w_in_shared_FD.dtype == w_out_shared_DF.dtype) and (
279
+ w_in_shared_FD.dtype == w_swiglu_FD.dtype
280
+ )
281
+ assert (w_in_shared_FD.is_scaled == w_out_shared_DF.is_scaled) and (
282
+ w_in_shared_FD.is_scaled == w_swiglu_FD.is_scaled
283
+ )
284
+
285
+ self._w13 = ScaledParameter(
286
+ data=torch.cat(
287
+ [
288
+ w_in_shared_FD,
289
+ w_swiglu_FD,
290
+ ]
291
+ ).contiguous(),
292
+ scale=(
293
+ torch.cat(
294
+ [
295
+ w_in_shared_FD.scales,
296
+ w_swiglu_FD.scales,
297
+ ]
298
+ ).contiguous()
299
+ if w_in_shared_FD.is_scaled
300
+ else None
301
+ ),
302
+ )
303
+ del w_in_shared_FD
304
+ del w_swiglu_FD
305
+
306
+ self._w2 = ScaledParameter(
307
+ data=w_out_shared_DF.data.contiguous(),
308
+ scale=(
309
+ w_out_shared_DF.scales.contiguous()
310
+ if w_out_shared_DF.is_scaled
311
+ else None
312
+ ),
313
+ )
314
+ del w_out_shared_DF
315
+
316
+ return self
317
+
318
+
319
+ class BaselineMoE(torch.nn.Module):
320
+ def __init__(
321
+ self,
322
+ ep_group: ProcessGroup,
323
+ ep_mp_group: ProcessGroup,
324
+ moe_args: MoEArgs,
325
+ ) -> None:
326
+ super().__init__()
327
+
328
+ self.moe_args = moe_args
329
+ self.mp_size: int = moe_args.mp_size
330
+ self.ep_size: int = moe_args.ep_size
331
+ self.ep_mp_size: int = (
332
+ moe_args.mp_size
333
+ if moe_args.mp_size_for_routed_experts is None
334
+ else moe_args.mp_size_for_routed_experts
335
+ )
336
+
337
+ self.ep_rank: int = get_rank(ep_group)
338
+ self.ep_mp_rank: int = get_rank(ep_mp_group)
339
+
340
+ self.ep_mp_group: ProcessGroup = ep_mp_group
341
+ self.ep_group: ProcessGroup = ep_group
342
+
343
+ self.num_experts: int = moe_args.num_experts
344
+ self.num_local_experts: int = none_throws(moe_args.num_local_experts)
345
+ assert self.num_experts == self.num_local_experts * self.ep_size
346
+
347
+ self.top_k: int = moe_args.top_k
348
+
349
+ self.dtype: torch.dtype = torch.get_default_dtype()
350
+
351
+ self._router_DE: Optional[ScaledParameter] = None
352
+ self.routed_experts = RoutedExperts(
353
+ moe_args.num_local_experts,
354
+ moe_args.dim,
355
+ moe_args.hidden_dim,
356
+ )
357
+ self.shared_experts = SharedExperts(
358
+ moe_args.dim,
359
+ moe_args.hidden_dim,
360
+ )
361
+
362
+ def build(self, init_methods: Optional[INIT_METHODS_TYPE] = None) -> "BaselineMoE":
363
+ init_methods = {} if init_methods is None else init_methods
364
+
365
+ router_DE = ScaledParameter(
366
+ torch.empty(self.moe_args.dim, self.moe_args.num_experts, dtype=self.dtype)
367
+ )
368
+ init_params("router_DE", router_DE, init_methods)
369
+ self._router_DE = router_DE
370
+
371
+ self.routed_experts.build(init_methods)
372
+ self.shared_experts.build(init_methods)
373
+ return self
374
+
375
+ @property
376
+ def router_DE(self) -> ScaledParameter:
377
+ assert self._router_DE is not None, "Parameters are not initialized!"
378
+ return self._router_DE
379
+
380
+ # Users should override this property.
381
+ @property
382
+ def is_shared_fp8_rowwise(self) -> bool:
383
+ return self.shared_experts.is_fp8_rowwise
384
+
385
+ @property
386
+ def is_routed_fp8_rowwise(self) -> bool:
387
+ return self.routed_experts.is_fp8_rowwise
388
+
389
+ @property
390
+ def E(self) -> int:
391
+ return self.num_experts
392
+
393
+ @property
394
+ def EG(self) -> int:
395
+ return self.num_local_experts
396
+
397
+ @property
398
+ def K(self) -> int:
399
+ return self.top_k
400
+
401
+ def forward(self, x: torch.Tensor, use_static_shape: bool) -> torch.Tensor:
402
+ with torch.no_grad():
403
+ return self._forward(x, use_static_shape)
404
+
405
+ def _forward(self, x: torch.Tensor, use_static_shape: bool) -> torch.Tensor:
406
+ (B, T, D) = x.shape
407
+ T *= B
408
+ tokens = x.view(T, D)
409
+
410
+ # Shared Experts
411
+ shared_y = self._fake_quant(torch.mm, tokens, self.shared_experts.w13)
412
+ shared_y0, shared_y1 = torch.chunk(shared_y, chunks=2, dim=-1)
413
+ shared_z = shared_y0 * torch.sigmoid(shared_y0) * shared_y1
414
+ shared_z = self._fake_quant(torch.mm, shared_z, self.shared_experts.w2)
415
+
416
+ # Routing Scores
417
+ E: int = self.E
418
+ scores = torch.nn.functional.linear(tokens, self.router_DE.T)
419
+ scores = torch.sigmoid(scores)
420
+ assert scores.shape == (T, E)
421
+
422
+ # Routing
423
+ K: int = self.K
424
+ topk_values, topk_indices = torch.topk(scores, K, dim=-1)
425
+ assert topk_values.shape == (T, K)
426
+ assert topk_indices.shape == (T, K)
427
+
428
+ masked_scores = torch.zeros_like(scores)
429
+ masked_scores = (
430
+ masked_scores.scatter_(dim=1, index=topk_indices, src=topk_values)
431
+ .transpose(0, 1) # (E, T)
432
+ .reshape(E, T, 1)
433
+ .expand(E, T, D)
434
+ )
435
+
436
+ tokens = tokens.view(1, T, D).expand(E, T, D)
437
+ masked_tokens = tokens * masked_scores
438
+
439
+ # Routed Experts
440
+ EG: int = self.EG
441
+ if self.ep_size > 1:
442
+ send_tokens = masked_tokens.contiguous()
443
+ send_list = list(torch.chunk(send_tokens, chunks=self.ep_size, dim=0))
444
+ recv_tokens = torch.empty_like(send_tokens)
445
+ recv_list = list(torch.chunk(recv_tokens, chunks=self.ep_size, dim=0))
446
+
447
+ torch.distributed.all_to_all(
448
+ output_tensor_list=recv_list,
449
+ input_tensor_list=send_list,
450
+ group=self.ep_group,
451
+ )
452
+
453
+ masked_tokens = recv_tokens.reshape(EG, -1, D)
454
+
455
+ routed_y = self._fake_quant(torch.bmm, masked_tokens, self.routed_experts.w13)
456
+ routed_y0, routed_y1 = torch.chunk(routed_y, chunks=2, dim=-1)
457
+ routed_z = routed_y0 * torch.sigmoid(routed_y0) * routed_y1
458
+ routed_z = self._fake_quant(torch.bmm, routed_z, self.routed_experts.w2)
459
+
460
+ if self.ep_size > 1:
461
+ send_tokens = routed_z.reshape(E * T, D).contiguous()
462
+ send_list = list(torch.chunk(send_tokens, chunks=self.ep_size, dim=0))
463
+ recv_tokens = torch.empty_like(send_tokens)
464
+ recv_list = list(torch.chunk(recv_tokens, chunks=self.ep_size, dim=0))
465
+
466
+ torch.distributed.all_to_all(
467
+ output_tensor_list=recv_list,
468
+ input_tensor_list=send_list,
469
+ group=self.ep_group,
470
+ )
471
+
472
+ routed_z = recv_tokens.reshape(E, T, D)
473
+
474
+ return (shared_z + routed_z.sum(dim=0)).reshape(B, -1, D)
475
+
476
+ def _fake_quant(self, op, x: torch.Tensor, w: ScaledParameter) -> torch.Tensor:
477
+ if not w.is_scaled:
478
+ return op(x, w.transpose(-1, -2))
479
+
480
+ xq, xs = triton_quantize_fp8_row(x)
481
+ wq, ws = w.weights, w.scales
482
+
483
+ y = (
484
+ op(xq.to(x.dtype), wq.transpose(-1, -2).to(x.dtype))
485
+ * xs.unsqueeze(-1)
486
+ * ws.unsqueeze(-2)
487
+ )
488
+ return y.to(x.dtype)
489
+
490
+
491
+ class MetaShufflingMoE(BaselineMoE):
492
+ def __init__(
493
+ self,
494
+ ep_group: ProcessGroup,
495
+ ep_mp_group: ProcessGroup,
496
+ moe_args: MoEArgs,
497
+ ) -> None:
498
+ super().__init__(ep_group=ep_group, ep_mp_group=ep_mp_group, moe_args=moe_args)
499
+
500
+ assert self.mp_size == self.ep_mp_size, (
501
+ "MetaShuffling only supports mp_size = mp_size_for_routed_experts now"
502
+ )
503
+
504
+ assert self.top_k == 1, (
505
+ "MetaShuffling only supports top 1 routing at the moment"
506
+ )
507
+
508
+ self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
509
+ self.comp_end_event: torch.cuda.Event = torch.cuda.Event()
510
+ self.comm_end_event: torch.cuda.Event = torch.cuda.Event()
511
+
512
+ self.use_fast_accum: bool = moe_args.use_fast_accum
513
+ self.dedup_comm: bool = moe_args.dedup_comm
514
+ if self.dedup_comm:
515
+ assert self.ep_mp_size == self.mp_size, (
516
+ "TP2EP is not supported for dedup at the moment."
517
+ )
518
+
519
+ self.activation_scale_ub = None
520
+
521
+ def forward(self, x: torch.Tensor, use_static_shape: bool) -> torch.Tensor:
522
+ with torch.no_grad():
523
+ if self.ep_size == 1:
524
+ return self._no_comm_forward(x, use_static_shape)
525
+ if use_static_shape:
526
+ return self._static_comm_forward(x)
527
+ else:
528
+ return self._dynamic_comm_forward(x)
529
+
530
+ def _dynamic_comm_forward(self, tokens: torch.Tensor) -> torch.Tensor:
531
+ comp_stream = torch.cuda.current_stream()
532
+
533
+ (B, T, D) = tokens.shape
534
+ T *= B
535
+
536
+ # 1. Dispatch router kernels.
537
+ routed_tokens, routed_tokens_scales, token_counts, token_indices = self._route(
538
+ tokens
539
+ )
540
+ assert routed_tokens_scales is None
541
+
542
+ # 2. Dispatch 1st all2all on shapes.
543
+ self.comp_end_event.record()
544
+ with torch.cuda.stream(self.comm_stream):
545
+ self.comp_end_event.wait()
546
+
547
+ send_token_counts = token_counts
548
+ recv_token_counts = self._exchange_shapes(send_token_counts)
549
+ send_token_counts.record_stream(self.comm_stream)
550
+
551
+ recv_token_counts.record_stream(comp_stream)
552
+
553
+ # 3. Dispatch shared expert part 1.
554
+ shared_y = self._shared_expert_part1(tokens)
555
+
556
+ with torch.cuda.stream(self.comm_stream):
557
+ # 4. CPU/GPU sync.
558
+ concat_counts = torch.concat(
559
+ [send_token_counts.flatten(), recv_token_counts.flatten()]
560
+ ).cpu()
561
+ send_tokens_list = concat_counts[: self.E].tolist()
562
+ recv_tokens_list = concat_counts[self.E :].tolist()
563
+
564
+ # 5. Dispatch 2nd all2all on tokens.
565
+ send_tokens = routed_tokens
566
+ recv_tokens = self._exchange_tokens(
567
+ send_tokens,
568
+ send_tokens_list,
569
+ recv_tokens_list,
570
+ is_input=True,
571
+ )
572
+ send_tokens.record_stream(self.comm_stream)
573
+
574
+ self.comm_end_event.record()
575
+ recv_tokens.record_stream(comp_stream)
576
+
577
+ # 6. Dispatch routed expert kernels.
578
+ self.comm_end_event.wait()
579
+ recv_T = recv_tokens.shape[0]
580
+ assert recv_tokens.shape == (recv_T, D)
581
+ assert recv_token_counts.shape == (self.ep_size, self.num_local_experts)
582
+ shuffled_recv_tokens, shuffled_recv_token_counts = combine_shuffling(
583
+ recv_tokens, recv_token_counts
584
+ )
585
+ assert shuffled_recv_tokens.shape == (recv_T, D)
586
+ assert shuffled_recv_token_counts.shape == (self.num_local_experts + 1,)
587
+ routed_z = self._routed_expert(
588
+ shuffled_recv_tokens,
589
+ shuffled_recv_token_counts[:-1],
590
+ )
591
+ assert routed_z.shape == (recv_T, D)
592
+ shuffled_send_tokens = split_shuffling(routed_z, recv_token_counts)
593
+ assert shuffled_send_tokens.shape == (recv_T, D)
594
+
595
+ # 7. Dispatch 3rd all2all on tokens.
596
+ self.comp_end_event.record()
597
+ with torch.cuda.stream(self.comm_stream):
598
+ self.comp_end_event.wait()
599
+
600
+ send_tokens = shuffled_send_tokens
601
+ recv_tokens = self._exchange_tokens(
602
+ send_tokens,
603
+ recv_tokens_list,
604
+ send_tokens_list,
605
+ is_input=False,
606
+ )
607
+ send_tokens.record_stream(self.comm_stream)
608
+
609
+ self.comm_end_event.record()
610
+ recv_tokens.record_stream(comp_stream)
611
+
612
+ # 8. Dispatch shared expert part 2.
613
+ shared_z = self._shared_expert_part2(shared_y)
614
+
615
+ # 9. Dispatch combine outputs.
616
+ self.comm_end_event.wait()
617
+ final_output = self._combine_outputs(
618
+ shared_z, recv_tokens, token_indices, token_counts, padded=False
619
+ )
620
+
621
+ T //= B
622
+ return final_output.view(B, T, D)
623
+
624
+ def _static_comm_forward(self, tokens: torch.Tensor) -> torch.Tensor:
625
+ comp_stream = torch.cuda.current_stream()
626
+
627
+ (B, T, D) = tokens.shape
628
+ T *= B
629
+
630
+ # 1. Dispatch router kernels.
631
+ routed_tokens, routed_tokens_scales, token_counts, token_indices = self._route(
632
+ tokens
633
+ )
634
+ assert routed_tokens_scales is None
635
+
636
+ # 2. Dispatch allgather on shapes and tokens.
637
+ self.comp_end_event.record()
638
+ with torch.cuda.stream(self.comm_stream):
639
+ self.comp_end_event.wait()
640
+
641
+ send_token_counts = token_counts
642
+ send_tokens = routed_tokens
643
+ # TODO(shikaili): Check if using 1 allgather is faster even with copies.
644
+ recv_token_counts = self._gather_shapes(send_token_counts)
645
+ recv_tokens = self._gather_tokens(send_tokens)
646
+ send_token_counts.record_stream(self.comm_stream)
647
+ send_tokens.record_stream(self.comm_stream)
648
+
649
+ self.comm_end_event.record()
650
+ recv_token_counts.record_stream(comp_stream)
651
+ recv_tokens.record_stream(comp_stream)
652
+
653
+ # 3. Dispatch shared expert part 1.
654
+ shared_y = self._shared_expert_part1(tokens)
655
+
656
+ # 4. Dispatch routed expert kernels.
657
+ self.comm_end_event.wait()
658
+ assert recv_tokens.shape == (
659
+ self.ep_size,
660
+ T,
661
+ D,
662
+ ), f"{recv_tokens.shape=}, {(self.ep_size, T, D)=}"
663
+ assert recv_token_counts.shape == (self.ep_size, self.E)
664
+ shuffled_recv_tokens, shuffled_recv_token_counts = combine_shuffling(
665
+ recv_tokens.view(-1, D),
666
+ recv_token_counts,
667
+ expert_start=self.ep_rank * self.num_local_experts,
668
+ expert_end=(self.ep_rank + 1) * self.num_local_experts,
669
+ )
670
+ assert shuffled_recv_tokens.shape == (self.ep_size * T, D)
671
+ assert shuffled_recv_token_counts.shape == (self.num_local_experts + 1,), (
672
+ f"{shuffled_recv_token_counts.shape=}"
673
+ )
674
+ routed_z = self._routed_expert(
675
+ shuffled_recv_tokens,
676
+ shuffled_recv_token_counts[:-1],
677
+ )
678
+ assert routed_z.shape == (self.ep_size * T, D)
679
+ shuffled_send_tokens = split_shuffling(
680
+ routed_z,
681
+ recv_token_counts,
682
+ expert_start=self.ep_rank * self.num_local_experts,
683
+ expert_end=(self.ep_rank + 1) * self.num_local_experts,
684
+ )
685
+ assert shuffled_send_tokens.shape == (self.ep_size * T, D)
686
+
687
+ # 5. Dispatch all2all on tokens.
688
+ self.comp_end_event.record()
689
+ with torch.cuda.stream(self.comm_stream):
690
+ self.comp_end_event.wait()
691
+
692
+ send_tokens = shuffled_send_tokens
693
+ recv_tokens = self._exchange_tokens(send_tokens, None, None, is_input=False)
694
+ send_tokens.record_stream(self.comm_stream)
695
+
696
+ self.comm_end_event.record()
697
+ recv_tokens.record_stream(comp_stream)
698
+
699
+ # 6. Dispatch shared expert part 2.
700
+ shared_z = self._shared_expert_part2(shared_y)
701
+
702
+ # 7. Dispatch combine outputs.
703
+ self.comm_end_event.wait()
704
+ final_output = self._combine_outputs(
705
+ shared_z,
706
+ recv_tokens.view(self.ep_size, T, D),
707
+ token_indices,
708
+ token_counts,
709
+ padded=True,
710
+ )
711
+
712
+ T //= B
713
+ return final_output.view(B, T, D)
714
+
715
+ def _no_comm_forward(
716
+ self, tokens: torch.Tensor, overlap_router_and_shared_expert: bool
717
+ ) -> torch.Tensor:
718
+ # Default stream for compute
719
+ comp_stream = torch.cuda.current_stream()
720
+ if overlap_router_and_shared_expert:
721
+ self.comp_end_event.record()
722
+ (B, T, D) = tokens.shape
723
+
724
+ # 1. Dispatch router kernels and shared experts GEMMs.
725
+ routed_tokens, routed_tokens_scales, token_counts, token_indices = self._route(
726
+ tokens
727
+ )
728
+
729
+ if overlap_router_and_shared_expert:
730
+ with torch.cuda.stream(self.comm_stream):
731
+ self.comp_end_event.wait()
732
+
733
+ shared_y = self._shared_expert_part1(tokens)
734
+ shared_z = self._shared_expert_part2(shared_y)
735
+ tokens.record_stream(self.comm_stream)
736
+
737
+ self.comm_end_event.record()
738
+ shared_z.record_stream(comp_stream)
739
+ self.comm_end_event.wait()
740
+ else:
741
+ shared_y = self._shared_expert_part1(tokens)
742
+ shared_z = self._shared_expert_part2(shared_y)
743
+
744
+ # 2. Dispatch routed expert GEMMs.
745
+ if not torch.version.hip:
746
+ final_output = self._routed_expert(
747
+ routed_tokens,
748
+ token_counts,
749
+ token_scales=routed_tokens_scales,
750
+ shared_output=shared_z,
751
+ token_indices=token_indices,
752
+ )
753
+ else:
754
+ routed_z = self._routed_expert(
755
+ routed_tokens,
756
+ token_counts,
757
+ token_scales=routed_tokens_scales,
758
+ )
759
+ # 3. Dispatch combine outputs.
760
+ final_output = self._combine_outputs(
761
+ shared_z, routed_z, token_indices, token_counts, padded=False
762
+ )
763
+
764
+ return final_output.view(B, T, D)
765
+
766
+ def _exchange_shapes(self, send_sizes: torch.Tensor) -> torch.Tensor:
767
+ "No CPU/GPU sync in this function."
768
+ if self.ep_size == 1:
769
+ return send_sizes
770
+
771
+ assert tuple(send_sizes.shape) == (self.E,)
772
+ recv_sizes = torch.empty_like(send_sizes)
773
+
774
+ recv_sizes_list = list(recv_sizes.chunk(self.ep_size))
775
+ send_sizes_list = list(send_sizes.chunk(self.ep_size))
776
+
777
+ assert all(r.is_contiguous() for r in recv_sizes_list)
778
+ assert all(s.is_contiguous() for s in send_sizes_list)
779
+ torch.distributed.all_to_all(
780
+ output_tensor_list=recv_sizes_list,
781
+ input_tensor_list=send_sizes_list,
782
+ group=self.ep_group,
783
+ )
784
+
785
+ # send_sizes: [E] viewed as [EP, EG]
786
+ # recv_sizes: [E] viewed as [EP, EG]
787
+ return recv_sizes.view(self.ep_size, self.num_local_experts)
788
+
789
+ def _gather_shapes(self, send_sizes: torch.Tensor) -> torch.Tensor:
790
+ "No CPU/GPU sync in this function."
791
+ if self.ep_size == 1:
792
+ return send_sizes
793
+
794
+ assert tuple(send_sizes.shape) == (self.E,)
795
+ recv_sizes = torch.empty(
796
+ (self.ep_size, self.E), dtype=send_sizes.dtype, device=send_sizes.device
797
+ )
798
+
799
+ assert send_sizes.is_contiguous()
800
+ assert recv_sizes.is_contiguous()
801
+ torch.distributed.all_gather_into_tensor(
802
+ output_tensor=recv_sizes,
803
+ input_tensor=send_sizes,
804
+ group=self.ep_group,
805
+ )
806
+
807
+ # send_sizes: [E]
808
+ # recv_sizes: [EP, E]
809
+ return recv_sizes
810
+
811
+ def _exchange_tokens(
812
+ self,
813
+ send_tokens: torch.Tensor,
814
+ send_sizes: Optional[list[int]],
815
+ recv_sizes: Optional[list[int]],
816
+ is_input: bool,
817
+ ) -> torch.Tensor:
818
+ """
819
+ When `send_sizes`/`recv_sizes` are `None`, we assume the tokens are evenly distributed
820
+ across the EP ranks, so the total number of tokens `T` is split evenly into `ep_size` chunks.
821
+ No CPU/GPU sync in this function.
822
+ """
823
+ if self.ep_size == 1:
824
+ return send_tokens
825
+
826
+ D = send_tokens.shape[-1]
827
+ send_tokens = send_tokens.view(-1, D)
828
+ T = send_tokens.shape[0]
829
+
830
+ if send_sizes is None:
831
+ send_sizes = [T // self.ep_size for _ in range(self.ep_size)]
832
+ else:
833
+ send_sizes = [
834
+ sum(
835
+ send_sizes[
836
+ r * self.num_local_experts : (r + 1) * self.num_local_experts
837
+ ]
838
+ )
839
+ for r in range(self.ep_size)
840
+ ]
841
+
842
+ if recv_sizes is None:
843
+ recv_sizes = [T // self.ep_size for _ in range(self.ep_size)]
844
+ else:
845
+ recv_sizes = [
846
+ sum(
847
+ recv_sizes[
848
+ r * self.num_local_experts : (r + 1) * self.num_local_experts
849
+ ]
850
+ )
851
+ for r in range(self.ep_size)
852
+ ]
853
+
854
+ # TODO: Add FP8 A2A to example.
855
+ if self.dedup_comm:
856
+ if is_input:
857
+ sliced_recv_tokens = torch.empty(
858
+ (sum(none_throws(recv_sizes)), D // self.ep_mp_size),
859
+ dtype=send_tokens.dtype,
860
+ device=send_tokens.device,
861
+ )
862
+ # TODO(shikaili): Extremely high copy overhead in prefill.
863
+ sliced_send_tokens = send_tokens.chunk(self.ep_mp_size, dim=-1)[
864
+ self.ep_mp_rank
865
+ ].contiguous()
866
+
867
+ recv_tokens_list = list(
868
+ sliced_recv_tokens.split(none_throws(recv_sizes))
869
+ )
870
+ send_tokens_list = list(
871
+ sliced_send_tokens.split(none_throws(send_sizes))
872
+ )
873
+
874
+ assert all(r.is_contiguous() for r in recv_tokens_list)
875
+ assert all(s.is_contiguous() for s in send_tokens_list)
876
+ torch.distributed.all_to_all(
877
+ output_tensor_list=recv_tokens_list,
878
+ input_tensor_list=send_tokens_list,
879
+ group=self.ep_group,
880
+ )
881
+
882
+ recv_tokens_permutated = torch.empty(
883
+ (
884
+ self.ep_mp_size,
885
+ sum(none_throws(recv_sizes)),
886
+ D // self.ep_mp_size,
887
+ ),
888
+ dtype=send_tokens.dtype,
889
+ device=send_tokens.device,
890
+ )
891
+
892
+ assert sliced_recv_tokens.is_contiguous()
893
+ assert recv_tokens_permutated.is_contiguous()
894
+ torch.distributed.all_gather_into_tensor(
895
+ output_tensor=recv_tokens_permutated,
896
+ input_tensor=sliced_recv_tokens,
897
+ group=self.ep_mp_group,
898
+ )
899
+
900
+ return (
901
+ recv_tokens_permutated.permute(1, 0, 2).reshape(-1, D).contiguous()
902
+ )
903
+ else:
904
+ # ReduceScatter
905
+ reduced_sliced_send_tokens = torch.empty(
906
+ (D // self.ep_mp_size, sum(none_throws(send_sizes))),
907
+ dtype=send_tokens.dtype,
908
+ device=send_tokens.device,
909
+ )
910
+ torch.distributed.reduce_scatter_tensor(
911
+ output=reduced_sliced_send_tokens,
912
+ input=send_tokens.transpose(0, 1).contiguous(),
913
+ group=self.ep_mp_group,
914
+ )
915
+ reduced_sliced_send_tokens = reduced_sliced_send_tokens.transpose(
916
+ 0, 1
917
+ ).contiguous()
918
+
919
+ # AlltoAll
920
+ reduced_sliced_recv_tokens = torch.empty(
921
+ (sum(none_throws(recv_sizes)), D // self.ep_mp_size),
922
+ dtype=send_tokens.dtype,
923
+ device=send_tokens.device,
924
+ )
925
+ recv_tokens_list = list(
926
+ reduced_sliced_recv_tokens.split(none_throws(recv_sizes))
927
+ )
928
+ send_tokens_list = list(
929
+ reduced_sliced_send_tokens.split(none_throws(send_sizes))
930
+ )
931
+
932
+ assert all(r.is_contiguous() for r in recv_tokens_list)
933
+ assert all(s.is_contiguous() for s in send_tokens_list)
934
+ torch.distributed.all_to_all(
935
+ output_tensor_list=recv_tokens_list,
936
+ input_tensor_list=send_tokens_list,
937
+ group=self.ep_group,
938
+ )
939
+
940
+ # Padding
941
+ slice_d = D // self.ep_mp_size
942
+ pad_l = slice_d * self.ep_mp_rank
943
+ pad_r = D - pad_l - slice_d
944
+ return torch.nn.functional.pad(
945
+ reduced_sliced_recv_tokens, (pad_l, pad_r)
946
+ )
947
+ else:
948
+ recv_tokens = torch.empty(
949
+ (sum(none_throws(recv_sizes)), D),
950
+ dtype=send_tokens.dtype,
951
+ device=send_tokens.device,
952
+ )
953
+
954
+ recv_tokens_list = list(recv_tokens.split(none_throws(recv_sizes)))
955
+ send_tokens_list = list(send_tokens.split(none_throws(send_sizes)))
956
+
957
+ assert all(r.is_contiguous() for r in recv_tokens_list)
958
+ assert all(s.is_contiguous() for s in send_tokens_list)
959
+ torch.distributed.all_to_all(
960
+ output_tensor_list=recv_tokens_list,
961
+ input_tensor_list=send_tokens_list,
962
+ group=self.ep_group,
963
+ )
964
+
965
+ return recv_tokens
966
+
967
+ def _gather_tokens(
968
+ self,
969
+ send_tokens: torch.Tensor,
970
+ ) -> torch.Tensor:
971
+ "No CPU/GPU sync in this function."
972
+ if self.ep_size == 1:
973
+ return send_tokens
974
+
975
+ # TODO: Add FP8 AG to example.
976
+ T, D = send_tokens.shape
977
+ if self.dedup_comm:
978
+ inter_node_recv_tokens = torch.empty(
979
+ (self.ep_size, T, D // self.ep_mp_size),
980
+ dtype=send_tokens.dtype,
981
+ device=send_tokens.device,
982
+ )
983
+ # Copy overhead.
984
+ inter_node_send_tokens = send_tokens.chunk(self.ep_mp_size, dim=-1)[
985
+ self.ep_mp_rank
986
+ ].contiguous()
987
+
988
+ assert inter_node_send_tokens.is_contiguous()
989
+ assert inter_node_recv_tokens.is_contiguous()
990
+ torch.distributed.all_gather_into_tensor(
991
+ output_tensor=inter_node_recv_tokens,
992
+ input_tensor=inter_node_send_tokens,
993
+ group=self.ep_group,
994
+ )
995
+
996
+ intra_node_recv_tokens_transposed = torch.empty(
997
+ (self.ep_mp_size, self.ep_size, T, D // self.ep_mp_size),
998
+ dtype=send_tokens.dtype,
999
+ device=send_tokens.device,
1000
+ )
1001
+
1002
+ assert inter_node_recv_tokens.is_contiguous()
1003
+ assert intra_node_recv_tokens_transposed.is_contiguous()
1004
+ torch.distributed.all_gather_into_tensor(
1005
+ output_tensor=intra_node_recv_tokens_transposed,
1006
+ input_tensor=inter_node_recv_tokens,
1007
+ group=self.ep_mp_group,
1008
+ )
1009
+
1010
+ # Copy overhead.
1011
+ return (
1012
+ intra_node_recv_tokens_transposed.permute(1, 2, 0, 3)
1013
+ .reshape(self.ep_size, T, D)
1014
+ .contiguous()
1015
+ )
1016
+ else:
1017
+ recv_tokens = torch.empty(
1018
+ (self.ep_size, T, D),
1019
+ dtype=send_tokens.dtype,
1020
+ device=send_tokens.device,
1021
+ )
1022
+
1023
+ assert send_tokens.is_contiguous()
1024
+ assert recv_tokens.is_contiguous()
1025
+ torch.distributed.all_gather_into_tensor(
1026
+ output_tensor=recv_tokens,
1027
+ input_tensor=send_tokens,
1028
+ group=self.ep_group,
1029
+ )
1030
+ return recv_tokens
1031
+
1032
+ def _route(
1033
+ self, tokens: torch.Tensor
1034
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
1035
+ B, T, D = tokens.shape
1036
+ tokens = tokens.view(-1, D)
1037
+
1038
+ assert not self.router_DE.is_scaled
1039
+ scores = torch.nn.functional.linear(tokens, self.router_DE.T)
1040
+ scores = torch.sigmoid(scores)
1041
+ assert scores.shape == (B * T, self.E)
1042
+
1043
+ token_counts, expert_indices, token_indices = index_shuffling(
1044
+ scores, # num_tokens
1045
+ )
1046
+ token_counts = token_counts[: self.E]
1047
+
1048
+ if self.dedup_comm:
1049
+ split_sizes = [
1050
+ token_counts.shape[0],
1051
+ expert_indices.shape[0],
1052
+ token_indices.shape[0],
1053
+ ]
1054
+ output = torch.concat([token_counts, expert_indices, token_indices], dim=0)
1055
+ # A broadcast is required because index_shuffling is not deterministic.
1056
+ torch.distributed.broadcast(
1057
+ output,
1058
+ src=(torch.distributed.get_rank() // self.ep_mp_size) * self.ep_mp_size,
1059
+ group=self.ep_mp_group,
1060
+ )
1061
+ token_counts, expert_indices, token_indices = torch.split(
1062
+ output, split_sizes, dim=0
1063
+ )
1064
+
1065
+ if self.is_routed_fp8_rowwise and self.ep_size == 1:
1066
+ routed_tokens, routed_tokens_scales = gather_scale_quant_dense_tokens(
1067
+ tokens,
1068
+ token_indices=token_indices.flatten(),
1069
+ expert_indices=expert_indices.flatten(),
1070
+ scores=scores,
1071
+ scale_ub=self.activation_scale_ub,
1072
+ )
1073
+ else:
1074
+ routed_tokens = gather_scale_dense_tokens(
1075
+ tokens,
1076
+ token_indices=token_indices.flatten(),
1077
+ expert_indices=expert_indices.flatten(),
1078
+ scores=scores,
1079
+ )
1080
+ routed_tokens_scales = None
1081
+ return routed_tokens, routed_tokens_scales, token_counts, token_indices
1082
+
1083
+ def _shared_expert_part1(self, x: torch.Tensor) -> torch.Tensor:
1084
+ # tokens: [B, T, D]
1085
+ D = x.shape[-1]
1086
+ x = x.view(-1, D)
1087
+ w13 = self.shared_experts.w13
1088
+
1089
+ if not self.is_shared_fp8_rowwise:
1090
+ # TODO(shikaili): Skip padded tokens.
1091
+ return x @ w13.T
1092
+ else:
1093
+ x, x_scale = triton_quantize_fp8_row(x, self.activation_scale_ub)
1094
+ # TODO(shikaili): Skip padded tokens.
1095
+ return torch.ops.mslk.f8f8bf16_rowwise(
1096
+ x,
1097
+ w13.weights,
1098
+ x_scale,
1099
+ w13.scales,
1100
+ use_fast_accum=self.use_fast_accum,
1101
+ )
1102
+
1103
+ def _shared_expert_part2(self, y: torch.Tensor) -> torch.Tensor:
1104
+ # tokens: [B, T, D]
1105
+ HD_L_2 = y.shape[-1]
1106
+ HD_L = HD_L_2 // 2
1107
+ w2 = self.shared_experts.w2
1108
+
1109
+ z, z_scale = self._fused_silu_mul(
1110
+ y[:, :HD_L],
1111
+ y[:, HD_L:],
1112
+ self.is_shared_fp8_rowwise,
1113
+ self.activation_scale_ub,
1114
+ )
1115
+ if not self.is_shared_fp8_rowwise:
1116
+ assert z_scale is None
1117
+ # TODO(shikaili): Skip padded tokens.
1118
+ return z @ w2.T
1119
+ else:
1120
+ assert z_scale is not None
1121
+ # TODO(shikaili): Skip padded tokens.
1122
+ return torch.ops.mslk.f8f8bf16_rowwise(
1123
+ z,
1124
+ w2.weights,
1125
+ z_scale,
1126
+ w2.scales,
1127
+ use_fast_accum=self.use_fast_accum,
1128
+ )
1129
+
1130
+ def _routed_expert(
1131
+ self,
1132
+ tokens: torch.Tensor,
1133
+ token_counts: torch.Tensor,
1134
+ token_scales: Optional[torch.Tensor] = None,
1135
+ shared_output: Optional[torch.Tensor] = None,
1136
+ token_indices: Optional[torch.Tensor] = None,
1137
+ ) -> torch.Tensor:
1138
+ # tokens: [B, T, D]
1139
+ D = tokens.shape[-1]
1140
+ x = tokens.view(-1, D)
1141
+
1142
+ if x.shape[0] == 0:
1143
+ return x
1144
+
1145
+ w13 = self.routed_experts.w13
1146
+ w2 = self.routed_experts.w2
1147
+
1148
+ assert D == w13.shape[-1]
1149
+ HD_L = w2.shape[-1]
1150
+
1151
+ assert token_counts.shape == (self.num_local_experts,)
1152
+ if not self.is_routed_fp8_rowwise:
1153
+ y = grouped_gemm(
1154
+ x,
1155
+ w13.view(-1, D),
1156
+ token_counts,
1157
+ use_fast_accum=self.use_fast_accum,
1158
+ _use_warp_specialization=not torch.version.hip,
1159
+ )
1160
+ z, _ = self._fused_silu_mul(y[:, :HD_L], y[:, HD_L:], False)
1161
+ return grouped_gemm(
1162
+ z,
1163
+ w2.view(-1, HD_L),
1164
+ token_counts,
1165
+ use_fast_accum=self.use_fast_accum,
1166
+ _use_warp_specialization=not torch.version.hip,
1167
+ _output_tensor=shared_output,
1168
+ _scatter_add_indices=token_indices,
1169
+ )
1170
+ else:
1171
+ if token_scales is None:
1172
+ x, x_scale = triton_quantize_fp8_row(x, self.activation_scale_ub)
1173
+ else:
1174
+ x_scale = token_scales
1175
+ y = grouped_gemm_fp8_rowwise(
1176
+ x,
1177
+ w13.weights.view(-1, D),
1178
+ token_counts,
1179
+ x_scale.view(-1),
1180
+ w13.scales.view(-1),
1181
+ use_fast_accum=self.use_fast_accum,
1182
+ _use_warp_specialization=not torch.version.hip,
1183
+ )
1184
+ # TODO(shikaili): Skip padded tokens.
1185
+ z, z_scale = self._fused_silu_mul(
1186
+ y[:, :HD_L], y[:, HD_L:], True, self.activation_scale_ub
1187
+ )
1188
+ assert z_scale is not None
1189
+ return grouped_gemm_fp8_rowwise(
1190
+ z,
1191
+ w2.weights.view(-1, HD_L),
1192
+ token_counts,
1193
+ z_scale.view(-1),
1194
+ w2.scales.view(-1),
1195
+ use_fast_accum=self.use_fast_accum,
1196
+ _use_warp_specialization=not torch.version.hip,
1197
+ _output_tensor=shared_output,
1198
+ _scatter_add_indices=token_indices,
1199
+ )
1200
+
1201
+ def _combine_outputs(
1202
+ self,
1203
+ shared_output_tokens: torch.Tensor,
1204
+ routed_output_tokens: torch.Tensor,
1205
+ token_indices: torch.Tensor,
1206
+ token_counts: torch.Tensor,
1207
+ padded: bool = False,
1208
+ ) -> torch.Tensor:
1209
+ D = shared_output_tokens.shape[-1]
1210
+ assert routed_output_tokens.shape[-1] == D
1211
+
1212
+ if padded:
1213
+ scatter_add_padded_tokens(
1214
+ in_tokens=routed_output_tokens,
1215
+ token_counts=token_counts,
1216
+ token_indices=token_indices,
1217
+ out_tokens=shared_output_tokens,
1218
+ )
1219
+ return shared_output_tokens
1220
+
1221
+ scatter_add_dense_tokens(
1222
+ shared_output_tokens,
1223
+ routed_output_tokens.view(-1, D),
1224
+ token_indices,
1225
+ )
1226
+ return shared_output_tokens
1227
+
1228
+ def _fused_silu_mul(
1229
+ self,
1230
+ x0: torch.Tensor,
1231
+ x1: torch.Tensor,
1232
+ is_fp8: bool,
1233
+ scale_ub: Optional[torch.Tensor] = None,
1234
+ ):
1235
+ z_scale = None
1236
+ if is_fp8:
1237
+ z, z_scale = silu_mul_quant(x0, x1, scale_ub)
1238
+ else:
1239
+ z = silu_mul(x0, x1)
1240
+ return z, z_scale
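
For orientation, the following is a minimal single-rank sketch of how the classes in this file fit together. It is illustrative only and not taken from the package's documentation: the process-group setup, the `initialize_model_parallel(1)` call, and the tensor shapes are assumptions, and `BaselineMoE` is used because, with `ep_size=1`, its forward path issues no collectives and does not invoke the `index_shuffling` custom op at run time (the op must still be registered for the module to import).

    # Hypothetical usage sketch (assumptions: one CUDA device, nccl backend,
    # fairscale model parallelism of size 1, default bf16 weights via kaiming init).
    import os

    import torch
    import torch.distributed as dist
    from fairscale.nn.model_parallel.initialize import initialize_model_parallel

    from mslk.moe.layers import BaselineMoE, MoEArgs

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)
    initialize_model_parallel(1)  # get_model_parallel_world_size() == 1

    torch.set_default_dtype(torch.bfloat16)
    args = MoEArgs(
        precision="bf16",
        dim=1024,
        hidden_dim=4096,
        num_experts=8,
        top_k=1,
        mp_size=1,
        ep_size=1,  # single rank: no all_to_all is issued in forward()
        mp_size_for_routed_experts=None,
        use_fast_accum=True,
        dedup_comm=False,
    )
    group = dist.group.WORLD
    moe = BaselineMoE(ep_group=group, ep_mp_group=group, moe_args=args).build().cuda()

    x = torch.randn(2, 16, args.dim, device="cuda")  # [batch, seq_len, dim]
    y = moe(x, use_static_shape=False)
    print(y.shape)  # expected: torch.Size([2, 16, 1024])

With `ep_size > 1`, the same modules expect one process per expert-parallel rank and real `ep_group`/`ep_mp_group` process groups in place of WORLD, and `MetaShufflingMoE` additionally overlaps its all2all communication with the shared-expert compute on a side CUDA stream.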