mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1534 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import functools
|
|
8
|
+
import sys
|
|
9
|
+
from typing import Callable, Dict, Tuple, Union
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
import triton
|
|
13
|
+
import triton.language as tl
|
|
14
|
+
|
|
15
|
+
from .vararg_kernel import unroll_varargs, VAR_ARGS_ARRAY
|
|
16
|
+
|
|
17
|
+
# pyre-ignore-all-errors
|
|
18
|
+
# Names of kernel arguments gathered for use as an autotuning key. Every entry
# matches a `tl.constexpr` parameter of `_fwd_kernel_splitK`.
# NOTE(review): presumably passed as `key=` to `triton.autotune`; the call
# site is not visible in this part of the file -- confirm.
AUTOTUNER_KEY = [
    "Z",
    "H",
    "G",
    "N_CTX_Q",
    "N_CTX_K",
    "BLOCK_DMODEL",
    "PACKED_PER_VAL",
    "N_GROUPS",
    "BLOCK_N_PER_SPLIT",
    "PAGE_SIZE",
]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@triton.jit
def _fwd_kernel_splitK(  # noqa: C901
    Q,
    K,
    V,
    sm_scale,
    Out_splitK,  # [B, H, split_k, Mq, K]
    LSE_splitk,  # [B, H, split_k, Mq]
    block_tables,
    Seq_len,
    Seq_starts_k,
    Seq_starts_q,
    Seq_starts_q_multiplier,
    additive_bias,
    K_fp8_scale_shift,
    V_fp8_scale_shift,
    q_fp8_scale_shift,
    stride_qz,
    stride_qm,
    stride_qg,
    stride_qh,
    stride_qk,
    stride_kz,
    stride_kn,
    stride_kg,
    stride_kh,
    stride_kk,
    stride_vz,
    stride_vn,
    stride_vg,
    stride_vh,
    stride_vk,
    stride_osk_z,
    stride_osk_g,
    stride_osk_h,
    stride_osk_s,
    stride_osk_m,
    stride_osk_k,  # not referenced in the body; the output block pointer uses a unit stride
    stride_lsek_z,
    stride_lsek_g,
    stride_lsek_h,
    stride_lsek_s,
    stride_lsek_m,
    stride_blocktablesz,
    stride_blocktablesl,
    stride_bias_b,
    stride_bias_g,
    stride_bias_h,
    stride_bias_qm,
    stride_bias_km,
    stride_k_fp8_scale_shift_z: tl.constexpr,
    stride_k_fp8_scale_shift_n: tl.constexpr,
    stride_k_fp8_scale_shift_g: tl.constexpr,
    stride_k_fp8_scale_shift_h: tl.constexpr,
    stride_v_fp8_scale_shift_z: tl.constexpr,
    stride_v_fp8_scale_shift_n: tl.constexpr,
    stride_v_fp8_scale_shift_g: tl.constexpr,
    stride_v_fp8_scale_shift_h: tl.constexpr,
    stride_q_fp8_scale_shift_z: tl.constexpr,
    stride_q_fp8_scale_shift_m: tl.constexpr,
    stride_q_fp8_scale_shift_g: tl.constexpr,
    stride_q_fp8_scale_shift_h: tl.constexpr,
    kv_cache_blocks_per_row: tl.constexpr,  # not referenced in the body
    Z: tl.constexpr,
    N_CTX_Q: tl.constexpr,  # The number of queries
    N_CTX_K: tl.constexpr,
    BLOCK_N_PER_SPLIT: tl.constexpr,
    H: tl.constexpr,
    G: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    USE_SEQ_LEN: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr,
    N_GROUPS: tl.constexpr,
    # It's important that BOUNDS_CHECKS_N, BLOCK_M, BLOCK_N come at the end of
    # the argument list, since they are provided by the heuristics/autotune decorator.
    # Otherwise Triton throws IndexError
    BOUNDS_CHECKS_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_SPLITK: tl.constexpr,
    SPLIT_K_EARLY_EXIT: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    IS_LOCAL: tl.constexpr,
    NUM_QUERIES_CAUSAL: tl.constexpr,  # The N_CTX_Q queries are from this many sequence positions
    USE_PAGED_ATTENTION: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    WINDOW_LEFT: tl.constexpr,
    WINDOW_RIGHT: tl.constexpr,
    WRITE_LSE: tl.constexpr,
    HAS_ADDITIVE_BIAS: tl.constexpr,
    NUM_PROGRAMS_DIM2_CONST: tl.constexpr,
    IS_HIP: tl.constexpr,
    QUANTIZE_PV_TO_FP8: tl.constexpr,
    QUANTIZE_QK_TO_FP8: tl.constexpr,
    USE_FP32_SCALES: tl.constexpr,
):
    """Split-K attention forward kernel: one program handles one chunk of keys.

    Launch grid, as read from the program ids below:
      - axis 0: index of the BLOCK_M-sized block of queries,
      - axis 1: fused (batch, head, group) index, decomposed using H and G,
      - axis 2: index of the split-k chunk of BLOCK_N_PER_SPLIT keys.

    Each program runs an online softmax over its chunk of keys and stores the
    normalized partial attention output into Out_splitK and, when WRITE_LSE is
    set, the log-sum-exp of its chunk into LSE_splitk.

    This kernel can accept non-quantized or int4-quantized keys/values.
    PACKED_PER_VAL determines the quantization type:
        - PACKED_PER_VAL == 1 means no quantization
        - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32)
    The static assertion below also admits PACKED_PER_VAL == 4 with int32 K/V;
    presumably that is the fp8 case (four 8-bit values per int32) -- TODO confirm.
    For the quantized case K/V should be int32 tensors.
    Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8.
    When K_fp8_scale_shift is provided, only N_GROUPS = 1 is accepted.
    Quantization coefficients are stored at the beginning of the row along the last dimension of K/V
    So K[B, H, M, :] has a form
    [   quant_coef0, quant_coef1, ...|
        group0_quant_value0, group0_quant_value1,... |
        group1_quant_value0, group1_quant_value1,...]
    where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset.

    Note: this kernel needs to be processed by unroll_varargs
    before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists.
    See how FwOp.apply does it below.

    Set IS_SPLITK=False to indicate the MHA result should be written directly.
    No metadata will be written.
    """
    # Accumulate in float64 only when the output buffer itself is float64.
    internal_dtype = (
        tl.float64 if Out_splitK.dtype.element_ty is tl.float64 else tl.float32
    )
    tl.static_assert(
        (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
        or (
            (PACKED_PER_VAL == 4 or PACKED_PER_VAL == 8)
            and tl.constexpr(K.dtype.element_ty == tl.int32)
        ),
        f"Only int4 and fp8 quantization is supported, K/V should have dtype int32 in "
        f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}",
    )
    tl.static_assert(
        (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8),
        "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.",
    )
    tl.static_assert(
        N_GROUPS == 1 or K_fp8_scale_shift is None,
        f"Only row-wise fp8 quantization is supported, but got {N_GROUPS=} > 1.",
    )
    # FP8 mode is selected by the presence of a separate scale/shift tensor;
    # any other packed format is treated as int4.
    FP8_QUANTIZED: tl.constexpr = K_fp8_scale_shift is not None
    INT4_QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 and not FP8_QUANTIZED
    PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS
    D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS

    start_m = tl.program_id(0)
    off_zhg = tl.program_id(1)
    # Batch index is widened to int64 before it is multiplied by strides.
    off_z = (off_zhg // (H * G)).to(tl.int64)
    off_hg = off_zhg % (H * G)
    off_h = off_hg // G
    off_g = off_hg % G
    splitk_idx = tl.program_id(2)

    if USE_SEQ_LEN:
        kv_len = tl.load(Seq_len + off_z)
        if SPLIT_K_EARLY_EXIT and kv_len == 0:
            return
    else:
        kv_len = N_CTX_K

    if Seq_starts_k is None:
        start_kv_idx = 0
    else:
        start_kv_idx = tl.load(Seq_starts_k + off_z)
        if USE_SEQ_LEN and PAGE_SIZE > 0:
            # gappy with paged attention stores each "end" instead of each "length"
            # because that's what FA3 needs.
            kv_len -= start_kv_idx

    if Seq_starts_q is None:
        q_len = N_CTX_Q
        queries_use_batch_dim = 1
        off_m = 0
    else:
        # Variable-length queries: this batch element's rows are the range
        # [Seq_starts_q[z], Seq_starts_q[z + 1]) scaled by the multiplier,
        # addressed along the query dimension instead of the batch dimension.
        queries_use_batch_dim = 0
        off_m = tl.load(Seq_starts_q + off_z) * Seq_starts_q_multiplier
        q_len = tl.load(Seq_starts_q + off_z + 1) * Seq_starts_q_multiplier - off_m
        if q_len == 0:
            return

    k_base = K + off_h * stride_kh + off_g * stride_kg
    v_base = V + off_h * stride_vh + off_g * stride_vg

    if FP8_QUANTIZED:
        k_fp8_scale_shift_base = (
            K_fp8_scale_shift
            + off_h * stride_k_fp8_scale_shift_h
            + off_g * stride_k_fp8_scale_shift_g
        )
        v_fp8_scale_shift_base = (
            V_fp8_scale_shift
            + off_h * stride_v_fp8_scale_shift_h
            + off_g * stride_v_fp8_scale_shift_g
        )
    else:
        k_fp8_scale_shift_base = None
        v_fp8_scale_shift_base = None

    # Boundaries of split-k chunk
    chunk_hi = (splitk_idx + 1) * BLOCK_N_PER_SPLIT
    chunk_lo = splitk_idx * BLOCK_N_PER_SPLIT
    ignore_in_first_block = 0
    # For paged attention case K/V_block_ptr are defined inside the loop
    # whereas for non-paged case they are defined before the loop.
    if PAGE_SIZE > 0:
        # Page contains several blocks
        BLOCKS_IN_PAGE: tl.constexpr = PAGE_SIZE // BLOCK_N
        # Align boundaries of split-k chunk to block boundaries
        # In the last chunk, shift hi to the right, in the other chunks, shift it to the left
        # TODO: Replace NUM_PROGRAMS_DIM2_CONST with tl.num_programs(2) after
        # the next Triton upgrade.
        is_last_chunk = splitk_idx == NUM_PROGRAMS_DIM2_CONST - 1
        shift = BLOCK_N - 1 if is_last_chunk else 0
        lo = (tl.maximum(chunk_lo, start_kv_idx) // BLOCK_N) * BLOCK_N
        # Number of leading keys in the first block that precede start_kv_idx;
        # they are masked out inside the main loop.
        ignore_in_first_block = tl.maximum(0, (start_kv_idx - lo))
        hi = ((chunk_hi + shift) // BLOCK_N) * BLOCK_N
        hi = tl.minimum(hi, kv_len + start_kv_idx)
        block_table = block_tables + stride_blocktablesz * off_z
        # Offset in integer blocks
        logical_block_idx = lo // BLOCK_N
    else:
        lo = chunk_lo
        hi = tl.minimum(chunk_hi, kv_len)
        if Seq_starts_k is not None:
            k_base += start_kv_idx * stride_kn
            v_base += start_kv_idx * stride_vn
        else:
            k_base += off_z * stride_kz
            v_base += off_z * stride_vz
        # Additional shift by 1 along the last dimension in the quantized case, since
        # the first element along that dim contains packed quantization coefficients.
        K_block_ptr = tl.make_block_ptr(
            base=k_base + stride_kk * INT4_QUANTIZED * N_GROUPS,
            shape=(PACKED_D_PER_GROUP, hi),
            strides=(stride_kk, stride_kn),
            offsets=(0, lo),
            block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
            order=(0, 1),
        )
        V_block_ptr = tl.make_block_ptr(
            base=v_base + stride_vk * INT4_QUANTIZED * N_GROUPS,
            shape=(hi, PACKED_D_PER_GROUP if not QUANTIZE_PV_TO_FP8 else D_PER_GROUP),
            strides=(stride_vn, stride_vk),
            offsets=(lo, 0),
            block_shape=(
                BLOCK_N,
                PACKED_D_PER_GROUP if not QUANTIZE_PV_TO_FP8 else D_PER_GROUP,
            ),
            order=(1, 0),
        )

        if INT4_QUANTIZED:
            # Pointers to quantization coefficients. Even though they are 1D,
            # we use block pointers here so the pointer arithmetic is in int64,
            # as otherwise the offsets for V_scale_shift_block_ptr may overflow.
            K_scale_shift_block_ptr = tl.make_block_ptr(
                base=k_base,
                shape=(1, hi),
                strides=(stride_kk, stride_kn),
                offsets=(0, lo),
                block_shape=(1, BLOCK_N),
                order=(0, 1),
            )
            V_scale_shift_block_ptr = tl.make_block_ptr(
                base=v_base,
                shape=(hi, 1),
                strides=(stride_vn, stride_vk),
                offsets=(lo, 0),
                block_shape=(BLOCK_N, 1),
                order=(1, 0),
            )
        elif FP8_QUANTIZED:
            # FP8 coefficients live in separate tensors; offset them the same
            # way k_base / v_base were offset above.
            if Seq_starts_k is not None:
                k_fp8_scale_shift_base += start_kv_idx * stride_k_fp8_scale_shift_n
                v_fp8_scale_shift_base += start_kv_idx * stride_v_fp8_scale_shift_n
            else:
                k_fp8_scale_shift_base += off_z * stride_k_fp8_scale_shift_z
                v_fp8_scale_shift_base += off_z * stride_v_fp8_scale_shift_z
            K_scale_shift_block_ptr = tl.make_block_ptr(
                base=k_fp8_scale_shift_base,
                shape=(1, hi),
                strides=(1, stride_k_fp8_scale_shift_n),
                offsets=(0, lo),
                block_shape=(1, BLOCK_N),
                order=(0, 1),
            )
            V_scale_shift_block_ptr = tl.make_block_ptr(
                base=v_fp8_scale_shift_base,
                shape=(hi, 1),
                strides=(stride_v_fp8_scale_shift_n, 1),
                offsets=(lo, 0),
                block_shape=(BLOCK_N, 1),
                order=(1, 0),
            )
        else:
            K_scale_shift_block_ptr = None
            V_scale_shift_block_ptr = None

        if HAS_ADDITIVE_BIAS:
            additive_bias_block_ptr = tl.make_block_ptr(
                base=additive_bias
                + off_z * stride_bias_b
                + off_g * stride_bias_g
                + off_h * stride_bias_h,
                shape=(N_CTX_Q, hi),
                strides=(stride_bias_qm, stride_bias_km),
                offsets=(start_m * BLOCK_M, lo),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(0, 1),
            )

    # Nothing to do when this split-k chunk contains no keys.
    if SPLIT_K_EARLY_EXIT and lo >= hi:
        return

    Q_block_ptr = tl.make_block_ptr(
        base=Q
        + off_m * stride_qm
        + off_h * stride_qh
        + off_z * stride_qz * queries_use_batch_dim
        + off_g * stride_qg,
        shape=(q_len, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, D_PER_GROUP),
        order=(1, 0),
    )

    # initialize pointer to m and l
    # (m_i: running row maximum of the scaled logits, l_i: running softmax denominator)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)

    # Before compilation, this kernel will be processed by unroll_varargs.
    # That turns tensors annotated as the one below into lists of tensors of length N_GROUPS.
    # This is a solution for Triton native lack of support for lists of tensors.
    acc: "VAR_ARGS_ARRAY"  # noqa: F821

    for i in range(len(acc)):  # noqa: F821
        acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=internal_dtype)  # noqa: F821
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    #
    # We declare log2e as a constant with a precisely-specified type to guarantee that
    # triton will use the exact same value in all instances below, rather than sometimes
    # using float32 and sometimes using float64. For more discussion see:
    # https://github.com/triton-lang/triton/issues/5466
    log2e = tl.full((), 1.44269504, tl.float32)
    qk_scale = sm_scale * log2e
    # load q: it will stay in SRAM throughout
    q: "VAR_ARGS_ARRAY"  # noqa: F821

    if QUANTIZE_QK_TO_FP8:
        # Create a block pointer for q_scale
        q_scale_block_ptr = tl.make_block_ptr(
            base=q_fp8_scale_shift
            + off_m * stride_q_fp8_scale_shift_m
            + off_h * stride_q_fp8_scale_shift_h
            + off_g * stride_q_fp8_scale_shift_g
            + off_z * stride_q_fp8_scale_shift_z * queries_use_batch_dim,
            shape=(q_len, 1),
            strides=(stride_q_fp8_scale_shift_m, 1),
            offsets=(start_m * BLOCK_M, 0),
            block_shape=(BLOCK_M, 1),
            order=(1, 0),
        )

        # For FP8 quantized query, load and dequantize
        for i in range(len(acc)):  # noqa: F821
            # Load quantized query
            q_quantized = tl.load(  # noqa: F821
                tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
            )

            # Load q_scale for dequantization - q_scale is per row
            # NOTE(review): q_scale_block_ptr has a single column, so advancing it
            # by i * D_PER_GROUP only stays in range for i == 0, and only the
            # q_scale of the last iteration is used later. Presumably this path
            # is only taken with a single group -- confirm.
            q_scale = tl.load(
                tl.advance(q_scale_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
            )
            # q_scale is not applied here; it is multiplied into qk in the main loop.
            q[i] = q_quantized.to(Q.dtype.element_ty)  # noqa: F821
    else:
        # Regular query loading
        for i in range(len(acc)):  # noqa: F821
            q[i] = tl.load(  # noqa: F821
                tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
            )

    if IS_CAUSAL or IS_LOCAL:
        # Why does the masking condition below work as a causal mask?
        # Assuming num_queries <= BLOCK_M:
        # kv_pos = kv_start + range(0, BLOCK_N)
        # q_offset = start_m * BLOCK_M + range(0, BLOCK_M)
        # q_pos = kv_start + kv_len - num_queries + q_offset % num_queries
        # mask = q_pos - kv_pos >= 0
        # So the final masking condition is:
        #   range(0, BLOCK_M) % num_queries - range(0, BLOCK_N) >= num_queries - kv_len

        q_offset = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        diag_idx = (q_offset[:, None] % NUM_QUERIES_CAUSAL) - tl.arange(0, BLOCK_N)[
            None, :
        ]
        diag_idx_shifted = tl.constexpr(diag_idx - NUM_QUERIES_CAUSAL + kv_len)

    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        if PAGE_SIZE > 0:
            # Paged case: the block pointers are rebuilt on every iteration from
            # the physical page index read out of the block table.
            # Offset in integer blocks from the beginning of the page
            block_offset_in_page = logical_block_idx % BLOCKS_IN_PAGE
            # Offset in integer pages
            logical_page_idx = logical_block_idx // BLOCKS_IN_PAGE
            physical_page_idx = tl.load(
                block_table + stride_blocktablesl * logical_page_idx
            ).to(tl.int32)
            offset = physical_page_idx * PAGE_SIZE + block_offset_in_page * BLOCK_N

            current_block_size = min(hi - start_n, BLOCK_N)
            K_block_ptr = tl.make_block_ptr(
                base=k_base + stride_kk * INT4_QUANTIZED * N_GROUPS,
                shape=(PACKED_D_PER_GROUP, offset + current_block_size),
                strides=(stride_kk, stride_kn),
                offsets=(0, offset),
                block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
                order=(0, 1),
            )
            V_block_ptr = tl.make_block_ptr(
                base=v_base + stride_vk * INT4_QUANTIZED * N_GROUPS,
                shape=(
                    offset + current_block_size,
                    PACKED_D_PER_GROUP if not QUANTIZE_PV_TO_FP8 else D_PER_GROUP,
                ),
                strides=(stride_vn, stride_vk),
                offsets=(offset, 0),
                block_shape=(
                    BLOCK_N,
                    PACKED_D_PER_GROUP if not QUANTIZE_PV_TO_FP8 else D_PER_GROUP,
                ),
                order=(1, 0),
            )
            if INT4_QUANTIZED:
                # Pointers to quantization coefficients. Even though they are 1D,
                # we use block pointers here so the pointer arithmetic is in int64,
                # as otherwise the offsets for V_scale_shift_block_ptr may overflow.
                K_scale_shift_block_ptr = tl.make_block_ptr(
                    base=k_base,
                    shape=(1, offset + current_block_size),
                    strides=(stride_kk, stride_kn),
                    offsets=(0, offset),
                    block_shape=(1, BLOCK_N),
                    order=(0, 1),
                )
                V_scale_shift_block_ptr = tl.make_block_ptr(
                    base=v_base,
                    shape=(offset + current_block_size, 1),
                    strides=(stride_vn, stride_vk),
                    offsets=(offset, 0),
                    block_shape=(BLOCK_N, 1),
                    order=(1, 0),
                )
            elif FP8_QUANTIZED:
                K_scale_shift_block_ptr = tl.make_block_ptr(
                    base=k_fp8_scale_shift_base,
                    shape=(1, offset + current_block_size),
                    strides=(1, stride_k_fp8_scale_shift_n),
                    offsets=(0, offset),
                    block_shape=(1, BLOCK_N),
                    order=(0, 1),
                )
                V_scale_shift_block_ptr = tl.make_block_ptr(
                    base=v_fp8_scale_shift_base,
                    shape=(offset + current_block_size, 1),
                    strides=(stride_v_fp8_scale_shift_n, 1),
                    offsets=(offset, 0),
                    block_shape=(BLOCK_N, 1),
                    order=(1, 0),
                )
            else:
                K_scale_shift_block_ptr = None
                V_scale_shift_block_ptr = None
            logical_block_idx += 1

        k: "VAR_ARGS_ARRAY"  # noqa: F821
        v: "VAR_ARGS_ARRAY"  # noqa: F821

        if QUANTIZE_PV_TO_FP8:
            v_dtype = tl.float8e4nv
        else:
            v_dtype = Q.dtype.element_ty
        for i in range(len(acc)):  # noqa: F821
            # Load and dequantize K/V with appropriate return values based on quantization flags
            result = load_dequantize_k_v_group(  # noqa: F821
                K_block_ptr,
                V_block_ptr,
                K_scale_shift_block_ptr,
                V_scale_shift_block_ptr,
                BOUNDS_CHECKS_N,
                PACKED_PER_VAL,
                PACKED_D_PER_GROUP,
                FP8_QUANTIZED,
                Q.dtype.element_ty,
                v_dtype,
                i,
                IS_HIP,
                QUANTIZE_PV_TO_FP8,
                QUANTIZE_QK_TO_FP8,
                USE_FP32_SCALES,
            )

            # Unpack results based on quantization configuration
            if QUANTIZE_PV_TO_FP8 and QUANTIZE_QK_TO_FP8:
                k[i], v[i], v_scale, k_scale = result  # noqa: F821
            elif QUANTIZE_PV_TO_FP8:
                k[i], v[i], v_scale = result  # noqa: F821
            elif QUANTIZE_QK_TO_FP8:
                k[i], v[i], k_scale = result  # noqa: F821
            else:
                k[i], v[i] = result  # noqa: F821
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for i in range(len(acc)):  # noqa: F821
            qk += tl.dot(q[i], k[i])  # noqa: F821

        if QUANTIZE_QK_TO_FP8:
            # Reshape k_scale for proper broadcasting with qk
            # k_scale has shape (BLOCK_N,), we need to reshape it to (1, BLOCK_N)
            # for proper broadcasting with qk of shape (BLOCK_M, BLOCK_N)
            k_scale_reshaped = tl.reshape(k_scale, (1, BLOCK_N))

            # Apply k_scale to qk
            qk = qk * k_scale_reshaped
            qk = qk * tl.reshape(q_scale, (BLOCK_M, 1))  # noqa: F821

        # Apply qk_scale (scalar)
        qk *= qk_scale

        # Mask the keys of the first block that lie before start_kv_idx
        # (only non-zero in the paged case, where lo is rounded down to a block).
        if start_n == lo and ignore_in_first_block > 0:
            qk = tl.where(
                tl.arange(0, BLOCK_N) < ignore_in_first_block, float("-inf"), qk
            )

        if HAS_ADDITIVE_BIAS:
            loaded_bias = tl.load(
                additive_bias_block_ptr,
                boundary_check=(0, 1) if BOUNDS_CHECKS_N else (0,),
            )
            # The bias is multiplied by log2e because the softmax below uses exp2.
            qk += loaded_bias.to(tl.float32) * log2e
            additive_bias_block_ptr = tl.advance(additive_bias_block_ptr, (0, BLOCK_N))

        # TODO: This is slow, and only needed at the last iteration.
        # Maybe we can unroll the last iteration instead?
        if BOUNDS_CHECKS_N:
            qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
        if IS_CAUSAL:
            # -- apply the causal mask --
            qk = tl.where(diag_idx_shifted >= start_n - start_kv_idx, qk, float("-inf"))
        if IS_LOCAL:
            # -- apply the local window size mask --
            qk = tl.where(
                diag_idx_shifted < start_n - start_kv_idx + WINDOW_LEFT + 1,
                qk,
                float("-inf"),
            )
            if not IS_CAUSAL and WINDOW_RIGHT >= 0:
                qk = tl.where(
                    diag_idx_shifted >= start_n - start_kv_idx - WINDOW_RIGHT,
                    qk,
                    float("-inf"),
                )

        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        if HAS_ADDITIVE_BIAS or (IS_CAUSAL or IS_LOCAL):
            # NOTE: It's possible that an entire block is masked out.
            # if this is the case, `m_i_new=nan` and everything becomes nan
            alpha = tl.where(m_i_new == float("-inf"), 0, alpha)
            p = tl.where(m_i_new[:, None] == float("-inf"), 0, p)

        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        if not QUANTIZE_PV_TO_FP8:
            p = p.to(v_dtype)
        else:
            # Apply v-scale to P
            p = p * tl.trans(v_scale)

            # Quantize P to FP8
            MAX_FP8 = 448
            amax = tl.max(p, axis=1)  # rowmax(P)
            p_scale = tl.maximum(amax / MAX_FP8, 1e-9)
            p_scaled = p / p_scale[:, None]
            p_clamped = tl.clamp(p_scaled, 0, MAX_FP8)

            # convert P to FP8
            p = p_clamped.to(v_dtype, fp_downcast_rounding="rtne")

        # -- scale and update acc --
        for i in range(len(acc)):  # noqa: F821
            acc[i] *= alpha[:, None]  # noqa: F821
            if not QUANTIZE_PV_TO_FP8:
                acc[i] += tl.dot(p, v[i])  # noqa: F821
            else:
                # Re-scale PV using p_scale
                acc[i] += tl.dot(p, v[i]) * p_scale[:, None]  # noqa: F821

        if not PAGE_SIZE:
            # update pointers
            K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
            V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
            if PACKED_PER_VAL > 1:
                K_scale_shift_block_ptr = tl.advance(
                    K_scale_shift_block_ptr, (0, BLOCK_N)
                )
                V_scale_shift_block_ptr = tl.advance(
                    V_scale_shift_block_ptr, (BLOCK_N, 0)
                )

    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out_splitK
        + off_z.to(tl.int64) * stride_osk_z * queries_use_batch_dim
        + off_m * stride_osk_m
        + off_g * stride_osk_g
        + off_h * stride_osk_h
        + splitk_idx * stride_osk_s,
        shape=(q_len, D_PER_GROUP),
        strides=(stride_osk_m, 1),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, D_PER_GROUP),
        order=(1, 0),
    )
    for i in range(len(acc)):  # noqa: F821
        # If for the current batch element there are no tokens in the current split-k chunk (because
        # seqlen is too short), l_i will be 0, so we need to make sure attention is filled with zeros and not NaNs.
        attn_out = tl.where(l_i[:, None] == 0, 0.0, acc[i] / l_i[:, None])  # noqa: F821
        tl.store(
            tl.advance(O_block_ptr, (0, i * D_PER_GROUP)),
            attn_out.to(Out_splitK.dtype.element_ty),  # noqa: F821
            boundary_check=(0,),
        )
    if WRITE_LSE:
        LSE_splitk_ptr = (
            LSE_splitk
            + off_z * stride_lsek_z * queries_use_batch_dim
            + off_m * stride_lsek_m
            + off_g * stride_lsek_g
            + off_h * stride_lsek_h
            + splitk_idx * stride_lsek_s
            + (start_m * BLOCK_M + tl.arange(0, BLOCK_M)) * stride_lsek_m
        )
        mask = start_m * BLOCK_M + tl.arange(0, BLOCK_M) < q_len
        # Can be float64 to improve numerics
        lse_dtype = LSE_splitk.dtype.element_ty
        # m_i and l_i are in base-2 units; dividing by log2e converts the
        # log-sum-exp back to natural-log units.
        tl.store(
            LSE_splitk_ptr,
            (tl.math.log2(l_i.to(lse_dtype)) + m_i.to(lse_dtype)) / log2e,
            mask=mask,
        )
|
|
685
|
+
|
|
686
|
+
|
|
687
|
+
def gen_config(
    block_m: int,
    block_n: int,
    stages: int,
    warps: int,
) -> triton.Config:
    """Shorthand constructor for a ``triton.Config`` that fits on one line.

    Args:
        block_m: value stored under the ``"BLOCK_M"`` meta-parameter.
        block_n: value stored under the ``"BLOCK_N"`` meta-parameter.
        stages: forwarded to ``triton.Config`` as ``num_stages``.
        warps: forwarded to ``triton.Config`` as ``num_warps``.

    Returns:
        The resulting ``triton.Config``.
    """
    meta_params = {"BLOCK_M": block_m, "BLOCK_N": block_n}
    return triton.Config(meta_params, num_stages=stages, num_warps=warps)
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def _get_splitk_kernel(num_groups):
    """
    Specialize ``_fwd_kernel_splitK`` for ``num_groups`` quantization groups.

    The kernel has to be post-processed by ``unroll_varargs`` before
    ``triton.heuristics`` (applied here) and ``triton.autotune`` can be used,
    which is why they are applied as plain calls rather than as decorators.

    Returns the unrolled kernel wrapped in ``triton.heuristics`` with a
    ``BOUNDS_CHECKS_N`` heuristic attached.
    """

    def _needs_bounds_checks_n(args):
        # Conditions are evaluated lazily, in this order; the first truthy one wins.
        n_per_split = args["BLOCK_N_PER_SPLIT"]
        # A split is not a whole number of BLOCK_N tiles.
        if n_per_split % args["BLOCK_N"]:
            return True
        # The key context is not a whole number of splits.
        if n_per_split > 0 and args["N_CTX_K"] % n_per_split:
            return True
        return bool(args["USE_SEQ_LEN"])

    unrolled = unroll_varargs(_fwd_kernel_splitK, N=num_groups)
    add_heuristics = triton.heuristics({"BOUNDS_CHECKS_N": _needs_bounds_checks_n})
    return add_heuristics(unrolled)
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
def early_config_prune(configs, named_args, **kwargs):
    """Drop autotune configs that are incompatible with paged attention.

    Args:
        configs: candidate configs; each must expose a ``kwargs`` mapping
            containing ``"BLOCK_N"``.
        named_args: unused; accepted because the function is registered as
            an ``early_config_prune`` callback.
        **kwargs: must contain ``USE_PAGED_ATTENTION``; ``PAGE_SIZE`` is only
            read when paged attention is enabled.

    Returns:
        ``configs`` itself when paged attention is off, otherwise a new list
        with the configs whose ``BLOCK_N`` evenly divides the page size.
    """
    if not kwargs["USE_PAGED_ATTENTION"]:
        return configs
    # Looked up only on the paged path so that callers which do not use
    # paged attention are not forced to supply a page size.
    page_size = kwargs["PAGE_SIZE"]
    return [
        config for config in configs if page_size % config.kwargs["BLOCK_N"] == 0
    ]
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
@functools.lru_cache(None)
def autotune_kernel(kernel: Callable):
    """Wrap ``kernel`` in ``triton.autotune`` over a grid of launch configs.

    The candidate grid is the product of tile sizes, stage counts and warp
    counts, restricted to ``BLOCK_N >= BLOCK_M``. Three pipeline stages are
    only offered when ``torch.version.hip`` is falsy. Results are memoized
    per ``kernel`` by ``functools.lru_cache``.
    """
    tile_sizes = (16, 32, 64, 128)
    stage_counts = [1, 2] if torch.version.hip else [1, 2, 3]
    warp_counts = (1, 2, 4, 8)

    candidates = []
    for block_m in tile_sizes:
        for block_n in tile_sizes:
            if block_n < block_m:
                continue
            for stages in stage_counts:
                for warps in warp_counts:
                    candidates.append(gen_config(block_m, block_n, stages, warps))

    tuner = triton.autotune(
        configs=candidates,
        key=AUTOTUNER_KEY,
        use_cuda_graph=True,
        prune_configs_by={"early_config_prune": early_config_prune},
    )
    return tuner(kernel)
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
# Autotuned forward split-K kernels, keyed by the number of quantization
# groups each one was specialized for.
_fwd_kernel_splitK_autotune: Dict[int, triton.runtime.Autotuner] = {}
# For every supported group count, the loop below:
# - specializes the jitted kernel with unroll_varargs,
# - attaches triton.heuristics,
# - wraps the result into the Triton autotuner. Autotuning itself happens the
#   first time the kernel is called.
if sys.version_info >= (3, 9):
    # unroll_varargs requires Python 3.9+
    for num_groups in (1, 2, 4, 8):
        _fwd_kernel_splitK_autotune[num_groups] = autotune_kernel(
            _get_splitk_kernel(num_groups)
        )

    def get_autotuner_cache(
        num_groups: int,
    ) -> Dict[Tuple[Union[int, str]], triton.Config]:
        """Return the autotuner cache of the kernel built for ``num_groups``.

        The cache maps kernel autotune keys (tuples describing the kernel
        inputs) to the selected ``triton.Config``.
        """
        tuner = _fwd_kernel_splitK_autotune[num_groups]
        return tuner.cache

    def set_autotuner_cache(
        cache: Dict[Tuple[Union[int, str]], triton.Config], num_groups: int
    ) -> None:
        """Replace the autotuner cache of the kernel built for ``num_groups``."""
        tuner = _fwd_kernel_splitK_autotune[num_groups]
        tuner.cache = cache
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
@triton.jit
def load_dequantize_k_v_group(
    K_block_ptr,
    V_block_ptr,
    K_scale_shift_block_ptr,
    V_scale_shift_block_ptr,
    BOUNDS_CHECKS_N: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr,
    PACKED_D_PER_GROUP: tl.constexpr,
    FP8_QUANTIZED: tl.constexpr,
    q_dtype: tl.constexpr,
    v_dtype: tl.constexpr,  # Q.dtype.element_ty
    group_id: tl.constexpr,
    IS_HIP: tl.constexpr,
    QUANTIZE_PV_TO_FP8: tl.constexpr,
    QUANTIZE_QK_TO_FP8: tl.constexpr,
    USE_FP32_SCALES: tl.constexpr,
):
    """Load one K/V block and, for int4/fp8-quantized inputs, dequantize it.

    With group-wise quantization, ``group_id`` selects the quantization group
    the block pointers are advanced to before loading.

    Returns (the tuple length depends on the compile-time flags):
        - k, v: loaded and potentially dequantized tensors
        - v_scale: additionally returned when QUANTIZE_PV_TO_FP8 is True
        - k_scale: additionally returned when QUANTIZE_QK_TO_FP8 is True
          (after v_scale when both flags are set)
    """
    # Move both block pointers to the requested quantization group.
    k_group_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0))
    # NOTE(review): with QUANTIZE_PV_TO_FP8 the V step is scaled by
    # PACKED_PER_VAL, presumably because V is then addressed in unpacked
    # elements - confirm against the caller.
    v_group_ptr = tl.advance(
        V_block_ptr,
        (0, PACKED_D_PER_GROUP * PACKED_PER_VAL * group_id)
        if QUANTIZE_PV_TO_FP8
        else (0, PACKED_D_PER_GROUP * group_id),
    )

    # Bounds are checked along dim 1 for the K block and dim 0 for the V block.
    k = tl.load(k_group_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ())
    v = tl.load(v_group_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ())

    # v_scale / k_scale only exist on the FP8 path below.
    if FP8_QUANTIZED:
        k, v, v_scale, k_scale = _process_fp8_quantization(
            k,
            v,
            K_scale_shift_block_ptr,
            V_scale_shift_block_ptr,
            BOUNDS_CHECKS_N,
            PACKED_PER_VAL,
            q_dtype,
            v_dtype,
            IS_HIP,
            QUANTIZE_PV_TO_FP8,
            QUANTIZE_QK_TO_FP8,
            USE_FP32_SCALES,
        )

    elif PACKED_PER_VAL > 1:
        # Several values per element without the FP8 flag: int4 quantization.
        k, v = _process_int4_quantization(
            k,
            v,
            K_scale_shift_block_ptr,
            V_scale_shift_block_ptr,
            group_id,
            BOUNDS_CHECKS_N,
            PACKED_PER_VAL,
            q_dtype,
            IS_HIP,
        )

    # Pick the return tuple matching the quantization flags. This must stay a
    # single if/elif/else chain so exactly one return is compiled.
    if QUANTIZE_PV_TO_FP8 and QUANTIZE_QK_TO_FP8:
        # Both scales are needed: v_scale for P, k_scale for K.
        return k, v, v_scale, k_scale
    elif QUANTIZE_PV_TO_FP8:
        # Only v_scale, to be applied to P.
        return k, v, v_scale
    elif QUANTIZE_QK_TO_FP8:
        # Only k_scale, to be applied to K.
        return k, v, k_scale
    else:
        return k, v
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
@triton.jit
def _process_fp8_quantization(
    k,
    v,
    K_scale_shift_block_ptr,
    V_scale_shift_block_ptr,
    BOUNDS_CHECKS_N: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr,
    q_dtype: tl.constexpr,
    v_dtype: tl.constexpr,
    IS_HIP: tl.constexpr,
    QUANTIZE_PV_TO_FP8: tl.constexpr,
    QUANTIZE_QK_TO_FP8: tl.constexpr,
    USE_FP32_SCALES: tl.constexpr,
):
    """Apply the FP8 quantization handling to already-loaded K and V blocks.

    Loads the packed scale/shift values for K and V and, unless the matching
    QUANTIZE_*_TO_FP8 flag asks for the tensor to stay in FP8, dequantizes the
    tensor to ``q_dtype`` (K) or ``v_dtype`` (V).

    Returns:
        k, v, v_scale, k_scale
    """
    # ---- V ----
    v_packed_params = tl.load(
        V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
    )
    v_scale, v_shift = _extract_scale_shift(v_packed_params, IS_HIP, USE_FP32_SCALES)
    if QUANTIZE_PV_TO_FP8:
        # V is left as-is: it has to be FP8 for the PV product.
        tl.static_assert(PACKED_PER_VAL == 4, "Assert: int32 packs four FP8 values")
    else:
        v = dequantize(
            v,
            v_scale,
            None if USE_FP32_SCALES else v_shift,
            PACKED_PER_VAL,
            IS_HIP,
            USE_FP32_SCALES,
        ).to(v_dtype)

    # ---- K ----
    k_packed_params = tl.load(
        K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
    )
    k_scale, k_shift = _extract_scale_shift(k_packed_params, IS_HIP, USE_FP32_SCALES)
    if IS_HIP:
        if QUANTIZE_QK_TO_FP8:
            # NOTE(review): unlike the non-HIP branch, K is not unpacked to
            # fp8 here; only the packing factor is asserted. Confirm that the
            # caller performs the unpacking on HIP.
            tl.static_assert(PACKED_PER_VAL == 4, "Assert: int32 packs four FP8 values")
        else:
            k = dequantize_k_hip(k, k_scale, k_shift, PACKED_PER_VAL).to(q_dtype)
    else:
        if QUANTIZE_QK_TO_FP8:
            # Keep K in FP8: unpack the int32 words into 8-bit entries and
            # reinterpret them as fp8. The helper works on the transposed block.
            tl.static_assert(PACKED_PER_VAL == 4, "Assert: int32 packs four FP8 values")
            k = tl.trans(_unpack_fp8_tensor(tl.trans(k), PACKED_PER_VAL, IS_HIP))
        else:
            # dequantize() works on the transposed block, so transpose in and out.
            k_transposed = dequantize(
                tl.trans(k),
                tl.trans(k_scale),
                None if USE_FP32_SCALES else tl.trans(k_shift),
                PACKED_PER_VAL,
                IS_HIP,
                USE_FP32_SCALES,
            ).to(q_dtype)
            k = tl.trans(k_transposed)

    return k, v, v_scale, k_scale
