quack-kernels 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff shows the content of publicly released package versions as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
quack/gemm_interface.py CHANGED
@@ -12,7 +12,9 @@ from quack.autotuner import autotune, AutotuneConfig
  from quack.cute_dsl_utils import get_device_capacity
  from quack.gemm import gemm as gemm_sm90_sm100
  from quack.gemm_act import gemm_act as gemm_act_sm90_sm100
+ from quack.gemm_act import gemm_gated as gemm_gated_sm90_sm100
  from quack.gemm_dact import gemm_dact as gemm_dact_sm90_sm100
+ from quack.gemm_dact import gemm_dgated as gemm_dgated_sm90_sm100
  from quack.gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100
 
 
@@ -1027,6 +1029,367 @@ def gemm_symmetric(
      return out
 
 
+ @autotune(
+     configs=[
+         AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0], "gated")
+     ],
+     key=["activation", "dynamic_scheduler"],
+     prune_configs_by={"early_config_prune": prune_invalid_gemm_configs},
+ )
+ def gemm_gated_tuned(
+     # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+     A: Tensor,
+     B: Tensor, # (K, N) or (L, K, N)
+     # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact
+     preact_out: Optional[Tensor],
+     postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
+     C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+     bias: Optional[Tensor] = None, # (N,) or (L, N)
+     activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
+     cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
+     A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
+     dynamic_scheduler: bool = False,
+     config: Optional[GemmConfig] = None,
+ ) -> None:
+     if config is None:
+         config = default_config(A.device)
+     varlen_m = cu_seqlens_m is not None
+     if varlen_m:
+         assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+     if A.ndim == 2 and not varlen_m:
+         A = A.unsqueeze(0) # (1, M, K)
+     B = B.mT # (N, K) or (L, N, K)
+     if B.ndim == 2:
+         B = B.unsqueeze(0) # (1, N, K)
+     if C is not None and C.ndim == 2 and not varlen_m:
+         C = C.unsqueeze(0) # (1, M, N)
+     if preact_out is not None and preact_out.ndim == 2 and not varlen_m:
+         D = preact_out.unsqueeze(0)
+     else:
+         D = preact_out
+     if postact_out.ndim == 2 and not varlen_m:
+         PostAct = postact_out.unsqueeze(0)
+     else:
+         PostAct = postact_out
+     if bias is not None and bias.ndim == 1:
+         bias = bias.unsqueeze(0) # (L, N)
+     tile_count_semaphore = (
+         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+     )
+     gemm_gated_sm90_sm100(
+         A if not config.swap_ab else B,
+         B if not config.swap_ab else A,
+         (D if not config.swap_ab else D.mT) if D is not None else None,
+         (C if not config.swap_ab else C.mT) if C is not None else None,
+         PostAct if not config.swap_ab else PostAct.mT,
+         tile_count_semaphore,
+         activation,
+         config.tile_m,
+         config.tile_n,
+         config.cluster_m,
+         config.cluster_n,
+         config.pingpong,
+         persistent=True,
+         max_swizzle_size=config.max_swizzle_size,
+         rowvec_bias=bias if not config.swap_ab else None,
+         colvec_bias=bias if config.swap_ab else None,
+         cu_seqlens_m=cu_seqlens_m,
+         A_idx=A_idx,
+     )
+ 
+ 
+ def prune_invalid_gemm_dgated_configs(configs, named_args: dict, **kwargs):
+     kwargs = named_args | kwargs
+     # if there's colvec_scale or colvec_reduce, don't swap_AB
+     if kwargs.get("colvec_scale", None) is not None or kwargs.get("colvec_reduce", False):
+         configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab]
+     return prune_invalid_gemm_configs(configs, named_args, **kwargs)
+ 
+ 
+ @autotune(
+     configs=[
+         AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0], "dgated")
+     ],
+     key=["activation", "colvec_reduce", "dynamic_scheduler"],
+     prune_configs_by={"early_config_prune": prune_invalid_gemm_dgated_configs},
+ )
+ def gemm_dgated_tuned(
+     # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+     A: Tensor,
+     B: Tensor, # (K, N) or (L, K, N)
+     PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
+     dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
+     postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+     colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m
+     activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
+     # whether to do colvec reduction, returning (M,) or (L, M) or (total_M,) if varlen_m
+     colvec_reduce: bool = False,
+     cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32
+     A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
+     dynamic_scheduler: bool = True,
+     config: Optional[GemmConfig] = None,
+ ) -> Optional[Tensor]:
+     if config is None:
+         config = default_config(A.device)
+     varlen_m = cu_seqlens_m is not None
+     if varlen_m:
+         assert not config.swap_ab, "Variable-length sequences not supported with swap_ab"
+     og_ndim_2 = A.ndim == 2 and not varlen_m
+     if A.ndim == 2 and not varlen_m:
+         A = A.unsqueeze(0) # (1, M, K)
+     B = B.mT # (N, K) or (L, N, K)
+     if B.ndim == 2:
+         B = B.unsqueeze(0) # (1, N, K)
+     if PreAct.ndim == 2 and not varlen_m:
+         PreAct = PreAct.unsqueeze(0) # (1, M, 2*N)
+     if dx_out.ndim == 2 and not varlen_m:
+         D = dx_out.unsqueeze(0)
+     else:
+         D = dx_out
+     if postact_out.ndim == 2 and not varlen_m:
+         PostAct = postact_out.unsqueeze(0)
+     else:
+         PostAct = postact_out
+     if colvec_scale is not None and colvec_scale.ndim == 1 and not varlen_m:
+         colvec_scale = colvec_scale.unsqueeze(0) # (L, M)
+     if colvec_scale is not None:
+         assert not config.swap_ab, "colvec_scale not supported with swap_ab"
+     if colvec_reduce:
+         tile_n = config.tile_n
+         shape_n = (B.shape[-2] + tile_n - 1) // tile_n
+         if varlen_m:
+             total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+             colvec_shape = (total_m, shape_n)
+         else:
+             colvec_shape = (A.shape[0], A.shape[-2], shape_n)
+         colvec_reduce_partial = torch.empty(colvec_shape, dtype=torch.float32, device=A.device)
+     else:
+         colvec_reduce_partial = None
+     tile_count_semaphore = (
+         torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None
+     )
+     gemm_dgated_sm90_sm100(
+         A if not config.swap_ab else B,
+         B if not config.swap_ab else A,
+         D if not config.swap_ab else D.mT,
+         PreAct if not config.swap_ab else PreAct.mT,
+         PostAct if not config.swap_ab else PostAct.mT,
+         tile_count_semaphore,
+         activation,
+         config.tile_m,
+         config.tile_n,
+         config.cluster_m,
+         config.cluster_n,
+         config.pingpong,
+         persistent=True,
+         max_swizzle_size=config.max_swizzle_size,
+         colvec_scale=colvec_scale,
+         colvec_reduce=colvec_reduce_partial,
+         cu_seqlens_m=cu_seqlens_m,
+         A_idx=A_idx,
+     )
+     if colvec_reduce:
+         colvec_reduce_final = colvec_reduce_partial.sum(dim=-1)
+         if og_ndim_2:
+             colvec_reduce_final = colvec_reduce_final.squeeze(0)
+     else:
+         colvec_reduce_final = None
+     return colvec_reduce_final
+ 
+ 
+ def gemm_gated(
+     A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+     B: Tensor, # (K, N) or (L, K, N)
+     C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+     bias: Optional[Tensor] = None, # (N,) or (L, N)
+     activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
+     preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+     postact_out: Optional[
+         Tensor
+     ] = None, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+     cu_seqlens_m: Optional[Tensor] = None,
+     A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
+     store_preact: bool = True,
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> Tuple[Optional[Tensor], Tensor]:
+     """GEMM with gated activation and optional output tensors."""
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = A.dtype if postact_dtype is None else postact_dtype
+     varlen_m = cu_seqlens_m is not None
+     # Determine output shape based on gather_A
+     if varlen_m:
+         total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+         out_shape = (total_m, B.shape[-1])
+     elif A.ndim == 2:
+         out_shape = (A.shape[0], B.shape[-1])
+     else:
+         out_shape = (A.shape[0], A.shape[-2], B.shape[-1])
+     postact_shape = (*out_shape[:-1], out_shape[-1] // 2)
+     if preact_out is None and store_preact:
+         preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
+     if postact_out is None:
+         postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device)
+     gemm_gated_out(
+         A,
+         B,
+         preact_out,
+         postact_out,
+         C,
+         bias,
+         activation,
+         cu_seqlens_m,
+         A_idx,
+         dynamic_scheduler,
+         tuned,
+     )
+     return preact_out, postact_out
+ 
+ 
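Editor's note: the new gemm_gated entry point above allocates its outputs when none are passed in. A minimal usage sketch follows; it is not part of the diff, and the tensor names, shapes, dtypes, and the import path are illustrative assumptions based only on the signature shown above.

# Usage sketch for gemm_gated (illustrative assumptions, not documented API).
import torch
from quack.gemm_interface import gemm_gated

M, K, N = 4096, 4096, 8192  # N counts both gate and value columns
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")

# preact has shape (M, N); postact holds the gated activation, shape (M, N // 2).
preact, postact = gemm_gated(A, B, activation="swiglu")

# Skip materializing the pre-activation when only the gated output is needed.
_, postact_only = gemm_gated(A, B, activation="swiglu", store_preact=False)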
+ @torch.library.custom_op(
+     "quack::gemm_gated_out",
+     mutates_args=("preact_out", "postact_out"),
+     device_types="cuda",
+     schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str activation='swiglu', Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()",
+ )
+ def gemm_gated_out(
+     A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+     B: Tensor, # (K, N) or (L, K, N)
+     preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+     postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m
+     C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+     bias: Optional[Tensor] = None, # (N,) or (L, N)
+     activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
+     cu_seqlens_m: Optional[Tensor] = None,
+     A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
+     dynamic_scheduler: bool = False,
+     tuned: bool = True,
+ ) -> None:
+     """GEMM with gated activation and pre-allocated output tensors."""
+     fn = gemm_gated_tuned if tuned else partial(gemm_gated_tuned.fn, config=None)
+     fn(A, B, preact_out, postact_out, C, bias, activation, cu_seqlens_m, A_idx, dynamic_scheduler)
+ 
+ 
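Editor's note: for pipelines that manage their own buffers, the out-variant above writes into caller-provided tensors. A short sketch (not part of the diff; buffer names and shapes are assumptions based on the shape comments in the signature):

# Pre-allocated output sketch for gemm_gated_out (illustrative assumptions).
import torch
from quack.gemm_interface import gemm_gated_out

M, K, N = 4096, 4096, 8192
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")
preact_buf = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")
postact_buf = torch.empty(M, N // 2, dtype=torch.bfloat16, device="cuda")

# Results are written in place into preact_buf and postact_buf.
gemm_gated_out(A, B, preact_buf, postact_buf)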
+ def gemm_dgated(
+     A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+     B: Tensor, # (K, N) or (L, K, N)
+     PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
+     colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m
+     activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
+     dx_out: Optional[Tensor] = None, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
+     postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+     out_dtype: Optional[torch.dtype] = None,
+     postact_dtype: Optional[torch.dtype] = None,
+     colvec_reduce: bool = False,
+     cu_seqlens_m: Optional[Tensor] = None,
+     A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
+     dynamic_scheduler: bool = True,
+     tuned: bool = True,
+ ) -> Tuple[Tensor, Tensor]:
+     """GEMM with gated activation gradient and optional output tensors."""
+     out_dtype = A.dtype if out_dtype is None else out_dtype
+     postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype
+     varlen_m = cu_seqlens_m is not None
+     # Determine output shape based on gather_A
+     if varlen_m:
+         total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+         out_shape = (total_m, B.shape[-1] * 2)
+     elif A.ndim == 2:
+         out_shape = (A.shape[0], B.shape[-1] * 2)
+     else:
+         out_shape = (A.shape[0], A.shape[-2], B.shape[-1] * 2)
+     postact_shape = (*out_shape[:-1], out_shape[-1] // 2)
+     if dx_out is None:
+         dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device)
+     if postact_out is None:
+         postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device)
+     colvec_reduce_final = gemm_dgated_out(
+         A,
+         B,
+         PreAct,
+         dx_out,
+         postact_out,
+         colvec_scale,
+         activation,
+         colvec_reduce,
+         cu_seqlens_m,
+         A_idx,
+         dynamic_scheduler,
+         tuned,
+     )
+     if not colvec_reduce:
+         return dx_out, postact_out
+     else:
+         return dx_out, postact_out, colvec_reduce_final
+ 
+ 
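Editor's note: a backward-side sketch for gemm_dgated follows; it is not part of the diff, and the variable names (dY, W, preact) and shapes are assumptions derived from the shape comments in the signature above.

# Usage sketch for gemm_dgated (illustrative assumptions, not documented API).
import torch
from quack.gemm_interface import gemm_dgated

M, K, N = 4096, 4096, 8192
dY = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")      # plays the role of A
W = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")       # plays the role of B
preact = torch.randn(M, 2 * N, dtype=torch.bfloat16, device="cuda")

# dx has shape (M, 2*N); postact recomputes the gated activation, shape (M, N).
dx, postact = gemm_dgated(dY, W, preact, activation="swiglu")

# With colvec_reduce=True a third float32 tensor of shape (M,) is returned,
# holding the column-vector reduction computed by the kernel.
dx, postact, colvec = gemm_dgated(dY, W, preact, activation="swiglu", colvec_reduce=True)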
+ @torch.library.custom_op(
+     "quack::gemm_dgated_out",
+     mutates_args=("dx_out", "postact_out"),
+     device_types="cuda",
+     schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, Tensor? colvec_scale=None, str activation='swiglu', bool colvec_reduce=False, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> Tensor?",
+ )
+ def gemm_dgated_out(
+     A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+     B: Tensor, # (K, N) or (L, K, N)
+     PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
+     dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
+     postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+     colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m
+     activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
+     colvec_reduce: bool = False,
+     cu_seqlens_m: Optional[Tensor] = None,
+     A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
+     dynamic_scheduler: bool = True,
+     tuned: bool = True,
+ ) -> Optional[Tensor]:
+     """GEMM with gated activation gradient and pre-allocated output tensors."""
+     fn = gemm_dgated_tuned if tuned else partial(gemm_dgated_tuned.fn, config=None)
+     return fn(
+         A,
+         B,
+         PreAct,
+         dx_out,
+         postact_out,
+         colvec_scale,
+         activation,
+         colvec_reduce,
+         cu_seqlens_m,
+         A_idx,
+         dynamic_scheduler,
+     )
+ 
+ 
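Editor's note: because the registration above uses torch.library.custom_op with the name "quack::gemm_dgated_out", the kernel is also reachable through the dispatcher as torch.ops.quack.gemm_dgated_out. The sketch below is illustrative only: tensor names, sizes, and argument values are assumptions, and the positional order follows the schema string above.

# Dispatcher-level sketch (illustrative assumptions; argument order follows the schema).
import torch
import quack.gemm_interface  # noqa: F401  (importing the module registers the op)

M, K, N = 1024, 1024, 2048
dY = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
W = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")
preact = torch.randn(M, 2 * N, dtype=torch.bfloat16, device="cuda")
dx = torch.empty(M, 2 * N, dtype=torch.bfloat16, device="cuda")
postact = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")

# dx and postact are mutated in place; the optional return is the colvec reduction.
colvec = torch.ops.quack.gemm_dgated_out(
    dY, W, preact, dx, postact,   # A, B, PreAct, dx_out, postact_out
    None, "swiglu", True,         # colvec_scale, activation, colvec_reduce
    None, None, True, True,       # cu_seqlens_m, A_idx, dynamic_scheduler, tuned
)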
+ @torch.library.register_fake("quack::gemm_dgated_out")
+ def gemm_dgated_out_fake(
+     A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m
+     B: Tensor, # (K, N) or (L, K, N)
+     PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
+     dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m
+     postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m
+     colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m
+     activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu",
+     colvec_reduce: bool = False,
+     cu_seqlens_m: Optional[Tensor] = None,
+     A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m
+     dynamic_scheduler: bool = True,
+     tuned: bool = True,
+ ) -> Optional[Tensor]:
+     if not colvec_reduce:
+         return None
+     else:
+         if cu_seqlens_m is not None:
+             total_m = A_idx.shape[0] if A_idx is not None else A.shape[0]
+             out_shape = (total_m,)
+         elif A.ndim == 2:
+             out_shape = (A.shape[0],)
+         else:
+             out_shape = (A.shape[0], A.shape[-2])
+         return torch.empty(out_shape, dtype=torch.float32, device=A.device)
+ 
+ 
  # TODO: this is not quite right, do we need to register gemm_add not gemm_add_out?
  # try:
  # from torch._inductor.fx_passes.reinplace import InplaceableOp