quack-kernels 0.2.2-py3-none-any.whl → 0.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/broadcast_utils.py +29 -0
  4. quack/compile_utils.py +19 -0
  5. quack/copy_utils.py +487 -0
  6. quack/cross_entropy.py +157 -233
  7. quack/cute_dsl_utils.py +20 -34
  8. quack/gemm.py +194 -0
  9. quack/{gemm_act_sm90.py → gemm_act.py} +218 -117
  10. quack/gemm_config.py +72 -46
  11. quack/{gemm_dact_sm90.py → gemm_dact.py} +53 -21
  12. quack/gemm_default_epi.py +259 -0
  13. quack/gemm_interface.py +177 -31
  14. quack/gemm_sm100.py +729 -506
  15. quack/{dense_gemm_sm90.py → gemm_sm90.py} +344 -814
  16. quack/gemm_symmetric.py +330 -0
  17. quack/gemm_wrapper_utils.py +3 -1
  18. quack/layout_utils.py +287 -0
  19. quack/linear.py +24 -16
  20. quack/pipeline.py +158 -3
  21. quack/reduce.py +88 -49
  22. quack/reduction_base.py +25 -36
  23. quack/rmsnorm.py +476 -526
  24. quack/sm100_utils.py +62 -0
  25. quack/sm90_utils.py +127 -0
  26. quack/softmax.py +135 -203
  27. quack/sort/bitonic_sort.py +13 -10
  28. quack/sort/utils.py +6 -6
  29. quack/tile_scheduler.py +23 -16
  30. quack/topk.py +409 -85
  31. quack/utils.py +32 -220
  32. quack/varlen_utils.py +370 -1
  33. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/METADATA +4 -2
  34. quack_kernels-0.2.4.dist-info/RECORD +44 -0
  35. quack/layernorm.py +0 -353
  36. quack/symmetric_dense_gemm_sm90.py +0 -2091
  37. quack_kernels-0.2.2.dist-info/RECORD +0 -37
  38. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/WHEEL +0 -0
  39. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/licenses/LICENSE +0 -0
  40. {quack_kernels-0.2.2.dist-info → quack_kernels-0.2.4.dist-info}/top_level.txt +0 -0
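
The headline change in 0.2.4 is that the SM90-only GEMM modules become architecture-generic ones covering both SM90 and SM100 (the renames above, plus the new quack/gemm.py, quack/gemm_sm100.py, and quack/gemm_symmetric.py). A rough sketch of how this surfaces at import level, taken from the quack/gemm_interface.py diff below rather than from separately verified 0.2.4 documentation:

# 0.2.2: SM90-only entry points
# from quack.dense_gemm_sm90 import gemm_sm90
# from quack.gemm_act_sm90 import gemm_act_sm90
# from quack.gemm_dact_sm90 import gemm_dact_sm90

# 0.2.4: architecture-generic entry points used for both SM90 and SM100
from quack.gemm import gemm
from quack.gemm_act import gemm_act
from quack.gemm_dact import gemm_dact
from quack.gemm_symmetric import gemm_symmetric  # new module in 0.2.4
from quack.cute_dsl_utils import get_device_capacity  # used to pick per-architecture defaults
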
quack/gemm_interface.py CHANGED
@@ -9,9 +9,11 @@ from torch import Tensor
 from quack.gemm_config import GemmConfig, get_all_configs

 from quack.autotuner import autotune, AutotuneConfig
-from quack.dense_gemm_sm90 import gemm_sm90
-from quack.gemm_act_sm90 import gemm_act_sm90
-from quack.gemm_dact_sm90 import gemm_dact_sm90
+from quack.cute_dsl_utils import get_device_capacity
+from quack.gemm import gemm as gemm_sm90_sm100
+from quack.gemm_act import gemm_act as gemm_act_sm90_sm100
+from quack.gemm_dact import gemm_dact as gemm_dact_sm90_sm100
+from quack.gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100


 # Dictionary mapping activation names to PyTorch functions
@@ -34,6 +36,16 @@ gated_to_pytorch_fn_map = {
 }


+default_device_capacity = get_device_capacity(torch.device("cuda"))
+
+
+def default_config(device):
+    if get_device_capacity(device)[0] != 10:
+        return GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+    else:
+        return GemmConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False)
+
+
 def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
     kwargs = named_args | kwargs
     gather_A = kwargs.get("A_idx", None) is not None
@@ -41,17 +53,18 @@ def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
     if varlen_m or gather_A:  # Doesn't support swap_ab
         configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
     if gather_A:
-        # tile_n == 208 causes register spills, as gather_A requires more registers for the producer
-        configs = [
-            conf
-            for conf in configs
-            if conf.kwargs["config"].cluster_n == 1 and conf.kwargs["config"].tile_n != 208
-        ]
+        if get_device_capacity(kwargs["A"].device)[0] == 9:
+            # tile_n == 208 causes register spills, as gather_A requires more registers for the producer
+            configs = [
+                conf
+                for conf in configs
+                if conf.kwargs["config"].cluster_n == 1 and conf.kwargs["config"].tile_n != 208
+            ]
     return configs


 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
     key=["dynamic_scheduler"],
     prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
@@ -61,6 +74,7 @@ def gemm_tuned(
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float | Tensor = 1.0,  # (1,)
     beta: float | Tensor = 1.0,  # (1,)
     cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
@@ -72,7 +86,7 @@ def gemm_tuned(
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
-        config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+        config = default_config(A.device)
     varlen_m = cu_seqlens_m is not None
     varlen_k = cu_seqlens_k is not None
     varlen = varlen_m or varlen_k
@@ -91,6 +105,8 @@ def gemm_tuned(
         C = C.unsqueeze(0)  # (1, M, N)
     if out.ndim == 2 and not varlen_m:
         out = out.unsqueeze(0)
+    if bias is not None and bias.ndim == 1:
+        bias = bias.unsqueeze(0)  # (L, N)
     batch_size = B.shape[0] if not varlen_k else cu_seqlens_k.shape[0] - 1
     if varlen_m:
         # If gather_A (A_idx provided), use its length; otherwise use A.shape[0]
@@ -102,7 +118,7 @@ def gemm_tuned(
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
-    gemm_sm90(
+    gemm_sm90_sm100(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         out if not config.swap_ab else out.mT,
@@ -113,6 +129,10 @@ def gemm_tuned(
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
+        persistent=True,
+        max_swizzle_size=config.max_swizzle_size,
+        rowvec_bias=bias if not config.swap_ab else None,
+        colvec_bias=bias if config.swap_ab else None,
         alpha=alpha,
         beta=beta,
         cu_seqlens_m=cu_seqlens_m,
@@ -124,7 +144,7 @@ def gemm_tuned(


 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
     key=["activation", "dynamic_scheduler"],
     prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
@@ -136,6 +156,7 @@ def gemm_act_tuned(
     preact_out: Optional[Tensor],
     postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
     cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
     A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
@@ -143,7 +164,7 @@ def gemm_act_tuned(
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
-        config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+        config = default_config(A.device)
     varlen_m = cu_seqlens_m is not None
     if varlen_m:
         assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
@@ -162,10 +183,12 @@ def gemm_act_tuned(
         PostAct = postact_out.unsqueeze(0)
     else:
         PostAct = postact_out
+    if bias is not None and bias.ndim == 1:
+        bias = bias.unsqueeze(0)  # (L, N)
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
-    gemm_act_sm90(
+    gemm_act_sm90_sm100(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         (D if not config.swap_ab else D.mT) if D is not None else None,
@@ -179,13 +202,16 @@ def gemm_act_tuned(
         config.cluster_n,
         config.pingpong,
         persistent=True,
+        max_swizzle_size=config.max_swizzle_size,
+        rowvec_bias=bias if not config.swap_ab else None,
+        colvec_bias=bias if config.swap_ab else None,
         cu_seqlens_m=cu_seqlens_m,
         A_idx=A_idx,
     )


 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
     key=["activation", "dynamic_scheduler"],
     prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
@@ -203,7 +229,7 @@ def gemm_dact_tuned(
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
-        config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+        config = default_config(A.device)
     varlen_m = cu_seqlens_m is not None
     if varlen_m:
         assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
@@ -225,7 +251,7 @@ def gemm_dact_tuned(
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
-    gemm_dact_sm90(
+    gemm_dact_sm90_sm100(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         D if not config.swap_ab else D.mT,
@@ -239,6 +265,7 @@ def gemm_dact_tuned(
         config.cluster_n,
         config.pingpong,
         persistent=True,
+        max_swizzle_size=config.max_swizzle_size,
         cu_seqlens_m=cu_seqlens_m,
         A_idx=A_idx,
     )
@@ -249,6 +276,7 @@ def gemm(
     A: Tensor,
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float | Tensor = 1.0,
     out_dtype: Optional[torch.dtype] = None,
     cu_seqlens_m: Optional[Tensor] = None,
@@ -281,6 +309,7 @@ def gemm(
         A,
         B,
         out,
+        bias=bias,
         alpha=alpha,
         alpha_tensor=alpha_tensor,
         cu_seqlens_m=cu_seqlens_m,
@@ -299,13 +328,14 @@ def gemm(
     device_types="cuda",
     # We have to split out alpha and alpha_tensor since torch.library requires
     # each argument to have a fixed type
-    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    # schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? bias, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_out(
     # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
     cu_seqlens_m: Optional[Tensor] = None,
@@ -323,6 +353,7 @@ def gemm_out(
         B,
         out,
         C=None,
+        bias=bias,
         alpha=alpha,
         cu_seqlens_m=cu_seqlens_m,
         cu_seqlens_k=cu_seqlens_k,
@@ -337,6 +368,7 @@ def gemm_ref(
     A: Tensor,
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float | Tensor = 1.0,
     cu_seqlens_m: Optional[Tensor] = None,
     cu_seqlens_k: Optional[Tensor] = None,
@@ -349,6 +381,11 @@ def gemm_ref(
     if cu_seqlens_m is None and cu_seqlens_k is None:
         fn = torch.bmm if A.ndim == 3 else torch.mm
         out = fn(A, B, out_dtype=out_dtype, out=out)
+        if not isinstance(alpha, float) or alpha != 1.0:
+            out *= alpha
+        if bias is not None:
+            bias = bias if A.ndim == 2 else bias.unsqueeze(1)
+            out += bias
     elif cu_seqlens_m is not None:
         # Handle varlen_m case
         if out is None:
@@ -362,6 +399,10 @@ def gemm_ref(
                 else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
             )
             torch.mm(A_slice, B[i], out=out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]])
+            if not isinstance(alpha, float) or alpha != 1.0:
+                out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] *= alpha
+            if bias is not None:
+                out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] += bias[i]
     else:  # cu_seqlens_k is not None
         L = cu_seqlens_k.shape[0] - 1
         if out is None:
@@ -373,8 +414,10 @@ def gemm_ref(
                 else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
             )
             torch.mm(A_slice, B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :], out=out[i])
-    if not isinstance(alpha, float) or alpha != 1.0:
-        out = out * alpha
+    if not isinstance(alpha, float) or alpha != 1.0:
+        out *= alpha
+    if bias is not None:
+        out += bias
     return out


@@ -488,6 +531,7 @@ def gemm_add_ref(
     A: Tensor,
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
@@ -499,7 +543,7 @@ def gemm_add_ref(
     """Reference implementation for GEMM with addition and pre-allocated output."""
     if cu_seqlens_m is None and cu_seqlens_k is None:
         if isinstance(alpha, float) and isinstance(beta, float):
-            return torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
+            out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
         else:
             out_dtype = (
                 out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
@@ -507,7 +551,9 @@ def gemm_add_ref(
             result = (alpha * (A @ B) + beta * C).to(out_dtype)
             if out is not None:
                 out.copy_(result)
-            return result
+        if bias is not None:
+            bias = bias if A.ndim == 2 else bias.unsqueeze(1)
+            out += bias
     elif cu_seqlens_m is not None:
         # Handle varlen_m case
         if out is None:
@@ -524,6 +570,8 @@ def gemm_add_ref(
             C_slice = C[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
             out_slice = out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
             result = alpha * torch.mm(A_slice, B[i]) + beta * C_slice
+            if bias is not None:
+                result += bias[i]
             out_slice.copy_(result)
     else:  # cu_seqlens_k is not None
         # Handle varlen_k case
@@ -540,6 +588,8 @@ def gemm_add_ref(
             B_slice = B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :]
             result = alpha * torch.mm(A_slice, B_slice) + beta * C[i]
             out[i].copy_(result)
+        if bias is not None:
+            out += bias
     return out


@@ -639,6 +689,7 @@ def gemm_act(
     A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
     B: Tensor,  # (K, N) or (L, K, N)
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
     preact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
@@ -667,7 +718,17 @@ def gemm_act(
     if postact_out is None:
         postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
     gemm_act_out(
-        A, B, preact_out, postact_out, C, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned
+        A,
+        B,
+        preact_out,
+        postact_out,
+        C,
+        bias,
+        activation,
+        cu_seqlens_m,
+        A_idx,
+        dynamic_scheduler,
+        tuned,
     )
     return preact_out, postact_out

@@ -676,7 +737,7 @@ def gemm_act(
     "quack::gemm_act_out",
     mutates_args=("preact_out", "postact_out"),
     device_types="cuda",
-    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_act_out(
     A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
@@ -684,6 +745,7 @@ def gemm_act_out(
     preact_out: Optional[Tensor],  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
     cu_seqlens_m: Optional[Tensor] = None,
     A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
@@ -692,13 +754,14 @@ def gemm_act_out(
 ) -> None:
     """GEMM with activation and pre-allocated output tensors."""
     fn = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None)
-    fn(A, B, preact_out, postact_out, C, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
+    fn(A, B, preact_out, postact_out, C, bias, activation, cu_seqlens_m, A_idx, dynamic_scheduler)


 def gemm_act_ref(
     A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
     cu_seqlens_m: Optional[Tensor] = None,
     A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
@@ -709,9 +772,9 @@ def gemm_act_ref(
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
     if C is None:
-        out = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+        out = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     else:
-        out = gemm_add_ref(A, B, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+        out = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype)
     return out.to(out_dtype) if store_preact else None, postact

@@ -806,6 +869,7 @@ def gemm_gated_ref(
     A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
     B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
     C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
     cu_seqlens_m: Optional[Tensor] = None,
     A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
@@ -832,9 +896,9 @@ def gemm_gated_ref(
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
     if C is None:
-        preact = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+        preact = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     else:
-        preact = gemm_add_ref(A, B, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+        preact = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     # Split preact into gate and up projections
     gate = preact[..., ::2]  # (M, N//2)
     up = preact[..., 1::2]  # (M, N//2)
@@ -873,14 +937,96 @@ def gemm_dgated_ref(
     # Split PreAct into gate and up projections
     gate = PreAct[..., ::2]  # (M, N)
     up = PreAct[..., 1::2]  # (M, N)
-    postact = gated_to_pytorch_fn_map[activation](gate, up)
     # Use autograd to compute gradients w.r.t. gate and up
+    gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad
+    gate.requires_grad_(True)
+    up.requires_grad_(True)
+    postact = gated_to_pytorch_fn_map[activation](gate, up)
     dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
+    gate.requires_grad_(gate_requires_grad)
+    up.requires_grad_(up_requires_grad)
     # Interleave gradients back
     dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
     return dx.to(out_dtype), postact.to(postact_dtype)


+@torch.library.custom_op(
+    "quack::gemm_symmetric_out",
+    mutates_args=("out",),
+    device_types="cuda",
+    schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? C=None, bool dynamic_scheduler=False, float alpha=1.0, float beta=1.0) -> ()",
+)
+def gemm_symmetric_out(
+    A: Tensor,  # (M, K) or (L, M, K)
+    B: Tensor,  # (K, M) or (L, K, M)
+    out: Tensor,  # (M, M) or (L, M, M)
+    C: Optional[Tensor] = None,  # (M, M) or (L, M, M)
+    dynamic_scheduler: bool = False,
+    alpha: float = 1.0,
+    beta: float = 1.0,
+) -> None:
+    """GEMM with guaranteed symmetric output."""
+    if A.ndim == 2:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (M, K) or (L, M, K)
+    if B.ndim == 2:
+        B = B.unsqueeze(0)  # (1, M, K)
+    if C is not None and C.ndim == 2:
+        C = C.unsqueeze(0)  # (1, M, M)
+    if out.ndim == 2:
+        out = out.unsqueeze(0)
+    else:
+        out = out
+    tile_count_semaphore = (
+        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+    )
+    gemm_symmetric_sm90_sm100(
+        A,
+        B,
+        out if out is not None else None,
+        C if C is not None else None,
+        tile_count_semaphore,
+        tile_M=128,
+        tile_N=256,
+        cluster_M=2,
+        cluster_N=1,
+        pingpong=False,
+        persistent=True,
+        max_swizzle_size=8,
+        alpha=alpha,
+        beta=beta,
+    )
+
+
+def gemm_symmetric(
+    A: Tensor,  # (M, K) or (L, M, K)
+    B: Tensor,  # (K, M) or (L, K, M)
+    C: Optional[Tensor] = None,  # (M, M) or (L, M, M)
+    out: Optional[Tensor] = None,  # (M, M) or (L, M, M)
+    out_dtype: Optional[torch.dtype] = None,
+    dynamic_scheduler: bool = False,
+    alpha: float | Tensor = 1.0,
+    beta: float | Tensor = 1.0,
+) -> Tuple[Optional[Tensor], Tensor]:
+    """GEMM with symmetric output."""
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    # Determine output shape based on gather_A
+    if A.ndim == 2:
+        out_shape = (A.shape[0], B.shape[-1])
+    else:
+        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
+    if out is None:
+        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
+
+    alpha_val = alpha if isinstance(alpha, float) else 1.0
+    beta_val = beta if isinstance(beta, float) else 1.0
+
+    gemm_symmetric_out(
+        A, B, out, C, dynamic_scheduler=dynamic_scheduler, alpha=alpha_val, beta=beta_val
+    )
+    return out
+
+
 # TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
 # try:
 #     from torch._inductor.fx_passes.reinplace import InplaceableOp
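
For orientation, a minimal usage sketch of the interface additions shown above (the bias argument threaded through gemm and gemm_act, and the new gemm_symmetric helper). Shapes, dtypes, and the bf16/CUDA setup are illustrative assumptions, not documented requirements; check the 0.2.4 source for exact semantics.

import torch
from quack.gemm_interface import gemm, gemm_symmetric

# Illustrative sizes: M=1024, K=512, N=2048 on an SM90/SM100 GPU.
A = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
B = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(2048, device="cuda", dtype=torch.bfloat16)  # (N,), added to every row of A @ B

out = torch.empty(1024, 2048, device="cuda", dtype=torch.bfloat16)
gemm(A, B, out, bias=bias)  # out = A @ B + bias; bias is forwarded to the kernel epilogue

# New in 0.2.4: symmetric GEMM, A (M, K) times B (K, M) giving an (M, M) result.
X = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
S = gemm_symmetric(X, X.mT)  # (1024, 1024), guaranteed symmetric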