|
|
947
|
+
|
|
948
|
+
|
|
949
|
+
@triton.jit
def _extract_scale_shift(
    scale_shift, IS_HIP: tl.constexpr, USE_FP32_SCALES: tl.constexpr
):
    """Split a packed scale/shift value into a ``(scale, shift)`` pair.

    - IS_HIP: two packed float16 values, widened to float32.
    - USE_FP32_SCALES (non-HIP): the value is reinterpreted as one float32
      scale and the shift is the constant 0.
    - otherwise: two packed float16 values.
    """
    if IS_HIP:
        return cast_uint32_to_float(scale_shift)
    elif USE_FP32_SCALES:
        fp32_scale = scale_shift.to(tl.float32, bitcast=True)
        return fp32_scale, 0
    else:
        return cast_uint32_to_half2(scale_shift)
|
|
960
|
+
|
|
961
|
+
|
|
962
|
+
@triton.jit
def _unpack_fp8_tensor(x_, PACKED_PER_VAL: tl.constexpr, IS_HIP: tl.constexpr):
    """Unpack an int32-packed FP8 K/V block into individual fp8 values.

    ``x_`` is treated as (rows, packed columns); the result has
    ``PACKED_PER_VAL`` times as many columns. The fp8 flavour is
    ``tl.float8e4b8`` on HIP and ``tl.float8e4nv`` otherwise.
    """
    tl.static_assert(PACKED_PER_VAL == 4, "Assert: int32 packs four FP8 values")

    N_ROWS: tl.constexpr = x_.shape[0]
    N_PACKED_COLS: tl.constexpr = x_.shape[1]
    # One right-shift amount per packed byte: 0, 8, 16 and 24 bits.
    bit_shifts = tl.arange(0, PACKED_PER_VAL) * 8

    # Add an unpack axis, bring every byte down to the low bits, then flatten
    # the unpack axis into the columns.
    shifted = x_[:, :, None, :] >> bit_shifts
    shifted = tl.reshape(shifted, (N_ROWS, N_PACKED_COLS * PACKED_PER_VAL))

    # Narrowing to uint8 keeps the low byte; its bits are then viewed as fp8.
    fp8_type = tl.float8e4b8 if IS_HIP else tl.float8e4nv
    unpacked = shifted.to(tl.uint8).to(fp8_type, bitcast=True)

    return unpacked
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
@triton.jit
def _process_int4_quantization(
    k,
    v,
    K_scale_shift_block_ptr,
    V_scale_shift_block_ptr,
    group_id: tl.constexpr,
    BOUNDS_CHECKS_N: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr,
    dtype: tl.constexpr,
    IS_HIP: tl.constexpr,
):
    """Dequantize already-loaded int4-packed K and V blocks to ``dtype``.

    The scale/shift block pointers are first advanced to quantization group
    ``group_id``. Returns the dequantized ``(k, v)`` pair.
    """
    # Select this group's scale/shift entries.
    k_params_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0))
    v_params_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id))

    k_packed_params = tl.load(
        k_params_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
    )
    v_packed_params = tl.load(
        v_params_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
    )
    if IS_HIP:
        # HIP: float32 scale/shift and a dedicated K routine that needs no transpose.
        k_scale, k_shift = cast_uint32_to_float(k_packed_params)
        v_scale, v_shift = cast_uint32_to_float(v_packed_params)
        v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL, IS_HIP).to(dtype)
        k = dequantize_k_hip(k, k_scale, k_shift, PACKED_PER_VAL).to(dtype)
    else:
        k_scale, k_shift = cast_uint32_to_half2(k_packed_params)
        v_scale, v_shift = cast_uint32_to_half2(v_packed_params)
        v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL, IS_HIP).to(dtype)
        # dequantize() works on the transposed K block, so transpose in and out.
        k_transposed = dequantize(
            tl.trans(k),
            tl.trans(k_scale),
            tl.trans(k_shift),
            PACKED_PER_VAL,
            IS_HIP,
        ).to(dtype)
        k = tl.trans(k_transposed)

    return k, v
