quack-kernels 0.2.1.tar.gz → 0.2.3.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {quack_kernels-0.2.1/quack_kernels.egg-info → quack_kernels-0.2.3}/PKG-INFO +4 -2
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/pyproject.toml +3 -1
- quack_kernels-0.2.3/quack/__init__.py +11 -0
- quack_kernels-0.2.3/quack/activation.py +524 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/autotuner.py +64 -5
- quack_kernels-0.2.3/quack/broadcast_utils.py +29 -0
- quack_kernels-0.2.3/quack/compile_utils.py +19 -0
- quack_kernels-0.2.3/quack/copy_utils.py +487 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/cross_entropy.py +157 -233
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/cute_dsl_utils.py +20 -35
- quack_kernels-0.2.3/quack/gemm.py +194 -0
- quack_kernels-0.2.3/quack/gemm_act.py +510 -0
- quack_kernels-0.2.3/quack/gemm_config.py +95 -0
- quack_kernels-0.2.3/quack/gemm_dact.py +215 -0
- quack_kernels-0.2.3/quack/gemm_default_epi.py +259 -0
- quack_kernels-0.2.3/quack/gemm_interface.py +1038 -0
- quack_kernels-0.2.1/quack/dense_gemm_sm100.py → quack_kernels-0.2.3/quack/gemm_sm100.py +1034 -787
- quack_kernels-0.2.1/quack/dense_gemm_sm90.py → quack_kernels-0.2.3/quack/gemm_sm90.py +552 -727
- quack_kernels-0.2.3/quack/gemm_symmetric.py +330 -0
- quack_kernels-0.2.3/quack/gemm_wrapper_utils.py +317 -0
- quack_kernels-0.2.3/quack/layout_utils.py +287 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/linear.py +24 -16
- quack_kernels-0.2.3/quack/pipeline.py +306 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/reduce.py +88 -49
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/reduction_base.py +25 -36
- quack_kernels-0.2.3/quack/rmsnorm.py +1134 -0
- quack_kernels-0.2.3/quack/sm100_utils.py +62 -0
- quack_kernels-0.2.3/quack/sm90_utils.py +127 -0
- quack_kernels-0.2.3/quack/softmax.py +403 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/sort/bitonic_sort.py +13 -10
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/sort/utils.py +6 -6
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/tile_scheduler.py +55 -61
- quack_kernels-0.2.3/quack/topk.py +551 -0
- quack_kernels-0.2.3/quack/utils.py +223 -0
- quack_kernels-0.2.3/quack/varlen_utils.py +386 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3/quack_kernels.egg-info}/PKG-INFO +4 -2
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack_kernels.egg-info/SOURCES.txt +16 -7
- quack_kernels-0.2.3/quack_kernels.egg-info/requires.txt +8 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_layernorm.py +17 -51
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_linear.py +43 -18
- quack_kernels-0.2.3/tests/test_linear_varlen_k.py +312 -0
- quack_kernels-0.2.3/tests/test_linear_varlen_m.py +395 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_rmsnorm.py +26 -17
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_softmax.py +1 -2
- quack_kernels-0.2.1/tests/test_symmetric_dense_gemm_sm90.py → quack_kernels-0.2.3/tests/test_symmetric_gemm.py +15 -16
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_topk.py +42 -23
- quack_kernels-0.2.1/quack/__init__.py +0 -18
- quack_kernels-0.2.1/quack/activation.py +0 -279
- quack_kernels-0.2.1/quack/gemm_act_sm90.py +0 -368
- quack_kernels-0.2.1/quack/gemm_config.py +0 -69
- quack_kernels-0.2.1/quack/gemm_dact_sm90.py +0 -150
- quack_kernels-0.2.1/quack/gemm_interface.py +0 -569
- quack_kernels-0.2.1/quack/gemm_wrapper_utils.py +0 -158
- quack_kernels-0.2.1/quack/layernorm.py +0 -353
- quack_kernels-0.2.1/quack/pipeline.py +0 -151
- quack_kernels-0.2.1/quack/rmsnorm.py +0 -1250
- quack_kernels-0.2.1/quack/softmax.py +0 -471
- quack_kernels-0.2.1/quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1/quack/topk.py +0 -227
- quack_kernels-0.2.1/quack/utils.py +0 -358
- quack_kernels-0.2.1/quack/varlen_utils.py +0 -22
- quack_kernels-0.2.1/quack_kernels.egg-info/requires.txt +0 -6
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/LICENSE +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/README.md +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/fast_math.py +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/linear_cross_entropy.py +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/mlp.py +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/sort/generate_sorting_networks.py +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/sort/sorting_networks.py +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack/tensormap_manager.py +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack_kernels.egg-info/dependency_links.txt +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/quack_kernels.egg-info/top_level.txt +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/setup.cfg +0 -0
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_cross_entropy.py +1 -1
- {quack_kernels-0.2.1 → quack_kernels-0.2.3}/tests/test_linear_cross_entropy.py +0 -0
PKG-INFO
@@ -1,10 +1,12 @@
 Metadata-Version: 2.4
 Name: quack-kernels
-Version: 0.2.1
+Version: 0.2.3
 Requires-Python: >=3.10
 License-File: LICENSE
-Requires-Dist: nvidia-cutlass-dsl==4.
+Requires-Dist: nvidia-cutlass-dsl==4.3.3
 Requires-Dist: torch
+Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5
+Requires-Dist: torch-c-dlpack-ext
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: ruff; extra == "dev"
pyproject.toml
@@ -7,8 +7,10 @@ name = "quack-kernels"
 dynamic = ["version"]
 requires-python = ">=3.10"
 dependencies = [
-    "nvidia-cutlass-dsl==4.
+    "nvidia-cutlass-dsl==4.3.3",
     "torch",
+    "apache-tvm-ffi>=0.1.5,<0.2",
+    "torch-c-dlpack-ext",
 ]
 
 [project.optional-dependencies]
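
The dependency change above pins nvidia-cutlass-dsl to 4.3.3 and adds apache-tvm-ffi and torch-c-dlpack-ext as new runtime requirements. A quick post-upgrade sanity check (a sketch using only the standard library; it assumes the packages are installed in the current environment):

import importlib.metadata as md

# Each call raises importlib.metadata.PackageNotFoundError if the
# distribution is missing, so this doubles as an install check.
for dist in ("quack-kernels", "nvidia-cutlass-dsl", "apache-tvm-ffi", "torch-c-dlpack-ext"):
    print(dist, md.version(dist))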

quack/activation.py (new file)
@@ -0,0 +1,524 @@
+# Copyright (c) 2025, Tri Dao.
+
+import math
+from typing import Tuple
+
+import cutlass.cute as cute
+from cutlass import Float32, Boolean, const_expr
+from cutlass.cutlass_dsl import T, dsl_user_op
+from cutlass._mlir.dialects import llvm
+
+import quack.utils as utils
+
+
+F32_or_F32x2 = Float32 | Tuple[Float32, Float32]
+
+
+@dsl_user_op
+def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32:
+    return Float32(
+        llvm.inline_asm(
+            T.f32(),
+            [Float32(a).ir_value(loc=loc, ip=ip)],
+            "tanh.approx.f32 $0, $1;",
+            "=f,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+@dsl_user_op
+def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    if const_expr(not isinstance(x, tuple)):
+        # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True)
+        return 0.5 + 0.5 * tanh(0.5 * x)
+    else:
+        x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
+        return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5))
+
+
+@dsl_user_op
+def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
+    # return dout * out * (1.0 - out)
+    return dout * (out - out * out)
+
+
+@dsl_user_op
+def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    if const_expr(not isinstance(x, tuple)):
+        return cute.arch.fmax(x, Float32(0.0))
+    else:
+        return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))
+
+
+@dsl_user_op
+@cute.jit
+def drelu(
+    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+    if const_expr(not isinstance(x, tuple)):
+        x_pos = Boolean(x > 0)
+        return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0))
+    else:
+        x0_pos = Boolean(x[0] > 0)
+        x1_pos = Boolean(x[1] > 0)
+        dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0))
+        return dx, relu(x)
+
+
+@dsl_user_op
+def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    if const_expr(not isinstance(x, tuple)):
+        return cute.arch.fmax(x, Float32(0.0)) * x
+    else:
+        relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)))
+        return utils.mul_packed_f32x2(relu_x, x)
+
+
+@dsl_user_op
+@cute.jit
+def drelu_sq(
+    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+    """
+    ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward
+    Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. relu_sq_out
+    Returns: (dx, relu_sq_out) where:
+    - dx = dout * 2 * x if x > 0, else 0
+    - relu_sq_out = max(x, 0) * x
+    """
+    if const_expr(not isinstance(x, tuple)):
+        relu_x = relu(x)
+        relu_sq_out = relu_x * x
+        # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0
+        dx = 2.0 * (dout * relu_x)
+        return dx, relu_sq_out
+    else:
+        relu_x = relu(x)
+        relu_sq_out = utils.mul_packed_f32x2(relu_x, x)
+        dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x))
+        return dx, relu_sq_out
+
+
+@dsl_user_op
+def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    """
+    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+            = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x)))
+    """
+    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # ~0.797885
+    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # ~0.0356774
+    if const_expr(not isinstance(x, tuple)):
+        return 0.5 * (
+            x
+            # Currently cute.math.tanh(x, fastmath=True) generates very slow code
+            # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True))
+            * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x))))
+        )
+    else:
+        x_sq = utils.mul_packed_f32x2(x, x)
+        x_sq_scaled = utils.fma_packed_f32x2(
+            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
+        )
+        z = utils.mul_packed_f32x2(x, x_sq_scaled)
+        tanh_z = (tanh(z[0]), tanh(z[1]))
+        x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x)
+        return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z)
+
+
+@dsl_user_op
+def dgelu_tanh_approx(
+    x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2]:
+    """
+    GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward
+    Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. gelu_out
+    Returns: (dx, gelu_out)
+
+    Derivative uses the chain rule:
+    d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+    where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2
+    and sech^2(z) = 1 - tanh^2(z)
+    """
+    sqrt_2_over_pi = math.sqrt(2 / math.pi)  # c1 ~0.797885
+    sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi  # c2 ~0.0356774
+    sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff  # c3 ~0.1070322
+
+    if const_expr(not isinstance(x, tuple)):
+        # Compute z = x * (c1 + c2 * x^2)
+        x_sq = x * x
+        # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True)
+        tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq))
+        half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z
+        gelu_out = x * half_tanh_z_plus_one
+
+        # Compute gradient
+        # sech^2(z) = 1 - tanh^2(z)
+        sech2_z = 1 - tanh_z * tanh_z
+        # dz/dx = c1 + 3 * c2 * x^2
+        dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq
+        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+        dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx))
+
+        dx = dout * dgelu
+        return dx, gelu_out
+    else:
+        # Compute z = x * (c1 + c2 * x^2)
+        x_sq = utils.mul_packed_f32x2(x, x)
+        x_sq_scaled = utils.fma_packed_f32x2(
+            x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi)
+        )
+        z = utils.mul_packed_f32x2(x, x_sq_scaled)
+        tanh_z = (tanh(z[0]), tanh(z[1]))
+        half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5))
+        gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one)
+
+        # Compute gradient
+        # sech^2(z) = 1 - tanh^2(z)
+        sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0))
+        # dz/dx = c1 + 3 * c2 * x^2
+        dz_dx = utils.fma_packed_f32x2(
+            x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi)
+        )
+        # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx
+        sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx)
+        x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx)
+        dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one)
+
+        dx = utils.mul_packed_f32x2(dout, dgelu)
+        return dx, gelu_out
+
+
+@dsl_user_op
+@cute.jit
+def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    if const_expr(not isinstance(x, tuple)):
+        use_linear = Boolean(x > 20.0)
+        return (
+            cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True)
+            if not use_linear
+            else x
+        )
+    else:
+        log2_e = math.log2(math.e)
+        x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e))
+        x_exp = (cute.math.exp2(x_log2e[0], fastmath=True), cute.math.exp2(x_log2e[1], fastmath=True))
+        x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0))
+        log_x_exp_p1 = (
+            cute.math.log2(x_exp_p1[0], fastmath=True),
+            cute.math.log2(x_exp_p1[1], fastmath=True),
+        )
+        ln2 = math.log(2.0)
+        softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2))
+        use_linear_0 = Boolean(x[0] > 20.0)
+        use_linear_1 = Boolean(x[1] > 20.0)
+        return (
+            softplus_x[0] if not use_linear_0 else x[0],
+            softplus_x[1] if not use_linear_1 else x[1],
+        )
+
+
+@dsl_user_op
+@cute.jit
+def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32:
+    use_linear = Boolean(out > 20.0)
+    # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout
+    dx = dout - dout * cute.math.exp(-out, fastmath=True)
+    return dx if not use_linear else dout
+
+
+@dsl_user_op
+def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> F32_or_F32x2:
+    """
+    silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x)
+    This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA.
+    """
+    if const_expr(not isinstance(x, tuple)):
+        x_half = 0.5 * x if const_expr(not already_halved) else x
+        # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half
+        return x_half * tanh(x_half) + x_half
+    else:
+        x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x
+        tanh_x_half = (tanh(x_half[0]), tanh(x_half[1]))
+        return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half)
+
+
+@dsl_user_op
+def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    if const_expr(not isinstance(x, tuple)):
+        return silu(x) * y
+    else:
+        return utils.mul_packed_f32x2(silu(x), y)
+
+
+@dsl_user_op
+def dswiglu(
+    x: F32_or_F32x2,
+    y: F32_or_F32x2,
+    dout: F32_or_F32x2,
+    *,
+    already_halved: bool = False,
+    loc=None,
+    ip=None,
+) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+    """
+    SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out
+    Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x)
+
+    d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+
+    This has been optimized to use fewer instructions (i.e. we expand things out
+    to use FFMA instead of FADD and FMUL).
+    """
+    if const_expr(not isinstance(x, tuple)):
+        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x))
+        # FMUL, MUFU.TANH, then FFMA
+        if const_expr(not already_halved):
+            sigmoid_x = sigmoid(x)
+            silu_x = x * sigmoid_x  # FMUL
+        else:
+            tanh_x = tanh(x)  # MUFU.TANH
+            sigmoid_x = 0.5 * tanh_x + 0.5  # FFMA
+            silu_x = x * tanh_x + x  # FFMA
+        silu_x_dout = silu_x * dout  # FMUL
+        # d_silu(x) * dout
+        # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout
+        # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout
+        # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout
+        # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout
+        # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+        d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout  # FFMA, FFMA
+        dx = d_silu_x_dout * y  # FMUL
+        dy = silu_x_dout
+        swiglu_out = silu_x * y  # FMUL
+        # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA
+        return dx, dy, swiglu_out
+    else:
+        # Compute sigmoid(x) and silu(x)
+        if const_expr(not already_halved):
+            sigmoid_x = sigmoid(x)
+            silu_x = utils.mul_packed_f32x2(x, sigmoid_x)
+        else:
+            tanh_x = (tanh(x[0]), tanh(x[1]))
+            sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5))
+            silu_x = utils.fma_packed_f32x2(x, tanh_x, x)
+        silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+        # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout
+        sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2(
+            sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x
+        )
+        d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout)
+        dx = utils.mul_packed_f32x2(d_silu_x_dout, y)
+        dy = silu_x_dout
+        swiglu_out = utils.mul_packed_f32x2(silu_x, y)
+        return dx, dy, swiglu_out
+
+
+@dsl_user_op
+def swiglu_oai(
+    x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
+) -> F32_or_F32x2:
+    """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y.
+    https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249
+    x * sigmoid(alpha * x) * (y + 1)
+    Compiles down to FMUL, FMUL, TANH, FFMA, FFMA
+    """
+    # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+    if const_expr(not isinstance(x, tuple)):
+        x_half = 0.5 * x
+        # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half
+        silu_x = x_half * tanh(alpha * x_half) + x_half
+        return silu_x * y + silu_x
+    else:
+        x_half = utils.mul_packed_f32x2((0.5, 0.5), x)
+        alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half)
+        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
+        silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half)
+        return utils.fma_packed_f32x2(silu_x, y, silu_x)
+
+
+@dsl_user_op
+def dswiglu_oai(
+    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+    """
+    Swiglu OAI backward pass: computes gradients w.r.t. x and y
+    Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out
+    Returns: (dx, dy, swiglu_oai_out)
+
+    Derivative of x * sigmoid(alpha * x) w.r.t. x:
+    d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))
+    """
+    if const_expr(not isinstance(x, tuple)):
+        # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2))
+        alpha_x_half = (0.5 * alpha) * x  # FMUL
+        # MUFU.TANH, then FFMA
+        # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True)
+        sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half)
+        silu_x = x * sigmoid_alpha_x  # FMUL
+        silu_x_dout = silu_x * dout  # FMUL
+        # FFMA, FFMA, FMUL
+        d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+        dx = d_silu_x_dout * y + d_silu_x_dout  # FFMA, instead of multiply by y + 1
+        dy = silu_x_dout
+        swiglu_out = silu_x * y + silu_x  # FFMA, instead of multiply by y + 1
+        # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA
+        return dx, dy, swiglu_out
+    else:
+        # Compute sigmoid(alpha * x)
+        alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x)
+        tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1]))
+        sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5))
+        silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x)
+        silu_x_dout = utils.mul_packed_f32x2(silu_x, dout)
+        # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout
+        silu_x_minus_product = utils.fma_packed_f32x2(
+            silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x
+        )
+        sigmoid_plus_alpha_diff = utils.fma_packed_f32x2(
+            (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x
+        )
+        d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout)
+        dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout)
+        dy = silu_x_dout
+        swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x)
+        return dx, dy, swiglu_out
+
+
+@dsl_user_op
+def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    """GLU: Gated Linear Unit
+    glu(x, y) = sigmoid(x) * y
+    Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+    """
+    if const_expr(not isinstance(x, tuple)):
+        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+        return sigmoid_x * y  # FMUL
+    else:
+        sigmoid_x = sigmoid(x)
+        return utils.mul_packed_f32x2(sigmoid_x, y)
+
+
+@dsl_user_op
+def dglu(
+    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+    """
+    GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out
+    Returns: (dx, dy, glu_out) where:
+    - dx = dout * y * sigmoid(x) * (1 - sigmoid(x))
+    - dy = dout * sigmoid(x)
+    - glu_out = sigmoid(x) * y
+    """
+    if const_expr(not isinstance(x, tuple)):
+        # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2))
+        sigmoid_x = sigmoid(x)  # FMUL, MUFU.TANH, then FFMA
+        sigmoid_x_dout = sigmoid_x * dout  # FMUL
+        glu_out = sigmoid_x * y  # FMUL
+        # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout
+        #    = y * (1 - sigmoid(x)) * sigmoid_x_dout
+        #    = (y - y * sigmoid(x)) * sigmoid_x_dout
+        #    = (y - glu_out) * sigmoid_x_dout
+        dx = (y - glu_out) * sigmoid_x_dout  # FADD, FMUL
+        dy = sigmoid_x_dout
+        # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA
+        return dx, dy, glu_out
+    else:
+        sigmoid_x = sigmoid(x)
+        sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout)
+        glu_out = utils.mul_packed_f32x2(sigmoid_x, y)
+        # dx = (y - glu_out) * sigmoid_x_dout
+        y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out)
+        dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout)
+        dy = sigmoid_x_dout
+        return dx, dy, glu_out
+
+
+@dsl_user_op
+def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    """ReGLU: ReLU Gated Linear Unit
+    reglu(x, y) = relu(x) * y = max(x, 0) * y
+    """
+    if const_expr(not isinstance(x, tuple)):
+        return cute.arch.fmax(x, Float32(0.0)) * y
+    else:
+        relu_x = relu(x)
+        return utils.mul_packed_f32x2(relu_x, y)
+
+
+@dsl_user_op
+@cute.jit
+def dreglu(
+    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+    """
+    ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out
+    Returns: (dx, dy, reglu_out) where:
+    - dx = dout * y if x > 0, else 0
+    - dy = dout * relu(x)
+    - reglu_out = relu(x) * y
+    """
+    if const_expr(not isinstance(x, tuple)):
+        x_pos = Boolean(x > 0)
+        relu_x = cute.arch.fmax(x, Float32(0.0))
+        dx = (dout * y) if x_pos else Float32(0.0)
+        dy = dout * relu_x
+        reglu_out = relu_x * y
+        return dx, dy, reglu_out
+    else:
+        x0_pos = Boolean(x[0] > 0)
+        x1_pos = Boolean(x[1] > 0)
+        relu_x = relu(x)
+        dout_y = utils.mul_packed_f32x2(dout, y)
+        dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0)))
+        dy = utils.mul_packed_f32x2(dout, relu_x)
+        reglu_out = utils.mul_packed_f32x2(relu_x, y)
+        return dx, dy, reglu_out
+
+
+@dsl_user_op
+def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2:
+    """GeGLU: GELU Gated Linear Unit
+    geglu(x, y) = gelu(x) * y
+    Uses the tanh approximation of GELU
+    """
+    if const_expr(not isinstance(x, tuple)):
+        return gelu_tanh_approx(x) * y
+    else:
+        return utils.mul_packed_f32x2(gelu_tanh_approx(x), y)
+
+
+@dsl_user_op
+def dgeglu(
+    x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None
+) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]:
+    """
+    GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection)
+    Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. geglu_out
+    Returns: (dx, dy, geglu_out) where:
+    - dx = dout * y * d_gelu(x)
+    - dy = dout * gelu(x)
+    - geglu_out = gelu(x) * y
+    """
+    if const_expr(not isinstance(x, tuple)):
+        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+        # Compute gradients for geglu
+        dx = dgelu_x_dout * y
+        dy = gelu_x * dout
+        geglu_out = gelu_x * y
+        return dx, dy, geglu_out
+    else:
+        # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x)
+        dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout)
+        # Compute gradients for geglu
+        dx = utils.mul_packed_f32x2(dgelu_x_dout, y)
+        dy = utils.mul_packed_f32x2(gelu_x, dout)
+        geglu_out = utils.mul_packed_f32x2(gelu_x, y)
+        return dx, dy, geglu_out
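
The new activation.py computes every sigmoid-family function through the hardware tanh approximation (MUFU.TANH), relying on identities such as sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x). A minimal numerical check of those identities in plain PyTorch (a hedged sketch for illustration, not part of the package):

import math
import torch

x = torch.randn(4096, dtype=torch.float32)

# sigmoid via tanh, as in sigmoid(): sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x)
assert torch.allclose(torch.sigmoid(x), 0.5 + 0.5 * torch.tanh(0.5 * x), atol=1e-6)

# silu as FMUL + TANH + FFMA, as in silu(): silu(x) = 0.5*x*tanh(0.5*x) + 0.5*x
x_half = 0.5 * x
assert torch.allclose(
    torch.nn.functional.silu(x), x_half * torch.tanh(x_half) + x_half, atol=1e-6
)

# gelu tanh approximation with the constants folded as in gelu_tanh_approx()
c1 = math.sqrt(2 / math.pi)  # ~0.797885
c2 = 0.044715 * c1  # ~0.0356774
gelu_tanh = 0.5 * x * (1 + torch.tanh(x * (c1 + c2 * x * x)))
assert torch.allclose(
    torch.nn.functional.gelu(x, approximate="tanh"), gelu_tanh, atol=1e-5
)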

quack/autotuner.py
@@ -11,7 +11,7 @@ import hashlib
 import json
 from pathlib import Path
 from functools import cached_property, partial
-from typing import Dict, Tuple
+from typing import Dict, Tuple, List, Optional, Any
 
 import torch
 from torch import Tensor
@@ -53,7 +53,22 @@ def _base32(key):
 
 
 class Autotuner:
-    def __init__(self, fn, key, configs, restore_value=None, do_bench=None, cache_results=False):
+    def __init__(
+        self,
+        fn,
+        key,
+        configs,
+        restore_value=None,
+        prune_configs_by: Optional[Dict] = None,
+        do_bench=None,
+        cache_results=False,
+    ):
+        """
+        :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+            'perf_model': performance model used to predict running time with different configs, returns running time
+            'top_k': number of configs to bench
+            'early_config_prune' (optional): a function used for early pruning (e.g. by num_stages). It takes configs: List[Config] as input and returns the pruned configs.
+        """
         if not configs:
             self.configs = [AutotuneConfig()]
         else:
@@ -90,6 +105,16 @@ class Autotuner:
         else:
             self.post_hook = None
 
+        self.perf_model = None
+        self.configs_top_k = 1.0
+        self.early_config_prune = None
+        if prune_configs_by:
+            self.perf_model = prune_configs_by.get("perf_model", self.perf_model)
+            self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k)
+            self.early_config_prune = prune_configs_by.get(
+                "early_config_prune", self.early_config_prune
+            )
+
         self.fn = fn
         self._do_bench = do_bench
 
@@ -198,13 +223,14 @@ class Autotuner:
         key = tuple(key)
         if key not in self.cache:
            used_cached_result = False
+            pruned_configs = self.prune_configs(kwargs)
 
            @torch.compiler.disable  # Don't want any tracing here
            def benchmark():
                bench_start = time.time()
                timings = {
                    config: self._bench(*args, config=config, **kwargs)
-                    for config in self.configs
+                    for config in pruned_configs
                }
                bench_end = time.time()
                if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1":
@@ -215,7 +241,7 @@
                self.configs_timings = timings
 
            if self.cache_results:
-                self.check_disk_cache(key, self.configs, benchmark)
+                self.check_disk_cache(key, pruned_configs, benchmark)
            else:
                benchmark()
 
@@ -239,6 +265,32 @@
        self.nargs = None
        return ret
 
+    def prune_configs(self, kwargs: Dict) -> List[Any]:
+        pruned_configs = self.configs
+        if self.early_config_prune:
+            pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs)
+        if self.perf_model:
+            top_k = self.configs_top_k
+            if isinstance(top_k, float) and top_k <= 1.0:
+                top_k = int(len(self.configs) * top_k)
+            elif not isinstance(top_k, int):
+                # Slice index must be an integer
+                raise TypeError(
+                    "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int"
+                )
+
+            if len(pruned_configs) > top_k:
+                est_timing = {
+                    config: self.perf_model(
+                        **self.nargs,
+                        **kwargs,
+                        **config.all_kwargs(),
+                    )
+                    for config in pruned_configs
+                }
+                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
+        return pruned_configs
+
 
 class AutotuneConfig:
     """
@@ -272,7 +324,9 @@ class AutotuneConfig:
        return self_tuple == other_tuple
 
 
-def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results=True):
+def autotune(
+    configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True
+):
     f"""
     Decorator for auto-tuning a function.
 
@@ -286,6 +340,10 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
     :type configs: list[AutotuneConfig]
     :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
     :type key: list[str]
+    :param prune_configs_by: a dict of functions that are used to prune configs, fields:
+        'perf_model': performance model used to predict running time with different configs, returns running time
+        'top_k': number of configs to bench
+        'early_config_prune' (optional): a function used for early pruning (e.g. by num_stages). It takes configs: List[Config] as input and returns the pruned configs.
     :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
     :type restore_value: list[str]
     :param do_bench: a benchmark function to measure the time of each run.
@@ -303,6 +361,7 @@ def autotune(configs, key=None, restore_value=None, do_bench=None, cache_results
             key,
             configs,
             restore_value=restore_value,
+            prune_configs_by=prune_configs_by,
             do_bench=do_bench,
             cache_results=cache_results,
         )
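
A usage sketch of the new prune_configs_by hook, matching the autotune signature in the diff above; toy_perf_model and my_kernel are illustrative stand-ins, not part of the package:

from quack.autotuner import autotune, AutotuneConfig

def toy_perf_model(**all_kwargs):
    # Receives the autotuned function's arguments plus each config's
    # all_kwargs(); returns an estimated runtime. Only the top_k configs
    # with the lowest estimates are actually benchmarked.
    return 1.0

@autotune(
    configs=[AutotuneConfig()],  # a real use would list several configs
    key=["M"],
    prune_configs_by={"perf_model": toy_perf_model, "top_k": 2},
)
def my_kernel(M):
    pass  # illustrative kernel body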