quack-kernels 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff shows the changes between the two publicly released package versions as published to their public registry; it is provided for informational purposes only.
- quack/__init__.py +1 -1
- quack/activation.py +16 -25
- quack/autotuner.py +64 -5
- quack/cross_entropy.py +6 -10
- quack/cute_dsl_utils.py +6 -7
- quack/dense_gemm_sm90.py +582 -287
- quack/gemm_act_sm90.py +70 -29
- quack/gemm_dact_sm90.py +43 -10
- quack/gemm_interface.py +453 -130
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +443 -419
- quack/gemm_wrapper_utils.py +179 -22
- quack/layernorm.py +1 -1
- quack/reduce.py +6 -7
- quack/rmsnorm.py +126 -158
- quack/softmax.py +1 -1
- quack/tile_scheduler.py +37 -49
- quack/utils.py +61 -71
- quack/varlen_utils.py +1 -6
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/METADATA +3 -3
- quack_kernels-0.2.2.dist-info/RECORD +37 -0
- quack_kernels-0.2.0.dist-info/RECORD +0 -37
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.0.dist-info → quack_kernels-0.2.2.dist-info}/top_level.txt +0 -0
quack/gemm_interface.py
CHANGED
@@ -34,30 +34,71 @@ gated_to_pytorch_fn_map = {
 }
 
 
+def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
+    kwargs = named_args | kwargs
+    gather_A = kwargs.get("A_idx", None) is not None
+    varlen_m = kwargs.get("cu_seqlens_m", None) is not None
+    if varlen_m or gather_A:  # Doesn't support swap_ab
+        configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
+    if gather_A:
+        # tile_n == 208 causes register spills, as gather_A requires more registers for the producer
+        configs = [
+            conf
+            for conf in configs
+            if conf.kwargs["config"].cluster_n == 1 and conf.kwargs["config"].tile_n != 208
+        ]
+    return configs
+
+
 @autotune(
     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
     key=["dynamic_scheduler"],
+    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
 def gemm_tuned(
-
-
-
-
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
+    A: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,  # (1,)
     beta: float | Tensor = 1.0,  # (1,)
+    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
+    cu_seqlens_k: Optional[Tensor] = None,  # (L+1), int32
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
+    add_to_output: bool = False,
     dynamic_scheduler: bool = False,
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
-
-
+    varlen_m = cu_seqlens_m is not None
+    varlen_k = cu_seqlens_k is not None
+    varlen = varlen_m or varlen_k
+    gather_A = A_idx is not None
+    if gather_A:
+        assert varlen, "gather_A requires either varlen_m or varlen_k"
+        assert config.cluster_n == 1, "gather_A requires cluster_n=1"
+    if varlen_m:
+        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+    if A.ndim == 2 and not varlen:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (N, K) or (L, N, K) or (N, total_K)
+    if B.ndim == 2 and not varlen_k:
+        B = B.unsqueeze(0)  # (1, N, K)
+    if C is not None and C.ndim == 2 and not varlen_m:
         C = C.unsqueeze(0)  # (1, M, N)
-
-
-
-
-
+    if out.ndim == 2 and not varlen_m:
+        out = out.unsqueeze(0)
+    batch_size = B.shape[0] if not varlen_k else cu_seqlens_k.shape[0] - 1
+    if varlen_m:
+        # If gather_A (A_idx provided), use its length; otherwise use A.shape[0]
+        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+        out_shape = (total_m, B.shape[-2])
+    else:
+        out_shape = (batch_size, A.shape[-2], B.shape[-2])
+    assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}"
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
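The new prune_invalid_gemm_configs hook is handed to the autotuner through prune_configs_by={"early_config_prune": ...}, so candidate configs that cannot run for the current call (swap_ab with varlen or gather_A; cluster_n != 1 or tile_n == 208 with gather_A) are dropped before anything is compiled or benchmarked. A minimal, self-contained sketch of the same pruning pattern, using a hypothetical Config class rather than quack's actual AutotuneConfig/GemmConfig:

    from dataclasses import dataclass

    @dataclass
    class Config:                      # hypothetical stand-in for a GEMM tile config
        tile_n: int
        cluster_n: int
        swap_ab: bool

    def early_config_prune(candidates, named_args, **kwargs):
        """Drop candidates that cannot run for the current call signature."""
        kwargs = named_args | kwargs
        gather_A = kwargs.get("A_idx") is not None
        varlen_m = kwargs.get("cu_seqlens_m") is not None
        if varlen_m or gather_A:       # swap_ab is not supported in these modes
            candidates = [c for c in candidates if not c.swap_ab]
        if gather_A:                   # gather_A needs cluster_n == 1 and avoids tile_n == 208
            candidates = [c for c in candidates if c.cluster_n == 1 and c.tile_n != 208]
        return candidates

    configs = [Config(192, 1, False), Config(208, 1, False), Config(192, 2, False), Config(192, 1, True)]
    print(early_config_prune(configs, {"A_idx": object(), "cu_seqlens_m": None}))
    # only Config(tile_n=192, cluster_n=1, swap_ab=False) survives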
@@ -74,71 +115,113 @@ def gemm_tuned(
         config.pingpong,
         alpha=alpha,
         beta=beta,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
     )
 
 
 @autotune(
     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
-    key=["activation"],
+    key=["activation", "dynamic_scheduler"],
+    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
 def gemm_act_tuned(
-
-
-
-
-
+    # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    A: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N)
+    # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
+    preact_out: Optional[Tensor],
+    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
+    dynamic_scheduler: bool = False,
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
-
-    if
+    varlen_m = cu_seqlens_m is not None
+    if varlen_m:
+        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+    if A.ndim == 2 and not varlen_m:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (N, K) or (L, N, K)
+    if B.ndim == 2:
+        B = B.unsqueeze(0)  # (1, N, K)
+    if C is not None and C.ndim == 2 and not varlen_m:
         C = C.unsqueeze(0)  # (1, M, N)
-    if preact_out is not None:
-        assert preact_out.shape == (A.shape[1], B.shape[1])
+    if preact_out is not None and preact_out.ndim == 2 and not varlen_m:
         D = preact_out.unsqueeze(0)
     else:
-        D =
-
-
+        D = preact_out
+    if postact_out.ndim == 2 and not varlen_m:
+        PostAct = postact_out.unsqueeze(0)
+    else:
+        PostAct = postact_out
+    tile_count_semaphore = (
+        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+    )
     gemm_act_sm90(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         (D if not config.swap_ab else D.mT) if D is not None else None,
         (C if not config.swap_ab else C.mT) if C is not None else None,
         PostAct if not config.swap_ab else PostAct.mT,
+        tile_count_semaphore,
         activation,
         config.tile_m,
         config.tile_n,
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
+        persistent=True,
+        cu_seqlens_m=cu_seqlens_m,
+        A_idx=A_idx,
     )
 
 
 @autotune(
     configs=[AutotuneConfig(config=c) for c in get_all_configs()],
     key=["activation", "dynamic_scheduler"],
+    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
 def gemm_dact_tuned(
-
-
-
-
-
+    # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    A: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N)
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    dx_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Tensor,  # (M, N) or (L, N, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     dynamic_scheduler: bool = True,
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
         config = GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
-
-
-
-
-
-
+    varlen_m = cu_seqlens_m is not None
+    if varlen_m:
+        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+    if A.ndim == 2 and not varlen_m:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (N, K) or (L, N, K)
+    if B.ndim == 2:
+        B = B.unsqueeze(0)  # (1, N, K)
+    if PreAct.ndim == 2 and not varlen_m:
+        PreAct = PreAct.unsqueeze(0)  # (1, M, N)
+    if dx_out.ndim == 2 and not varlen_m:
+        D = dx_out.unsqueeze(0)
+    else:
+        D = dx_out
+    if postact_out.ndim == 2 and not varlen_m:
+        PostAct = postact_out.unsqueeze(0)
+    else:
+        PostAct = postact_out
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
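Both tuned wrappers above keep the existing swap_ab trick: when the config asks for swapped operands, the kernel is launched on (B, A) and every (M, N) operand is passed as a transposed .mT view, relying on the identity (A·B)ᵀ = Bᵀ·Aᵀ. A tiny PyTorch check of that identity (illustrative only, not quack code):

    import torch

    A = torch.randn(128, 64)
    B = torch.randn(64, 256)

    # swap_ab path: compute B^T @ A^T, then read the result through a transposed view
    swapped = torch.mm(B.mT, A.mT)            # (N, M)
    torch.testing.assert_close(swapped.mT, A @ B)
    print("swap_ab identity holds for", swapped.mT.shape)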
@@ -155,22 +238,43 @@ def gemm_dact_tuned(
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
+        persistent=True,
+        cu_seqlens_m=cu_seqlens_m,
+        A_idx=A_idx,
     )
 
 
 def gemm(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
     out_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> Tensor:
     """GEMM with optional output tensor and tuning control."""
     if out is None:
         out_dtype = A.dtype if out_dtype is None else out_dtype
-
+        varlen_m = cu_seqlens_m is not None
+        varlen_k = cu_seqlens_k is not None
+        if varlen_m:
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out_shape = (total_m, B.shape[-1])
+        elif varlen_k:
+            L = cu_seqlens_k.shape[0] - 1
+            # For varlen_k, the first dimension is always A.shape[0] (M dimension)
+            out_shape = (L, A.shape[0], B.shape[-1])
+        else:
+            out_shape = (
+                (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1])
+            )
+        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
     alpha_tensor = alpha if not isinstance(alpha, float) else None
     alpha = alpha if isinstance(alpha, float) else 1.0
     gemm_out(
@@ -179,6 +283,10 @@ def gemm(
         out,
         alpha=alpha,
         alpha_tensor=alpha_tensor,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
         dynamic_scheduler=dynamic_scheduler,
         tuned=tuned,
     )
@@ -194,50 +302,117 @@ def gemm(
     # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_out(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     """GEMM with pre-allocated output tensor."""
     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
     alpha = alpha_tensor if alpha_tensor is not None else alpha
-    fn(
+    fn(
+        A,
+        B,
+        out,
+        C=None,
+        alpha=alpha,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        dynamic_scheduler=dynamic_scheduler,
+    )
 
 
 def gemm_ref(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
     out_dtype: Optional[torch.dtype] = None,
 ) -> Tensor:
     """Reference implementation for GEMM with pre-allocated output."""
     # The out_dtype argument requires torch >= 2.8
-
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    if cu_seqlens_m is None and cu_seqlens_k is None:
+        fn = torch.bmm if A.ndim == 3 else torch.mm
+        out = fn(A, B, out_dtype=out_dtype, out=out)
+    elif cu_seqlens_m is not None:
+        # Handle varlen_m case
+        if out is None:
+            # When gather_A (A_idx provided), output size is determined by A_idx length
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
+        for i in range(cu_seqlens_m.shape[0] - 1):
+            A_slice = (
+                A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
+                if A_idx is not None
+                else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            )
+            torch.mm(A_slice, B[i], out=out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]])
+    else:  # cu_seqlens_k is not None
+        L = cu_seqlens_k.shape[0] - 1
+        if out is None:
+            out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+        for i in range(L):
+            A_slice = (
+                A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
+                if A_idx is not None
+                else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
+            )
+            torch.mm(A_slice, B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :], out=out[i])
     if not isinstance(alpha, float) or alpha != 1.0:
         out = out * alpha
     return out
 
 
 def gemm_add(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    C: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
     out_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> Tensor:
     """GEMM with addition and optional output tensor."""
     if out is None:
         out_dtype = A.dtype if out_dtype is None else out_dtype
-
+        varlen_m = cu_seqlens_m is not None
+        varlen_k = cu_seqlens_k is not None
+        if varlen_m:
+            # If A_idx is provided (gather_A), use its length; otherwise use A.shape[0]
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out_shape = (total_m, B.shape[-1])
+        elif varlen_k:
+            L = cu_seqlens_k.shape[0] - 1
+            # For varlen_k, the first dimension is always A.shape[0] (M dimension)
+            out_shape = (L, A.shape[0], B.shape[-1])
+        else:
+            out_shape = (
+                (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1])
+            )
+        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
+    add_to_output = C is out and isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
     alpha_tensor = alpha if not isinstance(alpha, float) else None
     alpha = alpha if isinstance(alpha, float) else 1.0
     beta_tensor = beta if not isinstance(beta, float) else None
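In the varlen_m path above, A is a packed (total_M, K) tensor whose rows are split into L sequences by cu_seqlens_m, and each sequence is multiplied by its own (K, N) slice of B. The reference loop added to gemm_ref does exactly this; the snippet below is a self-contained, CPU-runnable restatement of the same idea in plain PyTorch, with shapes invented for illustration:

    import torch

    def varlen_m_gemm_ref(A, B, cu_seqlens_m):
        """A: (total_M, K) packed rows, B: (L, K, N), cu_seqlens_m: (L+1,) int32."""
        out = torch.empty(A.shape[0], B.shape[-1], dtype=A.dtype)
        for i in range(cu_seqlens_m.numel() - 1):
            s, e = cu_seqlens_m[i], cu_seqlens_m[i + 1]
            out[s:e] = A[s:e] @ B[i]          # each sequence uses its own weight matrix
        return out

    lengths = torch.tensor([3, 5, 2])
    cu_seqlens_m = torch.cat([torch.zeros(1, dtype=torch.int32),
                              lengths.cumsum(0).to(torch.int32)])
    A = torch.randn(int(lengths.sum()), 16)   # (total_M, K)
    B = torch.randn(len(lengths), 16, 8)      # (L, K, N)
    print(varlen_m_gemm_ref(A, B, cu_seqlens_m).shape)  # torch.Size([10, 8])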
@@ -245,12 +420,17 @@ def gemm_add(
     gemm_add_out(
         A,
         B,
-        C,
+        C if not add_to_output else None,
         out,
         alpha,
         beta,
         alpha_tensor,
         beta_tensor,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
         dynamic_scheduler=dynamic_scheduler,
         tuned=tuned,
     )
@@ -263,17 +443,23 @@ def gemm_add(
     device_types="cuda",
     # We have to split out alpha and alpha_tensor since torch.library requires
     # each argument to have a fixed type
-    # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_add_out(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    C: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float = 1.0,
     beta: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
     beta_tensor: Optional[Tensor] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
+    add_to_output: bool = False,
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
@@ -281,47 +467,105 @@ def gemm_add_out(
     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
     alpha = alpha_tensor if alpha_tensor is not None else alpha
     beta = beta_tensor if beta_tensor is not None else beta
-    fn(
+    fn(
+        A,
+        B,
+        out,
+        C,
+        alpha=alpha,
+        beta=beta,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
+        dynamic_scheduler=dynamic_scheduler,
+    )
 
 
 def gemm_add_ref(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    C: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
     out_dtype: Optional[torch.dtype] = None,
 ) -> Tensor:
     """Reference implementation for GEMM with addition and pre-allocated output."""
-    if
-
-
-
-
-
-
-
-    out
-
+    if cu_seqlens_m is None and cu_seqlens_k is None:
+        if isinstance(alpha, float) and isinstance(beta, float):
+            return torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
+        else:
+            out_dtype = (
+                out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
+            )
+            result = (alpha * (A @ B) + beta * C).to(out_dtype)
+            if out is not None:
+                out.copy_(result)
+            return result
+    elif cu_seqlens_m is not None:
+        # Handle varlen_m case
+        if out is None:
+            # When gather_A (A_idx provided), output size is determined by A_idx length
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out_dtype = out_dtype if out_dtype is not None else A.dtype
+            out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
+        for i in range(cu_seqlens_m.shape[0] - 1):
+            A_slice = (
+                A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
+                if A_idx is not None
+                else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            )
+            C_slice = C[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            out_slice = out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            result = alpha * torch.mm(A_slice, B[i]) + beta * C_slice
+            out_slice.copy_(result)
+    else:  # cu_seqlens_k is not None
+        # Handle varlen_k case
+        L = cu_seqlens_k.shape[0] - 1
+        out_dtype = out_dtype if out_dtype is not None else A.dtype
+        if out is None:
+            out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+        for i in range(L):
+            A_slice = (
+                A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
+                if A_idx is not None
+                else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
+            )
+            B_slice = B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :]
+            result = alpha * torch.mm(A_slice, B_slice) + beta * C[i]
+            out[i].copy_(result)
+    return out
 
 
 def gemm_add_inplace(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     """In-place GEMM with addition: out = alpha * A @ B + beta * out.
     Args:
-        A: (M, K) input tensor
-        B: (K, N) input tensor
-        out: (M, N) tensor to accumulate into (modified in-place)
+        A: (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k - input tensor
+        B: (K, N) or (L, K, N) or (total_K, N) if varlen_k - input tensor
+        out: (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k - tensor to accumulate into (modified in-place)
         alpha: Scalar multiplier for A @ B
         beta: Scalar multiplier for out
+        cu_seqlens_m: Optional cumulative sequence lengths for variable M
+        cu_seqlens_k: Optional cumulative sequence lengths for variable K
         dynamic_scheduler: Whether to use dynamic scheduler
         tuned: Whether to use autotuned configuration
     """
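When gather_A is used, A_idx holds row indices into A, so each sequence's rows are gathered on the fly instead of being packed into a contiguous (total_M, K) block; the reference implementations above model this as A[A_idx[start:end]]. A small plain-PyTorch sketch of those indexing semantics (not the fused kernel):

    import torch

    K, N = 16, 8
    A = torch.randn(100, K)                        # large row pool; only some rows participate
    B = torch.randn(2, K, N)                       # one weight matrix per sequence
    A_idx = torch.tensor([7, 3, 42, 0, 9])         # rows to gather, grouped by cu_seqlens_m
    cu_seqlens_m = torch.tensor([0, 2, 5])         # sequence 0 -> rows 7,3; sequence 1 -> rows 42,0,9

    out = torch.empty(A_idx.numel(), N)
    for i in range(cu_seqlens_m.numel() - 1):
        s, e = cu_seqlens_m[i], cu_seqlens_m[i + 1]
        out[s:e] = A[A_idx[s:e]] @ B[i]            # gather rows, then multiply
    print(out.shape)                               # torch.Size([5, 8])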
@@ -329,7 +573,21 @@ def gemm_add_inplace(
     alpha = alpha if isinstance(alpha, float) else 1.0
     beta_tensor = beta if not isinstance(beta, float) else None
     beta = beta if isinstance(beta, float) else 1.0
-    gemm_add_inplace_op(
+    gemm_add_inplace_op(
+        A,
+        B,
+        out,
+        alpha,
+        beta,
+        alpha_tensor,
+        beta_tensor,
+        cu_seqlens_m,
+        cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        dynamic_scheduler=dynamic_scheduler,
+        tuned=tuned,
+    )
 
 
 @torch.library.custom_op(
@@ -338,46 +596,79 @@ def gemm_add_inplace(
     device_types="cuda",
     # We have to split out alpha and alpha_tensor since torch.library requires
     # each argument to have a fixed type
-    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_add_inplace_op(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
     alpha: float = 1.0,
     beta: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
     beta_tensor: Optional[Tensor] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
     alpha = alpha_tensor if alpha_tensor is not None else alpha
     beta = beta_tensor if beta_tensor is not None else beta
-
-
+    add_to_output = isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
+    # Use out as both input bias and output
+    fn(
+        A,
+        B,
+        out,
+        out if not add_to_output else None,
+        alpha=alpha,
+        beta=beta,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
+        dynamic_scheduler=dynamic_scheduler,
+    )
 
 
 def gemm_act(
-    A: Tensor,
-    B: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
-    preact_out: Optional[Tensor] = None,
-    postact_out: Optional[Tensor] = None,
+    preact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     store_preact: bool = True,
+    dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> Tuple[Optional[Tensor], Tensor]:
     """GEMM with activation and optional output tensors."""
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+    varlen_m = cu_seqlens_m is not None
+    # Determine output shape based on gather_A
+    if varlen_m:
+        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+        out_shape = (total_m, B.shape[-1])
+    elif A.ndim == 2:
+        out_shape = (A.shape[0], B.shape[-1])
+    else:
+        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
     if preact_out is None and store_preact:
-        preact_out = torch.empty(
+        preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
     if postact_out is None:
-        postact_out = torch.empty(
-    gemm_act_out(
+        postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
+    gemm_act_out(
+        A, B, preact_out, postact_out, C, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned
+    )
     return preact_out, postact_out
 
 
@@ -385,58 +676,79 @@ def gemm_act(
     "quack::gemm_act_out",
     mutates_args=("preact_out", "postact_out"),
     device_types="cuda",
-    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, bool tuned=True) -> ()",
+    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_act_out(
-    A: Tensor,
-    B: Tensor,
-    preact_out: Optional[Tensor],
-    postact_out: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    preact_out: Optional[Tensor],  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
+    dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     """GEMM with activation and pre-allocated output tensors."""
     fn = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None)
-    fn(A, B, preact_out, postact_out, C, activation)
+    fn(A, B, preact_out, postact_out, C, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
 
 
 def gemm_act_ref(
-    A: Tensor,
-    B: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
     store_preact: bool = True,
 ) -> Tuple[Optional[Tensor], Tensor]:
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
-
+    if C is None:
+        out = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+    else:
+        out = gemm_add_ref(A, B, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype)
     return out.to(out_dtype) if store_preact else None, postact
 
 
 def gemm_dact(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
-    dx_out: Optional[Tensor] = None,
-    postact_out: Optional[Tensor] = None,
+    dx_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     dynamic_scheduler: bool = True,
     tuned: bool = True,
 ) -> Tuple[Tensor, Tensor]:
     """GEMM with activation gradient and optional output tensors."""
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+    varlen_m = cu_seqlens_m is not None
+    # Determine output shape based on gather_A
+    if varlen_m:
+        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+        out_shape = (total_m, B.shape[-1])
+    elif A.ndim == 2:
+        out_shape = (A.shape[0], B.shape[-1])
+    else:
+        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
     if dx_out is None:
-        dx_out = torch.empty(
+        dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
     if postact_out is None:
-        postact_out = torch.empty(
-    gemm_dact_out(
+        postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
+    gemm_dact_out(
+        A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned
+    )
     return dx_out, postact_out
 
 
@@ -444,35 +756,39 @@ def gemm_dact(
     "quack::gemm_dact_out",
     mutates_args=("dx_out", "postact_out"),
     device_types="cuda",
-    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
+    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
 )
 def gemm_dact_out(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
-    dx_out: Tensor,
-    postact_out: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    dx_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     dynamic_scheduler: bool = True,
     tuned: bool = True,
 ) -> None:
     """GEMM with activation gradient and pre-allocated output tensors."""
     fn = gemm_dact_tuned if tuned else partial(gemm_dact_tuned.fn, config=None)
-    fn(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler)
+    fn(A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
 
 
 def gemm_dact_ref(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
 ) -> Tuple[Tensor, Tensor]:
     """Reference implementation for GEMM with activation gradient."""
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
-    dout =
+    dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
     postact = act_to_pytorch_fn_map[activation](PreAct)
     # Compute gradient using autograd
     if activation is None:
@@ -487,10 +803,12 @@ def gemm_dact_ref(
 
 
 def gemm_gated_ref(
-    A: Tensor,
-    B: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
     store_preact: bool = True,
@@ -499,8 +817,8 @@ def gemm_gated_ref(
 
     Args:
         A: (M, K) - input tensor
-        B: (K,
-        C: (M,
+        B: (K, N) - weight tensor with gate and up projections
+        C: (M, N) - optional bias tensor
         activation: Type of gated activation
         out_dtype: Output dtype for preact
         postact_dtype: Output dtype for postact
@@ -508,24 +826,29 @@ def gemm_gated_ref(
 
     Returns:
         (preact, postact) where:
-        - preact: (M,
-        - postact: (M, N) post-activation output
+        - preact: (M, N) pre-activation (if store_preact=True, else None)
+        - postact: (M, N // 2) post-activation output
     """
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
-
+    if C is None:
+        preact = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+    else:
+        preact = gemm_add_ref(A, B, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     # Split preact into gate and up projections
-    gate = preact[..., ::2]  # (M, N)
-    up = preact[..., 1::2]  # (M, N)
+    gate = preact[..., ::2]  # (M, N//2)
+    up = preact[..., 1::2]  # (M, N//2)
     postact = gated_to_pytorch_fn_map[activation](gate, up)
     return preact.to(out_dtype) if store_preact else None, postact.to(postact_dtype)
 
 
 def gemm_dgated_ref(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    PreAct: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
 ) -> Tuple[Tensor, Tensor]:
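gemm_gated_ref splits the pre-activation along its last dimension into interleaved gate and up columns ([..., ::2] and [..., 1::2]), so an (M, N) pre-activation produces an (M, N // 2) post-activation. A standalone illustration of the swiglu case in plain PyTorch (the package's actual mapping lives in gated_to_pytorch_fn_map, so treat the silu-times-up formula as an assumption about its swiglu entry):

    import torch
    import torch.nn.functional as F

    preact = torch.randn(4, 12)                        # (M, N) with N even, gate/up columns interleaved
    gate, up = preact[..., ::2], preact[..., 1::2]     # both (M, N // 2)
    swiglu = F.silu(gate) * up                         # (M, N // 2) post-activation
    print(preact.shape, "->", swiglu.shape)            # torch.Size([4, 12]) -> torch.Size([4, 6])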
@@ -546,7 +869,7 @@ def gemm_dgated_ref(
     """
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
-    dout =
+    dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
     # Split PreAct into gate and up projections
     gate = PreAct[..., ::2]  # (M, N)
     up = PreAct[..., 1::2]  # (M, N)