mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/gemm/triton/fp8_gemm.py
@@ -0,0 +1,2702 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+ import os
10
+ from typing import Dict, List, Optional, Tuple, Union
11
+
12
+ import torch
13
+ import triton # @manual
14
+ import triton.language as tl # @manual
15
+ from mslk.gemm.triton.matmul_perf_model import early_config_prune, estimate_matmul_time
16
+ from mslk.gemm.triton.utils import map_dtype_to_triton, TmaAutoTuneHelper
17
+ from mslk.utils.triton.fp8_utils import get_fp8_constants, reinterpret_fp8_type
18
+ from packaging import version
19
+ from torch._tensor import Tensor
20
+ from triton import Config # @manual
21
+ from triton.runtime.jit import TensorWrapper # @manual
22
+
23
+ logger: logging.Logger = logging.getLogger(__name__)
24
+
25
+ running_on_github: bool = os.getenv("GITHUB_ENV") is not None
26
+
27
+ try:
28
+ # pyre-ignore[21]
29
+ from triton.fb.compat import disable_bufferops # @manual
30
+ except ModuleNotFoundError:
31
+ # Ensure we can call disable_bufferops if compat is not included (e.g. opensource)
32
+ # TODO(njriasan): Remove when we integrate triton.fb.compat into every Triton
33
+ # version.
34
+ from contextlib import contextmanager
35
+
36
+ @contextmanager
37
+ def disable_bufferops(_unused: bool):
38
+ yield None
39
+
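Either way the name behaves as a context manager, so call sites do not need to branch on whether triton.fb.compat is available. A minimal sketch of the intended usage (the launch below is a hypothetical placeholder, not part of this file):

# Hypothetical call site: identical code runs against the fbcode implementation
# or the no-op fallback defined above.
with disable_bufferops(True):
    run_some_triton_kernel()  # placeholder for any launch that must avoid buffer ops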
40
+
41
+ def init_to_zero(name):
42
+ return lambda nargs: nargs[name].zero_()
43
+
44
+
45
+ def get_configs_io_bound() -> List[Config]:
46
+ """
47
+ Returns a list of configs for matmul that are IO bound.
48
+
49
+ Returns:
50
+ List[Config]: list of configs.
51
+ """
52
+ configs = []
53
+ for num_stages in [2, 3, 4, 5, 6]:
54
+ for block_m in [16, 32]:
55
+ for block_k in [32, 64]:
56
+ for block_n in [32, 64, 128, 256]:
57
+ num_warps = 2 if block_n <= 64 else 4
58
+ configs.append(
59
+ Config(
60
+ {
61
+ "BLOCK_M": block_m,
62
+ "BLOCK_N": block_n,
63
+ "BLOCK_K": block_k,
64
+ "SPLIT_K": 1,
65
+ },
66
+ num_stages=num_stages,
67
+ num_warps=num_warps,
68
+ )
69
+ )
70
+ # split_k
71
+ for split_k in []: # Disabled [2, 4, 8, 16]:
72
+ configs.append(
73
+ Config(
74
+ {
75
+ "BLOCK_M": block_m,
76
+ "BLOCK_N": block_n,
77
+ "BLOCK_K": block_k,
78
+ "SPLIT_K": split_k,
79
+ },
80
+ num_stages=num_stages,
81
+ num_warps=num_warps,
82
+ pre_hook=init_to_zero("C"),
83
+ )
84
+ )
85
+ return configs
86
+
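For reference, the sweep above produces 5 x 2 x 2 x 4 = 80 configurations (num_stages x BLOCK_M x BLOCK_K x BLOCK_N), all with SPLIT_K=1 since the split-K inner loop is currently disabled. A quick way to inspect the search space (sketch, using the function defined above):

# Inspect the IO-bound autotuning search space.
cfgs = get_configs_io_bound()
print(len(cfgs))                                    # 80
print(all(c.kwargs["SPLIT_K"] == 1 for c in cfgs))  # True: split-K variants are disabled
print(cfgs[0].kwargs, cfgs[0].num_stages, cfgs[0].num_warps)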
87
+
88
+ def dummy_prune_configs(configs, named_args, **kwargs):
89
+ M = named_args["M"]
90
+ N = named_args["N"]
91
+ K = named_args["K"]
92
+
93
+ logger.info(f"{len(configs)=} for {M=} {N=} {K=}")
94
+ return configs
95
+
96
+
97
+ MATMUL_CONFIGS: List[Config] = [
98
+ # basic configs for compute-bound matmuls
99
+ Config(
100
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
101
+ num_stages=3,
102
+ num_warps=8,
103
+ ),
104
+ Config(
105
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
106
+ num_stages=3,
107
+ num_warps=8,
108
+ ),
109
+ Config(
110
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
111
+ num_stages=4,
112
+ num_warps=4,
113
+ ),
114
+ Config(
115
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
116
+ num_stages=4,
117
+ num_warps=4,
118
+ ),
119
+ Config(
120
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
121
+ num_stages=4,
122
+ num_warps=4,
123
+ ),
124
+ Config(
125
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 256, "SPLIT_K": 1},
126
+ num_stages=4,
127
+ num_warps=4,
128
+ ),
129
+ Config(
130
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
131
+ num_stages=4,
132
+ num_warps=4,
133
+ ),
134
+ Config(
135
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
136
+ num_stages=4,
137
+ num_warps=4,
138
+ ),
139
+ Config(
140
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
141
+ num_stages=4,
142
+ num_warps=4,
143
+ ),
144
+ Config(
145
+ {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
146
+ num_stages=4,
147
+ num_warps=4,
148
+ ),
149
+ Config(
150
+ {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
151
+ num_stages=5,
152
+ num_warps=2,
153
+ ),
154
+ # good for int8
155
+ Config(
156
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
157
+ num_stages=3,
158
+ num_warps=8,
159
+ ),
160
+ Config(
161
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
162
+ num_stages=3,
163
+ num_warps=8,
164
+ ),
165
+ Config(
166
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
167
+ num_stages=4,
168
+ num_warps=4,
169
+ ),
170
+ Config(
171
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
172
+ num_stages=4,
173
+ num_warps=4,
174
+ ),
175
+ Config(
176
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
177
+ num_stages=4,
178
+ num_warps=4,
179
+ ),
180
+ Config(
181
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
182
+ num_stages=4,
183
+ num_warps=4,
184
+ ),
185
+ Config(
186
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
187
+ num_stages=4,
188
+ num_warps=4,
189
+ ),
190
+ Config(
191
+ {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
192
+ num_stages=4,
193
+ num_warps=4,
194
+ ),
195
+ Config(
196
+ {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
197
+ num_stages=5,
198
+ num_warps=2,
199
+ ),
200
+ ] + get_configs_io_bound()
201
+
202
+
203
+ @triton.autotune(
204
+ configs=MATMUL_CONFIGS,
205
+ prune_configs_by={
206
+ "early_config_prune": dummy_prune_configs,
207
+ },
208
+ key=[
209
+ "m_key",
210
+ "n_key",
211
+ "k_key",
212
+ ],
213
+ )
214
+ @triton.jit
215
+ def _kernel_matmul_fp8_row(
216
+ A_ptr,
217
+ B_ptr,
218
+ C_ptr,
219
+ M,
220
+ N,
221
+ K,
222
+ m_key,
223
+ n_key,
224
+ k_key,
225
+ A_scale,
226
+ B_scale,
227
+ Bias,
228
+ stride_am,
229
+ stride_ak,
230
+ stride_bn,
231
+ stride_bk,
232
+ stride_cm,
233
+ stride_cn,
234
+ dot_out_dtype: tl.constexpr,
235
+ allow_tf32: tl.constexpr,
236
+ fp8_fast_accum: tl.constexpr,
237
+ skip_scaling_a: tl.constexpr,
238
+ BLOCK_M: tl.constexpr,
239
+ BLOCK_N: tl.constexpr,
240
+ BLOCK_K: tl.constexpr,
241
+ GROUP_M: tl.constexpr,
242
+ SPLIT_K: tl.constexpr,
243
+ USE_BIAS: tl.constexpr,
244
+ AB_DTYPE: tl.constexpr,
245
+ NUM_SMS: tl.constexpr,
246
+ ) -> None:
247
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
248
+
249
+ performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
250
+
251
+ Args:
252
+ A (TensorWrapper): [M, K] input tensor.
253
+ B (TensorWrapper): [N, K] input tensor.
254
+ C (TensorWrapper): [M, N] output tensor.
255
+ M (int): M dimension of input tensor.
256
+ N (int): N dimension of input tensor.
257
+ K (int): K dimension of input tensor.
258
+ m_key (int): Autotuning key for M dimension of input tensor.
259
+ n_key (int): Autotuning key for N dimension of input tensor.
260
+ k_key (int): Autotuning key for K dimension of input tensor.
261
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A.
262
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B.
263
+ Bias (TensorWrapper): [N] Optional bias tensor.
264
+ stride_am (int): Stride of M dimension of A.
265
+ stride_ak (int): Stride of K dimension of A.
266
+ stride_bn (int): Stride of N dimension of B.
267
+ stride_bk (int): Stride of K dimension of B.
268
+ stride_cm (int): Stride of M dimension of C.
269
+ stride_cn (int): Stride of N dimension of C.
270
+ dot_out_dtype (torch.dtype): Output type of tensor core.
271
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
272
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
273
+ BLOCK_M (int): Block size for M dimension.
274
+ BLOCK_N (int): Block size for N dimension.
275
+ BLOCK_K (int): Block size for K dimension.
276
+ GROUP_M (int): Number of groups for M dimension swizzle.
277
+ SPLIT_K (int): Number of splits along the K dimension (split-K reduction).
+ USE_BIAS (bool): Whether to use bias.
+ skip_scaling_a (bool): Whether to skip applying A_scale (set when A_scale is None).
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
+ NUM_SMS (int): Number of persistent programs (one per SM).
281
+ """
282
+ # Matrix multiplication.
283
+ start_pid = tl.program_id(axis=0)
284
+ num_pid_m = tl.cdiv(M, BLOCK_M)
285
+ num_pid_n = tl.cdiv(N, BLOCK_N)
286
+ k_tiles = tl.cdiv(K, BLOCK_K)
287
+ num_tiles = num_pid_m * num_pid_n
288
+
289
+ offs_k_for_mask = tl.arange(0, BLOCK_K)
290
+
291
+ num_pid_in_group = GROUP_M * num_pid_n
292
+
293
+ acc_dtype = tl.float32 if allow_tf32 else dot_out_dtype
294
+
295
+ # Outer loop over tiles assigned to this SM
296
+ for tile_id in range(start_pid, num_tiles, NUM_SMS):
297
+ group_id = tile_id // num_pid_in_group
298
+ first_pid_m = group_id * GROUP_M
299
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
300
+ # pyre-ignore[58]: `%` is not supported for operand types `int` and `tl.core.constexpr`.
301
+ pid_m = first_pid_m + (tile_id % group_size_m)
302
+ pid_n = (tile_id % num_pid_in_group) // group_size_m
303
+
304
+ start_m = pid_m * BLOCK_M
305
+ start_n = pid_n * BLOCK_N
306
+ offs_am = start_m + tl.arange(0, BLOCK_M)
307
+ offs_bn = start_n + tl.arange(0, BLOCK_N)
308
+ offs_am = tl.where(offs_am < M, offs_am, 0)
309
+ offs_bn = tl.where(offs_bn < N, offs_bn, 0)
310
+ offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M)
311
+ offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N)
312
+
313
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
314
+
315
+ # Inner loop over K dimension
316
+ for ki in range(0, k_tiles):
317
+ offs_k = ki * BLOCK_K + tl.arange(0, BLOCK_K)
318
+ A = A_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
319
+ B = B_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
320
+
321
+ a = tl.load(A, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_K, other=0.0)
322
+ b = tl.load(B, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_K, other=0.0)
323
+ acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)
324
+
325
+ # rematerialize rm and rn to save registers
326
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
327
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
328
+
329
+ # Invert scaling.
330
+ b_scale = tl.load(B_scale + rn, mask=rn < N)
331
+ if skip_scaling_a:
332
+ acc *= b_scale[None, :]
333
+ else:
334
+ a_scale = tl.load(A_scale + rm, mask=rm < M)
335
+ # pyre-ignore[16]: Undefined attribute [16]: `float`
336
+ # has no attribute `__getitem__`.
337
+ scale = a_scale[:, None] * b_scale[None, :]
338
+ acc *= scale
339
+
340
+ # Load and add bias if specified.
341
+ if USE_BIAS:
342
+ bias = tl.load(Bias + rn, mask=rn < N)
343
+ acc += bias[None, :]
344
+
345
+ acc = acc.to(C_ptr.dtype.element_ty)
346
+ C = C_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
347
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
348
+ # Handles write-back with reduction-splitting
349
+ tl.store(C, acc, mask=mask)
350
+
351
+
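The epilogue above is the entire dequantization step: the fp8 tile product is multiplied by the outer product of the reciprocal row scales (A_scale is skipped when skip_scaling_a is set), the optional bias is added, and the tile is stored once. A plain PyTorch sketch of the same computation, mirroring the CPU fallback used later in matmul_fp8_row:

# Reference for what _kernel_matmul_fp8_row computes (sketch, not the kernel itself).
import torch

def matmul_fp8_row_reference(a_fp8, b_fp8, a_scale, b_scale, bias=None):
    # a_fp8: [M, K], b_fp8: [N, K]; a_scale: [M], b_scale: [N] reciprocal row scales.
    acc = torch.matmul(a_fp8.to(torch.float32), b_fp8.to(torch.float32).T)
    acc = acc * a_scale[:, None] * b_scale[None, :]  # undo row-wise quantization
    if bias is not None:
        acc = acc + bias[None, :]
    return acc.to(torch.bfloat16)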
352
+ @triton.autotune(
353
+ configs=MATMUL_CONFIGS
354
+ + [
355
+ Config(
356
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
357
+ num_stages=3,
358
+ num_warps=8,
359
+ ),
360
+ ],
361
+ key=[
362
+ "m_key",
363
+ "n_key",
364
+ "k_key",
365
+ ],
366
+ )
367
+ @triton.heuristics(
368
+ {
369
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
370
+ }
371
+ )
372
+ @triton.jit
373
+ def _kernel_matmul_fp8_row_no_fast_acc(
374
+ A_ptr,
375
+ B_ptr,
376
+ C_ptr,
377
+ M,
378
+ N,
379
+ K,
380
+ m_key,
381
+ n_key,
382
+ k_key,
383
+ A_scale,
384
+ B_scale,
385
+ Bias,
386
+ stride_am,
387
+ stride_ak,
388
+ stride_bn,
389
+ stride_bk,
390
+ stride_cm,
391
+ stride_cn,
392
+ dot_out_dtype: tl.constexpr,
393
+ allow_tf32: tl.constexpr,
394
+ fp8_fast_accum: tl.constexpr,
395
+ BLOCK_M: tl.constexpr,
396
+ BLOCK_N: tl.constexpr,
397
+ BLOCK_K: tl.constexpr,
398
+ GROUP_M: tl.constexpr,
399
+ SPLIT_K: tl.constexpr,
400
+ EVEN_K: tl.constexpr,
401
+ USE_BIAS: tl.constexpr,
402
+ AB_DTYPE: tl.constexpr,
403
+ NUM_SMS: tl.constexpr,
404
+ ) -> None:
405
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
406
+
407
+ performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
408
+
409
+ Args:
410
+ A (TensorWrapper): [M, K] input tensor.
411
+ B (TensorWrapper): [N, K] input tensor.
412
+ C (TensorWrapper): [M, N] output tensor.
413
+ M (int): M dimension of input tensor.
414
+ N (int): N dimension of input tensor.
415
+ K (int): K dimension of input tensor.
416
+ m_key (int): Autotuning key for M dimension of input tensor.
417
+ n_key (int): Autotuning key for N dimension of input tensor.
418
+ k_key (int): Autotuning key for K dimension of input tensor.
419
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
420
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
421
+ Bias (TensorWrapper): [N] Optional bias tensor.
422
+ stride_am (int): Stride of M dimension of A.
423
+ stride_ak (int): Stride of K dimension of A.
424
+ stride_bn (int): Stride of N dimension of B.
425
+ stride_bk (int): Stride of K dimension of B.
426
+ stride_cm (int): Stride of M dimension of C.
427
+ stride_cn (int): Stride of N dimension of C.
428
+ dot_out_dtype (torch.dtype): Output type of tensor core.
429
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
430
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
431
+ BLOCK_M (int): Block size for M dimension.
432
+ BLOCK_N (int): Block size for N dimension.
433
+ BLOCK_K (int): Block size for K dimension.
434
+ GROUP_M (int): Number of groups for M dimension swizzle.
435
+ SPLIT_K (int): Number of splits along the K dimension (split-K reduction).
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
+ USE_BIAS (bool): Whether to use bias.
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
+ NUM_SMS (int): Number of persistent programs (one per SM).
439
+ """
440
+ # Matrix multiplication.
441
+
442
+ start_pid = tl.program_id(axis=0)
443
+ num_pid_m = tl.cdiv(M, BLOCK_M)
444
+ num_pid_n = tl.cdiv(N, BLOCK_N)
445
+ k_tiles = tl.cdiv(K, BLOCK_K)
446
+ num_tiles = num_pid_m * num_pid_n
447
+
448
+ tiles_per_SM = num_tiles // NUM_SMS
449
+ if start_pid < num_tiles % NUM_SMS:
450
+ tiles_per_SM += 1
451
+
452
+ tile_id = start_pid - NUM_SMS
453
+ ki = -1
454
+
455
+ offs_k_for_mask = tl.arange(0, BLOCK_K)
456
+
457
+ num_pid_in_group = GROUP_M * num_pid_n
458
+
459
+ pid_m = 0
460
+ pid_n = 0
461
+ offs_am = tl.arange(0, BLOCK_M)
462
+ offs_bn = tl.arange(0, BLOCK_N)
463
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
464
+
465
+ for _ in range(0, k_tiles * tiles_per_SM):
466
+ ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
467
+ if ki == 0:
468
+ tile_id += NUM_SMS
469
+ group_id = tile_id // num_pid_in_group
470
+ first_pid_m = group_id * GROUP_M
471
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
472
+ pid_m = first_pid_m + (tile_id % group_size_m)
473
+ pid_n = (tile_id % num_pid_in_group) // group_size_m
474
+
475
+ start_m = pid_m * BLOCK_M
476
+ start_n = pid_n * BLOCK_N
477
+ offs_am = start_m + tl.arange(0, BLOCK_M)
478
+ offs_bn = start_n + tl.arange(0, BLOCK_N)
479
+ offs_am = tl.where(offs_am < M, offs_am, 0)
480
+ offs_bn = tl.where(offs_bn < N, offs_bn, 0)
481
+ offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M)
482
+ offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N)
483
+ offs_k = ki * BLOCK_K + tl.arange(0, BLOCK_K)
484
+ A = A_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
485
+ B = B_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
486
+
487
+ a = tl.load(A, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_K, other=0.0)
488
+ b = tl.load(B, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_K, other=0.0)
489
+ acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
490
+
491
+ if ki == k_tiles - 1:
492
+ # rematerialize rm and rn to save registers
493
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
494
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
495
+
496
+ # Invert scaling.
497
+ a_scale = tl.load(A_scale + rm, mask=rm < M)
498
+ b_scale = tl.load(B_scale + rn, mask=rn < N)
499
+ # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
500
+ scale = a_scale[:, None] * b_scale[None, :]
501
+ acc *= scale
502
+
503
+ # Load and add bias if specified.
504
+ if USE_BIAS:
505
+ bias = tl.load(Bias + rn, mask=rn < N)
506
+ acc += bias[None, :]
507
+
508
+ acc = acc.to(C_ptr.dtype.element_ty)
509
+ C = C_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
510
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
511
+ # Handles write-back with reduction-splitting
512
+ tl.store(C, acc, mask=mask)
513
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
514
+
515
+
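The flattened loop above fuses the output-tile loop and the K loop into a single counter of k_tiles * tiles_per_SM iterations: ki wraps back to 0 whenever a new tile starts, and the scale/bias/store epilogue runs on the last K step of each tile. A small host-side sketch of the same schedule (names local to this example):

# Emulate the persistent schedule seen by one program (sketch).
def tiles_for_program(start_pid, num_tiles, k_tiles, NUM_SMS):
    tiles_per_sm = num_tiles // NUM_SMS + (1 if start_pid < num_tiles % NUM_SMS else 0)
    tile_id, ki, visited = start_pid - NUM_SMS, -1, []
    for _ in range(k_tiles * tiles_per_sm):
        ki = 0 if ki == k_tiles - 1 else ki + 1
        if ki == 0:
            tile_id += NUM_SMS       # claim the next output tile for this program
        if ki == k_tiles - 1:
            visited.append(tile_id)  # epilogue (scale, bias, store) happens here
    return visited

print(tiles_for_program(start_pid=0, num_tiles=10, k_tiles=3, NUM_SMS=4))  # [0, 4, 8]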
516
+ @triton.autotune(
517
+ configs=MATMUL_CONFIGS,
518
+ key=[
519
+ "m_key",
520
+ "n_key",
521
+ "k_key",
522
+ ],
523
+ )
524
+ @triton.heuristics(
525
+ {
526
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
527
+ }
528
+ )
529
+ @triton.jit
530
+ def _kernel_matmul_fp8_row_imprecise_acc(
531
+ A,
532
+ B,
533
+ C,
534
+ M,
535
+ N,
536
+ K,
537
+ m_key,
538
+ n_key,
539
+ k_key,
540
+ A_scale,
541
+ B_scale,
542
+ Bias,
543
+ stride_am,
544
+ stride_ak,
545
+ stride_bn,
546
+ stride_bk,
547
+ stride_cm,
548
+ stride_cn,
549
+ dot_out_dtype: tl.constexpr,
550
+ allow_tf32: tl.constexpr,
551
+ fp8_fast_accum: tl.constexpr,
552
+ BLOCK_M: tl.constexpr,
553
+ BLOCK_N: tl.constexpr,
554
+ BLOCK_K: tl.constexpr,
555
+ GROUP_M: tl.constexpr,
556
+ SPLIT_K: tl.constexpr,
557
+ EVEN_K: tl.constexpr,
558
+ USE_BIAS: tl.constexpr,
559
+ AB_DTYPE: tl.constexpr,
560
+ ) -> None:
561
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
562
+
563
+ performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
564
+
565
+ Args:
566
+ A (TensorWrapper): [M, K] input tensor.
567
+ B (TensorWrapper): [N, K] input tensor.
568
+ C (TensorWrapper): [M, N] output tensor.
569
+ M (int): M dimension of input tensor.
570
+ N (int): N dimension of input tensor.
571
+ K (int): K dimension of input tensor.
572
+ m_key (int): Autotuning key for M dimension of input tensor.
573
+ n_key (int): Autotuning key for N dimension of input tensor.
574
+ k_key (int): Autotuning key for K dimension of input tensor.
575
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
576
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
577
+ Bias (TensorWrapper): [N] Optional bias tensor.
578
+ stride_am (int): Stride of M dimension of A.
579
+ stride_ak (int): Stride of K dimension of A.
580
+ stride_bn (int): Stride of N dimension of B.
581
+ stride_bk (int): Stride of K dimension of B.
582
+ stride_cm (int): Stride of M dimension of C.
583
+ stride_cn (int): Stride of N dimension of C.
584
+ dot_out_dtype (torch.dtype): Output type of tensor core.
585
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
586
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
587
+ BLOCK_M (int): Block size for M dimension.
588
+ BLOCK_N (int): Block size for N dimension.
589
+ BLOCK_K (int): Block size for K dimension.
590
+ GROUP_M (int): Number of groups for M dimension swizzle.
591
+ SPLIT_K (int): Number of splits along the K dimension (split-K reduction).
592
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
593
+ USE_BIAS (bool): Whether to use bias.
594
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
595
+ """
596
+ # Matrix multiplication.
597
+ pid = tl.program_id(0)
598
+ pid_z = tl.program_id(1)
599
+ grid_m = tl.cdiv(M, BLOCK_M)
600
+ grid_n = tl.cdiv(N, BLOCK_N)
601
+ # Re-order program ID for better L2 performance (swizzle).
602
+ width = GROUP_M * grid_n
603
+ group_id = pid // width
604
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
605
+ pid_m = group_id * GROUP_M + (pid % group_size)
606
+ pid_n = (pid % width) // (group_size)
607
+ # Do matrix multiplication.
608
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
609
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
610
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
611
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
612
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
613
+ # Pointers.
614
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
615
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
616
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
617
+
618
+ for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
619
+ if EVEN_K:
620
+ a = tl.load(A)
621
+ b = tl.load(B)
622
+ else:
623
+ k_remaining = K - k * (BLOCK_K * SPLIT_K)
624
+ _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
625
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
626
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
627
+ if AB_DTYPE:
628
+ a = a.to(C.dtype.element_ty)
629
+ b = b.to(C.dtype.element_ty)
630
+ if fp8_fast_accum:
631
+ acc = tl.dot(
632
+ a,
633
+ b,
634
+ acc,
635
+ max_num_imprecise_acc=32,
636
+ out_dtype=dot_out_dtype,
637
+ allow_tf32=allow_tf32,
638
+ )
639
+ else:
640
+ acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
641
+
642
+ A += BLOCK_K * SPLIT_K * stride_ak
643
+ B += BLOCK_K * SPLIT_K * stride_bk
644
+
645
+ # rematerialize rm and rn to save registers
646
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
647
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
648
+
649
+ # Invert scaling.
650
+ a_scale = tl.load(A_scale + rm, mask=rm < M)
651
+ b_scale = tl.load(B_scale + rn, mask=rn < N)
652
+ # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
653
+ scale = a_scale[:, None] * b_scale[None, :]
654
+ acc *= scale
655
+
656
+ # Apply bias.
657
+ if USE_BIAS:
658
+ bias = tl.load(Bias + rn, mask=rn < N)
659
+ acc += bias[None, :]
660
+
661
+ acc = acc.to(C.dtype.element_ty)
662
+ C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
663
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
664
+ # Handles write-back with reduction-splitting
665
+ if SPLIT_K == 1:
666
+ tl.store(C, acc, mask=mask)
667
+ else:
668
+ tl.atomic_add(C, acc, mask=mask)
669
+
670
+
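The branch above is also why the (currently disabled) split-K configs in get_configs_io_bound attach pre_hook=init_to_zero("C"): with SPLIT_K > 1, each of the SPLIT_K programs reduces a disjoint subset of K and accumulates its partial product into C with tl.atomic_add, so C must start at zero. A host-side sketch of the equivalent reduction, with illustrative operands:

# Sketch of split-K write-back: partial products are accumulated into a zeroed C.
import torch

M, N, K, SPLIT_K = 64, 64, 256, 4
a = torch.randn(M, K)
b = torch.randn(N, K)
c = torch.zeros(M, N)                       # init_to_zero("C"): output must start at zero
for z in range(SPLIT_K):                    # each z plays the role of tl.program_id(1)
    ks = slice(z * K // SPLIT_K, (z + 1) * K // SPLIT_K)
    c += a[:, ks] @ b[:, ks].T              # the kernel uses tl.atomic_add(C, acc, mask=mask)
print((c - a @ b.T).abs().max())            # ~0 up to accumulation order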
671
+ @triton.autotune(
672
+ configs=[
673
+ Config(
674
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
675
+ num_stages=3,
676
+ num_warps=8,
677
+ ),
678
+ Config(
679
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
680
+ num_stages=3,
681
+ num_warps=8,
682
+ ),
683
+ Config(
684
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
685
+ num_stages=4,
686
+ num_warps=4,
687
+ ),
688
+ Config(
689
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
690
+ num_stages=4,
691
+ num_warps=4,
692
+ ),
693
+ Config(
694
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
695
+ num_stages=4,
696
+ num_warps=4,
697
+ ),
698
+ Config(
699
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
700
+ num_stages=4,
701
+ num_warps=4,
702
+ ),
703
+ Config(
704
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
705
+ num_stages=4,
706
+ num_warps=4,
707
+ ),
708
+ Config(
709
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 512, "SPLIT_K": 1},
710
+ num_stages=3,
711
+ num_warps=4,
712
+ ),
713
+ ],
714
+ key=[
715
+ "m_key",
716
+ "n_key",
717
+ "k_key",
718
+ ],
719
+ use_cuda_graph=True,
720
+ )
721
+ @triton.heuristics(
722
+ {
723
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
724
+ }
725
+ )
726
+ @triton.jit
727
+ def _kernel_matmul_fp8_row_tma_persistent(
728
+ A_ptr,
729
+ B_ptr,
730
+ C_ptr,
731
+ M,
732
+ N,
733
+ K,
734
+ m_key,
735
+ n_key,
736
+ k_key,
737
+ A_scale,
738
+ B_scale,
739
+ Bias,
740
+ stride_am,
741
+ stride_ak,
742
+ stride_bn,
743
+ stride_bk,
744
+ stride_cm,
745
+ stride_cn,
746
+ dot_out_dtype: tl.constexpr,
747
+ c_dtype: tl.constexpr,
748
+ bias_dtype: tl.constexpr,
749
+ allow_tf32: tl.constexpr,
750
+ fp8_fast_accum: tl.constexpr,
751
+ BLOCK_M: tl.constexpr,
752
+ BLOCK_N: tl.constexpr,
753
+ BLOCK_K: tl.constexpr,
754
+ GROUP_M: tl.constexpr,
755
+ AB_DTYPE: tl.constexpr,
756
+ SPLIT_K: tl.constexpr,
757
+ EVEN_K: tl.constexpr,
758
+ NUM_SMS: tl.constexpr,
759
+ USE_BIAS: tl.constexpr,
760
+ ) -> None:
761
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
762
+
763
+ performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
764
+
765
+ Args:
766
+ A (TensorWrapper): [M, K] input tensor.
767
+ B (TensorWrapper): [N, K] input tensor.
768
+ C (TensorWrapper): [M, N] output tensor.
769
+ M (int): M dimension of input tensor.
770
+ N (int): N dimension of input tensor.
771
+ K (int): K dimension of input tensor.
772
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
773
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
774
+ stride_am (int): Stride of M dimension of A.
775
+ stride_ak (int): Stride of K dimension of A.
776
+ stride_bn (int): Stride of N dimension of B.
777
+ stride_bk (int): Stride of K dimension of B.
778
+ stride_cm (int): Stride of M dimension of C.
779
+ stride_cn (int): Stride of N dimension of C.
780
+ dot_out_dtype (torch.dtype): Output type of tensor core.
781
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
782
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
783
+ BLOCK_M (int): Block size for M dimension.
784
+ BLOCK_N (int): Block size for N dimension.
785
+ BLOCK_K (int): Block size for K dimension.
786
+ GROUP_M (int): Number of groups for M dimension swizzle.
787
+ SPLIT_K (int): Number of splits along the K dimension (split-K reduction).
788
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
789
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
790
+ """
791
+ # Matrix multiplication.
792
+ start_pid = tl.program_id(axis=0)
793
+ num_pid_m = tl.cdiv(M, BLOCK_M)
794
+ num_pid_n = tl.cdiv(N, BLOCK_N)
795
+ k_tiles = tl.cdiv(K, BLOCK_K)
796
+ num_tiles = num_pid_m * num_pid_n
797
+
798
+ tiles_per_SM = num_tiles // NUM_SMS
799
+ if start_pid < num_tiles % NUM_SMS:
800
+ tiles_per_SM += 1
801
+
802
+ tile_id = start_pid - NUM_SMS
803
+ ki = -1
804
+
805
+ pid_m = 0
806
+ pid_n = 0
807
+ offs_am = 0
808
+ offs_bn = 0
809
+
810
+ num_pid_in_group = GROUP_M * num_pid_n
811
+
812
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
813
+
814
+ dtype_fp8 = tl.float8e4nv
815
+ scale_dtype = tl.float32
816
+
817
+ for _ in range(0, k_tiles * tiles_per_SM):
818
+ ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
819
+ if ki == 0:
820
+ tile_id += NUM_SMS
821
+ group_id = tile_id // num_pid_in_group
822
+ first_pid_m = group_id * GROUP_M
823
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
824
+ pid_m = first_pid_m + (tile_id % group_size_m)
825
+ pid_n = (tile_id % num_pid_in_group) // group_size_m
826
+
827
+ offs_am = pid_m * BLOCK_M
828
+ offs_bn = pid_n * BLOCK_N
829
+ offs_am = tl.multiple_of(offs_am, BLOCK_M)
830
+ offs_bn = tl.multiple_of(offs_bn, BLOCK_N)
831
+
832
+ offs_k = ki * BLOCK_K
833
+
834
+ a = tl._experimental_descriptor_load(
835
+ A_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], dtype_fp8
836
+ )
837
+ b = tl._experimental_descriptor_load(
838
+ B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
839
+ )
840
+
841
+ if fp8_fast_accum:
842
+ acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
843
+ else:
844
+ acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
845
+
846
+ if ki == k_tiles - 1:
847
+ # rematerialize rm and rn to save registers
848
+
849
+ # Invert scaling.
850
+ a_scale = tl._experimental_descriptor_load(
851
+ A_scale, [offs_am], [BLOCK_M], scale_dtype
852
+ )
853
+ b_scale = tl._experimental_descriptor_load(
854
+ B_scale, [offs_bn], [BLOCK_N], scale_dtype
855
+ )
856
+ # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
857
+ scale = a_scale[:, None] * b_scale[None, :]
858
+ acc *= scale
859
+
860
+ # Load and add bias if specified.
861
+ if USE_BIAS:
862
+ bias = tl._experimental_descriptor_load(
863
+ Bias, [offs_bn], [BLOCK_N], bias_dtype
864
+ )
865
+ acc += bias[None, :]
866
+
867
+ acc = acc.to(c_dtype)
868
+ tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
869
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
870
+
871
+
872
+ has_warp_specialization = hasattr(tl, "async_task")
873
+
874
+
875
+ def make_autotuner_config(dictargs, **kwargs):
876
+ # NOTE: Triton 3.4.x removed some keyword arguments from Config constructor;
877
+ # however, fbcode uses 3.3.1, and so this shim is provided to support both
878
+ # versions.
879
+ #
880
+ # https://github.com/triton-lang/triton/blob/v3.3.1/python/triton/runtime/autotuner.py#L275
881
+ # https://github.com/triton-lang/triton/blame/release/3.4.x/python/triton/runtime/autotuner.py#L319
882
+ if version.parse(triton.__version__) > version.parse("3.3.1"):
883
+ for key in ["num_buffers_warp_spec", "num_consumer_groups"]:
884
+ kwargs.pop(key, None)
885
+ return Config(dictargs, **kwargs)
886
+
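In practice this means the same config list can be declared once: on Triton 3.3.1 the warp-specialization keywords reach the Config constructor, and on newer releases they are silently dropped. For example (the keyword values mirror get_ws_configs below):

# One declaration that works on both Triton 3.3.1 and >= 3.4 (sketch).
cfg = make_autotuner_config(
    {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1, "NUM_CONSUMER_GROUPS": 2},
    num_stages=4,
    num_warps=4,
    num_consumer_groups=2,     # dropped when triton.__version__ > 3.3.1
    num_buffers_warp_spec=4,   # dropped when triton.__version__ > 3.3.1
)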
887
+
888
+ def get_ws_configs() -> List[Config]:
889
+ if not has_warp_specialization:
890
+ return []
891
+ return [
892
+ make_autotuner_config(
893
+ {
894
+ "BLOCK_M": 128,
895
+ "BLOCK_N": 256,
896
+ "BLOCK_K": 128,
897
+ "SPLIT_K": 1,
898
+ "NUM_CONSUMER_GROUPS": 2,
899
+ },
900
+ num_stages=3,
901
+ num_warps=4,
902
+ num_consumer_groups=2,
903
+ num_buffers_warp_spec=3,
904
+ ),
905
+ make_autotuner_config(
906
+ {
907
+ "BLOCK_M": 128,
908
+ "BLOCK_N": 128,
909
+ "BLOCK_K": 128,
910
+ "SPLIT_K": 1,
911
+ "NUM_CONSUMER_GROUPS": 2,
912
+ },
913
+ num_stages=4,
914
+ num_warps=4,
915
+ num_consumer_groups=2,
916
+ num_buffers_warp_spec=4,
917
+ ),
918
+ make_autotuner_config(
919
+ {
920
+ "BLOCK_M": 128,
921
+ "BLOCK_N": 256,
922
+ "BLOCK_K": 128,
923
+ "SPLIT_K": 1,
924
+ "NUM_CONSUMER_GROUPS": 1,
925
+ },
926
+ num_stages=3,
927
+ num_warps=8,
928
+ num_consumer_groups=0,
929
+ num_buffers_warp_spec=3,
930
+ ),
931
+ make_autotuner_config(
932
+ {
933
+ "BLOCK_M": 64,
934
+ "BLOCK_N": 64,
935
+ "BLOCK_K": 512,
936
+ "SPLIT_K": 1,
937
+ "NUM_CONSUMER_GROUPS": 1,
938
+ },
939
+ num_stages=3,
940
+ num_warps=4,
941
+ num_consumer_groups=0,
942
+ num_buffers_warp_spec=3,
943
+ ),
944
+ ]
945
+
946
+
947
+ @triton.autotune(
948
+ configs=[
949
+ Config(
950
+ {
951
+ "BLOCK_M": 128,
952
+ "BLOCK_N": 256,
953
+ "BLOCK_K": 128,
954
+ "SPLIT_K": 1,
955
+ "NUM_CONSUMER_GROUPS": 1,
956
+ },
957
+ num_stages=3,
958
+ num_warps=8,
959
+ ),
960
+ ]
961
+ + get_ws_configs(),
962
+ key=[
963
+ "m_key",
964
+ "n_key",
965
+ "k_key",
966
+ ],
967
+ use_cuda_graph=True,
968
+ )
969
+ @triton.heuristics(
970
+ {
971
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
972
+ }
973
+ )
974
+ @triton.jit
975
+ def _kernel_matmul_fp8_row_tma_persistent_ws_cooperative(
976
+ A_ptr,
977
+ B_ptr,
978
+ C_ptr,
979
+ M,
980
+ N,
981
+ K,
982
+ m_key,
983
+ n_key,
984
+ k_key,
985
+ A_scale,
986
+ B_scale,
987
+ Bias,
988
+ stride_am,
989
+ stride_ak,
990
+ stride_bn,
991
+ stride_bk,
992
+ stride_cm,
993
+ stride_cn,
994
+ dot_out_dtype: tl.constexpr,
995
+ c_dtype: tl.constexpr,
996
+ bias_dtype: tl.constexpr,
997
+ allow_tf32: tl.constexpr,
998
+ fp8_fast_accum: tl.constexpr,
999
+ BLOCK_M: tl.constexpr,
1000
+ BLOCK_N: tl.constexpr,
1001
+ BLOCK_K: tl.constexpr,
1002
+ GROUP_M: tl.constexpr,
1003
+ AB_DTYPE: tl.constexpr,
1004
+ SPLIT_K: tl.constexpr,
1005
+ EVEN_K: tl.constexpr,
1006
+ NUM_SMS: tl.constexpr,
1007
+ USE_BIAS: tl.constexpr,
1008
+ NUM_CONSUMER_GROUPS: tl.constexpr,
1009
+ ) -> None:
1010
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
1011
+
1012
+ performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
1013
+
1014
+ Args:
1015
+ A (TensorWrapper): [M, K] input tensor.
1016
+ B (TensorWrapper): [N, K] input tensor.
1017
+ C (TensorWrapper): [M, N] output tensor.
1018
+ M (int): M dimension of input tensor.
1019
+ N (int): N dimension of input tensor.
1020
+ K (int): K dimension of input tensor.
1021
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
1022
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
1023
+ stride_am (int): Stride of M dimension of A.
1024
+ stride_ak (int): Stride of K dimension of A.
1025
+ stride_bn (int): Stride of N dimension of B.
1026
+ stride_bk (int): Stride of K dimension of B.
1027
+ stride_cm (int): Stride of M dimension of C.
1028
+ stride_cn (int): Stride of N dimension of C.
1029
+ dot_out_dtype (torch.dtype): Output type of tensor core.
1030
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
1031
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
1032
+ BLOCK_M (int): Block size for M dimension.
1033
+ BLOCK_N (int): Block size for N dimension.
1034
+ BLOCK_K (int): Block size for K dimension.
1035
+ GROUP_M (int): Number of groups for M dimension swizzle.
1036
+ SPLIT_K (int): Number of splits along the K dimension (split-K reduction).
1037
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
1038
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
1039
+ """
1040
+ num_tiles = tl.cdiv(M, BLOCK_M) * tl.cdiv(N, BLOCK_N)
1041
+ num_pid_m = tl.cdiv(M, BLOCK_M)
1042
+ num_pid_n = tl.cdiv(N, BLOCK_N)
1043
+ dtype_fp8 = tl.float8e4nv
1044
+ for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)):
1045
+ num_pid_in_group = GROUP_M * num_pid_n
1046
+ group_id = pid // num_pid_in_group
1047
+ first_pid_m = group_id * GROUP_M
1048
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
1049
+ # pyre-ignore
1050
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
1051
+ pid_n = (pid % num_pid_in_group) // group_size_m
1052
+
1053
+ # ----------------------------------------------------------
1054
+ # Create pointers for the first blocks of A and B.
1055
+ # We will advance this pointer as we move in the K direction
1056
+ # and accumulate
1057
+ # `a_ptrs` is a block of [BLOCK_M, BLOCK_K] pointers
1058
+ # `b_ptrs` is a block of [BLOCK_K, BLOCK_N] pointers
1059
+ # See above `Pointer Arithmetic` section for details
1060
+ offs_am = pid_m * BLOCK_M
1061
+ offs_bn = pid_n * BLOCK_N
1062
+ offs_k = 0
1063
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
1064
+ # pyre-ignore
1065
+ tl.assume(tl.cdiv(K, BLOCK_K) > 0)
1066
+ for _ in range(0, tl.cdiv(K, BLOCK_K)):
1067
+ # pyre-ignore
1068
+ with tl.async_task([0]):
1069
+ a = tl._experimental_descriptor_load(
1070
+ A_ptr,
1071
+ [offs_am, offs_k],
1072
+ [BLOCK_M, BLOCK_K],
1073
+ dtype_fp8,
1074
+ )
1075
+ b = tl._experimental_descriptor_load(
1076
+ B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
1077
+ )
1078
+
1079
+ if fp8_fast_accum:
1080
+ acc = tl.dot(
1081
+ a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
1082
+ )
1083
+ else:
1084
+ acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
1085
+
1086
+ offs_k += BLOCK_K
1087
+
1088
+ # pyre-ignore
1089
+ with tl.async_task([1, NUM_CONSUMER_GROUPS]):
1090
+ # Invert scaling.
1091
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
1092
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
1093
+ a_scale = tl.load(A_scale + rm, mask=rm < M)
1094
+ b_scale = tl.load(B_scale + rn, mask=rn < N)
1095
+ scale = a_scale[:, None] * b_scale[None, :]
1096
+ acc *= scale
1097
+ # Load and add bias if specified.
1098
+ if USE_BIAS:
1099
+ bias = tl._experimental_descriptor_load(
1100
+ Bias, [offs_bn], [BLOCK_N], bias_dtype
1101
+ )
1102
+ acc += bias[None, :]
1103
+ acc = acc.to(c_dtype)
1104
+ tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
1105
+
1106
+
1107
+ def _is_eligible_for_skip_scaling(
1108
+ is_rowwise: bool,
1109
+ fp8_fast_accum: bool,
1110
+ imprecise_acc: bool,
1111
+ tma_persistent: bool,
1112
+ no_use_persistent: Optional[bool],
1113
+ use_warp_specialization: bool,
1114
+ ) -> bool:
1115
+ if not is_rowwise:
1116
+ return False
1117
+
1118
+ return (
1119
+ fp8_fast_accum
1120
+ and not imprecise_acc
1121
+ and not tma_persistent
1122
+ and not no_use_persistent
1123
+ and not use_warp_specialization
1124
+ )
1125
+
1126
+
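Concretely, a_scale=None is only honored on the plain row-wise path: fast accumulation enabled, persistent kernel, no imprecise accumulation, no TMA, no warp specialization. A quick check against the helper above (sketch):

# Only the default fast-accum persistent path may skip A scaling.
assert _is_eligible_for_skip_scaling(
    is_rowwise=True, fp8_fast_accum=True, imprecise_acc=False,
    tma_persistent=False, no_use_persistent=False, use_warp_specialization=False,
)
assert not _is_eligible_for_skip_scaling(
    is_rowwise=True, fp8_fast_accum=True, imprecise_acc=False,
    tma_persistent=True, no_use_persistent=False, use_warp_specialization=False,
)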
1127
+ @torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
1128
+ def matmul_fp8_row(
1129
+ a: torch.Tensor,
1130
+ b: torch.Tensor,
1131
+ a_scale: Optional[torch.Tensor],
1132
+ b_scale: torch.Tensor,
1133
+ bias: Optional[torch.Tensor] = None,
1134
+ dot_out_dtype: Optional[torch.dtype] = None,
1135
+ allow_tf32: bool = True,
1136
+ fp8_fast_accum: bool = True,
1137
+ imprecise_acc: bool = False,
1138
+ tma_persistent: bool = True,
1139
+ no_use_persistent: Optional[bool] = None,
1140
+ use_warp_specialization: bool = False,
1141
+ ) -> torch.Tensor:
1142
+ """
1143
+ Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].
1144
+
1145
+ Args:
1146
+ a (torch.Tensor): [M, K] input tensor.
1147
+ b (torch.Tensor): [N, K] input tensor.
1148
+ a_scale (Optional[torch.Tensor]): [M] reciprocal scale tensor per row.
+ A * a_scale = original A. Scaling will be skipped if a_scale is None.
1150
+ b_scale (torch.Tensor): [N] reciprocal scale tensor per row. B * b_scale = original B
1151
+ bias (torch.Tensor): [N] optional bias tensor to add to output if provided.
1152
+ dot_out_dtype (torch.dtype): Output type of tensor core.
1153
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
1154
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
1155
+ imprecise_acc (bool): Whether to use the imprecise-accumulation kernel (max_num_imprecise_acc=32).
+ tma_persistent (bool): Whether to use the TMA persistent kernel implementation.
+ no_use_persistent (Optional[bool]): Force the non-persistent kernel; if None, defaults to True on ROCm and False on CUDA.
+ use_warp_specialization (bool): Whether to use the warp-specialized TMA cooperative kernel (requires tl.async_task).
1156
+
1157
+ Returns:
1158
+ torch.Tensor: [M, N] output tensor a @ b.T * (a_scale[:, None] * b_scale[None, :]).
1159
+ """
1160
+ if no_use_persistent is None:
1161
+ # Default True for AMD and False for Nvidia.
1162
+ if torch.version.hip is not None:
1163
+ no_use_persistent = True
1164
+ else:
1165
+ no_use_persistent = False
1166
+ # Get datatypes and constants to use.
1167
+ pt_fp8_dtype, _, _, _ = get_fp8_constants()
1168
+ # Handle 3D+ a shape
1169
+ a_shape = a.shape
1170
+ a = a.view(-1, a.size(-1))
1171
+ # View inputs into proper torch fp8 dtype.
1172
+ if torch.version.cuda:
1173
+ assert a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
1174
+ elif torch.version.hip:
1175
+ if torch.cuda.get_device_capability() < (9, 5):
1176
+ assert a.dtype in (
1177
+ torch.float8_e4m3fnuz,
1178
+ torch.float8_e5m2fnuz,
1179
+ )
1180
+ else:
1181
+ assert a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
1182
+ else:
1183
+ assert a.dtype in (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
1184
+ assert b.dtype == pt_fp8_dtype
1185
+ M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
1186
+ prep_matmul(a, b, dot_out_dtype)
1187
+ )
1188
+
1189
+ # Skip scaling (a_scale is None) can only be applied in certain cases.
1190
+ assert a_scale is not None or _is_eligible_for_skip_scaling(
1191
+ is_rowwise=True,
1192
+ fp8_fast_accum=fp8_fast_accum,
1193
+ imprecise_acc=imprecise_acc,
1194
+ tma_persistent=tma_persistent,
1195
+ no_use_persistent=no_use_persistent,
1196
+ use_warp_specialization=use_warp_specialization,
1197
+ )
1198
+
1199
+ output_shape = a_shape[:-1] + (N,)
1200
+ # Handle tensor with empty inputs.
1201
+ if (M == 0) or (N == 0) or (K == 0):
1202
+ return torch.zeros(output_shape, device=device, dtype=torch.bfloat16)
1203
+ # launch kernel
1204
+ if a.device == torch.device("cpu"):
1205
+ logger.info(
1206
+ "FP8 Row-wise Triton kernel not supported on cpu, fallback to torch"
1207
+ )
1208
+ if a_scale is None:
1209
+ scale = b_scale[None, :]
1210
+ else:
1211
+ scale = a_scale[:, None] * b_scale[None, :]
1212
+ output = torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16).T) * scale
1213
+ if bias is not None:
1214
+ output += bias[None, :]
1215
+ return output.to(c.dtype)
1216
+
1217
+ def grid(META: Dict[str, int]) -> Tuple[int, int]:
1218
+ return (
1219
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
1220
+ META["SPLIT_K"],
1221
+ )
1222
+
1223
+ NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
1224
+
1225
+ def persistent_grid(META: Dict[str, int]) -> Tuple[int]:
1226
+ return (
1227
+ min(
1228
+ NUM_SMS,
1229
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
1230
+ ),
1231
+ )
1232
+
1233
+ if no_use_persistent:
1234
+ logger.debug("Using non-persistent kernel")
1235
+ with torch.cuda.device(a.device.index):
1236
+ torch._library.capture_triton(_kernel_matmul_fp8_row_non_persistent)[grid](
1237
+ a,
1238
+ b,
1239
+ c,
1240
+ M,
1241
+ N,
1242
+ K,
1243
+ m_key,
1244
+ n_key,
1245
+ k_key,
1246
+ a_scale,
1247
+ b_scale,
1248
+ bias,
1249
+ a.stride(0),
1250
+ a.stride(1),
1251
+ b.stride(0),
1252
+ b.stride(1),
1253
+ c.stride(0),
1254
+ c.stride(1),
1255
+ dot_out_dtype=dot_out_dtype_triton,
1256
+ allow_tf32=allow_tf32,
1257
+ fp8_fast_accum=fp8_fast_accum,
1258
+ # GROUP_M=8,
1259
+ USE_BIAS=bias is not None,
1260
+ AB_DTYPE=False,
1261
+ )
1262
+ elif use_warp_specialization:
1263
+ assert has_warp_specialization
1264
+ # used by TMA warp specialization kernel
1265
+ desc_helper = TmaAutoTuneHelper()
1266
+ desc_helper.init_tma_descriptor("a")
1267
+ desc_helper.init_tma_descriptor("b")
1268
+ desc_helper.init_tma_descriptor("c")
1269
+ desc_helper.init_tma_descriptor("a_scale")
1270
+ desc_helper.init_tma_descriptor("b_scale")
1271
+ desc_helper.init_tma_descriptor("bias")
1272
+
1273
+ def persistent_grid_tma_ws(META: Dict[str, int]) -> Tuple[int]:
1274
+ nonlocal desc_helper # noqa: F824
1275
+ assert a_scale is not None # Type narrowing for Pyre
1276
+ desc_helper.fill_2d_tma_descriptor(
1277
+ "a",
1278
+ a.data_ptr(),
1279
+ M,
1280
+ K,
1281
+ META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
1282
+ META["BLOCK_K"],
1283
+ a.element_size(),
1284
+ )
1285
+
1286
+ desc_helper.fill_2d_tma_descriptor(
1287
+ "b",
1288
+ b.data_ptr(),
1289
+ N,
1290
+ K,
1291
+ META["BLOCK_N"],
1292
+ META["BLOCK_K"],
1293
+ b.element_size(),
1294
+ )
1295
+ desc_helper.fill_2d_tma_descriptor(
1296
+ "c",
1297
+ c.data_ptr(),
1298
+ M,
1299
+ N,
1300
+ META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
1301
+ META["BLOCK_N"],
1302
+ c.element_size(),
1303
+ )
1304
+ desc_helper.fill_1d_tma_descriptor(
1305
+ "a_scale",
1306
+ a_scale.data_ptr(),
1307
+ M,
1308
+ META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
1309
+ a_scale.element_size(),
1310
+ )
1311
+ desc_helper.fill_1d_tma_descriptor(
1312
+ "b_scale",
1313
+ b_scale.data_ptr(),
1314
+ N,
1315
+ META["BLOCK_N"],
1316
+ b_scale.element_size(),
1317
+ )
1318
+ if bias is not None:
1319
+ desc_helper.fill_1d_tma_descriptor(
1320
+ "bias",
1321
+ bias.data_ptr(),
1322
+ N,
1323
+ META["BLOCK_N"],
1324
+ bias.element_size(),
1325
+ )
1326
+ return (
1327
+ min(
1328
+ NUM_SMS,
1329
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
1330
+ ),
1331
+ )
1332
+
1333
+ desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
1334
+ desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
1335
+ desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
1336
+ desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
1337
+ desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
1338
+ desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")
1339
+
1340
+ bias_dtype_triton = None
1341
+ if bias is not None:
1342
+ bias_dtype_triton = map_dtype_to_triton(bias.dtype)
1343
+
1344
+ # pyre-ignore
1345
+ torch._library.capture_triton(
1346
+ _kernel_matmul_fp8_row_tma_persistent_ws_cooperative
1347
+ )[persistent_grid_tma_ws](
1348
+ desc_a,
1349
+ desc_b,
1350
+ desc_c,
1351
+ M,
1352
+ N,
1353
+ K,
1354
+ m_key,
1355
+ n_key,
1356
+ k_key,
1357
+ a_scale,
1358
+ b_scale,
1359
+ desc_bias,
1360
+ a.stride(0),
1361
+ a.stride(1),
1362
+ b.stride(0),
1363
+ b.stride(1),
1364
+ c.stride(0),
1365
+ c.stride(1),
1366
+ dot_out_dtype=dot_out_dtype_triton,
1367
+ c_dtype=c_dtype_triton,
1368
+ bias_dtype=bias_dtype_triton,
1369
+ allow_tf32=allow_tf32,
1370
+ fp8_fast_accum=fp8_fast_accum,
1371
+ GROUP_M=8,
1372
+ AB_DTYPE=False,
1373
+ NUM_SMS=NUM_SMS,
1374
+ USE_BIAS=bias is not None,
1375
+ )
1376
+ elif tma_persistent:
1377
+ # used by TMA persistent kernel
1378
+ desc_helper = TmaAutoTuneHelper()
1379
+ desc_helper.init_tma_descriptor("a")
1380
+ desc_helper.init_tma_descriptor("b")
1381
+ desc_helper.init_tma_descriptor("c")
1382
+ desc_helper.init_tma_descriptor("a_scale")
1383
+ desc_helper.init_tma_descriptor("b_scale")
1384
+ desc_helper.init_tma_descriptor("bias")
1385
+
1386
+ def persistent_grid_tma(META: Dict[str, int]) -> Tuple[int]:
1387
+ nonlocal desc_helper # noqa: F824
1388
+ assert a_scale is not None # Type narrowing for Pyre
1389
+ desc_helper.fill_2d_tma_descriptor(
1390
+ "a",
1391
+ a.data_ptr(),
1392
+ M,
1393
+ K,
1394
+ META["BLOCK_M"],
1395
+ META["BLOCK_K"],
1396
+ a.element_size(),
1397
+ )
1398
+
1399
+ desc_helper.fill_2d_tma_descriptor(
1400
+ "b",
1401
+ b.data_ptr(),
1402
+ N,
1403
+ K,
1404
+ META["BLOCK_N"],
1405
+ META["BLOCK_K"],
1406
+ b.element_size(),
1407
+ )
1408
+ desc_helper.fill_2d_tma_descriptor(
1409
+ "c",
1410
+ c.data_ptr(),
1411
+ M,
1412
+ N,
1413
+ META["BLOCK_M"],
1414
+ META["BLOCK_N"],
1415
+ c.element_size(),
1416
+ )
1417
+ desc_helper.fill_1d_tma_descriptor(
1418
+ "a_scale",
1419
+ a_scale.data_ptr(),
1420
+ M,
1421
+ META["BLOCK_M"],
1422
+ a_scale.element_size(),
1423
+ )
1424
+ desc_helper.fill_1d_tma_descriptor(
1425
+ "b_scale",
1426
+ b_scale.data_ptr(),
1427
+ N,
1428
+ META["BLOCK_N"],
1429
+ b_scale.element_size(),
1430
+ )
1431
+ if bias is not None:
1432
+ desc_helper.fill_1d_tma_descriptor(
1433
+ "bias",
1434
+ bias.data_ptr(),
1435
+ N,
1436
+ META["BLOCK_N"],
1437
+ bias.element_size(),
1438
+ )
1439
+ return (
1440
+ min(
1441
+ NUM_SMS,
1442
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
1443
+ ),
1444
+ )
1445
+
1446
+ desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
1447
+ desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
1448
+ desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
1449
+ desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
1450
+ desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
1451
+ desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")
1452
+
1453
+ bias_dtype_triton = None
1454
+ if bias is not None:
1455
+ bias_dtype_triton = map_dtype_to_triton(bias.dtype)
1456
+
1457
+ # pyre-ignore
1458
+ torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[
1459
+ persistent_grid_tma
1460
+ ](
1461
+ desc_a,
1462
+ desc_b,
1463
+ desc_c,
1464
+ M,
1465
+ N,
1466
+ K,
1467
+ m_key,
1468
+ n_key,
1469
+ k_key,
1470
+ desc_a_scale,
1471
+ desc_b_scale,
1472
+ desc_bias,
1473
+ a.stride(0),
1474
+ a.stride(1),
1475
+ b.stride(0),
1476
+ b.stride(1),
1477
+ c.stride(0),
1478
+ c.stride(1),
1479
+ dot_out_dtype=dot_out_dtype_triton,
1480
+ c_dtype=c_dtype_triton,
1481
+ bias_dtype=bias_dtype_triton,
1482
+ allow_tf32=allow_tf32,
1483
+ fp8_fast_accum=fp8_fast_accum,
1484
+ GROUP_M=8,
1485
+ AB_DTYPE=False,
1486
+ NUM_SMS=NUM_SMS,
1487
+ USE_BIAS=bias is not None,
1488
+ )
1489
+ elif imprecise_acc:
1490
+ with torch.cuda.device(a.device.index):
1491
+ torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid](
1492
+ a,
1493
+ b,
1494
+ c,
1495
+ M,
1496
+ N,
1497
+ K,
1498
+ m_key,
1499
+ n_key,
1500
+ k_key,
1501
+ a_scale,
1502
+ b_scale,
1503
+ bias,
1504
+ a.stride(0),
1505
+ a.stride(1),
1506
+ b.stride(0),
1507
+ b.stride(1),
1508
+ c.stride(0),
1509
+ c.stride(1),
1510
+ dot_out_dtype=dot_out_dtype_triton,
1511
+ allow_tf32=allow_tf32,
1512
+ fp8_fast_accum=fp8_fast_accum,
1513
+ GROUP_M=8,
1514
+ USE_BIAS=bias is not None,
1515
+ AB_DTYPE=False,
1516
+ )
1517
+ elif fp8_fast_accum:
1518
+ skip_scaling_a = a_scale is None
1519
+ torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid](
1520
+ a,
1521
+ b,
1522
+ c,
1523
+ M,
1524
+ N,
1525
+ K,
1526
+ m_key,
1527
+ n_key,
1528
+ k_key,
1529
+ a_scale,
1530
+ b_scale,
1531
+ bias,
1532
+ a.stride(0),
1533
+ a.stride(1),
1534
+ b.stride(0),
1535
+ b.stride(1),
1536
+ c.stride(0),
1537
+ c.stride(1),
1538
+ dot_out_dtype=dot_out_dtype_triton,
1539
+ allow_tf32=allow_tf32,
1540
+ fp8_fast_accum=fp8_fast_accum,
1541
+ skip_scaling_a=skip_scaling_a,
1542
+ GROUP_M=8,
1543
+ USE_BIAS=bias is not None,
1544
+ AB_DTYPE=False,
1545
+ NUM_SMS=NUM_SMS,
1546
+ )
1547
+ else:
1548
+ torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[
1549
+ persistent_grid
1550
+ ](
1551
+ a,
1552
+ b,
1553
+ c,
1554
+ M,
1555
+ N,
1556
+ K,
1557
+ m_key,
1558
+ n_key,
1559
+ k_key,
1560
+ a_scale,
1561
+ b_scale,
1562
+ bias,
1563
+ a.stride(0),
1564
+ a.stride(1),
1565
+ b.stride(0),
1566
+ b.stride(1),
1567
+ c.stride(0),
1568
+ c.stride(1),
1569
+ dot_out_dtype=dot_out_dtype_triton,
1570
+ allow_tf32=allow_tf32,
1571
+ fp8_fast_accum=fp8_fast_accum,
1572
+ GROUP_M=8,
1573
+ USE_BIAS=bias is not None,
1574
+ AB_DTYPE=False,
1575
+ NUM_SMS=NUM_SMS,
1576
+ )
1577
+ return c.view(output_shape)
1578
+
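A minimal end-to-end sketch of calling the op, assuming a CUDA device whose fp8 dtype is torch.float8_e4m3fn (see get_fp8_constants); the row-wise quantization is inlined here for illustration rather than taken from the package's quantize helpers:

# End-to-end sketch: row-wise fp8 quantize two bf16 operands and multiply them.
import torch

M, N, K = 256, 512, 128
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_scale = x.abs().amax(dim=1).float() / fp8_max          # [M] reciprocal row scales
w_scale = w.abs().amax(dim=1).float() / fp8_max          # [N]
x_fp8 = (x / x_scale[:, None]).to(torch.float8_e4m3fn)
w_fp8 = (w / w_scale[:, None]).to(torch.float8_e4m3fn)

# tma_persistent=False keeps the example on the plain persistent kernel.
out = matmul_fp8_row(x_fp8, w_fp8, x_scale, w_scale, tma_persistent=False)
print(out.shape, (out - x @ w.T).abs().max())            # [M, N]; small quantization error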
1579
+
1580
+ @matmul_fp8_row.register_fake
1581
+ def matmul_fp8_row_meta(
1582
+ a: torch.Tensor,
1583
+ b: torch.Tensor,
1584
+ a_scale: Optional[torch.Tensor],
1585
+ b_scale: torch.Tensor,
1586
+ bias: Optional[torch.Tensor] = None,
1587
+ dot_out_dtype: Optional[torch.dtype] = None,
1588
+ allow_tf32: bool = True,
1589
+ fp8_fast_accum: bool = True,
1590
+ imprecise_acc: bool = False,
1591
+ tma_persistent: bool = True,
1592
+ no_use_persistent: Optional[bool] = None,
1593
+ use_warp_specialization: bool = False,
1594
+ ) -> torch.Tensor:
1595
+ """Shape function for torch compile."""
1596
+ N = b.shape[0]
+ output_shape = a.shape[:-1] + (N,)
+ return torch.empty(
+ output_shape,
1600
+ device=a.device,
1601
+ dtype=torch.bfloat16 if dot_out_dtype is None else dot_out_dtype,
1602
+ )
1603
+
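Because the op is registered via torch._library.triton_op and given this fake, it can be traced inside a torch.compile region: the fake supplies output shape and dtype during tracing, while the real Triton launch runs at execution time. A sketch:

# Sketch: the registered fake lets the op trace under torch.compile.
import torch

@torch.compile
def fp8_linear(a_fp8, b_fp8, a_scale, b_scale):
    return matmul_fp8_row(a_fp8, b_fp8, a_scale, b_scale, tma_persistent=False)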
1604
+
1605
+ # pruned some unreasonable config
1606
+ def prune_configs_block(configs, named_args, **kwargs):
1607
+ configs = early_config_prune(configs, named_args, **kwargs)
1608
+ scale_block_k = named_args["scale_block_k"]
1609
+ pruned_configs = []
1610
+ # Further rule out configs with scale_block_k is not a multiple of BLOCK_K
1611
+ for config in configs:
+ kw = config.kwargs
+ BLOCK_K = kw["BLOCK_K"]
+ if scale_block_k % BLOCK_K != 0:
+ continue
+ pruned_configs.append(config)
+ return pruned_configs
+
+
+ @triton.autotune(
+ configs=MATMUL_CONFIGS,
+ key=[
+ "m_key",
+ "n_key",
+ "k_key",
+ ], # TODO caller side bin keys so similar shapes can use same triton.autotune.
+ prune_configs_by={
+ "early_config_prune": prune_configs_block,
+ "perf_model": estimate_matmul_time,
+ "top_k": 10,
+ },
+ )
+ @triton.heuristics(
+ {
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+ }
+ )
+ @triton.jit
+ def _kernel_matmul_fp8_block_fastacc(
+ A,
+ B,
+ C,
+ M,
+ N,
+ K,
+ m_key,
+ n_key,
+ k_key,
+ A_scale,
+ B_scale,
+ scale_block_m: tl.constexpr,
+ scale_block_n: tl.constexpr,
+ scale_block_k: tl.constexpr,
+ stride_am,
+ stride_ak,
+ stride_bn,
+ stride_bk,
+ stride_cm,
+ stride_cn,
+ stride_scale_am,
+ stride_scale_ak,
+ stride_scale_bn,
+ stride_scale_bk,
+ dot_out_dtype: tl.constexpr,
+ allow_tf32: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ SPLIT_K: tl.constexpr,
+ EVEN_K: tl.constexpr,
+ AB_DTYPE: tl.constexpr,
+ ) -> None:
+ """Matmul kernel of [M, K] @ [N, K] with block-wise scales
+
+ Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles and
+ A and B scaled by a scaling factor per [scale_block_m, scale_block_k] and
+ [scale_block_n, scale_block_k] tiles
+ respectively.
+
+ Todo:
+ * Support scale_block_{mnk} < BLOCK{MNK} for each dim.
+ Args:
+ A (TensorWrapper): [M, K] input tensor.
+ B (TensorWrapper): [N, K] input tensor.
+ C (TensorWrapper): [M, N] output tensor.
+ M (int): M dimension of input tensor.
+ N (int): N dimension of input tensor.
+ K (int): K dimension of input tensor.
+ m_key (int): Autotuning key for M dimension of input tensor.
+ n_key (int): Autotuning key for N dimension of input tensor.
+ k_key (int): Autotuning key for K dimension of input tensor.
+ A_scale (TensorWrapper): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per block. A * A_scale = original A
+ B_scale (TensorWrapper): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per block. B * B_scale = original B
+ scale_block_m (int): Block size for M dimension of A_scale.
+ scale_block_n (int): Block size for N dimension of B_scale.
+ scale_block_k (int): Block size for K dimension of A_scale and B_scale.
+ stride_am (int): Stride of M dimension of A.
+ stride_ak (int): Stride of K dimension of A.
+ stride_bn (int): Stride of N dimension of B.
+ stride_bk (int): Stride of K dimension of B.
+ stride_cm (int): Stride of M dimension of C.
+ stride_cn (int): Stride of N dimension of C.
+ stride_scale_am (int): Stride of M dimension of A_scale.
+ stride_scale_ak (int): Stride of K dimension of A_scale.
+ stride_scale_bn (int): Stride of N dimension of B_scale.
+ stride_scale_bk (int): Stride of K dimension of B_scale.
+ dot_out_dtype (torch.dtype): Output type of tensor core.
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
+ BLOCK_M (int): Block size for M dimension.
+ BLOCK_N (int): Block size for N dimension.
+ BLOCK_K (int): Block size for K dimension.
+ GROUP_M (int): Number of groups for M dimension swizzle.
+ SPLIT_K (int): Number of SM's to launch per row.
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
+ """
+ assert BLOCK_M < scale_block_m
+ assert BLOCK_N < scale_block_n
+ assert BLOCK_K < scale_block_k
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_z = tl.program_id(1)
+
+ grid_m = tl.cdiv(M, BLOCK_M)
+ grid_n = tl.cdiv(N, BLOCK_N)
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ # do matrix multiplication
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
+ _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
+ scale_m = pid_m * BLOCK_M // scale_block_m
+ scale_n = pid_n * BLOCK_N // scale_block_n
+ k_multiple = scale_block_k // BLOCK_K
+
+ for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
+ k_remaining = K - k * (BLOCK_K * SPLIT_K)
+
+ if EVEN_K:
+ a = tl.load(A)
+ b = tl.load(B)
+ else:
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
+ if AB_DTYPE:
+ a = a.to(C.dtype.element_ty)
+ b = b.to(C.dtype.element_ty)
+
+ acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
+
+ A += BLOCK_K * SPLIT_K * stride_ak
+ B += BLOCK_K * SPLIT_K * stride_bk
+
+ # Some math to precompute on scalars, and apply once on matrix.
+ # a + c/s = (as + c) / s
+ # (((a_i-1 * s_i-1 + c_i-1) / s_i-1) * s_i + c_i) / s_i ... ) * s_k + c_k) * 1.0 / s_k
+ # Simplifies to (a_i-1 + c) * (s_i+1/s_i)
+ # And have s_k+1 be 1.
+ # Scale_i = pid_i * BLOCK_I / scale_block_i
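+ # In other words: partial sums for scale block i are accumulated unscaled; at
+ # each scale-block boundary acc is multiplied by s_i / s_{i+1} (and by s_k
+ # alone on the final iteration), so the final acc equals sum_i partial_i * s_i
+ # without rescaling on every K iteration.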
+ pid_k = k * SPLIT_K + pid_z
+ if ((pid_k + 1) % k_multiple == 0) or (k_remaining < BLOCK_K * SPLIT_K):
+ # Note: Due to split_k access "pid_k" = k * SPLIT_K + pid_z
+ # Access a_scale[pid_m, k * SPLIT_K + pid_z]
+ # and b_scale[k * SPLIT_K + pid_z, pid_n]
+
+ scale_k = pid_k // k_multiple
+ scale_k_next = scale_k + 1
+ a_scale = tl.load(
+ A_scale + scale_m * stride_scale_am + scale_k * stride_scale_ak
+ )
+ b_scale = tl.load(
+ B_scale + scale_n * stride_scale_bn + scale_k * stride_scale_bk
+ )
+ scale = a_scale * b_scale
+ if k + 1 == tl.cdiv(K, BLOCK_K * SPLIT_K):
+ scale_next_inv_scale = scale
+ else:
+ a_scale_next = tl.load(
+ A_scale + scale_m * stride_scale_am + scale_k_next * stride_scale_ak
+ )
+ b_scale_next = tl.load(
+ B_scale + scale_n * stride_scale_bn + scale_k_next * stride_scale_bk
+ )
+ scale_next = a_scale_next * b_scale_next
+ scale_next_inv_scale = scale / scale_next
+ acc *= scale_next_inv_scale
+
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ acc = acc.to(C.dtype.element_ty)
+ c = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+ # handles write-back with reduction-splitting
+ if SPLIT_K == 1:
+ tl.store(c, acc, mask=mask)
+ else:
+ tl.atomic_add(c, acc, mask=mask)
+
+
+ @triton.autotune(
+ configs=MATMUL_CONFIGS,
+ key=[
+ "m_key",
+ "n_key",
+ "k_key",
+ ], # TODO caller side bin keys so similar shapes can use same triton.autotune.
+ prune_configs_by={
+ "early_config_prune": early_config_prune,
+ "perf_model": estimate_matmul_time,
+ "top_k": 10,
+ },
+ )
+ @triton.heuristics(
+ {
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+ }
+ )
+ @triton.jit
+ def _kernel_matmul_fp8_block_slowacc(
+ A,
+ B,
+ C,
+ M,
+ N,
+ K,
+ m_key,
+ n_key,
+ k_key,
+ A_scale,
+ B_scale,
+ scale_block_m: tl.constexpr,
+ scale_block_n: tl.constexpr,
+ scale_block_k: tl.constexpr,
+ stride_am,
+ stride_ak,
+ stride_bn,
+ stride_bk,
+ stride_cm,
+ stride_cn,
+ stride_scale_am,
+ stride_scale_ak,
+ stride_scale_bn,
+ stride_scale_bk,
+ dot_out_dtype: tl.constexpr,
+ allow_tf32: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ SPLIT_K: tl.constexpr,
+ EVEN_K: tl.constexpr,
+ AB_DTYPE: tl.constexpr,
+ ) -> None:
+ """Matmul kernel of [M, K] @ [N, K] with block-wise scales
+
+ Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles and
+ A and B scaled by a scaling factor per [scale_block_m, scale_block_k] and
+ [scale_block_n, scale_block_k] tiles
+ respectively.
+
+ Todo:
+ * Support scale_block_{mnk} < BLOCK{MNK} for each dim.
+ Args:
+ A (TensorWrapper): [M, K] input tensor.
+ B (TensorWrapper): [N, K] input tensor.
+ C (TensorWrapper): [M, N] output tensor.
+ M (int): M dimension of input tensor.
+ N (int): N dimension of input tensor.
+ K (int): K dimension of input tensor.
+ m_key (int): Autotuning key for M dimension of input tensor.
+ n_key (int): Autotuning key for N dimension of input tensor.
+ k_key (int): Autotuning key for K dimension of input tensor.
+ A_scale (TensorWrapper): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per block. A * A_scale = original A
+ B_scale (TensorWrapper): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per block. B * B_scale = original B
+ scale_block_m (int): Block size for M dimension of A_scale.
+ scale_block_n (int): Block size for N dimension of B_scale.
+ scale_block_k (int): Block size for K dimension of A_scale and B_scale.
+ stride_am (int): Stride of M dimension of A.
+ stride_ak (int): Stride of K dimension of A.
+ stride_bn (int): Stride of N dimension of B.
+ stride_bk (int): Stride of K dimension of B.
+ stride_cm (int): Stride of M dimension of C.
+ stride_cn (int): Stride of N dimension of C.
+ stride_scale_am (int): Stride of M dimension of A_scale.
+ stride_scale_ak (int): Stride of K dimension of A_scale.
+ stride_scale_bn (int): Stride of N dimension of B_scale.
+ stride_scale_bk (int): Stride of K dimension of B_scale.
+ dot_out_dtype (torch.dtype): Output type of tensor core.
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
+ BLOCK_M (int): Block size for M dimension.
+ BLOCK_N (int): Block size for N dimension.
+ BLOCK_K (int): Block size for K dimension.
+ GROUP_M (int): Number of groups for M dimension swizzle.
+ SPLIT_K (int): Number of SM's to launch per row.
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
+ """
+ assert BLOCK_M < scale_block_m
+ assert BLOCK_N < scale_block_n
+ assert BLOCK_K < scale_block_k
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_z = tl.program_id(1)
+
+ grid_m = tl.cdiv(M, BLOCK_M)
+ grid_n = tl.cdiv(N, BLOCK_N)
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ # do matrix multiplication
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
+ scale_m = pid_m * BLOCK_M // scale_block_m
+ scale_n = pid_n * BLOCK_N // scale_block_n
+ _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
+
+ for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
+ # Note: Due to split_k access "pid_k" = k * SPLIT_K + pid_z
+ # Access a_scale[pid_m, k * SPLIT_K + pid_z]
+ # and b_scale[k * SPLIT_K + pid_z, pid_n]
+ pid_k = k * SPLIT_K + pid_z
+ scale_k = pid_k * BLOCK_K // scale_block_k
+ a_scale = tl.load(
+ A_scale + scale_m * stride_scale_am + scale_k * stride_scale_ak
+ )
+ b_scale = tl.load(
+ B_scale + scale_n * stride_scale_bn + scale_k * stride_scale_bk
+ )
+ scale = a_scale * b_scale
+
+ if EVEN_K:
+ a = tl.load(A)
+ b = tl.load(B)
+ else:
+ k_remaining = K - k * (BLOCK_K * SPLIT_K)
+
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
+ if AB_DTYPE:
+ a = a.to(C.dtype.element_ty)
+ b = b.to(C.dtype.element_ty)
+
+ acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) * scale
+ A += BLOCK_K * SPLIT_K * stride_ak
+ B += BLOCK_K * SPLIT_K * stride_bk
+
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ acc = acc.to(C.dtype.element_ty)
+ c = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+ # handles write-back with reduction-splitting
+ if SPLIT_K == 1:
+ tl.store(c, acc, mask=mask)
+ else:
+ tl.atomic_add(c, acc, mask=mask)
+
+
+ @torch.library.custom_op("triton::matmul_fp8_block", mutates_args=())
+ def matmul_fp8_block(
+ a: torch.Tensor,
+ b: torch.Tensor,
+ a_scale: torch.Tensor,
+ b_scale: torch.Tensor,
+ scale_block_m: int = 256,
+ scale_block_n: int = 256,
+ scale_block_k: int = 256,
+ dot_out_dtype: Optional[torch.dtype] = None,
+ allow_tf32: bool = True,
+ fp8_fast_accum: bool = True,
+ ) -> Tensor:
+ """Performs matmul on [M, K] and [N, K] fp8 matrices with block-wise scalings.
+
+ Args:
+ a (torch.Tensor): [M, K] input tensor.
+ b (torch.Tensor): [N, K] input tensor.
+ a_scale (torch.Tensor): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per scale block. A * A_scale = original A
+ b_scale (torch.Tensor): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per scale block. B * B_scale = original B
+ scale_block_m (int): Block size for M dimension of A_scale.
+ scale_block_n (int): Block size for N dimension of B_scale.
+ scale_block_k (int): Block size for K dimension of A_scale and B_scale.
+ dot_out_dtype (torch.dtype): Output type of tensor core.
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
+
+ Returns:
+ Tensor: [M, N] output tensor, approximately (a * a_scale) @ (b * b_scale).T with scales broadcast per block, i.e. the matmul of the dequantized inputs.
+ """
+ # Get datatypes and constants to use.
+ _, tl_fp8_dtype, _, _ = get_fp8_constants()
+ # Handle 3D+ a shape
+ a_shape = a.shape
+ a = a.view(-1, a.size(-1))
+ # View inputs into proper triton fp8 dtype.
+ a_tl = reinterpret_fp8_type(a, tl_fp8_dtype)
+ b_tl = reinterpret_fp8_type(b, tl_fp8_dtype)
+
+ M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul(
+ a_tl, b_tl, dot_out_dtype
+ )
+
+ output_shape = a_shape[:-1] + (N,)
+ # Handle case where inputs are empty.
+ if (M == 0) or (N == 0) or (K == 0):
+ return torch.zeros(output_shape, device=device, dtype=torch.bfloat16)
+
+ # launch kernel
+ assert device != torch.device("cpu"), (
+ "Blockwise matmul not supported on cpu, please use row-wise instead."
+ )
+
+ if b.device != a.device:
+ raise Exception("'b' must be on the same device as 'a'")
+ if a_scale.device != a.device:
+ raise Exception("'a_scale' must be on the same device as 'a'")
+ if b_scale.device != a.device:
+ raise Exception("'b_scale' must be on the same device as 'a'")
+
+ # noqa: E731:
+ def grid(META: Dict[str, int]) -> Tuple[int, int]:
+ return (
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
+ META["SPLIT_K"],
+ )
+
+ if fp8_fast_accum:
+ with torch.cuda.device(a_tl.device.index):
+ _kernel_matmul_fp8_block_fastacc[grid](
+ a_tl,
+ b_tl,
+ c,
+ M,
+ N,
+ K,
+ m_key,
+ n_key,
+ k_key,
+ a_scale,
+ b_scale,
+ scale_block_m,
+ scale_block_n,
+ scale_block_k,
+ a.stride(0),
+ a.stride(1),
+ b.stride(0),
+ b.stride(1),
+ c.stride(0),
+ c.stride(1),
+ a_scale.stride(0),
+ a_scale.stride(1),
+ b_scale.stride(0),
+ b_scale.stride(1),
+ dot_out_dtype=dot_out_dtype_triton,
+ allow_tf32=allow_tf32,
+ GROUP_M=8,
+ AB_DTYPE=False,
+ )
+ else:
+ with torch.cuda.device(a_tl.device.index):
+ _kernel_matmul_fp8_block_slowacc[grid](
+ a_tl,
+ b_tl,
+ c,
+ M,
+ N,
+ K,
+ m_key,
+ n_key,
+ k_key,
+ a_scale,
+ b_scale,
+ scale_block_m,
+ scale_block_n,
+ scale_block_k,
+ a.stride(0),
+ a.stride(1),
+ b.stride(0),
+ b.stride(1),
+ c.stride(0),
+ c.stride(1),
+ a_scale.stride(0),
+ a_scale.stride(1),
+ b_scale.stride(0),
+ b_scale.stride(1),
+ dot_out_dtype=dot_out_dtype_triton,
+ allow_tf32=allow_tf32,
+ GROUP_M=8,
+ AB_DTYPE=False,
+ )
+ return c.view(output_shape)
+
+
+ @matmul_fp8_block.register_fake
+ def matmul_fp8_block_meta(
+ a: torch.Tensor,
+ b: torch.Tensor,
+ a_scale: torch.Tensor,
+ b_scale: torch.Tensor,
+ scale_block_m: int = 256,
+ scale_block_n: int = 256,
+ scale_block_k: int = 256,
+ dot_out_dtype: Optional[torch.dtype] = None,
+ allow_tf32: bool = True,
+ fp8_fast_accum: bool = True,
+ ) -> torch.Tensor:
+ """Shape function for torch compile."""
+ M, K = a.shape
+ N, K = b.shape
+ return torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
+
+
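For illustration (an editor's sketch, not part of the packaged source): block-wise scales for matmul_fp8_block have shape [cdiv(M, 256), cdiv(K, 256)] by default and follow the same reciprocal convention, a_fp8 * a_scale ~= a. The quantize_fp8_block_ref helper below is an assumed reference implementation for the example; the op itself is reachable through the "triton::matmul_fp8_block" registration above.

import torch
import torch.nn.functional as F

def quantize_fp8_block_ref(x: torch.Tensor, block: int = 256):
    # Reciprocal scale per [block, block] tile so that x_fp8 * scale ~= x.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    rows, cols = x.shape
    xp = F.pad(x.float(), (0, -cols % block, 0, -rows % block))
    tiles = xp.view(xp.shape[0] // block, block, xp.shape[1] // block, block)
    scale = tiles.abs().amax(dim=(1, 3)).clamp(min=1e-12) / fp8_max
    full = scale.repeat_interleave(block, 0).repeat_interleave(block, 1)
    x_fp8 = (xp / full).to(torch.float8_e4m3fn)[:rows, :cols]
    return x_fp8, scale

a = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(768, 1024, device="cuda", dtype=torch.bfloat16)
a_fp8, a_scale = quantize_fp8_block_ref(a)  # a_scale: [2, 4]
b_fp8, b_scale = quantize_fp8_block_ref(b)  # b_scale: [3, 4]
c = torch.ops.triton.matmul_fp8_block(a_fp8, b_fp8, a_scale, b_scale)  # [512, 768] bf16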
+ def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]:
+ """
+ Generate a simplified matmul tune key for A @ B.T
+ with [M, K] A and [N, K] B to reduce excessive autotuning.
+
+ Args:
+ M (int): Number of rows in A.
+ N (int): Number of rows in B.
+ K (int): Number of cols in A and cols in B.
+
+ Returns:
+ m_key (int): Autotuning key for M dim.
+ n_key (int): Autotuning key for N dim.
+ k_key (int): Autotuning key for K dim.
+
+ TODO: Refine this. For now it's useful for LLM inference where N, K dims are fixed
+ and M dim varies due to seq_len.
+ """
+ if M < 256:
+ m_key = M
+ else:
+ m_key = 256 + M // 1024
+ return m_key, N, K
+
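For instance, the bucketing above collapses nearby M values into one autotuning key while passing N and K through unchanged (the shapes below are arbitrary examples):

assert get_matmul_tune(128, 7168, 8192) == (128, 7168, 8192)   # M < 256: key is M itself
assert get_matmul_tune(4096, 7168, 8192) == (260, 7168, 8192)  # 256 + 4096 // 1024 = 260
assert get_matmul_tune(4500, 7168, 8192) == (260, 7168, 8192)  # same bucket as M = 4096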
+
+ def prep_matmul(
+ a: Union[TensorWrapper, torch.Tensor],
+ b: Union[TensorWrapper, torch.Tensor],
+ dot_out_dtype: Optional[torch.dtype],
+ ) -> Tuple[
+ int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device
+ ]:
+ """
+ Shared bookkeeping for a @ b.T matmul.
+
+ Args:
+ a (torch.Tensor): [M, K] input tensor.
+ b (torch.Tensor): [N, K] input tensor.
+ dot_out_dtype (tl.dtype): Output type of tensor core.
+
+ Returns:
+ M (int): Number of rows in A.
+ N (int): Number of rows in B.
+ K (int): Number of cols in A and cols in B.
+ m_key (int): Autotuning key for M dim.
+ n_key (int): Autotuning key for N dim.
+ k_key (int): Autotuning key for K dim.
+ c (Tensor): [M, N] output tensor.
+ c_dtype_triton (tl.dtype): Type of output tensor.
+ dot_out_dtype (tl.dtype): Output type of tensor core.
+ device (torch.device): Device of output tensor.
+ """
+ device = a.device
+
+ # checks constraints
+ assert a.shape[1] == b.shape[1], (
+ f"incompatible dimensions, a: {a.shape}, b: {b.shape}"
+ )
+ M, K = a.shape
+ N, _ = b.shape
+ m_key, n_key, k_key = get_matmul_tune(M, N, K)
+
+ # allocates output
+ assert a.dtype in [
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ torch.float8_e4m3fnuz,
+ torch.float8_e5m2fnuz,
+ tl.float8e4nv,
+ tl.float8e4b15,
+ tl.float8e5,
+ tl.float8e4b8,
+ ]
+ assert b.dtype in [
+ torch.float8_e4m3fn,
+ torch.float8_e5m2,
+ torch.float8_e4m3fnuz,
+ torch.float8_e5m2fnuz,
+ tl.float8e4nv,
+ tl.float8e4b15,
+ tl.float8e5,
+ tl.float8e4b8,
+ ]
+
+ c_dtype, c_dtype_triton = (
+ (torch.bfloat16, tl.bfloat16)
+ if dot_out_dtype is None
+ else (dot_out_dtype, map_dtype_to_triton(dot_out_dtype))
+ )
+
+ c = torch.empty((M, N), device=device, dtype=c_dtype)
+ if dot_out_dtype is None:
+ dot_out_dtype_triton = tl.float32
+ else:
+ assert isinstance(dot_out_dtype, torch.dtype), (
+ f"dot_out_dtype type {type(dot_out_dtype)} must be a torch.dtype"
+ )
+ dot_out_dtype_triton = map_dtype_to_triton(dot_out_dtype)
+
+ return M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device
+
+
+ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
+ return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
+
+
+ # Force a failure instead of a warning when all configs are pruned.
+ # TODO: Determine a better approach for model level testing. We need
+ # to standardize our approach around prune_configs in general.
+ FORCE_FAILURE_ON_EMPTY_CONFIGS = False
+
+
+ def is_invalid_config(config, N, M, K, mfma, use_bias):
+ """
+ Contains all of the configuration checks for prune_configs
+ that will result in an invalid result if selected as the config.
+
+ This is done to ensure that if no config is "optimal" for a given
+ shape, we do not accidentally fall back to one that produces incorrect results.
+ """
+ BLOCK_SIZE_M = config.kwargs.get("BLOCK_M")
+ BLOCK_SIZE_N = config.kwargs.get("BLOCK_N")
+ BLOCK_SIZE_K = config.kwargs.get("BLOCK_K")
+ SPLIT_K = config.kwargs.get("SPLIT_K")
+ matrix_instr_nonkdim = config.kwargs.get("matrix_instr_nonkdim")
+ if matrix_instr_nonkdim > mfma:
+ return True
+ if mfma == 4 and BLOCK_SIZE_K < 64:
+ return True
+ # Some layouts do not work properly when the number of
+ # elements per thread is less than 1.
+ if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
+ return True
+ if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim:
+ return True
+ if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim:
+ return True
+ if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim:
+ return True
+ # split_k cannot be used if there is a bias
+ if use_bias and SPLIT_K != 1:
+ return True
+ return False
+
+
+ # Configs adapted from https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tools/tune_gemm/tune_gemm.py
+ def prune_configs(configs, named_args, **kwargs):
+ pruned_configs = []
+ M = named_args["M"]
+ N = named_args["N"]
+ K = named_args["K"]
+ elemBytes_a = named_args["A"].element_size()
+ elemBytes_b = named_args["B"].element_size()
+ use_bias = kwargs["USE_BIAS"]
+
+ if M < 32 or N < 32:
+ mfma = 16
+ else:
+ mfma = 32
+
+ for config in configs:
+ BLOCK_SIZE_M = config.kwargs.get("BLOCK_M")
+ BLOCK_SIZE_N = config.kwargs.get("BLOCK_N")
+ BLOCK_SIZE_K = config.kwargs.get("BLOCK_K")
+ SPLIT_K = config.kwargs.get("SPLIT_K")
+ GROUP_M = config.kwargs.get("GROUP_M")
+ if is_invalid_config(config, N, M, K, mfma, use_bias):
+ continue
+ # Skip BLOCK_SIZE that is too large compared to M/N,
+ # unless BLOCK_SIZE is already small enough.
+ if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16:
+ continue
+ if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16:
+ continue
+ # skip large split_k when not necessary
+ if SPLIT_K != 1 and not need_split_k(M, N, K):
+ continue
+ # skip large GROUP_M
+ if GROUP_M * BLOCK_SIZE_M >= M and GROUP_M != 1:
+ continue
+ # out of shared memory resource
+ # TODO (zhanglx): This does not consider the LDS usage in the epilogue
+ LDS = (
+ BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
+ + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
+ )
+ if LDS > 65536:
+ continue
+ pruned_configs.append(config)
+
+ print(f"{len(configs)=} {len(pruned_configs)=} for {M=} {N=} {K=}")
+ if len(pruned_configs) == 0:
+ if not FORCE_FAILURE_ON_EMPTY_CONFIGS:
+ # Prune configs that can lead to incorrect results even if all configs are sub-optimal.
+ candidate_configs = [
+ c for c in configs if not is_invalid_config(c, N, M, K, mfma, use_bias)
+ ]
+ print(f"No configs left after pruning! {M=} {N=} {K=}")
+ pruned_configs = candidate_configs[:10]
+ if len(pruned_configs) == 0:
+ raise RuntimeError(
+ "No valid configs left after pruning! Consider autotuning further with TritonBench"
+ )
+ return pruned_configs
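To make the shared-memory bound above concrete (a worked example, assuming 1-byte fp8 elements for both operands):

# BLOCK_M = BLOCK_N = 256, BLOCK_K = 128, elemBytes_a = elemBytes_b = 1:
LDS = 128 * 256 * 1 + 128 * 256 * 1  # = 65536 bytes, not > 65536, so the config is kept
# The same tile with BLOCK_K = 256 would need 131072 bytes of LDS and is pruned.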
+
+
+ def get_full_non_persistent_tuning_space():
+ configs = []
+
+ block_mn_range = [16, 32, 64, 128, 256]
+ block_k_range = [16, 32, 64, 128, 256]
+ split_k_range = [1]
+ num_warps_range = [1, 2, 4, 8]
+ group_m_range = [1, 2, 4, 8, 16, 32]
+ num_stage_range = [2]
+ waves_per_eu_range = [0]
+ matrix_instr_nonkdim_range = [16, 32]
+ kpack_range = [1, 2]
+
+ for block_m in block_mn_range:
+ for block_n in block_mn_range:
+ for block_k in block_k_range:
+ for num_warps in num_warps_range:
+ for group_m in group_m_range:
+ for split_k in split_k_range:
+ for num_stages in num_stage_range:
+ for waves_per_eu in waves_per_eu_range:
+ for (
+ matrix_instr_nonkdim
+ ) in matrix_instr_nonkdim_range:
+ for kpack in kpack_range:
+ configs.append(
+ triton.Config(
+ {
+ "BLOCK_M": block_m,
+ "BLOCK_N": block_n,
+ "BLOCK_K": block_k,
+ "GROUP_M": group_m,
+ "SPLIT_K": split_k,
+ "waves_per_eu": waves_per_eu,
+ "matrix_instr_nonkdim": matrix_instr_nonkdim,
+ "kpack": kpack,
+ },
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ )
+ return configs
+
+
+ MATMUL_CONFIGS_NON_PERSISTENT: List[Config] = get_full_non_persistent_tuning_space()
+ # (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, SPLIT_K, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages)
+ _MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K = [
+ (16, 16, 256, 1, 1, 8, 16, 2, 2, 2),
+ (16, 16, 256, 1, 1, 0, 16, 2, 2, 2),
+ (32, 64, 512, 1, 1, 2, 16, 2, 8, 2),
+ (64, 64, 256, 1, 1, 2, 16, 2, 4, 2),
+ (256, 256, 128, 32, 1, 2, 16, 1, 8, 2),
+ (256, 256, 128, 2, 1, 0, 32, 2, 8, 2),
+ (256, 256, 128, 1, 1, 0, 32, 2, 8, 2),
+ (256, 256, 128, 2, 1, 0, 16, 1, 8, 2),
+ (256, 256, 64, 2, 1, 2, 16, 1, 8, 2),
+ (128, 256, 64, 2, 1, 2, 16, 1, 4, 2),
+ (256, 128, 128, 4, 1, 0, 16, 1, 8, 2),
+ (128, 128, 128, 1, 1, 2, 16, 2, 4, 2),
+ (128, 128, 256, 1, 1, 2, 16, 2, 8, 2),
+ (128, 128, 64, 4, 1, 2, 16, 2, 4, 2),
+ (128, 128, 64, 1, 1, 2, 16, 2, 4, 2),
+ (128, 64, 64, 4, 1, 0, 16, 2, 4, 2),
+ (128, 64, 64, 1, 1, 0, 16, 2, 4, 2),
+ (256, 128, 128, 1, 1, 2, 16, 1, 8, 2),
+ (128, 256, 128, 2, 1, 2, 16, 2, 4, 1),
+ (256, 128, 64, 2, 1, 2, 16, 1, 4, 2),
+ (128, 128, 256, 2, 1, 0, 16, 2, 8, 2),
+ (128, 64, 128, 2, 1, 2, 16, 2, 4, 2),
+ (128, 128, 64, 2, 1, 0, 16, 1, 4, 2),
+ (128, 128, 128, 1, 1, 2, 16, 1, 4, 2),
+ ]
+
+
+ def _should_skip_config(block_k, matrix_instr_nonkdim):
+ """Skip config if BLOCK_K=64 and matrix_instr_nonkdim=16 on GFX95+"""
+ try:
+ return (
+ block_k == 64
+ and matrix_instr_nonkdim == 16
+ and torch.version.hip is not None
+ and torch.cuda.get_device_capability() >= (9, 5)
+ )
+ except RuntimeError:
+ # If no HIP GPUs are available, we can't check device capability
+ # so we don't skip any configs
+ return False
+
+
+ MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K = [
+ triton.Config(
+ {
+ "BLOCK_M": block_m,
+ "BLOCK_N": block_n,
+ "BLOCK_K": block_k,
+ "GROUP_M": group_m,
+ "SPLIT_K": split_k,
+ "waves_per_eu": waves_per_eu,
+ "matrix_instr_nonkdim": matrix_instr_nonkdim,
+ "kpack": kpack,
+ },
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ for block_m, block_n, block_k, group_m, split_k, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages in _MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K
+ if not _should_skip_config(block_k, matrix_instr_nonkdim)
+ ]
+
+ # Set this to enable full autotuning for proper benchmarking.
+ # This should only be used when invoking the kernel through
+ # Triton directly (e.g. TritonBench)
+ #
+ # NOTE: This will SIGNIFICANTLY increase autotuning time, often
+ # taking hours. You should combine this with TRITON_PRINT_AUTOTUNING=1
+ # to extract and add the optimal autotuning configs to
+ # MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K.
+
+ FULL_NON_PERSISTENT_AUTOTUNING = False
+ USED_MATMUL_NON_PERSISTENT_CONFIGS = (
+ MATMUL_CONFIGS_NON_PERSISTENT
+ if FULL_NON_PERSISTENT_AUTOTUNING
+ else MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K
+ )
+
+
+ @triton.autotune(
+ configs=USED_MATMUL_NON_PERSISTENT_CONFIGS,
+ key=["M", "N", "K"],
+ prune_configs_by={
+ "early_config_prune": prune_configs,
+ "perf_model": None,
+ "top_k": None,
+ },
+ use_cuda_graph=FULL_NON_PERSISTENT_AUTOTUNING,
+ )
+ @triton.heuristics(
+ {
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+ }
+ )
+ @triton.jit
+ def _kernel_matmul_fp8_row_non_persistent(
+ A,
+ B,
+ C,
+ M,
+ N,
+ K,
+ m_key,
+ n_key,
+ k_key,
+ A_scale,
+ B_scale,
+ Bias,
+ stride_am,
+ stride_ak,
+ stride_bn,
+ stride_bk,
+ stride_cm,
+ stride_cn,
+ dot_out_dtype: tl.constexpr,
+ allow_tf32: tl.constexpr,
+ fp8_fast_accum: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ SPLIT_K: tl.constexpr,
+ EVEN_K: tl.constexpr,
+ USE_BIAS: tl.constexpr,
+ AB_DTYPE: tl.constexpr,
+ ) -> None:
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
+
+ Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
+
+ Args:
+ A (TensorWrapper): [M, K] input tensor.
+ B (TensorWrapper): [N, K] input tensor.
+ C (TensorWrapper): [M, N] output tensor.
+ M (int): M dimension of input tensor.
+ N (int): N dimension of input tensor.
+ K (int): K dimension of input tensor.
+ m_key (int): Autotuning key for M dimension of input tensor.
+ n_key (int): Autotuning key for N dimension of input tensor.
+ k_key (int): Autotuning key for K dimension of input tensor.
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
+ Bias (TensorWrapper): [N] Optional bias tensor.
+ stride_am (int): Stride of M dimension of A.
+ stride_ak (int): Stride of K dimension of A.
+ stride_bn (int): Stride of N dimension of B.
+ stride_bk (int): Stride of K dimension of B.
+ stride_cm (int): Stride of M dimension of C.
+ stride_cn (int): Stride of N dimension of C.
+ dot_out_dtype (torch.dtype): Output type of tensor core.
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
+ BLOCK_M (int): Block size for M dimension.
+ BLOCK_N (int): Block size for N dimension.
+ BLOCK_K (int): Block size for K dimension.
+ GROUP_M (int): Number of groups for M dimension swizzle.
+ SPLIT_K (int): Number of SM's to launch per row.
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
+ USE_BIAS (bool): Whether to use bias.
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
+ """
+ tl.assume(M >= 0)
+ tl.assume(N >= 0)
+ tl.assume(K >= 0)
+ tl.assume(stride_am >= 0)
+ tl.assume(stride_ak >= 0)
+ tl.assume(stride_bn >= 0)
+ tl.assume(stride_bk >= 0)
+ tl.assume(stride_cm >= 0)
+ tl.assume(stride_cn >= 0)
+ # Matrix multiplication.
+ pid = tl.program_id(0)
+ pid_z = tl.program_id(1)
+ grid_m = tl.cdiv(M, BLOCK_M)
+ grid_n = tl.cdiv(N, BLOCK_N)
+ # Re-order program ID for better L2 performance (swizzle).
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + ((pid % width) % group_size)
+ pid_n = (pid % width) // (group_size)
+ tl.assume(pid_m >= 0)
+ tl.assume(pid_n >= 0)
+ # Do matrix multiplication.
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
+ # Pointers.
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ acc_dtype = tl.float32 if allow_tf32 else dot_out_dtype
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
+
+ for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
+ if EVEN_K:
+ a = tl.load(A)
+ b = tl.load(B)
+ else:
+ k_remaining = K - k * (BLOCK_K * SPLIT_K)
+ _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
+ if AB_DTYPE:
+ a = a.to(C.dtype.element_ty)
+ b = b.to(C.dtype.element_ty)
+ if fp8_fast_accum:
+ acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)
+ else:
+ acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32)
+
+ A += BLOCK_K * SPLIT_K * stride_ak
+ B += BLOCK_K * SPLIT_K * stride_bk
+
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ # Invert scaling.
+ a_scale = tl.load(A_scale + rm, mask=rm < M)
+ b_scale = tl.load(B_scale + rn, mask=rn < N)
+ # Invert vector, then multiply on matrix for speed.
+ # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
+ scale = a_scale[:, None] * b_scale[None, :]
+ acc *= scale
+
+ # Load and add bias if specified.
+ if USE_BIAS:
+ bias = tl.load(Bias + rn, mask=rn < N)
+ acc += bias[None, :]
+
+ acc = acc.to(C.dtype.element_ty)
+ C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+ # Handles write-back with reduction-splitting
+ if SPLIT_K == 1:
+ tl.store(C, acc, mask=mask)
+ else:
+ tl.atomic_add(C, acc, mask=mask)
+
+
+ # This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142
+ def to_mxfp8(
+ data_hp: torch.Tensor,
+ block_size: int = 32,
+ ):
+ assert data_hp.dtype in (
+ torch.bfloat16,
+ torch.float,
+ ), f"{data_hp.dtype} is not supported yet"
+ assert data_hp.shape[-1] % block_size == 0, (
+ f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
+ )
+ assert data_hp.is_contiguous(), "unsupported"
+
+ orig_shape = data_hp.shape
+ data_hp = data_hp.reshape(
+ *orig_shape[:-1], orig_shape[-1] // block_size, block_size
+ )
+
+ max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
+
+ data_hp = data_hp.to(torch.float32)
+ max_abs = max_abs.to(torch.float32)
+
+ F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
+ max_pos = F8E4M3_MAX
+
+ # RCEIL
+ def _to_mx_rceil(
+ data_hp: torch.Tensor,
+ max_abs: torch.Tensor,
+ max_pos: float,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ E8M0_EXPONENT_BIAS = 127
+ descale = max_abs / max_pos
+ exponent = torch.where(
+ torch.isnan(descale),
+ 0xFF, # Handle biased exponent for nan
+ # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
+ (
+ torch.clamp(
+ torch.ceil(torch.log2(descale)),
+ min=-E8M0_EXPONENT_BIAS,
+ max=E8M0_EXPONENT_BIAS,
+ )
+ + E8M0_EXPONENT_BIAS
+ ).to(torch.uint8),
+ )
+
+ descale_fp = torch.where(
+ exponent == 0,
+ 1.0,
+ torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)),
+ )
+
+ # scale and saturated cast the data elements to max of target dtype
+ data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
+ return exponent, data_lp
+
+ scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
+
+ # cast to target dtype
+ data_lp = data_lp.to(torch.float8_e4m3fn)
+ # need to reshape at the end to help inductor fuse things
+ data_lp = data_lp.reshape(orig_shape)
+
+ scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
+ scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
+ return scale_e8m0_biased, data_lp
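A minimal round-trip sketch for to_mxfp8 (an editorial example, assuming a PyTorch build with float8_e4m3fn and float8_e8m0fnu support); dequantization simply re-applies the per-block power-of-two scale encoded by the biased E8M0 exponent:

import torch

x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
scale_e8m0, x_fp8 = to_mxfp8(x, block_size=32)
# scale_e8m0: [4, 4] biased exponents (one per 32-element block), x_fp8: [4, 128] e4m3 data.

# Dequantize: scale = 2 ** (biased_exponent - 127), broadcast over each block.
scale = torch.exp2(scale_e8m0.view(torch.uint8).to(torch.float32) - 127)
x_hat = (x_fp8.to(torch.float32).view(4, 4, 32) * scale.unsqueeze(-1)).view(4, 128)
print((x_hat - x.float()).abs().max())  # small quantization error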