quack-kernels 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
- quack/__init__.py +1 -8
- quack/activation.py +366 -121
- quack/autotuner.py +64 -5
- quack/broadcast_utils.py +29 -0
- quack/compile_utils.py +19 -0
- quack/copy_utils.py +487 -0
- quack/cross_entropy.py +157 -233
- quack/cute_dsl_utils.py +20 -35
- quack/gemm.py +194 -0
- quack/gemm_act.py +510 -0
- quack/gemm_config.py +72 -46
- quack/gemm_dact.py +215 -0
- quack/gemm_default_epi.py +259 -0
- quack/gemm_interface.py +615 -146
- quack/{dense_gemm_sm100.py → gemm_sm100.py} +1034 -787
- quack/{dense_gemm_sm90.py → gemm_sm90.py} +552 -727
- quack/gemm_symmetric.py +330 -0
- quack/gemm_wrapper_utils.py +182 -23
- quack/layout_utils.py +287 -0
- quack/linear.py +24 -16
- quack/pipeline.py +158 -3
- quack/reduce.py +88 -49
- quack/reduction_base.py +25 -36
- quack/rmsnorm.py +508 -624
- quack/sm100_utils.py +62 -0
- quack/sm90_utils.py +127 -0
- quack/softmax.py +135 -203
- quack/sort/bitonic_sort.py +13 -10
- quack/sort/utils.py +6 -6
- quack/tile_scheduler.py +55 -61
- quack/topk.py +409 -85
- quack/utils.py +37 -172
- quack/varlen_utils.py +370 -6
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/METADATA +4 -2
- quack_kernels-0.2.3.dist-info/RECORD +44 -0
- quack/gemm_act_sm90.py +0 -368
- quack/gemm_dact_sm90.py +0 -150
- quack/layernorm.py +0 -353
- quack/symmetric_dense_gemm_sm90.py +0 -2091
- quack_kernels-0.2.1.dist-info/RECORD +0 -37
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.1.dist-info → quack_kernels-0.2.3.dist-info}/top_level.txt +0 -0
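The headline change in this release is quack/gemm_interface.py (diffed below): the GEMM wrappers gain variable-length batching via cu_seqlens_m / cu_seqlens_k, row gathering via A_idx, an optional bias, and a batch_idx_permute hint for the tile scheduler. The following is a minimal, hypothetical usage sketch based only on the signatures and shape comments shown in that diff; the sizes, dtype, and tolerance are illustrative and not taken from the package.

```python
# Sketch only: assumes the gemm / gemm_ref signatures shown in the
# quack/gemm_interface.py diff below (A: (total_M, K), B: (L, K, N),
# cu_seqlens_m: (L+1,) int32). Sizes and dtype are illustrative.
import torch
from quack.gemm_interface import gemm, gemm_ref

L, K, N = 3, 256, 512
seqlens = torch.tensor([128, 64, 200], device="cuda")
# Cumulative sequence lengths, prefixed with 0: [0, 128, 192, 392]
cu_seqlens_m = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
total_M = int(cu_seqlens_m[-1])

A = torch.randn(total_M, K, device="cuda", dtype=torch.bfloat16)  # packed rows of all L sequences
B = torch.randn(L, K, N, device="cuda", dtype=torch.bfloat16)     # one weight matrix per sequence

out = gemm(A, B, cu_seqlens_m=cu_seqlens_m)      # (total_M, N), varlen_m path
ref = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m)  # loop-based reference from the same module
torch.testing.assert_close(out.float(), ref.float(), rtol=1e-2, atol=1e-2)  # tolerances illustrative
```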
quack/gemm_interface.py
CHANGED
@@ -9,9 +9,11 @@ from torch import Tensor
 from quack.gemm_config import GemmConfig, get_all_configs

 from quack.autotuner import autotune, AutotuneConfig
-from quack.
-from quack.
-from quack.
+from quack.cute_dsl_utils import get_device_capacity
+from quack.gemm import gemm as gemm_sm90_sm100
+from quack.gemm_act import gemm_act as gemm_act_sm90_sm100
+from quack.gemm_dact import gemm_dact as gemm_dact_sm90_sm100
+from quack.gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100


 # Dictionary mapping activation names to PyTorch functions
@@ -34,34 +36,89 @@ gated_to_pytorch_fn_map = {
 }


+default_device_capacity = get_device_capacity(torch.device("cuda"))
+
+
+def default_config(device):
+    if get_device_capacity(device)[0] != 10:
+        return GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True)
+    else:
+        return GemmConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False)
+
+
+def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs):
+    kwargs = named_args | kwargs
+    gather_A = kwargs.get("A_idx", None) is not None
+    varlen_m = kwargs.get("cu_seqlens_m", None) is not None
+    if varlen_m or gather_A:  # Doesn't support swap_ab
+        configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
+    if gather_A:
+        if get_device_capacity(kwargs["A"].device)[0] == 9:
+            # tile_n == 208 causes register spills, as gather_A requires more registers for the producer
+            configs = [
+                conf
+                for conf in configs
+                if conf.kwargs["config"].cluster_n == 1 and conf.kwargs["config"].tile_n != 208
+            ]
+    return configs
+
+
 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
     key=["dynamic_scheduler"],
+    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
 def gemm_tuned(
-
-
-
-
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
+    A: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float | Tensor = 1.0,  # (1,)
     beta: float | Tensor = 1.0,  # (1,)
+    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
+    cu_seqlens_k: Optional[Tensor] = None,  # (L+1), int32
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
+    add_to_output: bool = False,
     dynamic_scheduler: bool = False,
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
-        config =
-
-
+        config = default_config(A.device)
+    varlen_m = cu_seqlens_m is not None
+    varlen_k = cu_seqlens_k is not None
+    varlen = varlen_m or varlen_k
+    gather_A = A_idx is not None
+    if gather_A:
+        assert varlen, "gather_A requires either varlen_m or varlen_k"
+        assert config.cluster_n == 1, "gather_A requires cluster_n=1"
+    if varlen_m:
+        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+    if A.ndim == 2 and not varlen:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (N, K) or (L, N, K) or (N, total_K)
+    if B.ndim == 2 and not varlen_k:
+        B = B.unsqueeze(0)  # (1, N, K)
+    if C is not None and C.ndim == 2 and not varlen_m:
         C = C.unsqueeze(0)  # (1, M, N)
-
-
-
-
-
+    if out.ndim == 2 and not varlen_m:
+        out = out.unsqueeze(0)
+    if bias is not None and bias.ndim == 1:
+        bias = bias.unsqueeze(0)  # (L, N)
+    batch_size = B.shape[0] if not varlen_k else cu_seqlens_k.shape[0] - 1
+    if varlen_m:
+        # If gather_A (A_idx provided), use its length; otherwise use A.shape[0]
+        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+        out_shape = (total_m, B.shape[-2])
+    else:
+        out_shape = (batch_size, A.shape[-2], B.shape[-2])
+    assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}"
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
-
+    gemm_sm90_sm100(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         out if not config.swap_ab else out.mT,
@@ -72,77 +129,129 @@ def gemm_tuned(
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
+        persistent=True,
+        max_swizzle_size=config.max_swizzle_size,
+        rowvec_bias=bias if not config.swap_ab else None,
+        colvec_bias=bias if config.swap_ab else None,
         alpha=alpha,
         beta=beta,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
     )


 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
-    key=["activation"],
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
+    key=["activation", "dynamic_scheduler"],
+    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
 def gemm_act_tuned(
-
-
-
-
-
+    # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    A: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N)
+    # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
+    preact_out: Optional[Tensor],
+    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
+    dynamic_scheduler: bool = False,
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
-        config =
-
-    if
+        config = default_config(A.device)
+    varlen_m = cu_seqlens_m is not None
+    if varlen_m:
+        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+    if A.ndim == 2 and not varlen_m:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (N, K) or (L, N, K)
+    if B.ndim == 2:
+        B = B.unsqueeze(0)  # (1, N, K)
+    if C is not None and C.ndim == 2 and not varlen_m:
         C = C.unsqueeze(0)  # (1, M, N)
-    if preact_out is not None:
-        assert preact_out.shape == (A.shape[1], B.shape[1])
+    if preact_out is not None and preact_out.ndim == 2 and not varlen_m:
         D = preact_out.unsqueeze(0)
     else:
-        D =
-
-
-
+        D = preact_out
+    if postact_out.ndim == 2 and not varlen_m:
+        PostAct = postact_out.unsqueeze(0)
+    else:
+        PostAct = postact_out
+    if bias is not None and bias.ndim == 1:
+        bias = bias.unsqueeze(0)  # (L, N)
+    tile_count_semaphore = (
+        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+    )
+    gemm_act_sm90_sm100(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         (D if not config.swap_ab else D.mT) if D is not None else None,
         (C if not config.swap_ab else C.mT) if C is not None else None,
         PostAct if not config.swap_ab else PostAct.mT,
+        tile_count_semaphore,
         activation,
         config.tile_m,
         config.tile_n,
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
+        persistent=True,
+        max_swizzle_size=config.max_swizzle_size,
+        rowvec_bias=bias if not config.swap_ab else None,
+        colvec_bias=bias if config.swap_ab else None,
+        cu_seqlens_m=cu_seqlens_m,
+        A_idx=A_idx,
     )


 @autotune(
-    configs=[AutotuneConfig(config=c) for c in get_all_configs()],
+    configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])],
     key=["activation", "dynamic_scheduler"],
+    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
 )
 def gemm_dact_tuned(
-
-
-
-
-
+    # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    A: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N)
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    dx_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Tensor,  # (M, N) or (L, N, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     dynamic_scheduler: bool = True,
     config: Optional[GemmConfig] = None,
 ) -> None:
     if config is None:
-        config =
-
-
-
-
-
-
+        config = default_config(A.device)
+    varlen_m = cu_seqlens_m is not None
+    if varlen_m:
+        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+    if A.ndim == 2 and not varlen_m:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (N, K) or (L, N, K)
+    if B.ndim == 2:
+        B = B.unsqueeze(0)  # (1, N, K)
+    if PreAct.ndim == 2 and not varlen_m:
+        PreAct = PreAct.unsqueeze(0)  # (1, M, N)
+    if dx_out.ndim == 2 and not varlen_m:
+        D = dx_out.unsqueeze(0)
+    else:
+        D = dx_out
+    if postact_out.ndim == 2 and not varlen_m:
+        PostAct = postact_out.unsqueeze(0)
+    else:
+        PostAct = postact_out
     tile_count_semaphore = (
         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
     )
-
+    gemm_dact_sm90_sm100(
         A if not config.swap_ab else B,
         B if not config.swap_ab else A,
         D if not config.swap_ab else D.mT,
@@ -155,30 +264,58 @@ def gemm_dact_tuned(
         config.cluster_m,
         config.cluster_n,
         config.pingpong,
+        persistent=True,
+        max_swizzle_size=config.max_swizzle_size,
+        cu_seqlens_m=cu_seqlens_m,
+        A_idx=A_idx,
     )


 def gemm(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float | Tensor = 1.0,
     out_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> Tensor:
     """GEMM with optional output tensor and tuning control."""
     if out is None:
         out_dtype = A.dtype if out_dtype is None else out_dtype
-
+        varlen_m = cu_seqlens_m is not None
+        varlen_k = cu_seqlens_k is not None
+        if varlen_m:
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out_shape = (total_m, B.shape[-1])
+        elif varlen_k:
+            L = cu_seqlens_k.shape[0] - 1
+            # For varlen_k, the first dimension is always A.shape[0] (M dimension)
+            out_shape = (L, A.shape[0], B.shape[-1])
+        else:
+            out_shape = (
+                (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1])
+            )
+        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
     alpha_tensor = alpha if not isinstance(alpha, float) else None
     alpha = alpha if isinstance(alpha, float) else 1.0
     gemm_out(
         A,
         B,
         out,
+        bias=bias,
         alpha=alpha,
         alpha_tensor=alpha_tensor,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
         dynamic_scheduler=dynamic_scheduler,
         tuned=tuned,
     )
@@ -191,53 +328,134 @@ def gemm(
     device_types="cuda",
     # We have to split out alpha and alpha_tensor since torch.library requires
     # each argument to have a fixed type
-    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    # schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? bias, float alpha=1.0, Tensor? alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_out(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     """GEMM with pre-allocated output tensor."""
     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
     alpha = alpha_tensor if alpha_tensor is not None else alpha
-    fn(
+    fn(
+        A,
+        B,
+        out,
+        C=None,
+        bias=bias,
+        alpha=alpha,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        dynamic_scheduler=dynamic_scheduler,
+    )


 def gemm_ref(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     alpha: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
     out_dtype: Optional[torch.dtype] = None,
 ) -> Tensor:
     """Reference implementation for GEMM with pre-allocated output."""
     # The out_dtype argument requires torch >= 2.8
-
-    if
-
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    if cu_seqlens_m is None and cu_seqlens_k is None:
+        fn = torch.bmm if A.ndim == 3 else torch.mm
+        out = fn(A, B, out_dtype=out_dtype, out=out)
+        if not isinstance(alpha, float) or alpha != 1.0:
+            out *= alpha
+        if bias is not None:
+            bias = bias if A.ndim == 2 else bias.unsqueeze(1)
+            out += bias
+    elif cu_seqlens_m is not None:
+        # Handle varlen_m case
+        if out is None:
+            # When gather_A (A_idx provided), output size is determined by A_idx length
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
+        for i in range(cu_seqlens_m.shape[0] - 1):
+            A_slice = (
+                A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
+                if A_idx is not None
+                else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            )
+            torch.mm(A_slice, B[i], out=out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]])
+            if not isinstance(alpha, float) or alpha != 1.0:
+                out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] *= alpha
+            if bias is not None:
+                out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] += bias[i]
+    else:  # cu_seqlens_k is not None
+        L = cu_seqlens_k.shape[0] - 1
+        if out is None:
+            out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+        for i in range(L):
+            A_slice = (
+                A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
+                if A_idx is not None
+                else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
+            )
+            torch.mm(A_slice, B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :], out=out[i])
+        if not isinstance(alpha, float) or alpha != 1.0:
+            out *= alpha
+        if bias is not None:
+            out += bias
     return out


 def gemm_add(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    C: Tensor,
-    out: Optional[Tensor] = None,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
     out_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> Tensor:
     """GEMM with addition and optional output tensor."""
     if out is None:
         out_dtype = A.dtype if out_dtype is None else out_dtype
-
+        varlen_m = cu_seqlens_m is not None
+        varlen_k = cu_seqlens_k is not None
+        if varlen_m:
+            # If A_idx is provided (gather_A), use its length; otherwise use A.shape[0]
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out_shape = (total_m, B.shape[-1])
+        elif varlen_k:
+            L = cu_seqlens_k.shape[0] - 1
+            # For varlen_k, the first dimension is always A.shape[0] (M dimension)
+            out_shape = (L, A.shape[0], B.shape[-1])
+        else:
+            out_shape = (
+                (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1])
+            )
+        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
+    add_to_output = C is out and isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
     alpha_tensor = alpha if not isinstance(alpha, float) else None
     alpha = alpha if isinstance(alpha, float) else 1.0
     beta_tensor = beta if not isinstance(beta, float) else None
@@ -245,12 +463,17 @@ def gemm_add(
     gemm_add_out(
         A,
         B,
-        C,
+        C if not add_to_output else None,
         out,
         alpha,
         beta,
         alpha_tensor,
         beta_tensor,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
         dynamic_scheduler=dynamic_scheduler,
         tuned=tuned,
     )
@@ -263,17 +486,23 @@ def gemm_add(
     device_types="cuda",
     # We have to split out alpha and alpha_tensor since torch.library requires
     # each argument to have a fixed type
-    # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_add_out(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    C: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float = 1.0,
     beta: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
     beta_tensor: Optional[Tensor] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
+    add_to_output: bool = False,
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
@@ -281,47 +510,112 @@ def gemm_add_out(
     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
     alpha = alpha_tensor if alpha_tensor is not None else alpha
     beta = beta_tensor if beta_tensor is not None else beta
-    fn(
+    fn(
+        A,
+        B,
+        out,
+        C,
+        alpha=alpha,
+        beta=beta,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
+        dynamic_scheduler=dynamic_scheduler,
+    )


 def gemm_add_ref(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    C: Tensor,
-
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
+    out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
     out_dtype: Optional[torch.dtype] = None,
 ) -> Tensor:
     """Reference implementation for GEMM with addition and pre-allocated output."""
-    if
-
-
-
-
-
-
-
-    out
-
+    if cu_seqlens_m is None and cu_seqlens_k is None:
+        if isinstance(alpha, float) and isinstance(beta, float):
+            out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out)
+        else:
+            out_dtype = (
+                out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype)
+            )
+            result = (alpha * (A @ B) + beta * C).to(out_dtype)
+            if out is not None:
+                out.copy_(result)
+        if bias is not None:
+            bias = bias if A.ndim == 2 else bias.unsqueeze(1)
+            out += bias
+    elif cu_seqlens_m is not None:
+        # Handle varlen_m case
+        if out is None:
+            # When gather_A (A_idx provided), output size is determined by A_idx length
+            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+            out_dtype = out_dtype if out_dtype is not None else A.dtype
+            out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device)
+        for i in range(cu_seqlens_m.shape[0] - 1):
+            A_slice = (
+                A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]]
+                if A_idx is not None
+                else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            )
+            C_slice = C[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            out_slice = out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]
+            result = alpha * torch.mm(A_slice, B[i]) + beta * C_slice
+            if bias is not None:
+                result += bias[i]
+            out_slice.copy_(result)
+    else:  # cu_seqlens_k is not None
+        # Handle varlen_k case
+        L = cu_seqlens_k.shape[0] - 1
+        out_dtype = out_dtype if out_dtype is not None else A.dtype
+        if out is None:
+            out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device)
+        for i in range(L):
+            A_slice = (
+                A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]]
+                if A_idx is not None
+                else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]]
+            )
+            B_slice = B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :]
+            result = alpha * torch.mm(A_slice, B_slice) + beta * C[i]
+            out[i].copy_(result)
+        if bias is not None:
+            out += bias
+    return out


 def gemm_add_inplace(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
     alpha: float | Tensor = 1.0,
     beta: float | Tensor = 1.0,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     """In-place GEMM with addition: out = alpha * A @ B + beta * out.
     Args:
-        A: (M, K) input tensor
-        B: (K, N) input tensor
-        out: (M, N) tensor to accumulate into (modified in-place)
+        A: (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k - input tensor
+        B: (K, N) or (L, K, N) or (total_K, N) if varlen_k - input tensor
+        out: (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k - tensor to accumulate into (modified in-place)
         alpha: Scalar multiplier for A @ B
        beta: Scalar multiplier for out
+        cu_seqlens_m: Optional cumulative sequence lengths for variable M
+        cu_seqlens_k: Optional cumulative sequence lengths for variable K
         dynamic_scheduler: Whether to use dynamic scheduler
         tuned: Whether to use autotuned configuration
     """
@@ -329,7 +623,21 @@ def gemm_add_inplace(
     alpha = alpha if isinstance(alpha, float) else 1.0
     beta_tensor = beta if not isinstance(beta, float) else None
     beta = beta if isinstance(beta, float) else 1.0
-    gemm_add_inplace_op(
+    gemm_add_inplace_op(
+        A,
+        B,
+        out,
+        alpha,
+        beta,
+        alpha_tensor,
+        beta_tensor,
+        cu_seqlens_m,
+        cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        dynamic_scheduler=dynamic_scheduler,
+        tuned=tuned,
+    )


 @torch.library.custom_op(
@@ -338,46 +646,90 @@ def gemm_add_inplace(
     device_types="cuda",
     # We have to split out alpha and alpha_tensor since torch.library requires
     # each argument to have a fixed type
-    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+    # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_add_inplace_op(
+    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k
     A: Tensor,
-    B: Tensor,
-    out: Tensor,
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k
     alpha: float = 1.0,
     beta: float = 1.0,
     alpha_tensor: Optional[Tensor] = None,
     beta_tensor: Optional[Tensor] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    cu_seqlens_k: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) or (total_K,) indices for gather_A when varlen
+    batch_idx_permute: Optional[Tensor] = None,  # (L,) permutation of batch indices for scheduler
     dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None)
     alpha = alpha_tensor if alpha_tensor is not None else alpha
     beta = beta_tensor if beta_tensor is not None else beta
-
-
+    add_to_output = isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None
+    # Use out as both input bias and output
+    fn(
+        A,
+        B,
+        out,
+        out if not add_to_output else None,
+        alpha=alpha,
+        beta=beta,
+        cu_seqlens_m=cu_seqlens_m,
+        cu_seqlens_k=cu_seqlens_k,
+        A_idx=A_idx,
+        batch_idx_permute=batch_idx_permute,
+        add_to_output=add_to_output,
+        dynamic_scheduler=dynamic_scheduler,
+    )


 def gemm_act(
-    A: Tensor,
-    B: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
-    preact_out: Optional[Tensor] = None,
-    postact_out: Optional[Tensor] = None,
+    preact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     store_preact: bool = True,
+    dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> Tuple[Optional[Tensor], Tensor]:
     """GEMM with activation and optional output tensors."""
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+    varlen_m = cu_seqlens_m is not None
+    # Determine output shape based on gather_A
+    if varlen_m:
+        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+        out_shape = (total_m, B.shape[-1])
+    elif A.ndim == 2:
+        out_shape = (A.shape[0], B.shape[-1])
+    else:
+        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
     if preact_out is None and store_preact:
-        preact_out = torch.empty(
+        preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
     if postact_out is None:
-        postact_out = torch.empty(
-    gemm_act_out(
+        postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
+    gemm_act_out(
+        A,
+        B,
+        preact_out,
+        postact_out,
+        C,
+        bias,
+        activation,
+        cu_seqlens_m,
+        A_idx,
+        dynamic_scheduler,
+        tuned,
+    )
     return preact_out, postact_out

@@ -385,58 +737,81 @@ def gemm_act(
     "quack::gemm_act_out",
     mutates_args=("preact_out", "postact_out"),
     device_types="cuda",
-    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, str? activation=None, bool tuned=True) -> ()",
+    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
 )
 def gemm_act_out(
-    A: Tensor,
-    B: Tensor,
-    preact_out: Optional[Tensor],
-    postact_out: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    preact_out: Optional[Tensor],  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
+    dynamic_scheduler: bool = False,
     tuned: bool = True,
 ) -> None:
     """GEMM with activation and pre-allocated output tensors."""
     fn = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None)
-    fn(A, B, preact_out, postact_out, C, activation)
+    fn(A, B, preact_out, postact_out, C, bias, activation, cu_seqlens_m, A_idx, dynamic_scheduler)


 def gemm_act_ref(
-    A: Tensor,
-    B: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
     store_preact: bool = True,
 ) -> Tuple[Optional[Tensor], Tensor]:
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
-
+    if C is None:
+        out = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+    else:
+        out = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype)
     return out.to(out_dtype) if store_preact else None, postact


 def gemm_dact(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
-    dx_out: Optional[Tensor] = None,
-    postact_out: Optional[Tensor] = None,
+    dx_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     dynamic_scheduler: bool = True,
     tuned: bool = True,
 ) -> Tuple[Tensor, Tensor]:
     """GEMM with activation gradient and optional output tensors."""
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+    varlen_m = cu_seqlens_m is not None
+    # Determine output shape based on gather_A
+    if varlen_m:
+        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+        out_shape = (total_m, B.shape[-1])
+    elif A.ndim == 2:
+        out_shape = (A.shape[0], B.shape[-1])
+    else:
+        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
     if dx_out is None:
-        dx_out = torch.empty(
+        dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
     if postact_out is None:
-        postact_out = torch.empty(
-    gemm_dact_out(
+        postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device)
+    gemm_dact_out(
+        A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned
+    )
     return dx_out, postact_out

@@ -444,35 +819,39 @@ def gemm_dact(
     "quack::gemm_dact_out",
     mutates_args=("dx_out", "postact_out"),
     device_types="cuda",
-    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
+    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> ()",
 )
 def gemm_dact_out(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
-    dx_out: Tensor,
-    postact_out: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+    B: Tensor,  # (K, N) or (L, K, N)
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    dx_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     dynamic_scheduler: bool = True,
     tuned: bool = True,
 ) -> None:
     """GEMM with activation gradient and pre-allocated output tensors."""
     fn = gemm_dact_tuned if tuned else partial(gemm_dact_tuned.fn, config=None)
-    fn(A, B, PreAct, dx_out, postact_out, activation, dynamic_scheduler)
+    fn(A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler)


 def gemm_dact_ref(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    PreAct: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
     activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None,
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
 ) -> Tuple[Tensor, Tensor]:
     """Reference implementation for GEMM with activation gradient."""
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
-    dout =
+    dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
     postact = act_to_pytorch_fn_map[activation](PreAct)
     # Compute gradient using autograd
     if activation is None:
@@ -487,10 +866,13 @@ def gemm_dact_ref(


 def gemm_gated_ref(
-    A: Tensor,
-    B: Tensor,
-    C: Optional[Tensor] = None,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+    bias: Optional[Tensor] = None,  # (N,) or (L, N)
     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu",
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
     store_preact: bool = True,
@@ -499,8 +881,8 @@ def gemm_gated_ref(

     Args:
         A: (M, K) - input tensor
-        B: (K,
-        C: (M,
+        B: (K, N) - weight tensor with gate and up projections
+        C: (M, N) - optional bias tensor
         activation: Type of gated activation
         out_dtype: Output dtype for preact
         postact_dtype: Output dtype for postact
@@ -508,24 +890,29 @@ def gemm_gated_ref(

     Returns:
         (preact, postact) where:
-            - preact: (M,
-            - postact: (M, N) post-activation output
+            - preact: (M, N) pre-activation (if store_preact=True, else None)
+            - postact: (M, N // 2) post-activation output
     """
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
-
+    if C is None:
+        preact = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
+    else:
+        preact = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx)
     # Split preact into gate and up projections
-    gate = preact[..., ::2]  # (M, N)
-    up = preact[..., 1::2]  # (M, N)
+    gate = preact[..., ::2]  # (M, N//2)
+    up = preact[..., 1::2]  # (M, N//2)
     postact = gated_to_pytorch_fn_map[activation](gate, up)
     return preact.to(out_dtype) if store_preact else None, postact.to(postact_dtype)


 def gemm_dgated_ref(
-    A: Tensor,
-    B: Tensor,
-    PreAct: Tensor,
+    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A
+    B: Tensor,  # (K, N) or (L, K, N) or (total_K, N) if varlen_k
+    PreAct: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
     activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"],
+    cu_seqlens_m: Optional[Tensor] = None,
+    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
     out_dtype: Optional[torch.dtype] = None,
     postact_dtype: Optional[torch.dtype] = None,
 ) -> Tuple[Tensor, Tensor]:
@@ -546,18 +933,100 @@ def gemm_dgated_ref(
     """
     out_dtype = A.dtype if out_dtype is None else out_dtype
     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
-    dout =
+    dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype)
     # Split PreAct into gate and up projections
     gate = PreAct[..., ::2]  # (M, N)
     up = PreAct[..., 1::2]  # (M, N)
-    postact = gated_to_pytorch_fn_map[activation](gate, up)
     # Use autograd to compute gradients w.r.t. gate and up
+    gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad
+    gate.requires_grad_(True)
+    up.requires_grad_(True)
+    postact = gated_to_pytorch_fn_map[activation](gate, up)
     dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False)
+    gate.requires_grad_(gate_requires_grad)
+    up.requires_grad_(up_requires_grad)
     # Interleave gradients back
     dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape)
     return dx.to(out_dtype), postact.to(postact_dtype)


+@torch.library.custom_op(
+    "quack::gemm_symmetric_out",
+    mutates_args=("out",),
+    device_types="cuda",
+    schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? C=None, bool dynamic_scheduler=False, float alpha=1.0, float beta=1.0) -> ()",
+)
+def gemm_symmetric_out(
+    A: Tensor,  # (M, K) or (L, M, K)
+    B: Tensor,  # (K, M) or (L, K, M)
+    out: Tensor,  # (M, M) or (L, M, M)
+    C: Optional[Tensor] = None,  # (M, M) or (L, M, M)
+    dynamic_scheduler: bool = False,
+    alpha: float = 1.0,
+    beta: float = 1.0,
+) -> None:
+    """GEMM with guaranteed symmetric output."""
+    if A.ndim == 2:
+        A = A.unsqueeze(0)  # (1, M, K)
+    B = B.mT  # (M, K) or (L, M, K)
+    if B.ndim == 2:
+        B = B.unsqueeze(0)  # (1, M, K)
+    if C is not None and C.ndim == 2:
+        C = C.unsqueeze(0)  # (1, M, M)
+    if out.ndim == 2:
+        out = out.unsqueeze(0)
+    else:
+        out = out
+    tile_count_semaphore = (
+        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+    )
+    gemm_symmetric_sm90_sm100(
+        A,
+        B,
+        out if out is not None else None,
+        C if C is not None else None,
+        tile_count_semaphore,
+        tile_M=128,
+        tile_N=256,
+        cluster_M=2,
+        cluster_N=1,
+        pingpong=False,
+        persistent=True,
+        max_swizzle_size=8,
+        alpha=alpha,
+        beta=beta,
+    )
+
+
+def gemm_symmetric(
+    A: Tensor,  # (M, K) or (L, M, K)
+    B: Tensor,  # (K, M) or (L, K, M)
+    C: Optional[Tensor] = None,  # (M, M) or (L, M, M)
+    out: Optional[Tensor] = None,  # (M, M) or (L, M, M)
+    out_dtype: Optional[torch.dtype] = None,
+    dynamic_scheduler: bool = False,
+    alpha: float | Tensor = 1.0,
+    beta: float | Tensor = 1.0,
+) -> Tuple[Optional[Tensor], Tensor]:
+    """GEMM with symmetric output."""
+    out_dtype = A.dtype if out_dtype is None else out_dtype
+    # Determine output shape based on gather_A
+    if A.ndim == 2:
+        out_shape = (A.shape[0], B.shape[-1])
+    else:
+        out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
+    if out is None:
+        out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
+
+    alpha_val = alpha if isinstance(alpha, float) else 1.0
+    beta_val = beta if isinstance(beta, float) else 1.0
+
+    gemm_symmetric_out(
+        A, B, out, C, dynamic_scheduler=dynamic_scheduler, alpha=alpha_val, beta=beta_val
+    )
+    return out
+
+
 # TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
 # try:
 #     from torch._inductor.fx_passes.reinplace import InplaceableOp