quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
@@ -1,9 +1,16 @@
-__version__ = "0.1.10"
+__version__ = "0.2.0"
+
+import cutlass.cute as cute
 
 from quack.rmsnorm import rmsnorm
 from quack.softmax import softmax
 from quack.cross_entropy import cross_entropy
 
+import quack.cute_dsl_utils
+
+# Patch cute.compile to optionally dump SASS
+cute.compile = quack.cute_dsl_utils.cute_compile_patched
+
 __all__ = [
     "rmsnorm",
     "softmax",
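The new `__init__.py` reassigns `cute.compile` to a wrapper from `quack.cute_dsl_utils` at import time. As a rough illustration of that wrap-and-delegate monkey-patching pattern only (the wrapper body, environment-variable name, and SASS handling below are assumptions for illustration, not the actual `cute_dsl_utils.cute_compile_patched` implementation):

import os
import cutlass.cute as cute

_original_compile = cute.compile  # keep a handle on the unpatched entry point

def cute_compile_patched_sketch(*args, **kwargs):
    # Delegate to the real compiler, then optionally inspect the result.
    compiled = _original_compile(*args, **kwargs)
    if os.getenv("QUACK_DUMP_SASS"):  # hypothetical switch, not quack's actual flag
        ...  # dump or inspect the generated SASS here
    return compiled

cute.compile = cute_compile_patched_sketch  # same reassignment as in __init__.py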
quack/activation.py
ADDED
@@ -0,0 +1,288 @@
+# Copyright (c) 2025, Tri Dao.
+
+import math
+from typing import Tuple
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32
+from cutlass.cutlass_dsl import T, dsl_user_op
+from cutlass._mlir.dialects import llvm
+
+
+@dsl_user_op
+def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
+        llvm.inline_asm(
+            T.f32(),
+            [Float32(a).ir_value(loc=loc, ip=ip)],
+            "tanh.approx.f32 $0, $1;",
+            "=f,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+@dsl_user_op
+def relu(x: Float32, *, loc=None, ip=None) -> Float32:
+    return cute.arch.fmax(x, Float32(0.0))
+
+
+@cute.jit
+@dsl_user_op
+def drelu(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+    x_pos = cutlass.Boolean(x > 0)
+    return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
+
+
+@dsl_user_op
+def relu_sq(x: Float32, *, loc=None, ip=None) -> Float32:
+    return cute.arch.fmax(x, Float32(0.0)) * x
+
+
+@cute.jit
+@dsl_user_op
+def drelu_sq(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+    """
+    ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
+    Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
+    Returns: (dx, relu_sq_out) where:
+        - dx = dout * 2 * x if x > 0, else 0
+        - relu_sq_out = max(x, 0) * x
+    """
+    x_pos = cutlass.Boolean(x > 0)
+    relu_sq_out = cute.arch.fmax(x, Float32(0.0)) * x
+    # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
+    dx = (2.0 * dout * x) if x_pos else Float32(0.0)
+    return dx, relu_sq_out
+
+
+@dsl_user_op
+def gelu_tanh_approx(x: Float32, *, loc=None, ip=None) -> Float32:
+    """
+    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+            = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
+    """
+    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
+    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
+    return 0.5 * (x * (1 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)))))
+
+
+@dsl_user_op
+def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+    """
+    GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
+    Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
+    Returns: (dx, gelu_out)
+
+    Derivative uses the chain rule:
+        d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+        where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
+        and sech^2(z) = 1 - tanh^2(z)
+    """
+    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
+    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
+    sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.01070322
+
+    # Compute z = x * (c1 + c2 * x^2)
+    x_sq = x * x
+    tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
+    half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
+    gelu_out = x * half_tanh_z_plus_one
+
+    # Compute gradient
+    # sech^2(z) = 1 - tanh^2(z)
+    sech2_z = 1 - tanh_z * tanh_z
+    # dz/dx = c1 + 3 * c2 * x^2
+    dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
+    # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+    dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
+
+    dx = dout * dgelu
+    return dx, gelu_out
+
+
+@dsl_user_op
+def silu(x: Float32, *, loc=None, ip=None) -> Float32:
+    """
+    silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
+    This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
+    """
+    x_half = 0.5 * x
+    return x_half * tanh(x_half) + x_half
+
+
+@dsl_user_op
+def swiglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+    return silu(x) * y
+
+
+@dsl_user_op
+def dswiglu(
+    x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+) -> Tuple[Float32, Float32, Float32]:
+    """
+    SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
+    Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)
+
+    d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+
+    This has been optimized to use fewer instructions (i.e. we expand things out
+    to use FFMA instead of FADD and FMUL).
+    """
+    # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
+    x_half = 0.5 * x  # FMUL
+    sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+    silu_x = x * sigmoid_x  # FMUL
+    silu_x_dout = silu_x * dout  # FMUL
+    # d_silu(x) * dout
+    # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
+    # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
+    # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
+    # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
+    # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+    d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
+    dx = d_silu_x_dout * y  # FMUL
+    dy = silu_x_dout
+    swiglu_out = silu_x * y  # FMUL
+    # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
+    return dx, dy, swiglu_out
+
+
+@dsl_user_op
+def swiglu_oai(x: Float32, y: Float32, alpha: float = 1.702, *, loc=None, ip=None) -> Float32:
+    """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
+    https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
+    x * sigmoid(alpha * x) * (y + 1)
+    Compile down to FMUL, FMUL, TANH, FFMA, FFMA
+    """
+    # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+    x_half = 0.5 * x
+    silu_x = x_half * tanh(alpha * x_half) + x_half
+    return silu_x * y + silu_x
+
+
+@dsl_user_op
+def dswiglu_oai(
+    x: Float32, y: Float32, dout: Float32, alpha: float = 1.702, *, loc=None, ip=None
+) -> Tuple[Float32, Float32, Float32]:
+    """
+    Swiglu OAI backward pass: computes gradients w.r.t. x and y
+    Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
+    Returns: (dx, dy, swiglu_oai_out)
+
+    Derivative of x * sigmoid(alpha * x) w.r.t. x:
+    d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
+    """
+    # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+    alpha_x_half = (0.5 * alpha) * x  # FMUL
+    sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)  # MUFU.TANH, then FFMA
+    silu_x = x * sigmoid_alpha_x  # FMUL
+    silu_x_dout = silu_x * dout  # FMUL
+    # FFMA, FFMA, FMUL
+    d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+    dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
+    dy = silu_x_dout
+    swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
+    # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
+    return dx, dy, swiglu_out
+
+
+@dsl_user_op
+def glu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+    """GLU: Gated Linear Unit
+    glu(x, y) = sigmoid(x) * y
+    Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+    """
+    x_half = 0.5 * x  # FMUL
+    sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+    return sigmoid_x * y  # FMUL
+
+
+@dsl_user_op
+def dglu(
+    x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+) -> Tuple[Float32, Float32, Float32]:
+    """
+    GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
+    Returns: (dx, dy, glu_out) where:
+        - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
+        - dy = dout * sigmoid(x)
+        - glu_out = sigmoid(x) * y
+    """
+    # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+    x_half = 0.5 * x  # FMUL
+    sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+    sigmoid_x_dout = sigmoid_x * dout  # FMUL
+    glu_out = sigmoid_x * y  # FMUL
+    # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
+    #    = y * (1 - sigmoid(x)) * sigmoid_x_dout
+    #    = (y - y * sigmoid(x)) * sigmoid_x_dout
+    #    = (y - glu_out) * sigmoid_x_dout
+    dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
+    dy = sigmoid_x_dout
+    # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
+    return dx, dy, glu_out
+
+
+@dsl_user_op
+def reglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+    """ReGLU: ReLU Gated Linear Unit
+    reglu(x, y) = relu(x) * y = max(x, 0) * y
+    """
+    return cute.arch.fmax(x, Float32(0.0)) * y
+
+
+@cute.jit
+@dsl_user_op
+def dreglu(
+    x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+) -> Tuple[Float32, Float32, Float32]:
+    """
+    ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
+    Returns: (dx, dy, reglu_out) where:
+        - dx = dout * y if x > 0, else 0
+        - dy = dout * relu(x)
+        - reglu_out = relu(x) * y
+    """
+    x_pos = cutlass.Boolean(x > 0)
+    relu_x = cute.arch.fmax(x, Float32(0.0))
+    dx = (dout * y) if x_pos else Float32(0.0)
+    dy = dout * relu_x
+    reglu_out = relu_x * y
+    return dx, dy, reglu_out
+
+
+@dsl_user_op
+def geglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+    """GeGLU: GELU Gated Linear Unit
+    geglu(x, y) = gelu(x) * y
+    Uses the tanh approximation of GELU
+    """
+    return gelu_tanh_approx(x) * y
+
+
+@dsl_user_op
+def dgeglu(
+    x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+) -> Tuple[Float32, Float32, Float32]:
+    """
+    GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
+    Returns: (dx, dy, geglu_out) where:
+        - dx = dout * y * d_gelu(x)
+        - dy = dout * gelu(x)
+        - geglu_out = gelu(x) * y
+    """
+    # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+    dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+    # Compute gradients for geglu
+    dx = dgelu_x_dout * y
+    dy = gelu_x * dout
+    geglu_out = gelu_x * y
+    return dx, dy, geglu_out
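The backward helpers above each recompute the forward value and fold the chain rule into FFMA-friendly forms. As a quick standalone sanity check of the dswiglu math stated in its docstring (PyTorch is used here only for verification and is not part of this wheel): dx = dout * y * d_silu(x) with d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))), and dy = dout * silu(x), match autograd:

import torch
import torch.nn.functional as F

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
y = torch.randn(1024, dtype=torch.float64, requires_grad=True)
dout = torch.randn(1024, dtype=torch.float64)

# SwiGLU forward, then autograd reference gradients
(F.silu(x) * y).backward(dout)

# Manual gradients from the docstring formulas
sig = torch.sigmoid(x.detach())
d_silu = sig * (1 + x.detach() * (1 - sig))
dx_manual = dout * y.detach() * d_silu
dy_manual = dout * F.silu(x.detach())

assert torch.allclose(x.grad, dx_manual)
assert torch.allclose(y.grad, dy_manual)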
quack/autotuner.py
ADDED
@@ -0,0 +1,310 @@
+# Adapted from https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py
+# Copyright (C) 2025, Tri Dao.
+from __future__ import annotations
+
+import builtins
+import os
+import time
+import inspect
+import base64
+import hashlib
+import json
+from pathlib import Path
+from functools import cached_property, partial
+from typing import Dict, Tuple
+
+import torch
+from torch import Tensor
+
+import triton
+
+from . import __version__
+
+
+PACKAGE_NAME = "quack"
+VERSION = __version__
+
+
+def get_home_dir():
+    return os.getenv(f"{PACKAGE_NAME.upper()}_HOME", Path.home())
+
+
+def default_cache_dir():
+    return os.path.join(get_home_dir(), f".{PACKAGE_NAME}", "cache")
+
+
+class FileCacheManager(triton.runtime.cache.FileCacheManager):
+    def __init__(self, key):
+        super().__init__(key)
+        self.cache_dir = (
+            os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_DIR", "").strip() or default_cache_dir()
+        )
+        if self.cache_dir:
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+            self.lock_path = os.path.join(self.cache_dir, "lock")
+            os.makedirs(self.cache_dir, exist_ok=True)
+        else:
+            raise RuntimeError("Could not create or locate cache dir")
+
+
+def _base32(key):
+    # Assume key is a hex string.
+    return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
+
+
+class Autotuner:
+    def __init__(self, fn, key, configs, restore_value=None, do_bench=None, cache_results=False):
+        if not configs:
+            self.configs = [AutotuneConfig()]
+        else:
+            self.configs = configs
+        signature = inspect.signature(fn)
+        self.keys = key
+        self.cache: Dict[Tuple, AutotuneConfig] = {}
+        self.arg_names = list(signature.parameters.keys())
+        self.cache_results = (
+            cache_results or os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_AUTOTUNING", None) == "1"
+        )
+
+        self.restore_value = []
+        if restore_value is not None:
+            self.restore_value = list(restore_value)
+
+        if len(self.restore_value) > 0:
+
+            def _pre_hook(kwargs):
+                self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value}
+
+            self.pre_hook = _pre_hook
+        else:
+            self.pre_hook = None
+
+        if len(self.restore_value) > 0:
+
+            def _post_hook(kwargs, exception):
+                for name in self.restore_value:
+                    kwargs[name].copy_(self.restore_copies[name])
+                self.restore_copies = {}
+
+            self.post_hook = _post_hook
+        else:
+            self.post_hook = None
+
+        self.fn = fn
+        self._do_bench = do_bench
+
+    @cached_property
+    def do_bench(self):
+        if self._do_bench is None:
+            return partial(triton.testing.do_bench, warmup=5, rep=25)
+        return self._do_bench
+
+    def _bench(self, *args, config, **meta):
+        verbose = os.environ.get(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
+        if verbose:
+            print(f"Autotuning kernel {self.fn.__name__} with config {config}")
+
+        # check for conflicts, i.e. meta-parameters both provided
+        # as kwargs and by the autotuner
+        conflicts = meta.keys() & config.kwargs.keys()
+        if conflicts:
+            raise ValueError(
+                f"Conflicting meta-parameters: {', '.join(conflicts)}."
+                " Make sure that you don't re-define auto-tuned symbols."
+            )
+        # augment meta-parameters with tunable ones
+        current = dict(meta, **config.all_kwargs())
+        full_nargs = {**self.nargs, **current}
+
+        def kernel_call():
+            if self.pre_hook is not None:
+                self.pre_hook(full_nargs)
+            try:
+                self.fn.__call__(
+                    *args,
+                    **current,
+                )
+            except Exception as e:
+                try:
+                    if self.post_hook is not None:
+                        self.post_hook(full_nargs, exception=e)
+                finally:
+                    # Throw exception raised by `self.fn.run`
+                    raise
+
+            if self.post_hook is not None:
+                self.post_hook(full_nargs, exception=None)
+
+        try:
+            return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
+        except Exception as e:
+            if verbose:
+                print(f"Autotuning failed with {e}")
+            return [float("inf"), float("inf"), float("inf")]
+
+    @torch.compiler.disable
+    def check_disk_cache(self, tuning_key, configs, bench_fn):
+        if not tuning_key:
+            bench_fn()
+            return
+
+        fn = self.fn
+        config_str_list = [str(c) for c in configs]
+        assert len(config_str_list) == len(set(config_str_list)), "Config strings must be unique"
+        cache_key = [VERSION, str(tuning_key)] + config_str_list
+        cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
+        cache = FileCacheManager(_base32(cache_key))
+        file_name = f"{fn.__name__[:150]}.autotune.json"
+        path = cache.get_file(file_name)
+        # There's an environment variable to force cache update
+        if path and not os.environ.get(f"{PACKAGE_NAME.upper()}_FORCE_CACHE_UPDATE", False):
+            str2config = {s: c for s, c in zip(config_str_list, configs)}
+            with open(path, "r") as cached_configs:
+                timings = json.load(cached_configs)["configs_timings"]
+            timings = {str2config[config]: timing for config, timing in timings}
+            self.cache[tuning_key] = builtins.min(timings, key=timings.get)
+            self.configs_timings = timings
+            self.bench_time = 0
+            return
+
+        bench_fn()
+        cache.put(
+            json.dumps(
+                {
+                    "key": tuning_key,
+                    "configs_timings": [
+                        (str(config), timings) for config, timings in self.configs_timings.items()
+                    ],
+                }
+            ),
+            file_name,
+            binary=False,
+        )
+
+    def __call__(self, *args, **kwargs):
+        self.nargs = dict(zip(self.arg_names, args))
+        used_cached_result = True
+        if len(self.configs) > 1:
+            all_args = {**self.nargs, **kwargs}
+            _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
+            # Need "str" to make it json-serializable
+            key = [str(_args[key]) for key in self.keys if key in _args]
+            for _, arg in _args.items():
+                if isinstance(arg, Tensor):
+                    key.append(str(arg.shape))
+                    # If stride != 0, 1, we just cache it as 2
+                    key.append(str([s if s in {0, 1} else 2 for s in arg.stride()]))
+                    key.append(str(arg.dtype))
+            key = tuple(key)
+            if key not in self.cache:
+                used_cached_result = False
+
+                @torch.compiler.disable  # Don't want any tracing here
+                def benchmark():
+                    bench_start = time.time()
+                    timings = {
+                        config: self._bench(*args, config=config, **kwargs)
+                        for config in self.configs
+                    }
+                    bench_end = time.time()
+                    if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
+                        for config, time_ in timings.items():
+                            print(f"[{config}] -> {time_[0]:.3f}ms")
+                    self.bench_time = bench_end - bench_start
+                    self.cache[key] = builtins.min(timings, key=timings.get)
+                    self.configs_timings = timings
+
+                if self.cache_results:
+                    self.check_disk_cache(key, self.configs, benchmark)
+                else:
+                    benchmark()
+
+            config = self.cache[key]
+        else:
+            config = self.configs[0]
+        self.best_config = config
+        if (
+            os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1"
+            and not used_cached_result
+        ):
+            print(
+                f"{PACKAGE_NAME} autotuning for function {self.fn.__name__} finished after "
+                f"{self.bench_time:.2f}s; best config selected: {self.best_config};"
+            )
+        ret = self.fn.__call__(
+            *args,
+            **kwargs,
+            **config.all_kwargs(),
+        )
+        self.nargs = None
+        return ret
+
+
+class AutotuneConfig:
+    """
+    An object that represents a possible kernel configuration for the auto-tuner to try.
+
+    :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
+    :type kwargs: dict[Str, Any]
+    """
+
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
+
+    def __setstate__(self, state):
+        self.kwargs = state.get("kwargs", {})
+
+    def all_kwargs(self):
+        return self.kwargs
+
+    def __str__(self):
+        res = []
+        for k, v in self.kwargs.items():
+            res.append(f"{k}: {v}")
+        return ", ".join(res)
+
+    def __hash__(self):
+        return hash(tuple(*self.all_kwargs().items()))
+
+    def __eq__(self, other):
+        self_tuple = tuple(*self.all_kwargs().items())
+        other_tuple = tuple(*other.all_kwargs().items())
+        return self_tuple == other_tuple
+
+
+def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results=True):
+    f"""
+    Decorator for auto-tuning a function function.
+
+    .. highlight:: python
+
+    If the environment variable :code:`{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING` is set to
+    :code:`"1"`, we will print a message to stdout after autotuning each
+    kernel, including the time spent autotuning and the best configuration.
+
+    :param configs: a list of :code:`AutotuneConfig` objects
+    :type configs: list[AutotuneConfig]
+    :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
+    :type key: list[str]
+    :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
+    :type restore_value: list[str]
+    :param do_bench: a benchmark function to measure the time of each run.
+    :type do_bench: lambda fn, quantiles
+    :param cache_results: whether to cache autotune timings to disk. Defaults to False.
+    "type cache_results: bool
+    """
+
+    if key is None:
+        key = []
+
+    def decorator(fn):
+        return Autotuner(
+            fn,
+            key,
+            configs,
+            restore_value=restore_value,
+            do_bench=do_bench,
+            cache_results=cache_results,
+        )
+
+    return decorator