quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quack/gemm_interface.py CHANGED
@@ -1,5 +1,6 @@
  # Copyright (c) 2025, Tri Dao
- from typing import Optional
+ from typing import Optional, Tuple, Literal
+ from functools import partial

  import torch
  import torch.nn.functional as F
@@ -8,314 +9,561 @@ from torch import Tensor
  from quack.gemm_config import GemmConfig, get_all_configs

  from quack.autotuner import autotune, AutotuneConfig
- from quack.lse import logsumexp
+ from quack.dense_gemm_sm90 import gemm_sm90
+ from quack.gemm_act_sm90 import gemm_act_sm90
+ from quack.gemm_dact_sm90 import gemm_dact_sm90


- def gemm_swiglu_out_ref(
-     A: Tensor, B: Tensor, out: Optional[Tensor], store_preact: bool
- ) -> (Tensor, Tensor):
-     preact = torch.mm(A, B)
-     out_ = F.silu(preact[..., ::2]) * preact[..., 1::2]
-     if out is not None:
-         out.copy_(out_)
-     else:
-         out = out_
-     if not store_preact:
-         preact = None
-     return out, preact
+ # Dictionary mapping activation names to PyTorch functions
+ act_to_pytorch_fn_map = {
+     None: lambda x: x,
+     "relu": F.relu,
+     "relu_sq": lambda x: F.relu(x).square(),
+     "gelu_tanh_approx": partial(F.gelu, approximate="tanh"),
+ }
+
+
+ # Dictionary mapping gated activation names to their forward functions
+ # Each function takes (gate, up) and returns postact
+ gated_to_pytorch_fn_map = {
+     "swiglu": lambda gate, up: F.silu(gate) * up,
+     "swiglu_oai": lambda gate, up: gate * torch.sigmoid(1.702 * gate) * (up + 1),
+     "reglu": lambda gate, up: F.relu(gate) * up,
+     "geglu": lambda gate, up: F.gelu(gate, approximate="tanh") * up,
+     "glu": lambda gate, up: torch.sigmoid(gate) * up,
+ }


  @autotune(
-     configs=[AutotuneConfig(config=c) for c in get_all_configs(epilogue=None)], key=["sm_carveout"]
+     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+     key=["dynamic_scheduler"],
  )
  def gemm_tuned(
-     A: Tensor,
-     B: Tensor,
-     sm_carveout: int = 0,
+     A: Tensor,  # (M, K)
+     B: Tensor,  # (K, N)
+     out: Tensor,  # (M, N) - required output tensor
+     C: Optional[Tensor] = None,  # (M, N)
+     alpha: float | Tensor = 1.0,  # (1,)
+     beta: float | Tensor = 1.0,  # (1,)
+     dynamic_scheduler: bool = False,
      config: Optional[GemmConfig] = None,
- ) -> (Tensor, Optional[Tensor]):
+ ) -> None:
      if config is None:
-         config = GemmConfig(
-             tile_m=256,
-             tile_n=192,
-             cluster_m=2,
-             cluster_n=1,
-             pingpong=False,
-             raster_order=2,
-             max_swizzle_size=1,
-         )
-     out = torch.ops.quack.gemm_impl.default(
-         A if not config.swap_ab else B.T,
-         B if not config.swap_ab else A.T,
-         sm_carveout,
+         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+     A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
+     if C is not None:
+         C = C.unsqueeze(0)  # (1, M, N)
+     assert out.shape == (
+         A.shape[1],
+         B.shape[1],
+     ), f"out shape mismatch: {out.shape} vs {(A.shape[1], B.shape[1])}"
+     out = out.unsqueeze(0)
+     tile_count_semaphore = (
+         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+     )
+     gemm_sm90(
+         A if not config.swap_ab else B,
+         B if not config.swap_ab else A,
+         out if not config.swap_ab else out.mT,
+         (C if not config.swap_ab else C.mT) if C is not None else None,
+         tile_count_semaphore,
          config.tile_m,
          config.tile_n,
          config.cluster_m,
          config.cluster_n,
-         not config.swap_ab,  # C_rowmajor
          config.pingpong,
-         config.raster_order,
-         config.max_swizzle_size,
+         alpha=alpha,
+         beta=beta,
      )
-     return out if not config.swap_ab else out.T


- @torch.library.custom_op("quack::gemm", mutates_args=(), device_types="cuda")
- def gemm(A: Tensor, B: Tensor, sm_carveout: int = 0) -> Tensor:
-     return gemm_tuned(A, B, sm_carveout)
-
-
- @torch.library.register_fake("quack::gemm")
- def gemm_ref(A: Tensor, B: Tensor, sm_carveout: int = 0) -> Tensor:
-     return torch.mm(A, B)
-
-
- @autotune(configs=[AutotuneConfig(config=c) for c in get_all_configs("add")])
- def gemm_add_tuned(
-     A: Tensor,
-     B: Tensor,
-     C: Tensor,
+ @autotune(
+     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+     key=["activation"],
+ )
+ def gemm_act_tuned(
+     A: Tensor,  # (M, K)
+     B: Tensor,  # (K, N)
+     preact_out: Optional[Tensor],  # (M, N) - None if not storing preact
+     postact_out: Tensor,  # (M, N)
+     C: Optional[Tensor] = None,  # (M, N)
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
      config: Optional[GemmConfig] = None,
- ) -> (Tensor, Optional[Tensor]):
+ ) -> None:
      if config is None:
-         config = GemmConfig(
-             tile_m=256,
-             tile_n=192,
-             cluster_m=2,
-             cluster_n=1,
-             pingpong=False,
-             raster_order=2,
-             max_swizzle_size=1,
-         )
-     out = torch.ops.quack.gemm_add_impl.default(
-         A if not config.swap_ab else B.T,
-         B if not config.swap_ab else A.T,
-         C if not config.swap_ab else C.T,
+         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+     A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
+     if C is not None:
+         C = C.unsqueeze(0)  # (1, M, N)
+     if preact_out is not None:
+         assert preact_out.shape == (A.shape[1], B.shape[1])
+         D = preact_out.unsqueeze(0)
+     else:
+         D = None
+     assert postact_out.shape == (A.shape[1], B.shape[1])
+     PostAct = postact_out.unsqueeze(0)
+     gemm_act_sm90(
+         A if not config.swap_ab else B,
+         B if not config.swap_ab else A,
+         (D if not config.swap_ab else D.mT) if D is not None else None,
+         (C if not config.swap_ab else C.mT) if C is not None else None,
+         PostAct if not config.swap_ab else PostAct.mT,
+         activation,
          config.tile_m,
          config.tile_n,
          config.cluster_m,
          config.cluster_n,
          config.pingpong,
-         config.raster_order,
-         config.max_swizzle_size,
      )
-     return out if not config.swap_ab else out.T
-
-
- @torch.library.custom_op("quack::gemm_add", mutates_args=(), device_types="cuda")
- def gemm_add(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
-     return gemm_add_tuned(A, B, C)
-
-
- @torch.library.register_fake("quack::gemm_add")
- def gemm_add_ref(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
-     return C + torch.mm(A, B)
-
-
- @torch.library.custom_op("quack::gemm_add_t", mutates_args=(), device_types="cuda")
- def gemm_t_add(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
-     return gemm_add_tuned(A, B.T, C)
-
-
- @torch.library.register_fake("quack::gemm_add_t")
- def gemm_t_add_ref(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
-     return gemm_add_ref(A, B.T, C)


  @autotune(
-     configs=[AutotuneConfig(config=c) for c in get_all_configs("swiglu")], key=["store_preact"]
+     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+     key=["activation", "dynamic_scheduler"],
  )
- def gemm_swiglu_tuned(
-     A: Tensor,
-     B: Tensor,
-     store_preact: bool = True,
+ def gemm_dact_tuned(
+     A: Tensor,  # (M, K)
+     B: Tensor,  # (K, N)
+     PreAct: Tensor,  # (M, N)
+     dx_out: Tensor,  # (M, N)
+     postact_out: Tensor,  # (M, N)
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     dynamic_scheduler: bool = True,
      config: Optional[GemmConfig] = None,
- ) -> (Tensor, Optional[Tensor]):
+ ) -> None:
      if config is None:
-         config = GemmConfig(
-             tile_m=256,
-             tile_n=192,
-             cluster_m=2,
-             cluster_n=1,
-             pingpong=False,
-             raster_order=2,
-             max_swizzle_size=1,
-         )
-     # out, preact
-     return torch.ops.quack.gemm_swiglu_impl.default(
-         A,
-         B,
-         store_preact,
+         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+     A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
+     PreAct = PreAct.unsqueeze(0)  # (1, M, N)
+     assert dx_out.shape == (A.shape[1], B.shape[1])
+     D = dx_out.unsqueeze(0)
+     assert postact_out.shape == (A.shape[1], B.shape[1])
+     PostAct = postact_out.unsqueeze(0)
+     tile_count_semaphore = (
+         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+     )
+     gemm_dact_sm90(
+         A if not config.swap_ab else B,
+         B if not config.swap_ab else A,
+         D if not config.swap_ab else D.mT,
+         PreAct if not config.swap_ab else PreAct.mT,
+         PostAct if not config.swap_ab else PostAct.mT,
+         tile_count_semaphore,
+         activation,
          config.tile_m,
          config.tile_n,
          config.cluster_m,
          config.cluster_n,
          config.pingpong,
-         config.raster_order,
-         config.max_swizzle_size,
      )


- # Specifying the schema manually here since torch.library._infer_schema doesn't work when return
- # type is a tuple of Tensor
+ def gemm(
+     A: Tensor,
+     B: Tensor,
+     out: Optional[Tensor] = None,
+     alpha: float | Tensor = 1.0,
+     out_dtype: Optional[torch.dtype] = None,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> Tensor:
+     """GEMM with optional output tensor and tuning control."""
+     if out is None:
+         out_dtype = A.dtype if out_dtype is None else out_dtype
+         out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+     alpha_tensor = alpha if not isinstance(alpha, float) else None
+     alpha = alpha if isinstance(alpha, float) else 1.0
+     gemm_out(
+         A,
+         B,
+         out,
+         alpha=alpha,
+         alpha_tensor=alpha_tensor,
+         dynamic_scheduler=dynamic_scheduler,
+         tuned=tuned,
+     )
+     return out
+
+
  @torch.library.custom_op(
-     "quack::gemm_swiglu",
-     mutates_args=(),
+     "quack::gemm_out",
+     mutates_args=("out",),
      device_types="cuda",
-     schema="(Tensor A, Tensor B, bool store_preact) -> (Tensor, Tensor)",
+     # We have to split out alpha and alpha_tensor since torch.library requires
+     # each argument to have a fixed type
+     # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
  )
- def gemm_swiglu(A: Tensor, B: Tensor, store_preact: bool = True) -> (Tensor, Tensor):
-     return gemm_swiglu_tuned(A, B, store_preact=store_preact)
-
-
- @torch.library.register_fake("quack::gemm_swiglu")
- def gemm_swiglu_ref(A: Tensor, B: Tensor, store_preact: bool) -> (Tensor, Tensor):
-     return gemm_swiglu_out_ref(A, B, None, store_preact)
-
-
- # @torch.library.custom_op("quack::gemm_swiglu_t", mutates_args=(), device_types="cuda",
- # schema="(Tensor A, Tensor B, bool store_preact) -> (Tensor, Tensor)")
- # def gemm_swiglu_t(A: Tensor, B: Tensor, store_preact: bool = True) -> (Tensor, Tensor):
- # return gemm_swiglu_tuned(A, B.T, store_preact=store_preact)
-
-
- # @torch.library.register_fake("quack::gemm_swiglu_t")
- # def gemm_swiglu_t_ref(A: Tensor, B: Tensor, store_preact: bool) -> (Tensor, Tensor):
- # return gemm_swiglu_ref(A, B.T, store_preact)
+ def gemm_out(
+     A: Tensor,
+     B: Tensor,
+     out: Tensor,
+     alpha: float = 1.0,
+     alpha_tensor: Optional[Tensor] = None,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> None:
+     """GEMM with pre-allocated output tensor."""
+     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
+     alpha = alpha_tensor if alpha_tensor is not None else alpha
+     fn(A, B, out, C=None, alpha=alpha, dynamic_scheduler=dynamic_scheduler)
+
+
+ def gemm_ref(
+     A: Tensor,
+     B: Tensor,
+     out: Optional[Tensor] = None,
+     alpha: float | Tensor = 1.0,
+     out_dtype: Optional[torch.dtype] = None,
+ ) -> Tensor:
+     """Reference implementation for GEMM with pre-allocated output."""
+     # The out_dtype argument requires torch >= 2.8
+     out = torch.mm(A, B, out_dtype=out_dtype, out=out)
+     if not isinstance(alpha, float) or alpha != 1.0:
+         out = out * alpha
+     return out
+
+
+ def gemm_add(
+     A: Tensor,
+     B: Tensor,
+     C: Tensor,
+     out: Optional[Tensor] = None,
+     alpha: float | Tensor = 1.0,
+     beta: float | Tensor = 1.0,
+     out_dtype: Optional[torch.dtype] = None,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> Tensor:
+     """GEMM with addition and optional output tensor."""
+     if out is None:
+         out_dtype = A.dtype if out_dtype is None else out_dtype
+         out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+     alpha_tensor = alpha if not isinstance(alpha, float) else None
+     alpha = alpha if isinstance(alpha, float) else 1.0
+     beta_tensor = beta if not isinstance(beta, float) else None
+     beta = beta if isinstance(beta, float) else 1.0
+     gemm_add_out(
+         A,
+         B,
+         C,
+         out,
+         alpha,
+         beta,
+         alpha_tensor,
+         beta_tensor,
+         dynamic_scheduler=dynamic_scheduler,
+         tuned=tuned,
+     )
+     return out


- @autotune(
-     configs=[AutotuneConfig(config=c) for c in get_all_configs("dswiglu")], key=["sm_carveout"]
+ @torch.library.custom_op(
+     "quack::gemm_add_out",
+     mutates_args=("out",),
+     device_types="cuda",
+     # We have to split out alpha and alpha_tensor since torch.library requires
+     # each argument to have a fixed type
+     # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
  )
- def gemm_dswiglu_tuned(
+ def gemm_add_out(
      A: Tensor,
      B: Tensor,
-     preact: Tensor,
-     sm_carveout: int = 0,
-     config: Optional[GemmConfig] = None,
- ) -> (Tensor, Tensor):
-     if config is None:
-         config = GemmConfig(
-             tile_m=128,
-             tile_n=192,
-             cluster_m=2,
-             cluster_n=1,
-             pingpong=True,
-             raster_order=2,
-             max_swizzle_size=1,
+     C: Tensor,
+     out: Tensor,
+     alpha: float = 1.0,
+     beta: float = 1.0,
+     alpha_tensor: Optional[Tensor] = None,
+     beta_tensor: Optional[Tensor] = None,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> None:
+     """GEMM with addition and pre-allocated output tensor."""
+     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
+     alpha = alpha_tensor if alpha_tensor is not None else alpha
+     beta = beta_tensor if beta_tensor is not None else beta
+     fn(A, B, out, C, alpha=alpha, beta=beta, dynamic_scheduler=dynamic_scheduler)
+
+
+ def gemm_add_ref(
+     A: Tensor,
+     B: Tensor,
+     C: Tensor,
+     out: Optional[Tensor] = None,
+     alpha: float | Tensor = 1.0,
+     beta: float | Tensor = 1.0,
+     out_dtype: Optional[torch.dtype] = None,
+ ) -> Tensor:
+     """Reference implementation for GEMM with addition and pre-allocated output."""
+     if isinstance(alpha, float) and isinstance(beta, float):
+         return torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
+     else:
+         out_dtype = (
+             out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
          )
-     out, postact = torch.ops.quack.gemm_dswiglu_impl.default(
-         A if not config.swap_ab else B.T,
-         B if not config.swap_ab else A.T,
-         preact if not config.swap_ab else preact.T,
-         sm_carveout,
-         config.tile_m,
-         config.tile_n,
-         config.cluster_m,
-         config.cluster_n,
-         not config.swap_ab,  # C_rowmajor
-         config.pingpong,
-         config.raster_order,
-         config.max_swizzle_size,
-     )
-     return (out, postact) if not config.swap_ab else (out.T, postact.T)
+         result = (alpha * (A @ B) + beta * C).to(out_dtype)
+         if out is not None:
+             out.copy_(result)
+         return result
+
+
+ def gemm_add_inplace(
+     A: Tensor,
+     B: Tensor,
+     out: Tensor,
+     alpha: float | Tensor = 1.0,
+     beta: float | Tensor = 1.0,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> None:
+     """In-place GEMM with addition: out = alpha * A @ B + beta * out.
+     Args:
+         A: (M, K) input tensor
+         B: (K, N) input tensor
+         out: (M, N) tensor to accumulate into (modified in-place)
+         alpha: Scalar multiplier for A @ B
+         beta: Scalar multiplier for out
+         dynamic_scheduler: Whether to use dynamic scheduler
+         tuned: Whether to use autotuned configuration
+     """
+     alpha_tensor = alpha if not isinstance(alpha, float) else None
+     alpha = alpha if isinstance(alpha, float) else 1.0
+     beta_tensor = beta if not isinstance(beta, float) else None
+     beta = beta if isinstance(beta, float) else 1.0
+     gemm_add_inplace_op(A, B, out, alpha, beta, alpha_tensor, beta_tensor, dynamic_scheduler, tuned)


- # Specifying the schema manually here since torch.library._infer_schema doesn't work when return
- # type is a tuple of Tensor
  @torch.library.custom_op(
-     "quack::gemm_dswiglu",
-     mutates_args=(),
+     "quack::gemm_add_inplace",
+     mutates_args=("out",),
      device_types="cuda",
-     schema="(Tensor A, Tensor B, Tensor preact, int sm_carveout=0) -> (Tensor, Tensor)",
+     # We have to split out alpha and alpha_tensor since torch.library requires
+     # each argument to have a fixed type
+     # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
  )
- def gemm_dswiglu(A: Tensor, B: Tensor, preact: Tensor, sm_carveout: int = 0) -> (Tensor, Tensor):
-     return gemm_dswiglu_tuned(A, B, preact, sm_carveout)
-
-
- @torch.library.register_fake("quack::gemm_dswiglu")
- def gemm_dswiglu_ref(
-     A: Tensor, B: Tensor, preact: Tensor, sm_carveout: int = 0
- ) -> (Tensor, Tensor):
-     # A: (M, K), B: (K, N), preact: (M, 2 * N)
-     dout = torch.mm(A, B)
-     p0, p1 = preact[..., ::2], preact[..., 1::2]
-     sigmoid = torch.sigmoid(p0)
-     silu = F.silu(p0)
-     postact = silu * p1
-     d0 = sigmoid * (1 + p0 * (1 - sigmoid)) * p1 * dout
-     d1 = F.silu(p0) * dout
-     out = torch.stack([d0, d1], dim=-1).reshape(d0.shape[:-1] + (2 * d0.shape[-1],))
-     return out, postact
-
-
- @autotune(configs=[AutotuneConfig(config=c) for c in get_all_configs("lse")])
- def gemm_lse_tuned(
+ def gemm_add_inplace_op(
      A: Tensor,
      B: Tensor,
-     softcap: float = 0.0,
-     config: Optional[GemmConfig] = None,
- ) -> (Tensor, Tensor):
-     if config is None:
-         config = GemmConfig(
-             tile_m=256,
-             tile_n=192,
-             cluster_m=2,
-             cluster_n=1,
-             pingpong=False,
-             raster_order=2,
-             max_swizzle_size=1,
-         )
-     out, lse_partial = torch.ops.quack.gemm_lse_impl.default(
-         A,
-         B,
-         None,  # bias
-         softcap,
-         config.tile_m,
-         config.tile_n,
-         config.cluster_m,
-         config.cluster_n,
-         config.pingpong,
-         config.raster_order,
-         config.max_swizzle_size,
-     )
-     lse = logsumexp(lse_partial)
-     return out, lse
+     out: Tensor,
+     alpha: float = 1.0,
+     beta: float = 1.0,
+     alpha_tensor: Optional[Tensor] = None,
+     beta_tensor: Optional[Tensor] = None,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> None:
+     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
+     alpha = alpha_tensor if alpha_tensor is not None else alpha
+     beta = beta_tensor if beta_tensor is not None else beta
+     # Use C as both input bias and output
+     fn(A, B, out, out, alpha=alpha, beta=beta, dynamic_scheduler=dynamic_scheduler)
+
+
+ def gemm_act(
+     A: Tensor,
+     B: Tensor,
+     C: Optional[Tensor] = None,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     preact_out: Optional[Tensor] = None,
+     postact_out: Optional[Tensor] = None,
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+     store_preact: bool = True,
+     tuned: bool = True,
+ ) -> Tuple[Optional[Tensor], Tensor]:
+     """GEMM with activation and optional output tensors."""
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+     if preact_out is None and store_preact:
+         preact_out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+     if postact_out is None:
+         postact_out = torch.empty((A.shape[0], B.shape[1]), dtype=postact_dtype, device=A.device)
+     gemm_act_out(A, B, preact_out, postact_out, C, activation, tuned)
+     return preact_out, postact_out


  @torch.library.custom_op(
-     "quack::gemm_lse",
-     mutates_args=(),
+     "quack::gemm_act_out",
+     mutates_args=("preact_out", "postact_out"),
      device_types="cuda",
-     schema="(Tensor A, Tensor B, float softcap=0.0) -> (Tensor, Tensor)",
+     schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, bool tuned=True) -> ()",
  )
- def gemm_lse(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
-     return gemm_lse_tuned(A, B, softcap)
+ def gemm_act_out(
+     A: Tensor,
+     B: Tensor,
+     preact_out: Optional[Tensor],
+     postact_out: Tensor,
+     C: Optional[Tensor] = None,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     tuned: bool = True,
+ ) -> None:
+     """GEMM with activation and pre-allocated output tensors."""
+     fn = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None)
+     fn(A, B, preact_out, postact_out, C, activation)
+
+
+ def gemm_act_ref(
+     A: Tensor,
+     B: Tensor,
+     C: Optional[Tensor] = None,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+     store_preact: bool = True,
+ ) -> Tuple[Optional[Tensor], Tensor]:
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+     out = torch.mm(A, B) if C is None else C + torch.mm(A, B)
+     postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype)
+     return out.to(out_dtype) if store_preact else None, postact


- @torch.library.register_fake("quack::gemm_lse")
- def gemm_lse_ref(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
-     # A: (M, K), B: (K, N)
-     out = torch.mm(A, B)
-     if softcap > 0:
-         out_fp32 = torch.tanh(out.to(torch.float32) / softcap) * softcap
-         out = out_fp32.to(out.dtype)
-     else:
-         out_fp32 = out.to(torch.float32)
-     lse = torch.logsumexp(out_fp32, dim=-1)
-     return out, lse
+ def gemm_dact(
+     A: Tensor,
+     B: Tensor,
+     PreAct: Tensor,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     dx_out: Optional[Tensor] = None,
+     postact_out: Optional[Tensor] = None,
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+     dynamic_scheduler: bool = True,
+     tuned: bool = True,
+ ) -> Tuple[Tensor, Tensor]:
+     """GEMM with activation gradient and optional output tensors."""
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+     if dx_out is None:
+         dx_out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+     if postact_out is None:
+         postact_out = torch.empty((A.shape[0], B.shape[1]), dtype=postact_dtype, device=A.device)
+     gemm_dact_out(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler, tuned)
+     return dx_out, postact_out


  @torch.library.custom_op(
-     "quack::gemm_lse_t",
-     mutates_args=(),
+     "quack::gemm_dact_out",
+     mutates_args=("dx_out", "postact_out"),
      device_types="cuda",
-     schema="(Tensor A, Tensor B, float softcap=0.0) -> (Tensor, Tensor)",
+     schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
  )
- def gemm_lse_t(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
-     return gemm_lse_tuned(A, B.T, softcap)
+ def gemm_dact_out(
+     A: Tensor,
+     B: Tensor,
+     PreAct: Tensor,
+     dx_out: Tensor,
+     postact_out: Tensor,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     dynamic_scheduler: bool = True,
+     tuned: bool = True,
+ ) -> None:
+     """GEMM with activation gradient and pre-allocated output tensors."""
+     fn = gemm_dact_tuned if tuned else partial(gemm_dact_tuned.fn, config=None)
+     fn(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler)
+
+
+ def gemm_dact_ref(
+     A: Tensor,
+     B: Tensor,
+     PreAct: Tensor,
+     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[Tensor, Tensor]:
+     """Reference implementation for GEMM with activation gradient."""
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+     dout = torch.mm(A, B).to(out_dtype)
+     postact = act_to_pytorch_fn_map[activation](PreAct)
+     # Compute gradient using autograd
+     if activation is None:
+         dx = dout
+     else:
+         PreAct_requires_grad = PreAct.requires_grad
+         PreAct.requires_grad_(True)
+         postact_for_grad = act_to_pytorch_fn_map[activation](PreAct)
+         dx = torch.autograd.grad(postact_for_grad, PreAct, dout, create_graph=False)[0]
+         PreAct.requires_grad_(PreAct_requires_grad)
+     return dx.to(out_dtype), postact.to(postact_dtype)


- @torch.library.register_fake("quack::gemm_lse_t")
- def gemm_lse_t_ref(A: Tensor, B: Tensor, softcap: float = 0.0) -> (Tensor, Tensor):
-     return gemm_lse_ref(A, B.T, softcap)
+ def gemm_gated_ref(
+     A: Tensor,
+     B: Tensor,
+     C: Optional[Tensor] = None,
+     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+     store_preact: bool = True,
+ ) -> Tuple[Optional[Tensor], Tensor]:
+     """Reference implementation for GEMM with gated activation forward.
+
+     Args:
+         A: (M, K) - input tensor
+         B: (K, 2*N) - weight tensor with gate and up projections
+         C: (M, 2*N) - optional bias tensor
+         activation: Type of gated activation
+         out_dtype: Output dtype for preact
+         postact_dtype: Output dtype for postact
+         store_preact: Whether to return the pre-activation
+
+     Returns:
+         (preact, postact) where:
+         - preact: (M, 2*N) pre-activation (if store_preact=True, else None)
+         - postact: (M, N) post-activation output
+     """
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+     preact = torch.mm(A, B) if C is None else C + torch.mm(A, B)
+     # Split preact into gate and up projections
+     gate = preact[..., ::2]  # (M, N)
+     up = preact[..., 1::2]  # (M, N)
+     postact = gated_to_pytorch_fn_map[activation](gate, up)
+     return preact.to(out_dtype) if store_preact else None, postact.to(postact_dtype)
+
+
+ def gemm_dgated_ref(
+     A: Tensor,
+     B: Tensor,
+     PreAct: Tensor,
+     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[Tensor, Tensor]:
+     """Reference implementation for GEMM with gated activation gradient.
+
+     Args:
+         A: (M, K) - dout input tensor
+         B: (K, N) - weight tensor
+         PreAct: (M, 2*N) - pre-activation tensor with gate and up projections interleaved
+         activation: Type of gated activation
+         out_dtype: Output dtype for dx
+         postact_dtype: Output dtype for postact
+
+     Returns:
+         (dx, postact) where:
+         - dx: (M, 2*N) gradient w.r.t. PreAct
+         - postact: (M, N) post-activation output
+     """
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+     dout = torch.mm(A, B).to(out_dtype)
+     # Split PreAct into gate and up projections
+     gate = PreAct[..., ::2]  # (M, N)
+     up = PreAct[..., 1::2]  # (M, N)
+     postact = gated_to_pytorch_fn_map[activation](gate, up)
+     # Use autograd to compute gradients w.r.t. gate and up
+     dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
+     # Interleave gradients back
+     dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
+     return dx.to(out_dtype), postact.to(postact_dtype)
+
+
+ # TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
+ # try:
+ # from torch._inductor.fx_passes.reinplace import InplaceableOp
+ # torch._inductor.fx_passes.reinplace.inplaceable_ops.update({
+ # torch.ops.quack.gemm_add_out.default:
+ # InplaceableOp(torch.ops.quack.gemm_add_inplace.default, mutated_arg=2)
+ # })
+ # except ImportError:
+ # pass
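
For orientation, a minimal usage sketch of the reworked interface follows. It assumes an SM90 (Hopper) GPU and bf16 inputs, since the new paths dispatch to the gemm_sm90 family of kernels; the shapes and dtype are illustrative assumptions, not requirements stated in the diff.

import torch
from quack.gemm_interface import gemm, gemm_add, gemm_add_inplace

M, K, N = 4096, 4096, 4096
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
C = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)

D = gemm(A, B)                              # D = A @ B, autotuned over GemmConfig candidates
E = gemm_add(A, B, C, alpha=1.0, beta=1.0)  # E = alpha * A @ B + beta * C
acc = C.clone()
gemm_add_inplace(A, B, acc)                 # acc = A @ B + acc, written in place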
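The fused-epilogue entry points that replace the old gemm_swiglu/gemm_dswiglu and gemm_lse ops can be exercised the same way. This sketch only mirrors the signatures and shape comments of gemm_act_ref/gemm_dact_ref above; the device, dtype, and shapes are again assumptions for illustration.

import torch
from quack.gemm_interface import gemm_act, gemm_act_ref, gemm_dact, gemm_dact_ref

M, K, N = 2048, 1024, 4096
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

# Forward: preact = x @ w and postact = relu(preact)**2 from a single kernel.
preact, postact = gemm_act(x, w, activation="relu_sq", store_preact=True)
pre_ref, post_ref = gemm_act_ref(x, w, activation="relu_sq")

# Backward-style epilogue: the GEMM computes dout = G @ W2 of shape (M, N), then the
# activation gradient is applied at preact and postact is recomputed alongside.
G = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
W2 = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
dx, postact2 = gemm_dact(G, W2, preact, activation="relu_sq")
dx_ref, post2_ref = gemm_dact_ref(G, W2, preact, activation="relu_sq")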