mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,739 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
# pyre-unsafe
|
|
8
|
+
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import triton
|
|
13
|
+
import triton.language as tl
|
|
14
|
+
from mslk.utils.triton.fp8_utils import get_fp8_constants
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Function APIs
|
|
18
|
+
def gather_scale_dense_tokens(
    x: torch.Tensor,
    token_indices: torch.Tensor,
    expert_indices: torch.Tensor,
    scores: torch.Tensor,
    valid_token_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Gather rows of ``x`` and scale each one by its routing score.

    For every output row ``i``::

        out[i, :] = x[token_indices[i], :] * scores[token_indices[i], expert_indices[i]]

    The multiplication is done in float32 and the result is stored in ``x.dtype``.

    If ``valid_token_count`` is provided, only the first ``valid_token_count``
    output rows are computed; the remaining rows of the returned tensor are
    left uninitialized.

    Any feature dimension ``D`` is accepted: the feature block sizes are halved
    until they divide ``D`` (the same policy as ``scatter_add_dense_tokens``).

    Args:
        x (torch.Tensor): contiguous input tensor of shape (T, D)
        token_indices (torch.Tensor): contiguous token indices of shape (a,)
        expert_indices (torch.Tensor): contiguous expert indices of shape (a,)
        scores (torch.Tensor): scores of shape (T, E); arbitrary strides are allowed
        valid_token_count (torch.Tensor, optional): single-element tensor holding
            the number of valid output rows

    Returns:
        torch.Tensor: output tensor of shape (a, D) with the dtype and device of ``x``
    """
    T, D = x.shape
    E = scores.shape[1]
    # a = K * T
    a = token_indices.shape[0]

    out = torch.empty((a, D), device=x.device, dtype=x.dtype)
    if a == 0 or D == 0:
        return out

    assert x.is_contiguous()
    assert token_indices.is_contiguous()
    assert expert_indices.is_contiguous()

    assert tuple(token_indices.shape) == (a,)
    assert tuple(expert_indices.shape) == (a,)
    assert tuple(scores.shape) == (T, E)

    stride_t = scores.stride(0)
    stride_e = scores.stride(1)

    # With at least one row per SM, each program handles an entire row (up to
    # 1024 features per step). Otherwise D is also split across programs in
    # slices of up to 512 features (up to 256 per step), presumably to expose
    # more parallelism. Query the SM count of x's GPU, not the default device.
    num_sms = torch.cuda.get_device_properties(x.device).multi_processor_count
    if a >= num_sms:
        BLOCK_D_OUTER = D
        BLOCK_D_INNER = 1024
    else:
        BLOCK_D_OUTER = 512
        BLOCK_D_INNER = 256
    # The kernel loads/stores without masks, so both block sizes must tile D
    # exactly; halving keeps them powers of two and INNER always divides OUTER.
    while D % BLOCK_D_OUTER != 0:
        BLOCK_D_OUTER //= 2
    while D % BLOCK_D_INNER != 0:
        BLOCK_D_INNER //= 2

    grid = (a, D // BLOCK_D_OUTER)
    _mslk_gather_scale_dense_tokens[grid](
        out,
        x,
        token_indices,
        expert_indices,
        scores,
        stride_t,
        stride_e,
        valid_token_count,
        D,  # pyre-ignore
        BLOCK_D_OUTER,  # pyre-ignore
        BLOCK_D_INNER,  # pyre-ignore
    )
    return out
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def gather_scale_quant_dense_tokens(
    x: torch.Tensor,
    token_indices: torch.Tensor,
    expert_indices: torch.Tensor,
    scores: torch.Tensor,
    scale_ub: Optional[torch.Tensor] = None,
    valid_token_count: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Gather rows of ``x``, scale them by their routing scores, and quantize each
    row to FP8 with its own row-wise scale.

    For every output row ``i``::

        y = x[token_indices[i], :] * scores[token_indices[i], expert_indices[i]]
        row_max = clamp(max(abs(y)), EPS, scale_ub)   # max(max(abs(y)), EPS) if scale_ub is None
        out[i, :] = clamp(y * (MAX_FP8 / row_max), -MAX_FP8, MAX_FP8) cast to FP8
        out_scale[i] = 1 / (MAX_FP8 / row_max)

    ``MAX_FP8``, ``EPS`` and the FP8 dtype come from ``get_fp8_constants()``.

    If ``valid_token_count`` is provided, only the first ``valid_token_count``
    output rows (and their scales) are written; the rest are left uninitialized.

    Args:
        x (torch.Tensor): contiguous input tensor of shape (T, D)
        token_indices (torch.Tensor): contiguous token indices of shape (a,)
        expert_indices (torch.Tensor): contiguous expert indices of shape (a,)
        scores (torch.Tensor): scores of shape (T, E); arbitrary strides are allowed
        scale_ub (torch.Tensor, optional): single-element upper bound for the row max
        valid_token_count (torch.Tensor, optional): single-element tensor holding
            the number of valid output rows

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(out, out_scale)``; ``out`` has shape
        (a, D) and the FP8 dtype, ``out_scale`` is float32 of shape (a,). Both are
        allocated on ``x.device``.
    """
    T, D = x.shape
    E = scores.shape[1]
    # a = K * T
    a = token_indices.shape[0]

    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()

    assert x.is_contiguous()
    assert token_indices.is_contiguous()
    assert expert_indices.is_contiguous()

    assert tuple(token_indices.shape) == (a,)
    assert tuple(expert_indices.shape) == (a,)
    assert tuple(scores.shape) == (T, E)

    stride_t = scores.stride(0)
    stride_e = scores.stride(1)

    # Allocate on the input's device (not the current default CUDA device) so
    # inputs on a non-default GPU work; this also matches the Meta impl.
    out = torch.empty((a, D), device=x.device, dtype=pt_dtype)
    out_scale = torch.empty((a,), device=x.device, dtype=torch.float32)
    if a == 0:
        # No rows: skip the (autotuned) kernel launch entirely.
        return out, out_scale

    grid = (a,)
    _mslk_gather_scale_fp8_rowwise_quant_dense_tokens[grid](
        out,
        out_scale,
        x,
        token_indices,
        expert_indices,
        scores,
        scale_ub,
        stride_t,
        stride_e,
        valid_token_count,
        D,
        TL_FP8_DTYPE=tl_dtype,
        MAX_FP8=max_fp8,
        EPS=eps,
        CLAMP_MAX=scale_ub is not None,
    )
    return out, out_scale
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def scatter_add_dense_tokens(
    out_tokens: torch.Tensor,  # [T, D]
    in_tokens: torch.Tensor,  # [a, D]
    token_indices: torch.Tensor,  # [a]
    valid_token_count: Optional[torch.Tensor] = None,
) -> None:
    """
    Scatter-add rows of ``in_tokens`` into ``out_tokens`` in place.

    For every row ``i``: ``out_tokens[token_indices[i], :] += in_tokens[i, :]``.
    The kernel uses atomic adds, so repeated indices accumulate. If
    ``valid_token_count`` is provided, only the first ``valid_token_count`` rows
    of ``in_tokens`` are added.

    Args:
        out_tokens (torch.Tensor): contiguous output tensor of shape (T, D), updated in place
        in_tokens (torch.Tensor): contiguous input tensor of shape (a, D)
        token_indices (torch.Tensor): contiguous token indices of shape (a,)
        valid_token_count (torch.Tensor, optional): single-element tensor holding
            the number of valid input rows

    Returns:
        None

    Raises:
        AssertionError: on NVIDIA builds older than CUDA 12.4 (or with no GPU
            backend), and on non-contiguous or mis-shaped inputs.
    """
    # Compare versions numerically; a string comparison misorders e.g. "9.2" vs "12.4".
    cuda_version = torch.version.cuda
    assert torch.version.hip is not None or (
        cuda_version is not None
        and tuple(int(part) for part in cuda_version.split(".")[:2]) >= (12, 4)
    ), "Requires CUDA version 12.4 or later on Nvidia GPUs!"

    assert in_tokens.is_contiguous()
    assert token_indices.is_contiguous()
    assert out_tokens.is_contiguous()

    a, D = in_tokens.shape
    # Nothing to add. D == 0 must also stop here: below it would make
    # BLOCK_D_OUTER zero and divide by it.
    if a == 0 or D == 0:
        return
    assert token_indices.shape == (a,)
    assert out_tokens.ndim == 2 and out_tokens.shape[1] == D

    # With at least one row per SM, each program handles an entire row;
    # otherwise D is also split across programs. Query the SM count of the
    # output's GPU rather than the current default device.
    num_sms = torch.cuda.get_device_properties(out_tokens.device).multi_processor_count
    if a >= num_sms:
        BLOCK_D_OUTER = D
        BLOCK_D_INNER = 1024
    else:
        BLOCK_D_OUTER = 512
        BLOCK_D_INNER = 256
    # The kernel is unmasked, so both block sizes must tile D exactly.
    while D % BLOCK_D_OUTER != 0:
        BLOCK_D_OUTER //= 2
    while D % BLOCK_D_INNER != 0:
        BLOCK_D_INNER //= 2

    grid = (a, D // BLOCK_D_OUTER)
    _mslk_scatter_add_dense_tokens[grid](
        out_tokens,
        in_tokens,
        token_indices,
        valid_token_count,
        D,  # pyre-ignore
        BLOCK_D_OUTER,  # pyre-ignore
        BLOCK_D_INNER,  # pyre-ignore
    )
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def scatter_add_padded_tokens(
    in_tokens: torch.Tensor,  # [EP, T_K, D]
    token_counts: torch.Tensor,  # [E]
    token_indices: torch.Tensor,  # [T_K]
    out_tokens: torch.Tensor,  # [T, D]
) -> None:
    """
    Scatter-add the valid (non-padding) rows of ``in_tokens`` into ``out_tokens``.

    ``token_counts[e]`` is the number of valid rows for expert ``e``. With
    ``off_e = sum(token_counts[:e])`` and ``r = e // (E // EP)``, the rows
    ``in_tokens[r, off_e : off_e + token_counts[e]]`` are atomically added into
    ``out_tokens`` at rows ``token_indices[off_e : off_e + token_counts[e]]``.
    Rows not covered by any expert's count are ignored.

    Args:
        in_tokens (torch.Tensor): contiguous input tensor of shape (EP, T_K, D)
        token_counts (torch.Tensor): contiguous token counts of shape (E,)
        token_indices (torch.Tensor): contiguous token indices of shape (T_K,)
        out_tokens (torch.Tensor): contiguous output tensor of shape (T, D), updated in place

    Returns:
        None

    Raises:
        AssertionError: on NVIDIA builds older than CUDA 12.4 (or with no GPU
            backend), and on non-contiguous or inconsistently shaped inputs.
    """
    # Compare versions numerically; a string comparison misorders e.g. "9.2" vs "12.4".
    cuda_version = torch.version.cuda
    assert torch.version.hip is not None or (
        cuda_version is not None
        and tuple(int(part) for part in cuda_version.split(".")[:2]) >= (12, 4)
    ), "Requires CUDA version 12.4 or later on Nvidia GPUs!"

    assert in_tokens.is_contiguous()
    assert token_counts.is_contiguous()
    assert token_indices.is_contiguous()
    assert out_tokens.is_contiguous()

    EP, T_K, D = in_tokens.shape
    E = token_counts.shape[0]
    T = out_tokens.shape[0]
    assert tuple(token_indices.shape) == (T_K,)
    assert out_tokens.shape[1] == D
    # Nothing to add. Returning here also avoids `T_K % T` dividing by zero
    # when the output is empty.
    if T_K == 0 or D == 0 or E == 0:
        return
    assert T > 0 and T_K % T == 0

    def grid(META):
        # One program per (expert, token split).
        return (
            E,
            META["SPLIT_T"],
        )

    # T_BUCKET only feeds the autotune key; bucketing T_K to a capped power of
    # two limits how many distinct shapes trigger re-tuning.
    T_BUCKET_CAP = 16384
    T_BUCKET = min(triton.next_power_of_2(T_K), T_BUCKET_CAP)
    BLOCK_E = max(triton.next_power_of_2(E), 8)
    _mslk_scatter_add_padded_tokens[grid](
        in_tokens,
        token_counts,
        token_indices,
        out_tokens,
        EP,
        E,
        T_BUCKET,
        T_K,
        D,
        BLOCK_E,
    )
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
# Torch Custom Op Registrations
|
|
285
|
+
_GATHER_SCALE_DENSE_TOKENS_OP_NAME = "mslk::gather_scale_dense_tokens"
|
|
286
|
+
|
|
287
|
+
torch.library.define(
|
|
288
|
+
_GATHER_SCALE_DENSE_TOKENS_OP_NAME,
|
|
289
|
+
"(Tensor x, Tensor token_indices, Tensor expert_indices, Tensor scores, Tensor? valid_token_count=None) -> Tensor",
|
|
290
|
+
)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
@torch.library.impl(_GATHER_SCALE_DENSE_TOKENS_OP_NAME, "Meta")
def gather_scale_dense_tokens_meta(
    x,
    token_indices,
    expert_indices,
    scores,
    valid_token_count=None,
):
    """Meta (shape-only) kernel: an uninitialized (a, D) tensor shaped like ``x``'s rows."""
    num_rows = token_indices.size(0)
    num_features = x.size(1)
    return x.new_empty((num_rows, num_features))
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@torch.library.impl(_GATHER_SCALE_DENSE_TOKENS_OP_NAME, "CUDA")
def gather_scale_dense_tokens_cuda(
    x,
    token_indices,
    expert_indices,
    scores,
    valid_token_count=None,
):
    """CUDA kernel for the op; delegates to the Python/Triton entry point."""
    return gather_scale_dense_tokens(
        x=x,
        token_indices=token_indices,
        expert_indices=expert_indices,
        scores=scores,
        valid_token_count=valid_token_count,
    )
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
_GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME = "mslk::gather_scale_quant_dense_tokens"
|
|
324
|
+
|
|
325
|
+
torch.library.define(
|
|
326
|
+
_GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME,
|
|
327
|
+
"(Tensor x, Tensor token_indices, Tensor expert_indices, Tensor scores, Tensor? scale_ub=None, Tensor? valid_token_count=None) -> Tensor",
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
@torch.library.impl(_GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME, "Meta")
def gather_scale_quant_dense_tokens_meta(
    x,
    token_indices,
    expert_indices,
    scores,
    scale_ub=None,
    valid_token_count=None,
):
    """Meta (shape-only) kernel: (a, D) FP8 values plus (a,) float32 row scales."""
    fp8_dtype = get_fp8_constants()[0]
    num_rows, num_features = token_indices.shape[0], x.shape[1]
    quantized = torch.empty((num_rows, num_features), device=x.device, dtype=fp8_dtype)
    row_scales = torch.empty((num_rows,), device=x.device, dtype=torch.float32)
    return quantized, row_scales
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
@torch.library.impl(_GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME, "CUDA")
def gather_scale_quant_dense_tokens_cuda(
    x,
    token_indices,
    expert_indices,
    scores,
    scale_ub=None,
    valid_token_count=None,
):
    """CUDA kernel for the op; delegates to the Python/Triton entry point."""
    return gather_scale_quant_dense_tokens(
        x=x,
        token_indices=token_indices,
        expert_indices=expert_indices,
        scores=scores,
        scale_ub=scale_ub,
        valid_token_count=valid_token_count,
    )
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
_SCATTER_ADD_DENSE_TOKENS_OP_NAME = "mslk::scatter_add_dense_tokens"
|
|
368
|
+
|
|
369
|
+
torch.library.define(
|
|
370
|
+
_SCATTER_ADD_DENSE_TOKENS_OP_NAME,
|
|
371
|
+
"(Tensor out_tokens, Tensor in_tokens, Tensor token_indices, Tensor? valid_token_count=None) -> None",
|
|
372
|
+
)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
@torch.library.impl(_SCATTER_ADD_DENSE_TOKENS_OP_NAME, "Meta")
def scatter_add_dense_tokens_meta(
    out_tokens,
    in_tokens,
    token_indices,
    valid_token_count=None,
):
    """Meta (shape-only) kernel: the op has no outputs, so nothing is allocated."""
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
@torch.library.impl(_SCATTER_ADD_DENSE_TOKENS_OP_NAME, "CUDA")
def scatter_add_dense_tokens_cuda(
    out_tokens,
    in_tokens,
    token_indices,
    valid_token_count=None,
):
    """CUDA kernel for the op; delegates to the Python/Triton entry point."""
    return scatter_add_dense_tokens(
        out_tokens=out_tokens,
        in_tokens=in_tokens,
        token_indices=token_indices,
        valid_token_count=valid_token_count,
    )
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
_SCATTER_ADD_PADDED_TOKENS_OP_NAME = "mslk::scatter_add_padded_tokens"
|
|
398
|
+
|
|
399
|
+
torch.library.define(
|
|
400
|
+
_SCATTER_ADD_PADDED_TOKENS_OP_NAME,
|
|
401
|
+
"(Tensor in_tokens, Tensor token_counts, Tensor token_indices, Tensor out_tokens) -> None",
|
|
402
|
+
)
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
@torch.library.impl(_SCATTER_ADD_PADDED_TOKENS_OP_NAME, "Meta")
def scatter_add_padded_tokens_meta(
    in_tokens,
    token_counts,
    token_indices,
    out_tokens,
):
    """Meta (shape-only) kernel: the op has no outputs, so nothing is allocated."""
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
@torch.library.impl(_SCATTER_ADD_PADDED_TOKENS_OP_NAME, "CUDA")
def scatter_add_padded_tokens_cuda(
    in_tokens,
    token_counts,
    token_indices,
    out_tokens,
):
    """CUDA kernel for the op; delegates to the Python/Triton entry point."""
    return scatter_add_padded_tokens(
        in_tokens=in_tokens,
        token_counts=token_counts,
        token_indices=token_indices,
        out_tokens=out_tokens,
    )
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
# Kernel Implementations
|
|
431
|
+
@triton.jit
def _mslk_gather_scale_dense_tokens(
    out,
    x,
    token_indices,
    expert_indices,
    scores,
    stride_t,
    stride_e,
    valid_token_count,
    D: tl.constexpr,
    BLOCK_D_OUTER: tl.constexpr,
    BLOCK_D_INNER: tl.constexpr,
):
    """
    out[i, :] = x[token_indices[i], :] * scores[token_indices[i], expert_indices[i]]

    Program (i, j) handles output row i = program_id(0) and the BLOCK_D_OUTER-wide
    feature slice starting at j * BLOCK_D_OUTER, in steps of BLOCK_D_INNER.
    x and out are addressed as contiguous rows of length D; scores is addressed
    through (stride_t, stride_e). Loads and stores are unmasked, so the caller must
    choose block sizes that tile D exactly.
    """
    output_token_index = tl.program_id(0)
    feature_offset = tl.program_id(1) * BLOCK_D_OUTER

    # Optional cutoff: rows at or beyond *valid_token_count are not written.
    if valid_token_count is not None:
        valid_token_count = tl.load(
            valid_token_count, None, eviction_policy="evict_last"
        )
        if output_token_index >= valid_token_count:
            return

    input_token_index = tl.load(
        token_indices + output_token_index, None, eviction_policy="evict_last"
    )
    input_expert_index = tl.load(
        expert_indices + output_token_index, None, eviction_policy="evict_last"
    )

    input_score = tl.load(
        scores + input_token_index * stride_t + input_expert_index * stride_e,
        None,
        eviction_policy="evict_last",
    ).to(tl.float32)

    for _ in range(0, BLOCK_D_OUTER // BLOCK_D_INNER):
        # Row offsets are widened to int64 so large (rows * D) do not overflow.
        input_token_value = tl.load(
            x
            + input_token_index.to(tl.int64) * D
            + feature_offset
            + tl.arange(0, BLOCK_D_INNER)[:],
            None,
        ).to(tl.float32)
        # Scale in float32; tl.store casts back to out's element type.
        output_token_value = input_token_value * input_score

        tl.store(
            out
            + output_token_index.to(tl.int64) * D
            + feature_offset
            + tl.arange(0, BLOCK_D_INNER)[:],
            output_token_value,
            None,
        )
        feature_offset += BLOCK_D_INNER
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
@triton.jit
def _mslk_scatter_add_dense_tokens(
    out_tokens,
    in_tokens,
    token_indices,
    valid_token_count,
    D: tl.constexpr,
    BLOCK_D_OUTER: tl.constexpr,
    BLOCK_D_INNER: tl.constexpr,
):
    """
    out_tokens[token_indices[i], :] += in_tokens[i, :]  (atomic)

    Program (i, j) handles input row i = program_id(0) and the BLOCK_D_OUTER-wide
    feature slice starting at j * BLOCK_D_OUTER, in steps of BLOCK_D_INNER.
    Both tensors are addressed as contiguous rows of length D. Accesses are
    unmasked, so the caller must choose block sizes that tile D exactly.
    """
    input_token_index = tl.program_id(0).to(tl.int64)
    feature_offset = tl.program_id(1) * BLOCK_D_OUTER + tl.arange(0, BLOCK_D_INNER)[:]

    # Optional cutoff: input rows at or beyond *valid_token_count are skipped.
    if valid_token_count is not None:
        valid_token_count = tl.load(
            valid_token_count, None, eviction_policy="evict_last"
        )
        if input_token_index >= valid_token_count:
            return

    output_token_index = tl.load(
        token_indices + input_token_index, None, eviction_policy="evict_last"
    ).to(tl.int64)

    for _ in range(0, BLOCK_D_OUTER // BLOCK_D_INNER):
        # Each input element is read exactly once, hence evict_first.
        input_token_value = tl.load(
            in_tokens + input_token_index * D + feature_offset,
            None,
            eviction_policy="evict_first",
        )

        # Atomic add because several input rows may target the same output row;
        # "relaxed" suffices since only the final sum matters. The add happens in
        # the tensors' own element type (no float32 upcast here).
        tl.atomic_add(
            out_tokens + output_token_index * D + feature_offset,
            input_token_value,
            None,
            sem="relaxed",
        )
        feature_offset += BLOCK_D_INNER
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_D": 256}),
        triton.Config({"BLOCK_D": 512}),
        triton.Config({"BLOCK_D": 1024}),
    ],
    key=["D"],
)
@triton.jit
def _mslk_gather_scale_fp8_rowwise_quant_dense_tokens(
    output_ptr,
    output_scale_ptr,
    input_ptr,
    token_indices_ptr,
    expert_indices_ptr,
    scores_ptr,
    scale_ub_ptr,
    stride_t,
    stride_e,
    valid_token_count,
    D: tl.constexpr,
    TL_FP8_DTYPE: tl.constexpr,
    MAX_FP8: tl.constexpr,
    EPS: tl.constexpr,
    CLAMP_MAX: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """
    Gather one input row, scale it by its routing score, and quantize it to FP8.

    One program per output row i = program_id(0). Two passes over the row:
      1. compute row_max = max |x[token_indices[i]] * score| over D;
      2. recompute the scaled row, multiply by MAX_FP8 / row_max, clamp to
         [-MAX_FP8, MAX_FP8], cast to TL_FP8_DTYPE, and store.
    The dequantization scale 1 / (MAX_FP8 / row_max) is stored at
    output_scale_ptr[i]. row_max is first clamped to [EPS, *scale_ub_ptr] when
    CLAMP_MAX is set, otherwise only floored at EPS (avoids dividing by zero).
    """
    tl.static_assert(D % BLOCK_D == 0, "D must be a multiple of BLOCK_D")

    output_token_index = tl.program_id(0)

    # Optional cutoff: rows at or beyond *valid_token_count are not written.
    if valid_token_count is not None:
        valid_token_count = tl.load(
            valid_token_count, None, eviction_policy="evict_last"
        )
        if output_token_index >= valid_token_count:
            return

    input_token_index = tl.load(
        token_indices_ptr + output_token_index, None, eviction_policy="evict_first"
    )
    input_expert_index = tl.load(
        expert_indices_ptr + output_token_index, None, eviction_policy="evict_first"
    )
    input_score = tl.load(
        scores_ptr + input_token_index * stride_t + input_expert_index * stride_e,
        None,
        eviction_policy="evict_first",
    ).to(tl.float32)

    # Pass 1: running max of |scaled value| across the row.
    row_max = 0.0
    in_2d_ptr = (
        input_ptr + input_token_index.to(tl.int64) * D + tl.arange(0, BLOCK_D)[:]
    )
    for _ in range(0, D, BLOCK_D):
        # evict_last: this row is read again in pass 2.
        input_token_value = tl.load(
            in_2d_ptr,
            None,
            eviction_policy="evict_last",
        ).to(tl.float32)
        output_token_value = input_token_value * input_score

        tile_max = tl.max(tl.abs(output_token_value))
        row_max = tl.maximum(tile_max, row_max)
        in_2d_ptr += BLOCK_D

    # Clamp max value appropriately.
    if CLAMP_MAX:
        ub = tl.load(scale_ub_ptr, eviction_policy="evict_last")
        row_max = tl.clamp(row_max, EPS, ub)
    else:
        row_max = tl.maximum(row_max, EPS)

    # Scale and quantize.
    output_scale = MAX_FP8 / row_max
    tl.store(output_scale_ptr + output_token_index, 1.0 / output_scale)

    # Pass 2: rewind to the start of the row.
    in_2d_ptr = (
        input_ptr + input_token_index.to(tl.int64) * D + tl.arange(0, BLOCK_D)[:]
    )
    out_2d_ptr = (
        output_ptr + output_token_index.to(tl.int64) * D + tl.arange(0, BLOCK_D)[:]
    )
    for _ in range(0, D, BLOCK_D):
        # Load from L2 (presumably still resident thanks to evict_last in pass 1);
        # this is the last use, hence evict_first.
        input_token_value = tl.load(
            in_2d_ptr,
            None,
            eviction_policy="evict_first",
        ).to(tl.float32)
        # Rematerialize the scaled value instead of keeping pass-1 results.
        output_token_value_fp8 = (input_token_value * input_score) * output_scale

        # Clamp A to fp8 range to make sure there's no overflow.
        # This is required for AMD. Nvidia's default saturation
        # handles it, but it's nice to have anyway.
        output_token_value_fp8 = tl.clamp(output_token_value_fp8, -MAX_FP8, MAX_FP8).to(
            TL_FP8_DTYPE
        )
        tl.store(
            out_2d_ptr,
            output_token_value_fp8,
            None,
            cache_modifier=".cg",
        )
        in_2d_ptr += BLOCK_D
        out_2d_ptr += BLOCK_D
|
|
636
|
+
|
|
637
|
+
|
|
638
|
+
# Autotune search space for _mslk_scatter_add_padded_tokens on NVIDIA GPUs:
# every combination of token split, feature block, pipeline depth and warp count.
_NV_CONFIGS = [
    triton.Config(
        {"SPLIT_T": splits, "BLOCK_D": d_block},
        num_stages=stages,
        num_warps=warps,
        num_ctas=ctas,
    )
    for splits in (1, 4, 8, 16)
    for d_block in (512, 1024)
    for stages in (1, 3)
    for warps in (8, 16)
    for ctas in (1,)
]
|
|
654
|
+
|
|
655
|
+
# Autotune search space for _mslk_scatter_add_padded_tokens on ROCm. Warp count
# and the AMD-specific waves_per_eu hint are varied together as fixed pairs.
_AMD_CONFIGS = [
    triton.Config(
        {"SPLIT_T": splits, "BLOCK_D": d_block, "waves_per_eu": waves},
        num_stages=stages,
        num_warps=warps,
    )
    for splits in (2, 8, 16, 32)
    for d_block in (512, 1024)
    for stages in (1, 3)
    for warps, waves in ((8, 2), (16, 4))
]
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
@triton.autotune(
    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
    # The kernel accumulates into out_tokens, so the autotuner must restore it
    # after each benchmarked config.
    restore_value=("out_tokens_ptr",),
    key=["EP", "E", "T_BUCKET", "D"],
)
@triton.jit
def _mslk_scatter_add_padded_tokens(
    in_tokens_ptr,
    token_counts_ptr,
    token_indices_ptr,
    out_tokens_ptr,
    EP: tl.constexpr,
    E: tl.constexpr,
    T_BUCKET,
    T_K,
    D: tl.constexpr,
    BLOCK_E: tl.constexpr,
    SPLIT_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """
    in_tokens: [EP, T_K, D]
    token_counts: [E]
    token_indices: [T_K]
    out_tokens: [T, D]

    Program (expert, t_tile) adds its share of expert's token_counts[expert]
    rows: with off = sum(token_counts[:expert]) and rank = expert // (E // EP),
    row in_tokens[rank, off + j] is atomically added to
    out_tokens[token_indices[off + j]]. T_BUCKET is not read here; it only
    participates in the autotune key.
    """
    expert = tl.program_id(0)
    t_tile = tl.program_id(1)

    tl.static_assert(D % BLOCK_D == 0)
    NUM_D_BLOCKS: tl.constexpr = D // BLOCK_D

    num_tokens = tl.load(token_counts_ptr + expert)
    if num_tokens == 0:
        return

    # Split this expert's rows into SPLIT_T contiguous chunks (rounded up).
    num_tokens_per_cta = tl.cdiv(num_tokens, SPLIT_T)
    start_token = t_tile * num_tokens_per_cta
    end_token = min(start_token + num_tokens_per_cta, num_tokens)

    tl.static_assert(E % EP == 0)
    EXPERT_PER_RANK: tl.constexpr = E // EP
    # Widen to int64 before forming rank * T_K * D: with int32 program ids that
    # product overflows once the padded buffer exceeds 2**31 elements.
    rank = (expert // EXPERT_PER_RANK).to(tl.int64)
    rank_offset = rank * T_K * D

    # Exclusive prefix sum of token_counts gives this expert's first row.
    offs_e = tl.arange(0, BLOCK_E)
    token_counts = tl.load(token_counts_ptr + offs_e, mask=(offs_e < E), other=0)
    input_local_offset = (
        tl.sum(tl.where(offs_e < expert, token_counts, 0)) + start_token
    ).to(tl.int64)

    d_ptr = tl.arange(0, BLOCK_D)
    for _t in range(start_token, end_token):
        output_local_offset = tl.load(token_indices_ptr + input_local_offset).to(
            tl.int64
        )
        output_global_offset = output_local_offset * D

        input_global_ptr = (
            in_tokens_ptr + rank_offset + input_local_offset * D + d_ptr
        )
        output_global_ptr = out_tokens_ptr + output_global_offset + d_ptr

        for _d in range(NUM_D_BLOCKS):
            vec = tl.load(input_global_ptr)
            # Several rows may map to the same output row, hence the atomic add.
            tl.atomic_add(output_global_ptr, vec, sem="relaxed")
            input_global_ptr += BLOCK_D
            output_global_ptr += BLOCK_D

        input_local_offset += 1
|