quack-kernels 0.2.1__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. {quack_kernels-0.2.1/quack_kernels.egg-info → quack_kernels-0.2.3}/PKG-INFO +4 -2
  2. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/pyproject.toml +3 -1
  3. quack_kernels-0.2.3/quack/__init__.py +11 -0
  4. quack_kernels-0.2.3/quack/activation.py +524 -0
  5. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/autotuner.py +64 -5
  6. quack_kernels-0.2.3/quack/broadcast_utils.py +29 -0
  7. quack_kernels-0.2.3/quack/compile_utils.py +19 -0
  8. quack_kernels-0.2.3/quack/copy_utils.py +487 -0
  9. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/cross_entropy.py +157 -233
  10. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/cute_dsl_utils.py +20 -35
  11. quack_kernels-0.2.3/quack/gemm.py +194 -0
  12. quack_kernels-0.2.3/quack/gemm_act.py +510 -0
  13. quack_kernels-0.2.3/quack/gemm_config.py +95 -0
  14. quack_kernels-0.2.3/quack/gemm_dact.py +215 -0
  15. quack_kernels-0.2.3/quack/gemm_default_epi.py +259 -0
  16. quack_kernels-0.2.3/quack/gemm_interface.py +1038 -0
  17. quack_kernels-0.2.1/quack/dense_gemm_sm100.py → quack_kernels-0.2.3/quack/gemm_sm100.py +1034 -787
  18. quack_kernels-0.2.1/quack/dense_gemm_sm90.py → quack_kernels-0.2.3/quack/gemm_sm90.py +552 -727
  19. quack_kernels-0.2.3/quack/gemm_symmetric.py +330 -0
  20. quack_kernels-0.2.3/quack/gemm_wrapper_utils.py +317 -0
  21. quack_kernels-0.2.3/quack/layout_utils.py +287 -0
  22. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/linear.py +24 -16
  23. quack_kernels-0.2.3/quack/pipeline.py +306 -0
  24. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/reduce.py +88 -49
  25. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/reduction_base.py +25 -36
  26. quack_kernels-0.2.3/quack/rmsnorm.py +1134 -0
  27. quack_kernels-0.2.3/quack/sm100_utils.py +62 -0
  28. quack_kernels-0.2.3/quack/sm90_utils.py +127 -0
  29. quack_kernels-0.2.3/quack/softmax.py +403 -0
  30. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/sort/bitonic_sort.py +13 -10
  31. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/sort/utils.py +6 -6
  32. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/tile_scheduler.py +55 -61
  33. quack_kernels-0.2.3/quack/topk.py +551 -0
  34. quack_kernels-0.2.3/quack/utils.py +223 -0
  35. quack_kernels-0.2.3/quack/varlen_utils.py +386 -0
  36. {quack_kernels-0.2.1 → quack_kernels-0.2.3/quack_kernels.egg-info}/PKG-INFO +4 -2
  37. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack_kernels.egg-info/SOURCES.txt +16 -7
  38. quack_kernels-0.2.3/quack_kernels.egg-info/requires.txt +8 -0
  39. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_layernorm.py +17 -51
  40. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_linear.py +43 -18
  41. quack_kernels-0.2.3/tests/test_linear_varlen_k.py +312 -0
  42. quack_kernels-0.2.3/tests/test_linear_varlen_m.py +395 -0
  43. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_rmsnorm.py +26 -17
  44. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_softmax.py +1 -2
  45. quack_kernels-0.2.1/tests/test_symmetric_dense_gemm_sm90.py → quack_kernels-0.2.3/tests/test_symmetric_gemm.py +15 -16
  46. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_topk.py +42 -23
  47. quack_kernels-0.2.1/quack/__init__.py +0 -18
  48. quack_kernels-0.2.1/quack/activation.py +0 -279
  49. quack_kernels-0.2.1/quack/gemm_act_sm90.py +0 -368
  50. quack_kernels-0.2.1/quack/gemm_config.py +0 -69
  51. quack_kernels-0.2.1/quack/gemm_dact_sm90.py +0 -150
  52. quack_kernels-0.2.1/quack/gemm_interface.py +0 -569
  53. quack_kernels-0.2.1/quack/gemm_wrapper_utils.py +0 -158
  54. quack_kernels-0.2.1/quack/layernorm.py +0 -353
  55. quack_kernels-0.2.1/quack/pipeline.py +0 -151
  56. quack_kernels-0.2.1/quack/rmsnorm.py +0 -1250
  57. quack_kernels-0.2.1/quack/softmax.py +0 -471
  58. quack_kernels-0.2.1/quack/symmetric_dense_gemm_sm90.py +0 -2091
  59. quack_kernels-0.2.1/quack/topk.py +0 -227
  60. quack_kernels-0.2.1/quack/utils.py +0 -358
  61. quack_kernels-0.2.1/quack/varlen_utils.py +0 -22
  62. quack_kernels-0.2.1/quack_kernels.egg-info/requires.txt +0 -6
  63. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/LICENSE +0 -0
  64. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/README.md +0 -0
  65. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/fast_math.py +0 -0
  66. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/linear_cross_entropy.py +0 -0
  67. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/mlp.py +0 -0
  68. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/sort/generate_sorting_networks.py +0 -0
  69. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/sort/sorting_networks.py +0 -0
  70. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/tensormap_manager.py +0 -0
  71. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack_kernels.egg-info/dependency_links.txt +0 -0
  72. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack_kernels.egg-info/top_level.txt +0 -0
  73. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/setup.cfg +0 -0
  74. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_cross_entropy.py +1 -1
  75. {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_linear_cross_entropy.py +0 -0
@@ -1,10 +1,12 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.2.1
+ Version: 0.2.3
  Requires-Python: >=3.10
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.2.0
+ Requires-Dist: nvidia-cutlass-dsl==4.3.3
  Requires-Dist: torch
+ Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5
+ Requires-Dist: torch-c-dlpack-ext
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
  Requires-Dist: ruff; extra == "dev"
@@ -7,8 +7,10 @@ name = "quack-kernels"
  dynamic = ["version"]
  requires-python = ">=3.10"
  dependencies = [
-     "nvidia-cutlass-dsl==4.2.0",
+     "nvidia-cutlass-dsl==4.3.3",
      "torch",
+     "apache-tvm-ffi>=0.1.5,<0.2",
+     "torch-c-dlpack-ext",
  ]
 
  [project.optional-dependencies]
@@ -0,0 +1,11 @@
+ __version__ = "0.2.3"
+
+ from quack.rmsnorm import rmsnorm
+ from quack.softmax import softmax
+ from quack.cross_entropy import cross_entropy
+
+ __all__ = [
+     "rmsnorm",
+     "softmax",
+     "cross_entropy",
+ ]
@@ -0,0 +1,524 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ import math
+ from typing import Tuple
+
+ import cutlass.cute as cute
+ from cutlass import Float32, Boolean, const_expr
+ from cutlass.cutlass_dsl import T, dsl_user_op
+ from cutlass._mlir.dialects import llvm
+
+ import quack.utils as utils
+
+
+ F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
+
+
+ @dsl_user_op
+ def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "tanh.approx.f32 $0, $1;",
+             "=f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @dsl_user_op
+ def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
+         return 0.5 + 0.5 * tanh(0.5 * x)
+     else:
+         x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+         tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
+         return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
+
+
+ @dsl_user_op
+ def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
+     # return dout * out * (1.0 - out)
+     return dout * (out - out * out)
+
+
+ @dsl_user_op
+ def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         return cute.arch.fmax(x, Float32(0.0))
+     else:
+         return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))
+
+
+ @dsl_user_op
+ @cute.jit
+ def drelu(
+     x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+     if const_expr(not isinstance(x, tuple)):
+         x_pos = Boolean(x > 0)
+         return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
+     else:
+         x0_pos = Boolean(x[0] > 0)
+         x1_pos = Boolean(x[1] > 0)
+         dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
+         return dx, relu(x)
+
+
+ @dsl_user_op
+ def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         return cute.arch.fmax(x, Float32(0.0)) * x
+     else:
+         relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
+         return utils.mul_packed_f32x2(relu_x, x)
+
+
+ @dsl_user_op
+ @cute.jit
+ def drelu_sq(
+     x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+     """
+     ReLU squared backward pass: computes the gradient w.r.t. x and recomputes the forward output.
+     Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
+     Returns: (dx, relu_sq_out) where:
+     - dx = dout * 2 * x if x > 0, else 0
+     - relu_sq_out = max(x, 0) * x
+     """
+     if const_expr(not isinstance(x, tuple)):
+         relu_x = relu(x)
+         relu_sq_out = relu_x * x
+         # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
+         dx = 2.0 * (dout * relu_x)
+         return dx, relu_sq_out
+     else:
+         relu_x = relu(x)
+         relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
+         dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
+         return dx, relu_sq_out
+
+
+ @dsl_user_op
+ def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     """
+     gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+             = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
+     """
+     sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
+     sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
+     if const_expr(not isinstance(x, tuple)):
+         return 0.5 * (
+             x
+             # Currently cute.math.tanh(x, fastmath=True) generates very slow code
+             # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
+             * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
+         )
+     else:
+         x_sq = utils.mul_packed_f32x2(x, x)
+         x_sq_scaled = utils.fma_packed_f32x2(
+             x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
+         )
+         z = utils.mul_packed_f32x2(x, x_sq_scaled)
+         tanh_z = (tanh(z[0]), tanh(z[1]))
+         x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
+         return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
+
+
+ @dsl_user_op
+ def dgelu_tanh_approx(
+     x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+     """
+     GELU tanh approximation backward pass: computes the gradient w.r.t. x and recomputes the forward output.
+     Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
+     Returns: (dx, gelu_out)
+
+     Derivative uses the chain rule:
+     d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+     where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
+     and sech^2(z) = 1 - tanh^2(z)
+     """
+     sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
+     sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
+     sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.1070322
+
+     if const_expr(not isinstance(x, tuple)):
+         # Compute z = x * (c1 + c2 * x^2)
+         x_sq = x * x
+         # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
+         tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
+         half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
+         gelu_out = x * half_tanh_z_plus_one
+
+         # Compute gradient
+         # sech^2(z) = 1 - tanh^2(z)
+         sech2_z = 1 - tanh_z * tanh_z
+         # dz/dx = c1 + 3 * c2 * x^2
+         dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
+         # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+         dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
+
+         dx = dout * dgelu
+         return dx, gelu_out
+     else:
+         # Compute z = x * (c1 + c2 * x^2)
+         x_sq = utils.mul_packed_f32x2(x, x)
+         x_sq_scaled = utils.fma_packed_f32x2(
+             x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
+         )
+         z = utils.mul_packed_f32x2(x, x_sq_scaled)
+         tanh_z = (tanh(z[0]), tanh(z[1]))
+         half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
+         gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)
+
+         # Compute gradient
+         # sech^2(z) = 1 - tanh^2(z)
+         sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
+         # dz/dx = c1 + 3 * c2 * x^2
+         dz_dx = utils.fma_packed_f32x2(
+             x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
+         )
+         # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+         sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
+         x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
+         dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
+
+         dx = utils.mul_packed_f32x2(dout, dgelu)
+         return dx, gelu_out
+
+
+ @dsl_user_op
+ @cute.jit
+ def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         use_linear = Boolean(x > 20.0)
+         return (
+             cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
+             if not use_linear
+             else x
+         )
+     else:
+         log2_e = math.log2(math.e)
+         x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
+         x_exp = (cute.math.exp2(x_log2e[0], fastmath=True), cute.math.exp2(x_log2e[1], fastmath=True))
+         x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
+         log_x_exp_p1 = (
+             cute.math.log2(x_exp_p1[0], fastmath=True),
+             cute.math.log2(x_exp_p1[1], fastmath=True),
+         )
+         ln2 = math.log(2.0)
+         softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
+         use_linear_0 = Boolean(x[0] > 20.0)
+         use_linear_1 = Boolean(x[1] > 20.0)
+         return (
+             softplus_x[0] if not use_linear_0 else x[0],
+             softplus_x[1] if not use_linear_1 else x[1],
+         )
+
+
+ @dsl_user_op
+ @cute.jit
+ def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
+     use_linear = Boolean(out > 20.0)
+     # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
+     dx = dout - dout * cute.math.exp(-out, fastmath=True)
+     return dx if not use_linear else dout
+
+
+ @dsl_user_op
+ def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
+     """
+     silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
+     This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
+     """
+     if const_expr(not isinstance(x, tuple)):
+         x_half = 0.5 * x if const_expr(not already_halved) else x
+         # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
+         return x_half * tanh(x_half) + x_half
+     else:
+         x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
+         tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
+         return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)
+
+
+ @dsl_user_op
+ def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         return silu(x) * y
+     else:
+         return utils.mul_packed_f32x2(silu(x), y)
+
+
+ @dsl_user_op
+ def dswiglu(
+     x: F32_or_F32x2,
+     y: F32_or_F32x2,
+     dout: F32_or_F32x2,
+     *,
+     already_halved: bool = False,
+     loc=None,
+     ip=None,
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+     """
+     SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection).
+     Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
+     Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)
+
+     d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+
+     This has been optimized to use fewer instructions (i.e. we expand things out
+     to use FFMA instead of FADD and FMUL).
+     """
+     if const_expr(not isinstance(x, tuple)):
+         # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
+         # FMUL, MUFU.TANH, then FFMA
+         if const_expr(not already_halved):
+             sigmoid_x = sigmoid(x)
+             silu_x = x * sigmoid_x  # FMUL
+         else:
+             tanh_x = tanh(x)  # MUFU.TANH
+             sigmoid_x = 0.5 * tanh_x + 0.5  # FFMA
+             silu_x = x * tanh_x + x  # FFMA
+         silu_x_dout = silu_x * dout  # FMUL
+         # d_silu(x) * dout
+         # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
+         # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
+         # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
+         # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
+         # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+         d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
+         dx = d_silu_x_dout * y  # FMUL
+         dy = silu_x_dout
+         swiglu_out = silu_x * y  # FMUL
+         # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
+         return dx, dy, swiglu_out
+     else:
+         # Compute sigmoid(x) and silu(x)
+         if const_expr(not already_halved):
+             sigmoid_x = sigmoid(x)
+             silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
+         else:
+             tanh_x = (tanh(x[0]), tanh(x[1]))
+             sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
+             silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
+         silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+         # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+         sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
+             sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
+         )
+         d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
+         dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
+         dy = silu_x_dout
+         swiglu_out = utils.mul_packed_f32x2(silu_x, y)
+         return dx, dy, swiglu_out
+
+
+ @dsl_user_op
+ def swiglu_oai(
+     x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
+ ) -> F32_or_F32x2:
+     """The swiglu variant used in gpt-oss, which applies a scaling factor to x and adds a bias of 1 to y.
+     https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
+     x * sigmoid(alpha * x) * (y + 1)
+     Compiles down to FMUL, FMUL, TANH, FFMA, FFMA
+     """
+     # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+     if const_expr(not isinstance(x, tuple)):
+         x_half = 0.5 * x
+         # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
+         silu_x = x_half * tanh(alpha * x_half) + x_half
+         return silu_x * y + silu_x
+     else:
+         x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+         alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
+         tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
+         silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
+         return utils.fma_packed_f32x2(silu_x, y, silu_x)
+
+
+ @dsl_user_op
+ def dswiglu_oai(
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+     """
+     Swiglu OAI backward pass: computes gradients w.r.t. x and y.
+     Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
+     Returns: (dx, dy, swiglu_oai_out)
+
+     Derivative of x * sigmoid(alpha * x) w.r.t. x:
+     d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
+     """
+     if const_expr(not isinstance(x, tuple)):
+         # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+         alpha_x_half = (0.5 * alpha) * x  # FMUL
+         # MUFU.TANH, then FFMA
+         # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
+         sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
+         silu_x = x * sigmoid_alpha_x  # FMUL
+         silu_x_dout = silu_x * dout  # FMUL
+         # FFMA, FFMA, FMUL
+         d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+         dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiplying by y + 1
+         dy = silu_x_dout
+         swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiplying by y + 1
+         # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
+         return dx, dy, swiglu_out
+     else:
+         # Compute sigmoid(alpha * x)
+         alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
+         tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
+         sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
+         silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
+         silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+         # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+         silu_x_minus_product = utils.fma_packed_f32x2(
+             silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
+         )
+         sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
+             (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
+         )
+         d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
+         dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
+         dy = silu_x_dout
+         swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
+         return dx, dy, swiglu_out
+
+
+ @dsl_user_op
+ def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     """GLU: Gated Linear Unit
+     glu(x, y) = sigmoid(x) * y
+     Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+     """
+     if const_expr(not isinstance(x, tuple)):
+         sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+         return sigmoid_x * y  # FMUL
+     else:
+         sigmoid_x = sigmoid(x)
+         return utils.mul_packed_f32x2(sigmoid_x, y)
+
+
+ @dsl_user_op
+ def dglu(
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+     """
+     GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection).
+     Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
+     Returns: (dx, dy, glu_out) where:
+     - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
+     - dy = dout * sigmoid(x)
+     - glu_out = sigmoid(x) * y
+     """
+     if const_expr(not isinstance(x, tuple)):
+         # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+         sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+         sigmoid_x_dout = sigmoid_x * dout  # FMUL
+         glu_out = sigmoid_x * y  # FMUL
+         # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
+         #    = y * (1 - sigmoid(x)) * sigmoid_x_dout
+         #    = (y - y * sigmoid(x)) * sigmoid_x_dout
+         #    = (y - glu_out) * sigmoid_x_dout
+         dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
+         dy = sigmoid_x_dout
+         # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
+         return dx, dy, glu_out
+     else:
+         sigmoid_x = sigmoid(x)
+         sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
+         glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
+         # dx = (y - glu_out) * sigmoid_x_dout
+         y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
+         dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
+         dy = sigmoid_x_dout
+         return dx, dy, glu_out
+
+
+ @dsl_user_op
+ def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     """ReGLU: ReLU Gated Linear Unit
+     reglu(x, y) = relu(x) * y = max(x, 0) * y
+     """
+     if const_expr(not isinstance(x, tuple)):
+         return cute.arch.fmax(x, Float32(0.0)) * y
+     else:
+         relu_x = relu(x)
+         return utils.mul_packed_f32x2(relu_x, y)
+
+
+ @dsl_user_op
+ @cute.jit
+ def dreglu(
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+     """
+     ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection).
+     Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
+     Returns: (dx, dy, reglu_out) where:
+     - dx = dout * y if x > 0, else 0
+     - dy = dout * relu(x)
+     - reglu_out = relu(x) * y
+     """
+     if const_expr(not isinstance(x, tuple)):
+         x_pos = Boolean(x > 0)
+         relu_x = cute.arch.fmax(x, Float32(0.0))
+         dx = (dout * y) if x_pos else Float32(0.0)
+         dy = dout * relu_x
+         reglu_out = relu_x * y
+         return dx, dy, reglu_out
+     else:
+         x0_pos = Boolean(x[0] > 0)
+         x1_pos = Boolean(x[1] > 0)
+         relu_x = relu(x)
+         dout_y = utils.mul_packed_f32x2(dout, y)
+         dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
+         dy = utils.mul_packed_f32x2(dout, relu_x)
+         reglu_out = utils.mul_packed_f32x2(relu_x, y)
+         return dx, dy, reglu_out
+
+
+ @dsl_user_op
+ def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     """GeGLU: GELU Gated Linear Unit
+     geglu(x, y) = gelu(x) * y
+     Uses the tanh approximation of GELU
+     """
+     if const_expr(not isinstance(x, tuple)):
+         return gelu_tanh_approx(x) * y
+     else:
+         return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)
+
+
+ @dsl_user_op
+ def dgeglu(
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+     """
+     GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection).
+     Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
+     Returns: (dx, dy, geglu_out) where:
+     - dx = dout * y * d_gelu(x)
+     - dy = dout * gelu(x)
+     - geglu_out = gelu(x) * y
+     """
+     if const_expr(not isinstance(x, tuple)):
+         # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+         dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+         # Compute gradients for geglu
+         dx = dgelu_x_dout * y
+         dy = gelu_x * dout
+         geglu_out = gelu_x * y
+         return dx, dy, geglu_out
+     else:
+         # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+         dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+         # Compute gradients for geglu
+         dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
+         dy = utils.mul_packed_f32x2(gelu_x, dout)
+         geglu_out = utils.mul_packed_f32x2(gelu_x, y)
+         return dx, dy, geglu_out
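
A quick host-side sanity check (plain Python, not part of the package) of two identities these kernels lean on: silu(x) = x * sigmoid(x) equals the tanh form (0.5 * x) * tanh(0.5 * x) + 0.5 * x, and the FFMA-friendly regrouping of the SwiGLU gradient from the dswiglu docstring matches the textbook d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))):

import math

def sigmoid_ref(x):
    # Reference sigmoid: 1 / (1 + e^-x)
    return 1.0 / (1.0 + math.exp(-x))

def silu_tanh(x):
    # The kernel's formulation: (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
    x_half = 0.5 * x
    return x_half * math.tanh(x_half) + x_half

def dswiglu_ref(x, y, dout):
    # Textbook gradients: dx = dout * y * d_silu(x), dy = dout * silu(x)
    s = sigmoid_ref(x)
    d_silu = s * (1.0 + x * (1.0 - s))
    return dout * y * d_silu, dout * (x * s)

def dswiglu_fused(x, y, dout):
    # The regrouped form from the docstring:
    # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
    s = sigmoid_ref(x)
    silu_x = x * s
    silu_x_dout = silu_x * dout
    d_silu_x_dout = (s - silu_x * s) * dout + silu_x_dout
    return d_silu_x_dout * y, silu_x_dout

for x in (-3.0, -0.5, 0.0, 0.7, 4.0):
    assert abs(silu_tanh(x) - x * sigmoid_ref(x)) < 1e-12
    dx_ref, dy_ref = dswiglu_ref(x, 2.0, 0.3)
    dx_fused, dy_fused = dswiglu_fused(x, 2.0, 0.3)
    assert abs(dx_ref - dx_fused) < 1e-12 and abs(dy_ref - dy_fused) < 1e-12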
@@ -11,7 +11,7 @@ import hashlib
  import json
  from pathlib import Path
  from functools import cached_property, partial
- from typing import Dict, Tuple
+ from typing import Dict, Tuple, List, Optional, Any
 
  import torch
  from torch import Tensor
@@ -53,7 +53,22 @@ def _base32(key):
 
 
  class Autotuner:
-     def __init__(self, fn, key, configs, restore_value=None, do_bench=None, cache_results=False):
+     def __init__(
+         self,
+         fn,
+         key,
+         configs,
+         restore_value=None,
+         prune_configs_by: Optional[Dict] = None,
+         do_bench=None,
+         cache_results=False,
+     ):
+         """
+         :param prune_configs_by: a dict of functions used to prune configs, with fields:
+             'perf_model': a performance model used to predict the running time of different configs; returns a running time
+             'top_k': the number of configs to benchmark
+             'early_config_prune' (optional): a function for early pruning (e.g. by num_stages); it takes configs: List[Config] as input and returns the pruned configs
+         """
          if not configs:
              self.configs = [AutotuneConfig()]
          else:
@@ -90,6 +105,16 @@ class Autotuner:
          else:
              self.post_hook = None
 
+         self.perf_model = None
+         self.configs_top_k = 1.0
+         self.early_config_prune = None
+         if prune_configs_by:
+             self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
+             self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
+             self.early_config_prune = prune_configs_by.get(
+                 "early_config_prune", self.early_config_prune
+             )
+
          self.fn = fn
          self._do_bench = do_bench
 
@@ -198,13 +223,14 @@
          key = tuple(key)
          if key not in self.cache:
              used_cached_result = False
+             pruned_configs = self.prune_configs(kwargs)
 
              @torch.compiler.disable  # Don't want any tracing here
              def benchmark():
                  bench_start = time.time()
                  timings = {
                      config: self._bench(*args, config=config, **kwargs)
-                     for config in self.configs
+                     for config in pruned_configs
                  }
                  bench_end = time.time()
                  if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
@@ -215,7 +241,7 @@
              self.configs_timings = timings
 
              if self.cache_results:
-                 self.check_disk_cache(key, self.configs, benchmark)
+                 self.check_disk_cache(key, pruned_configs, benchmark)
              else:
                  benchmark()
 
@@ -239,6 +265,32 @@
          self.nargs = None
          return ret
 
+     def prune_configs(self, kwargs: Dict) -> List[Any]:
+         pruned_configs = self.configs
+         if self.early_config_prune:
+             pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+         if self.perf_model:
+             top_k = self.configs_top_k
+             if isinstance(top_k, float) and top_k <= 1.0:
+                 top_k = int(len(self.configs) * top_k)
+             elif not isinstance(top_k, int):
+                 # Slice index must be an integer
+                 raise TypeError(
+                     "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
+                 )
+
+             if len(pruned_configs) > top_k:
+                 est_timing = {
+                     config: self.perf_model(
+                         **self.nargs,
+                         **kwargs,
+                         **config.all_kwargs(),
+                     )
+                     for config in pruned_configs
+                 }
+                 pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+         return pruned_configs
+
 
  class AutotuneConfig:
      """
@@ -272,7 +324,9 @@
          return self_tuple == other_tuple
 
 
- def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results=True):
+ def autotune(
+     configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True
+ ):
      f"""
      Decorator for auto-tuning a function.
 
@@ -286,6 +340,10 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
      :type configs: list[AutotuneConfig]
      :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
      :type key: list[str]
+     :param prune_configs_by: a dict of functions used to prune configs, with fields:
+         'perf_model': a performance model used to predict the running time of different configs; returns a running time
+         'top_k': the number of configs to benchmark
+         'early_config_prune' (optional): a function for early pruning (e.g. by num_stages); it takes configs: List[Config] as input and returns the pruned configs
      :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
      :type restore_value: list[str]
      :param do_bench: a benchmark function to measure the time of each run.
@@ -303,6 +361,7 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
          key,
          configs,
          restore_value=restore_value,
+         prune_configs_by=prune_configs_by,
          do_bench=do_bench,
          cache_results=cache_results,
      )
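
As a usage sketch (not part of the diff): only the autotune(...) signature, the prune_configs_by fields, and the call conventions perf_model(**nargs, **kwargs, **config.all_kwargs()) and early_config_prune(configs, nargs, **kwargs) are taken from the hunks above; the kernel and its parameter names ("M", "block", "num_stages") are hypothetical.

from quack.autotuner import autotune, AutotuneConfig

def perf_model(M=None, block=None, num_stages=None, **other):
    # Toy cost model: returns a predicted running time (lower is better).
    return M / block + 0.1 * num_stages

def early_config_prune(configs, nargs, **kwargs):
    # Drop configs whose tile is larger than the problem before any benchmarking.
    return [c for c in configs if c.all_kwargs().get("block", 0) <= nargs["M"]]

configs = []  # a list of AutotuneConfig objects; construction is not shown in this diff

@autotune(
    configs,
    key=["M"],
    prune_configs_by={
        "perf_model": perf_model,  # rank surviving configs by predicted time
        "top_k": 2,  # benchmark only the two predicted-fastest configs
        "early_config_prune": early_config_prune,
    },
)
def my_kernel(M, block=128, num_stages=2):
    ...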