quack-kernels 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quack/__init__.py +1 -1
- quack/activation.py +72 -64
- quack/broadcast_utils.py +1 -1
- quack/copy_utils.py +14 -18
- quack/fast_math.py +29 -76
- quack/gemm_act.py +296 -8
- quack/gemm_dact.py +520 -4
- quack/gemm_default_epi.py +4 -4
- quack/gemm_interface.py +363 -0
- quack/gemm_sm100.py +62 -88
- quack/gemm_sm90.py +68 -114
- quack/gemm_symmetric.py +2 -6
- quack/layout_utils.py +2 -4
- quack/linear.py +37 -0
- quack/pipeline.py +59 -89
- quack/reduce.py +2 -2
- quack/rmsnorm.py +1 -3
- quack/sm90_utils.py +5 -3
- quack/sort/bitonic_sort.py +3 -3
- quack/tile_scheduler.py +310 -256
- quack/topk.py +4 -4
- quack/utils.py +76 -40
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/METADATA +2 -2
- quack_kernels-0.2.6.dist-info/RECORD +45 -0
- quack_kernels-0.2.5.dist-info/RECORD +0 -45
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/WHEEL +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {quack_kernels-0.2.5.dist-info → quack_kernels-0.2.6.dist-info}/top_level.txt +0 -0
quack/gemm_interface.py
CHANGED
|
@@ -12,7 +12,9 @@ from quack.autotuner import autotune, AutotuneConfig
|
|
|
12
12
|
from quack.cute_dsl_utils import get_device_capacity
|
|
13
13
|
from quack.gemm import gemm as gemm_sm90_sm100
|
|
14
14
|
from quack.gemm_act import gemm_act as gemm_act_sm90_sm100
|
|
15
|
+
from quack.gemm_act import gemm_gated as gemm_gated_sm90_sm100
|
|
15
16
|
from quack.gemm_dact import gemm_dact as gemm_dact_sm90_sm100
|
|
17
|
+
from quack.gemm_dact import gemm_dgated as gemm_dgated_sm90_sm100
|
|
16
18
|
from quack.gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100
|
|
17
19
|
|
|
18
20
|
|
|
@@ -1027,6 +1029,367 @@ def gemm_symmetric(
|
|
|
1027
1029
|
return out
|
|
1028
1030
|
|
|
1029
1031
|
|
|
1032
|
+
@autotune(
    configs=[
        AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0], "gated")
    ],
    key=["activation", "dynamic_scheduler"],
    prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
)
def gemm_gated_tuned(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N)
    # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
    preact_out: Optional[Tensor],
    postact_out: Tensor,  # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = False,
    config: Optional[GemmConfig] = None,
) -> None:
    """Autotuned launcher for the gated-activation GEMM kernel.

    Brings every operand into the batched layout passed to
    ``gemm_gated_sm90_sm100`` and launches it. The kernel writes into
    ``preact_out`` (if not None) and ``postact_out``.

    Args:
        config: tile/cluster/swizzle configuration; ``default_config(A.device)``
            is used when None. With ``config.swap_ab`` the A/B operands are
            exchanged, the output and C tensors are passed as ``.mT`` views, and
            ``bias`` is routed to ``colvec_bias`` instead of ``rowvec_bias``.
        dynamic_scheduler: when True, a zeroed int32 tile-count semaphore is
            allocated and handed to the kernel.

    Raises:
        AssertionError: if ``cu_seqlens_m`` is given together with ``swap_ab``.
    """
    if config is None:
        config = default_config(A.device)
    varlen_m = cu_seqlens_m is not None
    swap = config.swap_ab
    if varlen_m:
        assert not swap, "Variable-length sequences not supported with swap_ab"

    def _with_batch_dim(t: Tensor) -> Tensor:
        # Non-varlen 2D tensors gain a leading batch dim of size 1: (M, X) -> (1, M, X).
        return t.unsqueeze(0) if t.ndim == 2 and not varlen_m else t

    A = _with_batch_dim(A)
    # B is handed over as (N, K) / (L, N, K); it always gets a batch dim, varlen or not.
    B = B.mT
    if B.ndim == 2:
        B = B.unsqueeze(0)
    C = None if C is None else _with_batch_dim(C)
    D = None if preact_out is None else _with_batch_dim(preact_out)
    PostAct = _with_batch_dim(postact_out)
    if bias is not None and bias.ndim == 1:
        bias = bias.unsqueeze(0)  # (1, N)

    tile_count_semaphore = None
    if dynamic_scheduler:
        tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device=A.device)

    def _oriented(t: Optional[Tensor]) -> Optional[Tensor]:
        # Under swap_ab the kernel sees the transposed problem, so outputs/C are passed as .mT views.
        return t.mT if swap and t is not None else t

    mat_a, mat_b = (B, A) if swap else (A, B)
    gemm_gated_sm90_sm100(
        mat_a,
        mat_b,
        _oriented(D),
        _oriented(C),
        _oriented(PostAct),
        tile_count_semaphore,
        activation,
        config.tile_m,
        config.tile_n,
        config.cluster_m,
        config.cluster_n,
        config.pingpong,
        persistent=True,
        max_swizzle_size=config.max_swizzle_size,
        rowvec_bias=None if swap else bias,
        colvec_bias=bias if swap else None,
        cu_seqlens_m=cu_seqlens_m,
        A_idx=A_idx,
    )
