mslk-cuda-nightly 2026.1.19__cp310-cp310-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mslk/__init__.py +56 -0
- mslk/attention/__init__.py +7 -0
- mslk/attention/cutlass_blackwell_fmha/__init__.py +30 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_custom_op.py +332 -0
- mslk/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py +533 -0
- mslk/attention/flash_attn/__init__.py +22 -0
- mslk/attention/flash_attn/ampere_helpers.py +104 -0
- mslk/attention/flash_attn/barrier.py +72 -0
- mslk/attention/flash_attn/benchmark.py +269 -0
- mslk/attention/flash_attn/blackwell_helpers.py +754 -0
- mslk/attention/flash_attn/block_info.py +109 -0
- mslk/attention/flash_attn/block_sparse_utils.py +1452 -0
- mslk/attention/flash_attn/block_sparsity.py +219 -0
- mslk/attention/flash_attn/compute_block_sparsity.py +378 -0
- mslk/attention/flash_attn/copy_utils.py +341 -0
- mslk/attention/flash_attn/cute_dsl_utils.py +135 -0
- mslk/attention/flash_attn/fast_math.py +22 -0
- mslk/attention/flash_attn/flash_bwd.py +1262 -0
- mslk/attention/flash_attn/flash_bwd_postprocess.py +464 -0
- mslk/attention/flash_attn/flash_bwd_preprocess.py +366 -0
- mslk/attention/flash_attn/flash_bwd_sm100.py +2951 -0
- mslk/attention/flash_attn/flash_bwd_sm90.py +1703 -0
- mslk/attention/flash_attn/flash_fwd.py +2471 -0
- mslk/attention/flash_attn/flash_fwd_combine.py +705 -0
- mslk/attention/flash_attn/flash_fwd_sm100.py +2727 -0
- mslk/attention/flash_attn/hopper_helpers.py +102 -0
- mslk/attention/flash_attn/interface.py +1771 -0
- mslk/attention/flash_attn/mask.py +610 -0
- mslk/attention/flash_attn/mma_sm100_desc.py +292 -0
- mslk/attention/flash_attn/named_barrier.py +32 -0
- mslk/attention/flash_attn/pack_gqa.py +165 -0
- mslk/attention/flash_attn/paged_kv.py +176 -0
- mslk/attention/flash_attn/pipeline.py +273 -0
- mslk/attention/flash_attn/seqlen_info.py +139 -0
- mslk/attention/flash_attn/softmax.py +583 -0
- mslk/attention/flash_attn/testing.py +424 -0
- mslk/attention/flash_attn/tile_scheduler.py +720 -0
- mslk/attention/flash_attn/utils.py +860 -0
- mslk/attention/fmha/__init__.py +967 -0
- mslk/attention/fmha/_triton/__init__.py +6 -0
- mslk/attention/fmha/_triton/available.py +50 -0
- mslk/attention/fmha/_triton/splitk_kernels.py +1534 -0
- mslk/attention/fmha/_triton/vararg_kernel.py +262 -0
- mslk/attention/fmha/attn_bias.py +2186 -0
- mslk/attention/fmha/attn_bias_utils.py +536 -0
- mslk/attention/fmha/ck.py +508 -0
- mslk/attention/fmha/ck_decoder.py +141 -0
- mslk/attention/fmha/ck_splitk.py +204 -0
- mslk/attention/fmha/common.py +598 -0
- mslk/attention/fmha/cutlass.py +461 -0
- mslk/attention/fmha/cutlass_blackwell.py +560 -0
- mslk/attention/fmha/dispatch.py +224 -0
- mslk/attention/fmha/flash.py +862 -0
- mslk/attention/fmha/flash3.py +858 -0
- mslk/attention/fmha/flash_mtia.py +245 -0
- mslk/attention/fmha/merge_training.py +192 -0
- mslk/attention/fmha/split_blocks_fairinternal.py +329 -0
- mslk/attention/fmha/torch_attention_compat.py +154 -0
- mslk/attention/fmha/tree_attention.py +718 -0
- mslk/attention/fmha/triton_splitk.py +1378 -0
- mslk/attention/fmha/unbind.py +130 -0
- mslk/attention/fmha/utils/__init__.py +6 -0
- mslk/attention/fmha/utils/bench.py +74 -0
- mslk/attention/fmha/utils/cpp_lib.py +148 -0
- mslk/attention/fmha/utils/op_common.py +65 -0
- mslk/attention/gqa_attn_splitk/__init__.py +11 -0
- mslk/bench/comm/__init__.py +7 -0
- mslk/bench/comm/comm_bench.py +255 -0
- mslk/bench/common/__init__.py +5 -0
- mslk/bench/common/utils.py +148 -0
- mslk/bench/conv/__init__.py +7 -0
- mslk/bench/conv/conv_bench.py +551 -0
- mslk/bench/conv/conv_ops.py +213 -0
- mslk/bench/gemm/__init__.py +7 -0
- mslk/bench/gemm/gemm_bench.py +859 -0
- mslk/bench/gemm/gemm_ops.py +3342 -0
- mslk/bench/gemm/grouped_gemm_bias_scale_benchmark.py +177 -0
- mslk/bench/moe/__init__.py +7 -0
- mslk/bench/moe/gather_scatter_bench.py +356 -0
- mslk/bench/quantize/quantize_bench.py +345 -0
- mslk/bench/quantize/quantize_ops.py +266 -0
- mslk/comm/__init__.py +11 -0
- mslk/conv/__init__.py +11 -0
- mslk/gemm/__init__.py +18 -0
- mslk/gemm/triton/__init__.py +7 -0
- mslk/gemm/triton/fp8_gemm.py +2702 -0
- mslk/gemm/triton/grouped_gemm.py +1132 -0
- mslk/gemm/triton/matmul_perf_model.py +237 -0
- mslk/gemm/triton/utils.py +128 -0
- mslk/kv_cache/__init__.py +11 -0
- mslk/moe/__init__.py +26 -0
- mslk/moe/activation.py +291 -0
- mslk/moe/gather_scatter.py +739 -0
- mslk/moe/layers.py +1240 -0
- mslk/moe/shuffling.py +421 -0
- mslk/mslk.so +0 -0
- mslk/quantize/__init__.py +11 -0
- mslk/quantize/shuffle.py +306 -0
- mslk/quantize/triton/__init__.py +7 -0
- mslk/quantize/triton/fp4_quantize.py +5942 -0
- mslk/quantize/triton/fp8_quantize.py +1902 -0
- mslk/testing/__init__.py +7 -0
- mslk/testing/attributes.py +60 -0
- mslk/testing/rocm.py +91 -0
- mslk/utils/__init__.py +7 -0
- mslk/utils/torch/__init__.py +7 -0
- mslk/utils/torch/library.py +150 -0
- mslk/utils/triton/__init__.py +7 -0
- mslk/utils/triton/fp8_utils.py +72 -0
- mslk/utils/triton/utils.py +128 -0
- mslk/version.py +11 -0
- mslk_cuda_nightly-2026.1.19.dist-info/METADATA +102 -0
- mslk_cuda_nightly-2026.1.19.dist-info/RECORD +116 -0
- mslk_cuda_nightly-2026.1.19.dist-info/WHEEL +5 -0
- mslk_cuda_nightly-2026.1.19.dist-info/licenses/LICENSE +30 -0
- mslk_cuda_nightly-2026.1.19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2951 @@
|
|
|
1
|
+
# @nolint # fbcode
|
|
2
|
+
# Copyright (c) 2025, Ted Zadouri, Markus Hoehnerbach, Jay Shah, Tri Dao.
|
|
3
|
+
import math
|
|
4
|
+
from typing import Callable, Optional
|
|
5
|
+
from functools import partial
|
|
6
|
+
|
|
7
|
+
import cuda.bindings.driver as cuda
|
|
8
|
+
|
|
9
|
+
import cutlass
|
|
10
|
+
import cutlass.cute as cute
|
|
11
|
+
from cutlass.cute import FastDivmodDivisor
|
|
12
|
+
from cutlass import Float32, Int32, const_expr
|
|
13
|
+
from cutlass.utils import LayoutEnum
|
|
14
|
+
from cutlass.cute.nvgpu import cpasync, tcgen05
|
|
15
|
+
import cutlass.utils.blackwell_helpers as sm100_utils_basic
|
|
16
|
+
from cutlass.pipeline import PipelineAsync, PipelineConsumer
|
|
17
|
+
|
|
18
|
+
from mslk.attention.flash_attn import utils
|
|
19
|
+
from mslk.attention.flash_attn import copy_utils
|
|
20
|
+
from mslk.attention.flash_attn import pipeline
|
|
21
|
+
from mslk.attention.flash_attn.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa
|
|
22
|
+
from mslk.attention.flash_attn.mask import AttentionMask
|
|
23
|
+
from mslk.attention.flash_attn.seqlen_info import SeqlenInfoQK
|
|
24
|
+
from mslk.attention.flash_attn.block_info import BlockInfo
|
|
25
|
+
from mslk.attention.flash_attn.tile_scheduler import (
|
|
26
|
+
TileSchedulerArguments,
|
|
27
|
+
SingleTileScheduler,
|
|
28
|
+
SingleTileLPTBwdScheduler, # noqa
|
|
29
|
+
SingleTileVarlenScheduler,
|
|
30
|
+
ParamsBase,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
from mslk.attention.flash_attn import barrier
|
|
34
|
+
from mslk.attention.flash_attn.named_barrier import NamedBarrierBwdSm100
|
|
35
|
+
from mslk.attention.flash_attn.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner
|
|
36
|
+
from mslk.attention.flash_attn.block_sparsity import BlockSparseTensors
|
|
37
|
+
from mslk.attention.flash_attn.block_sparse_utils import (
|
|
38
|
+
get_total_q_block_count_bwd,
|
|
39
|
+
get_block_sparse_iteration_info_bwd,
|
|
40
|
+
get_m_block_from_iter_bwd,
|
|
41
|
+
produce_block_sparse_q_loads_bwd_sm100,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class FlashAttentionBackwardSm100:
|
|
46
|
+
arch = 100
|
|
47
|
+
|
|
48
|
+
def __init__(
|
|
49
|
+
self,
|
|
50
|
+
head_dim: int,
|
|
51
|
+
head_dim_v: Optional[int] = None,
|
|
52
|
+
is_causal: bool = False,
|
|
53
|
+
is_local: bool = False,
|
|
54
|
+
qhead_per_kvhead: cutlass.Constexpr[int] = 1,
|
|
55
|
+
tile_m: int = 128,
|
|
56
|
+
tile_n: int = 128,
|
|
57
|
+
is_persistent: bool = False,
|
|
58
|
+
deterministic: bool = False,
|
|
59
|
+
cluster_size: int = 1,
|
|
60
|
+
score_mod: cutlass.Constexpr | None = None,
|
|
61
|
+
score_mod_bwd: cutlass.Constexpr | None = None,
|
|
62
|
+
mask_mod: cutlass.Constexpr | None = None,
|
|
63
|
+
has_aux_tensors: cutlass.Constexpr = False,
|
|
64
|
+
subtile_factor: cutlass.Constexpr[int] = 1,
|
|
65
|
+
):
|
|
66
|
+
# padding head_dim to a multiple of 16 as k_block_size
|
|
67
|
+
hdim_multiple_of = 16
|
|
68
|
+
self.tile_hdim = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
|
|
69
|
+
head_dim_v = head_dim_v if head_dim_v is not None else head_dim
|
|
70
|
+
self.same_hdim_kv = head_dim == head_dim_v
|
|
71
|
+
assert head_dim == head_dim_v, "head_dim and head_dim_v must be the same for now"
|
|
72
|
+
self.tile_hdimv = int(math.ceil(head_dim_v / hdim_multiple_of) * hdim_multiple_of)
|
|
73
|
+
assert self.tile_hdim == self.tile_hdimv, (
|
|
74
|
+
"tile_hdim and tile_hdimv must be the same for now"
|
|
75
|
+
)
|
|
76
|
+
self.check_hdim_oob = head_dim != self.tile_hdim
|
|
77
|
+
self.check_hdim_v_oob = head_dim_v != self.tile_hdimv
|
|
78
|
+
|
|
79
|
+
self.tile_m = tile_m
|
|
80
|
+
self.tile_n = tile_n
|
|
81
|
+
|
|
82
|
+
# CTA tiler
|
|
83
|
+
self.cta_tiler = (tile_n, tile_m, self.tile_hdim)
|
|
84
|
+
# S = K @ Q.T
|
|
85
|
+
self.mma_tiler_kq = (tile_n, tile_m, self.tile_hdim)
|
|
86
|
+
# dP = V @ dO.T
|
|
87
|
+
self.mma_tiler_vdo = (tile_n, tile_m, self.tile_hdimv)
|
|
88
|
+
# dV = P.T @ dO
|
|
89
|
+
self.mma_tiler_pdo = (tile_n, self.tile_hdimv, tile_m)
|
|
90
|
+
# dK = dS.T @ Q (N, M) (M, D)
|
|
91
|
+
self.mma_tiler_dsq = (tile_n, self.tile_hdimv, tile_m)
|
|
92
|
+
# dQ = dS @ K
|
|
93
|
+
self.mma_tiler_dsk = (tile_m, self.tile_hdimv, tile_n)
|
|
94
|
+
|
|
95
|
+
self.acc_dtype = Float32
|
|
96
|
+
|
|
97
|
+
assert cluster_size in (1, 2), "Only cluster_size=1 or 2 is supported"
|
|
98
|
+
self.cluster_shape_mn = (cluster_size, 1)
|
|
99
|
+
self.is_persistent = is_persistent
|
|
100
|
+
self.is_causal = is_causal
|
|
101
|
+
self.is_local = is_local
|
|
102
|
+
self.qhead_per_kvhead = qhead_per_kvhead
|
|
103
|
+
self.pack_gqa = False
|
|
104
|
+
self.deterministic = deterministic
|
|
105
|
+
|
|
106
|
+
# Score mod and mask mod support
|
|
107
|
+
self.score_mod = score_mod
|
|
108
|
+
self.score_mod_bwd = score_mod_bwd
|
|
109
|
+
self.mask_mod = mask_mod
|
|
110
|
+
self.has_aux_tensors = has_aux_tensors
|
|
111
|
+
self.subtile_factor = subtile_factor
|
|
112
|
+
# For score_mod, use vec_size=1 (like forward) to handle per-element indices
|
|
113
|
+
if cutlass.const_expr(has_aux_tensors):
|
|
114
|
+
self.vec_size: cutlass.Constexpr = 1
|
|
115
|
+
else:
|
|
116
|
+
self.vec_size: cutlass.Constexpr = 4
|
|
117
|
+
self.qk_acc_dtype = Float32
|
|
118
|
+
|
|
119
|
+
# Speed optimizations, does not affect correctness
|
|
120
|
+
self.shuffle_LSE = False
|
|
121
|
+
self.shuffle_dPsum = False
|
|
122
|
+
self.use_smem_dS_for_mma_dK = self.deterministic and self.is_causal
|
|
123
|
+
|
|
124
|
+
self.reduce_warp_ids = (0, 1, 2, 3)
|
|
125
|
+
self.compute_warp_ids = (4, 5, 6, 7, 8, 9, 10, 11)
|
|
126
|
+
self.mma_warp_id = 12
|
|
127
|
+
self.load_warp_id = 13
|
|
128
|
+
self.epi_warp_id = 14
|
|
129
|
+
self.empty_warp_id = 15
|
|
130
|
+
|
|
131
|
+
# 16 warps -> 512 threads
|
|
132
|
+
self.threads_per_cta = cute.arch.WARP_SIZE * len(
|
|
133
|
+
(
|
|
134
|
+
*self.reduce_warp_ids,
|
|
135
|
+
*self.compute_warp_ids,
|
|
136
|
+
self.mma_warp_id,
|
|
137
|
+
self.load_warp_id,
|
|
138
|
+
self.epi_warp_id,
|
|
139
|
+
self.empty_warp_id,
|
|
140
|
+
)
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# NamedBarrier
|
|
144
|
+
self.compute_sync_barrier = cutlass.pipeline.NamedBarrier(
|
|
145
|
+
barrier_id=int(NamedBarrierBwdSm100.Compute),
|
|
146
|
+
num_threads=len(self.compute_warp_ids) * cute.arch.WARP_SIZE,
|
|
147
|
+
)
|
|
148
|
+
# self.epilogue_sync_barrier = pipeline.NamedBarrier(
|
|
149
|
+
# barrier_id=2,
|
|
150
|
+
# num_threads=self.num_compute_warps * self.threads_per_warp,
|
|
151
|
+
# )
|
|
152
|
+
self.reduce_sync_barrier = cutlass.pipeline.NamedBarrier(
|
|
153
|
+
barrier_id=int(NamedBarrierBwdSm100.dQaccReduce),
|
|
154
|
+
num_threads=len(self.reduce_warp_ids) * cute.arch.WARP_SIZE,
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
# TMEM setup
|
|
158
|
+
SM100_TMEM_CAPACITY_COLUMNS = 512
|
|
159
|
+
self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
|
|
160
|
+
|
|
161
|
+
# self.tmem_dK_offset = 0
|
|
162
|
+
# self.tmem_dV_offset = self.tmem_dK_offset + self.tile_hdim
|
|
163
|
+
# self.tmem_dQ_offset = self.tmem_dV_offset + self.tile_hdimv
|
|
164
|
+
# self.tmem_dP_offset = self.tmem_dQ_offset # overlap with dQ
|
|
165
|
+
# self.tmem_S_offset = self.tmem_dQ_offset + max(self.tile_m, self.tile_hdim)
|
|
166
|
+
# self.tmem_P_offset = self.tmem_S_offset # overlap with S
|
|
167
|
+
# self.tmem_total = self.tmem_S_offset + self.tile_n
|
|
168
|
+
# assert self.tmem_total <= self.tmem_alloc_cols
|
|
169
|
+
|
|
170
|
+
self.tmem_S_offset = 0
|
|
171
|
+
self.tmem_P_offset = 0 # overlap with S
|
|
172
|
+
self.tmem_dV_offset = self.tmem_S_offset + self.tile_n
|
|
173
|
+
self.tmem_dP_offset = self.tmem_dV_offset + self.tile_hdimv
|
|
174
|
+
self.tmem_dQ_offset = self.tmem_dP_offset # overlap with dP
|
|
175
|
+
self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m
|
|
176
|
+
self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP
|
|
177
|
+
|
|
178
|
+
if (not is_causal and not is_local) or deterministic:
|
|
179
|
+
self.num_regs_reduce = 152
|
|
180
|
+
self.num_regs_compute = 136
|
|
181
|
+
else:
|
|
182
|
+
self.num_regs_reduce = 136
|
|
183
|
+
self.num_regs_compute = 144
|
|
184
|
+
self.num_regs_other = 96 - 8
|
|
185
|
+
self.num_regs_empty = 24
|
|
186
|
+
assert self.num_regs_reduce + self.num_regs_compute * 2 + self.num_regs_other <= 512
|
|
187
|
+
|
|
188
|
+
self.buffer_align_bytes = 1024
|
|
189
|
+
|
|
190
|
+
def _setup_attributes(self):
|
|
191
|
+
self.Q_stage = 2
|
|
192
|
+
self.dO_stage = 1
|
|
193
|
+
# LSE_stage = Q_stage and dPsum_stage = dO_stage
|
|
194
|
+
# self.sdKVaccum_stage = 2
|
|
195
|
+
# number of tma reduce adds per dQacc mma
|
|
196
|
+
self.dQ_reduce_ncol = 32
|
|
197
|
+
self.sdQaccum_stage = 64 // self.dQ_reduce_ncol
|
|
198
|
+
assert self.tile_hdim % self.dQ_reduce_ncol == 0
|
|
199
|
+
self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol
|
|
200
|
+
self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1
|
|
201
|
+
# number of tma reduce adds for dKacc and dVacc epilogue
|
|
202
|
+
self.dK_reduce_ncol = 32
|
|
203
|
+
|
|
204
|
+
def _get_tiled_mma(self):
|
|
205
|
+
cta_group = tcgen05.CtaGroup.ONE
|
|
206
|
+
# S = K @ Q.T
|
|
207
|
+
tiled_mma_S = sm100_utils_basic.make_trivial_tiled_mma(
|
|
208
|
+
self.q_dtype,
|
|
209
|
+
tcgen05.OperandMajorMode.K,
|
|
210
|
+
tcgen05.OperandMajorMode.K,
|
|
211
|
+
self.acc_dtype,
|
|
212
|
+
cta_group,
|
|
213
|
+
self.mma_tiler_kq[:2],
|
|
214
|
+
)
|
|
215
|
+
# dP = V @ dO.T
|
|
216
|
+
tiled_mma_dP = sm100_utils_basic.make_trivial_tiled_mma(
|
|
217
|
+
self.do_dtype,
|
|
218
|
+
tcgen05.OperandMajorMode.K,
|
|
219
|
+
tcgen05.OperandMajorMode.K,
|
|
220
|
+
self.acc_dtype,
|
|
221
|
+
cta_group,
|
|
222
|
+
self.mma_tiler_vdo[:2],
|
|
223
|
+
)
|
|
224
|
+
# dV += P @ dO --> (K, MN) major
|
|
225
|
+
tiled_mma_dV = sm100_utils_basic.make_trivial_tiled_mma(
|
|
226
|
+
self.do_dtype,
|
|
227
|
+
tcgen05.OperandMajorMode.K, # P_major_mode
|
|
228
|
+
tcgen05.OperandMajorMode.MN, # dO_major_mode
|
|
229
|
+
self.acc_dtype,
|
|
230
|
+
cta_group,
|
|
231
|
+
self.mma_tiler_pdo[:2],
|
|
232
|
+
a_source=tcgen05.OperandSource.TMEM,
|
|
233
|
+
)
|
|
234
|
+
# dK += dS.T @ Q
|
|
235
|
+
if const_expr(self.use_smem_dS_for_mma_dK):
|
|
236
|
+
mma_dK_a_src = tcgen05.OperandSource.SMEM
|
|
237
|
+
else:
|
|
238
|
+
mma_dK_a_src = tcgen05.OperandSource.TMEM
|
|
239
|
+
tiled_mma_dK = sm100_utils_basic.make_trivial_tiled_mma(
|
|
240
|
+
self.do_dtype,
|
|
241
|
+
tcgen05.OperandMajorMode.K, # dS_major_mode
|
|
242
|
+
tcgen05.OperandMajorMode.MN, # Q_major_mode
|
|
243
|
+
self.acc_dtype,
|
|
244
|
+
cta_group,
|
|
245
|
+
self.mma_tiler_dsq[:2],
|
|
246
|
+
a_source=mma_dK_a_src,
|
|
247
|
+
)
|
|
248
|
+
# dQ = dS @ K
|
|
249
|
+
tiled_mma_dQ = sm100_utils_basic.make_trivial_tiled_mma(
|
|
250
|
+
self.k_dtype,
|
|
251
|
+
tcgen05.OperandMajorMode.MN, # dS_major_mode
|
|
252
|
+
tcgen05.OperandMajorMode.MN, # Kt_major_mode
|
|
253
|
+
self.acc_dtype,
|
|
254
|
+
cta_group,
|
|
255
|
+
self.mma_tiler_dsk[:2],
|
|
256
|
+
)
|
|
257
|
+
return tiled_mma_S, tiled_mma_dP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ
|
|
258
|
+
|
|
259
|
+
def _setup_smem_layout(self):
|
|
260
|
+
# S = K @ Q.T
|
|
261
|
+
sK_layout = sm100_utils_basic.make_smem_layout_a(
|
|
262
|
+
self.tiled_mma_S,
|
|
263
|
+
self.mma_tiler_kq,
|
|
264
|
+
self.k_dtype,
|
|
265
|
+
1,
|
|
266
|
+
)
|
|
267
|
+
self.sK_layout = cute.slice_(sK_layout, (None, None, None, 0))
|
|
268
|
+
self.sQ_layout = sm100_utils_basic.make_smem_layout_b(
|
|
269
|
+
self.tiled_mma_S,
|
|
270
|
+
self.mma_tiler_kq,
|
|
271
|
+
self.q_dtype,
|
|
272
|
+
self.Q_stage,
|
|
273
|
+
)
|
|
274
|
+
# dP = V @ dO.T
|
|
275
|
+
sV_layout = sm100_utils_basic.make_smem_layout_a(
|
|
276
|
+
self.tiled_mma_dP,
|
|
277
|
+
self.mma_tiler_vdo,
|
|
278
|
+
self.v_dtype,
|
|
279
|
+
1,
|
|
280
|
+
)
|
|
281
|
+
self.sV_layout = cute.slice_(sV_layout, (None, None, None, 0))
|
|
282
|
+
self.sdOt_layout = sm100_utils_basic.make_smem_layout_b(
|
|
283
|
+
self.tiled_mma_dP,
|
|
284
|
+
self.mma_tiler_vdo,
|
|
285
|
+
self.do_dtype,
|
|
286
|
+
self.dO_stage,
|
|
287
|
+
)
|
|
288
|
+
# dV += P @ dO
|
|
289
|
+
tP_layout = sm100_utils_basic.make_smem_layout_a(
|
|
290
|
+
self.tiled_mma_dV,
|
|
291
|
+
self.mma_tiler_pdo,
|
|
292
|
+
self.do_dtype,
|
|
293
|
+
1,
|
|
294
|
+
)
|
|
295
|
+
self.tP_layout = cute.slice_(tP_layout, (None, None, None, 0))
|
|
296
|
+
self.sdO_layout = sm100_utils_basic.make_smem_layout_b(
|
|
297
|
+
self.tiled_mma_dV,
|
|
298
|
+
self.mma_tiler_pdo,
|
|
299
|
+
self.do_dtype,
|
|
300
|
+
self.dO_stage,
|
|
301
|
+
)
|
|
302
|
+
# dK += dS.T @ Q
|
|
303
|
+
sdSt_layout = sm100_utils_basic.make_smem_layout_a(
|
|
304
|
+
self.tiled_mma_dK,
|
|
305
|
+
self.mma_tiler_dsq,
|
|
306
|
+
self.ds_dtype,
|
|
307
|
+
1,
|
|
308
|
+
)
|
|
309
|
+
self.sdSt_layout = cute.slice_(sdSt_layout, (None, None, None, 0))
|
|
310
|
+
tdS_layout = sm100_utils_basic.make_smem_layout_a(
|
|
311
|
+
self.tiled_mma_dK,
|
|
312
|
+
self.mma_tiler_dsq,
|
|
313
|
+
self.ds_dtype,
|
|
314
|
+
1,
|
|
315
|
+
)
|
|
316
|
+
self.tdS_layout = cute.slice_(tdS_layout, (None, None, None, 0))
|
|
317
|
+
self.sQt_layout = sm100_utils_basic.make_smem_layout_b(
|
|
318
|
+
self.tiled_mma_dK,
|
|
319
|
+
self.mma_tiler_dsq,
|
|
320
|
+
self.q_dtype,
|
|
321
|
+
self.Q_stage,
|
|
322
|
+
)
|
|
323
|
+
# dQ = dS @ K
|
|
324
|
+
sdS_layout = sm100_utils_basic.make_smem_layout_a(
|
|
325
|
+
self.tiled_mma_dQ,
|
|
326
|
+
self.mma_tiler_dsk,
|
|
327
|
+
self.ds_dtype,
|
|
328
|
+
1,
|
|
329
|
+
)
|
|
330
|
+
self.sdS_layout = cute.slice_(sdS_layout, (None, None, None, 0))
|
|
331
|
+
sKt_layout = sm100_utils_basic.make_smem_layout_b(
|
|
332
|
+
self.tiled_mma_dQ,
|
|
333
|
+
self.mma_tiler_dsk,
|
|
334
|
+
self.k_dtype,
|
|
335
|
+
1,
|
|
336
|
+
)
|
|
337
|
+
self.sKt_layout = cute.slice_(sKt_layout, (None, None, None, 0))
|
|
338
|
+
self.sdQaccum_layout = cute.make_layout(
|
|
339
|
+
(self.tile_m * self.dQ_reduce_ncol, self.sdQaccum_stage)
|
|
340
|
+
)
|
|
341
|
+
self.sLSE_layout = cute.make_layout(
|
|
342
|
+
shape=(self.tile_m, self.Q_stage),
|
|
343
|
+
stride=(1, cute.round_up(self.tile_m, 64)),
|
|
344
|
+
)
|
|
345
|
+
self.sdPsum_layout = cute.make_layout(
|
|
346
|
+
shape=(self.tile_m, self.dO_stage),
|
|
347
|
+
stride=(1, cute.round_up(self.tile_m, 64)),
|
|
348
|
+
)
|
|
349
|
+
self.sdKV_epi_tile = (
|
|
350
|
+
self.tile_n,
|
|
351
|
+
min(128 // (self.dk_dtype.width // 8), self.tile_hdim // 2), # 64 or 32
|
|
352
|
+
) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2]
|
|
353
|
+
# headdim_64 gets 1 stage
|
|
354
|
+
self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdKV_epi_tile[1])
|
|
355
|
+
self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages
|
|
356
|
+
# TODO: dK and dV could have different shapes
|
|
357
|
+
if const_expr(not self.dKV_postprocess):
|
|
358
|
+
self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi(
|
|
359
|
+
self.dk_dtype,
|
|
360
|
+
LayoutEnum.ROW_MAJOR,
|
|
361
|
+
self.sdKV_epi_tile,
|
|
362
|
+
2, # num compute wgs
|
|
363
|
+
)
|
|
364
|
+
else:
|
|
365
|
+
self.sdKV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2))
|
|
366
|
+
|
|
367
|
+
@cute.jit
|
|
368
|
+
def __call__(
|
|
369
|
+
self,
|
|
370
|
+
mQ: cute.Tensor,
|
|
371
|
+
mK: cute.Tensor,
|
|
372
|
+
mV: cute.Tensor,
|
|
373
|
+
mdO: cute.Tensor,
|
|
374
|
+
mLSE: cute.Tensor,
|
|
375
|
+
mdPsum: cute.Tensor,
|
|
376
|
+
mdQaccum: cute.Tensor,
|
|
377
|
+
mdK: cute.Tensor,
|
|
378
|
+
mdV: cute.Tensor,
|
|
379
|
+
softmax_scale: Float32,
|
|
380
|
+
stream: cuda.CUstream,
|
|
381
|
+
mCuSeqlensQ: Optional[cute.Tensor] = None,
|
|
382
|
+
mCuSeqlensK: Optional[cute.Tensor] = None,
|
|
383
|
+
mSeqUsedQ: Optional[cute.Tensor] = None,
|
|
384
|
+
mSeqUsedK: Optional[cute.Tensor] = None,
|
|
385
|
+
softcap: Float32 | float | None = None,
|
|
386
|
+
window_size_left: Int32 | int | None = None,
|
|
387
|
+
window_size_right: Int32 | int | None = None,
|
|
388
|
+
mdQ_semaphore: Optional[cute.Tensor] = None,
|
|
389
|
+
mdK_semaphore: Optional[cute.Tensor] = None,
|
|
390
|
+
mdV_semaphore: Optional[cute.Tensor] = None,
|
|
391
|
+
aux_tensors: Optional[list] = None,
|
|
392
|
+
# Block-sparse tensors (Q direction - for iterating m_blocks per n_block):
|
|
393
|
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
394
|
+
):
|
|
395
|
+
self.q_dtype = mQ.element_type
|
|
396
|
+
self.k_dtype = mK.element_type
|
|
397
|
+
self.v_dtype = mV.element_type
|
|
398
|
+
self.do_dtype = mdO.element_type
|
|
399
|
+
self.lse_dtype = mLSE.element_type
|
|
400
|
+
self.dpsum_dtype = mdPsum.element_type
|
|
401
|
+
self.dqaccum_dtype = mdQaccum.element_type
|
|
402
|
+
self.dk_dtype = mdK.element_type
|
|
403
|
+
self.dv_dtype = mdV.element_type
|
|
404
|
+
self.ds_dtype = self.q_dtype
|
|
405
|
+
|
|
406
|
+
self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None
|
|
407
|
+
self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
|
|
408
|
+
self.use_tma_store = not (self.qhead_per_kvhead == 1 and mCuSeqlensK is not None)
|
|
409
|
+
self.dKV_postprocess = self.qhead_per_kvhead > 1
|
|
410
|
+
|
|
411
|
+
if const_expr(self.dKV_postprocess):
|
|
412
|
+
assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA"
|
|
413
|
+
assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA"
|
|
414
|
+
|
|
415
|
+
# Assume all strides are divisible by 128 bits except the last stride
|
|
416
|
+
new_stride = lambda t: (
|
|
417
|
+
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
|
|
418
|
+
t.stride[-1],
|
|
419
|
+
)
|
|
420
|
+
(
|
|
421
|
+
mdQaccum,
|
|
422
|
+
mdK,
|
|
423
|
+
mdV,
|
|
424
|
+
) = [
|
|
425
|
+
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
|
426
|
+
if t is not None
|
|
427
|
+
else None
|
|
428
|
+
for t in (
|
|
429
|
+
mdQaccum,
|
|
430
|
+
mdK,
|
|
431
|
+
mdV,
|
|
432
|
+
)
|
|
433
|
+
]
|
|
434
|
+
|
|
435
|
+
# (b, s, n, h) --> (s, h, n, b) or (t, n, h) -> (t, h, n)
|
|
436
|
+
QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
|
|
437
|
+
mQ, mdO = [utils.select(t, mode=QO_layout_transpose) for t in (mQ, mdO)]
|
|
438
|
+
|
|
439
|
+
KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
|
|
440
|
+
mK, mV = [utils.select(t, mode=KV_layout_transpose) for t in (mK, mV)]
|
|
441
|
+
|
|
442
|
+
# (b, n, s) --> (s, n, b) or (n, t) --> (t, n)
|
|
443
|
+
LSE_dPsum_dQaccum_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
|
|
444
|
+
mLSE, mdPsum, mdQaccum = [
|
|
445
|
+
utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
|
|
446
|
+
]
|
|
447
|
+
|
|
448
|
+
if const_expr(not self.dKV_postprocess):
|
|
449
|
+
layout_dKV_transpose = KV_layout_transpose
|
|
450
|
+
else:
|
|
451
|
+
layout_dKV_transpose = [2, 1, 0] if const_expr(mCuSeqlensK is None) else [1, 0]
|
|
452
|
+
mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)]
|
|
453
|
+
# (s, h, n, b) --> (h, s, n, b) or (t, h, n) -> (h, t, b)
|
|
454
|
+
dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2]
|
|
455
|
+
mdO = utils.select(mdO, mode=dO_transpose)
|
|
456
|
+
|
|
457
|
+
# (b, n, block, stage) -> (block, stage, n, b)
|
|
458
|
+
semaphore_transpose = [2, 3, 1, 0]
|
|
459
|
+
if const_expr(self.deterministic):
|
|
460
|
+
assert mdQ_semaphore is not None
|
|
461
|
+
mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose)
|
|
462
|
+
|
|
463
|
+
if const_expr(self.deterministic and self.qhead_per_kvhead > 1):
|
|
464
|
+
assert mdK_semaphore is not None
|
|
465
|
+
assert mdV_semaphore is not None
|
|
466
|
+
mdK_semaphore, mdV_semaphore = [
|
|
467
|
+
utils.select(t, mode=semaphore_transpose) for t in (mdK_semaphore, mdV_semaphore)
|
|
468
|
+
]
|
|
469
|
+
else:
|
|
470
|
+
mdK_semaphore = None
|
|
471
|
+
mdV_semaphore = None
|
|
472
|
+
|
|
473
|
+
self._setup_attributes()
|
|
474
|
+
(
|
|
475
|
+
self.tiled_mma_S,
|
|
476
|
+
self.tiled_mma_dP,
|
|
477
|
+
self.tiled_mma_dK,
|
|
478
|
+
self.tiled_mma_dV,
|
|
479
|
+
self.tiled_mma_dQ,
|
|
480
|
+
) = self._get_tiled_mma()
|
|
481
|
+
self._setup_smem_layout()
|
|
482
|
+
|
|
483
|
+
cta_group = tcgen05.CtaGroup.ONE
|
|
484
|
+
|
|
485
|
+
self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)
|
|
486
|
+
self.cluster_layout_vmnk = cute.tiled_divide(
|
|
487
|
+
cute.make_layout(self.cluster_shape_mnk),
|
|
488
|
+
(self.tiled_mma_S.thr_id.shape,),
|
|
489
|
+
)
|
|
490
|
+
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
|
|
491
|
+
self.is_q_do_mcast = self.num_mcast_ctas_b > 1
|
|
492
|
+
|
|
493
|
+
if const_expr(not self.dKV_postprocess):
|
|
494
|
+
self.mdK_layout_enum = LayoutEnum.from_tensor(mdK)
|
|
495
|
+
self.mdV_layout_enum = LayoutEnum.from_tensor(mdV)
|
|
496
|
+
dK_major_mode = self.mdK_layout_enum.mma_major_mode()
|
|
497
|
+
dV_major_mode = self.mdV_layout_enum.mma_major_mode()
|
|
498
|
+
if const_expr(dK_major_mode != tcgen05.OperandMajorMode.K):
|
|
499
|
+
raise RuntimeError("The layout of mdK is wrong")
|
|
500
|
+
if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K):
|
|
501
|
+
raise RuntimeError("The layout of mdV is wrong")
|
|
502
|
+
|
|
503
|
+
if const_expr(self.use_tma_store and not self.dKV_postprocess):
|
|
504
|
+
tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp()
|
|
505
|
+
tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom(
|
|
506
|
+
tma_copy_op_dKV,
|
|
507
|
+
mdK,
|
|
508
|
+
cute.select(self.sdKV_layout, mode=[0, 1]),
|
|
509
|
+
self.sdKV_epi_tile,
|
|
510
|
+
1, # no mcast
|
|
511
|
+
)
|
|
512
|
+
tma_atom_dV, mdV_tma_tensor = cpasync.make_tiled_tma_atom(
|
|
513
|
+
tma_copy_op_dKV,
|
|
514
|
+
mdV,
|
|
515
|
+
cute.select(self.sdKV_layout, mode=[0, 1]),
|
|
516
|
+
self.sdKV_epi_tile,
|
|
517
|
+
1, # no mcast
|
|
518
|
+
)
|
|
519
|
+
else:
|
|
520
|
+
mdV_tma_tensor = mdV
|
|
521
|
+
mdK_tma_tensor = mdK
|
|
522
|
+
tma_atom_dV = None
|
|
523
|
+
tma_atom_dK = None
|
|
524
|
+
|
|
525
|
+
if const_expr(not self.dKV_postprocess):
|
|
526
|
+
thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0)) # 128 threads
|
|
527
|
+
val_layout_r2s_dKV = cute.make_ordered_layout(
|
|
528
|
+
(1, 128 // self.dk_dtype.width), order=(1, 0)
|
|
529
|
+
) # 4 or 8 vals for 16 byte store
|
|
530
|
+
copy_atom_r2s_dKV = cute.make_copy_atom(
|
|
531
|
+
cute.nvgpu.CopyUniversalOp(),
|
|
532
|
+
self.dk_dtype,
|
|
533
|
+
num_bits_per_copy=128,
|
|
534
|
+
)
|
|
535
|
+
tiled_copy_r2s_dKV = cute.make_tiled_copy_tv(
|
|
536
|
+
copy_atom_r2s_dKV, thr_layout_r2s_dKV, val_layout_r2s_dKV
|
|
537
|
+
)
|
|
538
|
+
else:
|
|
539
|
+
tiled_copy_r2s_dKV = copy_utils.tiled_copy_1d(
|
|
540
|
+
Float32, 128, num_copy_elems=128 // Float32.width
|
|
541
|
+
)
|
|
542
|
+
|
|
543
|
+
tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group)
|
|
544
|
+
tma_load_op_multicast = cpasync.CopyBulkTensorTileG2SMulticastOp(cta_group)
|
|
545
|
+
|
|
546
|
+
# S.T = K @ Q.T
|
|
547
|
+
tma_atom_K, tma_tensor_K = cute.nvgpu.make_tiled_tma_atom_A(
|
|
548
|
+
tma_load_op,
|
|
549
|
+
mK,
|
|
550
|
+
cute.select(self.sK_layout, mode=[0, 1, 2]),
|
|
551
|
+
self.mma_tiler_kq,
|
|
552
|
+
self.tiled_mma_S,
|
|
553
|
+
self.cluster_layout_vmnk.shape,
|
|
554
|
+
)
|
|
555
|
+
Q_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B(
|
|
556
|
+
self.cluster_shape_mnk, self.tiled_mma_S.thr_id
|
|
557
|
+
)
|
|
558
|
+
tma_atom_Q, tma_tensor_Q = cute.nvgpu.make_tiled_tma_atom_B(
|
|
559
|
+
# tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast,
|
|
560
|
+
Q_tma_op,
|
|
561
|
+
mQ,
|
|
562
|
+
cute.select(self.sQ_layout, mode=[0, 1, 2]),
|
|
563
|
+
self.mma_tiler_kq,
|
|
564
|
+
self.tiled_mma_S,
|
|
565
|
+
self.cluster_layout_vmnk.shape,
|
|
566
|
+
)
|
|
567
|
+
# dP.T = V @ dO.T
|
|
568
|
+
tma_atom_V, tma_tensor_V = cute.nvgpu.make_tiled_tma_atom_A(
|
|
569
|
+
tma_load_op,
|
|
570
|
+
mV,
|
|
571
|
+
cute.select(self.sV_layout, mode=[0, 1, 2]),
|
|
572
|
+
self.mma_tiler_vdo,
|
|
573
|
+
self.tiled_mma_dP,
|
|
574
|
+
self.cluster_layout_vmnk.shape,
|
|
575
|
+
)
|
|
576
|
+
dO_tma_op = sm100_utils_basic.cluster_shape_to_tma_atom_B(
|
|
577
|
+
self.cluster_shape_mnk, self.tiled_mma_dV.thr_id
|
|
578
|
+
)
|
|
579
|
+
tma_atom_dO, tma_tensor_dO = cute.nvgpu.make_tiled_tma_atom_B(
|
|
580
|
+
# tma_load_op if const_expr(self.cluster_shape_mnk[0] == 1) else tma_load_op_multicast,
|
|
581
|
+
dO_tma_op,
|
|
582
|
+
mdO,
|
|
583
|
+
cute.select(self.sdO_layout, mode=[0, 1, 2]),
|
|
584
|
+
self.mma_tiler_pdo,
|
|
585
|
+
self.tiled_mma_dV,
|
|
586
|
+
self.cluster_layout_vmnk.shape,
|
|
587
|
+
)
|
|
588
|
+
|
|
589
|
+
self.tma_copy_bytes = {
|
|
590
|
+
name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
|
|
591
|
+
for name, mX, layout in [
|
|
592
|
+
("Q", mQ, self.sQ_layout),
|
|
593
|
+
("K", mK, self.sK_layout),
|
|
594
|
+
("V", mV, self.sV_layout),
|
|
595
|
+
("dO", mdO, self.sdO_layout),
|
|
596
|
+
]
|
|
597
|
+
}
|
|
598
|
+
self.tma_copy_bytes["LSE"] = self.tile_m * Float32.width // 8
|
|
599
|
+
self.tma_copy_bytes["dPsum"] = self.tile_m * Float32.width // 8
|
|
600
|
+
self.tma_copy_bytes["dQ"] = self.tile_m * self.dQ_reduce_ncol * Float32.width // 8
|
|
601
|
+
self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8
|
|
602
|
+
|
|
603
|
+
# TileScheduler = SingleTileScheduler
|
|
604
|
+
if const_expr(self.is_varlen_k):
|
|
605
|
+
TileScheduler = SingleTileVarlenScheduler
|
|
606
|
+
elif const_expr(self.deterministic):
|
|
607
|
+
TileScheduler = SingleTileLPTBwdScheduler
|
|
608
|
+
else:
|
|
609
|
+
TileScheduler = SingleTileScheduler
|
|
610
|
+
# reads n_blocks right-to-left
|
|
611
|
+
self.spt = (self.is_causal or self.is_local) and self.deterministic
|
|
612
|
+
tile_sched_args = TileSchedulerArguments(
|
|
613
|
+
cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks
|
|
614
|
+
cute.size(mQ.shape[2]), # num_heads = num_query_heads
|
|
615
|
+
cute.size(mK.shape[3])
|
|
616
|
+
if const_expr(mCuSeqlensK is None)
|
|
617
|
+
else cute.size(mCuSeqlensK.shape[0] - 1), # num_batches
|
|
618
|
+
1, # num_splits
|
|
619
|
+
cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k
|
|
620
|
+
mQ.shape[1], # headdim
|
|
621
|
+
mV.shape[1], # headdim_v
|
|
622
|
+
total_q=cute.size(mK.shape[0]) # pass total_k for total_q
|
|
623
|
+
if const_expr(mCuSeqlensK is not None)
|
|
624
|
+
else cute.size(mK.shape[0]) * cute.size(mK.shape[3]),
|
|
625
|
+
tile_shape_mn=self.cta_tiler[:2], # (tile_n, tile_m)
|
|
626
|
+
cluster_shape_mn=self.cluster_shape_mnk[:2],
|
|
627
|
+
mCuSeqlensQ=mCuSeqlensK,
|
|
628
|
+
mSeqUsedQ=mSeqUsedK,
|
|
629
|
+
qhead_per_kvhead_packgqa=1, # pack_gqa disabled for bwd
|
|
630
|
+
element_size=self.k_dtype.width // 8,
|
|
631
|
+
is_persistent=self.is_persistent, # persistent mode not tested
|
|
632
|
+
lpt=self.spt,
|
|
633
|
+
head_swizzle=self.deterministic,
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
|
|
637
|
+
self.tile_scheduler_cls = TileScheduler
|
|
638
|
+
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
|
|
639
|
+
# cute.printf("grid_dim = {}", grid_dim)
|
|
640
|
+
|
|
641
|
+
# Compute allocation sizes for shared buffers that are reused
|
|
642
|
+
# sQ is reused for sdK, sdO is reused for sdV
|
|
643
|
+
sQ_alloc_bytes = max(
|
|
644
|
+
cute.size_in_bytes(self.q_dtype, self.sQ_layout),
|
|
645
|
+
cute.size_in_bytes(self.dk_dtype, self.sdKV_layout),
|
|
646
|
+
)
|
|
647
|
+
sdO_alloc_bytes = max(
|
|
648
|
+
cute.size_in_bytes(self.dv_dtype, self.sdKV_layout),
|
|
649
|
+
cute.size_in_bytes(self.do_dtype, self.sdO_layout),
|
|
650
|
+
)
|
|
651
|
+
# Sanity check that layouts fit in allocation
|
|
652
|
+
sdV_bytes = cute.size_in_bytes(self.dv_dtype, self.sdKV_layout)
|
|
653
|
+
sdK_bytes = cute.size_in_bytes(self.dk_dtype, self.sdKV_layout)
|
|
654
|
+
assert sdV_bytes <= sdO_alloc_bytes, "sdV doesn't fit in sdO storage allocation"
|
|
655
|
+
assert sdK_bytes <= sQ_alloc_bytes, "sdK doesn't fit in sQ storage allocation"
|
|
656
|
+
|
|
657
|
+
@cute.struct
|
|
658
|
+
class SharedStorage:
|
|
659
|
+
Q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage]
|
|
660
|
+
dO_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage]
|
|
661
|
+
LSE_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.Q_stage]
|
|
662
|
+
dPsum_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * self.dO_stage]
|
|
663
|
+
S_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1]
|
|
664
|
+
dP_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1]
|
|
665
|
+
dS_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 1]
|
|
666
|
+
dKV_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2 * 2]
|
|
667
|
+
dQ_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
|
|
668
|
+
dQ_cluster_full_mbar_ptr: cute.struct.MemRange[
|
|
669
|
+
cutlass.Int64, self.dQaccum_reduce_stage // 2
|
|
670
|
+
]
|
|
671
|
+
dQ_cluster_empty_mbar_ptr: cute.struct.MemRange[
|
|
672
|
+
cutlass.Int64, self.dQaccum_reduce_stage // 2
|
|
673
|
+
]
|
|
674
|
+
tmem_holding_buf: Int32
|
|
675
|
+
tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1]
|
|
676
|
+
|
|
677
|
+
# Smem tensors
|
|
678
|
+
|
|
679
|
+
# sQ is reused for sdK which in the non-MHA case needs float32
|
|
680
|
+
sQ: cute.struct.Align[
|
|
681
|
+
cute.struct.MemRange[cute.Uint8, sQ_alloc_bytes],
|
|
682
|
+
self.buffer_align_bytes,
|
|
683
|
+
]
|
|
684
|
+
sK: cute.struct.Align[
|
|
685
|
+
cute.struct.MemRange[self.k_dtype, cute.cosize(self.sK_layout)],
|
|
686
|
+
self.buffer_align_bytes,
|
|
687
|
+
]
|
|
688
|
+
sV: cute.struct.Align[
|
|
689
|
+
cute.struct.MemRange[self.v_dtype, cute.cosize(self.sV_layout)],
|
|
690
|
+
self.buffer_align_bytes,
|
|
691
|
+
]
|
|
692
|
+
# sdO is reused for sdV which in the non-MHA case needs float32
|
|
693
|
+
sdO: cute.struct.Align[
|
|
694
|
+
cute.struct.MemRange[cute.Uint8, sdO_alloc_bytes],
|
|
695
|
+
self.buffer_align_bytes,
|
|
696
|
+
]
|
|
697
|
+
sdS: cute.struct.Align[
|
|
698
|
+
cute.struct.MemRange[self.ds_dtype, cute.cosize(self.sdSt_layout)],
|
|
699
|
+
128,
|
|
700
|
+
]
|
|
701
|
+
sLSE: cute.struct.Align[
|
|
702
|
+
cute.struct.MemRange[self.lse_dtype, cute.cosize(self.sLSE_layout)],
|
|
703
|
+
128,
|
|
704
|
+
]
|
|
705
|
+
sdPsum: cute.struct.Align[
|
|
706
|
+
cute.struct.MemRange[self.dpsum_dtype, cute.cosize(self.sdPsum_layout)],
|
|
707
|
+
128,
|
|
708
|
+
]
|
|
709
|
+
sdQaccum: cute.struct.Align[
|
|
710
|
+
cute.struct.MemRange[self.dqaccum_dtype, cute.cosize(self.sdQaccum_layout)],
|
|
711
|
+
self.buffer_align_bytes,
|
|
712
|
+
]
|
|
713
|
+
|
|
714
|
+
self.shared_storage = SharedStorage
|
|
715
|
+
|
|
716
|
+
LOG2_E = math.log2(math.e)
|
|
717
|
+
if const_expr(self.score_mod is None):
|
|
718
|
+
# Without score_mod: bake scale into log2
|
|
719
|
+
softmax_scale_log2 = softmax_scale * LOG2_E
|
|
720
|
+
else:
|
|
721
|
+
# With score_mod: score_mod applied to S * softmax_scale, then use LOG2_E only
|
|
722
|
+
softmax_scale_log2 = LOG2_E
|
|
723
|
+
|
|
724
|
+
if const_expr(window_size_left is not None):
|
|
725
|
+
window_size_left = Int32(window_size_left)
|
|
726
|
+
if const_expr(window_size_right is not None):
|
|
727
|
+
window_size_right = Int32(window_size_right)
|
|
728
|
+
|
|
729
|
+
fastdiv_mods = None
|
|
730
|
+
if const_expr(aux_tensors is not None):
|
|
731
|
+
seqlen_q = cute.size(mQ.shape[0]) // (
|
|
732
|
+
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
|
|
733
|
+
)
|
|
734
|
+
seqlen_k = cute.size(mK.shape[0])
|
|
735
|
+
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
|
|
736
|
+
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
|
|
737
|
+
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
|
|
738
|
+
self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
|
|
739
|
+
|
|
740
|
+
if const_expr(self.use_block_sparsity or aux_tensors is not None):
|
|
741
|
+
assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), (
|
|
742
|
+
"Variable sequence length is not supported yet for blocksparse or aux tensors in bwd"
|
|
743
|
+
)
|
|
744
|
+
|
|
745
|
+
self.kernel(
|
|
746
|
+
tma_tensor_Q,
|
|
747
|
+
tma_tensor_K,
|
|
748
|
+
tma_tensor_V,
|
|
749
|
+
mLSE,
|
|
750
|
+
mdPsum,
|
|
751
|
+
tma_tensor_dO,
|
|
752
|
+
mdV,
|
|
753
|
+
mdK,
|
|
754
|
+
mdQaccum,
|
|
755
|
+
mdV_tma_tensor,
|
|
756
|
+
mdK_tma_tensor,
|
|
757
|
+
mdQ_semaphore,
|
|
758
|
+
mdK_semaphore,
|
|
759
|
+
mdV_semaphore,
|
|
760
|
+
mCuSeqlensQ,
|
|
761
|
+
mCuSeqlensK,
|
|
762
|
+
mSeqUsedQ,
|
|
763
|
+
mSeqUsedK,
|
|
764
|
+
tma_atom_Q,
|
|
765
|
+
tma_atom_K,
|
|
766
|
+
tma_atom_V,
|
|
767
|
+
tma_atom_dO,
|
|
768
|
+
tma_atom_dV,
|
|
769
|
+
tma_atom_dK,
|
|
770
|
+
self.sQ_layout,
|
|
771
|
+
self.sQt_layout,
|
|
772
|
+
self.sK_layout,
|
|
773
|
+
self.sV_layout,
|
|
774
|
+
self.sLSE_layout,
|
|
775
|
+
self.sdPsum_layout,
|
|
776
|
+
self.sdO_layout,
|
|
777
|
+
self.sdOt_layout,
|
|
778
|
+
self.sdSt_layout,
|
|
779
|
+
self.sdS_layout,
|
|
780
|
+
self.sKt_layout,
|
|
781
|
+
self.sdQaccum_layout,
|
|
782
|
+
self.sdKV_layout,
|
|
783
|
+
self.tP_layout,
|
|
784
|
+
self.tdS_layout,
|
|
785
|
+
self.tiled_mma_S,
|
|
786
|
+
self.tiled_mma_dP,
|
|
787
|
+
self.tiled_mma_dV,
|
|
788
|
+
self.tiled_mma_dK,
|
|
789
|
+
self.tiled_mma_dQ,
|
|
790
|
+
tiled_copy_r2s_dKV,
|
|
791
|
+
softmax_scale,
|
|
792
|
+
softmax_scale_log2,
|
|
793
|
+
window_size_left,
|
|
794
|
+
window_size_right,
|
|
795
|
+
tile_sched_params,
|
|
796
|
+
aux_tensors,
|
|
797
|
+
fastdiv_mods,
|
|
798
|
+
blocksparse_tensors,
|
|
799
|
+
).launch(
|
|
800
|
+
grid=grid_dim,
|
|
801
|
+
block=[self.threads_per_cta, 1, 1],
|
|
802
|
+
cluster=self.cluster_shape_mnk if cute.size(self.cluster_shape_mnk) > 1 else None,
|
|
803
|
+
smem=self.shared_storage.size_in_bytes(),
|
|
804
|
+
stream=stream,
|
|
805
|
+
min_blocks_per_mp=1,
|
|
806
|
+
)
|
|
807
|
+
|
|
808
|
+
@cute.kernel
|
|
809
|
+
def kernel(
|
|
810
|
+
self,
|
|
811
|
+
mQ: cute.Tensor,
|
|
812
|
+
mK: cute.Tensor,
|
|
813
|
+
mV: cute.Tensor,
|
|
814
|
+
mLSE: cute.Tensor,
|
|
815
|
+
mdPsum: cute.Tensor,
|
|
816
|
+
mdO: cute.Tensor,
|
|
817
|
+
mdV: cute.Tensor,
|
|
818
|
+
mdK: cute.Tensor,
|
|
819
|
+
mdQaccum: cute.Tensor,
|
|
820
|
+
mdV_tma_tensor: Optional[cute.Tensor],
|
|
821
|
+
mdK_tma_tensor: Optional[cute.Tensor],
|
|
822
|
+
mdQ_semaphore: Optional[cute.Tensor],
|
|
823
|
+
mdK_semaphore: Optional[cute.Tensor],
|
|
824
|
+
mdV_semaphore: Optional[cute.Tensor],
|
|
825
|
+
mCuSeqlensQ: Optional[cute.Tensor],
|
|
826
|
+
mCuSeqlensK: Optional[cute.Tensor],
|
|
827
|
+
mSeqUsedQ: Optional[cute.Tensor],
|
|
828
|
+
mSeqUsedK: Optional[cute.Tensor],
|
|
829
|
+
tma_atom_Q: cute.CopyAtom,
|
|
830
|
+
tma_atom_K: cute.CopyAtom,
|
|
831
|
+
tma_atom_V: cute.CopyAtom,
|
|
832
|
+
tma_atom_dO: cute.CopyAtom,
|
|
833
|
+
tma_atom_dV: Optional[cute.CopyAtom],
|
|
834
|
+
tma_atom_dK: Optional[cute.CopyAtom],
|
|
835
|
+
sQ_layout: cute.ComposedLayout,
|
|
836
|
+
sQt_layout: cute.ComposedLayout,
|
|
837
|
+
sK_layout: cute.ComposedLayout,
|
|
838
|
+
sV_layout: cute.ComposedLayout,
|
|
839
|
+
sLSE_layout: cute.Layout,
|
|
840
|
+
sdPsum_layout: cute.Layout,
|
|
841
|
+
sdO_layout: cute.ComposedLayout,
|
|
842
|
+
sdOt_layout: cute.ComposedLayout,
|
|
843
|
+
sdSt_layout: cute.ComposedLayout,
|
|
844
|
+
sdS_layout: cute.ComposedLayout,
|
|
845
|
+
sKt_layout: cute.ComposedLayout,
|
|
846
|
+
sdQaccum_layout: cute.Layout,
|
|
847
|
+
sdKV_layout: cute.ComposedLayout | cute.Layout,
|
|
848
|
+
tP_layout: cute.ComposedLayout,
|
|
849
|
+
tdS_layout: cute.ComposedLayout,
|
|
850
|
+
tiled_mma_S: cute.TiledMma,
|
|
851
|
+
tiled_mma_dP: cute.TiledMma,
|
|
852
|
+
tiled_mma_dV: cute.TiledMma,
|
|
853
|
+
tiled_mma_dK: cute.TiledMma,
|
|
854
|
+
tiled_mma_dQ: cute.TiledMma,
|
|
855
|
+
tiled_copy_r2s_dKV: cute.TiledCopy,
|
|
856
|
+
softmax_scale: cutlass.Float32,
|
|
857
|
+
softmax_scale_log2: cutlass.Float32,
|
|
858
|
+
window_size_left: Optional[Int32],
|
|
859
|
+
window_size_right: Optional[Int32],
|
|
860
|
+
tile_sched_params: ParamsBase,
|
|
861
|
+
aux_tensors: Optional[list] = None,
|
|
862
|
+
fastdiv_mods=(None, None),
|
|
863
|
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
864
|
+
):
|
|
865
|
+
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
|
866
|
+
|
|
867
|
+
# Prefetch tma descriptor
|
|
868
|
+
if warp_idx == self.load_warp_id:
|
|
869
|
+
with cute.arch.elect_one():
|
|
870
|
+
cpasync.prefetch_descriptor(tma_atom_Q)
|
|
871
|
+
cpasync.prefetch_descriptor(tma_atom_K)
|
|
872
|
+
cpasync.prefetch_descriptor(tma_atom_V)
|
|
873
|
+
cpasync.prefetch_descriptor(tma_atom_dO)
|
|
874
|
+
if const_expr(tma_atom_dV is not None):
|
|
875
|
+
cpasync.prefetch_descriptor(tma_atom_dV)
|
|
876
|
+
if const_expr(tma_atom_dK is not None):
|
|
877
|
+
cpasync.prefetch_descriptor(tma_atom_dK)
|
|
878
|
+
|
|
879
|
+
cluster_layout_vmnk = cute.tiled_divide(
|
|
880
|
+
cute.make_layout(self.cluster_shape_mnk),
|
|
881
|
+
(tiled_mma_S.thr_id.shape,),
|
|
882
|
+
)
|
|
883
|
+
|
|
884
|
+
# Alloc
|
|
885
|
+
smem = cutlass.utils.SmemAllocator()
|
|
886
|
+
storage = smem.allocate(self.shared_storage)
|
|
887
|
+
|
|
888
|
+
tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr()
|
|
889
|
+
dQ_cluster_full_mbar_ptr = storage.dQ_cluster_full_mbar_ptr.data_ptr()
|
|
890
|
+
dQ_cluster_empty_mbar_ptr = storage.dQ_cluster_empty_mbar_ptr.data_ptr()
|
|
891
|
+
|
|
892
|
+
if warp_idx == 1:
|
|
893
|
+
cute.arch.mbarrier_init(
|
|
894
|
+
tmem_dealloc_mbar_ptr, cute.arch.WARP_SIZE * len(self.compute_warp_ids)
|
|
895
|
+
)
|
|
896
|
+
if const_expr(self.cluster_reduce_dQ):
|
|
897
|
+
if warp_idx == 4:
|
|
898
|
+
for i in range(self.dQaccum_reduce_stage // 2):
|
|
899
|
+
cute.arch.mbarrier_init(dQ_cluster_full_mbar_ptr + i, 1)
|
|
900
|
+
cute.arch.mbarrier_init(dQ_cluster_empty_mbar_ptr + i, 1)
|
|
901
|
+
|
|
902
|
+
# UMMA producers and AsyncThread consumers
|
|
903
|
+
pipeline_producer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(
|
|
904
|
+
cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])
|
|
905
|
+
)
|
|
906
|
+
# Only 1 thread per warp will signal
|
|
907
|
+
pipeline_consumer_group_MMA_AsyncThread = cutlass.pipeline.CooperativeGroup(
|
|
908
|
+
cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids)
|
|
909
|
+
)
|
|
910
|
+
pipeline_S_P = cutlass.pipeline.PipelineUmmaAsync.create(
|
|
911
|
+
num_stages=1,
|
|
912
|
+
producer_group=pipeline_producer_group_MMA_AsyncThread,
|
|
913
|
+
consumer_group=pipeline_consumer_group_MMA_AsyncThread,
|
|
914
|
+
barrier_storage=storage.S_mbar_ptr.data_ptr(),
|
|
915
|
+
)
|
|
916
|
+
pipeline_dP = cutlass.pipeline.PipelineUmmaAsync.create(
|
|
917
|
+
num_stages=1,
|
|
918
|
+
producer_group=pipeline_producer_group_MMA_AsyncThread,
|
|
919
|
+
consumer_group=pipeline_consumer_group_MMA_AsyncThread,
|
|
920
|
+
barrier_storage=storage.dP_mbar_ptr.data_ptr(),
|
|
921
|
+
)
|
|
922
|
+
pipeline_dKV = cutlass.pipeline.PipelineUmmaAsync.create(
|
|
923
|
+
num_stages=2,
|
|
924
|
+
producer_group=pipeline_producer_group_MMA_AsyncThread,
|
|
925
|
+
consumer_group=pipeline_consumer_group_MMA_AsyncThread,
|
|
926
|
+
barrier_storage=storage.dKV_mbar_ptr.data_ptr(),
|
|
927
|
+
)
|
|
928
|
+
pipeline_consumer_group_MMA_AsyncThread_dQ = cutlass.pipeline.CooperativeGroup(
|
|
929
|
+
cutlass.pipeline.Agent.Thread,
|
|
930
|
+
len(self.reduce_warp_ids),
|
|
931
|
+
) # Compute
|
|
932
|
+
pipeline_dQ = cutlass.pipeline.PipelineUmmaAsync.create(
|
|
933
|
+
num_stages=1,
|
|
934
|
+
producer_group=pipeline_producer_group_MMA_AsyncThread,
|
|
935
|
+
consumer_group=pipeline_consumer_group_MMA_AsyncThread_dQ,
|
|
936
|
+
barrier_storage=storage.dQ_mbar_ptr.data_ptr(),
|
|
937
|
+
)
|
|
938
|
+
|
|
939
|
+
# AsyncThread producers and UMMA consumers
|
|
940
|
+
# Only 1 thread per warp will signal
|
|
941
|
+
pipeline_PdS_producer_group = cutlass.pipeline.CooperativeGroup(
|
|
942
|
+
cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids)
|
|
943
|
+
) # Compute
|
|
944
|
+
pipeline_PdS_consumer_group = cutlass.pipeline.CooperativeGroup(
|
|
945
|
+
cutlass.pipeline.Agent.Thread, len([self.mma_warp_id])
|
|
946
|
+
) # MMA
|
|
947
|
+
pipeline_dS = cutlass.pipeline.PipelineAsyncUmma.create(
|
|
948
|
+
num_stages=1,
|
|
949
|
+
producer_group=pipeline_PdS_producer_group,
|
|
950
|
+
consumer_group=pipeline_PdS_consumer_group,
|
|
951
|
+
barrier_storage=storage.dS_mbar_ptr.data_ptr(),
|
|
952
|
+
)
|
|
953
|
+
|
|
954
|
+
# TMA producer and UMMA consumers
|
|
955
|
+
pipeline_producer_group = cutlass.pipeline.CooperativeGroup(
|
|
956
|
+
cutlass.pipeline.Agent.Thread, len([self.load_warp_id])
|
|
957
|
+
)
|
|
958
|
+
# The arrive count is the number of mcast size
|
|
959
|
+
pipeline_consumer_group = cutlass.pipeline.CooperativeGroup(
|
|
960
|
+
cutlass.pipeline.Agent.Thread, len([self.mma_warp_id]) * self.num_mcast_ctas_b
|
|
961
|
+
)
|
|
962
|
+
pipeline_consumer_group_compute = cutlass.pipeline.CooperativeGroup(
|
|
963
|
+
# cutlass.pipeline.Agent.Thread, len(self.compute_warp_ids) * self.num_mcast_ctas_b
|
|
964
|
+
cutlass.pipeline.Agent.Thread,
|
|
965
|
+
len(self.compute_warp_ids) * 1,
|
|
966
|
+
)
|
|
967
|
+
pipeline_LSE = cutlass.pipeline.PipelineTmaAsync.create(
|
|
968
|
+
barrier_storage=storage.LSE_mbar_ptr.data_ptr(),
|
|
969
|
+
num_stages=self.Q_stage,
|
|
970
|
+
producer_group=pipeline_producer_group,
|
|
971
|
+
consumer_group=pipeline_consumer_group_compute,
|
|
972
|
+
tx_count=self.tma_copy_bytes["LSE"],
|
|
973
|
+
# cta_layout_vmnk=cluster_layout_vmnk,
|
|
974
|
+
# init_wait=False,
|
|
975
|
+
)
|
|
976
|
+
pipeline_dPsum = cutlass.pipeline.PipelineTmaAsync.create(
|
|
977
|
+
barrier_storage=storage.dPsum_mbar_ptr.data_ptr(),
|
|
978
|
+
num_stages=self.dO_stage,
|
|
979
|
+
producer_group=pipeline_producer_group,
|
|
980
|
+
consumer_group=pipeline_consumer_group_compute,
|
|
981
|
+
tx_count=self.tma_copy_bytes["dPsum"],
|
|
982
|
+
# cta_layout_vmnk=cluster_layout_vmnk,
|
|
983
|
+
# init_wait=False,
|
|
984
|
+
)
|
|
985
|
+
pipeline_Q = pipeline.PipelineTmaUmma.create(
|
|
986
|
+
barrier_storage=storage.Q_mbar_ptr.data_ptr(),
|
|
987
|
+
num_stages=self.Q_stage,
|
|
988
|
+
producer_group=pipeline_producer_group,
|
|
989
|
+
consumer_group=pipeline_consumer_group,
|
|
990
|
+
tx_count=self.tma_copy_bytes["Q"],
|
|
991
|
+
cta_layout_vmnk=cluster_layout_vmnk,
|
|
992
|
+
init_wait=False,
|
|
993
|
+
)
|
|
994
|
+
pipeline_dO = pipeline.PipelineTmaUmma.create(
|
|
995
|
+
barrier_storage=storage.dO_mbar_ptr.data_ptr(),
|
|
996
|
+
num_stages=self.dO_stage,
|
|
997
|
+
producer_group=pipeline_producer_group,
|
|
998
|
+
consumer_group=pipeline_consumer_group,
|
|
999
|
+
tx_count=self.tma_copy_bytes["dO"],
|
|
1000
|
+
cta_layout_vmnk=cluster_layout_vmnk,
|
|
1001
|
+
init_wait=True,
|
|
1002
|
+
)
|
|
1003
|
+
|
|
1004
|
+
sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner, dtype=self.q_dtype)
|
|
1005
|
+
sQt = cute.make_tensor(
|
|
1006
|
+
cute.recast_ptr(sQ.iterator, sQt_layout.inner, dtype=self.q_dtype), sQt_layout.outer
|
|
1007
|
+
)
|
|
1008
|
+
sK = storage.sK.get_tensor(sK_layout.outer, swizzle=sK_layout.inner)
|
|
1009
|
+
sKt = cute.make_tensor(cute.recast_ptr(sK.iterator, sKt_layout.inner), sKt_layout.outer)
|
|
1010
|
+
sV = storage.sV.get_tensor(sV_layout.outer, swizzle=sV_layout.inner)
|
|
1011
|
+
sdSt = storage.sdS.get_tensor(sdSt_layout.outer, swizzle=sdSt_layout.inner)
|
|
1012
|
+
sdS = cute.make_tensor(cute.recast_ptr(sdSt.iterator, sdS_layout.inner), sdS_layout.outer)
|
|
1013
|
+
sdO = storage.sdO.get_tensor(
|
|
1014
|
+
sdO_layout.outer, swizzle=sdO_layout.inner, dtype=self.do_dtype
|
|
1015
|
+
)
|
|
1016
|
+
sdOt = cute.make_tensor(
|
|
1017
|
+
cute.recast_ptr(sdO.iterator, sdOt_layout.inner, dtype=self.do_dtype), sdOt_layout.outer
|
|
1018
|
+
)
|
|
1019
|
+
sLSE = storage.sLSE.get_tensor(sLSE_layout)
|
|
1020
|
+
sdPsum = storage.sdPsum.get_tensor(sdPsum_layout)
|
|
1021
|
+
if const_expr(not self.dKV_postprocess):
|
|
1022
|
+
sdV = storage.sdO.get_tensor(
|
|
1023
|
+
sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype
|
|
1024
|
+
)
|
|
1025
|
+
sdK = storage.sQ.get_tensor(
|
|
1026
|
+
sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dk_dtype
|
|
1027
|
+
)
|
|
1028
|
+
else:
|
|
1029
|
+
sdV = storage.sdO.get_tensor(sdKV_layout, dtype=self.dv_dtype)
|
|
1030
|
+
sdK = storage.sQ.get_tensor(sdKV_layout, dtype=self.dk_dtype)
|
|
1031
|
+
|
|
1032
|
+
# Buffer sizing is guaranteed by max(...) in SharedStorage declarations
|
|
1033
|
+
# for both sQ (reused as sdK) and sdO (reused as sdV)
|
|
1034
|
+
|
|
1035
|
+
sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout)
|
|
1036
|
+
|
|
1037
|
+
# TMEM
|
|
1038
|
+
# This is a fake tensor, by right need to retrieve tmem_ptr. But we know that we always
|
|
1039
|
+
# request 512 columns of tmem, so we know that it starts at 0.
|
|
1040
|
+
tmem_ptr = cute.make_ptr(Float32, 0, mem_space=cute.AddressSpace.tmem, assumed_align=16)
|
|
1041
|
+
# S
|
|
1042
|
+
thr_mma_S = tiled_mma_S.get_slice(0)
|
|
1043
|
+
Sacc_shape = thr_mma_S.partition_shape_C(self.mma_tiler_kq[:2]) # (M, N)
|
|
1044
|
+
tStS = thr_mma_S.make_fragment_C(Sacc_shape)
|
|
1045
|
+
# (MMA, MMA_M, MMA_N)
|
|
1046
|
+
tStS = cute.make_tensor(tmem_ptr + self.tmem_S_offset, tStS.layout)
|
|
1047
|
+
# dP
|
|
1048
|
+
thr_mma_dP = tiled_mma_dP.get_slice(0)
|
|
1049
|
+
dPacc_shape = thr_mma_dP.partition_shape_C(self.mma_tiler_vdo[:2])
|
|
1050
|
+
tdPtdP = thr_mma_dP.make_fragment_C(dPacc_shape)
|
|
1051
|
+
tdPtdP = cute.make_tensor(tmem_ptr + self.tmem_dP_offset, tdPtdP.layout)
|
|
1052
|
+
# dV
|
|
1053
|
+
thr_mma_dV = tiled_mma_dV.get_slice(0)
|
|
1054
|
+
dvacc_shape = thr_mma_dV.partition_shape_C(self.mma_tiler_pdo[:2])
|
|
1055
|
+
tdVtdV = thr_mma_dV.make_fragment_C(dvacc_shape)
|
|
1056
|
+
tdVtdV = cute.make_tensor(tmem_ptr + self.tmem_dV_offset, tdVtdV.layout)
|
|
1057
|
+
tP = cute.make_tensor(
|
|
1058
|
+
cute.recast_ptr(tmem_ptr + self.tmem_P_offset, dtype=self.do_dtype), tP_layout.outer
|
|
1059
|
+
)
|
|
1060
|
+
# dK
|
|
1061
|
+
thr_mma_dK = tiled_mma_dK.get_slice(0)
|
|
1062
|
+
dkacc_shape = thr_mma_dK.partition_shape_C(self.mma_tiler_dsq[:2])
|
|
1063
|
+
tdKtdK = thr_mma_dK.make_fragment_C(dkacc_shape)
|
|
1064
|
+
tdKtdK = cute.make_tensor(tmem_ptr + self.tmem_dK_offset, tdKtdK.layout)
|
|
1065
|
+
tdS = cute.make_tensor(
|
|
1066
|
+
cute.recast_ptr(tmem_ptr + self.tmem_dS_offset, dtype=self.ds_dtype), tdS_layout.outer
|
|
1067
|
+
)
|
|
1068
|
+
# dQ
|
|
1069
|
+
thr_mma_dQ = tiled_mma_dQ.get_slice(0)
|
|
1070
|
+
dQacc_shape = thr_mma_dQ.partition_shape_C(self.mma_tiler_dsk[:2])
|
|
1071
|
+
tdQtdQ = thr_mma_dQ.make_fragment_C(dQacc_shape)
|
|
1072
|
+
tdQtdQ = cute.make_tensor(tmem_ptr + self.tmem_dQ_offset, tdQtdQ.layout)
|
|
1073
|
+
|
|
1074
|
+
block_info = BlockInfo(
|
|
1075
|
+
self.tile_m,
|
|
1076
|
+
# self.tile_n,
|
|
1077
|
+
self.tile_n * self.cluster_shape_mnk[0], # careful, this case is not very well-tested
|
|
1078
|
+
self.is_causal,
|
|
1079
|
+
self.is_local,
|
|
1080
|
+
False, # is_split_kv
|
|
1081
|
+
window_size_left,
|
|
1082
|
+
window_size_right,
|
|
1083
|
+
qhead_per_kvhead_packgqa=1,
|
|
1084
|
+
)
|
|
1085
|
+
SeqlenInfoCls = partial(
|
|
1086
|
+
SeqlenInfoQK.create,
|
|
1087
|
+
seqlen_q_static=mQ.shape[0],
|
|
1088
|
+
seqlen_k_static=mK.shape[0],
|
|
1089
|
+
mCuSeqlensQ=mCuSeqlensQ,
|
|
1090
|
+
mCuSeqlensK=mCuSeqlensK,
|
|
1091
|
+
mSeqUsedQ=mSeqUsedQ,
|
|
1092
|
+
mSeqUsedK=mSeqUsedK,
|
|
1093
|
+
tile_m=self.tile_m,
|
|
1094
|
+
tile_n=self.tile_n,
|
|
1095
|
+
)
|
|
1096
|
+
TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)
|
|
1097
|
+
|
|
1098
|
+
AttentionMaskCls = partial(
|
|
1099
|
+
AttentionMask,
|
|
1100
|
+
self.tile_m,
|
|
1101
|
+
self.tile_n,
|
|
1102
|
+
swap_AB=True,
|
|
1103
|
+
window_size_left=window_size_left,
|
|
1104
|
+
window_size_right=window_size_right,
|
|
1105
|
+
)
|
|
1106
|
+
|
|
1107
|
+
# EMPTY
|
|
1108
|
+
# (15)
|
|
1109
|
+
if warp_idx == self.empty_warp_id:
|
|
1110
|
+
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
|
|
1111
|
+
|
|
1112
|
+
# EPI
|
|
1113
|
+
# (14)
|
|
1114
|
+
if warp_idx == self.epi_warp_id:
|
|
1115
|
+
# currently no-op, could use for tma store/reduce
|
|
1116
|
+
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
|
|
1117
|
+
|
|
1118
|
+
# LOAD
|
|
1119
|
+
# (13)
|
|
1120
|
+
if warp_idx == self.load_warp_id:
|
|
1121
|
+
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
|
|
1122
|
+
self.load(
|
|
1123
|
+
thr_mma_S,
|
|
1124
|
+
thr_mma_dP,
|
|
1125
|
+
thr_mma_dV,
|
|
1126
|
+
mQ,
|
|
1127
|
+
mK,
|
|
1128
|
+
mV,
|
|
1129
|
+
mLSE,
|
|
1130
|
+
mdPsum,
|
|
1131
|
+
mdO,
|
|
1132
|
+
sQ,
|
|
1133
|
+
sK,
|
|
1134
|
+
sV,
|
|
1135
|
+
sLSE,
|
|
1136
|
+
sdPsum,
|
|
1137
|
+
sdO,
|
|
1138
|
+
tma_atom_Q,
|
|
1139
|
+
tma_atom_K,
|
|
1140
|
+
tma_atom_V,
|
|
1141
|
+
tma_atom_dO,
|
|
1142
|
+
pipeline_Q,
|
|
1143
|
+
pipeline_dO,
|
|
1144
|
+
pipeline_LSE,
|
|
1145
|
+
pipeline_dPsum,
|
|
1146
|
+
cluster_layout_vmnk,
|
|
1147
|
+
block_info,
|
|
1148
|
+
SeqlenInfoCls,
|
|
1149
|
+
TileSchedulerCls,
|
|
1150
|
+
blocksparse_tensors,
|
|
1151
|
+
should_load_Q=True,
|
|
1152
|
+
should_load_dO=True,
|
|
1153
|
+
)
|
|
1154
|
+
|
|
1155
|
+
# MMA
|
|
1156
|
+
# (12)
|
|
1157
|
+
if warp_idx == self.mma_warp_id:
|
|
1158
|
+
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
|
|
1159
|
+
|
|
1160
|
+
# Alloc tmem buffer
|
|
1161
|
+
tmem_alloc_cols = Int32(self.tmem_alloc_cols)
|
|
1162
|
+
cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf)
|
|
1163
|
+
cute.arch.sync_warp()
|
|
1164
|
+
|
|
1165
|
+
self.mma(
|
|
1166
|
+
tiled_mma_S,
|
|
1167
|
+
tiled_mma_dP,
|
|
1168
|
+
tiled_mma_dV,
|
|
1169
|
+
tiled_mma_dK,
|
|
1170
|
+
tiled_mma_dQ,
|
|
1171
|
+
sQ,
|
|
1172
|
+
sQt,
|
|
1173
|
+
sK,
|
|
1174
|
+
sV,
|
|
1175
|
+
sdO,
|
|
1176
|
+
sdOt,
|
|
1177
|
+
sdSt,
|
|
1178
|
+
sdS,
|
|
1179
|
+
sKt,
|
|
1180
|
+
tP,
|
|
1181
|
+
tdS,
|
|
1182
|
+
tStS,
|
|
1183
|
+
tdPtdP,
|
|
1184
|
+
tdVtdV,
|
|
1185
|
+
tdKtdK,
|
|
1186
|
+
tdQtdQ,
|
|
1187
|
+
pipeline_Q.make_consumer(),
|
|
1188
|
+
pipeline_dO,
|
|
1189
|
+
pipeline_S_P,
|
|
1190
|
+
pipeline_dS,
|
|
1191
|
+
pipeline_dKV,
|
|
1192
|
+
pipeline_dP,
|
|
1193
|
+
pipeline_dQ,
|
|
1194
|
+
block_info,
|
|
1195
|
+
SeqlenInfoCls,
|
|
1196
|
+
TileSchedulerCls,
|
|
1197
|
+
blocksparse_tensors,
|
|
1198
|
+
)
|
|
1199
|
+
cute.arch.relinquish_tmem_alloc_permit()
|
|
1200
|
+
tmem_ptr = cute.arch.retrieve_tmem_ptr(
|
|
1201
|
+
Float32, alignment=16, ptr_to_buffer_holding_addr=storage.tmem_holding_buf
|
|
1202
|
+
)
|
|
1203
|
+
|
|
1204
|
+
cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0)
|
|
1205
|
+
tmem_alloc_cols = Int32(self.tmem_alloc_cols)
|
|
1206
|
+
cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols, is_two_cta=False)
|
|
1207
|
+
|
|
1208
|
+
# Compute
|
|
1209
|
+
# (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps
|
|
1210
|
+
if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]:
|
|
1211
|
+
cute.arch.warpgroup_reg_alloc(self.num_regs_compute) # 8 warps
|
|
1212
|
+
self.compute_loop(
|
|
1213
|
+
thr_mma_S,
|
|
1214
|
+
thr_mma_dP,
|
|
1215
|
+
thr_mma_dV,
|
|
1216
|
+
thr_mma_dK,
|
|
1217
|
+
tStS,
|
|
1218
|
+
sLSE,
|
|
1219
|
+
sdPsum,
|
|
1220
|
+
tdVtdV,
|
|
1221
|
+
tdKtdK,
|
|
1222
|
+
mdV,
|
|
1223
|
+
mdK,
|
|
1224
|
+
sdS,
|
|
1225
|
+
tdPtdP,
|
|
1226
|
+
pipeline_LSE,
|
|
1227
|
+
pipeline_dPsum,
|
|
1228
|
+
pipeline_S_P,
|
|
1229
|
+
pipeline_dS,
|
|
1230
|
+
pipeline_dKV,
|
|
1231
|
+
pipeline_dP,
|
|
1232
|
+
softmax_scale,
|
|
1233
|
+
softmax_scale_log2,
|
|
1234
|
+
block_info,
|
|
1235
|
+
SeqlenInfoCls,
|
|
1236
|
+
AttentionMaskCls,
|
|
1237
|
+
TileSchedulerCls,
|
|
1238
|
+
sdV,
|
|
1239
|
+
sdK,
|
|
1240
|
+
mdV_tma_tensor,
|
|
1241
|
+
mdK_tma_tensor,
|
|
1242
|
+
tma_atom_dV,
|
|
1243
|
+
tma_atom_dK,
|
|
1244
|
+
tiled_copy_r2s_dKV,
|
|
1245
|
+
mdK_semaphore,
|
|
1246
|
+
mdV_semaphore,
|
|
1247
|
+
aux_tensors,
|
|
1248
|
+
fastdiv_mods,
|
|
1249
|
+
blocksparse_tensors,
|
|
1250
|
+
)
|
|
1251
|
+
cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr)
|
|
1252
|
+
|
|
1253
|
+
# Reduce
|
|
1254
|
+
# (0, 1, 2, 3) - dQ
|
|
1255
|
+
if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]:
|
|
1256
|
+
cute.arch.warpgroup_reg_alloc(self.num_regs_reduce)
|
|
1257
|
+
self.dQacc_reduce(
|
|
1258
|
+
mdQaccum,
|
|
1259
|
+
sdQaccum,
|
|
1260
|
+
thr_mma_dQ,
|
|
1261
|
+
tdQtdQ,
|
|
1262
|
+
pipeline_dQ,
|
|
1263
|
+
block_info,
|
|
1264
|
+
SeqlenInfoCls,
|
|
1265
|
+
TileSchedulerCls,
|
|
1266
|
+
mdQ_semaphore,
|
|
1267
|
+
blocksparse_tensors,
|
|
1268
|
+
)
|
|
1269
|
+
|
|
1270
|
+
return
|
|
1271
|
+
|
|
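The dispatch above is fully warp-specialized: warps 0-3 reduce the dQ accumulator, warps 4-11 run the compute loop, warp 12 issues the tensor-core MMAs, warp 13 drives the TMA loads, warp 14 is reserved for the epilogue and warp 15 only deallocates its registers. A minimal standalone sketch of that role mapping, in plain Python with made-up role names (a reading aid only, not the mslk API):

```python
# Illustrative sketch of the warp-role dispatch above (hypothetical names, not the mslk API).
WARP_ROLES = {
    **{w: "reduce_dQ" for w in range(0, 4)},   # warps 0-3: dQ accumulator reduction
    **{w: "compute" for w in range(4, 12)},    # warps 4-11: softmax / dS compute loop
    12: "mma",                                 # warp 12: issues the tensor-core MMAs
    13: "load",                                # warp 13: TMA producer for Q/K/V/dO/LSE/dPsum
    14: "epilogue",                            # warp 14: reserved, currently a no-op
    15: "empty",                               # warp 15: only deallocates registers
}

def role_for_thread(thread_idx: int, warp_size: int = 32) -> str:
    """Map a thread index within the CTA to the role of its warp."""
    return WARP_ROLES[thread_idx // warp_size]

assert role_for_thread(0) == "reduce_dQ"
assert role_for_thread(5 * 32) == "compute"
assert role_for_thread(13 * 32 + 7) == "load"
```

Producer-side warps call warpgroup_reg_dealloc so the freed register budget can be handed to the eight compute warps, which call warpgroup_reg_alloc.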
1272
|
+
@cute.jit
|
|
1273
|
+
def load(
|
|
1274
|
+
self,
|
|
1275
|
+
thr_mma_S: cute.core.ThrMma,
|
|
1276
|
+
thr_mma_dP: cute.core.ThrMma,
|
|
1277
|
+
thr_mma_dV: cute.core.ThrMma,
|
|
1278
|
+
mQ: cute.Tensor,
|
|
1279
|
+
mK: cute.Tensor,
|
|
1280
|
+
mV: cute.Tensor,
|
|
1281
|
+
mLSE: cute.Tensor,
|
|
1282
|
+
mdPsum: cute.Tensor,
|
|
1283
|
+
mdO: cute.Tensor,
|
|
1284
|
+
sQ: cute.Tensor,
|
|
1285
|
+
sK: cute.Tensor,
|
|
1286
|
+
sV: cute.Tensor,
|
|
1287
|
+
sLSE: cute.Tensor,
|
|
1288
|
+
sdPsum: cute.Tensor,
|
|
1289
|
+
sdO: cute.Tensor,
|
|
1290
|
+
tma_atom_Q: cute.CopyAtom,
|
|
1291
|
+
tma_atom_K: cute.CopyAtom,
|
|
1292
|
+
tma_atom_V: cute.CopyAtom,
|
|
1293
|
+
tma_atom_dO: cute.CopyAtom,
|
|
1294
|
+
pipeline_Q: PipelineAsync,
|
|
1295
|
+
pipeline_dO: PipelineAsync,
|
|
1296
|
+
pipeline_LSE: PipelineAsync,
|
|
1297
|
+
pipeline_dPsum: PipelineAsync,
|
|
1298
|
+
cluster_layout_vmnk: cute.Layout,
|
|
1299
|
+
block_info: BlockInfo,
|
|
1300
|
+
SeqlenInfoCls: Callable,
|
|
1301
|
+
TileSchedulerCls: Callable,
|
|
1302
|
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
1303
|
+
should_load_Q: bool = True,
|
|
1304
|
+
should_load_dO: bool = True,
|
|
1305
|
+
):
|
|
1306
|
+
producer_state_Q_LSE = cutlass.pipeline.make_pipeline_state(
|
|
1307
|
+
cutlass.pipeline.PipelineUserType.Producer, self.Q_stage
|
|
1308
|
+
)
|
|
1309
|
+
producer_state_dO_dPsum = cutlass.pipeline.make_pipeline_state(
|
|
1310
|
+
cutlass.pipeline.PipelineUserType.Producer, self.dO_stage
|
|
1311
|
+
)
|
|
1312
|
+
|
|
1313
|
+
# Compute multicast mask for Q & dO buffer full
|
|
1314
|
+
cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
|
|
1315
|
+
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
|
|
1316
|
+
q_do_mcast_mask = None
|
|
1317
|
+
if const_expr(self.is_q_do_mcast):
|
|
1318
|
+
q_do_mcast_mask = cpasync.create_tma_multicast_mask(
|
|
1319
|
+
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
|
|
1320
|
+
)
|
|
1321
|
+
|
|
1322
|
+
tile_scheduler = TileSchedulerCls()
|
|
1323
|
+
work_tile = tile_scheduler.initial_work_tile_info()
|
|
1324
|
+
while work_tile.is_valid_tile:
|
|
1325
|
+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
|
1326
|
+
seqlen = SeqlenInfoCls(batch_idx)
|
|
1327
|
+
m_block_min, m_block_max = block_info.get_m_block_min_max(
|
|
1328
|
+
seqlen, n_block // self.cluster_shape_mnk[0]
|
|
1329
|
+
)
|
|
1330
|
+
head_idx_kv = head_idx // self.qhead_per_kvhead
|
|
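# With grouped-query attention, qhead_per_kvhead query heads share a single KV head,
# so the KV head index is obtained by integer division. For example, 32 query heads
# over 8 KV heads gives qhead_per_kvhead == 4, and query head 13 reads KV head 13 // 4 == 3.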
1331
|
+
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
|
|
1332
|
+
mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
|
|
1333
|
+
mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
|
|
1334
|
+
if const_expr(not seqlen.has_cu_seqlens_q):
|
|
1335
|
+
mdO_cur = mdO[None, None, head_idx, batch_idx]
|
|
1336
|
+
else:
|
|
1337
|
+
mdO_cur = cute.domain_offset((0, seqlen.offset_q), mdO[None, None, head_idx])
|
|
1338
|
+
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[None, head_idx]
|
|
1339
|
+
mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[
|
|
1340
|
+
None, head_idx
|
|
1341
|
+
]
|
|
1342
|
+
|
|
1343
|
+
gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0))
|
|
1344
|
+
tSgK = thr_mma_S.partition_A(gK)
|
|
1345
|
+
gV = cute.local_tile(mV_cur, cute.select(self.mma_tiler_vdo, mode=[0, 2]), (n_block, 0))
|
|
1346
|
+
tdPgV = thr_mma_dP.partition_A(gV)
|
|
1347
|
+
gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0))
|
|
1348
|
+
tSgQ = thr_mma_S.partition_B(gQ)
|
|
1349
|
+
gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
|
|
1350
|
+
gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))
|
|
1351
|
+
gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None))
|
|
1352
|
+
tdPgdO = thr_mma_dV.partition_B(gdO)
|
|
1353
|
+
|
|
1354
|
+
load_K, _, _ = copy_utils.tma_get_copy_fn(
|
|
1355
|
+
tma_atom_K, 0, cute.make_layout(1), tSgK, sK, single_stage=True
|
|
1356
|
+
)
|
|
1357
|
+
load_V, _, _ = copy_utils.tma_get_copy_fn(
|
|
1358
|
+
tma_atom_V,
|
|
1359
|
+
0,
|
|
1360
|
+
cute.make_layout(1),
|
|
1361
|
+
tdPgV,
|
|
1362
|
+
sV,
|
|
1363
|
+
single_stage=True,
|
|
1364
|
+
)
|
|
1365
|
+
b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
|
|
1366
|
+
load_Q, _, _ = copy_utils.tma_get_copy_fn(
|
|
1367
|
+
tma_atom_Q,
|
|
1368
|
+
cta_coord=block_in_cluster_coord_vmnk[1],
|
|
1369
|
+
cta_layout=b_cta_layout,
|
|
1370
|
+
src_tensor=tSgQ,
|
|
1371
|
+
dst_tensor=sQ,
|
|
1372
|
+
mcast_mask=q_do_mcast_mask,
|
|
1373
|
+
)
|
|
1374
|
+
load_Q = copy_utils.tma_producer_copy_fn(load_Q, pipeline_Q)
|
|
1375
|
+
load_dO, _, _ = copy_utils.tma_get_copy_fn(
|
|
1376
|
+
tma_atom_dO,
|
|
1377
|
+
cta_coord=block_in_cluster_coord_vmnk[1],
|
|
1378
|
+
cta_layout=b_cta_layout,
|
|
1379
|
+
src_tensor=tdPgdO,
|
|
1380
|
+
dst_tensor=sdO,
|
|
1381
|
+
mcast_mask=q_do_mcast_mask,
|
|
1382
|
+
)
|
|
1383
|
+
load_dO = copy_utils.tma_producer_copy_fn(load_dO, pipeline_dO)
|
|
1384
|
+
copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SOp(), Float32)
|
|
1385
|
+
copy_stats = partial(cute.copy, copy_atom_stats)
|
|
1386
|
+
# copy_atom_stats = cute.make_copy_atom(cpasync.CopyBulkG2SMulticastOp(), Float32)
|
|
1387
|
+
# sLSE = cute.logical_divide(sLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
|
|
1388
|
+
# gLSE = cute.logical_divide(gLSE, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
|
|
1389
|
+
# sdPsum = cute.logical_divide(sdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
|
|
1390
|
+
# gdPsum = cute.logical_divide(gdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None]
|
|
1391
|
+
# copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask)
|
|
1392
|
+
|
|
1393
|
+
# some tiles might be empty due to block sparsity
|
|
1394
|
+
if const_expr(self.use_block_sparsity):
|
|
1395
|
+
total_m_block_cnt = get_total_q_block_count_bwd(
|
|
1396
|
+
blocksparse_tensors,
|
|
1397
|
+
batch_idx,
|
|
1398
|
+
head_idx,
|
|
1399
|
+
n_block,
|
|
1400
|
+
subtile_factor=self.subtile_factor,
|
|
1401
|
+
m_block_max=m_block_max,
|
|
1402
|
+
)
|
|
1403
|
+
process_tile = total_m_block_cnt > Int32(0)
|
|
1404
|
+
else:
|
|
1405
|
+
process_tile = (
|
|
1406
|
+
const_expr(not self.is_local and not self.is_varlen_q)
|
|
1407
|
+
or m_block_min < m_block_max
|
|
1408
|
+
)
|
|
1409
|
+
|
|
1410
|
+
if process_tile:
|
|
1411
|
+
if const_expr(self.use_block_sparsity):
|
|
1412
|
+
producer_state_Q_LSE, producer_state_dO_dPsum = (
|
|
1413
|
+
produce_block_sparse_q_loads_bwd_sm100(
|
|
1414
|
+
blocksparse_tensors,
|
|
1415
|
+
batch_idx,
|
|
1416
|
+
head_idx,
|
|
1417
|
+
n_block,
|
|
1418
|
+
producer_state_Q_LSE,
|
|
1419
|
+
producer_state_dO_dPsum,
|
|
1420
|
+
pipeline_Q,
|
|
1421
|
+
pipeline_LSE,
|
|
1422
|
+
pipeline_dO,
|
|
1423
|
+
pipeline_dPsum,
|
|
1424
|
+
load_K,
|
|
1425
|
+
load_V,
|
|
1426
|
+
load_Q,
|
|
1427
|
+
load_dO,
|
|
1428
|
+
copy_stats,
|
|
1429
|
+
gLSE,
|
|
1430
|
+
sLSE,
|
|
1431
|
+
gdPsum,
|
|
1432
|
+
sdPsum,
|
|
1433
|
+
self.tma_copy_bytes["K"],
|
|
1434
|
+
self.tma_copy_bytes["V"],
|
|
1435
|
+
should_load_Q=should_load_Q,
|
|
1436
|
+
should_load_dO=should_load_dO,
|
|
1437
|
+
subtile_factor=self.subtile_factor,
|
|
1438
|
+
m_block_max=m_block_max,
|
|
1439
|
+
)
|
|
1440
|
+
)
|
|
1441
|
+
else:
|
|
1442
|
+
first_m_block = m_block_min
|
|
1443
|
+
|
|
1444
|
+
# First iteration: load K together with Q & LSE, then V together with dO & dPsum
|
|
1445
|
+
if const_expr(should_load_Q):
|
|
1446
|
+
pipeline_Q.producer_acquire(
|
|
1447
|
+
producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"]
|
|
1448
|
+
)
|
|
1449
|
+
load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
|
|
1450
|
+
load_Q(first_m_block, producer_state=producer_state_Q_LSE)
|
|
1451
|
+
pipeline_Q.producer_commit(producer_state_Q_LSE)
|
|
1452
|
+
pipeline_LSE.producer_acquire(producer_state_Q_LSE)
|
|
1453
|
+
with cute.arch.elect_one():
|
|
1454
|
+
copy_stats(
|
|
1455
|
+
gLSE[None, first_m_block],
|
|
1456
|
+
sLSE[None, producer_state_Q_LSE.index],
|
|
1457
|
+
mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
|
|
1458
|
+
)
|
|
1459
|
+
producer_state_Q_LSE.advance()
|
|
1460
|
+
if const_expr(should_load_dO):
|
|
1461
|
+
pipeline_dO.producer_acquire(
|
|
1462
|
+
producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"]
|
|
1463
|
+
)
|
|
1464
|
+
load_V(
|
|
1465
|
+
tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)
|
|
1466
|
+
)
|
|
1467
|
+
load_dO(first_m_block, producer_state=producer_state_dO_dPsum)
|
|
1468
|
+
pipeline_dO.producer_commit(producer_state_dO_dPsum)
|
|
1469
|
+
pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
|
|
1470
|
+
with cute.arch.elect_one():
|
|
1471
|
+
copy_stats(
|
|
1472
|
+
gdPsum[None, first_m_block],
|
|
1473
|
+
sdPsum[None, producer_state_dO_dPsum.index],
|
|
1474
|
+
mbar_ptr=pipeline_dPsum.producer_get_barrier(
|
|
1475
|
+
producer_state_dO_dPsum
|
|
1476
|
+
),
|
|
1477
|
+
)
|
|
1478
|
+
producer_state_dO_dPsum.advance()
|
|
1479
|
+
|
|
1480
|
+
# Dense path: iterate from m_block_min+1 to m_block_max
|
|
1481
|
+
for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1):
|
|
1482
|
+
if const_expr(should_load_Q):
|
|
1483
|
+
pipeline_Q.producer_acquire(producer_state_Q_LSE)
|
|
1484
|
+
load_Q(m_block, producer_state=producer_state_Q_LSE)
|
|
1485
|
+
pipeline_Q.producer_commit(producer_state_Q_LSE)
|
|
1486
|
+
pipeline_LSE.producer_acquire(producer_state_Q_LSE)
|
|
1487
|
+
with cute.arch.elect_one():
|
|
1488
|
+
copy_stats(
|
|
1489
|
+
gLSE[None, m_block],
|
|
1490
|
+
sLSE[None, producer_state_Q_LSE.index],
|
|
1491
|
+
mbar_ptr=pipeline_LSE.producer_get_barrier(
|
|
1492
|
+
producer_state_Q_LSE
|
|
1493
|
+
),
|
|
1494
|
+
)
|
|
1495
|
+
producer_state_Q_LSE.advance()
|
|
1496
|
+
if const_expr(should_load_dO):
|
|
1497
|
+
pipeline_dO.producer_acquire(producer_state_dO_dPsum)
|
|
1498
|
+
load_dO(m_block, producer_state=producer_state_dO_dPsum)
|
|
1499
|
+
pipeline_dO.producer_commit(producer_state_dO_dPsum)
|
|
1500
|
+
pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
|
|
1501
|
+
with cute.arch.elect_one():
|
|
1502
|
+
copy_stats(
|
|
1503
|
+
gdPsum[None, m_block],
|
|
1504
|
+
sdPsum[None, producer_state_dO_dPsum.index],
|
|
1505
|
+
mbar_ptr=pipeline_dPsum.producer_get_barrier(
|
|
1506
|
+
producer_state_dO_dPsum
|
|
1507
|
+
),
|
|
1508
|
+
)
|
|
1509
|
+
producer_state_dO_dPsum.advance()
|
|
1510
|
+
|
|
1511
|
+
if const_expr(should_load_Q):
|
|
1512
|
+
pipeline_Q.producer_tail(
|
|
1513
|
+
producer_state_Q_LSE.clone()
|
|
1514
|
+
) # will hang if we don't clone
|
|
1515
|
+
pipeline_LSE.producer_tail(producer_state_Q_LSE)
|
|
1516
|
+
if const_expr(should_load_dO):
|
|
1517
|
+
pipeline_dO.producer_tail(producer_state_dO_dPsum.clone())
|
|
1518
|
+
pipeline_dPsum.producer_tail(producer_state_dO_dPsum)
|
|
1519
|
+
|
|
1520
|
+
tile_scheduler.prefetch_next_work()
|
|
1521
|
+
tile_scheduler.advance_to_next_work()
|
|
1522
|
+
work_tile = tile_scheduler.get_current_work()
|
|
1523
|
+
|
|
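The loop above is the canonical multi-stage TMA producer pattern: acquire a free stage, issue the asynchronous copy against that stage's barrier, commit the stage, then advance the ring index and phase; on the first iteration the one-time K and V copies are folded into the same barriers via extra_tx_count. A toy CPU-side sketch of the acquire/commit ring discipline (illustrative classes only, not the cutlass.pipeline API):

```python
from collections import deque

class ToyPipeline:
    """Toy stand-in for a producer/consumer pipeline with `num_stages` smem slots."""
    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.free = deque(range(num_stages))   # slots the producer may fill
        self.full = deque()                    # slots the consumer may drain

    # producer side
    def producer_acquire(self) -> int:
        assert self.free, "would block: all stages in flight"
        return self.free[0]
    def producer_commit(self) -> None:
        self.full.append(self.free.popleft())

    # consumer side
    def consumer_wait(self) -> int:
        assert self.full, "would block: nothing produced yet"
        return self.full[0]
    def consumer_release(self) -> None:
        self.free.append(self.full.popleft())

pipe = ToyPipeline(num_stages=2)
staged = []
for m_block in range(4):                 # producer loop, mirrors the m_block loop above
    slot = pipe.producer_acquire()       # wait for a free smem stage
    staged.append((m_block, slot))       # "issue the TMA copy" into that stage
    pipe.producer_commit()               # mark the stage full for the consumer warps
    if len(pipe.full) == pipe.num_stages:
        pipe.consumer_wait(); pipe.consumer_release()  # consumer drains the oldest stage

print(staged)  # the slot index alternates 0, 1, 0, 1 as the ring wraps
```

The producer_tail calls at the end play the role of the final drain: the producer waits until the consumers have released every in-flight stage before the warp exits.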
1524
|
+
@cute.jit
|
|
1525
|
+
def mma(
|
|
1526
|
+
self,
|
|
1527
|
+
tiled_mma_S: cute.TiledMma,
|
|
1528
|
+
tiled_mma_dP: cute.TiledMma,
|
|
1529
|
+
tiled_mma_dV: cute.TiledMma,
|
|
1530
|
+
tiled_mma_dK: cute.TiledMma,
|
|
1531
|
+
tiled_mma_dQ: cute.TiledMma,
|
|
1532
|
+
sQ: cute.Tensor,
|
|
1533
|
+
sQt: cute.Tensor,
|
|
1534
|
+
sK: cute.Tensor,
|
|
1535
|
+
sV: cute.Tensor,
|
|
1536
|
+
sdO: cute.Tensor,
|
|
1537
|
+
sdOt: cute.Tensor,
|
|
1538
|
+
sdSt: cute.Tensor,
|
|
1539
|
+
sdS: cute.Tensor,
|
|
1540
|
+
sKt: cute.Tensor,
|
|
1541
|
+
tP: cute.Tensor,
|
|
1542
|
+
tdS: cute.Tensor,
|
|
1543
|
+
tStS: cute.Tensor,
|
|
1544
|
+
tdPtdP: cute.Tensor,
|
|
1545
|
+
tdVtdV: cute.Tensor,
|
|
1546
|
+
tdKtdK: cute.Tensor,
|
|
1547
|
+
tdQtdQ: cute.Tensor,
|
|
1548
|
+
pipeline_Q_consumer: PipelineConsumer,
|
|
1549
|
+
pipeline_dO: PipelineAsync,
|
|
1550
|
+
pipeline_S_P: PipelineAsync,
|
|
1551
|
+
pipeline_dS: PipelineAsync,
|
|
1552
|
+
pipeline_dKV: PipelineAsync,
|
|
1553
|
+
pipeline_dP: PipelineAsync,
|
|
1554
|
+
pipeline_dQ: PipelineAsync,
|
|
1555
|
+
block_info: BlockInfo,
|
|
1556
|
+
SeqlenInfoCls: Callable,
|
|
1557
|
+
TileSchedulerCls: Callable,
|
|
1558
|
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
1559
|
+
):
|
|
1560
|
+
# [2025-10-21] For reasons I don't understand, putting these partitioning steps in the main
|
|
1561
|
+
# kernel (before warp specialization) is a lot slower than putting them here.
|
|
1562
|
+
# Partition smem / tmem tensors
|
|
1563
|
+
# S = K @ Q.T
|
|
1564
|
+
tSrK = tiled_mma_S.make_fragment_A(sK)
|
|
1565
|
+
tSrQ = tiled_mma_S.make_fragment_B(sQ)
|
|
1566
|
+
# dP = V @ dO.T
|
|
1567
|
+
tdPrV = tiled_mma_dP.make_fragment_A(sV)
|
|
1568
|
+
tdPrdOt = tiled_mma_dP.make_fragment_B(sdOt)
|
|
1569
|
+
# dK = dS.T @ Q
|
|
1570
|
+
if const_expr(self.use_smem_dS_for_mma_dK):
|
|
1571
|
+
tdKrdS = tiled_mma_dK.make_fragment_A(sdSt)
|
|
1572
|
+
else:
|
|
1573
|
+
tdKrdS = tiled_mma_dK.make_fragment_A(tdS)
|
|
1574
|
+
tdKrQ = tiled_mma_dK.make_fragment_B(sQt)
|
|
1575
|
+
# dQ = dS @ K
|
|
1576
|
+
tdQrdS = tiled_mma_dQ.make_fragment_A(sdS)
|
|
1577
|
+
tdQrK = tiled_mma_dQ.make_fragment_B(sKt)
|
|
1578
|
+
# dV = P @ dO.T
|
|
1579
|
+
tdVrdO = tiled_mma_dV.make_fragment_B(sdO)
|
|
1580
|
+
tdVrP = tiled_mma_dV.make_fragment_A(tP)
|
|
1581
|
+
|
|
1582
|
+
# mma_qk_fn = partial(gemm_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, zero_init=True)
|
|
1583
|
+
mma_qk_fn = partial(
|
|
1584
|
+
gemm_ptx_w_idx, tiled_mma_S, tStS, tSrK, tSrQ, sA=sK, sB=sQ, zero_init=True
|
|
1585
|
+
)
|
|
1586
|
+
# mma_dov_fn = partial(gemm_w_idx, tiled_mma_dP, tdPtdP, tdPrV, tdPrdOt, zero_init=True)
|
|
1587
|
+
mma_dov_fn = partial(
|
|
1588
|
+
gemm_ptx_w_idx,
|
|
1589
|
+
tiled_mma_dP,
|
|
1590
|
+
tdPtdP,
|
|
1591
|
+
tdPrV,
|
|
1592
|
+
tdPrdOt,
|
|
1593
|
+
sA=sV,
|
|
1594
|
+
sB=sdOt,
|
|
1595
|
+
zero_init=True,
|
|
1596
|
+
)
|
|
1597
|
+
# mma_pdo_fn = partial(gemm_w_idx, tiled_mma_dV, tdVtdV, tdVrP, tdVrdO)
|
|
1598
|
+
mma_pdo_fn = partial(
|
|
1599
|
+
gemm_ptx_w_idx,
|
|
1600
|
+
tiled_mma_dV,
|
|
1601
|
+
tdVtdV,
|
|
1602
|
+
tdVrP,
|
|
1603
|
+
tdVrdO,
|
|
1604
|
+
sA=None,
|
|
1605
|
+
sB=sdO,
|
|
1606
|
+
tA_addr=self.tmem_P_offset,
|
|
1607
|
+
)
|
|
1608
|
+
mma_dsk_fn = partial(gemm_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, zero_init=True)
|
|
1609
|
+
# mma_dsk_fn = partial(
|
|
1610
|
+
# gemm_ptx_w_idx, tiled_mma_dQ, tdQtdQ, tdQrdS, tdQrK, sA=sdS, sB=sKt, zero_init=True
|
|
1611
|
+
# )
|
|
1612
|
+
if const_expr(self.use_smem_dS_for_mma_dK):
|
|
1613
|
+
mma_dsq_fn = partial(gemm_w_idx, tiled_mma_dK, tdKtdK, tdKrdS, tdKrQ)
|
|
1614
|
+
else:
|
|
1615
|
+
# Need to explicitly pass in tA_addr for correctness
|
|
1616
|
+
mma_dsq_fn = partial(
|
|
1617
|
+
gemm_ptx_w_idx,
|
|
1618
|
+
tiled_mma_dK,
|
|
1619
|
+
tdKtdK,
|
|
1620
|
+
tdKrdS,
|
|
1621
|
+
tdKrQ,
|
|
1622
|
+
sA=None,
|
|
1623
|
+
sB=sQt,
|
|
1624
|
+
tA_addr=self.tmem_dS_offset,
|
|
1625
|
+
)
|
|
1626
|
+
|
|
1627
|
+
consumer_state_dO = cutlass.pipeline.make_pipeline_state(
|
|
1628
|
+
cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
|
|
1629
|
+
)
|
|
1630
|
+
producer_phase_acc = Int32(1) # For S & P, dP, dQ
|
|
1631
|
+
consumer_state_dS = cutlass.pipeline.make_pipeline_state(
|
|
1632
|
+
cutlass.pipeline.PipelineUserType.Consumer, 1
|
|
1633
|
+
)
|
|
1634
|
+
# producer_state_dKV = cutlass.pipeline.make_pipeline_state(
|
|
1635
|
+
# cutlass.pipeline.PipelineUserType.Producer, 2
|
|
1636
|
+
# )
|
|
1637
|
+
producer_phase_dKV = Int32(1)
|
|
1638
|
+
cta_group = pipeline_S_P.cta_group
|
|
1639
|
+
|
|
1640
|
+
tile_scheduler = TileSchedulerCls()
|
|
1641
|
+
work_tile = tile_scheduler.initial_work_tile_info()
|
|
1642
|
+
while work_tile.is_valid_tile:
|
|
1643
|
+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
|
1644
|
+
seqlen = SeqlenInfoCls(batch_idx) # must be seqlen_k
|
|
1645
|
+
m_block_min, m_block_max = block_info.get_m_block_min_max(
|
|
1646
|
+
seqlen, n_block // self.cluster_shape_mnk[0]
|
|
1647
|
+
)
|
|
1648
|
+
|
|
1649
|
+
if const_expr(self.use_block_sparsity):
|
|
1650
|
+
block_iter_count = get_total_q_block_count_bwd(
|
|
1651
|
+
blocksparse_tensors,
|
|
1652
|
+
batch_idx,
|
|
1653
|
+
head_idx,
|
|
1654
|
+
n_block,
|
|
1655
|
+
subtile_factor=self.subtile_factor,
|
|
1656
|
+
m_block_max=m_block_max,
|
|
1657
|
+
)
|
|
1658
|
+
process_tile = block_iter_count > Int32(0)
|
|
1659
|
+
else:
|
|
1660
|
+
block_iter_count = m_block_max - m_block_min
|
|
1661
|
+
process_tile = (
|
|
1662
|
+
const_expr(not self.is_local and not self.is_varlen_q)
|
|
1663
|
+
or m_block_min < m_block_max
|
|
1664
|
+
)
|
|
1665
|
+
|
|
1666
|
+
if process_tile:
|
|
1667
|
+
accumulate_dK = False
|
|
1668
|
+
# -----------------------------------------------------------
|
|
1669
|
+
###### Prologue
|
|
1670
|
+
# -----------------------------------------------------------
|
|
1671
|
+
# 1. S = Q0 @ K.T
|
|
1672
|
+
# 2. dP = V @ dO.T
|
|
1673
|
+
# 3. dV = P @ dO
|
|
1674
|
+
# 1) S = Q0 @ K.T
|
|
1675
|
+
handle_Q = pipeline_Q_consumer.wait_and_advance()
|
|
1676
|
+
pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
|
|
1677
|
+
mma_qk_fn(B_idx=handle_Q.index)
|
|
1678
|
+
# Don't release Q yet
|
|
1679
|
+
pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
|
|
1680
|
+
|
|
1681
|
+
# 2) dP = V @ dO.T
|
|
1682
|
+
pipeline_dO.consumer_wait(consumer_state_dO)
|
|
1683
|
+
pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)
|
|
1684
|
+
# dQ uses the same tmem as dP
|
|
1685
|
+
pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc)
|
|
1686
|
+
mma_dov_fn(B_idx=consumer_state_dO.index)
|
|
1687
|
+
# Don't release dO yet
|
|
1688
|
+
pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)
|
|
1689
|
+
|
|
1690
|
+
producer_phase_acc ^= 1
|
|
1691
|
+
# 3) dV = P.T @ dO
|
|
1692
|
+
# wait for P to be ready, which uses the same tmem as S
|
|
1693
|
+
pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
|
|
1694
|
+
mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=True)
|
|
1695
|
+
pipeline_dO.consumer_release(consumer_state_dO)
|
|
1696
|
+
consumer_state_dO.advance()
|
|
1697
|
+
# -----------------------------------------------------------
|
|
1698
|
+
###### MAIN LOOP
|
|
1699
|
+
# -----------------------------------------------------------
|
|
1700
|
+
# 1. S = K @ Q.T
|
|
1701
|
+
# 2. dQ = dS @ K
|
|
1702
|
+
# 3. dK = dS.T @ Q
|
|
1703
|
+
# 4. dP = V @ dO.T
|
|
1704
|
+
# 5. dV = P.T @ dO
|
|
1705
|
+
|
|
1706
|
+
# For block sparsity, we use block_iter_count; for dense, use m_block range
|
|
1707
|
+
# MMA doesn't need actual m_block indices, just the iteration count
|
|
1708
|
+
main_loop_iters = (
|
|
1709
|
+
block_iter_count - 1
|
|
1710
|
+
if const_expr(self.use_block_sparsity)
|
|
1711
|
+
else m_block_max - m_block_min - 1
|
|
1712
|
+
)
|
|
1713
|
+
for _ in cutlass.range(main_loop_iters, unroll=1):
|
|
1714
|
+
# 1) S = K @ Q_i
|
|
1715
|
+
handle_Q_next = pipeline_Q_consumer.wait_and_advance()
|
|
1716
|
+
# Don't need to wait for S, as P must have been ready earlier, i.e., S is ready
|
|
1717
|
+
mma_qk_fn(B_idx=handle_Q_next.index)
|
|
1718
|
+
pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
|
|
1719
|
+
|
|
1720
|
+
# 2-3)
|
|
1721
|
+
# Do dK = dS.T @ Q, then dQ = dS @ K if dS in tmem for first mma
|
|
1722
|
+
# Otherwise, reverse order
|
|
1723
|
+
pipeline_dS.consumer_wait(consumer_state_dS)
|
|
1724
|
+
|
|
1725
|
+
if const_expr(self.use_smem_dS_for_mma_dK):
|
|
1726
|
+
mma_dsk_fn()
|
|
1727
|
+
pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
|
|
1728
|
+
mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
|
|
1729
|
+
accumulate_dK = True
|
|
1730
|
+
handle_Q.release()
|
|
1731
|
+
else:
|
|
1732
|
+
mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
|
|
1733
|
+
accumulate_dK = True
|
|
1734
|
+
handle_Q.release()
|
|
1735
|
+
mma_dsk_fn()
|
|
1736
|
+
pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
|
|
1737
|
+
|
|
1738
|
+
# dP uses the same tmem as dQ
|
|
1739
|
+
# However, if dS is ready, then dP must have been ready,
|
|
1740
|
+
# so we don't need this wait before mma_dsk_fn()
|
|
1741
|
+
# pipeline_dP.sync_object_empty.wait(0, producer_phase_acc)
|
|
1742
|
+
|
|
1743
|
+
pipeline_dS.consumer_release(consumer_state_dS)
|
|
1744
|
+
consumer_state_dS.advance()
|
|
1745
|
+
|
|
1746
|
+
# 4) dP = V @ dO.T
|
|
1747
|
+
pipeline_dO.consumer_wait(consumer_state_dO)
|
|
1748
|
+
# dQ uses the same tmem as dP
|
|
1749
|
+
pipeline_dQ.sync_object_empty.wait(0, producer_phase_acc)
|
|
1750
|
+
mma_dov_fn(B_idx=consumer_state_dO.index)
|
|
1751
|
+
pipeline_dP.sync_object_full.arrive(0, pipeline_dP.producer_mask, cta_group)
|
|
1752
|
+
|
|
1753
|
+
producer_phase_acc ^= 1
|
|
1754
|
+
# 5) dV += P @ dO
|
|
1755
|
+
# wait for P to be ready, which uses the same tmem as S
|
|
1756
|
+
pipeline_S_P.sync_object_empty.wait(0, producer_phase_acc)
|
|
1757
|
+
mma_pdo_fn(B_idx=consumer_state_dO.index, zero_init=False)
|
|
1758
|
+
pipeline_dO.consumer_release(consumer_state_dO)
|
|
1759
|
+
consumer_state_dO.advance()
|
|
1760
|
+
|
|
1761
|
+
handle_Q = handle_Q_next
|
|
1762
|
+
|
|
1763
|
+
pipeline_S_P.sync_object_full.arrive(0, pipeline_S_P.producer_mask, cta_group)
|
|
1764
|
+
|
|
1765
|
+
# signal to the epilogue that dV is ready
|
|
1766
|
+
# pipeline_dKV.producer_acquire(producer_state_dKV)
|
|
1767
|
+
pipeline_dKV.sync_object_empty.wait(0, producer_phase_dKV)
|
|
1768
|
+
# pipeline_dKV.producer_commit(producer_state_dKV)
|
|
1769
|
+
pipeline_dKV.sync_object_full.arrive(0, pipeline_dKV.producer_mask, cta_group)
|
|
1770
|
+
# producer_state_dKV.advance()
|
|
1771
|
+
# pipeline_dKV.producer_acquire(producer_state_dKV)
|
|
1772
|
+
pipeline_dKV.sync_object_empty.wait(1, producer_phase_dKV)
|
|
1773
|
+
|
|
1774
|
+
# -----------------------------------------------------------
|
|
1775
|
+
###### Remaining 2
|
|
1776
|
+
# -----------------------------------------------------------
|
|
1777
|
+
# 1) dK += dS.T @ Q
|
|
1778
|
+
pipeline_dS.consumer_wait(consumer_state_dS)
|
|
1779
|
+
mma_dsq_fn(B_idx=handle_Q.index, zero_init=not accumulate_dK)
|
|
1780
|
+
# signal to the epilogue that dK is ready
|
|
1781
|
+
# pipeline_dKV.producer_commit(producer_state_dKV)
|
|
1782
|
+
pipeline_dKV.sync_object_full.arrive(1, pipeline_dKV.producer_mask, cta_group)
|
|
1783
|
+
# producer_state_dKV.advance()
|
|
1784
|
+
producer_phase_dKV ^= 1
|
|
1785
|
+
|
|
1786
|
+
# 2) dQ = dS @ K
|
|
1787
|
+
# dS is done, so dP must have been ready; we don't need to wait
|
|
1788
|
+
mma_dsk_fn()
|
|
1789
|
+
pipeline_dQ.sync_object_full.arrive(0, pipeline_dQ.producer_mask, cta_group)
|
|
1790
|
+
# Wait until dQ is done before releasing Q, since K and Q0 use the same mbarrier
|
|
1791
|
+
handle_Q.release()
|
|
1792
|
+
pipeline_dS.consumer_release(consumer_state_dS)
|
|
1793
|
+
consumer_state_dS.advance()
|
|
1794
|
+
|
|
1795
|
+
producer_phase_acc ^= 1
|
|
1796
|
+
|
|
1797
|
+
tile_scheduler.advance_to_next_work()
|
|
1798
|
+
work_tile = tile_scheduler.get_current_work()
|
|
1799
|
+
|
|
1800
|
+
# Currently it hangs if we have this S_P.producer_tail; we will need to understand why
|
|
1801
|
+
# pipeline_S_P.producer_tail(producer_state_S_P)
|
|
1802
|
+
# pipeline_dP.producer_tail(producer_state_dP)
|
|
1803
|
+
# pipeline_dKV.producer_tail(producer_state_dKV)
|
|
1804
|
+
# pipeline_dQ.producer_tail(producer_state_dQ)
|
|
1805
|
+
|
|
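Per (KV-tile, Q-tile) pair the mma warp issues the five matmuls listed in the comments above: S = K Q^T, dP = V dO^T, dV += P^T dO, dK += dS^T Q and dQ = dS K. A dense NumPy reference of the same dataflow for one head, written in the untransposed convention (the kernel computes the swapped/transposed variants to suit its tmem layouts, and defers parts of the softmax_scale factor to the epilogue and postprocess steps):

```python
import numpy as np

def attention_bwd_reference(Q, K, V, dO, softmax_scale):
    """Dense reference for the backward matmuls issued by the mma warp (one head).

    Q, dO: (seqlen_q, d); K, V: (seqlen_k, d). Returns dQ, dK, dV.
    """
    S = softmax_scale * (Q @ K.T)                         # scaled forward scores
    LSE = np.log(np.exp(S).sum(axis=-1, keepdims=True))   # row log-sum-exp (saved by fwd)
    P = np.exp(S - LSE)                                   # recomputed softmax, P = exp(S - LSE)
    dV = P.T @ dO                                         # dV = P^T @ dO
    dP = dO @ V.T                                         # dP = dO @ V^T
    dPsum = (dO * (P @ V)).sum(axis=-1, keepdims=True)    # Delta_i = sum_j dO_ij * O_ij
    dS = P * (dP - dPsum)                                 # softmax backward
    dQ = softmax_scale * (dS @ K)                         # dQ = dS @ K (scaled)
    dK = softmax_scale * (dS.T @ Q)                       # dK = dS^T @ Q (scaled)
    return dQ, dK, dV

rng = np.random.default_rng(0)
Q, K, V, dO = (rng.standard_normal((8, 16)) for _ in range(4))
dQ, dK, dV = attention_bwd_reference(Q, K, V, dO, softmax_scale=16 ** -0.5)
print(dQ.shape, dK.shape, dV.shape)  # (8, 16) (8, 16) (8, 16)
```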
1806
|
+
@cute.jit
|
|
1807
|
+
def split_wg(
|
|
1808
|
+
self,
|
|
1809
|
+
t: cute.Tensor,
|
|
1810
|
+
wg_idx: cutlass.Int32,
|
|
1811
|
+
num_wg: cutlass.Constexpr[int],
|
|
1812
|
+
):
|
|
1813
|
+
reduced_shape = cute.product_each(t.shape)
|
|
1814
|
+
rank = len(reduced_shape)
|
|
1815
|
+
if const_expr(reduced_shape[1] > 1):
|
|
1816
|
+
assert rank >= 2, "Need rank >= 2 for t in split_wg"
|
|
1817
|
+
t = cute.logical_divide(t, (reduced_shape[0], reduced_shape[1] // num_wg))
|
|
1818
|
+
coord = (None, (None, wg_idx)) + (None,) * (rank - 2)
|
|
1819
|
+
else:
|
|
1820
|
+
assert rank >= 3, "Need rank >= 3 for t in split_wg"
|
|
1821
|
+
if const_expr(rank == 3):
|
|
1822
|
+
t = cute.logical_divide(
|
|
1823
|
+
t, (reduced_shape[0], reduced_shape[1], reduced_shape[2] // num_wg)
|
|
1824
|
+
)
|
|
1825
|
+
coord = (
|
|
1826
|
+
None,
|
|
1827
|
+
None,
|
|
1828
|
+
(None, wg_idx),
|
|
1829
|
+
) + (None,) * (rank - 3)
|
|
1830
|
+
else:
|
|
1831
|
+
t = cute.logical_divide(
|
|
1832
|
+
t,
|
|
1833
|
+
(
|
|
1834
|
+
reduced_shape[0],
|
|
1835
|
+
reduced_shape[1],
|
|
1836
|
+
reduced_shape[2],
|
|
1837
|
+
reduced_shape[3] // num_wg,
|
|
1838
|
+
),
|
|
1839
|
+
)
|
|
1840
|
+
coord = (
|
|
1841
|
+
None,
|
|
1842
|
+
None,
|
|
1843
|
+
None,
|
|
1844
|
+
(None, wg_idx),
|
|
1845
|
+
) + (None,) * (rank - 4)
|
|
1846
|
+
return t[coord]
|
|
1847
|
+
|
|
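split_wg gives each workgroup a contiguous 1/num_wg slice of the last non-trivial mode of a tensor while leaving every other mode intact; the branches only differ in which mode gets divided, based on the tensor's rank and shape. A rough NumPy analogy of that chunking (it mirrors the shape bookkeeping only, not CuTe layout algebra):

```python
import numpy as np

def split_wg_like(t: np.ndarray, wg_idx: int, num_wg: int, axis: int) -> np.ndarray:
    """Give workgroup `wg_idx` its contiguous 1/num_wg slice along `axis`.

    Rough analogy of the cute.logical_divide + coordinate slicing in split_wg above.
    """
    n = t.shape[axis]
    assert n % num_wg == 0, "mode must divide evenly across workgroups"
    chunk = n // num_wg
    index = [slice(None)] * t.ndim
    index[axis] = slice(wg_idx * chunk, (wg_idx + 1) * chunk)
    return t[tuple(index)]

t = np.arange(4 * 8).reshape(4, 8)
wg0 = split_wg_like(t, wg_idx=0, num_wg=2, axis=1)  # columns 0..3
wg1 = split_wg_like(t, wg_idx=1, num_wg=2, axis=1)  # columns 4..7
assert wg0.shape == wg1.shape == (4, 4)
```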
1848
|
+
@cute.jit
|
|
1849
|
+
def apply_score_mod(
|
|
1850
|
+
self,
|
|
1851
|
+
tSrS_t2r,
|
|
1852
|
+
thr_copy_t2r,
|
|
1853
|
+
thr_mma_S,
|
|
1854
|
+
batch_idx,
|
|
1855
|
+
head_idx,
|
|
1856
|
+
m_block,
|
|
1857
|
+
n_block,
|
|
1858
|
+
softmax_scale,
|
|
1859
|
+
seqlen_info,
|
|
1860
|
+
aux_tensors=None,
|
|
1861
|
+
fastdiv_mods=(None, None),
|
|
1862
|
+
):
|
|
1863
|
+
"""Apply forward score modification for SM100 backward pass."""
|
|
1864
|
+
# In bwd, S is computed as K @ Q.T so dimensions are (tile_n, tile_m)
|
|
1865
|
+
cS = cute.make_identity_tensor((self.tile_n, self.tile_m))
|
|
1866
|
+
cS = cute.domain_offset((n_block * self.tile_n, m_block * self.tile_m), cS)
|
|
1867
|
+
tScS = thr_mma_S.partition_C(cS)
|
|
1868
|
+
tScS_idx = thr_copy_t2r.partition_D(tScS)
|
|
1869
|
+
|
|
1870
|
+
apply_score_mod_inner(
|
|
1871
|
+
tSrS_t2r,
|
|
1872
|
+
tScS_idx,
|
|
1873
|
+
self.score_mod,
|
|
1874
|
+
batch_idx,
|
|
1875
|
+
head_idx,
|
|
1876
|
+
softmax_scale,
|
|
1877
|
+
self.vec_size,
|
|
1878
|
+
self.qk_acc_dtype,
|
|
1879
|
+
aux_tensors,
|
|
1880
|
+
fastdiv_mods,
|
|
1881
|
+
seqlen_info,
|
|
1882
|
+
constant_q_idx=None,
|
|
1883
|
+
qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
|
1884
|
+
transpose_indices=True,
|
|
1885
|
+
)
|
|
1886
|
+
|
|
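apply_score_mod replays the forward score modification on the recomputed scores, elementwise, before the mask is applied; the callback sees each score together with its (batch, head, q_idx, kv_idx) coordinates, and the indices are transposed here because S is materialized as (tile_n, tile_m). A standalone example of such an elementwise modification, an ALiBi-style relative-position bias, applied to a dense score tile (the callback signature below is illustrative, not the exact one expected by apply_score_mod_inner):

```python
import numpy as np

def rel_pos_bias_score_mod(score, batch_idx, head_idx, q_idx, kv_idx, slope=0.05):
    """Illustrative flex-attention-style score_mod: ALiBi-like relative-position bias."""
    return score - slope * abs(q_idx - kv_idx)

def apply_score_mod_dense(S, score_mod, batch_idx=0, head_idx=0):
    """Apply an elementwise score_mod to a dense (seqlen_q, seqlen_k) score tile."""
    out = np.empty_like(S)
    for q_idx in range(S.shape[0]):
        for kv_idx in range(S.shape[1]):
            out[q_idx, kv_idx] = score_mod(S[q_idx, kv_idx], batch_idx, head_idx, q_idx, kv_idx)
    return out

S = np.zeros((4, 4))
print(apply_score_mod_dense(S, rel_pos_bias_score_mod))
# off-diagonal scores are pushed down in proportion to |q_idx - kv_idx|
```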
1887
|
+
@cute.jit
|
|
1888
|
+
def apply_score_mod_bwd(
|
|
1889
|
+
self,
|
|
1890
|
+
grad_tensor,
|
|
1891
|
+
score_tensor,
|
|
1892
|
+
index_tensor,
|
|
1893
|
+
batch_idx,
|
|
1894
|
+
head_idx,
|
|
1895
|
+
softmax_scale,
|
|
1896
|
+
seqlen_info,
|
|
1897
|
+
aux_tensors=None,
|
|
1898
|
+
fastdiv_mods=(None, None),
|
|
1899
|
+
):
|
|
1900
|
+
"""Apply backward score modification (joint graph) for SM100."""
|
|
1901
|
+
apply_score_mod_bwd_inner(
|
|
1902
|
+
grad_tensor,
|
|
1903
|
+
score_tensor,
|
|
1904
|
+
index_tensor,
|
|
1905
|
+
self.score_mod_bwd,
|
|
1906
|
+
batch_idx,
|
|
1907
|
+
head_idx,
|
|
1908
|
+
softmax_scale,
|
|
1909
|
+
self.vec_size,
|
|
1910
|
+
self.qk_acc_dtype,
|
|
1911
|
+
aux_tensors,
|
|
1912
|
+
fastdiv_mods,
|
|
1913
|
+
seqlen_info,
|
|
1914
|
+
constant_q_idx=None,
|
|
1915
|
+
qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
|
|
1916
|
+
transpose_indices=True,
|
|
1917
|
+
)
|
|
1918
|
+
|
|
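apply_score_mod_bwd runs the backward (joint) graph of the score modification: given the gradient with respect to the modified score and the saved pre-modification score, it applies the chain rule ds = ds' * df/ds for s' = f(s, ...). A hand-worked example for tanh soft-capping (purely illustrative; this kernel does not necessarily use it), checked against a central finite difference:

```python
import numpy as np

CAP = 30.0

def softcap_fwd(s):
    """Forward score_mod: s' = CAP * tanh(s / CAP)."""
    return CAP * np.tanh(s / CAP)

def softcap_bwd(grad_out, s):
    """Backward of the score_mod via the chain rule: ds = ds' * (1 - tanh(s/CAP)^2)."""
    return grad_out * (1.0 - np.tanh(s / CAP) ** 2)

s = np.linspace(-50.0, 50.0, 11)
grad_out = np.ones_like(s)
analytic = softcap_bwd(grad_out, s)
eps = 1e-4
numeric = (softcap_fwd(s + eps) - softcap_fwd(s - eps)) / (2 * eps)
assert np.allclose(analytic, numeric, atol=1e-6)
```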
1919
|
+
@cute.jit
|
|
1920
|
+
def compute_loop(
|
|
1921
|
+
self,
|
|
1922
|
+
thr_mma_S: cute.core.ThrMma,
|
|
1923
|
+
thr_mma_dP: cute.core.ThrMma,
|
|
1924
|
+
thr_mma_dV: cute.core.ThrMma,
|
|
1925
|
+
thr_mma_dK: cute.core.ThrMma,
|
|
1926
|
+
tStS: cute.Tensor,
|
|
1927
|
+
sLSE: cute.Tensor,
|
|
1928
|
+
sdPsum: cute.Tensor,
|
|
1929
|
+
tdVtdV: cute.Tensor,
|
|
1930
|
+
tdKtdK: cute.Tensor,
|
|
1931
|
+
mdV: cute.Tensor,
|
|
1932
|
+
mdK: cute.Tensor,
|
|
1933
|
+
sdS: cute.Tensor,
|
|
1934
|
+
tdPtdP: cute.Tensor,
|
|
1935
|
+
pipeline_LSE: PipelineAsync,
|
|
1936
|
+
pipeline_dPsum: PipelineAsync,
|
|
1937
|
+
pipeline_S_P: PipelineAsync,
|
|
1938
|
+
pipeline_dS: PipelineAsync,
|
|
1939
|
+
pipeline_dKV: PipelineAsync,
|
|
1940
|
+
pipeline_dP: PipelineAsync,
|
|
1941
|
+
softmax_scale: cutlass.Float32,
|
|
1942
|
+
softmax_scale_log2: cutlass.Float32,
|
|
1943
|
+
block_info: BlockInfo,
|
|
1944
|
+
SeqlenInfoCls: Callable,
|
|
1945
|
+
AttentionMaskCls: Callable,
|
|
1946
|
+
TileSchedulerCls: Callable,
|
|
1947
|
+
sdV: Optional[cute.Tensor],
|
|
1948
|
+
sdK: Optional[cute.Tensor],
|
|
1949
|
+
mdV_tma_tensor: Optional[cute.Tensor],
|
|
1950
|
+
mdK_tma_tensor: Optional[cute.Tensor],
|
|
1951
|
+
tma_atom_dV: Optional[cute.CopyAtom],
|
|
1952
|
+
tma_atom_dK: Optional[cute.CopyAtom],
|
|
1953
|
+
tiled_copy_r2s_dKV: Optional[cute.TiledCopy],
|
|
1954
|
+
mdK_semaphore: Optional[cute.Tensor],
|
|
1955
|
+
mdV_semaphore: Optional[cute.Tensor],
|
|
1956
|
+
aux_tensors: Optional[list] = None,
|
|
1957
|
+
fastdiv_mods=(None, None),
|
|
1958
|
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
1959
|
+
):
|
|
1960
|
+
sLSE_2D = cute.make_tensor(
|
|
1961
|
+
sLSE.iterator,
|
|
1962
|
+
cute.make_layout(
|
|
1963
|
+
(self.tile_m, self.tile_n, self.Q_stage),
|
|
1964
|
+
stride=(1, 0, cute.round_up(self.tile_m, 64)),
|
|
1965
|
+
),
|
|
1966
|
+
)
|
|
1967
|
+
sdPsum_2D = cute.make_tensor(
|
|
1968
|
+
sdPsum.iterator,
|
|
1969
|
+
cute.make_layout(
|
|
1970
|
+
(self.tile_m, self.tile_n, self.dO_stage),
|
|
1971
|
+
stride=(1, 0, cute.round_up(self.tile_m, 64)),
|
|
1972
|
+
),
|
|
1973
|
+
)
|
|
1974
|
+
# if const_expr(self.SdP_swapAB):
|
|
1975
|
+
if const_expr(True):
|
|
1976
|
+
sLSE_2D = utils.transpose_view(sLSE_2D)
|
|
1977
|
+
sdPsum_2D = utils.transpose_view(sdPsum_2D)
|
|
1978
|
+
|
|
1979
|
+
# tidx: [128...384] 8 warps
|
|
1980
|
+
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) # 4-11
|
|
1981
|
+
tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))
|
|
1982
|
+
# tidx = cute.arch.thread_idx()[0] - (cute.arch.WARP_SIZE * self.compute_warp_ids[0])
|
|
1983
|
+
dp_idx = tidx % 128
|
|
1984
|
+
num_wg = len(self.compute_warp_ids) // 4 # 2
|
|
1985
|
+
# wg_idx:
|
|
1986
|
+
# 0: [256...384]
|
|
1987
|
+
# 1: [128...256]
|
|
1988
|
+
|
|
1989
|
+
tileP_f32_like = self.mma_tiler_kq[0] // 32 * self.v_dtype.width # 64 for tile_n = 128
|
|
1990
|
+
# tStS has shape ((128, 128), 1, 1), tStP has shape ((128, 64), 1, 1)
|
|
1991
|
+
# tP overlap with tS
|
|
1992
|
+
tStP = cute.composition(tStS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
|
|
1993
|
+
tStP = cute.make_tensor(tStS.iterator, tStP.layout) # Otherwise the tmem address is wrong
|
|
1994
|
+
tScS = thr_mma_S.partition_C(cute.make_identity_tensor(self.mma_tiler_kq[:2]))
|
|
1995
|
+
tScP = cute.composition(tScS, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
|
|
1996
|
+
# tdS overlap with tdP
|
|
1997
|
+
tdPtdS = cute.composition(tdPtdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
|
|
1998
|
+
tdPcdP = thr_mma_dP.partition_C(cute.make_identity_tensor(self.mma_tiler_vdo[:2]))
|
|
1999
|
+
tdPcdS = cute.composition(tdPcdP, (cute.make_layout((self.tile_n, tileP_f32_like)), 1, 1))
|
|
2000
|
+
|
|
2001
|
+
tmem_load_atom = cute.make_copy_atom(
|
|
2002
|
+
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
|
|
2003
|
+
)
|
|
2004
|
+
tmem_store_atom = cute.make_copy_atom(
|
|
2005
|
+
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32
|
|
2006
|
+
)
|
|
2007
|
+
|
|
2008
|
+
# tmem -> rmem
|
|
2009
|
+
thr_copy_t2r = copy_utils.make_tmem_copy(tmem_load_atom, num_wg).get_slice(tidx)
|
|
2010
|
+
tStS_t2r = thr_copy_t2r.partition_S(tStS) # (((32, 32), 1), 2, 1, 1)
|
|
2011
|
+
tdPtdP_t2r = thr_copy_t2r.partition_S(tdPtdP)
|
|
2012
|
+
tScS_t2r = thr_copy_t2r.partition_D(tScS) # ((32, 1), 2, 1, 1)
|
|
2013
|
+
t0ScS_t2r = thr_copy_t2r.get_slice(0).partition_D(tScS) # ((32, 1), 2, 1, 1)
|
|
2014
|
+
# ((32, 1), 2, 1, 1, STAGE)
|
|
2015
|
+
tSsLSE = thr_copy_t2r.partition_D(thr_mma_S.partition_C(sLSE_2D))
|
|
2016
|
+
tSsdPsum = thr_copy_t2r.partition_D(thr_mma_dP.partition_C(sdPsum_2D))
|
|
2017
|
+
# rmem -> tmem
|
|
2018
|
+
thr_copy_r2t = copy_utils.make_tmem_copy(tmem_store_atom, num_wg).get_slice(tidx)
|
|
2019
|
+
tScP_r2t = thr_copy_r2t.partition_S(tScP)
|
|
2020
|
+
tStP_r2t = thr_copy_r2t.partition_D(tStP)
|
|
2021
|
+
tdPcdS_r2t = thr_copy_r2t.partition_S(tdPcdS)
|
|
2022
|
+
tdPtdS_r2t = thr_copy_r2t.partition_D(tdPtdS)
|
|
2023
|
+
# rmem -> smem
|
|
2024
|
+
# This part is a bit iffy; we might be making a lot of assumptions here
|
|
2025
|
+
copy_atom_r2s = sm100_utils_basic.get_smem_store_op(
|
|
2026
|
+
LayoutEnum.ROW_MAJOR, self.ds_dtype, Float32, thr_copy_t2r
|
|
2027
|
+
)
|
|
2028
|
+
thr_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, thr_copy_t2r).get_slice(tidx)
|
|
2029
|
+
# We assume the swizzle (i.e. layout.inner) stays the same
|
|
2030
|
+
sdS_layout = sm100_utils_basic.make_smem_layout_epi(
|
|
2031
|
+
self.ds_dtype, LayoutEnum.ROW_MAJOR, (self.tile_n, self.tile_m), 1
|
|
2032
|
+
).outer # ((8,16), (64,2), (1, 1))
|
|
2033
|
+
sdS_layout = cute.slice_(sdS_layout, (None, None, 0)) # ((8,16), (64,2))
|
|
2034
|
+
# Need to group into 1 mode to be compatible with thr_copy_r2s
|
|
2035
|
+
sdS_layout = cute.make_layout((sdS_layout.shape,), stride=(sdS_layout.stride,))
|
|
2036
|
+
sdS_epi = cute.make_tensor(sdS.iterator, sdS_layout)
|
|
2037
|
+
tRS_sdS = thr_copy_r2s.partition_D(sdS_epi)
|
|
2038
|
+
|
|
2039
|
+
consumer_state_S_P_dP = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1
|
|
2040
|
+
cutlass.pipeline.PipelineUserType.Consumer, 1
|
|
2041
|
+
)
|
|
2042
|
+
# consumer_phase_S_P_dP = Int32(0)
|
|
2043
|
+
producer_state_dS = pipeline.make_pipeline_state( # Our impl has shortcut for stage==1
|
|
2044
|
+
cutlass.pipeline.PipelineUserType.Producer, 1
|
|
2045
|
+
)
|
|
2046
|
+
consumer_state_dKV = cutlass.pipeline.make_pipeline_state(
|
|
2047
|
+
cutlass.pipeline.PipelineUserType.Consumer, 2
|
|
2048
|
+
)
|
|
2049
|
+
consumer_state_LSE = cutlass.pipeline.make_pipeline_state(
|
|
2050
|
+
cutlass.pipeline.PipelineUserType.Consumer, self.Q_stage
|
|
2051
|
+
)
|
|
2052
|
+
# consumer_state_dPsum = cutlass.pipeline.make_pipeline_state(
|
|
2053
|
+
consumer_state_dPsum = pipeline.make_pipeline_state(
|
|
2054
|
+
cutlass.pipeline.PipelineUserType.Consumer, self.dO_stage
|
|
2055
|
+
)
|
|
2056
|
+
|
|
2057
|
+
tile_scheduler = TileSchedulerCls()
|
|
2058
|
+
work_tile = tile_scheduler.initial_work_tile_info()
|
|
2059
|
+
while work_tile.is_valid_tile:
|
|
2060
|
+
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
|
2061
|
+
seqlen = SeqlenInfoCls(batch_idx)
|
|
2062
|
+
m_block_min, m_block_max = block_info.get_m_block_min_max(
|
|
2063
|
+
seqlen, n_block // self.cluster_shape_mnk[0]
|
|
2064
|
+
)
|
|
2065
|
+
mask = AttentionMaskCls(seqlen)
|
|
2066
|
+
# TODO: condition mask_seqlen
|
|
2067
|
+
mask_fn = partial(
|
|
2068
|
+
mask.apply_mask_sm100_transposed,
|
|
2069
|
+
tScS_t2r=tScS_t2r,
|
|
2070
|
+
t0ScS_t2r=t0ScS_t2r,
|
|
2071
|
+
n_block=n_block,
|
|
2072
|
+
mask_seqlen=True,
|
|
2073
|
+
mask_causal=self.is_causal,
|
|
2074
|
+
mask_local=self.is_local,
|
|
2075
|
+
mask_mod=self.mask_mod,
|
|
2076
|
+
batch_idx=batch_idx,
|
|
2077
|
+
head_idx=head_idx,
|
|
2078
|
+
aux_tensors=aux_tensors,
|
|
2079
|
+
fastdiv_mods=fastdiv_mods,
|
|
2080
|
+
)
|
|
2081
|
+
|
|
2082
|
+
# prefetch_LSE = not self.is_causal
|
|
2083
|
+
prefetch_LSE = False
|
|
2084
|
+
|
|
2085
|
+
# some tiles might be empty due to block sparsity
|
|
2086
|
+
if const_expr(self.use_block_sparsity):
|
|
2087
|
+
(
|
|
2088
|
+
curr_q_cnt,
|
|
2089
|
+
curr_q_idx,
|
|
2090
|
+
curr_full_cnt,
|
|
2091
|
+
curr_full_idx,
|
|
2092
|
+
loop_count,
|
|
2093
|
+
) = get_block_sparse_iteration_info_bwd(
|
|
2094
|
+
blocksparse_tensors,
|
|
2095
|
+
batch_idx,
|
|
2096
|
+
head_idx,
|
|
2097
|
+
n_block,
|
|
2098
|
+
subtile_factor=self.subtile_factor,
|
|
2099
|
+
m_block_max=m_block_max,
|
|
2100
|
+
)
|
|
2101
|
+
process_tile = loop_count > Int32(0)
|
|
2102
|
+
else:
|
|
2103
|
+
process_tile = (
|
|
2104
|
+
const_expr(not self.is_local and not self.is_varlen_q)
|
|
2105
|
+
or m_block_min < m_block_max
|
|
2106
|
+
)
|
|
2107
|
+
loop_count = m_block_max - m_block_min
|
|
2108
|
+
|
|
2109
|
+
# Mainloop
|
|
2110
|
+
# Block sparsity: iterate over sparse m_block count and derive actual m_block
|
|
2111
|
+
# from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly.
|
|
2112
|
+
for iter_idx in cutlass.range(loop_count, unroll=1):
|
|
2113
|
+
if const_expr(self.use_block_sparsity):
|
|
2114
|
+
m_block, is_full_block = get_m_block_from_iter_bwd(
|
|
2115
|
+
iter_idx,
|
|
2116
|
+
curr_q_cnt,
|
|
2117
|
+
curr_q_idx,
|
|
2118
|
+
curr_full_cnt,
|
|
2119
|
+
curr_full_idx,
|
|
2120
|
+
subtile_factor=self.subtile_factor,
|
|
2121
|
+
m_block_max=m_block_max,
|
|
2122
|
+
)
|
|
2123
|
+
m_block_oob = m_block >= m_block_max
|
|
2124
|
+
else:
|
|
2125
|
+
m_block = m_block_min + iter_idx
|
|
2126
|
+
m_block_oob = False
|
|
2127
|
+
is_full_block = False
|
|
2128
|
+
# Prefetch 1 stage of LSE
|
|
2129
|
+
pipeline_LSE.consumer_wait(consumer_state_LSE)
|
|
2130
|
+
tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32)
|
|
2131
|
+
if const_expr(prefetch_LSE and not self.shuffle_LSE):
|
|
2132
|
+
cute.autovec_copy(tSsLSE[None, 0, 0, 0, consumer_state_LSE.index], tSrLSE_s2r)
|
|
2133
|
+
|
|
2134
|
+
pipeline_S_P.consumer_wait(consumer_state_S_P_dP)
|
|
2135
|
+
# pipeline_S_P.sync_object_full.wait(0, consumer_phase_S_P_dP)
|
|
2136
|
+
#### TMEM->RMEM (Load S from TMEM)
|
|
2137
|
+
tSrS_t2r = cute.make_fragment(tScS_t2r.shape, Float32)
|
|
2138
|
+
cute.copy(thr_copy_t2r, tStS_t2r, tSrS_t2r)
|
|
2139
|
+
if const_expr(self.score_mod_bwd is not None):
|
|
2140
|
+
tSrS_pre = cute.make_fragment_like(tSrS_t2r)
|
|
2141
|
+
cute.autovec_copy(tSrS_t2r, tSrS_pre)
|
|
2142
|
+
|
|
2143
|
+
if const_expr(self.score_mod is not None):
|
|
2144
|
+
# Apply score_mod FIRST -> matches forward
|
|
2145
|
+
self.apply_score_mod(
|
|
2146
|
+
tSrS_t2r,
|
|
2147
|
+
thr_copy_t2r,
|
|
2148
|
+
thr_mma_S,
|
|
2149
|
+
batch_idx,
|
|
2150
|
+
head_idx,
|
|
2151
|
+
m_block,
|
|
2152
|
+
n_block,
|
|
2153
|
+
softmax_scale,
|
|
2154
|
+
seqlen,
|
|
2155
|
+
aux_tensors,
|
|
2156
|
+
fastdiv_mods,
|
|
2157
|
+
)
|
|
2158
|
+
|
|
2159
|
+
#### APPLY MASK (after score_mod, matching forward pass order)
|
|
2160
|
+
check_m_boundary = (m_block + 1) * self.tile_m > seqlen.seqlen_q
|
|
2161
|
+
mask_fn(
|
|
2162
|
+
tSrS_t2r,
|
|
2163
|
+
m_block=m_block,
|
|
2164
|
+
is_full_block=is_full_block,
|
|
2165
|
+
check_m_boundary=check_m_boundary,
|
|
2166
|
+
)
|
|
2167
|
+
|
|
2168
|
+
num_stages = cute.size(tScS_t2r, mode=[1])
|
|
2169
|
+
|
|
2170
|
+
# ---------------------------------------------
|
|
2171
|
+
#### P = exp(S - LSE)
|
|
2172
|
+
# ---------------------------------------------
|
|
2173
|
+
lane_idx = cute.arch.lane_idx()
|
|
2174
|
+
tSrP_r2t_f32 = cute.make_fragment(tScP_r2t.shape, Float32) # 64
|
|
2175
|
+
tSrP_r2t = cute.recast_tensor(tSrP_r2t_f32, self.q_dtype)
|
|
2176
|
+
for stage in cutlass.range_constexpr(num_stages):
|
|
2177
|
+
tSrS_cur = tSrS_t2r[None, stage, 0, 0]
|
|
2178
|
+
tSsLSE_cur = tSsLSE[None, stage, 0, 0, consumer_state_LSE.index]
|
|
2179
|
+
if const_expr(not self.shuffle_LSE):
|
|
2180
|
+
if const_expr(stage > 0 or not prefetch_LSE):
|
|
2181
|
+
cute.autovec_copy(tSsLSE_cur, tSrLSE_s2r)
|
|
2182
|
+
tSrLSE = tSrLSE_s2r
|
|
2183
|
+
else:
|
|
2184
|
+
tSrLSE = tSsLSE_cur[lane_idx]
|
|
2185
|
+
for v in cutlass.range_constexpr(cute.size(tSrS_t2r, mode=[0]) // 2):
|
|
2186
|
+
if const_expr(not self.shuffle_LSE):
|
|
2187
|
+
lse_pair = (tSrLSE[2 * v], tSrLSE[2 * v + 1])
|
|
2188
|
+
else:
|
|
2189
|
+
lse_pair = (
|
|
2190
|
+
utils.shuffle_sync(tSrLSE, offset=2 * v),
|
|
2191
|
+
utils.shuffle_sync(tSrLSE, offset=2 * v + 1),
|
|
2192
|
+
)
|
|
2193
|
+
tSrS_cur[2 * v], tSrS_cur[2 * v + 1] = utils.fma_packed_f32x2(
|
|
2194
|
+
((tSrS_cur[2 * v], tSrS_cur[2 * v + 1])),
|
|
2195
|
+
(softmax_scale_log2, softmax_scale_log2),
|
|
2196
|
+
(-lse_pair[0], -lse_pair[1]),
|
|
2197
|
+
)
|
|
2198
|
+
tSrS_cur[2 * v] = cute.math.exp2(tSrS_cur[2 * v], fastmath=True)
|
|
2199
|
+
tSrS_cur[2 * v + 1] = cute.math.exp2(tSrS_cur[2 * v + 1], fastmath=True)
|
|
2200
|
+
utils.cvt_f16(tSrS_cur, tSrP_r2t[None, stage, 0, 0])
|
|
2201
|
+
if const_expr(stage == 0):
|
|
2202
|
+
cute.arch.fence_view_async_tmem_load()
|
|
2203
|
+
# Without this barrier, we could have 1 warp writing to P in tmem while
|
|
2204
|
+
# another warp is still reading S from tmem.
|
|
2205
|
+
self.compute_sync_barrier.arrive_and_wait()
|
|
2206
|
+
cute.copy(
|
|
2207
|
+
thr_copy_r2t,
|
|
2208
|
+
tSrP_r2t_f32[None, stage, None, None],
|
|
2209
|
+
tStP_r2t[None, stage, None, None],
|
|
2210
|
+
)
|
|
2211
|
+
|
|
2212
|
+
cute.arch.fence_view_async_tmem_store()
|
|
2213
|
+
self.compute_sync_barrier.arrive_and_wait()
|
|
2214
|
+
|
|
2215
|
+
with cute.arch.elect_one():
|
|
2216
|
+
pipeline_S_P.consumer_release(consumer_state_S_P_dP)
|
|
2217
|
+
# pipeline_S_P.sync_object_empty.arrive(0, pipeline_S_P.consumer_mask)
|
|
2218
|
+
pipeline_LSE.consumer_release(consumer_state_LSE)
|
|
2219
|
+
# consumer_state_S_P_dP.advance()
|
|
2220
|
+
consumer_state_LSE.advance()
|
|
2221
|
+
|
|
2222
|
+
# ---------------------------------------------
|
|
2223
|
+
# dS.T = P.T * (dP.T - D)
|
|
2224
|
+
# ---------------------------------------------
|
|
2225
|
+
pipeline_dPsum.consumer_wait(consumer_state_dPsum)
|
|
2226
|
+
|
|
2227
|
+
pipeline_dP.consumer_wait(consumer_state_S_P_dP)
|
|
2228
|
+
# pipeline_dP.sync_object_full.wait(0, consumer_phase_S_P_dP)
|
|
2229
|
+
consumer_state_S_P_dP.advance()
|
|
2230
|
+
# consumer_phase_S_P_dP ^= 1
|
|
2231
|
+
|
|
2232
|
+
##### dS.T = P.T * (dP.T - Psum)
|
|
2233
|
+
for stage in cutlass.range_constexpr(num_stages):
|
|
2234
|
+
tdPrdP_t2r = cute.make_fragment(tScS_t2r[None, 0, None, None].shape, Float32)
|
|
2235
|
+
cute.copy(thr_copy_t2r, tdPtdP_t2r[None, stage, None, None], tdPrdP_t2r)
|
|
2236
|
+
cute.arch.fence_view_async_tmem_load()
|
|
2237
|
+
self.compute_sync_barrier.arrive_and_wait()
|
|
2238
|
+
tdPrdP_cur = tdPrdP_t2r[None, 0, 0]
|
|
2239
|
+
tSrS_cur = tSrS_t2r[None, stage, 0, 0]
|
|
2240
|
+
tSsdPsum_cur = tSsdPsum[None, stage, 0, 0, consumer_state_dPsum.index]
|
|
2241
|
+
if const_expr(not self.shuffle_dPsum):
|
|
2242
|
+
tSrdPsum = cute.make_fragment_like(tSsdPsum_cur, Float32)
|
|
2243
|
+
cute.autovec_copy(tSsdPsum_cur, tSrdPsum)
|
|
2244
|
+
else:
|
|
2245
|
+
tSrdPsum = tSsdPsum_cur[lane_idx]
|
|
2246
|
+
for v in cutlass.range_constexpr(cute.size(tdPrdP_t2r, mode=[0]) // 2):
|
|
2247
|
+
if const_expr(not self.shuffle_dPsum):
|
|
2248
|
+
dPsum_pair = (tSrdPsum[2 * v], tSrdPsum[2 * v + 1])
|
|
2249
|
+
else:
|
|
2250
|
+
dPsum_pair = (
|
|
2251
|
+
utils.shuffle_sync(tSrdPsum, offset=2 * v),
|
|
2252
|
+
utils.shuffle_sync(tSrdPsum, offset=2 * v + 1),
|
|
2253
|
+
)
|
|
2254
|
+
tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.sub_packed_f32x2(
|
|
2255
|
+
(tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]), dPsum_pair
|
|
2256
|
+
)
|
|
2257
|
+
tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1] = utils.mul_packed_f32x2(
|
|
2258
|
+
(tSrS_cur[2 * v], tSrS_cur[2 * v + 1]),
|
|
2259
|
+
(tdPrdP_cur[2 * v], tdPrdP_cur[2 * v + 1]),
|
|
2260
|
+
)
|
|
2261
|
+
|
|
2262
|
+
if const_expr(self.score_mod_bwd is not None):
|
|
2263
|
+
tSrS_pre_cur = tSrS_pre[None, stage, 0, 0]
|
|
2264
|
+
cS_bwd = cute.make_identity_tensor((self.tile_n, self.tile_m))
|
|
2265
|
+
cS_bwd = cute.domain_offset(
|
|
2266
|
+
(n_block * self.tile_n, m_block * self.tile_m), cS_bwd
|
|
2267
|
+
)
|
|
2268
|
+
tScS_bwd = thr_mma_S.partition_C(cS_bwd)
|
|
2269
|
+
tScS_idx_bwd = thr_copy_t2r.partition_D(tScS_bwd)
|
|
2270
|
+
tScS_idx_cur = tScS_idx_bwd[None, stage, 0, 0]
|
|
2271
|
+
self.apply_score_mod_bwd(
|
|
2272
|
+
tdPrdP_cur,
|
|
2273
|
+
tSrS_pre_cur,
|
|
2274
|
+
tScS_idx_cur,
|
|
2275
|
+
batch_idx,
|
|
2276
|
+
head_idx,
|
|
2277
|
+
softmax_scale,
|
|
2278
|
+
seqlen,
|
|
2279
|
+
aux_tensors,
|
|
2280
|
+
fastdiv_mods,
|
|
2281
|
+
)
|
|
2282
|
+
# Zero out OOB positions (kv_idx >= seqlen_k) after score_mod_bwd
|
|
2283
|
+
for i in cutlass.range(cute.size(tdPrdP_cur), unroll_full=True):
|
|
2284
|
+
kv_idx = tScS_idx_cur[i][0]
|
|
2285
|
+
tdPrdP_cur[i] = 0.0 if kv_idx >= seqlen.seqlen_k else tdPrdP_cur[i]
|
|
2286
|
+
|
|
2287
|
+
tdPrdS_cvt = cute.make_fragment_like(tdPrdP_cur, self.ds_dtype)
|
|
2288
|
+
utils.cvt_f16(tdPrdP_cur, tdPrdS_cvt)
|
|
2289
|
+
if const_expr(stage == 0):
|
|
2290
|
+
pipeline_dS.producer_acquire(producer_state_dS)
|
|
2291
|
+
cute.autovec_copy(tdPrdS_cvt, tRS_sdS[None, stage])
|
|
2292
|
+
if const_expr(not self.use_smem_dS_for_mma_dK):
|
|
2293
|
+
tdPrdS_r2t_f32 = cute.recast_tensor(tdPrdS_cvt, Float32)
|
|
2294
|
+
cute.copy(thr_copy_r2t, tdPrdS_r2t_f32, tdPtdS_r2t[None, stage, 0, 0])
|
|
2295
|
+
|
|
2296
|
+
if const_expr(not self.use_smem_dS_for_mma_dK):
|
|
2297
|
+
cute.arch.fence_view_async_tmem_store()
|
|
2298
|
+
cute.arch.fence_proxy(
|
|
2299
|
+
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
|
2300
|
+
)
|
|
2301
|
+
self.compute_sync_barrier.arrive_and_wait()
|
|
2302
|
+
|
|
2303
|
+
# with cute.arch.elect_one():
|
|
2304
|
+
# The mma warp no longer waits for dP (it waits for dS), so we don't have to arrive
|
|
2305
|
+
# pipeline_dP.sync_object_empty.arrive(0, pipeline_dP.consumer_mask)
|
|
2306
|
+
pipeline_dPsum.consumer_release(consumer_state_dPsum)
|
|
2307
|
+
consumer_state_dPsum.advance()
|
|
2308
|
+
with cute.arch.elect_one():
|
|
2309
|
+
pipeline_dS.producer_commit(producer_state_dS)
|
|
2310
|
+
producer_state_dS.advance()
|
|
2311
|
+
|
|
2312
|
+
# Epilogue
|
|
2313
|
+
# Run epilogue if we processed any m_blocks for this n_block
|
|
2314
|
+
if process_tile:
|
|
2315
|
+
if const_expr(not self.use_tma_store):
|
|
2316
|
+
consumer_state_dKV = self.epilogue_dKV(
|
|
2317
|
+
dp_idx,
|
|
2318
|
+
warp_idx,
|
|
2319
|
+
batch_idx,
|
|
2320
|
+
head_idx,
|
|
2321
|
+
n_block,
|
|
2322
|
+
seqlen,
|
|
2323
|
+
thr_mma_dV,
|
|
2324
|
+
thr_mma_dK,
|
|
2325
|
+
tdVtdV,
|
|
2326
|
+
tdKtdK,
|
|
2327
|
+
mdV,
|
|
2328
|
+
mdK,
|
|
2329
|
+
pipeline_dKV,
|
|
2330
|
+
consumer_state_dKV,
|
|
2331
|
+
softmax_scale,
|
|
2332
|
+
)
|
|
2333
|
+
else:
|
|
2334
|
+
thr_copy_r2s_dKV = tiled_copy_r2s_dKV.get_slice(dp_idx)
|
|
2335
|
+
#### STORE dV
|
|
2336
|
+
consumer_state_dKV = self.epilogue_dK_or_dV_tma(
|
|
2337
|
+
dp_idx,
|
|
2338
|
+
batch_idx,
|
|
2339
|
+
head_idx,
|
|
2340
|
+
n_block,
|
|
2341
|
+
seqlen,
|
|
2342
|
+
thr_mma_dV,
|
|
2343
|
+
tdVtdV,
|
|
2344
|
+
mdV_tma_tensor,
|
|
2345
|
+
sdV,
|
|
2346
|
+
tma_atom_dV,
|
|
2347
|
+
thr_copy_r2s_dKV,
|
|
2348
|
+
pipeline_dKV,
|
|
2349
|
+
consumer_state_dKV,
|
|
2350
|
+
None, # Don't scale
|
|
2351
|
+
int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id
|
|
2352
|
+
mdV_semaphore,
|
|
2353
|
+
)
|
|
2354
|
+
#### STORE dK
|
|
2355
|
+
consumer_state_dKV = self.epilogue_dK_or_dV_tma(
|
|
2356
|
+
dp_idx,
|
|
2357
|
+
batch_idx,
|
|
2358
|
+
head_idx,
|
|
2359
|
+
n_block,
|
|
2360
|
+
seqlen,
|
|
2361
|
+
thr_mma_dK,
|
|
2362
|
+
tdKtdK,
|
|
2363
|
+
mdK_tma_tensor,
|
|
2364
|
+
sdK,
|
|
2365
|
+
tma_atom_dK,
|
|
2366
|
+
thr_copy_r2s_dKV,
|
|
2367
|
+
pipeline_dKV,
|
|
2368
|
+
consumer_state_dKV,
|
|
2369
|
+
softmax_scale if const_expr(not self.dKV_postprocess) else None,
|
|
2370
|
+
int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id
|
|
2371
|
+
mdK_semaphore,
|
|
2372
|
+
)
|
|
2373
|
+
# Zero dK/dV for empty tiles (local attention or block sparsity)
|
|
2374
|
+
# When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile
|
|
2375
|
+
if const_expr(not self.dKV_postprocess):
|
|
2376
|
+
should_zero_dKV = False
|
|
2377
|
+
if const_expr(self.is_local or self.is_varlen_q):
|
|
2378
|
+
should_zero_dKV = m_block_min >= m_block_max
|
|
2379
|
+
if const_expr(self.use_block_sparsity):
|
|
2380
|
+
# For block sparsity, zero when no m_blocks contribute to this n_block
|
|
2381
|
+
if not process_tile:
|
|
2382
|
+
should_zero_dKV = True
|
|
2383
|
+
|
|
2384
|
+
if should_zero_dKV:
|
|
2385
|
+
# like the other epilogues, this currently assumes hdim == hdimv
|
|
2386
|
+
gmem_tiled_copy_zero_dKV = copy_utils.tiled_copy_2d(
|
|
2387
|
+
self.dk_dtype,
|
|
2388
|
+
self.tile_hdim,
|
|
2389
|
+
128, # num_threads
|
|
2390
|
+
)
|
|
2391
|
+
gmem_thr_copy_zero_dKV = gmem_tiled_copy_zero_dKV.get_slice(dp_idx)
|
|
2392
|
+
mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx]
|
|
2393
|
+
mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx]
|
|
2394
|
+
gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
|
|
2395
|
+
gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
|
|
2396
|
+
tdKgdK = gmem_thr_copy_zero_dKV.partition_D(gdK)
|
|
2397
|
+
tdVgdV = gmem_thr_copy_zero_dKV.partition_D(gdV)
|
|
2398
|
+
assert tdKgdK.shape[2] == 1
|
|
2399
|
+
assert tdVgdV.shape[2] == 1
|
|
2400
|
+
cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim))
|
|
2401
|
+
tdKVcdKV = gmem_thr_copy_zero_dKV.partition_D(cdKV)
|
|
2402
|
+
zero = cute.make_fragment_like(tdKgdK[None, 0, 0])
|
|
2403
|
+
zero.fill(0.0)
|
|
2404
|
+
if tidx < 128:
|
|
2405
|
+
for i in cutlass.range_constexpr(tdKgdK.shape[1]):
|
|
2406
|
+
row_idx = tdKVcdKV[0, i, 0][0]
|
|
2407
|
+
if row_idx < seqlen.seqlen_k - self.tile_n * n_block:
|
|
2408
|
+
cute.copy(gmem_tiled_copy_zero_dKV, zero, tdKgdK[None, i, 0])
|
|
2409
|
+
else:
|
|
2410
|
+
for i in cutlass.range_constexpr(tdVgdV.shape[1]):
|
|
2411
|
+
row_idx = tdKVcdKV[0, i, 0][0]
|
|
2412
|
+
if row_idx < seqlen.seqlen_k - self.tile_n * n_block:
|
|
2413
|
+
cute.copy(gmem_tiled_copy_zero_dKV, zero, tdVgdV[None, i, 0])
|
|
2414
|
+
|
|
2415
|
+
tile_scheduler.advance_to_next_work()
|
|
2416
|
+
work_tile = tile_scheduler.get_current_work()
|
|
2417
|
+
|
|
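The pointwise work in compute_loop is the softmax backward identity: P is recomputed row-wise from the stored LSE (done above with packed fma and exp2), and the gradient through the softmax is dS = P * (dP - Delta), where Delta is the row sum of P*dP and equals the rowsum(dO*O) that the preprocess step provides as dPsum. A NumPy check of that identity against the explicit softmax Jacobian for a single row:

```python
import numpy as np

rng = np.random.default_rng(1)
s = rng.standard_normal(8)          # one row of (scaled) scores
dP = rng.standard_normal(8)         # upstream gradient w.r.t. the softmax output P

lse = np.log(np.exp(s).sum())
P = np.exp(s - lse)                 # row-wise softmax recomputed from LSE

# Closed form used above: dS = P * (dP - Delta), Delta = sum_j P_j * dP_j
delta = (P * dP).sum()
dS_closed = P * (dP - delta)

# Explicit softmax Jacobian: J_ij = dP_i/ds_j = P_i * (delta_ij - P_j)
J = np.diag(P) - np.outer(P, P)
dS_jacobian = J.T @ dP

assert np.allclose(dS_closed, dS_jacobian)
```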
2418
|
+
@cute.jit
|
|
2419
|
+
def dQacc_reduce(
|
|
2420
|
+
self,
|
|
2421
|
+
mdQaccum: cute.Tensor,
|
|
2422
|
+
sdQaccum: cute.Tensor,
|
|
2423
|
+
thr_mma_dQ: cute.core.ThrMma,
|
|
2424
|
+
tdQtdQ: cute.Tensor,
|
|
2425
|
+
pipeline_dQ: PipelineAsync,
|
|
2426
|
+
block_info: BlockInfo,
|
|
2427
|
+
SeqlenInfoCls: Callable,
|
|
2428
|
+
TileSchedulerCls: Callable,
|
|
2429
|
+
mdQ_semaphore: Optional[cute.Tensor],
|
|
2430
|
+
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
|
2431
|
+
):
|
|
2432
|
+
num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids)
|
|
2433
|
+
+        tidx = cute.arch.thread_idx()[0] % num_reduce_threads
+        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx() % len(self.reduce_warp_ids))
+        is_tma_warp = warp_idx == 0
+        # TMEM -> RMEM
+        tmem_load_atom = cute.make_copy_atom(
+            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(self.dQ_reduce_ncol)), Float32
+        )
+        thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdQtdQ).get_slice(tidx)
+        tdQtdQ_t2r = thr_copy_t2r.partition_S(tdQtdQ)
+        tdQcdQ = thr_mma_dQ.partition_C(cute.make_identity_tensor(self.mma_tiler_dsk[:2]))
+        tdQrdQ_t2r_shape = thr_copy_t2r.partition_D(tdQcdQ).shape
+        assert cute.size(tdQrdQ_t2r_shape, mode=[1]) == self.dQaccum_reduce_stage, (
+            "dQaccum reduce stage mismatch"
+        )
+
+        thr_copy_dQaccum_r2s = copy_utils.tiled_copy_1d(
+            self.dqaccum_dtype, num_reduce_threads, num_copy_elems=128 // self.dqaccum_dtype.width
+        ).get_slice(tidx)
+        tdQsdQ = thr_copy_dQaccum_r2s.partition_D(sdQaccum)
+
+        read_flag = const_expr(not self.deterministic)
+
+        tile_scheduler = TileSchedulerCls()
+        work_tile = tile_scheduler.initial_work_tile_info()
+        dQ_consumer_state = pipeline.make_pipeline_state(
+            cutlass.pipeline.PipelineUserType.Consumer, 1
+        )
+        dQ_tma_store_producer_state = pipeline.make_pipeline_state(
+            pipeline.PipelineUserType.Producer, self.sdQaccum_stage
+        )
+        while work_tile.is_valid_tile:
+            n_block, head_idx, batch_idx, _ = work_tile.tile_idx
+            seqlen = SeqlenInfoCls(batch_idx)
+            m_block_min, m_block_max = block_info.get_m_block_min_max(
+                seqlen, n_block // self.cluster_shape_mnk[0]
+            )
+            if const_expr(not seqlen.has_cu_seqlens_q):
+                mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
+            else:
+                mdQaccum_cur = cute.domain_offset(
+                    (seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx]
+                )
+            gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,))
+            # (M * K / STAGE, STAGE, _)
+            gdQaccum = cute.flat_divide(
+                gdQaccum_, (self.tile_m * self.tile_hdim // self.dQaccum_reduce_stage,)
+            )
+
+            if const_expr(self.deterministic):
+                mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx]
+
+            delay_semaphore_release = self.is_causal
+            n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n)
+
+            # some tiles might be empty due to block sparsity
+            if const_expr(self.use_block_sparsity):
+                (
+                    curr_q_cnt,
+                    curr_q_idx,
+                    curr_full_cnt,
+                    curr_full_idx,
+                    loop_count,
+                ) = get_block_sparse_iteration_info_bwd(
+                    blocksparse_tensors,
+                    batch_idx,
+                    head_idx,
+                    n_block,
+                    subtile_factor=self.subtile_factor,
+                    m_block_max=m_block_max,
+                )
+                process_tile = loop_count > Int32(0)
+            else:
+                process_tile = (
+                    const_expr(not self.is_local and not self.is_varlen_q)
+                    or m_block_min < m_block_max
+                )
+                loop_count = m_block_max - m_block_min
+
+            # dQacc_reduce mainloop
+            # Block sparsity: iterate over sparse m_block count and derive actual m_block
+            # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly.
+            for iter_idx in cutlass.range(loop_count, unroll=1):
+                if const_expr(self.use_block_sparsity):
+                    m_block, _ = get_m_block_from_iter_bwd(
+                        iter_idx,
+                        curr_q_cnt,
+                        curr_q_idx,
+                        curr_full_cnt,
+                        curr_full_idx,
+                        subtile_factor=self.subtile_factor,
+                        m_block_max=m_block_max,
+                    )
+                    if m_block_max > 0:
+                        m_block = cutlass.min(m_block, m_block_max - 1)
+                else:
+                    m_block = m_block_min + iter_idx
+                pipeline_dQ.consumer_wait(dQ_consumer_state)
+                # TMEM -> RMEM
+                tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32)
+                cute.copy(thr_copy_t2r, tdQtdQ_t2r, tdQrdQ_t2r)
+                cute.arch.fence_view_async_tmem_load()
+                cute.arch.sync_warp()
+                with cute.arch.elect_one():
+                    pipeline_dQ.consumer_release(dQ_consumer_state)
+                dQ_consumer_state.advance()
+
+                gdQaccum_cur = gdQaccum[None, None, m_block]
+
+                for stage in cutlass.range_constexpr(cute.size(tdQrdQ_t2r, mode=[1])):  # 4
+                    smem_idx = dQ_tma_store_producer_state.index
+                    tdQsdQ_r2s = tdQsdQ[None, None, smem_idx]
+                    tdQrdQ_r2s = cute.make_tensor(
+                        tdQrdQ_t2r[None, stage, None, None].iterator, tdQsdQ_r2s.shape
+                    )
+                    cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s)
+                    # Fence and barrier to make sure shared memory store is visible to TMA store
+                    cute.arch.fence_proxy(
+                        cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+                    )
+                    # semaphore acquire
+                    if const_expr(self.deterministic and stage == 0):
+                        if const_expr(self.spt):
+                            if const_expr(
+                                self.is_causal or block_info.window_size_right is not None
+                            ):
+                                n_idx_right = (
+                                    (m_block + 1) * self.tile_m + seqlen.seqlen_k - seqlen.seqlen_q
+                                )
+                                if const_expr(block_info.window_size_right is not None):
+                                    n_idx_right += block_info.window_size_right
+                                n_block_max_for_m_block = min(
+                                    n_block_global_max,
+                                    cute.ceil_div(n_idx_right, self.tile_n),
+                                )
+                            else:
+                                n_block_max_for_m_block = n_block_global_max
+                            lock_value = n_block_max_for_m_block - 1 - n_block
+                        else:
+                            lock_value = n_block
+                        barrier.wait_eq(
+                            mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, lock_value
+                        )
+                    self.reduce_sync_barrier.arrive_and_wait()
+                    # Copy from shared memory to global memory
+                    if is_tma_warp:
+                        with cute.arch.elect_one():
+                            copy_utils.cpasync_reduce_bulk_add_f32(
+                                sdQaccum[None, smem_idx].iterator,
+                                gdQaccum_cur[None, stage].iterator,
+                                self.tma_copy_bytes["dQ"] // 1,
+                            )
+                        cute.arch.cp_async_bulk_commit_group()
+                        cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag)
+                    self.reduce_sync_barrier.arrive_and_wait()
+                    dQ_tma_store_producer_state.advance()
+                    # Directly add to gmem, much slower
+                    # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block])
+                    # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ)
+                    # for i in cutlass.range(cute.size(tdQrdQ_r2s) // 4, unroll_full=True):
+                    #     copy_utils.atomic_add_fp32x4(
+                    #         tdQrdQ_r2s[4 * i],
+                    #         tdQrdQ_r2s[4 * i + 1],
+                    #         tdQrdQ_r2s[4 * i + 2],
+                    #         tdQrdQ_r2s[4 * i + 3],
+                    #         utils.elem_pointer(tdQgdQ, 4 * i),
+                    #     )
+                    # semaphore release for prior m_block
+                    if const_expr(self.deterministic and stage == 0 and delay_semaphore_release):
+                        if m_block > m_block_min:
+                            barrier.arrive_inc(
+                                mdQ_semaphore_cur[(m_block - 1, None)].iterator, tidx, 0, 1
+                            )
+
+                # semaphore release
+                # NOTE: arrive_inc calls red_release which issues membar
+                if const_expr(self.deterministic and not delay_semaphore_release):
+                    if is_tma_warp:
+                        cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
+                    self.reduce_sync_barrier.arrive_and_wait()
+                    barrier.arrive_inc(mdQ_semaphore_cur[m_block, None].iterator, tidx, 0, 1)
+
+            if const_expr(not self.is_local) or m_block_min < m_block_max:
+                if is_tma_warp:
+                    cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
+                self.reduce_sync_barrier.arrive_and_wait()
+                # final semaphore release
+                if const_expr(self.deterministic and delay_semaphore_release):
+                    barrier.arrive_inc(
+                        mdQ_semaphore_cur[(m_block_max - 1, None)].iterator, tidx, 0, 1
+                    )
+
+            if const_expr(
+                self.deterministic and not self.spt and block_info.window_size_left is not None
+            ):
+                m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m)
+                for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1):
+                    barrier.arrive_inc(mdQ_semaphore_cur[(m_block, None)].iterator, tidx, 0, 1)
+
+            tile_scheduler.advance_to_next_work()
+            work_tile = tile_scheduler.get_current_work()
+
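Editor's note: the dQaccum reduce mainloop above stages each accumulator slice TMEM -> RMEM -> SMEM and then folds it into global memory with a bulk cp.async f32 add; when `self.deterministic` is set, the per-`m_block` semaphore (`barrier.wait_eq` / `barrier.arrive_inc`) forces the `n_block` tiles to perform those adds in a fixed order. The sketch below is a host-side analog in plain Python/NumPy, not the cute DSL: the sizes, the worker threads, and the `tickets` list are invented for illustration, and only the ticket-ordering idea mirrors the kernel's deterministic path.

```python
import threading

import numpy as np

# Hypothetical tile counts/sizes, chosen only to keep the example small.
num_m_blocks, num_n_blocks, tile_m, head_dim = 4, 8, 16, 32

rng = np.random.default_rng(0)
partials = rng.standard_normal(
    (num_n_blocks, num_m_blocks, tile_m, head_dim)
).astype(np.float32)

dq_accum = np.zeros((num_m_blocks, tile_m, head_dim), dtype=np.float32)
tickets = [0] * num_m_blocks  # per-m_block "semaphore" value
cond = threading.Condition()


def reduce_worker(n_block: int) -> None:
    # Each worker owns one n_block and adds its partial dQ into every m_block,
    # but only once the per-m_block ticket equals its n_block index
    # (the analog of barrier.wait_eq(..., lock_value) in the kernel).
    for m_block in range(num_m_blocks):
        with cond:
            cond.wait_for(lambda: tickets[m_block] == n_block)
            dq_accum[m_block] += partials[n_block, m_block]
            tickets[m_block] += 1  # analog of barrier.arrive_inc(...)
            cond.notify_all()


threads = [threading.Thread(target=reduce_worker, args=(nb,)) for nb in range(num_n_blocks)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Same order of additions on every run -> bitwise-reproducible dQ.
reference = np.zeros_like(dq_accum)
for nb in range(num_n_blocks):
    reference += partials[nb]
assert np.array_equal(dq_accum, reference)
```

Serializing the adds per m_block costs some overlap between tiles, but it makes the accumulated dQ independent of how the tile scheduler orders the work, which is the point of the deterministic path.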
+    @cute.jit
+    def epilogue_dKV(
+        self,
+        tidx: Int32,
+        warp_idx: Int32,
+        batch_idx: Int32,
+        head_idx: Int32,
+        n_block: Int32,
+        seqlen,
+        thr_mma_dV: cute.core.ThrMma,
+        thr_mma_dK: cute.core.ThrMma,
+        tdVtdV: cute.Tensor,
+        tdKtdK: cute.Tensor,
+        mdV: cute.Tensor,
+        mdK: cute.Tensor,
+        pipeline_dKV: PipelineAsync,
+        consumer_state_dKV: cutlass.pipeline.PipelineState,
+        softmax_scale: Float32,
+    ):
+        wg_idx = (
+            cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.compute_warp_ids))
+        ) // 128
+        num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128
+
+        assert self.qhead_per_kvhead == 1, "This epilogue path is only for MHA"
+        mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx]
+        mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx]
+
+        tmem_load_atom = cute.make_copy_atom(
+            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32
+        )
+
+        # dV
+        pipeline_dKV.consumer_wait(consumer_state_dKV)
+
+        tiled_tmem_ld_dV = tcgen05.make_tmem_copy(tmem_load_atom, tdVtdV)
+        thr_tmem_ld_dV = tiled_tmem_ld_dV.get_slice(tidx)
+
+        tdVtdV_t2r_p = thr_tmem_ld_dV.partition_S(tdVtdV)
+        tdVtdV_t2r = self.split_wg(tdVtdV_t2r_p, wg_idx, num_wg)
+
+        cdV = cute.make_identity_tensor((self.mma_tiler_pdo[0], self.mma_tiler_pdo[1]))
+        tdVcdV = thr_mma_dV.partition_C(cdV)
+        tdVcdV_tensor = cute.make_tensor(tdVcdV.iterator, tdVcdV.layout)
+
+        tdVcdV_t2r_p = thr_tmem_ld_dV.partition_D(tdVcdV_tensor)
+        tdVcdV_t2r = self.split_wg(tdVcdV_t2r_p, wg_idx, num_wg)
+        tdVrdV_t2r = cute.make_fragment(tdVcdV_t2r.shape, Float32)
+
+        cute.copy(thr_tmem_ld_dV, tdVtdV_t2r, tdVrdV_t2r)
+        cute.arch.fence_view_async_tmem_load()
+
+        universal_copy_bits = 128
+        atom_universal_copy = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(),
+            self.dv_dtype,
+            num_bits_per_copy=universal_copy_bits,
+        )
+        tiled_gmem_store_dV = cute.make_tiled_copy(
+            atom_universal_copy,
+            layout_tv=tiled_tmem_ld_dV.layout_dst_tv_tiled,
+            tiler_mn=tiled_tmem_ld_dV.tiler_mn,
+        )
+
+        tdVrdV_r2s = cute.make_fragment(tdVrdV_t2r.shape, self.dv_dtype)
+        for i in cutlass.range_constexpr(cute.size(tdVrdV_t2r, mode=[1])):
+            dV_vec = tdVrdV_t2r[(None, i, 0, 0)].load()
+            tdVrdV_r2s[(None, i, 0, 0)].store(dV_vec.to(self.dv_dtype))
+
+        gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (None, 0))
+        gdV_tile = gdV[None, None, n_block]
+
+        tdVgdV = thr_mma_dV.partition_C(gdV_tile)
+        tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV)
+        tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg)
+
+        if tidx < seqlen.seqlen_k - self.tile_n * n_block:
+            cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g)
+
+        cute.arch.sync_warp()
+        with cute.arch.elect_one():
+            pipeline_dKV.consumer_release(consumer_state_dKV)
+        consumer_state_dKV.advance()
+
+        # dK
+        pipeline_dKV.consumer_wait(consumer_state_dKV)
+
+        tiled_tmem_ld_dK = tcgen05.make_tmem_copy(tmem_load_atom, tdKtdK)
+        thr_tmem_ld_dK = tiled_tmem_ld_dK.get_slice(tidx)
+
+        tdKtdK_t2r_p = thr_tmem_ld_dK.partition_S(tdKtdK)
+        tdKtdK_t2r = self.split_wg(tdKtdK_t2r_p, wg_idx, num_wg)
+
+        cdK = cute.make_identity_tensor((self.mma_tiler_dsq[0], self.mma_tiler_dsq[1]))
+        tdKcdK = thr_mma_dK.partition_C(cdK)
+        tdKcdK_tensor = cute.make_tensor(tdKcdK.iterator, tdKcdK.layout)
+
+        tdKcdK_t2r_p = thr_tmem_ld_dK.partition_D(tdKcdK_tensor)
+        tdKcdK_t2r = self.split_wg(tdKcdK_t2r_p, wg_idx, num_wg)
+        tdKrdK_t2r = cute.make_fragment(tdKcdK_t2r.shape, Float32)
+
+        cute.copy(tiled_tmem_ld_dK, tdKtdK_t2r, tdKrdK_t2r)
+        cute.arch.fence_view_async_tmem_load()
+
+        universal_copy_bits = 128
+        atom_universal_copy = cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(),
+            self.dk_dtype,
+            num_bits_per_copy=universal_copy_bits,
+        )
+
+        tiled_gmem_store_dK = cute.make_tiled_copy(
+            atom_universal_copy,
+            layout_tv=tiled_tmem_ld_dK.layout_dst_tv_tiled,
+            tiler_mn=tiled_tmem_ld_dK.tiler_mn,
+        )
+
+        tdKrdK_r2s = cute.make_fragment(tdKrdK_t2r.shape, self.dk_dtype)
+
+        for i in cutlass.range_constexpr(cute.size(tdKrdK_t2r, mode=[1])):
+            dK_vec = tdKrdK_t2r[(None, i, 0, 0)].load() * softmax_scale
+            tdKrdK_r2s[(None, i, 0, 0)].store(dK_vec.to(self.dk_dtype))
+
+        gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdimv), (None, 0))
+        gdK_tile = gdK[None, None, n_block]
+
+        tdKgdK = thr_mma_dK.partition_C(gdK_tile)
+        tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK)
+        tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg)
+
+        if tidx < seqlen.seqlen_k - self.tile_n * n_block:
+            cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g)
+
+        cute.arch.sync_warp()
+        with cute.arch.elect_one():
+            pipeline_dKV.consumer_release(consumer_state_dKV)
+        consumer_state_dKV.advance()
+        return consumer_state_dKV
+
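Editor's note: `epilogue_dKV` pulls the dV and dK accumulators out of tensor memory, multiplies only dK by the softmax scale, converts to the output dtype, and predicates the global store so that only rows inside `seqlen_k` are written for the last `n_block`. Below is a minimal NumPy analog of that post-processing step; the concrete shapes, the float16 output type, and the variable names are assumptions for illustration (the real output types come from `dk_dtype` / `dv_dtype`).

```python
import numpy as np

# Illustrative sizes only; the real tile shapes come from the kernel config.
tile_n, head_dim = 128, 64
softmax_scale = 0.125
seqlen_k, n_block = 300, 2  # the tile at n_block=2 is only partially valid

rng = np.random.default_rng(0)
dV_acc = rng.standard_normal((tile_n, head_dim)).astype(np.float32)
dK_acc = rng.standard_normal((tile_n, head_dim)).astype(np.float32)

dV_out = np.zeros((seqlen_k, head_dim), dtype=np.float16)
dK_out = np.zeros((seqlen_k, head_dim), dtype=np.float16)

# dV is stored as-is; dK is scaled before the down-cast, mirroring
# `dK_vec = tdKrdK_t2r[...].load() * softmax_scale` above.
row0 = tile_n * n_block
rows_valid = min(tile_n, seqlen_k - row0)  # analog of the tidx < seqlen_k - tile_n * n_block guard
dV_out[row0:row0 + rows_valid] = dV_acc[:rows_valid].astype(np.float16)
dK_out[row0:row0 + rows_valid] = (dK_acc[:rows_valid] * softmax_scale).astype(np.float16)
```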
+    @cute.jit
+    def epilogue_dK_or_dV_tma(
+        self,
+        tidx: Int32,
+        batch_idx: Int32,
+        head_idx: Int32,
+        n_block: Int32,
+        seqlen,
+        thr_mma: cute.core.ThrMma,
+        tdKVtdKV: cute.Tensor,
+        mdKV: cute.Tensor,
+        sdKV: cute.Tensor,
+        tma_atom_dKV: cute.CopyAtom,
+        thr_copy_r2s_dKV: cute.TiledCopy,
+        pipeline_dKV: PipelineAsync,
+        consumer_state_dKV: cutlass.pipeline.PipelineState,
+        scale: Optional[Float32],
+        barrier_id: Int32,
+        mdKV_semaphore: Optional[cute.Tensor],
+    ) -> cutlass.pipeline.PipelineState:
+        # assumes mma_tiler_pdo = mma_tiler_dsq = (tile_n, head_dim)
+        # head_dim = head_dim_v, dk_dtype = dv_dtype
+        num_compute_threads = cute.arch.WARP_SIZE * len(self.compute_warp_ids)
+        wg_idx = (cute.arch.thread_idx()[0] % num_compute_threads) // 128
+        num_wg = num_compute_threads // 128
+        leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0
+
+        if const_expr(not self.dKV_postprocess):
+            sdKV = sdKV[None, None, wg_idx]  # (tile_n, 64) for bf16
+        else:
+            sdKV = sdKV[None, wg_idx]  # (tile_n * 32) for fp32
+
+        # (8, tile_n / 128, 64 / 8) = (8, 1, 8) or (4, tile_n * 32 / (128 * 4)) = (4, 8)
+        tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV)
+
+        head_idx_kv = head_idx // self.qhead_per_kvhead
+        if const_expr(not self.dKV_postprocess):
+            assert not seqlen.has_cu_seqlens_k, "varlen uses non tma store path"
+            mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx]  # (seqlen, hdim)
+            gdKV_p = cute.local_tile(
+                mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0)
+            )  # (tile_n, hdim)
+            gdKV = self.split_wg(gdKV_p, wg_idx, num_wg)  # (tile_n, hdim / 2)
+            gdKV_epi = cute.local_tile(
+                gdKV, self.sdKV_epi_tile, (0, None)
+            )  # (tile_n, 64, epi_stage = (hdim / 2) / 64)
+        else:
+            if const_expr(not seqlen.has_cu_seqlens_k):
+                mdKV_cur = mdKV[None, head_idx_kv, batch_idx]  # (seqlen * hdim)
+            else:
+                mdKV_cur = cute.domain_offset(
+                    (seqlen.padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv]
+                )
+            gdKV_p = cute.local_tile(
+                mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,)
+            )  # (tile_n * hdim)
+            gdKV = cute.logical_divide(gdKV_p, (self.tile_n * self.tile_hdim // num_wg,))[
+                ((None, wg_idx),)
+            ]  # (tile_n * hdim / 2)
+            gdKV_epi = cute.flat_divide(
+                gdKV, (self.sdKV_flat_epi_tile,)
+            )  # (tile_n * hdim / 2 / epi_stage, epi_stage)
+
+        deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1
+        if const_expr(deterministic_KV):
+            mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx]
+
+        if const_expr(not self.dKV_postprocess):
+            tdKVsdKV, tdKVgdKV = cpasync.tma_partition(
+                tma_atom_dKV,
+                0,  # no multicast
+                cute.make_layout(1),
+                cute.group_modes(sdKV, 0, 2),
+                cute.group_modes(gdKV_epi, 0, 2),
+            )  # (TMA) and (TMA, EPI_STAGE)
+            assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV"
+            assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV"
+            num_epi_stages = cute.size(tdKVgdKV.shape[1])
+            assert num_epi_stages == self.num_epi_stages, "Epi stage calculation is wrong"
+        else:
+            num_epi_stages = self.num_epi_stages
+
+        tmem_load_atom = cute.make_copy_atom(
+            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32
+        )
+
+        read_flag = const_expr(not deterministic_KV)
+
+        pipeline_dKV.consumer_wait(consumer_state_dKV)
+
+        # semaphore acquire
+        if const_expr(deterministic_KV):
+            barrier.wait_eq(
+                mdKV_semaphore_cur.iterator, tidx, wg_idx, head_idx % self.qhead_per_kvhead
+            )
+            cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)
+
+        for epi_stage in cutlass.range_constexpr(num_epi_stages):
+            # TMEM -> RMEM -- setup
+            thr_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tdKVtdKV).get_slice(tidx)
+            tdKVtdKV_t2r_p = thr_copy_t2r.partition_S(tdKVtdKV)
+            tdKVtdKV_t2r = self.split_wg(tdKVtdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0]
+            if const_expr(num_epi_stages > 1):
+                tdKVtdKV_t2r = tdKVtdKV_t2r[None, epi_stage]
+
+            cdKV = cute.make_identity_tensor((self.tile_n, self.tile_hdim))
+            tdKVcdKV = thr_mma.partition_C(cdKV)
+            tdKVcdKV_t2r_p = thr_copy_t2r.partition_D(tdKVcdKV)
+            tdKVcdKV_t2r = self.split_wg(tdKVcdKV_t2r_p, wg_idx, num_wg)[None, None, 0, 0]
+            if const_expr(num_epi_stages > 1):
+                tdKVcdKV_t2r = tdKVcdKV_t2r[None, epi_stage]
+
+            tdKVrdKV_t2r = cute.make_fragment(tdKVcdKV_t2r.shape, Float32)
+
+            assert cute.size(tdKVrdKV_t2r) == cute.size(tdKVtdKV_t2r) // cute.arch.WARP_SIZE, (
+                "RMEM<->TMEM fragment size mismatch"
+            )
+
+            # TMEM -> RMEM -- copy and fence
+            cute.copy(thr_copy_t2r, tdKVtdKV_t2r, tdKVrdKV_t2r)
+            cute.arch.fence_view_async_tmem_load()
+
+            # RMEM -- scale and convert
+            if const_expr(scale is not None):
+                for i in cutlass.range(cute.size(tdKVrdKV_t2r.shape) // 2, unroll_full=True):
+                    tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = utils.mul_packed_f32x2(
+                        (tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale)
+                    )
+            tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype)  # (32 columns)
+            tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype))
+
+            # RMEM -> SMEM -- copy, fence and barrier
+            tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape)
+            cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s)
+            cute.arch.fence_proxy(
+                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+            )
+            cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)
+
+            # SMEM -> GMEM
+            if leader_warp:
+                if const_expr(not self.dKV_postprocess):
+                    cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage])
+                else:
+                    with cute.arch.elect_one():
+                        copy_utils.cpasync_reduce_bulk_add_f32(
+                            sdKV.iterator,
+                            gdKV_epi[None, epi_stage].iterator,
+                            self.tma_copy_bytes["dKacc"],
+                        )
+                if const_expr(epi_stage < num_epi_stages - 1):
+                    cute.arch.cp_async_bulk_commit_group()
+                    cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
+                cute.arch.barrier_arrive(
+                    barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE
+                )
+
+            # Barrier since all warps need to wait for SMEM to be freed
+            cute.arch.fence_proxy(
+                cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
+            )
+            cute.arch.barrier(
+                barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE
+            )
+
+        # semaphore release
+        # NOTE: arrive_inc calls red_release which issues membar
+        if const_expr(deterministic_KV):
+            if leader_warp:
+                cute.arch.cp_async_bulk_commit_group()
+                cute.arch.cp_async_bulk_wait_group(0, read=read_flag)
+            cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)
+            barrier.arrive_inc(mdKV_semaphore_cur.iterator, tidx, wg_idx, 1)
+
+        cute.arch.sync_warp()
+        with cute.arch.elect_one():
+            pipeline_dKV.consumer_release(consumer_state_dKV)
+        consumer_state_dKV.advance()
+        return consumer_state_dKV
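Editor's note: `epilogue_dK_or_dV_tma` drains one (tile_n, head_dim) accumulator in head-dim chunks ("epi stages") through a single shared-memory staging buffer, either as a TMA store of converted values or, on the `dKV_postprocess` path, as an f32 bulk add into a dK/dV accumulation buffer. The following is a rough NumPy analog of that staged write; the 64-column stage width, the shapes, and the float16 store type are assumptions for illustration only.

```python
import numpy as np

# Illustrative shapes: one (tile_n, head_dim) accumulator drained in
# head_dim chunks ("epi stages") through one reusable staging buffer.
tile_n, head_dim, epi_cols = 128, 128, 64
num_epi_stages = head_dim // epi_cols
scale = 0.125

rng = np.random.default_rng(0)
acc = rng.standard_normal((tile_n, head_dim)).astype(np.float32)

gmem_store = np.zeros((tile_n, head_dim), dtype=np.float16)  # TMA-store path (overwrite)
gmem_accum = np.zeros((tile_n, head_dim), dtype=np.float32)  # postprocess path (add)

smem = np.empty((tile_n, epi_cols), dtype=np.float32)  # reused by every stage

for epi_stage in range(num_epi_stages):
    cols = slice(epi_stage * epi_cols, (epi_stage + 1) * epi_cols)
    # RMEM -> SMEM: scale and stage one chunk; in the kernel a barrier guards
    # this buffer before the next stage may overwrite it.
    smem[:] = acc[:, cols] * scale
    # SMEM -> GMEM, either as a plain store of the converted values ...
    gmem_store[:, cols] = smem.astype(np.float16)
    # ... or as an in-place f32 addition (the cpasync_reduce_bulk_add_f32 path).
    gmem_accum[:, cols] += smem
```

Reusing a single staging buffer across stages is what the 128 + WARP_SIZE barrier in the kernel protects: the next stage's RMEM -> SMEM copy may only start once the previous stage's store out of SMEM has completed.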