|
|
1027
|
+
|
|
1028
|
+
|
|
1029
|
+
@triton.jit
def cast_uint64_to_float2(scale_shift):
    """Extract the single fp32 scale stored in the low 32 bits of a packed int64.

    FP32 scales carry no shift, so the second returned value is the constant 0.
    """
    low_word = scale_shift & 0xFFFFFFFF
    fp32_scale = low_word.to(tl.uint32).to(tl.float32, bitcast=True)
    return fp32_scale, 0
|
|
1035
|
+
|
|
1036
|
+
|
|
1037
|
+
@triton.jit
def cast_uint32_to_half2(scale_shift):
    """Extract the two float16 values packed into one int32.

    The low 16 bits hold the scale, the high 16 bits the shift.
    Returns ``(scale, shift)`` as float16.
    """
    low_half = scale_shift & 0xFFFF
    high_half = scale_shift >> 16
    # Narrow to 16 bits, then reinterpret the bit pattern as float16.
    scale = low_half.to(tl.uint16).to(tl.float16, bitcast=True)
    shift = high_half.to(tl.uint16).to(tl.float16, bitcast=True)
    return scale, shift
|
|
1045
|
+
|
|
1046
|
+
|
|
1047
|
+
@triton.jit
def cast_uint32_to_float(scale_shift):
    """Extract the two float16 values packed into one int32, widened to float32.

    The low 16 bits hold the scale, the high 16 bits the shift.
    Returns ``(scale, shift)`` as float32.
    """
    low_half = scale_shift & 0xFFFF
    high_half = scale_shift >> 16
    # Narrow to 16 bits, reinterpret as float16, then widen to float32.
    scale = low_half.to(tl.uint16).to(tl.float16, bitcast=True).to(tl.float32)
    shift = high_half.to(tl.uint16).to(tl.float16, bitcast=True).to(tl.float32)
    return scale, shift
