quack-kernels 0.1.11__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +288 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +61 -59
- quack/topk.py +14 -8
- quack/utils.py +14 -259
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
|
@@ -1,11 +1,15 @@
|
|
|
1
|
-
__version__ = "0.
|
|
1
|
+
__version__ = "0.2.0"
|
|
2
|
+
|
|
3
|
+
import cutlass.cute as cute
|
|
2
4
|
|
|
3
5
|
from quack.rmsnorm import rmsnorm
|
|
4
6
|
from quack.softmax import softmax
|
|
5
7
|
from quack.cross_entropy import cross_entropy
|
|
6
8
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
+
import quack.cute_dsl_utils
|
|
10
|
+
|
|
11
|
+
# Patch cute.compile to optionally dump SASS
|
|
12
|
+
cute.compile = quack.cute_dsl_utils.cute_compile_patched
|
|
9
13
|
|
|
10
14
|
__all__ = [
|
|
11
15
|
"rmsnorm",
|
quack/activation.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
|
|
6
|
+
import cutlass
|
|
7
|
+
import cutlass.cute as cute
|
|
8
|
+
from cutlass import Float32
|
|
9
|
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
|
10
|
+
from cutlass._mlir.dialects import llvm
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dsl_user_op
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Approximate hyperbolic tangent using the PTX ``tanh.approx.f32`` instruction.

    The input is converted to ``Float32`` and passed as the single operand of an
    inline-asm call; the raw f32 result is wrapped back into ``Float32``.

    Args:
        a: Python float or ``Float32`` value.
        loc, ip: MLIR location / insertion point forwarded by ``dsl_user_op``.
    """
    operand = Float32(a).ir_value(loc=loc, ip=ip)
    raw_result = llvm.inline_asm(
        T.f32(),
        [operand],
        "tanh.approx.f32 $0, $1;",
        "=f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return Float32(raw_result)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dsl_user_op
def relu(x: Float32, *, loc=None, ip=None) -> Float32:
    """Rectified linear unit: ``max(x, 0)`` via ``cute.arch.fmax``."""
    zero = Float32(0.0)
    return cute.arch.fmax(x, zero)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@cute.jit
@dsl_user_op
def drelu(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """ReLU backward pass that also recomputes the forward output.

    Args:
        x: pre-activation input.
        dout: gradient w.r.t. ``relu(x)``.

    Returns:
        ``(dx, relu_out)`` where ``dx = dout`` if ``x > 0`` else ``0``, and
        ``relu_out = max(x, 0)``.
    """
    is_positive = cutlass.Boolean(x > 0)
    dx = dout if is_positive else Float32(0.0)
    relu_out = cute.arch.fmax(x, Float32(0.0))
    return dx, relu_out
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dsl_user_op
def relu_sq(x: Float32, *, loc=None, ip=None) -> Float32:
    """Squared ReLU computed as ``max(x, 0) * x``.

    For finite inputs this is ``x * x`` when ``x > 0`` and zero otherwise.
    """
    relu_x = cute.arch.fmax(x, Float32(0.0))
    return relu_x * x
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@cute.jit
@dsl_user_op
def drelu_sq(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Squared-ReLU backward pass that also recomputes the forward output.

    Forward: ``relu_sq_out = max(x, 0) * x``; ``dout`` is the gradient w.r.t. it.

    Returns:
        ``(dx, relu_sq_out)`` where ``dx = 2 * x * dout`` when ``x > 0`` and ``0``
        otherwise.
    """
    is_positive = cutlass.Boolean(x > 0)
    relu_x = cute.arch.fmax(x, Float32(0.0))
    relu_sq_out = relu_x * x
    # d/dx [max(x, 0) * x] is 2x on the positive side and 0 elsewhere.
    dx = (2.0 * dout * x) if is_positive else Float32(0.0)
    return dx, relu_sq_out
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dsl_user_op
def gelu_tanh_approx(x: Float32, *, loc=None, ip=None) -> Float32:
    """GELU using the tanh approximation.

        gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
                = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2)))

    with c1 = sqrt(2/pi) ~ 0.797885 and c2 = 0.044715 * c1 ~ 0.0356774.
    """
    c1 = math.sqrt(2 / math.pi)
    c2 = 0.044715 * c1
    inner = x * (c1 + c2 * (x * x))
    return 0.5 * (x * (1 + tanh(inner)))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dsl_user_op
def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """
    GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
    Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
    Returns: (dx, gelu_out)

    Derivative uses the chain rule:
        d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
        where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
        and sech^2(z) = 1 - tanh^2(z)

    Constants: c1 = sqrt(2/pi) ~ 0.797885, c2 = 0.044715 * c1 ~ 0.0356774,
    c3 = 3 * c2 ~ 0.1070322.
    """
    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
    # Fixed: 3 * 0.0356774 ~ 0.1070322 (the previous comment said ~0.01070322, off by 10x).
    sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.1070322

    # Compute z = x * (c1 + c2 * x^2)
    x_sq = x * x
    tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
    half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
    gelu_out = x * half_tanh_z_plus_one

    # Compute gradient
    # sech^2(z) = 1 - tanh^2(z)
    sech2_z = 1 - tanh_z * tanh_z
    # dz/dx = c1 + 3 * c2 * x^2
    dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
    # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
    dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))

    dx = dout * dgelu
    return dx, gelu_out
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@dsl_user_op
def silu(x: Float32, *, loc=None, ip=None) -> Float32:
    """SiLU (swish): ``x * sigmoid(x)``.

    Using sigmoid(x) = (1 + tanh(x / 2)) / 2, this becomes
    ``h * tanh(h) + h`` with ``h = 0.5 * x``. The form is chosen to lower to three
    SASS instructions: FMUL (for h), MUFU.TANH, and FFMA.
    """
    h = 0.5 * x
    t = tanh(h)
    return h * t + h
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@dsl_user_op
def swiglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
    """SwiGLU: ``silu(x) * y``, with ``x`` as the gate and ``y`` as the up projection."""
    gate = silu(x)
    return gate * y
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@dsl_user_op
def dswiglu(
    x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
) -> Tuple[Float32, Float32, Float32]:
    """SwiGLU backward pass that also recomputes the forward output.

    Forward: ``out = silu(x) * y`` (x = gate, y = up projection); ``dout`` = grad w.r.t. out.

    Returns:
        ``(dx, dy, out)`` where, with ``s = sigmoid(x)``,
          dx  = dout * y * silu'(x),   silu'(x) = s * (1 + x * (1 - s))
          dy  = dout * silu(x)
          out = silu(x) * y

    Terms are expanded so multiply-adds fuse into FFMA; the intended instruction mix
    is 1 MUFU.TANH, 5 FMUL, 3 FFMA.
    """
    # sigmoid through tanh: s = 0.5 + 0.5 * tanh(x / 2)
    half_x = 0.5 * x  # FMUL
    sig = 0.5 + 0.5 * tanh(half_x)  # MUFU.TANH, FFMA
    silu_val = x * sig  # FMUL
    dy = silu_val * dout  # FMUL
    # silu'(x) * dout = (s + silu - silu * s) * dout
    #                 = (s - silu * s) * dout + silu * dout
    grad_gate = (sig - silu_val * sig) * dout + dy  # FFMA, FFMA
    dx = grad_gate * y  # FMUL
    out = silu_val * y  # FMUL
    return dx, dy, out
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
@dsl_user_op
def swiglu_oai(x: Float32, y: Float32, alpha: float = 1.702, *, loc=None, ip=None) -> Float32:
    """SwiGLU variant used in gpt-oss: a scaled sigmoid gate and a +1 bias on ``y``.

    Reference:
    https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249

        out = x * sigmoid(alpha * x) * (y + 1)

    With ``h = x / 2`` and sigmoid(z) = (1 + tanh(z / 2)) / 2, the gate term is
    ``h * tanh(alpha * h) + h``. Multiplying by ``(y + 1)`` is written as ``g * y + g``.
    Intended lowering: FMUL, FMUL, TANH, FFMA, FFMA.
    """
    h = 0.5 * x
    gated = h * tanh(alpha * h) + h
    return gated * y + gated
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@dsl_user_op
def dswiglu_oai(
    x: Float32, y: Float32, dout: Float32, alpha: float = 1.702, *, loc=None, ip=None
) -> Tuple[Float32, Float32, Float32]:
    """Backward pass of the gpt-oss SwiGLU variant; also recomputes the forward output.

    Forward: ``out = x * sigmoid(alpha * x) * (y + 1)``; ``dout`` = grad w.r.t. out.

    Returns:
        ``(dx, dy, out)``. With ``s = sigmoid(alpha * x)`` and ``g = x * s``:
          d g / dx = s + alpha * x * s * (1 - s) = s + alpha * (g - g * s)
          dx = dout * (d g / dx) * (y + 1),  dy = dout * g,  out = g * (y + 1)

    Intended instruction mix: 1 MUFU.TANH, 4 FMUL, 5 FFMA.
    """
    # s = sigmoid(alpha * x) = 0.5 + 0.5 * tanh(alpha * x / 2)
    alpha_half_x = (0.5 * alpha) * x  # FMUL
    sig = 0.5 + 0.5 * tanh(alpha_half_x)  # MUFU.TANH, FFMA
    gated = x * sig  # FMUL
    dy = gated * dout  # FMUL
    grad_gated = (sig + alpha * (gated - gated * sig)) * dout  # FFMA, FFMA, FMUL
    # Multiply by (y + 1) as an FFMA instead of a separate add.
    dx = grad_gated * y + grad_gated  # FFMA
    out = gated * y + gated  # FFMA
    return dx, dy, out
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
@dsl_user_op
def glu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
    """Gated Linear Unit: ``sigmoid(x) * y``.

    sigmoid is evaluated through tanh as ``0.5 + 0.5 * tanh(x / 2)``
    (FMUL, MUFU.TANH, FFMA), followed by one FMUL with ``y``.
    """
    sig = 0.5 + 0.5 * tanh(0.5 * x)
    return sig * y
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
@dsl_user_op
def dglu(
    x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
) -> Tuple[Float32, Float32, Float32]:
    """GLU backward pass that also recomputes the forward output.

    Forward: ``out = sigmoid(x) * y`` (x = gate, y = up projection); ``dout`` = grad w.r.t. out.

    Returns:
        ``(dx, dy, out)`` with ``s = sigmoid(x)``:
          dx  = dout * y * s * (1 - s)
          dy  = dout * s
          out = s * y

    Intended instruction mix: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA.
    """
    half_x = 0.5 * x  # FMUL
    sig = 0.5 + 0.5 * tanh(half_x)  # MUFU.TANH, FFMA
    dy = sig * dout  # FMUL
    out = sig * y  # FMUL
    # y * s * (1 - s) * dout = (y - y * s) * (s * dout) = (y - out) * dy
    dx = (y - out) * dy  # FADD, FMUL
    return dx, dy, out
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
@dsl_user_op
def reglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
    """ReLU Gated Linear Unit: ``max(x, 0) * y``."""
    relu_x = cute.arch.fmax(x, Float32(0.0))
    return relu_x * y
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
@cute.jit
@dsl_user_op
def dreglu(
    x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
) -> Tuple[Float32, Float32, Float32]:
    """ReGLU backward pass that also recomputes the forward output.

    Forward: ``out = max(x, 0) * y`` (x = gate, y = up projection); ``dout`` = grad w.r.t. out.

    Returns:
        ``(dx, dy, out)`` where ``dx = dout * y`` if ``x > 0`` else ``0``,
        ``dy = dout * max(x, 0)``, and ``out = max(x, 0) * y``.
    """
    is_positive = cutlass.Boolean(x > 0)
    relu_x = cute.arch.fmax(x, Float32(0.0))
    dx = (dout * y) if is_positive else Float32(0.0)
    dy = dout * relu_x
    out = relu_x * y
    return dx, dy, out
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
@dsl_user_op
def geglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
    """GELU Gated Linear Unit: ``gelu(x) * y``, using the tanh approximation of GELU."""
    gelu_x = gelu_tanh_approx(x)
    return gelu_x * y
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@dsl_user_op
def dgeglu(
    x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
) -> Tuple[Float32, Float32, Float32]:
    """GeGLU backward pass that also recomputes the forward output.

    Forward: ``out = gelu(x) * y`` (tanh-approximate GELU; x = gate, y = up projection);
    ``dout`` = grad w.r.t. out.

    Returns:
        ``(dx, dy, out)`` with ``dx = dout * y * gelu'(x)``, ``dy = dout * gelu(x)``,
        and ``out = gelu(x) * y``.
    """
    # A single call yields both dout * gelu'(x) and gelu(x).
    grad_gelu, gelu_x = dgelu_tanh_approx(x, dout)
    dx = grad_gelu * y
    dy = gelu_x * dout
    out = gelu_x * y
    return dx, dy, out
|
quack/autotuner.py
CHANGED
|
@@ -187,7 +187,8 @@ class Autotuner:
|
|
|
187
187
|
if len(self.configs) > 1:
|
|
188
188
|
all_args = {**self.nargs, **kwargs}
|
|
189
189
|
_args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
|
|
190
|
-
|
|
190
|
+
# Need "str" to make it json-serializable
|
|
191
|
+
key = [str(_args[key]) for key in self.keys if key in _args]
|
|
191
192
|
for _, arg in _args.items():
|
|
192
193
|
if isinstance(arg, Tensor):
|
|
193
194
|
key.append(str(arg.shape))
|