mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/moe/layers.py
ADDED
|
@@ -0,0 +1,1240 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
|
|
9
|
+
from abc import ABCMeta, abstractmethod
|
|
10
|
+
from collections.abc import Mapping
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from functools import cached_property
|
|
13
|
+
from typing import Callable, Optional, Union
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from fairscale.nn.model_parallel.initialize import get_model_parallel_world_size
|
|
17
|
+
from mslk.gemm.triton.grouped_gemm import grouped_gemm, grouped_gemm_fp8_rowwise
|
|
18
|
+
from mslk.moe.activation import silu_mul, silu_mul_quant
|
|
19
|
+
from mslk.moe.gather_scatter import (
|
|
20
|
+
gather_scale_dense_tokens,
|
|
21
|
+
gather_scale_quant_dense_tokens,
|
|
22
|
+
scatter_add_dense_tokens,
|
|
23
|
+
scatter_add_padded_tokens,
|
|
24
|
+
)
|
|
25
|
+
from mslk.moe.shuffling import combine_shuffling, split_shuffling
|
|
26
|
+
from mslk.quantize.triton.fp8_quantize import triton_quantize_fp8_row
|
|
27
|
+
from pyre_extensions import none_throws
|
|
28
|
+
from torch.distributed import get_rank, ProcessGroup
|
|
29
|
+
|
|
30
|
+
# Resolve the custom ``index_shuffling`` op only when CUDA is available;
# otherwise the module-level name is bound to ``None``.
index_shuffling = (
    torch.ops.mslk.index_shuffling  # noqa F401
    if torch.cuda.is_available()
    else None
)


