quack-kernels 0.2.4.tar.gz → 0.2.6.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {quack_kernels-0.2.4/quack_kernels.egg-info → quack_kernels-0.2.6}/PKG-INFO +2 -2
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/pyproject.toml +3 -2
- quack_kernels-0.2.6/quack/__init__.py +21 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/activation.py +72 -64
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/broadcast_utils.py +1 -1
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/copy_utils.py +143 -20
- quack_kernels-0.2.6/quack/cute_dsl_ptxas.py +151 -0
- quack_kernels-0.2.6/quack/fast_math.py +33 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_act.py +296 -8
- quack_kernels-0.2.6/quack/gemm_dact.py +731 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_default_epi.py +4 -4
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_interface.py +363 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_sm100.py +62 -88
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_sm90.py +68 -114
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_symmetric.py +2 -6
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/layout_utils.py +10 -4
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/linear.py +37 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/pipeline.py +87 -99
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/reduce.py +2 -2
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/rmsnorm.py +1 -3
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sm90_utils.py +34 -2
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sort/bitonic_sort.py +4 -4
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/tile_scheduler.py +310 -256
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/topk.py +4 -4
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/utils.py +76 -40
- {quack_kernels-0.2.4 → quack_kernels-0.2.6/quack_kernels.egg-info}/PKG-INFO +2 -2
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack_kernels.egg-info/SOURCES.txt +1 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack_kernels.egg-info/requires.txt +1 -1
- quack_kernels-0.2.6/quack_kernels.egg-info/top_level.txt +1 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_linear.py +93 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_linear_varlen_m.py +163 -0
- quack_kernels-0.2.4/quack/__init__.py +0 -11
- quack_kernels-0.2.4/quack/fast_math.py +0 -80
- quack_kernels-0.2.4/quack/gemm_dact.py +0 -215
- quack_kernels-0.2.4/quack_kernels.egg-info/top_level.txt +0 -5
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/LICENSE +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/README.md +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/autotuner.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/compile_utils.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/cross_entropy.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/cute_dsl_utils.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_config.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/gemm_wrapper_utils.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/linear_cross_entropy.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/mlp.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/reduction_base.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sm100_utils.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/softmax.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sort/generate_sorting_networks.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sort/sorting_networks.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/sort/utils.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/tensormap_manager.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack/varlen_utils.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/setup.cfg +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_cross_entropy.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_layernorm.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_linear_cross_entropy.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_linear_varlen_k.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_rmsnorm.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_softmax.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_symmetric_gemm.py +0 -0
- {quack_kernels-0.2.4 → quack_kernels-0.2.6}/tests/test_topk.py +0 -0
--- quack_kernels-0.2.4/quack_kernels.egg-info/PKG-INFO
+++ quack_kernels-0.2.6/PKG-INFO
@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.2.4
+Version: 0.2.6
 Requires-Python: >=3.10
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl
+Requires-Dist: nvidia-cutlass-dsl>=4.4.0.dev1
 Requires-Dist: torch
 Requires-Dist: apache-tvm-ffi<0.2,>=0.1.6
 Requires-Dist: torch-c-dlpack-ext
--- quack_kernels-0.2.4/pyproject.toml
+++ quack_kernels-0.2.6/pyproject.toml
@@ -7,7 +7,7 @@ name = "quack-kernels"
 dynamic = ["version"]
 requires-python = ">=3.10"
 dependencies = [
-    "nvidia-cutlass-dsl>=4.
+    "nvidia-cutlass-dsl>=4.4.0.dev1",
     "torch",
     "apache-tvm-ffi>=0.1.6,<0.2",
     "torch-c-dlpack-ext",
@@ -20,7 +20,8 @@ dev = [
 ]

 [tool.setuptools.packages.find]
-
+where = ["."]
+include = ["quack*"]

 [tool.setuptools.dynamic]
 version = {attr = "quack.__version__"}
--- /dev/null
+++ quack_kernels-0.2.6/quack/__init__.py
@@ -0,0 +1,21 @@
+__version__ = "0.2.6"
+
+import os
+
+from quack.rmsnorm import rmsnorm
+from quack.softmax import softmax
+from quack.cross_entropy import cross_entropy
+
+
+if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None:
+    import quack.cute_dsl_ptxas  # noqa: F401
+
+    # Patch to dump ptx and then use system ptxas to compile to cubin
+    quack.cute_dsl_ptxas.patch()
+
+
+__all__ = [
+    "rmsnorm",
+    "softmax",
+    "cross_entropy",
+]
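
Note: the new `quack/__init__.py` gates the ptxas patch on an environment variable checked at import time, so opting in only requires setting the variable before the first import. A minimal usage sketch (the ptxas path below is illustrative and machine-specific):

    import os

    # Must be set before `import quack`, since the check runs at import time.
    os.environ["CUTE_DSL_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas"

    import quack  # triggers quack.cute_dsl_ptxas.patch()
    print(quack.__version__)  # "0.2.6"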
--- quack_kernels-0.2.4/quack/activation.py
+++ quack_kernels-0.2.6/quack/activation.py
@@ -2,18 +2,24 @@

 import math
 from typing import Tuple
+from functools import partial

 import cutlass.cute as cute
 from cutlass import Float32, Boolean, const_expr
 from cutlass.cutlass_dsl import T, dsl_user_op
-from cutlass._mlir.dialects import llvm
-
-import quack.utils as utils
+from cutlass._mlir.dialects import llvm, nvvm


 F32_or_F32x2 = Float32 | Tuple[Float32, Float32]


+sub_packed_f32x2 = partial(
+    cute.arch.calc_packed_f32x2_op,
+    src_c=None,
+    calc_func=nvvm.sub_packed_f32x2,
+)
+
+
 @dsl_user_op
 def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
     return Float32(
@@ -35,9 +41,9 @@ def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
         # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
         return 0.5 + 0.5 * tanh(0.5 * x)
     else:
-        x_half =
+        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
         tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
-        return
+        return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))


 @dsl_user_op
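
Note: the `*_packed_f32x2` ops used throughout this file apply the scalar operation lanewise to a pair of f32 values, so each mul/fma handles two lanes at once. A plain-Python reference sketch of the sigmoid rewrite above, using the identity sigmoid(x) = 0.5 + 0.5*tanh(0.5*x):

    import math

    def mul_packed_f32x2(a, b):
        return (a[0] * b[0], a[1] * b[1])

    def fma_packed_f32x2(a, b, c):
        return (a[0] * b[0] + c[0], a[1] * b[1] + c[1])

    def sigmoid_packed(x):
        x_half = mul_packed_f32x2((0.5, 0.5), x)
        tanh_x_half = (math.tanh(x_half[0]), math.tanh(x_half[1]))
        return fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))

    for s, v in zip(sigmoid_packed((0.3, -1.2)), (0.3, -1.2)):
        assert abs(s - 1 / (1 + math.exp(-v))) < 1e-12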
@@ -75,7 +81,7 @@ def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
         return cute.arch.fmax(x, Float32(0.0)) * x
     else:
         relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
-        return
+        return cute.arch.mul_packed_f32x2(relu_x, x)


 @dsl_user_op
@@ -98,8 +104,8 @@ def drelu_sq(
         return dx, relu_sq_out
     else:
         relu_x = relu(x)
-        relu_sq_out =
-        dx =
+        relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x)
+        dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x))
         return dx, relu_sq_out

@@ -119,14 +125,14 @@ def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
             * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
         )
     else:
-        x_sq =
-        x_sq_scaled =
+        x_sq = cute.arch.mul_packed_f32x2(x, x)
+        x_sq_scaled = cute.arch.fma_packed_f32x2(
             x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
         )
-        z =
+        z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
         tanh_z = (tanh(z[0]), tanh(z[1]))
-        x_tanh_z =
-        return
+        x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x)
+        return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z)


 @dsl_user_op
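
Note: `gelu_tanh_approx` implements the usual tanh approximation gelu(x) ≈ 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))), with the cubic folded into z = x*(c1 + c2*x^2). A plain-Python sanity check against exact GELU; the 0.044715 coefficient is the standard one and is assumed to match the kernel's `sqrt_2_over_pi_coeff` constant:

    import math

    def gelu_tanh(v):
        c1 = math.sqrt(2 / math.pi)
        return 0.5 * v * (1 + math.tanh(c1 * (v + 0.044715 * v**3)))

    def gelu_exact(v):
        return 0.5 * v * (1 + math.erf(v / math.sqrt(2)))

    for v in (-2.0, -0.3, 0.0, 1.5):
        assert abs(gelu_tanh(v) - gelu_exact(v)) < 5e-3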
@@ -167,28 +173,28 @@ def dgelu_tanh_approx(
         return dx, gelu_out
     else:
         # Compute z = x * (c1 + c2 * x^2)
-        x_sq =
-        x_sq_scaled =
+        x_sq = cute.arch.mul_packed_f32x2(x, x)
+        x_sq_scaled = cute.arch.fma_packed_f32x2(
             x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
         )
-        z =
+        z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
         tanh_z = (tanh(z[0]), tanh(z[1]))
-        half_tanh_z_plus_one =
-        gelu_out =
+        half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
+        gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one)

         # Compute gradient
         # sech^2(z) = 1 - tanh^2(z)
-        sech2_z =
+        sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
         # dz/dx = c1 + 3 * c2 * x^2
-        dz_dx =
+        dz_dx = cute.arch.fma_packed_f32x2(
             x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
         )
         # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
-        sech2_dz_dx =
-        x_sech2_dz_dx =
-        dgelu =
+        sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx)
+        x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx)
+        dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)

-        dx =
+        dx = cute.arch.mul_packed_f32x2(dout, dgelu)
         return dx, gelu_out

@@ -204,15 +210,15 @@ def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
         )
     else:
         log2_e = math.log2(math.e)
-        x_log2e =
+        x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e))
         x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
-        x_exp_p1 =
+        x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0))
         log_x_exp_p1 = (
             cute.math.log2(x_exp_p1[0], fastmath=True),
             cute.math.log2(x_exp_p1[1], fastmath=True),
         )
         ln2 = math.log(2.0)
-        softplus_x =
+        softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
         use_linear_0 = Boolean(x[0] > 20.0)
         use_linear_1 = Boolean(x[1] > 20.0)
         return (
@@ -241,9 +247,9 @@ def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) ->
         # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
         return x_half * tanh(x_half) + x_half
     else:
-        x_half =
+        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
         tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
-        return
+        return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half)


 @dsl_user_op
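
Note: the packed path computes silu with one mul and one fma per pair, using the identity x*sigmoid(x) = x_half*tanh(x_half) + x_half with x_half = 0.5*x. A plain-Python identity check:

    import math

    def silu_ref(v):
        return v / (1 + math.exp(-v))

    def silu_via_fma(v):
        v_half = 0.5 * v
        return v_half * math.tanh(v_half) + v_half  # fma(x_half, tanh(x_half), x_half)

    for v in (-3.0, -0.5, 0.0, 1.7):
        assert abs(silu_ref(v) - silu_via_fma(v)) < 1e-12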
@@ -251,7 +257,7 @@ def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32
     if const_expr(not isinstance(x, tuple)):
         return silu(x) * y
     else:
-        return
+        return cute.arch.mul_packed_f32x2(silu(x), y)


 @dsl_user_op
@@ -301,20 +307,22 @@ def dswiglu(
         # Compute sigmoid(x) and silu(x)
         if const_expr(not already_halved):
             sigmoid_x = sigmoid(x)
-            silu_x =
+            silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x)
         else:
             tanh_x = (tanh(x[0]), tanh(x[1]))
-            sigmoid_x =
-            silu_x =
-        silu_x_dout =
+            sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
+            silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x)
+        silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
         # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
-        sigmoid_x_minus_silu_x_sigmoid_x =
+        sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2(
             sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
         )
-        d_silu_x_dout =
-
+        d_silu_x_dout = cute.arch.fma_packed_f32x2(
+            sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout
+        )
+        dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y)
         dy = silu_x_dout
-        swiglu_out =
+        swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y)
         return dx, dy, swiglu_out

@@ -334,11 +342,11 @@ def swiglu_oai(
         silu_x = x_half * tanh(alpha * x_half) + x_half
         return silu_x * y + silu_x
     else:
-        x_half =
-        alpha_x_half =
+        x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
+        alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half)
         tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
-        silu_x =
-        return
+        silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
+        return cute.arch.fma_packed_f32x2(silu_x, y, silu_x)


 @dsl_user_op
@@ -370,22 +378,22 @@ def dswiglu_oai(
         return dx, dy, swiglu_out
     else:
         # Compute sigmoid(alpha * x)
-        alpha_x_half =
+        alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
         tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
-        sigmoid_alpha_x =
-        silu_x =
-        silu_x_dout =
+        sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
+        silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x)
+        silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
         # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
-        silu_x_minus_product =
+        silu_x_minus_product = cute.arch.fma_packed_f32x2(
             silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
         )
-        sigmoid_plus_alpha_diff =
+        sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2(
             (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
         )
-        d_silu_x_dout =
-        dx =
+        d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
+        dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
         dy = silu_x_dout
-        swiglu_out =
+        swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
         return dx, dy, swiglu_out

@@ -400,7 +408,7 @@ def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
         return sigmoid_x * y  # FMUL
     else:
         sigmoid_x = sigmoid(x)
-        return
+        return cute.arch.mul_packed_f32x2(sigmoid_x, y)


 @dsl_user_op
@@ -430,11 +438,11 @@ def dglu(
         return dx, dy, glu_out
     else:
         sigmoid_x = sigmoid(x)
-        sigmoid_x_dout =
-        glu_out =
+        sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout)
+        glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y)
         # dx = (y - glu_out) * sigmoid_x_dout
-        y_minus_glu_out =
-        dx =
+        y_minus_glu_out = sub_packed_f32x2(y, glu_out)
+        dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
         dy = sigmoid_x_dout
         return dx, dy, glu_out

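
Note: with glu_out = sigmoid(x)*y, the gradient d/dx[sigmoid(x)*y] = sigmoid(x)*(1 - sigmoid(x))*y equals (y - glu_out)*sigmoid(x), which is why the packed path above needs only a sub and a mul. A quick numeric check in plain Python:

    import math

    def sigmoid(v):
        return 1 / (1 + math.exp(-v))

    x, y, dout, eps = 0.8, -1.3, 0.7, 1e-6
    glu_out = sigmoid(x) * y
    dx = (y - glu_out) * sigmoid(x) * dout
    dx_numeric = (sigmoid(x + eps) - sigmoid(x - eps)) / (2 * eps) * y * dout
    assert abs(dx - dx_numeric) < 1e-6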
@@ -448,7 +456,7 @@ def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
         return cute.arch.fmax(x, Float32(0.0)) * y
     else:
         relu_x = relu(x)
-        return
+        return cute.arch.mul_packed_f32x2(relu_x, y)


 @dsl_user_op
@@ -475,10 +483,10 @@ def dreglu(
         x0_pos = Boolean(x[0] > 0)
         x1_pos = Boolean(x[1] > 0)
         relu_x = relu(x)
-        dout_y =
+        dout_y = cute.arch.mul_packed_f32x2(dout, y)
         dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
-        dy =
-        reglu_out =
+        dy = cute.arch.mul_packed_f32x2(dout, relu_x)
+        reglu_out = cute.arch.mul_packed_f32x2(relu_x, y)
         return dx, dy, reglu_out

@@ -491,7 +499,7 @@ def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
     if const_expr(not isinstance(x, tuple)):
         return gelu_tanh_approx(x) * y
     else:
-        return
+        return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y)


 @dsl_user_op
@@ -518,7 +526,7 @@ def dgeglu(
         # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
         dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
         # Compute gradients for geglu
-        dx =
-        dy =
-        geglu_out =
+        dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y)
+        dy = cute.arch.mul_packed_f32x2(gelu_x, dout)
+        geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y)
         return dx, dy, geglu_out
--- quack_kernels-0.2.4/quack/broadcast_utils.py
+++ quack_kernels-0.2.6/quack/broadcast_utils.py
@@ -11,7 +11,7 @@ from quack.layout_utils import make_acc_tensor_mn_view
 @cute.jit
 def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
     if const_expr(tCrC.element_type != Float32):  # Convert to f32
-        tCrC_f32 = cute.
+        tCrC_f32 = cute.make_rmem_tensor(tCrC.shape, Float32)
         tCrC_f32.store(tCrC.load().to(Float32))
     else:
         tCrC_f32 = tCrC
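
Note: `vec_op` promotes a non-f32 accumulator into an f32 register fragment before applying `op`, so the broadcast arithmetic always runs in f32. A plain-Python stand-in for that convert-then-operate pattern (illustrative, not the CUTE DSL):

    def vec_op_ref(acc, vec, op):
        acc_f32 = [float(a) for a in acc]  # stands in for make_rmem_tensor + store
        return [op(a, v) for a, v in zip(acc_f32, vec)]

    assert vec_op_ref([1, 2], [10.0, 20.0], lambda a, v: a + v) == [11.0, 22.0]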
--- quack_kernels-0.2.4/quack/copy_utils.py
+++ quack_kernels-0.2.6/quack/copy_utils.py
@@ -7,18 +7,19 @@ import cutlass
 import cutlass.cute as cute

 from cutlass import Int32, Boolean, const_expr
-from cutlass.cute.nvgpu import cpasync
+from cutlass.cute.nvgpu import cpasync, warp, warpgroup
 from cutlass.cutlass_dsl import dsl_user_op
 import cutlass.pipeline


 @dsl_user_op
 def cvt_copy(
-
+    tiled_copy: cute.TiledCopy,
     src: cute.Tensor,
     dst: cute.Tensor,
     *,
     pred: Optional[cute.Tensor] = None,
+    retile: bool = False,
     loc=None,
     ip=None,
     **kwargs,
|
|
|
28
29
|
src_cvt = cute.make_fragment_like(src, dst.element_type)
|
|
29
30
|
src_cvt.store(src.load().to(dst.element_type))
|
|
30
31
|
src = src_cvt
|
|
31
|
-
|
|
32
|
+
if const_expr(retile):
|
|
33
|
+
src = tiled_copy.retile(src)
|
|
34
|
+
cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
|
|
32
35
|
|
|
33
36
|
|
|
34
37
|
@dsl_user_op
|
|
@@ -49,7 +52,7 @@ def load_s2r_retile(
 ) -> cute.Tensor:
     # Will also accept dst_shape being a tensor, in which case we write into that tensor
     if const_expr(not isinstance(dst_shape, cute.Tensor)):
-        dst = cute.
+        dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip)
     else:
         dst = dst_shape
     cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
@@ -114,7 +117,7 @@ def tiled_copy_2d(
 @cute.jit
 def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
     # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
-    tApA = cute.
+    tApA = cute.make_rmem_tensor(
         cute.make_layout(
             (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
             stride=(cute.size(tAcA, mode=[2]), 0, 1),
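
Note: `predicate_k` materializes one Boolean per k coordinate against the out-of-bounds limit, while the mn dimension is guarded with plain "if"s at copy time. The idea in plain Python (coordinates below are illustrative):

    cols = [0, 8, 16, 24]  # per-thread k coordinates within the tile
    limit_k = 20
    tApA_k = [c < limit_k for c in cols]
    assert tApA_k == [True, True, True, False]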
@@ -239,9 +242,7 @@ def sm90_get_smem_load_op(
         raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
     is_m_major = layout_c.is_m_major_c()
     if elem_ty_c.width == 16:
-        return cute.make_copy_atom(
-            cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
-        )
+        return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip)
     else:
         return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)

@@ -257,11 +258,127 @@ def get_smem_store_atom(
         )
     else:
         return cute.make_copy_atom(
-
+            warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
             element_type,
         )


+def get_smem_load_atom(
+    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
+) -> cute.CopyAtom:
+    if const_expr(arch < 90 or element_type.width != 16):
+        return cute.make_copy_atom(
+            cute.nvgpu.CopyUniversalOp(),
+            element_type,
+            num_bits_per_copy=(2 if not transpose else 1) * element_type.width,
+        )
+    else:
+        return cute.make_copy_atom(
+            warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
+            element_type,
+        )
+
+
+def get_smem_store_C(
+    tiled_mma: cute.TiledMma,
+    sC: cute.Tensor,
+    tidx: Int32,
+    arch: int,
+    transpose: bool = False,
+    position_independent=False,
+) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+    dtype = sC.element_type
+    copy_atom = get_smem_store_atom(arch, dtype, transpose)
+    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
+    thr_copy = tiled_copy.get_slice(tidx)
+    if const_expr(not position_independent):
+        tRS_sC = thr_copy.partition_D(sC)
+    else:
+        tRS_sC = partition_D_position_independent(thr_copy, sC)
+
+    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
+        cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], retile=True, **new_kwargs)
+
+    return copy_fn, thr_copy, tRS_sC
+
+
+def get_smem_load_C(
+    tiled_mma: cute.TiledMma,
+    sC: cute.Tensor,
+    tidx: Int32,
+    arch: int,
+    transpose: bool = False,
+    position_independent=False,
+) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+    dtype = sC.element_type
+    copy_atom = get_smem_load_atom(arch, dtype, transpose)
+    tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma)
+    thr_copy = tiled_copy.get_slice(tidx)
+    if const_expr(not position_independent):
+        tSR_sC = thr_copy.partition_S(sC)
+    else:
+        tSR_sC = partition_S_position_independent(thr_copy, sC)
+    copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
+    thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
+    tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape
+
+    def copy_fn(src_idx: Int32, **new_kwargs):
+        return load_s2r_retile(
+            tiled_copy, tSR_sC[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
+        )
+
+    return copy_fn, thr_copy, tSR_sC
+
+
+def get_smem_store_A(
+    tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False
+) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+    dtype = sA.element_type
+    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
+    copy_atom = get_smem_store_atom(arch, dtype, transpose)
+    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
+    thr_copy = tiled_copy.get_slice(tidx)
+    if const_expr(not position_independent):
+        tRS_sA = thr_copy.partition_D(sA)
+    else:
+        tRS_sA = partition_D_position_independent(thr_copy, sA)
+
+    def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs):
+        cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs)
+
+    return copy_fn, thr_copy, tRS_sA
+
+
+def get_smem_load_A(
+    tiled_mma: cute.TiledMma,
+    sA: cute.Tensor,
+    tidx: Int32,
+    arch: int,
+    with_dst_tensor: bool = False,
+    position_independent=False,
+) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]:
+    dtype = sA.element_type
+    transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN
+    copy_atom = get_smem_load_atom(arch, dtype, transpose)
+    tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma)
+    thr_copy = tiled_copy.get_slice(tidx)
+    if const_expr(not position_independent):
+        tSR_sA = thr_copy.partition_S(sA)
+    else:
+        tSR_sA = partition_S_position_independent(thr_copy, sA)
+    tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
+
+    def copy_fn(src_idx: Int32, **new_kwargs):
+        return load_s2r_retile(
+            tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs
+        )
+
+    def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs):
+        return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs)
+
+    return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA
+
+
 def tma_get_copy_fn(
     atom: cute.CopyAtom,
     cta_coord: cute.Coord,
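
Note: the new `get_smem_store_C` / `get_smem_load_C` / `get_smem_store_A` / `get_smem_load_A` helpers share one pattern: build the tiled copy and partition smem once, then return a closure so the hot loop only supplies a stage index (plus the source fragment for stores). A plain-Python sketch of that closure pattern (names illustrative, not the CUTE DSL):

    def make_store_fn(stages):
        # Partitioning happens once, here; per-stage calls just index in.
        def copy_fn(src, dst_idx):
            stages[dst_idx] = list(src)  # stands in for cvt_copy into stage dst_idx
        return copy_fn

    stages = [None, None]
    store = make_store_fn(stages)
    store([1.0, 2.0], dst_idx=0)
    assert stages[0] == [1.0, 2.0]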
@@ -269,6 +386,7 @@ def tma_get_copy_fn(
     src_tensor: cute.Tensor,
     dst_tensor: cute.Tensor,
     filter_zeros: bool = False,
+    single_stage: bool = False,
     **kwargs,
 ) -> Callable:
     src_is_smem = const_expr(
@@ -276,13 +394,15 @@ def tma_get_copy_fn(
         and src_tensor.memspace == cute.AddressSpace.smem
     )
     smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor)
+    group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0))
+    group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if not single_stage else 0))
     # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK)
     s, g = cpasync.tma_partition(
         atom,
         cta_coord,
         cta_layout,
-        cute.group_modes(smem_tensor, 0,
-        cute.group_modes(gmem_tensor, 0,
+        cute.group_modes(smem_tensor, 0, group_rank_smem),
+        cute.group_modes(gmem_tensor, 0, group_rank_gmem),
     )
     if const_expr(filter_zeros):
         s = cute.filter_zeros(s)
@@ -292,7 +412,10 @@ def tma_get_copy_fn(
     def copy_tma(src_idx, dst_idx, **new_kwargs):
         cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs)

-
+    def copy_tma_single_stage(**new_kwargs):
+        cute.copy(atom, src, dst, **new_kwargs, **kwargs)
+
+    return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g


 def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync):
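
Note: with `single_stage=True` the stage mode is never split off, so the helper hands back a closure that takes no indices at all, while the staged variant keeps the `(src_idx, dst_idx)` signature. A Python-level sketch of the switch (illustrative, not the DSL):

    def pick_copy_fn(copy, single_stage=False):
        def copy_staged(src_idx, dst_idx):
            copy(src_idx, dst_idx)

        def copy_single_stage():
            copy(0, 0)

        return copy_staged if not single_stage else copy_single_stage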
@@ -337,10 +460,10 @@ def gather_m_get_copy_fn(
     # Read and cache indices for A
     rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
     cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
-    tApA_m = cute.
+    tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
     for m in cutlass.range(rows_per_thread, unroll_full=True):
         tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
-    m_idx = cute.
+    m_idx = cute.make_rmem_tensor(rows_per_thread, Int32)
     for m in cutlass.range(rows_per_thread, unroll_full=True):
         row_idx = tAcA[0, m, 0][0]
         if tApA_m[m]:
@@ -353,7 +476,7 @@ def gather_m_get_copy_fn(
     def copy_fn(src_idx, dst_idx, pred: bool = False):
         tApA_k = None
         if const_expr(pred):
-            tApA_k = cute.
+            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
             for k in cutlass.range(cols_per_thread, unroll_full=True):
                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
@@ -411,7 +534,7 @@ def gather_k_get_copy_fn(
     # Read and cache indices for A
     rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
    cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
-    tApA_m = cute.
+    tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
     for m in cutlass.range(rows_per_thread, unroll_full=True):
         tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
     threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
@@ -427,12 +550,12 @@ def gather_k_get_copy_fn(
         # Prefetch mAIdx early, even before smem is free
         tApA_k = None
         if const_expr(pred):
-            tApA_k = cute.
+            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
             for k in cutlass.range(cols_per_thread, unroll_full=True):
                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
         gAIdx_cur = gAIdx[None, src_idx]
-        k_idx = cute.
+        k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
         for k in cutlass.range(cols_per_thread):
             col_idx = tAcA[0, 0, k][1]
             if const_expr(not pred):
@@ -449,13 +572,13 @@ def gather_k_get_copy_fn(
     ) -> Tuple[cute.Tensor, cute.Tensor]:
         tApA_k = None
         if const_expr(pred):
-            tApA_k = cute.
+            tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
             limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
             for k in cutlass.range(cols_per_thread, unroll_full=True):
                 tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
         a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
         sAIdx_cur = sAIdx[None, dst_idx]
-        k_idx = cute.
+        k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
         for k in cutlass.range(cols_per_thread):
             col_idx = tAcA[0, 0, k][1]
             k_idx[k] = sAIdx_cur[col_idx]