fbgemm-gpu-genai-nightly 2025.12.19-cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fbgemm-gpu-genai-nightly might be problematic.
Files changed (127)
  1. fbgemm_gpu/__init__.py +186 -0
  2. fbgemm_gpu/asmjit.so +0 -0
  3. fbgemm_gpu/batched_unary_embeddings_ops.py +87 -0
  4. fbgemm_gpu/config/__init__.py +9 -0
  5. fbgemm_gpu/config/feature_list.py +88 -0
  6. fbgemm_gpu/docs/__init__.py +18 -0
  7. fbgemm_gpu/docs/common.py +9 -0
  8. fbgemm_gpu/docs/examples.py +73 -0
  9. fbgemm_gpu/docs/jagged_tensor_ops.py +259 -0
  10. fbgemm_gpu/docs/merge_pooled_embedding_ops.py +36 -0
  11. fbgemm_gpu/docs/permute_pooled_embedding_ops.py +108 -0
  12. fbgemm_gpu/docs/quantize_ops.py +41 -0
  13. fbgemm_gpu/docs/sparse_ops.py +616 -0
  14. fbgemm_gpu/docs/target.genai.json.py +6 -0
  15. fbgemm_gpu/enums.py +24 -0
  16. fbgemm_gpu/experimental/example/__init__.py +29 -0
  17. fbgemm_gpu/experimental/example/fbgemm_gpu_experimental_example_py.so +0 -0
  18. fbgemm_gpu/experimental/example/utils.py +20 -0
  19. fbgemm_gpu/experimental/gemm/triton_gemm/__init__.py +15 -0
  20. fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py +5654 -0
  21. fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +4422 -0
  22. fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py +1192 -0
  23. fbgemm_gpu/experimental/gemm/triton_gemm/matmul_perf_model.py +232 -0
  24. fbgemm_gpu/experimental/gemm/triton_gemm/utils.py +130 -0
  25. fbgemm_gpu/experimental/gen_ai/__init__.py +56 -0
  26. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/__init__.py +46 -0
  27. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +333 -0
  28. fbgemm_gpu/experimental/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +552 -0
  29. fbgemm_gpu/experimental/gen_ai/bench/__init__.py +13 -0
  30. fbgemm_gpu/experimental/gen_ai/bench/comm_bench.py +257 -0
  31. fbgemm_gpu/experimental/gen_ai/bench/gather_scatter_bench.py +348 -0
  32. fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +707 -0
  33. fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +3483 -0
  34. fbgemm_gpu/experimental/gen_ai/fbgemm_gpu_experimental_gen_ai.so +0 -0
  35. fbgemm_gpu/experimental/gen_ai/moe/README.md +15 -0
  36. fbgemm_gpu/experimental/gen_ai/moe/__init__.py +66 -0
  37. fbgemm_gpu/experimental/gen_ai/moe/activation.py +292 -0
  38. fbgemm_gpu/experimental/gen_ai/moe/gather_scatter.py +740 -0
  39. fbgemm_gpu/experimental/gen_ai/moe/layers.py +1272 -0
  40. fbgemm_gpu/experimental/gen_ai/moe/shuffling.py +421 -0
  41. fbgemm_gpu/experimental/gen_ai/quantize.py +307 -0
  42. fbgemm_gpu/fbgemm.so +0 -0
  43. fbgemm_gpu/metrics.py +160 -0
  44. fbgemm_gpu/permute_pooled_embedding_modules.py +142 -0
  45. fbgemm_gpu/permute_pooled_embedding_modules_split.py +85 -0
  46. fbgemm_gpu/quantize/__init__.py +43 -0
  47. fbgemm_gpu/quantize/quantize_ops.py +64 -0
  48. fbgemm_gpu/quantize_comm.py +315 -0
  49. fbgemm_gpu/quantize_utils.py +246 -0
  50. fbgemm_gpu/runtime_monitor.py +237 -0
  51. fbgemm_gpu/sll/__init__.py +189 -0
  52. fbgemm_gpu/sll/cpu/__init__.py +80 -0
  53. fbgemm_gpu/sll/cpu/cpu_sll.py +1001 -0
  54. fbgemm_gpu/sll/meta/__init__.py +35 -0
  55. fbgemm_gpu/sll/meta/meta_sll.py +337 -0
  56. fbgemm_gpu/sll/triton/__init__.py +127 -0
  57. fbgemm_gpu/sll/triton/common.py +38 -0
  58. fbgemm_gpu/sll/triton/triton_dense_jagged_cat_jagged_out.py +72 -0
  59. fbgemm_gpu/sll/triton/triton_jagged2_to_padded_dense.py +221 -0
  60. fbgemm_gpu/sll/triton/triton_jagged_bmm.py +418 -0
  61. fbgemm_gpu/sll/triton/triton_jagged_bmm_jagged_out.py +553 -0
  62. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_add.py +52 -0
  63. fbgemm_gpu/sll/triton/triton_jagged_dense_elementwise_mul_jagged_out.py +175 -0
  64. fbgemm_gpu/sll/triton/triton_jagged_dense_flash_attention.py +861 -0
  65. fbgemm_gpu/sll/triton/triton_jagged_flash_attention_basic.py +667 -0
  66. fbgemm_gpu/sll/triton/triton_jagged_self_substraction_jagged_out.py +73 -0
  67. fbgemm_gpu/sll/triton/triton_jagged_softmax.py +463 -0
  68. fbgemm_gpu/sll/triton/triton_multi_head_jagged_flash_attention.py +751 -0
  69. fbgemm_gpu/sparse_ops.py +1455 -0
  70. fbgemm_gpu/split_embedding_configs.py +452 -0
  71. fbgemm_gpu/split_embedding_inference_converter.py +175 -0
  72. fbgemm_gpu/split_embedding_optimizer_ops.py +21 -0
  73. fbgemm_gpu/split_embedding_utils.py +29 -0
  74. fbgemm_gpu/split_table_batched_embeddings_ops.py +73 -0
  75. fbgemm_gpu/split_table_batched_embeddings_ops_common.py +484 -0
  76. fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +2042 -0
  77. fbgemm_gpu/split_table_batched_embeddings_ops_training.py +4600 -0
  78. fbgemm_gpu/split_table_batched_embeddings_ops_training_common.py +146 -0
  79. fbgemm_gpu/ssd_split_table_batched_embeddings_ops.py +26 -0
  80. fbgemm_gpu/tbe/__init__.py +6 -0
  81. fbgemm_gpu/tbe/bench/__init__.py +55 -0
  82. fbgemm_gpu/tbe/bench/bench_config.py +156 -0
  83. fbgemm_gpu/tbe/bench/bench_runs.py +709 -0
  84. fbgemm_gpu/tbe/bench/benchmark_click_interface.py +187 -0
  85. fbgemm_gpu/tbe/bench/eeg_cli.py +137 -0
  86. fbgemm_gpu/tbe/bench/embedding_ops_common_config.py +149 -0
  87. fbgemm_gpu/tbe/bench/eval_compression.py +119 -0
  88. fbgemm_gpu/tbe/bench/reporter.py +35 -0
  89. fbgemm_gpu/tbe/bench/tbe_data_config.py +137 -0
  90. fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +323 -0
  91. fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +289 -0
  92. fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +170 -0
  93. fbgemm_gpu/tbe/bench/utils.py +48 -0
  94. fbgemm_gpu/tbe/cache/__init__.py +11 -0
  95. fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py +385 -0
  96. fbgemm_gpu/tbe/cache/split_embeddings_cache_ops.py +48 -0
  97. fbgemm_gpu/tbe/ssd/__init__.py +15 -0
  98. fbgemm_gpu/tbe/ssd/common.py +46 -0
  99. fbgemm_gpu/tbe/ssd/inference.py +586 -0
  100. fbgemm_gpu/tbe/ssd/training.py +4908 -0
  101. fbgemm_gpu/tbe/ssd/utils/__init__.py +7 -0
  102. fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py +273 -0
  103. fbgemm_gpu/tbe/stats/__init__.py +10 -0
  104. fbgemm_gpu/tbe/stats/bench_params_reporter.py +339 -0
  105. fbgemm_gpu/tbe/utils/__init__.py +13 -0
  106. fbgemm_gpu/tbe/utils/common.py +42 -0
  107. fbgemm_gpu/tbe/utils/offsets.py +65 -0
  108. fbgemm_gpu/tbe/utils/quantize.py +251 -0
  109. fbgemm_gpu/tbe/utils/requests.py +556 -0
  110. fbgemm_gpu/tbe_input_multiplexer.py +108 -0
  111. fbgemm_gpu/triton/__init__.py +22 -0
  112. fbgemm_gpu/triton/common.py +77 -0
  113. fbgemm_gpu/triton/jagged/__init__.py +8 -0
  114. fbgemm_gpu/triton/jagged/triton_jagged_tensor_ops.py +824 -0
  115. fbgemm_gpu/triton/quantize.py +647 -0
  116. fbgemm_gpu/triton/quantize_ref.py +286 -0
  117. fbgemm_gpu/utils/__init__.py +11 -0
  118. fbgemm_gpu/utils/filestore.py +211 -0
  119. fbgemm_gpu/utils/loader.py +36 -0
  120. fbgemm_gpu/utils/torch_library.py +132 -0
  121. fbgemm_gpu/uvm.py +40 -0
  122. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/METADATA +62 -0
  123. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/RECORD +127 -0
  124. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/WHEEL +5 -0
  125. fbgemm_gpu_genai_nightly-2025.12.19.dist-info/top_level.txt +2 -0
  126. list_versions/__init__.py +12 -0
  127. list_versions/cli_run.py +163 -0
fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -0,0 +1,4422 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import functools
9
+ import logging
10
+ import os
11
+ from typing import Optional, Union
12
+
13
+ import torch
14
+ import triton # @manual
15
+
16
+ import triton.language as tl # @manual
17
+
18
+ from fbgemm_gpu.experimental.gemm.triton_gemm.matmul_perf_model import (
19
+ early_config_prune,
20
+ estimate_matmul_time,
21
+ )
22
+ from fbgemm_gpu.experimental.gemm.triton_gemm.utils import (
23
+ map_dtype_to_triton,
24
+ TmaAutoTuneHelper,
25
+ )
26
+
27
+ from packaging import version
28
+ from torch._tensor import Tensor
29
+
30
+ from triton import Config # @manual
31
+ from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper # @manual
32
+
33
+ logger: logging.Logger = logging.getLogger(__name__)
34
+
35
+ running_on_github: bool = os.getenv("GITHUB_ENV") is not None
36
+
37
+ try:
38
+ # pyre-ignore[21]
39
+ from triton.fb.compat import disable_bufferops # @manual
40
+ except ModuleNotFoundError:
41
+ # Ensure we can call disable_bufferops if compat is not included (e.g. opensource)
42
+ # TODO(njriasan): Remove when we integrate triton.fb.compat into every Triton
43
+ # version.
44
+ from contextlib import contextmanager
45
+
46
+ @contextmanager
47
+ def disable_bufferops(_unused: bool):
48
+ yield None
49
+
50
+
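# Editorial usage sketch (not part of this file): with the fallback above in
# place, call sites can wrap kernel launches uniformly in open-source and
# fbcode builds alike. The helper name below is hypothetical.
def _launch_with_bufferops_disabled(kernel, grid, *args, **kwargs):
    with disable_bufferops(True):
        return kernel[grid](*args, **kwargs)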
51
+ @functools.lru_cache
52
+ def supports_float8_fnuz(throw_on_hip_incompatibility: bool = True) -> bool:
53
+ if torch.version.hip:
54
+ device_capability = torch.cuda.get_device_capability()
55
+
56
+ if device_capability < (9, 4):
57
+ gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName
58
+ msg = f"Unsupported GPU arch: {gpu_arch} for FP8"
59
+ if throw_on_hip_incompatibility:
60
+ raise RuntimeError(msg)
61
+ else:
62
+ logging.error(msg)
63
+ return False
64
+
65
+ elif device_capability == (9, 4):
66
+ return True
67
+
68
+ return False
69
+
70
+
71
+ def get_fp8_constants() -> tuple[torch.dtype, tl.dtype, float, float]:
72
+ """
73
+ Helper function to get constant values for the current platform.
74
+
75
+ Returns:
76
+ pt_dtype (torch.dtype): The correct torch fp8 datatype.
77
+ tl_dtype (tl.dtype): The correct triton fp8 datatype.
78
+ max_fp8 (float): The maximum representable value for the fp8 datatype.
79
+ eps (float): Minimum clip value to prevent divide by zero.
80
+ """
81
+ if supports_float8_fnuz(throw_on_hip_incompatibility=(not running_on_github)):
82
+ pt_fp8_dtype = torch.float8_e4m3fnuz
83
+ tl_fp8_dtype = tl.float8e4b8
84
+ else:
85
+ pt_fp8_dtype = torch.float8_e4m3fn
86
+ tl_fp8_dtype = tl.float8e4nv
87
+
88
+ return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
89
+
90
+
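# Editorial usage sketch (not part of this file): quantize a 2D tensor row-wise
# into the platform fp8 dtype chosen above, returning the (fp8 tensor,
# reciprocal per-row scale) pair that the row-wise matmul kernels below expect.
# The helper name is hypothetical.
def _rowwise_fp8_quantize_example(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    pt_dtype, _, max_fp8, eps = get_fp8_constants()
    row_max = x.abs().amax(dim=-1).float().clamp(min=eps)  # per-row absolute max
    scale = row_max / max_fp8                               # x_fp8 * scale ~= x
    x_fp8 = (x.float() / scale[:, None]).to(pt_dtype)
    return x_fp8, scale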
91
+ def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper:
92
+ """
93
+ Converts tensor to triton fp8 type.
94
+
95
+ Args:
96
+ tensor (torch.Tensor): input tensor.
97
+ dtype (tl.dtype): target triton dtype.
98
+
99
+ Returns:
100
+ triton.TensorWrapper: fp8 tensor.
101
+ """
102
+ return tl_reinterpret(tensor, dtype=dtype)
103
+
104
+
105
+ def init_to_zero(name):
106
+ return lambda nargs: nargs[name].zero_()
107
+
108
+
109
+ def get_configs_io_bound() -> list[Config]:
110
+ """
111
+ Returns a list of configs for matmul that are IO bound.
112
+
113
+ Returns:
114
+ List[Config]: list of configs.
115
+ """
116
+ configs = []
117
+ for num_stages in [2, 3, 4, 5, 6]:
118
+ for block_m in [16, 32]:
119
+ for block_k in [32, 64]:
120
+ for block_n in [32, 64, 128, 256]:
121
+ num_warps = 2 if block_n <= 64 else 4
122
+ configs.append(
123
+ Config(
124
+ {
125
+ "BLOCK_M": block_m,
126
+ "BLOCK_N": block_n,
127
+ "BLOCK_K": block_k,
128
+ "SPLIT_K": 1,
129
+ },
130
+ num_stages=num_stages,
131
+ num_warps=num_warps,
132
+ )
133
+ )
134
+ # split_k
135
+ for split_k in []: # Disabled [2, 4, 8, 16]:
136
+ configs.append(
137
+ Config(
138
+ {
139
+ "BLOCK_M": block_m,
140
+ "BLOCK_N": block_n,
141
+ "BLOCK_K": block_k,
142
+ "SPLIT_K": split_k,
143
+ },
144
+ num_stages=num_stages,
145
+ num_warps=num_warps,
146
+ pre_hook=init_to_zero("C"),
147
+ )
148
+ )
149
+ return configs
150
+
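# Editorial note: with the split_k sweep disabled above, get_configs_io_bound()
# enumerates 5 (num_stages) x 2 (BLOCK_M) x 2 (BLOCK_K) x 4 (BLOCK_N) = 80
# configs, all with SPLIT_K=1; they are appended to MATMUL_CONFIGS below.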
151
+
152
+ def dummy_prune_configs(configs, named_args, **kwargs):
153
+
154
+ M = named_args["M"]
155
+ N = named_args["N"]
156
+ K = named_args["K"]
157
+
158
+ logger.info(f"{len(configs)=} for {M=} {N=} {K=}")
159
+ return configs
160
+
161
+
162
+ MATMUL_CONFIGS: list[Config] = [
163
+ # basic configs for compute-bound matmuls
164
+ Config(
165
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
166
+ num_stages=3,
167
+ num_warps=8,
168
+ ),
169
+ Config(
170
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
171
+ num_stages=3,
172
+ num_warps=8,
173
+ ),
174
+ Config(
175
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
176
+ num_stages=4,
177
+ num_warps=4,
178
+ ),
179
+ Config(
180
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
181
+ num_stages=4,
182
+ num_warps=4,
183
+ ),
184
+ Config(
185
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
186
+ num_stages=4,
187
+ num_warps=4,
188
+ ),
189
+ Config(
190
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 256, "SPLIT_K": 1},
191
+ num_stages=4,
192
+ num_warps=4,
193
+ ),
194
+ Config(
195
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
196
+ num_stages=4,
197
+ num_warps=4,
198
+ ),
199
+ Config(
200
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
201
+ num_stages=4,
202
+ num_warps=4,
203
+ ),
204
+ Config(
205
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
206
+ num_stages=4,
207
+ num_warps=4,
208
+ ),
209
+ Config(
210
+ {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
211
+ num_stages=4,
212
+ num_warps=4,
213
+ ),
214
+ Config(
215
+ {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
216
+ num_stages=5,
217
+ num_warps=2,
218
+ ),
219
+ # good for int8
220
+ Config(
221
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
222
+ num_stages=3,
223
+ num_warps=8,
224
+ ),
225
+ Config(
226
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
227
+ num_stages=3,
228
+ num_warps=8,
229
+ ),
230
+ Config(
231
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
232
+ num_stages=4,
233
+ num_warps=4,
234
+ ),
235
+ Config(
236
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
237
+ num_stages=4,
238
+ num_warps=4,
239
+ ),
240
+ Config(
241
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
242
+ num_stages=4,
243
+ num_warps=4,
244
+ ),
245
+ Config(
246
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
247
+ num_stages=4,
248
+ num_warps=4,
249
+ ),
250
+ Config(
251
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
252
+ num_stages=4,
253
+ num_warps=4,
254
+ ),
255
+ Config(
256
+ {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
257
+ num_stages=4,
258
+ num_warps=4,
259
+ ),
260
+ Config(
261
+ {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
262
+ num_stages=5,
263
+ num_warps=2,
264
+ ),
265
+ ] + get_configs_io_bound()
266
+
267
+
268
+ @triton.autotune(
269
+ configs=MATMUL_CONFIGS,
270
+ prune_configs_by={
271
+ "early_config_prune": dummy_prune_configs,
272
+ },
273
+ key=[
274
+ "m_key",
275
+ "n_key",
276
+ "k_key",
277
+ ],
278
+ )
279
+ @triton.jit
280
+ def _kernel_matmul_fp8_row(
281
+ A_ptr,
282
+ B_ptr,
283
+ C_ptr,
284
+ M,
285
+ N,
286
+ K,
287
+ m_key,
288
+ n_key,
289
+ k_key,
290
+ A_scale,
291
+ B_scale,
292
+ Bias,
293
+ stride_am,
294
+ stride_ak,
295
+ stride_bn,
296
+ stride_bk,
297
+ stride_cm,
298
+ stride_cn,
299
+ dot_out_dtype: tl.constexpr,
300
+ allow_tf32: tl.constexpr,
301
+ fp8_fast_accum: tl.constexpr,
302
+ skip_scaling_a: tl.constexpr,
303
+ BLOCK_M: tl.constexpr,
304
+ BLOCK_N: tl.constexpr,
305
+ BLOCK_K: tl.constexpr,
306
+ GROUP_M: tl.constexpr,
307
+ SPLIT_K: tl.constexpr,
308
+ USE_BIAS: tl.constexpr,
309
+ AB_DTYPE: tl.constexpr,
310
+ NUM_SMS: tl.constexpr,
311
+ ) -> None:
312
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
313
+
314
+ performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
315
+
316
+ Args:
317
+ A (TensorWrapper): [M, K] input tensor.
318
+ B (TensorWrapper): [N, K] input tensor.
319
+ C (TensorWrapper): [M, N] output tensor.
320
+ M (int): M dimension of input tensor.
321
+ N (int): N dimension of input tensor.
322
+ K (int): K dimension of input tensor.
323
+ m_key (int): Autotuning key for M dimension of input tensor.
324
+ n_key (int): Autotuning key for N dimension of input tensor.
325
+ k_key (int): Autotuning key for K dimension of input tensor.
326
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A.
327
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B.
328
+ Bias (TensorWrapper): [N] Optional bias tensor.
329
+ stride_am (int): Stride of M dimension of A.
330
+ stride_ak (int): Stride of K dimension of A.
331
+ stride_bn (int): Stride of N dimension of B.
332
+ stride_bk (int): Stride of K dimension of B.
333
+ stride_cm (int): Stride of M dimension of C.
334
+ stride_cn (int): Stride of N dimension of C.
335
+ dot_out_dtype (torch.dtype): Output type of tensor core.
336
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
337
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
338
+ BLOCK_M (int): Block size for M dimension.
339
+ BLOCK_N (int): Block size for N dimension.
340
+ BLOCK_K (int): Block size for K dimension.
341
+ GROUP_M (int): Number of groups for M dimension swizzle.
342
+ SPLIT_K (int): Number of SM's to launch per row.
343
+ USE_BIAS (bool): Whether to use bias.
344
+ skip_scaling_a (bool): Whether to skip applying A_scale (set when A_scale is None).
+ NUM_SMS (int): Number of SMs to launch for the persistent grid.
345
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
346
+ """
347
+ # Matrix multiplication.
348
+ start_pid = tl.program_id(axis=0)
349
+ num_pid_m = tl.cdiv(M, BLOCK_M)
350
+ num_pid_n = tl.cdiv(N, BLOCK_N)
351
+ k_tiles = tl.cdiv(K, BLOCK_K)
352
+ num_tiles = num_pid_m * num_pid_n
353
+
354
+ tiles_per_SM = num_tiles // NUM_SMS
355
+ if start_pid < num_tiles % NUM_SMS:
356
+ tiles_per_SM += 1
357
+
358
+ tile_id = start_pid - NUM_SMS
359
+ ki = -1
360
+
361
+ offs_k_for_mask = tl.arange(0, BLOCK_K)
362
+
363
+ num_pid_in_group = GROUP_M * num_pid_n
364
+
365
+ pid_m = 0
366
+ pid_n = 0
367
+ offs_am = tl.arange(0, BLOCK_M)
368
+ offs_bn = tl.arange(0, BLOCK_N)
369
+ acc_dtype = tl.float32 if allow_tf32 else dot_out_dtype
370
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
371
+
372
+ for _ in range(0, k_tiles * tiles_per_SM):
373
+ ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
374
+ if ki == 0:
375
+ tile_id += NUM_SMS
376
+ group_id = tile_id // num_pid_in_group
377
+ first_pid_m = group_id * GROUP_M
378
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
379
+ pid_m = first_pid_m + (tile_id % group_size_m)
380
+ pid_n = (tile_id % num_pid_in_group) // group_size_m
381
+
382
+ start_m = pid_m * BLOCK_M
383
+ start_n = pid_n * BLOCK_N
384
+ offs_am = start_m + tl.arange(0, BLOCK_M)
385
+ offs_bn = start_n + tl.arange(0, BLOCK_N)
386
+ offs_am = tl.where(offs_am < M, offs_am, 0)
387
+ offs_bn = tl.where(offs_bn < N, offs_bn, 0)
388
+ offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M)
389
+ offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N)
390
+ offs_k = ki * BLOCK_K + tl.arange(0, BLOCK_K)
391
+ A = A_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
392
+ B = B_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
393
+
394
+ a = tl.load(A, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_K, other=0.0)
395
+ b = tl.load(B, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_K, other=0.0)
396
+ acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)
397
+
398
+ if ki == k_tiles - 1:
399
+ # rematerialize rm and rn to save registers
400
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
401
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
402
+
403
+ # Invert scaling.
404
+ b_scale = tl.load(B_scale + rn, mask=rn < N)
405
+ if skip_scaling_a:
406
+ acc *= b_scale[None, :]
407
+ else:
408
+ a_scale = tl.load(A_scale + rm, mask=rm < M)
409
+ # pyre-ignore[16]: Undefined attribute [16]: `float`
410
+ # has no attribute `__getitem__`.
411
+ scale = a_scale[:, None] * b_scale[None, :]
412
+ acc *= scale
413
+
414
+ # Load and add bias if specified.
415
+ if USE_BIAS:
416
+ bias = tl.load(Bias + rn, mask=rn < N)
417
+ acc += bias[None, :]
418
+
419
+ acc = acc.to(C_ptr.dtype.element_ty)
420
+ C = C_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
421
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
422
+ # Handles write-back with reduction-splitting
423
+ tl.store(C, acc, mask=mask)
424
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
425
+
426
+
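# Editorial sketch (not part of this file): host-side Python mirroring the
# persistent tile schedule used by the kernel above. Each of NUM_SMS programs
# walks tile ids start_pid, start_pid + NUM_SMS, ... and maps each id to a
# (pid_m, pid_n) pair with the same GROUP_M swizzle. The function is hypothetical.
def _persistent_tile_schedule(M, N, BLOCK_M, BLOCK_N, GROUP_M, NUM_SMS, start_pid):
    num_pid_m = -(-M // BLOCK_M)  # ceil division, like tl.cdiv
    num_pid_n = -(-N // BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    tiles = []
    for tile_id in range(start_pid, num_pid_m * num_pid_n, NUM_SMS):
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
        pid_m = first_pid_m + (tile_id % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m
        tiles.append((pid_m, pid_n))
    return tiles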
427
+ @triton.autotune(
428
+ configs=MATMUL_CONFIGS
429
+ + [
430
+ Config(
431
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
432
+ num_stages=3,
433
+ num_warps=8,
434
+ ),
435
+ ],
436
+ key=[
437
+ "m_key",
438
+ "n_key",
439
+ "k_key",
440
+ ],
441
+ )
442
+ @triton.heuristics(
443
+ {
444
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
445
+ }
446
+ )
447
+ @triton.jit
448
+ def _kernel_matmul_fp8_row_no_fast_acc(
449
+ A_ptr,
450
+ B_ptr,
451
+ C_ptr,
452
+ M,
453
+ N,
454
+ K,
455
+ m_key,
456
+ n_key,
457
+ k_key,
458
+ A_scale,
459
+ B_scale,
460
+ Bias,
461
+ stride_am,
462
+ stride_ak,
463
+ stride_bn,
464
+ stride_bk,
465
+ stride_cm,
466
+ stride_cn,
467
+ dot_out_dtype: tl.constexpr,
468
+ allow_tf32: tl.constexpr,
469
+ fp8_fast_accum: tl.constexpr,
470
+ BLOCK_M: tl.constexpr,
471
+ BLOCK_N: tl.constexpr,
472
+ BLOCK_K: tl.constexpr,
473
+ GROUP_M: tl.constexpr,
474
+ SPLIT_K: tl.constexpr,
475
+ EVEN_K: tl.constexpr,
476
+ USE_BIAS: tl.constexpr,
477
+ AB_DTYPE: tl.constexpr,
478
+ NUM_SMS: tl.constexpr,
479
+ ) -> None:
480
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
481
+
482
+ performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
483
+
484
+ Args:
485
+ A (TensorWrapper): [M, K] input tensor.
486
+ B (TensorWrapper): [N, K] input tensor.
487
+ C (TensorWrapper): [M, N] output tensor.
488
+ M (int): M dimension of input tensor.
489
+ N (int): N dimension of input tensor.
490
+ K (int): K dimension of input tensor.
491
+ m_key (int): Autotuning key for M dimension of input tensor.
492
+ n_key (int): Autotuning key for N dimension of input tensor.
493
+ k_key (int): Autotuning key for K dimension of input tensor.
494
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
495
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
496
+ Bias (TensorWrapper): [N] Optional bias tensor.
497
+ stride_am (int): Stride of M dimension of A.
498
+ stride_ak (int): Stride of K dimension of A.
499
+ stride_bn (int): Stride of N dimension of B.
500
+ stride_bk (int): Stride of K dimension of B.
501
+ stride_cm (int): Stride of M dimension of C.
502
+ stride_cn (int): Stride of N dimension of C.
503
+ dot_out_dtype (torch.dtype): Output type of tensor core.
504
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
505
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
506
+ BLOCK_M (int): Block size for M dimension.
507
+ BLOCK_N (int): Block size for N dimension.
508
+ BLOCK_K (int): Block size for K dimension.
509
+ GROUP_M (int): Number of groups for M dimension swizzle.
510
+ SPLIT_K (int): Number of SM's to launch per row.
511
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
512
+ USE_BIAS (bool): Whether to use bias.
513
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
514
+ """
515
+ # Matrix multiplication.
516
+
517
+ start_pid = tl.program_id(axis=0)
518
+ num_pid_m = tl.cdiv(M, BLOCK_M)
519
+ num_pid_n = tl.cdiv(N, BLOCK_N)
520
+ k_tiles = tl.cdiv(K, BLOCK_K)
521
+ num_tiles = num_pid_m * num_pid_n
522
+
523
+ tiles_per_SM = num_tiles // NUM_SMS
524
+ if start_pid < num_tiles % NUM_SMS:
525
+ tiles_per_SM += 1
526
+
527
+ tile_id = start_pid - NUM_SMS
528
+ ki = -1
529
+
530
+ offs_k_for_mask = tl.arange(0, BLOCK_K)
531
+
532
+ num_pid_in_group = GROUP_M * num_pid_n
533
+
534
+ pid_m = 0
535
+ pid_n = 0
536
+ offs_am = tl.arange(0, BLOCK_M)
537
+ offs_bn = tl.arange(0, BLOCK_N)
538
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
539
+
540
+ for _ in range(0, k_tiles * tiles_per_SM):
541
+ ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
542
+ if ki == 0:
543
+ tile_id += NUM_SMS
544
+ group_id = tile_id // num_pid_in_group
545
+ first_pid_m = group_id * GROUP_M
546
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
547
+ pid_m = first_pid_m + (tile_id % group_size_m)
548
+ pid_n = (tile_id % num_pid_in_group) // group_size_m
549
+
550
+ start_m = pid_m * BLOCK_M
551
+ start_n = pid_n * BLOCK_N
552
+ offs_am = start_m + tl.arange(0, BLOCK_M)
553
+ offs_bn = start_n + tl.arange(0, BLOCK_N)
554
+ offs_am = tl.where(offs_am < M, offs_am, 0)
555
+ offs_bn = tl.where(offs_bn < N, offs_bn, 0)
556
+ offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M)
557
+ offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N)
558
+ offs_k = ki * BLOCK_K + tl.arange(0, BLOCK_K)
559
+ A = A_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
560
+ B = B_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
561
+
562
+ a = tl.load(A, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_K, other=0.0)
563
+ b = tl.load(B, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_K, other=0.0)
564
+ acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
565
+
566
+ if ki == k_tiles - 1:
567
+ # rematerialize rm and rn to save registers
568
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
569
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
570
+
571
+ # Invert scaling.
572
+ a_scale = tl.load(A_scale + rm, mask=rm < M)
573
+ b_scale = tl.load(B_scale + rn, mask=rn < N)
574
+ # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
575
+ scale = a_scale[:, None] * b_scale[None, :]
576
+ acc *= scale
577
+
578
+ # Load and add bias if specified.
579
+ if USE_BIAS:
580
+ bias = tl.load(Bias + rn, mask=rn < N)
581
+ acc += bias[None, :]
582
+
583
+ acc = acc.to(C_ptr.dtype.element_ty)
584
+ C = C_ptr + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
585
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
586
+ # Handles write-back with reduction-splitting
587
+ tl.store(C, acc, mask=mask)
588
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
589
+
590
+
591
+ @triton.autotune(
592
+ configs=MATMUL_CONFIGS,
593
+ key=[
594
+ "m_key",
595
+ "n_key",
596
+ "k_key",
597
+ ],
598
+ )
599
+ @triton.heuristics(
600
+ {
601
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
602
+ }
603
+ )
604
+ @triton.jit
605
+ def _kernel_matmul_fp8_row_imprecise_acc(
606
+ A,
607
+ B,
608
+ C,
609
+ M,
610
+ N,
611
+ K,
612
+ m_key,
613
+ n_key,
614
+ k_key,
615
+ A_scale,
616
+ B_scale,
617
+ Bias,
618
+ stride_am,
619
+ stride_ak,
620
+ stride_bn,
621
+ stride_bk,
622
+ stride_cm,
623
+ stride_cn,
624
+ dot_out_dtype: tl.constexpr,
625
+ allow_tf32: tl.constexpr,
626
+ fp8_fast_accum: tl.constexpr,
627
+ BLOCK_M: tl.constexpr,
628
+ BLOCK_N: tl.constexpr,
629
+ BLOCK_K: tl.constexpr,
630
+ GROUP_M: tl.constexpr,
631
+ SPLIT_K: tl.constexpr,
632
+ EVEN_K: tl.constexpr,
633
+ USE_BIAS: tl.constexpr,
634
+ AB_DTYPE: tl.constexpr,
635
+ ) -> None:
636
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
637
+
638
+ performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
639
+
640
+ Args:
641
+ A (TensorWrapper): [M, K] input tensor.
642
+ B (TensorWrapper): [N, K] input tensor.
643
+ C (TensorWrapper): [M, N] output tensor.
644
+ M (int): M dimension of input tensor.
645
+ N (int): N dimension of input tensor.
646
+ K (int): K dimension of input tensor.
647
+ m_key (int): Autotuning key for M dimension of input tensor.
648
+ n_key (int): Autotuning key for N dimension of input tensor.
649
+ k_key (int): Autotuning key for K dimension of input tensor.
650
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
651
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
652
+ Bias (TensorWrapper): [N] Optional bias tensor.
653
+ stride_am (int): Stride of M dimension of A.
654
+ stride_ak (int): Stride of K dimension of A.
655
+ stride_bn (int): Stride of N dimension of B.
656
+ stride_bk (int): Stride of K dimension of B.
657
+ stride_cm (int): Stride of M dimension of C.
658
+ stride_cn (int): Stride of N dimension of C.
659
+ dot_out_dtype (torch.dtype): Output type of tensor core.
660
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
661
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
662
+ BLOCK_M (int): Block size for M dimension.
663
+ BLOCK_N (int): Block size for N dimension.
664
+ BLOCK_K (int): Block size for K dimension.
665
+ GROUP_M (int): Number of groups for M dimension swizzle.
666
+ SPLIT_K (int): Number of SM's to launch per row.
667
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
668
+ USE_BIAS (bool): Whether to use bias.
669
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
670
+ """
671
+ # Matrix multiplication.
672
+ pid = tl.program_id(0)
673
+ pid_z = tl.program_id(1)
674
+ grid_m = tl.cdiv(M, BLOCK_M)
675
+ grid_n = tl.cdiv(N, BLOCK_N)
676
+ # Re-order program ID for better L2 performance (swizzle).
677
+ width = GROUP_M * grid_n
678
+ group_id = pid // width
679
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
680
+ pid_m = group_id * GROUP_M + (pid % group_size)
681
+ pid_n = (pid % width) // (group_size)
682
+ # Do matrix multiplication.
683
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
684
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
685
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
686
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
687
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
688
+ # Pointers.
689
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
690
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
691
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
692
+
693
+ for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
694
+ if EVEN_K:
695
+ a = tl.load(A)
696
+ b = tl.load(B)
697
+ else:
698
+ k_remaining = K - k * (BLOCK_K * SPLIT_K)
699
+ _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
700
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
701
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
702
+ if AB_DTYPE:
703
+ a = a.to(C.dtype.element_ty)
704
+ b = b.to(C.dtype.element_ty)
705
+ if fp8_fast_accum:
706
+ acc = tl.dot(
707
+ a,
708
+ b,
709
+ acc,
710
+ max_num_imprecise_acc=32,
711
+ out_dtype=dot_out_dtype,
712
+ allow_tf32=allow_tf32,
713
+ )
714
+ else:
715
+ acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
716
+
717
+ A += BLOCK_K * SPLIT_K * stride_ak
718
+ B += BLOCK_K * SPLIT_K * stride_bk
719
+
720
+ # rematerialize rm and rn to save registers
721
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
722
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
723
+
724
+ # Invert scaling.
725
+ a_scale = tl.load(A_scale + rm, mask=rm < M)
726
+ b_scale = tl.load(B_scale + rn, mask=rn < N)
727
+ # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
728
+ scale = a_scale[:, None] * b_scale[None, :]
729
+ acc *= scale
730
+
731
+ # Apply bias.
732
+ if USE_BIAS:
733
+ bias = tl.load(Bias + rn, mask=rn < N)
734
+ acc += bias[None, :]
735
+
736
+ acc = acc.to(C.dtype.element_ty)
737
+ C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
738
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
739
+ # Handles write-back with reduction-splitting
740
+ if SPLIT_K == 1:
741
+ tl.store(C, acc, mask=mask)
742
+ else:
743
+ tl.atomic_add(C, acc, mask=mask)
744
+
745
+
746
+ @triton.autotune(
747
+ configs=[
748
+ Config(
749
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
750
+ num_stages=3,
751
+ num_warps=8,
752
+ ),
753
+ Config(
754
+ {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
755
+ num_stages=3,
756
+ num_warps=8,
757
+ ),
758
+ Config(
759
+ {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
760
+ num_stages=4,
761
+ num_warps=4,
762
+ ),
763
+ Config(
764
+ {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
765
+ num_stages=4,
766
+ num_warps=4,
767
+ ),
768
+ Config(
769
+ {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
770
+ num_stages=4,
771
+ num_warps=4,
772
+ ),
773
+ Config(
774
+ {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
775
+ num_stages=4,
776
+ num_warps=4,
777
+ ),
778
+ Config(
779
+ {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
780
+ num_stages=4,
781
+ num_warps=4,
782
+ ),
783
+ Config(
784
+ {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 512, "SPLIT_K": 1},
785
+ num_stages=3,
786
+ num_warps=4,
787
+ ),
788
+ ],
789
+ key=[
790
+ "m_key",
791
+ "n_key",
792
+ "k_key",
793
+ ],
794
+ use_cuda_graph=True,
795
+ )
796
+ @triton.heuristics(
797
+ {
798
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
799
+ }
800
+ )
801
+ @triton.jit
802
+ def _kernel_matmul_fp8_row_tma_persistent(
803
+ A_ptr,
804
+ B_ptr,
805
+ C_ptr,
806
+ M,
807
+ N,
808
+ K,
809
+ m_key,
810
+ n_key,
811
+ k_key,
812
+ A_scale,
813
+ B_scale,
814
+ Bias,
815
+ stride_am,
816
+ stride_ak,
817
+ stride_bn,
818
+ stride_bk,
819
+ stride_cm,
820
+ stride_cn,
821
+ dot_out_dtype: tl.constexpr,
822
+ c_dtype: tl.constexpr,
823
+ bias_dtype: tl.constexpr,
824
+ allow_tf32: tl.constexpr,
825
+ fp8_fast_accum: tl.constexpr,
826
+ BLOCK_M: tl.constexpr,
827
+ BLOCK_N: tl.constexpr,
828
+ BLOCK_K: tl.constexpr,
829
+ GROUP_M: tl.constexpr,
830
+ AB_DTYPE: tl.constexpr,
831
+ SPLIT_K: tl.constexpr,
832
+ EVEN_K: tl.constexpr,
833
+ NUM_SMS: tl.constexpr,
834
+ USE_BIAS: tl.constexpr,
835
+ ) -> None:
836
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
837
+
838
+ performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
839
+
840
+ Args:
841
+ A (TensorWrapper): [M, K] input tensor.
842
+ B (TensorWrapper): [N, K] input tensor.
843
+ C (TensorWrapper): [M, N] output tensor.
844
+ M (int): M dimension of input tensor.
845
+ N (int): N dimension of input tensor.
846
+ K (int): K dimension of input tensor.
847
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
848
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
849
+ stride_am (int): Stride of M dimension of A.
850
+ stride_ak (int): Stride of K dimension of A.
851
+ stride_bn (int): Stride of N dimension of B.
852
+ stride_bk (int): Stride of K dimension of B.
853
+ stride_cm (int): Stride of M dimension of C.
854
+ stride_cn (int): Stride of N dimension of C.
855
+ dot_out_dtype (torch.dtype): Output type of tensor core.
856
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
857
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
858
+ BLOCK_M (int): Block size for M dimension.
859
+ BLOCK_N (int): Block size for N dimension.
860
+ BLOCK_K (int): Block size for K dimension.
861
+ GROUP_M (int): Number of groups for M dimension swizzle.
862
+ SPLIT_K (int): Number of SM's to launch per row.
863
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
864
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
865
+ """
866
+ # Matrix multiplication.
867
+ start_pid = tl.program_id(axis=0)
868
+ num_pid_m = tl.cdiv(M, BLOCK_M)
869
+ num_pid_n = tl.cdiv(N, BLOCK_N)
870
+ k_tiles = tl.cdiv(K, BLOCK_K)
871
+ num_tiles = num_pid_m * num_pid_n
872
+
873
+ tiles_per_SM = num_tiles // NUM_SMS
874
+ if start_pid < num_tiles % NUM_SMS:
875
+ tiles_per_SM += 1
876
+
877
+ tile_id = start_pid - NUM_SMS
878
+ ki = -1
879
+
880
+ pid_m = 0
881
+ pid_n = 0
882
+ offs_am = 0
883
+ offs_bn = 0
884
+
885
+ num_pid_in_group = GROUP_M * num_pid_n
886
+
887
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
888
+
889
+ dtype_fp8 = tl.float8e4nv
890
+ scale_dtype = tl.float32
891
+
892
+ for _ in range(0, k_tiles * tiles_per_SM):
893
+ ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
894
+ if ki == 0:
895
+ tile_id += NUM_SMS
896
+ group_id = tile_id // num_pid_in_group
897
+ first_pid_m = group_id * GROUP_M
898
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
899
+ pid_m = first_pid_m + (tile_id % group_size_m)
900
+ pid_n = (tile_id % num_pid_in_group) // group_size_m
901
+
902
+ offs_am = pid_m * BLOCK_M
903
+ offs_bn = pid_n * BLOCK_N
904
+ offs_am = tl.multiple_of(offs_am, BLOCK_M)
905
+ offs_bn = tl.multiple_of(offs_bn, BLOCK_N)
906
+
907
+ offs_k = ki * BLOCK_K
908
+
909
+ a = tl._experimental_descriptor_load(
910
+ A_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], dtype_fp8
911
+ )
912
+ b = tl._experimental_descriptor_load(
913
+ B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
914
+ )
915
+
916
+ if fp8_fast_accum:
917
+ acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
918
+ else:
919
+ acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
920
+
921
+ if ki == k_tiles - 1:
922
+ # rematerialize rm and rn to save registers
923
+
924
+ # Invert scaling.
925
+ a_scale = tl._experimental_descriptor_load(
926
+ A_scale, [offs_am], [BLOCK_M], scale_dtype
927
+ )
928
+ b_scale = tl._experimental_descriptor_load(
929
+ B_scale, [offs_bn], [BLOCK_N], scale_dtype
930
+ )
931
+ # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
932
+ scale = a_scale[:, None] * b_scale[None, :]
933
+ acc *= scale
934
+
935
+ # Load and add bias if specified.
936
+ if USE_BIAS:
937
+ bias = tl._experimental_descriptor_load(
938
+ Bias, [offs_bn], [BLOCK_N], bias_dtype
939
+ )
940
+ acc += bias[None, :]
941
+
942
+ acc = acc.to(c_dtype)
943
+ tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
944
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
945
+
946
+
947
+ has_warp_specialization = hasattr(tl, "async_task")
948
+
949
+
950
+ def make_autotuner_config(dictargs, **kwargs):
951
+ # NOTE: Triton 3.4.x removed some keyword arguments from Config constructor;
952
+ # however, fbcode uses 3.3.1, and so this shim is provided to support both
953
+ # versions.
954
+ #
955
+ # https://github.com/triton-lang/triton/blob/v3.3.1/python/triton/runtime/autotuner.py#L275
956
+ # https://github.com/triton-lang/triton/blame/release/3.4.x/python/triton/runtime/autotuner.py#L319
957
+ if version.parse(triton.__version__) > version.parse("3.3.1"):
958
+ for key in ["num_buffers_warp_spec", "num_consumer_groups"]:
959
+ kwargs.pop(key, None)
960
+ return Config(dictargs, **kwargs)
961
+
962
+
963
+ def get_ws_configs() -> list[Config]:
964
+ if not has_warp_specialization:
965
+ return []
966
+ return [
967
+ make_autotuner_config(
968
+ {
969
+ "BLOCK_M": 128,
970
+ "BLOCK_N": 256,
971
+ "BLOCK_K": 128,
972
+ "SPLIT_K": 1,
973
+ "NUM_CONSUMER_GROUPS": 2,
974
+ },
975
+ num_stages=3,
976
+ num_warps=4,
977
+ num_consumer_groups=2,
978
+ num_buffers_warp_spec=3,
979
+ ),
980
+ make_autotuner_config(
981
+ {
982
+ "BLOCK_M": 128,
983
+ "BLOCK_N": 128,
984
+ "BLOCK_K": 128,
985
+ "SPLIT_K": 1,
986
+ "NUM_CONSUMER_GROUPS": 2,
987
+ },
988
+ num_stages=4,
989
+ num_warps=4,
990
+ num_consumer_groups=2,
991
+ num_buffers_warp_spec=4,
992
+ ),
993
+ make_autotuner_config(
994
+ {
995
+ "BLOCK_M": 128,
996
+ "BLOCK_N": 256,
997
+ "BLOCK_K": 128,
998
+ "SPLIT_K": 1,
999
+ "NUM_CONSUMER_GROUPS": 1,
1000
+ },
1001
+ num_stages=3,
1002
+ num_warps=8,
1003
+ num_consumer_groups=0,
1004
+ num_buffers_warp_spec=3,
1005
+ ),
1006
+ make_autotuner_config(
1007
+ {
1008
+ "BLOCK_M": 64,
1009
+ "BLOCK_N": 64,
1010
+ "BLOCK_K": 512,
1011
+ "SPLIT_K": 1,
1012
+ "NUM_CONSUMER_GROUPS": 1,
1013
+ },
1014
+ num_stages=3,
1015
+ num_warps=4,
1016
+ num_consumer_groups=0,
1017
+ num_buffers_warp_spec=3,
1018
+ ),
1019
+ ]
1020
+
1021
+
1022
+ @triton.autotune(
1023
+ configs=[
1024
+ Config(
1025
+ {
1026
+ "BLOCK_M": 128,
1027
+ "BLOCK_N": 256,
1028
+ "BLOCK_K": 128,
1029
+ "SPLIT_K": 1,
1030
+ "NUM_CONSUMER_GROUPS": 1,
1031
+ },
1032
+ num_stages=3,
1033
+ num_warps=8,
1034
+ ),
1035
+ ]
1036
+ + get_ws_configs(),
1037
+ key=[
1038
+ "m_key",
1039
+ "n_key",
1040
+ "k_key",
1041
+ ],
1042
+ use_cuda_graph=True,
1043
+ )
1044
+ @triton.heuristics(
1045
+ {
1046
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
1047
+ }
1048
+ )
1049
+ @triton.jit
1050
+ def _kernel_matmul_fp8_row_tma_persistent_ws_cooperative(
1051
+ A_ptr,
1052
+ B_ptr,
1053
+ C_ptr,
1054
+ M,
1055
+ N,
1056
+ K,
1057
+ m_key,
1058
+ n_key,
1059
+ k_key,
1060
+ A_scale,
1061
+ B_scale,
1062
+ Bias,
1063
+ stride_am,
1064
+ stride_ak,
1065
+ stride_bn,
1066
+ stride_bk,
1067
+ stride_cm,
1068
+ stride_cn,
1069
+ dot_out_dtype: tl.constexpr,
1070
+ c_dtype: tl.constexpr,
1071
+ bias_dtype: tl.constexpr,
1072
+ allow_tf32: tl.constexpr,
1073
+ fp8_fast_accum: tl.constexpr,
1074
+ BLOCK_M: tl.constexpr,
1075
+ BLOCK_N: tl.constexpr,
1076
+ BLOCK_K: tl.constexpr,
1077
+ GROUP_M: tl.constexpr,
1078
+ AB_DTYPE: tl.constexpr,
1079
+ SPLIT_K: tl.constexpr,
1080
+ EVEN_K: tl.constexpr,
1081
+ NUM_SMS: tl.constexpr,
1082
+ USE_BIAS: tl.constexpr,
1083
+ NUM_CONSUMER_GROUPS: tl.constexpr,
1084
+ ) -> None:
1085
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
1086
+
1087
+ performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
1088
+
1089
+ Args:
1090
+ A (TensorWrapper): [M, K] input tensor.
1091
+ B (TensorWrapper): [N, K] input tensor.
1092
+ C (TensorWrapper): [M, N] output tensor.
1093
+ M (int): M dimension of input tensor.
1094
+ N (int): N dimension of input tensor.
1095
+ K (int): K dimension of input tensor.
1096
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
1097
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
1098
+ stride_am (int): Stride of M dimension of A.
1099
+ stride_ak (int): Stride of K dimension of A.
1100
+ stride_bn (int): Stride of N dimension of B.
1101
+ stride_bk (int): Stride of K dimension of B.
1102
+ stride_cm (int): Stride of M dimension of C.
1103
+ stride_cn (int): Stride of N dimension of C.
1104
+ dot_out_dtype (torch.dtype): Output type of tensor core.
1105
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
1106
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
1107
+ BLOCK_M (int): Block size for M dimension.
1108
+ BLOCK_N (int): Block size for N dimension.
1109
+ BLOCK_K (int): Block size for K dimension.
1110
+ GROUP_M (int): Number of groups for M dimension swizzle.
1111
+ SPLIT_K (int): Number of SM's to launch per row.
1112
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
1113
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
1114
+ """
1115
+ num_tiles = tl.cdiv(M, BLOCK_M) * tl.cdiv(N, BLOCK_N)
1116
+ num_pid_m = tl.cdiv(M, BLOCK_M)
1117
+ num_pid_n = tl.cdiv(N, BLOCK_N)
1118
+ dtype_fp8 = tl.float8e4nv
1119
+ for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)):
1120
+ num_pid_in_group = GROUP_M * num_pid_n
1121
+ group_id = pid // num_pid_in_group
1122
+ first_pid_m = group_id * GROUP_M
1123
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
1124
+ # pyre-ignore
1125
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
1126
+ pid_n = (pid % num_pid_in_group) // group_size_m
1127
+
1128
+ # ----------------------------------------------------------
1129
+ # Create pointers for the first blocks of A and B.
1130
+ # We will advance this pointer as we move in the K direction
1131
+ # and accumulate
1132
+ # `a_ptrs` is a block of [BLOCK_M, BLOCK_K] pointers
1133
+ # `b_ptrs` is a block of [BLOCK_K, BLOCK_N] pointers
1134
+ # See above `Pointer Arithmetic` section for details
1135
+ offs_am = pid_m * BLOCK_M
1136
+ offs_bn = pid_n * BLOCK_N
1137
+ offs_k = 0
1138
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
1139
+ # pyre-ignore
1140
+ tl.assume(tl.cdiv(K, BLOCK_K) > 0)
1141
+ for _ in range(0, tl.cdiv(K, BLOCK_K)):
1142
+ # pyre-ignore
1143
+ with tl.async_task([0]):
1144
+ a = tl._experimental_descriptor_load(
1145
+ A_ptr,
1146
+ [offs_am, offs_k],
1147
+ [BLOCK_M, BLOCK_K],
1148
+ dtype_fp8,
1149
+ )
1150
+ b = tl._experimental_descriptor_load(
1151
+ B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
1152
+ )
1153
+
1154
+ if fp8_fast_accum:
1155
+ acc = tl.dot(
1156
+ a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32
1157
+ )
1158
+ else:
1159
+ acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
1160
+
1161
+ offs_k += BLOCK_K
1162
+
1163
+ # pyre-ignore
1164
+ with tl.async_task([1, NUM_CONSUMER_GROUPS]):
1165
+ # Invert scaling.
1166
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
1167
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
1168
+ a_scale = tl.load(A_scale + rm, mask=rm < M)
1169
+ b_scale = tl.load(B_scale + rn, mask=rn < N)
1170
+ scale = a_scale[:, None] * b_scale[None, :]
1171
+ acc *= scale
1172
+ # Load and add bias if specified.
1173
+ if USE_BIAS:
1174
+ bias = tl._experimental_descriptor_load(
1175
+ Bias, [offs_bn], [BLOCK_N], bias_dtype
1176
+ )
1177
+ acc += bias[None, :]
1178
+ acc = acc.to(c_dtype)
1179
+ tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
1180
+
1181
+
1182
+ def _is_eligible_for_skip_scaling(
1183
+ is_rowwise: bool,
1184
+ fp8_fast_accum: bool,
1185
+ imprecise_acc: bool,
1186
+ tma_persistent: bool,
1187
+ no_use_persistent: Optional[bool],
1188
+ use_warp_specialization: bool,
1189
+ ) -> bool:
1190
+ if not is_rowwise:
1191
+ return False
1192
+
1193
+ return (
1194
+ fp8_fast_accum
1195
+ and not imprecise_acc
1196
+ and not tma_persistent
1197
+ and not no_use_persistent
1198
+ and not use_warp_specialization
1199
+ )
1200
+
1201
+
1202
+ @torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
1203
+ def matmul_fp8_row(
1204
+ a: torch.Tensor,
1205
+ b: torch.Tensor,
1206
+ a_scale: Optional[torch.Tensor],
1207
+ b_scale: torch.Tensor,
1208
+ bias: Optional[torch.Tensor] = None,
1209
+ dot_out_dtype: Optional[torch.dtype] = None,
1210
+ allow_tf32: bool = True,
1211
+ fp8_fast_accum: bool = True,
1212
+ imprecise_acc: bool = False,
1213
+ tma_persistent: bool = True,
1214
+ no_use_persistent: Optional[bool] = None,
1215
+ # add an option to explicitly require the use of the persistent kernel
1216
+ use_persistent: Optional[bool] = None,
1217
+ use_warp_specialization: bool = False,
1218
+ ) -> torch.Tensor:
1219
+ """
1220
+ Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].
1221
+
1222
+ Args:
1223
+ a (torch.Tensor): [M, K] input tensor.
1224
+ b (torch.Tensor): [N, K] input tensor.
1225
+ a_scale (Optional[torch.Tensor]): [M] reciprocal scale tensor per row.
1226
+ A * a_scale = original A. Scaling will be skipped if a_scale is None.
1227
+ b_scale (torch.Tensor): [N] reciprocal scale tensor per row. B * b_scale = original B
1228
+ bias (torch.Tensor): [N] optional bias tensor to add to output if provided.
1229
+ dot_out_dtype (torch.dtype): Output type of tensor core.
1230
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
1231
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
1232
+ imprecise_acc (bool): Whether to use the imprecise-accumulation kernel variant.
+ tma_persistent (bool): Whether to use the TMA persistent kernel implementation.
+ no_use_persistent (Optional[bool]): Force the non-persistent kernel; defaults to True on AMD and False on Nvidia.
+ use_persistent (Optional[bool]): Explicitly require the persistent kernel.
+ use_warp_specialization (bool): Whether to use the warp-specialized TMA kernel.
1233
+
1234
+ Returns:
1235
+ torch.Tensor: [M, N] output tensor (a @ b.T) * (a_scale[:, None] * b_scale[None, :]).
1236
+ """
1237
+ if use_persistent:
1238
+ no_use_persistent = False
1239
+ elif no_use_persistent is None:
1240
+ # Default True for AMD and False for Nvidia.
1241
+ if torch.version.hip is not None:
1242
+ no_use_persistent = True
1243
+ else:
1244
+ no_use_persistent = False
1245
+ # if use_persistent is explicitly requested, set no_use_persistent to False
1246
+
1247
+ # Get datatypes and constants to use.
1248
+ pt_fp8_dtype, _, _, _ = get_fp8_constants()
1249
+ # Handle 3D+ a shape
1250
+ a_shape = a.shape
1251
+ a = a.view(-1, a.size(-1))
1252
+ # View inputs into proper torch fp8 dtype.
1253
+ if torch.version.cuda:
1254
+ assert a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
1255
+ elif torch.version.hip:
1256
+ if torch.cuda.get_device_capability() < (9, 5):
1257
+ assert a.dtype in (
1258
+ torch.float8_e4m3fnuz,
1259
+ torch.float8_e5m2fnuz,
1260
+ )
1261
+ else:
1262
+ assert a.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
1263
+ else:
1264
+ assert a.dtype in (
1265
+ torch.float8_e4m3fnuz,
1266
+ torch.float8_e5m2fnuz,
1267
+ )
1268
+ assert b.dtype == pt_fp8_dtype
1269
+ M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
1270
+ prep_matmul(a, b, dot_out_dtype)
1271
+ )
1272
+
1273
+ # Skip scaling (a_scale is None) can only be applied in certain cases.
1274
+ assert a_scale is not None or _is_eligible_for_skip_scaling(
1275
+ is_rowwise=True,
1276
+ fp8_fast_accum=fp8_fast_accum,
1277
+ imprecise_acc=imprecise_acc,
1278
+ tma_persistent=tma_persistent,
1279
+ no_use_persistent=no_use_persistent,
1280
+ use_warp_specialization=use_warp_specialization,
1281
+ )
1282
+
1283
+ output_shape = a_shape[:-1] + (N,)
1284
+ # Handle tensor with empty inputs.
1285
+ if (M == 0) or (N == 0) or (K == 0):
1286
+ return torch.zeros(output_shape, device=device, dtype=torch.bfloat16)
1287
+ # launch kernel
1288
+ if a.device == torch.device("cpu"):
1289
+ logger.info(
1290
+ "FP8 Row-wise Triton kernel not supported on cpu, fallback to torch"
1291
+ )
1292
+ if a_scale is None:
1293
+ scale = b_scale[None, :]
1294
+ else:
1295
+ scale = a_scale[:, None] * b_scale[None, :]
1296
+ output = torch.matmul(a.to(torch.bfloat16), b.to(torch.bfloat16).T) * scale
1297
+ if bias is not None:
1298
+ output += bias[None, :]
1299
+ return output.to(c.dtype)
1300
+
1301
+ def grid(META: dict[str, int]) -> tuple[int, int]:
1302
+ return (
1303
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
1304
+ META["SPLIT_K"],
1305
+ )
1306
+
1307
+ NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
1308
+
1309
+ def persistent_grid(META: dict[str, int]) -> tuple[int]:
1310
+ return (
1311
+ min(
1312
+ NUM_SMS,
1313
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
1314
+ ),
1315
+ )
1316
+
1317
+ if no_use_persistent:
1318
+ logger.debug("Using non-persistent kernel")
1319
+ with torch.cuda.device(a.device.index):
1320
+ torch._library.capture_triton(_kernel_matmul_fp8_row_non_persistent)[grid](
1321
+ a,
1322
+ b,
1323
+ c,
1324
+ M,
1325
+ N,
1326
+ K,
1327
+ m_key,
1328
+ n_key,
1329
+ k_key,
1330
+ a_scale,
1331
+ b_scale,
1332
+ bias,
1333
+ a.stride(0),
1334
+ a.stride(1),
1335
+ b.stride(0),
1336
+ b.stride(1),
1337
+ c.stride(0),
1338
+ c.stride(1),
1339
+ dot_out_dtype=dot_out_dtype_triton,
1340
+ allow_tf32=allow_tf32,
1341
+ fp8_fast_accum=fp8_fast_accum,
1342
+ # GROUP_M=8,
1343
+ USE_BIAS=bias is not None,
1344
+ AB_DTYPE=False,
1345
+ )
1346
+ elif use_warp_specialization:
1347
+ assert has_warp_specialization
1348
+ # used by TMA warp specialization kernel
1349
+ desc_helper = TmaAutoTuneHelper()
1350
+ desc_helper.init_tma_descriptor("a")
1351
+ desc_helper.init_tma_descriptor("b")
1352
+ desc_helper.init_tma_descriptor("c")
1353
+ desc_helper.init_tma_descriptor("a_scale")
1354
+ desc_helper.init_tma_descriptor("b_scale")
1355
+ desc_helper.init_tma_descriptor("bias")
1356
+
1357
+ def persistent_grid_tma_ws(META: dict[str, int]) -> tuple[int]:
1358
+ nonlocal desc_helper # noqa: F824
1359
+ assert a_scale is not None # Type narrowing for Pyre
1360
+ desc_helper.fill_2d_tma_descriptor(
1361
+ "a",
1362
+ a.data_ptr(),
1363
+ M,
1364
+ K,
1365
+ META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
1366
+ META["BLOCK_K"],
1367
+ a.element_size(),
1368
+ )
1369
+
1370
+ desc_helper.fill_2d_tma_descriptor(
1371
+ "b",
1372
+ b.data_ptr(),
1373
+ N,
1374
+ K,
1375
+ META["BLOCK_N"],
1376
+ META["BLOCK_K"],
1377
+ b.element_size(),
1378
+ )
1379
+ desc_helper.fill_2d_tma_descriptor(
1380
+ "c",
1381
+ c.data_ptr(),
1382
+ M,
1383
+ N,
1384
+ META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
1385
+ META["BLOCK_N"],
1386
+ c.element_size(),
1387
+ )
1388
+ desc_helper.fill_1d_tma_descriptor(
1389
+ "a_scale",
1390
+ a_scale.data_ptr(),
1391
+ M,
1392
+ META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"],
1393
+ a_scale.element_size(),
1394
+ )
1395
+ desc_helper.fill_1d_tma_descriptor(
1396
+ "b_scale",
1397
+ b_scale.data_ptr(),
1398
+ N,
1399
+ META["BLOCK_N"],
1400
+ b_scale.element_size(),
1401
+ )
1402
+ if bias is not None:
1403
+ desc_helper.fill_1d_tma_descriptor(
1404
+ "bias",
1405
+ bias.data_ptr(),
1406
+ N,
1407
+ META["BLOCK_N"],
1408
+ bias.element_size(),
1409
+ )
1410
+ return (
1411
+ min(
1412
+ NUM_SMS,
1413
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
1414
+ ),
1415
+ )
1416
+
1417
+ desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
1418
+ desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
1419
+ desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
1420
+ desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
1421
+ desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
1422
+ desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")
1423
+
1424
+ bias_dtype_triton = None
1425
+ if bias is not None:
1426
+ bias_dtype_triton = map_dtype_to_triton(bias.dtype)
1427
+
1428
+ # pyre-ignore
1429
+ torch._library.capture_triton(
1430
+ _kernel_matmul_fp8_row_tma_persistent_ws_cooperative
1431
+ )[persistent_grid_tma_ws](
1432
+ desc_a,
1433
+ desc_b,
1434
+ desc_c,
1435
+ M,
1436
+ N,
1437
+ K,
1438
+ m_key,
1439
+ n_key,
1440
+ k_key,
1441
+ a_scale,
1442
+ b_scale,
1443
+ desc_bias,
1444
+ a.stride(0),
1445
+ a.stride(1),
1446
+ b.stride(0),
1447
+ b.stride(1),
1448
+ c.stride(0),
1449
+ c.stride(1),
1450
+ dot_out_dtype=dot_out_dtype_triton,
1451
+ c_dtype=c_dtype_triton,
1452
+ bias_dtype=bias_dtype_triton,
1453
+ allow_tf32=allow_tf32,
1454
+ fp8_fast_accum=fp8_fast_accum,
1455
+ GROUP_M=8,
1456
+ AB_DTYPE=False,
1457
+ NUM_SMS=NUM_SMS,
1458
+ USE_BIAS=bias is not None,
1459
+ )
1460
+ elif tma_persistent:
1461
+ # used by TMA persistent kernel
1462
+ desc_helper = TmaAutoTuneHelper()
1463
+ desc_helper.init_tma_descriptor("a")
1464
+ desc_helper.init_tma_descriptor("b")
1465
+ desc_helper.init_tma_descriptor("c")
1466
+ desc_helper.init_tma_descriptor("a_scale")
1467
+ desc_helper.init_tma_descriptor("b_scale")
1468
+ desc_helper.init_tma_descriptor("bias")
1469
+
1470
+ def persistent_grid_tma(META: dict[str, int]) -> tuple[int]:
1471
+ nonlocal desc_helper # noqa: F824
1472
+ assert a_scale is not None # Type narrowing for Pyre
1473
+ desc_helper.fill_2d_tma_descriptor(
1474
+ "a",
1475
+ a.data_ptr(),
1476
+ M,
1477
+ K,
1478
+ META["BLOCK_M"],
1479
+ META["BLOCK_K"],
1480
+ a.element_size(),
1481
+ )
1482
+
1483
+ desc_helper.fill_2d_tma_descriptor(
1484
+ "b",
1485
+ b.data_ptr(),
1486
+ N,
1487
+ K,
1488
+ META["BLOCK_N"],
1489
+ META["BLOCK_K"],
1490
+ b.element_size(),
1491
+ )
1492
+ desc_helper.fill_2d_tma_descriptor(
1493
+ "c",
1494
+ c.data_ptr(),
1495
+ M,
1496
+ N,
1497
+ META["BLOCK_M"],
1498
+ META["BLOCK_N"],
1499
+ c.element_size(),
1500
+ )
1501
+ desc_helper.fill_1d_tma_descriptor(
1502
+ "a_scale",
1503
+ a_scale.data_ptr(),
1504
+ M,
1505
+ META["BLOCK_M"],
1506
+ a_scale.element_size(),
1507
+ )
1508
+ desc_helper.fill_1d_tma_descriptor(
1509
+ "b_scale",
1510
+ b_scale.data_ptr(),
1511
+ N,
1512
+ META["BLOCK_N"],
1513
+ b_scale.element_size(),
1514
+ )
1515
+ if bias is not None:
1516
+ desc_helper.fill_1d_tma_descriptor(
1517
+ "bias",
1518
+ bias.data_ptr(),
1519
+ N,
1520
+ META["BLOCK_N"],
1521
+ bias.element_size(),
1522
+ )
1523
+ return (
1524
+ min(
1525
+ NUM_SMS,
1526
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
1527
+ ),
1528
+ )
1529
+
1530
+ desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
1531
+ desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
1532
+ desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
1533
+ desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
1534
+ desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
1535
+ desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")
1536
+
1537
+ bias_dtype_triton = None
1538
+ if bias is not None:
1539
+ bias_dtype_triton = map_dtype_to_triton(bias.dtype)
1540
+
1541
+ # pyre-ignore
1542
+ torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[
1543
+ persistent_grid_tma
1544
+ ](
1545
+ desc_a,
1546
+ desc_b,
1547
+ desc_c,
1548
+ M,
1549
+ N,
1550
+ K,
1551
+ m_key,
1552
+ n_key,
1553
+ k_key,
1554
+ desc_a_scale,
1555
+ desc_b_scale,
1556
+ desc_bias,
1557
+ a.stride(0),
1558
+ a.stride(1),
1559
+ b.stride(0),
1560
+ b.stride(1),
1561
+ c.stride(0),
1562
+ c.stride(1),
1563
+ dot_out_dtype=dot_out_dtype_triton,
1564
+ c_dtype=c_dtype_triton,
1565
+ bias_dtype=bias_dtype_triton,
1566
+ allow_tf32=allow_tf32,
1567
+ fp8_fast_accum=fp8_fast_accum,
1568
+ GROUP_M=8,
1569
+ AB_DTYPE=False,
1570
+ NUM_SMS=NUM_SMS,
1571
+ USE_BIAS=bias is not None,
1572
+ )
1573
+ elif imprecise_acc:
1574
+ with torch.cuda.device(a.device.index):
1575
+ torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid](
1576
+ a,
1577
+ b,
1578
+ c,
1579
+ M,
1580
+ N,
1581
+ K,
1582
+ m_key,
1583
+ n_key,
1584
+ k_key,
1585
+ a_scale,
1586
+ b_scale,
1587
+ bias,
1588
+ a.stride(0),
1589
+ a.stride(1),
1590
+ b.stride(0),
1591
+ b.stride(1),
1592
+ c.stride(0),
1593
+ c.stride(1),
1594
+ dot_out_dtype=dot_out_dtype_triton,
1595
+ allow_tf32=allow_tf32,
1596
+ fp8_fast_accum=fp8_fast_accum,
1597
+ GROUP_M=8,
1598
+ USE_BIAS=bias is not None,
1599
+ AB_DTYPE=False,
1600
+ )
1601
+ elif fp8_fast_accum:
1602
+ skip_scaling_a = a_scale is None
1603
+ torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid](
1604
+ a,
1605
+ b,
1606
+ c,
1607
+ M,
1608
+ N,
1609
+ K,
1610
+ m_key,
1611
+ n_key,
1612
+ k_key,
1613
+ a_scale,
1614
+ b_scale,
1615
+ bias,
1616
+ a.stride(0),
1617
+ a.stride(1),
1618
+ b.stride(0),
1619
+ b.stride(1),
1620
+ c.stride(0),
1621
+ c.stride(1),
1622
+ dot_out_dtype=dot_out_dtype_triton,
1623
+ allow_tf32=allow_tf32,
1624
+ fp8_fast_accum=fp8_fast_accum,
1625
+ skip_scaling_a=skip_scaling_a,
1626
+ GROUP_M=8,
1627
+ USE_BIAS=bias is not None,
1628
+ AB_DTYPE=False,
1629
+ NUM_SMS=NUM_SMS,
1630
+ )
1631
+ else:
1632
+ torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[
1633
+ persistent_grid
1634
+ ](
1635
+ a,
1636
+ b,
1637
+ c,
1638
+ M,
1639
+ N,
1640
+ K,
1641
+ m_key,
1642
+ n_key,
1643
+ k_key,
1644
+ a_scale,
1645
+ b_scale,
1646
+ bias,
1647
+ a.stride(0),
1648
+ a.stride(1),
1649
+ b.stride(0),
1650
+ b.stride(1),
1651
+ c.stride(0),
1652
+ c.stride(1),
1653
+ dot_out_dtype=dot_out_dtype_triton,
1654
+ allow_tf32=allow_tf32,
1655
+ fp8_fast_accum=fp8_fast_accum,
1656
+ GROUP_M=8,
1657
+ USE_BIAS=bias is not None,
1658
+ AB_DTYPE=False,
1659
+ NUM_SMS=NUM_SMS,
1660
+ )
1661
+ return c.view(output_shape)
1662
+
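# A minimal reference sketch (not from the packaged sources) of what the row-wise
# kernels above compute, assuming [M, K] `a`, [N, K] `b`, and per-row reciprocal
# scales `a_scale` ([M]) and `b_scale` ([N]); the Triton kernels fuse the
# dequantization into the GEMM instead of materializing it.
def _matmul_fp8_row_reference(a, b, a_scale, b_scale, bias=None):
    # Dequantize each operand with its reciprocal row scale, then matmul.
    a_hp = a.to(torch.float32) * a_scale[:, None]
    b_hp = b.to(torch.float32) * b_scale[:, None]
    out = a_hp @ b_hp.T
    if bias is not None:
        out = out + bias
    return out.to(torch.bfloat16)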
1663
+
1664
+ @matmul_fp8_row.register_fake
1665
+ def matmul_fp8_row_meta(
1666
+ a: torch.Tensor,
1667
+ b: torch.Tensor,
1668
+ a_scale: Optional[torch.Tensor],
1669
+ b_scale: torch.Tensor,
1670
+ bias: Optional[torch.Tensor] = None,
1671
+ dot_out_dtype: Optional[torch.dtype] = None,
1672
+ allow_tf32: bool = True,
1673
+ fp8_fast_accum: bool = True,
1674
+ imprecise_acc: bool = False,
1675
+ tma_persistent: bool = True,
1676
+ no_use_persistent: Optional[bool] = None,
1677
+ use_warp_specialization: bool = False,
1678
+ ) -> torch.Tensor:
1679
+ """Shape function for torch compile."""
1680
+ M, K = a.shape
1681
+ N, K = b.shape
1682
+ return torch.empty(
1683
+ (M, N),
1684
+ device=a.device,
1685
+ dtype=torch.bfloat16 if dot_out_dtype is None else dot_out_dtype,
1686
+ )
1687
+
1688
+
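# A minimal sketch (not from the packaged sources) of why the register_fake entry
# above exists: torch.compile traces the custom op through matmul_fp8_row_meta,
# which only reports the output shape and dtype, and defers the Triton launch to runtime.
@torch.compile
def _compiled_fp8_row_gemm(a, b, a_scale, b_scale):
    return matmul_fp8_row(a, b, a_scale, b_scale)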
1689
+ # Prune some unreasonable configs before autotuning.
1690
+ def prune_configs_block(configs, named_args, **kwargs):
1691
+ configs = early_config_prune(configs, named_args, **kwargs)
1692
+ scale_block_k = named_args["scale_block_k"]
1693
+ pruned_configs = []
1694
+ # Further rule out configs where scale_block_k is not a multiple of BLOCK_K
1695
+ for config in configs:
1696
+ kw = config.kwargs
1697
+ BLOCK_K = kw["BLOCK_K"]
1698
+ if scale_block_k % BLOCK_K != 0:
1699
+ continue
1700
+ pruned_configs.append(config)
1701
+ return pruned_configs
1702
+
1703
+
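# A standalone illustration (not from the packaged sources) of the predicate used by
# prune_configs_block: only configs whose BLOCK_K divides scale_block_k survive.
# The candidate values below are examples only.
_example_block_k_candidates = [64, 96, 128, 256]
_example_scale_block_k = 256
_kept = [bk for bk in _example_block_k_candidates if _example_scale_block_k % bk == 0]
assert _kept == [64, 128, 256]  # 96 is pruned because 256 % 96 != 0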
1704
+ @triton.autotune(
1705
+ configs=MATMUL_CONFIGS,
1706
+ key=[
1707
+ "m_key",
1708
+ "n_key",
1709
+ "k_key",
1710
+ ], # TODO caller side bin keys so similar shapes can use same triton.autotune.
1711
+ prune_configs_by={
1712
+ "early_config_prune": prune_configs_block,
1713
+ "perf_model": estimate_matmul_time,
1714
+ "top_k": 10,
1715
+ },
1716
+ )
1717
+ @triton.heuristics(
1718
+ {
1719
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
1720
+ }
1721
+ )
1722
+ @triton.jit
1723
+ def _kernel_matmul_fp8_block_fastacc(
1724
+ A,
1725
+ B,
1726
+ C,
1727
+ M,
1728
+ N,
1729
+ K,
1730
+ m_key,
1731
+ n_key,
1732
+ k_key,
1733
+ A_scale,
1734
+ B_scale,
1735
+ scale_block_m: tl.constexpr,
1736
+ scale_block_n: tl.constexpr,
1737
+ scale_block_k: tl.constexpr,
1738
+ stride_am,
1739
+ stride_ak,
1740
+ stride_bn,
1741
+ stride_bk,
1742
+ stride_cm,
1743
+ stride_cn,
1744
+ stride_scale_am,
1745
+ stride_scale_ak,
1746
+ stride_scale_bn,
1747
+ stride_scale_bk,
1748
+ dot_out_dtype: tl.constexpr,
1749
+ allow_tf32: tl.constexpr,
1750
+ BLOCK_M: tl.constexpr,
1751
+ BLOCK_N: tl.constexpr,
1752
+ BLOCK_K: tl.constexpr,
1753
+ GROUP_M: tl.constexpr,
1754
+ SPLIT_K: tl.constexpr,
1755
+ EVEN_K: tl.constexpr,
1756
+ AB_DTYPE: tl.constexpr,
1757
+ ) -> None:
1758
+ """Matmul kernel of [M, K] @ [N, K] with block-wise scales
1759
+
1760
+ Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles and
1761
+ A and B scaled by a scaling factor per [scale_block_m, scale_block_k] and
1762
+ [scale_block_n, scale_block_k] tiles
1763
+ respectively.
1764
+
1765
+ Todo:
1766
+ * Support scale_block_{mnk} < BLOCK{MNK} for each dim.
1767
+ Args:
1768
+ A (TensorWrapper): [M, K] input tensor.
1769
+ B (TensorWrapper): [N, K] input tensor.
1770
+ C (TensorWrapper): [M, N] output tensor.
1771
+ M (int): M dimension of input tensor.
1772
+ N (int): N dimension of input tensor.
1773
+ K (int): K dimension of input tensor.
1774
+ m_key (int): Autotuning key for M dimension of input tensor.
1775
+ n_key (int): Autotuning key for N dimension of input tensor.
1776
+ k_key (int): Autotuning key for K dimension of input tensor.
1777
+ A_scale (TensorWrapper): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per block. A * A_scale = original A
1778
+ B_scale (TensorWrapper): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per block. B * B_scale = original B
1779
+ scale_block_m (int): Block size for M dimension of A_scale.
1780
+ scale_block_n (int): Block size for N dimension of B_scale.
1781
+ scale_block_k (int): Block size for K dimension of A_scale and B_scale.
1782
+ stride_am (int): Stride of M dimension of A.
1783
+ stride_ak (int): Stride of K dimension of A.
1784
+ stride_bn (int): Stride of N dimension of B.
1785
+ stride_bk (int): Stride of K dimension of B.
1786
+ stride_cm (int): Stride of M dimension of C.
1787
+ stride_cn (int): Stride of N dimension of C.
1788
+ stride_scale_am (int): Stride of M dimension of A_scale.
1789
+ stride_scale_ak (int): Stride of K dimension of A_scale.
1790
+ stride_scale_bn (int): Stride of N dimension of B_scale.
1791
+ stride_scale_bk (int): Stride of K dimension of B_scale.
1792
+ dot_out_dtype (torch.dtype): Output type of tensor core.
1793
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
1794
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
1795
+ BLOCK_M (int): Block size for M dimension.
1796
+ BLOCK_N (int): Block size for N dimension.
1797
+ BLOCK_K (int): Block size for K dimension.
1798
+ GROUP_M (int): Number of groups for M dimension swizzle.
1799
+ SPLIT_K (int): Number of SMs to launch per row.
1800
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
1801
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
1802
+ """
1803
+ assert BLOCK_M < scale_block_m
1804
+ assert BLOCK_N < scale_block_n
1805
+ assert BLOCK_K < scale_block_k
1806
+ # matrix multiplication
1807
+ pid = tl.program_id(0)
1808
+ pid_z = tl.program_id(1)
1809
+
1810
+ grid_m = tl.cdiv(M, BLOCK_M)
1811
+ grid_n = tl.cdiv(N, BLOCK_N)
1812
+ # re-order program ID for better L2 performance
1813
+ width = GROUP_M * grid_n
1814
+ group_id = pid // width
1815
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
1816
+ pid_m = group_id * GROUP_M + (pid % group_size)
1817
+ pid_n = (pid % width) // (group_size)
1818
+ # do matrix multiplication
1819
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
1820
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
1821
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
1822
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
1823
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
1824
+ # pointers
1825
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
1826
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
1827
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
1828
+ _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
1829
+ scale_m = pid_m * BLOCK_M // scale_block_m
1830
+ scale_n = pid_n * BLOCK_N // scale_block_n
1831
+ k_multiple = scale_block_k // BLOCK_K
1832
+
1833
+ for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
1834
+
1835
+ k_remaining = K - k * (BLOCK_K * SPLIT_K)
1836
+
1837
+ if EVEN_K:
1838
+ a = tl.load(A)
1839
+ b = tl.load(B)
1840
+ else:
1841
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
1842
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
1843
+ if AB_DTYPE:
1844
+ a = a.to(C.dtype.element_ty)
1845
+ b = b.to(C.dtype.element_ty)
1846
+
1847
+ acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
1848
+
1849
+ A += BLOCK_K * SPLIT_K * stride_ak
1850
+ B += BLOCK_K * SPLIT_K * stride_bk
1851
+
1852
+ # Some math to precompute on scalars, and apply once on matrix.
1853
+ # a + c/s = (as + c) / s
1854
+ # (((a_i-1 * s_i-1 + c_i-1) / s_i-1) * s_i + c_i) / s_i ... ) * s_k + c_k) * 1.0 / s_k
1855
+ # Simplifies to (a_i-1 + c) * (s_i+1/s_i)
1856
+ # And have s_k+1 be 1.
1857
+ # Scale_i = pid_i * BLOCK_I / scale_block_i
1858
+ pid_k = k * SPLIT_K + pid_z
1859
+ if ((pid_k + 1) % k_multiple == 0) or (k_remaining < BLOCK_K * SPLIT_K):
1860
+ # Note: Due to split_k access "pid_k" = k * SPLIT_K + pid_z
1861
+ # Access a_scale[pid_m, k * SPLIT_K + pid_z]
1862
+ # and b_scale[k * SPLIT_K + pid_z, pid_n]
1863
+
1864
+ scale_k = pid_k // k_multiple
1865
+ scale_k_next = scale_k + 1
1866
+ a_scale = tl.load(
1867
+ A_scale + scale_m * stride_scale_am + scale_k * stride_scale_ak
1868
+ )
1869
+ b_scale = tl.load(
1870
+ B_scale + scale_n * stride_scale_bn + scale_k * stride_scale_bk
1871
+ )
1872
+ scale = a_scale * b_scale
1873
+ if k + 1 == tl.cdiv(K, BLOCK_K * SPLIT_K):
1874
+ scale_next_inv_scale = scale
1875
+ else:
1876
+ a_scale_next = tl.load(
1877
+ A_scale + scale_m * stride_scale_am + scale_k_next * stride_scale_ak
1878
+ )
1879
+ b_scale_next = tl.load(
1880
+ B_scale + scale_n * stride_scale_bn + scale_k_next * stride_scale_bk
1881
+ )
1882
+ scale_next = a_scale_next * b_scale_next
1883
+ scale_next_inv_scale = scale / scale_next
1884
+ acc *= scale_next_inv_scale
1885
+
1886
+ # rematerialize rm and rn to save registers
1887
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
1888
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
1889
+
1890
+ acc = acc.to(C.dtype.element_ty)
1891
+ c = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
1892
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
1893
+ # handles write-back with reduction-splitting
1894
+ if SPLIT_K == 1:
1895
+ tl.store(c, acc, mask=mask)
1896
+ else:
1897
+ tl.atomic_add(c, acc, mask=mask)
1898
+
1899
+
1900
+ @triton.autotune(
1901
+ configs=MATMUL_CONFIGS,
1902
+ key=[
1903
+ "m_key",
1904
+ "n_key",
1905
+ "k_key",
1906
+ ], # TODO caller side bin keys so similar shapes can use same triton.autotune.
1907
+ prune_configs_by={
1908
+ "early_config_prune": early_config_prune,
1909
+ "perf_model": estimate_matmul_time,
1910
+ "top_k": 10,
1911
+ },
1912
+ )
1913
+ @triton.heuristics(
1914
+ {
1915
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
1916
+ }
1917
+ )
1918
+ @triton.jit
1919
+ def _kernel_matmul_fp8_block_slowacc(
1920
+ A,
1921
+ B,
1922
+ C,
1923
+ M,
1924
+ N,
1925
+ K,
1926
+ m_key,
1927
+ n_key,
1928
+ k_key,
1929
+ A_scale,
1930
+ B_scale,
1931
+ scale_block_m: tl.constexpr,
1932
+ scale_block_n: tl.constexpr,
1933
+ scale_block_k: tl.constexpr,
1934
+ stride_am,
1935
+ stride_ak,
1936
+ stride_bn,
1937
+ stride_bk,
1938
+ stride_cm,
1939
+ stride_cn,
1940
+ stride_scale_am,
1941
+ stride_scale_ak,
1942
+ stride_scale_bn,
1943
+ stride_scale_bk,
1944
+ dot_out_dtype: tl.constexpr,
1945
+ allow_tf32: tl.constexpr,
1946
+ BLOCK_M: tl.constexpr,
1947
+ BLOCK_N: tl.constexpr,
1948
+ BLOCK_K: tl.constexpr,
1949
+ GROUP_M: tl.constexpr,
1950
+ SPLIT_K: tl.constexpr,
1951
+ EVEN_K: tl.constexpr,
1952
+ AB_DTYPE: tl.constexpr,
1953
+ ) -> None:
1954
+ """Matmul kernel of [M, K] @ [N, K] with block-wise scales
1955
+
1956
+ Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles and
1957
+ A and B scaled by a scaling factor per [scale_block_m, scale_block_k] and
1958
+ [scale_block_n, scale_block_k] tiles
1959
+ respectively.
1960
+
1961
+ Todo:
1962
+ * Support scale_block_{mnk} < BLOCK{MNK} for each dim.
1963
+ Args:
1964
+ A (TensorWrapper): [M, K] input tensor.
1965
+ B (TensorWrapper): [N, K] input tensor.
1966
+ C (TensorWrapper): [M, N] output tensor.
1967
+ M (int): M dimension of input tensor.
1968
+ N (int): N dimension of input tensor.
1969
+ K (int): K dimension of input tensor.
1970
+ m_key (int): Autotuning key for M dimension of input tensor.
1971
+ n_key (int): Autotuning key for N dimension of input tensor.
1972
+ k_key (int): Autotuning key for K dimension of input tensor.
1973
+ A_scale (TensorWrapper): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per block. A * A_scale = original A
1974
+ B_scale (TensorWrapper): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per block. B * B_scale = original B
1975
+ scale_block_m (int): Block size for M dimension of A_scale.
1976
+ scale_block_n (int): Block size for N dimension of B_scale.
1977
+ scale_block_k (int): Block size for K dimension of A_scale and B_scale.
1978
+ stride_am (int): Stride of M dimension of A.
1979
+ stride_ak (int): Stride of K dimension of A.
1980
+ stride_bn (int): Stride of N dimension of B.
1981
+ stride_bk (int): Stride of K dimension of B.
1982
+ stride_cm (int): Stride of M dimension of C.
1983
+ stride_cn (int): Stride of N dimension of C.
1984
+ stride_scale_am (int): Stride of M dimension of A_scale.
1985
+ stride_scale_ak (int): Stride of K dimension of A_scale.
1986
+ stride_scale_bn (int): Stride of N dimension of B_scale.
1987
+ stride_scale_bk (int): Stride of K dimension of B_scale.
1988
+ dot_out_dtype (torch.dtype): Output type of tensor core.
1989
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
1990
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
1991
+ BLOCK_M (int): Block size for M dimension.
1992
+ BLOCK_N (int): Block size for N dimension.
1993
+ BLOCK_K (int): Block size for K dimension.
1994
+ GROUP_M (int): Number of groups for M dimension swizzle.
1995
+ SPLIT_K (int): Number of SMs to launch per row.
1996
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
1997
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
1998
+ """
1999
+ assert BLOCK_M < scale_block_m
2000
+ assert BLOCK_N < scale_block_n
2001
+ assert BLOCK_K < scale_block_k
2002
+ # matrix multiplication
2003
+ pid = tl.program_id(0)
2004
+ pid_z = tl.program_id(1)
2005
+
2006
+ grid_m = tl.cdiv(M, BLOCK_M)
2007
+ grid_n = tl.cdiv(N, BLOCK_N)
2008
+ # re-order program ID for better L2 performance
2009
+ width = GROUP_M * grid_n
2010
+ group_id = pid // width
2011
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
2012
+ pid_m = group_id * GROUP_M + (pid % group_size)
2013
+ pid_n = (pid % width) // (group_size)
2014
+ # do matrix multiplication
2015
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
2016
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
2017
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
2018
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
2019
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
2020
+ # pointers
2021
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
2022
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
2023
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
2024
+ scale_m = pid_m * BLOCK_M // scale_block_m
2025
+ scale_n = pid_n * BLOCK_N // scale_block_n
2026
+ _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
2027
+
2028
+ for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
2029
+ # Note: Due to split_k access "pid_k" = k * SPLIT_K + pid_z
2030
+ # Access a_scale[pid_m, k * SPLIT_K + pid_z]
2031
+ # and b_scale[k * SPLIT_K + pid_z, pid_n]
2032
+ pid_k = k * SPLIT_K + pid_z
2033
+ scale_k = pid_k * BLOCK_K // scale_block_k
2034
+ a_scale = tl.load(
2035
+ A_scale + scale_m * stride_scale_am + scale_k * stride_scale_ak
2036
+ )
2037
+ b_scale = tl.load(
2038
+ B_scale + scale_n * stride_scale_bn + scale_k * stride_scale_bk
2039
+ )
2040
+ scale = a_scale * b_scale
2041
+
2042
+ if EVEN_K:
2043
+ a = tl.load(A)
2044
+ b = tl.load(B)
2045
+ else:
2046
+ k_remaining = K - k * (BLOCK_K * SPLIT_K)
2047
+
2048
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
2049
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
2050
+ if AB_DTYPE:
2051
+ a = a.to(C.dtype.element_ty)
2052
+ b = b.to(C.dtype.element_ty)
2053
+
2054
+ acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) * scale
2055
+ A += BLOCK_K * SPLIT_K * stride_ak
2056
+ B += BLOCK_K * SPLIT_K * stride_bk
2057
+
2058
+ # rematerialize rm and rn to save registers
2059
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
2060
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
2061
+
2062
+ acc = acc.to(C.dtype.element_ty)
2063
+ c = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
2064
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
2065
+ # handles write-back with reduction-splitting
2066
+ if SPLIT_K == 1:
2067
+ tl.store(c, acc, mask=mask)
2068
+ else:
2069
+ tl.atomic_add(c, acc, mask=mask)
2070
+
2071
+
2072
+ @torch.library.custom_op("triton::matmul_fp8_block", mutates_args=())
2073
+ def matmul_fp8_block(
2074
+ a: torch.Tensor,
2075
+ b: torch.Tensor,
2076
+ a_scale: torch.Tensor,
2077
+ b_scale: torch.Tensor,
2078
+ scale_block_m: int = 256,
2079
+ scale_block_n: int = 256,
2080
+ scale_block_k: int = 256,
2081
+ dot_out_dtype: Optional[torch.dtype] = None,
2082
+ allow_tf32: bool = True,
2083
+ fp8_fast_accum: bool = True,
2084
+ ) -> Tensor:
2085
+ """Performs matmul on [M, K] and [N, K] fp8 matrices with block-wise scalings.
2086
+
2087
+ Args:
2088
+ a (torch.Tensor): [M, K] input tensor.
2089
+ b (torch.Tensor): [N, K] input tensor.
2090
+ a_scale (torch.Tensor): [cdiv(M, scale_block_m), cdiv(K, scale_block_k)] reciprocal scale tensor per scale block. A * A_scale = original A
2091
+ b_scale (torch.Tensor): [cdiv(N, scale_block_n), cdiv(K, scale_block_k)] reciprocal scale tensor per scale block. B * B_scale = original B
2092
+ scale_block_m (int): Block size for M dimension of A_scale.
2093
+ scale_block_n (int): Block size for N dimension of B_scale.
2094
+ scale_block_k (int): Block size for K dimension of A_scale and B_scale.
2095
+ dot_out_dtype (torch.dtype): Output type of tensor core.
2096
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
2097
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
2098
+
2099
+ Returns:
2100
+ Tensor: [M, N] output tensor, the matmul of the dequantized inputs, i.e. (a * a_scale) @ (b * b_scale).T with the block-wise scales broadcast.
2101
+ """
2102
+ # Get datatypes and constants to use.
2103
+ _, tl_fp8_dtype, _, _ = get_fp8_constants()
2104
+ # Handle 3D+ a shape
2105
+ a_shape = a.shape
2106
+ a = a.view(-1, a.size(-1))
2107
+ # View inputs into proper triton fp8 dtype.
2108
+ a_tl = reinterpret_fp8_type(a, tl_fp8_dtype)
2109
+ b_tl = reinterpret_fp8_type(b, tl_fp8_dtype)
2110
+
2111
+ M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul(
2112
+ a_tl, b_tl, dot_out_dtype
2113
+ )
2114
+
2115
+ output_shape = a_shape[:-1] + (N,)
2116
+ # Handle case where inputs are empty.
2117
+ if (M == 0) or (N == 0) or (K == 0):
2118
+ return torch.zeros(output_shape, device=device, dtype=torch.bfloat16)
2119
+
2120
+ # launch kernel
2121
+ assert device != torch.device(
2122
+ "cpu"
2123
+ ), "Blockwise matmul not supported on cpu, please use row-wise instead."
2124
+
2125
+ if b.device != a.device:
2126
+ raise Exception("'b' must be on the same device as 'a'")
2127
+ if a_scale.device != a.device:
2128
+ raise Exception("'a_scale' must be on the same device as 'a'")
2129
+ if b_scale.device != a.device:
2130
+ raise Exception("'b_scale' must be on the same device as 'a'")
2131
+
2132
+ # noqa: E731:
2133
+ def grid(META: dict[str, int]) -> tuple[int, int]:
2134
+ return (
2135
+ triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
2136
+ META["SPLIT_K"],
2137
+ )
2138
+
2139
+ if fp8_fast_accum:
2140
+ with torch.cuda.device(a_tl.device.index):
2141
+ _kernel_matmul_fp8_block_fastacc[grid](
2142
+ a_tl,
2143
+ b_tl,
2144
+ c,
2145
+ M,
2146
+ N,
2147
+ K,
2148
+ m_key,
2149
+ n_key,
2150
+ k_key,
2151
+ a_scale,
2152
+ b_scale,
2153
+ scale_block_m,
2154
+ scale_block_n,
2155
+ scale_block_k,
2156
+ a.stride(0),
2157
+ a.stride(1),
2158
+ b.stride(0),
2159
+ b.stride(1),
2160
+ c.stride(0),
2161
+ c.stride(1),
2162
+ a_scale.stride(0),
2163
+ a_scale.stride(1),
2164
+ b_scale.stride(0),
2165
+ b_scale.stride(1),
2166
+ dot_out_dtype=dot_out_dtype_triton,
2167
+ allow_tf32=allow_tf32,
2168
+ GROUP_M=8,
2169
+ AB_DTYPE=False,
2170
+ )
2171
+ else:
2172
+ with torch.cuda.device(a_tl.device.index):
2173
+ _kernel_matmul_fp8_block_slowacc[grid](
2174
+ a_tl,
2175
+ b_tl,
2176
+ c,
2177
+ M,
2178
+ N,
2179
+ K,
2180
+ m_key,
2181
+ n_key,
2182
+ k_key,
2183
+ a_scale,
2184
+ b_scale,
2185
+ scale_block_m,
2186
+ scale_block_n,
2187
+ scale_block_k,
2188
+ a.stride(0),
2189
+ a.stride(1),
2190
+ b.stride(0),
2191
+ b.stride(1),
2192
+ c.stride(0),
2193
+ c.stride(1),
2194
+ a_scale.stride(0),
2195
+ a_scale.stride(1),
2196
+ b_scale.stride(0),
2197
+ b_scale.stride(1),
2198
+ dot_out_dtype=dot_out_dtype_triton,
2199
+ allow_tf32=allow_tf32,
2200
+ GROUP_M=8,
2201
+ AB_DTYPE=False,
2202
+ )
2203
+ return c.view(output_shape)
2204
+
2205
+
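# A minimal reference sketch (not from the packaged sources) of the block-wise
# semantics under the reciprocal-scale convention documented above: each
# [scale_block_*, scale_block_k] tile is dequantized with its block scale before
# the matmul. The Triton kernels compute the same result without materializing
# the dequantized operands.
def _matmul_fp8_block_reference(
    a, b, a_scale, b_scale, scale_block_m=256, scale_block_n=256, scale_block_k=256
):
    M, K = a.shape
    N, _ = b.shape
    # Broadcast the per-block scales back to element granularity, then crop padding.
    a_sc = a_scale.repeat_interleave(scale_block_m, 0).repeat_interleave(scale_block_k, 1)[:M, :K]
    b_sc = b_scale.repeat_interleave(scale_block_n, 0).repeat_interleave(scale_block_k, 1)[:N, :K]
    a_hp = a.to(torch.float32) * a_sc
    b_hp = b.to(torch.float32) * b_sc
    return (a_hp @ b_hp.T).to(torch.bfloat16)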
2206
+ @matmul_fp8_block.register_fake
2207
+ def matmul_fp8_block_meta(
2208
+ a: torch.Tensor,
2209
+ b: torch.Tensor,
2210
+ a_scale: torch.Tensor,
2211
+ b_scale: torch.Tensor,
2212
+ scale_block_m: int = 256,
2213
+ scale_block_n: int = 256,
2214
+ scale_block_k: int = 256,
2215
+ dot_out_dtype: Optional[torch.dtype] = None,
2216
+ allow_tf32: bool = True,
2217
+ fp8_fast_accum: bool = True,
2218
+ ) -> torch.Tensor:
2219
+ """Shape function for torch compile."""
2220
+ M, K = a.shape
2221
+ N, K = b.shape
2222
+ return torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
2223
+
2224
+
2225
+ def get_matmul_tune(M: int, N: int, K: int) -> tuple[int, int, int]:
2226
+ """
2227
+ Generate a simplified matmul tune key for A @ B.T
2228
+ with [M, K] A and [N, K] B to reduce excessive autotuning.
2229
+
2230
+ Args:
2231
+ M (int): Number of rows in A.
2232
+ N (int): Number of rows in B.
2233
+ K (int): Number of cols in A and cols in B.
2234
+
2235
+ Returns:
2236
+ m_key (int): Autotuning key for M dim.
2237
+ n_key (int): Autotuning key for N dim.
2238
+ k_key (int): Autotuning key for K dim.
2239
+
2240
+ TODO: Refine this. For now it's useful for LLM inference where N, K dims are fixed
2241
+ and M dim varies due to seq_len.
2242
+ """
2243
+ if M < 256:
2244
+ m_key = M
2245
+ else:
2246
+ m_key = 256 + M // 1024
2247
+ return m_key, N, K
2248
+
2249
+
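# Worked examples (not from the packaged sources) of the bucketing above: small M
# is kept exact, larger M collapses into coarse buckets so nearby sequence lengths
# reuse the same autotuning result.
assert get_matmul_tune(64, 4096, 4096) == (64, 4096, 4096)
assert get_matmul_tune(5000, 4096, 4096) == (256 + 5000 // 1024, 4096, 4096)  # m_key == 260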
2250
+ def prep_matmul(
2251
+ a: Union[TensorWrapper, torch.Tensor],
2252
+ b: Union[TensorWrapper, torch.Tensor],
2253
+ dot_out_dtype: Optional[torch.dtype],
2254
+ ) -> tuple[
2255
+ int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device
2256
+ ]:
2257
+ """
2258
+ Shared bookkeeping for a @ b.T matmul.
2259
+
2260
+ Args:
2261
+ a (torch.Tensor): [M, K] input tensor.
2262
+ b (torch.Tensor): [N, K] input tensor.
2263
+ dot_out_dtype (tl.dtype): Output type of tensor core.
2264
+
2265
+ Returns:
2266
+ M (int): Number of rows in A.
2267
+ N (int): Number of rows in B.
2268
+ K (int): Number of cols in A and cols in B.
2269
+ m_key (int): Autotuning key for M dim.
2270
+ n_key (int): Autotuning key for N dim.
2271
+ k_key (int): Autotuning key for K dim.
2272
+ c (Tensor): [M, N] output tensor.
2273
+ c_dtype_triton (tl.dtype): Type of output tensor.
2274
+ dot_out_dtype (tl.dtype): Output type of tensor core.
2275
+ device (torch.device): Device of output tensor.
2276
+ """
2277
+ device = a.device
2278
+
2279
+ # checks constraints
2280
+ assert (
2281
+ a.shape[1] == b.shape[1]
2282
+ ), f"incompatible dimensions, a: {a.shape}, b: {b.shape}"
2283
+ M, K = a.shape
2284
+ N, _ = b.shape
2285
+ m_key, n_key, k_key = get_matmul_tune(M, N, K)
2286
+
2287
+ # allocates output
2288
+ assert a.dtype in [
2289
+ torch.float8_e4m3fn,
2290
+ torch.float8_e5m2,
2291
+ torch.float8_e4m3fnuz,
2292
+ torch.float8_e5m2fnuz,
2293
+ tl.float8e4nv,
2294
+ tl.float8e4b15,
2295
+ tl.float8e5,
2296
+ tl.float8e4b8,
2297
+ ]
2298
+ assert b.dtype in [
2299
+ torch.float8_e4m3fn,
2300
+ torch.float8_e5m2,
2301
+ torch.float8_e4m3fnuz,
2302
+ torch.float8_e5m2fnuz,
2303
+ tl.float8e4nv,
2304
+ tl.float8e4b15,
2305
+ tl.float8e5,
2306
+ tl.float8e4b8,
2307
+ ]
2308
+
2309
+ c_dtype, c_dtype_triton = (
2310
+ (torch.bfloat16, tl.bfloat16)
2311
+ if dot_out_dtype is None
2312
+ else (dot_out_dtype, map_dtype_to_triton(dot_out_dtype))
2313
+ )
2314
+
2315
+ c = torch.empty((M, N), device=device, dtype=c_dtype)
2316
+ if dot_out_dtype is None:
2317
+ dot_out_dtype_triton = tl.float32
2318
+ else:
2319
+ assert isinstance(
2320
+ dot_out_dtype, torch.dtype
2321
+ ), f"dot_out_dtype type {type(dot_out_dtype)} must be a torch.dtype"
2322
+ dot_out_dtype_triton = map_dtype_to_triton(dot_out_dtype)
2323
+
2324
+ return M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device
2325
+
2326
+
2327
+ @triton.autotune(
2328
+ configs=[
2329
+ Config({"BLOCK_SIZE": 512}),
2330
+ Config({"BLOCK_SIZE": 1024}),
2331
+ Config({"BLOCK_SIZE": 2048}),
2332
+ Config({"BLOCK_SIZE": 4096}),
2333
+ Config({"BLOCK_SIZE": 8192}),
2334
+ ],
2335
+ key=["K"],
2336
+ )
2337
+ @triton.jit
2338
+ def _kernel_quantize_fp8_row(
2339
+ A,
2340
+ A_scale,
2341
+ A_fp8,
2342
+ scale_ub,
2343
+ zero_start_index_M,
2344
+ B,
2345
+ M,
2346
+ N,
2347
+ K,
2348
+ K_fp8, # used when padding
2349
+ stride_ab,
2350
+ stride_am,
2351
+ stride_an,
2352
+ stride_ak,
2353
+ stride_ob,
2354
+ stride_om,
2355
+ stride_on,
2356
+ stride_ok,
2357
+ stride_zb,
2358
+ stride_zm,
2359
+ TL_FP8_DTYPE: tl.constexpr,
2360
+ MAX_FP8: tl.constexpr,
2361
+ EPS: tl.constexpr,
2362
+ CLAMP_MAX: tl.constexpr,
2363
+ JAGGED: tl.constexpr,
2364
+ BLOCK_SIZE: tl.constexpr,
2365
+ USE_INT64: tl.constexpr,
2366
+ ) -> None:
2367
+ """Quantize and scale each row.
2368
+
2369
+ Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))
2370
+
2371
+ Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
2372
+ in a max pass then scale/quantize pass.
2373
+
2374
+ Todo:
2375
+ * Better tiling schemes.
2376
+
2377
+ Args:
2378
+ A (Tensor): higher precision input tensor of 4 dimensions.
2379
+ A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
2380
+ A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
2381
+ scale_ub (Tensor): [1] Maximum value allowed for scale.
2382
+ B (int): Size of dimension 0
2383
+ M (int): Size of dimension 1
2384
+ N (int): Size of dimension 2
2385
+ K (int): Size of dimension 3 (input row size)
2386
+ K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K)
2387
+ stride_ab (int): Stride of b dimension of A.
2388
+ stride_am (int): Stride of m dimension of A.
2389
+ stride_an (int): Stride of n dimension of A.
2390
+ stride_ak (int): Stride of k dimension of A.
2391
+ stride_ob (int): Stride of b dimension of output.
2392
+ stride_om (int): Stride of m dimension of output.
2393
+ stride_on (int): Stride of n dimension of output.
2394
+ stride_ok (int): Stride of k dimension of output.
2395
+ stride_zb (int): Stride of b dimension of jagged index.
2396
+ stride_zm (int): Stride of m dimension of jagged index.
2397
+ TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
2398
+ MAX_FP8 (float): Maximum expressible value for FP8.
2399
+ EPS (float): Epsilon value for numerical stability.
2400
+ CLAMP_MAX (bool): Whether to apply scale_ub.
2401
+ JAGGED (bool): Whether to use jagged indexing.
2402
+ BLOCK_SIZE (int): Block size for reduction.
2403
+ USE_INT64 (bool): Whether to use int64 indexing for large inputs.
2404
+ """
2405
+ pid = tl.program_id(0)
2406
+ # Use int64 indexing for large inputs. This is slower, but
2407
+ # needed to avoid index overflows.
2408
+ if USE_INT64:
2409
+ pid = pid.to(tl.int64)
2410
+ n_offset = tl.arange(0, BLOCK_SIZE)
2411
+ a_offset_base = (
2412
+ pid // (M * N) * stride_ab
2413
+ + (pid % (M * N)) // N * stride_am
2414
+ + (pid % (M * N)) % N * stride_an
2415
+ )
2416
+ a_fp8_offset_base = (
2417
+ pid // (M * N) * stride_ob
2418
+ + (pid % (M * N)) // N * stride_om
2419
+ + (pid % (M * N)) % N * stride_on
2420
+ )
2421
+
2422
+ K_in = K
2423
+
2424
+ if JAGGED:
2425
+ z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
2426
+ group_rows = tl.load(zero_start_index_M + z_offset_base)
2427
+ current_row = pid % N
2428
+ # If this row is empty, don't process any of it.
2429
+ if current_row >= group_rows:
2430
+ K_in = 0
2431
+
2432
+ # Calculate max.
2433
+ cur_max = 0.0
2434
+ for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
2435
+ a = tl.load(
2436
+ A + a_offset_base + n_offset * stride_ak,
2437
+ mask=n_offset < K_in,
2438
+ other=0.0,
2439
+ )
2440
+ tile_max = tl.max(tl.abs(a))
2441
+ cur_max = tl.maximum(tile_max, cur_max)
2442
+ n_offset += BLOCK_SIZE
2443
+
2444
+ # Clamp max value appropriately.
2445
+ if CLAMP_MAX:
2446
+ ub = tl.load(scale_ub)
2447
+ cur_max = tl.clamp(cur_max, EPS, ub)
2448
+ else:
2449
+ cur_max = tl.maximum(cur_max, EPS)
2450
+ # Scale and quantize.
2451
+ a_scale = MAX_FP8 / cur_max
2452
+ tl.store(A_scale + pid, 1.0 / a_scale)
2453
+ n_offset = tl.arange(0, BLOCK_SIZE)
2454
+
2455
+ # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
2456
+ for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
2457
+ # Load from A if in range, else 0 (we're going all the way to K_fp8)
2458
+ a = tl.load(
2459
+ A + a_offset_base + n_offset * stride_ak,
2460
+ mask=n_offset < K_in,
2461
+ other=0.0,
2462
+ )
2463
+ # For elements >= K, a will be 0
2464
+ a_fp8 = a * a_scale
2465
+ # Clamp A to fp8 range to make sure there's no overflow.
2466
+ # This is required for AMD. Nvidia's default saturation
2467
+ # handles it, but it's nice to have anyway.
2468
+ a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
2469
+
2470
+ # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
2471
+ tl.store(
2472
+ A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
2473
+ a_fp8,
2474
+ mask=n_offset < K_fp8,
2475
+ )
2476
+ n_offset += BLOCK_SIZE
2477
+
2478
+
2479
+ def triton_quantize_fp8_row(
2480
+ a: Tensor,
2481
+ scale_ub: Optional[Tensor] = None,
2482
+ zero_start_index_M: Optional[Tensor] = None,
2483
+ align_rows_to: Optional[int] = None,
2484
+ ) -> tuple[Tensor, Tensor]:
2485
+ """
2486
+ Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
2487
+
2488
+ Args:
2489
+ a (Tensor): higher precision input tensor of up to 4 dimensions.
2490
+ scale_ub (Tensor): Maximum allowed value for scale.
2491
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
2492
+ align_rows_to (int): Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., a multiple of 16).
2493
+
2494
+ Returns:
2495
+ torch.Tensor: fp8 scaled tensor.
2496
+ torch.Tensor: reciprocal scale tensor per row.
2497
+ """
2498
+ if scale_ub is not None and scale_ub.device != a.device:
2499
+ raise Exception("'scale_ub' must be on the same device as 'a'")
2500
+ if zero_start_index_M is not None and zero_start_index_M.device != a.device:
2501
+ raise Exception("'zero_start_index_M' must be on the same device as 'a'")
2502
+
2503
+ assert a.dim() <= 4, "Triton only supports up to 4-dimensional input tensors."
2504
+ a_shape = a.shape
2505
+ while a.dim() < 4:
2506
+ a = a.unsqueeze(0)
2507
+ if zero_start_index_M is not None:
2508
+ # There should be one value of zero_start_index_M per NxK matrix.
2509
+ zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
2510
+ # Get constant values.
2511
+ pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
2512
+ num_rows = a.numel() // a.shape[-1]
2513
+ a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
2514
+ # If align_rows_to is provided, pad the last dimension to be a multiple of it
2515
+ if align_rows_to is not None:
2516
+ last_dim = a.shape[-1]
2517
+ padded_last_dim = (
2518
+ (last_dim + align_rows_to - 1) // align_rows_to
2519
+ ) * align_rows_to
2520
+ a_fp8 = torch.empty(
2521
+ (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
2522
+ )
2523
+ a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
2524
+ else:
2525
+ a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
2526
+
2527
+ # If input tensor is sufficiently large, we need to use int64 indexing.
2528
+ use_int64 = a.numel() > (2**31 - 1)
2529
+ grid = (num_rows,)
2530
+ # Use a conservative threshold: disable BufferOps for small (inference) shapes on ROCm.
2531
+ should_disable_bufferops = torch.version.hip is not None and a_shape[0] < 32
2532
+ with disable_bufferops(should_disable_bufferops):
2533
+ with torch.cuda.device(a.device.index):
2534
+ _kernel_quantize_fp8_row[grid](
2535
+ a,
2536
+ a_scale,
2537
+ a_fp8,
2538
+ scale_ub,
2539
+ zero_start_index_M,
2540
+ a.shape[0],
2541
+ a.shape[1],
2542
+ a.shape[2],
2543
+ a.shape[3],
2544
+ a_fp8.shape[3],
2545
+ a.stride(0),
2546
+ a.stride(1),
2547
+ a.stride(2),
2548
+ a.stride(3),
2549
+ a_fp8.stride(0),
2550
+ a_fp8.stride(1),
2551
+ a_fp8.stride(2),
2552
+ a_fp8.stride(3),
2553
+ (
2554
+ zero_start_index_M.stride(0)
2555
+ if zero_start_index_M is not None
2556
+ else None
2557
+ ),
2558
+ (
2559
+ zero_start_index_M.stride(1)
2560
+ if zero_start_index_M is not None
2561
+ else None
2562
+ ),
2563
+ TL_FP8_DTYPE=tl_dtype,
2564
+ MAX_FP8=max_fp8,
2565
+ EPS=eps,
2566
+ CLAMP_MAX=scale_ub is not None,
2567
+ JAGGED=zero_start_index_M is not None,
2568
+ USE_INT64=use_int64,
2569
+ )
2570
+
2571
+ return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
2572
+
2573
+
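# A minimal usage sketch (not from the packaged sources) of align_rows_to: rows are
# zero-padded to the next multiple of the alignment, while each reciprocal scale is
# still computed from the original K values. Requires a CUDA device.
def _example_aligned_row_quantization():
    x = torch.randn(4, 100, device="cuda", dtype=torch.bfloat16)
    xq, x_scale = triton_quantize_fp8_row(x, align_rows_to=16)
    assert xq.shape == (4, 112)   # 100 padded up to a multiple of 16
    assert x_scale.shape == (4,)  # one reciprocal scale per original row
    return xq, x_scale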
2574
+ @triton.autotune(
2575
+ configs=[
2576
+ Config({"BLOCK_SIZE": 512}),
2577
+ Config({"BLOCK_SIZE": 1024}),
2578
+ Config({"BLOCK_SIZE": 2048}),
2579
+ Config({"BLOCK_SIZE": 4096}),
2580
+ Config({"BLOCK_SIZE": 8192}),
2581
+ ],
2582
+ key=["K"],
2583
+ )
2584
+ @triton.jit
2585
+ def _kernel_quantize_fp8_packed_row(
2586
+ A,
2587
+ A_fp8,
2588
+ packed_scale,
2589
+ scale_ub,
2590
+ zero_start_index_M,
2591
+ B,
2592
+ M,
2593
+ N,
2594
+ K,
2595
+ stride_ab,
2596
+ stride_am,
2597
+ stride_an,
2598
+ stride_ak,
2599
+ stride_ob,
2600
+ stride_om,
2601
+ stride_on,
2602
+ stride_ok,
2603
+ packed_scale_stride,
2604
+ stride_zb,
2605
+ stride_zm,
2606
+ TL_FP8_DTYPE: tl.constexpr,
2607
+ MAX_FP8: tl.constexpr,
2608
+ EPS: tl.constexpr,
2609
+ CLAMP_MAX: tl.constexpr,
2610
+ JAGGED: tl.constexpr,
2611
+ BLOCK_SIZE: tl.constexpr,
2612
+ USE_INT64: tl.constexpr,
2613
+ ) -> None:
2614
+ """Quantize and scale each row.
2615
+
2616
+ Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))
2617
+
2618
+ Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
2619
+ in a max pass then scale/quantize pass.
2620
+
2621
+ Todo:
2622
+ * Better tiling schemes.
2623
+
2624
+ Args:
2625
+ A (Tensor): higher precision input tensor of 4 dimensions.
2626
+ packed_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
2627
+ A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
2628
+ scale_ub (Tensor): [1] Maximum value allowed for scale.
2629
+ B (int): Size of dimension 0
2630
+ M (int): Size of dimension 1
2631
+ N (int): Size of dimension 2
2632
+ K (int): Size of dimension 3
2633
+ stride_ab (int): Stride of b dimension of A.
2634
+ stride_am (int): Stride of m dimension of A.
2635
+ stride_an (int): Stride of n dimension of A.
2636
+ stride_ak (int): Stride of k dimension of A.
2637
+ stride_ob (int): Stride of b dimension of output.
2638
+ stride_om (int): Stride of m dimension of output.
2639
+ stride_on (int): Stride of n dimension of output.
2640
+ stride_ok (int): Stride of k dimension of output.
2641
+ packed_scale_stride (int): Stride of the packed scale, indexing into a_fp8.
2642
+ stride_zb (int): Stride of b dimension of jagged index.
2643
+ stride_zm (int): Stride of m dimension of jagged index.
2644
+ TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
2645
+ MAX_FP8 (float): Maximum expressible value for FP8.
2646
+ EPS (float): Epsilon value for numerical stability.
2647
+ CLAMP_MAX (bool): Whether to apply scale_ub.
2648
+ JAGGED (bool): Whether to use jagged indexing.
2649
+ BLOCK_SIZE (int): Block size for reduction.
2650
+ USE_INT64 (bool): Whether to use int64 indexing for large inputs.
2651
+ """
2652
+ pid = tl.program_id(0)
2653
+ # Use int64 indexing for large inputs. This is slower, but
2654
+ # needed to avoid index overflows.
2655
+ if USE_INT64:
2656
+ pid = pid.to(tl.int64)
2657
+ n_offset = tl.arange(0, BLOCK_SIZE)
2658
+ a_offset_base = (
2659
+ pid // (M * N) * stride_ab
2660
+ + (pid % (M * N)) // N * stride_am
2661
+ + (pid % (M * N)) % N * stride_an
2662
+ )
2663
+ a_fp8_offset_base = (
2664
+ pid // (M * N) * stride_ob
2665
+ + (pid % (M * N)) // N * stride_om
2666
+ + (pid % (M * N)) % N * stride_on
2667
+ )
2668
+
2669
+ K_in = K
2670
+
2671
+ if JAGGED:
2672
+ z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
2673
+ group_rows = tl.load(zero_start_index_M + z_offset_base)
2674
+ current_row = pid % N
2675
+ # If this row is empty, don't process any of it.
2676
+ if current_row >= group_rows:
2677
+ K_in = 0
2678
+
2679
+ # Calculate max.
2680
+ cur_max = 0.0
2681
+ for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
2682
+ a = tl.load(
2683
+ A + a_offset_base + n_offset * stride_ak,
2684
+ mask=n_offset < K_in,
2685
+ other=0.0,
2686
+ )
2687
+ tile_max = tl.max(tl.abs(a))
2688
+ cur_max = tl.maximum(tile_max, cur_max)
2689
+ n_offset += BLOCK_SIZE
2690
+
2691
+ # Clamp max value appropriately.
2692
+ if CLAMP_MAX:
2693
+ ub = tl.load(scale_ub)
2694
+ cur_max = tl.clamp(cur_max, EPS, ub)
2695
+ else:
2696
+ cur_max = tl.maximum(cur_max, EPS)
2697
+ # Scale and quantize.
2698
+ a_scale = MAX_FP8 / cur_max
2699
+
2700
+ (fp8_0, fp8_1, fp8_2, fp8_3) = tl.inline_asm_elementwise(
2701
+ asm="""
2702
+ {
2703
+ // $4 is the input register
2704
+ .reg .b32 input;
2705
+ mov.b32 input, $4;
2706
+ mov.b32 $0, $4;
2707
+ shr.b32 $1, $4, 8;
2708
+ shr.b32 $2, $4, 16;
2709
+ shr.b32 $3, $4, 24;
2710
+ }
2711
+ """,
2712
+ constraints=("=r,=r,=r,=r," "r"),
2713
+ # Pass the fp32 reciprocal scale in as one 32-bit register and split it into 4 bytes.
2714
+ args=[1.0 / a_scale],
2715
+ dtype=(
2716
+ tl.uint8,
2717
+ tl.uint8,
2718
+ tl.uint8,
2719
+ tl.uint8,
2720
+ ),
2721
+ is_pure=True,
2722
+ pack=1,
2723
+ )
2724
+
2725
+ # There are some compiler issues with FP8 pointers
2726
+ packed_scale_ptr = packed_scale.to(tl.pointer_type(tl.uint8))
2727
+ tl.store(packed_scale_ptr + pid * packed_scale_stride, fp8_0)
2728
+ tl.store(packed_scale_ptr + pid * packed_scale_stride + 1, fp8_1)
2729
+ tl.store(packed_scale_ptr + pid * packed_scale_stride + 2, fp8_2)
2730
+ tl.store(packed_scale_ptr + pid * packed_scale_stride + 3, fp8_3)
2731
+
2732
+ n_offset = tl.arange(0, BLOCK_SIZE)
2733
+
2734
+ for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
2735
+ a = tl.load(
2736
+ A + a_offset_base + n_offset * stride_ak,
2737
+ mask=n_offset < K_in,
2738
+ other=0.0,
2739
+ )
2740
+ a_fp8 = a * a_scale
2741
+ # Clamp A to fp8 range to make sure there's no overflow.
2742
+ # This is required for AMD. Nvidia's default saturation
2743
+ # handles it, but it's nice to have anyway.
2744
+ a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
2745
+ tl.store(
2746
+ A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
2747
+ a_fp8,
2748
+ mask=n_offset < K,
2749
+ )
2750
+
2751
+ n_offset += BLOCK_SIZE
2752
+
2753
+
2754
+ def triton_quantize_fp8_packed_row(
2755
+ a: Tensor,
2756
+ scale_ub: Optional[Tensor] = None,
2757
+ zero_start_index_M: Optional[Tensor] = None,
2758
+ return_only_packed: Optional[bool] = False,
2759
+ ) -> tuple[Optional[Tensor], Optional[Tensor], Tensor]:
2760
+ """
2761
+ Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
2762
+
2763
+ This packs the FP32 scale at the end of each row, so the fp8 scaled tensor and the reciprocal scale tensor per row are contiguous in memory.
2764
+
2765
+ Args:
2766
+ a (Tensor): higher precision input tensor of up to 4 dimensions.
2767
+ scale_ub (Tensor): Maximum allowed value for scale.
2768
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
2769
+ return_only_packed (bool): Only return the packed tensor, do not unpack results if True
2770
+ Returns:
2771
+ torch.Tensor: fp8 scaled tensor.
2772
+ torch.Tensor: reciprocal scale tensor per row.
2773
+ torch.Tensor: The packed FP8 scaled tensor, with the scale at the end of each row.
2774
+ """
2775
+ if scale_ub is not None and scale_ub.device != a.device:
2776
+ raise Exception("'scale_ub' must be on the same device as 'a'")
2777
+ if zero_start_index_M is not None and zero_start_index_M.device != a.device:
2778
+ raise Exception("'zero_start_index_M' must be on the same device as 'a'")
2779
+
2780
+ assert a.dim() <= 4, "Triton only supports up to 4-dimensional input tensors."
2781
+ a_shape = a.shape
2782
+ while a.dim() < 4:
2783
+ a = a.unsqueeze(0)
2784
+ if zero_start_index_M is not None:
2785
+ # There should be one value of zero_start_index_M per NxK matrix.
2786
+ zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
2787
+ # Get constant values.
2788
+ pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
2789
+ num_rows = a.numel() // a.shape[-1]
2790
+
2791
+ # Allocate an extra 4-bytes at the end of each row for the scale.
2792
+ a_fp8 = torch.empty(
2793
+ (*a.shape[:-1], a.shape[-1] + 4), device=a.device, dtype=pt_dtype
2794
+ )
2795
+
2796
+ # create a view of the packed scale
2797
+ packed_scale = a_fp8[..., -4:]
2798
+
2799
+ # If input tensor is sufficiently large, we need to use int64 indexing.
2800
+ use_int64 = a.numel() > (2**31 - 1)
2801
+ grid = (num_rows,)
2802
+
2803
+ with torch.cuda.device(a.device.index):
2804
+ _kernel_quantize_fp8_packed_row[grid](
2805
+ a,
2806
+ a_fp8,
2807
+ packed_scale,
2808
+ scale_ub,
2809
+ zero_start_index_M,
2810
+ a.shape[0],
2811
+ a.shape[1],
2812
+ a.shape[2],
2813
+ a.shape[3],
2814
+ a.stride(0),
2815
+ a.stride(1),
2816
+ a.stride(2),
2817
+ a.stride(3),
2818
+ a_fp8.stride(0),
2819
+ a_fp8.stride(1),
2820
+ a_fp8.stride(2),
2821
+ a_fp8.stride(3),
2822
+ packed_scale.stride(2), # this is the stride that matters
2823
+ zero_start_index_M.stride(0) if zero_start_index_M is not None else None,
2824
+ zero_start_index_M.stride(1) if zero_start_index_M is not None else None,
2825
+ TL_FP8_DTYPE=tl_dtype,
2826
+ MAX_FP8=max_fp8,
2827
+ EPS=eps,
2828
+ CLAMP_MAX=scale_ub is not None,
2829
+ JAGGED=zero_start_index_M is not None,
2830
+ USE_INT64=use_int64,
2831
+ )
2832
+ if return_only_packed:
2833
+ return None, None, a_fp8.view((*a_shape[:-1], a_shape[-1] + 4))
2834
+
2835
+ # Extract the original shape data without the extra 4 bytes per row
2836
+ # The data is still contiguous in memory, so we have to unpack it.
2837
+ final_fp8_view = a_fp8[..., :-4].view(a_shape)
2838
+ scale_view = a_fp8[..., -4:].reshape((num_rows * 4)).view(torch.float32)
2839
+
2840
+ # the difference with the packed API is that it also
2841
+ # returns the full packed tensor as a third return value
2842
+ return final_fp8_view, scale_view.view(a_shape[:-1]), a_fp8
2843
+
2844
+
2845
+ @torch.library.custom_op("triton::quantize_fp8_packed_row", mutates_args=())
2846
+ def quantize_fp8_packed_row(
2847
+ a: Tensor,
2848
+ scale_ub: Optional[Tensor] = None,
2849
+ zero_start_index_M: Optional[Tensor] = None,
2850
+ use_triton: bool = True,
2851
+ output_device: Optional[torch.device] = None,
2852
+ ) -> tuple[torch.Tensor, torch.Tensor]:
2853
+ """
2854
+ Quantize a to fp8 with row-wise scalings and optionally move to output device.
2855
+
2856
+ Args:
2857
+ a (Tensor): Input high precision tensor. Required to have no more than 4 dimensions
2858
+ scale_ub (Tensor): Maximum allowed value for scale.
2859
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
2860
+ use_triton (bool): Whether to use triton kernel or pytorch.
2861
+ output_device (torch.device): Device to optionally move the scaled tensors to.
2862
+ Returns:
2863
+ torch.Tensor: fp8 scaled tensor.
2864
+ torch.Tensor: The reciprocal scale tensor per row.
2865
+ """
2866
+
2867
+ if a.device == torch.device("cpu"):
2868
+ logger.info("Triton does not support cpu, falling back to torch ops.")
2869
+ use_triton = False
2870
+ if use_triton:
2871
+ # Discard the raw packed tensor; this op returns only the unpacked views.
2872
+ a_fp8, scale, _ = triton_quantize_fp8_packed_row(
2873
+ a, scale_ub, zero_start_index_M, return_only_packed=False
2874
+ )
2875
+ assert a_fp8 is not None
2876
+ assert scale is not None
2877
+ return a_fp8, scale
2878
+ # else use pytorch implementation.
2879
+ if not output_device:
2880
+ output_device = a.device
2881
+
2882
+ a_shape = a.shape
2883
+ # Get constants.
2884
+ pt_dtype, _, max_fp8, eps = get_fp8_constants()
2885
+ row_max: torch.Tensor = torch.max(torch.abs(a), dim=-1)[0]
2886
+ # Apply clamping.
2887
+ if scale_ub is not None:
2888
+ row_max = torch.clamp(row_max, min=eps, max=scale_ub.item())
2889
+ else:
2890
+ # pyre-ignore[6]: Incompatible parameter type [6]
2891
+ row_max = torch.clamp(row_max, min=eps)
2892
+ a_scale = torch.empty((a.shape[:-1]), dtype=torch.float32, device=output_device)
2893
+ a_scale = max_fp8 / row_max.to(torch.float32) # pyre-ignore
2894
+ a_scale[a_scale == float("inf")] = 1.0 # pyre-ignore
2895
+ a_fp8 = a * a_scale[..., None] # pyre-ignore
2896
+ # Cast and move data to output device (for cpu weight loading).
2897
+ a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype)
2898
+ a_scale = a_scale.to(output_device) # pyre-ignore
2899
+ del a
2900
+ return a_fp8, (1 / a_scale).view(a_shape[:-1]) # pyre-ignore
2901
+
2902
+
2903
+ @torch.library.custom_op("triton::quantize_fp8_packed_row_raw", mutates_args=())
2904
+ def quantize_fp8_packed_row_raw(
2905
+ a: Tensor,
2906
+ scale_ub: Optional[Tensor] = None,
2907
+ zero_start_index_M: Optional[Tensor] = None,
2908
+ use_triton: bool = True,
2909
+ output_device: Optional[torch.device] = None,
2910
+ ) -> torch.Tensor:
2911
+ """
2912
+ Quantize a to fp8 with row-wise scalings and optionally move to output device.
2913
+
2914
+ Identical to quantize_fp8_packed_row, except it only returns the raw packed tensor.
2915
+
2916
+ Args:
2917
+ a (Tensor): Input high precision tensor. Required to have no more than 4 dimensions
2918
+ scale_ub (Tensor): Maximum allowed value for scale.
2919
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
2920
+ use_triton (bool): Whether to use triton kernel or pytorch.
2921
+ output_device (torch.device): Device to optionally move the scaled tensors to.
2922
+ Returns:
2923
+ torch.Tensor: The packed fp8 tensor, with K fp8 values per row followed by
2924
+ 4 bytes that reinterpret as the fp32 reciprocal scale for that row.
2925
+ """
2926
+
2927
+ if a.device == torch.device("cpu"):
2928
+ logger.info("Triton does not support cpu, falling back to torch ops.")
2929
+ use_triton = False
2930
+ if use_triton:
2931
+ # With return_only_packed=True, only the raw packed tensor is produced.
2932
+ _, _, packed_tensor = triton_quantize_fp8_packed_row(
2933
+ a, scale_ub, zero_start_index_M, return_only_packed=True
2934
+ )
2935
+ return packed_tensor
2936
+ else:
2937
+ raise Exception(
2938
+ "No PyTorch implementation provided for triton::quantize_fp8_packed_row_raw"
2939
+ )
2940
+
2941
+
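# A minimal sketch (not from the packaged sources) of how a consumer might split the
# raw packed layout returned above: each row carries K fp8 values followed by 4 bytes
# that reinterpret as one fp32 reciprocal scale, mirroring the unpacking done in
# triton_quantize_fp8_packed_row. Requires a CUDA device.
def _example_unpack_packed_row(x):
    packed = quantize_fp8_packed_row_raw(x)                    # [..., K + 4] fp8 storage
    payload = packed[..., :-4]                                 # [..., K] fp8 values
    scale = packed[..., -4:].reshape(-1).view(torch.float32)   # one fp32 scale per row
    return payload, scale.view(packed.shape[:-1])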
2942
+ @torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
2943
+ def quantize_fp8_row(
2944
+ a: Tensor,
2945
+ scale_ub: Optional[Tensor] = None,
2946
+ zero_start_index_M: Optional[Tensor] = None,
2947
+ use_triton: bool = True,
2948
+ output_device: Optional[torch.device] = None,
2949
+ align_rows_to: Optional[int] = None,
2950
+ ) -> tuple[torch.Tensor, torch.Tensor]:
2951
+ """
2952
+ Quantize a to fp8 with row-wise scalings and optionally move to output device.
2953
+
2954
+ Args:
2955
+ a (Tensor): Input high precision tensor. Required to have no more than 4 dimensions
2956
+ scale_ub (Tensor): Maximum allowed value for scale.
2957
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
2958
+ use_triton (bool): Whether to use triton kernel or pytorch.
2959
+ output_device (torch.device): Device to optionally move the scaled tensors to.
2960
+ align_rows_to (int): Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., a multiple of 16).
2961
+
2962
+ Returns:
2963
+ torch.Tensor: fp8 scaled tensor.
2964
+ torch.Tensor: The reciprocal scale tensor per row.
2965
+ """
2966
+
2967
+ if a.device == torch.device("cpu"):
2968
+ logger.info("Triton does not support cpu, falling back to torch ops.")
2969
+ use_triton = False
2970
+ if use_triton:
2971
+ return triton_quantize_fp8_row(
2972
+ a,
2973
+ scale_ub,
2974
+ zero_start_index_M,
2975
+ align_rows_to=align_rows_to,
2976
+ )
2977
+ # else use pytorch implementation.
2978
+ if not output_device:
2979
+ output_device = a.device
2980
+
2981
+ a_shape = a.shape
2982
+ # Get constants.
2983
+ pt_dtype, _, max_fp8, eps = get_fp8_constants()
2984
+ row_max: torch.Tensor = torch.max(torch.abs(a), dim=-1)[0]
2985
+ # Apply clamping.
2986
+ if scale_ub is not None:
2987
+ row_max = torch.clamp(row_max, min=eps, max=scale_ub.item())
2988
+ else:
2989
+ # pyre-ignore[6]: Incompatible parameter type [6]
2990
+ row_max = torch.clamp(row_max, min=eps)
2991
+ a_scale = torch.empty((a.shape[:-1]), dtype=torch.float32, device=output_device)
2992
+ a_scale = max_fp8 / row_max.to(torch.float32) # pyre-ignore
2993
+ a_scale[a_scale == float("inf")] = 1.0 # pyre-ignore
2994
+ a_fp8 = a * a_scale[..., None] # pyre-ignore
2995
+ # Cast and move data to output device (for cpu weight loading).
2996
+ a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype)
2997
+ a_scale = a_scale.to(output_device) # pyre-ignore
2998
+ del a
2999
+ return a_fp8, (1 / a_scale).view(a_shape[:-1]) # pyre-ignore
3000
+
3001
+
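# A minimal end-to-end usage sketch (not from the packaged sources): quantize
# activations and weights per row, then multiply with the fused rowwise-scaled
# kernel. Shapes are illustrative; requires a CUDA device with fp8 support.
def _example_rowwise_fp8_gemm():
    x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
    xq, x_scale = quantize_fp8_row(x)
    wq, w_scale = quantize_fp8_row(w)
    return matmul_fp8_row(xq, wq, x_scale, w_scale)  # [16, 32] bf16 output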
3002
+ @quantize_fp8_row.register_fake
3003
+ def quantize_fp8_row_meta(
3004
+ a: Tensor,
3005
+ scale_ub: Optional[Tensor] = None,
3006
+ zero_start_index_M: Optional[Tensor] = None,
3007
+ use_triton: bool = True,
3008
+ output_device: Optional[torch.device] = None,
3009
+ align_rows_to: Optional[int] = None,
3010
+ ) -> tuple[torch.Tensor, torch.Tensor]:
3011
+ """Shape function for torch compile."""
3012
+ if output_device is None:
3013
+ output_device = a.device
3014
+ a_shape = a.shape
3015
+ dtype = get_fp8_constants()[0]
3016
+ fake_scale = torch.empty(a_shape[:-1], device=output_device, dtype=torch.float32)
3017
+ if align_rows_to is not None:
3018
+ last_dim = a.shape[-1]
3019
+ padded_last_dim = (
3020
+ (last_dim + align_rows_to - 1) // align_rows_to
3021
+ ) * align_rows_to
3022
+ fake_out = torch.empty(
3023
+ (*a.shape[:-1], padded_last_dim), device=output_device, dtype=dtype
3024
+ )
3025
+ return fake_out, fake_scale
3026
+ else:
3027
+ fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
3028
+ return fake_out, fake_scale
3029
+
3030
+
3031
+ @triton.autotune(
3032
+ configs=[
3033
+ Config({"BLOCK_SIZE": 512}),
3034
+ Config({"BLOCK_SIZE": 1024}),
3035
+ Config({"BLOCK_SIZE": 2048}),
3036
+ Config({"BLOCK_SIZE": 4096}),
3037
+ Config({"BLOCK_SIZE": 8192}),
3038
+ ],
3039
+ key=["N"],
3040
+ )
3041
+ @triton.jit
3042
+ def _kernel_scale_fp8_row(
3043
+ A,
3044
+ x_scale,
3045
+ w_scale,
3046
+ scaled_out,
3047
+ M,
3048
+ N,
3049
+ stride_am,
3050
+ stride_an,
3051
+ stride_om,
3052
+ stride_on,
3053
+ BLOCK_SIZE: tl.constexpr,
3054
+ ) -> None:
3055
+ """
3056
+ Scale each row of A by x_scale and each column of A by w_scale.
3057
+
3058
+ Args:
3059
+ A (Tensor): [m, n] Input tensor to scale.
3060
+ x_scale (Tensor): [m] Row-wise scale tensor.
3061
+ w_scale (Tensor): [n] Col-wise scale tensor.
3062
+ scaled_out (Tensor): [m, n] Output tensor.
3063
+ M (int): Number of rows.
3064
+ N (int): Number of columns.
3065
+ stride_am (int): Stride of m dimension of A.
3066
+ stride_an (int): Stride of n dimension of A.
3067
+ stride_om (int): Stride of m dimension of output.
3068
+ stride_on (int): Stride of n dimension of output.
3069
+ BLOCK_SIZE (int): Block size for data loads.
3070
+ """
3071
+ pid = tl.program_id(0)
3072
+ n_offset = tl.arange(0, BLOCK_SIZE)
3073
+ # Load activation scale for this row.
3074
+ row_scale = tl.load(x_scale + pid)
3075
+
3076
+ # Iterate over chunks of the row and apply scales.
3077
+ for _k in range(0, tl.cdiv(N, BLOCK_SIZE)):
3078
+ a = tl.load(
3079
+ A + pid * stride_am + n_offset * stride_an, mask=n_offset < N, other=0.0
3080
+ )
3081
+ col_scale = tl.load(w_scale + n_offset)
3082
+ scaled_a = a * row_scale * col_scale
3083
+ tl.store(
3084
+ scaled_out + pid * stride_om + n_offset * stride_on,
3085
+ scaled_a,
3086
+ mask=n_offset < N,
3087
+ )
3088
+ n_offset += BLOCK_SIZE
3089
+
3090
+
3091
+ def scale_fp8_row(
3092
+ a: Tensor,
3093
+ x_scale: Tensor,
3094
+ w_scale: Tensor,
3095
+ ) -> torch.Tensor:
3096
+ """
3097
+ Apply only rowwise scaling to a tensor. Useful when combining with kernels
3098
+ that do not support fused rowwise scaling.
3099
+
3100
+ Args:
3101
+ a (Tensor): Input floating point tensor to be scaled.
3102
+ x_scale (Tensor): Row-wise activation scale tensor.
3103
+ w_scale (Tensor): Col-wise weight scale tensor.
3104
+ """
3105
+ if a.device == torch.device("cpu"):
3106
+ # On CPU we'll just use native pytorch to scale.
3107
+ return a * x_scale[:, None] * w_scale[None, :]
3108
+
3109
+ if x_scale.device != a.device:
3110
+ raise Exception("'x_scale' must be on the same device as 'a'")
3111
+ if w_scale.device != a.device:
3112
+ raise Exception("'w_scale' must be on the same device as 'a'")
3113
+
3114
+ # Otherwise, use a fast triton kernel to implement.
3115
+ # We'll parallelize over rows.
3116
+ num_rows = a.shape[0]
3117
+ scaled_out = torch.empty(a.shape, device=a.device, dtype=a.dtype)
3118
+ grid = (num_rows,)
3119
+ with torch.cuda.device(a.device.index):
3120
+ _kernel_scale_fp8_row[grid](
3121
+ a,
3122
+ x_scale,
3123
+ w_scale,
3124
+ scaled_out,
3125
+ a.shape[0],
3126
+ a.shape[1],
3127
+ a.stride(0),
3128
+ a.stride(1),
3129
+ scaled_out.stride(0),
3130
+ scaled_out.stride(1),
3131
+ )
3132
+
3133
+ return scaled_out
3134
+
3135
+
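# A minimal check sketch (not from the packaged sources): the Triton path of
# scale_fp8_row matches the plain PyTorch expression used by the CPU fallback above.
# Requires a CUDA device; tolerances are assert_close defaults.
def _example_check_scale_fp8_row():
    a = torch.randn(4, 8, device="cuda")
    x_scale = torch.rand(4, device="cuda")
    w_scale = torch.rand(8, device="cuda")
    expected = a * x_scale[:, None] * w_scale[None, :]
    torch.testing.assert_close(scale_fp8_row(a, x_scale, w_scale), expected)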
3136
+ @triton.jit
3137
+ def _kernel_quantize_fp8_block(
3138
+ A,
3139
+ A_scale,
3140
+ A_fp8,
3141
+ scale_ub,
3142
+ M,
3143
+ K,
3144
+ stride_am,
3145
+ stride_ak,
3146
+ stride_om,
3147
+ stride_ok,
3148
+ stride_a_scale_m,
3149
+ stride_a_scale_k,
3150
+ TL_FP8_DTYPE: tl.constexpr,
3151
+ MAX_FP8: tl.constexpr,
3152
+ EPS: tl.constexpr,
3153
+ CLAMP_MAX: tl.constexpr,
3154
+ BLOCK_M: tl.constexpr,
3155
+ BLOCK_K: tl.constexpr,
3156
+ K_MAJOR: tl.constexpr,
3157
+ ) -> None:
3158
+ """Quantize and scale each [BLOCK_M, BLOCK_K] block.
3159
+
3160
+ Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(A[i:i+BLOCK_M, j:j+BLOCK_K])))
3161
+
3162
+ Kernel naively iterates through matrix with [BLOCK_M, BLOCK_K] tiles.
3163
+
3164
+ Todo:
3165
+ * Better tiling and ordering schemes.
3166
+
3167
+ Args:
3168
+ A (Tensor): [M, K] higher precision input tensor.
3169
+ A_scale (Tensor): [cdiv(M, BLOCK_M), cdiv(K, BLOCK_K)] reciprocal scale tensor per block.
3170
+ A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a_scale
3171
+ scale_ub (Tensor): [1] Maximum allowed value for scale.
3172
+ M (int): Number of rows.
3173
+ K (int): Number of columns.
3174
+ stride_am (int): Stride of m dimension of A.
3175
+ stride_ak (int): Stride of k dimension of A.
3176
+ stride_om (int): Stride of m dimension of output.
3177
+ stride_ok (int): Stride of k dimension of output.
3178
+ stride_a_scale_m (int): Stride of m dimension of A_scale.
3179
+ stride_a_scale_k (int): Stride of k dimension of A_scale.
3180
+ TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
3181
+ MAX_FP8 (float): Maximum expressible value for FP8.
3182
+ EPS (float): Epsilon value for numerical stability.
3183
+ CLAMP_MAX (bool): Whether to apply scale_ub.
3184
+ BLOCK_M (int): Block size for M dimension of A_scale and kernel.
3185
+ BLOCK_K (int): Block size for K dimension of A_scale and kernel.
3186
+ K_MAJOR (bool): Whether output scales should be K major (True) or MN major (False).
3187
+ """
3188
+ pid = tl.program_id(0)
3189
+ grid_k = tl.cdiv(K, BLOCK_K)
3190
+ block_m = pid // grid_k
3191
+ block_k = pid % grid_k
3192
+ rm = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
3193
+ rk = block_k * BLOCK_K + tl.arange(0, BLOCK_K)
3194
+ a_offset = rm[:, None] * stride_am + rk[None, :] * stride_ak
3195
+ out_offset = rm[:, None] * stride_om + rk[None, :] * stride_ok
3196
+ a_mask = (rm < M)[:, None] & (rk < K)[None, :]
3197
+ a_block = tl.load(A + a_offset, mask=a_mask, other=0.0)
3198
+
3199
+ block_max = tl.max(tl.abs(a_block))
3200
+ # Apply appropriate clamping.
3201
+ if CLAMP_MAX:
3202
+ ub = tl.load(scale_ub)
3203
+ block_max = tl.clamp(block_max, EPS, ub)
3204
+ else:
3205
+ block_max = tl.maximum(block_max, EPS)
3206
+ scale = MAX_FP8 / block_max
3207
+
3208
+ # Write in transposed order if specified.
3209
+ if K_MAJOR:
3210
+ scale_offset = block_m * stride_a_scale_m + block_k * stride_a_scale_k
3211
+ else:
3212
+ scale_offset = block_k * stride_a_scale_m + block_m * stride_a_scale_k
3213
+ tl.store(A_scale + scale_offset, 1.0 / scale)
3214
+ a_fp8 = a_block * scale
3215
+ # Clamp A to fp8 range to make sure there's no overflow.
3216
+ # This is required for AMD. Nvidia's default saturation
3217
+ # handles it, but it's nice to have anyway.
3218
+ a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
3219
+ a_fp8 = a_fp8.to(TL_FP8_DTYPE)
3220
+ tl.store(A_fp8 + out_offset, a_fp8, mask=a_mask)
3221
+
3222
+
3223
+ def triton_quantize_fp8_block(
3224
+ x: torch.Tensor,
3225
+ block_m: int = 256,
3226
+ block_k: int = 256,
3227
+ scale_ub: Optional[torch.Tensor] = None,
3228
+ k_major: bool = True,
3229
+ ) -> tuple[torch.Tensor, torch.Tensor]:
3230
+ """
3231
+ Quantize a tensor to fp8 with block-wise scalings.
3232
+
3233
+ Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k])))
3234
+
3235
+ Args:
3236
+ x (torch.Tensor): [M, K] higher precision input tensor.
3237
+ block_m (int): Block size for M dimension of scale.
3238
+ block_k (int): Block size for K dimension of scale.
3239
+ scale_ub: Maximum allowed value for scale.
3240
+ k_major (bool): Whether output scales should be K major (True) or MN major (False).
3241
+
3242
+ Returns:
3243
+ torch.Tensor : [M, K] fp8 scaled tensor.
3244
+ torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
3245
+ if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_m)].
3246
+ """
3247
+ assert x.device != torch.device(
3248
+ "cpu"
3249
+ ), "Blockwise quantization not support on cpu, please use row-wise quantization instead."
3250
+
3251
+ if scale_ub is not None and scale_ub.device != x.device:
3252
+ raise Exception("'scale_ub' must be on the same device as 'a'")
3253
+
3254
+ x_shape = x.shape
3255
+ x = x.view(-1, x.size(-1))
3256
+ # Get constant values.
3257
+ pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
3258
+ M, K = x.shape
3259
+ grid_m = triton.cdiv(M, block_m)
3260
+ grid_k = triton.cdiv(K, block_k)
3261
+ if k_major:
3262
+ x_scale = torch.empty((grid_m, grid_k), device=x.device, dtype=torch.float32)
3263
+ else:
3264
+ x_scale = torch.empty((grid_k, grid_m), device=x.device, dtype=torch.float32)
3265
+ x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
3266
+
3267
+ _kernel_quantize_fp8_block[(grid_m * grid_k,)](
3268
+ x,
3269
+ x_scale,
3270
+ x_fp8,
3271
+ scale_ub,
3272
+ M,
3273
+ K,
3274
+ x.stride(0),
3275
+ x.stride(1),
3276
+ x_fp8.stride(0),
3277
+ x_fp8.stride(1),
3278
+ x_scale.stride(0),
3279
+ x_scale.stride(1),
3280
+ # pyre-ignore[6]: Incompatible parameter type [6]
3281
+ TL_FP8_DTYPE=tl_dtype,
3282
+ # pyre-ignore[6]: Incompatible parameter type [6]
3283
+ MAX_FP8=max_fp8,
3284
+ # pyre-ignore[6]: Incompatible parameter type [6]
3285
+ EPS=eps,
3286
+ # pyre-ignore[6]: Incompatible parameter type [6]
3287
+ CLAMP_MAX=scale_ub is not None,
3288
+ # pyre-ignore[6]: Incompatible parameter type [6]
3289
+ BLOCK_M=block_m,
3290
+ # pyre-ignore[6]: Incompatible parameter type [6]
3291
+ BLOCK_K=block_k,
3292
+ # pyre-ignore[6]: Incompatible parameter type [6]
3293
+ K_MAJOR=k_major,
3294
+ )
3295
+
3296
+ return x_fp8.view(x_shape), x_scale
3297
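+
+ # Illustrative shape sketch (editor's addition, not part of the fbgemm source).
+ # Assumes a CUDA device; shows the scale layouts produced by
+ # triton_quantize_fp8_block when the input is not a multiple of the block size.
+ def _example_triton_quantize_fp8_block_shapes() -> None:
+     x = torch.randn(300, 500, device="cuda", dtype=torch.bfloat16)
+     xq, x_scale = triton_quantize_fp8_block(x, block_m=128, block_k=256)
+     assert xq.shape == (300, 500)     # same shape as the input, fp8 dtype
+     assert x_scale.shape == (3, 2)    # (cdiv(300, 128), cdiv(500, 256)), K major
+     _, x_scale_t = triton_quantize_fp8_block(x, block_m=128, block_k=256, k_major=False)
+     assert x_scale_t.shape == (2, 3)  # MN-major layout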
+
3298
+
3299
+ def quantize_fp8_block(
3300
+ x: torch.Tensor,
3301
+ block_m: int = 256,
3302
+ block_k: int = 256,
3303
+ scale_ub: Optional[torch.Tensor] = None,
3304
+ use_triton: bool = True,
3305
+ output_device: Optional[torch.device] = None,
3306
+ k_major: bool = True,
3307
+ ) -> tuple[torch.Tensor, torch.Tensor]:
3308
+ """
3309
+ Quantize a tensor to fp8 with block-wise scalings and optionally move to output device.
3310
+
3311
+ Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k])))
3312
+
3313
+ Args:
3314
+ x (Tensor): [M, K] higher precision input tensor.
3315
+ block_m (int): Block size for M dimension of scale.
3316
+ block_k (int): Block size for K dimension of scale.
3317
+ scale_ub: Maximum allowed value for scale.
3318
+ use_triton (bool): Whether to use triton kernel or pytorch.
3319
+ output_device (torch.device): Device to optionally move the scaled tensors to.
3320
+ k_major (bool): Whether output scales should be K major (True) or MN major (False).
3321
+
3322
+ Returns:
3323
+ torch.Tensor: [M, K] fp8 scaled tensor.
3324
+ torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
3325
+ if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_m)].
3326
+ """
3327
+ x_shape = x.shape
3328
+ x = x.view(-1, x.size(-1))
3329
+ if x.device == torch.device("cpu"):
3330
+ logger.info("Triton does not support cpu, falling back to torch ops.")
3331
+ use_triton = False
3332
+ if use_triton:
3333
+ xq, x_scale = triton_quantize_fp8_block(x, block_m, block_k, scale_ub, k_major)
3334
+ return xq.view(x_shape), x_scale
3335
+ # else use pytorch implementation.
3336
+ if not output_device:
3337
+ output_device = x.device
3338
+
3339
+ # Get constants.
3340
+ pt_dtype, _, max_fp8, eps = get_fp8_constants()
3341
+
3342
+ M, K = x.shape
3343
+ grid_m = triton.cdiv(M, block_m)
3344
+ grid_k = triton.cdiv(K, block_k)
3345
+
3346
+ # Pad x to multiple of block size.
3347
+ padded_m = grid_m * block_m
3348
+ padded_k = grid_k * block_k
3349
+ x_padded = torch.zeros(padded_m, padded_k, dtype=x.dtype, device=x.device)
3350
+ x_padded[:M, :K] = x
3351
+
3352
+ # Blockwise max.
3353
+ block_max = (
3354
+ x_padded.abs().reshape(grid_m, block_m, grid_k, block_k).amax(dim=(1, 3))
3355
+ )
3356
+
3357
+ # Apply clamping.
3358
+ if scale_ub is not None:
3359
+ block_max = torch.clamp(block_max, min=eps, max=scale_ub.item())
3360
+ else:
3361
+ block_max = torch.clamp(block_max, min=eps)
3362
+ x_scale = torch.empty((grid_m, grid_k), dtype=torch.float32, device=output_device)
3363
+ x_scale = max_fp8 / block_max.to(torch.float32) # pyre-ignore
3364
+ # pyre-ignore[16]: Undefined attribute [16]
3365
+ x_scale[x_scale == float("inf")] = 1.0
3366
+ x_fp8 = (
3367
+ x_padded
3368
+ # pyre-ignore[16]: Undefined attribute [16]
3369
+ * x_scale.repeat_interleave(block_m, dim=0).repeat_interleave(block_k, dim=1)
3370
+ )[:M, :K]
3371
+
3372
+ # Cast and move data to output device (for cpu weight loading).
3373
+ x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
3374
+ x_scale = x_scale.to(output_device) # pyre-ignore
3375
+ del x, x_padded
3376
+ if not k_major:
3377
+ x_scale = x_scale.t().contiguous()
3378
+ return x_fp8.view(x_shape), 1 / x_scale # pyre-ignore
3379
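+
+ # Illustrative round-trip sketch (editor's addition, not part of the fbgemm source).
+ # Because the second return value is a *reciprocal* scale, the input can be
+ # approximately reconstructed by expanding each block scale over its block.
+ # Uses the torch fallback path, so it runs on CPU.
+ def _example_quantize_fp8_block_roundtrip() -> None:
+     x = torch.randn(64, 128, dtype=torch.float32)
+     xq, recip_scale = quantize_fp8_block(x, block_m=32, block_k=64, use_triton=False)
+     expanded = recip_scale.repeat_interleave(32, dim=0).repeat_interleave(64, dim=1)
+     x_hat = xq.to(torch.float32) * expanded
+     torch.testing.assert_close(x_hat, x, atol=0.1, rtol=0.1)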
+
3380
+
3381
+ @triton.autotune(
3382
+ configs=[
3383
+ Config({"GROUP_LOAD": 2}),
3384
+ Config({"GROUP_LOAD": 4}),
3385
+ Config({"GROUP_LOAD": 8}),
3386
+ Config({"GROUP_LOAD": 16}),
3387
+ Config({"GROUP_LOAD": 32}),
3388
+ ],
3389
+ key=["K"],
3390
+ )
3391
+ @triton.jit
3392
+ def _kernel_quantize_fp8_group(
3393
+ A,
3394
+ A_scale,
3395
+ A_fp8,
3396
+ scale_ub,
3397
+ m_sizes,
3398
+ M,
3399
+ K,
3400
+ stride_am,
3401
+ stride_ak,
3402
+ stride_om,
3403
+ stride_ok,
3404
+ stride_a_scale_m,
3405
+ stride_a_scale_k,
3406
+ TL_FP8_DTYPE: tl.constexpr,
3407
+ MAX_FP8: tl.constexpr,
3408
+ EPS: tl.constexpr,
3409
+ CLAMP_MAX: tl.constexpr,
3410
+ USE_INT64: tl.constexpr,
3411
+ GROUP_SIZE: tl.constexpr,
3412
+ USE_M_MAJOR: tl.constexpr,
3413
+ G: tl.constexpr,
3414
+ GROUP_LOAD: tl.constexpr,
3415
+ ):
3416
+ """Quantize and scale each GROUP_SIZE chunk of each row.
3417
+
3418
+ Scale per group i is computed as 1 / (MAX_FP8 / max(abs(A[i:i+GROUP_SIZE])))
3419
+
3420
+ Each kernel thread is responsible for one row and loads and processes a tunable
3421
+ number of groups at once.
3422
+
3423
+ Args:
3424
+ A (Tensor): [M, K] higher precision input tensor.
3425
+ A_scale (Tensor): [M, cdiv(K, GROUP_SIZE)] reciprocal scale tensor per group.
3426
+ A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a
3427
+ scale_ub (Tensor): [1] Maximum allowed value for scale.
3428
+ m_sizes (Optional[Tensor]): [G] Number of rows in each group.
3429
+ M (int): Number of rows.
3430
+ K (int): Number of columns.
3431
+ stride_am (int): Stride of m dimension of A.
3432
+ stride_ak (int): Stride of k dimension of A.
3433
+ stride_om (int): Stride of m dimension of output.
3434
+ stride_ok (int): Stride of k dimension of output.
3435
+ stride_a_scale_m (int): Stride of m dimension of A_scale.
3436
+ stride_a_scale_k (int): Stride of k dimension of A_scale.
3437
+ TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
3438
+ MAX_FP8 (float): Maximum expressible value for FP8.
3439
+ EPS (float): Epsilon value for numerical stability.
3440
+ CLAMP_MAX (bool): Whether to apply scale_ub.
3441
+ USE_INT64 (bool): Whether to index using int64, which may be needed for large tensors.
3442
+ GROUP_SIZE (int): Group size for K dimension of A_scale and kernel.
3443
+ USE_M_MAJOR (bool): Whether to use grouped M-major layout for A_scale.
3444
+ G (int): Number of groups in A_scale, only relevant when m_sizes is provided.
3445
+ GROUP_LOAD (int): Number of groups to load and process simultaneously.
3446
+ """
3447
+ pid = tl.program_id(0)
3448
+ if USE_INT64:
3449
+ pid = pid.to(tl.int64)
3450
+ # We load group_size * group_load chunks at a time.
3451
+ row_offset = pid * stride_am
3452
+ out_offset = pid * stride_om
3453
+ scale_row_offset = pid * stride_a_scale_m
3454
+ k_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE)
3455
+ scale_k_offset = tl.arange(0, GROUP_LOAD)
3456
+ NUM_GROUPS: tl.constexpr = K // GROUP_SIZE
3457
+
3458
+ # When dealing with an M-major grouped gemm, we need to figure out
3459
+ # which group this thread corresponds to and figure out the corresponding
3460
+ # scale offset.
3461
+ group_offset = 0
3462
+ group_cumsum = 0
3463
+ group_M = 0
3464
+ stop = False
3465
+ if USE_M_MAJOR and G > 0:
3466
+ # Iterate over groups to both compute the cumulative sum and find which group we are in.
3467
+ for i in range(G):
3468
+ if not stop:
3469
+ group_M = tl.cast(tl.load(m_sizes + i), pid.dtype)
3470
+ if (group_cumsum + group_M) <= pid:
3471
+ group_cumsum += group_M
3472
+ else:
3473
+ # Indicate we are finished computing cumsum.
3474
+ stop = True
3475
+
3476
+ group_offset = group_cumsum * NUM_GROUPS
3477
+
3478
+ for k in range(0, tl.cdiv(K, (GROUP_LOAD * GROUP_SIZE))):
3479
+ # Load groups of the input.
3480
+ chunk_offset = k_offset + k * GROUP_LOAD * GROUP_SIZE
3481
+ a = tl.load(
3482
+ A + row_offset + chunk_offset * stride_ak, mask=chunk_offset < K, other=0.0
3483
+ )
3484
+ # View loaded chunk as a set of groups.
3485
+ a_grouped = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
3486
+ # Reduce over groups.
3487
+ group_max = tl.max(tl.abs(a_grouped), axis=1)
3488
+ # Apply clamping if specified.
3489
+ if CLAMP_MAX:
3490
+ ub = tl.load(scale_ub)
3491
+ group_max = tl.clamp(group_max, EPS, ub)
3492
+ else:
3493
+ group_max = tl.maximum(group_max, EPS)
3494
+ # Scale and quantize.
3495
+ a_scale = MAX_FP8 / group_max
3496
+ scale_chunk_offset = scale_k_offset + k * GROUP_LOAD
3497
+
3498
+ if USE_M_MAJOR and G > 0:
3499
+ tl.store(
3500
+ A_scale
3501
+ + group_offset
3502
+ + (pid - group_cumsum) * stride_a_scale_k
3503
+ + (scale_chunk_offset * group_M),
3504
+ 1.0 / a_scale,
3505
+ mask=scale_chunk_offset < NUM_GROUPS,
3506
+ )
3507
+ else:
3508
+ if USE_M_MAJOR:
3509
+ tl.store(
3510
+ A_scale
3511
+ + pid * stride_a_scale_k
3512
+ + scale_chunk_offset * stride_a_scale_m,
3513
+ 1.0 / a_scale,
3514
+ mask=scale_chunk_offset < NUM_GROUPS,
3515
+ )
3516
+ else:
3517
+ tl.store(
3518
+ A_scale + scale_row_offset + scale_chunk_offset * stride_a_scale_k,
3519
+ 1.0 / a_scale,
3520
+ mask=scale_chunk_offset < NUM_GROUPS,
3521
+ )
3522
+ # Apply scale to input.
3523
+ a_fp8 = a_grouped * a_scale[:, None]
3524
+ # Clamp to FP8 range to avoid overflow
3525
+ a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
3526
+ # Write to output.
3527
+ tl.store(
3528
+ A_fp8 + out_offset + chunk_offset * stride_ok,
3529
+ tl.ravel(a_fp8),
3530
+ mask=chunk_offset < K,
3531
+ )
3532
+
3533
+
3534
+ def triton_quantize_fp8_group(
3535
+ x: torch.Tensor,
3536
+ group_size: int = 128,
3537
+ scale_ub: Optional[torch.Tensor] = None,
3538
+ m_sizes: Optional[torch.Tensor] = None,
3539
+ k_major: bool = True,
3540
+ ) -> tuple[torch.Tensor, torch.Tensor]:
3541
+ """
3542
+ Quantize a tensor to fp8 with group-wise scalings.
3543
+
3544
+ Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size])))
3545
+
3546
+ Args:
3547
+ x (torch.Tensor): [M, K] higher precision input tensor.
3548
+ group_size (int): Group size for K dimension of scale.
3549
+ scale_ub: Maximum allowed value for scale.
3550
+ m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
3551
+ k_major (bool): Whether output scales should be K major (True) or MN major (False).
3552
+
3553
+ Returns:
3554
+ torch.Tensor: [M, K] fp8 scaled tensor.
3555
+ torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group if k_major is True, otherwise [cdiv(K, group_size), M].
3556
+ """
3557
+ assert x.device != torch.device(
3558
+ "cpu"
3559
+ ), "Triton groupwise quantization not supported on cpu."
3560
+
3561
+ if scale_ub is not None and scale_ub.device != x.device:
3562
+ raise Exception("'scale_ub' must be on the same device as 'a'")
3563
+ if m_sizes is not None and m_sizes.device != x.device:
3564
+ raise Exception("'m_sizes' must be on the same device as 'a'")
3565
+
3566
+ x_shape = x.shape
3567
+ x = x.view(-1, x.size(-1))
3568
+ pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
3569
+ M, K = x.shape
3570
+ k_groups = triton.cdiv(K, group_size)
3571
+ if k_major:
3572
+ x_scale = torch.empty((M, k_groups), device=x.device, dtype=torch.float32)
3573
+ else:
3574
+ x_scale = torch.empty((k_groups, M), device=x.device, dtype=torch.float32)
3575
+ x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
3576
+ _kernel_quantize_fp8_group[(M,)](
3577
+ x,
3578
+ x_scale,
3579
+ x_fp8,
3580
+ scale_ub,
3581
+ m_sizes,
3582
+ M,
3583
+ K,
3584
+ x.stride(0),
3585
+ x.stride(1),
3586
+ x_fp8.stride(0),
3587
+ x_fp8.stride(1),
3588
+ x_scale.stride(0),
3589
+ x_scale.stride(1),
3590
+ TL_FP8_DTYPE=tl_dtype,
3591
+ MAX_FP8=max_fp8,
3592
+ EPS=eps,
3593
+ CLAMP_MAX=scale_ub is not None,
3594
+ USE_INT64=x.numel() > (2**32 - 1),
3595
+ GROUP_SIZE=group_size,
3596
+ USE_M_MAJOR=m_sizes is not None or k_major is False,
3597
+ G=m_sizes.numel() if m_sizes is not None else 0,
3598
+ )
3599
+ return x_fp8.view(x_shape), x_scale
3600
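+
+ # Illustrative shape sketch (editor's addition, not part of the fbgemm source).
+ # Assumes a CUDA device; each row is split into K // group_size groups and each
+ # group gets its own reciprocal scale.
+ def _example_triton_quantize_fp8_group_shapes() -> None:
+     x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
+     xq, x_scale = triton_quantize_fp8_group(x, group_size=128)
+     assert xq.shape == (16, 512)
+     assert x_scale.shape == (16, 4)   # (M, cdiv(K, group_size)), K major
+     _, x_scale_t = triton_quantize_fp8_group(x, group_size=128, k_major=False)
+     assert x_scale_t.shape == (4, 16) # MN-major layout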
+
3601
+
3602
+ def quantize_fp8_group(
3603
+ x: torch.Tensor,
3604
+ group_size: int = 128,
3605
+ scale_ub: Optional[torch.Tensor] = None,
3606
+ m_sizes: Optional[torch.Tensor] = None,
3607
+ k_major: bool = True,
3608
+ use_triton: bool = True,
3609
+ output_device: Optional[torch.device] = None,
3610
+ ) -> tuple[torch.Tensor, torch.Tensor]:
3611
+ """
3612
+ Quantize a tensor to fp8 with group-wise scalings and optionally move to output device.
3613
+
3614
+ Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size])))
3615
+
3616
+ Args:
3617
+ x (Tensor): [M, K] higher precision input tensor.
3618
+ group_size (int): Group size for K dimension of scale.
3619
+ scale_ub: Maximum allowed value for scale.
3620
+ m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
3621
+ k_major (bool): Whether output scales should be K major (True) or MN major (False).
3622
+ This is needed because some kernels like cutlass require a special layout for scales.
3623
+ use_triton (bool): Whether to use triton kernel or pytorch.
3624
+ output_device (torch.device): Device to optionally move the scaled tensors to.
3625
+
3626
+ Returns:
3627
+ torch.Tensor: [M, K] fp8 scaled tensor.
3628
+ torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group if k_major is True, otherwise [cdiv(K, group_size), M].
3629
+ """
3630
+ x_shape = x.shape
3631
+ x = x.view(-1, x.size(-1))
3632
+ if x.device == torch.device("cpu"):
3633
+ logger.info("Triton does not support cpu, falling back to torch ops.")
3634
+ use_triton = False
3635
+ if use_triton:
3636
+ xq, x_scale = triton_quantize_fp8_group(
3637
+ x, group_size, scale_ub, m_sizes, k_major
3638
+ )
3639
+ return xq.view(x_shape), x_scale
3640
+ # else use pytorch implementation.
3641
+ if not output_device:
3642
+ output_device = x.device
3643
+
3644
+ # Get constants.
3645
+ pt_dtype, _, max_fp8, eps = get_fp8_constants()
3646
+
3647
+ M, K = x.shape
3648
+ assert (
3649
+ K % group_size == 0
3650
+ ), "K must be divisible by group_size for cpu implementation."
3651
+ assert m_sizes is None, "m_sizes is not supported for cpu implementation."
3652
+ k_groups = triton.cdiv(K, group_size)
3653
+ # View input as collection of groups for reduction.
3654
+ x_grouped = x.view(M, k_groups, group_size).to(torch.float32)
3655
+ # Reduce over groups.
3656
+ group_max = x_grouped.abs().amax(dim=2)
3657
+ # Apply clamping.
3658
+ group_max = (
3659
+ torch.clamp(group_max, min=eps, max=scale_ub.item())
3660
+ if scale_ub is not None
3661
+ else torch.clamp(group_max, min=eps)
3662
+ )
3663
+ x_scale = torch.empty((M, k_groups), dtype=torch.float32, device=output_device)
3664
+ x_scale = max_fp8 / group_max # pyre-ignore
3665
+ # pyre-ignore[16]: Undefined attribute [16]
3666
+ x_scale[x_scale == float("inf")] = 1.0
3667
+ # pyre-ignore[16]: Undefined attribute [16]
3668
+ x_fp8 = x.view(-1, k_groups, group_size) * x_scale.unsqueeze(2)
3669
+ # Cast and move data to output device (for cpu weight loading).
3670
+ x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
3671
+ x_scale = x_scale.to(output_device) # pyre-ignore
3672
+ if not k_major:
3673
+ x_scale = x_scale.t().contiguous()
3674
+ return x_fp8.view(x_shape), 1 / x_scale # pyre-ignore
3675
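+
+ # Illustrative round-trip sketch (editor's addition, not part of the fbgemm source).
+ # Reconstructs the input from the group-quantized output via the returned
+ # reciprocal scales, using the torch fallback path on CPU.
+ def _example_quantize_fp8_group_roundtrip() -> None:
+     group_size = 128
+     x = torch.randn(8, 512, dtype=torch.float32)
+     xq, recip_scale = quantize_fp8_group(x, group_size=group_size, use_triton=False)
+     x_hat = (
+         xq.to(torch.float32).view(8, -1, group_size) * recip_scale.unsqueeze(-1)
+     ).view(8, 512)
+     torch.testing.assert_close(x_hat, x, atol=0.1, rtol=0.1)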
+
3676
+
3677
+ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
3678
+ return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
3679
+
3680
+
3681
+ # Force a failure instead of a warning when all configs are pruned.
3682
+ # TODO: Determine a better approach for model level testing. We need
3683
+ # to standardize our approach around prune_configs in general.
3684
+ FORCE_FAILURE_ON_EMPTY_CONFIGS = False
3685
+
3686
+
3687
+ def is_invalid_config(config, N, M, K, mfma, use_bias):
3688
+ """
3689
+ Contains all of the configuration checks for prune_configs
3690
+ that would produce an invalid result if selected as the config.
3691
+
3692
+ This is done to ensure that if no config is "optimal" for a given
3693
+ shape we don't accidentally select one that yields incorrect results.
3694
+ """
3695
+ BLOCK_SIZE_M = config.kwargs.get("BLOCK_M")
3696
+ BLOCK_SIZE_N = config.kwargs.get("BLOCK_N")
3697
+ BLOCK_SIZE_K = config.kwargs.get("BLOCK_K")
3698
+ SPLIT_K = config.kwargs.get("SPLIT_K")
3699
+ matrix_instr_nonkdim = config.kwargs.get("matrix_instr_nonkdim")
3700
+ if matrix_instr_nonkdim > mfma:
3701
+ return True
3702
+ if mfma == 4 and BLOCK_SIZE_K < 64:
3703
+ return True
3704
+ # some layouts could not work properly when the
3705
+ # number of elements per thread is less than 1
3706
+ if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
3707
+ return True
3708
+ if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim:
3709
+ return True
3710
+ if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim:
3711
+ return True
3712
+ if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim:
3713
+ return True
3714
+ # split_k cannot be used if there is a bias
3715
+ if use_bias and SPLIT_K != 1:
3716
+ return True
3717
+ return False
3718
+
3719
+
3720
+ # Configs adapted from https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tools/tune_gemm/tune_gemm.py
3721
+ def prune_configs(configs, named_args, **kwargs):
3722
+
3723
+ pruned_configs = []
3724
+ M = named_args["M"]
3725
+ N = named_args["N"]
3726
+ K = named_args["K"]
3727
+ elemBytes_a = named_args["A"].element_size()
3728
+ elemBytes_b = named_args["B"].element_size()
3729
+ use_bias = kwargs["USE_BIAS"]
3730
+
3731
+ if M < 32 or N < 32:
3732
+ mfma = 16
3733
+ else:
3734
+ mfma = 32
3735
+
3736
+ for config in configs:
3737
+ BLOCK_SIZE_M = config.kwargs.get("BLOCK_M")
3738
+ BLOCK_SIZE_N = config.kwargs.get("BLOCK_N")
3739
+ BLOCK_SIZE_K = config.kwargs.get("BLOCK_K")
3740
+ SPLIT_K = config.kwargs.get("SPLIT_K")
3741
+ GROUP_M = config.kwargs.get("GROUP_M")
3742
+ if is_invalid_config(config, N, M, K, mfma, use_bias):
3743
+ continue
3744
+ # Skip BLOCK_SIZE that is too large compared to M/N
3745
+ # unless BLOCK_SIZE is already small enough
3746
+ if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16:
3747
+ continue
3748
+ if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16:
3749
+ continue
3750
+ # skip large split_k when not necessary
3751
+ if SPLIT_K != 1 and not need_split_k(M, N, K):
3752
+ continue
3753
+ # skip large GROUP_M
3754
+ if GROUP_M * BLOCK_SIZE_M >= M and GROUP_M != 1:
3755
+ continue
3756
+ # out of shared memory resource
3757
+ # TODO (zhanglx): This does not consider the LDS usage in the epilogue
3758
+ LDS = (
3759
+ BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
3760
+ + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
3761
+ )
3762
+ if LDS > 65536:
3763
+ continue
3764
+ pruned_configs.append(config)
3765
+
3766
+ print(f"{len(configs)=} {len(pruned_configs)=} for {M=} {N=} {K=}")
3767
+ if len(pruned_configs) == 0:
3768
+ if not FORCE_FAILURE_ON_EMPTY_CONFIGS:
3769
+ # Even though all configs are sub-optimal, still filter out those that could produce incorrect results.
3770
+ candidate_configs = [
3771
+ c for c in configs if not is_invalid_config(c, N, M, K, mfma, use_bias)
3772
+ ]
3773
+ print(f"No configs left after pruning! {M=} {N=} {K=}")
3774
+ pruned_configs = candidate_configs[:10]
3775
+ if len(pruned_configs) == 0:
3776
+ raise RuntimeError(
3777
+ "No valid configs left after pruning! Consider autotuning further with TritonBench"
3778
+ )
3779
+ return pruned_configs
3780
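+
+ # Illustrative sketch (editor's addition, not part of the fbgemm source): how
+ # is_invalid_config flags a config that cannot produce correct results, here
+ # SPLIT_K > 1 combined with a bias.
+ def _example_is_invalid_config() -> None:
+     cfg = triton.Config(
+         {
+             "BLOCK_M": 64,
+             "BLOCK_N": 64,
+             "BLOCK_K": 128,
+             "GROUP_M": 4,
+             "SPLIT_K": 2,
+             "waves_per_eu": 0,
+             "matrix_instr_nonkdim": 16,
+             "kpack": 2,
+         },
+         num_warps=4,
+         num_stages=2,
+     )
+     assert is_invalid_config(cfg, N=4096, M=4096, K=4096, mfma=32, use_bias=True)
+     assert not is_invalid_config(cfg, N=4096, M=4096, K=4096, mfma=32, use_bias=False)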
+
3781
+
3782
+ def get_full_non_persistent_tuning_space():
3783
+ configs = []
3784
+
3785
+ block_mn_range = [16, 32, 64, 128, 256]
3786
+ block_k_range = [16, 32, 64, 128, 256]
3787
+ split_k_range = [1]
3788
+ num_warps_range = [1, 2, 4, 8]
3789
+ group_m_range = [1, 2, 4, 8, 16, 32]
3790
+ num_stage_range = [2]
3791
+ waves_per_eu_range = [0]
3792
+ matrix_instr_nonkdim_range = [16, 32]
3793
+ kpack_range = [1, 2]
3794
+
3795
+ for block_m in block_mn_range:
3796
+ for block_n in block_mn_range:
3797
+ for block_k in block_k_range:
3798
+ for num_warps in num_warps_range:
3799
+ for group_m in group_m_range:
3800
+ for split_k in split_k_range:
3801
+ for num_stages in num_stage_range:
3802
+ for waves_per_eu in waves_per_eu_range:
3803
+ for (
3804
+ matrix_instr_nonkdim
3805
+ ) in matrix_instr_nonkdim_range:
3806
+ for kpack in kpack_range:
3807
+ configs.append(
3808
+ triton.Config(
3809
+ {
3810
+ "BLOCK_M": block_m,
3811
+ "BLOCK_N": block_n,
3812
+ "BLOCK_K": block_k,
3813
+ "GROUP_M": group_m,
3814
+ "SPLIT_K": split_k,
3815
+ "waves_per_eu": waves_per_eu,
3816
+ "matrix_instr_nonkdim": matrix_instr_nonkdim,
3817
+ "kpack": kpack,
3818
+ },
3819
+ num_warps=num_warps,
3820
+ num_stages=num_stages,
3821
+ )
3822
+ )
3823
+ return configs
3824
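+
+ # Note: the full sweep above enumerates 5 (BLOCK_M) * 5 (BLOCK_N) * 5 (BLOCK_K)
+ # * 4 (num_warps) * 6 (GROUP_M) * 1 (SPLIT_K) * 1 (num_stages) * 1 (waves_per_eu)
+ # * 2 (matrix_instr_nonkdim) * 2 (kpack) = 12,000 candidate configs, which is why
+ # it is only used when FULL_NON_PERSISTENT_AUTOTUNING is enabled below.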
+
3825
+
3826
+ MATMUL_CONFIGS_NON_PERSISTENT: list[Config] = get_full_non_persistent_tuning_space()
3827
+ # (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, SPLIT_K, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages)
3828
+ _MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K = [
3829
+ (16, 16, 256, 1, 1, 8, 16, 2, 2, 2),
3830
+ (16, 16, 256, 1, 1, 0, 16, 2, 2, 2),
3831
+ (32, 64, 512, 1, 1, 2, 16, 2, 8, 2),
3832
+ (64, 64, 256, 1, 1, 2, 16, 2, 4, 2),
3833
+ (256, 256, 128, 32, 1, 2, 16, 1, 8, 2),
3834
+ (256, 256, 128, 2, 1, 0, 32, 2, 8, 2),
3835
+ (256, 256, 128, 1, 1, 0, 32, 2, 8, 2),
3836
+ (256, 256, 128, 2, 1, 0, 16, 1, 8, 2),
3837
+ (256, 256, 64, 2, 1, 2, 16, 1, 8, 2),
3838
+ (128, 256, 64, 2, 1, 2, 16, 1, 4, 2),
3839
+ (256, 128, 128, 4, 1, 0, 16, 1, 8, 2),
3840
+ (128, 128, 128, 1, 1, 2, 16, 2, 4, 2),
3841
+ (128, 128, 256, 1, 1, 2, 16, 2, 8, 2),
3842
+ (128, 128, 64, 4, 1, 2, 16, 2, 4, 2),
3843
+ (128, 128, 64, 1, 1, 2, 16, 2, 4, 2),
3844
+ (128, 64, 64, 4, 1, 0, 16, 2, 4, 2),
3845
+ (128, 64, 64, 1, 1, 0, 16, 2, 4, 2),
3846
+ (256, 128, 128, 1, 1, 2, 16, 1, 8, 2),
3847
+ (128, 256, 128, 2, 1, 2, 16, 2, 4, 1),
3848
+ (256, 128, 64, 2, 1, 2, 16, 1, 4, 2),
3849
+ (128, 128, 256, 2, 1, 0, 16, 2, 8, 2),
3850
+ (128, 64, 128, 2, 1, 2, 16, 2, 4, 2),
3851
+ (128, 128, 64, 2, 1, 0, 16, 1, 4, 2),
3852
+ (128, 128, 128, 1, 1, 2, 16, 1, 4, 2),
3853
+ ]
3854
+
3855
+
3856
+ def _should_skip_config(block_k, matrix_instr_nonkdim):
3857
+ """Skip config if BLOCK_K=64 and matrix_instr_nonkdim=16 on GFX95+"""
3858
+ try:
3859
+ return (
3860
+ block_k == 64
3861
+ and matrix_instr_nonkdim == 16
3862
+ and torch.version.hip is not None
3863
+ and torch.cuda.get_device_capability() >= (9, 5)
3864
+ )
3865
+ except RuntimeError:
3866
+ # If no HIP GPUs are available, we can't check device capability
3867
+ # so we don't skip any configs
3868
+ return False
3869
+
3870
+
3871
+ MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K = [
3872
+ triton.Config(
3873
+ {
3874
+ "BLOCK_M": block_m,
3875
+ "BLOCK_N": block_n,
3876
+ "BLOCK_K": block_k,
3877
+ "GROUP_M": group_m,
3878
+ "SPLIT_K": split_k,
3879
+ "waves_per_eu": waves_per_eu,
3880
+ "matrix_instr_nonkdim": matrix_instr_nonkdim,
3881
+ "kpack": kpack,
3882
+ },
3883
+ num_warps=num_warps,
3884
+ num_stages=num_stages,
3885
+ )
3886
+ for block_m, block_n, block_k, group_m, split_k, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages in _MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K
3887
+ if not _should_skip_config(block_k, matrix_instr_nonkdim)
3888
+ ]
3889
+
3890
+ # Set this to enable full autotuning for proper benchmarking.
3891
+ # This should only be used when invoking the kernel through
3892
+ # Triton directly (e.g. TritonBench)
3893
+ #
3894
+ # NOTE: This will SIGNIFICANTLY increase autotuning time, often
3895
+ # taking hours. You should combine this with TRITON_PRINT_AUTOTUNING=1
3896
+ # to extract and add the optimal autotuning configs to
3897
+ # MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K.
3898
+
3899
+ FULL_NON_PERSISTENT_AUTOTUNING = False
3900
+ USED_MATMUL_NON_PERSISTENT_CONFIGS = (
3901
+ MATMUL_CONFIGS_NON_PERSISTENT
3902
+ if FULL_NON_PERSISTENT_AUTOTUNING
3903
+ else MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K
3904
+ )
3905
+
3906
+
3907
+ @triton.autotune(
3908
+ configs=USED_MATMUL_NON_PERSISTENT_CONFIGS,
3909
+ key=["M", "N", "K"],
3910
+ prune_configs_by={
3911
+ "early_config_prune": prune_configs,
3912
+ "perf_model": None,
3913
+ "top_k": None,
3914
+ },
3915
+ use_cuda_graph=FULL_NON_PERSISTENT_AUTOTUNING,
3916
+ )
3917
+ @triton.heuristics(
3918
+ {
3919
+ "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
3920
+ }
3921
+ )
3922
+ @triton.jit
3923
+ def _kernel_matmul_fp8_row_non_persistent(
3924
+ A,
3925
+ B,
3926
+ C,
3927
+ M,
3928
+ N,
3929
+ K,
3930
+ m_key,
3931
+ n_key,
3932
+ k_key,
3933
+ A_scale,
3934
+ B_scale,
3935
+ Bias,
3936
+ stride_am,
3937
+ stride_ak,
3938
+ stride_bn,
3939
+ stride_bk,
3940
+ stride_cm,
3941
+ stride_cn,
3942
+ dot_out_dtype: tl.constexpr,
3943
+ allow_tf32: tl.constexpr,
3944
+ fp8_fast_accum: tl.constexpr,
3945
+ BLOCK_M: tl.constexpr,
3946
+ BLOCK_N: tl.constexpr,
3947
+ BLOCK_K: tl.constexpr,
3948
+ GROUP_M: tl.constexpr,
3949
+ SPLIT_K: tl.constexpr,
3950
+ EVEN_K: tl.constexpr,
3951
+ USE_BIAS: tl.constexpr,
3952
+ AB_DTYPE: tl.constexpr,
3953
+ ) -> None:
3954
+ """Matmul kernel of [M, K] @ [N, K] with row-wise scales
3955
+
3956
+ Performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
3957
+
3958
+ Args:
3959
+ A (TensorWrapper): [M, K] input tensor.
3960
+ B (TensorWrapper): [N, K] input tensor.
3961
+ C (TensorWrapper): [M, N] output tensor.
3962
+ M (int): M dimension of input tensor.
3963
+ N (int): N dimension of input tensor.
3964
+ K (int): K dimension of input tensor.
3965
+ m_key (int): Autotuning key for M dimension of input tensor.
3966
+ n_key (int): Autotuning key for N dimension of input tensor.
3967
+ k_key (int): Autotuning key for K dimension of input tensor.
3968
+ A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
3969
+ B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
3970
+ Bias (TensorWrapper): [N] Optional bias tensor.
3971
+ stride_am (int): Stride of M dimension of A.
3972
+ stride_ak (int): Stride of K dimension of A.
3973
+ stride_bn (int): Stride of N dimension of B.
3974
+ stride_bk (int): Stride of K dimension of B.
3975
+ stride_cm (int): Stride of M dimension of C.
3976
+ stride_cn (int): Stride of N dimension of C.
3977
+ dot_out_dtype (torch.dtype): Output type of tensor core.
3978
+ allow_tf32 (bool): Whether to use TF32 for tensor core.
3979
+ fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
3980
+ BLOCK_M (int): Block size for M dimension.
3981
+ BLOCK_N (int): Block size for N dimension.
3982
+ BLOCK_K (int): Block size for K dimension.
3983
+ GROUP_M (int): Number of groups for M dimension swizzle.
3984
+ SPLIT_K (int): Number of SMs to launch per row.
3985
+ EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
3986
+ USE_BIAS (bool): Whether to use bias.
3987
+ AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
3988
+ """
3989
+ tl.assume(M >= 0)
3990
+ tl.assume(N >= 0)
3991
+ tl.assume(K >= 0)
3992
+ tl.assume(stride_am >= 0)
3993
+ tl.assume(stride_ak >= 0)
3994
+ tl.assume(stride_bn >= 0)
3995
+ tl.assume(stride_bk >= 0)
3996
+ tl.assume(stride_cm >= 0)
3997
+ tl.assume(stride_cn >= 0)
3998
+ # Matrix multiplication.
3999
+ pid = tl.program_id(0)
4000
+ pid_z = tl.program_id(1)
4001
+ grid_m = tl.cdiv(M, BLOCK_M)
4002
+ grid_n = tl.cdiv(N, BLOCK_N)
4003
+ # Re-order program ID for better L2 performance (swizzle).
4004
+ width = GROUP_M * grid_n
4005
+ group_id = pid // width
4006
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
4007
+ pid_m = group_id * GROUP_M + ((pid % width) % group_size)
4008
+ pid_n = (pid % width) // (group_size)
4009
+ tl.assume(pid_m >= 0)
4010
+ tl.assume(pid_n >= 0)
4011
+ # Do matrix multiplication.
4012
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
4013
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
4014
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
4015
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
4016
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
4017
+ # Pointers.
4018
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
4019
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
4020
+ acc_dtype = tl.float32 if allow_tf32 else dot_out_dtype
4021
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
4022
+
4023
+ for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
4024
+ if EVEN_K:
4025
+ a = tl.load(A)
4026
+ b = tl.load(B)
4027
+ else:
4028
+ k_remaining = K - k * (BLOCK_K * SPLIT_K)
4029
+ _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
4030
+ a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
4031
+ b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
4032
+ if AB_DTYPE:
4033
+ a = a.to(C.dtype.element_ty)
4034
+ b = b.to(C.dtype.element_ty)
4035
+ if fp8_fast_accum:
4036
+ acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32)
4037
+ else:
4038
+ acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32)
4039
+
4040
+ A += BLOCK_K * SPLIT_K * stride_ak
4041
+ B += BLOCK_K * SPLIT_K * stride_bk
4042
+
4043
+ # rematerialize rm and rn to save registers
4044
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
4045
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
4046
+
4047
+ # Invert scaling.
4048
+ a_scale = tl.load(A_scale + rm, mask=rm < M)
4049
+ b_scale = tl.load(B_scale + rn, mask=rn < N)
4050
+ # Invert vector, then multiply on matrix for speed.
4051
+ # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
4052
+ scale = a_scale[:, None] * b_scale[None, :]
4053
+ acc *= scale
4054
+
4055
+ # Load and add bias if specified.
4056
+ if USE_BIAS:
4057
+ bias = tl.load(Bias + rn, mask=rn < N)
4058
+ acc += bias[None, :]
4059
+
4060
+ acc = acc.to(C.dtype.element_ty)
4061
+ C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
4062
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
4063
+ # Handles write-back with reduction-splitting
4064
+ if SPLIT_K == 1:
4065
+ tl.store(C, acc, mask=mask)
4066
+ else:
4067
+ tl.atomic_add(C, acc, mask=mask)
4068
+
4069
+
4070
+ @triton.autotune(
4071
+ configs=[Config({"BLOCK_M": 16, "BLOCK_K": 512, "NUM_STAGES": 2})],
4072
+ key=["M", "K"],
4073
+ )
4074
+ @triton.jit
4075
+ def _kernel_dequantize_fp8_row(
4076
+ xq_ptr,
4077
+ x_scale_ptr,
4078
+ x_dequant_ptr,
4079
+ M,
4080
+ K,
4081
+ stride_xm,
4082
+ stride_xk,
4083
+ stride_xdqm,
4084
+ stride_xdqk,
4085
+ BLOCK_M: tl.constexpr,
4086
+ BLOCK_K: tl.constexpr,
4087
+ NUM_STAGES: tl.constexpr,
4088
+ USE_INT64: tl.constexpr,
4089
+ ):
4090
+ """
4091
+ Kernel to dequantize FP8 tensor to BF16 tensor.
4092
+ Args:
4093
+ xq_ptr (tl.constexpr): Pointer to FP8 tensor.
4094
+ x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
4095
+ x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
4096
+ M (tl.constexpr): M dimension of input tensor.
4097
+ K (tl.constexpr): K dimension of input tensor (along which scales are applied)
4098
+ BLOCK_M (tl.constexpr): Block size for the M dimension.
+ BLOCK_K (tl.constexpr): Block size for the K dimension.
4099
+ """
4100
+ pid = tl.program_id(axis=0)
4101
+ if USE_INT64:
4102
+ pid = pid.to(tl.int64)
4103
+ offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
4104
+ offs_k = tl.arange(0, BLOCK_K)
4105
+ scales = tl.load(x_scale_ptr + offs_m)
4106
+
4107
+ for _k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
4108
+ mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
4109
+ xq = tl.load(
4110
+ xq_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
4111
+ mask=mask,
4112
+ )
4113
+ x_dq = xq * scales[:, None]
4114
+ tl.store(
4115
+ x_dequant_ptr
4116
+ + offs_m[:, None] * stride_xdqm
4117
+ + offs_k[None, :] * stride_xdqk,
4118
+ x_dq,
4119
+ mask=mask,
4120
+ )
4121
+ offs_k += BLOCK_K
4122
+
4123
+
4124
+ def dequantize_fp8_row(
4125
+ xq: torch.Tensor,
4126
+ x_scale: torch.Tensor,
4127
+ ) -> torch.Tensor:
4128
+ """
4129
+ Rowwise Dequantize FP8 tensor to BF16 tensor along last axis.
4130
+
4131
+ Args:
4132
+ xq (torch.Tensor): FP8 tensor to be dequantized.
4133
+ x_scale (torch.Tensor): FP8 scale tensor.
4134
+
4135
+ Returns:
4136
+ torch.Tensor: Dequantized BF16 tensor.
4137
+ """
4138
+
4139
+ assert (
4140
+ xq.is_contiguous() and x_scale.is_contiguous()
4141
+ ), "Input tensors must be contiguous"
4142
+ x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
4143
+
4144
+ # Reshape to 2-d array keeping last dim only.
4145
+ K = xq.shape[-1]
4146
+ xq = xq.reshape(-1, K)
4147
+ M = xq.shape[0]
4148
+ use_int64 = xq.numel() > 2**31
4149
+
4150
+ def grid(meta: dict[str, int]) -> tuple[int]:
4151
+ return (triton.cdiv(M, meta["BLOCK_M"]),)
4152
+
4153
+ with torch.cuda.device(xq.device.index):
4154
+ _kernel_dequantize_fp8_row[grid](
4155
+ xq,
4156
+ x_scale,
4157
+ x_dequant,
4158
+ M,
4159
+ K,
4160
+ xq.stride(0),
4161
+ xq.stride(1),
4162
+ xq.stride(0), # Use squashed stride.
4163
+ xq.stride(1),
4164
+ USE_INT64=use_int64,
4165
+ )
4166
+ return x_dequant
4167
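+
+ # Illustrative sketch (editor's addition, not part of the fbgemm source). Assumes a
+ # CUDA device and the torch.float8_e4m3fn dtype; the Triton kernel should match a
+ # plain broadcast multiply by the per-row scales.
+ def _example_dequantize_fp8_row() -> None:
+     xq = torch.randn(32, 256, device="cuda").to(torch.float8_e4m3fn)
+     x_scale = torch.rand(32, device="cuda", dtype=torch.float32)
+     out = dequantize_fp8_row(xq, x_scale)
+     ref = (xq.to(torch.float32) * x_scale[:, None]).to(torch.bfloat16)
+     torch.testing.assert_close(out, ref)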
+
4168
+
4169
+ @triton.autotune(
4170
+ configs=[Config({"BLOCK_M": 16, "BLOCK_K": 512, "NUM_STAGES": 2})],
4171
+ key=["M", "K"],
4172
+ )
4173
+ @triton.jit
4174
+ def _kernel_dequantize_fp8_packed_row(
4175
+ xq_ptr,
4176
+ x_scale_ptr,
4177
+ x_dequant_ptr,
4178
+ M,
4179
+ K,
4180
+ stride_xm,
4181
+ stride_xk,
4182
+ stride_xdqm,
4183
+ stride_xdqk,
4184
+ BLOCK_M: tl.constexpr,
4185
+ BLOCK_K: tl.constexpr,
4186
+ NUM_STAGES: tl.constexpr,
4187
+ USE_INT64: tl.constexpr,
4188
+ ):
4189
+ """
4190
+ Kernel to dequantize FP8 tensor to BF16 tensor.
4191
+ Args:
4192
+ xq_ptr (tl.constexpr): Pointer to FP8 tensor.
4193
+ x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
4194
+ x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
4195
+ M (tl.constexpr): M dimension of input tensor.
4196
+ K (tl.constexpr): K dimension of input tensor (along which scales are applied)
4197
+ BLOCK_M (tl.constexpr): Block size for the M dimension.
+ BLOCK_K (tl.constexpr): Block size for the K dimension.
4198
+ """
4199
+ pid = tl.program_id(axis=0)
4200
+ if USE_INT64:
4201
+ pid = pid.to(tl.int64)
4202
+ offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
4203
+ offs_k = tl.arange(0, BLOCK_K)
4204
+ scales = tl.load(x_scale_ptr + offs_m)
4205
+
4206
+ for _k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
4207
+ mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
4208
+
4209
+ xq = tl.load(
4210
+ xq_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
4211
+ mask=mask,
4212
+ other=0.0,
4213
+ )
4214
+ x_dq = xq * scales[:, None]
4215
+
4216
+ tl.store(
4217
+ x_dequant_ptr
4218
+ + offs_m[:, None] * stride_xdqm
4219
+ + offs_k[None, :] * stride_xdqk,
4220
+ x_dq,
4221
+ mask=mask,
4222
+ )
4223
+ offs_k += BLOCK_K
4224
+
4225
+
4226
+ def dequantize_fp8_packed_row(
4227
+ xq: torch.Tensor,
4228
+ ) -> torch.Tensor:
4229
+ """
4230
+ Rowwise Dequantize FP8 tensor to BF16 tensor along last axis.
4231
+
4232
+ Args:
4233
+ xq (torch.Tensor): Packed FP8 tensor to be dequantized. The last 4 bytes of each row hold the FP32 scale for that row.
4234
+
4235
+ Returns:
4236
+ torch.Tensor: Dequantized BF16 tensor.
4237
+ """
4238
+
4239
+ # Create a view of the packed tensors, get the scale and actual xq tensor
4240
+ # This makes it much easier to write the kernel
4241
+ orig_shape = (*xq.shape[:-1], xq.shape[-1] - 4)
4242
+ actual_xq = xq[..., :-4].view(orig_shape)
4243
+
4244
+ assert xq.is_contiguous(), "Input tensors must be contiguous"
4245
+ x_dequant = torch.empty(orig_shape, dtype=torch.bfloat16, device=xq.device)
4246
+
4247
+ # Calculate number of rows when flattened
4248
+ num_rows = actual_xq.numel() // actual_xq.shape[-1]
4249
+
4250
+ # TODO: we take a perf hit from these reshapes, can we do better?
4251
+ # It's hard to skip this reshape; we can't create an int32/float32 view because of alignment issues
4252
+ scale_view = xq[..., -4:].reshape((num_rows * 4)).view(torch.float32)
4253
+ scale_view = scale_view.view(orig_shape[:-1])
4254
+
4255
+ # Reshape to 2-d array keeping last dim only.
4256
+ K = actual_xq.shape[-1]
4257
+ actual_xq = actual_xq.reshape(-1, K)
4258
+ M = actual_xq.shape[0]
4259
+ use_int64 = actual_xq.numel() > 2**31
4260
+
4261
+ def grid(meta: dict[str, int]) -> tuple[int]:
4262
+ return (triton.cdiv(M, meta["BLOCK_M"]),)
4263
+
4264
+ with torch.cuda.device(actual_xq.device.index):
4265
+ _kernel_dequantize_fp8_packed_row[grid](
4266
+ actual_xq,
4267
+ scale_view,
4268
+ x_dequant,
4269
+ M,
4270
+ K,
4271
+ actual_xq.stride(0),
4272
+ actual_xq.stride(1),
4273
+ x_dequant.stride(-2), # Use squashed stride.
4274
+ x_dequant.stride(-1),
4275
+ USE_INT64=use_int64,
4276
+ )
4277
+
4278
+ return x_dequant
4279
+
4280
+
4281
+ @triton.jit
4282
+ def _kernel_dequantize_fp8_block(
4283
+ xq_ptr,
4284
+ x_scale_ptr,
4285
+ x_dequant_ptr,
4286
+ M,
4287
+ K,
4288
+ BLOCK_M: tl.constexpr,
4289
+ BLOCK_K: tl.constexpr,
4290
+ ):
4291
+ """
4292
+ Kernel to dequantize FP8 tensor to BF16 tensor.
4293
+ Args:
4294
+ xq_ptr (tl.constexpr): Pointer to FP8 tensor.
4295
+ x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
4296
+ x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
4297
+ M (tl.constexpr): M dimension of input tensor.
4298
+ K (tl.constexpr): K dimension of input tensor.
4299
+ BLOCK_M (tl.constexpr): Block size for the M dimension.
4300
+ BLOCK_K (tl.constexpr): Block size for the K dimension.
4301
+ """
4302
+ pid_m = tl.program_id(axis=0)
4303
+ pid_k = tl.program_id(axis=1)
4304
+ k = tl.cdiv(K, BLOCK_K)
4305
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
4306
+ offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
4307
+ offs = offs_m[:, None] * K + offs_k[None, :]
4308
+ mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
4309
+ xq = tl.load(xq_ptr + offs, mask=mask).to(tl.bfloat16)
4310
+ x_scale = tl.load(x_scale_ptr + pid_m * k + pid_k)
4311
+ x_dequant = xq * x_scale
4312
+ tl.store(x_dequant_ptr + offs, x_dequant, mask=mask)
4313
+
4314
+
4315
+ def dequantize_fp8_block(
4316
+ xq: torch.Tensor,
4317
+ x_scale: torch.Tensor,
4318
+ block_m: int = 256,
4319
+ block_k: int = 256,
4320
+ ) -> torch.Tensor:
4321
+ """
4322
+ Dequantize FP8 tensor to BF16 tensor.
4323
+
4324
+ Args:
4325
+ xq (torch.Tensor): FP8 tensor to be dequantized.
4326
+ x_scale (torch.Tensor): FP8 scale tensor.
4327
+ block_m (int): Block size for the M dimension.
4328
+ block_k (int): Block size for the K dimension.
4329
+
4330
+ Returns:
4331
+ torch.Tensor: Dequantized BF16 tensor.
4332
+ """
4333
+
4334
+ assert (
4335
+ xq.is_contiguous() and x_scale.is_contiguous()
4336
+ ), "Input tensors must be contiguous"
4337
+ assert xq.dim() == 2 and x_scale.dim() == 2, "Input tensors must have 2 dimensions"
4338
+ M, K = xq.size()
4339
+ x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
4340
+
4341
+ def grid(meta: dict[str, int]) -> tuple[int, int]:
4342
+ return (
4343
+ triton.cdiv(M, meta["BLOCK_M"]),
4344
+ triton.cdiv(K, meta["BLOCK_K"]),
4345
+ )
4346
+
4347
+ with torch.cuda.device(xq.device.index):
4348
+ _kernel_dequantize_fp8_block[grid](
4349
+ xq, x_scale, x_dequant, M, K, BLOCK_M=block_m, BLOCK_K=block_k # pyre-ignore[6]
4350
+ )
4351
+ return x_dequant
4352
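+
+ # Illustrative round-trip sketch (editor's addition, not part of the fbgemm source).
+ # Assumes a CUDA device; quantizing and then dequantizing with matching block sizes
+ # should approximately reproduce the input.
+ def _example_fp8_block_roundtrip() -> None:
+     x = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
+     xq, x_scale = triton_quantize_fp8_block(x, block_m=256, block_k=256)
+     x_hat = dequantize_fp8_block(xq, x_scale, block_m=256, block_k=256)
+     torch.testing.assert_close(x_hat, x, atol=0.1, rtol=0.1)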
+
4353
+
4354
+ # This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142
4355
+ def to_mxfp8(
4356
+ data_hp: torch.Tensor,
4357
+ block_size: int = 32,
4358
+ ):
4359
+ assert data_hp.dtype in (
4360
+ torch.bfloat16,
4361
+ torch.float,
4362
+ ), f"{data_hp.dtype} is not supported yet"
4363
+ assert (
4364
+ data_hp.shape[-1] % block_size == 0
4365
+ ), f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
4366
+ assert data_hp.is_contiguous(), "unsupported"
4367
+
4368
+ orig_shape = data_hp.shape
4369
+ data_hp = data_hp.reshape(
4370
+ *orig_shape[:-1], orig_shape[-1] // block_size, block_size
4371
+ )
4372
+
4373
+ max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
4374
+
4375
+ data_hp = data_hp.to(torch.float32)
4376
+ max_abs = max_abs.to(torch.float32)
4377
+
4378
+ F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0
4379
+ max_pos = F8E4M3_MAX
4380
+
4381
+ # RCEIL
4382
+ def _to_mx_rceil(
4383
+ data_hp: torch.Tensor,
4384
+ max_abs: torch.Tensor,
4385
+ max_pos: float,
4386
+ ) -> tuple[torch.Tensor, torch.Tensor]:
4387
+ E8M0_EXPONENT_BIAS = 127
4388
+ descale = max_abs / max_pos
4389
+ exponent = torch.where(
4390
+ torch.isnan(descale),
4391
+ 0xFF, # Handle biased exponent for nan
4392
+ # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
4393
+ (
4394
+ torch.clamp(
4395
+ torch.ceil(torch.log2(descale)),
4396
+ min=-E8M0_EXPONENT_BIAS,
4397
+ max=E8M0_EXPONENT_BIAS,
4398
+ )
4399
+ + E8M0_EXPONENT_BIAS
4400
+ ).to(torch.uint8),
4401
+ )
4402
+
4403
+ descale_fp = torch.where(
4404
+ exponent == 0,
4405
+ 1.0,
4406
+ torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)),
4407
+ )
4408
+
4409
+ # scale and saturated cast the data elements to max of target dtype
4410
+ data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
4411
+ return exponent, data_lp
4412
+
4413
+ scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
4414
+
4415
+ # cast to target dtype
4416
+ data_lp = data_lp.to(torch.float8_e4m3fn)
4417
+ # need to reshape at the end to help inductor fuse things
4418
+ data_lp = data_lp.reshape(orig_shape)
4419
+
4420
+ scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
4421
+ scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
4422
+ return scale_e8m0_biased, data_lp
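+
+ # Illustrative sketch (editor's addition, not part of the fbgemm or torchao source).
+ # It undoes to_mxfp8 by expanding each biased e8m0 exponent back into a float scale.
+ # Corner cases (NaN and zero exponents) are ignored, and block_size must match the
+ # value used during quantization.
+ def _dequantize_mxfp8_sketch(
+     scale_e8m0: torch.Tensor,
+     data_lp: torch.Tensor,
+     block_size: int = 32,
+ ) -> torch.Tensor:
+     # e8m0 stores only a biased exponent; 2 ** (stored - 127) recovers the scale.
+     exponent = scale_e8m0.view(torch.uint8).to(torch.float32) - 127.0
+     scale = torch.exp2(exponent)
+     data = data_lp.to(torch.float32).reshape(*data_lp.shape[:-1], -1, block_size)
+     return (data * scale.unsqueeze(-1)).reshape(data_lp.shape)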