mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. mslk/__init__.py +56 -0
  2. mslk/attention/__init__.py +7 -0
  3. mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
  4. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
  5. mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
  6. mslk/attention/flash_attn/__init__.py +22 -0
  7. mslk/attention/flash_attn/ampere_helpers.py +104 -0
  8. mslk/attention/flash_attn/barrier.py +72 -0
  9. mslk/attention/flash_attn/benchmark.py +269 -0
  10. mslk/attention/flash_attn/blackwell_helpers.py +754 -0
  11. mslk/attention/flash_attn/block_info.py +109 -0
  12. mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
  13. mslk/attention/flash_attn/block_sparsity.py +219 -0
  14. mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
  15. mslk/attention/flash_attn/copy_utils.py +341 -0
  16. mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
  17. mslk/attention/flash_attn/fast_math.py +22 -0
  18. mslk/attention/flash_attn/flash_bwd.py +1262 -0
  19. mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
  20. mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
  21. mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
  22. mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
  23. mslk/attention/flash_attn/flash_fwd.py +2471 -0
  24. mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
  25. mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
  26. mslk/attention/flash_attn/hopper_helpers.py +102 -0
  27. mslk/attention/flash_attn/interface.py +1771 -0
  28. mslk/attention/flash_attn/mask.py +610 -0
  29. mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
  30. mslk/attention/flash_attn/named_barrier.py +32 -0
  31. mslk/attention/flash_attn/pack_gqa.py +165 -0
  32. mslk/attention/flash_attn/paged_kv.py +176 -0
  33. mslk/attention/flash_attn/pipeline.py +273 -0
  34. mslk/attention/flash_attn/seqlen_info.py +139 -0
  35. mslk/attention/flash_attn/softmax.py +583 -0
  36. mslk/attention/flash_attn/testing.py +424 -0
  37. mslk/attention/flash_attn/tile_scheduler.py +720 -0
  38. mslk/attention/flash_attn/utils.py +860 -0
  39. mslk/attention/fmha/__init__.py +967 -0
  40. mslk/attention/fmha/_triton/__init__.py +6 -0
  41. mslk/attention/fmha/_triton/available.py +50 -0
  42. mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
  43. mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
  44. mslk/attention/fmha/attn_bias.py +2186 -0
  45. mslk/attention/fmha/attn_bias_utils.py +536 -0
  46. mslk/attention/fmha/ck.py +508 -0
  47. mslk/attention/fmha/ck_decoder.py +141 -0
  48. mslk/attention/fmha/ck_splitk.py +204 -0
  49. mslk/attention/fmha/common.py +598 -0
  50. mslk/attention/fmha/cutlass.py +461 -0
  51. mslk/attention/fmha/cutlass_blackwell.py +560 -0
  52. mslk/attention/fmha/dispatch.py +224 -0
  53. mslk/attention/fmha/flash.py +862 -0
  54. mslk/attention/fmha/flash3.py +858 -0
  55. mslk/attention/fmha/flash_mtia.py +245 -0
  56. mslk/attention/fmha/merge_training.py +192 -0
  57. mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
  58. mslk/attention/fmha/torch_attention_compat.py +154 -0
  59. mslk/attention/fmha/tree_attention.py +718 -0
  60. mslk/attention/fmha/triton_splitk.py +1378 -0
  61. mslk/attention/fmha/unbind.py +130 -0
  62. mslk/attention/fmha/utils/__init__.py +6 -0
  63. mslk/attention/fmha/utils/bench.py +74 -0
  64. mslk/attention/fmha/utils/cpp_lib.py +148 -0
  65. mslk/attention/fmha/utils/op_common.py +65 -0
  66. mslk/attention/gqa_attn_splitk/__init__.py +11 -0
  67. mslk/bench/comm/__init__.py +7 -0
  68. mslk/bench/comm/comm_bench.py +255 -0
  69. mslk/bench/common/__init__.py +5 -0
  70. mslk/bench/common/utils.py +148 -0
  71. mslk/bench/conv/__init__.py +7 -0
  72. mslk/bench/conv/conv_bench.py +551 -0
  73. mslk/bench/conv/conv_ops.py +213 -0
  74. mslk/bench/gemm/__init__.py +7 -0
  75. mslk/bench/gemm/gemm_bench.py +859 -0
  76. mslk/bench/gemm/gemm_ops.py +3342 -0
  77. mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
  78. mslk/bench/moe/__init__.py +7 -0
  79. mslk/bench/moe/gather_scatter_bench.py +356 -0
  80. mslk/bench/quantize/quantize_bench.py +345 -0
  81. mslk/bench/quantize/quantize_ops.py +266 -0
  82. mslk/comm/__init__.py +11 -0
  83. mslk/conv/__init__.py +11 -0
  84. mslk/gemm/__init__.py +18 -0
  85. mslk/gemm/triton/__init__.py +7 -0
  86. mslk/gemm/triton/fp8_gemm.py +2702 -0
  87. mslk/gemm/triton/grouped_gemm.py +1132 -0
  88. mslk/gemm/triton/matmul_perf_model.py +237 -0
  89. mslk/gemm/triton/utils.py +128 -0
  90. mslk/kv_cache/__init__.py +11 -0
  91. mslk/moe/__init__.py +26 -0
  92. mslk/moe/activation.py +291 -0
  93. mslk/moe/gather_scatter.py +739 -0
  94. mslk/moe/layers.py +1240 -0
  95. mslk/moe/shuffling.py +421 -0
  96. mslk/mslk.so +0 -0
  97. mslk/quantize/__init__.py +11 -0
  98. mslk/quantize/shuffle.py +306 -0
  99. mslk/quantize/triton/__init__.py +7 -0
  100. mslk/quantize/triton/fp4_quantize.py +5942 -0
  101. mslk/quantize/triton/fp8_quantize.py +1902 -0
  102. mslk/testing/__init__.py +7 -0
  103. mslk/testing/attributes.py +60 -0
  104. mslk/testing/rocm.py +91 -0
  105. mslk/utils/__init__.py +7 -0
  106. mslk/utils/torch/__init__.py +7 -0
  107. mslk/utils/torch/library.py +150 -0
  108. mslk/utils/triton/__init__.py +7 -0
  109. mslk/utils/triton/fp8_utils.py +72 -0
  110. mslk/utils/triton/utils.py +128 -0
  111. mslk/version.py +11 -0
  112. mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
  113. mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
  114. mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
  115. mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
  116. mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
mslk/attention/fmha/_triton/splitk_kernels.py
@@ -0,0 +1,1534 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import functools
8
+ import sys
9
+ from typing import Callable, Dict, Tuple, Union
10
+
11
+ import torch
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ from .vararg_kernel import unroll_varargs, VAR_ARGS_ARRAY
16
+
17
+ # pyre-ignore-all-errors
18
+ AUTOTUNER_KEY = [
19
+ "Z",
20
+ "H",
21
+ "G",
22
+ "N_CTX_Q",
23
+ "N_CTX_K",
24
+ "BLOCK_DMODEL",
25
+ "PACKED_PER_VAL",
26
+ "N_GROUPS",
27
+ "BLOCK_N_PER_SPLIT",
28
+ "PAGE_SIZE",
29
+ ]
30
+
31
+
32
+ @triton.jit
33
+ def _fwd_kernel_splitK( # noqa: C901
34
+ Q,
35
+ K,
36
+ V,
37
+ sm_scale,
38
+ Out_splitK, # [B, H, split_k, Mq, K]
39
+ LSE_splitk, # [B, H, split_k, Mq]
40
+ block_tables,
41
+ Seq_len,
42
+ Seq_starts_k,
43
+ Seq_starts_q,
44
+ Seq_starts_q_multiplier,
45
+ additive_bias,
46
+ K_fp8_scale_shift,
47
+ V_fp8_scale_shift,
48
+ q_fp8_scale_shift,
49
+ stride_qz,
50
+ stride_qm,
51
+ stride_qg,
52
+ stride_qh,
53
+ stride_qk,
54
+ stride_kz,
55
+ stride_kn,
56
+ stride_kg,
57
+ stride_kh,
58
+ stride_kk,
59
+ stride_vz,
60
+ stride_vn,
61
+ stride_vg,
62
+ stride_vh,
63
+ stride_vk,
64
+ stride_osk_z,
65
+ stride_osk_g,
66
+ stride_osk_h,
67
+ stride_osk_s,
68
+ stride_osk_m,
69
+ stride_osk_k,
70
+ stride_lsek_z,
71
+ stride_lsek_g,
72
+ stride_lsek_h,
73
+ stride_lsek_s,
74
+ stride_lsek_m,
75
+ stride_blocktablesz,
76
+ stride_blocktablesl,
77
+ stride_bias_b,
78
+ stride_bias_g,
79
+ stride_bias_h,
80
+ stride_bias_qm,
81
+ stride_bias_km,
82
+ stride_k_fp8_scale_shift_z: tl.constexpr,
83
+ stride_k_fp8_scale_shift_n: tl.constexpr,
84
+ stride_k_fp8_scale_shift_g: tl.constexpr,
85
+ stride_k_fp8_scale_shift_h: tl.constexpr,
86
+ stride_v_fp8_scale_shift_z: tl.constexpr,
87
+ stride_v_fp8_scale_shift_n: tl.constexpr,
88
+ stride_v_fp8_scale_shift_g: tl.constexpr,
89
+ stride_v_fp8_scale_shift_h: tl.constexpr,
90
+ stride_q_fp8_scale_shift_z: tl.constexpr,
91
+ stride_q_fp8_scale_shift_m: tl.constexpr,
92
+ stride_q_fp8_scale_shift_g: tl.constexpr,
93
+ stride_q_fp8_scale_shift_h: tl.constexpr,
94
+ kv_cache_blocks_per_row: tl.constexpr,
95
+ Z: tl.constexpr,
96
+ N_CTX_Q: tl.constexpr, # The number of queries
97
+ N_CTX_K: tl.constexpr,
98
+ BLOCK_N_PER_SPLIT: tl.constexpr,
99
+ H: tl.constexpr,
100
+ G: tl.constexpr,
101
+ BLOCK_DMODEL: tl.constexpr,
102
+ USE_SEQ_LEN: tl.constexpr,
103
+ PACKED_PER_VAL: tl.constexpr,
104
+ N_GROUPS: tl.constexpr,
105
+ # It's important that BOUNDS_CHECKS_N, BLOCK_M, BLOCK_N come at the end of
106
+ # the argument list, since they are provided by the heuristics/autotune decorator.
107
+ # Otherwise Triton throws IndexError
108
+ BOUNDS_CHECKS_N: tl.constexpr,
109
+ BLOCK_M: tl.constexpr,
110
+ BLOCK_N: tl.constexpr,
111
+ IS_SPLITK: tl.constexpr,
112
+ SPLIT_K_EARLY_EXIT: tl.constexpr,
113
+ IS_CAUSAL: tl.constexpr,
114
+ IS_LOCAL: tl.constexpr,
115
+ NUM_QUERIES_CAUSAL: tl.constexpr, # The N_CTX_Q queries are from this many sequence positions
116
+ USE_PAGED_ATTENTION: tl.constexpr,
117
+ PAGE_SIZE: tl.constexpr,
118
+ WINDOW_LEFT: tl.constexpr,
119
+ WINDOW_RIGHT: tl.constexpr,
120
+ WRITE_LSE: tl.constexpr,
121
+ HAS_ADDITIVE_BIAS: tl.constexpr,
122
+ NUM_PROGRAMS_DIM2_CONST: tl.constexpr,
123
+ IS_HIP: tl.constexpr,
124
+ QUANTIZE_PV_TO_FP8: tl.constexpr,
125
+ QUANTIZE_QK_TO_FP8: tl.constexpr,
126
+ USE_FP32_SCALES: tl.constexpr,
127
+ ):
128
+ """This kernel can accept non-quantized or int4-quantized keys/values.
129
+ PACKED_PER_VAL determines the quantization type:
130
+ - PACKED_PER_VAL == 1 means no quantization
131
+ - PACKED_PER_VAL == 4 means fp8 quantization (4 packed quantized values inside one int32)
+ - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32)
132
+ For the quantized case K/V should be int32 tensors.
133
+ Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8.
134
+ Quantization coefficients are stored at the beginning of the row along the last dimension of K/V,
+ so K[B, H, M, :] has the form
136
+ [ quant_coef0, quant_coef1, ...|
137
+ group0_quant_value0, group0_quant_value1,... |
138
+ group1_quant_value0, group1_quant_value1,...]
139
+ where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset.
140
+
141
+ Note: this kernel needs to be processed by unroll_varargs
142
+ before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists.
143
+ See how FwOp.apply does it below.
144
+
145
+ Set IS_SPLITK=False to indicate the MHA result should be written directly.
146
+ No metadata will be written.
147
+ """
148
+ internal_dtype = (
149
+ tl.float64 if Out_splitK.dtype.element_ty is tl.float64 else tl.float32
150
+ )
151
+ tl.static_assert(
152
+ (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
153
+ or (
154
+ (PACKED_PER_VAL == 4 or PACKED_PER_VAL == 8)
155
+ and tl.constexpr(K.dtype.element_ty == tl.int32)
156
+ ),
157
+ f"Only int4 and fp8 quantization is supported, K/V should have dtype int32 in "
158
+ f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}",
159
+ )
160
+ tl.static_assert(
161
+ (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8),
162
+ "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.",
163
+ )
164
+ tl.static_assert(
165
+ N_GROUPS == 1 or K_fp8_scale_shift is None,
166
+ f"Only row-wise fp8 quantization is supported, but got {N_GROUPS=} > 1.",
167
+ )
168
+ FP8_QUANTIZED: tl.constexpr = K_fp8_scale_shift is not None
169
+ INT4_QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 and not FP8_QUANTIZED
170
+ PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS
171
+ D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS
172
+
173
+ start_m = tl.program_id(0)
174
+ off_zhg = tl.program_id(1)
175
+ off_z = (off_zhg // (H * G)).to(tl.int64)
176
+ off_hg = off_zhg % (H * G)
177
+ off_h = off_hg // G
178
+ off_g = off_hg % G
179
+ splitk_idx = tl.program_id(2)
180
+
181
+ if USE_SEQ_LEN:
182
+ kv_len = tl.load(Seq_len + off_z)
183
+ if SPLIT_K_EARLY_EXIT and kv_len == 0:
184
+ return
185
+ else:
186
+ kv_len = N_CTX_K
187
+
188
+ if Seq_starts_k is None:
189
+ start_kv_idx = 0
190
+ else:
191
+ start_kv_idx = tl.load(Seq_starts_k + off_z)
192
+ if USE_SEQ_LEN and PAGE_SIZE > 0:
193
+ # gappy with paged attention stores each "end" instead of each "length"
194
+ # because that's what FA3 needs.
195
+ kv_len -= start_kv_idx
196
+
197
+ if Seq_starts_q is None:
198
+ q_len = N_CTX_Q
199
+ queries_use_batch_dim = 1
200
+ off_m = 0
201
+ else:
202
+ queries_use_batch_dim = 0
203
+ off_m = tl.load(Seq_starts_q + off_z) * Seq_starts_q_multiplier
204
+ q_len = tl.load(Seq_starts_q + off_z + 1) * Seq_starts_q_multiplier - off_m
205
+ if q_len == 0:
206
+ return
207
+
208
+ k_base = K + off_h * stride_kh + off_g * stride_kg
209
+ v_base = V + off_h * stride_vh + off_g * stride_vg
210
+
211
+ if FP8_QUANTIZED:
212
+ k_fp8_scale_shift_base = (
213
+ K_fp8_scale_shift
214
+ + off_h * stride_k_fp8_scale_shift_h
215
+ + off_g * stride_k_fp8_scale_shift_g
216
+ )
217
+ v_fp8_scale_shift_base = (
218
+ V_fp8_scale_shift
219
+ + off_h * stride_v_fp8_scale_shift_h
220
+ + off_g * stride_v_fp8_scale_shift_g
221
+ )
222
+ else:
223
+ k_fp8_scale_shift_base = None
224
+ v_fp8_scale_shift_base = None
225
+
226
+ # Boundaries of split-k chunk
227
+ chunk_hi = (splitk_idx + 1) * BLOCK_N_PER_SPLIT
228
+ chunk_lo = splitk_idx * BLOCK_N_PER_SPLIT
229
+ ignore_in_first_block = 0
230
+ # In the paged attention case, K/V_block_ptr are defined inside the loop,
+ # whereas in the non-paged case they are defined before the loop.
232
+ if PAGE_SIZE > 0:
233
+ # Page contains several blocks
234
+ BLOCKS_IN_PAGE: tl.constexpr = PAGE_SIZE // BLOCK_N
235
+ # Align boundaries of split-k chunk to block boundaries
236
+ # In the last chunk, shift hi to the right, in the other chunks, shift it to the left
237
+ # TODO: Replace NUM_PROGRAMS_DIM2_CONST with tl.num_programs(2) after
238
+ # the next Triton upgrade.
239
+ is_last_chunk = splitk_idx == NUM_PROGRAMS_DIM2_CONST - 1
240
+ shift = BLOCK_N - 1 if is_last_chunk else 0
241
+ lo = (tl.maximum(chunk_lo, start_kv_idx) // BLOCK_N) * BLOCK_N
242
+ ignore_in_first_block = tl.maximum(0, (start_kv_idx - lo))
243
+ hi = ((chunk_hi + shift) // BLOCK_N) * BLOCK_N
244
+ hi = tl.minimum(hi, kv_len + start_kv_idx)
245
+ block_table = block_tables + stride_blocktablesz * off_z
246
+ # Offset in integer blocks
247
+ logical_block_idx = lo // BLOCK_N
248
+ else:
249
+ lo = chunk_lo
250
+ hi = tl.minimum(chunk_hi, kv_len)
251
+ if Seq_starts_k is not None:
252
+ k_base += start_kv_idx * stride_kn
253
+ v_base += start_kv_idx * stride_vn
254
+ else:
255
+ k_base += off_z * stride_kz
256
+ v_base += off_z * stride_vz
257
+ # Additional shift by 1 along the last dimension in the quantized case, since
258
+ # the first element along that dim contains packed quantization coefficients.
259
+ K_block_ptr = tl.make_block_ptr(
260
+ base=k_base + stride_kk * INT4_QUANTIZED * N_GROUPS,
261
+ shape=(PACKED_D_PER_GROUP, hi),
262
+ strides=(stride_kk, stride_kn),
263
+ offsets=(0, lo),
264
+ block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
265
+ order=(0, 1),
266
+ )
267
+ V_block_ptr = tl.make_block_ptr(
268
+ base=v_base + stride_vk * INT4_QUANTIZED * N_GROUPS,
269
+ shape=(hi, PACKED_D_PER_GROUP if not QUANTIZE_PV_TO_FP8 else D_PER_GROUP),
270
+ strides=(stride_vn, stride_vk),
271
+ offsets=(lo, 0),
272
+ block_shape=(
273
+ BLOCK_N,
274
+ PACKED_D_PER_GROUP if not QUANTIZE_PV_TO_FP8 else D_PER_GROUP,
275
+ ),
276
+ order=(1, 0),
277
+ )
278
+
279
+ if INT4_QUANTIZED:
280
+ # Pointers to quantization coefficients. Even though they are 1D,
281
+ # we use block pointers here so the pointer arithmetic is in int64,
282
+ # as otherwise the offsets for V_scale_shift_block_ptr may overflow.
283
+ K_scale_shift_block_ptr = tl.make_block_ptr(
284
+ base=k_base,
285
+ shape=(1, hi),
286
+ strides=(stride_kk, stride_kn),
287
+ offsets=(0, lo),
288
+ block_shape=(1, BLOCK_N),
289
+ order=(0, 1),
290
+ )
291
+ V_scale_shift_block_ptr = tl.make_block_ptr(
292
+ base=v_base,
293
+ shape=(hi, 1),
294
+ strides=(stride_vn, stride_vk),
295
+ offsets=(lo, 0),
296
+ block_shape=(BLOCK_N, 1),
297
+ order=(1, 0),
298
+ )
299
+ elif FP8_QUANTIZED:
300
+ if Seq_starts_k is not None:
301
+ k_fp8_scale_shift_base += start_kv_idx * stride_k_fp8_scale_shift_n
302
+ v_fp8_scale_shift_base += start_kv_idx * stride_v_fp8_scale_shift_n
303
+ else:
304
+ k_fp8_scale_shift_base += off_z * stride_k_fp8_scale_shift_z
305
+ v_fp8_scale_shift_base += off_z * stride_v_fp8_scale_shift_z
306
+ K_scale_shift_block_ptr = tl.make_block_ptr(
307
+ base=k_fp8_scale_shift_base,
308
+ shape=(1, hi),
309
+ strides=(1, stride_k_fp8_scale_shift_n),
310
+ offsets=(0, lo),
311
+ block_shape=(1, BLOCK_N),
312
+ order=(0, 1),
313
+ )
314
+ V_scale_shift_block_ptr = tl.make_block_ptr(
315
+ base=v_fp8_scale_shift_base,
316
+ shape=(hi, 1),
317
+ strides=(stride_v_fp8_scale_shift_n, 1),
318
+ offsets=(lo, 0),
319
+ block_shape=(BLOCK_N, 1),
320
+ order=(1, 0),
321
+ )
322
+ else:
323
+ K_scale_shift_block_ptr = None
324
+ V_scale_shift_block_ptr = None
325
+
326
+ if HAS_ADDITIVE_BIAS:
327
+ additive_bias_block_ptr = tl.make_block_ptr(
328
+ base=additive_bias
329
+ + off_z * stride_bias_b
330
+ + off_g * stride_bias_g
331
+ + off_h * stride_bias_h,
332
+ shape=(N_CTX_Q, hi),
333
+ strides=(stride_bias_qm, stride_bias_km),
334
+ offsets=(start_m * BLOCK_M, lo),
335
+ block_shape=(BLOCK_M, BLOCK_N),
336
+ order=(0, 1),
337
+ )
338
+
339
+ if SPLIT_K_EARLY_EXIT and lo >= hi:
340
+ return
341
+
342
+ Q_block_ptr = tl.make_block_ptr(
343
+ base=Q
344
+ + off_m * stride_qm
345
+ + off_h * stride_qh
346
+ + off_z * stride_qz * queries_use_batch_dim
347
+ + off_g * stride_qg,
348
+ shape=(q_len, BLOCK_DMODEL),
349
+ strides=(stride_qm, stride_qk),
350
+ offsets=(start_m * BLOCK_M, 0),
351
+ block_shape=(BLOCK_M, D_PER_GROUP),
352
+ order=(1, 0),
353
+ )
354
+
355
+ # initialize pointer to m and l
356
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
357
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
358
+
359
+ # Before compilation, this kernel will be processed by unroll_varargs.
360
+ # That turns tensors annotated like the one below into lists of tensors of length N_GROUPS.
+ # This works around Triton's lack of native support for lists of tensors.
362
+ acc: "VAR_ARGS_ARRAY" # noqa: F821
363
+
364
+ for i in range(len(acc)): # noqa: F821
365
+ acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=internal_dtype) # noqa: F821
366
+ # scale sm_scale by log_2(e) and use
367
+ # 2^x instead of exp in the loop because CSE and LICM
368
+ # don't work as expected with `exp` in the loop
369
+ #
370
+ # We declare log2e as a constant with a precisely-specified type to guarantee that
371
+ # triton will use the exact same value in all instances below, rather than sometimes
372
+ # using float32 and sometimes using float64. For more discussion see:
373
+ # https://github.com/triton-lang/triton/issues/5466
374
+ log2e = tl.full((), 1.44269504, tl.float32)
375
+ qk_scale = sm_scale * log2e
376
+ # load q: it will stay in SRAM throughout
377
+ q: "VAR_ARGS_ARRAY" # noqa: F821
378
+
379
+ if QUANTIZE_QK_TO_FP8:
380
+ # Create a block pointer for q_scale
381
+ q_scale_block_ptr = tl.make_block_ptr(
382
+ base=q_fp8_scale_shift
383
+ + off_m * stride_q_fp8_scale_shift_m
384
+ + off_h * stride_q_fp8_scale_shift_h
385
+ + off_g * stride_q_fp8_scale_shift_g
386
+ + off_z * stride_q_fp8_scale_shift_z * queries_use_batch_dim,
387
+ shape=(q_len, 1),
388
+ strides=(stride_q_fp8_scale_shift_m, 1),
389
+ offsets=(start_m * BLOCK_M, 0),
390
+ block_shape=(BLOCK_M, 1),
391
+ order=(1, 0),
392
+ )
393
+
394
+ # For FP8 quantized query, load and dequantize
395
+ for i in range(len(acc)): # noqa: F821
396
+ # Load quantized query
397
+ q_quantized = tl.load( # noqa: F821
398
+ tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
399
+ )
400
+
401
+ # Load q_scale for dequantization - q_scale is per row
402
+ q_scale = tl.load(
403
+ tl.advance(q_scale_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
404
+ )
405
+ q[i] = q_quantized.to(Q.dtype.element_ty) # noqa: F821
406
+ else:
407
+ # Regular query loading
408
+ for i in range(len(acc)): # noqa: F821
409
+ q[i] = tl.load( # noqa: F821
410
+ tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
411
+ )
412
+
413
+ if IS_CAUSAL or IS_LOCAL:
414
+ # Why does the masking condition below work as a causal mask? (See the sketch after this kernel.)
415
+ # Assuming num_queries <= BLOCK_M:
416
+ # kv_pos = kv_start + range(0, BLOCK_N)
417
+ # q_offset = start_m * BLOCK_M + range(0, BLOCK_M)
418
+ # q_pos = kv_start + kv_len - num_queries + q_offset % num_queries
419
+ # mask = q_pos - kv_pos >= 0
420
+ # So the final masking condition is:
421
+ # range(0, BLOCK_M) % num_queries - range(0, BLOCK_N) >= num_queries - kv_len
422
+
423
+ q_offset = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
424
+ diag_idx = (q_offset[:, None] % NUM_QUERIES_CAUSAL) - tl.arange(0, BLOCK_N)[
425
+ None, :
426
+ ]
427
+ diag_idx_shifted = tl.constexpr(diag_idx - NUM_QUERIES_CAUSAL + kv_len)
428
+
429
+ # loop over k, v and update accumulator
430
+ for start_n in range(lo, hi, BLOCK_N):
431
+ if PAGE_SIZE > 0:
432
+ # Offset in integer blocks from the beginning of the page
433
+ block_offset_in_page = logical_block_idx % BLOCKS_IN_PAGE
434
+ # Offset in integer pages
435
+ logical_page_idx = logical_block_idx // BLOCKS_IN_PAGE
436
+ physical_page_idx = tl.load(
437
+ block_table + stride_blocktablesl * logical_page_idx
438
+ ).to(tl.int32)
439
+ offset = physical_page_idx * PAGE_SIZE + block_offset_in_page * BLOCK_N
440
+
441
+ current_block_size = min(hi - start_n, BLOCK_N)
442
+ K_block_ptr = tl.make_block_ptr(
443
+ base=k_base + stride_kk * INT4_QUANTIZED * N_GROUPS,
444
+ shape=(PACKED_D_PER_GROUP, offset + current_block_size),
445
+ strides=(stride_kk, stride_kn),
446
+ offsets=(0, offset),
447
+ block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
448
+ order=(0, 1),
449
+ )
450
+ V_block_ptr = tl.make_block_ptr(
451
+ base=v_base + stride_vk * INT4_QUANTIZED * N_GROUPS,
452
+ shape=(
453
+ offset + current_block_size,
454
+ PACKED_D_PER_GROUP if not QUANTIZE_PV_TO_FP8 else D_PER_GROUP,
455
+ ),
456
+ strides=(stride_vn, stride_vk),
457
+ offsets=(offset, 0),
458
+ block_shape=(
459
+ BLOCK_N,
460
+ PACKED_D_PER_GROUP if not QUANTIZE_PV_TO_FP8 else D_PER_GROUP,
461
+ ),
462
+ order=(1, 0),
463
+ )
464
+ if INT4_QUANTIZED:
465
+ # Pointers to quantization coefficients. Even though they are 1D,
466
+ # we use block pointers here so the pointer arithmetic is in int64,
467
+ # as otherwise the offsets for V_scale_shift_block_ptr may overflow.
468
+ K_scale_shift_block_ptr = tl.make_block_ptr(
469
+ base=k_base,
470
+ shape=(1, offset + current_block_size),
471
+ strides=(stride_kk, stride_kn),
472
+ offsets=(0, offset),
473
+ block_shape=(1, BLOCK_N),
474
+ order=(0, 1),
475
+ )
476
+ V_scale_shift_block_ptr = tl.make_block_ptr(
477
+ base=v_base,
478
+ shape=(offset + current_block_size, 1),
479
+ strides=(stride_vn, stride_vk),
480
+ offsets=(offset, 0),
481
+ block_shape=(BLOCK_N, 1),
482
+ order=(1, 0),
483
+ )
484
+ elif FP8_QUANTIZED:
485
+ K_scale_shift_block_ptr = tl.make_block_ptr(
486
+ base=k_fp8_scale_shift_base,
487
+ shape=(1, offset + current_block_size),
488
+ strides=(1, stride_k_fp8_scale_shift_n),
489
+ offsets=(0, offset),
490
+ block_shape=(1, BLOCK_N),
491
+ order=(0, 1),
492
+ )
493
+ V_scale_shift_block_ptr = tl.make_block_ptr(
494
+ base=v_fp8_scale_shift_base,
495
+ shape=(offset + current_block_size, 1),
496
+ strides=(stride_v_fp8_scale_shift_n, 1),
497
+ offsets=(offset, 0),
498
+ block_shape=(BLOCK_N, 1),
499
+ order=(1, 0),
500
+ )
501
+ else:
502
+ K_scale_shift_block_ptr = None
503
+ V_scale_shift_block_ptr = None
504
+ logical_block_idx += 1
505
+
506
+ k: "VAR_ARGS_ARRAY" # noqa: F821
507
+ v: "VAR_ARGS_ARRAY" # noqa: F821
508
+
509
+ if QUANTIZE_PV_TO_FP8:
510
+ v_dtype = tl.float8e4nv
511
+ else:
512
+ v_dtype = Q.dtype.element_ty
513
+ for i in range(len(acc)): # noqa: F821
514
+ # Load and dequantize K/V with appropriate return values based on quantization flags
515
+ result = load_dequantize_k_v_group( # noqa: F821
516
+ K_block_ptr,
517
+ V_block_ptr,
518
+ K_scale_shift_block_ptr,
519
+ V_scale_shift_block_ptr,
520
+ BOUNDS_CHECKS_N,
521
+ PACKED_PER_VAL,
522
+ PACKED_D_PER_GROUP,
523
+ FP8_QUANTIZED,
524
+ Q.dtype.element_ty,
525
+ v_dtype,
526
+ i,
527
+ IS_HIP,
528
+ QUANTIZE_PV_TO_FP8,
529
+ QUANTIZE_QK_TO_FP8,
530
+ USE_FP32_SCALES,
531
+ )
532
+
533
+ # Unpack results based on quantization configuration
534
+ if QUANTIZE_PV_TO_FP8 and QUANTIZE_QK_TO_FP8:
535
+ k[i], v[i], v_scale, k_scale = result # noqa: F821
536
+ elif QUANTIZE_PV_TO_FP8:
537
+ k[i], v[i], v_scale = result # noqa: F821
538
+ elif QUANTIZE_QK_TO_FP8:
539
+ k[i], v[i], k_scale = result # noqa: F821
540
+ else:
541
+ k[i], v[i] = result # noqa: F821
542
+ # -- compute qk ---
543
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
544
+ for i in range(len(acc)): # noqa: F821
545
+ qk += tl.dot(q[i], k[i]) # noqa: F821
546
+
547
+ if QUANTIZE_QK_TO_FP8:
548
+ # Reshape k_scale for proper broadcasting with qk
549
+ # k_scale has shape (BLOCK_N,), so we reshape it to (1, BLOCK_N)
550
+ # for proper broadcasting with qk of shape (BLOCK_M, BLOCK_N)
551
+ k_scale_reshaped = tl.reshape(k_scale, (1, BLOCK_N))
552
+
553
+ # Apply k_scale to qk
554
+ qk = qk * k_scale_reshaped
555
+ qk = qk * tl.reshape(q_scale, (BLOCK_M, 1)) # noqa: F821
556
+
557
+ # Apply qk_scale (scalar)
558
+ qk *= qk_scale
559
+
560
+ if start_n == lo and ignore_in_first_block > 0:
561
+ qk = tl.where(
562
+ tl.arange(0, BLOCK_N) < ignore_in_first_block, float("-inf"), qk
563
+ )
564
+
565
+ if HAS_ADDITIVE_BIAS:
566
+ loaded_bias = tl.load(
567
+ additive_bias_block_ptr,
568
+ boundary_check=(0, 1) if BOUNDS_CHECKS_N else (0,),
569
+ )
570
+ qk += loaded_bias.to(tl.float32) * log2e
571
+ additive_bias_block_ptr = tl.advance(additive_bias_block_ptr, (0, BLOCK_N))
572
+
573
+ # TODO: This is slow, and only needed at the last iteration.
574
+ # Maybe we can unroll the last iteration instead?
575
+ if BOUNDS_CHECKS_N:
576
+ qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
577
+ if IS_CAUSAL:
578
+ # -- apply the causal mask --
579
+ qk = tl.where(diag_idx_shifted >= start_n - start_kv_idx, qk, float("-inf"))
580
+ if IS_LOCAL:
581
+ # -- apply the local window size mask --
582
+ qk = tl.where(
583
+ diag_idx_shifted < start_n - start_kv_idx + WINDOW_LEFT + 1,
584
+ qk,
585
+ float("-inf"),
586
+ )
587
+ if not IS_CAUSAL and WINDOW_RIGHT >= 0:
588
+ qk = tl.where(
589
+ diag_idx_shifted >= start_n - start_kv_idx - WINDOW_RIGHT,
590
+ qk,
591
+ float("-inf"),
592
+ )
593
+
594
+ # -- compute scaling constant ---
595
+ m_i_new = tl.maximum(m_i, tl.max(qk, 1))
596
+ alpha = tl.math.exp2(m_i - m_i_new)
597
+ p = tl.math.exp2(qk - m_i_new[:, None])
598
+ if HAS_ADDITIVE_BIAS or (IS_CAUSAL or IS_LOCAL):
599
+ # NOTE: It's possible that an entire block is masked out.
600
+ # If this is the case, `m_i_new == -inf` and alpha/p become nan.
601
+ alpha = tl.where(m_i_new == float("-inf"), 0, alpha)
602
+ p = tl.where(m_i_new[:, None] == float("-inf"), 0, p)
603
+
604
+ # -- update m_i and l_i --
605
+ l_i = l_i * alpha + tl.sum(p, 1)
606
+ m_i = m_i_new
607
+ if not QUANTIZE_PV_TO_FP8:
608
+ p = p.to(v_dtype)
609
+ else:
610
+ # Apply v-scale to P
611
+ p = p * tl.trans(v_scale)
612
+
613
+ # Quantize P to FP8
614
+ MAX_FP8 = 448
615
+ amax = tl.max(p, axis=1) # rowmax(P)
616
+ p_scale = tl.maximum(amax / MAX_FP8, 1e-9)
617
+ p_scaled = p / p_scale[:, None]
618
+ p_clamped = tl.clamp(p_scaled, 0, MAX_FP8)
619
+
620
+ # convert P to FP8
621
+ p = p_clamped.to(v_dtype, fp_downcast_rounding="rtne")
622
+
623
+ # -- scale and update acc --
624
+ for i in range(len(acc)): # noqa: F821
625
+ acc[i] *= alpha[:, None] # noqa: F821
626
+ if not QUANTIZE_PV_TO_FP8:
627
+ acc[i] += tl.dot(p, v[i]) # noqa: F821
628
+ else:
629
+ # Re-scale PV using p_scale
630
+ acc[i] += tl.dot(p, v[i]) * p_scale[:, None] # noqa: F821
631
+
632
+ if not PAGE_SIZE:
633
+ # update pointers
634
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
635
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
636
+ if PACKED_PER_VAL > 1:
637
+ K_scale_shift_block_ptr = tl.advance(
638
+ K_scale_shift_block_ptr, (0, BLOCK_N)
639
+ )
640
+ V_scale_shift_block_ptr = tl.advance(
641
+ V_scale_shift_block_ptr, (BLOCK_N, 0)
642
+ )
643
+
644
+ # write back O
645
+ O_block_ptr = tl.make_block_ptr(
646
+ base=Out_splitK
647
+ + off_z.to(tl.int64) * stride_osk_z * queries_use_batch_dim
648
+ + off_m * stride_osk_m
649
+ + off_g * stride_osk_g
650
+ + off_h * stride_osk_h
651
+ + splitk_idx * stride_osk_s,
652
+ shape=(q_len, D_PER_GROUP),
653
+ strides=(stride_osk_m, 1),
654
+ offsets=(start_m * BLOCK_M, 0),
655
+ block_shape=(BLOCK_M, D_PER_GROUP),
656
+ order=(1, 0),
657
+ )
658
+ for i in range(len(acc)): # noqa: F821
659
+ # If for the current batch element there are no tokens in the current split-k chunk (because
660
+ # seqlen is too short), l_i will be 0, so we need to make sure attention is filled with zeros and not NaNs.
661
+ attn_out = tl.where(l_i[:, None] == 0, 0.0, acc[i] / l_i[:, None]) # noqa: F821
662
+ tl.store(
663
+ tl.advance(O_block_ptr, (0, i * D_PER_GROUP)),
664
+ attn_out.to(Out_splitK.dtype.element_ty), # noqa: F821
665
+ boundary_check=(0,),
666
+ )
667
+ if WRITE_LSE:
668
+ LSE_splitk_ptr = (
669
+ LSE_splitk
670
+ + off_z * stride_lsek_z * queries_use_batch_dim
671
+ + off_m * stride_lsek_m
672
+ + off_g * stride_lsek_g
673
+ + off_h * stride_lsek_h
674
+ + splitk_idx * stride_lsek_s
675
+ + (start_m * BLOCK_M + tl.arange(0, BLOCK_M)) * stride_lsek_m
676
+ )
677
+ mask = start_m * BLOCK_M + tl.arange(0, BLOCK_M) < q_len
678
+ # Can be float64 to improve numerics
679
+ lse_dtype = LSE_splitk.dtype.element_ty
680
+ tl.store(
681
+ LSE_splitk_ptr,
682
+ (tl.math.log2(l_i.to(lse_dtype)) + m_i.to(lse_dtype)) / log2e,
683
+ mask=mask,
684
+ )
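# Illustrative sketch (not from the packaged file): the causal-mask comment inside
# _fwd_kernel_splitK aligns the last query with the last key, so query i may attend
# to key t iff (kv_len - num_queries + i) >= t. A direct restatement with made-up sizes:
import torch

num_queries, kv_len = 3, 7
q_pos = kv_len - num_queries + torch.arange(num_queries)  # positions of the queries in the sequence
kv_pos = torch.arange(kv_len)
allowed = q_pos[:, None] >= kv_pos[None, :]  # True where attention is permitted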
685
+
686
+
687
+ def gen_config(
688
+ block_m: int,
689
+ block_n: int,
690
+ stages: int,
691
+ warps: int,
692
+ ) -> triton.Config:
693
+ """A more compact way to define a triton.Config, so it fits on one line"""
694
+
695
+ return triton.Config(
696
+ {
697
+ "BLOCK_M": block_m,
698
+ "BLOCK_N": block_n,
699
+ },
700
+ num_stages=stages,
701
+ num_warps=warps,
702
+ )
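# For illustration (not from the packaged file): gen_config is shorthand for the usual
# triton.Config constructor. Assuming the wheel is installed, the call below builds a
# config with BLOCK_M=64, BLOCK_N=128, num_stages=2 and num_warps=4.
from mslk.attention.fmha._triton.splitk_kernels import gen_config

cfg = gen_config(block_m=64, block_n=128, stages=2, warps=4)
assert cfg.kwargs["BLOCK_M"] == 64 and cfg.kwargs["BLOCK_N"] == 128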
703
+
704
+
705
+ def _get_splitk_kernel(num_groups):
706
+ """
707
+ Kernel _fwd_kernel_splitK needs to be post-processed by unroll_varargs
708
+ to specialize it for a given number of quantization groups N_GROUPS
709
+ before we can apply triton.heuristics and triton.autotune, so we
710
+ do not apply them as decorators.
711
+ """
712
+
713
+ _fwd_kernel_splitK_unrolled = unroll_varargs(_fwd_kernel_splitK, N=num_groups)
714
+ kernel = triton.heuristics(
715
+ {
716
+ "BOUNDS_CHECKS_N": lambda args: bool(
717
+ (args["BLOCK_N_PER_SPLIT"] % args["BLOCK_N"])
718
+ or (
719
+ args["BLOCK_N_PER_SPLIT"] > 0
720
+ and args["N_CTX_K"] % args["BLOCK_N_PER_SPLIT"]
721
+ )
722
+ or args["USE_SEQ_LEN"]
723
+ )
724
+ }
725
+ )(_fwd_kernel_splitK_unrolled)
726
+ return kernel
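# Illustrative restatement (not from the packaged file) of the BOUNDS_CHECKS_N heuristic
# wired up above: bounds checks are only needed when the split does not tile evenly into
# BLOCK_N-sized blocks, when the K context does not tile evenly into splits, or when
# per-batch sequence lengths are used. The sample values below are made up.
def needs_bounds_checks(args: dict) -> bool:
    return bool(
        (args["BLOCK_N_PER_SPLIT"] % args["BLOCK_N"])
        or (args["BLOCK_N_PER_SPLIT"] > 0 and args["N_CTX_K"] % args["BLOCK_N_PER_SPLIT"])
        or args["USE_SEQ_LEN"]
    )

assert not needs_bounds_checks(
    {"BLOCK_N_PER_SPLIT": 256, "BLOCK_N": 64, "N_CTX_K": 1024, "USE_SEQ_LEN": False}
)
assert needs_bounds_checks(
    {"BLOCK_N_PER_SPLIT": 256, "BLOCK_N": 64, "N_CTX_K": 1000, "USE_SEQ_LEN": False}
)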
727
+
728
+
729
+ def early_config_prune(configs, named_args, **kwargs):
730
+ use_paged_attention = kwargs["USE_PAGED_ATTENTION"]
731
+ page_size = kwargs["PAGE_SIZE"]
732
+ if use_paged_attention:
733
+ return list(
734
+ filter(lambda config: page_size % config.kwargs["BLOCK_N"] == 0, configs)
735
+ )
736
+ else:
737
+ return configs
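# Illustration (not from the packaged file) of the pruning rule above: with paged
# attention, only configs whose BLOCK_N divides the page size survive. A lightweight
# stand-in object with a .kwargs dict replaces triton.Config; the values are made up.
class _FakeConfig:
    def __init__(self, block_n: int) -> None:
        self.kwargs = {"BLOCK_N": block_n}

_page_size = 64
_configs = [_FakeConfig(n) for n in (16, 32, 64, 128)]
_kept = [c for c in _configs if _page_size % c.kwargs["BLOCK_N"] == 0]
assert [c.kwargs["BLOCK_N"] for c in _kept] == [16, 32, 64]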
738
+
739
+
740
+ @functools.lru_cache(None)
741
+ def autotune_kernel(kernel: Callable):
742
+ BLOCK_M_VALUES = [16, 32, 64, 128]
743
+ BLOCK_N_VALUES = [16, 32, 64, 128]
744
+ STAGES_VALUES = [1, 2] if torch.version.hip else [1, 2, 3]
745
+ WARPS_VALUES = [1, 2, 4, 8]
746
+
747
+ TRITON_CONFIGS = [
748
+ gen_config(block_m, block_n, stages, warps)
749
+ for block_m in BLOCK_M_VALUES
750
+ for block_n in BLOCK_N_VALUES
751
+ for stages in STAGES_VALUES
752
+ for warps in WARPS_VALUES
753
+ if block_n >= block_m
754
+ ]
755
+
756
+ kernel = triton.autotune(
757
+ configs=TRITON_CONFIGS,
758
+ key=AUTOTUNER_KEY,
759
+ use_cuda_graph=True,
760
+ prune_configs_by={
761
+ "early_config_prune": early_config_prune,
762
+ },
763
+ )(kernel)
764
+ return kernel
765
+
766
+
767
+ # This object contains forward kernels wrapped into autotuner for different number
768
+ # of quantization groups.
769
+ _fwd_kernel_splitK_autotune: Dict[int, triton.runtime.Autotuner] = {}
770
+ # The loop below:
771
+ # - transforms the jitted kernel with unroll_varargs, producing a new kernel for each value of num_groups
772
+ # - wraps the kernel into triton.heuristics
773
+ # - wraps the kernel into the Triton autotuner. Autotuning itself happens the first time the kernel is called.
774
+ if sys.version_info >= (3, 9):
775
+ # unroll_varargs requires Python 3.9+
776
+ for num_groups in [1, 2, 4, 8]:
777
+ _fwd_kernel_splitK_autotune[num_groups] = autotune_kernel(
778
+ _get_splitk_kernel(num_groups)
779
+ )
780
+
781
+ def get_autotuner_cache(
782
+ num_groups: int,
783
+ ) -> Dict[Tuple[Union[int, str]], triton.Config]:
784
+ """Returns a triton.runtime.autotuner.AutoTuner.cache object, which
785
+ represents mappings from kernel autotune keys (tuples describing kernel inputs)
786
+ to triton.Config
787
+ """
788
+ return _fwd_kernel_splitK_autotune[num_groups].cache
789
+
790
+ def set_autotuner_cache(
791
+ cache: Dict[Tuple[Union[int, str]], triton.Config], num_groups: int
792
+ ) -> None:
793
+ _fwd_kernel_splitK_autotune[num_groups].cache = cache
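# A possible use of the two helpers above (an assumption, not something this module does
# itself): persist autotuning results so a later process can skip re-tuning. The file path
# and the picklability of triton.Config objects are assumptions.
import pickle

from mslk.attention.fmha._triton.splitk_kernels import (
    get_autotuner_cache,
    set_autotuner_cache,
)

def save_autotuner_cache(path: str, num_groups: int) -> None:
    with open(path, "wb") as f:
        pickle.dump(get_autotuner_cache(num_groups), f)

def load_autotuner_cache(path: str, num_groups: int) -> None:
    with open(path, "rb") as f:
        set_autotuner_cache(pickle.load(f), num_groups)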
794
+
795
+
796
+ @triton.jit
797
+ def load_dequantize_k_v_group(
798
+ K_block_ptr,
799
+ V_block_ptr,
800
+ K_scale_shift_block_ptr,
801
+ V_scale_shift_block_ptr,
802
+ BOUNDS_CHECKS_N: tl.constexpr,
803
+ PACKED_PER_VAL: tl.constexpr,
804
+ PACKED_D_PER_GROUP: tl.constexpr,
805
+ FP8_QUANTIZED: tl.constexpr,
806
+ q_dtype: tl.constexpr,
807
+ v_dtype: tl.constexpr, # Q.dtype.element_ty
808
+ group_id: tl.constexpr,
809
+ IS_HIP: tl.constexpr,
810
+ QUANTIZE_PV_TO_FP8: tl.constexpr,
811
+ QUANTIZE_QK_TO_FP8: tl.constexpr,
812
+ USE_FP32_SCALES: tl.constexpr,
813
+ ):
814
+ """Load K/V for a given block. In case of int4/fp8-quantized K/V, dequantize them after loading.
815
+ If quantization is group-wise, use group_id to advance the pointers to the current group.
816
+
817
+ Returns:
818
+ - k, v: loaded and potentially dequantized tensors
819
+ - v_scale (optional): V scale factor if QUANTIZE_PV_TO_FP8 is True
820
+ - k_scale (optional): K scale factor if QUANTIZE_QK_TO_FP8 is True
821
+ """
822
+ # Advance to the current quantization group
823
+ K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0))
824
+ V_block_ptr = tl.advance(
825
+ V_block_ptr,
826
+ (0, PACKED_D_PER_GROUP * group_id)
827
+ if not QUANTIZE_PV_TO_FP8
828
+ else (0, PACKED_D_PER_GROUP * PACKED_PER_VAL * group_id),
829
+ )
830
+
831
+ # -- load k, v --
832
+ k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ())
833
+ v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ())
834
+
835
+ # Initialize return values
836
+ # v_scale = None
837
+ # k_scale = None
838
+
839
+ if FP8_QUANTIZED:
840
+ k, v, v_scale, k_scale = _process_fp8_quantization(
841
+ k,
842
+ v,
843
+ K_scale_shift_block_ptr,
844
+ V_scale_shift_block_ptr,
845
+ BOUNDS_CHECKS_N,
846
+ PACKED_PER_VAL,
847
+ q_dtype,
848
+ v_dtype,
849
+ IS_HIP,
850
+ QUANTIZE_PV_TO_FP8,
851
+ QUANTIZE_QK_TO_FP8,
852
+ USE_FP32_SCALES,
853
+ )
854
+
855
+ elif PACKED_PER_VAL > 1:
856
+ # Int4 quantization.
857
+ k, v = _process_int4_quantization(
858
+ k,
859
+ v,
860
+ K_scale_shift_block_ptr,
861
+ V_scale_shift_block_ptr,
862
+ group_id,
863
+ BOUNDS_CHECKS_N,
864
+ PACKED_PER_VAL,
865
+ q_dtype,
866
+ IS_HIP,
867
+ )
868
+
869
+ # Return appropriate values based on quantization flags
870
+ if QUANTIZE_PV_TO_FP8 and QUANTIZE_QK_TO_FP8:
871
+ # Return both v_scale and k_scale for applying to P and K
872
+ return k, v, v_scale, k_scale
873
+ elif QUANTIZE_PV_TO_FP8:
874
+ # Return v_scale for applying v_scale to P
875
+ return k, v, v_scale
876
+ elif QUANTIZE_QK_TO_FP8:
877
+ # Return k_scale for applying k_scale to K
878
+ return k, v, k_scale
879
+ else:
880
+ return k, v
881
+
882
+
883
+ @triton.jit
884
+ def _process_fp8_quantization(
885
+ k,
886
+ v,
887
+ K_scale_shift_block_ptr,
888
+ V_scale_shift_block_ptr,
889
+ BOUNDS_CHECKS_N: tl.constexpr,
890
+ PACKED_PER_VAL: tl.constexpr,
891
+ q_dtype: tl.constexpr,
892
+ v_dtype: tl.constexpr,
893
+ IS_HIP: tl.constexpr,
894
+ QUANTIZE_PV_TO_FP8: tl.constexpr,
895
+ QUANTIZE_QK_TO_FP8: tl.constexpr,
896
+ USE_FP32_SCALES: tl.constexpr,
897
+ ):
898
+ """Process FP8 quantization for K and V tensors."""
899
+ # Process V tensor
900
+ v_scale_shift = tl.load(
901
+ V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
902
+ )
903
+ v_scale, v_shift = _extract_scale_shift(v_scale_shift, IS_HIP, USE_FP32_SCALES)
904
+ if not QUANTIZE_PV_TO_FP8:
905
+ v = dequantize(
906
+ v,
907
+ v_scale,
908
+ v_shift if not USE_FP32_SCALES else None,
909
+ PACKED_PER_VAL,
910
+ IS_HIP,
911
+ USE_FP32_SCALES,
912
+ ).to(v_dtype)
913
+ else:
914
+ # Do not dequantize V; V needs to be FP8 for PV.
915
+ tl.static_assert(PACKED_PER_VAL == 4, "Assert: int32 packs four FP8 values")
916
+
917
+ # Process K tensor
918
+ k_scale_shift = tl.load(
919
+ K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
920
+ )
921
+ k_scale, k_shift = _extract_scale_shift(k_scale_shift, IS_HIP, USE_FP32_SCALES)
922
+ if IS_HIP:
923
+ if not QUANTIZE_QK_TO_FP8:
924
+ k = dequantize_k_hip(k, k_scale, k_shift, PACKED_PER_VAL).to(q_dtype)
925
+ else:
926
+ # For QUANTIZE_QK_TO_FP8, unpack int32 to 8-bit entries and interpret as fp8
927
+ tl.static_assert(PACKED_PER_VAL == 4, "Assert: int32 packs four FP8 values")
928
+ else:
929
+ if not QUANTIZE_QK_TO_FP8:
930
+ k_t = dequantize(
931
+ tl.trans(k),
932
+ tl.trans(k_scale),
933
+ tl.trans(k_shift) if not USE_FP32_SCALES else None,
934
+ PACKED_PER_VAL,
935
+ IS_HIP,
936
+ USE_FP32_SCALES,
937
+ ).to(q_dtype)
938
+ k = tl.trans(k_t)
939
+ else:
940
+ # For QUANTIZE_QK_TO_FP8, unpack int32 to 8-bit entries and interpret as fp8
941
+ tl.static_assert(PACKED_PER_VAL == 4, "Assert: int32 packs four FP8 values")
942
+ k_t = tl.trans(k)
943
+ k_t = _unpack_fp8_tensor(k_t, PACKED_PER_VAL, IS_HIP)
944
+ k = tl.trans(k_t)
945
+
946
+ return k, v, v_scale, k_scale
947
+
948
+
949
+ @triton.jit
950
+ def _extract_scale_shift(
951
+ scale_shift, IS_HIP: tl.constexpr, USE_FP32_SCALES: tl.constexpr
952
+ ):
953
+ """Extract scale and shift values from packed representation."""
954
+ if IS_HIP:
955
+ return cast_uint32_to_float(scale_shift)
956
+ elif USE_FP32_SCALES:
957
+ return scale_shift.to(tl.float32, bitcast=True), 0
958
+ else:
959
+ return cast_uint32_to_half2(scale_shift)
960
+
961
+
962
+ @triton.jit
963
+ def _unpack_fp8_tensor(x_, PACKED_PER_VAL: tl.constexpr, IS_HIP: tl.constexpr):
964
+ """Unpack FP8 K/V tensor from int32 packed representation."""
965
+ tl.static_assert(PACKED_PER_VAL == 4, "Assert: int32 packs four FP8 values")
966
+
967
+ BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
968
+ BLOCK_N: tl.constexpr = x_.shape[0]
969
+ # Create bit offsets for unpacking (0, 8, 16, 24 bits)
970
+ offsets = tl.arange(0, PACKED_PER_VAL) * 8
971
+
972
+ # Extract 8-bit values by right-shifting and masking
973
+ unpacked_values = x_[:, :, None, :] >> offsets
974
+ unpacked_values = tl.reshape(
975
+ unpacked_values, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
976
+ )
977
+
978
+ # Convert to FP8 through bitcast
979
+ fp8_type = tl.float8e4b8 if IS_HIP else tl.float8e4nv
980
+ x_ = unpacked_values.to(tl.uint8).to(fp8_type, bitcast=True)
981
+
982
+ return x_
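# Host-side sketch (not from the packaged file) of the same unpacking: four 8-bit values
# are pulled out of each int32 and reinterpreted as fp8. Assumes a recent PyTorch with
# float8 dtypes; the CUDA-style float8_e4m3fn is used here (the kernel picks float8e4b8 on HIP).
import torch

packed = torch.randint(-2**31, 2**31 - 1, (8, 16), dtype=torch.int32)  # (BLOCK_N, D // 4), made up
shifts = torch.tensor([0, 8, 16, 24], dtype=torch.int32)
as_bytes = ((packed.unsqueeze(-1) >> shifts) & 0xFF).to(torch.uint8)   # (BLOCK_N, D // 4, 4)
as_fp8 = as_bytes.reshape(8, 64).view(torch.float8_e4m3fn)             # (BLOCK_N, D)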
983
+
984
+
985
+ @triton.jit
986
+ def _process_int4_quantization(
987
+ k,
988
+ v,
989
+ K_scale_shift_block_ptr,
990
+ V_scale_shift_block_ptr,
991
+ group_id: tl.constexpr,
992
+ BOUNDS_CHECKS_N: tl.constexpr,
993
+ PACKED_PER_VAL: tl.constexpr,
994
+ dtype: tl.constexpr,
995
+ IS_HIP: tl.constexpr,
996
+ ):
997
+ """Process INT4 quantization for K and V tensors."""
998
+ # Advance scale/shift pointers for INT4
999
+ K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0))
1000
+ V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id))
1001
+
1002
+ k_scale_shift = tl.load(
1003
+ K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
1004
+ )
1005
+ v_scale_shift = tl.load(
1006
+ V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
1007
+ )
1008
+ if IS_HIP:
1009
+ k_scale, k_shift = cast_uint32_to_float(k_scale_shift)
1010
+ v_scale, v_shift = cast_uint32_to_float(v_scale_shift)
1011
+ v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL, IS_HIP).to(dtype)
1012
+ k = dequantize_k_hip(k, k_scale, k_shift, PACKED_PER_VAL).to(dtype)
1013
+ else:
1014
+ k_scale, k_shift = cast_uint32_to_half2(k_scale_shift)
1015
+ v_scale, v_shift = cast_uint32_to_half2(v_scale_shift)
1016
+ v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL, IS_HIP).to(dtype)
1017
+ k_t = dequantize(
1018
+ tl.trans(k),
1019
+ tl.trans(k_scale),
1020
+ tl.trans(k_shift),
1021
+ PACKED_PER_VAL,
1022
+ IS_HIP,
1023
+ ).to(dtype)
1024
+ k = tl.trans(k_t)
1025
+
1026
+ return k, v
1027
+
1028
+
1029
+ @triton.jit
1030
+ def cast_uint64_to_float2(scale_shift):
1031
+ """Using FP32 scales, so only extract one fp32 from the packed int64"""
1032
+ scale = scale_shift & 0xFFFFFFFF
1033
+ scale = scale.to(tl.uint32).to(tl.float32, bitcast=True)
1034
+ return scale, 0
1035
+
1036
+
1037
+ @triton.jit
1038
+ def cast_uint32_to_half2(scale_shift):
1039
+ """Extract two float16 packed into one int32"""
1040
+ scale = scale_shift & 0xFFFF
1041
+ shift = scale_shift >> 16
1042
+ scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
1043
+ shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
1044
+ return scale, shift
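# Host-side sketch (not from the packaged file) of the packing that cast_uint32_to_half2
# undoes: a float16 scale in the low 16 bits of an int32 and a float16 shift in the high
# 16 bits. Uses only the standard library; the example values are made up.
import struct

scale, shift = 0.125, 3.0
packed = struct.unpack("<I", struct.pack("<ee", scale, shift))[0]  # scale -> low bits, shift -> high bits
lo, hi = packed & 0xFFFF, packed >> 16
assert struct.unpack("<e", lo.to_bytes(2, "little"))[0] == scale
assert struct.unpack("<e", hi.to_bytes(2, "little"))[0] == shift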
1045
+
1046
+
1047
+ @triton.jit
1048
+ def cast_uint32_to_float(scale_shift):
1049
+ """Extract two float16 packed into one int32 as float32"""
1050
+ scale = scale_shift & 0xFFFF
1051
+ shift = scale_shift >> 16
1052
+ scale = scale.to(tl.uint16).to(tl.float16, bitcast=True).to(tl.float32)
1053
+ shift = shift.to(tl.uint16).to(tl.float16, bitcast=True).to(tl.float32)
1054
+ return scale, shift
1055
+
1056
+
1057
+ @triton.jit
1058
+ def dequantize_k_hip(
1059
+ x_,
1060
+ scale,
1061
+ shift,
1062
+ PACKED_PER_VAL: tl.constexpr,
1063
+ ):
1064
+ """PACKED_PER_VAL is the number of values packed into each element x_.
1065
+ For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8.
1066
+ """
1067
+ # x_ : (D // PACKED_PER_VAL, BLOCK_N)
+ # scale: (1, BLOCK_N)
1069
+ # offsets: (PACKED_PER_VAL,)
1070
+ BLOCK_N: tl.constexpr = x_.shape[1]
1071
+ BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[0]
1072
+ offsets = tl.arange(0, PACKED_PER_VAL) * (32 // PACKED_PER_VAL)
1073
+ quant_offset = (
1074
+ x_[:, None, :, :] >> offsets[:, None]
1075
+ ) # (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)
1076
+
1077
+ quant_offset = tl.reshape(
1078
+ quant_offset, (BLOCK_DMODEL_PACKED * PACKED_PER_VAL, BLOCK_N)
1079
+ )
1080
+
1081
+ if PACKED_PER_VAL == 4:
1082
+ # FP8 quantization.
1083
+ fp8_type = tl.float8e4b8 if torch.version.hip is not None else tl.float8e4nv
1084
+ dequant = (
1085
+ quant_offset.to(tl.uint8).to(fp8_type, bitcast=True).to(scale.dtype) * scale
1086
+ + shift
1087
+ )
1088
+ else:
1089
+ # Int4 quantization.
1090
+ # Trick - instead of converting int4 to float16 we view it as float16
1091
+ # and then multiply by 32768 * 512 == 2**24
1092
+ quant_offset = (
1093
+ (quant_offset & 0xF)
1094
+ .to(tl.uint16)
1095
+ .to(tl.float16, bitcast=True)
1096
+ .to(tl.float32)
1097
+ )
1098
+ quant_offset = quant_offset * 32768.0
1099
+ scale_512 = scale * 512
1100
+
1101
+ dequant = quant_offset * scale_512 + shift
1102
+ return dequant
1103
+
1104
+
1105
+ @triton.jit
1106
+ def dequantize(
1107
+ x_,
1108
+ scale,
1109
+ shift,
1110
+ PACKED_PER_VAL: tl.constexpr,
1111
+ IS_HIP: tl.constexpr,
1112
+ USE_FP32_SCALES: tl.constexpr = False,
1113
+ ):
1114
+ """PACKED_PER_VAL is the number of values packed into each element x_.
1115
+ For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8.
1116
+ """
1117
+ # x_ : (BLOCK_N, D // PACKED_PER_VAL)
1118
+ # scale: (BLOCK_N, 1)
1119
+ # offsets: (PACKED_PER_VAL,)
1120
+ BLOCK_N: tl.constexpr = x_.shape[0]
1121
+ BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
1122
+ offsets = tl.arange(0, PACKED_PER_VAL) * (32 // PACKED_PER_VAL)
1123
+ quant_offset = (
1124
+ x_[:, :, None, :] >> offsets
1125
+ ) # (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)
1126
+
1127
+ quant_offset = tl.reshape(
1128
+ quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
1129
+ )
1130
+ if PACKED_PER_VAL == 4:
1131
+ # FP8 quantization.
1132
+ fp8_type = tl.float8e4b8 if torch.version.hip is not None else tl.float8e4nv
1133
+ dequant = (
1134
+ quant_offset.to(tl.uint8).to(fp8_type, bitcast=True).to(scale.dtype) * scale
1135
+ )
1136
+ if not USE_FP32_SCALES:
1137
+ # Use asymmetric quantization only for FP16 scales
1138
+ dequant += shift
1139
+ else:
1140
+ # Int4 quantization.
1141
+ # Trick - instead of converting int4 to float16 we view it as float16
1142
+ # and then multiply by 32768 * 512 == 2**24
1143
+ if IS_HIP:
1144
+ # Do final math in float32 to avoid casting to bf16 on MI300. There are
+ # no direct instructions for this, so it's less performant on this workload.
1146
+ quant_offset = (
1147
+ (quant_offset & 0xF)
1148
+ .to(tl.uint16)
1149
+ .to(tl.float16, bitcast=True)
1150
+ .to(tl.float32)
1151
+ )
1152
+ quant_offset = quant_offset * 32768.0
1153
+ else:
1154
+ quant_offset = (
1155
+ (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
1156
+ )
1157
+ quant_offset = (quant_offset * 32768.0).to(tl.float16)
1158
+ scale_512 = scale * 512
1159
+
1160
+ dequant = quant_offset * scale_512
1161
+ if not USE_FP32_SCALES:
1162
+ # Use asymmetric quantization only for FP16 scales
1163
+ dequant += shift
1164
+ return dequant
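# Sketch (not from the packaged file) of why the int4 trick above works: a 4-bit value n,
# reinterpreted as the raw bits of a float16, is the subnormal n * 2**-24, so multiplying
# by 32768 * 512 == 2**24 recovers float(n) exactly. Verified with the standard library:
import struct

for n in range(16):
    as_half = struct.unpack("<e", n.to_bytes(2, "little"))[0]  # reinterpret n's bits as float16
    assert as_half * 32768.0 * 512.0 == float(n)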
1165
+
1166
+
1167
+ @triton.jit
1168
+ def _splitK_reduce(
1169
+ Out_splitK, # [B, G, H, split_k, Mq, K]
1170
+ LSE_splitK, # [B, G, H, split_k, Mq]
1171
+ Out, # [B, H, M, K]
1172
+ LSE, # [B, H, M]
1173
+ split_k: tl.constexpr,
1174
+ splitK_pow2: tl.constexpr,
1175
+ stride_osk_z: tl.constexpr,
1176
+ stride_osk_g: tl.constexpr,
1177
+ stride_osk_h: tl.constexpr,
1178
+ stride_osk_s: tl.constexpr,
1179
+ stride_osk_m: tl.constexpr,
1180
+ stride_osk_k: tl.constexpr,
1181
+ stride_lsek_z: tl.constexpr,
1182
+ stride_lsek_g: tl.constexpr,
1183
+ stride_lsek_h: tl.constexpr,
1184
+ stride_lsek_s: tl.constexpr,
1185
+ stride_lsek_m: tl.constexpr,
1186
+ stride_oz: tl.constexpr,
1187
+ stride_og: tl.constexpr,
1188
+ stride_oh: tl.constexpr,
1189
+ stride_om: tl.constexpr,
1190
+ stride_ok: tl.constexpr,
1191
+ stride_lse_z: tl.constexpr,
1192
+ stride_lse_g: tl.constexpr,
1193
+ stride_lse_h: tl.constexpr,
1194
+ stride_lse_m: tl.constexpr,
1195
+ head_dim: tl.constexpr,
1196
+ head_dim_pow_2: tl.constexpr,
1197
+ H: tl.constexpr,
1198
+ G: tl.constexpr,
1199
+ WRITE_LSE: tl.constexpr,
1200
+ ):
1201
+ # grid = (M, B * G * H, 1)
1202
+ off_m = tl.program_id(0).to(tl.int64)
1203
+ off_zhg = tl.program_id(1).to(tl.int64)
1204
+ off_z = off_zhg // (H * G)
1205
+ off_h = (off_zhg // G) % H
1206
+ off_g = off_zhg % G
1207
+
1208
+ head_dim_mask = tl.arange(0, head_dim_pow_2) < head_dim
1209
+
1210
+ Out_splitK_ptr = (
1211
+ Out_splitK
1212
+ + stride_osk_z * off_z
1213
+ + stride_osk_g * off_g
1214
+ + stride_osk_h * off_h
1215
+ + stride_osk_m * off_m
1216
+ + tl.arange(0, head_dim_pow_2)[None, :]
1217
+ + stride_osk_s * tl.arange(0, splitK_pow2)[:, None]
1218
+ )
1219
+
1220
+ LSE_splitK_ptr0 = (
1221
+ LSE_splitK
1222
+ + stride_lsek_z * off_z
1223
+ + stride_lsek_g * off_g
1224
+ + stride_lsek_h * off_h
1225
+ + stride_lsek_m * off_m
1226
+ + stride_lsek_s * tl.arange(0, splitK_pow2)
1227
+ )
1228
+
1229
+ if splitK_pow2 > split_k:
1230
+ mask_1d = tl.arange(0, splitK_pow2) < split_k
1231
+ mask_2d = mask_1d[:, None] & head_dim_mask[None, :]
1232
+ lse_splitk = tl.load(LSE_splitK_ptr0, mask=mask_1d, other=float("-inf"))
1233
+ lse_max = tl.max(lse_splitk)
1234
+ out_splitk = tl.load(
1235
+ Out_splitK_ptr, mask=mask_2d, other=0
1236
+ ) # (split_k, head_dim_pow_2)
1237
+ lse_splitk = tl.load(
1238
+ LSE_splitK_ptr0, mask=mask_1d, other=float("-inf")
1239
+ ) # (split_k,)
1240
+ else:
1241
+ lse_splitk = tl.load(LSE_splitK_ptr0)
1242
+ lse_max = tl.max(lse_splitk)
1243
+ out_splitk = tl.load(Out_splitK_ptr)
1244
+ lse_splitk = tl.load(LSE_splitK_ptr0)
1245
+
1246
+ sumexp_normalized_splitk = tl.math.exp2(
1247
+ (lse_splitk - lse_max).to(tl.float32) * 1.44269504
1248
+ ) # (split_k,)
1249
+ sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0) # scalar
1250
+ # Compute numerator
1251
+ numerator_normalized = tl.sum(
1252
+ out_splitk * sumexp_normalized_splitk[:, None], axis=0
1253
+ )
1254
+ acc = numerator_normalized / sumexp_normalized
1255
+ acc = tl.where(lse_max == float("-inf"), 0.0, acc)
1256
+
1257
+ Out_ptr = (
1258
+ Out
1259
+ + stride_oz * off_z
1260
+ + stride_oh * off_h
1261
+ + stride_og * off_g
1262
+ + stride_om * off_m
1263
+ + tl.arange(0, head_dim_pow_2)
1264
+ )
1265
+ if acc.dtype is tl.float64 and Out.dtype.element_ty is not tl.float64:
1266
+ # must avoid direct cast f64->f16
1267
+ acc = acc.to(tl.float32)
1268
+ tl.store(Out_ptr, acc, mask=head_dim_mask)
1269
+
1270
+ if WRITE_LSE:
1271
+ l_ptrs = (
1272
+ LSE
1273
+ + off_z * stride_lse_z
1274
+ + off_g * stride_lse_g
1275
+ + off_h * stride_lse_h
1276
+ + off_m * stride_lse_m
1277
+ )
1278
+ to_store = lse_max + tl.math.log2(sumexp_normalized) / 1.44269504
1279
+ to_store = tl.where(lse_max == float("-inf"), lse_max, to_store)
1280
+ tl.store(l_ptrs, to_store)
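# Reference merge (not from the packaged file) mirroring the reduction above: each split
# contributes an already-normalized partial output and its log-sum-exp, and the splits are
# combined with weights exp(lse_i - lse_max). Shapes and values below are made up.
import torch

torch.manual_seed(0)
split_k, head_dim = 4, 8
out_k = torch.randn(split_k, head_dim)   # per-split partial attention outputs
lse_k = torch.randn(split_k)             # per-split log-sum-exp values

lse_max = lse_k.max()
w = torch.exp(lse_k - lse_max)           # sumexp of each split, normalized by lse_max
merged_out = (out_k * w[:, None]).sum(0) / w.sum()
merged_lse = lse_max + torch.log(w.sum())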
1281
+
1282
+
1283
+ @triton.jit
1284
+ def _splitK_reduce_varargs(
1285
+ Out_splitK: "VAR_ARGS_ARRAY", # list of [B, G, H, Mq, K];
1286
+ LSE_splitK: "VAR_ARGS_ARRAY", # list of [B, G, H, Mq]
1287
+ Out, # [B, G, H, M, K]
1288
+ LSE, # [B, G, H, M]
1289
+ stride_osk_z: "VAR_ARGS_ARRAY",
1290
+ stride_osk_g: "VAR_ARGS_ARRAY",
1291
+ stride_osk_h: "VAR_ARGS_ARRAY",
1292
+ stride_osk_m: "VAR_ARGS_ARRAY",
1293
+ stride_osk_k: "VAR_ARGS_ARRAY",
1294
+ stride_lsek_z: "VAR_ARGS_ARRAY",
1295
+ stride_lsek_g: "VAR_ARGS_ARRAY",
1296
+ stride_lsek_h: "VAR_ARGS_ARRAY",
1297
+ stride_lsek_m: "VAR_ARGS_ARRAY",
1298
+ stride_oz,
1299
+ stride_og,
1300
+ stride_oh,
1301
+ stride_om,
1302
+ stride_ok,
1303
+ stride_lse_z,
1304
+ stride_lse_g,
1305
+ stride_lse_h,
1306
+ stride_lse_m,
1307
+ head_dim: tl.constexpr,
1308
+ head_dim_pow_2: tl.constexpr,
1309
+ H: tl.constexpr,
1310
+ G: tl.constexpr,
1311
+ WRITE_LSE: tl.constexpr,
1312
+ ):
1313
+ """
1314
+ This version of reduce kernel takes attention and LSE of chunks as lists of tensors,
1315
+ as opposed to _splitK_reduce, which takes each as a stacked tensor.
1316
+ """
1317
+ # grid = (M, B * G * H, 1)
1318
+ off_m = tl.program_id(0).to(tl.int64)
1319
+ off_zhg = tl.program_id(1).to(tl.int64)
1320
+ off_z = off_zhg // (H * G)
1321
+ off_h = (off_zhg // G) % H
1322
+ off_g = off_zhg % G
1323
+ head_dim_mask = tl.arange(0, head_dim_pow_2) < head_dim
1324
+
1325
+ out_splitk_offset: "VAR_ARGS_ARRAY" # noqa: F821
1326
+ for i in range(len(Out_splitK)):
1327
+ out_splitk_offset[i] = ( # noqa: F821
1328
+ stride_osk_z[i] * off_z # type: ignore # noqa: F821
1329
+ + stride_osk_g[i] * off_g
1330
+ + stride_osk_h[i] * off_h
1331
+ + stride_osk_m[i] * off_m
1332
+ + tl.arange(0, head_dim_pow_2)
1333
+ )
1334
+ lse_splitk_offset: "VAR_ARGS_ARRAY" # noqa: F821
1335
+ for i in range(len(Out_splitK)):
1336
+ lse_splitk_offset[i] = ( # noqa: F821
1337
+ stride_lsek_z[i] * off_z # type: ignore # noqa: F821
1338
+ + stride_lsek_g[i] * off_g
1339
+ + stride_lsek_h[i] * off_h
1340
+ + stride_lsek_m[i] * off_m
1341
+ )
1342
+
1343
+ lse_max = float("-inf")
1344
+ for split_k_idx in range(len(Out_splitK)): # type: ignore # noqa: F821
1345
+ LSE_splitK_ptr = LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx] # type: ignore # noqa: F821
1346
+ lse_splitk = tl.load(LSE_splitK_ptr)
1347
+ lse_max = tl.maximum(lse_max, lse_splitk)
1348
+
1349
+ sumexp_normalized = 0.0
1350
+ numerator_normalized = tl.zeros([head_dim_pow_2], dtype=tl.float32)
1351
+
1352
+ for split_k_idx in range(len(Out_splitK)): # type: ignore # noqa: F821
1353
+ out_splitk = tl.load(
1354
+ Out_splitK[split_k_idx] + out_splitk_offset[split_k_idx], # type: ignore # noqa: F821
1355
+ mask=head_dim_mask,
1356
+ )
1357
+ lse_splitk = tl.load(LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx]) # type: ignore # noqa: F821
1358
+ # Compute denominator
1359
+ sumexp_normalized_splitk = tl.math.exp2(
1360
+ (lse_splitk - lse_max).to(tl.float32) * 1.44269504
1361
+ )
1362
+ sumexp_normalized += sumexp_normalized_splitk
1363
+
1364
+ # Compute numerator
1365
+ numerator_normalized += out_splitk * sumexp_normalized_splitk
1366
+
1367
+ acc = numerator_normalized / sumexp_normalized
1368
+ acc = tl.where(lse_max == float("-inf"), 0.0, acc)
1369
+
1370
+ Out_ptr = (
1371
+ Out
1372
+ + stride_oz * off_z
1373
+ + stride_oh * off_h
1374
+ + stride_og * off_g
1375
+ + stride_om * off_m
1376
+ + tl.arange(0, head_dim_pow_2)
1377
+ )
1378
+ if acc.dtype is tl.float64 and Out.dtype.element_ty is not tl.float64:
1379
+ # must avoid direct cast f64->f16
1380
+ acc = acc.to(tl.float32)
1381
+ tl.store(Out_ptr, acc, mask=head_dim_mask)
1382
+
1383
+ if WRITE_LSE:
1384
+ l_ptrs = (
1385
+ LSE
1386
+ + off_z * stride_lse_z
1387
+ + off_g * stride_lse_g
1388
+ + off_h * stride_lse_h
1389
+ + off_m * stride_lse_m
1390
+ )
1391
+ to_store = lse_max + tl.math.log2(sumexp_normalized) / 1.44269504
1392
+ to_store = tl.where(lse_max == float("-inf"), lse_max, to_store)
1393
+ tl.store(l_ptrs, to_store)
1394
+
1395
+
1396
+ @triton.jit
1397
+ def _splitK_reduce_varargs_backward(
1398
+ Out_splitK: "VAR_ARGS_ARRAY", # list of [B, G, H, Mq, K];
1399
+ LSE_splitK: "VAR_ARGS_ARRAY", # list of [B, G, H, Mq]
1400
+ Dout_splitK: "VAR_ARGS_ARRAY", # gradients - same shape as the inputs themselves
1401
+ DLSE_splitK: "VAR_ARGS_ARRAY",
1402
+ Out, # [B, G, H, M, K]
1403
+ LSE, # [B, G, H, M]
1404
+ DOut,
1405
+ DLSE,
1406
+ # strides of chunked inputs: attention and LSE
1407
+ stride_osk_z: "VAR_ARGS_ARRAY",
1408
+ stride_osk_g: "VAR_ARGS_ARRAY",
1409
+ stride_osk_h: "VAR_ARGS_ARRAY",
1410
+ stride_osk_m: "VAR_ARGS_ARRAY",
1411
+ stride_osk_k: "VAR_ARGS_ARRAY",
1412
+ stride_lsek_z: "VAR_ARGS_ARRAY",
1413
+ stride_lsek_g: "VAR_ARGS_ARRAY",
1414
+ stride_lsek_h: "VAR_ARGS_ARRAY",
1415
+ stride_lsek_m: "VAR_ARGS_ARRAY",
1416
+ # strides of merged outputs: attention and LSE
1417
+ stride_oz,
1418
+ stride_og,
1419
+ stride_oh,
1420
+ stride_om,
1421
+ stride_ok,
1422
+ stride_lse_z,
1423
+ stride_lse_g,
1424
+ stride_lse_h,
1425
+ stride_lse_m,
1426
+ # strides of gradients
1427
+ stride_doz,
1428
+ stride_dog,
1429
+ stride_doh,
1430
+ stride_dom,
1431
+ stride_dok,
1432
+ stride_dlse_z,
1433
+ stride_dlse_g,
1434
+ stride_dlse_h,
1435
+ stride_dlse_m,
1436
+ BLOCK_SIZE: tl.constexpr,
1437
+ H: tl.constexpr,
1438
+ G: tl.constexpr,
1439
+ ):
1440
+ """
1441
+ Backward for _splitK_reduce_varargs. Similar to forward, it takes
1442
+ attention and LSE of chunks as lists of tensors,
1443
+ and outputs the corresponding gradients in the same format.
1444
+ """
1445
+
1446
+ # grid = (M, B * G * H, 1)
1447
+ off_m = tl.program_id(0).to(tl.int64)
1448
+ off_zhg = tl.program_id(1).to(tl.int64)
1449
+ off_z = off_zhg // (H * G)
1450
+ off_h = (off_zhg // G) % H
1451
+ off_g = off_zhg % G
1452
+
1453
+ # Compute offsets inside each attention/LSE chunk.
1454
+ # Note that each chunk can have different strides, so offsets can also be different.
1455
+ out_splitk_offset: "VAR_ARGS_ARRAY" # noqa: F821
1456
+ for i in range(len(Out_splitK)):
1457
+ out_splitk_offset[i] = ( # type: ignore # noqa: F821
1458
+ stride_osk_z[i] * off_z
1459
+ + stride_osk_g[i] * off_g
1460
+ + stride_osk_h[i] * off_h
1461
+ + stride_osk_m[i] * off_m
1462
+ + tl.arange(0, BLOCK_SIZE)
1463
+ )
1464
+ lse_splitk_offset: "VAR_ARGS_ARRAY" # noqa: F821
1465
+ for i in range(len(Out_splitK)):
1466
+ lse_splitk_offset[i] = ( # type: ignore # noqa: F821
1467
+ stride_lsek_z[i] * off_z
1468
+ + stride_lsek_g[i] * off_g
1469
+ + stride_lsek_h[i] * off_h
1470
+ + stride_lsek_m[i] * off_m
1471
+ )
1472
+
1473
+ lse_max = float("-inf")
1474
+ for split_k_idx in range(len(Out_splitK)): # type: ignore # noqa: F821
1475
+ LSE_splitK_ptr = LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx] # type: ignore # noqa: F821
1476
+ lse_splitk = tl.load(LSE_splitK_ptr)
1477
+ lse_max = tl.maximum(lse_max, lse_splitk)
1478
+
1479
+ # Load attention and the corresponding gradient
1480
+ offset_out = (
1481
+ stride_oz * off_z
1482
+ + stride_oh * off_h
1483
+ + stride_og * off_g
1484
+ + stride_om * off_m
1485
+ + tl.arange(0, BLOCK_SIZE)
1486
+ )
1487
+ offset_dout = (
1488
+ stride_doz * off_z
1489
+ + stride_doh * off_h
1490
+ + stride_dog * off_g
1491
+ + stride_dom * off_m
1492
+ + tl.arange(0, BLOCK_SIZE)
1493
+ )
1494
+ out = tl.load(Out + offset_out)
1495
+ dattn = tl.load(DOut + offset_dout)
1496
+
1497
+ # Load LSE and the corresponding gradient
1498
+ offset_lse = (
1499
+ stride_lse_z * off_z
1500
+ + stride_lse_h * off_h
1501
+ + stride_lse_g * off_g
1502
+ + stride_lse_m * off_m
1503
+ )
1504
+ offset_dlse = (
1505
+ stride_dlse_z * off_z
1506
+ + stride_dlse_h * off_h
1507
+ + stride_dlse_g * off_g
1508
+ + stride_dlse_m * off_m
1509
+ )
1510
+ lse = tl.load(LSE + offset_lse)
1511
+ dlse = tl.load(DLSE + offset_dlse)
1512
+
1513
+ for split_k_idx in range(len(Out_splitK)): # type: ignore # noqa: F821
1514
+ # Load attention and LSE of chunks
1515
+ out_splitk = tl.load(Out_splitK[split_k_idx] + out_splitk_offset[split_k_idx]) # type: ignore # noqa: F821
1516
+ lse_splitk = tl.load(LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx]) # type: ignore # noqa: F821
1517
+
1518
+ # Pointers to save gradients of attention and LSE of chunks
1519
+ dout_splitk_ptr = Dout_splitK[split_k_idx] + out_splitk_offset[split_k_idx] # type: ignore # noqa: F821
1520
+ dlse_splitk_ptr = DLSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx] # type: ignore # noqa: F821
1521
+
1522
+ # dX/dattn_i = dX/dattn * dattn/dattn_i + dX/dlse * dlse/dattn_i, and dlse/dattn_i == 0
1523
+ dattn_dattn_i = tl.exp(lse_splitk - lse_max) / tl.exp(lse - lse_max)
1524
+ dX_dattn_i = dattn_dattn_i * dattn
1525
+ tl.store(dout_splitk_ptr, dX_dattn_i)
1526
+
1527
+ dattn_dlse_i = (out_splitk - out) * dattn_dattn_i
1528
+
1529
+ # dX/dlse_i = dX/dattn * dattn/dlse_i + dX/dlse * dlse/dlse_i
1530
+ dlse_dlse_i = dattn_dattn_i
1531
+ dX_dlse_i = dlse_dlse_i * dlse + tl.sum(
1532
+ dattn_dlse_i * dattn
1533
+ ) # Sum is over the hidden dimension
1534
+ tl.store(dlse_splitk_ptr, dX_dlse_i)
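# Autograd check (not from the packaged file) of the gradient formulas in the comments of
# _splitK_reduce_varargs_backward, written against a plain PyTorch restatement of the merge.
# All shapes and values are made up; float64 is used only to make the comparison tight.
import torch

torch.manual_seed(0)
split_k, head_dim = 3, 5
out_k = torch.randn(split_k, head_dim, dtype=torch.float64, requires_grad=True)
lse_k = torch.randn(split_k, dtype=torch.float64, requires_grad=True)

w = torch.softmax(lse_k, dim=0)          # == exp(lse_k - lse), with lse = logsumexp(lse_k)
out = (w[:, None] * out_k).sum(0)        # merged attention output
lse = torch.logsumexp(lse_k, dim=0)      # merged LSE

dout = torch.randn(head_dim, dtype=torch.float64)
dlse = torch.randn((), dtype=torch.float64)
((out * dout).sum() + lse * dlse).backward()

with torch.no_grad():
    dattn_dattn_i = torch.exp(lse_k - lse)   # dout/dout_i, as in the kernel comment
    expected_dout_k = dattn_dattn_i[:, None] * dout
    expected_dlse_k = dattn_dattn_i * dlse + ((out_k - out) * dattn_dattn_i[:, None] * dout).sum(1)
assert torch.allclose(out_k.grad, expected_dout_k)
assert torch.allclose(lse_k.grad, expected_dlse_k)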