# Public API of this module.
__all__ = ["MoEArgs", "BaselineMoE", "MetaShufflingMoE"]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass(frozen=True)
class MoEArgs:
    """Immutable configuration shared by the MoE layers in this module."""

    precision: str  # NOTE(review): not read in the visible part of this file.
    dim: int  # token/model dimension (D)
    hidden_dim: int  # expert hidden dimension before model-parallel sharding
    num_experts: int  # total number of routed experts across all EP ranks (E)
    top_k: int  # experts selected per token (K)
    mp_size: int  # model-parallel size
    ep_size: int  # expert-parallel size
    mp_size_for_routed_experts: Optional[int]  # None -> BaselineMoE uses mp_size
    use_fast_accum: bool
    dedup_comm: bool

    @cached_property
    def num_local_experts(self) -> int:
        """Experts hosted per expert-parallel rank, i.e. ``num_experts // ep_size``."""
        per_rank, _remainder = divmod(self.num_experts, self.ep_size)
        return per_rank
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Maps a parameter key (e.g. "router_DE", "moe_w_in_eDF") to a callable that
# receives the freshly allocated tensor and returns either the new tensor or a
# ``(tensor, scale)`` pair (see ``init_params``).
INIT_METHODS_TYPE = Mapping[
    str, Callable[[torch.Tensor], Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]]
]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ScaledParameter(torch.nn.Parameter):
    """A ``torch.nn.Parameter`` (always ``requires_grad=False``) with an optional scale.

    The wrapped tensor is exposed as :attr:`weights`. When a ``scale`` tensor is
    attached (e.g. quantization scales returned by an init method, see
    ``init_params``), it is available through :attr:`scales` and
    :attr:`is_scaled` is ``True``.

    NOTE(review): ``_scale`` is a plain Python attribute rather than a registered
    parameter/buffer, so ``Module.to()`` / ``.cuda()`` presumably do not move it
    together with the data -- confirm callers handle scale placement.
    """

    def __new__(
        cls,
        data: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
    ) -> "ScaledParameter":
        # ``scale`` is consumed by ``__init__``; the tensor itself is created
        # with gradients disabled.
        return super().__new__(cls, data, False)

    def __init__(
        self,
        data: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
    ):
        self._scale: Optional[torch.Tensor] = scale

    def __deepcopy__(self, memo: dict) -> "ScaledParameter":
        """Deep-copy both the data and the attached scale.

        ``torch.nn.Parameter.__deepcopy__`` rebuilds the copy as
        ``type(self)(data, self.requires_grad)``, which for this class binds
        ``requires_grad`` to ``scale``: the copy would lose its real scale and
        report ``is_scaled`` for a bogus ``False`` "scale".
        """
        if id(self) in memo:
            return memo[id(self)]
        scale = self._scale
        result = type(self)(
            self.data.clone(memory_format=torch.preserve_format),
            None
            if scale is None
            else scale.clone(memory_format=torch.preserve_format),
        )
        # Preserve the gradient flag in case it was changed after construction.
        result.requires_grad_(self.requires_grad)
        memo[id(self)] = result
        return result

    @property
    def weights(self) -> torch.Tensor:
        """The underlying tensor (``self.data``)."""
        return self.data

    @property
    def scales(self) -> torch.Tensor:
        """The attached scale tensor; asserts that one is present."""
        assert self._scale is not None
        return self._scale

    @scales.setter
    def scales(self, s: torch.Tensor) -> None:
        self._scale = s

    @property
    def is_scaled(self) -> bool:
        """Whether a scale tensor is attached."""
        return self._scale is not None
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Helper functions/modules to perform weights sharding and initialization.
def init_params(
    key: str,
    param: ScaledParameter,
    init_methods: INIT_METHODS_TYPE,
):
    """Initialise ``param`` in place.

    If ``init_methods`` has an entry for ``key``, it is called with
    ``param.data``. A tensor result replaces ``param.data``; any other result is
    unpacked as ``(data, scales)`` and both are stored on ``param``. Without an
    entry, ``param`` is filled with ``torch.nn.init.kaiming_uniform_``.
    """
    if key not in init_methods:
        torch.nn.init.kaiming_uniform_(param)
        return

    produced = init_methods[key](param.data)
    if not isinstance(produced, torch.Tensor):
        # Tuple form: (data, scale).
        param.data, param.scales = produced
        return
    param.data = produced
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class Experts(torch.nn.Module, metaclass=ABCMeta):
    """Abstract holder of a fused pair of expert weight matrices.

    Subclasses allocate and fuse their weights in :meth:`build`, storing them in
    ``self._w13`` and ``self._w2``; the :attr:`w13` / :attr:`w2` properties
    assert that this has happened. ``divide_factor`` is the fairscale
    model-parallel world size, and both ``dim`` and ``hidden_dim`` must be
    divisible by it.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()

        self.dim: int = dim
        self.hidden_dim: int = hidden_dim

        # Weights are allocated in the default dtype captured at construction.
        self.dtype: torch.dtype = torch.get_default_dtype()
        self.divide_factor: int = get_model_parallel_world_size()

        for size in (self.dim, self.hidden_dim):
            assert size % self.divide_factor == 0

        # Populated by ``build``.
        self._w13: Optional[ScaledParameter] = None
        self._w2: Optional[ScaledParameter] = None

    @abstractmethod
    def build(self, init_methods: Optional[INIT_METHODS_TYPE] = None) -> "Experts":
        """Allocate and initialise the weights (subclasses here return ``self``)."""

    @property
    def w13(self) -> ScaledParameter:
        """Fused first-layer weights produced by :meth:`build`."""
        assert self._w13 is not None, "Parameters are not initialized!"
        return self._w13

    @property
    def w2(self) -> ScaledParameter:
        """Second-layer weights produced by :meth:`build`."""
        assert self._w2 is not None, "Parameters are not initialized!"
        return self._w2

    @property
    def is_fp8_rowwise(self) -> bool:
        """True when the fused first-layer weights are stored as ``float8_e4m3fn``."""
        return self.w13.dtype == torch.float8_e4m3fn
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class RoutedExperts(Experts):
    """Weights of the routed experts hosted on this expert-parallel rank.

    :meth:`build` allocates three per-expert projections, initialises each with
    :func:`init_params`, and fuses them:

    * ``w13`` = ``cat([moe_w_in_eDF, moe_w_swiglu_eDF], dim=-1)`` transposed to
      ``(num_local_experts, 2 * hidden_dim // divide_factor, dim)``;
    * ``w2`` = ``moe_w_out_eFD`` transposed to
      ``(num_local_experts, dim, hidden_dim // divide_factor)``.

    (Shapes assume the init methods keep the allocated shapes.) When scaled,
    the in/swiglu scales are concatenated along their last dimension.
    """

    def __init__(
        self,
        num_local_experts: int,
        dim: int,
        hidden_dim: int,
    ) -> None:
        super().__init__(dim, hidden_dim)
        self.num_local_experts: int = num_local_experts

    def build(
        self, init_methods: Optional[INIT_METHODS_TYPE] = None
    ) -> "RoutedExperts":
        methods = init_methods if init_methods is not None else {}
        n_local = self.num_local_experts
        f_local = self.hidden_dim // self.divide_factor

        def _alloc(key: str, *shape: int) -> ScaledParameter:
            # Allocate an uninitialised tensor and fill it via ``init_params``.
            weight = ScaledParameter(torch.empty(*shape, dtype=self.dtype))
            init_params(key, weight, methods)
            return weight

        # Allocation order (in, out, swiglu) is kept so the default random
        # initialisation draws numbers in the same sequence.
        w_in = _alloc("moe_w_in_eDF", n_local, self.dim, f_local)
        w_out = _alloc("moe_w_out_eFD", n_local, f_local, self.dim)
        w_swiglu = _alloc("moe_w_swiglu_eDF", n_local, self.dim, f_local)

        assert w_in.dtype == w_out.dtype and w_in.dtype == w_swiglu.dtype
        assert (
            w_in.is_scaled == w_out.is_scaled
            and w_in.is_scaled == w_swiglu.is_scaled
        )

        w13_data = torch.cat([w_in, w_swiglu], dim=-1).transpose(1, 2).contiguous()
        w13_scale = None
        if w_in.is_scaled:
            w13_scale = torch.cat([w_in.scales, w_swiglu.scales], dim=-1).contiguous()
        self._w13 = ScaledParameter(data=w13_data, scale=w13_scale)
        del w_in, w_swiglu

        w2_data = w_out.transpose(1, 2).contiguous()
        w2_scale = w_out.scales.contiguous() if w_out.is_scaled else None
        self._w2 = ScaledParameter(data=w2_data, scale=w2_scale)
        del w_out

        return self
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class SharedExperts(Experts):
    """Weights of the shared expert applied to every token.

    :meth:`build` allocates three 2-D projections, initialises each with
    :func:`init_params`, and fuses them:

    * ``w13`` = ``cat([w_in_shared_FD, w_swiglu_FD], dim=0)`` of shape
      ``(2 * hidden_dim // divide_factor, dim)``;
    * ``w2`` = ``w_out_shared_DF`` of shape ``(dim, hidden_dim // divide_factor)``.

    (Shapes assume the init methods keep the allocated shapes.) When scaled,
    the in/swiglu scales are concatenated along dimension 0.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__(dim, hidden_dim)

    def build(
        self, init_methods: Optional[INIT_METHODS_TYPE] = None
    ) -> "SharedExperts":
        methods = init_methods if init_methods is not None else {}
        f_local = self.hidden_dim // self.divide_factor

        def _alloc(key: str, shape: tuple[int, int]) -> ScaledParameter:
            # Allocate an uninitialised tensor and fill it via ``init_params``.
            weight = ScaledParameter(torch.empty(shape, dtype=self.dtype))
            init_params(key, weight, methods)
            return weight

        # Allocation order (in, out, swiglu) is kept so the default random
        # initialisation draws numbers in the same sequence.
        w_in = _alloc("w_in_shared_FD", (f_local, self.dim))
        w_out = _alloc("w_out_shared_DF", (self.dim, f_local))
        w_swiglu = _alloc("w_swiglu_FD", (f_local, self.dim))

        assert (w_in.dtype == w_out.dtype) and (w_in.dtype == w_swiglu.dtype)
        assert (w_in.is_scaled == w_out.is_scaled) and (
            w_in.is_scaled == w_swiglu.is_scaled
        )

        w13_data = torch.cat([w_in, w_swiglu]).contiguous()
        w13_scale = (
            torch.cat([w_in.scales, w_swiglu.scales]).contiguous()
            if w_in.is_scaled
            else None
        )
        self._w13 = ScaledParameter(data=w13_data, scale=w13_scale)
        del w_in, w_swiglu

        w2_data = w_out.data.contiguous()
        w2_scale = w_out.scales.contiguous() if w_out.is_scaled else None
        self._w2 = ScaledParameter(data=w2_data, scale=w2_scale)
        del w_out

        return self
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
class BaselineMoE(torch.nn.Module):
    """Dense reference implementation of a top-k routed mixture-of-experts layer.

    The output is the sum of a :class:`SharedExperts` SwiGLU block applied to
    every token and the :class:`RoutedExperts` outputs. Routing uses a sigmoid
    router and masking instead of gathering: each token is multiplied by its
    per-expert score (zero outside its top-k experts) and the resulting
    ``(E, T, D)`` tensor is fed to every expert.

    With expert parallelism (``ep_size > 1``) each rank owns
    ``num_local_experts`` experts, and the masked tokens are exchanged with
    ``torch.distributed.all_to_all`` over ``ep_group`` before and after the
    routed expert GEMMs. Buffers are split evenly, so every EP rank is assumed
    to process the same number of tokens.

    :meth:`build` must be called before :meth:`forward`.
    """

    def __init__(
        self,
        ep_group: ProcessGroup,
        ep_mp_group: ProcessGroup,
        moe_args: MoEArgs,
    ) -> None:
        super().__init__()

        self.moe_args = moe_args
        self.mp_size: int = moe_args.mp_size
        self.ep_size: int = moe_args.ep_size
        # Routed experts fall back to ``mp_size`` when no dedicated size is set.
        self.ep_mp_size: int = (
            moe_args.mp_size
            if moe_args.mp_size_for_routed_experts is None
            else moe_args.mp_size_for_routed_experts
        )

        self.ep_rank: int = get_rank(ep_group)
        self.ep_mp_rank: int = get_rank(ep_mp_group)

        self.ep_mp_group: ProcessGroup = ep_mp_group
        self.ep_group: ProcessGroup = ep_group

        self.num_experts: int = moe_args.num_experts
        self.num_local_experts: int = none_throws(moe_args.num_local_experts)
        assert self.num_experts == self.num_local_experts * self.ep_size

        self.top_k: int = moe_args.top_k

        self.dtype: torch.dtype = torch.get_default_dtype()

        self._router_DE: Optional[ScaledParameter] = None
        self.routed_experts = RoutedExperts(
            moe_args.num_local_experts,
            moe_args.dim,
            moe_args.hidden_dim,
        )
        self.shared_experts = SharedExperts(
            moe_args.dim,
            moe_args.hidden_dim,
        )

    def build(self, init_methods: Optional[INIT_METHODS_TYPE] = None) -> "BaselineMoE":
        """Allocate the router and both expert groups; return ``self``."""
        init_methods = {} if init_methods is None else init_methods

        router_DE = ScaledParameter(
            torch.empty(self.moe_args.dim, self.moe_args.num_experts, dtype=self.dtype)
        )
        init_params("router_DE", router_DE, init_methods)
        self._router_DE = router_DE

        self.routed_experts.build(init_methods)
        self.shared_experts.build(init_methods)
        return self

    @property
    def router_DE(self) -> ScaledParameter:
        """Router weights of shape ``(dim, num_experts)`` created by :meth:`build`."""
        assert self._router_DE is not None, "Parameters are not initialized!"
        return self._router_DE

    # User should overwrite this property
    @property
    def is_shared_fp8_rowwise(self) -> bool:
        return self.shared_experts.is_fp8_rowwise

    @property
    def is_routed_fp8_rowwise(self) -> bool:
        return self.routed_experts.is_fp8_rowwise

    @property
    def E(self) -> int:
        """Total number of routed experts."""
        return self.num_experts

    @property
    def EG(self) -> int:
        """Number of routed experts hosted on this EP rank."""
        return self.num_local_experts

    @property
    def K(self) -> int:
        """Number of experts selected per token."""
        return self.top_k

    def forward(self, x: torch.Tensor, use_static_shape: bool) -> torch.Tensor:
        """Compute the layer output for ``x`` of shape ``(B, T, D)`` without autograd.

        ``use_static_shape`` is not used by this reference implementation.
        """
        with torch.no_grad():
            return self._forward(x, use_static_shape)

    def _forward(self, x: torch.Tensor, use_static_shape: bool) -> torch.Tensor:
        (B, T, D) = x.shape
        T *= B
        # ``reshape`` (not ``view``) so non-contiguous inputs are accepted too.
        tokens = x.reshape(T, D)

        # Shared Experts
        shared_y = self._fake_quant(torch.mm, tokens, self.shared_experts.w13)
        shared_y0, shared_y1 = torch.chunk(shared_y, chunks=2, dim=-1)
        shared_z = shared_y0 * torch.sigmoid(shared_y0) * shared_y1
        shared_z = self._fake_quant(torch.mm, shared_z, self.shared_experts.w2)

        # Routing Scores
        E: int = self.E
        scores = torch.nn.functional.linear(tokens, self.router_DE.T)
        scores = torch.sigmoid(scores)
        assert scores.shape == (T, E)

        # Routing
        K: int = self.K
        topk_values, topk_indices = torch.topk(scores, K, dim=-1)
        assert topk_values.shape == (T, K)
        assert topk_indices.shape == (T, K)

        # Keep only each token's top-k scores, then broadcast them over D.
        masked_scores = torch.zeros_like(scores)
        masked_scores = (
            masked_scores.scatter_(dim=1, index=topk_indices, src=topk_values)
            .transpose(0, 1)  # (E, T)
            .reshape(E, T, 1)
            .expand(E, T, D)
        )

        tokens = tokens.view(1, T, D).expand(E, T, D)
        masked_tokens = tokens * masked_scores

        # Routed Experts
        EG: int = self.EG
        if self.ep_size > 1:
            send_tokens = masked_tokens.contiguous()
            send_list = list(torch.chunk(send_tokens, chunks=self.ep_size, dim=0))
            recv_tokens = torch.empty_like(send_tokens)
            recv_list = list(torch.chunk(recv_tokens, chunks=self.ep_size, dim=0))

            torch.distributed.all_to_all(
                output_tensor_list=recv_list,
                input_tensor_list=send_list,
                group=self.ep_group,
            )

            # ``recv_tokens`` is laid out as (src_rank, local_expert, T, D).
            # Regroup so that every local expert sees the tokens from all source
            # ranks: (local_expert, src_rank * T, D). A plain reshape here would
            # mix different local experts into one expert's batch when EG > 1.
            masked_tokens = (
                recv_tokens.reshape(self.ep_size, EG, T, D)
                .transpose(0, 1)
                .reshape(EG, self.ep_size * T, D)
            )

        routed_y = self._fake_quant(torch.bmm, masked_tokens, self.routed_experts.w13)
        routed_y0, routed_y1 = torch.chunk(routed_y, chunks=2, dim=-1)
        routed_z = routed_y0 * torch.sigmoid(routed_y0) * routed_y1
        routed_z = self._fake_quant(torch.bmm, routed_z, self.routed_experts.w2)

        if self.ep_size > 1:
            # Undo the regrouping: back to (dst_rank, local_expert, T, D) so
            # that chunk ``r`` holds the outputs owed to rank ``r``.
            send_tokens = (
                routed_z.reshape(EG, self.ep_size, T, D)
                .transpose(0, 1)
                .reshape(E * T, D)
                .contiguous()
            )
            send_list = list(torch.chunk(send_tokens, chunks=self.ep_size, dim=0))
            recv_tokens = torch.empty_like(send_tokens)
            recv_list = list(torch.chunk(recv_tokens, chunks=self.ep_size, dim=0))

            torch.distributed.all_to_all(
                output_tensor_list=recv_list,
                input_tensor_list=send_list,
                group=self.ep_group,
            )

            # Chunk ``r`` came from rank ``r`` and covers its local experts, so
            # the concatenation is ordered by global expert index.
            routed_z = recv_tokens.reshape(E, T, D)

        return (shared_z + routed_z.sum(dim=0)).reshape(B, -1, D)

    def _fake_quant(self, op, x: torch.Tensor, w: ScaledParameter) -> torch.Tensor:
        """Apply ``op(x, w^T)``, emulating a row-wise FP8 GEMM when ``w`` is scaled.

        Unscaled weights are used directly. For scaled weights, ``x`` is
        quantized with ``triton_quantize_fp8_row``, both quantized operands are
        upcast to ``x.dtype``, multiplied, and rescaled by the row scales of
        ``x`` and the scales of ``w``.
        """
        if not w.is_scaled:
            return op(x, w.transpose(-1, -2))

        xq, xs = triton_quantize_fp8_row(x)
        wq, ws = w.weights, w.scales

        y = (
            op(xq.to(x.dtype), wq.transpose(-1, -2).to(x.dtype))
            * xs.unsqueeze(-1)
            * ws.unsqueeze(-2)
        )
        return y.to(x.dtype)
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
class MetaShufflingMoE(BaselineMoE):
|
|
492
|
+
def __init__(
    self,
    ep_group: ProcessGroup,
    ep_mp_group: ProcessGroup,
    moe_args: MoEArgs,
) -> None:
    """Initialise the base layer and the side stream/events used for overlap.

    Raises:
        AssertionError: if ``mp_size_for_routed_experts`` resolves to a value
            different from ``mp_size``, or if ``top_k != 1``.
    """
    super().__init__(ep_group=ep_group, ep_mp_group=ep_mp_group, moe_args=moe_args)

    # Configuration restrictions of the current implementation.
    assert self.mp_size == self.ep_mp_size, (
        "MetaShuffling only supports mp_size = mp_size_for_routed_experts now"
    )
    assert self.top_k == 1, "MetaShuffling only supports top 1 routing at the moment"

    # Side CUDA stream plus two events used to order work between it and the
    # caller's current (compute) stream.
    self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
    self.comp_end_event: torch.cuda.Event = torch.cuda.Event()
    self.comm_end_event: torch.cuda.Event = torch.cuda.Event()

    self.use_fast_accum: bool = moe_args.use_fast_accum
    self.dedup_comm: bool = moe_args.dedup_comm
    # NOTE(review): redundant with the mp_size == ep_mp_size assertion above.
    assert not self.dedup_comm or self.ep_mp_size == self.mp_size, (
        "TP2EP is not supported for dedup at the moment."
    )

    self.activation_scale_ub = None
|
|
520
|
+
|
|
521
|
+
def forward(self, x: torch.Tensor, use_static_shape: bool) -> torch.Tensor:
    """Run the layer without autograd, choosing an execution path.

    * ``ep_size == 1``: no expert-parallel communication; ``use_static_shape``
      is forwarded as ``_no_comm_forward``'s ``overlap_router_and_shared_expert``.
    * otherwise ``use_static_shape`` selects ``_static_comm_forward`` (allgather
      based) over ``_dynamic_comm_forward`` (all-to-all based).
    """
    with torch.no_grad():
        if self.ep_size != 1:
            comm_path = (
                self._static_comm_forward
                if use_static_shape
                else self._dynamic_comm_forward
            )
            return comm_path(x)
        return self._no_comm_forward(x, use_static_shape)
|
|
529
|
+
|
|
530
|
+
def _dynamic_comm_forward(self, tokens: torch.Tensor) -> torch.Tensor:
    """Expert-parallel forward whose all-to-all split sizes are data dependent.

    Compute runs on the caller's current stream and communication on
    ``self.comm_stream``; ``comp_end_event`` / ``comm_end_event`` order the two.

    Steps: (1) route; (2) exchange per-expert token counts with the EP ranks;
    (3) start the shared expert meanwhile; (4) copy the counts to the host
    (blocking) to obtain split lists; (5) exchange routed tokens; (6) regroup
    received tokens, run the local routed experts and undo the regrouping;
    (7) send results back with the split lists swapped; (8) finish the shared
    expert meanwhile; (9) combine shared and routed outputs.

    Args:
        tokens: input of shape ``(B, T, D)``.

    Returns:
        Tensor of shape ``(B, T, D)``.
    """
    comp_stream = torch.cuda.current_stream()

    (B, T, D) = tokens.shape
    T *= B

    # 1. Dispatch router kernels.
    routed_tokens, routed_tokens_scales, token_counts, token_indices = self._route(
        tokens
    )
    # This path does not handle pre-quantized routed tokens.
    assert routed_tokens_scales is None

    # 2. Dispatch 1st all2all on shapes.
    # Make the comm stream wait for routing before exchanging counts.
    self.comp_end_event.record()
    with torch.cuda.stream(self.comm_stream):
        self.comp_end_event.wait()

        send_token_counts = token_counts
        recv_token_counts = self._exchange_shapes(send_token_counts)
        # Tell the caching allocator these tensors are used on another stream.
        send_token_counts.record_stream(self.comm_stream)

    recv_token_counts.record_stream(comp_stream)

    # 3. Dispatch shared expert part 1.
    # Overlaps with the count exchange above.
    shared_y = self._shared_expert_part1(tokens)

    with torch.cuda.stream(self.comm_stream):
        # 4. CPU/GPU sync.
        # ``.cpu()`` blocks the host until the counts are available; the split
        # sizes must be Python lists for the token exchange below.
        concat_counts = torch.concat(
            [send_token_counts.flatten(), recv_token_counts.flatten()]
        ).cpu()
        send_tokens_list = concat_counts[: self.E].tolist()
        recv_tokens_list = concat_counts[self.E :].tolist()

        # 5. Dispatch 2nd all2all on tokens.
        send_tokens = routed_tokens
        recv_tokens = self._exchange_tokens(
            send_tokens,
            send_tokens_list,
            recv_tokens_list,
            is_input=True,
        )
        send_tokens.record_stream(self.comm_stream)

        # Recorded on the comm stream: marks the end of steps 2-5.
        self.comm_end_event.record()
    recv_tokens.record_stream(comp_stream)

    # 6. Dispatch routed expert kernels.
    # Compute stream waits for the received tokens/counts.
    self.comm_end_event.wait()
    recv_T = recv_tokens.shape[0]
    assert recv_tokens.shape == (recv_T, D)
    assert recv_token_counts.shape == (self.ep_size, self.num_local_experts)
    # NOTE(review): combine_shuffling presumably regroups rows from
    # (src_rank, local_expert) order into per-local-expert order; its last
    # count entry is dropped below -- confirm against mslk.moe.shuffling.
    shuffled_recv_tokens, shuffled_recv_token_counts = combine_shuffling(
        recv_tokens, recv_token_counts
    )
    assert shuffled_recv_tokens.shape == (recv_T, D)
    assert shuffled_recv_token_counts.shape == (self.num_local_experts + 1,)
    routed_z = self._routed_expert(
        shuffled_recv_tokens,
        shuffled_recv_token_counts[:-1],
    )
    assert routed_z.shape == (recv_T, D)
    # Presumably the inverse of combine_shuffling, restoring per-source order.
    shuffled_send_tokens = split_shuffling(routed_z, recv_token_counts)
    assert shuffled_send_tokens.shape == (recv_T, D)

    # 7. Dispatch 3rd all2all on tokens.
    self.comp_end_event.record()
    with torch.cuda.stream(self.comm_stream):
        self.comp_end_event.wait()

        send_tokens = shuffled_send_tokens
        # Reverse direction: the split lists from step 5 are swapped.
        recv_tokens = self._exchange_tokens(
            send_tokens,
            recv_tokens_list,
            send_tokens_list,
            is_input=False,
        )
        send_tokens.record_stream(self.comm_stream)

        self.comm_end_event.record()
    recv_tokens.record_stream(comp_stream)

    # 8. Dispatch shared expert part 2.
    # Overlaps with the return exchange above.
    shared_z = self._shared_expert_part2(shared_y)

    # 9. Dispatch combine outputs.
    self.comm_end_event.wait()
    final_output = self._combine_outputs(
        shared_z, recv_tokens, token_indices, token_counts, padded=False
    )

    T //= B
    return final_output.view(B, T, D)
|
|
623
|
+
|
|
624
|
+
def _static_comm_forward(self, tokens: torch.Tensor) -> torch.Tensor:
    """Expert-parallel forward with fixed-size (shape-static) communication.

    Every rank gathers the routed token buffers and per-expert counts of all
    EP ranks (``recv_tokens`` is asserted to be ``(ep_size, T, D)``), keeps
    the experts in ``[ep_rank * num_local_experts,
    (ep_rank + 1) * num_local_experts)``, and returns results with a
    fixed-size token exchange. No host synchronization on counts is needed.
    Communication runs on ``self.comm_stream``, overlapped with the shared
    expert on the caller's current stream.

    Args:
        tokens: input of shape ``(B, T, D)``.

    Returns:
        Tensor of shape ``(B, T, D)``.
    """
    comp_stream = torch.cuda.current_stream()

    (B, T, D) = tokens.shape
    T *= B

    # 1. Dispatch router kernels.
    routed_tokens, routed_tokens_scales, token_counts, token_indices = self._route(
        tokens
    )
    # This path does not handle pre-quantized routed tokens.
    assert routed_tokens_scales is None

    # 2. Dispatch allgather on shapes and tokens.
    self.comp_end_event.record()
    with torch.cuda.stream(self.comm_stream):
        self.comp_end_event.wait()

        send_token_counts = token_counts
        send_tokens = routed_tokens
        # TODO(shikaili): Check if using 1 allgather is faster even with copies.
        recv_token_counts = self._gather_shapes(send_token_counts)
        recv_tokens = self._gather_tokens(send_tokens)
        # Tell the caching allocator these tensors are used on another stream.
        send_token_counts.record_stream(self.comm_stream)
        send_tokens.record_stream(self.comm_stream)

        # Recorded on the comm stream: marks the end of the gathers.
        self.comm_end_event.record()
    recv_token_counts.record_stream(comp_stream)
    recv_tokens.record_stream(comp_stream)

    # 3. Dispatch shared expert part 1.
    # Overlaps with the gathers above.
    shared_y = self._shared_expert_part1(tokens)

    # 4. Dispatch routed expert kernels.
    self.comm_end_event.wait()
    assert recv_tokens.shape == (
        self.ep_size,
        T,
        D,
    ), f"{recv_tokens.shape=}, {(self.ep_size, T, D)=}"
    assert recv_token_counts.shape == (self.ep_size, self.E)
    # Only this rank's expert range is selected; presumably rows for other
    # experts are excluded from the counts -- confirm against combine_shuffling.
    shuffled_recv_tokens, shuffled_recv_token_counts = combine_shuffling(
        recv_tokens.view(-1, D),
        recv_token_counts,
        expert_start=self.ep_rank * self.num_local_experts,
        expert_end=(self.ep_rank + 1) * self.num_local_experts,
    )
    assert shuffled_recv_tokens.shape == (self.ep_size * T, D)
    assert shuffled_recv_token_counts.shape == (self.num_local_experts + 1,), (
        f"{shuffled_recv_token_counts.shape=}"
    )
    routed_z = self._routed_expert(
        shuffled_recv_tokens,
        shuffled_recv_token_counts[:-1],
    )
    assert routed_z.shape == (self.ep_size * T, D)
    shuffled_send_tokens = split_shuffling(
        routed_z,
        recv_token_counts,
        expert_start=self.ep_rank * self.num_local_experts,
        expert_end=(self.ep_rank + 1) * self.num_local_experts,
    )
    assert shuffled_send_tokens.shape == (self.ep_size * T, D)

    # 5. Dispatch all2all on tokens.
    self.comp_end_event.record()
    with torch.cuda.stream(self.comm_stream):
        self.comp_end_event.wait()

        send_tokens = shuffled_send_tokens
        # No split lists: fixed-size exchange.
        recv_tokens = self._exchange_tokens(send_tokens, None, None, is_input=False)
        send_tokens.record_stream(self.comm_stream)

        self.comm_end_event.record()
    recv_tokens.record_stream(comp_stream)

    # 6. Dispatch shared expert part 2.
    # Overlaps with the exchange above.
    shared_z = self._shared_expert_part2(shared_y)

    # 7. Dispatch combine outputs.
    self.comm_end_event.wait()
    final_output = self._combine_outputs(
        shared_z,
        recv_tokens.view(self.ep_size, T, D),
        token_indices,
        token_counts,
        padded=True,
    )

    T //= B
    return final_output.view(B, T, D)
|
|
714
|
+
|
|
715
|
+
def _no_comm_forward(
    self, tokens: torch.Tensor, overlap_router_and_shared_expert: bool
) -> torch.Tensor:
    """MoE forward pass that issues no expert-parallel communication.

    Routes ``tokens``, runs the shared expert and the routed experts, and
    combines the two outputs.

    Args:
        tokens: activations unpacked as ``(B, T, D)``.
        overlap_router_and_shared_expert: when true, the shared expert is
            issued on ``self.comm_stream`` so it can overlap with the router
            kernels issued on the current (compute) stream. The two streams
            are ordered with ``comp_end_event`` / ``comm_end_event``.

    Returns:
        The combined output viewed as ``(B, T, D)``. This is the shared-expert
        output with the routed-expert output folded into it, either by the
        routed GEMM (non-ROCm) or by ``_combine_outputs`` (ROCm).
    """
    # Default stream for compute
    comp_stream = torch.cuda.current_stream()
    if overlap_router_and_shared_expert:
        # Record on the compute stream; the shared-expert work on comm_stream
        # waits on this event before reading `tokens`.
        self.comp_end_event.record()
    (B, T, D) = tokens.shape

    # 1. Dispatch router kernels and shared experts GEMMs.
    routed_tokens, routed_tokens_scales, token_counts, token_indices = self._route(
        tokens
    )

    if overlap_router_and_shared_expert:
        # Shared expert runs on comm_stream, concurrently with the routing
        # kernels queued above on the compute stream.
        with torch.cuda.stream(self.comm_stream):
            self.comp_end_event.wait()

            shared_y = self._shared_expert_part1(tokens)
            shared_z = self._shared_expert_part2(shared_y)
            # `tokens` is consumed on comm_stream; inform the caching allocator.
            tokens.record_stream(self.comm_stream)

            self.comm_end_event.record()
        # shared_z was produced on comm_stream but is used on the compute stream.
        shared_z.record_stream(comp_stream)
        # The compute stream must not proceed until the shared expert is done.
        self.comm_end_event.wait()
    else:
        shared_y = self._shared_expert_part1(tokens)
        shared_z = self._shared_expert_part2(shared_y)

    # 2. Dispatch routed expert GEMMs.
    if not torch.version.hip:
        # Non-ROCm: shared_z and token_indices are handed to the second routed
        # GEMM (as its output tensor / scatter-add indices), so no separate
        # combine step is issued on this path.
        final_output = self._routed_expert(
            routed_tokens,
            token_counts,
            token_scales=routed_tokens_scales,
            shared_output=shared_z,
            token_indices=token_indices,
        )
    else:
        routed_z = self._routed_expert(
            routed_tokens,
            token_counts,
            token_scales=routed_tokens_scales,
        )
        # 3. Dispatch combine outputs.
        final_output = self._combine_outputs(
            shared_z, routed_z, token_indices, token_counts, padded=False
        )

    return final_output.view(B, T, D)
|
|
765
|
+
|
|
766
|
+
def _exchange_shapes(self, send_sizes: torch.Tensor) -> torch.Tensor:
    """All-to-all the per-expert token counts between EP ranks.

    No CPU/GPU sync in this function.

    ``send_sizes`` holds one count per global expert (``[E]``). It is cut into
    ``ep_size`` equal chunks; chunk ``i`` is sent to EP rank ``i``, and chunk
    ``j`` of the result is what EP rank ``j`` sent here. The result is
    returned as ``[EP, EG]`` where ``EG = num_local_experts``.

    Note: when ``ep_size == 1`` the input tensor itself is returned, still
    shaped ``[E]`` rather than ``[1, E]``.
    """
    ep = self.ep_size
    if ep == 1:
        return send_sizes

    assert tuple(send_sizes.shape) == (self.E,)
    received = torch.empty_like(send_sizes)

    # Chunks are views into `received`, so the collective fills it in place.
    out_chunks = [*received.chunk(ep)]
    in_chunks = [*send_sizes.chunk(ep)]
    assert all(chunk.is_contiguous() for chunk in out_chunks)
    assert all(chunk.is_contiguous() for chunk in in_chunks)

    torch.distributed.all_to_all(
        output_tensor_list=out_chunks,
        input_tensor_list=in_chunks,
        group=self.ep_group,
    )
    return received.view(ep, self.num_local_experts)
|
|
788
|
+
|
|
789
|
+
def _gather_shapes(self, send_sizes: torch.Tensor) -> torch.Tensor:
    """All-gather every EP rank's per-expert token counts.

    No CPU/GPU sync in this function.

    The input is ``[E]``. The output is ``[EP, E]``, where row ``r`` holds the
    counts contributed by EP rank ``r``.

    Note: when ``ep_size == 1`` the input tensor itself is returned, still
    shaped ``[E]``.
    """
    ep = self.ep_size
    if ep == 1:
        return send_sizes

    assert tuple(send_sizes.shape) == (self.E,)
    # Same dtype/device as the input.
    gathered = send_sizes.new_empty((ep, self.E))

    for buf in (send_sizes, gathered):
        assert buf.is_contiguous()
    torch.distributed.all_gather_into_tensor(
        output_tensor=gathered,
        input_tensor=send_sizes,
        group=self.ep_group,
    )
    return gathered
|
|
810
|
+
|
|
811
|
+
def _exchange_tokens(
    self,
    send_tokens: torch.Tensor,
    send_sizes: Optional[list[int]],
    recv_sizes: Optional[list[int]],
    is_input: bool,
) -> torch.Tensor:
    """All-to-all token rows between EP ranks.

    ``send_sizes`` / ``recv_sizes`` are per-expert counts (one entry per
    expert, experts of EP rank ``r`` occupying
    ``[r * num_local_experts, (r + 1) * num_local_experts)``). They are
    summed into one split size per EP rank. When either is ``None``, tokens
    are assumed to be evenly distributed across EP ranks: the ``T`` rows of
    ``send_tokens`` are split into ``ep_size`` chunks of ``T // ep_size``.

    With ``self.dedup_comm``:
      * ``is_input=True``: only this rank's ``D // ep_mp_size`` column slice
        is all-to-all'ed over ``ep_group``. The slices are then all-gathered
        over ``ep_mp_group`` and re-assembled into full ``D``-wide rows.
      * ``is_input=False``: rows are reduce-scattered (default SUM) along
        ``D`` over ``ep_mp_group``. The reduced slice is all-to-all'ed over
        ``ep_group`` and zero-padded back to width ``D`` at this rank's column
        offset.

    No CPU/GPU sync in this function.

    Returns:
        ``[sum(per-rank recv sizes), D]`` tokens, or ``send_tokens`` unchanged
        when ``ep_size == 1``.
    """
    if self.ep_size == 1:
        return send_tokens

    D = send_tokens.shape[-1]
    send_tokens = send_tokens.view(-1, D)
    T = send_tokens.shape[0]

    def _per_rank_sizes(sizes: Optional[list[int]]) -> list[int]:
        # Collapse per-expert counts into one count per EP rank, or fall back
        # to an even split of the T rows.
        if sizes is None:
            return [T // self.ep_size for _ in range(self.ep_size)]
        return [
            sum(sizes[r * self.num_local_experts : (r + 1) * self.num_local_experts])
            for r in range(self.ep_size)
        ]

    send_split = _per_rank_sizes(send_sizes)
    recv_split = _per_rank_sizes(recv_sizes)
    num_recv = sum(recv_split)

    def _empty(*shape: int) -> torch.Tensor:
        return torch.empty(shape, dtype=send_tokens.dtype, device=send_tokens.device)

    def _all_to_all_rows(recv_buf: torch.Tensor, send_buf: torch.Tensor) -> None:
        # Variable-size all-to-all of rows over ep_group. The split pieces are
        # views of recv_buf, so it is filled in place.
        recv_list = list(recv_buf.split(recv_split))
        send_list = list(send_buf.split(send_split))
        assert all(r.is_contiguous() for r in recv_list)
        assert all(s.is_contiguous() for s in send_list)
        torch.distributed.all_to_all(
            output_tensor_list=recv_list,
            input_tensor_list=send_list,
            group=self.ep_group,
        )

    # TODO: Add FP8 A2A to example.
    if self.dedup_comm:
        slice_d = D // self.ep_mp_size
        if is_input:
            sliced_recv_tokens = _empty(num_recv, slice_d)
            # TODO(shikaili): Extremely high copy overhead in prefill.
            sliced_send_tokens = send_tokens.chunk(self.ep_mp_size, dim=-1)[
                self.ep_mp_rank
            ].contiguous()
            _all_to_all_rows(sliced_recv_tokens, sliced_send_tokens)

            recv_tokens_permutated = _empty(self.ep_mp_size, num_recv, slice_d)
            assert sliced_recv_tokens.is_contiguous()
            assert recv_tokens_permutated.is_contiguous()
            torch.distributed.all_gather_into_tensor(
                output_tensor=recv_tokens_permutated,
                input_tensor=sliced_recv_tokens,
                group=self.ep_mp_group,
            )
            # [EP_MP, N, D / EP_MP] -> [N, EP_MP, D / EP_MP] -> [N, D]
            return recv_tokens_permutated.permute(1, 0, 2).reshape(-1, D).contiguous()

        # ReduceScatter along D: the input is transposed so D becomes dim 0.
        reduced_sliced_send_tokens = _empty(slice_d, sum(send_split))
        torch.distributed.reduce_scatter_tensor(
            output=reduced_sliced_send_tokens,
            input=send_tokens.transpose(0, 1).contiguous(),
            group=self.ep_mp_group,
        )
        reduced_sliced_send_tokens = reduced_sliced_send_tokens.transpose(
            0, 1
        ).contiguous()

        # AlltoAll
        reduced_sliced_recv_tokens = _empty(num_recv, slice_d)
        _all_to_all_rows(reduced_sliced_recv_tokens, reduced_sliced_send_tokens)

        # Padding: place this rank's slice at its column offset, zeros elsewhere.
        pad_l = slice_d * self.ep_mp_rank
        pad_r = D - pad_l - slice_d
        return torch.nn.functional.pad(reduced_sliced_recv_tokens, (pad_l, pad_r))

    recv_tokens = _empty(num_recv, D)
    _all_to_all_rows(recv_tokens, send_tokens)
    return recv_tokens
|
|
966
|
+
|
|
967
|
+
def _gather_tokens(self, send_tokens: torch.Tensor) -> torch.Tensor:
    """All-gather ``[T, D]`` tokens from every EP rank into ``[EP, T, D]``.

    No CPU/GPU sync in this function. When ``ep_size == 1`` the input is
    returned as-is.

    With ``self.dedup_comm``, only this rank's ``D // ep_mp_size`` column
    slice is gathered over ``ep_group``. Those partial results are then
    gathered over ``ep_mp_group``, and the column slices are re-interleaved
    back to full width ``D``.
    """
    if self.ep_size == 1:
        return send_tokens

    # TODO: Add FP8 AG to example.
    num_tokens, dim = send_tokens.shape
    dtype, device = send_tokens.dtype, send_tokens.device

    if not self.dedup_comm:
        gathered = torch.empty((self.ep_size, num_tokens, dim), dtype=dtype, device=device)
        assert send_tokens.is_contiguous()
        assert gathered.is_contiguous()
        torch.distributed.all_gather_into_tensor(
            output_tensor=gathered,
            input_tensor=send_tokens,
            group=self.ep_group,
        )
        return gathered

    slice_dim = dim // self.ep_mp_size
    inter_recv = torch.empty(
        (self.ep_size, num_tokens, slice_dim), dtype=dtype, device=device
    )
    # Copy overhead.
    inter_send = send_tokens.chunk(self.ep_mp_size, dim=-1)[self.ep_mp_rank].contiguous()

    assert inter_send.is_contiguous()
    assert inter_recv.is_contiguous()
    torch.distributed.all_gather_into_tensor(
        output_tensor=inter_recv,
        input_tensor=inter_send,
        group=self.ep_group,
    )

    intra_recv = torch.empty(
        (self.ep_mp_size, self.ep_size, num_tokens, slice_dim),
        dtype=dtype,
        device=device,
    )
    assert intra_recv.is_contiguous()
    torch.distributed.all_gather_into_tensor(
        output_tensor=intra_recv,
        input_tensor=inter_recv,
        group=self.ep_mp_group,
    )

    # Copy overhead.
    # [EP_MP, EP, T, D / EP_MP] -> [EP, T, EP_MP, D / EP_MP] -> [EP, T, D]
    interleaved = intra_recv.permute(1, 2, 0, 3).reshape(self.ep_size, num_tokens, dim)
    return interleaved.contiguous()
|
|
1031
|
+
|
|
1032
|
+
def _route(
    self, tokens: torch.Tensor
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Score tokens with the router and gather them in expert-sorted order.

    Args:
        tokens: activations unpacked as ``(B, T, D)``.

    Returns:
        ``(routed_tokens, routed_tokens_scales, token_counts, token_indices)``.
        ``routed_tokens_scales`` is non-None only on the routed-FP8 path with
        ``ep_size == 1``. ``token_counts`` keeps only the first ``E`` entries
        returned by ``index_shuffling``.

    With ``self.dedup_comm``, the shuffling results are broadcast within
    ``ep_mp_group`` from the global rank ``(rank // ep_mp_size) * ep_mp_size``.
    This presumably assumes each ep_mp_group is a contiguous block of global
    ranks (TODO confirm).
    """
    batch, seqlen, dim = tokens.shape
    flat = tokens.view(-1, dim)

    assert not self.router_DE.is_scaled
    probs = torch.sigmoid(torch.nn.functional.linear(flat, self.router_DE.T))
    assert probs.shape == (batch * seqlen, self.E)

    counts, expert_idx, token_idx = index_shuffling(probs)
    counts = counts[: self.E]

    if self.dedup_comm:
        pieces = [counts.shape[0], expert_idx.shape[0], token_idx.shape[0]]
        packed = torch.concat([counts, expert_idx, token_idx], dim=0)
        # Require broadcast as index_shuffling is not deterministic.
        leader = (torch.distributed.get_rank() // self.ep_mp_size) * self.ep_mp_size
        torch.distributed.broadcast(packed, src=leader, group=self.ep_mp_group)
        counts, expert_idx, token_idx = torch.split(packed, pieces, dim=0)

    gather_kwargs = {
        "token_indices": token_idx.flatten(),
        "expert_indices": expert_idx.flatten(),
        "scores": probs,
    }
    if self.is_routed_fp8_rowwise and self.ep_size == 1:
        routed, routed_scales = gather_scale_quant_dense_tokens(
            flat, **gather_kwargs, scale_ub=self.activation_scale_ub
        )
    else:
        routed = gather_scale_dense_tokens(flat, **gather_kwargs)
        routed_scales = None
    return routed, routed_scales, counts, token_idx
|
|
1082
|
+
|
|
1083
|
+
def _shared_expert_part1(self, x: torch.Tensor) -> torch.Tensor:
    """First shared-expert GEMM: flattened tokens times the fused ``w13`` weight.

    ``x`` is flattened to ``[N, D]`` over all leading dims. On the FP8 path the
    activations are rowwise-quantized first and ``f8f8bf16_rowwise`` is used.
    """
    flat = x.view(-1, x.shape[-1])
    w13 = self.shared_experts.w13

    if self.is_shared_fp8_rowwise:
        q, q_scale = triton_quantize_fp8_row(flat, self.activation_scale_ub)
        # TODO(shikaili): Skip padded tokens.
        return torch.ops.mslk.f8f8bf16_rowwise(
            q,
            w13.weights,
            q_scale,
            w13.scales,
            use_fast_accum=self.use_fast_accum,
        )

    # TODO(shikaili): Skip padded tokens.
    return flat @ w13.T
|
|
1102
|
+
|
|
1103
|
+
def _shared_expert_part2(self, y: torch.Tensor) -> torch.Tensor:
    """Second shared-expert stage: SiLU-gated product, then the ``w2`` GEMM.

    ``y`` is the ``[N, 2 * HD_L]`` output of ``_shared_expert_part1``. Its first
    half is the gate input and its second half is the multiplier.
    """
    half = y.shape[-1] // 2
    w2 = self.shared_experts.w2
    fp8 = self.is_shared_fp8_rowwise

    act, act_scale = self._fused_silu_mul(
        y[:, :half], y[:, half:], fp8, self.activation_scale_ub
    )

    if fp8:
        assert act_scale is not None
        # TODO(shikaili): Skip padded tokens.
        return torch.ops.mslk.f8f8bf16_rowwise(
            act,
            w2.weights,
            act_scale,
            w2.scales,
            use_fast_accum=self.use_fast_accum,
        )

    assert act_scale is None
    # TODO(shikaili): Skip padded tokens.
    return act @ w2.T
|
|
1129
|
+
|
|
1130
|
+
def _routed_expert(
    self,
    tokens: torch.Tensor,
    token_counts: torch.Tensor,
    token_scales: Optional[torch.Tensor] = None,
    shared_output: Optional[torch.Tensor] = None,
    token_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the routed experts' gated FFN as two grouped GEMMs.

    Args:
        tokens: routed token rows, flattened to ``[N, D]``. Rows are presumably
            grouped per local expert in the order given by ``token_counts``
            (TODO confirm against grouped_gemm).
        token_counts: ``[num_local_experts]`` row count per local expert.
        token_scales: precomputed FP8 row scales for ``tokens``. On the FP8
            path, when this is None the tokens are quantized here.
        shared_output: optional tensor passed to the second grouped GEMM as
            ``_output_tensor``. Together with ``token_indices``
            (``_scatter_add_indices``) it presumably receives the routed output
            scatter-added into it.
        token_indices: scatter-add indices for ``shared_output``.

    Returns:
        The second GEMM's output. For empty input: ``shared_output`` when it
        was supplied, otherwise the empty ``[0, D]`` view of ``tokens``.
    """
    D = tokens.shape[-1]
    x = tokens.view(-1, D)

    if x.shape[0] == 0:
        # Nothing to route. A caller that asked for accumulation into
        # `shared_output` expects that tensor back, not the empty input.
        return x if shared_output is None else shared_output

    w13 = self.routed_experts.w13
    w2 = self.routed_experts.w2

    assert D == w13.shape[-1]
    HD_L = w2.shape[-1]

    assert token_counts.shape == (self.num_local_experts,)
    if not self.is_routed_fp8_rowwise:
        y = grouped_gemm(
            x,
            w13.view(-1, D),
            token_counts,
            use_fast_accum=self.use_fast_accum,
            _use_warp_specialization=not torch.version.hip,
        )
        z, _ = self._fused_silu_mul(y[:, :HD_L], y[:, HD_L:], False)
        return grouped_gemm(
            z,
            w2.view(-1, HD_L),
            token_counts,
            use_fast_accum=self.use_fast_accum,
            _use_warp_specialization=not torch.version.hip,
            _output_tensor=shared_output,
            _scatter_add_indices=token_indices,
        )

    if token_scales is None:
        x, x_scale = triton_quantize_fp8_row(x, self.activation_scale_ub)
    else:
        # Tokens were already quantized upstream (e.g. by the routing gather).
        x_scale = token_scales
    y = grouped_gemm_fp8_rowwise(
        x,
        w13.weights.view(-1, D),
        token_counts,
        x_scale.view(-1),
        w13.scales.view(-1),
        use_fast_accum=self.use_fast_accum,
        _use_warp_specialization=not torch.version.hip,
    )
    # TODO(shikaili): Skip padded tokens.
    z, z_scale = self._fused_silu_mul(
        y[:, :HD_L], y[:, HD_L:], True, self.activation_scale_ub
    )
    assert z_scale is not None
    return grouped_gemm_fp8_rowwise(
        z,
        w2.weights.view(-1, HD_L),
        token_counts,
        z_scale.view(-1),
        w2.scales.view(-1),
        use_fast_accum=self.use_fast_accum,
        _use_warp_specialization=not torch.version.hip,
        _output_tensor=shared_output,
        _scatter_add_indices=token_indices,
    )
|
|
1200
|
+
|
|
1201
|
+
def _combine_outputs(
    self,
    shared_output_tokens: torch.Tensor,
    routed_output_tokens: torch.Tensor,
    token_indices: torch.Tensor,
    token_counts: torch.Tensor,
    padded: bool = False,
) -> torch.Tensor:
    """Fold routed-expert outputs into the shared-expert output and return it.

    The ``scatter_add_*`` helpers' return values are not used, and
    ``shared_output_tokens`` itself is returned. The accumulation is therefore
    presumably done in place in that tensor.

    With ``padded=True`` the routed output is left in its original shape and
    ``token_counts`` is passed along. Otherwise it is flattened to ``[N, D]``.
    """
    dim = shared_output_tokens.shape[-1]
    assert routed_output_tokens.shape[-1] == dim

    if not padded:
        scatter_add_dense_tokens(
            shared_output_tokens,
            routed_output_tokens.view(-1, dim),
            token_indices,
        )
    else:
        scatter_add_padded_tokens(
            in_tokens=routed_output_tokens,
            token_counts=token_counts,
            token_indices=token_indices,
            out_tokens=shared_output_tokens,
        )
    return shared_output_tokens
|
|
1227
|
+
|
|
1228
|
+
def _fused_silu_mul(
    self,
    x0: torch.Tensor,
    x1: torch.Tensor,
    is_fp8: bool,
    scale_ub: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply the SiLU-gated product of ``x0`` and ``x1``.

    Returns ``(out, scale)``. On the FP8 path both come from
    ``silu_mul_quant(x0, x1, scale_ub)``. Otherwise ``out`` is
    ``silu_mul(x0, x1)`` and ``scale`` is ``None``.
    """
    if not is_fp8:
        return silu_mul(x0, x1), None
    out, out_scale = silu_mul_quant(x0, x1, scale_ub)
    return out, out_scale
|