quack-kernels 0.2.2.tar.gz → 0.2.4.tar.gz

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (70)
  1. {quack_kernels-0.2.2/quack_kernels.egg-info → quack_kernels-0.2.4}/PKG-INFO +4 -2
  2. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/pyproject.toml +3 -1
  3. quack_kernels-0.2.4/quack/__init__.py +11 -0
  4. quack_kernels-0.2.4/quack/activation.py +524 -0
  5. quack_kernels-0.2.4/quack/broadcast_utils.py +29 -0
  6. quack_kernels-0.2.4/quack/compile_utils.py +19 -0
  7. quack_kernels-0.2.4/quack/copy_utils.py +487 -0
  8. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/cross_entropy.py +157 -233
  9. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/cute_dsl_utils.py +20 -34
  10. quack_kernels-0.2.4/quack/gemm.py +194 -0
  11. quack_kernels-0.2.2/quack/gemm_act_sm90.py → quack_kernels-0.2.4/quack/gemm_act.py +218 -117
  12. quack_kernels-0.2.4/quack/gemm_config.py +95 -0
  13. quack_kernels-0.2.2/quack/gemm_dact_sm90.py → quack_kernels-0.2.4/quack/gemm_dact.py +53 -21
  14. quack_kernels-0.2.4/quack/gemm_default_epi.py +259 -0
  15. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/gemm_interface.py +177 -31
  16. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/gemm_sm100.py +729 -506
  17. quack_kernels-0.2.2/quack/dense_gemm_sm90.py → quack_kernels-0.2.4/quack/gemm_sm90.py +344 -814
  18. quack_kernels-0.2.4/quack/gemm_symmetric.py +330 -0
  19. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/gemm_wrapper_utils.py +3 -1
  20. quack_kernels-0.2.4/quack/layout_utils.py +287 -0
  21. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/linear.py +24 -16
  22. quack_kernels-0.2.4/quack/pipeline.py +306 -0
  23. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/reduce.py +88 -49
  24. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/reduction_base.py +25 -36
  25. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/rmsnorm.py +476 -526
  26. quack_kernels-0.2.4/quack/sm100_utils.py +62 -0
  27. quack_kernels-0.2.4/quack/sm90_utils.py +127 -0
  28. quack_kernels-0.2.4/quack/softmax.py +403 -0
  29. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/sort/bitonic_sort.py +13 -10
  30. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/sort/utils.py +6 -6
  31. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/tile_scheduler.py +23 -16
  32. quack_kernels-0.2.4/quack/topk.py +551 -0
  33. quack_kernels-0.2.4/quack/utils.py +223 -0
  34. quack_kernels-0.2.4/quack/varlen_utils.py +386 -0
  35. {quack_kernels-0.2.2 → quack_kernels-0.2.4/quack_kernels.egg-info}/PKG-INFO +4 -2
  36. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack_kernels.egg-info/SOURCES.txt +13 -6
  37. quack_kernels-0.2.4/quack_kernels.egg-info/requires.txt +8 -0
  38. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_layernorm.py +17 -51
  39. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_linear.py +37 -17
  40. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_linear_varlen_k.py +49 -3
  41. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_linear_varlen_m.py +43 -24
  42. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_rmsnorm.py +26 -17
  43. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_softmax.py +1 -2
  44. quack_kernels-0.2.2/tests/test_symmetric_dense_gemm_sm90.py → quack_kernels-0.2.4/tests/test_symmetric_gemm.py +15 -16
  45. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_topk.py +42 -23
  46. quack_kernels-0.2.2/quack/__init__.py +0 -18
  47. quack_kernels-0.2.2/quack/activation.py +0 -279
  48. quack_kernels-0.2.2/quack/gemm_config.py +0 -69
  49. quack_kernels-0.2.2/quack/layernorm.py +0 -353
  50. quack_kernels-0.2.2/quack/pipeline.py +0 -151
  51. quack_kernels-0.2.2/quack/softmax.py +0 -471
  52. quack_kernels-0.2.2/quack/symmetric_dense_gemm_sm90.py +0 -2091
  53. quack_kernels-0.2.2/quack/topk.py +0 -227
  54. quack_kernels-0.2.2/quack/utils.py +0 -411
  55. quack_kernels-0.2.2/quack/varlen_utils.py +0 -17
  56. quack_kernels-0.2.2/quack_kernels.egg-info/requires.txt +0 -6
  57. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/LICENSE +0 -0
  58. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/README.md +0 -0
  59. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/autotuner.py +0 -0
  60. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/fast_math.py +0 -0
  61. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/linear_cross_entropy.py +0 -0
  62. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/mlp.py +0 -0
  63. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/sort/generate_sorting_networks.py +0 -0
  64. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/sort/sorting_networks.py +0 -0
  65. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack/tensormap_manager.py +0 -0
  66. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack_kernels.egg-info/dependency_links.txt +0 -0
  67. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/quack_kernels.egg-info/top_level.txt +0 -0
  68. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/setup.cfg +0 -0
  69. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_cross_entropy.py +1 -1
  70. {quack_kernels-0.2.2 → quack_kernels-0.2.4}/tests/test_linear_cross_entropy.py +0 -0
@@ -1,10 +1,12 @@
  Metadata-Version: 2.4
  Name: quack-kernels
- Version: 0.2.2
+ Version: 0.2.4
  Requires-Python: >=3.10
  License-File: LICENSE
- Requires-Dist: nvidia-cutlass-dsl==4.2.1
+ Requires-Dist: nvidia-cutlass-dsl<4.4.0,>=4.3.4
  Requires-Dist: torch
+ Requires-Dist: apache-tvm-ffi<0.2,>=0.1.6
+ Requires-Dist: torch-c-dlpack-ext
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
  Requires-Dist: ruff; extra == "dev"
@@ -7,8 +7,10 @@ name = "quack-kernels"
  dynamic = ["version"]
  requires-python = ">=3.10"
  dependencies = [
-     "nvidia-cutlass-dsl==4.2.1",
+     "nvidia-cutlass-dsl>=4.3.4,<4.4.0",
      "torch",
+     "apache-tvm-ffi>=0.1.6,<0.2",
+     "torch-c-dlpack-ext",
  ]

  [project.optional-dependencies]
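
The headline packaging change in 0.2.4 is the loosened CUTLASS DSL pin (the exact `==4.2.1` becomes the range `>=4.3.4,<4.4.0`) plus two new runtime dependencies. A quick way to check which installed versions satisfy the new range, as a sketch using the `packaging` library (the candidate version numbers below are arbitrary examples, not from this diff):

    # Sketch: checking candidate versions against the new specifier range.
    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet(">=4.3.4,<4.4.0")  # new nvidia-cutlass-dsl constraint
    for candidate in ["4.2.1", "4.3.4", "4.3.9", "4.4.0"]:
        print(candidate, candidate in spec)
    # 4.2.1 False, 4.3.4 True, 4.3.9 True, 4.4.0 False (upper bound is exclusive)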
@@ -0,0 +1,11 @@
1
+ __version__ = "0.2.4"
2
+
3
+ from quack.rmsnorm import rmsnorm
4
+ from quack.softmax import softmax
5
+ from quack.cross_entropy import cross_entropy
6
+
7
+ __all__ = [
8
+ "rmsnorm",
9
+ "softmax",
10
+ "cross_entropy",
11
+ ]
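
The new top-level `quack/__init__.py` re-exports three entry points. A minimal usage sketch, assuming a CUDA-capable PyTorch build; the exported names come from the hunk above, but the call signatures and shapes are illustrative assumptions, not shown in this diff:

    # Illustrative only: signatures below are assumptions, not from the diff.
    import torch
    from quack import rmsnorm, softmax, cross_entropy

    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, device="cuda", dtype=torch.float32)

    y = rmsnorm(x, w)  # fused RMSNorm over the last dim (assumed signature)
    p = softmax(x)     # fused softmax over the last dim (assumed signature)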
@@ -0,0 +1,524 @@
+ # Copyright (c) 2025, Tri Dao.
+
+ import math
+ from typing import Tuple
+
+ import cutlass.cute as cute
+ from cutlass import Float32, Boolean, const_expr
+ from cutlass.cutlass_dsl import T, dsl_user_op
+ from cutlass._mlir.dialects import llvm
+
+ import quack.utils as utils
+
+
+ F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
+
+
+ @dsl_user_op
+ def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "tanh.approx.f32 $0, $1;",
+             "=f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )
+
+
+ @dsl_user_op
+ def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
+         return 0.5 + 0.5 * tanh(0.5 * x)
+     else:
+         x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+         tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
+         return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
+
+
+ @dsl_user_op
+ def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
+     # return dout * out * (1.0 - out)
+     return dout * (out - out * out)
+
+
+ @dsl_user_op
+ def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         return cute.arch.fmax(x, Float32(0.0))
+     else:
+         return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))
+
+
+ @dsl_user_op
+ @cute.jit
+ def drelu(
+     x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+     if const_expr(not isinstance(x, tuple)):
+         x_pos = Boolean(x > 0)
+         return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
+     else:
+         x0_pos = Boolean(x[0] > 0)
+         x1_pos = Boolean(x[1] > 0)
+         dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
+         return dx, relu(x)
+
+
+ @dsl_user_op
+ def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         return cute.arch.fmax(x, Float32(0.0)) * x
+     else:
+         relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
+         return utils.mul_packed_f32x2(relu_x, x)
+
+
+ @dsl_user_op
+ @cute.jit
+ def drelu_sq(
+     x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+     """
+     ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
+     Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
+     Returns: (dx, relu_sq_out) where:
+     - dx = dout * 2 * x if x > 0, else 0
+     - relu_sq_out = max(x, 0) * x
+     """
+     if const_expr(not isinstance(x, tuple)):
+         relu_x = relu(x)
+         relu_sq_out = relu_x * x
+         # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
+         dx = 2.0 * (dout * relu_x)
+         return dx, relu_sq_out
+     else:
+         relu_x = relu(x)
+         relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
+         dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
+         return dx, relu_sq_out
+
+
+ @dsl_user_op
+ def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     """
+     gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+             = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
+     """
+     sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
+     sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
+     if const_expr(not isinstance(x, tuple)):
+         return 0.5 * (
+             x
+             # Currently cute.math.tanh(x, fastmath=True) generates very slow code
+             # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
+             * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
+         )
+     else:
+         x_sq = utils.mul_packed_f32x2(x, x)
+         x_sq_scaled = utils.fma_packed_f32x2(
+             x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
+         )
+         z = utils.mul_packed_f32x2(x, x_sq_scaled)
+         tanh_z = (tanh(z[0]), tanh(z[1]))
+         x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
+         return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
+
+
+ @dsl_user_op
+ def dgelu_tanh_approx(
+     x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+     """
+     GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
+     Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
+     Returns: (dx, gelu_out)
+
+     Derivative uses the chain rule:
+         d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+     where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
+     and sech^2(z) = 1 - tanh^2(z)
+     """
+     sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
+     sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
+     sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.1070322
+
+     if const_expr(not isinstance(x, tuple)):
+         # Compute z = x * (c1 + c2 * x^2)
+         x_sq = x * x
+         # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
+         tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
+         half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
+         gelu_out = x * half_tanh_z_plus_one
+
+         # Compute gradient
+         # sech^2(z) = 1 - tanh^2(z)
+         sech2_z = 1 - tanh_z * tanh_z
+         # dz/dx = c1 + 3 * c2 * x^2
+         dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
+         # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+         dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
+
+         dx = dout * dgelu
+         return dx, gelu_out
+     else:
+         # Compute z = x * (c1 + c2 * x^2)
+         x_sq = utils.mul_packed_f32x2(x, x)
+         x_sq_scaled = utils.fma_packed_f32x2(
+             x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
+         )
+         z = utils.mul_packed_f32x2(x, x_sq_scaled)
+         tanh_z = (tanh(z[0]), tanh(z[1]))
+         half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
+         gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)
+
+         # Compute gradient
+         # sech^2(z) = 1 - tanh^2(z)
+         sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
+         # dz/dx = c1 + 3 * c2 * x^2
+         dz_dx = utils.fma_packed_f32x2(
+             x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
+         )
+         # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+         sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
+         x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
+         dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
+
+         dx = utils.mul_packed_f32x2(dout, dgelu)
+         return dx, gelu_out
+
+
+ @dsl_user_op
+ @cute.jit
+ def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         use_linear = Boolean(x > 20.0)
+         return (
+             cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
+             if not use_linear
+             else x
+         )
+     else:
+         log2_e = math.log2(math.e)
+         x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
+         x_exp = (cute.math.exp2(x_log2e[0], fastmath=True), cute.math.exp2(x_log2e[1], fastmath=True))
+         x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
+         log_x_exp_p1 = (
+             cute.math.log2(x_exp_p1[0], fastmath=True),
+             cute.math.log2(x_exp_p1[1], fastmath=True),
+         )
+         ln2 = math.log(2.0)
+         softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
+         use_linear_0 = Boolean(x[0] > 20.0)
+         use_linear_1 = Boolean(x[1] > 20.0)
+         return (
+             softplus_x[0] if not use_linear_0 else x[0],
+             softplus_x[1] if not use_linear_1 else x[1],
+         )
+
+
+ @dsl_user_op
+ @cute.jit
+ def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
+     use_linear = Boolean(out > 20.0)
+     # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
+     dx = dout - dout * cute.math.exp(-out, fastmath=True)
+     return dx if not use_linear else dout
+
+
+ @dsl_user_op
+ def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
+     """
+     silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
+     This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
+     """
+     if const_expr(not isinstance(x, tuple)):
+         x_half = 0.5 * x if const_expr(not already_halved) else x
+         # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
+         return x_half * tanh(x_half) + x_half
+     else:
+         x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
+         tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
+         return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)
+
+
+ @dsl_user_op
+ def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         return silu(x) * y
+     else:
+         return utils.mul_packed_f32x2(silu(x), y)
+
+
+ @dsl_user_op
+ def dswiglu(
+     x: F32_or_F32x2,
+     y: F32_or_F32x2,
+     dout: F32_or_F32x2,
+     *,
+     already_halved: bool = False,
+     loc=None,
+     ip=None,
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+     """
+     SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
+     Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)
+
+     d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+
+     This has been optimized to use fewer instructions (i.e. we expand things out
+     to use FFMA instead of FADD and FMUL).
+     """
+     if const_expr(not isinstance(x, tuple)):
+         # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
+         # FMUL, MUFU.TANH, then FFMA
+         if const_expr(not already_halved):
+             sigmoid_x = sigmoid(x)
+             silu_x = x * sigmoid_x  # FMUL
+         else:
+             tanh_x = tanh(x)  # MUFU.TANH
+             sigmoid_x = 0.5 * tanh_x + 0.5  # FFMA
+             silu_x = x * tanh_x + x  # FFMA
+         silu_x_dout = silu_x * dout  # FMUL
+         # d_silu(x) * dout
+         # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
+         # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
+         # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
+         # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
+         # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+         d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
+         dx = d_silu_x_dout * y  # FMUL
+         dy = silu_x_dout
+         swiglu_out = silu_x * y  # FMUL
+         # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
+         return dx, dy, swiglu_out
+     else:
+         # Compute sigmoid(x) and silu(x)
+         if const_expr(not already_halved):
+             sigmoid_x = sigmoid(x)
+             silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
+         else:
+             tanh_x = (tanh(x[0]), tanh(x[1]))
+             sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
+             silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
+         silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+         # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+         sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
+             sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
+         )
+         d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
+         dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
+         dy = silu_x_dout
+         swiglu_out = utils.mul_packed_f32x2(silu_x, y)
+         return dx, dy, swiglu_out
+
+
+ @dsl_user_op
+ def swiglu_oai(
+     x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
+ ) -> F32_or_F32x2:
+     """The swiglu variant used in gpt-oss, which applies a scaling factor to x and a bias of 1 to y.
+     https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
+     x * sigmoid(alpha * x) * (y + 1)
+     Compiles down to FMUL, FMUL, TANH, FFMA, FFMA
+     """
+     # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+     if const_expr(not isinstance(x, tuple)):
+         x_half = 0.5 * x
+         # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
+         silu_x = x_half * tanh(alpha * x_half) + x_half
+         return silu_x * y + silu_x
+     else:
+         x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+         alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
+         tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
+         silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
+         return utils.fma_packed_f32x2(silu_x, y, silu_x)
+
+
+ @dsl_user_op
+ def dswiglu_oai(
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+     """
+     Swiglu OAI backward pass: computes gradients w.r.t. x and y
+     Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
+     Returns: (dx, dy, swiglu_oai_out)
+
+     Derivative of x * sigmoid(alpha * x) w.r.t. x:
+         d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
+     """
+     if const_expr(not isinstance(x, tuple)):
+         # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+         alpha_x_half = (0.5 * alpha) * x  # FMUL
+         # MUFU.TANH, then FFMA
+         # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
+         sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
+         silu_x = x * sigmoid_alpha_x  # FMUL
+         silu_x_dout = silu_x * dout  # FMUL
+         # FFMA, FFMA, FMUL
+         d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+         dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
+         dy = silu_x_dout
+         swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
+         # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
+         return dx, dy, swiglu_out
+     else:
+         # Compute sigmoid(alpha * x)
+         alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
+         tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
+         sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
+         silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
+         silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+         # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+         silu_x_minus_product = utils.fma_packed_f32x2(
+             silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
+         )
+         sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
+             (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
+         )
+         d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
+         dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
+         dy = silu_x_dout
+         swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
+         return dx, dy, swiglu_out
+
+
+ @dsl_user_op
+ def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     """GLU: Gated Linear Unit
+     glu(x, y) = sigmoid(x) * y
+     Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+     """
+     if const_expr(not isinstance(x, tuple)):
+         sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+         return sigmoid_x * y  # FMUL
+     else:
+         sigmoid_x = sigmoid(x)
+         return utils.mul_packed_f32x2(sigmoid_x, y)
+
+
+ @dsl_user_op
+ def dglu(
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+     """
+     GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
+     Returns: (dx, dy, glu_out) where:
+     - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
+     - dy = dout * sigmoid(x)
+     - glu_out = sigmoid(x) * y
+     """
+     if const_expr(not isinstance(x, tuple)):
+         # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+         sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+         sigmoid_x_dout = sigmoid_x * dout  # FMUL
+         glu_out = sigmoid_x * y  # FMUL
+         # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
+         #    = y * (1 - sigmoid(x)) * sigmoid_x_dout
+         #    = (y - y * sigmoid(x)) * sigmoid_x_dout
+         #    = (y - glu_out) * sigmoid_x_dout
+         dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
+         dy = sigmoid_x_dout
+         # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
+         return dx, dy, glu_out
+     else:
+         sigmoid_x = sigmoid(x)
+         sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
+         glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
+         # dx = (y - glu_out) * sigmoid_x_dout
+         y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
+         dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
+         dy = sigmoid_x_dout
+         return dx, dy, glu_out
+
+
+ @dsl_user_op
+ def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     """ReGLU: ReLU Gated Linear Unit
+     reglu(x, y) = relu(x) * y = max(x, 0) * y
+     """
+     if const_expr(not isinstance(x, tuple)):
+         return cute.arch.fmax(x, Float32(0.0)) * y
+     else:
+         relu_x = relu(x)
+         return utils.mul_packed_f32x2(relu_x, y)
+
+
+ @dsl_user_op
+ @cute.jit
+ def dreglu(
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+     """
+     ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
+     Returns: (dx, dy, reglu_out) where:
+     - dx = dout * y if x > 0, else 0
+     - dy = dout * relu(x)
+     - reglu_out = relu(x) * y
+     """
+     if const_expr(not isinstance(x, tuple)):
+         x_pos = Boolean(x > 0)
+         relu_x = cute.arch.fmax(x, Float32(0.0))
+         dx = (dout * y) if x_pos else Float32(0.0)
+         dy = dout * relu_x
+         reglu_out = relu_x * y
+         return dx, dy, reglu_out
+     else:
+         x0_pos = Boolean(x[0] > 0)
+         x1_pos = Boolean(x[1] > 0)
+         relu_x = relu(x)
+         dout_y = utils.mul_packed_f32x2(dout, y)
+         dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
+         dy = utils.mul_packed_f32x2(dout, relu_x)
+         reglu_out = utils.mul_packed_f32x2(relu_x, y)
+         return dx, dy, reglu_out
+
+
+ @dsl_user_op
+ def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     """GeGLU: GELU Gated Linear Unit
+     geglu(x, y) = gelu(x) * y
+     Uses the tanh approximation of GELU
+     """
+     if const_expr(not isinstance(x, tuple)):
+         return gelu_tanh_approx(x) * y
+     else:
+         return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)
+
+
+ @dsl_user_op
+ def dgeglu(
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+     """
+     GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+     Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
+     Returns: (dx, dy, geglu_out) where:
+     - dx = dout * y * d_gelu(x)
+     - dy = dout * gelu(x)
+     - geglu_out = gelu(x) * y
+     """
+     if const_expr(not isinstance(x, tuple)):
+         # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+         dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+         # Compute gradients for geglu
+         dx = dgelu_x_dout * y
+         dy = gelu_x * dout
+         geglu_out = gelu_x * y
+         return dx, dy, geglu_out
+     else:
+         # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+         dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+         # Compute gradients for geglu
+         dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
+         dy = utils.mul_packed_f32x2(gelu_x, dout)
+         geglu_out = utils.mul_packed_f32x2(gelu_x, y)
+         return dx, dy, geglu_out
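
All of the kernels in the new quack/activation.py lean on one trick: the GPU has a hardware tanh approximation (MUFU.TANH, emitted via the `tanh.approx.f32` inline PTX above), so sigmoid, SiLU, and tanh-GELU are all rewritten in terms of tanh. The identities and the derivative algebra are easy to verify numerically in plain PyTorch; a verification sketch, independent of the CUTLASS DSL code above:

    # Numerical sanity check (float64, plain PyTorch) of the tanh-based
    # identities and the derivative algebra used by the kernels above.
    import torch

    x = torch.linspace(-6.0, 6.0, steps=1001, dtype=torch.float64)

    # sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x)
    assert torch.allclose(torch.sigmoid(x), 0.5 + 0.5 * torch.tanh(0.5 * x))

    # silu(x) = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)   (the 3-instruction form)
    x_half = 0.5 * x
    assert torch.allclose(torch.nn.functional.silu(x), x_half * torch.tanh(x_half) + x_half)

    # gelu_tanh_approx matches PyTorch's gelu(approximate="tanh")
    c1 = (2.0 / torch.pi) ** 0.5
    g = 0.5 * x * (1.0 + torch.tanh(x * (c1 + 0.044715 * c1 * x * x)))
    assert torch.allclose(torch.nn.functional.gelu(x, approximate="tanh"), g)

    # d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))   (used in dswiglu)
    xg = x.clone().requires_grad_(True)
    torch.nn.functional.silu(xg).sum().backward()
    s = torch.sigmoid(x)
    assert torch.allclose(xg.grad, s * (1.0 + x * (1.0 - s)))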
@@ -0,0 +1,29 @@
+ # Copyright (c) 2025, Tri Dao.
+ from typing import Callable
+
+ import cutlass
+ import cutlass.cute as cute
+ from cutlass import Float32, const_expr
+
+ from quack.layout_utils import make_acc_tensor_mn_view
+
+
+ @cute.jit
+ def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None:
+     if const_expr(tCrC.element_type != Float32):  # Convert to f32
+         tCrC_f32 = cute.make_fragment(tCrC.shape, Float32)
+         tCrC_f32.store(tCrC.load().to(Float32))
+     else:
+         tCrC_f32 = tCrC
+     # this happens to work for frgA layout too, not just acc layout
+     tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32)
+     if const_expr(is_colvec):
+         assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec)
+         for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True):
+             tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r]))
+     else:
+         assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec)
+         for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True):
+             tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c]))
+     if const_expr(tCrC.element_type != Float32):  # Convert back to original dtype
+         tCrC.store(tCrC_f32.load().to(tCrC.element_type))
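
`vec_op` applies a binary op between a thread's accumulator fragment, viewed as an (M, N) tile, and either a per-row (column-vector) or per-column (row-vector) fragment, computing in f32 and converting back. A plain-PyTorch reference of the same broadcast semantics, as a sketch for intuition only (`vec_op` itself operates on per-thread register fragments, not whole tensors, and `vec_op_ref` is a hypothetical name):

    # Reference semantics only; names here are hypothetical, not from the package.
    import torch

    def vec_op_ref(tile: torch.Tensor, vec: torch.Tensor, op, is_colvec: bool) -> torch.Tensor:
        acc = tile.float()  # compute in f32, like the fragment conversion above
        if is_colvec:
            acc = op(acc, vec.float()[:, None])  # length-M vector, broadcast across columns
        else:
            acc = op(acc, vec.float()[None, :])  # length-N vector, broadcast across rows
        return acc.to(tile.dtype)  # convert back to the original dtype

    tile = torch.randn(4, 8, dtype=torch.bfloat16)
    bias = torch.randn(8)
    out = vec_op_ref(tile, bias, torch.add, is_colvec=False)  # add a row vector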
@@ -0,0 +1,19 @@
+ # Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
+
+ from typing import Optional
+
+ import cutlass.cute as cute
+
+
+ def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]:
+     if leading_dim < 0:
+         leading_dim = len(shape) + leading_dim
+     if dtype is None:
+         return None
+     stride = tuple(
+         cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1
+         for i in range(len(shape))
+     )
+     return cute.runtime.make_fake_tensor(
+         dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8
+     )
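
`make_fake_tensor` builds a symbolic stand-in tensor for compilation: every stride is a symbolic int64 carrying a divisibility hint, except the leading dimension, which is contiguous (stride 1). A small sketch of just the stride-pattern logic in plain Python (`SYM` is a hypothetical placeholder for what `cute.sym_int64` would return):

    # Stride-pattern logic only; SYM stands in for cute.sym_int64(...).
    SYM = object()

    def fake_strides(shape, leading_dim=-1):
        if leading_dim < 0:
            leading_dim = len(shape) + leading_dim  # normalize negative index
        return tuple(1 if i == leading_dim else SYM for i in range(len(shape)))

    assert fake_strides((128, 64)) == (SYM, 1)                 # last dim contiguous (default)
    assert fake_strides((128, 64), leading_dim=0) == (1, SYM)  # first dim contiguous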