quack-kernels 0.1.11__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +7 -3
- quack/activation.py +279 -0
- quack/autotuner.py +2 -1
- quack/cross_entropy.py +330 -184
- quack/cute_dsl_utils.py +83 -4
- quack/dense_gemm_sm100.py +1 -1
- quack/dense_gemm_sm90.py +911 -1140
- quack/fast_math.py +10 -27
- quack/gemm_act_sm90.py +368 -0
- quack/gemm_config.py +43 -35
- quack/gemm_dact_sm90.py +150 -0
- quack/gemm_interface.py +491 -243
- quack/gemm_wrapper_utils.py +158 -0
- quack/layernorm.py +6 -4
- quack/linear.py +128 -64
- quack/linear_cross_entropy.py +275 -0
- quack/mlp.py +30 -160
- quack/pipeline.py +2 -17
- quack/reduce.py +240 -0
- quack/reduction_base.py +2 -11
- quack/rmsnorm.py +614 -228
- quack/softmax.py +28 -16
- quack/symmetric_dense_gemm_sm90.py +6 -3
- quack/tensormap_manager.py +1 -0
- quack/tile_scheduler.py +64 -61
- quack/topk.py +14 -8
- quack/utils.py +14 -322
- quack/varlen_utils.py +22 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/METADATA +3 -3
- quack_kernels-0.2.1.dist-info/RECORD +37 -0
- quack/lse.py +0 -62
- quack_kernels-0.1.11.dist-info/RECORD +0 -31
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/WHEEL +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.1.11.dist-info → quack_kernels-0.2.1.dist-info}/top_level.txt +0 -0
quack/gemm_interface.py
CHANGED
@@ -1,5 +1,6 @@
 # Copyright (c) 2025, Tri Dao
-from typing import Optional
+from typing import Optional, Tuple, Literal
+from functools import partial

 import torch
 import torch.nn.functional as F
@@ -8,314 +9,561 @@ from torch import Tensor
 from quack.gemm_config import GemmConfig, get_all_configs

 from quack.autotuner import autotune, AutotuneConfig
-from quack.
+from quack.dense_gemm_sm90 import gemm_sm90
+from quack.gemm_act_sm90 import gemm_act_sm90
+from quack.gemm_dact_sm90 import gemm_dact_sm90


-
-
-
-
-
-
-
-
-
-
-
-
+# Dictionary mapping activation names to PyTorch functions
+act_to_pytorch_fn_map = {
+    None: lambda x: x,
+    "relu": F.relu,
+    "relu_sq": lambda x: F.relu(x).square(),
+    "gelu_tanh_approx": partial(F.gelu, approximate="tanh"),
+}
+
+
+# Dictionary mapping gated activation names to their forward functions
+# Each function takes (gate, up) and returns postact
+gated_to_pytorch_fn_map = {
+    "swiglu": lambda gate, up: F.silu(gate) * up,
+    "swiglu_oai": lambda gate, up: gate * torch.sigmoid(1.702 * gate) * (up + 1),
+    "reglu": lambda gate, up: F.relu(gate) * up,
+    "geglu": lambda gate, up: F.gelu(gate, approximate="tanh") * up,
+    "glu": lambda gate, up: torch.sigmoid(gate) * up,
+}


 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs(
+    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    key=["dynamic_scheduler"],
 )
 def gemm_tuned(
-    A: Tensor,
-    B: Tensor,
-
+    A: Tensor,  # (M, K)
+    B: Tensor,  # (K, N)
+    out: Tensor,  # (M, N) - required output tensor
+    C: Optional[Tensor] = None,  # (M, N)
+    alpha: float | Tensor = 1.0,  # (1,)
+    beta: float | Tensor = 1.0,  # (1,)
+    dynamic_scheduler: bool = False,
     config: Optional[GemmConfig] = None,
-) ->
+) -> None:
     if config is None:
-        config = GemmConfig(
-
-
-
-
-
-
-
-
-
-        A if
-
-
+        config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+    A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
+    if C is not None:
+        C = C.unsqueeze(0)  # (1, M, N)
+    assert out.shape == (
+        A.shape[1],
+        B.shape[1],
+    ), f"out shape mismatch: {out.shape} vs {(A.shape[1], B.shape[1])}"
+    out = out.unsqueeze(0)
+    tile_count_semaphore = (
+        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+    )
+    gemm_sm90(
+        A if not config.swap_ab else B,
+        B if not config.swap_ab else A,
+        out if not config.swap_ab else out.mT,
+        (C if not config.swap_ab else C.mT) if C is not None else None,
+        tile_count_semaphore,
         config.tile_m,
         config.tile_n,
         config.cluster_m,
         config.cluster_n,
-        not config.swap_ab,  # C_rowmajor
         config.pingpong,
-
-
+        alpha=alpha,
+        beta=beta,
     )
-    return out if not config.swap_ab else out.T


-@
-
-
-
-
-
-
-
-
-
-
-def gemm_add_tuned(
-    A: Tensor,
-    B: Tensor,
-    C: Tensor,
+@autotune(
+    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    key=["activation"],
+)
+def gemm_act_tuned(
+    A: Tensor,  # (M, K)
+    B: Tensor,  # (K, N)
+    preact_out: Optional[Tensor],  # (M, N) - None if not storing preact
+    postact_out: Tensor,  # (M, N)
+    C: Optional[Tensor] = None,  # (M, N)
+    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
     config: Optional[GemmConfig] = None,
-) ->
+) -> None:
     if config is None:
-        config = GemmConfig(
-
-
-
-
-
-
-
-
-
-
-
-
+        config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+    A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
+    if C is not None:
+        C = C.unsqueeze(0)  # (1, M, N)
+    if preact_out is not None:
+        assert preact_out.shape == (A.shape[1], B.shape[1])
+        D = preact_out.unsqueeze(0)
+    else:
+        D = None
+    assert postact_out.shape == (A.shape[1], B.shape[1])
+    PostAct = postact_out.unsqueeze(0)
+    gemm_act_sm90(
+        A if not config.swap_ab else B,
+        B if not config.swap_ab else A,
+        (D if not config.swap_ab else D.mT) if D is not None else None,
+        (C if not config.swap_ab else C.mT) if C is not None else None,
+        PostAct if not config.swap_ab else PostAct.mT,
+        activation,
         config.tile_m,
         config.tile_n,
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
-        config.raster_order,
-        config.max_swizzle_size,
     )
-    return out if not config.swap_ab else out.T
-
-
-@torch.library.custom_op("quack::gemm_add", mutates_args=(), device_types="cuda")
-def gemm_add(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
-    return gemm_add_tuned(A, B, C)
-
-
-@torch.library.register_fake("quack::gemm_add")
-def gemm_add_ref(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
-    return C + torch.mm(A, B)
-
-
-@torch.library.custom_op("quack::gemm_add_t", mutates_args=(), device_types="cuda")
-def gemm_t_add(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
-    return gemm_add_tuned(A, B.T, C)
-
-
-@torch.library.register_fake("quack::gemm_add_t")
-def gemm_t_add_ref(A: Tensor, B: Tensor, C: Tensor) -> Tensor:
-    return gemm_add_ref(A, B.T, C)


 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs(
+    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    key=["activation", "dynamic_scheduler"],
 )
-def
-    A: Tensor,
-    B: Tensor,
-
+def gemm_dact_tuned(
+    A: Tensor,  # (M, K)
+    B: Tensor,  # (K, N)
+    PreAct: Tensor,  # (M, N)
+    dx_out: Tensor,  # (M, N)
+    postact_out: Tensor,  # (M, N)
+    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    dynamic_scheduler: bool = True,
     config: Optional[GemmConfig] = None,
-) ->
+) -> None:
     if config is None:
-        config = GemmConfig(
-
-
-
-
-
-
-
-        )
-
-
-        A,
-        B,
-
+        config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+    A, B = A.unsqueeze(0), B.mT.unsqueeze(0)  # (1, M, K), (1, N, K)
+    PreAct = PreAct.unsqueeze(0)  # (1, M, N)
+    assert dx_out.shape == (A.shape[1], B.shape[1])
+    D = dx_out.unsqueeze(0)
+    assert postact_out.shape == (A.shape[1], B.shape[1])
+    PostAct = postact_out.unsqueeze(0)
+    tile_count_semaphore = (
+        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+    )
+    gemm_dact_sm90(
+        A if not config.swap_ab else B,
+        B if not config.swap_ab else A,
+        D if not config.swap_ab else D.mT,
+        PreAct if not config.swap_ab else PreAct.mT,
+        PostAct if not config.swap_ab else PostAct.mT,
+        tile_count_semaphore,
+        activation,
         config.tile_m,
         config.tile_n,
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
-        config.raster_order,
-        config.max_swizzle_size,
     )


-
-
+def gemm(
+    A: Tensor,
+    B: Tensor,
+    out: Optional[Tensor] = None,
+    alpha: float | Tensor = 1.0,
+    out_dtype: Optional[torch.dtype] = None,
+    dynamic_scheduler: bool = False,
+    tuned: bool = True,
+) -> Tensor:
+    """GEMM with optional output tensor and tuning control."""
+    if out is None:
+        out_dtype = A.dtype if out_dtype is None else out_dtype
+        out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+    alpha_tensor = alpha if not isinstance(alpha, float) else None
+    alpha = alpha if isinstance(alpha, float) else 1.0
+    gemm_out(
+        A,
+        B,
+        out,
+        alpha=alpha,
+        alpha_tensor=alpha_tensor,
+        dynamic_scheduler=dynamic_scheduler,
+        tuned=tuned,
+    )
+    return out
+
+
 @torch.library.custom_op(
-    "quack::
-    mutates_args=(),
+    "quack::gemm_out",
+    mutates_args=("out",),
     device_types="cuda",
-
+    # We have to split out alpha and alpha_tensor since torch.library requires
+    # each argument to have a fixed type
+    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
-def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+def gemm_out(
+    A: Tensor,
+    B: Tensor,
+    out: Tensor,
+    alpha: float = 1.0,
+    alpha_tensor: Optional[Tensor] = None,
+    dynamic_scheduler: bool = False,
+    tuned: bool = True,
+) -> None:
+    """GEMM with pre-allocated output tensor."""
+    fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
+    alpha = alpha_tensor if alpha_tensor is not None else alpha
+    fn(A, B, out, C=None, alpha=alpha, dynamic_scheduler=dynamic_scheduler)
+
+
+def gemm_ref(
+    A: Tensor,
+    B: Tensor,
+    out: Optional[Tensor] = None,
+    alpha: float | Tensor = 1.0,
+    out_dtype: Optional[torch.dtype] = None,
+) -> Tensor:
+    """Reference implementation for GEMM with pre-allocated output."""
+    # The out_dtype argument requires torch >= 2.8
+    out = torch.mm(A, B, out_dtype=out_dtype, out=out)
+    if not isinstance(alpha, float) or alpha != 1.0:
+        out = out * alpha
+    return out
+
+
+def gemm_add(
+    A: Tensor,
+    B: Tensor,
+    C: Tensor,
+    out: Optional[Tensor] = None,
+    alpha: float | Tensor = 1.0,
+    beta: float | Tensor = 1.0,
+    out_dtype: Optional[torch.dtype] = None,
+    dynamic_scheduler: bool = False,
+    tuned: bool = True,
+) -> Tensor:
+    """GEMM with addition and optional output tensor."""
+    if out is None:
+        out_dtype = A.dtype if out_dtype is None else out_dtype
+        out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+    alpha_tensor = alpha if not isinstance(alpha, float) else None
+    alpha = alpha if isinstance(alpha, float) else 1.0
+    beta_tensor = beta if not isinstance(beta, float) else None
+    beta = beta if isinstance(beta, float) else 1.0
+    gemm_add_out(
+        A,
+        B,
+        C,
+        out,
+        alpha,
+        beta,
+        alpha_tensor,
+        beta_tensor,
+        dynamic_scheduler=dynamic_scheduler,
+        tuned=tuned,
+    )
+    return out


-@
-
+@torch.library.custom_op(
+    "quack::gemm_add_out",
+    mutates_args=("out",),
+    device_types="cuda",
+    # We have to split out alpha and alpha_tensor since torch.library requires
+    # each argument to have a fixed type
+    # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
-def
+def gemm_add_out(
     A: Tensor,
     B: Tensor,
-
-
-
-
-
-
-
-
-
-
-
-
-
+    C: Tensor,
+    out: Tensor,
+    alpha: float = 1.0,
+    beta: float = 1.0,
+    alpha_tensor: Optional[Tensor] = None,
+    beta_tensor: Optional[Tensor] = None,
+    dynamic_scheduler: bool = False,
+    tuned: bool = True,
+) -> None:
+    """GEMM with addition and pre-allocated output tensor."""
+    fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
+    alpha = alpha_tensor if alpha_tensor is not None else alpha
+    beta = beta_tensor if beta_tensor is not None else beta
+    fn(A, B, out, C, alpha=alpha, beta=beta, dynamic_scheduler=dynamic_scheduler)
+
+
+def gemm_add_ref(
+    A: Tensor,
+    B: Tensor,
+    C: Tensor,
+    out: Optional[Tensor] = None,
+    alpha: float | Tensor = 1.0,
+    beta: float | Tensor = 1.0,
+    out_dtype: Optional[torch.dtype] = None,
+) -> Tensor:
+    """Reference implementation for GEMM with addition and pre-allocated output."""
+    if isinstance(alpha, float) and isinstance(beta, float):
+        return torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
+    else:
+        out_dtype = (
+            out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        result = (alpha * (A @ B) + beta * C).to(out_dtype)
+        if out is not None:
+            out.copy_(result)
+        return result
+
+
+def gemm_add_inplace(
+    A: Tensor,
+    B: Tensor,
+    out: Tensor,
+    alpha: float | Tensor = 1.0,
+    beta: float | Tensor = 1.0,
+    dynamic_scheduler: bool = False,
+    tuned: bool = True,
+) -> None:
+    """In-place GEMM with addition: out = alpha * A @ B + beta * out.
+    Args:
+        A: (M, K) input tensor
+        B: (K, N) input tensor
+        out: (M, N) tensor to accumulate into (modified in-place)
+        alpha: Scalar multiplier for A @ B
+        beta: Scalar multiplier for out
+        dynamic_scheduler: Whether to use dynamic scheduler
+        tuned: Whether to use autotuned configuration
+    """
+    alpha_tensor = alpha if not isinstance(alpha, float) else None
+    alpha = alpha if isinstance(alpha, float) else 1.0
+    beta_tensor = beta if not isinstance(beta, float) else None
+    beta = beta if isinstance(beta, float) else 1.0
+    gemm_add_inplace_op(A, B, out, alpha, beta, alpha_tensor, beta_tensor, dynamic_scheduler, tuned)


-# Specifying the schema manually here since torch.library._infer_schema doesn't work when return
-# type is a tuple of Tensor
 @torch.library.custom_op(
-    "quack::
-    mutates_args=(),
+    "quack::gemm_add_inplace",
+    mutates_args=("out",),
     device_types="cuda",
-
+    # We have to split out alpha and alpha_tensor since torch.library requires
+    # each argument to have a fixed type
+    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
-def
-    return gemm_dswiglu_tuned(A, B, preact, sm_carveout)
-
-
-@torch.library.register_fake("quack::gemm_dswiglu")
-def gemm_dswiglu_ref(
-    A: Tensor, B: Tensor, preact: Tensor, sm_carveout: int = 0
-) -> (Tensor, Tensor):
-    # A: (M, K), B: (K, N), preact: (M, 2 * N)
-    dout = torch.mm(A, B)
-    p0, p1 = preact[..., ::2], preact[..., 1::2]
-    sigmoid = torch.sigmoid(p0)
-    silu = F.silu(p0)
-    postact = silu * p1
-    d0 = sigmoid * (1 + p0 * (1 - sigmoid)) * p1 * dout
-    d1 = F.silu(p0) * dout
-    out = torch.stack([d0, d1], dim=-1).reshape(d0.shape[:-1] + (2 * d0.shape[-1],))
-    return out, postact
-
-
-@autotune(configs=[AutotuneConfig(config=c) for c in get_all_configs("lse")])
-def gemm_lse_tuned(
+def gemm_add_inplace_op(
     A: Tensor,
     B: Tensor,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    out: Tensor,
+    alpha: float = 1.0,
+    beta: float = 1.0,
+    alpha_tensor: Optional[Tensor] = None,
+    beta_tensor: Optional[Tensor] = None,
+    dynamic_scheduler: bool = False,
+    tuned: bool = True,
+) -> None:
+    fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
+    alpha = alpha_tensor if alpha_tensor is not None else alpha
+    beta = beta_tensor if beta_tensor is not None else beta
+    # Use C as both input bias and output
+    fn(A, B, out, out, alpha=alpha, beta=beta, dynamic_scheduler=dynamic_scheduler)
+
+
+def gemm_act(
+    A: Tensor,
+    B: Tensor,
+    C: Optional[Tensor] = None,
+    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    preact_out: Optional[Tensor] = None,
+    postact_out: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    postact_dtype: Optional[torch.dtype] = None,
+    store_preact: bool = True,
+    tuned: bool = True,
+) -> Tuple[Optional[Tensor], Tensor]:
+    """GEMM with activation and optional output tensors."""
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+    if preact_out is None and store_preact:
+        preact_out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+    if postact_out is None:
+        postact_out = torch.empty((A.shape[0], B.shape[1]), dtype=postact_dtype, device=A.device)
+    gemm_act_out(A, B, preact_out, postact_out, C, activation, tuned)
+    return preact_out, postact_out


 @torch.library.custom_op(
-    "quack::
-    mutates_args=(),
+    "quack::gemm_act_out",
+    mutates_args=("preact_out", "postact_out"),
     device_types="cuda",
-    schema="(Tensor A, Tensor B,
+    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, bool tuned=True) -> ()",
 )
-def
-
+def gemm_act_out(
+    A: Tensor,
+    B: Tensor,
+    preact_out: Optional[Tensor],
+    postact_out: Tensor,
+    C: Optional[Tensor] = None,
+    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    tuned: bool = True,
+) -> None:
+    """GEMM with activation and pre-allocated output tensors."""
+    fn = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None)
+    fn(A, B, preact_out, postact_out, C, activation)
+
+
+def gemm_act_ref(
+    A: Tensor,
+    B: Tensor,
+    C: Optional[Tensor] = None,
+    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    postact_dtype: Optional[torch.dtype] = None,
+    store_preact: bool = True,
+) -> Tuple[Optional[Tensor], Tensor]:
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+    out = torch.mm(A, B) if C is None else C + torch.mm(A, B)
+    postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype)
+    return out.to(out_dtype) if store_preact else None, postact


-
-
-
-
-
-
-
-
-
-
-
+def gemm_dact(
+    A: Tensor,
+    B: Tensor,
+    PreAct: Tensor,
+    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    dx_out: Optional[Tensor] = None,
+    postact_out: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    postact_dtype: Optional[torch.dtype] = None,
+    dynamic_scheduler: bool = True,
+    tuned: bool = True,
+) -> Tuple[Tensor, Tensor]:
+    """GEMM with activation gradient and optional output tensors."""
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+    if dx_out is None:
+        dx_out = torch.empty((A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+    if postact_out is None:
+        postact_out = torch.empty((A.shape[0], B.shape[1]), dtype=postact_dtype, device=A.device)
+    gemm_dact_out(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler, tuned)
+    return dx_out, postact_out


 @torch.library.custom_op(
-    "quack::
-    mutates_args=(),
+    "quack::gemm_dact_out",
+    mutates_args=("dx_out", "postact_out"),
     device_types="cuda",
-    schema="(Tensor A, Tensor B,
+    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
 )
-def
-
+def gemm_dact_out(
+    A: Tensor,
+    B: Tensor,
+    PreAct: Tensor,
+    dx_out: Tensor,
+    postact_out: Tensor,
+    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    dynamic_scheduler: bool = True,
+    tuned: bool = True,
+) -> None:
+    """GEMM with activation gradient and pre-allocated output tensors."""
+    fn = gemm_dact_tuned if tuned else partial(gemm_dact_tuned.fn, config=None)
+    fn(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler)
+
+
+def gemm_dact_ref(
+    A: Tensor,
+    B: Tensor,
+    PreAct: Tensor,
+    activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    postact_dtype: Optional[torch.dtype] = None,
+) -> Tuple[Tensor, Tensor]:
+    """Reference implementation for GEMM with activation gradient."""
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+    dout = torch.mm(A, B).to(out_dtype)
+    postact = act_to_pytorch_fn_map[activation](PreAct)
+    # Compute gradient using autograd
+    if activation is None:
+        dx = dout
+    else:
+        PreAct_requires_grad = PreAct.requires_grad
+        PreAct.requires_grad_(True)
+        postact_for_grad = act_to_pytorch_fn_map[activation](PreAct)
+        dx = torch.autograd.grad(postact_for_grad, PreAct, dout, create_graph=False)[0]
+        PreAct.requires_grad_(PreAct_requires_grad)
+    return dx.to(out_dtype), postact.to(postact_dtype)


-
-
-
+def gemm_gated_ref(
+    A: Tensor,
+    B: Tensor,
+    C: Optional[Tensor] = None,
+    activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
+    out_dtype: Optional[torch.dtype] = None,
+    postact_dtype: Optional[torch.dtype] = None,
+    store_preact: bool = True,
+) -> Tuple[Optional[Tensor], Tensor]:
+    """Reference implementation for GEMM with gated activation forward.
+
+    Args:
+        A: (M, K) - input tensor
+        B: (K, 2*N) - weight tensor with gate and up projections
+        C: (M, 2*N) - optional bias tensor
+        activation: Type of gated activation
+        out_dtype: Output dtype for preact
+        postact_dtype: Output dtype for postact
+        store_preact: Whether to return the pre-activation
+
+    Returns:
+        (preact, postact) where:
+            - preact: (M, 2*N) pre-activation (if store_preact=True, else None)
+            - postact: (M, N) post-activation output
+    """
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+    preact = torch.mm(A, B) if C is None else C + torch.mm(A, B)
+    # Split preact into gate and up projections
+    gate = preact[..., ::2]  # (M, N)
+    up = preact[..., 1::2]  # (M, N)
+    postact = gated_to_pytorch_fn_map[activation](gate, up)
+    return preact.to(out_dtype) if store_preact else None, postact.to(postact_dtype)
+
+
+def gemm_dgated_ref(
+    A: Tensor,
+    B: Tensor,
+    PreAct: Tensor,
+    activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
+    out_dtype: Optional[torch.dtype] = None,
+    postact_dtype: Optional[torch.dtype] = None,
+) -> Tuple[Tensor, Tensor]:
+    """Reference implementation for GEMM with gated activation gradient.
+
+    Args:
+        A: (M, K) - dout input tensor
+        B: (K, N) - weight tensor
+        PreAct: (M, 2*N) - pre-activation tensor with gate and up projections interleaved
+        activation: Type of gated activation
+        out_dtype: Output dtype for dx
+        postact_dtype: Output dtype for postact
+
+    Returns:
+        (dx, postact) where:
+            - dx: (M, 2*N) gradient w.r.t. PreAct
+            - postact: (M, N) post-activation output
+    """
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+    dout = torch.mm(A, B).to(out_dtype)
+    # Split PreAct into gate and up projections
+    gate = PreAct[..., ::2]  # (M, N)
+    up = PreAct[..., 1::2]  # (M, N)
+    postact = gated_to_pytorch_fn_map[activation](gate, up)
+    # Use autograd to compute gradients w.r.t. gate and up
+    dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
+    # Interleave gradients back
+    dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
+    return dx.to(out_dtype), postact.to(postact_dtype)
+
+
+# TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
+# try:
+#     from torch._inductor.fx_passes.reinplace import InplaceableOp
+#     torch._inductor.fx_passes.reinplace.inplaceable_ops.update({
+#         torch.ops.quack.gemm_add_out.default:
+#             InplaceableOp(torch.ops.quack.gemm_add_inplace.default, mutated_arg=2)
+#     })
+# except ImportError:
+#     pass
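
Usage sketch (not part of the diff): based on the new signatures in quack/gemm_interface.py shown above, calling the 0.2.1 interface might look like the following. This is a minimal, unverified sketch; it assumes an SM90 GPU with quack-kernels 0.2.1 installed, and the matrix sizes and bf16 dtype are illustrative assumptions, not values from the package.

import torch
from quack.gemm_interface import gemm, gemm_add, gemm_act, gemm_dact

# Assumed shapes/dtype for illustration only.
M, K, N = 4096, 4096, 4096
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
C = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)

out = gemm(A, B)         # out = A @ B, autotuned by default (tuned=True)
acc = gemm_add(A, B, C)  # acc = alpha * A @ B + beta * C, with alpha/beta defaulting to 1.0

# Fused activation epilogue: returns (preact, postact); preact is None when store_preact=False.
preact, postact = gemm_act(A, B, activation="relu_sq", store_preact=True)

# Backward-style fusion: dx = (A @ B) * act'(preact), plus the recomputed postact,
# mirroring gemm_dact_ref in the diff above.
dx, postact_recomputed = gemm_dact(A, B, preact, activation="relu_sq")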