quack-kernels 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
quack/__init__.py
CHANGED
|
@@ -1,16 +1,9 @@
|
|
|
1
|
-
__version__ = "0.2.
|
|
2
|
-
|
|
3
|
-
import cutlass.cute as cute
|
|
1
|
+
__version__ = "0.2.3"
|
|
4
2
|
|
|
5
3
|
from quack.rmsnorm import rmsnorm
|
|
6
4
|
from quack.softmax import softmax
|
|
7
5
|
from quack.cross_entropy import cross_entropy
|
|
8
6
|
|
|
9
|
-
import quack.cute_dsl_utils
|
|
10
|
-
|
|
11
|
-
# Patch cute.compile to optionally dump SASS
|
|
12
|
-
cute.compile = quack.cute_dsl_utils.cute_compile_patched
|
|
13
|
-
|
|
14
7
|
__all__ = [
|
|
15
8
|
"rmsnorm",
|
|
16
9
|
"softmax",
|
quack/activation.py
CHANGED
|
@@ -3,37 +3,86 @@
|
|
|
3
3
|
import math
|
|
4
4
|
from typing import Tuple
|
|
5
5
|
|
|
6
|
-
import cutlass
|
|
7
6
|
import cutlass.cute as cute
|
|
8
|
-
from cutlass import Float32
|
|
9
|
-
from cutlass.cutlass_dsl import dsl_user_op
|
|
7
|
+
from cutlass import Float32, Boolean, const_expr
|
|
8
|
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
|
9
|
+
from cutlass._mlir.dialects import llvm
|
|
10
|
+
|
|
11
|
+
import quack.utils as utils
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
|
|
10
15
|
|
|
11
16
|
|
|
12
17
|
@dsl_user_op
|
|
13
|
-
def
|
|
14
|
-
return
|
|
18
|
+
def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
|
19
|
+
return Float32(
|
|
20
|
+
llvm.inline_asm(
|
|
21
|
+
T.f32(),
|
|
22
|
+
[Float32(a).ir_value(loc=loc, ip=ip)],
|
|
23
|
+
"tanh.approx.f32 $0, $1;",
|
|
24
|
+
"=f,f",
|
|
25
|
+
has_side_effects=False,
|
|
26
|
+
is_align_stack=False,
|
|
27
|
+
asm_dialect=llvm.AsmDialect.AD_ATT,
|
|
28
|
+
)
|
|
29
|
+
)
|
|
15
30
|
|
|
16
31
|
|
|
17
32
|
@dsl_user_op
|
|
18
|
-
def
|
|
19
|
-
|
|
33
|
+
def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
34
|
+
if const_expr(not isinstance(x, tuple)):
|
|
35
|
+
# return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
|
|
36
|
+
return 0.5 + 0.5 * tanh(0.5 * x)
|
|
37
|
+
else:
|
|
38
|
+
x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
|
|
39
|
+
tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
|
|
40
|
+
return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
|
|
20
41
|
|
|
21
42
|
|
|
22
|
-
@cute.jit
|
|
23
43
|
@dsl_user_op
|
|
24
|
-
def
|
|
25
|
-
|
|
26
|
-
return dout
|
|
44
|
+
def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
|
|
45
|
+
# return dout * out * (1.0 - out)
|
|
46
|
+
return dout * (out - out * out)
|
|
27
47
|
|
|
28
48
|
|
|
29
49
|
@dsl_user_op
|
|
30
|
-
def
|
|
31
|
-
|
|
50
|
+
def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
51
|
+
if const_expr(not isinstance(x, tuple)):
|
|
52
|
+
return cute.arch.fmax(x, Float32(0.0))
|
|
53
|
+
else:
|
|
54
|
+
return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))
|
|
32
55
|
|
|
33
56
|
|
|
57
|
+
@dsl_user_op
|
|
34
58
|
@cute.jit
|
|
59
|
+
def drelu(
|
|
60
|
+
x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
|
61
|
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
|
|
62
|
+
if const_expr(not isinstance(x, tuple)):
|
|
63
|
+
x_pos = Boolean(x > 0)
|
|
64
|
+
return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
|
|
65
|
+
else:
|
|
66
|
+
x0_pos = Boolean(x[0] > 0)
|
|
67
|
+
x1_pos = Boolean(x[1] > 0)
|
|
68
|
+
dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
|
|
69
|
+
return dx, relu(x)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dsl_user_op
|
|
73
|
+
def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
74
|
+
if const_expr(not isinstance(x, tuple)):
|
|
75
|
+
return cute.arch.fmax(x, Float32(0.0)) * x
|
|
76
|
+
else:
|
|
77
|
+
relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
|
|
78
|
+
return utils.mul_packed_f32x2(relu_x, x)
|
|
79
|
+
|
|
80
|
+
|
|
35
81
|
@dsl_user_op
|
|
36
|
-
|
|
82
|
+
@cute.jit
|
|
83
|
+
def drelu_sq(
|
|
84
|
+
x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
|
85
|
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
|
|
37
86
|
"""
|
|
38
87
|
ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
|
|
39
88
|
Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
|
|
@@ -41,29 +90,49 @@ def drelu_sq(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[Float32,
|
|
|
41
90
|
- dx = dout * 2 * x if x > 0, else 0
|
|
42
91
|
- relu_sq_out = max(x, 0) * x
|
|
43
92
|
"""
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
93
|
+
if const_expr(not isinstance(x, tuple)):
|
|
94
|
+
relu_x = relu(x)
|
|
95
|
+
relu_sq_out = relu_x * x
|
|
96
|
+
# Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
|
|
97
|
+
dx = 2.0 * (dout * relu_x)
|
|
98
|
+
return dx, relu_sq_out
|
|
99
|
+
else:
|
|
100
|
+
relu_x = relu(x)
|
|
101
|
+
relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
|
|
102
|
+
dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
|
|
103
|
+
return dx, relu_sq_out
|
|
49
104
|
|
|
50
105
|
|
|
51
106
|
@dsl_user_op
|
|
52
|
-
def gelu_tanh_approx(x:
|
|
107
|
+
def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
53
108
|
"""
|
|
54
109
|
gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
|
55
110
|
= 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
|
|
56
111
|
"""
|
|
57
112
|
sqrt_2_over_pi = math.sqrt(2 / math.pi) # ~0.797885
|
|
58
113
|
sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # ~0.0356774
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
114
|
+
if const_expr(not isinstance(x, tuple)):
|
|
115
|
+
return 0.5 * (
|
|
116
|
+
x
|
|
117
|
+
# Currently cute.math.tanh(x, fastmath=True) generates very slow code
|
|
118
|
+
# * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
|
|
119
|
+
* (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
|
|
120
|
+
)
|
|
121
|
+
else:
|
|
122
|
+
x_sq = utils.mul_packed_f32x2(x, x)
|
|
123
|
+
x_sq_scaled = utils.fma_packed_f32x2(
|
|
124
|
+
x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
|
|
125
|
+
)
|
|
126
|
+
z = utils.mul_packed_f32x2(x, x_sq_scaled)
|
|
127
|
+
tanh_z = (tanh(z[0]), tanh(z[1]))
|
|
128
|
+
x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
|
|
129
|
+
return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
|
|
63
130
|
|
|
64
131
|
|
|
65
132
|
@dsl_user_op
|
|
66
|
-
def dgelu_tanh_approx(
|
|
133
|
+
def dgelu_tanh_approx(
|
|
134
|
+
x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
|
135
|
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
|
|
67
136
|
"""
|
|
68
137
|
GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
|
|
69
138
|
Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
|
|
@@ -78,43 +147,123 @@ def dgelu_tanh_approx(x: Float32, dout: Float32, *, loc=None, ip=None) -> Tuple[
|
|
|
78
147
|
sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # c2 ~0.0356774
|
|
79
148
|
sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff # c3 ~0.01070322
|
|
80
149
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
150
|
+
if const_expr(not isinstance(x, tuple)):
|
|
151
|
+
# Compute z = x * (c1 + c2 * x^2)
|
|
152
|
+
x_sq = x * x
|
|
153
|
+
# tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
|
|
154
|
+
tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
|
|
155
|
+
half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
|
|
156
|
+
gelu_out = x * half_tanh_z_plus_one
|
|
157
|
+
|
|
158
|
+
# Compute gradient
|
|
159
|
+
# sech^2(z) = 1 - tanh^2(z)
|
|
160
|
+
sech2_z = 1 - tanh_z * tanh_z
|
|
161
|
+
# dz/dx = c1 + 3 * c2 * x^2
|
|
162
|
+
dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
|
|
163
|
+
# d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
|
|
164
|
+
dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
|
|
165
|
+
|
|
166
|
+
dx = dout * dgelu
|
|
167
|
+
return dx, gelu_out
|
|
168
|
+
else:
|
|
169
|
+
# Compute z = x * (c1 + c2 * x^2)
|
|
170
|
+
x_sq = utils.mul_packed_f32x2(x, x)
|
|
171
|
+
x_sq_scaled = utils.fma_packed_f32x2(
|
|
172
|
+
x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
|
|
173
|
+
)
|
|
174
|
+
z = utils.mul_packed_f32x2(x, x_sq_scaled)
|
|
175
|
+
tanh_z = (tanh(z[0]), tanh(z[1]))
|
|
176
|
+
half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
|
|
177
|
+
gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)
|
|
178
|
+
|
|
179
|
+
# Compute gradient
|
|
180
|
+
# sech^2(z) = 1 - tanh^2(z)
|
|
181
|
+
sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
|
|
182
|
+
# dz/dx = c1 + 3 * c2 * x^2
|
|
183
|
+
dz_dx = utils.fma_packed_f32x2(
|
|
184
|
+
x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
|
|
185
|
+
)
|
|
186
|
+
# d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
|
|
187
|
+
sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
|
|
188
|
+
x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
|
|
189
|
+
dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
|
|
190
|
+
|
|
191
|
+
dx = utils.mul_packed_f32x2(dout, dgelu)
|
|
192
|
+
return dx, gelu_out
|
|
86
193
|
|
|
87
|
-
# Compute gradient
|
|
88
|
-
# sech^2(z) = 1 - tanh^2(z)
|
|
89
|
-
sech2_z = 1 - tanh_z * tanh_z
|
|
90
|
-
# dz/dx = c1 + 3 * c2 * x^2
|
|
91
|
-
dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
|
|
92
|
-
# d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
|
|
93
|
-
dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
|
|
94
194
|
|
|
95
|
-
|
|
96
|
-
|
|
195
|
+
@dsl_user_op
|
|
196
|
+
@cute.jit
|
|
197
|
+
def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
198
|
+
if const_expr(not isinstance(x, tuple)):
|
|
199
|
+
use_linear = Boolean(x > 20.0)
|
|
200
|
+
return (
|
|
201
|
+
cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
|
|
202
|
+
if not use_linear
|
|
203
|
+
else x
|
|
204
|
+
)
|
|
205
|
+
else:
|
|
206
|
+
log2_e = math.log2(math.e)
|
|
207
|
+
x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
|
|
208
|
+
x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True))
|
|
209
|
+
x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
|
|
210
|
+
log_x_exp_p1 = (
|
|
211
|
+
cute.math.log2(x_exp_p1[0], fastmath=True),
|
|
212
|
+
cute.math.log2(x_exp_p1[1], fastmath=True),
|
|
213
|
+
)
|
|
214
|
+
ln2 = math.log(2.0)
|
|
215
|
+
softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
|
|
216
|
+
use_linear_0 = Boolean(x[0] > 20.0)
|
|
217
|
+
use_linear_1 = Boolean(x[1] > 20.0)
|
|
218
|
+
return (
|
|
219
|
+
softplus_x[0] if not use_linear_0 else x[0],
|
|
220
|
+
softplus_x[1] if not use_linear_1 else x[1],
|
|
221
|
+
)
|
|
97
222
|
|
|
98
223
|
|
|
99
224
|
@dsl_user_op
|
|
100
|
-
|
|
225
|
+
@cute.jit
|
|
226
|
+
def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
|
|
227
|
+
use_linear = Boolean(out > 20.0)
|
|
228
|
+
# dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
|
|
229
|
+
dx = dout - dout * cute.math.exp(-out, fastmath=True)
|
|
230
|
+
return dx if not use_linear else dout
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
@dsl_user_op
|
|
234
|
+
def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
|
|
101
235
|
"""
|
|
102
236
|
silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
|
|
103
237
|
This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
|
|
104
238
|
"""
|
|
105
|
-
|
|
106
|
-
|
|
239
|
+
if const_expr(not isinstance(x, tuple)):
|
|
240
|
+
x_half = 0.5 * x if const_expr(not already_halved) else x
|
|
241
|
+
# return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
|
|
242
|
+
return x_half * tanh(x_half) + x_half
|
|
243
|
+
else:
|
|
244
|
+
x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
|
|
245
|
+
tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
|
|
246
|
+
return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)
|
|
107
247
|
|
|
108
248
|
|
|
109
249
|
@dsl_user_op
|
|
110
|
-
def swiglu(x:
|
|
111
|
-
|
|
250
|
+
def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
251
|
+
if const_expr(not isinstance(x, tuple)):
|
|
252
|
+
return silu(x) * y
|
|
253
|
+
else:
|
|
254
|
+
return utils.mul_packed_f32x2(silu(x), y)
|
|
112
255
|
|
|
113
256
|
|
|
114
257
|
@dsl_user_op
|
|
115
258
|
def dswiglu(
|
|
116
|
-
x:
|
|
117
|
-
|
|
259
|
+
x: F32_or_F32x2,
|
|
260
|
+
y: F32_or_F32x2,
|
|
261
|
+
dout: F32_or_F32x2,
|
|
262
|
+
*,
|
|
263
|
+
already_halved: bool = False,
|
|
264
|
+
loc=None,
|
|
265
|
+
ip=None,
|
|
266
|
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
|
|
118
267
|
"""
|
|
119
268
|
SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
|
|
120
269
|
Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
|
|
@@ -125,42 +274,77 @@ def dswiglu(
|
|
|
125
274
|
This has been optimized to use fewer instructions (i.e. we expand things out
|
|
126
275
|
to use FFMA instead of FADD and FMUL).
|
|
127
276
|
"""
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
277
|
+
if const_expr(not isinstance(x, tuple)):
|
|
278
|
+
# Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
|
|
279
|
+
# FMUL, MUFU.TANH, then FFMA
|
|
280
|
+
if const_expr(not already_halved):
|
|
281
|
+
sigmoid_x = sigmoid(x)
|
|
282
|
+
silu_x = x * sigmoid_x # FMUL
|
|
283
|
+
else:
|
|
284
|
+
tanh_x = tanh(x) # MUFU.TANH
|
|
285
|
+
sigmoid_x = 0.5 * tanh_x + 0.5 # FFMA
|
|
286
|
+
silu_x = x * tanh_x + x # FFMA
|
|
287
|
+
silu_x_dout = silu_x * dout # FMUL
|
|
288
|
+
# d_silu(x) * dout
|
|
289
|
+
# = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
|
|
290
|
+
# = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
|
|
291
|
+
# = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
|
|
292
|
+
# = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
|
|
293
|
+
# = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
|
|
294
|
+
d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout # FFMA, FFMA
|
|
295
|
+
dx = d_silu_x_dout * y # FMUL
|
|
296
|
+
dy = silu_x_dout
|
|
297
|
+
swiglu_out = silu_x * y # FMUL
|
|
298
|
+
# Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
|
|
299
|
+
return dx, dy, swiglu_out
|
|
300
|
+
else:
|
|
301
|
+
# Compute sigmoid(x) and silu(x)
|
|
302
|
+
if const_expr(not already_halved):
|
|
303
|
+
sigmoid_x = sigmoid(x)
|
|
304
|
+
silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
|
|
305
|
+
else:
|
|
306
|
+
tanh_x = (tanh(x[0]), tanh(x[1]))
|
|
307
|
+
sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
|
|
308
|
+
silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
|
|
309
|
+
silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
|
|
310
|
+
# d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
|
|
311
|
+
sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
|
|
312
|
+
sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
|
|
313
|
+
)
|
|
314
|
+
d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
|
|
315
|
+
dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
|
|
316
|
+
dy = silu_x_dout
|
|
317
|
+
swiglu_out = utils.mul_packed_f32x2(silu_x, y)
|
|
318
|
+
return dx, dy, swiglu_out
|
|
145
319
|
|
|
146
320
|
|
|
147
321
|
@dsl_user_op
|
|
148
|
-
def swiglu_oai(
|
|
322
|
+
def swiglu_oai(
|
|
323
|
+
x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
|
|
324
|
+
) -> F32_or_F32x2:
|
|
149
325
|
"""The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
|
|
150
326
|
https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
|
|
151
327
|
x * sigmoid(alpha * x) * (y + 1)
|
|
152
328
|
Compile down to FMUL, FMUL, TANH, FFMA, FFMA
|
|
153
329
|
"""
|
|
154
330
|
# Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
331
|
+
if const_expr(not isinstance(x, tuple)):
|
|
332
|
+
x_half = 0.5 * x
|
|
333
|
+
# silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
|
|
334
|
+
silu_x = x_half * tanh(alpha * x_half) + x_half
|
|
335
|
+
return silu_x * y + silu_x
|
|
336
|
+
else:
|
|
337
|
+
x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
|
|
338
|
+
alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
|
|
339
|
+
tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
|
|
340
|
+
silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
|
|
341
|
+
return utils.fma_packed_f32x2(silu_x, y, silu_x)
|
|
158
342
|
|
|
159
343
|
|
|
160
344
|
@dsl_user_op
|
|
161
345
|
def dswiglu_oai(
|
|
162
|
-
x:
|
|
163
|
-
) -> Tuple[
|
|
346
|
+
x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
|
|
347
|
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
|
|
164
348
|
"""
|
|
165
349
|
Swiglu OAI backward pass: computes gradients w.r.t. x and y
|
|
166
350
|
Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
|
|
@@ -169,35 +353,60 @@ def dswiglu_oai(
|
|
|
169
353
|
Derivative of x * sigmoid(alpha * x) w.r.t. x:
|
|
170
354
|
d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
|
|
171
355
|
"""
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
356
|
+
if const_expr(not isinstance(x, tuple)):
|
|
357
|
+
# Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
|
|
358
|
+
alpha_x_half = (0.5 * alpha) * x # FMUL
|
|
359
|
+
# MUFU.TANH, then FFMA
|
|
360
|
+
# sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
|
|
361
|
+
sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
|
|
362
|
+
silu_x = x * sigmoid_alpha_x # FMUL
|
|
363
|
+
silu_x_dout = silu_x * dout # FMUL
|
|
364
|
+
# FFMA, FFMA, FMUL
|
|
365
|
+
d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
|
|
366
|
+
dx = d_silu_x_dout * y + d_silu_x_dout # FFMA, instead of multiply by y + 1
|
|
367
|
+
dy = silu_x_dout
|
|
368
|
+
swiglu_out = silu_x * y + silu_x # FFMA, instead of multiply by y + 1
|
|
369
|
+
# Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
|
|
370
|
+
return dx, dy, swiglu_out
|
|
371
|
+
else:
|
|
372
|
+
# Compute sigmoid(alpha * x)
|
|
373
|
+
alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
|
|
374
|
+
tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
|
|
375
|
+
sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
|
|
376
|
+
silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
|
|
377
|
+
silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
|
|
378
|
+
# d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
|
|
379
|
+
silu_x_minus_product = utils.fma_packed_f32x2(
|
|
380
|
+
silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
|
|
381
|
+
)
|
|
382
|
+
sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
|
|
383
|
+
(alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
|
|
384
|
+
)
|
|
385
|
+
d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
|
|
386
|
+
dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
|
|
387
|
+
dy = silu_x_dout
|
|
388
|
+
swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
|
|
389
|
+
return dx, dy, swiglu_out
|
|
185
390
|
|
|
186
391
|
|
|
187
392
|
@dsl_user_op
|
|
188
|
-
def glu(x:
|
|
393
|
+
def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
189
394
|
"""GLU: Gated Linear Unit
|
|
190
395
|
glu(x, y) = sigmoid(x) * y
|
|
191
396
|
Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
|
|
192
397
|
"""
|
|
193
|
-
|
|
194
|
-
|
|
398
|
+
if const_expr(not isinstance(x, tuple)):
|
|
399
|
+
sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA
|
|
400
|
+
return sigmoid_x * y # FMUL
|
|
401
|
+
else:
|
|
402
|
+
sigmoid_x = sigmoid(x)
|
|
403
|
+
return utils.mul_packed_f32x2(sigmoid_x, y)
|
|
195
404
|
|
|
196
405
|
|
|
197
406
|
@dsl_user_op
|
|
198
407
|
def dglu(
|
|
199
|
-
x:
|
|
200
|
-
) -> Tuple[
|
|
408
|
+
x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
|
409
|
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
|
|
201
410
|
"""
|
|
202
411
|
GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
|
|
203
412
|
Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
|
|
@@ -206,33 +415,47 @@ def dglu(
|
|
|
206
415
|
- dy = dout * sigmoid(x)
|
|
207
416
|
- glu_out = sigmoid(x) * y
|
|
208
417
|
"""
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
418
|
+
if const_expr(not isinstance(x, tuple)):
|
|
419
|
+
# Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
|
|
420
|
+
sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA
|
|
421
|
+
sigmoid_x_dout = sigmoid_x * dout # FMUL
|
|
422
|
+
glu_out = sigmoid_x * y # FMUL
|
|
423
|
+
# dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
|
|
424
|
+
# = y * (1 - sigmoid(x)) * sigmoid_x_dout
|
|
425
|
+
# = (y - y * sigmoid(x)) * sigmoid_x_dout
|
|
426
|
+
# = (y - glu_out) * sigmoid_x_dout
|
|
427
|
+
dx = (y - glu_out) * sigmoid_x_dout # FADD, FMUL
|
|
428
|
+
dy = sigmoid_x_dout
|
|
429
|
+
# Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
|
|
430
|
+
return dx, dy, glu_out
|
|
431
|
+
else:
|
|
432
|
+
sigmoid_x = sigmoid(x)
|
|
433
|
+
sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
|
|
434
|
+
glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
|
|
435
|
+
# dx = (y - glu_out) * sigmoid_x_dout
|
|
436
|
+
y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
|
|
437
|
+
dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
|
|
438
|
+
dy = sigmoid_x_dout
|
|
439
|
+
return dx, dy, glu_out
|
|
221
440
|
|
|
222
441
|
|
|
223
442
|
@dsl_user_op
|
|
224
|
-
def reglu(x:
|
|
443
|
+
def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
225
444
|
"""ReGLU: ReLU Gated Linear Unit
|
|
226
445
|
reglu(x, y) = relu(x) * y = max(x, 0) * y
|
|
227
446
|
"""
|
|
228
|
-
|
|
447
|
+
if const_expr(not isinstance(x, tuple)):
|
|
448
|
+
return cute.arch.fmax(x, Float32(0.0)) * y
|
|
449
|
+
else:
|
|
450
|
+
relu_x = relu(x)
|
|
451
|
+
return utils.mul_packed_f32x2(relu_x, y)
|
|
229
452
|
|
|
230
453
|
|
|
231
|
-
@cute.jit
|
|
232
454
|
@dsl_user_op
|
|
455
|
+
@cute.jit
|
|
233
456
|
def dreglu(
|
|
234
|
-
x:
|
|
235
|
-
) -> Tuple[
|
|
457
|
+
x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
|
458
|
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
|
|
236
459
|
"""
|
|
237
460
|
ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
|
|
238
461
|
Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
|
|
@@ -241,27 +464,40 @@ def dreglu(
|
|
|
241
464
|
- dy = dout * relu(x)
|
|
242
465
|
- reglu_out = relu(x) * y
|
|
243
466
|
"""
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
467
|
+
if const_expr(not isinstance(x, tuple)):
|
|
468
|
+
x_pos = Boolean(x > 0)
|
|
469
|
+
relu_x = cute.arch.fmax(x, Float32(0.0))
|
|
470
|
+
dx = (dout * y) if x_pos else Float32(0.0)
|
|
471
|
+
dy = dout * relu_x
|
|
472
|
+
reglu_out = relu_x * y
|
|
473
|
+
return dx, dy, reglu_out
|
|
474
|
+
else:
|
|
475
|
+
x0_pos = Boolean(x[0] > 0)
|
|
476
|
+
x1_pos = Boolean(x[1] > 0)
|
|
477
|
+
relu_x = relu(x)
|
|
478
|
+
dout_y = utils.mul_packed_f32x2(dout, y)
|
|
479
|
+
dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
|
|
480
|
+
dy = utils.mul_packed_f32x2(dout, relu_x)
|
|
481
|
+
reglu_out = utils.mul_packed_f32x2(relu_x, y)
|
|
482
|
+
return dx, dy, reglu_out
|
|
250
483
|
|
|
251
484
|
|
|
252
485
|
@dsl_user_op
|
|
253
|
-
def geglu(x:
|
|
486
|
+
def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
|
|
254
487
|
"""GeGLU: GELU Gated Linear Unit
|
|
255
488
|
geglu(x, y) = gelu(x) * y
|
|
256
489
|
Uses the tanh approximation of GELU
|
|
257
490
|
"""
|
|
258
|
-
|
|
491
|
+
if const_expr(not isinstance(x, tuple)):
|
|
492
|
+
return gelu_tanh_approx(x) * y
|
|
493
|
+
else:
|
|
494
|
+
return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)
|
|
259
495
|
|
|
260
496
|
|
|
261
497
|
@dsl_user_op
|
|
262
498
|
def dgeglu(
|
|
263
|
-
x:
|
|
264
|
-
) -> Tuple[
|
|
499
|
+
x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
|
|
500
|
+
) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
|
|
265
501
|
"""
|
|
266
502
|
GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
|
|
267
503
|
Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
|
|
@@ -270,10 +506,19 @@ def dgeglu(
|
|
|
270
506
|
- dy = dout * gelu(x)
|
|
271
507
|
- geglu_out = gelu(x) * y
|
|
272
508
|
"""
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
509
|
+
if const_expr(not isinstance(x, tuple)):
|
|
510
|
+
# Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
|
|
511
|
+
dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
|
|
512
|
+
# Compute gradients for geglu
|
|
513
|
+
dx = dgelu_x_dout * y
|
|
514
|
+
dy = gelu_x * dout
|
|
515
|
+
geglu_out = gelu_x * y
|
|
516
|
+
return dx, dy, geglu_out
|
|
517
|
+
else:
|
|
518
|
+
# Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
|
|
519
|
+
dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
|
|
520
|
+
# Compute gradients for geglu
|
|
521
|
+
dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
|
|
522
|
+
dy = utils.mul_packed_f32x2(gelu_x, dout)
|
|
523
|
+
geglu_out = utils.mul_packed_f32x2(gelu_x, y)
|
|
524
|
+
return dx, dy, geglu_out
|