quack-kernels 0.1.11__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {quack_kernels-0.1.11/quack_kernels.egg-info → quack_kernels-0.2.0}/PKG-INFO +2 -2
  2. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/README.md +2 -3
  3. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/pyproject.toml +1 -1
  4. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/__init__.py +7 -3
  5. quack_kernels-0.2.0/quack/activation.py +288 -0
  6. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/autotuner.py +2 -1
  7. quack_kernels-0.2.0/quack/cross_entropy.py +734 -0
  8. quack_kernels-0.2.0/quack/cute_dsl_utils.py +119 -0
  9. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/dense_gemm_sm100.py +1 -1
  10. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/dense_gemm_sm90.py +911 -1140
  11. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/fast_math.py +10 -27
  12. quack_kernels-0.2.0/quack/gemm_act_sm90.py +368 -0
  13. quack_kernels-0.2.0/quack/gemm_config.py +69 -0
  14. quack_kernels-0.2.0/quack/gemm_dact_sm90.py +150 -0
  15. quack_kernels-0.2.0/quack/gemm_interface.py +569 -0
  16. quack_kernels-0.2.0/quack/gemm_wrapper_utils.py +158 -0
  17. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/layernorm.py +5 -3
  18. quack_kernels-0.2.0/quack/linear.py +240 -0
  19. quack_kernels-0.2.0/quack/linear_cross_entropy.py +275 -0
  20. quack_kernels-0.2.0/quack/mlp.py +74 -0
  21. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/pipeline.py +2 -17
  22. quack_kernels-0.2.0/quack/reduce.py +241 -0
  23. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/reduction_base.py +2 -11
  24. quack_kernels-0.2.0/quack/rmsnorm.py +1216 -0
  25. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/softmax.py +27 -15
  26. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/symmetric_dense_gemm_sm90.py +6 -3
  27. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/tensormap_manager.py +1 -0
  28. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/tile_scheduler.py +61 -59
  29. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/topk.py +14 -8
  30. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/utils.py +14 -259
  31. quack_kernels-0.2.0/quack/varlen_utils.py +22 -0
  32. {quack_kernels-0.1.11 → quack_kernels-0.2.0/quack_kernels.egg-info}/PKG-INFO +2 -2
  33. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack_kernels.egg-info/SOURCES.txt +9 -1
  34. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack_kernels.egg-info/requires.txt +1 -1
  35. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack_kernels.egg-info/top_level.txt +1 -0
  36. quack_kernels-0.2.0/tests/test_cross_entropy.py +333 -0
  37. quack_kernels-0.2.0/tests/test_linear.py +131 -0
  38. quack_kernels-0.2.0/tests/test_linear_cross_entropy.py +85 -0
  39. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/tests/test_rmsnorm.py +158 -16
  40. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/tests/test_softmax.py +19 -14
  41. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/tests/test_symmetric_dense_gemm_sm90.py +82 -94
  42. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/tests/test_topk.py +14 -8
  43. quack_kernels-0.1.11/quack/cross_entropy.py +0 -584
  44. quack_kernels-0.1.11/quack/cute_dsl_utils.py +0 -40
  45. quack_kernels-0.1.11/quack/gemm_config.py +0 -61
  46. quack_kernels-0.1.11/quack/gemm_interface.py +0 -321
  47. quack_kernels-0.1.11/quack/linear.py +0 -176
  48. quack_kernels-0.1.11/quack/lse.py +0 -62
  49. quack_kernels-0.1.11/quack/mlp.py +0 -204
  50. quack_kernels-0.1.11/quack/rmsnorm.py +0 -864
  51. quack_kernels-0.1.11/tests/test_cross_entropy.py +0 -109
  52. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/LICENSE +0 -0
  53. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/sort/bitonic_sort.py +0 -0
  54. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/sort/generate_sorting_networks.py +0 -0
  55. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/sort/sorting_networks.py +0 -0
  56. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/sort/utils.py +0 -0
  57. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack_kernels.egg-info/dependency_links.txt +0 -0
  58. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/setup.cfg +0 -0
  59. {quack_kernels-0.1.11 → quack_kernels-0.2.0}/tests/test_layernorm.py +0 -0
{quack_kernels-0.1.11/quack_kernels.egg-info → quack_kernels-0.2.0}/PKG-INFO
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.11
+ Version: 0.2.0
  Requires-Python: >=3.12
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.1.0
+ Requires-Dist: nvidia-cutlass-dsl==4.2.0
  Requires-Dist: torch
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
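
The metadata above bumps the package to 0.2.0 and moves the nvidia-cutlass-dsl pin from 4.1.0 to 4.2.0. A quick way to confirm that an environment actually picked up the matching pair after upgrading; this is only a minimal sketch, using nothing beyond the distribution names shown in the metadata:

from importlib.metadata import version

# Distribution names taken from the metadata above; expected values per this release.
print(version("quack-kernels"))       # 0.2.0
print(version("nvidia-cutlass-dsl"))  # 4.2.0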
{quack_kernels-0.1.11 → quack_kernels-0.2.0}/README.md
@@ -20,9 +20,8 @@ pip install quack-kernels
  - 🦆 Softmax forward + backward
  - 🦆 Cross entropy forward + backward
  - 🦆 Layernorm forward
-
- Upcoming:
- - 🦆 Rotary forward + backward
+ - 🦆 Hopper gemm + epilogue
+ - 🦆 Blackwell gemm + epilogue

  ## Usage

@@ -7,7 +7,7 @@ name = "quack-kernels"
7
7
  dynamic = ["version"]
8
8
  requires-python = ">=3.12"
9
9
  dependencies = [
10
- "nvidia-cutlass-dsl==4.1.0",
10
+ "nvidia-cutlass-dsl==4.2.0",
11
11
  "torch",
12
12
  ]
13
13
 
{quack_kernels-0.1.11 → quack_kernels-0.2.0}/quack/__init__.py
@@ -1,11 +1,15 @@
- __version__ = "0.1.11"
+ __version__ = "0.2.0"
+
+ import cutlass.cute as cute

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
  from quack.cross_entropy import cross_entropy

- # ruff: noqa
- import quack.cute_dsl_utils  # Patch cute.compile to optionally dump SASS
+ import quack.cute_dsl_utils
+
+ # Patch cute.compile to optionally dump SASS
+ cute.compile = quack.cute_dsl_utils.cute_compile_patched

  __all__ = [
      "rmsnorm",
quack_kernels-0.2.0/quack/activation.py (new file)
@@ -0,0 +1,288 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ import math
+ from typing import Tuple
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Float32
+ from cutlass.cutlass_dsl import T, dsl_user_op
+ from cutlass._mlir.dialects import llvm
+
+
+ @dsl_user_op
+ def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "tanh.approx.f32 $0, $1;",
+             "=f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @dsl_user_op
+ def relu(x: Float32, *, loc=None, ip=None) -> Float32:
+     return cute.arch.fmax(x, Float32(0.0))
+
+
+ @cute.jit
+ @dsl_user_op
+ def drelu(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+     x_pos = cutlass.Boolean(x > 0)
+     return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
+
+
+ @dsl_user_op
+ def relu_sq(x: Float32, *, loc=None, ip=None) -> Float32:
+     return cute.arch.fmax(x, Float32(0.0)) * x
+
+
+ @cute.jit
+ @dsl_user_op
+ def drelu_sq(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+     """
+     ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
+     Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
+     Returns: (dx, relu_sq_out) where:
+     - dx = dout * 2 * x if x > 0, else 0
+     - relu_sq_out = max(x, 0) * x
+     """
+     x_pos = cutlass.Boolean(x > 0)
+     relu_sq_out = cute.arch.fmax(x, Float32(0.0)) * x
+     # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
+     dx = (2.0 * dout * x) if x_pos else Float32(0.0)
+     return dx, relu_sq_out
+
+
+ @dsl_user_op
+ def gelu_tanh_approx(x: Float32, *, loc=None, ip=None) -> Float32:
+     """
+     gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+             = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
+     """
+     sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
+     sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
+     return 0.5 * (x * (1 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)))))
+
+
+ @dsl_user_op
+ def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+     """
+     GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
+     Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
+     Returns: (dx, gelu_out)
+
+     Derivative uses the chain rule:
+     d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+     where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
+     and sech^2(z) = 1 - tanh^2(z)
+     """
+     sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
+     sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
+     sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.01070322
+
+     # Compute z = x * (c1 + c2 * x^2)
+     x_sq = x * x
+     tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
+     half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
+     gelu_out = x * half_tanh_z_plus_one
+
+     # Compute gradient
+     # sech^2(z) = 1 - tanh^2(z)
+     sech2_z = 1 - tanh_z * tanh_z
+     # dz/dx = c1 + 3 * c2 * x^2
+     dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
+     # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+     dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
+
+     dx = dout * dgelu
+     return dx, gelu_out
+
+
+ @dsl_user_op
+ def silu(x: Float32, *, loc=None, ip=None) -> Float32:
+     """
+     silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
+     This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
+     """
+     x_half = 0.5 * x
+     return x_half * tanh(x_half) + x_half
+
+
+ @dsl_user_op
+ def swiglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+     return silu(x) * y
+
+
+ @dsl_user_op
+ def dswiglu(
+     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32, Float32]:
+     """
+     SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
+     Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)
+
+     d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+
+     This has been optimized to use fewer instructions (i.e. we expand things out
+     to use FFMA instead of FADD and FMUL).
+     """
+     # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
+     x_half = 0.5 * x  # FMUL
+     sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+     silu_x = x * sigmoid_x  # FMUL
+     silu_x_dout = silu_x * dout  # FMUL
+     # d_silu(x) * dout
+     # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
+     # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
+     # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
+     # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
+     # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+     d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
+     dx = d_silu_x_dout * y  # FMUL
+     dy = silu_x_dout
+     swiglu_out = silu_x * y  # FMUL
+     # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
+     return dx, dy, swiglu_out
+
+
+ @dsl_user_op
+ def swiglu_oai(x: Float32, y: Float32, alpha: float = 1.702, *, loc=None, ip=None) -> Float32:
+     """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
+     https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
+     x * sigmoid(alpha * x) * (y + 1)
+     Compile down to FMUL, FMUL, TANH, FFMA, FFMA
+     """
+     # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+     x_half = 0.5 * x
+     silu_x = x_half * tanh(alpha * x_half) + x_half
+     return silu_x * y + silu_x
+
+
+ @dsl_user_op
+ def dswiglu_oai(
+     x: Float32, y: Float32, dout: Float32, alpha: float = 1.702, *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32, Float32]:
+     """
+     Swiglu OAI backward pass: computes gradients w.r.t. x and y
+     Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
+     Returns: (dx, dy, swiglu_oai_out)
+
+     Derivative of x * sigmoid(alpha * x) w.r.t. x:
+     d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
+     """
+     # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+     alpha_x_half = (0.5 * alpha) * x  # FMUL
+     sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)  # MUFU.TANH, then FFMA
+     silu_x = x * sigmoid_alpha_x  # FMUL
+     silu_x_dout = silu_x * dout  # FMUL
+     # FFMA, FFMA, FMUL
+     d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+     dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
+     dy = silu_x_dout
+     swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
+     # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
+     return dx, dy, swiglu_out
+
+
+ @dsl_user_op
+ def glu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+     """GLU: Gated Linear Unit
+     glu(x, y) = sigmoid(x) * y
+     Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+     """
+     x_half = 0.5 * x  # FMUL
+     sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+     return sigmoid_x * y  # FMUL
+
+
+ @dsl_user_op
+ def dglu(
+     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32, Float32]:
+     """
+     GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
+     Returns: (dx, dy, glu_out) where:
+     - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
+     - dy = dout * sigmoid(x)
+     - glu_out = sigmoid(x) * y
+     """
+     # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+     x_half = 0.5 * x  # FMUL
+     sigmoid_x = 0.5 + 0.5 * tanh(x_half)  # MUFU.TANH, then FFMA
+     sigmoid_x_dout = sigmoid_x * dout  # FMUL
+     glu_out = sigmoid_x * y  # FMUL
+     # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
+     # = y * (1 - sigmoid(x)) * sigmoid_x_dout
+     # = (y - y * sigmoid(x)) * sigmoid_x_dout
+     # = (y - glu_out) * sigmoid_x_dout
+     dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
+     dy = sigmoid_x_dout
+     # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
+     return dx, dy, glu_out
+
+
+ @dsl_user_op
+ def reglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+     """ReGLU: ReLU Gated Linear Unit
+     reglu(x, y) = relu(x) * y = max(x, 0) * y
+     """
+     return cute.arch.fmax(x, Float32(0.0)) * y
+
+
+ @cute.jit
+ @dsl_user_op
+ def dreglu(
+     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32, Float32]:
+     """
+     ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
+     Returns: (dx, dy, reglu_out) where:
+     - dx = dout * y if x > 0, else 0
+     - dy = dout * relu(x)
+     - reglu_out = relu(x) * y
+     """
+     x_pos = cutlass.Boolean(x > 0)
+     relu_x = cute.arch.fmax(x, Float32(0.0))
+     dx = (dout * y) if x_pos else Float32(0.0)
+     dy = dout * relu_x
+     reglu_out = relu_x * y
+     return dx, dy, reglu_out
+
+
+ @dsl_user_op
+ def geglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+     """GeGLU: GELU Gated Linear Unit
+     geglu(x, y) = gelu(x) * y
+     Uses the tanh approximation of GELU
+     """
+     return gelu_tanh_approx(x) * y
+
+
+ @dsl_user_op
+ def dgeglu(
+     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32, Float32]:
+     """
+     GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
+     Returns: (dx, dy, geglu_out) where:
+     - dx = dout * y * d_gelu(x)
+     - dy = dout * gelu(x)
+     - geglu_out = gelu(x) * y
+     """
+     # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+     dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+     # Compute gradients for geglu
+     dx = dgelu_x_dout * y
+     dy = gelu_x * dout
+     geglu_out = gelu_x * y
+     return dx, dy, geglu_out
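
Every activation in this new file leans on the same trick: sigmoid and SiLU are computed through tanh (sigmoid(x) = 0.5 + 0.5 * tanh(x/2)) so the generated SASS reduces to MUFU.TANH plus FMUL/FFMA. The algebraic rewrites documented in the docstrings can be checked off-device with plain Python floats; the sketch below is a host-side sanity check only (the kernels use the approximate tanh.approx.f32 instruction, so on-device results differ slightly):

import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


for x in (-3.0, -0.5, 0.0, 0.7, 2.5):
    # silu(x) = x * sigmoid(x), rewritten as 0.5*x*tanh(0.5*x) + 0.5*x
    x_half = 0.5 * x
    assert math.isclose(x_half * math.tanh(x_half) + x_half, x * sigmoid(x),
                        rel_tol=1e-12, abs_tol=1e-12)

    # gelu tanh approximation: the two docstring forms are the same polynomial
    c1 = math.sqrt(2 / math.pi)
    gelu_ref = 0.5 * x * (1 + math.tanh(c1 * (x + 0.044715 * x**3)))
    gelu_alt = 0.5 * (x * (1 + math.tanh(x * (c1 + 0.044715 * c1 * x * x))))
    assert math.isclose(gelu_alt, gelu_ref, rel_tol=1e-12, abs_tol=1e-12)

    # dswiglu's expanded gradient equals sigmoid(x) * (1 + x * (1 - sigmoid(x))) * dout
    dout, s = 1.3, sigmoid(x)
    silu_x = x * s
    lhs = (s - silu_x * s) * dout + silu_x * dout
    rhs = s * (1 + x * (1 - s)) * dout
    assert math.isclose(lhs, rhs, rel_tol=1e-12, abs_tol=1e-12)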
@@ -187,7 +187,8 @@ class Autotuner:
187
187
  if len(self.configs) > 1:
188
188
  all_args = {**self.nargs, **kwargs}
189
189
  _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
190
- key = [_args[key] for key in self.keys if key in _args]
190
+ # Need "str" to make it json-serializable
191
+ key = [str(_args[key]) for key in self.keys if key in _args]
191
192
  for _, arg in _args.items():
192
193
  if isinstance(arg, Tensor):
193
194
  key.append(str(arg.shape))