|
|
1099
|
+
|
|
1100
|
+
|
|
1101
|
+
def prune_invalid_gemm_dgated_configs(configs, named_args: dict, **kwargs):
    """Early-prune hook used when autotuning ``gemm_dgated_tuned``.

    Configs with ``swap_ab`` are dropped whenever a ``colvec_scale`` tensor is
    supplied or ``colvec_reduce`` is truthy. The survivors are then filtered
    by ``prune_invalid_gemm_configs``, which receives the merged arguments.
    """
    merged = named_args | kwargs
    forbid_swap = merged.get("colvec_scale") is not None or merged.get("colvec_reduce", False)
    if forbid_swap:
        configs = [cand for cand in configs if not cand.kwargs["config"].swap_ab]
    return prune_invalid_gemm_configs(configs, named_args, **merged)
|
|
1107
|
+
|
|
1108
|
+
|
|
1109
|
+
@autotune(
    configs=[
        AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0], "dgated")
    ],
    key=["activation", "colvec_reduce", "dynamic_scheduler"],
    prune_configs_by={"early_config_prune": prune_invalid_gemm_dgated_configs},
)
def gemm_dgated_tuned(
    # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    A: Tensor,
    B: Tensor,  # (K, N) or (L, K, N)
    PreAct: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
    dx_out: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    colvec_scale: Optional[Tensor] = None,  # (M,) or (L, M) or (total_M,) if varlen_m
    activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
    # whether to do colvec reduction, returning (M,) or (L, M) or (total_M) if varlen_m
    colvec_reduce: bool = False,
    cu_seqlens_m: Optional[Tensor] = None,  # (L+1), int32
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = True,
    config: Optional[GemmConfig] = None,
) -> Optional[Tensor]:
    """Autotuned launcher for the gated-activation gradient GEMM kernel.

    Brings the operands into the batched layout passed to
    ``gemm_dgated_sm90_sm100`` and launches it. The kernel writes into
    ``dx_out`` and ``postact_out``.

    Args:
        config: tile/cluster/swizzle configuration; ``default_config(A.device)``
            is used when None. With ``config.swap_ab`` the A/B operands are
            exchanged and D/PreAct/PostAct are passed as ``.mT`` views.
        colvec_reduce: when True, a float32 partial buffer of shape
            ``(L, M, ceil(N / tile_n))`` (or ``(total_M, ceil(N / tile_n))`` for
            varlen_m) is handed to the kernel and summed over its last axis.
        dynamic_scheduler: when True, a zeroed int32 tile-count semaphore is
            allocated and handed to the kernel.

    Returns:
        The float32 reduction when ``colvec_reduce`` is True: ``(M,)`` for 2D
        non-varlen A, ``(L, M)`` for batched A, ``(total_M,)`` for varlen_m.
        Otherwise None.

    Raises:
        AssertionError: if ``config.swap_ab`` is combined with varlen_m,
            ``colvec_scale`` or ``colvec_reduce``.
    """
    if config is None:
        config = default_config(A.device)
    varlen_m = cu_seqlens_m is not None
    if varlen_m:
        assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
    # Remember whether A was a plain 2D matrix so the reduction can be squeezed back.
    og_ndim_2 = A.ndim == 2 and not varlen_m
    if og_ndim_2:
        A = A.unsqueeze(0)  # (1, M, K)
    B = B.mT  # (N, K) or (L, N, K)
    if B.ndim == 2:
        B = B.unsqueeze(0)  # (1, N, K)
    if PreAct.ndim == 2 and not varlen_m:
        PreAct = PreAct.unsqueeze(0)  # (1, M, 2*N)
    if dx_out.ndim == 2 and not varlen_m:
        D = dx_out.unsqueeze(0)
    else:
        D = dx_out
    if postact_out.ndim == 2 and not varlen_m:
        PostAct = postact_out.unsqueeze(0)
    else:
        PostAct = postact_out
    if colvec_scale is not None and colvec_scale.ndim == 1 and not varlen_m:
        colvec_scale = colvec_scale.unsqueeze(0)  # (1, M)
    if colvec_scale is not None:
        assert not config.swap_ab, "colvec_scale not supported with swap_ab"
    if colvec_reduce:
        # The partial buffer below is laid out in the unswapped (M, N-tile) frame,
        # so swap_ab cannot be honored (the autotune pruner already excludes it;
        # this guards direct calls with an explicit config).
        assert not config.swap_ab, "colvec_reduce not supported with swap_ab"
        tile_n = config.tile_n
        # Number of N tiles; each tile writes one partial value per row.
        shape_n = (B.shape[-2] + tile_n - 1) // tile_n
        if varlen_m:
            total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
            colvec_shape = (total_m, shape_n)
        else:
            colvec_shape = (A.shape[0], A.shape[-2], shape_n)
        colvec_reduce_partial = torch.empty(colvec_shape, dtype=torch.float32, device=A.device)
    else:
        colvec_reduce_partial = None
    tile_count_semaphore = (
        torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
    )
    gemm_dgated_sm90_sm100(
        A if not config.swap_ab else B,
        B if not config.swap_ab else A,
        D if not config.swap_ab else D.mT,
        PreAct if not config.swap_ab else PreAct.mT,
        PostAct if not config.swap_ab else PostAct.mT,
        tile_count_semaphore,
        activation,
        config.tile_m,
        config.tile_n,
        config.cluster_m,
        config.cluster_n,
        config.pingpong,
        persistent=True,
        max_swizzle_size=config.max_swizzle_size,
        colvec_scale=colvec_scale,
        colvec_reduce=colvec_reduce_partial,
        cu_seqlens_m=cu_seqlens_m,
        A_idx=A_idx,
    )
    if not colvec_reduce:
        return None
    # Collapse the per-N-tile partials into the final per-row reduction.
    colvec_reduce_final = colvec_reduce_partial.sum(dim=-1)
    if og_ndim_2:
        colvec_reduce_final = colvec_reduce_final.squeeze(0)
    return colvec_reduce_final
