quack-kernels 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +72 -64
- quack/broadcast_utils.py +1 -1
- quack/copy_utils.py +14 -18
- quack/fast_math.py +29 -76
- quack/gemm_act.py +296 -8
- quack/gemm_dact.py +520 -4
- quack/gemm_default_epi.py +4 -4
- quack/gemm_interface.py +363 -0
- quack/gemm_sm100.py +62 -88
- quack/gemm_sm90.py +68 -114
- quack/gemm_symmetric.py +2 -6
- quack/layout_utils.py +2 -4
- quack/linear.py +37 -0
- quack/pipeline.py +59 -89
- quack/reduce.py +2 -2
- quack/rmsnorm.py +1 -3
- quack/sm90_utils.py +5 -3
- quack/sort/bitonic_sort.py +3 -3
- quack/tile_scheduler.py +310 -256
- quack/topk.py +4 -4
- quack/utils.py +76 -40
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/METADATA +2 -2
- quack_kernels-0.2.6.dist-info/RECORD +45 -0
- quack_kernels-0.2.5.dist-info/RECORD +0 -45
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
quack/activation.py
CHANGED
|
@@ -2,18 +2,24 @@
|
|
|
2
2
|
|
|
3
3
|
import math
|
|
4
4
|
from typing import Tuple
|
|
5
|
+
from functools import partial
|
|
5
6
|
|
|
6
7
|
import cutlass.cute as cute
|
|
7
8
|
from cutlass import Float32, Boolean, const_expr
|
|
8
9
|
from cutlass.cutlass_dsl import T, dsl_user_op
|
|
9
|
-
from cutlass._mlir.dialects import llvm
|
|
10
|
-
|
|
11
|
-
import quack.utils as utils
|
|
10
|
+
from cutlass._mlir.dialects import llvm, nvvm
|
|
12
11
|
|
|
13
12
|
|
|
14
13
|
F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
|
|
15
14
|
|
|
16
15
|
|
|
16
|
+
sub_packed_f32x2 = partial(
|
|
17
|
+
cute.arch.calc_packed_f32x2_op,
|
|
18
|
+
src_c=None,
|
|
19
|
+
calc_func=nvvm.sub_packed_f32x2,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
17
23
|
@dsl_user_op
|
|
18
24
|
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
19
25
|
return Float32(
|
|
@@ -35,9 +41,9 @@ def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
|
35
41
|
# return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
|
|
36
42
|
return 0.5 + 0.5 * tanh(0.5 * x)
|
|
37
43
|
else:
|
|
38
|
-
x_half =
|
|
44
|
+
x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
|
|
39
45
|
tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
|
|
40
|
-
return
|
|
46
|
+
return cute.arch.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
|
|
41
47
|
|
|
42
48
|
|
|
43
49
|
@dsl_user_op
|
|
@@ -75,7 +81,7 @@ def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
|
75
81
|
return cute.arch.fmax(x, Float32(0.0)) * x
|
|
76
82
|
else:
|
|
77
83
|
relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
|
|
78
|
-
return
|
|
84
|
+
return cute.arch.mul_packed_f32x2(relu_x, x)
|
|
79
85
|
|
|
80
86
|
|
|
81
87
|
@dsl_user_op
|
|
@@ -98,8 +104,8 @@ def drelu_sq(
|
|
|
98
104
|
return dx, relu_sq_out
|
|
99
105
|
else:
|
|
100
106
|
relu_x = relu(x)
|
|
101
|
-
relu_sq_out =
|
|
102
|
-
dx =
|
|
107
|
+
relu_sq_out = cute.arch.mul_packed_f32x2(relu_x, x)
|
|
108
|
+
dx = cute.arch.mul_packed_f32x2((2.0, 2.0), cute.arch.mul_packed_f32x2(dout, relu_x))
|
|
103
109
|
return dx, relu_sq_out
|
|
104
110
|
|
|
105
111
|
|
|
@@ -119,14 +125,14 @@ def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
|
119
125
|
* (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
|
|
120
126
|
)
|
|
121
127
|
else:
|
|
122
|
-
x_sq =
|
|
123
|
-
x_sq_scaled =
|
|
128
|
+
x_sq = cute.arch.mul_packed_f32x2(x, x)
|
|
129
|
+
x_sq_scaled = cute.arch.fma_packed_f32x2(
|
|
124
130
|
x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
|
|
125
131
|
)
|
|
126
|
-
z =
|
|
132
|
+
z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
|
|
127
133
|
tanh_z = (tanh(z[0]), tanh(z[1]))
|
|
128
|
-
x_tanh_z =
|
|
129
|
-
return
|
|
134
|
+
x_tanh_z = cute.arch.fma_packed_f32x2(tanh_z, x, x)
|
|
135
|
+
return cute.arch.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
|
|
130
136
|
|
|
131
137
|
|
|
132
138
|
@dsl_user_op
|
|
@@ -167,28 +173,28 @@ def dgelu_tanh_approx(
|
|
|
167
173
|
return dx, gelu_out
|
|
168
174
|
else:
|
|
169
175
|
# Compute z = x * (c1 + c2 * x^2)
|
|
170
|
-
x_sq =
|
|
171
|
-
x_sq_scaled =
|
|
176
|
+
x_sq = cute.arch.mul_packed_f32x2(x, x)
|
|
177
|
+
x_sq_scaled = cute.arch.fma_packed_f32x2(
|
|
172
178
|
x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
|
|
173
179
|
)
|
|
174
|
-
z =
|
|
180
|
+
z = cute.arch.mul_packed_f32x2(x, x_sq_scaled)
|
|
175
181
|
tanh_z = (tanh(z[0]), tanh(z[1]))
|
|
176
|
-
half_tanh_z_plus_one =
|
|
177
|
-
gelu_out =
|
|
182
|
+
half_tanh_z_plus_one = cute.arch.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
|
|
183
|
+
gelu_out = cute.arch.mul_packed_f32x2(x, half_tanh_z_plus_one)
|
|
178
184
|
|
|
179
185
|
# Compute gradient
|
|
180
186
|
# sech^2(z) = 1 - tanh^2(z)
|
|
181
|
-
sech2_z =
|
|
187
|
+
sech2_z = cute.arch.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
|
|
182
188
|
# dz/dx = c1 + 3 * c2 * x^2
|
|
183
|
-
dz_dx =
|
|
189
|
+
dz_dx = cute.arch.fma_packed_f32x2(
|
|
184
190
|
x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
|
|
185
191
|
)
|
|
186
192
|
# d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
|
|
187
|
-
sech2_dz_dx =
|
|
188
|
-
x_sech2_dz_dx =
|
|
189
|
-
dgelu =
|
|
193
|
+
sech2_dz_dx = cute.arch.mul_packed_f32x2(sech2_z, dz_dx)
|
|
194
|
+
x_sech2_dz_dx = cute.arch.mul_packed_f32x2(x, sech2_dz_dx)
|
|
195
|
+
dgelu = cute.arch.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
|
|
190
196
|
|
|
191
|
-
dx =
|
|
197
|
+
dx = cute.arch.mul_packed_f32x2(dout, dgelu)
|
|
192
198
|
return dx, gelu_out
|
|
193
199
|
|
|
194
200
|
|
|
@@ -204,15 +210,15 @@ def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
|
204
210
|
)
|
|
205
211
|
else:
|
|
206
212
|
log2_e = math.log2(math.e)
|
|
207
|
-
x_log2e =
|
|
213
|
+
x_log2e = cute.arch.mul_packed_f32x2(x, (log2_e, log2_e))
|
|
208
214
|
x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
|
|
209
|
-
x_exp_p1 =
|
|
215
|
+
x_exp_p1 = cute.arch.add_packed_f32x2(x_exp, (1.0, 1.0))
|
|
210
216
|
log_x_exp_p1 = (
|
|
211
217
|
cute.math.log2(x_exp_p1[0], fastmath=True),
|
|
212
218
|
cute.math.log2(x_exp_p1[1], fastmath=True),
|
|
213
219
|
)
|
|
214
220
|
ln2 = math.log(2.0)
|
|
215
|
-
softplus_x =
|
|
221
|
+
softplus_x = cute.arch.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
|
|
216
222
|
use_linear_0 = Boolean(x[0] > 20.0)
|
|
217
223
|
use_linear_1 = Boolean(x[1] > 20.0)
|
|
218
224
|
return (
|
|
@@ -241,9 +247,9 @@ def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) ->
|
|
|
241
247
|
# return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
|
|
242
248
|
return x_half * tanh(x_half) + x_half
|
|
243
249
|
else:
|
|
244
|
-
x_half =
|
|
250
|
+
x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
|
|
245
251
|
tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
|
|
246
|
-
return
|
|
252
|
+
return cute.arch.fma_packed_f32x2(x_half, tanh_x_half, x_half)
|
|
247
253
|
|
|
248
254
|
|
|
249
255
|
@dsl_user_op
|
|
@@ -251,7 +257,7 @@ def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32
|
|
|
251
257
|
if const_expr(not isinstance(x, tuple)):
|
|
252
258
|
return silu(x) * y
|
|
253
259
|
else:
|
|
254
|
-
return
|
|
260
|
+
return cute.arch.mul_packed_f32x2(silu(x), y)
|
|
255
261
|
|
|
256
262
|
|
|
257
263
|
@dsl_user_op
|
|
@@ -301,20 +307,22 @@ def dswiglu(
|
|
|
301
307
|
# Compute sigmoid(x) and silu(x)
|
|
302
308
|
if const_expr(not already_halved):
|
|
303
309
|
sigmoid_x = sigmoid(x)
|
|
304
|
-
silu_x =
|
|
310
|
+
silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_x)
|
|
305
311
|
else:
|
|
306
312
|
tanh_x = (tanh(x[0]), tanh(x[1]))
|
|
307
|
-
sigmoid_x =
|
|
308
|
-
silu_x =
|
|
309
|
-
silu_x_dout =
|
|
313
|
+
sigmoid_x = cute.arch.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
|
|
314
|
+
silu_x = cute.arch.fma_packed_f32x2(x, tanh_x, x)
|
|
315
|
+
silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
|
|
310
316
|
# d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
|
|
311
|
-
sigmoid_x_minus_silu_x_sigmoid_x =
|
|
317
|
+
sigmoid_x_minus_silu_x_sigmoid_x = cute.arch.fma_packed_f32x2(
|
|
312
318
|
sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
|
|
313
319
|
)
|
|
314
|
-
d_silu_x_dout =
|
|
315
|
-
|
|
320
|
+
d_silu_x_dout = cute.arch.fma_packed_f32x2(
|
|
321
|
+
sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout
|
|
322
|
+
)
|
|
323
|
+
dx = cute.arch.mul_packed_f32x2(d_silu_x_dout, y)
|
|
316
324
|
dy = silu_x_dout
|
|
317
|
-
swiglu_out =
|
|
325
|
+
swiglu_out = cute.arch.mul_packed_f32x2(silu_x, y)
|
|
318
326
|
return dx, dy, swiglu_out
|
|
319
327
|
|
|
320
328
|
|
|
@@ -334,11 +342,11 @@ def swiglu_oai(
|
|
|
334
342
|
silu_x = x_half * tanh(alpha * x_half) + x_half
|
|
335
343
|
return silu_x * y + silu_x
|
|
336
344
|
else:
|
|
337
|
-
x_half =
|
|
338
|
-
alpha_x_half =
|
|
345
|
+
x_half = cute.arch.mul_packed_f32x2((0.5, 0.5), x)
|
|
346
|
+
alpha_x_half = cute.arch.mul_packed_f32x2((alpha, alpha), x_half)
|
|
339
347
|
tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
|
|
340
|
-
silu_x =
|
|
341
|
-
return
|
|
348
|
+
silu_x = cute.arch.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
|
|
349
|
+
return cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
|
|
342
350
|
|
|
343
351
|
|
|
344
352
|
@dsl_user_op
|
|
@@ -370,22 +378,22 @@ def dswiglu_oai(
|
|
|
370
378
|
return dx, dy, swiglu_out
|
|
371
379
|
else:
|
|
372
380
|
# Compute sigmoid(alpha * x)
|
|
373
|
-
alpha_x_half =
|
|
381
|
+
alpha_x_half = cute.arch.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
|
|
374
382
|
tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
|
|
375
|
-
sigmoid_alpha_x =
|
|
376
|
-
silu_x =
|
|
377
|
-
silu_x_dout =
|
|
383
|
+
sigmoid_alpha_x = cute.arch.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
|
|
384
|
+
silu_x = cute.arch.mul_packed_f32x2(x, sigmoid_alpha_x)
|
|
385
|
+
silu_x_dout = cute.arch.mul_packed_f32x2(silu_x, dout)
|
|
378
386
|
# d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
|
|
379
|
-
silu_x_minus_product =
|
|
387
|
+
silu_x_minus_product = cute.arch.fma_packed_f32x2(
|
|
380
388
|
silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
|
|
381
389
|
)
|
|
382
|
-
sigmoid_plus_alpha_diff =
|
|
390
|
+
sigmoid_plus_alpha_diff = cute.arch.fma_packed_f32x2(
|
|
383
391
|
(alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
|
|
384
392
|
)
|
|
385
|
-
d_silu_x_dout =
|
|
386
|
-
dx =
|
|
393
|
+
d_silu_x_dout = cute.arch.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
|
|
394
|
+
dx = cute.arch.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
|
|
387
395
|
dy = silu_x_dout
|
|
388
|
-
swiglu_out =
|
|
396
|
+
swiglu_out = cute.arch.fma_packed_f32x2(silu_x, y, silu_x)
|
|
389
397
|
return dx, dy, swiglu_out
|
|
390
398
|
|
|
391
399
|
|
|
@@ -400,7 +408,7 @@ def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
|
400
408
|
return sigmoid_x * y # FMUL
|
|
401
409
|
else:
|
|
402
410
|
sigmoid_x = sigmoid(x)
|
|
403
|
-
return
|
|
411
|
+
return cute.arch.mul_packed_f32x2(sigmoid_x, y)
|
|
404
412
|
|
|
405
413
|
|
|
406
414
|
@dsl_user_op
|
|
@@ -430,11 +438,11 @@ def dglu(
|
|
|
430
438
|
return dx, dy, glu_out
|
|
431
439
|
else:
|
|
432
440
|
sigmoid_x = sigmoid(x)
|
|
433
|
-
sigmoid_x_dout =
|
|
434
|
-
glu_out =
|
|
441
|
+
sigmoid_x_dout = cute.arch.mul_packed_f32x2(sigmoid_x, dout)
|
|
442
|
+
glu_out = cute.arch.mul_packed_f32x2(sigmoid_x, y)
|
|
435
443
|
# dx = (y - glu_out) * sigmoid_x_dout
|
|
436
|
-
y_minus_glu_out =
|
|
437
|
-
dx =
|
|
444
|
+
y_minus_glu_out = sub_packed_f32x2(y, glu_out)
|
|
445
|
+
dx = cute.arch.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
|
|
438
446
|
dy = sigmoid_x_dout
|
|
439
447
|
return dx, dy, glu_out
|
|
440
448
|
|
|
@@ -448,7 +456,7 @@ def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
|
|
|
448
456
|
return cute.arch.fmax(x, Float32(0.0)) * y
|
|
449
457
|
else:
|
|
450
458
|
relu_x = relu(x)
|
|
451
|
-
return
|
|
459
|
+
return cute.arch.mul_packed_f32x2(relu_x, y)
|
|
452
460
|
|
|
453
461
|
|
|
454
462
|
@dsl_user_op
|
|
@@ -475,10 +483,10 @@ def dreglu(
|
|
|
475
483
|
x0_pos = Boolean(x[0] > 0)
|
|
476
484
|
x1_pos = Boolean(x[1] > 0)
|
|
477
485
|
relu_x = relu(x)
|
|
478
|
-
dout_y =
|
|
486
|
+
dout_y = cute.arch.mul_packed_f32x2(dout, y)
|
|
479
487
|
dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
|
|
480
|
-
dy =
|
|
481
|
-
reglu_out =
|
|
488
|
+
dy = cute.arch.mul_packed_f32x2(dout, relu_x)
|
|
489
|
+
reglu_out = cute.arch.mul_packed_f32x2(relu_x, y)
|
|
482
490
|
return dx, dy, reglu_out
|
|
483
491
|
|
|
484
492
|
|
|
@@ -491,7 +499,7 @@ def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x
|
|
|
491
499
|
if const_expr(not isinstance(x, tuple)):
|
|
492
500
|
return gelu_tanh_approx(x) * y
|
|
493
501
|
else:
|
|
494
|
-
return
|
|
502
|
+
return cute.arch.mul_packed_f32x2(gelu_tanh_approx(x), y)
|
|
495
503
|
|
|
496
504
|
|
|
497
505
|
@dsl_user_op
|
|
@@ -518,7 +526,7 @@ def dgeglu(
|
|
|
518
526
|
# Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
|
|
519
527
|
dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
|
|
520
528
|
# Compute gradients for geglu
|
|
521
|
-
dx =
|
|
522
|
-
dy =
|
|
523
|
-
geglu_out =
|
|
529
|
+
dx = cute.arch.mul_packed_f32x2(dgelu_x_dout, y)
|
|
530
|
+
dy = cute.arch.mul_packed_f32x2(gelu_x, dout)
|
|
531
|
+
geglu_out = cute.arch.mul_packed_f32x2(gelu_x, y)
|
|
524
532
|
return dx, dy, geglu_out
|
quack/broadcast_utils.py
CHANGED
|
@@ -11,7 +11,7 @@ from quack.layout_utils import make_acc_tensor_mn_view
|
|
|
11
11
|
@cute.jit
|
|
12
12
|
def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
|
|
13
13
|
if const_expr(tCrC.element_type != Float32): # Convert to f32
|
|
14
|
-
tCrC_f32 = cute.
|
|
14
|
+
tCrC_f32 = cute.make_rmem_tensor(tCrC.shape, Float32)
|
|
15
15
|
tCrC_f32.store(tCrC.load().to(Float32))
|
|
16
16
|
else:
|
|
17
17
|
tCrC_f32 = tCrC
|
quack/copy_utils.py
CHANGED
|
@@ -7,7 +7,7 @@ import cutlass
|
|
|
7
7
|
import cutlass.cute as cute
|
|
8
8
|
|
|
9
9
|
from cutlass import Int32, Boolean, const_expr
|
|
10
|
-
from cutlass.cute.nvgpu import cpasync, warpgroup
|
|
10
|
+
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
|
11
11
|
from cutlass.cutlass_dsl import dsl_user_op
|
|
12
12
|
import cutlass.pipeline
|
|
13
13
|
|
|
@@ -52,7 +52,7 @@ def load_s2r_retile(
|
|
|
52
52
|
) -> cute.Tensor:
|
|
53
53
|
# Will also accept dst_shape being a tensor, in which case we write into that tensor
|
|
54
54
|
if const_expr(not isinstance(dst_shape, cute.Tensor)):
|
|
55
|
-
dst = cute.
|
|
55
|
+
dst = cute.make_rmem_tensor(dst_shape, src.element_type, loc=loc, ip=ip)
|
|
56
56
|
else:
|
|
57
57
|
dst = dst_shape
|
|
58
58
|
cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip)
|
|
@@ -117,7 +117,7 @@ def tiled_copy_2d(
|
|
|
117
117
|
@cute.jit
|
|
118
118
|
def predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor:
|
|
119
119
|
# Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
|
|
120
|
-
tApA = cute.
|
|
120
|
+
tApA = cute.make_rmem_tensor(
|
|
121
121
|
cute.make_layout(
|
|
122
122
|
(cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
|
|
123
123
|
stride=(cute.size(tAcA, mode=[2]), 0, 1),
|
|
@@ -242,9 +242,7 @@ def sm90_get_smem_load_op(
|
|
|
242
242
|
raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}")
|
|
243
243
|
is_m_major = layout_c.is_m_major_c()
|
|
244
244
|
if elem_ty_c.width == 16:
|
|
245
|
-
return cute.make_copy_atom(
|
|
246
|
-
cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip
|
|
247
|
-
)
|
|
245
|
+
return cute.make_copy_atom(warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip)
|
|
248
246
|
else:
|
|
249
247
|
return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip)
|
|
250
248
|
|
|
@@ -260,7 +258,7 @@ def get_smem_store_atom(
|
|
|
260
258
|
)
|
|
261
259
|
else:
|
|
262
260
|
return cute.make_copy_atom(
|
|
263
|
-
|
|
261
|
+
warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
|
|
264
262
|
element_type,
|
|
265
263
|
)
|
|
266
264
|
|
|
@@ -276,7 +274,7 @@ def get_smem_load_atom(
|
|
|
276
274
|
)
|
|
277
275
|
else:
|
|
278
276
|
return cute.make_copy_atom(
|
|
279
|
-
|
|
277
|
+
warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
|
|
280
278
|
element_type,
|
|
281
279
|
)
|
|
282
280
|
|
|
@@ -368,8 +366,6 @@ def get_smem_load_A(
|
|
|
368
366
|
tSR_sA = thr_copy.partition_S(sA)
|
|
369
367
|
else:
|
|
370
368
|
tSR_sA = partition_S_position_independent(thr_copy, sA)
|
|
371
|
-
copy_atom_RS = get_smem_store_atom(arch, dtype, transpose)
|
|
372
|
-
thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx)
|
|
373
369
|
tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2])
|
|
374
370
|
|
|
375
371
|
def copy_fn(src_idx: Int32, **new_kwargs):
|
|
@@ -464,10 +460,10 @@ def gather_m_get_copy_fn(
|
|
|
464
460
|
# Read and cache indices for A
|
|
465
461
|
rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
|
|
466
462
|
cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
|
|
467
|
-
tApA_m = cute.
|
|
463
|
+
tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
|
|
468
464
|
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
|
469
465
|
tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
|
|
470
|
-
m_idx = cute.
|
|
466
|
+
m_idx = cute.make_rmem_tensor(rows_per_thread, Int32)
|
|
471
467
|
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
|
472
468
|
row_idx = tAcA[0, m, 0][0]
|
|
473
469
|
if tApA_m[m]:
|
|
@@ -480,7 +476,7 @@ def gather_m_get_copy_fn(
|
|
|
480
476
|
def copy_fn(src_idx, dst_idx, pred: bool = False):
|
|
481
477
|
tApA_k = None
|
|
482
478
|
if const_expr(pred):
|
|
483
|
-
tApA_k = cute.
|
|
479
|
+
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
|
484
480
|
limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
|
|
485
481
|
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
|
486
482
|
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
|
@@ -538,7 +534,7 @@ def gather_k_get_copy_fn(
|
|
|
538
534
|
# Read and cache indices for A
|
|
539
535
|
rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1]))
|
|
540
536
|
cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2]))
|
|
541
|
-
tApA_m = cute.
|
|
537
|
+
tApA_m = cute.make_rmem_tensor(rows_per_thread, Boolean)
|
|
542
538
|
for m in cutlass.range(rows_per_thread, unroll_full=True):
|
|
543
539
|
tApA_m[m] = t0AcA[0, m, 0][0] < limit_m
|
|
544
540
|
threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load)
|
|
@@ -554,12 +550,12 @@ def gather_k_get_copy_fn(
|
|
|
554
550
|
# Prefetch mAIdx early, even before smem is free
|
|
555
551
|
tApA_k = None
|
|
556
552
|
if const_expr(pred):
|
|
557
|
-
tApA_k = cute.
|
|
553
|
+
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
|
558
554
|
limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
|
|
559
555
|
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
|
560
556
|
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
|
561
557
|
gAIdx_cur = gAIdx[None, src_idx]
|
|
562
|
-
k_idx = cute.
|
|
558
|
+
k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
|
|
563
559
|
for k in cutlass.range(cols_per_thread):
|
|
564
560
|
col_idx = tAcA[0, 0, k][1]
|
|
565
561
|
if const_expr(not pred):
|
|
@@ -576,13 +572,13 @@ def gather_k_get_copy_fn(
|
|
|
576
572
|
) -> Tuple[cute.Tensor, cute.Tensor]:
|
|
577
573
|
tApA_k = None
|
|
578
574
|
if const_expr(pred):
|
|
579
|
-
tApA_k = cute.
|
|
575
|
+
tApA_k = cute.make_rmem_tensor(cols_per_thread, Boolean)
|
|
580
576
|
limit_k_cur = limit_k - src_idx * tile_shape_mk[1]
|
|
581
577
|
for k in cutlass.range(cols_per_thread, unroll_full=True):
|
|
582
578
|
tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur
|
|
583
579
|
a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state)
|
|
584
580
|
sAIdx_cur = sAIdx[None, dst_idx]
|
|
585
|
-
k_idx = cute.
|
|
581
|
+
k_idx = cute.make_rmem_tensor(cols_per_thread, Int32)
|
|
586
582
|
for k in cutlass.range(cols_per_thread):
|
|
587
583
|
col_idx = tAcA[0, 0, k][1]
|
|
588
584
|
k_idx[k] = sAIdx_cur[col_idx]
|
quack/fast_math.py
CHANGED
|
@@ -1,80 +1,33 @@
|
|
|
1
1
|
# Copyright (c) 2025, Tri Dao.
|
|
2
2
|
|
|
3
|
-
from typing import Tuple
|
|
4
|
-
from dataclasses import dataclass
|
|
5
|
-
|
|
6
3
|
import cutlass
|
|
7
4
|
import cutlass.cute as cute
|
|
8
|
-
from cutlass import
|
|
9
|
-
from cutlass.cutlass_dsl import
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
@
|
|
16
|
-
def
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32:
|
|
38
|
-
return Uint32(
|
|
39
|
-
llvm.inline_asm(
|
|
40
|
-
T.i32(),
|
|
41
|
-
[Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)],
|
|
42
|
-
"mul.hi.u32 $0, $1, $2;",
|
|
43
|
-
"=r,r,r",
|
|
44
|
-
has_side_effects=False,
|
|
45
|
-
is_align_stack=False,
|
|
46
|
-
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
47
|
-
)
|
|
48
|
-
)
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
@dataclass
|
|
52
|
-
class FastDivmod(ParamsBase):
|
|
53
|
-
divisor: Int32
|
|
54
|
-
multiplier: Uint32
|
|
55
|
-
shift_right: Uint32
|
|
56
|
-
|
|
57
|
-
# called by host
|
|
58
|
-
@staticmethod
|
|
59
|
-
def create(divisor: Int32) -> "FastDivmod":
|
|
60
|
-
"""Construct the FastDivmod object, in host code.
|
|
61
|
-
This precomputes some values based on the divisor and is computationally expensive.
|
|
62
|
-
"""
|
|
63
|
-
p = Uint32(31 + find_log2(divisor))
|
|
64
|
-
divisor_u32 = Uint32(divisor)
|
|
65
|
-
multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32)
|
|
66
|
-
shift_right = Uint32(p - 32)
|
|
67
|
-
return FastDivmod(divisor, multiplier, shift_right)
|
|
68
|
-
|
|
69
|
-
@cute.jit
|
|
70
|
-
def div(self, dividend: Int32) -> Int32:
|
|
71
|
-
return (
|
|
72
|
-
Int32(umulhi(dividend, self.multiplier) >> self.shift_right)
|
|
73
|
-
if self.divisor != 1
|
|
74
|
-
else dividend
|
|
75
|
-
)
|
|
76
|
-
|
|
77
|
-
def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]:
|
|
78
|
-
quotient = self.div(dividend)
|
|
79
|
-
remainder = dividend - quotient * self.divisor
|
|
80
|
-
return quotient, remainder
|
|
5
|
+
from cutlass.base_dsl.typing import Integer
|
|
6
|
+
from cutlass.cutlass_dsl import dsl_user_op
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FastDivmod(cute.FastDivmodDivisor):
|
|
10
|
+
"""We store the divisor along with the FastDivmodDivisor."""
|
|
11
|
+
|
|
12
|
+
@dsl_user_op
|
|
13
|
+
def __init__(
|
|
14
|
+
self,
|
|
15
|
+
divisor: Integer,
|
|
16
|
+
is_power_of_2: bool = None,
|
|
17
|
+
*,
|
|
18
|
+
loc=None,
|
|
19
|
+
ip=None,
|
|
20
|
+
):
|
|
21
|
+
super().__init__(divisor, is_power_of_2=is_power_of_2, loc=loc, ip=ip)
|
|
22
|
+
self.divisor = divisor
|
|
23
|
+
|
|
24
|
+
def __extract_mlir_values__(self):
|
|
25
|
+
"""Extract MLIR values for Host->Device transfer."""
|
|
26
|
+
return [self._divisor] + cutlass.extract_mlir_values(self.divisor)
|
|
27
|
+
|
|
28
|
+
def __new_from_mlir_values__(self, values):
|
|
29
|
+
"""Reconstruct FastDivmodDivisor from MLIR values."""
|
|
30
|
+
new_obj = object.__new__(FastDivmod)
|
|
31
|
+
new_obj._divisor = values[0]
|
|
32
|
+
new_obj.divisor = cutlass.new_from_mlir_values(self.divisor, values[1:])
|
|
33
|
+
return new_obj
|