quack-kernels 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -34
- quack/gemm.py +194 -0
- quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
- quack/gemm_config.py +72 -46
- quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +177 -31
- quack/gemm_sm100.py +729 -506
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +3 -1
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +476 -526
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +23 -16
- quack/topk.py +409 -85
- quack/utils.py +32 -220
- quack/varlen_utils.py +370 -1
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
- quack_kernels-0.2.4.dist-info/RECORD +44 -0
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.2.dist-info/RECORD +0 -37
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
quack/gemm_interface.py
CHANGED
```diff
@@ -9,9 +9,11 @@ from torch import Tensor
 from quack.gemm_config import GemmConfig, get_all_configs

 from quack.autotuner import autotune, AutotuneConfig
-from quack.
-from quack.
-from quack.
+from quack.cute_dsl_utils import get_device_capacity
+from quack.gemm import gemm as gemm_sm90_sm100
+from quack.gemm_act import gemm_act as gemm_act_sm90_sm100
+from quack.gemm_dact import gemm_dact as gemm_dact_sm90_sm100
+from quack.gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100


 # Dictionary mapping activation names to PyTorch functions
@@ -34,6 +36,16 @@ gated_to_pytorch_fn_map = {
 }


+default_device_capacity = get_device_capacity(torch.device("cuda"))
+
+
+def default_config(device):
+    if get_device_capacity(device)[0] != 10:
+        return GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+    else:
+        return GemmConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False)
+
+
 def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
     kwargs = named_args | kwargs
     gather_A = kwargs.get("A_idx", None) is not None
@@ -41,17 +53,18 @@ def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
     if varlen_m or gather_A:  # Doesn't support swap_ab
         configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
     if gather_A:
-
-
-
-
-
-
+        if get_device_capacity(kwargs["A"].device)[0] == 9:
+            # tile_n == 208 causes register spills, as gather_A requires more registers for the producer
+            configs = [
+                conf
+                for conf in configs
+                if conf.kwargs["config"].cluster_n == 1 and conf.kwargs["config"].tile_n != 208
+            ]
     return configs


 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
     key=["dynamic_scheduler"],
     prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
```
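The new `default_config` helper and the `get_all_configs(default_device_capacity[0])` calls key both the fallback tile shape and the autotuning search space on the GPU's compute capability (SM 9.x vs SM 10.x). A minimal standalone sketch of the same selection logic, using `torch.cuda.get_device_capability` as a stand-in for quack's `get_device_capacity` (assumed here to return the same `(major, minor)` tuple) and a hypothetical `TileConfig` in place of `GemmConfig`:

```python
import torch
from dataclasses import dataclass


@dataclass
class TileConfig:  # hypothetical stand-in for quack's GemmConfig
    tile_m: int
    tile_n: int
    cluster_m: int
    cluster_n: int
    pingpong: bool


def pick_default_config(device: torch.device) -> TileConfig:
    # Mirrors the branch in default_config above: SM 10.x gets the larger
    # 256x256 tile, everything else the 128x192 pingpong tile.
    major, _ = torch.cuda.get_device_capability(device)
    if major != 10:
        return TileConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
    return TileConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False)


if torch.cuda.is_available():
    print(pick_default_config(torch.device("cuda")))
```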
```diff
@@ -61,6 +74,7 @@ def gemm_tuned(
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float | Tensor = 1.0,  # (1,)
     beta: float | Tensor = 1.0,  # (1,)
     cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
@@ -72,7 +86,7 @@ def gemm_tuned(
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
-        config =
+        config = default_config(A.device)
     varlen_m = cu_seqlens_m is not None
     varlen_k = cu_seqlens_k is not None
     varlen = varlen_m or varlen_k
@@ -91,6 +105,8 @@ def gemm_tuned(
         C = C.unsqueeze(0)  # (1, M, N)
     if out.ndim == 2 and not varlen_m:
         out = out.unsqueeze(0)
+    if bias is not None and bias.ndim == 1:
+        bias = bias.unsqueeze(0)  # (L, N)
     batch_size = B.shape[0] if not varlen_k else cu_seqlens_k.shape[0] - 1
     if varlen_m:
         # If gather_A (A_idx provided), use its length; otherwise use A.shape[0]
@@ -102,7 +118,7 @@ def gemm_tuned(
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
-
+    gemm_sm90_sm100(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         out if not config.swap_ab else out.mT,
@@ -113,6 +129,10 @@ def gemm_tuned(
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
+        persistent=True,
+        max_swizzle_size=config.max_swizzle_size,
+        rowvec_bias=bias if not config.swap_ab else None,
+        colvec_bias=bias if config.swap_ab else None,
         alpha=alpha,
         beta=beta,
         cu_seqlens_m=cu_seqlens_m,
```
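`gemm_tuned` routes the single `bias` tensor to either `rowvec_bias` or `colvec_bias` depending on `config.swap_ab`. The reason is the transpose identity `A @ B + bias_row == ((B.T @ A.T) + bias_col).T`, where `bias_col` is the same bias viewed as a column: when the kernel computes the swapped product into `out.mT`, a per-column bias on the logical output becomes a per-row bias on the swapped output. A quick CPU check of that identity in plain PyTorch, independent of the kernels:

```python
import torch

M, N, K = 4, 6, 8
A = torch.randn(M, K)
B = torch.randn(K, N)
bias = torch.randn(N)  # row-vector bias, one value per output column

out = A @ B + bias                              # logical layout: (M, N)
out_swapped = (B.mT @ A.mT + bias[:, None]).mT  # swap_ab layout: computed as (N, M), viewed back

torch.testing.assert_close(out, out_swapped)
```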
```diff
@@ -124,7 +144,7 @@ def gemm_tuned(


 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
     key=["activation", "dynamic_scheduler"],
     prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
@@ -136,6 +156,7 @@ def gemm_act_tuned(
     preact_out: Optional[Tensor],
     postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
     cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
     A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
@@ -143,7 +164,7 @@ def gemm_act_tuned(
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
-        config =
+        config = default_config(A.device)
     varlen_m = cu_seqlens_m is not None
     if varlen_m:
         assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
@@ -162,10 +183,12 @@ def gemm_act_tuned(
         PostAct = postact_out.unsqueeze(0)
     else:
         PostAct = postact_out
+    if bias is not None and bias.ndim == 1:
+        bias = bias.unsqueeze(0)  # (L, N)
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
-
+    gemm_act_sm90_sm100(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         (D if not config.swap_ab else D.mT) if D is not None else None,
@@ -179,13 +202,16 @@ def gemm_act_tuned(
         config.cluster_n,
         config.pingpong,
         persistent=True,
+        max_swizzle_size=config.max_swizzle_size,
+        rowvec_bias=bias if not config.swap_ab else None,
+        colvec_bias=bias if config.swap_ab else None,
         cu_seqlens_m=cu_seqlens_m,
         A_idx=A_idx,
     )


 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
     key=["activation", "dynamic_scheduler"],
     prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
@@ -203,7 +229,7 @@ def gemm_dact_tuned(
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
-        config =
+        config = default_config(A.device)
     varlen_m = cu_seqlens_m is not None
     if varlen_m:
         assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
@@ -225,7 +251,7 @@ def gemm_dact_tuned(
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
-
+    gemm_dact_sm90_sm100(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         D if not config.swap_ab else D.mT,
@@ -239,6 +265,7 @@ def gemm_dact_tuned(
         config.cluster_n,
         config.pingpong,
         persistent=True,
+        max_swizzle_size=config.max_swizzle_size,
         cu_seqlens_m=cu_seqlens_m,
         A_idx=A_idx,
     )
```
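For reference, the epilogue fused by `gemm_act_tuned` corresponds to `postact = act(A @ B [+ C] [+ bias])`, with the pre-activation optionally written out as well. An unfused PyTorch sketch of that semantics; the activation names are assumed here to map to the obvious `torch.nn.functional` calls, as in `act_to_pytorch_fn_map` elsewhere in this file:

```python
import torch
import torch.nn.functional as F


def gemm_act_reference(A, B, C=None, bias=None, activation="relu"):
    # Unfused reference for the epilogue the kernel computes in one pass.
    act_fns = {
        None: lambda x: x,
        "relu": F.relu,
        "relu_sq": lambda x: F.relu(x).square(),
        "gelu_tanh_approx": lambda x: F.gelu(x, approximate="tanh"),
    }
    preact = A @ B
    if C is not None:
        preact = preact + C
    if bias is not None:
        preact = preact + bias
    return preact, act_fns[activation](preact)


A, B = torch.randn(128, 64), torch.randn(64, 256)
bias = torch.randn(256)
preact, postact = gemm_act_reference(A, B, bias=bias, activation="gelu_tanh_approx")
```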
```diff
@@ -249,6 +276,7 @@ def gemm(
     A: Tensor,
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float | Tensor = 1.0,
     out_dtype: Optional[torch.dtype] = None,
     cu_seqlens_m: Optional[Tensor] = None,
@@ -281,6 +309,7 @@ def gemm(
         A,
         B,
         out,
+        bias=bias,
         alpha=alpha,
         alpha_tensor=alpha_tensor,
         cu_seqlens_m=cu_seqlens_m,
@@ -299,13 +328,14 @@ def gemm(
     device_types="cuda",
     # We have to split out alpha and alpha_tensor since torch.library requires
     # each argument to have a fixed type
-    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    # schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? bias, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_out(
     # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
     cu_seqlens_m: Optional[Tensor] = None,
@@ -323,6 +353,7 @@ def gemm_out(
         B,
         out,
         C=None,
+        bias=bias,
         alpha=alpha,
         cu_seqlens_m=cu_seqlens_m,
         cu_seqlens_k=cu_seqlens_k,
@@ -337,6 +368,7 @@ def gemm_ref(
     A: Tensor,
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float | Tensor = 1.0,
     cu_seqlens_m: Optional[Tensor] = None,
     cu_seqlens_k: Optional[Tensor] = None,
```
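Since `bias` was inserted between `out` and `alpha` in the public `gemm` signature, call sites that passed `alpha` positionally should switch to keywords. A hedged usage sketch based only on the signature shown above (bf16 inputs and a supported SM90/SM100 GPU assumed; `gemm` is assumed to allocate and return the output when `out=` is omitted):

```python
import torch
from quack.gemm_interface import gemm  # this module

if torch.cuda.is_available():
    M, N, K = 1024, 1024, 512
    A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    bias = torch.randn(N, device="cuda", dtype=torch.bfloat16)  # (N,), per the signature comment
    # Pass bias and alpha by keyword; positionally, the fourth argument is now bias, not alpha.
    out = gemm(A, B, bias=bias, alpha=1.0)
```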
```diff
@@ -349,6 +381,11 @@ def gemm_ref(
     if cu_seqlens_m is None and cu_seqlens_k is None:
         fn = torch.bmm if A.ndim == 3 else torch.mm
         out = fn(A, B, out_dtype=out_dtype, out=out)
+        if not isinstance(alpha, float) or alpha != 1.0:
+            out *= alpha
+        if bias is not None:
+            bias = bias if A.ndim == 2 else bias.unsqueeze(1)
+            out += bias
     elif cu_seqlens_m is not None:
         # Handle varlen_m case
         if out is None:
@@ -362,6 +399,10 @@ def gemm_ref(
                 else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
             )
             torch.mm(A_slice, B[i], out=out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]])
+            if not isinstance(alpha, float) or alpha != 1.0:
+                out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] *= alpha
+            if bias is not None:
+                out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] += bias[i]
     else:  # cu_seqlens_k is not None
         L = cu_seqlens_k.shape[0] - 1
         if out is None:
@@ -373,8 +414,10 @@ def gemm_ref(
                 else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
             )
             torch.mm(A_slice, B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :], out=out[i])
-
-
+        if not isinstance(alpha, float) or alpha != 1.0:
+            out *= alpha
+        if bias is not None:
+            out += bias
     return out


```
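One detail of the `gemm_ref` epilogue above is the ordering: the output is scaled by `alpha` first and `bias` is added afterwards, so the bias itself is not scaled, i.e. the reference computes `alpha * (A @ B) + bias`. A quick numeric check of that ordering:

```python
import torch

A = torch.randn(8, 16)
B = torch.randn(16, 32)
bias = torch.randn(32)
alpha = 0.5

# Same order of operations as the epilogue in gemm_ref above.
out = A @ B
out *= alpha
out += bias

torch.testing.assert_close(out, alpha * (A @ B) + bias)
```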
```diff
@@ -488,6 +531,7 @@ def gemm_add_ref(
     A: Tensor,
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
@@ -499,7 +543,7 @@ def gemm_add_ref(
     """Reference implementation for GEMM with addition and pre-allocated output."""
     if cu_seqlens_m is None and cu_seqlens_k is None:
         if isinstance(alpha, float) and isinstance(beta, float):
-
+            out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
         else:
             out_dtype = (
                 out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
@@ -507,7 +551,9 @@ def gemm_add_ref(
             result = (alpha * (A @ B) + beta * C).to(out_dtype)
             if out is not None:
                 out.copy_(result)
-
+        if bias is not None:
+            bias = bias if A.ndim == 2 else bias.unsqueeze(1)
+            out += bias
     elif cu_seqlens_m is not None:
         # Handle varlen_m case
         if out is None:
@@ -524,6 +570,8 @@ def gemm_add_ref(
             C_slice = C[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
             out_slice = out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
             result = alpha * torch.mm(A_slice, B[i]) + beta * C_slice
+            if bias is not None:
+                result += bias[i]
             out_slice.copy_(result)
     else:  # cu_seqlens_k is not None
         # Handle varlen_k case
@@ -540,6 +588,8 @@ def gemm_add_ref(
             B_slice = B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :]
             result = alpha * torch.mm(A_slice, B_slice) + beta * C[i]
             out[i].copy_(result)
+        if bias is not None:
+            out += bias
     return out


```
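The float-only fast path in `gemm_add_ref` now goes through `torch.addmm`, whose convention is `beta * C + alpha * (A @ B)`; as in `gemm_ref`, the optional `bias` is then added unscaled. A small equivalence check against the explicit formula (the `out_dtype=` keyword on `addmm` used above is a newer PyTorch addition and is omitted here):

```python
import torch

M, N, K = 8, 32, 16
A, B = torch.randn(M, K), torch.randn(K, N)
C = torch.randn(M, N)
bias = torch.randn(N)
alpha, beta = 0.5, 2.0

out = torch.addmm(C, A, B, alpha=alpha, beta=beta)  # beta * C + alpha * (A @ B)
out += bias                                          # bias added after scaling

torch.testing.assert_close(out, alpha * (A @ B) + beta * C + bias)
```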
```diff
@@ -639,6 +689,7 @@ def gemm_act(
     A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
     B: Tensor,  # (K, N) or (L, K, N)
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
     preact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
@@ -667,7 +718,17 @@ def gemm_act(
     if postact_out is None:
         postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
     gemm_act_out(
-        A,
+        A,
+        B,
+        preact_out,
+        postact_out,
+        C,
+        bias,
+        activation,
+        cu_seqlens_m,
+        A_idx,
+        dynamic_scheduler,
+        tuned,
     )
     return preact_out, postact_out

@@ -676,7 +737,7 @@ def gemm_act(
     "quack::gemm_act_out",
     mutates_args=("preact_out", "postact_out"),
     device_types="cuda",
-    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_act_out(
     A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
@@ -684,6 +745,7 @@ def gemm_act_out(
     preact_out: Optional[Tensor],  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
     cu_seqlens_m: Optional[Tensor] = None,
     A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
@@ -692,13 +754,14 @@ def gemm_act_out(
 ) -> None:
     """GEMM with activation and pre-allocated output tensors."""
     fn = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None)
-    fn(A, B, preact_out, postact_out, C, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
+    fn(A, B, preact_out, postact_out, C, bias, activation, cu_seqlens_m, A_idx, dynamic_scheduler)


 def gemm_act_ref(
     A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
     cu_seqlens_m: Optional[Tensor] = None,
     A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
```
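Because `bias` was also inserted into the middle of `gemm_act_out`'s argument list (between `C` and `activation`), the internal call above now passes it positionally; external callers are safer with keywords. A hedged usage sketch of the public `gemm_act` wrapper, based only on the signature and return statement shown above (bf16 inputs and a supported GPU assumed):

```python
import torch
from quack.gemm_interface import gemm_act  # this module

if torch.cuda.is_available():
    M, N, K = 1024, 1024, 512
    A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    bias = torch.randn(N, device="cuda", dtype=torch.bfloat16)
    # Returns (preact_out, postact_out); the pre-activation may be None
    # if no pre-activation buffer was requested.
    preact, postact = gemm_act(A, B, bias=bias, activation="gelu_tanh_approx")
```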
```diff
@@ -709,9 +772,9 @@ def gemm_act_ref(
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
     if C is None:
-        out = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+        out = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     else:
-        out = gemm_add_ref(A, B, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+        out = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype)
     return out.to(out_dtype) if store_preact else None, postact

@@ -806,6 +869,7 @@ def gemm_gated_ref(
     A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
     cu_seqlens_m: Optional[Tensor] = None,
     A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
@@ -832,9 +896,9 @@ def gemm_gated_ref(
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
     if C is None:
-        preact = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+        preact = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     else:
-        preact = gemm_add_ref(A, B, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+        preact = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     # Split preact into gate and up projections
     gate = preact[..., ::2]  # (M, N//2)
     up = preact[..., 1::2]  # (M, N//2)
```
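The gated reference splits the pre-activation into interleaved gate/up halves (`preact[..., ::2]` and `preact[..., 1::2]`). For the `"swiglu"` case the per-element formula is assumed to be the usual `silu(gate) * up`; a compact sketch of the split-and-gate step:

```python
import torch
import torch.nn.functional as F

M, N = 16, 64           # N must be even: gate/up are interleaved along the last dim
preact = torch.randn(M, N)

gate = preact[..., ::2]       # (M, N//2)
up = preact[..., 1::2]        # (M, N//2)
postact = F.silu(gate) * up   # assumed "swiglu" definition; the other gated variants differ

print(postact.shape)  # torch.Size([16, 32])
```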
```diff
@@ -873,14 +937,96 @@ def gemm_dgated_ref(
     # Split PreAct into gate and up projections
     gate = PreAct[..., ::2]  # (M, N)
     up = PreAct[..., 1::2]  # (M, N)
-    postact = gated_to_pytorch_fn_map[activation](gate, up)
     # Use autograd to compute gradients w.r.t. gate and up
+    gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad
+    gate.requires_grad_(True)
+    up.requires_grad_(True)
+    postact = gated_to_pytorch_fn_map[activation](gate, up)
     dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
+    gate.requires_grad_(gate_requires_grad)
+    up.requires_grad_(up_requires_grad)
     # Interleave gradients back
     dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
     return dx.to(out_dtype), postact.to(postact_dtype)


+@torch.library.custom_op(
+    "quack::gemm_symmetric_out",
+    mutates_args=("out",),
+    device_types="cuda",
+    schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? C=None, bool dynamic_scheduler=False, float alpha=1.0, float beta=1.0) -> ()",
+)
+def gemm_symmetric_out(
+    A: Tensor,  # (M, K) or (L, M, K)
+    B: Tensor,  # (K, M) or (L, K, M)
+    out: Tensor,  # (M, M) or (L, M, M)
+    C: Optional[Tensor] = None,  # (M, M) or (L, M, M)
+    dynamic_scheduler: bool = False,
+    alpha: float = 1.0,
+    beta: float = 1.0,
+) -> None:
+    """GEMM with guaranteed symmetric output."""
+    if A.ndim == 2:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (M, K) or (L, M, K)
+    if B.ndim == 2:
+        B = B.unsqueeze(0)  # (1, M, K)
+    if C is not None and C.ndim == 2:
+        C = C.unsqueeze(0)  # (1, M, M)
+    if out.ndim == 2:
+        out = out.unsqueeze(0)
+    else:
+        out = out
+    tile_count_semaphore = (
+        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+    )
+    gemm_symmetric_sm90_sm100(
+        A,
+        B,
+        out if out is not None else None,
+        C if C is not None else None,
+        tile_count_semaphore,
+        tile_M=128,
+        tile_N=256,
+        cluster_M=2,
+        cluster_N=1,
+        pingpong=False,
+        persistent=True,
+        max_swizzle_size=8,
+        alpha=alpha,
+        beta=beta,
+    )
+
+
+def gemm_symmetric(
+    A: Tensor,  # (M, K) or (L, M, K)
+    B: Tensor,  # (K, M) or (L, K, M)
+    C: Optional[Tensor] = None,  # (M, M) or (L, M, M)
+    out: Optional[Tensor] = None,  # (M, M) or (L, M, M)
+    out_dtype: Optional[torch.dtype] = None,
+    dynamic_scheduler: bool = False,
+    alpha: float | Tensor = 1.0,
+    beta: float | Tensor = 1.0,
+) -> Tuple[Optional[Tensor], Tensor]:
+    """GEMM with symmetric output."""
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    # Determine output shape based on gather_A
+    if A.ndim == 2:
+        out_shape = (A.shape[0], B.shape[-1])
+    else:
+        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
+    if out is None:
+        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
+
+    alpha_val = alpha if isinstance(alpha, float) else 1.0
+    beta_val = beta if isinstance(beta, float) else 1.0
+
+    gemm_symmetric_out(
+        A, B, out, C, dynamic_scheduler=dynamic_scheduler, alpha=alpha_val, beta=beta_val
+    )
+    return out
+
+
 # TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
 # try:
 #     from torch._inductor.fx_passes.reinplace import InplaceableOp
```
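The `gemm_dgated_ref` change above is autograd bookkeeping: `requires_grad` is enabled on `gate` and `up` before the gated function is applied, so `torch.autograd.grad` can differentiate through it even when `PreAct` arrives without gradient tracking, and the original flags are restored afterwards.

The other addition is the `gemm_symmetric` wrapper, which computes `alpha * A @ B + beta * C` for the case where the result is a symmetric (M, M) matrix, e.g. `B = A.mT` (note the annotated return type is a tuple, but the body returns the single output tensor). A hedged usage sketch based on the signature added above (bf16 inputs and a supported SM90/SM100 GPU assumed):

```python
import torch
from quack.gemm_interface import gemm_symmetric  # this module

if torch.cuda.is_available():
    M, K = 1024, 512
    A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    # A @ A.mT is symmetric by construction; the dedicated kernel is meant to
    # exploit that rather than computing both triangles independently.
    out = gemm_symmetric(A, A.mT)
    print(out.shape)  # (M, M)
```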