mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1902 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import logging
9
+ import os
10
+ from typing import Dict, Optional, Tuple
11
+
12
+ import torch
13
+ import triton # @manual
14
+ import triton.language as tl # @manual
15
+ from mslk.utils.triton.fp8_utils import get_fp8_constants
16
+ from triton import Config # @manual
17
+
18
+ logger: logging.Logger = logging.getLogger(__name__)
19
+
20
+ running_on_github: bool = os.getenv("GITHUB_ENV") is not None
21
+
22
+ try:
23
+ # pyre-ignore[21]
24
+ from triton.fb.compat import disable_bufferops # @manual
25
+ except ModuleNotFoundError:
26
+ # Ensure we can call disable_bufferops if compat is not included (e.g. opensource)
27
+ # TODO(njriasan): Remove when we integrate triton.fb.compat into every Triton
28
+ # version.
29
+ from contextlib import contextmanager
30
+
31
+ @contextmanager
32
+ def disable_bufferops(_unused: bool):
33
+ yield None
34
+
35
+
36
+ @triton.autotune(
37
+ configs=[
38
+ Config({"BLOCK_SIZE": 512}),
39
+ Config({"BLOCK_SIZE": 1024}),
40
+ Config({"BLOCK_SIZE": 2048}),
41
+ Config({"BLOCK_SIZE": 4096}),
42
+ Config({"BLOCK_SIZE": 8192}),
43
+ ],
44
+ key=["K"],
45
+ )
46
+ @triton.jit
47
+ def _kernel_quantize_fp8_row(
48
+ A,
49
+ A_scale,
50
+ A_fp8,
51
+ scale_ub,
52
+ zero_start_index_M,
53
+ B,
54
+ M,
55
+ N,
56
+ K,
57
+ K_fp8, # used when padding
58
+ stride_ab,
59
+ stride_am,
60
+ stride_an,
61
+ stride_ak,
62
+ stride_ob,
63
+ stride_om,
64
+ stride_on,
65
+ stride_ok,
66
+ stride_zb,
67
+ stride_zm,
68
+ TL_FP8_DTYPE: tl.constexpr,
69
+ MAX_FP8: tl.constexpr,
70
+ EPS: tl.constexpr,
71
+ CLAMP_MAX: tl.constexpr,
72
+ JAGGED: tl.constexpr,
73
+ BLOCK_SIZE: tl.constexpr,
74
+ USE_INT64: tl.constexpr,
75
+ ) -> None:
76
+ """Quantize and scale each row.
77
+
78
+ Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))
79
+
80
+ Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
81
+ in a max pass then scale/quantize pass.
82
+
83
+ Todo:
84
+ * Better tiling schemes.
85
+
86
+ Args:
87
+ A (Tensor): higher precision input tensor of 4 dimensions.
88
+ A_scale (Tensor): [B * M * N] reciprocal scale tensor per row.
89
+ A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
90
+ scale_ub (Tensor): [1] Maximum value allowed for scale.
91
+ B (int): Size of dimension 0
92
+ M (int): Size of dimension 1
93
+ N (int): Size of dimension 2
94
+ K (int): Size of dimension 3 (input row size)
95
+ K_fp8 (int): Size of dimension 3 for A_fp8 (output row size, can be >= K)
96
+ stride_ab (int): Stride of b dimension of A.
97
+ stride_am (int): Stride of m dimension of A.
98
+ stride_an (int): Stride of n dimension of A.
99
+ stride_ak (int): Stride of k dimension of A.
100
+ stride_ob (int): Stride of b dimension of output.
101
+ stride_om (int): Stride of m dimension of output.
102
+ stride_on (int): Stride of n dimension of output.
103
+ stride_ok (int): Stride of k dimension of output.
104
+ stride_zb (int): Stride of b dimension of jagged index.
105
+ stride_zm (int): Stride of m dimension of jagged index.
106
+ TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
107
+ MAX_FP8 (float): Maximum expressible value for FP8.
108
+ EPS (float): Epsilon value for numerical stability.
109
+ CLAMP_MAX (bool): Whether to apply scale_ub.
110
+ JAGGED (bool): Whether to use jagged indexing.
111
+ BLOCK_SIZE (int): Block size for reduction.
112
+ USE_INT64 (bool): Whether to use int64 indexing for large inputs.
113
+ """
114
+ pid = tl.program_id(0)
115
+ # Use int64 indexing for large inputs. This is slower, but
116
+ # needed to avoid index overflows.
117
+ if USE_INT64:
118
+ pid = pid.to(tl.int64)
119
+ n_offset = tl.arange(0, BLOCK_SIZE)
120
+ a_offset_base = (
121
+ pid // (M * N) * stride_ab
122
+ + (pid % (M * N)) // N * stride_am
123
+ + (pid % (M * N)) % N * stride_an
124
+ )
125
+ a_fp8_offset_base = (
126
+ pid // (M * N) * stride_ob
127
+ + (pid % (M * N)) // N * stride_om
128
+ + (pid % (M * N)) % N * stride_on
129
+ )
130
+
131
+ K_in = K
132
+
133
+ if JAGGED:
134
+ z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
135
+ group_rows = tl.load(zero_start_index_M + z_offset_base)
136
+ current_row = pid % N
137
+ # If this row is empty, don't process any of it.
138
+ if current_row >= group_rows:
139
+ K_in = 0
140
+
141
+ # Calculate max.
142
+ cur_max = 0.0
143
+ for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
144
+ a = tl.load(
145
+ A + a_offset_base + n_offset * stride_ak,
146
+ mask=n_offset < K_in,
147
+ other=0.0,
148
+ )
149
+ tile_max = tl.max(tl.abs(a))
150
+ cur_max = tl.maximum(tile_max, cur_max)
151
+ n_offset += BLOCK_SIZE
152
+
153
+ # Clamp max value appropriately.
154
+ if CLAMP_MAX:
155
+ ub = tl.load(scale_ub)
156
+ cur_max = tl.clamp(cur_max, EPS, ub)
157
+ else:
158
+ cur_max = tl.maximum(cur_max, EPS)
159
+ # Scale and quantize.
160
+ a_scale = MAX_FP8 / cur_max
161
+ tl.store(A_scale + pid, 1.0 / a_scale)
162
+ n_offset = tl.arange(0, BLOCK_SIZE)
163
+
164
+ # Write quantized values for the first K elements (from A), and pad the rest with zeros up to K_fp8
165
+ for _k in range(0, tl.cdiv(K_fp8, BLOCK_SIZE)):
166
+ # Load from A if in range, else 0 (we're going all the way to K_fp8)
167
+ a = tl.load(
168
+ A + a_offset_base + n_offset * stride_ak,
169
+ mask=n_offset < K_in,
170
+ other=0.0,
171
+ )
172
+ # For elements >= K, a will be 0
173
+ a_fp8 = a * a_scale
174
+ # Clamp A to fp8 range to make sure there's no overflow.
175
+ # This is required for AMD. Nvidia's default saturation
176
+ # handles it, but it's nice to have anyway.
177
+ a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
178
+
179
+ # Store the full new row in its place (for elements >= K, a_fp8 is already 0)
180
+ tl.store(
181
+ A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
182
+ a_fp8,
183
+ mask=n_offset < K_fp8,
184
+ )
185
+ n_offset += BLOCK_SIZE
186
+
187
+
188
+ def triton_quantize_fp8_row(
189
+ a: torch.Tensor,
190
+ scale_ub: Optional[torch.Tensor] = None,
191
+ zero_start_index_M: Optional[torch.Tensor] = None,
192
+ align_rows_to: Optional[int] = None,
193
+ ) -> tuple[torch.Tensor, torch.Tensor]:
194
+ """
195
+ Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
196
+
197
+ Args:
198
+ a (Tensor): higher precision input tensor of up to 4 dimensions.
199
+ scale_ub (Tensor): Maximum allowed value for scale.
200
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
201
+ align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16)
202
+
203
+ Returns:
204
+ torch.Tensor: fp8 scaled tensor.
205
+ torch.Tensor: reciprocal scale tensor per row.
206
+ """
207
+ if scale_ub is not None and scale_ub.device != a.device:
208
+ raise Exception("'scale_ub' must be on the same device as 'a'")
209
+ if zero_start_index_M is not None and zero_start_index_M.device != a.device:
210
+ raise Exception("'zero_start_index_M' must be on the same device as 'a'")
211
+
212
+ assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
213
+ a_shape = a.shape
214
+ while a.dim() < 4:
215
+ a = a.unsqueeze(0)
216
+ if zero_start_index_M is not None:
217
+ # There should be one value of zero_start_index_M per NxK matrix.
218
+ zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
219
+ # Get constant values.
220
+ pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
221
+ num_rows = a.numel() // a.shape[-1]
222
+ a_scale = torch.empty((num_rows), dtype=torch.float32, device=a.device)
223
+ # If align_rows_to is provided, pad the last dimension to be a multiple of it
224
+ if align_rows_to is not None:
225
+ last_dim = a.shape[-1]
226
+ padded_last_dim = (
227
+ (last_dim + align_rows_to - 1) // align_rows_to
228
+ ) * align_rows_to
229
+ a_fp8 = torch.empty(
230
+ (*a.shape[:-1], padded_last_dim), device=a.device, dtype=pt_dtype
231
+ )
232
+ a_shape = torch.Size((*a_shape[:-1], padded_last_dim))
233
+ else:
234
+ a_fp8 = torch.empty(a.shape, device=a.device, dtype=pt_dtype)
235
+
236
+ # If input tensor is sufficiently large, we need to use int64 indexing.
237
+ use_int64 = a.numel() > (2**31 - 1)
238
+ grid = (num_rows,)
239
+ # Pick a conservative value for inference shapes for disabling BufferOps.
240
+ should_disable_bufferops = torch.version.hip is not None and a_shape[0] < 32
241
+ with disable_bufferops(should_disable_bufferops):
242
+ with torch.cuda.device(a.device.index):
243
+ _kernel_quantize_fp8_row[grid](
244
+ a,
245
+ a_scale,
246
+ a_fp8,
247
+ scale_ub,
248
+ zero_start_index_M,
249
+ a.shape[0],
250
+ a.shape[1],
251
+ a.shape[2],
252
+ a.shape[3],
253
+ a_fp8.shape[3],
254
+ a.stride(0),
255
+ a.stride(1),
256
+ a.stride(2),
257
+ a.stride(3),
258
+ a_fp8.stride(0),
259
+ a_fp8.stride(1),
260
+ a_fp8.stride(2),
261
+ a_fp8.stride(3),
262
+ (
263
+ zero_start_index_M.stride(0)
264
+ if zero_start_index_M is not None
265
+ else None
266
+ ),
267
+ (
268
+ zero_start_index_M.stride(1)
269
+ if zero_start_index_M is not None
270
+ else None
271
+ ),
272
+ TL_FP8_DTYPE=tl_dtype,
273
+ MAX_FP8=max_fp8,
274
+ EPS=eps,
275
+ CLAMP_MAX=scale_ub is not None,
276
+ JAGGED=zero_start_index_M is not None,
277
+ USE_INT64=use_int64,
278
+ )
279
+
280
+ return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
281
+
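# Usage sketch (a minimal illustration, not lines from the package): row-wise FP8
# quantization with triton_quantize_fp8_row and an approximate reconstruction.
# Assumes a CUDA device and that this hunk is mslk/quantize/triton/fp8_quantize.py
# from the file listing above, so the module imports as shown.
import torch
from mslk.quantize.triton.fp8_quantize import triton_quantize_fp8_row

a = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
a_fp8, a_scale = triton_quantize_fp8_row(a)  # a_scale holds the reciprocal scale per row
a_approx = a_fp8.to(torch.float32) * a_scale[:, None]  # approximately reconstructs a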
282
+
283
+ @triton.autotune(
284
+ configs=[
285
+ Config({"BLOCK_SIZE": 512}),
286
+ Config({"BLOCK_SIZE": 1024}),
287
+ Config({"BLOCK_SIZE": 2048}),
288
+ Config({"BLOCK_SIZE": 4096}),
289
+ Config({"BLOCK_SIZE": 8192}),
290
+ ],
291
+ key=["K"],
292
+ )
293
+ @triton.jit
294
+ def _kernel_quantize_fp8_packed_row(
295
+ A,
296
+ A_fp8,
297
+ packed_scale,
298
+ scale_ub,
299
+ zero_start_index_M,
300
+ B,
301
+ M,
302
+ N,
303
+ K,
304
+ stride_ab,
305
+ stride_am,
306
+ stride_an,
307
+ stride_ak,
308
+ stride_ob,
309
+ stride_om,
310
+ stride_on,
311
+ stride_ok,
312
+ packed_scale_stride,
313
+ stride_zb,
314
+ stride_zm,
315
+ TL_FP8_DTYPE: tl.constexpr,
316
+ MAX_FP8: tl.constexpr,
317
+ EPS: tl.constexpr,
318
+ CLAMP_MAX: tl.constexpr,
319
+ JAGGED: tl.constexpr,
320
+ BLOCK_SIZE: tl.constexpr,
321
+ USE_INT64: tl.constexpr,
322
+ ) -> None:
323
+ """Quantize and scale each row.
324
+
325
+ Scale per row i is computed as MAX_FP8 / max(abs(A[i, :]))
326
+
327
+ Kernel naively iterates through matrix with [1, BLOCK_SIZE] tiles
328
+ in a max pass then scale/quantize pass.
329
+
330
+ Todo:
331
+ * Better tiling schemes.
332
+
333
+ Args:
334
+ A (Tensor): higher precision input tensor of 4 dimensions.
335
+ packed_scale (Tensor): view into A_fp8 holding the float32 reciprocal scale bytes of each row.
336
+ A_fp8 (Tensor): fp8 scaled tensor. A_fp8 = A / a_scale
337
+ scale_ub (Tensor): [1] Maximum value allowed for scale.
338
+ B (int): Size of dimension 0
339
+ M (int): Size of dimension 1
340
+ N (int): Size of dimension 2
341
+ K (int): Size of dimension 3
342
+ stride_ab (int): Stride of b dimension of A.
343
+ stride_am (int): Stride of m dimension of A.
344
+ stride_an (int): Stride of n dimension of A.
345
+ stride_ak (int): Stride of k dimension of A.
346
+ stride_ob (int): Stride of b dimension of output.
347
+ stride_om (int): Stride of m dimension of output.
348
+ stride_on (int): Stride of n dimension of output.
349
+ stride_ok (int): Stride of k dimension of output.
350
+ packed_scale_stride (int): Stride of the packed scale, indexing into a_fp8.
351
+ stride_zb (int): Stride of b dimension of jagged index.
352
+ stride_zm (int): Stride of m dimension of jagged index.
353
+ TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
354
+ MAX_FP8 (float): Maximum expressible value for FP8.
355
+ EPS (float): Epsilon value for numerical stability.
356
+ CLAMP_MAX (bool): Whether to apply scale_ub.
357
+ JAGGED (bool): Whether to use jagged indexing.
358
+ BLOCK_SIZE (int): Block size for reduction.
359
+ USE_INT64 (bool): Whether to use int64 indexing for large inputs.
360
+ """
361
+ pid = tl.program_id(0)
362
+ # Use int64 indexing for large inputs. This is slower, but
363
+ # needed to avoid index overflows.
364
+ if USE_INT64:
365
+ pid = pid.to(tl.int64)
366
+ n_offset = tl.arange(0, BLOCK_SIZE)
367
+ a_offset_base = (
368
+ pid // (M * N) * stride_ab
369
+ + (pid % (M * N)) // N * stride_am
370
+ + (pid % (M * N)) % N * stride_an
371
+ )
372
+ a_fp8_offset_base = (
373
+ pid // (M * N) * stride_ob
374
+ + (pid % (M * N)) // N * stride_om
375
+ + (pid % (M * N)) % N * stride_on
376
+ )
377
+
378
+ K_in = K
379
+
380
+ if JAGGED:
381
+ z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
382
+ group_rows = tl.load(zero_start_index_M + z_offset_base)
383
+ current_row = pid % N
384
+ # If this row is empty, don't process any of it.
385
+ if current_row >= group_rows:
386
+ K_in = 0
387
+
388
+ # Calculate max.
389
+ cur_max = 0.0
390
+ for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
391
+ a = tl.load(
392
+ A + a_offset_base + n_offset * stride_ak,
393
+ mask=n_offset < K_in,
394
+ other=0.0,
395
+ )
396
+ tile_max = tl.max(tl.abs(a))
397
+ cur_max = tl.maximum(tile_max, cur_max)
398
+ n_offset += BLOCK_SIZE
399
+
400
+ # Clamp max value appropriately.
401
+ if CLAMP_MAX:
402
+ ub = tl.load(scale_ub)
403
+ cur_max = tl.clamp(cur_max, EPS, ub)
404
+ else:
405
+ cur_max = tl.maximum(cur_max, EPS)
406
+ # Scale and quantize.
407
+ a_scale = MAX_FP8 / cur_max
408
+
409
+ (fp8_0, fp8_1, fp8_2, fp8_3) = tl.inline_asm_elementwise(
410
+ asm="""
411
+ {
412
+ // $4 is the input register
413
+ .reg .b32 input;
414
+ mov.b32 input, $4;
415
+ mov.b32 $0, $4;
416
+ shr.b32 $1, $4, 8;
417
+ shr.b32 $2, $4, 16;
418
+ shr.b32 $3, $4, 24;
419
+ }
420
+ """,
421
+ constraints=("=r,=r,=r,=r,r"),
422
+ # Pass the float32 reciprocal scale in as a single 32-bit value; the asm splits it into 4 bytes.
423
+ args=[1.0 / a_scale],
424
+ dtype=(
425
+ tl.uint8,
426
+ tl.uint8,
427
+ tl.uint8,
428
+ tl.uint8,
429
+ ),
430
+ is_pure=True,
431
+ pack=1,
432
+ )
433
+
434
+ # There are some compiler issues with FP8 pointers
435
+ packed_scale_ptr = packed_scale.to(tl.pointer_type(tl.uint8))
436
+ tl.store(packed_scale_ptr + pid * packed_scale_stride, fp8_0)
437
+ tl.store(packed_scale_ptr + pid * packed_scale_stride + 1, fp8_1)
438
+ tl.store(packed_scale_ptr + pid * packed_scale_stride + 2, fp8_2)
439
+ tl.store(packed_scale_ptr + pid * packed_scale_stride + 3, fp8_3)
440
+
441
+ n_offset = tl.arange(0, BLOCK_SIZE)
442
+
443
+ for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
444
+ a = tl.load(
445
+ A + a_offset_base + n_offset * stride_ak,
446
+ mask=n_offset < K_in,
447
+ other=0.0,
448
+ )
449
+ a_fp8 = a * a_scale
450
+ # Clamp A to fp8 range to make sure there's no overflow.
451
+ # This is required for AMD. Nvidia's default saturation
452
+ # handles it, but it's nice to have anyway.
453
+ a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
454
+ tl.store(
455
+ A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
456
+ a_fp8,
457
+ mask=n_offset < K,
458
+ )
459
+
460
+ n_offset += BLOCK_SIZE
461
+
462
+
463
+ def triton_quantize_fp8_packed_row(
464
+ a: torch.Tensor,
465
+ scale_ub: Optional[torch.Tensor] = None,
466
+ zero_start_index_M: Optional[torch.Tensor] = None,
467
+ return_only_packed: Optional[bool] = False,
468
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]:
469
+ """
470
+ Call the triton quantize fp8 row kernel to quantize a tensor to fp8 with row-wise scalings.
471
+
472
+ This packs the FP32 scale at the end of each row, so the fp8 scaled tensor and the reciprocal scale tensor per row are contiguous in memory.
473
+
474
+ Args:
475
+ a (Tensor): higher precision input tensor of up to 4 dimensions.
476
+ scale_ub (Tensor): Maximum allowed value for scale.
477
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
478
+ return_only_packed (bool): Only return the packed tensor, do not unpack results if True
479
+ Returns:
480
+ torch.Tensor: fp8 scaled tensor.
481
+ torch.Tensor: reciprocal scale tensor per row.
482
+ torch.Tensor: The packed FP8 scaled tensor, with the scale at the end of each row.
483
+ """
484
+ if scale_ub is not None and scale_ub.device != a.device:
485
+ raise Exception("'scale_ub' must be on the same device as 'a'")
486
+ if zero_start_index_M is not None and zero_start_index_M.device != a.device:
487
+ raise Exception("'zero_start_index_M' must be on the same device as 'a'")
488
+
489
+ assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
490
+ a_shape = a.shape
491
+ while a.dim() < 4:
492
+ a = a.unsqueeze(0)
493
+ if zero_start_index_M is not None:
494
+ # There should be one value of zero_start_index_M per NxK matrix.
495
+ zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
496
+ # Get constant values.
497
+ pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
498
+ num_rows = a.numel() // a.shape[-1]
499
+
500
+ # Allocate an extra 4-bytes at the end of each row for the scale.
501
+ a_fp8 = torch.empty(
502
+ (*a.shape[:-1], a.shape[-1] + 4), device=a.device, dtype=pt_dtype
503
+ )
504
+
505
+ # create a view of the packed scale
506
+ packed_scale = a_fp8[..., -4:]
507
+
508
+ # If input tensor is sufficiently large, we need to use int64 indexing.
509
+ use_int64 = a.numel() > (2**31 - 1)
510
+ grid = (num_rows,)
511
+
512
+ with torch.cuda.device(a.device.index):
513
+ _kernel_quantize_fp8_packed_row[grid](
514
+ a,
515
+ a_fp8,
516
+ packed_scale,
517
+ scale_ub,
518
+ zero_start_index_M,
519
+ a.shape[0],
520
+ a.shape[1],
521
+ a.shape[2],
522
+ a.shape[3],
523
+ a.stride(0),
524
+ a.stride(1),
525
+ a.stride(2),
526
+ a.stride(3),
527
+ a_fp8.stride(0),
528
+ a_fp8.stride(1),
529
+ a_fp8.stride(2),
530
+ a_fp8.stride(3),
531
+ packed_scale.stride(2), # stride between consecutive rows of the packed view
532
+ zero_start_index_M.stride(0) if zero_start_index_M is not None else None,
533
+ zero_start_index_M.stride(1) if zero_start_index_M is not None else None,
534
+ TL_FP8_DTYPE=tl_dtype,
535
+ MAX_FP8=max_fp8,
536
+ EPS=eps,
537
+ CLAMP_MAX=scale_ub is not None,
538
+ JAGGED=zero_start_index_M is not None,
539
+ USE_INT64=use_int64,
540
+ )
541
+ if return_only_packed:
542
+ return None, None, a_fp8.view((*a_shape[:-1], a_shape[-1] + 4))
543
+
544
+ # Extract the original shape data without the extra 4 bytes per row
545
+ # The data is still contiguous in memory, so we have to unpack it.
546
+ final_fp8_view = a_fp8[..., :-4].view(a_shape)
547
+ scale_view = a_fp8[..., -4:].reshape((num_rows * 4)).view(torch.float32)
548
+
549
+ # Unlike the non-packed API, the full packed tensor is also
550
+ # returned as a third value.
551
+ return final_fp8_view, scale_view.view(a_shape[:-1]), a_fp8
552
+
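# Usage sketch (illustrative, same import-path assumption as the sketch above):
# the packed variant appends the float32 reciprocal scale as four extra fp8-typed
# bytes at the end of each row, so payload and scale live in one contiguous buffer.
import torch
from mslk.quantize.triton.fp8_quantize import triton_quantize_fp8_packed_row

a = torch.randn(8, 128, device="cuda", dtype=torch.bfloat16)
_, _, packed = triton_quantize_fp8_packed_row(a, return_only_packed=True)
data = packed[..., :-4]                                    # [8, 128] fp8 payload
scales = packed[..., -4:].reshape(-1).view(torch.float32)  # one float32 reciprocal scale per row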
553
+
554
+ @torch.library.custom_op("triton::quantize_fp8_packed_row", mutates_args=())
555
+ def quantize_fp8_packed_row(
556
+ a: torch.Tensor,
557
+ scale_ub: Optional[torch.Tensor] = None,
558
+ zero_start_index_M: Optional[torch.Tensor] = None,
559
+ use_triton: bool = True,
560
+ output_device: Optional[torch.device] = None,
561
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
562
+ """
563
+ Quantize a to fp8 with row-wise scalings and optionally move to output device.
564
+
565
+ Args:
566
+ a (Tensor): Input high precision tensor. Required to have no more than 4 dimensions
567
+ scale_ub (Tensor): Maximum allowed value for scale.
568
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
569
+ use_triton (bool): Whether to use triton kernel or pytorch.
570
+ output_device (torch.device): Device to optionally move the scaled tensors to.
571
+ Returns:
572
+ torch.Tensor: fp8 scaled tensor.
573
+ torch.Tensor: The reciprocal scale tensor per row.
574
+ """
575
+
576
+ if a.device == torch.device("cpu"):
577
+ logger.info("Triton does not support cpu, falling back to torch ops.")
578
+ use_triton = False
579
+ if use_triton:
580
+ # The full packed tensor is not needed here; return the unpacked views.
581
+ a_fp8, scale, _ = triton_quantize_fp8_packed_row(
582
+ a, scale_ub, zero_start_index_M, return_only_packed=False
583
+ )
584
+ assert a_fp8 is not None
585
+ assert scale is not None
586
+ return a_fp8, scale
587
+ # else use pytorch implementation.
588
+ if not output_device:
589
+ output_device = a.device
590
+
591
+ a_shape = a.shape
592
+ # Get constants.
593
+ pt_dtype, _, max_fp8, eps = get_fp8_constants()
594
+ row_max: torch.Tensor = torch.max(torch.abs(a), dim=-1)[0]
595
+ # Apply clamping.
596
+ if scale_ub is not None:
597
+ row_max = torch.clamp(row_max, min=eps, max=scale_ub.item())
598
+ else:
599
+ # pyre-ignore[6]: Incompatible parameter type [6]
600
+ row_max = torch.clamp(row_max, min=eps)
601
+ a_scale = torch.empty((a.shape[:-1]), dtype=torch.float32, device=output_device)
602
+ a_scale = max_fp8 / row_max.to(torch.float32) # pyre-ignore
603
+ a_scale[a_scale == float("inf")] = 1.0 # pyre-ignore
604
+ a_fp8 = a * a_scale[..., None] # pyre-ignore
605
+ # Cast and move data to output device (for cpu weight loading).
606
+ a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype)
607
+ a_scale = a_scale.to(output_device) # pyre-ignore
608
+ del a
609
+ return a_fp8, (1 / a_scale).view(a_shape[:-1]) # pyre-ignore
610
+
611
+
612
+ @torch.library.custom_op("triton::quantize_fp8_packed_row_raw", mutates_args=())
613
+ def quantize_fp8_packed_row_raw(
614
+ a: torch.Tensor,
615
+ scale_ub: Optional[torch.Tensor] = None,
616
+ zero_start_index_M: Optional[torch.Tensor] = None,
617
+ use_triton: bool = True,
618
+ output_device: Optional[torch.device] = None,
619
+ ) -> torch.Tensor:
620
+ """
621
+ Quantize a to fp8 with row-wise scalings and optionally move to output device.
622
+
623
+ Identical to quantize_fp8_packed_row, except it only returns the raw packed tensor.
624
+
625
+ Args:
626
+ a (Tensor): Input high precision tensor. Required to have no more than 4 dimensions
627
+ scale_ub (Tensor): Maximum allowed value for scale.
628
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
629
+ use_triton (bool): Whether to use triton kernel or pytorch.
630
+ output_device (torch.device): Device to optionally move the scaled tensors to.
631
+ Returns:
632
+ torch.Tensor: The packed fp8 tensor, with the float32 reciprocal scale
633
+ stored in the last 4 bytes of each row.
634
+ """
635
+
636
+ if a.device == torch.device("cpu"):
637
+ logger.info("Triton does not support cpu, falling back to torch ops.")
638
+ use_triton = False
639
+ if use_triton:
640
+ # Only the packed tensor is needed here; discard the unpacked views.
641
+ _, _, packed_tensor = triton_quantize_fp8_packed_row(
642
+ a, scale_ub, zero_start_index_M, return_only_packed=True
643
+ )
644
+ return packed_tensor
645
+ else:
646
+ raise Exception(
647
+ "No PyTorch implementation provided for triton::quantize_fp8_packed_row_raw"
648
+ )
649
+
650
+
651
+ @torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
652
+ def quantize_fp8_row(
653
+ a: torch.Tensor,
654
+ scale_ub: Optional[torch.Tensor] = None,
655
+ zero_start_index_M: Optional[torch.Tensor] = None,
656
+ use_triton: bool = True,
657
+ output_device: Optional[torch.device] = None,
658
+ align_rows_to: Optional[int] = None,
659
+ ) -> tuple[torch.Tensor, torch.Tensor]:
660
+ """
661
+ Quantize a to fp8 with row-wise scalings and optionally move to output device.
662
+
663
+ Args:
664
+ a (Tensor): Input high precision tensor. Required to have no more than 4 dimensions
665
+ scale_ub (Tensor): Maximum allowed value for scale.
666
+ zero_start_index_M (Tensor): Indicates number of nonzero elements in each row.
667
+ use_triton (bool): Whether to use triton kernel or pytorch.
668
+ output_device (torch.device): Device to optionally move the scaled tensors to.
669
+ align_rows_to: Pad rows to align to this value. Useful for downstream kernels accepting specific sizes (e.g., multiple of 16)
670
+
671
+ Returns:
672
+ torch.Tensor: fp8 scaled tensor.
673
+ torch.Tensor: The reciprocal scale tensor per row.
674
+ """
675
+
676
+ if a.device == torch.device("cpu"):
677
+ logger.info("Triton does not support cpu, falling back to torch ops.")
678
+ use_triton = False
679
+ if use_triton:
680
+ return triton_quantize_fp8_row(
681
+ a,
682
+ scale_ub,
683
+ zero_start_index_M,
684
+ align_rows_to=align_rows_to,
685
+ )
686
+ # else use pytorch implementation.
687
+ if not output_device:
688
+ output_device = a.device
689
+
690
+ a_shape = a.shape
691
+ # Get constants.
692
+ pt_dtype, _, max_fp8, eps = get_fp8_constants()
693
+ row_max: torch.Tensor = torch.max(torch.abs(a), dim=-1)[0]
694
+ # Apply clamping.
695
+ if scale_ub is not None:
696
+ row_max = torch.clamp(row_max, min=eps, max=scale_ub.item())
697
+ else:
698
+ # pyre-ignore[6]: Incompatible parameter type [6]
699
+ row_max = torch.clamp(row_max, min=eps)
700
+ a_scale = torch.empty((a.shape[:-1]), dtype=torch.float32, device=output_device)
701
+ a_scale = max_fp8 / row_max.to(torch.float32) # pyre-ignore
702
+ a_scale[a_scale == float("inf")] = 1.0 # pyre-ignore
703
+ a_fp8 = a * a_scale[..., None] # pyre-ignore
704
+ # Cast and move data to output device (for cpu weight loading).
705
+ a_fp8 = a_fp8.to(device=output_device, dtype=pt_dtype)
706
+ a_scale = a_scale.to(output_device) # pyre-ignore
707
+ del a
708
+ return a_fp8, (1 / a_scale).view(a_shape[:-1]) # pyre-ignore
709
+
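# Usage sketch (illustrative, same import-path assumption as earlier sketches):
# the public quantize_fp8_row op, here with align_rows_to so rows are zero-padded
# to a multiple of 16 for downstream kernels with size requirements.
import torch
from mslk.quantize.triton.fp8_quantize import quantize_fp8_row

w = torch.randn(8, 100, device="cuda", dtype=torch.bfloat16)
wq, w_scale = quantize_fp8_row(w, align_rows_to=16)
# wq: [8, 112] fp8 with rows zero-padded from 100 to 112; w_scale: [8] float32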
710
+
711
+ @quantize_fp8_row.register_fake
712
+ def quantize_fp8_row_meta(
713
+ a: torch.Tensor,
714
+ scale_ub: Optional[torch.Tensor] = None,
715
+ zero_start_index_M: Optional[torch.Tensor] = None,
716
+ use_triton: bool = True,
717
+ output_device: Optional[torch.device] = None,
718
+ align_rows_to: Optional[int] = None,
719
+ ) -> tuple[torch.Tensor, torch.Tensor]:
720
+ """Shape function for torch compile."""
721
+ if output_device is None:
722
+ output_device = a.device
723
+ a_shape = a.shape
724
+ dtype = get_fp8_constants()[0]
725
+ fake_scale = torch.empty(a_shape[:-1], device=output_device, dtype=torch.float32)
726
+ if align_rows_to is not None:
727
+ last_dim = a.shape[-1]
728
+ padded_last_dim = (
729
+ (last_dim + align_rows_to - 1) // align_rows_to
730
+ ) * align_rows_to
731
+ fake_out = torch.empty(
732
+ (*a.shape[:-1], padded_last_dim), device=output_device, dtype=dtype
733
+ )
734
+ return fake_out, fake_scale
735
+ else:
736
+ fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
737
+ return fake_out, fake_scale
738
+
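# Usage sketch (illustrative): because quantize_fp8_row is a torch custom op and
# the fake (meta) function above is registered for it, the op can be traced by
# torch.compile without launching the Triton kernel at trace time.
import torch
from mslk.quantize.triton.fp8_quantize import quantize_fp8_row

@torch.compile
def quantize(t: torch.Tensor) -> torch.Tensor:
    xq, _ = quantize_fp8_row(t)
    return xq

xq = quantize(torch.randn(16, 256, device="cuda", dtype=torch.bfloat16))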
739
+
740
+ @triton.autotune(
741
+ configs=[
742
+ Config({"BLOCK_SIZE": 512}),
743
+ Config({"BLOCK_SIZE": 1024}),
744
+ Config({"BLOCK_SIZE": 2048}),
745
+ Config({"BLOCK_SIZE": 4096}),
746
+ Config({"BLOCK_SIZE": 8192}),
747
+ ],
748
+ key=["N"],
749
+ )
750
+ @triton.jit
751
+ def _kernel_scale_fp8_row(
752
+ A,
753
+ x_scale,
754
+ w_scale,
755
+ scaled_out,
756
+ M,
757
+ N,
758
+ stride_am,
759
+ stride_an,
760
+ stride_om,
761
+ stride_on,
762
+ BLOCK_SIZE: tl.constexpr,
763
+ ) -> None:
764
+ """
765
+ Scale each row of A by x_scale and each column of A by w_scale.
766
+
767
+ Args:
768
+ A (Tensor): [m, n] Input tensor to scale.
769
+ x_scale (Tensor): [m] Row-wise scale tensor.
770
+ w_scale (Tensor): [n] Col-wise scale tensor.
771
+ scaled_out (Tensor): [m, n] Output tensor.
772
+ M (int): Number of rows.
773
+ N (int): Number of columns.
774
+ stride_am (int): Stride of m dimension of A.
775
+ stride_an (int): Stride of n dimension of A.
776
+ stride_om (int): Stride of m dimension of output.
777
+ stride_on (int): Stride of n dimension of output.
778
+ BLOCK_SIZE (int): Block size for data loads.
779
+ """
780
+ pid = tl.program_id(0)
781
+ n_offset = tl.arange(0, BLOCK_SIZE)
782
+ # Load activation scale for this row.
783
+ row_scale = tl.load(x_scale + pid)
784
+
785
+ # Iterate over chunks of the row and apply scales.
786
+ for _k in range(0, tl.cdiv(N, BLOCK_SIZE)):
787
+ a = tl.load(
788
+ A + pid * stride_am + n_offset * stride_an, mask=n_offset < N, other=0.0
789
+ )
790
+ col_scale = tl.load(w_scale + n_offset)
791
+ scaled_a = a * row_scale * col_scale
792
+ tl.store(
793
+ scaled_out + pid * stride_om + n_offset * stride_on,
794
+ scaled_a,
795
+ mask=n_offset < N,
796
+ )
797
+ n_offset += BLOCK_SIZE
798
+
799
+
800
+ def scale_fp8_row(
801
+ a: torch.Tensor,
802
+ x_scale: torch.Tensor,
803
+ w_scale: torch.Tensor,
804
+ ) -> torch.Tensor:
805
+ """
806
+ Apply only rowwise scaling to a tensor. Useful when combining with kernels
807
+ that do not support fused rowwise scaling.
808
+
809
+ Args:
810
+ a (Tensor): Input floating point tensor to be scaled.
811
+ x_scale (Tensor): Row-wise activation scale tensor.
812
+ w_scale (Tensor): Col-wise weight scale tensor.
813
+ """
814
+ if a.device == torch.device("cpu"):
815
+ # On CPU we'll just use native pytorch to scale.
816
+ return a * x_scale[:, None] * w_scale[None, :]
817
+
818
+ if x_scale.device != a.device:
819
+ raise Exception("'x_scale' must be on the same device as 'a'")
820
+ if w_scale.device != a.device:
821
+ raise Exception("'w_scale' must be on the same device as 'a'")
822
+
823
+ # Otherwise, use a fast triton kernel to implement.
824
+ # We'll parallelize over rows.
825
+ num_rows = a.shape[0]
826
+ scaled_out = torch.empty(a.shape, device=a.device, dtype=a.dtype)
827
+ grid = (num_rows,)
828
+ with torch.cuda.device(a.device.index):
829
+ _kernel_scale_fp8_row[grid](
830
+ a,
831
+ x_scale,
832
+ w_scale,
833
+ scaled_out,
834
+ a.shape[0],
835
+ a.shape[1],
836
+ a.stride(0),
837
+ a.stride(1),
838
+ scaled_out.stride(0),
839
+ scaled_out.stride(1),
840
+ )
841
+
842
+ return scaled_out
843
+
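# Usage sketch (illustrative): scale_fp8_row fuses the row and column scaling
# into one pass; on GPU it matches the broadcast product used by the CPU branch.
import torch
from mslk.quantize.triton.fp8_quantize import scale_fp8_row

y = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)
x_scale = torch.rand(64, device="cuda")
w_scale = torch.rand(128, device="cuda")
out = scale_fp8_row(y, x_scale, w_scale)
ref = y * x_scale[:, None] * w_scale[None, :]  # unfused reference (up to bf16 rounding)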
844
+
845
+ @triton.jit
846
+ def _kernel_quantize_fp8_block(
847
+ A,
848
+ A_scale,
849
+ A_fp8,
850
+ scale_ub,
851
+ M,
852
+ K,
853
+ stride_am,
854
+ stride_ak,
855
+ stride_om,
856
+ stride_ok,
857
+ stride_a_scale_m,
858
+ stride_a_scale_k,
859
+ TL_FP8_DTYPE: tl.constexpr,
860
+ MAX_FP8: tl.constexpr,
861
+ EPS: tl.constexpr,
862
+ CLAMP_MAX: tl.constexpr,
863
+ BLOCK_M: tl.constexpr,
864
+ BLOCK_K: tl.constexpr,
865
+ K_MAJOR: tl.constexpr,
866
+ ) -> None:
867
+ """Quantize and scale each [BLOCK_M, BLOCK_K] block.
868
+
869
+ Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(A[i:i+BLOCK_M, j:j+BLOCK_K])))
870
+
871
+ Kernel naively iterates through matrix with [BLOCK_M, BLOCK_K] tiles.
872
+
873
+ Todo:
874
+ * Better tiling and ordering schemes.
875
+
876
+ Args:
877
+ A (Tensor): [M, K] higher precision input tensor.
878
+ A_scale (Tensor): [cdiv(M, BLOCK_M), cdiv(K, BLOCK_K)] reciprocal scale tensor per block.
879
+ A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a_scale
880
+ scale_ub (Tensor): [1] Maximum allowed value for scale.
881
+ M (int): Number of rows.
882
+ K (int): Number of columns.
883
+ stride_am (int): Stride of m dimension of A.
884
+ stride_ak (int): Stride of k dimension of A.
885
+ stride_om (int): Stride of m dimension of output.
886
+ stride_ok (int): Stride of k dimension of output.
887
+ stride_a_scale_m (int): Stride of m dimension of A_scale.
888
+ stride_a_scale_k (int): Stride of k dimension of A_scale.
889
+ TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
890
+ MAX_FP8 (float): Maximum expressible value for FP8.
891
+ EPS (float): Epsilon value for numerical stability.
892
+ CLAMP_MAX (bool): Whether to apply scale_ub.
893
+ BLOCK_M (int): Block size for M dimension of A_scale and kernel.
894
+ BLOCK_K (int): Block size for K dimension of A_scale and kernel.
895
+ K_MAJOR (bool): Whether output scales should be K major (True) or MN major (False).
896
+ """
897
+ pid = tl.program_id(0)
898
+ grid_k = tl.cdiv(K, BLOCK_K)
899
+ block_m = pid // grid_k
900
+ block_k = pid % grid_k
901
+ rm = block_m * BLOCK_M + tl.arange(0, BLOCK_M)
902
+ rk = block_k * BLOCK_K + tl.arange(0, BLOCK_K)
903
+ a_offset = rm[:, None] * stride_am + rk[None, :] * stride_ak
904
+ out_offset = rm[:, None] * stride_om + rk[None, :] * stride_ok
905
+ a_mask = (rm < M)[:, None] & (rk < K)[None, :]
906
+ a_block = tl.load(A + a_offset, mask=a_mask, other=0.0)
907
+
908
+ block_max = tl.max(tl.abs(a_block))
909
+ # Apply appropriate clamping.
910
+ if CLAMP_MAX:
911
+ ub = tl.load(scale_ub)
912
+ block_max = tl.clamp(block_max, EPS, ub)
913
+ else:
914
+ block_max = tl.maximum(block_max, EPS)
915
+ scale = MAX_FP8 / block_max
916
+
917
+ # Write in transposed order if specified.
918
+ if K_MAJOR:
919
+ scale_offset = block_m * stride_a_scale_m + block_k * stride_a_scale_k
920
+ else:
921
+ scale_offset = block_k * stride_a_scale_m + block_m * stride_a_scale_k
922
+ tl.store(A_scale + scale_offset, 1.0 / scale)
923
+ a_fp8 = a_block * scale
924
+ # Clamp A to fp8 range to make sure there's no overflow.
925
+ # This is required for AMD. Nvidia's default saturation
926
+ # handles it, but it's nice to have anyway.
927
+ a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
928
+ a_fp8 = a_fp8.to(TL_FP8_DTYPE)
929
+ tl.store(A_fp8 + out_offset, a_fp8, mask=a_mask)
930
+
931
+
932
+ def triton_quantize_fp8_block(
933
+ x: torch.Tensor,
934
+ block_m: int = 256,
935
+ block_k: int = 256,
936
+ scale_ub: Optional[torch.Tensor] = None,
937
+ k_major: bool = True,
938
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
939
+ """
940
+ Quantize a tensor to fp8 with block-wise scalings.
941
+
942
+ Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k])))
943
+
944
+ Args:
945
+ x (torch.Tensor): [M, K] higher precision input tensor.
946
+ block_m (int): Block size for M dimension of scale.
947
+ block_k (int): Block size for K dimension of scale.
948
+ scale_ub: Maximum allowed value for scale.
949
+ k_major (bool): Whether output scales should be K major (True) or MN major (False).
950
+
951
+ Returns:
952
+ torch.Tensor : [M, K] fp8 scaled tensor.
953
+ torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
954
+ if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_m)].
955
+ """
956
+ assert x.device != torch.device("cpu"), (
957
+ "Blockwise quantization not support on cpu, please use row-wise quantization instead."
958
+ )
959
+
960
+ if scale_ub is not None and scale_ub.device != x.device:
961
+ raise Exception("'scale_ub' must be on the same device as 'a'")
962
+
963
+ x_shape = x.shape
964
+ x = x.view(-1, x.size(-1))
965
+ # Get constant values.
966
+ pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
967
+ M, K = x.shape
968
+ grid_m = triton.cdiv(M, block_m)
969
+ grid_k = triton.cdiv(K, block_k)
970
+ if k_major:
971
+ x_scale = torch.empty((grid_m, grid_k), device=x.device, dtype=torch.float32)
972
+ else:
973
+ x_scale = torch.empty((grid_k, grid_m), device=x.device, dtype=torch.float32)
974
+ x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
975
+
976
+ _kernel_quantize_fp8_block[(grid_m * grid_k,)](
977
+ x,
978
+ x_scale,
979
+ x_fp8,
980
+ scale_ub,
981
+ M,
982
+ K,
983
+ x.stride(0),
984
+ x.stride(1),
985
+ x_fp8.stride(0),
986
+ x_fp8.stride(1),
987
+ x_scale.stride(0),
988
+ x_scale.stride(1),
989
+ # pyre-ignore[6]: Incompatible parameter type [6]
990
+ TL_FP8_DTYPE=tl_dtype,
991
+ # pyre-ignore[6]: Incompatible parameter type [6]
992
+ MAX_FP8=max_fp8,
993
+ # pyre-ignore[6]: Incompatible parameter type [6]
994
+ EPS=eps,
995
+ # pyre-ignore[6]: Incompatible parameter type [6]
996
+ CLAMP_MAX=scale_ub is not None,
997
+ # pyre-ignore[6]: Incompatible parameter type [6]
998
+ BLOCK_M=block_m,
999
+ # pyre-ignore[6]: Incompatible parameter type [6]
1000
+ BLOCK_K=block_k,
1001
+ # pyre-ignore[6]: Incompatible parameter type [6]
1002
+ K_MAJOR=k_major,
1003
+ )
1004
+
1005
+ return x_fp8.view(x_shape), x_scale
1006
+
1007
+
1008
+ @torch.library.custom_op("triton::quantize_fp8_block", mutates_args=())
1009
+ def quantize_fp8_block(
1010
+ x: torch.Tensor,
1011
+ block_m: int = 256,
1012
+ block_k: int = 256,
1013
+ scale_ub: Optional[torch.Tensor] = None,
1014
+ use_triton: bool = True,
1015
+ output_device: Optional[torch.device] = None,
1016
+ k_major: bool = True,
1017
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1018
+ """
1019
+ Quantize a tensor to fp8 with block-wise scalings and optionally move to output device.
1020
+
1021
+ Scale per block i, j is computed as 1 / (MAX_FP8 / max(abs(x[i:i+block_m, j:j+block_k])))
1022
+
1023
+ Args:
1024
+ x (Tensor): [M, K] higher precision input tensor.
1025
+ block_m (int): Block size for M dimension of scale.
1026
+ block_k (int): Block size for K dimension of scale.
1027
+ scale_ub: Maximum allowed value for scale.
1028
+ use_triton (bool): Whether to use triton kernel or pytorch.
1029
+ output_device (torch.device): Device to optionally move the scaled tensors to.
1030
+ k_major (bool): Whether output scales should be K major (True) or MN major (False).
1031
+
1032
+ Returns:
1033
+ torch.Tensor: [M, K] fp8 scaled tensor.
1034
+ torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
1035
+ if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_m)].
1036
+ """
1037
+ x_shape = x.shape
1038
+ x = x.view(-1, x.size(-1))
1039
+ if x.device == torch.device("cpu"):
1040
+ logger.info("Triton does not support cpu, falling back to torch ops.")
1041
+ use_triton = False
1042
+ if use_triton:
1043
+ xq, x_scale = triton_quantize_fp8_block(x, block_m, block_k, scale_ub, k_major)
1044
+ return xq.view(x_shape), x_scale
1045
+ # else use pytorch implementation.
1046
+ if not output_device:
1047
+ output_device = x.device
1048
+
1049
+ # Get constants.
1050
+ pt_dtype, _, max_fp8, eps = get_fp8_constants()
1051
+
1052
+ M, K = x.shape
1053
+ grid_m = triton.cdiv(M, block_m)
1054
+ grid_k = triton.cdiv(K, block_k)
1055
+
1056
+ # Pad x to multiple of block size.
1057
+ padded_m = grid_m * block_m
1058
+ padded_k = grid_k * block_k
1059
+ x_padded = torch.zeros(padded_m, padded_k, dtype=x.dtype, device=x.device)
1060
+ x_padded[:M, :K] = x
1061
+
1062
+ # Blockwise max.
1063
+ block_max = (
1064
+ x_padded.abs().reshape(grid_m, block_m, grid_k, block_k).amax(dim=(1, 3))
1065
+ )
1066
+
1067
+ # Apply clamping.
1068
+ if scale_ub is not None:
1069
+ block_max = torch.clamp(block_max, min=eps, max=scale_ub.item())
1070
+ else:
1071
+ block_max = torch.clamp(block_max, min=eps)
1072
+ x_scale = torch.empty((grid_m, grid_k), dtype=torch.float32, device=output_device)
1073
+ x_scale = max_fp8 / block_max.to(torch.float32) # pyre-ignore
1074
+ # pyre-ignore[16]: Undefined attribute [16]
1075
+ x_scale[x_scale == float("inf")] = 1.0
1076
+ x_fp8 = (
1077
+ x_padded
1078
+ # pyre-ignore[16]: Undefined attribute [16]
1079
+ * x_scale.repeat_interleave(block_m, dim=0).repeat_interleave(block_k, dim=1)
1080
+ )[:M, :K]
1081
+
1082
+ # Cast and move data to output device (for cpu weight loading).
1083
+ x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
1084
+ x_scale = x_scale.to(output_device) # pyre-ignore
1085
+ del x, x_padded
1086
+ if not k_major:
1087
+ x_scale = x_scale.t().contiguous()
1088
+ return x_fp8.view(x_shape), 1 / x_scale # pyre-ignore
1089
+
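# Usage sketch (illustrative): block-wise quantization keeps one reciprocal scale
# per [block_m, block_k] tile, the layout consumed by block-scaled FP8 GEMMs.
import torch
from mslk.quantize.triton.fp8_quantize import quantize_fp8_block

x = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
xq, x_scale = quantize_fp8_block(x, block_m=128, block_k=128)
# xq: [512, 1024] fp8; x_scale: [4, 8] float32, one entry per 128x128 block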
1090
+
1091
+ @quantize_fp8_block.register_fake
1092
+ def quantize_fp8_block_meta(
1093
+ a: torch.Tensor,
1094
+ block_m: int = 256,
1095
+ block_k: int = 256,
1096
+ scale_ub: Optional[torch.Tensor] = None,
1097
+ use_triton: bool = True,
1098
+ output_device: Optional[torch.device] = None,
1099
+ k_major: bool = True,
1100
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1101
+ """Shape function for torch compile."""
1102
+ if output_device is None:
1103
+ output_device = a.device
1104
+ a_shape = a.shape
1105
+ dtype = get_fp8_constants()[0]
1106
+ fake_out = torch.empty(a.shape, device=output_device, dtype=dtype)
1107
+ scale_m = triton.cdiv(a_shape[0], block_m)
1108
+ scale_k = triton.cdiv(a_shape[1], block_k)
1109
+ scale_out_shape = (
1110
+ a_shape[:-2] + (scale_m, scale_k) if k_major else (scale_k, scale_m)
1111
+ )
1112
+ fake_scale = torch.empty(
1113
+ scale_out_shape,
1114
+ device=output_device,
1115
+ dtype=torch.float32,
1116
+ )
1117
+ return fake_out, fake_scale
1118
+
1119
+
1120
+ @triton.autotune(
1121
+ configs=[
1122
+ Config({"GROUP_LOAD": 2}),
1123
+ Config({"GROUP_LOAD": 4}),
1124
+ Config({"GROUP_LOAD": 8}),
1125
+ Config({"GROUP_LOAD": 16}),
1126
+ Config({"GROUP_LOAD": 32}),
1127
+ ],
1128
+ key=["K"],
1129
+ )
1130
+ @triton.jit
1131
+ def _kernel_quantize_fp8_group(
1132
+ A,
1133
+ A_scale,
1134
+ A_fp8,
1135
+ scale_ub,
1136
+ m_sizes,
1137
+ M,
1138
+ K,
1139
+ stride_am,
1140
+ stride_ak,
1141
+ stride_om,
1142
+ stride_ok,
1143
+ stride_a_scale_m,
1144
+ stride_a_scale_k,
1145
+ TL_FP8_DTYPE: tl.constexpr,
1146
+ MAX_FP8: tl.constexpr,
1147
+ EPS: tl.constexpr,
1148
+ CLAMP_MAX: tl.constexpr,
1149
+ USE_INT64: tl.constexpr,
1150
+ GROUP_SIZE: tl.constexpr,
1151
+ USE_M_MAJOR: tl.constexpr,
1152
+ G: tl.constexpr,
1153
+ GROUP_LOAD: tl.constexpr,
1154
+ ):
1155
+ """Quantize and scale each GROUP_SIZE chunk of each row.
1156
+
1157
+ Scale per group i is computed as 1 / (MAX_FP8 / max(abs(A[i:i+GROUP_SIZE])))
1158
+
1159
+ Each kernel thread is responsible for one row and loads and processes a tunable
1160
+ number of groups at once.
1161
+
1162
+ Args:
1163
+ A (Tensor): [M, K] higher precision input tensor.
1164
+ A_scale (Tensor): [M, cdiv(K, GROUP_SIZE)] reciprocal scale tensor per group.
1165
+ A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a_scale
1166
+ scale_ub (Tensor): [1] Maximum allowed value for scale.
1167
+ m_sizes (Optional[Tensor]): [G] Number of rows in each group.
1168
+ M (int): Number of rows.
1169
+ K (int): Number of columns.
1170
+ stride_am (int): Stride of m dimension of A.
1171
+ stride_ak (int): Stride of k dimension of A.
1172
+ stride_om (int): Stride of m dimension of output.
1173
+ stride_ok (int): Stride of k dimension of output.
1174
+ stride_a_scale_m (int): Stride of m dimension of A_scale.
1175
+ stride_a_scale_k (int): Stride of k dimension of A_scale.
1176
+ TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
1177
+ MAX_FP8 (float): Maximum expressible value for FP8.
1178
+ EPS (float): Epsilon value for numerical stability.
1179
+ CLAMP_MAX (bool): Whether to apply scale_ub.
1180
+ USE_INT64 (bool): Whether to index using int64, which may be needed for large tensors.
1181
+ GROUP_SIZE (int): Group size for K dimension of A_scale and kernel.
1182
+ USE_M_MAJOR (bool): Whether to use grouped M-major layout for A_scale.
1183
+ G (int): Number of groups in A_scale, only relevant when m_sizes is provided.
1184
+ GROUP_LOAD (int): Number of groups to load and process simultaneously.
1185
+ """
1186
+ pid = tl.program_id(0)
1187
+ if USE_INT64:
1188
+ pid = pid.to(tl.int64)
1189
+ # We load group_size * group_load chunks at a time.
1190
+ row_offset = pid * stride_am
1191
+ out_offset = pid * stride_om
1192
+ scale_row_offset = pid * stride_a_scale_m
1193
+ k_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE)
1194
+ scale_k_offset = tl.arange(0, GROUP_LOAD)
1195
+ NUM_GROUPS: tl.constexpr = K // GROUP_SIZE
1196
+
1197
+ # When dealing with an M-major grouped gemm, we need to figure out
1198
+ # which group this thread corresponds to and figure out the corresponding
1199
+ # scale offset.
1200
+ group_offset = 0
1201
+ group_cumsum = 0
1202
+ group_M = 0
1203
+ stop = False
1204
+ if USE_M_MAJOR and G > 0:
1205
+ # Iterate over groups to both compute the cumulative sum and find which group we are in.
1206
+ for i in range(G):
1207
+ if not stop:
1208
+ group_M = tl.cast(tl.load(m_sizes + i), pid.dtype)
1209
+ if (group_cumsum + group_M) <= pid:
1210
+ group_cumsum += group_M
1211
+ else:
1212
+ # Indicate we are finished computing cumsum.
1213
+ stop = True
1214
+
1215
+ group_offset = group_cumsum * NUM_GROUPS
1216
+
1217
+ for k in range(0, tl.cdiv(K, (GROUP_LOAD * GROUP_SIZE))):
1218
+ # Load groups of the input.
1219
+ chunk_offset = k_offset + k * GROUP_LOAD * GROUP_SIZE
1220
+ a = tl.load(
1221
+ A + row_offset + chunk_offset * stride_ak, mask=chunk_offset < K, other=0.0
1222
+ )
1223
+ # View loaded chunk as a set of groups.
1224
+ a_grouped = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
1225
+ # Reduce over groups.
1226
+ group_max = tl.max(tl.abs(a_grouped), axis=1)
1227
+ # Apply clamping if specified.
1228
+ if CLAMP_MAX:
1229
+ ub = tl.load(scale_ub)
1230
+ group_max = tl.clamp(group_max, EPS, ub)
1231
+ else:
1232
+ group_max = tl.maximum(group_max, EPS)
1233
+ # Scale and quantize.
1234
+ a_scale = MAX_FP8 / group_max
1235
+ scale_chunk_offset = scale_k_offset + k * GROUP_LOAD
1236
+
1237
+ if USE_M_MAJOR and G > 0:
1238
+ tl.store(
1239
+ A_scale
1240
+ + group_offset
1241
+ + (pid - group_cumsum) * stride_a_scale_k
1242
+ + (scale_chunk_offset * group_M),
1243
+ 1.0 / a_scale,
1244
+ mask=scale_chunk_offset < NUM_GROUPS,
1245
+ )
1246
+ else:
1247
+ if USE_M_MAJOR:
1248
+ tl.store(
1249
+ A_scale
1250
+ + pid * stride_a_scale_k
1251
+ + scale_chunk_offset * stride_a_scale_m,
1252
+ 1.0 / a_scale,
1253
+ mask=scale_chunk_offset < NUM_GROUPS,
1254
+ )
1255
+ else:
1256
+ tl.store(
1257
+ A_scale + scale_row_offset + scale_chunk_offset * stride_a_scale_k,
1258
+ 1.0 / a_scale,
1259
+ mask=scale_chunk_offset < NUM_GROUPS,
1260
+ )
1261
+ # Apply scale to input.
1262
+ a_fp8 = a_grouped * a_scale[:, None]
1263
+ # Clamp to FP8 range to avoid overflow
1264
+ a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
1265
+ # Write to output.
1266
+ tl.store(
1267
+ A_fp8 + out_offset + chunk_offset * stride_ok,
1268
+ tl.ravel(a_fp8),
1269
+ mask=chunk_offset < K,
1270
+ )
1271
+
1272
+
1273
+ def triton_quantize_fp8_group(
1274
+ x: torch.Tensor,
1275
+ group_size: int = 128,
1276
+ scale_ub: Optional[torch.Tensor] = None,
1277
+ m_sizes: Optional[torch.Tensor] = None,
1278
+ k_major: bool = True,
1279
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1280
+ """
1281
+ Quantize a tensor to fp8 with group-wise scalings.
1282
+
1283
+ Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size])))
1284
+
1285
+ Args:
1286
+ x (torch.Tensor): [M, K] higher precision input tensor.
1287
+ group_size (int): Group size for K dimension of scale.
1288
+ scale_ub: Maximum allowed value for scale.
1289
+ m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
1290
+ k_major (bool): Whether output scales should be K major (True) or MN major (False).
1291
+
1292
+ Returns:
1293
+ torch.Tensor: [M, K] fp8 scaled tensor.
1294
+ torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group.
1295
+ """
1296
+ assert x.device != torch.device("cpu"), (
1297
+ "Triton groupwise quantization not supported on cpu."
1298
+ )
1299
+
1300
+ if scale_ub is not None and scale_ub.device != x.device:
1301
+ raise Exception("'scale_ub' must be on the same device as 'a'")
1302
+ if m_sizes is not None and m_sizes.device != x.device:
1303
+ raise Exception("'m_sizes' must be on the same device as 'a'")
1304
+
1305
+ x_shape = x.shape
1306
+ x = x.view(-1, x.size(-1))
1307
+ pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
1308
+ M, K = x.shape
1309
+ k_groups = triton.cdiv(K, group_size)
1310
+ if k_major:
1311
+ x_scale = torch.empty((M, k_groups), device=x.device, dtype=torch.float32)
1312
+ else:
1313
+ x_scale = torch.empty((k_groups, M), device=x.device, dtype=torch.float32)
1314
+ x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
1315
+ _kernel_quantize_fp8_group[(M,)](
1316
+ x,
1317
+ x_scale,
1318
+ x_fp8,
1319
+ scale_ub,
1320
+ m_sizes,
1321
+ M,
1322
+ K,
1323
+ x.stride(0),
1324
+ x.stride(1),
1325
+ x_fp8.stride(0),
1326
+ x_fp8.stride(1),
1327
+ x_scale.stride(0),
1328
+ x_scale.stride(1),
1329
+ TL_FP8_DTYPE=tl_dtype,
1330
+ MAX_FP8=max_fp8,
1331
+ EPS=eps,
1332
+ CLAMP_MAX=scale_ub is not None,
1333
+ USE_INT64=x.numel() > (2**32 - 1),
1334
+ GROUP_SIZE=group_size,
1335
+ USE_M_MAJOR=m_sizes is not None or k_major is False,
1336
+ G=m_sizes.numel() if m_sizes is not None else 0,
1337
+ )
1338
+ return x_fp8.view(x_shape), x_scale
1339
+
1340
+
1341
+ def quantize_fp8_group(
1342
+ x: torch.Tensor,
1343
+ group_size: int = 128,
1344
+ scale_ub: Optional[torch.Tensor] = None,
1345
+ m_sizes: Optional[torch.Tensor] = None,
1346
+ k_major: bool = True,
1347
+ use_triton: bool = True,
1348
+ output_device: Optional[torch.device] = None,
1349
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1350
+ """
1351
+ Quantize a tensor to fp8 with group-wise scalings and optionally move to output device.
1352
+
1353
+ Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size])))
1354
+
1355
+ Args:
1356
+ x (Tensor): [M, K] higher precision input tensor.
1357
+ group_size (int): Group size for K dimension of scale.
1358
+ scale_ub: Maximum allowed value for scale.
1359
+ m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
1360
+ k_major (bool): Whether output scales should be K major (True) or MN major (False).
1361
+ This is needed because some kernels like cutlass require a special layout for scales.
1362
+ use_triton (bool): Whether to use triton kernel or pytorch.
1363
+ output_device (torch.device): Device to optionally move the scaled tensors to.
1364
+
1365
+ Returns:
1366
+ torch.Tensor: [M, K] fp8 scaled tensor.
1367
+ torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group.
1368
+ """
1369
+ x_shape = x.shape
1370
+ x = x.view(-1, x.size(-1))
1371
+ if x.device == torch.device("cpu"):
1372
+ logger.info("Triton does not support cpu, falling back to torch ops.")
1373
+ use_triton = False
1374
+ if use_triton:
1375
+ xq, x_scale = triton_quantize_fp8_group(
1376
+ x, group_size, scale_ub, m_sizes, k_major
1377
+ )
1378
+ return xq.view(x_shape), x_scale
1379
+ # else use pytorch implementation.
1380
+ if not output_device:
1381
+ output_device = x.device
1382
+
1383
+ # Get constants.
1384
+ pt_dtype, _, max_fp8, eps = get_fp8_constants()
1385
+
1386
+ M, K = x.shape
1387
+ assert K % group_size == 0, (
1388
+ "K must be divisible by group_size for cpu implementation."
1389
+ )
1390
+ assert m_sizes is None, "m_sizes is not supported for cpu implementation."
1391
+ k_groups = triton.cdiv(K, group_size)
1392
+ # View input as collection of groups for reduction.
1393
+ x_grouped = x.view(M, k_groups, group_size).to(torch.float32)
1394
+ # Reduce over groups.
1395
+ group_max = x_grouped.abs().amax(dim=2)
1396
+ # Apply clamping.
1397
+ group_max = (
1398
+ torch.clamp(group_max, min=eps, max=scale_ub.item())
1399
+ if scale_ub is not None
1400
+ else torch.clamp(group_max, min=eps)
1401
+ )
1402
+ x_scale = torch.empty((M, k_groups), dtype=torch.float32, device=output_device)
1403
+ x_scale = max_fp8 / group_max # pyre-ignore
1404
+ # pyre-ignore[16]: Undefined attribute [16]
1405
+ x_scale[x_scale == float("inf")] = 1.0
1406
+ # pyre-ignore[16]: Undefined attribute [16]
1407
+ x_fp8 = x.view(-1, k_groups, group_size) * x_scale.unsqueeze(2)
1408
+ # Cast and move data to output device (for cpu weight loading).
1409
+ x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
1410
+ x_scale = x_scale.to(output_device) # pyre-ignore
1411
+ if not k_major:
1412
+ x_scale = x_scale.t().contiguous()
1413
+ return x_fp8.view(x_shape), 1 / x_scale # pyre-ignore
1414
+
1415
+
1416
+ @triton.autotune(
+     configs=[Config({"BLOCK_M": 16, "BLOCK_K": 512, "NUM_STAGES": 2})],
+     key=["M", "K"],
+ )
+ @triton.jit
+ def _kernel_dequantize_fp8_row(
+     xq_ptr,
+     x_scale_ptr,
+     x_dequant_ptr,
+     M,
+     K,
+     stride_xm,
+     stride_xk,
+     stride_xdqm,
+     stride_xdqk,
+     BLOCK_M: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+     USE_INT64: tl.constexpr,
+ ):
+     """
+     Kernel to dequantize an FP8 tensor to a BF16 tensor.
+     Args:
+         xq_ptr (tl.constexpr): Pointer to FP8 tensor.
+         x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
+         x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
+         M (tl.constexpr): M dimension of input tensor.
+         K (tl.constexpr): K dimension of input tensor (along which scales are applied).
+         BLOCK_M (tl.constexpr): Block size for the M dimension.
+         BLOCK_K (tl.constexpr): Block size for the K dimension.
+         NUM_STAGES (tl.constexpr): Number of pipeline stages for the K loop.
+         USE_INT64 (tl.constexpr): Whether to compute offsets in int64.
+     """
+     pid = tl.program_id(axis=0)
+     if USE_INT64:
+         pid = pid.to(tl.int64)
+     offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_k = tl.arange(0, BLOCK_K)
+     scales = tl.load(x_scale_ptr + offs_m)
+
+     for _k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
+         mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+         xq = tl.load(
+             xq_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
+             mask=mask,
+         )
+         x_dq = xq * scales[:, None]
+         tl.store(
+             x_dequant_ptr
+             + offs_m[:, None] * stride_xdqm
+             + offs_k[None, :] * stride_xdqk,
+             x_dq,
+             mask=mask,
+         )
+         offs_k += BLOCK_K
+
+
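The autotuner above carries a single candidate config; a broader search space can be registered the same way. A sketch under the assumption that these extra tile shapes are worth benchmarking (they are illustrative, not tuned values):

from triton import Config  # already imported at the top of this module

# Hypothetical wider search space for the row-dequantize kernel; each entry supplies
# the same constexpr names (BLOCK_M, BLOCK_K, NUM_STAGES) the kernel declares.
_ROW_DEQUANT_CONFIGS = [
    Config({"BLOCK_M": 16, "BLOCK_K": 512, "NUM_STAGES": 2}),
    Config({"BLOCK_M": 32, "BLOCK_K": 256, "NUM_STAGES": 2}),
    Config({"BLOCK_M": 8, "BLOCK_K": 1024, "NUM_STAGES": 3}),
]
# @triton.autotune(configs=_ROW_DEQUANT_CONFIGS, key=["M", "K"]) would benchmark each
# candidate once per distinct (M, K) and cache the fastest choice.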
+ def dequantize_fp8_row(
+     xq: torch.Tensor,
+     x_scale: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Rowwise dequantize an FP8 tensor to a BF16 tensor along the last axis.
+
+     Args:
+         xq (torch.Tensor): FP8 tensor to be dequantized.
+         x_scale (torch.Tensor): FP8 scale tensor.
+
+     Returns:
+         torch.Tensor: Dequantized BF16 tensor.
+     """
+
+     assert xq.is_contiguous() and x_scale.is_contiguous(), (
+         "Input tensors must be contiguous"
+     )
+     x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
+
+     # Reshape to a 2-d array keeping the last dim only.
+     K = xq.shape[-1]
+     xq = xq.reshape(-1, K)
+     M = xq.shape[0]
+     use_int64 = xq.numel() > 2**31
+
+     def grid(meta: Dict[str, int]) -> Tuple[int]:
+         return (triton.cdiv(M, meta["BLOCK_M"]),)
+
+     with torch.cuda.device(xq.device.index):
+         _kernel_dequantize_fp8_row[grid](
+             xq,
+             x_scale,
+             x_dequant,
+             M,
+             K,
+             xq.stride(0),
+             xq.stride(1),
+             xq.stride(0),  # Use squashed stride.
+             xq.stride(1),
+             USE_INT64=use_int64,
+         )
+     return x_dequant
+
+
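A plain PyTorch broadcast gives a convenient reference for this kernel; a minimal sketch, assuming x_scale holds one reciprocal scale per flattened row as the kernel above expects:

import torch

def dequantize_fp8_row_ref(xq: torch.Tensor, x_scale: torch.Tensor) -> torch.Tensor:
    # Reference path: upcast and broadcast the per-row reciprocal scale over the last axis.
    return (xq.to(torch.float32) * x_scale.reshape(xq.shape[:-1] + (1,))).to(torch.bfloat16)

# y = dequantize_fp8_row(xq, x_scale)
# torch.testing.assert_close(y.float(), dequantize_fp8_row_ref(xq, x_scale).float())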
+ @triton.autotune(
+     configs=[Config({"BLOCK_M": 16, "BLOCK_K": 512, "NUM_STAGES": 2})],
+     key=["M", "K"],
+ )
+ @triton.jit
+ def _kernel_dequantize_fp8_packed_row(
+     xq_ptr,
+     x_scale_ptr,
+     x_dequant_ptr,
+     M,
+     K,
+     stride_xm,
+     stride_xk,
+     stride_xdqm,
+     stride_xdqk,
+     BLOCK_M: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     NUM_STAGES: tl.constexpr,
+     USE_INT64: tl.constexpr,
+ ):
+     """
+     Kernel to dequantize a packed-row FP8 tensor to a BF16 tensor.
+     Args:
+         xq_ptr (tl.constexpr): Pointer to FP8 tensor (scale bytes already stripped).
+         x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
+         x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
+         M (tl.constexpr): M dimension of input tensor.
+         K (tl.constexpr): K dimension of input tensor (along which scales are applied).
+         BLOCK_M (tl.constexpr): Block size for the M dimension.
+         BLOCK_K (tl.constexpr): Block size for the K dimension.
+         NUM_STAGES (tl.constexpr): Number of pipeline stages for the K loop.
+         USE_INT64 (tl.constexpr): Whether to compute offsets in int64.
+     """
+     pid = tl.program_id(axis=0)
+     if USE_INT64:
+         pid = pid.to(tl.int64)
+     offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_k = tl.arange(0, BLOCK_K)
+     scales = tl.load(x_scale_ptr + offs_m)
+
+     for _k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
+         mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+
+         xq = tl.load(
+             xq_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
+             mask=mask,
+             other=0.0,
+         )
+         x_dq = xq * scales[:, None]
+
+         tl.store(
+             x_dequant_ptr
+             + offs_m[:, None] * stride_xdqm
+             + offs_k[None, :] * stride_xdqk,
+             x_dq,
+             mask=mask,
+         )
+         offs_k += BLOCK_K
+
+
+ def dequantize_fp8_packed_row(
+     xq: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Rowwise dequantize a packed FP8 tensor to a BF16 tensor along the last axis.
+
+     Args:
+         xq (torch.Tensor): Packed FP8 tensor to be dequantized. The last 4 bytes of each row are the FP32 scale for that row.
+
+     Returns:
+         torch.Tensor: Dequantized BF16 tensor.
+     """
+
+     # Create views of the packed tensor to separate the scale from the actual xq data.
+     # This makes it much easier to write the kernel.
+     orig_shape = (*xq.shape[:-1], xq.shape[-1] - 4)
+     actual_xq = xq[..., :-4].view(orig_shape)
+
+     assert xq.is_contiguous(), "Input tensors must be contiguous"
+     x_dequant = torch.empty(orig_shape, dtype=torch.bfloat16, device=xq.device)
+
+     # Calculate the number of rows when flattened.
+     num_rows = actual_xq.numel() // actual_xq.shape[-1]
+
+     # TODO: we take a perf hit from these reshapes, can we do better?
+     # It's hard to skip this reshape: we can't create an int32/float32 view directly because of alignment issues.
+     scale_view = xq[..., -4:].reshape((num_rows * 4)).view(torch.float32)
+     scale_view = scale_view.view(orig_shape[:-1])
+
+     # Reshape to a 2-d array keeping the last dim only.
+     K = actual_xq.shape[-1]
+     actual_xq = actual_xq.reshape(-1, K)
+     M = actual_xq.shape[0]
+     use_int64 = actual_xq.numel() > 2**31
+
+     def grid(meta: Dict[str, int]) -> Tuple[int]:
+         return (triton.cdiv(M, meta["BLOCK_M"]),)
+
+     with torch.cuda.device(actual_xq.device.index):
+         _kernel_dequantize_fp8_packed_row[grid](
+             actual_xq,
+             scale_view,
+             x_dequant,
+             M,
+             K,
+             actual_xq.stride(0),
+             actual_xq.stride(1),
+             x_dequant.stride(-2),  # Use squashed stride.
+             x_dequant.stride(-1),
+             USE_INT64=use_int64,
+         )
+
+     return x_dequant
+
+
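The packed layout appends each row's FP32 scale, reinterpreted as 4 bytes, to that row's FP8 payload. A minimal hand-built packing sketch (the fp8 dtype, shapes, and scale values are assumptions; get_fp8_constants()[0] is the authoritative dtype):

import torch

M, K = 4, 64
fp8_dtype = torch.float8_e4m3fn  # assumed fp8 dtype for illustration

payload = torch.randn(M, K, device="cuda").to(fp8_dtype)           # [M, K] fp8 payload
scales = torch.rand(M, device="cuda", dtype=torch.float32) + 0.5   # per-row fp32 scales

# Reinterpret payload and scales as raw bytes, append 4 scale bytes to each row,
# then view the [M, K + 4] byte buffer back as the fp8 dtype for the kernel.
packed_bytes = torch.cat(
    [payload.view(torch.uint8), scales.view(torch.uint8).view(M, 4)], dim=1
).contiguous()
packed = packed_bytes.view(fp8_dtype)

out = dequantize_fp8_packed_row(packed)  # [M, K] bf16, approx payload * scales[:, None]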
+ @triton.jit
+ def _kernel_quantize_fp8_tensor(
+     A,
+     A_fp8,
+     global_max_ptr,
+     blocks_done_ptr,
+     scale_ready_ptr,
+     scale_out_ptr,
+     N,
+     num_sms,
+     TL_FP8_DTYPE: tl.constexpr,
+     MAX_FP8: tl.constexpr,
+     EPS: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+ ) -> None:
+     """Fused persistent kernel that finds global max and quantizes.
+
+     Uses a persistent kernel approach where we launch exactly num_sms blocks,
+     guaranteeing all blocks run concurrently and avoiding deadlocks.
+     Each block processes multiple chunks of the input in a loop.
+
+     Args:
+         A (Tensor): Flattened input tensor.
+         A_fp8 (Tensor): Output fp8 tensor.
+         global_max_ptr (Tensor): Pointer to global max value (initialized to 0).
+         blocks_done_ptr (Tensor): Pointer to atomic counter (initialized to 0).
+         scale_ready_ptr (Tensor): Pointer to ready flag (initialized to 0).
+         scale_out_ptr (Tensor): Pointer to output scale value.
+         N (int): Total number of elements.
+         num_sms (int): Number of SMs (equals number of blocks launched).
+         TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
+         MAX_FP8 (float): Maximum expressible value for FP8.
+         EPS (float): Epsilon for numerical stability.
+         BLOCK_SIZE (int): Block size for processing.
+     """
+     pid = tl.program_id(0)
+
+     # Phase 1: Each block finds max across all its assigned chunks
+     local_max = 0.0
+     chunk_id = pid
+     num_chunks = tl.cdiv(N, BLOCK_SIZE)
+
+     while chunk_id < num_chunks:
+         offset = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+         a = tl.load(A + offset, mask=offset < N, other=0.0)
+         chunk_max = tl.max(tl.abs(a))
+         local_max = tl.maximum(local_max, chunk_max)
+         chunk_id += num_sms
+
+     # Atomically update global max using integer atomics on float bits
+     local_max_int = local_max.to(tl.float32, bitcast=False).to(tl.int32, bitcast=True)
+     tl.atomic_max(global_max_ptr, local_max_int)
+
+     # Increment completed block counter
+     old_count = tl.atomic_add(blocks_done_ptr, 1)
+
+     # Last block to finish computes the scale
+     if old_count == num_sms - 1:
+         global_max_int = tl.load(global_max_ptr)
+         global_max_float = global_max_int.to(tl.float32, bitcast=True)
+         global_max_float = tl.maximum(global_max_float, EPS)
+         scale = tl.div_rn(global_max_float, MAX_FP8)
+         tl.store(scale_out_ptr, scale)
+         tl.atomic_xchg(scale_ready_ptr, 1)
+
+     # Phase 2: Spin-wait for scale to be ready
+     # Safe because all num_sms blocks are guaranteed to be running
+     while tl.atomic_add(scale_ready_ptr, 0) == 0:
+         pass
+
+     # Load scale and quantize all assigned chunks
+     scale = tl.load(scale_out_ptr)
+     chunk_id = pid
+
+     while chunk_id < num_chunks:
+         offset = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+         a = tl.load(A + offset, mask=offset < N, other=0.0)
+         a_fp8 = a * tl.div_rn(1.0, scale)
+         a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+         tl.store(A_fp8 + offset, a_fp8, mask=offset < N)
+         chunk_id += num_sms
+
+
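The atomic max works because, for non-negative IEEE-754 floats, comparing the raw bit patterns as 32-bit integers gives the same ordering as comparing the floats, so tl.atomic_max on the bitcast value tracks the true maximum. A small standalone illustration of that invariant:

import random
import struct

def f32_bits(x: float) -> int:
    # Round to float32 and reinterpret the raw 32-bit pattern as a signed int.
    return struct.unpack("<i", struct.pack("<f", x))[0]

def f32(x: float) -> float:
    return struct.unpack("<f", struct.pack("<f", x))[0]

vals = [f32(random.uniform(0.0, 1e6)) for _ in range(200)] + [0.0]
# For non-negative float32 values (sign bit clear), integer comparison of the bit
# patterns agrees with float comparison, so an integer atomic max finds the float max.
assert all((f32_bits(a) < f32_bits(b)) == (a < b) for a in vals for b in vals)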
+ def _get_num_sms(device: torch.device) -> int:
+     """Get the number of SMs on the current GPU device."""
+     return torch.cuda.get_device_properties(device).multi_processor_count
+
+
+ def triton_quantize_fp8_tensor(
+     a: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Triton implementation to quantize a tensor to fp8 with a single scale.
+
+     Uses a fused persistent kernel with atomic operations for inter-block
+     coordination. By launching exactly num_sms blocks, we guarantee all
+     blocks run concurrently, avoiding deadlocks from spin-waiting.
+
+     Args:
+         a (Tensor): Input tensor to be quantized.
+
+     Returns:
+         torch.Tensor: fp8 quantized tensor.
+         torch.Tensor: scalar reciprocal scale tensor (fp32).
+     """
+     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+     N = a.numel()
+
+     BLOCK_SIZE = 4096
+     # Launch exactly num_sms blocks to guarantee concurrent execution
+     num_sms = _get_num_sms(a.device)
+
+     # Allocate synchronization buffers (initialized to 0)
+     global_max = torch.zeros(1, device=a.device, dtype=torch.int32)
+     blocks_done = torch.zeros(1, device=a.device, dtype=torch.int32)
+     scale_ready = torch.zeros(1, device=a.device, dtype=torch.int32)
+     scale_out = torch.empty((), device=a.device, dtype=torch.float32)
+
+     # Output tensor matches shape of a but is contiguous.
+     a_fp8 = torch.empty_like(a, dtype=pt_dtype)
+
+     with torch.cuda.device(a.device.index):
+         _kernel_quantize_fp8_tensor[(num_sms,)](
+             a,
+             a_fp8,
+             global_max,
+             blocks_done,
+             scale_ready,
+             scale_out,
+             N,
+             num_sms,
+             # pyre-ignore[6]: Incompatible parameter type
+             TL_FP8_DTYPE=tl_dtype,
+             # pyre-ignore[6]: Incompatible parameter type
+             MAX_FP8=max_fp8,
+             # pyre-ignore[6]: Incompatible parameter type
+             EPS=eps,
+             # pyre-ignore[6]: Incompatible parameter type
+             BLOCK_SIZE=BLOCK_SIZE,
+         )
+
+     return a_fp8, scale_out
+
+
+ @torch.library.custom_op("triton::quantize_fp8_tensor", mutates_args=())
+ def quantize_fp8_tensor(
+     a: torch.Tensor,
+     use_triton: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Quantize a tensor to fp8 with a single scale factor across the entire tensor.
+
+     The scale is computed as MAX_FP8 / max(abs(a)) and applied uniformly.
+     Handles non-contiguous input tensors and returns a contiguous output.
+
+     Args:
+         a (Tensor): Input tensor of any shape. May be non-contiguous.
+         use_triton (bool): Whether to use the optimized triton kernel.
+
+     Returns:
+         torch.Tensor: fp8 quantized tensor (contiguous, same shape as input).
+         torch.Tensor: scalar reciprocal scale tensor (fp32).
+     """
+     if a.device == torch.device("cpu"):
+         use_triton = False
+
+     if use_triton:
+         a_fp8, reciprocal_scale = triton_quantize_fp8_tensor(a)
+         return a_fp8, reciprocal_scale
+
+     # Fallback to PyTorch implementation
+     pt_dtype, _, max_fp8, eps = get_fp8_constants()
+
+     tensor_max = torch.max(torch.abs(a)).to(torch.float32)
+     tensor_max = torch.clamp(tensor_max, min=eps)
+
+     scale = max_fp8 / tensor_max  # pyre-ignore[58]
+     a_scaled = a.to(torch.float32) * scale
+     a_scaled = torch.clamp(a_scaled, -max_fp8, max_fp8)
+     a_fp8 = a_scaled.to(pt_dtype)
+
+     reciprocal_scale = (1.0 / scale).to(torch.float32)  # pyre-ignore[16]
+
+     return a_fp8, reciprocal_scale
+
+
+ @quantize_fp8_tensor.register_fake
+ def quantize_fp8_tensor_meta(
+     a: torch.Tensor,
+     use_triton: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Shape function for torch compile."""
+     dtype = get_fp8_constants()[0]
+     # Preserve memory format (e.g., channels_last_3d) from input tensor
+     fake_out = torch.empty_like(a, dtype=dtype)
+     fake_scale = torch.empty((), device=a.device, dtype=torch.float32)
+     return fake_out, fake_scale
+
+
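Because the op is registered with a fake implementation, it traces cleanly under torch.compile. A minimal usage sketch (illustrative shapes; assumes a CUDA build with fp8 support):

import torch

def fp8_roundtrip(a: torch.Tensor) -> torch.Tensor:
    a_fp8, rscale = quantize_fp8_tensor(a)
    # Dequantize by multiplying with the scalar reciprocal scale.
    return a_fp8.to(torch.float32) * rscale

a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
eager_out = fp8_roundtrip(a)
compiled_out = torch.compile(fp8_roundtrip)(a)  # traceable thanks to the fake registration
torch.testing.assert_close(eager_out, compiled_out)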
+ @triton.jit
+ def _kernel_dequantize_fp8_block(
+     xq_ptr,
+     x_scale_ptr,
+     x_dequant_ptr,
+     M,
+     K,
+     BLOCK_M: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+ ):
+     """
+     Kernel to dequantize FP8 tensor to BF16 tensor.
+     Args:
+         xq_ptr (tl.constexpr): Pointer to FP8 tensor.
+         x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
+         x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
+         M (tl.constexpr): M dimension of input tensor.
+         K (tl.constexpr): K dimension of input tensor.
+         BLOCK_M (tl.constexpr): Block size for the M dimension.
+         BLOCK_K (tl.constexpr): Block size for the K dimension.
+     """
+     pid_m = tl.program_id(axis=0)
+     pid_k = tl.program_id(axis=1)
+     k = tl.cdiv(K, BLOCK_K)
+     offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
+     offs = offs_m[:, None] * K + offs_k[None, :]
+     mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+     xq = tl.load(xq_ptr + offs, mask=mask).to(tl.bfloat16)
+     x_scale = tl.load(x_scale_ptr + pid_m * k + pid_k)
+     x_dequant = xq * x_scale
+     tl.store(x_dequant_ptr + offs, x_dequant, mask=mask)
+
+
+ def dequantize_fp8_block(
+     xq: torch.Tensor,
+     x_scale: torch.Tensor,
+     block_m: int = 256,
+     block_k: int = 256,
+ ) -> torch.Tensor:
+     """
+     Dequantize FP8 tensor to BF16 tensor.
+
+     Args:
+         xq (torch.Tensor): FP8 tensor to be dequantized.
+         x_scale (torch.Tensor): FP8 scale tensor.
+         block_m (int): Block size for the M dimension.
+         block_k (int): Block size for the K dimension.
+
+     Returns:
+         torch.Tensor: Dequantized BF16 tensor.
+     """
+
+     assert xq.is_contiguous() and x_scale.is_contiguous(), (
+         "Input tensors must be contiguous"
+     )
+     assert xq.dim() == 2 and x_scale.dim() == 2, "Input tensors must have 2 dimensions"
+     M, K = xq.size()
+     x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
+
+     def grid(meta: Dict[str, int]) -> Tuple[int, int]:
+         return (
+             triton.cdiv(M, meta["BLOCK_M"]),
+             triton.cdiv(K, meta["BLOCK_K"]),
+         )
+
+     with torch.cuda.device(xq.device.index):
+         _kernel_dequantize_fp8_block[grid](
+             xq,
+             x_scale,
+             x_dequant,
+             M,
+             K,
+             BLOCK_M=block_m,  # pyre-ignore[6]
+             BLOCK_K=block_k,  # pyre-ignore[6]
+         )
+     return x_dequant
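A small round-trip sketch of the block-wise dequantizer; the block scale tensor is built by hand here with shape [cdiv(M, block_m), cdiv(K, block_k)], and the fp8 dtype and scale values are illustrative assumptions:

import torch

M, K, block = 512, 512, 256
fp8_dtype = torch.float8_e4m3fn  # assumed; get_fp8_constants()[0] is the real source of truth

xq = torch.randn(M, K, device="cuda").to(fp8_dtype)
# One reciprocal scale per (block_m x block_k) tile: a 2 x 2 grid of blocks here.
x_scale = torch.full((M // block, K // block), 0.25, device="cuda", dtype=torch.float32)

x_bf16 = dequantize_fp8_block(xq, x_scale, block_m=block, block_k=block)
print(x_bf16.shape, x_bf16.dtype)  # torch.Size([512, 512]) torch.bfloat16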