|
|
1198
|
+
|
|
1199
|
+
|
|
1200
|
+
def gemm_gated(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
    preact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Optional[
        Tensor
    ] = None,  # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    store_preact: bool = True,
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> Tuple[Optional[Tensor], Tensor]:
    """GEMM with gated activation, allocating whichever outputs are not supplied.

    Missing outputs are allocated on ``A.device``:
    ``preact_out`` of shape ``(..., N)`` in ``out_dtype`` (only when
    ``store_preact``), and ``postact_out`` of shape ``(..., N // 2)`` in
    ``postact_dtype``. Both dtypes default to ``A.dtype``. The leading dims
    are ``(total_M,)`` for varlen_m (``A_idx`` rows when gathering), ``(M,)``
    for 2D A, and ``(L, M)`` otherwise. The work is done by the
    ``quack::gemm_gated_out`` custom op.

    Returns:
        ``(preact_out, postact_out)``; ``preact_out`` is None when it was not
        supplied and ``store_preact`` is False.
    """
    preact_dtype = out_dtype if out_dtype is not None else A.dtype
    post_dtype = postact_dtype if postact_dtype is not None else A.dtype
    n_cols = B.shape[-1]
    if cu_seqlens_m is not None:
        # varlen_m: rows are packed; with gather_A the row count comes from A_idx.
        leading = (A.shape[0] if A_idx is None else A_idx.shape[0],)
    elif A.ndim == 2:
        leading = (A.shape[0],)
    else:
        leading = (A.shape[0], A.shape[-2])
    if store_preact and preact_out is None:
        preact_out = torch.empty((*leading, n_cols), dtype=preact_dtype, device=A.device)
    if postact_out is None:
        postact_out = torch.empty((*leading, n_cols // 2), dtype=post_dtype, device=A.device)
    gemm_gated_out(
        A,
        B,
        preact_out,
        postact_out,
        C,
        bias,
        activation,
        cu_seqlens_m,
        A_idx,
        dynamic_scheduler,
        tuned,
    )
    return preact_out, postact_out
|
|
1249
|
+
|
|
1250
|
+
|
|
1251
|
+
@torch.library.custom_op(
    "quack::gemm_gated_out",
    mutates_args=("preact_out", "postact_out"),
    device_types="cuda",
    schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str activation='swiglu', Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
)
def gemm_gated_out(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    preact_out: Optional[Tensor],  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    postact_out: Tensor,  # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
    C: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    bias: Optional[Tensor] = None,  # (N,) or (L, N)
    activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = False,
    tuned: bool = True,
) -> None:
    """Custom op: gated-activation GEMM writing into pre-allocated outputs.

    Registered as ``quack::gemm_gated_out``, declared to mutate
    ``preact_out`` and ``postact_out``. With ``tuned`` the autotuned
    ``gemm_gated_tuned`` is used. Otherwise its underlying ``.fn`` is called
    with ``config=None``, which falls back to the default config.
    """
    if tuned:
        launcher = gemm_gated_tuned
    else:
        launcher = partial(gemm_gated_tuned.fn, config=None)
    launcher(
        A,
        B,
        preact_out,
        postact_out,
        C,
        bias,
        activation,
        cu_seqlens_m,
        A_idx,
        dynamic_scheduler,
    )
|
|
1273
|
+
|
|
1274
|
+
|
|
1275
|
+
def gemm_dgated(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    PreAct: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
    colvec_scale: Optional[Tensor] = None,  # (M,) or (L, M) or (total_M,) if varlen_m
    activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
    dx_out: Optional[Tensor] = None,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
    postact_out: Optional[Tensor] = None,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    out_dtype: Optional[torch.dtype] = None,
    postact_dtype: Optional[torch.dtype] = None,
    colvec_reduce: bool = False,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = True,
    tuned: bool = True,
) -> Tuple[Tensor, ...]:
    """GEMM with gated activation gradient, allocating outputs not supplied.

    Missing outputs are allocated on ``A.device``:
    ``dx_out`` of shape ``(..., 2*N)`` in ``out_dtype`` (default ``A.dtype``)
    and ``postact_out`` of shape ``(..., N)`` in ``postact_dtype`` (default
    ``PreAct.dtype``). The leading dims are ``(total_M,)`` for varlen_m
    (``A_idx`` rows when gathering), ``(M,)`` for 2D A, and ``(L, M)``
    otherwise. The work is done by the ``quack::gemm_dgated_out`` custom op.

    Returns:
        ``(dx_out, postact_out)`` when ``colvec_reduce`` is False, otherwise
        ``(dx_out, postact_out, colvec_reduce_final)``. The third element is
        the float32 reduction returned by the op: ``(M,)``, ``(L, M)`` or
        ``(total_M,)``.
    """
    out_dtype = A.dtype if out_dtype is None else out_dtype
    postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
    varlen_m = cu_seqlens_m is not None
    # Determine output shape based on varlen_m / gather_A
    if varlen_m:
        total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
        out_shape = (total_m, B.shape[-1] * 2)
    elif A.ndim == 2:
        out_shape = (A.shape[0], B.shape[-1] * 2)
    else:
        out_shape = (A.shape[0], A.shape[-2], B.shape[-1] * 2)
    postact_shape = (*out_shape[:-1], out_shape[-1] // 2)
    if dx_out is None:
        dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
    if postact_out is None:
        postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device)
    colvec_reduce_final = gemm_dgated_out(
        A,
        B,
        PreAct,
        dx_out,
        postact_out,
        colvec_scale,
        activation,
        colvec_reduce,
        cu_seqlens_m,
        A_idx,
        dynamic_scheduler,
        tuned,
    )
    if not colvec_reduce:
        return dx_out, postact_out
    return dx_out, postact_out, colvec_reduce_final
|
|
1326
|
+
|
|
1327
|
+
|
|
1328
|
+
@torch.library.custom_op(
    "quack::gemm_dgated_out",
    mutates_args=("dx_out", "postact_out"),
    device_types="cuda",
    schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, Tensor? colvec_scale=None, str activation='swiglu', bool colvec_reduce=False, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> Tensor?",
)
def gemm_dgated_out(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    PreAct: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
    dx_out: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    colvec_scale: Optional[Tensor] = None,  # (M,) or (L, M) or (total_M,) if varlen_m
    activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
    colvec_reduce: bool = False,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = True,
    tuned: bool = True,
) -> Optional[Tensor]:
    """Custom op: gated-activation gradient GEMM into pre-allocated outputs.

    Registered as ``quack::gemm_dgated_out``, declared to mutate ``dx_out``
    and ``postact_out``. With ``tuned`` the autotuned ``gemm_dgated_tuned``
    is used. Otherwise its underlying ``.fn`` is called with ``config=None``.
    The launcher's return value (the colvec reduction or None) is passed
    through unchanged.
    """
    if tuned:
        launcher = gemm_dgated_tuned
    else:
        launcher = partial(gemm_dgated_tuned.fn, config=None)
    positional = (
        A,
        B,
        PreAct,
        dx_out,
        postact_out,
        colvec_scale,
        activation,
        colvec_reduce,
        cu_seqlens_m,
        A_idx,
        dynamic_scheduler,
    )
    return launcher(*positional)
|
|
1363
|
+
|
|
1364
|
+
|
|
1365
|
+
@torch.library.register_fake("quack::gemm_dgated_out")
def gemm_dgated_out_fake(
    A: Tensor,  # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
    B: Tensor,  # (K, N) or (L, K, N)
    PreAct: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
    dx_out: Tensor,  # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
    postact_out: Tensor,  # (M, N) or (L, M, N) or (total_M, N) if varlen_m
    colvec_scale: Optional[Tensor] = None,  # (M,) or (L, M) or (total_M,) if varlen_m
    activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
    colvec_reduce: bool = False,
    cu_seqlens_m: Optional[Tensor] = None,
    A_idx: Optional[Tensor] = None,  # (total_M,) if gather_A with varlen_m
    dynamic_scheduler: bool = True,
    tuned: bool = True,
) -> Optional[Tensor]:
    """Fake implementation of ``quack::gemm_dgated_out`` for tracing.

    Only the optional return value is modeled. It is None unless
    ``colvec_reduce`` is set, in which case an empty float32 tensor is
    returned with shape ``(total_M,)`` for varlen_m (``A_idx`` rows when
    gathering), ``(M,)`` for 2D A, or ``(L, M)`` otherwise.
    """
    if not colvec_reduce:
        return None
    if cu_seqlens_m is not None:
        reduced_shape = (A.shape[0] if A_idx is None else A_idx.shape[0],)
    elif A.ndim == 2:
        reduced_shape = (A.shape[0],)
    else:
        reduced_shape = (A.shape[0], A.shape[-2])
    return torch.empty(reduced_shape, dtype=torch.float32, device=A.device)
|
|
1391
|
+
|
|
1392
|
+
|
|
1030
1393
|
# TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
|
|
1031
1394
|
# try:
|
|
1032
1395
|
# from torch._inductor.fx_passes.reinplace import InplaceableOp
|