mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/gemm/triton/grouped_gemm.py
@@ -0,0 +1,1132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import functools
+import warnings
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+from triton.runtime import driver  # @manual
+
+try:
+    # @manual=//triton:triton
+    from triton.tools.tensor_descriptor import TensorDescriptor
+
+    TMA_AVAILABLE = True
+except ImportError:
+    TMA_AVAILABLE = False
+    pass
+
+
+def _grouped_gemm_set_block_size_hook(nargs):
+    BLOCK_M = nargs["BLOCK_SIZE_M"]
+    BLOCK_N = nargs["BLOCK_SIZE_N"]
+    BLOCK_K = nargs["BLOCK_SIZE_K"]
+    if nargs["USE_TMA_LOAD"]:
+        nargs["a_desc_ptr"].block_shape = [BLOCK_M, BLOCK_K]
+        nargs["b_desc_ptr"].block_shape = [BLOCK_N, BLOCK_K]
+
+
+_NV_CONFIGS = [
+    triton.Config(
+        {
+            "BLOCK_SIZE_M": block_size_m,
+            "BLOCK_SIZE_N": block_size_n,
+            "BLOCK_SIZE_K": block_size_k,
+            "NUM_CONSUMER_GROUPS": 1,
+        },
+        num_stages=num_stages,
+        num_warps=num_warps,
+        num_ctas=num_ctas,
+        pre_hook=_grouped_gemm_set_block_size_hook,
+    )
+    for block_size_m in [64, 128]
+    for block_size_n in [64, 128, 256]
+    for block_size_k in [64, 128, 256]
+    for num_stages in [3, 4]
+    for num_warps in [4, 8]
+    for num_ctas in [1]
+]
+
+if TMA_AVAILABLE:
+    _NV_WS_CONFIGS = [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": block_size_m,
+                "BLOCK_SIZE_N": block_size_n,
+                "BLOCK_SIZE_K": block_size_k,
+                "NUM_CONSUMER_GROUPS": 1,
+                "USE_TMA_STORE": use_tma_store,
+            },
+            num_stages=num_stages,
+            num_warps=num_warps,
+            num_ctas=num_ctas,
+            pre_hook=_grouped_gemm_set_block_size_hook,
+        )
+        for block_size_m in [64, 128, 256]
+        for block_size_n in [64, 128, 256]
+        for block_size_k in [64, 128, 256]
+        for num_stages in [2, 3, 4]
+        for num_warps in [4, 8, 16]
+        for num_ctas in [1]
+        for use_tma_store in [False]
+    ]
+else:
+    _NV_WS_CONFIGS = _NV_CONFIGS
+
+
+_AMD_CONFIGS = [
+    triton.Config(
+        {
+            "BLOCK_SIZE_M": block_size_m,
+            "BLOCK_SIZE_N": block_size_n,
+            "BLOCK_SIZE_K": block_size_k,
+            "waves_per_eu": waves_per_cu,
+            "matrix_instr_nonkdim": matrix_instr_nonkdim,
+            "NUM_CONSUMER_GROUPS": 1,
+        },
+        num_stages=num_stages,
+        num_warps=num_warps,
+    )
+    for block_size_m in [32, 64, 128]
+    for block_size_n in [32, 64, 128, 256]
+    for block_size_k in [128, 256]
+    for num_stages in [1, 2]
+    for num_warps, waves_per_cu in [(4, 1), (8, 2), (16, 4)]
+    for matrix_instr_nonkdim in [16]
+]
+
+
+def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
+    device = torch.cuda.current_device()
+    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
+    if dtsize is None:
+        dtsize = named_args["c_ptr"].element_size()
+    if dtype is None:
+        dtype = named_args["c_ptr"].dtype
+
+    pruned_configs = []
+    for config in configs:
+        kw = config.kwargs
+        (
+            BLOCK_M,
+            BLOCK_N,
+            BLOCK_K,
+            num_stages,
+            use_tma_load_on_scales,
+        ) = (
+            kw["BLOCK_SIZE_M"],
+            kw["BLOCK_SIZE_N"],
+            kw["BLOCK_SIZE_K"],
+            config.num_stages,
+            kw.get("USE_TMA_LOAD_ON_SCALES", False),
+        )
+        G, M, N = (
+            named_args["G"],
+            named_args["M_BUCKET"],
+            named_args["N"],
+        )
+
+        # 1. make sure we have enough smem
+        max_shared_memory = driver.active.utils.get_device_properties(device)[
+            "max_shared_mem"
+        ]
+        if torch.version.hip:
+            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
+        else:
+            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
+        if required_shared_memory > max_shared_memory:
+            continue
+
+        M_PER_GROUP = M // G
+        MIN_M_TILES = 32 if torch.version.hip else 64
+        # 2. make sure we don't load M tiles that are too big
+        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
+            continue
+        # 3. make sure we don't load M tiles that are too small
+        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
+            continue
+
+        num_sm = driver.active.utils.get_device_properties(device)[
+            "multiprocessor_count"
+        ]
+        N_TILES = (N + BLOCK_N - 1) // BLOCK_N
+        MIN_N_TILES = 32 if torch.version.hip else 64
+        # 4. make sure we don't load N tiles that are too big
+        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
+            continue
+        # 5. make sure we don't load N tiles that are too small
+        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
+            continue
+        if dtsize >= 2:
+            if use_tma_load_on_scales:
+                continue
+        pruned_configs.append(config)
+
+    return pruned_configs
+
+
+def early_config_prune_ws(configs, named_args, dtsize=None, dtype=None, **kwargs):
+    device = torch.cuda.current_device()
+    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
+    if dtsize is None:
+        dtsize = named_args["c_ptr"].element_size()
+    if dtype is None:
+        dtype = named_args["c_ptr"].dtype
+
+    pruned_configs = []
+    for config in configs:
+        kw = config.kwargs
+        (
+            BLOCK_M,
+            BLOCK_N,
+            BLOCK_K,
+            num_stages,
+            use_tma_load_on_scales,
+        ) = (
+            kw["BLOCK_SIZE_M"],
+            kw["BLOCK_SIZE_N"],
+            kw["BLOCK_SIZE_K"],
+            config.num_stages,
+            kw.get("USE_TMA_LOAD_ON_SCALES", False),
+        )
+        G, M, N = (
+            named_args["G"],
+            named_args["M_BUCKET"],
+            named_args["N"],
+        )
+
+        # 1. make sure we have enough smem
+        max_shared_memory = driver.active.utils.get_device_properties(device)[
+            "max_shared_mem"
+        ]
+        if torch.version.hip:
+            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
+        else:
+            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
+        if required_shared_memory > max_shared_memory:
+            continue
+
+        M_PER_GROUP = M // G
+        MIN_M_TILES = 32 if torch.version.hip else 64
+        # 2. make sure we don't load M tiles that are too big
+        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
+            continue
+        # 3. make sure we don't load M tiles that are too small
+        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
+            continue
+
+        num_sm = driver.active.utils.get_device_properties(device)[
+            "multiprocessor_count"
+        ]
+        N_TILES = (N + BLOCK_N - 1) // BLOCK_N
+        MIN_N_TILES = 32 if torch.version.hip else 64
+        # 4. make sure we don't load N tiles that are too big
+        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
+            continue
+        # 5. make sure we don't load N tiles that are too small
+        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
+            continue
+
+        if dtsize >= 2:
+            if use_tma_load_on_scales:
+                continue
+        pruned_configs.append(config)
+
+    return pruned_configs
+
+
+@triton.autotune(
+    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
+    key=["G", "M_BUCKET", "N", "K"],
+    prune_configs_by={"early_config_prune": early_config_prune},
+    restore_value=["c_ptr"],  # restore for scatter_add fusion
+)
+@triton.jit
+def _mslk_grouped_gemm(
+    a_desc_ptr,
+    b_desc_ptr,
+    c_ptr,
+    scatter_add_indices,
+    m_sizes,
+    bias_ptr,
+    token_weights_ptr,
+    # problem sizes
+    G: tl.constexpr,
+    M_BUCKET,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+    FUSE_SCATTER_ADD: tl.constexpr,
+    USE_TMA_LOAD: tl.constexpr,
+    USE_TMA_STORE: tl.constexpr,
+    USE_FAST_ACCUM: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    HAS_TOKEN_WEIGHTS: tl.constexpr,
+    # tile sizes
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    NUM_CONSUMER_GROUPS: tl.constexpr,
+) -> None:
+    tl.static_assert(
+        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
+        "Cannot fuse scatter add with TMA store!",
+    )
+
+    tidx = tl.program_id(0)
+
+    M_end_offset = 0
+    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
+    iterated_tiles = 0
+    for g in tl.range(G):
+        # Move across groups
+        m_size = tl.load(m_sizes + g)
+
+        if m_size > 0:
+            M_start_offset = M_end_offset
+            M_end_offset = M_start_offset + m_size
+            N_start_offset = g.to(tl.int64) * N
+            n_size = N
+
+            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
+            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
+            num_tiles = num_m_tiles * num_n_tiles
+
+            if USE_TMA_STORE:
+                c_desc_ptr = tl.make_tensor_descriptor(
+                    c_ptr + M_start_offset * N,
+                    shape=[m_size, n_size],
+                    # pyre-ignore
+                    strides=[n_size, 1],
+                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                )
+
+            # Move across tiles
+            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
+                gidx = tidx - iterated_tiles
+                # Split M first and N second.
+                tile_m_idx = gidx % num_m_tiles
+                tile_n_idx = gidx // num_m_tiles
+
+                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+                if USE_TMA_LOAD:
+                    tl.static_assert(K % BLOCK_SIZE_K == 0)
+                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                    for k_offset in range(0, K, BLOCK_SIZE_K):
+                        a = a_desc_ptr.load([m_offset, k_offset])
+                        b = b_desc_ptr.load([n_offset, k_offset])
+                        if USE_FAST_ACCUM:
+                            accumulator = tl.dot(a, b.T, accumulator)
+                        else:
+                            accumulator += tl.dot(a, b.T)
+                else:
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    offs_k = tl.arange(0, BLOCK_SIZE_K)
+                    a_ptrs = (
+                        a_desc_ptr
+                        + (M_start_offset + offs_am[:, None]) * K
+                        + offs_k[None, :]
+                    )
+                    b_ptrs = (
+                        b_desc_ptr
+                        + (N_start_offset + offs_bn[:, None]) * K
+                        + offs_k[None, :]
+                    )
+                    for k_offset in range(0, K, BLOCK_SIZE_K):
+                        updated_k_offset = k_offset + offs_k
+                        updated_k_offset_mask = updated_k_offset[None, :] < K  # type: ignore[16]
+                        a = tl.load(
+                            a_ptrs,
+                            mask=((offs_am[:, None] < m_size) & updated_k_offset_mask),
+                            other=0.0,
+                        )
+                        b = tl.load(
+                            b_ptrs,
+                            mask=((offs_bn[:, None] < n_size) & updated_k_offset_mask),
+                            other=0.0,
+                        )
+                        accumulator += tl.dot(a, b.T)
+                        a_ptrs += BLOCK_SIZE_K
+                        b_ptrs += BLOCK_SIZE_K
+
+                if HAS_BIAS:
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bn
+                    bias = tl.load(bias_ptrs, mask=(offs_bn < n_size), other=0.0).to(
+                        accumulator.dtype
+                    )
+                    accumulator = accumulator + bias[None, :]
+
+                if HAS_TOKEN_WEIGHTS:
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    tw_ptrs = token_weights_ptr + M_start_offset + offs_am
+                    tw = tl.load(tw_ptrs, mask=(offs_am < m_size), other=1.0).to(
+                        accumulator.dtype
+                    )
+                    accumulator = accumulator * tw[:, None]
+
+                if USE_TMA_STORE:
+                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                    # pyre-ignore
+                    c_desc_ptr.store(
+                        [m_offset, n_offset], accumulator.to(c_ptr.dtype.element_ty)
+                    )
+                elif FUSE_SCATTER_ADD:
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    mask = offs_am < m_size
+                    m_offsets = tl.load(
+                        scatter_add_indices + M_start_offset + offs_am,
+                        mask=mask,
+                        cache_modifier=".ca",
+                    )
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    c = accumulator.to(c_ptr.dtype.element_ty)
+                    tl.atomic_add(
+                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                        c,
+                        mask=mask[:, None] and offs_bn[None, :] < n_size,
+                        sem="relaxed",
+                    )
+                else:
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    c = accumulator.to(c_ptr.dtype.element_ty)
+                    tl.store(
+                        c_ptr
+                        + (M_start_offset + offs_am[:, None]) * N
+                        + offs_bn[None, :],
+                        c,
+                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
+                    )
+                tidx += NUM_SMS
+
+            iterated_tiles += num_tiles
+
+
+# TODO(shikaili): Too much code duplication. Need to refactor.
+@triton.autotune(
+    configs=_NV_WS_CONFIGS,
+    key=["G", "M_BUCKET", "N", "K"],
+    prune_configs_by={"early_config_prune": early_config_prune_ws},
+    restore_value=["c_ptr"],  # restore for scatter_add fusion
+)
+@triton.jit
+def _mslk_grouped_gemm_ws(
+    a_desc_ptr,
+    b_desc_ptr,
+    c_ptr,
+    scatter_add_indices,
+    m_sizes,
+    bias_ptr,
+    token_weights_ptr,
+    # problem sizes
+    G: tl.constexpr,
+    M_BUCKET: tl.constexpr,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+    FUSE_SCATTER_ADD: tl.constexpr,
+    USE_TMA_LOAD: tl.constexpr,
+    USE_FAST_ACCUM: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    HAS_TOKEN_WEIGHTS: tl.constexpr,
+    # tile sizes
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    NUM_CONSUMER_GROUPS: tl.constexpr,
+    USE_TMA_STORE: tl.constexpr,
+) -> None:
+    tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialization!")
+    tl.static_assert(
+        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
+        "Cannot fuse scatter add with TMA store!",
+    )
+
+    tidx = tl.program_id(0)
+
+    M_end_offset = 0
+    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
+    iterated_tiles = 0
+    for g in tl.range(G):
+        # Move across groups
+        m_size = tl.load(m_sizes + g, cache_modifier=".ca")
+
+        if m_size > 0:
+            M_start_offset = M_end_offset
+            M_end_offset = M_start_offset + m_size
+            N_start_offset = g.to(tl.int64) * N
+
+            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
+            tl.static_assert(N % BLOCK_SIZE_N == 0, f"{N=} {BLOCK_SIZE_N=}")
+            NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
+            num_tiles = num_m_tiles * NUM_N_TILES
+
+            if USE_TMA_STORE:
+                c_desc_ptr = tl.make_tensor_descriptor(
+                    c_ptr + M_start_offset * N,
+                    shape=[m_size, N],
+                    # pyre-ignore
+                    strides=[N, 1],
+                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                )
+
+            # Move across tiles
+            next_iterated_tiles = iterated_tiles + num_tiles
+            if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
+                for i in range(tidx, next_iterated_tiles, NUM_SMS):
+                    gidx = i - iterated_tiles
+                    # Split M first and N second.
+                    tile_m_idx = gidx % num_m_tiles
+                    tile_n_idx = gidx // num_m_tiles
+
+                    accumulator = tl.zeros(
+                        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
+                    )
+                    tl.static_assert(K % BLOCK_SIZE_K == 0)
+                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                    for k_offset in range(0, K, BLOCK_SIZE_K):
+                        a = a_desc_ptr.load([m_offset, k_offset])
+                        b = b_desc_ptr.load([n_offset, k_offset])
+                        if USE_FAST_ACCUM:
+                            accumulator = tl.dot(a, b.T, accumulator)
+                        else:
+                            accumulator += tl.dot(a, b.T)
+
+                    if HAS_BIAS:
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bn
+                        bias = tl.load(bias_ptrs).to(accumulator.dtype)
+                        accumulator = accumulator + bias[None, :]
+
+                    if HAS_TOKEN_WEIGHTS:
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        tw_ptrs = token_weights_ptr + M_start_offset + offs_am
+                        tw = tl.load(tw_ptrs, mask=(offs_am < m_size), other=1.0).to(
+                            accumulator.dtype
+                        )
+                        accumulator = accumulator * tw[:, None]
+
+                    if USE_TMA_STORE:
+                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                        # pyre-ignore
+                        c_desc_ptr.store(
+                            [m_offset, n_offset],
+                            accumulator.to(c_ptr.dtype.element_ty),
+                        )
+                    elif FUSE_SCATTER_ADD:
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        mask = offs_am < m_size
+                        m_offsets = tl.load(
+                            scatter_add_indices + M_start_offset + offs_am,
+                            mask=mask,
+                            cache_modifier=".ca",
+                        )
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        c = accumulator.to(c_ptr.dtype.element_ty)
+                        tl.atomic_add(
+                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                            c,
+                            mask=mask[:, None],
+                            sem="relaxed",
+                        )
+                    else:
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        c = accumulator.to(c_ptr.dtype.element_ty)
+                        tl.store(
+                            c_ptr
+                            + (M_start_offset + offs_am[:, None]) * N
+                            + offs_bn[None, :],
+                            c,
+                            mask=offs_am[:, None] < m_size,
+                            cache_modifier=".cs",
+                        )
+                    tidx += NUM_SMS
+
+            iterated_tiles += num_tiles
+
+
+TT_FP8_DTYPE = tl.float8e4b8 if torch.version.hip else tl.float8e4nv
+
+
+# TODO(shikaili): clean up redundant 'b_scale_desc_ptr' argument.
+@triton.autotune(
+    configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
+    key=["G", "M_BUCKET", "N", "K"],
+    prune_configs_by={
+        "early_config_prune": functools.partial(
+            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
+        )
+    },
+    restore_value=["c_ptr"],  # restore for scatter_add fusion
+)
+@triton.jit
+def _mslk_grouped_gemm_fp8_rowwise(
+    a_desc_ptr,
+    a_scale_ptr,
+    b_desc_ptr,
+    b_scale_ptr,
+    b_scale_desc_ptr,
+    c_ptr,
+    scatter_add_indices,
+    m_sizes,
+    # problem sizes
+    G: tl.constexpr,
+    M_BUCKET,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+    FUSE_SCATTER_ADD: tl.constexpr,
+    USE_TMA_LOAD: tl.constexpr,
+    USE_TMA_STORE: tl.constexpr,
+    USE_FAST_ACCUM: tl.constexpr,
+    # tile sizes
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    NUM_CONSUMER_GROUPS: tl.constexpr,
+) -> None:
+    tl.static_assert(
+        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
+        "Cannot fuse scatter add with TMA store!",
+    )
+
+    tidx = tl.program_id(0)
+
+    M_end_offset = 0
+    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
+    iterated_tiles = 0
+    for g in tl.range(G):
+        # Move across groups
+        m_size = tl.load(m_sizes + g)
+
+        if m_size > 0:
+            M_start_offset = M_end_offset
+            M_end_offset = M_start_offset + m_size
+            N_start_offset = g.to(tl.int64) * N
+            n_size = N
+
+            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
+            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
+            num_tiles = num_m_tiles * num_n_tiles
+
+            if USE_TMA_STORE:
+                c_desc_ptr = tl.make_tensor_descriptor(
+                    c_ptr + M_start_offset * N,
+                    shape=[m_size, n_size],
+                    # pyre-ignore
+                    strides=[n_size, 1],
+                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                )
+
+            # Move across tiles
+            while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
+                gidx = tidx - iterated_tiles
+                # Split M first and N second.
+                tile_m_idx = gidx % num_m_tiles
+                tile_n_idx = gidx // num_m_tiles
+
+                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+                tl.static_assert(K % BLOCK_SIZE_K == 0)
+                if USE_TMA_LOAD:
+                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                    for k_offset in range(0, K, BLOCK_SIZE_K):
+                        a = a_desc_ptr.load([m_offset, k_offset])
+                        b = b_desc_ptr.load([n_offset, k_offset])
+                        if USE_FAST_ACCUM:
+                            accumulator = tl.dot(a, b.T, accumulator)
+                        else:
+                            accumulator += tl.dot(a, b.T)
+                else:
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    offs_k = tl.arange(0, BLOCK_SIZE_K)
+                    a_ptrs = (
+                        a_desc_ptr
+                        + (M_start_offset + offs_am[:, None]) * K
+                        + offs_k[None, :]
+                    )
+                    b_ptrs = (
+                        b_desc_ptr
+                        + (N_start_offset + offs_bn[:, None]) * K
+                        + offs_k[None, :]
+                    )
+                    for _ in range(0, K, BLOCK_SIZE_K):
+                        a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size)
+                        b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size)
+                        accumulator += tl.dot(a, b.T)
+                        a_ptrs += BLOCK_SIZE_K
+                        b_ptrs += BLOCK_SIZE_K
+
+                offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                a_scale = tl.load(
+                    a_scale_ptr + M_start_offset + offs_am[:, None],
+                    mask=offs_am[:, None] < m_size,
+                )
+                b_scale = tl.load(
+                    b_scale_ptr + N_start_offset + offs_bn[None, :],
+                    mask=offs_bn[None, :] < n_size,
+                )
+                c = accumulator.to(tl.float32) * a_scale * b_scale
+
+                if USE_TMA_STORE:
+                    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                    # pyre-ignore
+                    c_desc_ptr.store([m_offset, n_offset], c.to(c_ptr.dtype.element_ty))
+                elif FUSE_SCATTER_ADD:
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    mask = offs_am < m_size
+                    m_offsets = tl.load(
+                        scatter_add_indices + M_start_offset + offs_am,
+                        mask=mask,
+                        cache_modifier=".ca",
+                    )
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    tl.atomic_add(
+                        c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                        c.to(c_ptr.dtype.element_ty),
+                        mask=mask[:, None] and offs_bn[None, :] < n_size,
+                        sem="relaxed",
+                    )
+                else:
+                    tl.store(
+                        c_ptr
+                        + (M_start_offset + offs_am[:, None]) * N
+                        + offs_bn[None, :],
+                        c,
+                        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
+                    )
+                tidx += NUM_SMS
+
+            iterated_tiles += num_tiles
+
+
+# TODO(shikaili): Too much code duplication. Need to refactor.
+@triton.autotune(
+    configs=_NV_WS_CONFIGS,
+    key=["G", "M_BUCKET", "N", "K"],
+    prune_configs_by={
+        "early_config_prune": functools.partial(
+            early_config_prune_ws, dtype=TT_FP8_DTYPE, dtsize=1
+        )
+    },
+    restore_value=["c_ptr"],  # restore for scatter_add fusion
+)
+@triton.jit
+def _mslk_grouped_gemm_fp8_rowwise_ws(
+    a_desc_ptr,
+    a_scale_ptr,
+    b_desc_ptr,
+    b_scale_ptr,
+    c_ptr,
+    scatter_add_indices,
+    m_sizes,
+    # problem sizes
+    G: tl.constexpr,
+    M_BUCKET: tl.constexpr,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    NUM_SMS: tl.constexpr,
+    FUSE_SCATTER_ADD: tl.constexpr,
+    USE_TMA_LOAD: tl.constexpr,
+    USE_FAST_ACCUM: tl.constexpr,
+    # tile sizes
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    NUM_CONSUMER_GROUPS: tl.constexpr,
+    USE_TMA_STORE: tl.constexpr,
+) -> None:
+    tl.static_assert(USE_TMA_LOAD, "Always use TMA load with warp specialization!")
+    tl.static_assert(
+        not (FUSE_SCATTER_ADD and USE_TMA_STORE),
+        "Cannot fuse scatter add with TMA store!",
+    )
+
+    tidx = tl.program_id(0)
+
+    M_end_offset = 0
+    M_end_offset = M_end_offset.to(tl.int64)  # pyre-ignore
+    iterated_tiles = 0
+    for g in tl.range(G):
+        # Move across groups
+        m_size = tl.load(m_sizes + g, cache_modifier=".ca")
+
+        if m_size > 0:
+            M_start_offset = M_end_offset
+            M_end_offset = M_start_offset + m_size
+            N_start_offset = g.to(tl.int64) * N
+
+            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
+            tl.static_assert(N % BLOCK_SIZE_N == 0)
+            NUM_N_TILES: tl.constexpr = N // BLOCK_SIZE_N
+            num_tiles = num_m_tiles * NUM_N_TILES
+
+            if USE_TMA_STORE:
+                c_desc_ptr = tl.make_tensor_descriptor(
+                    c_ptr + M_start_offset * N,
+                    shape=[m_size, N],
+                    # pyre-ignore
+                    strides=[N, 1],
+                    block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+                )
+
+            # Move across tiles
+            next_iterated_tiles = iterated_tiles + num_tiles
+            if (tidx >= iterated_tiles) and (tidx < next_iterated_tiles):
+                for i in range(tidx, next_iterated_tiles, NUM_SMS):
+                    gidx = i - iterated_tiles
+                    # Split M first and N second.
+                    tile_m_idx = gidx % num_m_tiles
+                    tile_n_idx = gidx // num_m_tiles
+
+                    accumulator = tl.zeros(
+                        (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32
+                    )
+                    tl.static_assert(K % BLOCK_SIZE_K == 0)
+
+                    m_offset = (M_start_offset + tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                    n_offset = (N_start_offset + tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                    for k_offset in range(0, K, BLOCK_SIZE_K):
+                        a = a_desc_ptr.load([m_offset, k_offset])
+                        b = b_desc_ptr.load([n_offset, k_offset])
+                        if USE_FAST_ACCUM:
+                            accumulator = tl.dot(a, b.T, accumulator)
+                        else:
+                            accumulator += tl.dot(a, b.T)
+
+                    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                    a_scale = tl.load(
+                        a_scale_ptr + M_start_offset + offs_am[:, None],
+                        mask=offs_am[:, None] < m_size,
+                        cache_modifier=".ca",
+                    )
+                    b_scale = tl.load(
+                        b_scale_ptr + N_start_offset + offs_bn[None, :],
+                        cache_modifier=".ca",
+                    )
+                    c = accumulator.to(tl.float32) * a_scale * b_scale
+
+                    if USE_TMA_STORE:
+                        m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+                        n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+                        # pyre-ignore
+                        c_desc_ptr.store(
+                            [m_offset, n_offset], c.to(c_ptr.dtype.element_ty)
+                        )
+                    elif FUSE_SCATTER_ADD:
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        mask = offs_am < m_size
+                        m_offsets = tl.load(
+                            scatter_add_indices + M_start_offset + offs_am,
+                            mask=mask,
+                            cache_modifier=".ca",
+                        )
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        tl.atomic_add(
+                            c_ptr + m_offsets[:, None] * N + offs_bn[None, :],
+                            c,
+                            mask=mask[:, None],
+                            sem="relaxed",
+                        )
+                    else:
+                        offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+                        offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+                        tl.store(
+                            c_ptr
+                            + (M_start_offset + offs_am[:, None]) * N
+                            + offs_bn[None, :],
+                            c,
+                            mask=offs_am[:, None] < m_size,
+                            cache_modifier=".cs",
+                        )
+                    tidx += NUM_SMS
+
+            iterated_tiles += num_tiles
+
+
+warnings.simplefilter("once")
+
+
+def _grouped_gemm(
+    *,
+    x: torch.Tensor,
+    w: torch.Tensor,
+    m_sizes: torch.Tensor,
+    x_scale: Optional[torch.Tensor],
+    w_scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    token_weights: Optional[torch.Tensor],
+    use_fast_accum: bool,
+    use_warp_specialization: bool,
+    output_tensor: Optional[torch.Tensor],
+    scatter_add_indices: Optional[torch.Tensor],
+) -> torch.Tensor:
+    USE_TMA_LOAD = not torch.version.hip and TMA_AVAILABLE
+    USE_TMA_STORE = False
+
+    # TODO(shikaili): Check the readiness of WS on ROCm side in Meta's Triton.
+    if use_warp_specialization and torch.version.hip:
+        warnings.warn(
+            "Warp specialization is disabled as it is not supported on ROCm.",
+            stacklevel=2,
+        )
+        use_warp_specialization = False
+
+    if use_warp_specialization:
+        assert TMA_AVAILABLE, "TMA is not available"
+        USE_TMA_STORE = True  # Tuning decision
+
+    G = m_sizes.shape[0]
+
+    assert x.is_contiguous()
+    assert w.is_contiguous()
+    assert m_sizes.is_contiguous()
+
+    M, K = x.shape
+    N = w.shape[0] // G
+    assert K == w.shape[1]
+
+    if K % 8 != 0 or N % 8 != 0:
+        use_warp_specialization = False
+        USE_TMA_LOAD = False
+        USE_TMA_STORE = False
+        warnings.warn(
+            f"TMA load and warp specialization are disabled since K or N is not a multiple of 8: {K=}, {N=}.",
+            stacklevel=2,
+        )
+        assert x_scale is None, (
+            f"Quantisation is not supported yet when K or N is not a multiple of 8: {K=}, {N=}."
+        )
+
+        assert output_tensor is None, (
+            f"Fused scatter add has large rounding error when K or N is not a multiple of 8: {K=}, {N=}."
+        )
+
+    HAS_BIAS = bias is not None
+    if HAS_BIAS:
+        assert bias is not None  # for type checker
+        assert bias.is_contiguous(), "Bias must be contiguous"
+        assert len(bias.shape) == 2, f"Bias must be 2D, got shape {bias.shape}"
+        assert bias.shape[0] == G, f"Bias dim 0 must match G={G}, got {bias.shape[0]}"
+        assert bias.shape[1] == N, f"Bias dim 1 must match N={N}, got {bias.shape[1]}"
+
+    HAS_TOKEN_WEIGHTS = token_weights is not None
+    if HAS_TOKEN_WEIGHTS:
+        assert token_weights is not None  # for type checker
+        assert token_weights.is_contiguous(), "token_weights must be contiguous"
+        assert len(token_weights.shape) == 1, (
+            f"token_weights must be 1D, got shape {token_weights.shape}"
+        )
+        assert token_weights.shape[0] == M, (
+            f"token_weights dim 0 must match M={M}, got {token_weights.shape[0]}"
+        )
+
+    if output_tensor is None:
+        FUSE_SCATTER_ADD = False
+        assert scatter_add_indices is None
+        y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
+    else:
+        FUSE_SCATTER_ADD = True
+        assert scatter_add_indices is not None
+        assert scatter_add_indices.is_contiguous()
+        assert scatter_add_indices.shape == (M,)
+        y = output_tensor
+    if M == 0 or N == 0:
+        return y
+
+    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
+
+    # A dummy block value that will be overwritten in the pre_hook when we have the real block size
+    dummy_block = [1, 1]
+
+    if USE_TMA_LOAD:
+        # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional
+        # argument, expected `List[int]` but got `Size`
+        desc_x = TensorDescriptor(x, x.shape, x.stride(), dummy_block)
+        # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional
+        # argument, expected `List[int]` but got `Size`
+        desc_w = TensorDescriptor(w, w.shape, w.stride(), dummy_block)
+    else:
+        desc_x = x
+        desc_w = w
+
+    if USE_TMA_STORE:
+
+        def alloc_fn(size: int, alignment: int, stream: Optional[int]):
+            return torch.empty(size, device="cuda", dtype=torch.int8)
+
+        triton.set_allocator(alloc_fn)
+
+    def grid(META):
+        return (NUM_SMS,)
+
+    M_BUCKET_CAP = 16384
+    M_BUCKET = min(triton.next_power_of_2(M), M_BUCKET_CAP)
+    if x_scale is not None and w_scale is not None:
+        assert x_scale.is_contiguous()
+        assert w_scale.is_contiguous()
+        fn = (
+            _mslk_grouped_gemm_fp8_rowwise_ws
+            if use_warp_specialization
+            else _mslk_grouped_gemm_fp8_rowwise
+        )
+        if use_warp_specialization:
+            args = (
+                desc_x,
+                x_scale,
+                desc_w,
+                w_scale,
+                y,
+                scatter_add_indices,
+                m_sizes,
+                G,
+                M_BUCKET,
+                N,
+                K,
+                NUM_SMS,
+                FUSE_SCATTER_ADD,
+                USE_TMA_LOAD,
+                use_fast_accum,
+            )
+        else:
+            args = (
+                desc_x,
+                x_scale,
+                desc_w,
+                w_scale,
+                w_scale,  # b_scale_desc_ptr (unused, just passed for API compatibility)
+                y,
+                scatter_add_indices,
+                m_sizes,
+                G,
+                M_BUCKET,
+                N,
+                K,
+                NUM_SMS,
+                FUSE_SCATTER_ADD,
+                USE_TMA_LOAD,
+                USE_TMA_STORE,
+                use_fast_accum,
+            )
+        fn[grid](*args)
+    else:
+        assert x_scale is None
+        assert w_scale is None
+        fn = _mslk_grouped_gemm_ws if use_warp_specialization else _mslk_grouped_gemm
+        args = (
+            desc_x,
+            desc_w,
+            y,
+            scatter_add_indices,
+            m_sizes,
+            bias if HAS_BIAS else None,
+            token_weights if HAS_TOKEN_WEIGHTS else None,
+            G,
+            M_BUCKET,
+            N,
+            K,
+            NUM_SMS,
+            FUSE_SCATTER_ADD,
+            USE_TMA_LOAD,
+        )
+        if use_warp_specialization:
+            args += (use_fast_accum, HAS_BIAS, HAS_TOKEN_WEIGHTS)
+        else:
+            args += (USE_TMA_STORE, use_fast_accum, HAS_BIAS, HAS_TOKEN_WEIGHTS)
+        fn[grid](*args)
+
+    return y
+
+
+def grouped_gemm(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    m_sizes: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    token_weights: Optional[torch.Tensor] = None,
+    use_fast_accum: bool = True,
+    *,
+    _use_warp_specialization: bool = True,
+    _output_tensor: Optional[torch.Tensor] = None,
+    _scatter_add_indices: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    Grouped GEMM with optional bias addition and per-token weight scaling.
+
+    Performs: output = (x @ w.T + bias) * token_weights
+    where operations are grouped by experts.
+
+    Args:
+        x: Input tensor [M, K] where M is the total number of tokens across all experts
+        w: Weight tensor [G * N, K] where G is the number of experts
+        m_sizes: Tensor [G] indicating the number of tokens per expert
+        bias: Optional bias tensor [G, N], one bias vector per expert
+        token_weights: Optional per-token scaling weights [M] (e.g., router weights)
+        use_fast_accum: Enable fast accumulation for better performance
+        _use_warp_specialization: Flag for warp specialization
+        _output_tensor: Optional pre-allocated output tensor for scatter-add
+        _scatter_add_indices: Optional indices for the scatter-add operation
+
+    Returns:
+        Output tensor [M, N]
+    """
+    return _grouped_gemm(
+        x=x,
+        w=w,
+        m_sizes=m_sizes,
+        x_scale=None,
+        w_scale=None,
+        bias=bias,
+        token_weights=token_weights,
+        use_fast_accum=use_fast_accum,
+        use_warp_specialization=_use_warp_specialization,
+        output_tensor=_output_tensor,
+        scatter_add_indices=_scatter_add_indices,
+    )
+
+
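The contract documented in the `grouped_gemm` docstring above amounts to a per-expert loop over row blocks of `x`. The eager PyTorch sketch below illustrates that documented contract only; the helper name `grouped_gemm_reference` is ours and is not part of the package. It can serve as a correctness oracle when experimenting with the Triton kernels:

import torch

def grouped_gemm_reference(x, w, m_sizes, bias=None, token_weights=None):
    # Eager-mode illustration of the documented semantics:
    # output = (x @ w.T + bias) * token_weights, computed independently per expert.
    # x: [M, K], w: [G * N, K], m_sizes: [G], bias: [G, N], token_weights: [M].
    G = m_sizes.shape[0]
    M = x.shape[0]
    N = w.shape[0] // G
    out = torch.zeros((M, N), device=x.device, dtype=torch.bfloat16)
    start = 0
    for g in range(G):
        m = int(m_sizes[g])
        if m == 0:
            continue
        x_g = x[start : start + m]          # rows of x routed to expert g
        w_g = w[g * N : (g + 1) * N]        # [N, K] weight slice for expert g
        acc = x_g.float() @ w_g.float().T   # accumulate in FP32, as the kernels do
        if bias is not None:
            acc = acc + bias[g].float()
        if token_weights is not None:
            acc = acc * token_weights[start : start + m].float()[:, None]
        out[start : start + m] = acc.to(out.dtype)
        start += m
    return out

In a MoE setting, `m_sizes` would typically hold per-expert token counts for an input that has already been grouped by expert id; producing that grouping is outside the scope of this file.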
+def grouped_gemm_fp8_rowwise(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    m_sizes: torch.Tensor,
+    x_scale: torch.Tensor,
+    w_scale: torch.Tensor,
+    use_fast_accum: bool = True,
+    *,
+    _use_warp_specialization: bool = True,
+    _output_tensor: Optional[torch.Tensor] = None,
+    _scatter_add_indices: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return _grouped_gemm(
+        x=x,
+        w=w,
+        m_sizes=m_sizes,
+        x_scale=x_scale,
+        w_scale=w_scale,
+        bias=None,
+        token_weights=None,
+        use_fast_accum=use_fast_accum,
+        use_warp_specialization=_use_warp_specialization,
+        output_tensor=_output_tensor,
+        scatter_add_indices=_scatter_add_indices,
+    )
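`grouped_gemm_fp8_rowwise` follows the same grouping, but takes FP8 operands with row-wise scales: the kernels above index `x_scale` per row of `x` (shape `[M]`) and `w_scale` per row of `w` (shape `[G * N]`), and apply `a_scale * b_scale` to the FP32 accumulator before the store. Below is a dequantize-and-matmul sketch of that contract, assuming `float8_e4m3fn` inputs and the scale shapes just described; the helper name is ours, not part of the package:

import torch

def grouped_gemm_fp8_rowwise_reference(x_fp8, w_fp8, m_sizes, x_scale, w_scale):
    # x_fp8: [M, K] float8, x_scale: [M]; w_fp8: [G * N, K] float8, w_scale: [G * N].
    G = m_sizes.shape[0]
    M = x_fp8.shape[0]
    N = w_fp8.shape[0] // G
    out = torch.zeros((M, N), device=x_fp8.device, dtype=torch.bfloat16)
    start = 0
    for g in range(G):
        m = int(m_sizes[g])
        if m == 0:
            continue
        x_g = x_fp8[start : start + m].float()   # upcast stands in for the FP8 tensor-core path
        w_g = w_fp8[g * N : (g + 1) * N].float()
        acc = x_g @ w_g.T                        # FP32 accumulator, as in the kernels
        # Row-wise scales: one per token row of x, one per output column (row of w).
        acc = acc * x_scale[start : start + m, None] * w_scale[g * N : (g + 1) * N][None, :]
        out[start : start + m] = acc.to(out.dtype)
        start += m
    return out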