quack-kernels 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. quack/__init__.py +1 -8
  2. quack/activation.py +366 -121
  3. quack/autotuner.py +64 -5
  4. quack/broadcast_utils.py +29 -0
  5. quack/compile_utils.py +19 -0
  6. quack/copy_utils.py +487 -0
  7. quack/cross_entropy.py +157 -233
  8. quack/cute_dsl_utils.py +20 -35
  9. quack/gemm.py +194 -0
  10. quack/gemm_act.py +510 -0
  11. quack/gemm_config.py +72 -46
  12. quack/gemm_dact.py +215 -0
  13. quack/gemm_default_epi.py +259 -0
  14. quack/gemm_interface.py +615 -146
  15. quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
  16. quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
  17. quack/gemm_symmetric.py +330 -0
  18. quack/gemm_wrapper_utils.py +182 -23
  19. quack/layout_utils.py +287 -0
  20. quack/linear.py +24 -16
  21. quack/pipeline.py +158 -3
  22. quack/reduce.py +88 -49
  23. quack/reduction_base.py +25 -36
  24. quack/rmsnorm.py +508 -624
  25. quack/sm100_utils.py +62 -0
  26. quack/sm90_utils.py +127 -0
  27. quack/softmax.py +135 -203
  28. quack/sort/bitonic_sort.py +13 -10
  29. quack/sort/utils.py +6 -6
  30. quack/tile_scheduler.py +55 -61
  31. quack/topk.py +409 -85
  32. quack/utils.py +37 -172
  33. quack/varlen_utils.py +370 -6
  34. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
  35. quack_kernels-0.2.3.dist-info/RECORD +44 -0
  36. quack/gemm_act_sm90.py +0 -368
  37. quack/gemm_dact_sm90.py +0 -150
  38. quack/layernorm.py +0 -353
  39. quack/symmetric_dense_gemm_sm90.py +0 -2091
  40. quack_kernels-0.2.1.dist-info/RECORD +0 -37
  41. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
  42. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
  43. {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
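For orientation (editorial note, not part of the registry diff): 0.2.3 renames the architecture-specific dense GEMM modules (dense_gemm_sm90.py → gemm_sm90.py, dense_gemm_sm100.py → gemm_sm100.py), adds unified wrappers (gemm.py, gemm_act.py, gemm_dact.py, gemm_default_epi.py, gemm_symmetric.py), and drops the SM90-only gemm_act_sm90.py, gemm_dact_sm90.py, symmetric_dense_gemm_sm90.py and layernorm.py. A minimal usage sketch of the public wrappers, assuming a CUDA device the kernels support:

import torch
from quack.gemm_interface import gemm, gemm_add

A = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
B = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)
C = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)
D = gemm(A, B)                   # (1024, 2048); output allocated for you
E = gemm_add(A, B, C, beta=0.5)  # alpha * A @ B + beta * C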
quack/gemm_interface.py CHANGED
@@ -9,9 +9,11 @@ from torch import Tensor
 from quack.gemm_config import GemmConfig, get_all_configs
 
 from quack.autotuner import autotune, AutotuneConfig
-from quack.dense_gemm_sm90 import gemm_sm90
-from quack.gemm_act_sm90 import gemm_act_sm90
-from quack.gemm_dact_sm90 import gemm_dact_sm90
+from quack.cute_dsl_utils import get_device_capacity
+from quack.gemm import gemm as gemm_sm90_sm100
+from quack.gemm_act import gemm_act as gemm_act_sm90_sm100
+from quack.gemm_dact import gemm_dact as gemm_dact_sm90_sm100
+from quack.gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100
 
 
 # Dictionary mapping activation names to PyTorch functions
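The interface no longer imports per-architecture kernels; a single dispatcher from quack.gemm (aliased gemm_sm90_sm100) serves both Hopper and Blackwell. A hedged migration sketch for downstream code that pinned the 0.2.1 module paths (the major-version reading of get_device_capacity is inferred from its use in the diff below):

import torch

# 0.2.1:  from quack.dense_gemm_sm90 import gemm_sm90
# 0.2.3:
from quack.gemm import gemm                      # dispatches per device capability
from quack.cute_dsl_utils import get_device_capacity

major = get_device_capacity(torch.device("cuda"))[0]  # 9 on SM90 (Hopper), 10 on SM100 (Blackwell)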
@@ -34,34 +36,89 @@ gated_to_pytorch_fn_map = {
 }
 
 
+default_device_capacity = get_device_capacity(torch.device("cuda"))
+
+
+def default_config(device):
+    if get_device_capacity(device)[0] != 10:
+        return GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+    else:
+        return GemmConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False)
+
+
+def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
+    kwargs = named_args | kwargs
+    gather_A = kwargs.get("A_idx", None) is not None
+    varlen_m = kwargs.get("cu_seqlens_m", None) is not None
+    if varlen_m or gather_A:  # Doesn't support swap_ab
+        configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
+    if gather_A:
+        if get_device_capacity(kwargs["A"].device)[0] == 9:
+            # tile_n == 208 causes register spills, as gather_A requires more registers for the producer
+            configs = [
+                conf
+                for conf in configs
+                if conf.kwargs["config"].cluster_n == 1 and conf.kwargs["config"].tile_n != 208
+            ]
+    return configs
+
+
 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
     key=["dynamic_scheduler"],
+    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
 def gemm_tuned(
-    A: Tensor,  # (M, K)
-    B: Tensor,  # (K, N)
-    out: Tensor,  # (M, N) - required output tensor
-    C: Optional[Tensor] = None,  # (M, N)
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
+    A: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float | Tensor = 1.0,  # (1,)
     beta: float | Tensor = 1.0,  # (1,)
+    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
+    cu_seqlens_k: Optional[Tensor] = None,  # (L+1), int32
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
+    add_to_output: bool = False,
     dynamic_scheduler: bool = False,
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
-        config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
-    A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
-    if C is not None:
+        config = default_config(A.device)
+    varlen_m = cu_seqlens_m is not None
+    varlen_k = cu_seqlens_k is not None
+    varlen = varlen_m or varlen_k
+    gather_A = A_idx is not None
+    if gather_A:
+        assert varlen, "gather_A requires either varlen_m or varlen_k"
+        assert config.cluster_n == 1, "gather_A requires cluster_n=1"
+    if varlen_m:
+        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+    if A.ndim == 2 and not varlen:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (N, K) or (L, N, K) or (N, total_K)
+    if B.ndim == 2 and not varlen_k:
+        B = B.unsqueeze(0)  # (1, N, K)
+    if C is not None and C.ndim == 2 and not varlen_m:
         C = C.unsqueeze(0)  # (1, M, N)
-    assert out.shape == (
-        A.shape[1],
-        B.shape[1],
-    ), f"out shape mismatch: {out.shape} vs {(A.shape[1], B.shape[1])}"
-    out = out.unsqueeze(0)
+    if out.ndim == 2 and not varlen_m:
+        out = out.unsqueeze(0)
+    if bias is not None and bias.ndim == 1:
+        bias = bias.unsqueeze(0)  # (L, N)
+    batch_size = B.shape[0] if not varlen_k else cu_seqlens_k.shape[0] - 1
+    if varlen_m:
+        # If gather_A (A_idx provided), use its length; otherwise use A.shape[0]
+        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+        out_shape = (total_m, B.shape[-2])
+    else:
+        out_shape = (batch_size, A.shape[-2], B.shape[-2])
+    assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}"
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
-    gemm_sm90(
+    gemm_sm90_sm100(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         out if not config.swap_ab else out.mT,
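default_config now picks a tile shape by compute-capability major version, and the autotuner filters candidates through prune_invalid_gemm_configs before benchmarking. The toy sketch below mirrors the swap_ab and cluster_n rules on a stand-in config class, purely to illustrate the early_config_prune contract; it does not use quack's AutotuneConfig machinery.

from dataclasses import dataclass

@dataclass
class StubConfig:  # stand-in for the GemmConfig fields the pruner inspects
    swap_ab: bool
    cluster_n: int
    tile_n: int

def prune(configs, named_args):
    # varlen_m or gather_A rule out swap_ab; gather_A additionally needs cluster_n == 1
    # (the SM90-only tile_n == 208 exclusion is omitted here)
    if named_args.get("cu_seqlens_m") is not None or named_args.get("A_idx") is not None:
        configs = [c for c in configs if not c.swap_ab]
    if named_args.get("A_idx") is not None:
        configs = [c for c in configs if c.cluster_n == 1]
    return configs

candidates = [StubConfig(False, 1, 192), StubConfig(True, 1, 192), StubConfig(False, 2, 208)]
print(prune(candidates, {"A_idx": object()}))  # only StubConfig(False, 1, 192) survives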
@@ -72,77 +129,129 @@ def gemm_tuned(
72
129
  config.cluster_m,
73
130
  config.cluster_n,
74
131
  config.pingpong,
132
+ persistent=True,
133
+ max_swizzle_size=config.max_swizzle_size,
134
+ rowvec_bias=bias if not config.swap_ab else None,
135
+ colvec_bias=bias if config.swap_ab else None,
75
136
  alpha=alpha,
76
137
  beta=beta,
138
+ cu_seqlens_m=cu_seqlens_m,
139
+ cu_seqlens_k=cu_seqlens_k,
140
+ A_idx=A_idx,
141
+ batch_idx_permute=batch_idx_permute,
142
+ add_to_output=add_to_output,
77
143
  )
78
144
 
79
145
 
80
146
  @autotune(
81
- configs=[AutotuneConfig(config=c) for c in get_all_configs()],
82
- key=["activation"],
147
+ configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
148
+ key=["activation", "dynamic_scheduler"],
149
+ prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
83
150
  )
84
151
  def gemm_act_tuned(
85
- A: Tensor, # (M, K)
86
- B: Tensor, # (K, N)
87
- preact_out: Optional[Tensor], # (M, N) - None if not storing preact
88
- postact_out: Tensor, # (M, N)
89
- C: Optional[Tensor] = None, # (M, N)
152
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
153
+ A: Tensor,
154
+ B: Tensor, # (K, N) or (L, K, N)
155
+ # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
156
+ preact_out: Optional[Tensor],
157
+ postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
158
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
159
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
90
160
  activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
161
+ cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
162
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
163
+ dynamic_scheduler: bool = False,
91
164
  config: Optional[GemmConfig] = None,
92
165
  ) -> None:
93
166
  if config is None:
94
- config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
95
- A, B = A.unsqueeze(0), B.mT.unsqueeze(0) # (1, M, K), (1, N, K)
96
- if C is not None:
167
+ config = default_config(A.device)
168
+ varlen_m = cu_seqlens_m is not None
169
+ if varlen_m:
170
+ assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
171
+ if A.ndim == 2 and not varlen_m:
172
+ A = A.unsqueeze(0) # (1, M, K)
173
+ B = B.mT # (N, K) or (L, N, K)
174
+ if B.ndim == 2:
175
+ B = B.unsqueeze(0) # (1, N, K)
176
+ if C is not None and C.ndim == 2 and not varlen_m:
97
177
  C = C.unsqueeze(0) # (1, M, N)
98
- if preact_out is not None:
99
- assert preact_out.shape == (A.shape[1], B.shape[1])
178
+ if preact_out is not None and preact_out.ndim == 2 and not varlen_m:
100
179
  D = preact_out.unsqueeze(0)
101
180
  else:
102
- D = None
103
- assert postact_out.shape == (A.shape[1], B.shape[1])
104
- PostAct = postact_out.unsqueeze(0)
105
- gemm_act_sm90(
181
+ D = preact_out
182
+ if postact_out.ndim == 2 and not varlen_m:
183
+ PostAct = postact_out.unsqueeze(0)
184
+ else:
185
+ PostAct = postact_out
186
+ if bias is not None and bias.ndim == 1:
187
+ bias = bias.unsqueeze(0) # (L, N)
188
+ tile_count_semaphore = (
189
+ torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
190
+ )
191
+ gemm_act_sm90_sm100(
106
192
  A if not config.swap_ab else B,
107
193
  B if not config.swap_ab else A,
108
194
  (D if not config.swap_ab else D.mT) if D is not None else None,
109
195
  (C if not config.swap_ab else C.mT) if C is not None else None,
110
196
  PostAct if not config.swap_ab else PostAct.mT,
197
+ tile_count_semaphore,
111
198
  activation,
112
199
  config.tile_m,
113
200
  config.tile_n,
114
201
  config.cluster_m,
115
202
  config.cluster_n,
116
203
  config.pingpong,
204
+ persistent=True,
205
+ max_swizzle_size=config.max_swizzle_size,
206
+ rowvec_bias=bias if not config.swap_ab else None,
207
+ colvec_bias=bias if config.swap_ab else None,
208
+ cu_seqlens_m=cu_seqlens_m,
209
+ A_idx=A_idx,
117
210
  )
118
211
 
119
212
 
120
213
  @autotune(
121
- configs=[AutotuneConfig(config=c) for c in get_all_configs()],
214
+ configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
122
215
  key=["activation", "dynamic_scheduler"],
216
+ prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
123
217
  )
124
218
  def gemm_dact_tuned(
125
- A: Tensor, # (M, K)
126
- B: Tensor, # (K, N)
127
- PreAct: Tensor, # (M, N)
128
- dx_out: Tensor, # (M, N)
129
- postact_out: Tensor, # (M, N)
219
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
220
+ A: Tensor,
221
+ B: Tensor, # (K, N) or (L, K, N)
222
+ PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
223
+ dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
224
+ postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
130
225
  activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
226
+ cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
227
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
131
228
  dynamic_scheduler: bool = True,
132
229
  config: Optional[GemmConfig] = None,
133
230
  ) -> None:
134
231
  if config is None:
135
- config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
136
- A, B = A.unsqueeze(0), B.mT.unsqueeze(0) # (1, M, K), (1, N, K)
137
- PreAct = PreAct.unsqueeze(0) # (1, M, N)
138
- assert dx_out.shape == (A.shape[1], B.shape[1])
139
- D = dx_out.unsqueeze(0)
140
- assert postact_out.shape == (A.shape[1], B.shape[1])
141
- PostAct = postact_out.unsqueeze(0)
232
+ config = default_config(A.device)
233
+ varlen_m = cu_seqlens_m is not None
234
+ if varlen_m:
235
+ assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
236
+ if A.ndim == 2 and not varlen_m:
237
+ A = A.unsqueeze(0) # (1, M, K)
238
+ B = B.mT # (N, K) or (L, N, K)
239
+ if B.ndim == 2:
240
+ B = B.unsqueeze(0) # (1, N, K)
241
+ if PreAct.ndim == 2 and not varlen_m:
242
+ PreAct = PreAct.unsqueeze(0) # (1, M, N)
243
+ if dx_out.ndim == 2 and not varlen_m:
244
+ D = dx_out.unsqueeze(0)
245
+ else:
246
+ D = dx_out
247
+ if postact_out.ndim == 2 and not varlen_m:
248
+ PostAct = postact_out.unsqueeze(0)
249
+ else:
250
+ PostAct = postact_out
142
251
  tile_count_semaphore = (
143
252
  torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
144
253
  )
145
- gemm_dact_sm90(
254
+ gemm_dact_sm90_sm100(
146
255
  A if not config.swap_ab else B,
147
256
  B if not config.swap_ab else A,
148
257
  D if not config.swap_ab else D.mT,
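gemm_tuned now understands packed variable-length batches. A plain-PyTorch sketch of the varlen_m semantics (matching the gemm_ref branch added later in this diff): cu_seqlens_m marks sequence boundaries along a packed M dimension and each sequence i is multiplied by its own B[i]; the shapes below are invented for illustration.

import torch

L, K, N = 2, 8, 16
cu_seqlens_m = torch.tensor([0, 3, 7], dtype=torch.int32)  # two sequences, lengths 3 and 4
A = torch.randn(7, K)        # (total_M, K), sequences packed along M
B = torch.randn(L, K, N)     # one weight matrix per sequence
out = torch.empty(7, N)      # (total_M, N)
for i in range(L):
    s, e = cu_seqlens_m[i], cu_seqlens_m[i + 1]
    out[s:e] = A[s:e] @ B[i]
# varlen_k is the transposed analogue: cu_seqlens_k splits a shared K dimension and
# produces an (L, M, N) output; gather_A replaces A[s:e] with A[A_idx[s:e]].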
@@ -155,30 +264,58 @@ def gemm_dact_tuned(
155
264
  config.cluster_m,
156
265
  config.cluster_n,
157
266
  config.pingpong,
267
+ persistent=True,
268
+ max_swizzle_size=config.max_swizzle_size,
269
+ cu_seqlens_m=cu_seqlens_m,
270
+ A_idx=A_idx,
158
271
  )
159
272
 
160
273
 
161
274
  def gemm(
275
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
162
276
  A: Tensor,
163
- B: Tensor,
164
- out: Optional[Tensor] = None,
277
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
278
+ out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
279
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
165
280
  alpha: float | Tensor = 1.0,
166
281
  out_dtype: Optional[torch.dtype] = None,
282
+ cu_seqlens_m: Optional[Tensor] = None,
283
+ cu_seqlens_k: Optional[Tensor] = None,
284
+ A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
285
+ batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
167
286
  dynamic_scheduler: bool = False,
168
287
  tuned: bool = True,
169
288
  ) -> Tensor:
170
289
  """GEMM with optional output tensor and tuning control."""
171
290
  if out is None:
172
291
  out_dtype = A.dtype if out_dtype is None else out_dtype
173
- out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
292
+ varlen_m = cu_seqlens_m is not None
293
+ varlen_k = cu_seqlens_k is not None
294
+ if varlen_m:
295
+ total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
296
+ out_shape = (total_m, B.shape[-1])
297
+ elif varlen_k:
298
+ L = cu_seqlens_k.shape[0] - 1
299
+ # For varlen_k, the first dimension is always A.shape[0] (M dimension)
300
+ out_shape = (L, A.shape[0], B.shape[-1])
301
+ else:
302
+ out_shape = (
303
+ (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1])
304
+ )
305
+ out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
174
306
  alpha_tensor = alpha if not isinstance(alpha, float) else None
175
307
  alpha = alpha if isinstance(alpha, float) else 1.0
176
308
  gemm_out(
177
309
  A,
178
310
  B,
179
311
  out,
312
+ bias=bias,
180
313
  alpha=alpha,
181
314
  alpha_tensor=alpha_tensor,
315
+ cu_seqlens_m=cu_seqlens_m,
316
+ cu_seqlens_k=cu_seqlens_k,
317
+ A_idx=A_idx,
318
+ batch_idx_permute=batch_idx_permute,
182
319
  dynamic_scheduler=dynamic_scheduler,
183
320
  tuned=tuned,
184
321
  )
@@ -191,53 +328,134 @@ def gemm(
191
328
  device_types="cuda",
192
329
  # We have to split out alpha and alpha_tensor since torch.library requires
193
330
  # each argument to have a fixed type
194
- # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
331
+ # schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? bias, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
195
332
  )
196
333
  def gemm_out(
334
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
197
335
  A: Tensor,
198
- B: Tensor,
199
- out: Tensor,
336
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
337
+ out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
338
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
200
339
  alpha: float = 1.0,
201
340
  alpha_tensor: Optional[Tensor] = None,
341
+ cu_seqlens_m: Optional[Tensor] = None,
342
+ cu_seqlens_k: Optional[Tensor] = None,
343
+ A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
344
+ batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
202
345
  dynamic_scheduler: bool = False,
203
346
  tuned: bool = True,
204
347
  ) -> None:
205
348
  """GEMM with pre-allocated output tensor."""
206
349
  fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
207
350
  alpha = alpha_tensor if alpha_tensor is not None else alpha
208
- fn(A, B, out, C=None, alpha=alpha, dynamic_scheduler=dynamic_scheduler)
351
+ fn(
352
+ A,
353
+ B,
354
+ out,
355
+ C=None,
356
+ bias=bias,
357
+ alpha=alpha,
358
+ cu_seqlens_m=cu_seqlens_m,
359
+ cu_seqlens_k=cu_seqlens_k,
360
+ A_idx=A_idx,
361
+ batch_idx_permute=batch_idx_permute,
362
+ dynamic_scheduler=dynamic_scheduler,
363
+ )
209
364
 
210
365
 
211
366
  def gemm_ref(
367
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
212
368
  A: Tensor,
213
- B: Tensor,
214
- out: Optional[Tensor] = None,
369
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
370
+ out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
371
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
215
372
  alpha: float | Tensor = 1.0,
373
+ cu_seqlens_m: Optional[Tensor] = None,
374
+ cu_seqlens_k: Optional[Tensor] = None,
375
+ A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
216
376
  out_dtype: Optional[torch.dtype] = None,
217
377
  ) -> Tensor:
218
378
  """Reference implementation for GEMM with pre-allocated output."""
219
379
  # The out_dtype argument requires torch >= 2.8
220
- out = torch.mm(A, B, out_dtype=out_dtype, out=out)
221
- if not isinstance(alpha, float) or alpha != 1.0:
222
- out = out * alpha
380
+ out_dtype = A.dtype if out_dtype is None else out_dtype
381
+ if cu_seqlens_m is None and cu_seqlens_k is None:
382
+ fn = torch.bmm if A.ndim == 3 else torch.mm
383
+ out = fn(A, B, out_dtype=out_dtype, out=out)
384
+ if not isinstance(alpha, float) or alpha != 1.0:
385
+ out *= alpha
386
+ if bias is not None:
387
+ bias = bias if A.ndim == 2 else bias.unsqueeze(1)
388
+ out += bias
389
+ elif cu_seqlens_m is not None:
390
+ # Handle varlen_m case
391
+ if out is None:
392
+ # When gather_A (A_idx provided), output size is determined by A_idx length
393
+ total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
394
+ out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
395
+ for i in range(cu_seqlens_m.shape[0] - 1):
396
+ A_slice = (
397
+ A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
398
+ if A_idx is not None
399
+ else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
400
+ )
401
+ torch.mm(A_slice, B[i], out=out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]])
402
+ if not isinstance(alpha, float) or alpha != 1.0:
403
+ out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] *= alpha
404
+ if bias is not None:
405
+ out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] += bias[i]
406
+ else: # cu_seqlens_k is not None
407
+ L = cu_seqlens_k.shape[0] - 1
408
+ if out is None:
409
+ out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
410
+ for i in range(L):
411
+ A_slice = (
412
+ A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
413
+ if A_idx is not None
414
+ else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
415
+ )
416
+ torch.mm(A_slice, B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :], out=out[i])
417
+ if not isinstance(alpha, float) or alpha != 1.0:
418
+ out *= alpha
419
+ if bias is not None:
420
+ out += bias
223
421
  return out
224
422
 
225
423
 
226
424
  def gemm_add(
425
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
227
426
  A: Tensor,
228
- B: Tensor,
229
- C: Tensor,
230
- out: Optional[Tensor] = None,
427
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
428
+ C: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
429
+ out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
231
430
  alpha: float | Tensor = 1.0,
232
431
  beta: float | Tensor = 1.0,
233
432
  out_dtype: Optional[torch.dtype] = None,
433
+ cu_seqlens_m: Optional[Tensor] = None,
434
+ cu_seqlens_k: Optional[Tensor] = None,
435
+ A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
436
+ batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
234
437
  dynamic_scheduler: bool = False,
235
438
  tuned: bool = True,
236
439
  ) -> Tensor:
237
440
  """GEMM with addition and optional output tensor."""
238
441
  if out is None:
239
442
  out_dtype = A.dtype if out_dtype is None else out_dtype
240
- out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
443
+ varlen_m = cu_seqlens_m is not None
444
+ varlen_k = cu_seqlens_k is not None
445
+ if varlen_m:
446
+ # If A_idx is provided (gather_A), use its length; otherwise use A.shape[0]
447
+ total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
448
+ out_shape = (total_m, B.shape[-1])
449
+ elif varlen_k:
450
+ L = cu_seqlens_k.shape[0] - 1
451
+ # For varlen_k, the first dimension is always A.shape[0] (M dimension)
452
+ out_shape = (L, A.shape[0], B.shape[-1])
453
+ else:
454
+ out_shape = (
455
+ (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1])
456
+ )
457
+ out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
458
+ add_to_output = C is out and isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
241
459
  alpha_tensor = alpha if not isinstance(alpha, float) else None
242
460
  alpha = alpha if isinstance(alpha, float) else 1.0
243
461
  beta_tensor = beta if not isinstance(beta, float) else None
@@ -245,12 +463,17 @@ def gemm_add(
245
463
  gemm_add_out(
246
464
  A,
247
465
  B,
248
- C,
466
+ C if not add_to_output else None,
249
467
  out,
250
468
  alpha,
251
469
  beta,
252
470
  alpha_tensor,
253
471
  beta_tensor,
472
+ cu_seqlens_m=cu_seqlens_m,
473
+ cu_seqlens_k=cu_seqlens_k,
474
+ A_idx=A_idx,
475
+ batch_idx_permute=batch_idx_permute,
476
+ add_to_output=add_to_output,
254
477
  dynamic_scheduler=dynamic_scheduler,
255
478
  tuned=tuned,
256
479
  )
@@ -263,17 +486,23 @@ def gemm_add(
263
486
  device_types="cuda",
264
487
  # We have to split out alpha and alpha_tensor since torch.library requires
265
488
  # each argument to have a fixed type
266
- # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
489
+ # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
267
490
  )
268
491
  def gemm_add_out(
492
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
269
493
  A: Tensor,
270
- B: Tensor,
271
- C: Tensor,
272
- out: Tensor,
494
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
495
+ C: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
496
+ out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
273
497
  alpha: float = 1.0,
274
498
  beta: float = 1.0,
275
499
  alpha_tensor: Optional[Tensor] = None,
276
500
  beta_tensor: Optional[Tensor] = None,
501
+ cu_seqlens_m: Optional[Tensor] = None,
502
+ cu_seqlens_k: Optional[Tensor] = None,
503
+ A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
504
+ batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
505
+ add_to_output: bool = False,
277
506
  dynamic_scheduler: bool = False,
278
507
  tuned: bool = True,
279
508
  ) -> None:
@@ -281,47 +510,112 @@ def gemm_add_out(
281
510
  fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
282
511
  alpha = alpha_tensor if alpha_tensor is not None else alpha
283
512
  beta = beta_tensor if beta_tensor is not None else beta
284
- fn(A, B, out, C, alpha=alpha, beta=beta, dynamic_scheduler=dynamic_scheduler)
513
+ fn(
514
+ A,
515
+ B,
516
+ out,
517
+ C,
518
+ alpha=alpha,
519
+ beta=beta,
520
+ cu_seqlens_m=cu_seqlens_m,
521
+ cu_seqlens_k=cu_seqlens_k,
522
+ A_idx=A_idx,
523
+ batch_idx_permute=batch_idx_permute,
524
+ add_to_output=add_to_output,
525
+ dynamic_scheduler=dynamic_scheduler,
526
+ )
285
527
 
286
528
 
287
529
  def gemm_add_ref(
530
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
288
531
  A: Tensor,
289
- B: Tensor,
290
- C: Tensor,
291
- out: Optional[Tensor] = None,
532
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
533
+ C: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
534
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
535
+ out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
292
536
  alpha: float | Tensor = 1.0,
293
537
  beta: float | Tensor = 1.0,
538
+ cu_seqlens_m: Optional[Tensor] = None,
539
+ cu_seqlens_k: Optional[Tensor] = None,
540
+ A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
294
541
  out_dtype: Optional[torch.dtype] = None,
295
542
  ) -> Tensor:
296
543
  """Reference implementation for GEMM with addition and pre-allocated output."""
297
- if isinstance(alpha, float) and isinstance(beta, float):
298
- return torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
299
- else:
300
- out_dtype = (
301
- out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
302
- )
303
- result = (alpha * (A @ B) + beta * C).to(out_dtype)
304
- if out is not None:
305
- out.copy_(result)
306
- return result
544
+ if cu_seqlens_m is None and cu_seqlens_k is None:
545
+ if isinstance(alpha, float) and isinstance(beta, float):
546
+ out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
547
+ else:
548
+ out_dtype = (
549
+ out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
550
+ )
551
+ result = (alpha * (A @ B) + beta * C).to(out_dtype)
552
+ if out is not None:
553
+ out.copy_(result)
554
+ if bias is not None:
555
+ bias = bias if A.ndim == 2 else bias.unsqueeze(1)
556
+ out += bias
557
+ elif cu_seqlens_m is not None:
558
+ # Handle varlen_m case
559
+ if out is None:
560
+ # When gather_A (A_idx provided), output size is determined by A_idx length
561
+ total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
562
+ out_dtype = out_dtype if out_dtype is not None else A.dtype
563
+ out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
564
+ for i in range(cu_seqlens_m.shape[0] - 1):
565
+ A_slice = (
566
+ A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
567
+ if A_idx is not None
568
+ else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
569
+ )
570
+ C_slice = C[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
571
+ out_slice = out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
572
+ result = alpha * torch.mm(A_slice, B[i]) + beta * C_slice
573
+ if bias is not None:
574
+ result += bias[i]
575
+ out_slice.copy_(result)
576
+ else: # cu_seqlens_k is not None
577
+ # Handle varlen_k case
578
+ L = cu_seqlens_k.shape[0] - 1
579
+ out_dtype = out_dtype if out_dtype is not None else A.dtype
580
+ if out is None:
581
+ out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
582
+ for i in range(L):
583
+ A_slice = (
584
+ A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
585
+ if A_idx is not None
586
+ else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
587
+ )
588
+ B_slice = B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :]
589
+ result = alpha * torch.mm(A_slice, B_slice) + beta * C[i]
590
+ out[i].copy_(result)
591
+ if bias is not None:
592
+ out += bias
593
+ return out
307
594
 
308
595
 
309
596
  def gemm_add_inplace(
597
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
310
598
  A: Tensor,
311
- B: Tensor,
312
- out: Tensor,
599
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
600
+ out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
313
601
  alpha: float | Tensor = 1.0,
314
602
  beta: float | Tensor = 1.0,
603
+ cu_seqlens_m: Optional[Tensor] = None,
604
+ cu_seqlens_k: Optional[Tensor] = None,
605
+ A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
606
+ batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
315
607
  dynamic_scheduler: bool = False,
316
608
  tuned: bool = True,
317
609
  ) -> None:
318
610
  """In-place GEMM with addition: out = alpha * A @ B + beta * out.
319
611
  Args:
320
- A: (M, K) input tensor
321
- B: (K, N) input tensor
322
- out: (M, N) tensor to accumulate into (modified in-place)
612
+ A: (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k - input tensor
613
+ B: (K, N) or (L, K, N) or (total_K, N) if varlen_k - input tensor
614
+ out: (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k - tensor to accumulate into (modified in-place)
323
615
  alpha: Scalar multiplier for A @ B
324
616
  beta: Scalar multiplier for out
617
+ cu_seqlens_m: Optional cumulative sequence lengths for variable M
618
+ cu_seqlens_k: Optional cumulative sequence lengths for variable K
325
619
  dynamic_scheduler: Whether to use dynamic scheduler
326
620
  tuned: Whether to use autotuned configuration
327
621
  """
@@ -329,7 +623,21 @@ def gemm_add_inplace(
329
623
  alpha = alpha if isinstance(alpha, float) else 1.0
330
624
  beta_tensor = beta if not isinstance(beta, float) else None
331
625
  beta = beta if isinstance(beta, float) else 1.0
332
- gemm_add_inplace_op(A, B, out, alpha, beta, alpha_tensor, beta_tensor, dynamic_scheduler, tuned)
626
+ gemm_add_inplace_op(
627
+ A,
628
+ B,
629
+ out,
630
+ alpha,
631
+ beta,
632
+ alpha_tensor,
633
+ beta_tensor,
634
+ cu_seqlens_m,
635
+ cu_seqlens_k,
636
+ A_idx=A_idx,
637
+ batch_idx_permute=batch_idx_permute,
638
+ dynamic_scheduler=dynamic_scheduler,
639
+ tuned=tuned,
640
+ )
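gemm_add_inplace now lowers to the kernel's add_to_output path when beta is a plain 1.0, accumulating into out in the epilogue instead of passing out again as the C operand. A hedged numerical model of what the call computes; this is reference arithmetic only, not the kernel:

import torch

def gemm_add_inplace_model(A, B, out, alpha=1.0, beta=1.0):
    # out <- alpha * A @ B + beta * out (the real op does this on-GPU, in place)
    out.copy_(alpha * (A @ B) + beta * out)

A, B = torch.randn(4, 8), torch.randn(8, 5)
out = torch.randn(4, 5)
expected = A @ B + out                 # beta == 1.0 case, i.e. the add_to_output path
gemm_add_inplace_model(A, B, out)
torch.testing.assert_close(out, expected)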
333
641
 
334
642
 
335
643
  @torch.library.custom_op(
@@ -338,46 +646,90 @@ def gemm_add_inplace(
338
646
  device_types="cuda",
339
647
  # We have to split out alpha and alpha_tensor since torch.library requires
340
648
  # each argument to have a fixed type
341
- # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
649
+ # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
342
650
  )
343
651
  def gemm_add_inplace_op(
652
+ # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
344
653
  A: Tensor,
345
- B: Tensor,
346
- out: Tensor,
654
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
655
+ out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
347
656
  alpha: float = 1.0,
348
657
  beta: float = 1.0,
349
658
  alpha_tensor: Optional[Tensor] = None,
350
659
  beta_tensor: Optional[Tensor] = None,
660
+ cu_seqlens_m: Optional[Tensor] = None,
661
+ cu_seqlens_k: Optional[Tensor] = None,
662
+ A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen
663
+ batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler
351
664
  dynamic_scheduler: bool = False,
352
665
  tuned: bool = True,
353
666
  ) -> None:
354
667
  fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
355
668
  alpha = alpha_tensor if alpha_tensor is not None else alpha
356
669
  beta = beta_tensor if beta_tensor is not None else beta
357
- # Use C as both input bias and output
358
- fn(A, B, out, out, alpha=alpha, beta=beta, dynamic_scheduler=dynamic_scheduler)
670
+ add_to_output = isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
671
+ # Use out as both input bias and output
672
+ fn(
673
+ A,
674
+ B,
675
+ out,
676
+ out if not add_to_output else None,
677
+ alpha=alpha,
678
+ beta=beta,
679
+ cu_seqlens_m=cu_seqlens_m,
680
+ cu_seqlens_k=cu_seqlens_k,
681
+ A_idx=A_idx,
682
+ batch_idx_permute=batch_idx_permute,
683
+ add_to_output=add_to_output,
684
+ dynamic_scheduler=dynamic_scheduler,
685
+ )
359
686
 
360
687
 
361
688
  def gemm_act(
362
- A: Tensor,
363
- B: Tensor,
364
- C: Optional[Tensor] = None,
689
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
690
+ B: Tensor, # (K, N) or (L, K, N)
691
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
692
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
365
693
  activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
366
- preact_out: Optional[Tensor] = None,
367
- postact_out: Optional[Tensor] = None,
694
+ preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
695
+ postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
368
696
  out_dtype: Optional[torch.dtype] = None,
369
697
  postact_dtype: Optional[torch.dtype] = None,
698
+ cu_seqlens_m: Optional[Tensor] = None,
699
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
370
700
  store_preact: bool = True,
701
+ dynamic_scheduler: bool = False,
371
702
  tuned: bool = True,
372
703
  ) -> Tuple[Optional[Tensor], Tensor]:
373
704
  """GEMM with activation and optional output tensors."""
374
705
  out_dtype = A.dtype if out_dtype is None else out_dtype
375
706
  postact_dtype = A.dtype if postact_dtype is None else postact_dtype
707
+ varlen_m = cu_seqlens_m is not None
708
+ # Determine output shape based on gather_A
709
+ if varlen_m:
710
+ total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
711
+ out_shape = (total_m, B.shape[-1])
712
+ elif A.ndim == 2:
713
+ out_shape = (A.shape[0], B.shape[-1])
714
+ else:
715
+ out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
376
716
  if preact_out is None and store_preact:
377
- preact_out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
717
+ preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
378
718
  if postact_out is None:
379
- postact_out = torch.empty((A.shape[0], B.shape[1]), dtype=postact_dtype, device=A.device)
380
- gemm_act_out(A, B, preact_out, postact_out, C, activation, tuned)
719
+ postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
720
+ gemm_act_out(
721
+ A,
722
+ B,
723
+ preact_out,
724
+ postact_out,
725
+ C,
726
+ bias,
727
+ activation,
728
+ cu_seqlens_m,
729
+ A_idx,
730
+ dynamic_scheduler,
731
+ tuned,
732
+ )
381
733
  return preact_out, postact_out
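gemm_act fuses the activation into the GEMM epilogue and optionally keeps the pre-activation for the backward pass. A minimal usage sketch (shapes invented; requires a CUDA device the kernels support):

import torch
from quack.gemm_interface import gemm_act

A = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
B = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
preact, postact = gemm_act(A, B, activation="relu", store_preact=True)
# postact should match torch.relu(preact); with store_preact=False, preact is None.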
382
734
 
383
735
 
@@ -385,58 +737,81 @@ def gemm_act(
385
737
  "quack::gemm_act_out",
386
738
  mutates_args=("preact_out", "postact_out"),
387
739
  device_types="cuda",
388
- schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, bool tuned=True) -> ()",
740
+ schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
389
741
  )
390
742
  def gemm_act_out(
391
- A: Tensor,
392
- B: Tensor,
393
- preact_out: Optional[Tensor],
394
- postact_out: Tensor,
395
- C: Optional[Tensor] = None,
743
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
744
+ B: Tensor, # (K, N) or (L, K, N)
745
+ preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m
746
+ postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
747
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
748
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
396
749
  activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
750
+ cu_seqlens_m: Optional[Tensor] = None,
751
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
752
+ dynamic_scheduler: bool = False,
397
753
  tuned: bool = True,
398
754
  ) -> None:
399
755
  """GEMM with activation and pre-allocated output tensors."""
400
756
  fn = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None)
401
- fn(A, B, preact_out, postact_out, C, activation)
757
+ fn(A, B, preact_out, postact_out, C, bias, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
402
758
 
403
759
 
404
760
  def gemm_act_ref(
405
- A: Tensor,
406
- B: Tensor,
407
- C: Optional[Tensor] = None,
761
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
762
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
763
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
764
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
408
765
  activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
766
+ cu_seqlens_m: Optional[Tensor] = None,
767
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
409
768
  out_dtype: Optional[torch.dtype] = None,
410
769
  postact_dtype: Optional[torch.dtype] = None,
411
770
  store_preact: bool = True,
412
771
  ) -> Tuple[Optional[Tensor], Tensor]:
413
772
  out_dtype = A.dtype if out_dtype is None else out_dtype
414
773
  postact_dtype = A.dtype if postact_dtype is None else postact_dtype
415
- out = torch.mm(A, B) if C is None else C + torch.mm(A, B)
774
+ if C is None:
775
+ out = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
776
+ else:
777
+ out = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
416
778
  postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype)
417
779
  return out.to(out_dtype) if store_preact else None, postact
418
780
 
419
781
 
420
782
  def gemm_dact(
421
- A: Tensor,
422
- B: Tensor,
423
- PreAct: Tensor,
783
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
784
+ B: Tensor, # (K, N) or (L, K, N)
785
+ PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
424
786
  activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
425
- dx_out: Optional[Tensor] = None,
426
- postact_out: Optional[Tensor] = None,
787
+ dx_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
788
+ postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
427
789
  out_dtype: Optional[torch.dtype] = None,
428
790
  postact_dtype: Optional[torch.dtype] = None,
791
+ cu_seqlens_m: Optional[Tensor] = None,
792
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
429
793
  dynamic_scheduler: bool = True,
430
794
  tuned: bool = True,
431
795
  ) -> Tuple[Tensor, Tensor]:
432
796
  """GEMM with activation gradient and optional output tensors."""
433
797
  out_dtype = A.dtype if out_dtype is None else out_dtype
434
798
  postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
799
+ varlen_m = cu_seqlens_m is not None
800
+ # Determine output shape based on gather_A
801
+ if varlen_m:
802
+ total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
803
+ out_shape = (total_m, B.shape[-1])
804
+ elif A.ndim == 2:
805
+ out_shape = (A.shape[0], B.shape[-1])
806
+ else:
807
+ out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
435
808
  if dx_out is None:
436
- dx_out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
809
+ dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
437
810
  if postact_out is None:
438
- postact_out = torch.empty((A.shape[0], B.shape[1]), dtype=postact_dtype, device=A.device)
439
- gemm_dact_out(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler, tuned)
811
+ postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
812
+ gemm_dact_out(
813
+ A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned
814
+ )
440
815
  return dx_out, postact_out
441
816
 
442
817
 
@@ -444,35 +819,39 @@ def gemm_dact(
444
819
  "quack::gemm_dact_out",
445
820
  mutates_args=("dx_out", "postact_out"),
446
821
  device_types="cuda",
447
- schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
822
+ schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
448
823
  )
449
824
  def gemm_dact_out(
450
- A: Tensor,
451
- B: Tensor,
452
- PreAct: Tensor,
453
- dx_out: Tensor,
454
- postact_out: Tensor,
825
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
826
+ B: Tensor, # (K, N) or (L, K, N)
827
+ PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
828
+ dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
829
+ postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
455
830
  activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
831
+ cu_seqlens_m: Optional[Tensor] = None,
832
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
456
833
  dynamic_scheduler: bool = True,
457
834
  tuned: bool = True,
458
835
  ) -> None:
459
836
  """GEMM with activation gradient and pre-allocated output tensors."""
460
837
  fn = gemm_dact_tuned if tuned else partial(gemm_dact_tuned.fn, config=None)
461
- fn(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler)
838
+ fn(A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
462
839
 
463
840
 
464
841
  def gemm_dact_ref(
465
- A: Tensor,
466
- B: Tensor,
467
- PreAct: Tensor,
842
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
843
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
844
+ PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
468
845
  activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
846
+ cu_seqlens_m: Optional[Tensor] = None,
847
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
469
848
  out_dtype: Optional[torch.dtype] = None,
470
849
  postact_dtype: Optional[torch.dtype] = None,
471
850
  ) -> Tuple[Tensor, Tensor]:
472
851
  """Reference implementation for GEMM with activation gradient."""
473
852
  out_dtype = A.dtype if out_dtype is None else out_dtype
474
853
  postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
475
- dout = torch.mm(A, B).to(out_dtype)
854
+ dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
476
855
  postact = act_to_pytorch_fn_map[activation](PreAct)
477
856
  # Compute gradient using autograd
478
857
  if activation is None:
@@ -487,10 +866,13 @@ def gemm_dact_ref(
487
866
 
488
867
 
489
868
  def gemm_gated_ref(
490
- A: Tensor,
491
- B: Tensor,
492
- C: Optional[Tensor] = None,
869
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
870
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
871
+ C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
872
+ bias: Optional[Tensor] = None, # (N,) or (L, N)
493
873
  activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
874
+ cu_seqlens_m: Optional[Tensor] = None,
875
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
494
876
  out_dtype: Optional[torch.dtype] = None,
495
877
  postact_dtype: Optional[torch.dtype] = None,
496
878
  store_preact: bool = True,
@@ -499,8 +881,8 @@ def gemm_gated_ref(
499
881
 
500
882
  Args:
501
883
  A: (M, K) - input tensor
502
- B: (K, 2*N) - weight tensor with gate and up projections
503
- C: (M, 2*N) - optional bias tensor
884
+ B: (K, N) - weight tensor with gate and up projections
885
+ C: (M, N) - optional bias tensor
504
886
  activation: Type of gated activation
505
887
  out_dtype: Output dtype for preact
506
888
  postact_dtype: Output dtype for postact
@@ -508,24 +890,29 @@ def gemm_gated_ref(
508
890
 
509
891
  Returns:
510
892
  (preact, postact) where:
511
- - preact: (M, 2*N) pre-activation (if store_preact=True, else None)
512
- - postact: (M, N) post-activation output
893
+ - preact: (M, N) pre-activation (if store_preact=True, else None)
894
+ - postact: (M, N // 2) post-activation output
513
895
  """
514
896
  out_dtype = A.dtype if out_dtype is None else out_dtype
515
897
  postact_dtype = A.dtype if postact_dtype is None else postact_dtype
516
- preact = torch.mm(A, B) if C is None else C + torch.mm(A, B)
898
+ if C is None:
899
+ preact = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
900
+ else:
901
+ preact = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
517
902
  # Split preact into gate and up projections
518
- gate = preact[..., ::2] # (M, N)
519
- up = preact[..., 1::2] # (M, N)
903
+ gate = preact[..., ::2] # (M, N//2)
904
+ up = preact[..., 1::2] # (M, N//2)
520
905
  postact = gated_to_pytorch_fn_map[activation](gate, up)
521
906
  return preact.to(out_dtype) if store_preact else None, postact.to(postact_dtype)
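gemm_gated_ref treats the output's last dimension as interleaved (gate, up) pairs: even columns are the gate, odd columns the up projection, so an (M, N) pre-activation yields an (M, N // 2) post-activation. A toy sketch of the split for the swiglu case, assuming the usual silu(gate) * up definition:

import torch
import torch.nn.functional as F

preact = torch.randn(4, 8)                      # (M, N), N = 2 * N_half
gate, up = preact[..., ::2], preact[..., 1::2]  # each (M, N // 2)
postact = F.silu(gate) * up                     # swiglu, shape (4, 4)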
522
907
 
523
908
 
524
909
  def gemm_dgated_ref(
525
- A: Tensor,
526
- B: Tensor,
527
- PreAct: Tensor,
910
+ A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
911
+ B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k
912
+ PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
528
913
  activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
914
+ cu_seqlens_m: Optional[Tensor] = None,
915
+ A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
529
916
  out_dtype: Optional[torch.dtype] = None,
530
917
  postact_dtype: Optional[torch.dtype] = None,
531
918
  ) -> Tuple[Tensor, Tensor]:
@@ -546,18 +933,100 @@ def gemm_dgated_ref(
546
933
  """
547
934
  out_dtype = A.dtype if out_dtype is None else out_dtype
548
935
  postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
549
- dout = torch.mm(A, B).to(out_dtype)
936
+ dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
550
937
  # Split PreAct into gate and up projections
551
938
  gate = PreAct[..., ::2] # (M, N)
552
939
  up = PreAct[..., 1::2] # (M, N)
553
- postact = gated_to_pytorch_fn_map[activation](gate, up)
554
940
  # Use autograd to compute gradients w.r.t. gate and up
941
+ gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad
942
+ gate.requires_grad_(True)
943
+ up.requires_grad_(True)
944
+ postact = gated_to_pytorch_fn_map[activation](gate, up)
555
945
  dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
946
+ gate.requires_grad_(gate_requires_grad)
947
+ up.requires_grad_(up_requires_grad)
556
948
  # Interleave gradients back
557
949
  dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
558
950
  return dx.to(out_dtype), postact.to(postact_dtype)
559
951
 
560
952
 
953
+ @torch.library.custom_op(
954
+ "quack::gemm_symmetric_out",
955
+ mutates_args=("out",),
956
+ device_types="cuda",
957
+ schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? C=None, bool dynamic_scheduler=False, float alpha=1.0, float beta=1.0) -> ()",
958
+ )
959
+ def gemm_symmetric_out(
960
+ A: Tensor, # (M, K) or (L, M, K)
961
+ B: Tensor, # (K, M) or (L, K, M)
962
+ out: Tensor, # (M, M) or (L, M, M)
963
+ C: Optional[Tensor] = None, # (M, M) or (L, M, M)
964
+ dynamic_scheduler: bool = False,
965
+ alpha: float = 1.0,
966
+ beta: float = 1.0,
967
+ ) -> None:
968
+ """GEMM with guaranteed symmetric output."""
969
+ if A.ndim == 2:
970
+ A = A.unsqueeze(0) # (1, M, K)
971
+ B = B.mT # (M, K) or (L, M, K)
972
+ if B.ndim == 2:
973
+ B = B.unsqueeze(0) # (1, M, K)
974
+ if C is not None and C.ndim == 2:
975
+ C = C.unsqueeze(0) # (1, M, M)
976
+ if out.ndim == 2:
977
+ out = out.unsqueeze(0)
978
+ else:
979
+ out = out
980
+ tile_count_semaphore = (
981
+ torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
982
+ )
983
+ gemm_symmetric_sm90_sm100(
984
+ A,
985
+ B,
986
+ out if out is not None else None,
987
+ C if C is not None else None,
988
+ tile_count_semaphore,
989
+ tile_M=128,
990
+ tile_N=256,
991
+ cluster_M=2,
992
+ cluster_N=1,
993
+ pingpong=False,
994
+ persistent=True,
995
+ max_swizzle_size=8,
996
+ alpha=alpha,
997
+ beta=beta,
998
+ )
999
+
1000
+
1001
+ def gemm_symmetric(
1002
+ A: Tensor, # (M, K) or (L, M, K)
1003
+ B: Tensor, # (K, M) or (L, K, M)
1004
+ C: Optional[Tensor] = None, # (M, M) or (L, M, M)
1005
+ out: Optional[Tensor] = None, # (M, M) or (L, M, M)
1006
+ out_dtype: Optional[torch.dtype] = None,
1007
+ dynamic_scheduler: bool = False,
1008
+ alpha: float | Tensor = 1.0,
1009
+ beta: float | Tensor = 1.0,
1010
+ ) -> Tuple[Optional[Tensor], Tensor]:
1011
+ """GEMM with symmetric output."""
1012
+ out_dtype = A.dtype if out_dtype is None else out_dtype
1013
+ # Determine output shape based on gather_A
1014
+ if A.ndim == 2:
1015
+ out_shape = (A.shape[0], B.shape[-1])
1016
+ else:
1017
+ out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
1018
+ if out is None:
1019
+ out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
1020
+
1021
+ alpha_val = alpha if isinstance(alpha, float) else 1.0
1022
+ beta_val = beta if isinstance(beta, float) else 1.0
1023
+
1024
+ gemm_symmetric_out(
1025
+ A, B, out, C, dynamic_scheduler=dynamic_scheduler, alpha=alpha_val, beta=beta_val
1026
+ )
1027
+ return out
1028
+
1029
+
561
1030
  # TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
562
1031
  # try:
563
1032
  # from torch._inductor.fx_passes.reinplace import InplaceableOp
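gemm_symmetric and gemm_symmetric_out are new in 0.2.3 and wrap quack.gemm_symmetric with fixed tile and cluster sizes. A hedged usage sketch for the common Gram-matrix case; the symmetry guarantee itself is the kernel's documented contract, not something enforced in this wrapper:

import torch
from quack.gemm_interface import gemm_symmetric

A = torch.randn(1024, 256, device="cuda", dtype=torch.bfloat16)
out = gemm_symmetric(A, A.mT)            # (1024, 1024)
torch.testing.assert_close(out, out.mT)  # "guaranteed symmetric output"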