mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/moe/gather_scatter.py
@@ -0,0 +1,739 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+from mslk.utils.triton.fp8_utils import get_fp8_constants
+
+
+# Function APIs
+def gather_scale_dense_tokens(
+    x: torch.Tensor,
+    token_indices: torch.Tensor,
+    expert_indices: torch.Tensor,
+    scores: torch.Tensor,
+    valid_token_count: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    Gather and scale dense tokens along 1D indices.
+
+    For each output row i, token_indices[i] is the index of the source token in
+    the input sequence, expert_indices[i] is the index of the expert that the
+    token is assigned to, and scores[token_indices[i], expert_indices[i]] is the
+    token's routing score.
+
+    For each expert, the tokens assigned to that expert are gathered from the
+    input sequence and scaled element-wise by their scores.
+
+    valid_token_count is an optional tensor that can be used to filter out
+    trailing tokens: if provided, only the first valid_token_count output rows
+    are computed.
+
+    The function returns a tensor of shape (a, D), where a is the number of
+    gathered tokens and D is the input feature dimension.
+
+    Args:
+        x (torch.Tensor): input tensor of shape (T, D)
+        token_indices (torch.Tensor): token indices of shape (a,)
+        expert_indices (torch.Tensor): expert indices of shape (a,)
+        scores (torch.Tensor): scores of shape (T, E)
+        valid_token_count (torch.Tensor, optional): valid token count of shape (1,)
+
+    Returns:
+        torch.Tensor: output tensor of shape (a, D)
+    """
+    T, D = x.shape
+    E = scores.shape[1]
+    # a = K * T
+    a = token_indices.shape[0]
+
+    out = torch.empty((a, D), device=x.device, dtype=x.dtype)
+    if a == 0 or D == 0:
+        return out
+
+    assert x.is_contiguous()
+    assert token_indices.is_contiguous()
+    assert expert_indices.is_contiguous()
+
+    assert tuple(token_indices.shape) == (a,)
+    assert tuple(expert_indices.shape) == (a,)
+    assert tuple(scores.shape) == (T, E)
+
+    stride_t = scores.stride(0)
+    stride_e = scores.stride(1)
+
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    if a >= NUM_SMS:
+        BLOCK_D_OUTER = D
+        BLOCK_D_INNER = 1024
+        assert D % BLOCK_D_INNER == 0
+    else:
+        BLOCK_D_OUTER = 512
+        BLOCK_D_INNER = 256
+        assert D % BLOCK_D_OUTER == 0
+    grid = (a, D // BLOCK_D_OUTER)
+    _mslk_gather_scale_dense_tokens[grid](
+        out,
+        x,
+        token_indices,
+        expert_indices,
+        scores,
+        stride_t,
+        stride_e,
+        valid_token_count,
+        D,  # pyre-ignore
+        BLOCK_D_OUTER,  # pyre-ignore
+        BLOCK_D_INNER,  # pyre-ignore
+    )
+    return out
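
For reference, the gather path above matches this minimal eager-mode sketch (an illustration only; it ignores the valid_token_count masking and the Triton blocking):

    import torch

    def gather_scale_dense_tokens_ref(x, token_indices, expert_indices, scores):
        # Gather the selected source rows, then scale each row by its routing score.
        gathered = x[token_indices]                          # (a, D)
        row_scores = scores[token_indices, expert_indices]   # (a,)
        return gathered * row_scores[:, None]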
+
+
+def gather_scale_quant_dense_tokens(
+    x: torch.Tensor,
+    token_indices: torch.Tensor,
+    expert_indices: torch.Tensor,
+    scores: torch.Tensor,
+    scale_ub: Optional[torch.Tensor] = None,
+    valid_token_count: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Gather, scale, and quantize dense tokens along 1D indices.
+
+    For each output row i, token_indices[i] is the index of the source token in
+    the input sequence, expert_indices[i] is the index of the expert that the
+    token is assigned to, and scores[token_indices[i], expert_indices[i]] is the
+    token's routing score.
+
+    For each expert, the tokens assigned to that expert are gathered from the
+    input sequence, scaled element-wise by their scores, and quantized to FP8
+    with one scale per row.
+
+    valid_token_count is an optional tensor that can be used to filter out
+    trailing tokens: if provided, only the first valid_token_count output rows
+    are computed.
+
+    The function returns an FP8 tensor of shape (a, D) and its row-wise scales
+    of shape (a,), where a is the number of gathered tokens and D is the input
+    feature dimension.
+
+    Args:
+        x (torch.Tensor): input tensor of shape (T, D)
+        token_indices (torch.Tensor): token indices of shape (a,)
+        expert_indices (torch.Tensor): expert indices of shape (a,)
+        scores (torch.Tensor): scores of shape (T, E)
+        scale_ub (torch.Tensor, optional): scale upper bound of shape (1,)
+        valid_token_count (torch.Tensor, optional): valid token count of shape (1,)
+
+    Returns:
+        tuple[torch.Tensor, torch.Tensor]: FP8 output tensor of shape (a, D)
+        and row-wise scales of shape (a,)
+    """
+    T, D = x.shape
+    E = scores.shape[1]
+    # a = K * T
+    a = token_indices.shape[0]
+
+    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+
+    assert x.is_contiguous()
+    assert token_indices.is_contiguous()
+    assert expert_indices.is_contiguous()
+
+    assert tuple(token_indices.shape) == (a,)
+    assert tuple(expert_indices.shape) == (a,)
+    assert tuple(scores.shape) == (T, E)
+
+    stride_t = scores.stride(0)
+    stride_e = scores.stride(1)
+
+    out = torch.empty((a, D), device="cuda", dtype=pt_dtype)
+    out_scale = torch.empty((a,), device="cuda", dtype=torch.float32)
+
+    grid = (a,)
+    _mslk_gather_scale_fp8_rowwise_quant_dense_tokens[grid](
+        out,
+        out_scale,
+        x,
+        token_indices,
+        expert_indices,
+        scores,
+        scale_ub,
+        stride_t,
+        stride_e,
+        valid_token_count,
+        D,
+        TL_FP8_DTYPE=tl_dtype,
+        MAX_FP8=max_fp8,
+        EPS=eps,
+        CLAMP_MAX=scale_ub is not None,
+    )
+    return out, out_scale
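
The returned scales invert the quantization: row i of out times out_scale[i] approximates the scaled gather in higher precision. A hedged self-check sketch (the tolerances below are illustrative, not part of the library):

    out, out_scale = gather_scale_quant_dense_tokens(x, token_indices, expert_indices, scores)
    dequant = out.to(torch.float32) * out_scale[:, None]
    ref = x[token_indices].float() * scores[token_indices, expert_indices].float()[:, None]
    torch.testing.assert_close(dequant, ref, atol=1e-1, rtol=1e-1)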
+
+
+def scatter_add_dense_tokens(
+    out_tokens: torch.Tensor,  # [T, D]
+    in_tokens: torch.Tensor,  # [a, D]
+    token_indices: torch.Tensor,  # [a]
+    valid_token_count: Optional[torch.Tensor] = None,
+) -> None:
+    """
+    Scatter-add dense tokens along 1D indices.
+
+    Args:
+        out_tokens (torch.Tensor): output tensor of shape (T, D)
+        in_tokens (torch.Tensor): input tensor of shape (a, D)
+        token_indices (torch.Tensor): token indices of shape (a,)
+        valid_token_count (torch.Tensor, optional): valid token count of shape (1,)
+
+    Returns:
+        None
+    """
+
+    # Note: a plain string comparison such as `torch.version.cuda >= "12.4"`
+    # would misorder versions like "12.10", so compare parsed (major, minor).
+    assert torch.version.hip is not None or (
+        torch.version.cuda is not None
+        and tuple(int(v) for v in torch.version.cuda.split(".")[:2]) >= (12, 4)
+    ), "Requires CUDA version 12.4 or later on NVIDIA GPUs!"
+
+    assert in_tokens.is_contiguous()
+    assert token_indices.is_contiguous()
+    assert out_tokens.is_contiguous()
+
+    a, D = in_tokens.shape
+    if a == 0:
+        return
+    assert token_indices.shape == (a,)
+    assert out_tokens.ndim == 2 and out_tokens.shape[1] == D
+
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+    if a >= NUM_SMS:
+        BLOCK_D_OUTER = D
+        BLOCK_D_INNER = 1024
+    else:
+        BLOCK_D_OUTER = 512
+        BLOCK_D_INNER = 256
+    while D % BLOCK_D_OUTER != 0:
+        BLOCK_D_OUTER //= 2
+    while D % BLOCK_D_INNER != 0:
+        BLOCK_D_INNER //= 2
+
+    grid = (a, D // BLOCK_D_OUTER)
+    _mslk_scatter_add_dense_tokens[grid](
+        out_tokens,
+        in_tokens,
+        token_indices,
+        valid_token_count,
+        D,  # pyre-ignore
+        BLOCK_D_OUTER,  # pyre-ignore
+        BLOCK_D_INNER,  # pyre-ignore
+    )
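
Functionally this is an in-place index_add_; the Triton kernel exists to use relaxed per-block atomics and cache-eviction hints tuned for the MoE combine step. An eager-mode equivalent (ignoring valid_token_count):

    # Accumulate each input row into the output row picked by token_indices.
    out_tokens.index_add_(0, token_indices, in_tokens)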
+
+
+def scatter_add_padded_tokens(
+    in_tokens: torch.Tensor,  # [EP, T_K, D]
+    token_counts: torch.Tensor,  # [E]
+    token_indices: torch.Tensor,  # [T_K]
+    out_tokens: torch.Tensor,  # [T, D]
+) -> None:
+    """
+    Scatter-add valid tokens based on token-counts metadata.
+
+    Args:
+        in_tokens (torch.Tensor): input tensor of shape (EP, T_K, D)
+        token_counts (torch.Tensor): token counts of shape (E,)
+        token_indices (torch.Tensor): token indices of shape (T_K,)
+        out_tokens (torch.Tensor): output tensor of shape (T, D)
+
+    Returns:
+        None
+    """
+    assert torch.version.hip is not None or (
+        torch.version.cuda is not None
+        and tuple(int(v) for v in torch.version.cuda.split(".")[:2]) >= (12, 4)
+    ), "Requires CUDA version 12.4 or later on NVIDIA GPUs!"
+
+    assert in_tokens.is_contiguous()
+    assert token_counts.is_contiguous()
+    assert token_indices.is_contiguous()
+    assert out_tokens.is_contiguous()
+
+    EP, T_K, D = in_tokens.shape
+    E = token_counts.shape[0]
+    assert tuple(token_indices.shape) == (T_K,)
+    assert T_K % out_tokens.shape[0] == 0 and out_tokens.shape[1] == D
+
+    def grid(META):
+        return (
+            E,
+            META["SPLIT_T"],
+        )
+
+    T_BUCKET_CAP = 16384
+    T_BUCKET = min(triton.next_power_of_2(T_K), T_BUCKET_CAP)
+    BLOCK_E = max(triton.next_power_of_2(E), 8)
+    _mslk_scatter_add_padded_tokens[grid](
+        in_tokens,
+        token_counts,
+        token_indices,
+        out_tokens,
+        EP,
+        E,
+        T_BUCKET,
+        T_K,
+        D,
+        BLOCK_E,
+    )
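
A minimal eager-mode sketch of the same reduction, assuming in_tokens is laid out per expert in token_counts order (as the kernel further below expects):

    import torch

    def scatter_add_padded_tokens_ref(in_tokens, token_counts, token_indices, out_tokens):
        EP, T_K, D = in_tokens.shape
        E = token_counts.shape[0]
        experts_per_rank = E // EP
        # Exclusive prefix sum: expert e's tokens start at offsets[e].
        offsets = torch.cumsum(token_counts, 0) - token_counts
        for e in range(E):
            rank = e // experts_per_rank
            start = int(offsets[e])
            for t in range(start, start + int(token_counts[e])):
                out_tokens[token_indices[t]] += in_tokens[rank, t]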
+
+
+# Torch Custom Op Registrations
+_GATHER_SCALE_DENSE_TOKENS_OP_NAME = "mslk::gather_scale_dense_tokens"
+
+torch.library.define(
+    _GATHER_SCALE_DENSE_TOKENS_OP_NAME,
+    "(Tensor x, Tensor token_indices, Tensor expert_indices, Tensor scores, Tensor? valid_token_count=None) -> Tensor",
+)
+
+
+@torch.library.impl(_GATHER_SCALE_DENSE_TOKENS_OP_NAME, "Meta")
+def gather_scale_dense_tokens_meta(
+    x,
+    token_indices,
+    expert_indices,
+    scores,
+    valid_token_count=None,
+):
+    D = x.shape[1]
+    a = token_indices.shape[0]
+    return x.new_empty((a, D))
+
+
+@torch.library.impl(_GATHER_SCALE_DENSE_TOKENS_OP_NAME, "CUDA")
+def gather_scale_dense_tokens_cuda(
+    x,
+    token_indices,
+    expert_indices,
+    scores,
+    valid_token_count=None,
+):
+    return gather_scale_dense_tokens(
+        x,
+        token_indices,
+        expert_indices,
+        scores,
+        valid_token_count,
+    )
+
+
+_GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME = "mslk::gather_scale_quant_dense_tokens"
+
+torch.library.define(
+    _GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME,
+    "(Tensor x, Tensor token_indices, Tensor expert_indices, Tensor scores, Tensor? scale_ub=None, Tensor? valid_token_count=None) -> (Tensor, Tensor)",
+)
+
+
+@torch.library.impl(_GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME, "Meta")
+def gather_scale_quant_dense_tokens_meta(
+    x,
+    token_indices,
+    expert_indices,
+    scores,
+    scale_ub=None,
+    valid_token_count=None,
+):
+    D = x.shape[1]
+    a = token_indices.shape[0]
+    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+    return torch.empty((a, D), device=x.device, dtype=pt_dtype), torch.empty(
+        (a,), device=x.device, dtype=torch.float32
+    )
+
+
+@torch.library.impl(_GATHER_SCALE_QUANT_DENSE_TOKENS_OP_NAME, "CUDA")
+def gather_scale_quant_dense_tokens_cuda(
+    x,
+    token_indices,
+    expert_indices,
+    scores,
+    scale_ub=None,
+    valid_token_count=None,
+):
+    return gather_scale_quant_dense_tokens(
+        x,
+        token_indices,
+        expert_indices,
+        scores,
+        scale_ub,
+        valid_token_count,
+    )
+
+
+_SCATTER_ADD_DENSE_TOKENS_OP_NAME = "mslk::scatter_add_dense_tokens"
+
+torch.library.define(
+    _SCATTER_ADD_DENSE_TOKENS_OP_NAME,
+    "(Tensor out_tokens, Tensor in_tokens, Tensor token_indices, Tensor? valid_token_count=None) -> None",
+)
+
+
+@torch.library.impl(_SCATTER_ADD_DENSE_TOKENS_OP_NAME, "Meta")
+def scatter_add_dense_tokens_meta(
+    out_tokens,
+    in_tokens,
+    token_indices,
+    valid_token_count=None,
+):
+    return None
+
+
+@torch.library.impl(_SCATTER_ADD_DENSE_TOKENS_OP_NAME, "CUDA")
+def scatter_add_dense_tokens_cuda(
+    out_tokens,
+    in_tokens,
+    token_indices,
+    valid_token_count=None,
+):
+    return scatter_add_dense_tokens(
+        out_tokens, in_tokens, token_indices, valid_token_count
+    )
+
+
+_SCATTER_ADD_PADDED_TOKENS_OP_NAME = "mslk::scatter_add_padded_tokens"
+
+torch.library.define(
+    _SCATTER_ADD_PADDED_TOKENS_OP_NAME,
+    "(Tensor in_tokens, Tensor token_counts, Tensor token_indices, Tensor out_tokens) -> None",
+)
+
+
+@torch.library.impl(_SCATTER_ADD_PADDED_TOKENS_OP_NAME, "Meta")
+def scatter_add_padded_tokens_meta(
+    in_tokens,
+    token_counts,
+    token_indices,
+    out_tokens,
+):
+    return None
+
+
+@torch.library.impl(_SCATTER_ADD_PADDED_TOKENS_OP_NAME, "CUDA")
+def scatter_add_padded_tokens_cuda(
+    in_tokens,
+    token_counts,
+    token_indices,
+    out_tokens,
+):
+    return scatter_add_padded_tokens(
+        in_tokens,
+        token_counts,
+        token_indices,
+        out_tokens,
+    )
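
Once these registrations run (importing the module is enough), the same entry points are reachable through the PyTorch dispatcher, which keeps them visible to torch.compile and meta-tensor tracing:

    out = torch.ops.mslk.gather_scale_dense_tokens(x, token_indices, expert_indices, scores)
    torch.ops.mslk.scatter_add_dense_tokens(out_tokens, in_tokens, token_indices)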
+
+
+# Kernel Implementations
+@triton.jit
+def _mslk_gather_scale_dense_tokens(
+    out,
+    x,
+    token_indices,
+    expert_indices,
+    scores,
+    stride_t,
+    stride_e,
+    valid_token_count,
+    D: tl.constexpr,
+    BLOCK_D_OUTER: tl.constexpr,
+    BLOCK_D_INNER: tl.constexpr,
+):
+    output_token_index = tl.program_id(0)
+    feature_offset = tl.program_id(1) * BLOCK_D_OUTER
+
+    if valid_token_count is not None:
+        valid_token_count = tl.load(
+            valid_token_count, None, eviction_policy="evict_last"
+        )
+        if output_token_index >= valid_token_count:
+            return
+
+    input_token_index = tl.load(
+        token_indices + output_token_index, None, eviction_policy="evict_last"
+    )
+    input_expert_index = tl.load(
+        expert_indices + output_token_index, None, eviction_policy="evict_last"
+    )
+
+    input_score = tl.load(
+        scores + input_token_index * stride_t + input_expert_index * stride_e,
+        None,
+        eviction_policy="evict_last",
+    ).to(tl.float32)
+
+    for _ in range(0, BLOCK_D_OUTER // BLOCK_D_INNER):
+        input_token_value = tl.load(
+            x
+            + input_token_index.to(tl.int64) * D
+            + feature_offset
+            + tl.arange(0, BLOCK_D_INNER)[:],
+            None,
+        ).to(tl.float32)
+        output_token_value = input_token_value * input_score
+
+        tl.store(
+            out
+            + output_token_index.to(tl.int64) * D
+            + feature_offset
+            + tl.arange(0, BLOCK_D_INNER)[:],
+            output_token_value,
+            None,
+        )
+        feature_offset += BLOCK_D_INNER
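
Each program here owns one output token and a BLOCK_D_OUTER-wide slab of features, stepping BLOCK_D_INNER columns per iteration. To make the host-side grid choice from gather_scale_dense_tokens concrete (the sizes below are hypothetical):

    a, D, NUM_SMS = 64, 5120, 132
    if a >= NUM_SMS:
        grid = (a, 1)         # BLOCK_D_OUTER = D: one program per token
    else:
        grid = (a, D // 512)  # BLOCK_D_OUTER = 512: here (64, 10) programs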
+
+
+@triton.jit
+def _mslk_scatter_add_dense_tokens(
+    out_tokens,
+    in_tokens,
+    token_indices,
+    valid_token_count,
+    D: tl.constexpr,
+    BLOCK_D_OUTER: tl.constexpr,
+    BLOCK_D_INNER: tl.constexpr,
+):
+    input_token_index = tl.program_id(0).to(tl.int64)
+    feature_offset = tl.program_id(1) * BLOCK_D_OUTER + tl.arange(0, BLOCK_D_INNER)[:]
+
+    if valid_token_count is not None:
+        valid_token_count = tl.load(
+            valid_token_count, None, eviction_policy="evict_last"
+        )
+        if input_token_index >= valid_token_count:
+            return
+
+    output_token_index = tl.load(
+        token_indices + input_token_index, None, eviction_policy="evict_last"
+    ).to(tl.int64)
+
+    for _ in range(0, BLOCK_D_OUTER // BLOCK_D_INNER):
+        input_token_value = tl.load(
+            in_tokens + input_token_index * D + feature_offset,
+            None,
+            eviction_policy="evict_first",
+        )
+
+        tl.atomic_add(
+            out_tokens + output_token_index * D + feature_offset,
+            input_token_value,
+            None,
+            sem="relaxed",
+        )
+        feature_offset += BLOCK_D_INNER
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_D": 256}),
+        triton.Config({"BLOCK_D": 512}),
+        triton.Config({"BLOCK_D": 1024}),
+    ],
+    key=["D"],
+)
+@triton.jit
+def _mslk_gather_scale_fp8_rowwise_quant_dense_tokens(
+    output_ptr,
+    output_scale_ptr,
+    input_ptr,
+    token_indices_ptr,
+    expert_indices_ptr,
+    scores_ptr,
+    scale_ub_ptr,
+    stride_t,
+    stride_e,
+    valid_token_count,
+    D: tl.constexpr,
+    TL_FP8_DTYPE: tl.constexpr,
+    MAX_FP8: tl.constexpr,
+    EPS: tl.constexpr,
+    CLAMP_MAX: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    tl.static_assert(D % BLOCK_D == 0, "D must be a multiple of BLOCK_D")
+
+    output_token_index = tl.program_id(0)
+
+    if valid_token_count is not None:
+        valid_token_count = tl.load(
+            valid_token_count, None, eviction_policy="evict_last"
+        )
+        if output_token_index >= valid_token_count:
+            return
+
+    input_token_index = tl.load(
+        token_indices_ptr + output_token_index, None, eviction_policy="evict_first"
+    )
+    input_expert_index = tl.load(
+        expert_indices_ptr + output_token_index, None, eviction_policy="evict_first"
+    )
+    input_score = tl.load(
+        scores_ptr + input_token_index * stride_t + input_expert_index * stride_e,
+        None,
+        eviction_policy="evict_first",
+    ).to(tl.float32)
+
+    row_max = 0.0
+    in_2d_ptr = (
+        input_ptr + input_token_index.to(tl.int64) * D + tl.arange(0, BLOCK_D)[:]
+    )
+    for _ in range(0, D, BLOCK_D):
+        input_token_value = tl.load(
+            in_2d_ptr,
+            None,
+            eviction_policy="evict_last",
+        ).to(tl.float32)
+        output_token_value = input_token_value * input_score
+
+        tile_max = tl.max(tl.abs(output_token_value))
+        row_max = tl.maximum(tile_max, row_max)
+        in_2d_ptr += BLOCK_D
+
+    # Clamp max value appropriately.
+    if CLAMP_MAX:
+        ub = tl.load(scale_ub_ptr, eviction_policy="evict_last")
+        row_max = tl.clamp(row_max, EPS, ub)
+    else:
+        row_max = tl.maximum(row_max, EPS)
+
+    # Scale and quantize.
+    output_scale = MAX_FP8 / row_max
+    tl.store(output_scale_ptr + output_token_index, 1.0 / output_scale)
+
+    in_2d_ptr = (
+        input_ptr + input_token_index.to(tl.int64) * D + tl.arange(0, BLOCK_D)[:]
+    )
+    out_2d_ptr = (
+        output_ptr + output_token_index.to(tl.int64) * D + tl.arange(0, BLOCK_D)[:]
+    )
+    for _ in range(0, D, BLOCK_D):
+        # Load from L2
+        input_token_value = tl.load(
+            in_2d_ptr,
+            None,
+            eviction_policy="evict_first",
+        ).to(tl.float32)
+        # Rematerialize
+        output_token_value_fp8 = (input_token_value * input_score) * output_scale
+
+        # Clamp to the fp8 range to make sure there's no overflow.
+        # This is required for AMD. Nvidia's default saturation
+        # handles it, but it's nice to have anyway.
+        output_token_value_fp8 = tl.clamp(output_token_value_fp8, -MAX_FP8, MAX_FP8).to(
+            TL_FP8_DTYPE
+        )
+        tl.store(
+            out_2d_ptr,
+            output_token_value_fp8,
+            None,
+            cache_modifier=".cg",
+        )
+        in_2d_ptr += BLOCK_D
+        out_2d_ptr += BLOCK_D
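
The kernel above makes two passes over each row: the first finds the absolute row maximum, the second rescales by MAX_FP8 / row_max and casts to FP8. A tiny numeric sketch of the scale math, assuming an e4m3 maximum of 448.0 for MAX_FP8:

    row_max = 7.0                      # max |x * score| over the row
    output_scale = 448.0 / row_max     # multiplier applied before the fp8 cast
    stored_scale = 1.0 / output_scale  # written to out_scale: row_max / 448.0
    # Dequantization: fp8_value * stored_scale recovers the original magnitude.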
+
+
+_NV_CONFIGS = [
+    triton.Config(
+        {
+            "SPLIT_T": split_t,
+            "BLOCK_D": block_d,
+        },
+        num_stages=num_stages,
+        num_warps=num_warps,
+        num_ctas=num_ctas,
+    )
+    for split_t in [1, 4, 8, 16]
+    for block_d in [512, 1024]
+    for num_stages in [1, 3]
+    for num_warps in [8, 16]
+    for num_ctas in [1]
+]
+
+_AMD_CONFIGS = [
+    triton.Config(
+        {
+            "SPLIT_T": split_t,
+            "BLOCK_D": block_d,
+            "waves_per_eu": waves_per_eu,
+        },
+        num_stages=num_stages,
+        num_warps=num_warps,
+    )
+    for split_t in [2, 8, 16, 32]
+    for block_d in [512, 1024]
+    for num_stages in [1, 3]
+    for num_warps, waves_per_eu in [(8, 2), (16, 4)]
+]
+
+
+@triton.autotune(
+    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
+    restore_value=("out_tokens_ptr",),
+    key=["EP", "E", "T_BUCKET", "D"],
+)
+@triton.jit
+def _mslk_scatter_add_padded_tokens(
+    in_tokens_ptr,
+    token_counts_ptr,
+    token_indices_ptr,
+    out_tokens_ptr,
+    EP: tl.constexpr,
+    E: tl.constexpr,
+    T_BUCKET,
+    T_K,
+    D: tl.constexpr,
+    BLOCK_E: tl.constexpr,
+    SPLIT_T: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+):
+    """
+    in_tokens: [EP, T_K, D]
+    token_counts: [E]
+    out_tokens: [T, D]
+    """
+    expert = tl.program_id(0)
+    t_tile = tl.program_id(1)
+
+    tl.static_assert(D % BLOCK_D == 0)
+    NUM_D_BLOCKS: tl.constexpr = D // BLOCK_D
+
+    num_tokens = tl.load(token_counts_ptr + expert)
+    if num_tokens == 0:
+        return
+
+    num_tokens_per_cta = tl.cdiv(num_tokens, SPLIT_T)
+    start_token = t_tile * num_tokens_per_cta
+    end_token = min(start_token + num_tokens_per_cta, num_tokens)
+
+    tl.static_assert(E % EP == 0)
+    EXPERT_PER_RANK: tl.constexpr = E // EP
+    rank = expert // EXPERT_PER_RANK
+
+    offs_e = tl.arange(0, BLOCK_E)
+    token_counts = tl.load(token_counts_ptr + offs_e, mask=(offs_e < E), other=0)
+    input_local_offset = (
+        tl.sum(tl.where(offs_e < expert, token_counts, 0)) + start_token
+    ).to(tl.int64)
+
+    for _t in range(start_token, end_token):
+        output_local_offset = tl.load(token_indices_ptr + input_local_offset).to(
+            tl.int64
+        )
+        output_global_offset = output_local_offset * D
+
+        d_ptr = tl.arange(0, BLOCK_D)
+        input_global_ptr = (
+            in_tokens_ptr + rank * T_K * D + input_local_offset * D + d_ptr
+        )
+        output_global_ptr = out_tokens_ptr + output_global_offset + d_ptr
+
+        for _d in range(NUM_D_BLOCKS):
+            vec = tl.load(input_global_ptr)
+            tl.atomic_add(output_global_ptr, vec, sem="relaxed")
+            input_global_ptr += BLOCK_D
+            output_global_ptr += BLOCK_D
+
+        input_local_offset += 1
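
The input_local_offset computed above is an exclusive prefix sum over token_counts, evaluated in-register with a masked tl.sum. An eager-mode equivalent of that offset calculation:

    import torch

    token_counts = torch.tensor([3, 0, 5, 2])
    offsets = torch.cumsum(token_counts, 0) - token_counts
    # offsets == tensor([0, 3, 3, 8]): expert e's rows start at offsets[e]
    # within the flattened [T_K] token layout.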