quack-kernels 0.2.0-py3-none-any.whl → 0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/gemm_interface.py CHANGED
@@ -34,30 +34,71 @@ gated_to_pytorch_fn_map = {
 }
 
 
+def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
+    kwargs = named_args | kwargs
+    gather_A = kwargs.get("A_idx", None) is not None
+    varlen_m = kwargs.get("cu_seqlens_m", None) is not None
+    if varlen_m or gather_A:  # Doesn't support swap_ab
+        configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
+    if gather_A:
+        # tile_n == 208 causes register spills, as gather_A requires more registers for the producer
+        configs = [
+            conf
+            for conf in configs
+            if conf.kwargs["config"].cluster_n == 1 and conf.kwargs["config"].tile_n != 208
+        ]
+    return configs
+
+
 @autotune(
     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
     key=["dynamic_scheduler"],
+    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
 def gemm_tuned(
-    A: Tensor,  # (M, K)
-    B: Tensor,  # (K, N)
-    out: Tensor,  # (M, N) - required output tensor
-    C: Optional[Tensor] = None,  # (M, N)
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
+    A: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,  # (1,)
     beta: float | Tensor = 1.0,  # (1,)
+    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
+    cu_seqlens_k: Optional[Tensor] = None,  # (L+1), int32
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
+    add_to_output: bool = False,
     dynamic_scheduler: bool = False,
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
-    A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
-    if C is not None:
+    varlen_m = cu_seqlens_m is not None
+    varlen_k = cu_seqlens_k is not None
+    varlen = varlen_m or varlen_k
+    gather_A = A_idx is not None
+    if gather_A:
+        assert varlen, "gather_A requires either varlen_m or varlen_k"
+        assert config.cluster_n == 1, "gather_A requires cluster_n=1"
+    if varlen_m:
+        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+    if A.ndim == 2 and not varlen:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (N, K) or (L, N, K) or (N, total_K)
+    if B.ndim == 2 and not varlen_k:
+        B = B.unsqueeze(0)  # (1, N, K)
+    if C is not None and C.ndim == 2 and not varlen_m:
         C = C.unsqueeze(0)  # (1, M, N)
-    assert out.shape == (
-        A.shape[1],
-        B.shape[1],
-    ), f"out shape mismatch: {out.shape} vs {(A.shape[1], B.shape[1])}"
-    out = out.unsqueeze(0)
+    if out.ndim == 2 and not varlen_m:
+        out = out.unsqueeze(0)
+    batch_size = B.shape[0] if not varlen_k else cu_seqlens_k.shape[0] - 1
+    if varlen_m:
+        # If gather_A (A_idx provided), use its length; otherwise use A.shape[0]
+        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+        out_shape = (total_m, B.shape[-2])
+    else:
+        out_shape = (batch_size, A.shape[-2], B.shape[-2])
+    assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}"
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
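
For illustration only (not part of the upstream file): the early_config_prune hook above receives the candidate autotune configs plus the call's named arguments and returns the subset worth benchmarking. Below is a minimal, self-contained sketch of the same filtering rules, using hypothetical Dummy* stand-ins in place of quack's real AutotuneConfig/GemmConfig classes.

    from dataclasses import dataclass

    @dataclass
    class DummyGemmConfig:          # stand-in for quack's GemmConfig
        tile_n: int
        cluster_n: int
        swap_ab: bool

    @dataclass
    class DummyAutotuneConfig:      # stand-in for quack's AutotuneConfig
        kwargs: dict                # holds {"config": DummyGemmConfig(...)}

    def prune(configs, named_args: dict, **kwargs):
        kwargs = named_args | kwargs
        gather_A = kwargs.get("A_idx") is not None
        varlen_m = kwargs.get("cu_seqlens_m") is not None
        if varlen_m or gather_A:    # swap_ab is unsupported on these paths
            configs = [c for c in configs if not c.kwargs["config"].swap_ab]
        if gather_A:                # gather_A needs cluster_n == 1 and avoids tile_n == 208
            configs = [
                c for c in configs
                if c.kwargs["config"].cluster_n == 1 and c.kwargs["config"].tile_n != 208
            ]
        return configs

    candidates = [
        DummyAutotuneConfig({"config": DummyGemmConfig(192, 1, False)}),
        DummyAutotuneConfig({"config": DummyGemmConfig(208, 1, False)}),
        DummyAutotuneConfig({"config": DummyGemmConfig(192, 2, True)}),
    ]
    kept = prune(candidates, {"A_idx": object(), "cu_seqlens_m": None})
    assert [c.kwargs["config"].tile_n for c in kept] == [192]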
@@ -74,71 +115,113 @@ def gemm_tuned(
         config.pingpong,
         alpha=alpha,
         beta=beta,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
     )
 
 
 @autotune(
     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
-    key=["activation"],
+    key=["activation", "dynamic_scheduler"],
+    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
 def gemm_act_tuned(
-    A: Tensor,  # (M, K)
-    B: Tensor,  # (K, N)
-    preact_out: Optional[Tensor],  # (M, N) - None if not storing preact
-    postact_out: Tensor,  # (M, N)
-    C: Optional[Tensor] = None,  # (M, N)
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    A: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N)
+    # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
+    preact_out: Optional[Tensor],
+    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
+    dynamic_scheduler: bool = False,
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
-    A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
-    if C is not None:
+    varlen_m = cu_seqlens_m is not None
+    if varlen_m:
+        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+    if A.ndim == 2 and not varlen_m:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (N, K) or (L, N, K)
+    if B.ndim == 2:
+        B = B.unsqueeze(0)  # (1, N, K)
+    if C is not None and C.ndim == 2 and not varlen_m:
         C = C.unsqueeze(0)  # (1, M, N)
-    if preact_out is not None:
-        assert preact_out.shape == (A.shape[1], B.shape[1])
+    if preact_out is not None and preact_out.ndim == 2 and not varlen_m:
         D = preact_out.unsqueeze(0)
     else:
-        D = None
-    assert postact_out.shape == (A.shape[1], B.shape[1])
-    PostAct = postact_out.unsqueeze(0)
+        D = preact_out
+    if postact_out.ndim == 2 and not varlen_m:
+        PostAct = postact_out.unsqueeze(0)
+    else:
+        PostAct = postact_out
+    tile_count_semaphore = (
+        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+    )
     gemm_act_sm90(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         (D if not config.swap_ab else D.mT) if D is not None else None,
         (C if not config.swap_ab else C.mT) if C is not None else None,
         PostAct if not config.swap_ab else PostAct.mT,
+        tile_count_semaphore,
         activation,
         config.tile_m,
         config.tile_n,
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
+        persistent=True,
+        cu_seqlens_m=cu_seqlens_m,
+        A_idx=A_idx,
     )
 
 
 @autotune(
     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
     key=["activation", "dynamic_scheduler"],
+    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
 def gemm_dact_tuned(
-    A: Tensor,  # (M, K)
-    B: Tensor,  # (K, N)
-    PreAct: Tensor,  # (M, N)
-    dx_out: Tensor,  # (M, N)
-    postact_out: Tensor,  # (M, N)
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    A: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N)
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    dx_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     dynamic_scheduler: bool = True,
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
-    A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
-    PreAct = PreAct.unsqueeze(0)  # (1, M, N)
-    assert dx_out.shape == (A.shape[1], B.shape[1])
-    D = dx_out.unsqueeze(0)
-    assert postact_out.shape == (A.shape[1], B.shape[1])
-    PostAct = postact_out.unsqueeze(0)
+    varlen_m = cu_seqlens_m is not None
+    if varlen_m:
+        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+    if A.ndim == 2 and not varlen_m:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (N, K) or (L, N, K)
+    if B.ndim == 2:
+        B = B.unsqueeze(0)  # (1, N, K)
+    if PreAct.ndim == 2 and not varlen_m:
+        PreAct = PreAct.unsqueeze(0)  # (1, M, N)
+    if dx_out.ndim == 2 and not varlen_m:
+        D = dx_out.unsqueeze(0)
+    else:
+        D = dx_out
+    if postact_out.ndim == 2 and not varlen_m:
+        PostAct = postact_out.unsqueeze(0)
+    else:
+        PostAct = postact_out
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
    )
@@ -155,22 +238,43 @@ def gemm_dact_tuned(
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
+        persistent=True,
+        cu_seqlens_m=cu_seqlens_m,
+        A_idx=A_idx,
     )
 
 
 def gemm(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
     out_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> Tensor:
     """GEMM with optional output tensor and tuning control."""
     if out is None:
         out_dtype = A.dtype if out_dtype is None else out_dtype
-        out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+        varlen_m = cu_seqlens_m is not None
+        varlen_k = cu_seqlens_k is not None
+        if varlen_m:
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out_shape = (total_m, B.shape[-1])
+        elif varlen_k:
+            L = cu_seqlens_k.shape[0] - 1
+            # For varlen_k, the first dimension is always A.shape[0] (M dimension)
+            out_shape = (L, A.shape[0], B.shape[-1])
+        else:
+            out_shape = (
+                (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1])
+            )
+        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
     alpha_tensor = alpha if not isinstance(alpha, float) else None
     alpha = alpha if isinstance(alpha, float) else 1.0
     gemm_out(
@@ -179,6 +283,10 @@ def gemm(
         out,
         alpha=alpha,
         alpha_tensor=alpha_tensor,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
         dynamic_scheduler=dynamic_scheduler,
         tuned=tuned,
     )
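
A hedged usage sketch of the new varlen_m path through gemm() (not part of the package diff). It assumes the module path quack.gemm_interface and a Hopper-class (SM90) GPU, since gemm_tuned ultimately dispatches to gemm_sm90; shapes follow the comments above: packed A of shape (total_M, K), per-segment B of shape (L, K, N), and cu_seqlens_m of shape (L+1) in int32.

    import torch
    from quack.gemm_interface import gemm

    # Three variable-length segments of 3, 5 and 2 rows: total_M = 10, L = 3.
    cu_seqlens_m = torch.tensor([0, 3, 8, 10], dtype=torch.int32, device="cuda")
    A = torch.randn(10, 64, dtype=torch.bfloat16, device="cuda")      # (total_M, K)
    B = torch.randn(3, 64, 128, dtype=torch.bfloat16, device="cuda")  # (L, K, N)
    out = gemm(A, B, cu_seqlens_m=cu_seqlens_m)                       # (total_M, N)
    assert out.shape == (10, 128)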
@@ -194,50 +302,117 @@ def gemm(
     # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_out(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     """GEMM with pre-allocated output tensor."""
     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
     alpha = alpha_tensor if alpha_tensor is not None else alpha
-    fn(A, B, out, C=None, alpha=alpha, dynamic_scheduler=dynamic_scheduler)
+    fn(
+        A,
+        B,
+        out,
+        C=None,
+        alpha=alpha,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        dynamic_scheduler=dynamic_scheduler,
+    )
 
 
 def gemm_ref(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
     out_dtype: Optional[torch.dtype] = None,
 ) -> Tensor:
     """Reference implementation for GEMM with pre-allocated output."""
     # The out_dtype argument requires torch >= 2.8
-    out = torch.mm(A, B, out_dtype=out_dtype, out=out)
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    if cu_seqlens_m is None and cu_seqlens_k is None:
+        fn = torch.bmm if A.ndim == 3 else torch.mm
+        out = fn(A, B, out_dtype=out_dtype, out=out)
+    elif cu_seqlens_m is not None:
+        # Handle varlen_m case
+        if out is None:
+            # When gather_A (A_idx provided), output size is determined by A_idx length
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
+        for i in range(cu_seqlens_m.shape[0] - 1):
+            A_slice = (
+                A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
+                if A_idx is not None
+                else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            )
+            torch.mm(A_slice, B[i], out=out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]])
+    else:  # cu_seqlens_k is not None
+        L = cu_seqlens_k.shape[0] - 1
+        if out is None:
+            out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+        for i in range(L):
+            A_slice = (
+                A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
+                if A_idx is not None
+                else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
+            )
+            torch.mm(A_slice, B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :], out=out[i])
     if not isinstance(alpha, float) or alpha != 1.0:
         out = out * alpha
     return out
 
 
 def gemm_add(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    C: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
     out_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> Tensor:
     """GEMM with addition and optional output tensor."""
     if out is None:
         out_dtype = A.dtype if out_dtype is None else out_dtype
-        out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+        varlen_m = cu_seqlens_m is not None
+        varlen_k = cu_seqlens_k is not None
+        if varlen_m:
+            # If A_idx is provided (gather_A), use its length; otherwise use A.shape[0]
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out_shape = (total_m, B.shape[-1])
+        elif varlen_k:
+            L = cu_seqlens_k.shape[0] - 1
+            # For varlen_k, the first dimension is always A.shape[0] (M dimension)
+            out_shape = (L, A.shape[0], B.shape[-1])
+        else:
+            out_shape = (
+                (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1])
+            )
+        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
+    add_to_output = C is out and isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
     alpha_tensor = alpha if not isinstance(alpha, float) else None
     alpha = alpha if isinstance(alpha, float) else 1.0
     beta_tensor = beta if not isinstance(beta, float) else None
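
A small CPU check of the varlen_k convention implemented by gemm_ref above (again, not part of the diff): A is (M, total_K), B is (total_K, N), and each K-segment contributes one (M, N) partial product, stacked into (L, M, N). The loop in gemm_ref is plain PyTorch; the sketch only assumes the module imports cleanly on your setup.

    import torch
    from quack.gemm_interface import gemm_ref

    cu_seqlens_k = torch.tensor([0, 4, 10], dtype=torch.int32)   # L = 2 segments of K
    A = torch.randn(8, 10)                                       # (M, total_K)
    B = torch.randn(10, 16)                                      # (total_K, N)
    out = gemm_ref(A, B, cu_seqlens_k=cu_seqlens_k)              # (L, M, N) = (2, 8, 16)
    expected = torch.stack([A[:, 0:4] @ B[0:4], A[:, 4:10] @ B[4:10]])
    torch.testing.assert_close(out, expected)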
@@ -245,12 +420,17 @@ def gemm_add(
     gemm_add_out(
         A,
         B,
-        C,
+        C if not add_to_output else None,
         out,
         alpha,
         beta,
         alpha_tensor,
         beta_tensor,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
         dynamic_scheduler=dynamic_scheduler,
         tuned=tuned,
     )
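
The add_to_output flag computed in gemm_add above is an aliasing fast path: when the caller passes the same tensor as C and out with the default beta of 1.0 (and no varlen_m), C is dropped from the call and the kernel accumulates directly into the existing output. A hedged sketch of that calling pattern, assuming an SM90 GPU and the quack.gemm_interface module path:

    import torch
    from quack.gemm_interface import gemm_add

    acc = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)  # running accumulator
    A = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
    # C is acc and out is acc with beta=1.0, so the add_to_output path is taken
    # and the epilogue performs acc += A @ B without reading a separate C.
    gemm_add(A, B, acc, out=acc)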
@@ -263,17 +443,23 @@ def gemm_add(
     device_types="cuda",
     # We have to split out alpha and alpha_tensor since torch.library requires
     # each argument to have a fixed type
-    # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_add_out(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    C: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float = 1.0,
     beta: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
     beta_tensor: Optional[Tensor] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
+    add_to_output: bool = False,
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
@@ -281,47 +467,105 @@ def gemm_add_out(
     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
     alpha = alpha_tensor if alpha_tensor is not None else alpha
     beta = beta_tensor if beta_tensor is not None else beta
-    fn(A, B, out, C, alpha=alpha, beta=beta, dynamic_scheduler=dynamic_scheduler)
+    fn(
+        A,
+        B,
+        out,
+        C,
+        alpha=alpha,
+        beta=beta,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
+        dynamic_scheduler=dynamic_scheduler,
+    )
 
 
 def gemm_add_ref(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    C: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
     out_dtype: Optional[torch.dtype] = None,
 ) -> Tensor:
     """Reference implementation for GEMM with addition and pre-allocated output."""
-    if isinstance(alpha, float) and isinstance(beta, float):
-        return torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
-    else:
-        out_dtype = (
-            out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
-        )
-        result = (alpha * (A @ B) + beta * C).to(out_dtype)
-        if out is not None:
-            out.copy_(result)
-        return result
+    if cu_seqlens_m is None and cu_seqlens_k is None:
+        if isinstance(alpha, float) and isinstance(beta, float):
+            return torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
+        else:
+            out_dtype = (
+                out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
+            )
+            result = (alpha * (A @ B) + beta * C).to(out_dtype)
+            if out is not None:
+                out.copy_(result)
+            return result
+    elif cu_seqlens_m is not None:
+        # Handle varlen_m case
+        if out is None:
+            # When gather_A (A_idx provided), output size is determined by A_idx length
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out_dtype = out_dtype if out_dtype is not None else A.dtype
+            out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
+        for i in range(cu_seqlens_m.shape[0] - 1):
+            A_slice = (
+                A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
+                if A_idx is not None
+                else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            )
+            C_slice = C[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            out_slice = out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            result = alpha * torch.mm(A_slice, B[i]) + beta * C_slice
+            out_slice.copy_(result)
+    else:  # cu_seqlens_k is not None
+        # Handle varlen_k case
+        L = cu_seqlens_k.shape[0] - 1
+        out_dtype = out_dtype if out_dtype is not None else A.dtype
+        if out is None:
+            out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+        for i in range(L):
+            A_slice = (
+                A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
+                if A_idx is not None
+                else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
+            )
+            B_slice = B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :]
+            result = alpha * torch.mm(A_slice, B_slice) + beta * C[i]
+            out[i].copy_(result)
+    return out
 
 
 def gemm_add_inplace(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     """In-place GEMM with addition: out = alpha * A @ B + beta * out.
     Args:
-        A: (M, K) input tensor
-        B: (K, N) input tensor
-        out: (M, N) tensor to accumulate into (modified in-place)
+        A: (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k - input tensor
+        B: (K, N) or (L, K, N) or (total_K, N) if varlen_k - input tensor
+        out: (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k - tensor to accumulate into (modified in-place)
         alpha: Scalar multiplier for A @ B
         beta: Scalar multiplier for out
+        cu_seqlens_m: Optional cumulative sequence lengths for variable M
+        cu_seqlens_k: Optional cumulative sequence lengths for variable K
        dynamic_scheduler: Whether to use dynamic scheduler
        tuned: Whether to use autotuned configuration
    """
@@ -329,7 +573,21 @@ def gemm_add_inplace(
     alpha = alpha if isinstance(alpha, float) else 1.0
     beta_tensor = beta if not isinstance(beta, float) else None
     beta = beta if isinstance(beta, float) else 1.0
-    gemm_add_inplace_op(A, B, out, alpha, beta, alpha_tensor, beta_tensor, dynamic_scheduler, tuned)
+    gemm_add_inplace_op(
+        A,
+        B,
+        out,
+        alpha,
+        beta,
+        alpha_tensor,
+        beta_tensor,
+        cu_seqlens_m,
+        cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        dynamic_scheduler=dynamic_scheduler,
+        tuned=tuned,
+    )
 
 
 @torch.library.custom_op(
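
An analogous CPU check for the varlen_m path of gemm_add_ref shown above (not part of the diff): rows of A, C and the output are packed along total_M, and segment i is multiplied against B[i]. As before, this only assumes the module imports cleanly.

    import torch
    from quack.gemm_interface import gemm_add_ref

    cu_seqlens_m = torch.tensor([0, 3, 8], dtype=torch.int32)    # L = 2, total_M = 8
    A = torch.randn(8, 16)                                       # (total_M, K)
    B = torch.randn(2, 16, 32)                                   # (L, K, N)
    C = torch.randn(8, 32)                                       # (total_M, N)
    out = gemm_add_ref(A, B, C, alpha=2.0, beta=0.5, cu_seqlens_m=cu_seqlens_m)
    expected = torch.cat([2.0 * A[0:3] @ B[0] + 0.5 * C[0:3],
                          2.0 * A[3:8] @ B[1] + 0.5 * C[3:8]])
    torch.testing.assert_close(out, expected)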
@@ -338,46 +596,79 @@ def gemm_add_inplace(
     device_types="cuda",
     # We have to split out alpha and alpha_tensor since torch.library requires
     # each argument to have a fixed type
-    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_add_inplace_op(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
     alpha: float = 1.0,
     beta: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
     beta_tensor: Optional[Tensor] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
     alpha = alpha_tensor if alpha_tensor is not None else alpha
     beta = beta_tensor if beta_tensor is not None else beta
-    # Use C as both input bias and output
-    fn(A, B, out, out, alpha=alpha, beta=beta, dynamic_scheduler=dynamic_scheduler)
+    add_to_output = isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
+    # Use out as both input bias and output
+    fn(
+        A,
+        B,
+        out,
+        out if not add_to_output else None,
+        alpha=alpha,
+        beta=beta,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
+        dynamic_scheduler=dynamic_scheduler,
+    )
 
 
 def gemm_act(
-    A: Tensor,
-    B: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
-    preact_out: Optional[Tensor] = None,
-    postact_out: Optional[Tensor] = None,
+    preact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     store_preact: bool = True,
+    dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> Tuple[Optional[Tensor], Tensor]:
     """GEMM with activation and optional output tensors."""
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+    varlen_m = cu_seqlens_m is not None
+    # Determine output shape based on gather_A
+    if varlen_m:
+        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+        out_shape = (total_m, B.shape[-1])
+    elif A.ndim == 2:
+        out_shape = (A.shape[0], B.shape[-1])
+    else:
+        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
     if preact_out is None and store_preact:
-        preact_out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+        preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
     if postact_out is None:
-        postact_out = torch.empty((A.shape[0], B.shape[1]), dtype=postact_dtype, device=A.device)
-    gemm_act_out(A, B, preact_out, postact_out, C, activation, tuned)
+        postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
+    gemm_act_out(
+        A, B, preact_out, postact_out, C, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned
+    )
     return preact_out, postact_out
 
 
@@ -385,58 +676,79 @@ def gemm_act(
     "quack::gemm_act_out",
     mutates_args=("preact_out", "postact_out"),
     device_types="cuda",
-    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, bool tuned=True) -> ()",
+    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_act_out(
-    A: Tensor,
-    B: Tensor,
-    preact_out: Optional[Tensor],
-    postact_out: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    preact_out: Optional[Tensor],  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
+    dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     """GEMM with activation and pre-allocated output tensors."""
     fn = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None)
-    fn(A, B, preact_out, postact_out, C, activation)
+    fn(A, B, preact_out, postact_out, C, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
 
 
 def gemm_act_ref(
-    A: Tensor,
-    B: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
     store_preact: bool = True,
 ) -> Tuple[Optional[Tensor], Tensor]:
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
-    out = torch.mm(A, B) if C is None else C + torch.mm(A, B)
+    if C is None:
+        out = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+    else:
+        out = gemm_add_ref(A, B, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype)
     return out.to(out_dtype) if store_preact else None, postact
 
 
 def gemm_dact(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
-    dx_out: Optional[Tensor] = None,
-    postact_out: Optional[Tensor] = None,
+    dx_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     dynamic_scheduler: bool = True,
     tuned: bool = True,
 ) -> Tuple[Tensor, Tensor]:
     """GEMM with activation gradient and optional output tensors."""
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+    varlen_m = cu_seqlens_m is not None
+    # Determine output shape based on gather_A
+    if varlen_m:
+        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+        out_shape = (total_m, B.shape[-1])
+    elif A.ndim == 2:
+        out_shape = (A.shape[0], B.shape[-1])
+    else:
+        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
     if dx_out is None:
-        dx_out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+        dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
     if postact_out is None:
-        postact_out = torch.empty((A.shape[0], B.shape[1]), dtype=postact_dtype, device=A.device)
-    gemm_dact_out(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler, tuned)
+        postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
+    gemm_dact_out(
+        A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned
+    )
     return dx_out, postact_out
 
 
@@ -444,35 +756,39 @@ def gemm_dact(
     "quack::gemm_dact_out",
     mutates_args=("dx_out", "postact_out"),
     device_types="cuda",
-    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
+    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
 )
 def gemm_dact_out(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
-    dx_out: Tensor,
-    postact_out: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    dx_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     dynamic_scheduler: bool = True,
     tuned: bool = True,
 ) -> None:
     """GEMM with activation gradient and pre-allocated output tensors."""
     fn = gemm_dact_tuned if tuned else partial(gemm_dact_tuned.fn, config=None)
-    fn(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler)
+    fn(A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
 
 
 def gemm_dact_ref(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
 ) -> Tuple[Tensor, Tensor]:
     """Reference implementation for GEMM with activation gradient."""
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
-    dout = torch.mm(A, B).to(out_dtype)
+    dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
     postact = act_to_pytorch_fn_map[activation](PreAct)
     # Compute gradient using autograd
     if activation is None:
@@ -487,10 +803,12 @@ def gemm_dact_ref(
 
 
 def gemm_gated_ref(
-    A: Tensor,
-    B: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
     store_preact: bool = True,
@@ -499,8 +817,8 @@ def gemm_gated_ref(
 
     Args:
         A: (M, K) - input tensor
-        B: (K, 2*N) - weight tensor with gate and up projections
-        C: (M, 2*N) - optional bias tensor
+        B: (K, N) - weight tensor with gate and up projections
+        C: (M, N) - optional bias tensor
         activation: Type of gated activation
         out_dtype: Output dtype for preact
         postact_dtype: Output dtype for postact
@@ -508,24 +826,29 @@ def gemm_gated_ref(
 
     Returns:
         (preact, postact) where:
-        - preact: (M, 2*N) pre-activation (if store_preact=True, else None)
-        - postact: (M, N) post-activation output
+        - preact: (M, N) pre-activation (if store_preact=True, else None)
+        - postact: (M, N // 2) post-activation output
     """
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
-    preact = torch.mm(A, B) if C is None else C + torch.mm(A, B)
+    if C is None:
+        preact = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+    else:
+        preact = gemm_add_ref(A, B, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     # Split preact into gate and up projections
-    gate = preact[..., ::2]  # (M, N)
-    up = preact[..., 1::2]  # (M, N)
+    gate = preact[..., ::2]  # (M, N//2)
+    up = preact[..., 1::2]  # (M, N//2)
     postact = gated_to_pytorch_fn_map[activation](gate, up)
     return preact.to(out_dtype) if store_preact else None, postact.to(postact_dtype)
 
 
 def gemm_dgated_ref(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    PreAct: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
 ) -> Tuple[Tensor, Tensor]:
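
A last CPU sketch (not part of the diff) of the gate/up interleaving that gemm_gated_ref documents: columns of the preactivation alternate gate, up, gate, up, so the post-activation has half as many columns. It assumes the module imports cleanly, torch >= 2.8 for the out_dtype path in gemm_ref, and that the "swiglu" entry of gated_to_pytorch_fn_map is the usual silu(gate) * up.

    import torch
    import torch.nn.functional as F
    from quack.gemm_interface import gemm_gated_ref

    A = torch.randn(4, 8)                    # (M, K)
    B = torch.randn(8, 12)                   # (K, N) with N even: 6 gate/up column pairs
    preact, postact = gemm_gated_ref(A, B, activation="swiglu")
    assert preact.shape == (4, 12) and postact.shape == (4, 6)
    # assumes gated_to_pytorch_fn_map["swiglu"](gate, up) == F.silu(gate) * up
    torch.testing.assert_close(postact, F.silu(preact[:, ::2]) * preact[:, 1::2])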
@@ -546,7 +869,7 @@ def gemm_dgated_ref(
     """
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
-    dout = torch.mm(A, B).to(out_dtype)
+    dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
     # Split PreAct into gate and up projections
     gate = PreAct[..., ::2]  # (M, N)
     up = PreAct[..., 1::2]  # (M, N)