fbgemm-gpu-genai-nightly 2025.12.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gen_ai/moe/layers.py
@@ -0,0 +1,1272 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ from abc import ABCMeta, abstractmethod
+ from collections.abc import Mapping
+ from dataclasses import dataclass
+ from functools import cached_property
+ from typing import Callable, Optional, Union
+
+ import torch
+
+ from fairscale.nn.model_parallel.initialize import get_model_parallel_world_size
+
+ from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_row
+ from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
+     grouped_gemm,
+     grouped_gemm_fp8_rowwise,
+ )
+ from fbgemm_gpu.experimental.gen_ai.moe.activation import silu_mul, silu_mul_quant
+ from fbgemm_gpu.experimental.gen_ai.moe.gather_scatter import (
+     gather_scale_dense_tokens,
+     gather_scale_quant_dense_tokens,
+     scatter_add_dense_tokens,
+     scatter_add_padded_tokens,
+ )
+ from fbgemm_gpu.experimental.gen_ai.moe.shuffling import (
+     combine_shuffling,
+     split_shuffling,
+ )
+ from pyre_extensions import none_throws
+ from torch.distributed import get_rank, ProcessGroup
+
+ try:
+     # pyre-ignore[21]
+     # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
+     from fbgemm_gpu import open_source
+ except Exception:
+     open_source: bool = False
+
+ if open_source:
+     torch.ops.load_library(
+         os.path.join(
+             os.path.dirname(os.path.dirname(__file__)),
+             "fbgemm_gpu_experimental_gen_ai.so",
+         )
+     )
+     torch.classes.load_library(
+         os.path.join(
+             os.path.dirname(os.path.dirname(__file__)),
+             "fbgemm_gpu_experimental_gen_ai.so",
+         )
+     )
+ else:
+     torch.ops.load_library(
+         "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:index_shuffling_ops"
+     )
+
+ if torch.cuda.is_available():
+     index_shuffling = torch.ops.fbgemm.index_shuffling  # noqa F401
+ else:
+     index_shuffling = None
+
+
+ __all__ = ["MoEArgs", "BaselineMoE", "MetaShufflingMoE"]
+
+
+ @dataclass(frozen=True)
+ class MoEArgs:
+     precision: str
+     dim: int
+     hidden_dim: int
+     num_experts: int
+     top_k: int
+     mp_size: int
+     ep_size: int
+     mp_size_for_routed_experts: Optional[int]
+     use_fast_accum: bool
+     dedup_comm: bool
+
+     @cached_property
+     def num_local_experts(self) -> int:
+         return self.num_experts // self.ep_size
+
+
+ INIT_METHODS_TYPE = Mapping[
+     str,
+     Callable[[torch.Tensor], Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]],
+ ]
+
+
+ class ScaledParameter(torch.nn.Parameter):
+     def __new__(
+         cls,
+         data: torch.Tensor,
+         scale: Optional[torch.Tensor] = None,
+     ) -> "ScaledParameter":
+         return super().__new__(cls, data, False)
+
+     def __init__(
+         self,
+         data: torch.Tensor,
+         scale: Optional[torch.Tensor] = None,
+     ):
+         self._scale: Optional[torch.Tensor] = scale
+
+     @property
+     def weights(self) -> torch.Tensor:
+         return self.data
+
+     @property
+     def scales(self) -> torch.Tensor:
+         assert self._scale is not None
+         return self._scale
+
+     @scales.setter
+     def scales(self, s: torch.Tensor) -> None:
+         self._scale = s
+
+     @property
+     def is_scaled(self) -> bool:
+         return self._scale is not None
+
+
+ # Helper functions/modules to perform weights sharding and initialization.
+ def init_params(
+     key: str,
+     param: ScaledParameter,
+     init_methods: INIT_METHODS_TYPE,
+ ):
+     if key in init_methods:
+         ret = init_methods[key](param.data)
+         if isinstance(ret, torch.Tensor):
+             param.data = ret
+         else:
+             param.data, param.scales = ret
+     else:
+         torch.nn.init.kaiming_uniform_(param)
+
+
+ class Experts(torch.nn.Module, metaclass=ABCMeta):
+     def __init__(
+         self,
+         dim: int,
+         hidden_dim: int,
+     ):
+         super().__init__()
+
+         self.dim: int = dim
+         self.hidden_dim: int = hidden_dim
+
+         self.dtype: torch.dtype = torch.get_default_dtype()
+         self.divide_factor: int = get_model_parallel_world_size()
+
+         assert self.dim % self.divide_factor == 0
+         assert self.hidden_dim % self.divide_factor == 0
+
+         self._w13: Optional[ScaledParameter] = None
+         self._w2: Optional[ScaledParameter] = None
+
+     @abstractmethod
+     def build(self, init_methods: Optional[INIT_METHODS_TYPE] = None) -> "Experts":
+         pass
+
+     @property
+     def w13(self) -> ScaledParameter:
+         assert self._w13 is not None, "Parameters are not initialized!"
+         return self._w13
+
+     @property
+     def w2(self) -> ScaledParameter:
+         assert self._w2 is not None, "Parameters are not initialized!"
+         return self._w2
+
+     @property
+     def is_fp8_rowwise(self) -> bool:
+         return self.w13.dtype == torch.float8_e4m3fn
+
+
+ class RoutedExperts(Experts):
+     def __init__(
+         self,
+         num_local_experts: int,
+         dim: int,
+         hidden_dim: int,
+     ) -> None:
+         super().__init__(dim, hidden_dim)
+
+         self.num_local_experts: int = num_local_experts
+
+     def build(
+         self, init_methods: Optional[INIT_METHODS_TYPE] = None
+     ) -> "RoutedExperts":
+         init_methods = {} if init_methods is None else init_methods
+
+         moe_w_in_eDF: ScaledParameter = ScaledParameter(
+             torch.empty(
+                 self.num_local_experts,
+                 self.dim,
+                 self.hidden_dim // self.divide_factor,
+                 dtype=self.dtype,
+             )
+         )
+         init_params("moe_w_in_eDF", moe_w_in_eDF, init_methods)
+
+         moe_w_out_eFD: ScaledParameter = ScaledParameter(
+             torch.empty(
+                 self.num_local_experts,
+                 self.hidden_dim // self.divide_factor,
+                 self.dim,
+                 dtype=self.dtype,
+             )
+         )
+         init_params("moe_w_out_eFD", moe_w_out_eFD, init_methods)
+
+         moe_w_swiglu_eDF: ScaledParameter = ScaledParameter(
+             torch.empty(
+                 self.num_local_experts,
+                 self.dim,
+                 self.hidden_dim // self.divide_factor,
+                 dtype=self.dtype,
+             )
+         )
+         init_params("moe_w_swiglu_eDF", moe_w_swiglu_eDF, init_methods)
+
+         assert (
+             moe_w_in_eDF.dtype == moe_w_out_eFD.dtype
+             and moe_w_in_eDF.dtype == moe_w_swiglu_eDF.dtype
+         )
+         assert (
+             moe_w_in_eDF.is_scaled == moe_w_out_eFD.is_scaled
+             and moe_w_in_eDF.is_scaled == moe_w_swiglu_eDF.is_scaled
+         )
+
+         self._w13 = ScaledParameter(
+             data=torch.cat(
+                 [
+                     moe_w_in_eDF,
+                     moe_w_swiglu_eDF,
+                 ],
+                 dim=-1,
+             )
+             .transpose(1, 2)
+             .contiguous(),
+             scale=(
+                 torch.cat(
+                     [
+                         moe_w_in_eDF.scales,
+                         moe_w_swiglu_eDF.scales,
+                     ],
+                     dim=-1,
+                 ).contiguous()
+                 if moe_w_in_eDF.is_scaled
+                 else None
+             ),
+         )
+
+         del moe_w_in_eDF
+         del moe_w_swiglu_eDF
+
+         self._w2 = ScaledParameter(
+             data=moe_w_out_eFD.transpose(1, 2).contiguous(),
+             scale=(
+                 moe_w_out_eFD.scales.contiguous() if moe_w_out_eFD.is_scaled else None
+             ),
+         )
+
+         del moe_w_out_eFD
+
+         return self
+
+
+ class SharedExperts(Experts):
+     def __init__(
+         self,
+         dim: int,
+         hidden_dim: int,
+     ):
+         super().__init__(dim, hidden_dim)
+
+     def build(
+         self, init_methods: Optional[INIT_METHODS_TYPE] = None
+     ) -> "SharedExperts":
+         init_methods = {} if init_methods is None else init_methods
+
+         w_in_shared_FD = ScaledParameter(
+             torch.empty(
+                 (self.hidden_dim // self.divide_factor, self.dim), dtype=self.dtype
+             )
+         )
+         init_params("w_in_shared_FD", w_in_shared_FD, init_methods)
+
+         w_out_shared_DF = ScaledParameter(
+             torch.empty(
+                 (self.dim, self.hidden_dim // self.divide_factor), dtype=self.dtype
+             )
+         )
+         init_params("w_out_shared_DF", w_out_shared_DF, init_methods)
+
+         w_swiglu_FD = ScaledParameter(
+             torch.empty(
+                 (self.hidden_dim // self.divide_factor, self.dim), dtype=self.dtype
+             )
+         )
+         init_params("w_swiglu_FD", w_swiglu_FD, init_methods)
+
+         assert (w_in_shared_FD.dtype == w_out_shared_DF.dtype) and (
+             w_in_shared_FD.dtype == w_swiglu_FD.dtype
+         )
+         assert (w_in_shared_FD.is_scaled == w_out_shared_DF.is_scaled) and (
+             w_in_shared_FD.is_scaled == w_swiglu_FD.is_scaled
+         )
+
+         self._w13 = ScaledParameter(
+             data=torch.cat(
+                 [
+                     w_in_shared_FD,
+                     w_swiglu_FD,
+                 ]
+             ).contiguous(),
+             scale=(
+                 torch.cat(
+                     [
+                         w_in_shared_FD.scales,
+                         w_swiglu_FD.scales,
+                     ]
+                 ).contiguous()
+                 if w_in_shared_FD.is_scaled
+                 else None
+             ),
+         )
+         del w_in_shared_FD
+         del w_swiglu_FD
+
+         self._w2 = ScaledParameter(
+             data=w_out_shared_DF.data.contiguous(),
+             scale=(
+                 w_out_shared_DF.scales.contiguous()
+                 if w_out_shared_DF.is_scaled
+                 else None
+             ),
+         )
+         del w_out_shared_DF
+
+         return self
+
+
+ class BaselineMoE(torch.nn.Module):
+     def __init__(
+         self,
+         ep_group: ProcessGroup,
+         ep_mp_group: ProcessGroup,
+         moe_args: MoEArgs,
+     ) -> None:
+         super().__init__()
+
+         self.moe_args = moe_args
+         self.mp_size: int = moe_args.mp_size
+         self.ep_size: int = moe_args.ep_size
+         self.ep_mp_size: int = (
+             moe_args.mp_size
+             if moe_args.mp_size_for_routed_experts is None
+             else moe_args.mp_size_for_routed_experts
+         )
+
+         self.ep_rank: int = get_rank(ep_group)
+         self.ep_mp_rank: int = get_rank(ep_mp_group)
+
+         self.ep_mp_group: ProcessGroup = ep_mp_group
+         self.ep_group: ProcessGroup = ep_group
+
+         self.num_experts: int = moe_args.num_experts
+         self.num_local_experts: int = none_throws(moe_args.num_local_experts)
+         assert self.num_experts == self.num_local_experts * self.ep_size
+
+         self.top_k: int = moe_args.top_k
+
+         self.dtype: torch.dtype = torch.get_default_dtype()
+
+         self._router_DE: Optional[ScaledParameter] = None
+         self.routed_experts = RoutedExperts(
+             moe_args.num_local_experts,
+             moe_args.dim,
+             moe_args.hidden_dim,
+         )
+         self.shared_experts = SharedExperts(
+             moe_args.dim,
+             moe_args.hidden_dim,
+         )
+
+     def build(self, init_methods: Optional[INIT_METHODS_TYPE] = None) -> "BaselineMoE":
+         init_methods = {} if init_methods is None else init_methods
+
+         router_DE = ScaledParameter(
+             torch.empty(self.moe_args.dim, self.moe_args.num_experts, dtype=self.dtype)
+         )
+         init_params("router_DE", router_DE, init_methods)
+         self._router_DE = router_DE
+
+         self.routed_experts.build(init_methods)
+         self.shared_experts.build(init_methods)
+         return self
+
+     @property
+     def router_DE(self) -> ScaledParameter:
+         assert self._router_DE is not None, "Parameters are not initialized!"
+         return self._router_DE
+
+     # User should overwrite this property
+     @property
+     def is_shared_fp8_rowwise(self) -> bool:
+         return self.shared_experts.is_fp8_rowwise
+
+     @property
+     def is_routed_fp8_rowwise(self) -> bool:
+         return self.routed_experts.is_fp8_rowwise
+
+     @property
+     def E(self) -> int:
+         return self.num_experts
+
+     @property
+     def EG(self) -> int:
+         return self.num_local_experts
+
+     @property
+     def K(self) -> int:
+         return self.top_k
+
+     def forward(self, x: torch.Tensor, use_static_shape: bool) -> torch.Tensor:
+         with torch.no_grad():
+             return self._forward(x, use_static_shape)
+
+     def _forward(self, x: torch.Tensor, use_static_shape: bool) -> torch.Tensor:
+         (B, T, D) = x.shape
+         T *= B
+         tokens = x.view(T, D)
+
+         # Shared Experts
+         shared_y = self._fake_quant(torch.mm, tokens, self.shared_experts.w13)
+         shared_y0, shared_y1 = torch.chunk(shared_y, chunks=2, dim=-1)
+         shared_z = shared_y0 * torch.sigmoid(shared_y0) * shared_y1
+         shared_z = self._fake_quant(torch.mm, shared_z, self.shared_experts.w2)
+
+         # Routing Scores
+         E: int = self.E
+         scores = torch.nn.functional.linear(tokens, self.router_DE.T)
+         scores = torch.sigmoid(scores)
+         assert scores.shape == (T, E)
+
+         # Routing
+         K: int = self.K
+         topk_values, topk_indices = torch.topk(scores, K, dim=-1)
+         assert topk_values.shape == (T, K)
+         assert topk_indices.shape == (T, K)
+
+         masked_scores = torch.zeros_like(scores)
+         masked_scores = (
+             masked_scores.scatter_(dim=1, index=topk_indices, src=topk_values)
+             .transpose(0, 1)  # (E, T)
+             .reshape(E, T, 1)
+             .expand(E, T, D)
+         )
+
+         tokens = tokens.view(1, T, D).expand(E, T, D)
+         masked_tokens = tokens * masked_scores
+
+         # Routed Experts
+         EG: int = self.EG
+         if self.ep_size > 1:
+             send_tokens = masked_tokens.contiguous()
+             send_list = list(torch.chunk(send_tokens, chunks=self.ep_size, dim=0))
+             recv_tokens = torch.empty_like(send_tokens)
+             recv_list = list(torch.chunk(recv_tokens, chunks=self.ep_size, dim=0))
+
+             torch.distributed.all_to_all(
+                 output_tensor_list=recv_list,
+                 input_tensor_list=send_list,
+                 group=self.ep_group,
+             )
+
+             masked_tokens = recv_tokens.reshape(EG, -1, D)
+
+         routed_y = self._fake_quant(torch.bmm, masked_tokens, self.routed_experts.w13)
+         routed_y0, routed_y1 = torch.chunk(routed_y, chunks=2, dim=-1)
+         routed_z = routed_y0 * torch.sigmoid(routed_y0) * routed_y1
+         routed_z = self._fake_quant(torch.bmm, routed_z, self.routed_experts.w2)
+
+         if self.ep_size > 1:
+             send_tokens = routed_z.reshape(E * T, D).contiguous()
+             send_list = list(torch.chunk(send_tokens, chunks=self.ep_size, dim=0))
+             recv_tokens = torch.empty_like(send_tokens)
+             recv_list = list(torch.chunk(recv_tokens, chunks=self.ep_size, dim=0))
+
+             torch.distributed.all_to_all(
+                 output_tensor_list=recv_list,
+                 input_tensor_list=send_list,
+                 group=self.ep_group,
+             )
+
+             routed_z = recv_tokens.reshape(E, T, D)
+
+         return (shared_z + routed_z.sum(dim=0)).reshape(B, -1, D)
+
+     def _fake_quant(self, op, x: torch.Tensor, w: ScaledParameter) -> torch.Tensor:
+         if not w.is_scaled:
+             return op(x, w.transpose(-1, -2))
+
+         xq, xs = triton_quantize_fp8_row(x)
+         wq, ws = w.weights, w.scales
+
+         y = (
+             op(xq.to(x.dtype), wq.transpose(-1, -2).to(x.dtype))
+             * xs.unsqueeze(-1)
+             * ws.unsqueeze(-2)
+         )
+         return y.to(x.dtype)
+
+
+ class MetaShufflingMoE(BaselineMoE):
+     def __init__(
+         self,
+         ep_group: ProcessGroup,
+         ep_mp_group: ProcessGroup,
+         moe_args: MoEArgs,
+     ) -> None:
+         super().__init__(ep_group=ep_group, ep_mp_group=ep_mp_group, moe_args=moe_args)
+
+         assert (
+             self.mp_size == self.ep_mp_size
+         ), "MetaShuffling only supports mp_size = mp_size_for_routed_experts now"
+
+         assert (
+             self.top_k == 1
+         ), "MetaShuffling only supports top 1 routing at the moment"
+
+         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
+         self.comp_end_event: torch.cuda.Event = torch.cuda.Event()
+         self.comm_end_event: torch.cuda.Event = torch.cuda.Event()
+
+         self.use_fast_accum: bool = moe_args.use_fast_accum
+         self.dedup_comm: bool = moe_args.dedup_comm
+         if self.dedup_comm:
+             assert (
+                 self.ep_mp_size == self.mp_size
+             ), "TP2EP is not supported for dedup at the moment."
+
+         self.activation_scale_ub = None
+
+     def forward(self, x: torch.Tensor, use_static_shape: bool) -> torch.Tensor:
+         with torch.no_grad():
+             if self.ep_size == 1:
+                 return self._no_comm_forward(x, use_static_shape)
+             if use_static_shape:
+                 return self._static_comm_forward(x)
+             else:
+                 return self._dynamic_comm_forward(x)
+
+     def _dynamic_comm_forward(self, tokens: torch.Tensor) -> torch.Tensor:
+         comp_stream = torch.cuda.current_stream()
+
+         (B, T, D) = tokens.shape
+         T *= B
+
+         # 1. Dispatch router kernels.
+         routed_tokens, routed_tokens_scales, token_counts, token_indices = self._route(
+             tokens
+         )
+         assert routed_tokens_scales is None
+
+         # 2. Dispatch 1st all2all on shapes.
+         self.comp_end_event.record()
+         with torch.cuda.stream(self.comm_stream):
+             self.comp_end_event.wait()
+
+             send_token_counts = token_counts
+             recv_token_counts = self._exchange_shapes(send_token_counts)
+             send_token_counts.record_stream(self.comm_stream)
+
+         recv_token_counts.record_stream(comp_stream)
+
+         # 3. Dispatch shared expert part 1.
+         shared_y = self._shared_expert_part1(tokens)
+
+         with torch.cuda.stream(self.comm_stream):
+             # 4. CPU/GPU sync.
+             concat_counts = torch.concat(
+                 [send_token_counts.flatten(), recv_token_counts.flatten()]
+             ).cpu()
+             send_tokens_list = concat_counts[: self.E].tolist()
+             recv_tokens_list = concat_counts[self.E :].tolist()
+
+             # 5. Dispatch 2nd all2all on tokens.
+             send_tokens = routed_tokens
+             recv_tokens = self._exchange_tokens(
+                 send_tokens,
+                 send_tokens_list,
+                 recv_tokens_list,
+                 is_input=True,
+             )
+             send_tokens.record_stream(self.comm_stream)
+
+             self.comm_end_event.record()
+         recv_tokens.record_stream(comp_stream)
+
+         # 6. Dispatch routed expert kernels.
+         self.comm_end_event.wait()
+         recv_T = recv_tokens.shape[0]
+         assert recv_tokens.shape == (recv_T, D)
+         assert recv_token_counts.shape == (self.ep_size, self.num_local_experts)
+         shuffled_recv_tokens, shuffled_recv_token_counts = combine_shuffling(
+             recv_tokens, recv_token_counts
+         )
+         assert shuffled_recv_tokens.shape == (recv_T, D)
+         assert shuffled_recv_token_counts.shape == (self.num_local_experts + 1,)
+         routed_z = self._routed_expert(
+             shuffled_recv_tokens,
+             shuffled_recv_token_counts[:-1],
+         )
+         assert routed_z.shape == (recv_T, D)
+         shuffled_send_tokens = split_shuffling(routed_z, recv_token_counts)
+         assert shuffled_send_tokens.shape == (recv_T, D)
+
+         # 7. Dispatch 3rd all2all on tokens.
+         self.comp_end_event.record()
+         with torch.cuda.stream(self.comm_stream):
+             self.comp_end_event.wait()
+
+             send_tokens = shuffled_send_tokens
+             recv_tokens = self._exchange_tokens(
+                 send_tokens,
+                 recv_tokens_list,
+                 send_tokens_list,
+                 is_input=False,
+             )
+             send_tokens.record_stream(self.comm_stream)
+
+             self.comm_end_event.record()
+         recv_tokens.record_stream(comp_stream)
+
+         # 8. Dispatch shared expert part 2.
+         shared_z = self._shared_expert_part2(shared_y)
+
+         # 9. Dispatch combine outputs.
+         self.comm_end_event.wait()
+         final_output = self._combine_outputs(
+             shared_z, recv_tokens, token_indices, token_counts, padded=False
+         )
+
+         T //= B
+         return final_output.view(B, T, D)
+
+     def _static_comm_forward(self, tokens: torch.Tensor) -> torch.Tensor:
+         comp_stream = torch.cuda.current_stream()
+
+         (B, T, D) = tokens.shape
+         T *= B
+
+         # 1. Dispatch router kernels.
+         routed_tokens, routed_tokens_scales, token_counts, token_indices = self._route(
+             tokens
+         )
+         assert routed_tokens_scales is None
+
+         # 2. Dispatch allgather on shapes and tokens.
+         self.comp_end_event.record()
+         with torch.cuda.stream(self.comm_stream):
+             self.comp_end_event.wait()
+
+             send_token_counts = token_counts
+             send_tokens = routed_tokens
+             # TODO(shikaili): Check if using 1 allgather is faster even with copies.
+             recv_token_counts = self._gather_shapes(send_token_counts)
+             recv_tokens = self._gather_tokens(send_tokens)
+             send_token_counts.record_stream(self.comm_stream)
+             send_tokens.record_stream(self.comm_stream)
+
+             self.comm_end_event.record()
+         recv_token_counts.record_stream(comp_stream)
+         recv_tokens.record_stream(comp_stream)
+
+         # 3. Dispatch shared expert part 1.
+         shared_y = self._shared_expert_part1(tokens)
+
+         # 4. Dispatch routed expert kernels.
+         self.comm_end_event.wait()
+         assert recv_tokens.shape == (
+             self.ep_size,
+             T,
+             D,
+         ), f"{recv_tokens.shape=}, {(self.ep_size, T, D)=}"
+         assert recv_token_counts.shape == (self.ep_size, self.E)
+         shuffled_recv_tokens, shuffled_recv_token_counts = combine_shuffling(
+             recv_tokens.view(-1, D),
+             recv_token_counts,
+             expert_start=self.ep_rank * self.num_local_experts,
+             expert_end=(self.ep_rank + 1) * self.num_local_experts,
+         )
+         assert shuffled_recv_tokens.shape == (self.ep_size * T, D)
+         assert shuffled_recv_token_counts.shape == (
+             self.num_local_experts + 1,
+         ), f"{shuffled_recv_token_counts.shape=}"
+         routed_z = self._routed_expert(
+             shuffled_recv_tokens,
+             shuffled_recv_token_counts[:-1],
+         )
+         assert routed_z.shape == (self.ep_size * T, D)
+         shuffled_send_tokens = split_shuffling(
+             routed_z,
+             recv_token_counts,
+             expert_start=self.ep_rank * self.num_local_experts,
+             expert_end=(self.ep_rank + 1) * self.num_local_experts,
+         )
+         assert shuffled_send_tokens.shape == (self.ep_size * T, D)
+
+         # 5. Dispatch all2all on tokens.
+         self.comp_end_event.record()
+         with torch.cuda.stream(self.comm_stream):
+             self.comp_end_event.wait()
+
+             send_tokens = shuffled_send_tokens
+             recv_tokens = self._exchange_tokens(send_tokens, None, None, is_input=False)
+             send_tokens.record_stream(self.comm_stream)
+
+             self.comm_end_event.record()
+         recv_tokens.record_stream(comp_stream)
+
+         # 6. Dispatch shared expert part 2.
+         shared_z = self._shared_expert_part2(shared_y)
+
+         # 7. Dispatch combine outputs.
+         self.comm_end_event.wait()
+         final_output = self._combine_outputs(
+             shared_z,
+             recv_tokens.view(self.ep_size, T, D),
+             token_indices,
+             token_counts,
+             padded=True,
+         )
+
+         T //= B
+         return final_output.view(B, T, D)
+
+     def _no_comm_forward(
+         self, tokens: torch.Tensor, overlap_router_and_shared_expert: bool
+     ) -> torch.Tensor:
+         # Default stream for compute
+         comp_stream = torch.cuda.current_stream()
+         if overlap_router_and_shared_expert:
+             self.comp_end_event.record()
+         (B, T, D) = tokens.shape
+
+         # 1. Dispatch router kernels and shared experts GEMMs.
+         routed_tokens, routed_tokens_scales, token_counts, token_indices = self._route(
+             tokens
+         )
+
+         if overlap_router_and_shared_expert:
+             with torch.cuda.stream(self.comm_stream):
+                 self.comp_end_event.wait()
+
+                 shared_y = self._shared_expert_part1(tokens)
+                 shared_z = self._shared_expert_part2(shared_y)
+                 tokens.record_stream(self.comm_stream)
+
+                 self.comm_end_event.record()
+             shared_z.record_stream(comp_stream)
+             self.comm_end_event.wait()
+         else:
+             shared_y = self._shared_expert_part1(tokens)
+             shared_z = self._shared_expert_part2(shared_y)
+
+         # 2. Dispatch routed expert GEMMs.
+         if not torch.version.hip:
+             final_output = self._routed_expert(
+                 routed_tokens,
+                 token_counts,
+                 token_scales=routed_tokens_scales,
+                 shared_output=shared_z,
+                 token_indices=token_indices,
+             )
+         else:
+             routed_z = self._routed_expert(
+                 routed_tokens,
+                 token_counts,
+                 token_scales=routed_tokens_scales,
+             )
+             # 3. Dispatch combine outputs.
+             final_output = self._combine_outputs(
+                 shared_z, routed_z, token_indices, token_counts, padded=False
+             )
+
+         return final_output.view(B, T, D)
+
+     def _exchange_shapes(self, send_sizes: torch.Tensor) -> torch.Tensor:
+         "No CPU/GPU sync in this function."
+         if self.ep_size == 1:
+             return send_sizes
+
+         assert tuple(send_sizes.shape) == (self.E,)
+         recv_sizes = torch.empty_like(send_sizes)
+
+         recv_sizes_list = list(recv_sizes.chunk(self.ep_size))
+         send_sizes_list = list(send_sizes.chunk(self.ep_size))
+
+         assert all(r.is_contiguous() for r in recv_sizes_list)
+         assert all(s.is_contiguous() for s in send_sizes_list)
+         torch.distributed.all_to_all(
+             output_tensor_list=recv_sizes_list,
+             input_tensor_list=send_sizes_list,
+             group=self.ep_group,
+         )
+
+         # send_sizes: [E] viewed as [EP, EG]
+         # recv_sizes: [E] viewed as [EP, EG]
+         return recv_sizes.view(self.ep_size, self.num_local_experts)
+
+     def _gather_shapes(self, send_sizes: torch.Tensor) -> torch.Tensor:
+         "No CPU/GPU sync in this function."
+         if self.ep_size == 1:
+             return send_sizes
+
+         assert tuple(send_sizes.shape) == (self.E,)
+         recv_sizes = torch.empty(
+             (self.ep_size, self.E), dtype=send_sizes.dtype, device=send_sizes.device
+         )
+
+         assert send_sizes.is_contiguous()
+         assert recv_sizes.is_contiguous()
+         torch.distributed.all_gather_into_tensor(
+             output_tensor=recv_sizes,
+             input_tensor=send_sizes,
+             group=self.ep_group,
+         )
+
+         # send_sizes: [E]
+         # recv_sizes: [EP, E]
+         return recv_sizes
+
+     def _exchange_tokens(
+         self,
+         send_tokens: torch.Tensor,
+         send_sizes: Optional[list[int]],
+         recv_sizes: Optional[list[int]],
+         is_input: bool,
+     ) -> torch.Tensor:
+         """
+         When `send_sizes`/`recv_size` are `None`, we assume the tokens are evenly distributed
+         across different EP ranks, so the total number of tokens `T` are split by `E`.
+         No CPU/GPU sync in this function.
+         """
+         if self.ep_size == 1:
+             return send_tokens
+
+         D = send_tokens.shape[-1]
+         send_tokens = send_tokens.view(-1, D)
+         T = send_tokens.shape[0]
+
+         if send_sizes is None:
+             send_sizes = [T // self.ep_size for _ in range(self.ep_size)]
+         else:
+             send_sizes = [
+                 sum(
+                     send_sizes[
+                         r * self.num_local_experts : (r + 1) * self.num_local_experts
+                     ]
+                 )
+                 for r in range(self.ep_size)
+             ]
+
+         if recv_sizes is None:
+             recv_sizes = [T // self.ep_size for _ in range(self.ep_size)]
+         else:
+             recv_sizes = [
+                 sum(
+                     recv_sizes[
+                         r * self.num_local_experts : (r + 1) * self.num_local_experts
+                     ]
+                 )
+                 for r in range(self.ep_size)
+             ]
+
+         # TODO: Add FP8 A2A to example.
+         if self.dedup_comm:
+             if is_input:
+                 sliced_recv_tokens = torch.empty(
+                     (sum(none_throws(recv_sizes)), D // self.ep_mp_size),
+                     dtype=send_tokens.dtype,
+                     device=send_tokens.device,
+                 )
+                 # TODO(shikaili): Extremely high copy overhead in prefill.
+                 sliced_send_tokens = send_tokens.chunk(self.ep_mp_size, dim=-1)[
+                     self.ep_mp_rank
+                 ].contiguous()
+
+                 recv_tokens_list = list(
+                     sliced_recv_tokens.split(none_throws(recv_sizes))
+                 )
+                 send_tokens_list = list(
+                     sliced_send_tokens.split(none_throws(send_sizes))
+                 )
+
+                 assert all(r.is_contiguous() for r in recv_tokens_list)
+                 assert all(s.is_contiguous() for s in send_tokens_list)
+                 torch.distributed.all_to_all(
+                     output_tensor_list=recv_tokens_list,
+                     input_tensor_list=send_tokens_list,
+                     group=self.ep_group,
+                 )
+
+                 recv_tokens_permutated = torch.empty(
+                     (
+                         self.ep_mp_size,
+                         sum(none_throws(recv_sizes)),
+                         D // self.ep_mp_size,
+                     ),
+                     dtype=send_tokens.dtype,
+                     device=send_tokens.device,
+                 )
+
+                 assert sliced_recv_tokens.is_contiguous()
+                 assert recv_tokens_permutated.is_contiguous()
+                 torch.distributed.all_gather_into_tensor(
+                     output_tensor=recv_tokens_permutated,
+                     input_tensor=sliced_recv_tokens,
+                     group=self.ep_mp_group,
+                 )
+
+                 return (
+                     recv_tokens_permutated.permute(1, 0, 2).reshape(-1, D).contiguous()
+                 )
+             else:
+                 # ReduceScatter
+                 reduced_sliced_send_tokens = torch.empty(
+                     (D // self.ep_mp_size, sum(none_throws(send_sizes))),
+                     dtype=send_tokens.dtype,
+                     device=send_tokens.device,
+                 )
+                 torch.distributed.reduce_scatter_tensor(
+                     output=reduced_sliced_send_tokens,
+                     input=send_tokens.transpose(0, 1).contiguous(),
+                     group=self.ep_mp_group,
+                 )
+                 reduced_sliced_send_tokens = reduced_sliced_send_tokens.transpose(
+                     0, 1
+                 ).contiguous()
+
+                 # AlltoAll
+                 reduced_sliced_recv_tokens = torch.empty(
+                     (sum(none_throws(recv_sizes)), D // self.ep_mp_size),
+                     dtype=send_tokens.dtype,
+                     device=send_tokens.device,
+                 )
+                 recv_tokens_list = list(
+                     reduced_sliced_recv_tokens.split(none_throws(recv_sizes))
+                 )
+                 send_tokens_list = list(
+                     reduced_sliced_send_tokens.split(none_throws(send_sizes))
+                 )
+
+                 assert all(r.is_contiguous() for r in recv_tokens_list)
+                 assert all(s.is_contiguous() for s in send_tokens_list)
+                 torch.distributed.all_to_all(
+                     output_tensor_list=recv_tokens_list,
+                     input_tensor_list=send_tokens_list,
+                     group=self.ep_group,
+                 )
+
+                 # Padding
+                 slice_d = D // self.ep_mp_size
+                 pad_l = slice_d * self.ep_mp_rank
+                 pad_r = D - pad_l - slice_d
+                 return torch.nn.functional.pad(
+                     reduced_sliced_recv_tokens, (pad_l, pad_r)
+                 )
+         else:
+             recv_tokens = torch.empty(
+                 (sum(none_throws(recv_sizes)), D),
+                 dtype=send_tokens.dtype,
+                 device=send_tokens.device,
+             )
+
+             recv_tokens_list = list(recv_tokens.split(none_throws(recv_sizes)))
+             send_tokens_list = list(send_tokens.split(none_throws(send_sizes)))
+
+             assert all(r.is_contiguous() for r in recv_tokens_list)
+             assert all(s.is_contiguous() for s in send_tokens_list)
+             torch.distributed.all_to_all(
+                 output_tensor_list=recv_tokens_list,
+                 input_tensor_list=send_tokens_list,
+                 group=self.ep_group,
+             )
+
+             return recv_tokens
+
+     def _gather_tokens(
+         self,
+         send_tokens: torch.Tensor,
+     ) -> torch.Tensor:
+         "No CPU/GPU sync in this function."
+         if self.ep_size == 1:
+             return send_tokens
+
+         # TODO: Add FP8 AG to example.
+         T, D = send_tokens.shape
+         if self.dedup_comm:
+             inter_node_recv_tokens = torch.empty(
+                 (self.ep_size, T, D // self.ep_mp_size),
+                 dtype=send_tokens.dtype,
+                 device=send_tokens.device,
+             )
+             # Copy overhead.
+             inter_node_send_tokens = send_tokens.chunk(self.ep_mp_size, dim=-1)[
+                 self.ep_mp_rank
+             ].contiguous()
+
+             assert inter_node_send_tokens.is_contiguous()
+             assert inter_node_recv_tokens.is_contiguous()
+             torch.distributed.all_gather_into_tensor(
+                 output_tensor=inter_node_recv_tokens,
+                 input_tensor=inter_node_send_tokens,
+                 group=self.ep_group,
+             )
+
+             intra_node_recv_tokens_transposed = torch.empty(
+                 (self.ep_mp_size, self.ep_size, T, D // self.ep_mp_size),
+                 dtype=send_tokens.dtype,
+                 device=send_tokens.device,
+             )
+
+             assert inter_node_recv_tokens.is_contiguous()
+             assert intra_node_recv_tokens_transposed.is_contiguous()
+             torch.distributed.all_gather_into_tensor(
+                 output_tensor=intra_node_recv_tokens_transposed,
+                 input_tensor=inter_node_recv_tokens,
+                 group=self.ep_mp_group,
+             )
+
+             # Copy overhead.
+             return (
+                 intra_node_recv_tokens_transposed.permute(1, 2, 0, 3)
+                 .reshape(self.ep_size, T, D)
+                 .contiguous()
+             )
+         else:
+             recv_tokens = torch.empty(
+                 (self.ep_size, T, D),
+                 dtype=send_tokens.dtype,
+                 device=send_tokens.device,
+             )
+
+             assert send_tokens.is_contiguous()
+             assert recv_tokens.is_contiguous()
+             torch.distributed.all_gather_into_tensor(
+                 output_tensor=recv_tokens,
+                 input_tensor=send_tokens,
+                 group=self.ep_group,
+             )
+             return recv_tokens
+
+     def _route(
+         self, tokens: torch.Tensor
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
+         B, T, D = tokens.shape
+         tokens = tokens.view(-1, D)
+
+         assert not self.router_DE.is_scaled
+         scores = torch.nn.functional.linear(tokens, self.router_DE.T)
+         scores = torch.sigmoid(scores)
+         assert scores.shape == (B * T, self.E)
+
+         token_counts, expert_indices, token_indices = index_shuffling(
+             scores,  # num_tokens
+         )
+         token_counts = token_counts[: self.E]
+
+         if self.dedup_comm:
+             split_sizes = [
+                 token_counts.shape[0],
+                 expert_indices.shape[0],
+                 token_indices.shape[0],
+             ]
+             output = torch.concat([token_counts, expert_indices, token_indices], dim=0)
+             # Require broadcast as index_shuffling is not deterministic.
+             torch.distributed.broadcast(
+                 output,
+                 src=(torch.distributed.get_rank() // self.ep_mp_size) * self.ep_mp_size,
+                 group=self.ep_mp_group,
+             )
+             token_counts, expert_indices, token_indices = torch.split(
+                 output, split_sizes, dim=0
+             )
+
+         if self.is_routed_fp8_rowwise and self.ep_size == 1:
+             routed_tokens, routed_tokens_scales = gather_scale_quant_dense_tokens(
+                 tokens,
+                 token_indices=token_indices.flatten(),
+                 expert_indices=expert_indices.flatten(),
+                 scores=scores,
+                 scale_ub=self.activation_scale_ub,
+             )
+         else:
+             routed_tokens = gather_scale_dense_tokens(
+                 tokens,
+                 token_indices=token_indices.flatten(),
+                 expert_indices=expert_indices.flatten(),
+                 scores=scores,
+             )
+             routed_tokens_scales = None
+         return routed_tokens, routed_tokens_scales, token_counts, token_indices
+
+     def _shared_expert_part1(self, x: torch.Tensor) -> torch.Tensor:
+         # tokens: [B, T, D]
+         D = x.shape[-1]
+         x = x.view(-1, D)
+         w13 = self.shared_experts.w13
+
+         if not self.is_shared_fp8_rowwise:
+             # TODO(shikaili): Skip padded tokens.
+             return x @ w13.T
+         else:
+             x, x_scale = triton_quantize_fp8_row(x, self.activation_scale_ub)
+             # TODO(shikaili): Skip padded tokens.
+             return torch.ops.fbgemm.f8f8bf16_rowwise(
+                 x,
+                 w13.weights,
+                 x_scale,
+                 w13.scales,
+                 use_fast_accum=self.use_fast_accum,
+             )
+
+     def _shared_expert_part2(self, y: torch.Tensor) -> torch.Tensor:
+         # tokens: [B, T, D]
+         HD_L_2 = y.shape[-1]
+         HD_L = HD_L_2 // 2
+         w2 = self.shared_experts.w2
+
+         z, z_scale = self._fused_silu_mul(
+             y[:, :HD_L],
+             y[:, HD_L:],
+             self.is_shared_fp8_rowwise,
+             self.activation_scale_ub,
+         )
+         if not self.is_shared_fp8_rowwise:
+             assert z_scale is None
+             # TODO(shikaili): Skip padded tokens.
+             return z @ w2.T
+         else:
+             assert z_scale is not None
+             # TODO(shikaili): Skip padded tokens.
+             return torch.ops.fbgemm.f8f8bf16_rowwise(
+                 z,
+                 w2.weights,
+                 z_scale,
+                 w2.scales,
+                 use_fast_accum=self.use_fast_accum,
+             )
+
+     def _routed_expert(
+         self,
+         tokens: torch.Tensor,
+         token_counts: torch.Tensor,
+         token_scales: Optional[torch.Tensor] = None,
+         shared_output: Optional[torch.Tensor] = None,
+         token_indices: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         # tokens: [B, T, D]
+         D = tokens.shape[-1]
+         x = tokens.view(-1, D)
+
+         if x.shape[0] == 0:
+             return x
+
+         w13 = self.routed_experts.w13
+         w2 = self.routed_experts.w2
+
+         assert D == w13.shape[-1]
+         HD_L = w2.shape[-1]
+
+         assert token_counts.shape == (self.num_local_experts,)
+         if not self.is_routed_fp8_rowwise:
+             y = grouped_gemm(
+                 x,
+                 w13.view(-1, D),
+                 token_counts,
+                 use_fast_accum=self.use_fast_accum,
+                 _use_warp_specialization=not torch.version.hip,
+             )
+             z, _ = self._fused_silu_mul(y[:, :HD_L], y[:, HD_L:], False)
+             return grouped_gemm(
+                 z,
+                 w2.view(-1, HD_L),
+                 token_counts,
+                 use_fast_accum=self.use_fast_accum,
+                 _use_warp_specialization=not torch.version.hip,
+                 _output_tensor=shared_output,
+                 _scatter_add_indices=token_indices,
+             )
+         else:
+             if token_scales is None:
+                 x, x_scale = triton_quantize_fp8_row(x, self.activation_scale_ub)
+             else:
+                 x_scale = token_scales
+             y = grouped_gemm_fp8_rowwise(
+                 x,
+                 w13.weights.view(-1, D),
+                 token_counts,
+                 x_scale.view(-1),
+                 w13.scales.view(-1),
+                 use_fast_accum=self.use_fast_accum,
+                 _use_warp_specialization=not torch.version.hip,
+             )
+             # TODO(shikaili): Skip padded tokens.
+             z, z_scale = self._fused_silu_mul(
+                 y[:, :HD_L], y[:, HD_L:], True, self.activation_scale_ub
+             )
+             assert z_scale is not None
+             return grouped_gemm_fp8_rowwise(
+                 z,
+                 w2.weights.view(-1, HD_L),
+                 token_counts,
+                 z_scale.view(-1),
+                 w2.scales.view(-1),
+                 use_fast_accum=self.use_fast_accum,
+                 _use_warp_specialization=not torch.version.hip,
+                 _output_tensor=shared_output,
+                 _scatter_add_indices=token_indices,
+             )
+
+     def _combine_outputs(
+         self,
+         shared_output_tokens: torch.Tensor,
+         routed_output_tokens: torch.Tensor,
+         token_indices: torch.Tensor,
+         token_counts: torch.Tensor,
+         padded: bool = False,
+     ) -> torch.Tensor:
+         D = shared_output_tokens.shape[-1]
+         assert routed_output_tokens.shape[-1] == D
+
+         if padded:
+             scatter_add_padded_tokens(
+                 in_tokens=routed_output_tokens,
+                 token_counts=token_counts,
+                 token_indices=token_indices,
+                 out_tokens=shared_output_tokens,
+             )
+             return shared_output_tokens
+
+         scatter_add_dense_tokens(
+             shared_output_tokens,
+             routed_output_tokens.view(-1, D),
+             token_indices,
+         )
+         return shared_output_tokens
+
+     def _fused_silu_mul(
+         self,
+         x0: torch.Tensor,
+         x1: torch.Tensor,
+         is_fp8: bool,
+         scale_ub: Optional[torch.Tensor] = None,
+     ):
+         z_scale = None
+         if is_fp8:
+             z, z_scale = silu_mul_quant(x0, x1, scale_ub)
+         else:
+             z = silu_mul(x0, x1)
+         return z, z_scale
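
For orientation, the snippet below is a minimal usage sketch of the classes this file exports (MoEArgs, MetaShufflingMoE). The process-group setup, dimensions, dtype, and initialization shown here are illustrative assumptions, not values shipped in the wheel; a real deployment would build its expert-parallel (ep_group) and model-parallel (ep_mp_group) groups to match its own topology.

# Illustrative sketch only: assumes a single CUDA rank with EP = MP = 1 and that
# the fbgemm_gpu gen_ai extension and its index_shuffling/grouped_gemm kernels
# load successfully on this machine.
import torch
import torch.distributed as dist
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from fbgemm_gpu.experimental.gen_ai.moe.layers import MetaShufflingMoE, MoEArgs

# Single-process distributed setup (assumed); production code would use its own launcher.
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
initialize_model_parallel(1)
ep_group = dist.new_group([0])     # expert-parallel group (assumed single rank)
ep_mp_group = dist.new_group([0])  # model-parallel group for routed experts (assumed)

torch.set_default_dtype(torch.bfloat16)  # Experts picks up torch.get_default_dtype()
args = MoEArgs(
    precision="bf16",                # assumed value
    dim=1024,                        # token dimension D (assumed)
    hidden_dim=4096,                 # expert hidden dimension (assumed)
    num_experts=8,
    top_k=1,                         # MetaShufflingMoE asserts top_k == 1
    mp_size=1,
    ep_size=1,
    mp_size_for_routed_experts=None,
    use_fast_accum=True,
    dedup_comm=False,
)

moe = MetaShufflingMoE(ep_group=ep_group, ep_mp_group=ep_mp_group, moe_args=args)
moe = moe.build().cuda()  # build() materializes router and expert weights, then move to GPU

x = torch.randn(2, 16, args.dim, device="cuda")  # [B, T, D] activations
out = moe(x, use_static_shape=True)              # with ep_size == 1 this takes the no-comm path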