quack-kernels 0.1.11__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {quack_kernels-0.1.11/quack_kernels.egg-info → quack_kernels-0.2.1}/PKG-INFO +3 -3
  2. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/README.md +2 -3
  3. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/pyproject.toml +2 -2
  4. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/__init__.py +7 -3
  5. quack_kernels-0.2.1/quack/activation.py +279 -0
  6. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/autotuner.py +2 -1
  7. quack_kernels-0.2.1/quack/cross_entropy.py +730 -0
  8. quack_kernels-0.2.1/quack/cute_dsl_utils.py +119 -0
  9. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/dense_gemm_sm100.py +1 -1
  10. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/dense_gemm_sm90.py +911 -1140
  11. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/fast_math.py +10 -27
  12. quack_kernels-0.2.1/quack/gemm_act_sm90.py +368 -0
  13. quack_kernels-0.2.1/quack/gemm_config.py +69 -0
  14. quack_kernels-0.2.1/quack/gemm_dact_sm90.py +150 -0
  15. quack_kernels-0.2.1/quack/gemm_interface.py +569 -0
  16. quack_kernels-0.2.1/quack/gemm_wrapper_utils.py +158 -0
  17. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/layernorm.py +6 -4
  18. quack_kernels-0.2.1/quack/linear.py +240 -0
  19. quack_kernels-0.2.1/quack/linear_cross_entropy.py +275 -0
  20. quack_kernels-0.2.1/quack/mlp.py +74 -0
  21. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/pipeline.py +2 -17
  22. quack_kernels-0.2.1/quack/reduce.py +240 -0
  23. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/reduction_base.py +2 -11
  24. quack_kernels-0.2.1/quack/rmsnorm.py +1250 -0
  25. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/softmax.py +28 -16
  26. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/symmetric_dense_gemm_sm90.py +6 -3
  27. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/tensormap_manager.py +1 -0
  28. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/tile_scheduler.py +64 -61
  29. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/topk.py +14 -8
  30. quack_kernels-0.2.1/quack/utils.py +358 -0
  31. quack_kernels-0.2.1/quack/varlen_utils.py +22 -0
  32. {quack_kernels-0.1.11 → quack_kernels-0.2.1/quack_kernels.egg-info}/PKG-INFO +3 -3
  33. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack_kernels.egg-info/SOURCES.txt +9 -1
  34. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack_kernels.egg-info/requires.txt +1 -1
  35. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack_kernels.egg-info/top_level.txt +1 -0
  36. quack_kernels-0.2.1/tests/test_cross_entropy.py +333 -0
  37. quack_kernels-0.2.1/tests/test_linear.py +131 -0
  38. quack_kernels-0.2.1/tests/test_linear_cross_entropy.py +85 -0
  39. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/tests/test_rmsnorm.py +158 -16
  40. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/tests/test_softmax.py +19 -14
  41. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/tests/test_symmetric_dense_gemm_sm90.py +82 -94
  42. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/tests/test_topk.py +14 -8
  43. quack_kernels-0.1.11/quack/cross_entropy.py +0 -584
  44. quack_kernels-0.1.11/quack/cute_dsl_utils.py +0 -40
  45. quack_kernels-0.1.11/quack/gemm_config.py +0 -61
  46. quack_kernels-0.1.11/quack/gemm_interface.py +0 -321
  47. quack_kernels-0.1.11/quack/linear.py +0 -176
  48. quack_kernels-0.1.11/quack/lse.py +0 -62
  49. quack_kernels-0.1.11/quack/mlp.py +0 -204
  50. quack_kernels-0.1.11/quack/rmsnorm.py +0 -864
  51. quack_kernels-0.1.11/quack/utils.py +0 -666
  52. quack_kernels-0.1.11/tests/test_cross_entropy.py +0 -109
  53. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/LICENSE +0 -0
  54. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/sort/bitonic_sort.py +0 -0
  55. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/sort/generate_sorting_networks.py +0 -0
  56. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/sort/sorting_networks.py +0 -0
  57. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/sort/utils.py +0 -0
  58. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack_kernels.egg-info/dependency_links.txt +0 -0
  59. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/setup.cfg +0 -0
  60. {quack_kernels-0.1.11 → quack_kernels-0.2.1}/tests/test_layernorm.py +0 -0
{quack_kernels-0.1.11/quack_kernels.egg-info → quack_kernels-0.2.1}/PKG-INFO
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.1.11
- Requires-Python: >=3.12
+ Version: 0.2.1
+ Requires-Python: >=3.10
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.1.0
+ Requires-Dist: nvidia-cutlass-dsl==4.2.0
  Requires-Dist: torch
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
{quack_kernels-0.1.11 → quack_kernels-0.2.1}/README.md
@@ -20,9 +20,8 @@ pip install quack-kernels
  - 🦆 Softmax forward + backward
  - 🦆 Cross entropy forward + backward
  - 🦆 Layernorm forward
-
- Upcoming:
- - 🦆 Rotary forward + backward
+ - 🦆 Hopper gemm + epilogue
+ - 🦆 Blackwell gemm + epilogue

  ## Usage

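For context alongside the "## Usage" section referenced above: the diff to quack/__init__.py below confirms that rmsnorm, softmax, and cross_entropy are exported at package level. A hypothetical usage sketch follows; only the import path is confirmed by this diff, while the argument shapes and call signature are assumptions for illustration, not the documented API:

# Hypothetical usage sketch: the exact signatures live in the respective modules.
import torch

from quack import softmax

x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
y = softmax(x)  # assumed: row-wise softmax over the last dimension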
{quack_kernels-0.1.11 → quack_kernels-0.2.1}/pyproject.toml
@@ -5,9 +5,9 @@ build-backend = "setuptools.build_meta"
  [project]
  name = "quack-kernels"
  dynamic = ["version"]
- requires-python = ">=3.12"
+ requires-python = ">=3.10"
  dependencies = [
-     "nvidia-cutlass-dsl==4.1.0",
+     "nvidia-cutlass-dsl==4.2.0",
      "torch",
  ]

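Version 0.2.1 relaxes the Python floor from 3.12 to 3.10 and pins nvidia-cutlass-dsl to 4.2.0 instead of 4.1.0. A minimal environment check mirroring these constraints (not part of the package, just a sketch using the standard library) could look like:

# Sanity check that the local environment matches the 0.2.1 requirements.
import sys
from importlib.metadata import version

assert sys.version_info >= (3, 10), "quack-kernels 0.2.1 requires Python >= 3.10"
assert version("nvidia-cutlass-dsl") == "4.2.0", "0.2.1 pins nvidia-cutlass-dsl to 4.2.0"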
{quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/__init__.py
@@ -1,11 +1,15 @@
- __version__ = "0.1.11"
+ __version__ = "0.2.1"
+
+ import cutlass.cute as cute

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
  from quack.cross_entropy import cross_entropy

- # ruff: noqa
- import quack.cute_dsl_utils  # Patch cute.compile to optionally dump SASS
+ import quack.cute_dsl_utils
+
+ # Patch cute.compile to optionally dump SASS
+ cute.compile = quack.cute_dsl_utils.cute_compile_patched

  __all__ = [
      "rmsnorm",
quack_kernels-0.2.1/quack/activation.py
@@ -0,0 +1,279 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ import math
+ from typing import Tuple
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Float32
+ from cutlass.cutlass_dsl import dsl_user_op
+
+
+ @dsl_user_op
+ def sigmoid(x: Float32, *, loc=None, ip=None) -> Float32:
+     return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
+
+
+ @dsl_user_op
+ def relu(x: Float32, *, loc=None, ip=None) -> Float32:
+     return cute.arch.fmax(x, Float32(0.0))
+
+
+ @cute.jit
+ @dsl_user_op
+ def drelu(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+     x_pos = cutlass.Boolean(x > 0)
+     return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
+
+
+ @dsl_user_op
+ def relu_sq(x: Float32, *, loc=None, ip=None) -> Float32:
+     return cute.arch.fmax(x, Float32(0.0)) * x
+
+
+ @cute.jit
+ @dsl_user_op
+ def drelu_sq(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+     """
+     ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
+     Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
+     Returns: (dx, relu_sq_out) where:
+     - dx = dout * 2 * x if x > 0, else 0
+     - relu_sq_out = max(x, 0) * x
+     """
+     x_pos = cutlass.Boolean(x > 0)
+     relu_sq_out = cute.arch.fmax(x, Float32(0.0)) * x
+     # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
+     dx = (2.0 * dout * x) if x_pos else Float32(0.0)
+     return dx, relu_sq_out
+
+
+ @dsl_user_op
+ def gelu_tanh_approx(x: Float32, *, loc=None, ip=None) -> Float32:
+     """
+     gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+             = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
+     """
+     sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
+     sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
+     return 0.5 * (
+         x
+         * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
+     )
+
+
+ @dsl_user_op
+ def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+     """
+     GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
+     Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
+     Returns: (dx, gelu_out)
+
+     Derivative uses the chain rule:
+     d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+     where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
+     and sech^2(z) = 1 - tanh^2(z)
+     """
+     sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
+     sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
+     sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.01070322
+
+     # Compute z = x * (c1 + c2 * x^2)
+     x_sq = x * x
+     tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
+     half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
+     gelu_out = x * half_tanh_z_plus_one
+
+     # Compute gradient
+     # sech^2(z) = 1 - tanh^2(z)
+     sech2_z = 1 - tanh_z * tanh_z
+     # dz/dx = c1 + 3 * c2 * x^2
+     dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
+     # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+     dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
+
+     dx = dout * dgelu
+     return dx, gelu_out
+
+
+ @dsl_user_op
+ def silu(x: Float32, *, loc=None, ip=None) -> Float32:
+     """
+     silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
+     This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
+     """
+     x_half = 0.5 * x
+     return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
+
+
+ @dsl_user_op
+ def swiglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+     return silu(x) * y
+
+
+ @dsl_user_op
+ def dswiglu(
+     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32, Float32]:
+     """
+     SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
+     Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)
+
+     d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+
+     This has been optimized to use fewer instructions (i.e. we expand things out
+     to use FFMA instead of FADD and FMUL).
+     """
+     # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
+     # FMUL, MUFU.TANH, then FFMA
+     sigmoid_x = sigmoid(x)
+     silu_x = x * sigmoid_x  # FMUL
+     silu_x_dout = silu_x * dout  # FMUL
+     # d_silu(x) * dout
+     # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
+     # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
+     # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
+     # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
+     # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+     d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
+     dx = d_silu_x_dout * y  # FMUL
+     dy = silu_x_dout
+     swiglu_out = silu_x * y  # FMUL
+     # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
+     return dx, dy, swiglu_out
+
+
+ @dsl_user_op
+ def swiglu_oai(x: Float32, y: Float32, alpha: float = 1.702, *, loc=None, ip=None) -> Float32:
+     """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
+     https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
+     x * sigmoid(alpha * x) * (y + 1)
+     Compile down to FMUL, FMUL, TANH, FFMA, FFMA
+     """
+     # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+     x_half = 0.5 * x
+     silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
+     return silu_x * y + silu_x
+
+
+ @dsl_user_op
+ def dswiglu_oai(
+     x: Float32, y: Float32, dout: Float32, alpha: float = 1.702, *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32, Float32]:
+     """
+     Swiglu OAI backward pass: computes gradients w.r.t. x and y
+     Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
+     Returns: (dx, dy, swiglu_oai_out)
+
+     Derivative of x * sigmoid(alpha * x) w.r.t. x:
+     d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
+     """
+     # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+     alpha_x_half = (0.5 * alpha) * x  # FMUL
+     # MUFU.TANH, then FFMA
+     sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
+     silu_x = x * sigmoid_alpha_x  # FMUL
+     silu_x_dout = silu_x * dout  # FMUL
+     # FFMA, FFMA, FMUL
+     d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+     dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
+     dy = silu_x_dout
+     swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
+     # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
+     return dx, dy, swiglu_out
+
+
+ @dsl_user_op
+ def glu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+     """GLU: Gated Linear Unit
+     glu(x, y) = sigmoid(x) * y
+     Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+     """
+     sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+     return sigmoid_x * y  # FMUL
+
+
+ @dsl_user_op
+ def dglu(
+     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32, Float32]:
+     """
+     GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
+     Returns: (dx, dy, glu_out) where:
+     - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
+     - dy = dout * sigmoid(x)
+     - glu_out = sigmoid(x) * y
+     """
+     # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+     sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+     sigmoid_x_dout = sigmoid_x * dout  # FMUL
+     glu_out = sigmoid_x * y  # FMUL
+     # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
+     # = y * (1 - sigmoid(x)) * sigmoid_x_dout
+     # = (y - y * sigmoid(x)) * sigmoid_x_dout
+     # = (y - glu_out) * sigmoid_x_dout
+     dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
+     dy = sigmoid_x_dout
+     # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
+     return dx, dy, glu_out
+
+
+ @dsl_user_op
+ def reglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+     """ReGLU: ReLU Gated Linear Unit
+     reglu(x, y) = relu(x) * y = max(x, 0) * y
+     """
+     return cute.arch.fmax(x, Float32(0.0)) * y
+
+
+ @cute.jit
+ @dsl_user_op
+ def dreglu(
+     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32, Float32]:
+     """
+     ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
+     Returns: (dx, dy, reglu_out) where:
+     - dx = dout * y if x > 0, else 0
+     - dy = dout * relu(x)
+     - reglu_out = relu(x) * y
+     """
+     x_pos = cutlass.Boolean(x > 0)
+     relu_x = cute.arch.fmax(x, Float32(0.0))
+     dx = (dout * y) if x_pos else Float32(0.0)
+     dy = dout * relu_x
+     reglu_out = relu_x * y
+     return dx, dy, reglu_out
+
+
+ @dsl_user_op
+ def geglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+     """GeGLU: GELU Gated Linear Unit
+     geglu(x, y) = gelu(x) * y
+     Uses the tanh approximation of GELU
+     """
+     return gelu_tanh_approx(x) * y
+
+
+ @dsl_user_op
+ def dgeglu(
+     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
+ ) -> Tuple[Float32, Float32, Float32]:
+     """
+     GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
+     Returns: (dx, dy, geglu_out) where:
+     - dx = dout * y * d_gelu(x)
+     - dy = dout * gelu(x)
+     - geglu_out = gelu(x) * y
+     """
+     # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+     dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+     # Compute gradients for geglu
+     dx = dgelu_x_dout * y
+     dy = gelu_x * dout
+     geglu_out = gelu_x * y
+     return dx, dy, geglu_out
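The backward helpers in the new activation.py carry their own derivations, for example d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) in dswiglu and the chain-rule expression in dgelu_tanh_approx. These identities can be sanity-checked outside the CuTe DSL with plain Python floats and central finite differences; the snippet below is a standalone check, not part of the package:

# Numerical check of the derivative algebra used in dswiglu and dgelu_tanh_approx.
import math

def sigmoid(x):
    return 0.5 + 0.5 * math.tanh(0.5 * x)

def silu(x):
    return x * sigmoid(x)

def d_silu(x):
    # Formula from the dswiglu docstring
    return sigmoid(x) * (1 + x * (1 - sigmoid(x)))

def gelu_tanh(x):
    c1 = math.sqrt(2 / math.pi)
    c2 = 0.044715 * c1
    return 0.5 * x * (1 + math.tanh(x * (c1 + c2 * x * x)))

def d_gelu_tanh(x):
    # Chain-rule expression from the dgelu_tanh_approx docstring
    c1 = math.sqrt(2 / math.pi)
    c2 = 0.044715 * c1
    t = math.tanh(x * (c1 + c2 * x * x))
    return 0.5 * (1 + t) + 0.5 * x * (1 - t * t) * (c1 + 3 * c2 * x * x)

eps = 1e-5
for x in (-2.0, -0.3, 0.7, 1.9):
    fd_silu = (silu(x + eps) - silu(x - eps)) / (2 * eps)
    fd_gelu = (gelu_tanh(x + eps) - gelu_tanh(x - eps)) / (2 * eps)
    assert abs(fd_silu - d_silu(x)) < 1e-4
    assert abs(fd_gelu - d_gelu_tanh(x)) < 1e-4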
{quack_kernels-0.1.11 → quack_kernels-0.2.1}/quack/autotuner.py
@@ -187,7 +187,8 @@ class Autotuner:
  if len(self.configs) > 1:
      all_args = {**self.nargs, **kwargs}
      _args = {k: v for (k, v) in all_args.items() if k in self.arg_names}
-     key = [_args[key] for key in self.keys if key in _args]
+     # Need "str" to make it json-serializable
+     key = [str(_args[key]) for key in self.keys if key in _args]
      for _, arg in _args.items():
          if isinstance(arg, Tensor):
              key.append(str(arg.shape))
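The autotuner change stringifies the key entries because raw values such as torch.dtype are not JSON-serializable, while tensor shapes are already appended via str(arg.shape). A standalone illustration (not part of the package):

# Why str() is needed for the cache key.
import json

import torch

raw_key = [torch.float16, torch.Size([4096, 1024])]
try:
    json.dumps(raw_key)
except TypeError:
    pass  # torch.dtype is not JSON-serializable

print(json.dumps([str(k) for k in raw_key]))
# ["torch.float16", "torch.Size([4096, 1024])"]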