quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +8 -1
- quack/activation.py +288 -0
- quack/autotuner.py +310 -0
- quack/cross_entropy.py +325 -175
- quack/cute_dsl_utils.py +119 -0
- quack/dense_gemm_sm100.py +2562 -0
- quack/dense_gemm_sm90.py +1657 -842
- quack/fast_math.py +80 -0
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +69 -0
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +569 -0
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +5 -3
- quack/linear.py +240 -0
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +74 -0
- quack/pipeline.py +151 -0
- quack/reduce.py +241 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +583 -231
- quack/softmax.py +27 -15
- quack/sort/bitonic_sort.py +126 -0
- quack/sort/generate_sorting_networks.py +326 -0
- quack/sort/sorting_networks.py +120 -0
- quack/sort/utils.py +31 -0
- quack/symmetric_dense_gemm_sm90.py +2091 -0
- quack/tensormap_manager.py +115 -0
- quack/tile_scheduler.py +937 -0
- quack/topk.py +227 -0
- quack/utils.py +203 -230
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/METADATA +2 -2
- quack_kernels-0.2.0.dist-info/RECORD +37 -0
- quack_kernels-0.1.10.dist-info/RECORD +0 -13
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.10.dist-info → quack_kernels-0.2.0.dist-info}/top_level.txt +0 -0
quack/gemm_interface.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao
|
|
2
|
+
from typing import Optional, Tuple, Literal
|
|
3
|
+
from functools import partial
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from quack.gemm_config import GemmConfig, get_all_configs
|
|
10
|
+
|
|
11
|
+
from quack.autotuner import autotune, AutotuneConfig
|
|
12
|
+
from quack.dense_gemm_sm90 import gemm_sm90
|
|
13
|
+
from quack.gemm_act_sm90 import gemm_act_sm90
|
|
14
|
+
from quack.gemm_dact_sm90 import gemm_dact_sm90
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Activation name -> PyTorch reference function; the ``None`` key is the identity.
def _identity(x):
    return x


def _relu_sq(x):
    return F.relu(x).square()


