quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/__init__.py CHANGED
@@ -1,16 +1,9 @@
- __version__ = "0.2.2"
-
- import cutlass.cute as cute
+ __version__ = "0.2.3"

  from quack.rmsnorm import rmsnorm
  from quack.softmax import softmax
  from quack.cross_entropy import cross_entropy

- import quack.cute_dsl_utils
-
- # Patch cute.compile to optionally dump SASS
- cute.compile = quack.cute_dsl_utils.cute_compile_patched
-
  __all__ = [
      "rmsnorm",
      "softmax",
quack/activation.py CHANGED
@@ -3,37 +3,86 @@
  import math
  from typing import Tuple

- import cutlass
  import cutlass.cute as cute
- from cutlass import Float32
- from cutlass.cutlass_dsl import dsl_user_op
+ from cutlass import Float32, Boolean, const_expr
+ from cutlass.cutlass_dsl import T, dsl_user_op
+ from cutlass._mlir.dialects import llvm
+
+ import quack.utils as utils
+
+
+ F32_or_F32x2 = Float32 | Tuple[Float32, Float32]


  @dsl_user_op
- def sigmoid(x: Float32, *, loc=None, ip=None) -> Float32:
-     return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
+ def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
+     return Float32(
+         llvm.inline_asm(
+             T.f32(),
+             [Float32(a).ir_value(loc=loc, ip=ip)],
+             "tanh.approx.f32 $0, $1;",
+             "=f,f",
+             has_side_effects=False,
+             is_align_stack=False,
+             asm_dialect=llvm.AsmDialect.AD_ATT,
+         )
+     )


  @dsl_user_op
- def relu(x: Float32, *, loc=None, ip=None) -> Float32:
-     return cute.arch.fmax(x, Float32(0.0))
+ def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
+         return 0.5 + 0.5 * tanh(0.5 * x)
+     else:
+         x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+         tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
+         return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))


- @cute.jit
  @dsl_user_op
- def drelu(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
-     x_pos = cutlass.Boolean(x > 0)
-     return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
+ def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
+     # return dout * out * (1.0 - out)
+     return dout * (out - out * out)


  @dsl_user_op
- def relu_sq(x: Float32, *, loc=None, ip=None) -> Float32:
-     return cute.arch.fmax(x, Float32(0.0)) * x
+ def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         return cute.arch.fmax(x, Float32(0.0))
+     else:
+         return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))


+ @dsl_user_op
  @cute.jit
+ def drelu(
+     x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+     if const_expr(not isinstance(x, tuple)):
+         x_pos = Boolean(x > 0)
+         return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
+     else:
+         x0_pos = Boolean(x[0] > 0)
+         x1_pos = Boolean(x[1] > 0)
+         dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
+         return dx, relu(x)
+
+
+ @dsl_user_op
+ def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         return cute.arch.fmax(x, Float32(0.0)) * x
+     else:
+         relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
+         return utils.mul_packed_f32x2(relu_x, x)
+
+
  @dsl_user_op
- def drelu_sq(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+ @cute.jit
+ def drelu_sq(
+     x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
      """
      ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
      Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
@@ -41,29 +90,49 @@ def drelu_sq(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32,
      - dx = dout * 2 * x if x > 0, else 0
      - relu_sq_out = max(x, 0) * x
      """
-     x_pos = cutlass.Boolean(x > 0)
-     relu_sq_out = cute.arch.fmax(x, Float32(0.0)) * x
-     # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
-     dx = (2.0 * dout * x) if x_pos else Float32(0.0)
-     return dx, relu_sq_out
+     if const_expr(not isinstance(x, tuple)):
+         relu_x = relu(x)
+         relu_sq_out = relu_x * x
+         # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
+         dx = 2.0 * (dout * relu_x)
+         return dx, relu_sq_out
+     else:
+         relu_x = relu(x)
+         relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
+         dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
+         return dx, relu_sq_out


  @dsl_user_op
- def gelu_tanh_approx(x: Float32, *, loc=None, ip=None) -> Float32:
+ def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
      """
      gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
              = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
      """
      sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
      sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
-     return 0.5 * (
-         x
-         * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
-     )
+     if const_expr(not isinstance(x, tuple)):
+         return 0.5 * (
+             x
+             # Currently cute.math.tanh(x, fastmath=True) generates very slow code
+             # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
+             * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
+         )
+     else:
+         x_sq = utils.mul_packed_f32x2(x, x)
+         x_sq_scaled = utils.fma_packed_f32x2(
+             x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
+         )
+         z = utils.mul_packed_f32x2(x, x_sq_scaled)
+         tanh_z = (tanh(z[0]), tanh(z[1]))
+         x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
+         return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)


  @dsl_user_op
- def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+ def dgelu_tanh_approx(
+     x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
      """
      GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
      Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
@@ -78,43 +147,123 @@ def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[
      sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
      sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.01070322

-     # Compute z = x * (c1 + c2 * x^2)
-     x_sq = x * x
-     tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
-     half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
-     gelu_out = x * half_tanh_z_plus_one
+     if const_expr(not isinstance(x, tuple)):
+         # Compute z = x * (c1 + c2 * x^2)
+         x_sq = x * x
+         # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
+         tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
+         half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
+         gelu_out = x * half_tanh_z_plus_one
+
+         # Compute gradient
+         # sech^2(z) = 1 - tanh^2(z)
+         sech2_z = 1 - tanh_z * tanh_z
+         # dz/dx = c1 + 3 * c2 * x^2
+         dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
+         # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+         dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
+
+         dx = dout * dgelu
+         return dx, gelu_out
+     else:
+         # Compute z = x * (c1 + c2 * x^2)
+         x_sq = utils.mul_packed_f32x2(x, x)
+         x_sq_scaled = utils.fma_packed_f32x2(
+             x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
+         )
+         z = utils.mul_packed_f32x2(x, x_sq_scaled)
+         tanh_z = (tanh(z[0]), tanh(z[1]))
+         half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
+         gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)
+
+         # Compute gradient
+         # sech^2(z) = 1 - tanh^2(z)
+         sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
+         # dz/dx = c1 + 3 * c2 * x^2
+         dz_dx = utils.fma_packed_f32x2(
+             x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
+         )
+         # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+         sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
+         x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
+         dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
+
+         dx = utils.mul_packed_f32x2(dout, dgelu)
+         return dx, gelu_out

-     # Compute gradient
-     # sech^2(z) = 1 - tanh^2(z)
-     sech2_z = 1 - tanh_z * tanh_z
-     # dz/dx = c1 + 3 * c2 * x^2
-     dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
-     # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
-     dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))

-     dx = dout * dgelu
-     return dx, gelu_out
+ @dsl_user_op
+ @cute.jit
+ def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         use_linear = Boolean(x > 20.0)
+         return (
+             cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
+             if not use_linear
+             else x
+         )
+     else:
+         log2_e = math.log2(math.e)
+         x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
+         x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
+         x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
+         log_x_exp_p1 = (
+             cute.math.log2(x_exp_p1[0], fastmath=True),
+             cute.math.log2(x_exp_p1[1], fastmath=True),
+         )
+         ln2 = math.log(2.0)
+         softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
+         use_linear_0 = Boolean(x[0] > 20.0)
+         use_linear_1 = Boolean(x[1] > 20.0)
+         return (
+             softplus_x[0] if not use_linear_0 else x[0],
+             softplus_x[1] if not use_linear_1 else x[1],
+         )


  @dsl_user_op
- def silu(x: Float32, *, loc=None, ip=None) -> Float32:
+ @cute.jit
+ def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
+     use_linear = Boolean(out > 20.0)
+     # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
+     dx = dout - dout * cute.math.exp(-out, fastmath=True)
+     return dx if not use_linear else dout
+
+
+ @dsl_user_op
+ def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
      """
      silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
      This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
      """
-     x_half = 0.5 * x
-     return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
+     if const_expr(not isinstance(x, tuple)):
+         x_half = 0.5 * x if const_expr(not already_halved) else x
+         # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
+         return x_half * tanh(x_half) + x_half
+     else:
+         x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
+         tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
+         return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)


  @dsl_user_op
- def swiglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
-     return silu(x) * y
+ def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+     if const_expr(not isinstance(x, tuple)):
+         return silu(x) * y
+     else:
+         return utils.mul_packed_f32x2(silu(x), y)


  @dsl_user_op
  def dswiglu(
-     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
- ) -> Tuple[Float32, Float32, Float32]:
+     x: F32_or_F32x2,
+     y: F32_or_F32x2,
+     dout: F32_or_F32x2,
+     *,
+     already_halved: bool = False,
+     loc=None,
+     ip=None,
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
      """
      SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
      Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
@@ -125,42 +274,77 @@ def dswiglu(
      This has been optimized to use fewer instructions (i.e. we expand things out
      to use FFMA instead of FADD and FMUL).
      """
-     # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
-     # FMUL, MUFU.TANH, then FFMA
-     sigmoid_x = sigmoid(x)
-     silu_x = x * sigmoid_x  # FMUL
-     silu_x_dout = silu_x * dout  # FMUL
-     # d_silu(x) * dout
-     # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
-     # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
-     # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
-     # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
-     # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
-     d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
-     dx = d_silu_x_dout * y  # FMUL
-     dy = silu_x_dout
-     swiglu_out = silu_x * y  # FMUL
-     # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
-     return dx, dy, swiglu_out
+     if const_expr(not isinstance(x, tuple)):
+         # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
+         # FMUL, MUFU.TANH, then FFMA
+         if const_expr(not already_halved):
+             sigmoid_x = sigmoid(x)
+             silu_x = x * sigmoid_x  # FMUL
+         else:
+             tanh_x = tanh(x)  # MUFU.TANH
+             sigmoid_x = 0.5 * tanh_x + 0.5  # FFMA
+             silu_x = x * tanh_x + x  # FFMA
+         silu_x_dout = silu_x * dout  # FMUL
+         # d_silu(x) * dout
+         # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
+         # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
+         # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
+         # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
+         # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+         d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
+         dx = d_silu_x_dout * y  # FMUL
+         dy = silu_x_dout
+         swiglu_out = silu_x * y  # FMUL
+         # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
+         return dx, dy, swiglu_out
+     else:
+         # Compute sigmoid(x) and silu(x)
+         if const_expr(not already_halved):
+             sigmoid_x = sigmoid(x)
+             silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
+         else:
+             tanh_x = (tanh(x[0]), tanh(x[1]))
+             sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
+             silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
+         silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+         # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+         sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
+             sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
+         )
+         d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
+         dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
+         dy = silu_x_dout
+         swiglu_out = utils.mul_packed_f32x2(silu_x, y)
+         return dx, dy, swiglu_out


  @dsl_user_op
- def swiglu_oai(x: Float32, y: Float32, alpha: float = 1.702, *, loc=None, ip=None) -> Float32:
+ def swiglu_oai(
+     x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
+ ) -> F32_or_F32x2:
      """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
      https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
      x * sigmoid(alpha * x) * (y + 1)
      Compile down to FMUL, FMUL, TANH, FFMA, FFMA
      """
      # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
-     x_half = 0.5 * x
-     silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
-     return silu_x * y + silu_x
+     if const_expr(not isinstance(x, tuple)):
+         x_half = 0.5 * x
+         # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
+         silu_x = x_half * tanh(alpha * x_half) + x_half
+         return silu_x * y + silu_x
+     else:
+         x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+         alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
+         tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
+         silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
+         return utils.fma_packed_f32x2(silu_x, y, silu_x)


  @dsl_user_op
  def dswiglu_oai(
-     x: Float32, y: Float32, dout: Float32, alpha: float = 1.702, *, loc=None, ip=None
- ) -> Tuple[Float32, Float32, Float32]:
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
      """
      Swiglu OAI backward pass: computes gradients w.r.t. x and y
      Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
@@ -169,35 +353,60 @@ def dswiglu_oai(
      Derivative of x * sigmoid(alpha * x) w.r.t. x:
      d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
      """
-     # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
-     alpha_x_half = (0.5 * alpha) * x  # FMUL
-     # MUFU.TANH, then FFMA
-     sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
-     silu_x = x * sigmoid_alpha_x  # FMUL
-     silu_x_dout = silu_x * dout  # FMUL
-     # FFMA, FFMA, FMUL
-     d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
-     dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
-     dy = silu_x_dout
-     swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
-     # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
-     return dx, dy, swiglu_out
+     if const_expr(not isinstance(x, tuple)):
+         # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+         alpha_x_half = (0.5 * alpha) * x  # FMUL
+         # MUFU.TANH, then FFMA
+         # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
+         sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
+         silu_x = x * sigmoid_alpha_x  # FMUL
+         silu_x_dout = silu_x * dout  # FMUL
+         # FFMA, FFMA, FMUL
+         d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+         dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
+         dy = silu_x_dout
+         swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
+         # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
+         return dx, dy, swiglu_out
+     else:
+         # Compute sigmoid(alpha * x)
+         alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
+         tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
+         sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
+         silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
+         silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+         # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+         silu_x_minus_product = utils.fma_packed_f32x2(
+             silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
+         )
+         sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
+             (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
+         )
+         d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
+         dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
+         dy = silu_x_dout
+         swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
+         return dx, dy, swiglu_out


  @dsl_user_op
- def glu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+ def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
      """GLU: Gated Linear Unit
      glu(x, y) = sigmoid(x) * y
      Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
      """
-     sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
-     return sigmoid_x * y  # FMUL
+     if const_expr(not isinstance(x, tuple)):
+         sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+         return sigmoid_x * y  # FMUL
+     else:
+         sigmoid_x = sigmoid(x)
+         return utils.mul_packed_f32x2(sigmoid_x, y)


  @dsl_user_op
  def dglu(
-     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
- ) -> Tuple[Float32, Float32, Float32]:
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
      """
      GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
      Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
@@ -206,33 +415,47 @@ def dglu(
      - dy = dout * sigmoid(x)
      - glu_out = sigmoid(x) * y
      """
-     # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
-     sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
-     sigmoid_x_dout = sigmoid_x * dout  # FMUL
-     glu_out = sigmoid_x * y  # FMUL
-     # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
-     # = y * (1 - sigmoid(x)) * sigmoid_x_dout
-     # = (y - y * sigmoid(x)) * sigmoid_x_dout
-     # = (y - glu_out) * sigmoid_x_dout
-     dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
-     dy = sigmoid_x_dout
-     # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
-     return dx, dy, glu_out
+     if const_expr(not isinstance(x, tuple)):
+         # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+         sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+         sigmoid_x_dout = sigmoid_x * dout  # FMUL
+         glu_out = sigmoid_x * y  # FMUL
+         # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
+         # = y * (1 - sigmoid(x)) * sigmoid_x_dout
+         # = (y - y * sigmoid(x)) * sigmoid_x_dout
+         # = (y - glu_out) * sigmoid_x_dout
+         dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
+         dy = sigmoid_x_dout
+         # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
+         return dx, dy, glu_out
+     else:
+         sigmoid_x = sigmoid(x)
+         sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
+         glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
+         # dx = (y - glu_out) * sigmoid_x_dout
+         y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
+         dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
+         dy = sigmoid_x_dout
+         return dx, dy, glu_out


  @dsl_user_op
- def reglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+ def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
      """ReGLU: ReLU Gated Linear Unit
      reglu(x, y) = relu(x) * y = max(x, 0) * y
      """
-     return cute.arch.fmax(x, Float32(0.0)) * y
+     if const_expr(not isinstance(x, tuple)):
+         return cute.arch.fmax(x, Float32(0.0)) * y
+     else:
+         relu_x = relu(x)
+         return utils.mul_packed_f32x2(relu_x, y)


- @cute.jit
  @dsl_user_op
+ @cute.jit
  def dreglu(
-     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
- ) -> Tuple[Float32, Float32, Float32]:
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
      """
      ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
      Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
@@ -241,27 +464,40 @@ def dreglu(
      - dy = dout * relu(x)
      - reglu_out = relu(x) * y
      """
-     x_pos = cutlass.Boolean(x > 0)
-     relu_x = cute.arch.fmax(x, Float32(0.0))
-     dx = (dout * y) if x_pos else Float32(0.0)
-     dy = dout * relu_x
-     reglu_out = relu_x * y
-     return dx, dy, reglu_out
+     if const_expr(not isinstance(x, tuple)):
+         x_pos = Boolean(x > 0)
+         relu_x = cute.arch.fmax(x, Float32(0.0))
+         dx = (dout * y) if x_pos else Float32(0.0)
+         dy = dout * relu_x
+         reglu_out = relu_x * y
+         return dx, dy, reglu_out
+     else:
+         x0_pos = Boolean(x[0] > 0)
+         x1_pos = Boolean(x[1] > 0)
+         relu_x = relu(x)
+         dout_y = utils.mul_packed_f32x2(dout, y)
+         dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
+         dy = utils.mul_packed_f32x2(dout, relu_x)
+         reglu_out = utils.mul_packed_f32x2(relu_x, y)
+         return dx, dy, reglu_out


  @dsl_user_op
- def geglu(x: Float32, y: Float32, *, loc=None, ip=None) -> Float32:
+ def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
      """GeGLU: GELU Gated Linear Unit
      geglu(x, y) = gelu(x) * y
      Uses the tanh approximation of GELU
      """
-     return gelu_tanh_approx(x) * y
+     if const_expr(not isinstance(x, tuple)):
+         return gelu_tanh_approx(x) * y
+     else:
+         return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)


  @dsl_user_op
  def dgeglu(
-     x: Float32, y: Float32, dout: Float32, *, loc=None, ip=None
- ) -> Tuple[Float32, Float32, Float32]:
+     x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+ ) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
      """
      GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
      Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
@@ -270,10 +506,19 @@ def dgeglu(
      - dy = dout * gelu(x)
      - geglu_out = gelu(x) * y
      """
-     # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
-     dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
-     # Compute gradients for geglu
-     dx = dgelu_x_dout * y
-     dy = gelu_x * dout
-     geglu_out = gelu_x * y
-     return dx, dy, geglu_out
+     if const_expr(not isinstance(x, tuple)):
+         # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+         dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+         # Compute gradients for geglu
+         dx = dgelu_x_dout * y
+         dy = gelu_x * dout
+         geglu_out = gelu_x * y
+         return dx, dy, geglu_out
+     else:
+         # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+         dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+         # Compute gradients for geglu
+         dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
+         dy = utils.mul_packed_f32x2(gelu_x, dout)
+         geglu_out = utils.mul_packed_f32x2(gelu_x, y)
+         return dx, dy, geglu_out
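
The new activation code above repeatedly relies on the identities sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x) and silu(x) = (0.5 * x) * tanh(0.5 * x) + (0.5 * x), and on the factored gradient term used in dswiglu. A minimal plain-Python sketch of that math, using only the standard library (ref_sigmoid and ref_silu are hypothetical reference helpers, and math.tanh stands in for the tanh.approx.f32 instruction emitted by the new tanh op):

import math


def ref_sigmoid(x: float) -> float:
    # Reference logistic sigmoid, 1 / (1 + e^-x).
    return 1.0 / (1.0 + math.exp(-x))


def ref_silu(x: float) -> float:
    # Reference SiLU, x * sigmoid(x).
    return x * ref_sigmoid(x)


for x in (-3.0, -0.5, 0.0, 0.7, 4.2):
    # sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x), the form used by sigmoid() above.
    assert math.isclose(0.5 + 0.5 * math.tanh(0.5 * x), ref_sigmoid(x), rel_tol=1e-12)

    # silu(x) = (0.5 * x) * tanh(0.5 * x) + (0.5 * x): one FMUL, one MUFU.TANH, one FFMA.
    x_half = 0.5 * x
    assert math.isclose(x_half * math.tanh(x_half) + x_half, ref_silu(x), rel_tol=1e-12)

    # dswiglu's factored term, d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout
    # + silu_x * dout, checked against a central finite difference of silu.
    dout, eps = 1.3, 1e-5
    s, sl = ref_sigmoid(x), ref_silu(x)
    factored = (s - sl * s) * dout + sl * dout
    finite_diff = (ref_silu(x + eps) - ref_silu(x - eps)) / (2 * eps) * dout
    assert math.isclose(factored, finite_diff, rel_tol=1e-4, abs_tol=1e-6)

print("tanh-based sigmoid/silu identities and dswiglu gradient factoring check out")

The factored form matters because each multiply-then-add collapses to a single FFMA, which is the instruction-count accounting spelled out in the comments of dswiglu and dglu.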