quack-kernels 0.1.10__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,569 @@
+ # Copyright (c) 2025, Tri Dao
+ from typing import Optional, Tuple, Literal
+ from functools import partial
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor
+
+ from quack.gemm_config import GemmConfig, get_all_configs
+
+ from quack.autotuner import autotune, AutotuneConfig
+ from quack.dense_gemm_sm90 import gemm_sm90
+ from quack.gemm_act_sm90 import gemm_act_sm90
+ from quack.gemm_dact_sm90 import gemm_dact_sm90
+
+
+ # Dictionary mapping activation names to PyTorch functions
+ act_to_pytorch_fn_map = {
+     None: lambda x: x,
+     "relu": F.relu,
+     "relu_sq": lambda x: F.relu(x).square(),
+     "gelu_tanh_approx": partial(F.gelu, approximate="tanh"),
+ }
+
+
+ # Dictionary mapping gated activation names to their forward functions
+ # Each function takes (gate, up) and returns postact
+ gated_to_pytorch_fn_map = {
+     "swiglu": lambda gate, up: F.silu(gate) * up,
+     "swiglu_oai": lambda gate, up: gate * torch.sigmoid(1.702 * gate) * (up + 1),
+     "reglu": lambda gate, up: F.relu(gate) * up,
+     "geglu": lambda gate, up: F.gelu(gate, approximate="tanh") * up,
+     "glu": lambda gate, up: torch.sigmoid(gate) * up,
+ }
+
+
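These two maps drive the reference implementations below; a small sketch of how they are indexed (shapes here are arbitrary, and the snippet assumes the maps above are in scope):

import torch

x = torch.randn(4, 8)
gate, up = torch.randn(4, 8), torch.randn(4, 8)
y = act_to_pytorch_fn_map["relu_sq"](x)          # relu(x) ** 2, elementwise
z = gated_to_pytorch_fn_map["swiglu"](gate, up)  # silu(gate) * up, elementwise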
+ @autotune(
+     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+     key=["dynamic_scheduler"],
+ )
+ def gemm_tuned(
+     A: Tensor,  # (M, K)
+     B: Tensor,  # (K, N)
+     out: Tensor,  # (M, N) - required output tensor
+     C: Optional[Tensor] = None,  # (M, N)
+     alpha: float | Tensor = 1.0,  # (1,)
+     beta: float | Tensor = 1.0,  # (1,)
+     dynamic_scheduler: bool = False,
+     config: Optional[GemmConfig] = None,
+ ) -> None:
+     if config is None:
+         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+     A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
+     if C is not None:
+         C = C.unsqueeze(0)  # (1, M, N)
+     assert out.shape == (
+         A.shape[1],
+         B.shape[1],
+     ), f"out shape mismatch: {out.shape} vs {(A.shape[1], B.shape[1])}"
+     out = out.unsqueeze(0)
+     tile_count_semaphore = (
+         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+     )
+     gemm_sm90(
+         A if not config.swap_ab else B,
+         B if not config.swap_ab else A,
+         out if not config.swap_ab else out.mT,
+         (C if not config.swap_ab else C.mT) if C is not None else None,
+         tile_count_semaphore,
+         config.tile_m,
+         config.tile_n,
+         config.cluster_m,
+         config.cluster_n,
+         config.pingpong,
+         alpha=alpha,
+         beta=beta,
+     )
+
+
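The swap_ab branch passes the operands in swapped order and writes the transposed view of the output; a quick plain-PyTorch sketch of the underlying identity it relies on (shapes are arbitrary):

import torch

A = torch.randn(64, 32, dtype=torch.float64)
B = torch.randn(32, 48, dtype=torch.float64)
# Swapping operands and transposing the result reproduces the original product:
# (A @ B).mT == B.mT @ A.mT, which is what the swap_ab path exploits.
torch.testing.assert_close((A @ B).mT, B.mT @ A.mT)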
+ @autotune(
+     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+     key=["activation"],
+ )
+ def gemm_act_tuned(
+     A: Tensor,  # (M, K)
+     B: Tensor,  # (K, N)
+     preact_out: Optional[Tensor],  # (M, N) - None if not storing preact
+     postact_out: Tensor,  # (M, N)
+     C: Optional[Tensor] = None,  # (M, N)
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     config: Optional[GemmConfig] = None,
+ ) -> None:
+     if config is None:
+         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+     A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
+     if C is not None:
+         C = C.unsqueeze(0)  # (1, M, N)
+     if preact_out is not None:
+         assert preact_out.shape == (A.shape[1], B.shape[1])
+         D = preact_out.unsqueeze(0)
+     else:
+         D = None
+     assert postact_out.shape == (A.shape[1], B.shape[1])
+     PostAct = postact_out.unsqueeze(0)
+     gemm_act_sm90(
+         A if not config.swap_ab else B,
+         B if not config.swap_ab else A,
+         (D if not config.swap_ab else D.mT) if D is not None else None,
+         (C if not config.swap_ab else C.mT) if C is not None else None,
+         PostAct if not config.swap_ab else PostAct.mT,
+         activation,
+         config.tile_m,
+         config.tile_n,
+         config.cluster_m,
+         config.cluster_n,
+         config.pingpong,
+     )
+
+
+ @autotune(
+     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+     key=["activation", "dynamic_scheduler"],
+ )
+ def gemm_dact_tuned(
+     A: Tensor,  # (M, K)
+     B: Tensor,  # (K, N)
+     PreAct: Tensor,  # (M, N)
+     dx_out: Tensor,  # (M, N)
+     postact_out: Tensor,  # (M, N)
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     dynamic_scheduler: bool = True,
+     config: Optional[GemmConfig] = None,
+ ) -> None:
+     if config is None:
+         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+     A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
+     PreAct = PreAct.unsqueeze(0)  # (1, M, N)
+     assert dx_out.shape == (A.shape[1], B.shape[1])
+     D = dx_out.unsqueeze(0)
+     assert postact_out.shape == (A.shape[1], B.shape[1])
+     PostAct = postact_out.unsqueeze(0)
+     tile_count_semaphore = (
+         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+     )
+     gemm_dact_sm90(
+         A if not config.swap_ab else B,
+         B if not config.swap_ab else A,
+         D if not config.swap_ab else D.mT,
+         PreAct if not config.swap_ab else PreAct.mT,
+         PostAct if not config.swap_ab else PostAct.mT,
+         tile_count_semaphore,
+         activation,
+         config.tile_m,
+         config.tile_n,
+         config.cluster_m,
+         config.cluster_n,
+         config.pingpong,
+     )
+
+
+ def gemm(
+     A: Tensor,
+     B: Tensor,
+     out: Optional[Tensor] = None,
+     alpha: float | Tensor = 1.0,
+     out_dtype: Optional[torch.dtype] = None,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> Tensor:
+     """GEMM with optional output tensor and tuning control."""
+     if out is None:
+         out_dtype = A.dtype if out_dtype is None else out_dtype
+         out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+     alpha_tensor = alpha if not isinstance(alpha, float) else None
+     alpha = alpha if isinstance(alpha, float) else 1.0
+     gemm_out(
+         A,
+         B,
+         out,
+         alpha=alpha,
+         alpha_tensor=alpha_tensor,
+         dynamic_scheduler=dynamic_scheduler,
+         tuned=tuned,
+     )
+     return out
+
+
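A minimal usage sketch, assuming an SM90 GPU, bf16 inputs, and the functions defined in this module being in scope (tensor names and shapes are illustrative):

import torch

A = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
B = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)
D = gemm(A, B, alpha=0.5)  # allocates a (4096, 2048) bf16 output
out = torch.empty(4096, 2048, device="cuda", dtype=torch.bfloat16)
gemm(A, B, out=out, dynamic_scheduler=True)  # writes into a pre-allocated buffer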
+ @torch.library.custom_op(
+     "quack::gemm_out",
+     mutates_args=("out",),
+     device_types="cuda",
+     # We have to split out alpha and alpha_tensor since torch.library requires
+     # each argument to have a fixed type
+     # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+ )
+ def gemm_out(
+     A: Tensor,
+     B: Tensor,
+     out: Tensor,
+     alpha: float = 1.0,
+     alpha_tensor: Optional[Tensor] = None,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> None:
+     """GEMM with pre-allocated output tensor."""
+     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
+     alpha = alpha_tensor if alpha_tensor is not None else alpha
+     fn(A, B, out, C=None, alpha=alpha, dynamic_scheduler=dynamic_scheduler)
+
+
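As the comment above notes, the registered op keeps fixed argument types, so a scale that lives on the device goes through alpha_tensor while the gemm() wrapper does the routing. A sketch of passing a tensor-valued alpha (the (1,)-shaped fp32 scale, the shapes, and the bf16 dtype are assumptions, not requirements stated by the package):

import torch

A = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
B = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)
scale = torch.full((1,), 0.125, device="cuda", dtype=torch.float32)
D = gemm(A, B, alpha=scale)  # forwarded as alpha_tensor=scale under the hood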
+ def gemm_ref(
+     A: Tensor,
+     B: Tensor,
+     out: Optional[Tensor] = None,
+     alpha: float | Tensor = 1.0,
+     out_dtype: Optional[torch.dtype] = None,
+ ) -> Tensor:
+     """Reference implementation for GEMM with pre-allocated output."""
+     # The out_dtype argument requires torch >= 2.8
+     out = torch.mm(A, B, out_dtype=out_dtype, out=out)
+     if not isinstance(alpha, float) or alpha != 1.0:
+         out = out * alpha
+     return out
+
+
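A sanity check one might run against the reference (illustrative; the tolerances are a rough guess for bf16 inputs with fp32 accumulation, not values taken from the package's tests):

import torch

A = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
B = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
D = gemm(A, B)
D_ref = gemm_ref(A.float(), B.float()).to(D.dtype)  # fp32 reference, cast back to bf16
torch.testing.assert_close(D, D_ref, atol=2e-2, rtol=2e-2)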
+ def gemm_add(
+     A: Tensor,
+     B: Tensor,
+     C: Tensor,
+     out: Optional[Tensor] = None,
+     alpha: float | Tensor = 1.0,
+     beta: float | Tensor = 1.0,
+     out_dtype: Optional[torch.dtype] = None,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> Tensor:
+     """GEMM with addition and optional output tensor."""
+     if out is None:
+         out_dtype = A.dtype if out_dtype is None else out_dtype
+         out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+     alpha_tensor = alpha if not isinstance(alpha, float) else None
+     alpha = alpha if isinstance(alpha, float) else 1.0
+     beta_tensor = beta if not isinstance(beta, float) else None
+     beta = beta if isinstance(beta, float) else 1.0
+     gemm_add_out(
+         A,
+         B,
+         C,
+         out,
+         alpha,
+         beta,
+         alpha_tensor,
+         beta_tensor,
+         dynamic_scheduler=dynamic_scheduler,
+         tuned=tuned,
+     )
+     return out
+
+
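A usage sketch for the fused epilogue out = alpha * A @ B + beta * C, e.g. a residual add (names, shapes, and dtype are illustrative):

import torch

A = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
B = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)
C = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)  # residual / bias term
D = gemm_add(A, B, C, beta=0.5)  # D = A @ B + 0.5 * C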
+ @torch.library.custom_op(
+     "quack::gemm_add_out",
+     mutates_args=("out",),
+     device_types="cuda",
+     # We have to split out alpha and alpha_tensor since torch.library requires
+     # each argument to have a fixed type
+     # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+ )
+ def gemm_add_out(
+     A: Tensor,
+     B: Tensor,
+     C: Tensor,
+     out: Tensor,
+     alpha: float = 1.0,
+     beta: float = 1.0,
+     alpha_tensor: Optional[Tensor] = None,
+     beta_tensor: Optional[Tensor] = None,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> None:
+     """GEMM with addition and pre-allocated output tensor."""
+     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
+     alpha = alpha_tensor if alpha_tensor is not None else alpha
+     beta = beta_tensor if beta_tensor is not None else beta
+     fn(A, B, out, C, alpha=alpha, beta=beta, dynamic_scheduler=dynamic_scheduler)
+
+
+ def gemm_add_ref(
+     A: Tensor,
+     B: Tensor,
+     C: Tensor,
+     out: Optional[Tensor] = None,
+     alpha: float | Tensor = 1.0,
+     beta: float | Tensor = 1.0,
+     out_dtype: Optional[torch.dtype] = None,
+ ) -> Tensor:
+     """Reference implementation for GEMM with addition and pre-allocated output."""
+     if isinstance(alpha, float) and isinstance(beta, float):
+         return torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
+     else:
+         out_dtype = (
+             out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
+         )
+         result = (alpha * (A @ B) + beta * C).to(out_dtype)
+         if out is not None:
+             out.copy_(result)
+         return result
+
+
+ def gemm_add_inplace(
+     A: Tensor,
+     B: Tensor,
+     out: Tensor,
+     alpha: float | Tensor = 1.0,
+     beta: float | Tensor = 1.0,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> None:
+     """In-place GEMM with addition: out = alpha * A @ B + beta * out.
+     Args:
+         A: (M, K) input tensor
+         B: (K, N) input tensor
+         out: (M, N) tensor to accumulate into (modified in-place)
+         alpha: Scalar multiplier for A @ B
+         beta: Scalar multiplier for out
+         dynamic_scheduler: Whether to use dynamic scheduler
+         tuned: Whether to use autotuned configuration
+     """
+     alpha_tensor = alpha if not isinstance(alpha, float) else None
+     alpha = alpha if isinstance(alpha, float) else 1.0
+     beta_tensor = beta if not isinstance(beta, float) else None
+     beta = beta if isinstance(beta, float) else 1.0
+     gemm_add_inplace_op(A, B, out, alpha, beta, alpha_tensor, beta_tensor, dynamic_scheduler, tuned)
+
+
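A sketch of the in-place form, e.g. accumulating successive partial products into one buffer (names, shapes, and dtype are illustrative):

import torch

A = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
B = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)
acc = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
gemm_add_inplace(A, B, acc)                       # acc = A @ B + acc
gemm_add_inplace(A, B, acc, alpha=0.5, beta=1.0)  # acc = 0.5 * A @ B + acc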
+ @torch.library.custom_op(
+     "quack::gemm_add_inplace",
+     mutates_args=("out",),
+     device_types="cuda",
+     # We have to split out alpha and alpha_tensor since torch.library requires
+     # each argument to have a fixed type
+     # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+ )
+ def gemm_add_inplace_op(
+     A: Tensor,
+     B: Tensor,
+     out: Tensor,
+     alpha: float = 1.0,
+     beta: float = 1.0,
+     alpha_tensor: Optional[Tensor] = None,
+     beta_tensor: Optional[Tensor] = None,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> None:
+     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
+     alpha = alpha_tensor if alpha_tensor is not None else alpha
+     beta = beta_tensor if beta_tensor is not None else beta
+     # Pass out as both the input bias C and the output so the kernel accumulates in place
+     fn(A, B, out, out, alpha=alpha, beta=beta, dynamic_scheduler=dynamic_scheduler)
+
+
+ def gemm_act(
+     A: Tensor,
+     B: Tensor,
+     C: Optional[Tensor] = None,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     preact_out: Optional[Tensor] = None,
+     postact_out: Optional[Tensor] = None,
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+     store_preact: bool = True,
+     tuned: bool = True,
+ ) -> Tuple[Optional[Tensor], Tensor]:
+     """GEMM with activation and optional output tensors."""
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+     if preact_out is None and store_preact:
+         preact_out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+     if postact_out is None:
+         postact_out = torch.empty((A.shape[0], B.shape[1]), dtype=postact_dtype, device=A.device)
+     gemm_act_out(A, B, preact_out, postact_out, C, activation, tuned)
+     return preact_out, postact_out
+
+
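A usage sketch for the fused GEMM + activation, keeping the pre-activation around for the backward pass (tensor names and shapes are illustrative):

import torch

x = torch.randn(8192, 1024, device="cuda", dtype=torch.bfloat16)
W1 = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
preact, h = gemm_act(x, W1, activation="gelu_tanh_approx", store_preact=True)
# preact = x @ W1 is saved for gemm_dact in the backward pass; h = gelu(preact)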
+ @torch.library.custom_op(
+     "quack::gemm_act_out",
+     mutates_args=("preact_out", "postact_out"),
+     device_types="cuda",
+     schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, bool tuned=True) -> ()",
+ )
+ def gemm_act_out(
+     A: Tensor,
+     B: Tensor,
+     preact_out: Optional[Tensor],
+     postact_out: Tensor,
+     C: Optional[Tensor] = None,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     tuned: bool = True,
+ ) -> None:
+     """GEMM with activation and pre-allocated output tensors."""
+     fn = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None)
+     fn(A, B, preact_out, postact_out, C, activation)
+
+
+ def gemm_act_ref(
+     A: Tensor,
+     B: Tensor,
+     C: Optional[Tensor] = None,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+     store_preact: bool = True,
+ ) -> Tuple[Optional[Tensor], Tensor]:
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+     out = torch.mm(A, B) if C is None else C + torch.mm(A, B)
+     postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype)
+     return out.to(out_dtype) if store_preact else None, postact
+
+
+ def gemm_dact(
+     A: Tensor,
+     B: Tensor,
+     PreAct: Tensor,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     dx_out: Optional[Tensor] = None,
+     postact_out: Optional[Tensor] = None,
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+     dynamic_scheduler: bool = True,
+     tuned: bool = True,
+ ) -> Tuple[Tensor, Tensor]:
+     """GEMM with activation gradient and optional output tensors."""
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+     if dx_out is None:
+         dx_out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+     if postact_out is None:
+         postact_out = torch.empty((A.shape[0], B.shape[1]), dtype=postact_dtype, device=A.device)
+     gemm_dact_out(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler, tuned)
+     return dx_out, postact_out
+
+
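A sketch of how this pairs with gemm_act in an MLP backward pass: with preact = x @ W1, h = act(preact), y = h @ W2, the call below fuses dpreact = (dy @ W2.mT) * act'(preact) and recomputes h in one pass. Names, shapes, and the use of .mT for the transposed weight are illustrative assumptions, not an API contract from the package:

import torch

preact = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)  # saved from the forward
W2 = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
dy = torch.randn(8192, 1024, device="cuda", dtype=torch.bfloat16)
dpreact, h_recomputed = gemm_dact(dy, W2.mT, preact, activation="gelu_tanh_approx")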
443
+ @torch.library.custom_op(
444
+ "quack::gemm_dact_out",
445
+ mutates_args=("dx_out", "postact_out"),
446
+ device_types="cuda",
447
+ schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
448
+ )
449
+ def gemm_dact_out(
450
+ A: Tensor,
451
+ B: Tensor,
452
+ PreAct: Tensor,
453
+ dx_out: Tensor,
454
+ postact_out: Tensor,
455
+ activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
456
+ dynamic_scheduler: bool = True,
457
+ tuned: bool = True,
458
+ ) -> None:
459
+ """GEMM with activation gradient and pre-allocated output tensors."""
460
+ fn = gemm_dact_tuned if tuned else partial(gemm_dact_tuned.fn, config=None)
461
+ fn(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler)
462
+
463
+
+ def gemm_dact_ref(
+     A: Tensor,
+     B: Tensor,
+     PreAct: Tensor,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[Tensor, Tensor]:
+     """Reference implementation for GEMM with activation gradient."""
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+     dout = torch.mm(A, B).to(out_dtype)
+     postact = act_to_pytorch_fn_map[activation](PreAct)
+     # Compute gradient using autograd
+     if activation is None:
+         dx = dout
+     else:
+         PreAct_requires_grad = PreAct.requires_grad
+         PreAct.requires_grad_(True)
+         postact_for_grad = act_to_pytorch_fn_map[activation](PreAct)
+         dx = torch.autograd.grad(postact_for_grad, PreAct, dout, create_graph=False)[0]
+         PreAct.requires_grad_(PreAct_requires_grad)
+     return dx.to(out_dtype), postact.to(postact_dtype)
+
+
+ def gemm_gated_ref(
+     A: Tensor,
+     B: Tensor,
+     C: Optional[Tensor] = None,
+     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+     store_preact: bool = True,
+ ) -> Tuple[Optional[Tensor], Tensor]:
+     """Reference implementation for GEMM with gated activation forward.
+
+     Args:
+         A: (M, K) - input tensor
+         B: (K, 2*N) - weight tensor with gate and up projections interleaved
+         C: (M, 2*N) - optional bias tensor
+         activation: Type of gated activation
+         out_dtype: Output dtype for preact
+         postact_dtype: Output dtype for postact
+         store_preact: Whether to return the pre-activation
+
+     Returns:
+         (preact, postact) where:
+         - preact: (M, 2*N) pre-activation (if store_preact=True, else None)
+         - postact: (M, N) post-activation output
+     """
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+     preact = torch.mm(A, B) if C is None else C + torch.mm(A, B)
+     # Split preact into gate and up projections (interleaved along the last dim)
+     gate = preact[..., ::2]  # (M, N)
+     up = preact[..., 1::2]  # (M, N)
+     postact = gated_to_pytorch_fn_map[activation](gate, up)
+     return preact.to(out_dtype) if store_preact else None, postact.to(postact_dtype)
+
+
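The gate/up columns are interleaved along the last dimension, so a conventional SwiGLU with separate W_gate / W_up weights maps onto this layout by interleaving their columns. A sketch (names, sizes, and tolerances are illustrative):

import torch
import torch.nn.functional as F

M, K, N = 256, 128, 512
x = torch.randn(M, K)
W_gate, W_up = torch.randn(K, N), torch.randn(K, N)
B_interleaved = torch.stack([W_gate, W_up], dim=-1).reshape(K, 2 * N)  # columns alternate gate/up
_, postact = gemm_gated_ref(x, B_interleaved, activation="swiglu")
torch.testing.assert_close(postact, F.silu(x @ W_gate) * (x @ W_up), atol=1e-3, rtol=1e-3)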
+ def gemm_dgated_ref(
+     A: Tensor,
+     B: Tensor,
+     PreAct: Tensor,
+     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[Tensor, Tensor]:
+     """Reference implementation for GEMM with gated activation gradient.
+
+     Args:
+         A: (M, K) - dout input tensor
+         B: (K, N) - weight tensor
+         PreAct: (M, 2*N) - pre-activation tensor with gate and up projections interleaved
+         activation: Type of gated activation
+         out_dtype: Output dtype for dx
+         postact_dtype: Output dtype for postact
+
+     Returns:
+         (dx, postact) where:
+         - dx: (M, 2*N) gradient w.r.t. PreAct
+         - postact: (M, N) post-activation output
+     """
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+     dout = torch.mm(A, B).to(out_dtype)
+     # Split PreAct into gate and up projections
+     gate = PreAct[..., ::2]  # (M, N)
+     up = PreAct[..., 1::2]  # (M, N)
+     postact = gated_to_pytorch_fn_map[activation](gate, up)
+     # Use autograd to compute gradients w.r.t. gate and up; temporarily enable grad
+     # on PreAct (mirroring gemm_dact_ref) so the slices are differentiable
+     PreAct_requires_grad = PreAct.requires_grad
+     PreAct.requires_grad_(True)
+     gate_for_grad, up_for_grad = PreAct[..., ::2], PreAct[..., 1::2]
+     postact_for_grad = gated_to_pytorch_fn_map[activation](gate_for_grad, up_for_grad)
+     dgate, dup = torch.autograd.grad(postact_for_grad, [gate_for_grad, up_for_grad], dout, create_graph=False)
+     PreAct.requires_grad_(PreAct_requires_grad)
+     # Interleave gradients back
+     dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
+     return dx.to(out_dtype), postact.to(postact_dtype)
+
+
+ # TODO: this is not quite right; do we need to register gemm_add rather than gemm_add_out?
+ # try:
+ #     from torch._inductor.fx_passes.reinplace import InplaceableOp
+ #     torch._inductor.fx_passes.reinplace.inplaceable_ops.update({
+ #         torch.ops.quack.gemm_add_out.default:
+ #             InplaceableOp(torch.ops.quack.gemm_add_inplace.default, mutated_arg=2)
+ #     })
+ # except ImportError:
+ #     pass