act_to_pytorch_fn_map = {
    None: _identity,
    "relu": F.relu,
    "relu_sq": _relu_sq,
    "gelu_tanh_approx": partial(F.gelu, approximate="tanh"),
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Gated activation name -> forward function. Every entry takes (gate, up) and
# returns the post-activation tensor.
def _swiglu(gate, up):
    return F.silu(gate) * up


def _swiglu_oai(gate, up):
    return gate * torch.sigmoid(1.702 * gate) * (up + 1)


def _reglu(gate, up):
    return F.relu(gate) * up


def _geglu(gate, up):
    return F.gelu(gate, approximate="tanh") * up


def _glu(gate, up):
    return torch.sigmoid(gate) * up


gated_to_pytorch_fn_map = {
    "swiglu": _swiglu,
    "swiglu_oai": _swiglu_oai,
    "reglu": _reglu,
    "geglu": _geglu,
    "glu": _glu,
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@autotune(
    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
    key=["dynamic_scheduler"],
)
def gemm_tuned(
    A: Tensor,  # (M, K)
    B: Tensor,  # (K, N)
    out: Tensor,  # (M, N) - required output tensor
    C: Optional[Tensor] = None,  # (M, N)
    alpha: float | Tensor = 1.0,  # (1,)
    beta: float | Tensor = 1.0,  # (1,)
    dynamic_scheduler: bool = False,
    config: Optional[GemmConfig] = None,
) -> None:
    """Launch ``gemm_sm90`` for one GEMM configuration, writing into ``out``.

    ``alpha``/``beta`` and the optional addend ``C`` are forwarded to the kernel
    (``gemm_ref`` / ``gemm_add_ref`` in this file give the reference math).
    When ``config`` is None a fixed default tile/cluster configuration is used.
    When ``config.swap_ab`` is set, the operands are swapped and the (1, M, N)
    output/addend views are transposed before the launch.
    """
    cfg = config if config is not None else GemmConfig(
        tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True
    )
    lhs = A.unsqueeze(0)  # (1, M, K)
    rhs = B.mT.unsqueeze(0)  # (1, N, K)
    addend = None if C is None else C.unsqueeze(0)  # (1, M, N)
    expected = (lhs.shape[1], rhs.shape[1])
    assert out.shape == expected, f"out shape mismatch: {out.shape} vs {expected}"
    dst = out.unsqueeze(0)
    # The dynamic scheduler needs a zero-initialized int32 tile counter.
    semaphore = torch.zeros(1, dtype=torch.int32, device=lhs.device) if dynamic_scheduler else None
    if cfg.swap_ab:
        lhs, rhs = rhs, lhs
        dst = dst.mT
        if addend is not None:
            addend = addend.mT
    gemm_sm90(
        lhs,
        rhs,
        dst,
        addend,
        semaphore,
        cfg.tile_m,
        cfg.tile_n,
        cfg.cluster_m,
        cfg.cluster_n,
        cfg.pingpong,
        alpha=alpha,
        beta=beta,
    )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@autotune(
    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
    key=["activation"],
)
def gemm_act_tuned(
    A: Tensor,  # (M, K)
    B: Tensor,  # (K, N)
    preact_out: Optional[Tensor],  # (M, N) - None if not storing preact
    postact_out: Tensor,  # (M, N)
    C: Optional[Tensor] = None,  # (M, N)
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    config: Optional[GemmConfig] = None,
) -> None:
    """Launch ``gemm_act_sm90`` for one configuration.

    ``preact_out`` (optional) and ``postact_out`` receive the kernel outputs;
    ``gemm_act_ref`` in this file gives the reference math. With ``config=None``
    a fixed default tile/cluster configuration is used. ``config.swap_ab`` swaps
    the operands and transposes the (1, M, N) tensor views.
    """
    cfg = config if config is not None else GemmConfig(
        tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True
    )
    lhs = A.unsqueeze(0)  # (1, M, K)
    rhs = B.mT.unsqueeze(0)  # (1, N, K)
    addend = None if C is None else C.unsqueeze(0)  # (1, M, N)
    mn = (lhs.shape[1], rhs.shape[1])
    preact = None
    if preact_out is not None:
        assert preact_out.shape == mn
        preact = preact_out.unsqueeze(0)
    assert postact_out.shape == mn
    postact = postact_out.unsqueeze(0)
    if cfg.swap_ab:
        lhs, rhs = rhs, lhs
        preact = None if preact is None else preact.mT
        addend = None if addend is None else addend.mT
        postact = postact.mT
    gemm_act_sm90(
        lhs,
        rhs,
        preact,
        addend,
        postact,
        activation,
        cfg.tile_m,
        cfg.tile_n,
        cfg.cluster_m,
        cfg.cluster_n,
        cfg.pingpong,
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@autotune(
    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
    key=["activation", "dynamic_scheduler"],
)
def gemm_dact_tuned(
    A: Tensor,  # (M, K)
    B: Tensor,  # (K, N)
    PreAct: Tensor,  # (M, N)
    dx_out: Tensor,  # (M, N)
    postact_out: Tensor,  # (M, N)
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    dynamic_scheduler: bool = True,
    config: Optional[GemmConfig] = None,
) -> None:
    """Launch ``gemm_dact_sm90`` for one configuration.

    Writes into the pre-allocated ``dx_out`` and ``postact_out``;
    ``gemm_dact_ref`` in this file gives the reference math. With ``config=None``
    a fixed default tile/cluster configuration is used. ``config.swap_ab`` swaps
    the operands and transposes the (1, M, N) tensor views.
    """
    cfg = config if config is not None else GemmConfig(
        tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True
    )
    lhs = A.unsqueeze(0)  # (1, M, K)
    rhs = B.mT.unsqueeze(0)  # (1, N, K)
    preact = PreAct.unsqueeze(0)  # (1, M, N)
    mn = (lhs.shape[1], rhs.shape[1])
    assert dx_out.shape == mn
    dx = dx_out.unsqueeze(0)
    assert postact_out.shape == mn
    postact = postact_out.unsqueeze(0)
    # The dynamic scheduler needs a zero-initialized int32 tile counter.
    semaphore = torch.zeros(1, dtype=torch.int32, device=lhs.device) if dynamic_scheduler else None
    if cfg.swap_ab:
        lhs, rhs = rhs, lhs
        dx, preact, postact = dx.mT, preact.mT, postact.mT
    gemm_dact_sm90(
        lhs,
        rhs,
        dx,
        preact,
        postact,
        semaphore,
        activation,
        cfg.tile_m,
        cfg.tile_n,
        cfg.cluster_m,
        cfg.cluster_n,
        cfg.pingpong,
    )
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def gemm(
    A: Tensor,
    B: Tensor,
    out: Optional[Tensor] = None,
    alpha: float | Tensor = 1.0,
    out_dtype: Optional[torch.dtype] = None,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> Tensor:
    """GEMM ``out = alpha * (A @ B)`` with optional output tensor and tuning control.

    Args:
        A: (M, K) input.
        B: (K, N) input.
        out: Optional pre-allocated (M, N) output; allocated when None.
        alpha: Python number or tensor scaling factor.
        out_dtype: dtype of the allocated output when ``out`` is None
            (defaults to ``A.dtype``); ignored when ``out`` is given.
        dynamic_scheduler: Whether to use the dynamic scheduler (forwarded to ``gemm_out``).
        tuned: Use the autotuned configuration when True, otherwise the default one.

    Returns:
        The output tensor (``out`` itself when it was provided).
    """
    if out is None:
        out_dtype = A.dtype if out_dtype is None else out_dtype
        out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
    # The custom op has separate float / Optional[Tensor] slots for alpha, because
    # torch.library requires a fixed type per argument. Route on "is it a Tensor"
    # so plain ints and numpy scalars (e.g. alpha=2) take the float slot.
    if isinstance(alpha, Tensor):
        alpha_tensor, alpha = alpha, 1.0
    else:
        alpha_tensor, alpha = None, float(alpha)
    gemm_out(
        A,
        B,
        out,
        alpha=alpha,
        alpha_tensor=alpha_tensor,
        dynamic_scheduler=dynamic_scheduler,
        tuned=tuned,
    )
    return out
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@torch.library.custom_op(
    "quack::gemm_out",
    mutates_args=("out",),
    device_types="cuda",
    # torch.library needs a fixed type for every argument, so the scalar and the
    # tensor form of alpha arrive through two separate parameters.
    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_out(
    A: Tensor,
    B: Tensor,
    out: Tensor,
    alpha: float = 1.0,
    alpha_tensor: Optional[Tensor] = None,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """Run the GEMM into the pre-allocated ``out`` (mutated in place).

    ``alpha_tensor``, when given, takes precedence over the float ``alpha``.
    ``tuned=False`` bypasses autotuning and uses the default configuration.
    """
    scale = alpha if alpha_tensor is None else alpha_tensor
    if tuned:
        runner = gemm_tuned
    else:
        runner = partial(gemm_tuned.fn, config=None)
    runner(A, B, out, C=None, alpha=scale, dynamic_scheduler=dynamic_scheduler)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def gemm_ref(
|
|
212
|
+
A: Tensor,
|
|
213
|
+
B: Tensor,
|
|
214
|
+
out: Optional[Tensor] = None,
|
|
215
|
+
alpha: float | Tensor = 1.0,
|
|
216
|
+
out_dtype: Optional[torch.dtype] = None,
|
|
217
|
+
) -> Tensor:
|
|
218
|
+
"""Reference implementation for GEMM with pre-allocated output."""
|
|
219
|
+
# The out_dtype argument requires torch >= 2.8
|
|
220
|
+
out = torch.mm(A, B, out_dtype=out_dtype, out=out)
|
|
221
|
+
if not isinstance(alpha, float) or alpha != 1.0:
|
|
222
|
+
out = out * alpha
|
|
223
|
+
return out
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def gemm_add(
    A: Tensor,
    B: Tensor,
    C: Tensor,
    out: Optional[Tensor] = None,
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
    out_dtype: Optional[torch.dtype] = None,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> Tensor:
    """GEMM with addition, ``out = alpha * (A @ B) + beta * C``, with optional output tensor.

    Args:
        A: (M, K) input.
        B: (K, N) input.
        C: (M, N) addend.
        out: Optional pre-allocated (M, N) output; allocated when None.
        alpha: Python number or tensor multiplier for ``A @ B``.
        beta: Python number or tensor multiplier for ``C``.
        out_dtype: dtype of the allocated output when ``out`` is None
            (defaults to ``A.dtype``); ignored when ``out`` is given.
        dynamic_scheduler: Whether to use the dynamic scheduler (forwarded to ``gemm_add_out``).
        tuned: Use the autotuned configuration when True, otherwise the default one.

    Returns:
        The output tensor (``out`` itself when it was provided).
    """
    if out is None:
        out_dtype = A.dtype if out_dtype is None else out_dtype
        out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
    # The custom op has separate float / Optional[Tensor] slots for each coefficient.
    # Route on "is it a Tensor" so plain ints and numpy scalars take the float slot.
    if isinstance(alpha, Tensor):
        alpha_tensor, alpha = alpha, 1.0
    else:
        alpha_tensor, alpha = None, float(alpha)
    if isinstance(beta, Tensor):
        beta_tensor, beta = beta, 1.0
    else:
        beta_tensor, beta = None, float(beta)
    gemm_add_out(
        A,
        B,
        C,
        out,
        alpha,
        beta,
        alpha_tensor,
        beta_tensor,
        dynamic_scheduler=dynamic_scheduler,
        tuned=tuned,
    )
    return out
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@torch.library.custom_op(
    "quack::gemm_add_out",
    mutates_args=("out",),
    device_types="cuda",
    # torch.library needs a fixed type for every argument, so scalar and tensor
    # coefficients arrive through separate parameters.
    # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_add_out(
    A: Tensor,
    B: Tensor,
    C: Tensor,
    out: Tensor,
    alpha: float = 1.0,
    beta: float = 1.0,
    alpha_tensor: Optional[Tensor] = None,
    beta_tensor: Optional[Tensor] = None,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """Run the GEMM-with-addend into the pre-allocated ``out`` (mutated in place).

    ``alpha_tensor``/``beta_tensor``, when given, take precedence over the float
    ``alpha``/``beta``. ``tuned=False`` uses the default configuration.
    """
    runner = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
    scale_ab = alpha if alpha_tensor is None else alpha_tensor
    scale_c = beta if beta_tensor is None else beta_tensor
    runner(A, B, out, C, alpha=scale_ab, beta=scale_c, dynamic_scheduler=dynamic_scheduler)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def gemm_add_ref(
|
|
288
|
+
A: Tensor,
|
|
289
|
+
B: Tensor,
|
|
290
|
+
C: Tensor,
|
|
291
|
+
out: Optional[Tensor] = None,
|
|
292
|
+
alpha: float | Tensor = 1.0,
|
|
293
|
+
beta: float | Tensor = 1.0,
|
|
294
|
+
out_dtype: Optional[torch.dtype] = None,
|
|
295
|
+
) -> Tensor:
|
|
296
|
+
"""Reference implementation for GEMM with addition and pre-allocated output."""
|
|
297
|
+
if isinstance(alpha, float) and isinstance(beta, float):
|
|
298
|
+
return torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
|
|
299
|
+
else:
|
|
300
|
+
out_dtype = (
|
|
301
|
+
out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
|
|
302
|
+
)
|
|
303
|
+
result = (alpha * (A @ B) + beta * C).to(out_dtype)
|
|
304
|
+
if out is not None:
|
|
305
|
+
out.copy_(result)
|
|
306
|
+
return result
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def gemm_add_inplace(
    A: Tensor,
    B: Tensor,
    out: Tensor,
    alpha: float | Tensor = 1.0,
    beta: float | Tensor = 1.0,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """In-place GEMM with addition: out = alpha * A @ B + beta * out.

    Args:
        A: (M, K) input tensor
        B: (K, N) input tensor
        out: (M, N) tensor to accumulate into (modified in-place)
        alpha: Scalar multiplier for A @ B (Python number or tensor)
        beta: Scalar multiplier for out (Python number or tensor)
        dynamic_scheduler: Whether to use dynamic scheduler
        tuned: Whether to use autotuned configuration
    """
    # The custom op has separate float / Optional[Tensor] slots for each coefficient.
    # Route on "is it a Tensor" so plain ints and numpy scalars take the float slot.
    if isinstance(alpha, Tensor):
        alpha_tensor, alpha = alpha, 1.0
    else:
        alpha_tensor, alpha = None, float(alpha)
    if isinstance(beta, Tensor):
        beta_tensor, beta = beta, 1.0
    else:
        beta_tensor, beta = None, float(beta)
    gemm_add_inplace_op(A, B, out, alpha, beta, alpha_tensor, beta_tensor, dynamic_scheduler, tuned)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
@torch.library.custom_op(
    "quack::gemm_add_inplace",
    mutates_args=("out",),
    device_types="cuda",
    # torch.library needs a fixed type for every argument, so scalar and tensor
    # coefficients arrive through separate parameters.
    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_add_inplace_op(
    A: Tensor,
    B: Tensor,
    out: Tensor,
    alpha: float = 1.0,
    beta: float = 1.0,
    alpha_tensor: Optional[Tensor] = None,
    beta_tensor: Optional[Tensor] = None,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """Custom op behind ``gemm_add_inplace``: ``out = alpha * A @ B + beta * out``.

    Tensor coefficients, when given, take precedence over the float ones.
    """
    scale_ab = alpha if alpha_tensor is None else alpha_tensor
    scale_out = beta if beta_tensor is None else beta_tensor
    runner = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
    # ``out`` is passed both as the addend (C) and as the destination.
    runner(A, B, out, out, alpha=scale_ab, beta=scale_out, dynamic_scheduler=dynamic_scheduler)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def gemm_act(
    A: Tensor,
    B: Tensor,
    C: Optional[Tensor] = None,
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    preact_out: Optional[Tensor] = None,
    postact_out: Optional[Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    store_preact: bool = True,
    tuned: bool = True,
) -> Tuple[Optional[Tensor], Tensor]:
    """GEMM followed by an activation, allocating any missing output tensors.

    Args:
        A: (M, K) input.
        B: (K, N) input.
        C: Optional (M, N) addend.
        activation: Activation name (see ``act_to_pytorch_fn_map``).
        preact_out: Optional pre-allocated pre-activation output.
        postact_out: Optional pre-allocated post-activation output.
        out_dtype: dtype for an allocated ``preact_out`` (default ``A.dtype``).
        postact_dtype: dtype for an allocated ``postact_out`` (default ``A.dtype``).
        store_preact: When False and ``preact_out`` is None, no pre-activation is stored.
        tuned: Use the autotuned configuration when True.

    Returns:
        ``(preact_out, postact_out)``; ``preact_out`` may be None.
    """
    shape = (A.shape[0], B.shape[1])
    preact_dtype = out_dtype if out_dtype is not None else A.dtype
    act_dtype = postact_dtype if postact_dtype is not None else A.dtype
    if store_preact and preact_out is None:
        preact_out = torch.empty(shape, dtype=preact_dtype, device=A.device)
    if postact_out is None:
        postact_out = torch.empty(shape, dtype=act_dtype, device=A.device)
    gemm_act_out(A, B, preact_out, postact_out, C, activation, tuned)
    return preact_out, postact_out
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
@torch.library.custom_op(
    "quack::gemm_act_out",
    mutates_args=("preact_out", "postact_out"),
    device_types="cuda",
    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, bool tuned=True) -> ()",
)
def gemm_act_out(
    A: Tensor,
    B: Tensor,
    preact_out: Optional[Tensor],
    postact_out: Tensor,
    C: Optional[Tensor] = None,
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    tuned: bool = True,
) -> None:
    """Custom op: GEMM + activation into pre-allocated ``preact_out``/``postact_out``.

    ``tuned=False`` bypasses autotuning and uses the default configuration.
    """
    if tuned:
        runner = gemm_act_tuned
    else:
        runner = partial(gemm_act_tuned.fn, config=None)
    runner(A, B, preact_out, postact_out, C, activation)
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
def gemm_act_ref(
    A: Tensor,
    B: Tensor,
    C: Optional[Tensor] = None,
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    store_preact: bool = True,
) -> Tuple[Optional[Tensor], Tensor]:
    """Reference implementation for GEMM + activation.

    Computes ``preact = A @ B`` (plus ``C`` when given) and
    ``postact = act(preact)``. Returns ``(preact, postact)`` cast to
    ``out_dtype`` / ``postact_dtype`` (both default to ``A.dtype``); the first
    element is None when ``store_preact`` is False.
    """
    preact_dtype = A.dtype if out_dtype is None else out_dtype
    act_dtype = A.dtype if postact_dtype is None else postact_dtype
    acc = torch.mm(A, B)
    if C is not None:
        acc = C + acc
    postact = act_to_pytorch_fn_map[activation](acc).to(act_dtype)
    preact = acc.to(preact_dtype) if store_preact else None
    return preact, postact
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def gemm_dact(
    A: Tensor,
    B: Tensor,
    PreAct: Tensor,
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    dx_out: Optional[Tensor] = None,
    postact_out: Optional[Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    dynamic_scheduler: bool = True,
    tuned: bool = True,
) -> Tuple[Tensor, Tensor]:
    """GEMM with activation gradient, allocating any missing output tensors.

    Args:
        A: (M, K) input.
        B: (K, N) input.
        PreAct: (M, N) pre-activation tensor.
        activation: Activation name (see ``act_to_pytorch_fn_map``).
        dx_out: Optional pre-allocated gradient output (default dtype ``A.dtype``).
        postact_out: Optional pre-allocated post-activation output
            (default dtype ``PreAct.dtype``).
        out_dtype: dtype for an allocated ``dx_out``.
        postact_dtype: dtype for an allocated ``postact_out``.
        dynamic_scheduler: Whether to use the dynamic scheduler.
        tuned: Use the autotuned configuration when True.

    Returns:
        ``(dx_out, postact_out)``.
    """
    shape = (A.shape[0], B.shape[1])
    dx_dtype = out_dtype if out_dtype is not None else A.dtype
    act_dtype = postact_dtype if postact_dtype is not None else PreAct.dtype
    if dx_out is None:
        dx_out = torch.empty(shape, dtype=dx_dtype, device=A.device)
    if postact_out is None:
        postact_out = torch.empty(shape, dtype=act_dtype, device=A.device)
    gemm_dact_out(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler, tuned)
    return dx_out, postact_out
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
@torch.library.custom_op(
    "quack::gemm_dact_out",
    mutates_args=("dx_out", "postact_out"),
    device_types="cuda",
    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
)
def gemm_dact_out(
    A: Tensor,
    B: Tensor,
    PreAct: Tensor,
    dx_out: Tensor,
    postact_out: Tensor,
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    dynamic_scheduler: bool = True,
    tuned: bool = True,
) -> None:
    """Custom op: activation-gradient GEMM into pre-allocated ``dx_out``/``postact_out``.

    ``tuned=False`` bypasses autotuning and uses the default configuration.
    """
    if tuned:
        runner = gemm_dact_tuned
    else:
        runner = partial(gemm_dact_tuned.fn, config=None)
    runner(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler)
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def gemm_dact_ref(
    A: Tensor,
    B: Tensor,
    PreAct: Tensor,
    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
) -> Tuple[Tensor, Tensor]:
    """Reference implementation for GEMM with activation gradient.

    Computes ``dout = A @ B`` and returns ``(dx, postact)``, where
    ``dx = dout * act'(PreAct)`` (via autograd) and ``postact = act(PreAct)``.
    With ``activation=None`` (identity), ``dx`` is ``dout``.

    Args:
        A: (M, K) input.
        B: (K, N) input.
        PreAct: Pre-activation tensor, same shape as ``A @ B``.
        activation: Activation name (see ``act_to_pytorch_fn_map``).
        out_dtype: dtype of ``dx`` (default ``A.dtype``).
        postact_dtype: dtype of ``postact`` (default ``PreAct.dtype``).

    Returns:
        ``(dx, postact)``.
    """
    out_dtype = A.dtype if out_dtype is None else out_dtype
    postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
    dout = torch.mm(A, B).to(out_dtype)
    act_fn = act_to_pytorch_fn_map[activation]
    postact = act_fn(PreAct)
    if activation is None:
        dx = dout
    else:
        # Differentiate through a detached leaf copy. This never mutates the caller's
        # requires_grad flag, works for non-leaf inputs, and still works when the
        # caller is inside torch.no_grad().
        with torch.enable_grad():
            preact_leaf = PreAct.detach().requires_grad_(True)
            postact_for_grad = act_fn(preact_leaf)
            dx = torch.autograd.grad(postact_for_grad, preact_leaf, dout, create_graph=False)[0]
    return dx.to(out_dtype), postact.to(postact_dtype)
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def gemm_gated_ref(
    A: Tensor,
    B: Tensor,
    C: Optional[Tensor] = None,
    activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    store_preact: bool = True,
) -> Tuple[Optional[Tensor], Tensor]:
    """Reference implementation for GEMM with gated activation forward.

    Args:
        A: (M, K) - input tensor
        B: (K, 2*N) - weight tensor; gate and up projections are interleaved
            column-wise (gate in even columns, up in odd columns)
        C: (M, 2*N) - optional bias tensor
        activation: Type of gated activation (key of ``gated_to_pytorch_fn_map``)
        out_dtype: Output dtype for preact (default ``A.dtype``)
        postact_dtype: Output dtype for postact (default ``A.dtype``)
        store_preact: Whether to return the pre-activation

    Returns:
        (preact, postact) where:
            - preact: (M, 2*N) pre-activation (if store_preact=True, else None)
            - postact: (M, N) post-activation output
    """
    preact_dtype = A.dtype if out_dtype is None else out_dtype
    act_dtype = A.dtype if postact_dtype is None else postact_dtype
    preact = torch.mm(A, B)
    if C is not None:
        preact = C + preact
    # De-interleave: even columns are the gate, odd columns the up projection.
    gate, up = preact[..., ::2], preact[..., 1::2]  # (M, N) each
    postact = gated_to_pytorch_fn_map[activation](gate, up)
    preact_result = preact.to(preact_dtype) if store_preact else None
    return preact_result, postact.to(act_dtype)
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def gemm_dgated_ref(
    A: Tensor,
    B: Tensor,
    PreAct: Tensor,
    activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
) -> Tuple[Tensor, Tensor]:
    """Reference implementation for GEMM with gated activation gradient.

    Args:
        A: (M, K) - dout input tensor
        B: (K, N) - weight tensor
        PreAct: (M, 2*N) - pre-activation tensor with gate and up projections
            interleaved (gate in even columns, up in odd columns)
        activation: Type of gated activation (key of ``gated_to_pytorch_fn_map``)
        out_dtype: Output dtype for dx (default ``A.dtype``)
        postact_dtype: Output dtype for postact (default ``PreAct.dtype``)

    Returns:
        (dx, postact) where:
            - dx: (M, 2*N) gradient w.r.t. PreAct, interleaved like PreAct
            - postact: (M, N) post-activation output (detached from autograd)
    """
    out_dtype = A.dtype if out_dtype is None else out_dtype
    postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
    dout = torch.mm(A, B).to(out_dtype)
    # Differentiate through detached leaf copies of gate/up. Without this,
    # autograd.grad raises unless PreAct already requires grad, and it always
    # raises under torch.no_grad(). It also keeps the caller's graph untouched.
    with torch.enable_grad():
        gate = PreAct[..., ::2].detach().requires_grad_(True)  # (M, N)
        up = PreAct[..., 1::2].detach().requires_grad_(True)  # (M, N)
        postact = gated_to_pytorch_fn_map[activation](gate, up)
        dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
    # Interleave gradients back into PreAct's column layout.
    dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
    return dx.to(out_dtype), postact.detach().to(postact_dtype)
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
# TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
|
|
562
|
+
# try:
|
|
563
|
+
# from torch._inductor.fx_passes.reinplace import InplaceableOp
|
|
564
|
+
# torch._inductor.fx_passes.reinplace.inplaceable_ops.update({
|
|
565
|
+
# torch.ops.quack.gemm_add_out.default:
|
|
566
|
+
# InplaceableOp(torch.ops.quack.gemm_add_inplace.default, mutated_arg=2)
|
|
567
|
+
# })
|
|
568
|
+
# except ImportError:
|
|
569
|
+
# pass
|