|
|
1055
|
+
|
|
1056
|
+
|
|
1057
|
+
@triton.jit
def dequantize_k_hip(
    x_,
    scale,
    shift,
    PACKED_PER_VAL: tl.constexpr,
):
    """Dequantize a packed K block without transposing it first.

    PACKED_PER_VAL is the number of values packed into each element of x_.
    For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8;
    a value of 4 is treated as FP8.

    Returns ``value * scale + shift`` with the packed axis expanded.
    """
    # x_     : (D // PACKED_PER_VAL, BLOCK_N), per the shape reads below
    # result : (D, BLOCK_N)
    # scale / shift must broadcast against the result - presumably
    # (1, BLOCK_N); TODO confirm against the caller.
    BLOCK_N: tl.constexpr = x_.shape[1]
    BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[0]
    # Right-shift amount that brings each packed value down to the low bits.
    bit_shifts = tl.arange(0, PACKED_PER_VAL) * (32 // PACKED_PER_VAL)
    unpacked = (
        x_[:, None, :, :] >> bit_shifts[:, None]
    )  # (D // PACKED_PER_VAL, PACKED_PER_VAL, BLOCK_N)

    unpacked = tl.reshape(unpacked, (BLOCK_DMODEL_PACKED * PACKED_PER_VAL, BLOCK_N))

    if PACKED_PER_VAL == 4:
        # FP8 quantization: view the low byte as fp8 and rescale.
        fp8_type = tl.float8e4b8 if torch.version.hip is not None else tl.float8e4nv
        as_fp8 = unpacked.to(tl.uint8).to(fp8_type, bitcast=True)
        dequant = as_fp8.to(scale.dtype) * scale + shift
    else:
        # Int4 quantization.
        # Trick - instead of converting int4 to float16 we view its bits as
        # float16 (a nibble n then reads as the subnormal n * 2**-24) and
        # multiply by 32768 * 512 == 2**24 to recover n.
        nibbles = (
            (unpacked & 0xF)
            .to(tl.uint16)
            .to(tl.float16, bitcast=True)
            .to(tl.float32)
        )
        nibbles = nibbles * 32768.0
        # The remaining factor of 512 is folded into the scale.
        scale_512 = scale * 512

        dequant = nibbles * scale_512 + shift
    return dequant
|
|
1103
|
+
|
|
1104
|
+
|
|
1105
|
+
@triton.jit
def dequantize(
    x_,
    scale,
    shift,
    PACKED_PER_VAL: tl.constexpr,
    IS_HIP: tl.constexpr,
    USE_FP32_SCALES: tl.constexpr = False,
):
    """Dequantize a packed block: unpack, rescale and (optionally) shift.

    PACKED_PER_VAL is the number of values packed into each element of x_.
    For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8;
    a value of 4 is treated as FP8.

    ``shift`` is only added when USE_FP32_SCALES is False.
    """
    # x_     : (BLOCK_N, D // PACKED_PER_VAL)
    # scale  : (BLOCK_N, 1)
    # result : (BLOCK_N, D)
    BLOCK_N: tl.constexpr = x_.shape[0]
    BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
    # Right-shift amount that brings each packed value down to the low bits.
    bit_shifts = tl.arange(0, PACKED_PER_VAL) * (32 // PACKED_PER_VAL)
    unpacked = (
        x_[:, :, None, :] >> bit_shifts
    )  # (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)

    unpacked = tl.reshape(unpacked, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL))
    if PACKED_PER_VAL == 4:
        # FP8 quantization: view the low byte as fp8 and rescale.
        fp8_type = tl.float8e4b8 if torch.version.hip is not None else tl.float8e4nv
        as_fp8 = unpacked.to(tl.uint8).to(fp8_type, bitcast=True)
        dequant = as_fp8.to(scale.dtype) * scale
        if not USE_FP32_SCALES:
            # Asymmetric quantization (with a shift) is only used for FP16 scales.
            dequant += shift
    else:
        # Int4 quantization.
        # Trick - instead of converting int4 to float16 we view its bits as
        # float16 (a nibble n then reads as the subnormal n * 2**-24) and
        # multiply by 32768 * 512 == 2**24 to recover n.
        if IS_HIP:
            # Do the final math in float32 to avoid casting to bf16 on MI300.
            # There are no direct instructions for this, so it is less
            # performant on this workload.
            nibbles = (
                (unpacked & 0xF)
                .to(tl.uint16)
                .to(tl.float16, bitcast=True)
                .to(tl.float32)
            )
            nibbles = nibbles * 32768.0
        else:
            nibbles = (unpacked & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
            nibbles = (nibbles * 32768.0).to(tl.float16)
        # The remaining factor of 512 is folded into the scale.
        scale_512 = scale * 512

        dequant = nibbles * scale_512
        if not USE_FP32_SCALES:
            # Asymmetric quantization (with a shift) is only used for FP16 scales.
            dequant += shift
    return dequant
|
|
1165
|
+
|
|
1166
|
+
|
|
1167
|
+
@triton.jit
def _splitK_reduce(
    Out_splitK,  # [B, G, H, split_k, Mq, K]
    LSE_splitK,  # [B, G, H, split_k, Mq]
    Out,  # [B, H, M, K]
    LSE,  # [B, H, M]
    split_k: tl.constexpr,
    splitK_pow2: tl.constexpr,
    stride_osk_z: tl.constexpr,
    stride_osk_g: tl.constexpr,
    stride_osk_h: tl.constexpr,
    stride_osk_s: tl.constexpr,
    stride_osk_m: tl.constexpr,
    stride_osk_k: tl.constexpr,
    stride_lsek_z: tl.constexpr,
    stride_lsek_g: tl.constexpr,
    stride_lsek_h: tl.constexpr,
    stride_lsek_s: tl.constexpr,
    stride_lsek_m: tl.constexpr,
    stride_oz: tl.constexpr,
    stride_og: tl.constexpr,
    stride_oh: tl.constexpr,
    stride_om: tl.constexpr,
    stride_ok: tl.constexpr,
    stride_lse_z: tl.constexpr,
    stride_lse_g: tl.constexpr,
    stride_lse_h: tl.constexpr,
    stride_lse_m: tl.constexpr,
    head_dim: tl.constexpr,
    head_dim_pow_2: tl.constexpr,
    H: tl.constexpr,
    G: tl.constexpr,
    WRITE_LSE: tl.constexpr,
):
    """Merge per-split attention outputs and LSEs into the final result.

    Each program handles one query row of one (batch, group, head). With
    ``w_s = exp(lse_s - max_s lse_s)`` the merged values are::

        out = sum_s(out_s * w_s) / sum_s(w_s)
        lse = max_s lse_s + log(sum_s(w_s))

    Rows whose every split has ``lse == -inf`` produce a zero output and an
    ``-inf`` LSE. The split axis is padded to ``splitK_pow2`` and the head
    dimension to ``head_dim_pow_2``; padded entries are masked out.

    ``stride_osk_k`` and ``stride_ok`` are accepted but not used: consecutive
    head-dim elements are addressed with a stride of 1.
    """
    # grid = (M, B * G * H, 1)
    off_m = tl.program_id(0).to(tl.int64)
    off_zhg = tl.program_id(1).to(tl.int64)
    off_z = off_zhg // (H * G)
    off_h = (off_zhg // G) % H
    off_g = off_zhg % G

    head_dim_mask = tl.arange(0, head_dim_pow_2) < head_dim

    # (splitK_pow2, head_dim_pow_2) pointers: one row per split.
    Out_splitK_ptr = (
        Out_splitK
        + stride_osk_z * off_z
        + stride_osk_g * off_g
        + stride_osk_h * off_h
        + stride_osk_m * off_m
        + tl.arange(0, head_dim_pow_2)[None, :]
        + stride_osk_s * tl.arange(0, splitK_pow2)[:, None]
    )

    # (splitK_pow2,) pointers: one LSE per split.
    LSE_splitK_ptr0 = (
        LSE_splitK
        + stride_lsek_z * off_z
        + stride_lsek_g * off_g
        + stride_lsek_h * off_h
        + stride_lsek_m * off_m
        + stride_lsek_s * tl.arange(0, splitK_pow2)
    )

    if splitK_pow2 > split_k:
        # Padded splits get lse = -inf (zero weight) and out = 0.
        mask_1d = tl.arange(0, splitK_pow2) < split_k
        mask_2d = mask_1d[:, None] & head_dim_mask[None, :]
        lse_splitk = tl.load(
            LSE_splitK_ptr0, mask=mask_1d, other=float("-inf")
        )  # (split_k,)
        out_splitk = tl.load(
            Out_splitK_ptr, mask=mask_2d, other=0
        )  # (split_k, head_dim_pow_2)
    else:
        lse_splitk = tl.load(LSE_splitK_ptr0)
        # The head dimension may still be padded even when the split axis is
        # not, so the load must be masked to stay inside each row.
        out_splitk = tl.load(Out_splitK_ptr, mask=head_dim_mask[None, :], other=0)
    lse_max = tl.max(lse_splitk)

    # 1.44269504 == log2(e), so exp2(x * 1.44269504) == exp(x).
    sumexp_normalized_splitk = tl.math.exp2(
        (lse_splitk - lse_max).to(tl.float32) * 1.44269504
    )  # (split_k,)
    sumexp_normalized = tl.sum(sumexp_normalized_splitk, axis=0)  # scalar
    # Compute numerator
    numerator_normalized = tl.sum(
        out_splitk * sumexp_normalized_splitk[:, None], axis=0
    )
    acc = numerator_normalized / sumexp_normalized
    # All splits empty: emit zeros instead of the NaNs from the division above.
    acc = tl.where(lse_max == float("-inf"), 0.0, acc)

    Out_ptr = (
        Out
        + stride_oz * off_z
        + stride_oh * off_h
        + stride_og * off_g
        + stride_om * off_m
        + tl.arange(0, head_dim_pow_2)
    )
    if acc.dtype is tl.float64 and Out.dtype.element_ty is not tl.float64:
        # must avoid direct cast f64->f16
        acc = acc.to(tl.float32)
    tl.store(Out_ptr, acc, mask=head_dim_mask)

    if WRITE_LSE:
        l_ptrs = (
            LSE
            + off_z * stride_lse_z
            + off_g * stride_lse_g
            + off_h * stride_lse_h
            + off_m * stride_lse_m
        )
        # log2(x) / log2(e) == ln(x)
        to_store = lse_max + tl.math.log2(sumexp_normalized) / 1.44269504
        to_store = tl.where(lse_max == float("-inf"), lse_max, to_store)
        tl.store(l_ptrs, to_store)
|
|
1281
|
+
|
|
1282
|
+
|
|
1283
|
+
@triton.jit
def _splitK_reduce_varargs(
    Out_splitK: "VAR_ARGS_ARRAY",  # list of [B, G, H, Mq, K];
    LSE_splitK: "VAR_ARGS_ARRAY",  # list of [B, G, H, Mq]
    Out,  # [B, G, H, M, K]
    LSE,  # [B, G, H, M]
    stride_osk_z: "VAR_ARGS_ARRAY",
    stride_osk_g: "VAR_ARGS_ARRAY",
    stride_osk_h: "VAR_ARGS_ARRAY",
    stride_osk_m: "VAR_ARGS_ARRAY",
    stride_osk_k: "VAR_ARGS_ARRAY",
    stride_lsek_z: "VAR_ARGS_ARRAY",
    stride_lsek_g: "VAR_ARGS_ARRAY",
    stride_lsek_h: "VAR_ARGS_ARRAY",
    stride_lsek_m: "VAR_ARGS_ARRAY",
    stride_oz,
    stride_og,
    stride_oh,
    stride_om,
    stride_ok,
    stride_lse_z,
    stride_lse_g,
    stride_lse_h,
    stride_lse_m,
    head_dim: tl.constexpr,
    head_dim_pow_2: tl.constexpr,
    H: tl.constexpr,
    G: tl.constexpr,
    WRITE_LSE: tl.constexpr,
):
    """
    Merge per-chunk attention outputs and LSEs into the final result.

    Unlike _splitK_reduce, which takes attention and LSE as one stacked tensor
    each, this variant receives them as lists of tensors, one per chunk.
    Rows whose every chunk has ``lse == -inf`` produce a zero output and an
    ``-inf`` LSE.
    """
    # grid = (M, B * G * H, 1)
    off_m = tl.program_id(0).to(tl.int64)
    off_zhg = tl.program_id(1).to(tl.int64)
    off_z = off_zhg // (H * G)
    off_h = (off_zhg // G) % H
    off_g = off_zhg % G
    head_dim_mask = tl.arange(0, head_dim_pow_2) < head_dim

    # Per-chunk element offsets of this program's output row.
    out_splitk_offset: "VAR_ARGS_ARRAY"  # noqa: F821
    for i in range(len(Out_splitK)):
        out_splitk_offset[i] = (  # noqa: F821
            stride_osk_z[i] * off_z  # type: ignore # noqa: F821
            + stride_osk_g[i] * off_g
            + stride_osk_h[i] * off_h
            + stride_osk_m[i] * off_m
            + tl.arange(0, head_dim_pow_2)
        )
    # Per-chunk element offsets of this program's LSE entry.
    lse_splitk_offset: "VAR_ARGS_ARRAY"  # noqa: F821
    for i in range(len(Out_splitK)):
        lse_splitk_offset[i] = (  # noqa: F821
            stride_lsek_z[i] * off_z  # type: ignore # noqa: F821
            + stride_lsek_g[i] * off_g
            + stride_lsek_h[i] * off_h
            + stride_lsek_m[i] * off_m
        )

    # First pass: running maximum of the chunk LSEs, used to keep the
    # exponentials below from overflowing.
    lse_max = float("-inf")
    for split_k_idx in range(len(Out_splitK)):  # type: ignore # noqa: F821
        chunk_lse_ptr = LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx]  # type: ignore # noqa: F821
        chunk_lse = tl.load(chunk_lse_ptr)
        lse_max = tl.maximum(lse_max, chunk_lse)

    weight_sum = 0.0
    weighted_out = tl.zeros([head_dim_pow_2], dtype=tl.float32)

    # Second pass: accumulate the normalized weights and the weighted outputs.
    for split_k_idx in range(len(Out_splitK)):  # type: ignore # noqa: F821
        chunk_out = tl.load(
            Out_splitK[split_k_idx] + out_splitk_offset[split_k_idx],  # type: ignore # noqa: F821
            mask=head_dim_mask,
        )
        chunk_lse = tl.load(LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx])  # type: ignore # noqa: F821
        # Denominator term; 1.44269504 == log2(e), so this is exp(lse - lse_max).
        chunk_weight = tl.math.exp2(
            (chunk_lse - lse_max).to(tl.float32) * 1.44269504
        )
        weight_sum += chunk_weight

        # Numerator term
        weighted_out += chunk_out * chunk_weight

    merged = weighted_out / weight_sum
    # All chunks empty: emit zeros instead of the NaNs from the division above.
    merged = tl.where(lse_max == float("-inf"), 0.0, merged)

    Out_ptr = (
        Out
        + stride_oz * off_z
        + stride_oh * off_h
        + stride_og * off_g
        + stride_om * off_m
        + tl.arange(0, head_dim_pow_2)
    )
    if merged.dtype is tl.float64 and Out.dtype.element_ty is not tl.float64:
        # must avoid direct cast f64->f16
        merged = merged.to(tl.float32)
    tl.store(Out_ptr, merged, mask=head_dim_mask)

    if WRITE_LSE:
        lse_out_ptr = (
            LSE
            + off_z * stride_lse_z
            + off_g * stride_lse_g
            + off_h * stride_lse_h
            + off_m * stride_lse_m
        )
        # log2(x) / log2(e) == ln(x)
        merged_lse = lse_max + tl.math.log2(weight_sum) / 1.44269504
        merged_lse = tl.where(lse_max == float("-inf"), lse_max, merged_lse)
        tl.store(lse_out_ptr, merged_lse)
|
|
1394
|
+
|
|
1395
|
+
|
|
1396
|
+
@triton.jit
def _splitK_reduce_varargs_backward(
    Out_splitK: "VAR_ARGS_ARRAY",  # list of [B, G, H, Mq, K];
    LSE_splitK: "VAR_ARGS_ARRAY",  # list of [B, G, H, Mq]
    Dout_splitK: "VAR_ARGS_ARRAY",  # gradients - same shape as the inputs themselves
    DLSE_splitK: "VAR_ARGS_ARRAY",
    Out,  # [B, G, H, M, K]
    LSE,  # [B, G, H, M]
    DOut,
    DLSE,
    # strides of chunked inputs: attention and LSE
    stride_osk_z: "VAR_ARGS_ARRAY",
    stride_osk_g: "VAR_ARGS_ARRAY",
    stride_osk_h: "VAR_ARGS_ARRAY",
    stride_osk_m: "VAR_ARGS_ARRAY",
    stride_osk_k: "VAR_ARGS_ARRAY",
    stride_lsek_z: "VAR_ARGS_ARRAY",
    stride_lsek_g: "VAR_ARGS_ARRAY",
    stride_lsek_h: "VAR_ARGS_ARRAY",
    stride_lsek_m: "VAR_ARGS_ARRAY",
    # strides of merged outputs: attention and LSE
    stride_oz,
    stride_og,
    stride_oh,
    stride_om,
    stride_ok,
    stride_lse_z,
    stride_lse_g,
    stride_lse_h,
    stride_lse_m,
    # strides of gradients
    stride_doz,
    stride_dog,
    stride_doh,
    stride_dom,
    stride_dok,
    stride_dlse_z,
    stride_dlse_g,
    stride_dlse_h,
    stride_dlse_m,
    BLOCK_SIZE: tl.constexpr,
    H: tl.constexpr,
    G: tl.constexpr,
):
    """
    Backward for _splitK_reduce_varargs. Similar to forward, it takes
    attention and LSE of chunks as lists of tensors,
    and outputs the corresponding gradients in the same format.

    Inputs read:
        Out_splitK / LSE_splitK: per-chunk attention outputs and LSEs.
        Out / LSE: merged attention output and LSE.
        DOut / DLSE: incoming gradients w.r.t. Out and LSE.

    Outputs written:
        Dout_splitK / DLSE_splitK: gradients w.r.t. each chunk's attention
        output and LSE. They are addressed with the same offsets as
        Out_splitK / LSE_splitK, i.e. the chunk strides are reused for them.

    With w_i = exp(lse_i - lse), the stored values are:
        dX/dattn_i = w_i * dattn
        dX/dlse_i  = w_i * dlse + sum_k((attn_i - attn) * w_i * dattn)

    Each program handles one (off_m, off_z, off_h, off_g) row. The last
    dimension is addressed as tl.arange(0, BLOCK_SIZE) with unit stride and
    without a mask; stride_osk_k, stride_ok and stride_dok are accepted but
    not used.
    NOTE(review): this presumably requires BLOCK_SIZE to equal the head
    dimension and the last dimension to be contiguous - confirm with caller.

    If every chunk has LSE == -inf for a row, the stored gradients are zero,
    mirroring the lse_max == -inf guard of the forward kernel. Without the
    guard the weights evaluate to exp(-inf - (-inf)) = NaN.
    """

    # grid = (M, B * G * H, 1)
    off_m = tl.program_id(0).to(tl.int64)
    off_zhg = tl.program_id(1).to(tl.int64)
    off_z = off_zhg // (H * G)
    off_h = (off_zhg // G) % H
    off_g = off_zhg % G

    # Compute offsets inside each attention/LSE chunk.
    # Note that each chunk can have different strides, so offsets can also be different.
    out_splitk_offset: "VAR_ARGS_ARRAY"  # noqa: F821
    for i in range(len(Out_splitK)):
        out_splitk_offset[i] = (  # type: ignore # noqa: F821
            stride_osk_z[i] * off_z
            + stride_osk_g[i] * off_g
            + stride_osk_h[i] * off_h
            + stride_osk_m[i] * off_m
            + tl.arange(0, BLOCK_SIZE)
        )
    lse_splitk_offset: "VAR_ARGS_ARRAY"  # noqa: F821
    for i in range(len(Out_splitK)):
        lse_splitk_offset[i] = (  # type: ignore # noqa: F821
            stride_lsek_z[i] * off_z
            + stride_lsek_g[i] * off_g
            + stride_lsek_h[i] * off_h
            + stride_lsek_m[i] * off_m
        )

    # Running maximum of the chunk LSEs; subtracted below before exponentiating.
    lse_max = float("-inf")
    for split_k_idx in range(len(Out_splitK)):  # type: ignore # noqa: F821
        LSE_splitK_ptr = LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx]  # type: ignore # noqa: F821
        lse_splitk = tl.load(LSE_splitK_ptr)
        lse_max = tl.maximum(lse_max, lse_splitk)

    # Load attention and the corresponding gradient
    offset_out = (
        stride_oz * off_z
        + stride_oh * off_h
        + stride_og * off_g
        + stride_om * off_m
        + tl.arange(0, BLOCK_SIZE)
    )
    offset_dout = (
        stride_doz * off_z
        + stride_doh * off_h
        + stride_dog * off_g
        + stride_dom * off_m
        + tl.arange(0, BLOCK_SIZE)
    )
    out = tl.load(Out + offset_out)
    dattn = tl.load(DOut + offset_dout)

    # Load LSE and the corresponding gradient
    offset_lse = (
        stride_lse_z * off_z
        + stride_lse_h * off_h
        + stride_lse_g * off_g
        + stride_lse_m * off_m
    )
    offset_dlse = (
        stride_dlse_z * off_z
        + stride_dlse_h * off_h
        + stride_dlse_g * off_g
        + stride_dlse_m * off_m
    )
    lse = tl.load(LSE + offset_lse)
    dlse = tl.load(DLSE + offset_dlse)

    for split_k_idx in range(len(Out_splitK)):  # type: ignore # noqa: F821
        # Load attention and LSE of chunks
        out_splitk = tl.load(Out_splitK[split_k_idx] + out_splitk_offset[split_k_idx])  # type: ignore # noqa: F821
        lse_splitk = tl.load(LSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx])  # type: ignore # noqa: F821

        # Pointers to save gradients of attention and LSE of chunks
        dout_splitk_ptr = Dout_splitK[split_k_idx] + out_splitk_offset[split_k_idx]  # type: ignore # noqa: F821
        dlse_splitk_ptr = DLSE_splitK[split_k_idx] + lse_splitk_offset[split_k_idx]  # type: ignore # noqa: F821

        # dX/dattn_i = dX/dattn * dattn/dattn_i + dX/dlse * dlse/dattn_i, and dlse/dattn_i == 0
        dattn_dattn_i = tl.exp(lse_splitk - lse_max) / tl.exp(lse - lse_max)
        dX_dattn_i = dattn_dattn_i * dattn
        # When all chunks are empty (lse_max == -inf) the weight above is NaN;
        # store a zero gradient instead, matching the forward's handling.
        dX_dattn_i = tl.where(lse_max == float("-inf"), 0.0, dX_dattn_i)
        tl.store(dout_splitk_ptr, dX_dattn_i)

        dattn_dlse_i = (out_splitk - out) * dattn_dattn_i

        # dX/dlse_i = dX/dattn * dattn/dlse_i + dX/dlse * dlse/dlse_i
        dlse_dlse_i = dattn_dattn_i
        dX_dlse_i = dlse_dlse_i * dlse + tl.sum(
            dattn_dlse_i * dattn
        )  # Sum is over the hidden dimension
        # Same all-empty guard as for the attention gradient.
        dX_dlse_i = tl.where(lse_max == float("-inf"), 0.0, dX_dlse_i)
        tl.store(dlse_splitk_ptr, dX_dlse_i)
|