whisper.rn 0.5.1 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/jni.cpp +12 -3
- package/cpp/ggml-alloc.c +49 -18
- package/cpp/ggml-backend-impl.h +0 -3
- package/cpp/ggml-backend-reg.cpp +8 -0
- package/cpp/ggml-backend.cpp +0 -2
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-cpu/amx/amx.cpp +1 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
- package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
- package/cpp/ggml-cpu/ggml-cpu.c +67 -24
- package/cpp/ggml-cpu/ops.cpp +489 -364
- package/cpp/ggml-cpu/ops.h +4 -4
- package/cpp/ggml-cpu/repack.cpp +143 -29
- package/cpp/ggml-cpu/simd-mappings.h +25 -25
- package/cpp/ggml-cpu/unary-ops.cpp +151 -0
- package/cpp/ggml-cpu/unary-ops.h +7 -0
- package/cpp/ggml-cpu/vec.cpp +83 -0
- package/cpp/ggml-cpu/vec.h +20 -8
- package/cpp/ggml-impl.h +67 -2
- package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
- package/cpp/ggml-metal/ggml-metal-context.m +5 -6
- package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
- package/cpp/ggml-metal/ggml-metal-device.h +26 -1
- package/cpp/ggml-metal/ggml-metal-device.m +243 -28
- package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
- package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
- package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
- package/cpp/ggml-metal/ggml-metal.cpp +8 -3
- package/cpp/ggml-metal/ggml-metal.metal +12436 -0
- package/cpp/ggml.c +317 -4
- package/cpp/ggml.h +139 -0
- package/cpp/jsi/RNWhisperJSI.cpp +7 -2
- package/cpp/rn-whisper.h +1 -0
- package/cpp/whisper.cpp +8 -2
- package/ios/RNWhisperContext.mm +3 -1
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +2 -0
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +2 -0
- package/src/version.json +1 -1
- package/whisper-rn.podspec +1 -1
- package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
|
@@ -268,6 +268,25 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu(wsp_ggml_metal
|
|
|
268
268
|
return res;
|
|
269
269
|
}
|
|
270
270
|
|
|
271
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
272
|
+
assert(op->op == WSP_GGML_OP_SUM);
|
|
273
|
+
|
|
274
|
+
char base[256];
|
|
275
|
+
char name[256];
|
|
276
|
+
|
|
277
|
+
snprintf(base, 256, "kernel_op_sum_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
278
|
+
snprintf(name, 256, "%s", base);
|
|
279
|
+
|
|
280
|
+
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
281
|
+
if (res) {
|
|
282
|
+
return res;
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
286
|
+
|
|
287
|
+
return res;
|
|
288
|
+
}
|
|
289
|
+
|
|
271
290
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
272
291
|
WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));
|
|
273
292
|
|
|
@@ -299,6 +318,44 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_
|
|
|
299
318
|
return res;
|
|
300
319
|
}
|
|
301
320
|
|
|
321
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
322
|
+
WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
|
|
323
|
+
|
|
324
|
+
char base[256];
|
|
325
|
+
char name[256];
|
|
326
|
+
|
|
327
|
+
snprintf(base, 256, "kernel_cumsum_blk_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
328
|
+
snprintf(name, 256, "%s", base);
|
|
329
|
+
|
|
330
|
+
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
331
|
+
if (res) {
|
|
332
|
+
return res;
|
|
333
|
+
}
|
|
334
|
+
|
|
335
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
336
|
+
|
|
337
|
+
return res;
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
341
|
+
WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
|
|
342
|
+
|
|
343
|
+
char base[256];
|
|
344
|
+
char name[256];
|
|
345
|
+
|
|
346
|
+
snprintf(base, 256, "kernel_cumsum_add_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
347
|
+
snprintf(name, 256, "%s", base);
|
|
348
|
+
|
|
349
|
+
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
350
|
+
if (res) {
|
|
351
|
+
return res;
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
355
|
+
|
|
356
|
+
return res;
|
|
357
|
+
}
|
|
358
|
+
|
|
302
359
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
303
360
|
WSP_GGML_ASSERT(!op->src[1] || op->src[1]->type == WSP_GGML_TYPE_F16 || op->src[1]->type == WSP_GGML_TYPE_F32);
|
|
304
361
|
|
|
@@ -338,7 +395,13 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_
|
|
|
338
395
|
char base[256];
|
|
339
396
|
char name[256];
|
|
340
397
|
|
|
341
|
-
|
|
398
|
+
const char * suffix = "";
|
|
399
|
+
|
|
400
|
+
if (op->src[1]->ne[0] % 4 == 0) {
|
|
401
|
+
suffix = "_4";
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type), suffix);
|
|
342
405
|
snprintf(name, 256, "%s", base);
|
|
343
406
|
|
|
344
407
|
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
@@ -352,15 +415,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_
|
|
|
352
415
|
}
|
|
353
416
|
|
|
354
417
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
418
|
+
WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
419
|
+
|
|
355
420
|
char base[256];
|
|
356
421
|
char name[256];
|
|
357
422
|
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
}
|
|
363
|
-
snprintf(name, 256, "%s", base);
|
|
423
|
+
const int nsg = (ne00 + 31)/32;
|
|
424
|
+
|
|
425
|
+
snprintf(base, 256, "kernel_ssm_scan_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
426
|
+
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
|
364
427
|
|
|
365
428
|
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
366
429
|
if (res) {
|
|
@@ -369,7 +432,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_
|
|
|
369
432
|
|
|
370
433
|
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
371
434
|
|
|
372
|
-
wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
|
435
|
+
wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
|
|
373
436
|
|
|
374
437
|
return res;
|
|
375
438
|
}
|
|
@@ -652,7 +715,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(wsp
|
|
|
652
715
|
char name[256];
|
|
653
716
|
|
|
654
717
|
snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
|
|
655
|
-
snprintf(name, 256, "%
|
|
718
|
+
snprintf(name, 256, "%s_ne02=%d", base, ne02);
|
|
656
719
|
|
|
657
720
|
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
658
721
|
if (res) {
|
|
@@ -918,6 +981,124 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_m
|
|
|
918
981
|
return res;
|
|
919
982
|
}
|
|
920
983
|
|
|
984
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
985
|
+
assert(op->op == WSP_GGML_OP_ARGSORT);
|
|
986
|
+
|
|
987
|
+
char base[256];
|
|
988
|
+
char name[256];
|
|
989
|
+
|
|
990
|
+
wsp_ggml_sort_order order = (wsp_ggml_sort_order) op->op_params[0];
|
|
991
|
+
|
|
992
|
+
const char * order_str = "undefined";
|
|
993
|
+
switch (order) {
|
|
994
|
+
case WSP_GGML_SORT_ORDER_ASC: order_str = "asc"; break;
|
|
995
|
+
case WSP_GGML_SORT_ORDER_DESC: order_str = "desc"; break;
|
|
996
|
+
default: WSP_GGML_ABORT("fatal error");
|
|
997
|
+
};
|
|
998
|
+
|
|
999
|
+
snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
|
|
1000
|
+
snprintf(name, 256, "%s", base);
|
|
1001
|
+
|
|
1002
|
+
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1003
|
+
if (res) {
|
|
1004
|
+
return res;
|
|
1005
|
+
}
|
|
1006
|
+
|
|
1007
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1008
|
+
|
|
1009
|
+
return res;
|
|
1010
|
+
}
|
|
1011
|
+
|
|
1012
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
|
1013
|
+
wsp_ggml_metal_library_t lib,
|
|
1014
|
+
const struct wsp_ggml_tensor * op,
|
|
1015
|
+
bool has_mask,
|
|
1016
|
+
int32_t ncpsg) {
|
|
1017
|
+
assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
|
|
1018
|
+
WSP_GGML_UNUSED(op);
|
|
1019
|
+
|
|
1020
|
+
char base[256];
|
|
1021
|
+
char name[256];
|
|
1022
|
+
|
|
1023
|
+
snprintf(base, 256, "kernel_%s",
|
|
1024
|
+
"flash_attn_ext_pad");
|
|
1025
|
+
|
|
1026
|
+
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
|
|
1027
|
+
base,
|
|
1028
|
+
has_mask,
|
|
1029
|
+
ncpsg);
|
|
1030
|
+
|
|
1031
|
+
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1032
|
+
if (res) {
|
|
1033
|
+
return res;
|
|
1034
|
+
}
|
|
1035
|
+
|
|
1036
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1037
|
+
|
|
1038
|
+
wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
|
|
1039
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
|
|
1040
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
|
|
1041
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
|
|
1042
|
+
|
|
1043
|
+
//wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
|
|
1044
|
+
//wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
|
|
1045
|
+
//wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
|
|
1046
|
+
//wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
|
|
1047
|
+
//wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
|
|
1048
|
+
wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
|
|
1049
|
+
|
|
1050
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1051
|
+
|
|
1052
|
+
wsp_ggml_metal_cv_free(cv);
|
|
1053
|
+
|
|
1054
|
+
return res;
|
|
1055
|
+
}
|
|
1056
|
+
|
|
1057
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
|
1058
|
+
wsp_ggml_metal_library_t lib,
|
|
1059
|
+
const struct wsp_ggml_tensor * op,
|
|
1060
|
+
int32_t nqptg,
|
|
1061
|
+
int32_t ncpsg) {
|
|
1062
|
+
assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
|
|
1063
|
+
WSP_GGML_UNUSED(op);
|
|
1064
|
+
|
|
1065
|
+
char base[256];
|
|
1066
|
+
char name[256];
|
|
1067
|
+
|
|
1068
|
+
snprintf(base, 256, "kernel_%s",
|
|
1069
|
+
"flash_attn_ext_blk");
|
|
1070
|
+
|
|
1071
|
+
snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
|
|
1072
|
+
base,
|
|
1073
|
+
nqptg,
|
|
1074
|
+
ncpsg);
|
|
1075
|
+
|
|
1076
|
+
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1077
|
+
if (res) {
|
|
1078
|
+
return res;
|
|
1079
|
+
}
|
|
1080
|
+
|
|
1081
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1082
|
+
|
|
1083
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
|
|
1084
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
|
|
1085
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
|
|
1086
|
+
//wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
|
|
1087
|
+
|
|
1088
|
+
//wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
|
|
1089
|
+
//wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
|
|
1090
|
+
//wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
|
|
1091
|
+
//wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
|
|
1092
|
+
wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
|
|
1093
|
+
wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
|
|
1094
|
+
|
|
1095
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1096
|
+
|
|
1097
|
+
wsp_ggml_metal_cv_free(cv);
|
|
1098
|
+
|
|
1099
|
+
return res;
|
|
1100
|
+
}
|
|
1101
|
+
|
|
921
1102
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
922
1103
|
wsp_ggml_metal_library_t lib,
|
|
923
1104
|
const wsp_ggml_tensor * op,
|
|
@@ -925,6 +1106,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
|
925
1106
|
bool has_sinks,
|
|
926
1107
|
bool has_bias,
|
|
927
1108
|
bool has_scap,
|
|
1109
|
+
bool has_kvpad,
|
|
928
1110
|
int32_t nsg) {
|
|
929
1111
|
assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
|
|
930
1112
|
|
|
@@ -937,18 +1119,23 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
|
937
1119
|
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
|
|
938
1120
|
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
|
|
939
1121
|
|
|
1122
|
+
// do bounds checks for the mask?
|
|
1123
|
+
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
|
|
1124
|
+
|
|
940
1125
|
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
|
|
941
1126
|
"flash_attn_ext",
|
|
942
1127
|
wsp_ggml_type_name(op->src[1]->type),
|
|
943
1128
|
dk,
|
|
944
1129
|
dv);
|
|
945
1130
|
|
|
946
|
-
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
|
|
1131
|
+
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
|
|
947
1132
|
base,
|
|
948
1133
|
has_mask,
|
|
949
1134
|
has_sinks,
|
|
950
1135
|
has_bias,
|
|
951
1136
|
has_scap,
|
|
1137
|
+
has_kvpad,
|
|
1138
|
+
bc_mask,
|
|
952
1139
|
ns10,
|
|
953
1140
|
ns20,
|
|
954
1141
|
nsg);
|
|
@@ -964,6 +1151,9 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
|
964
1151
|
wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
|
|
965
1152
|
wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
|
|
966
1153
|
wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
|
|
1154
|
+
wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
|
|
1155
|
+
|
|
1156
|
+
wsp_ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
|
|
967
1157
|
|
|
968
1158
|
wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
|
|
969
1159
|
wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
|
|
@@ -983,6 +1173,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
|
|
|
983
1173
|
bool has_sinks,
|
|
984
1174
|
bool has_bias,
|
|
985
1175
|
bool has_scap,
|
|
1176
|
+
bool has_kvpad,
|
|
986
1177
|
int32_t nsg,
|
|
987
1178
|
int32_t nwg) {
|
|
988
1179
|
assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
|
|
@@ -1002,12 +1193,13 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
|
|
|
1002
1193
|
dk,
|
|
1003
1194
|
dv);
|
|
1004
1195
|
|
|
1005
|
-
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%
|
|
1196
|
+
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
|
1006
1197
|
base,
|
|
1007
1198
|
has_mask,
|
|
1008
1199
|
has_sinks,
|
|
1009
1200
|
has_bias,
|
|
1010
1201
|
has_scap,
|
|
1202
|
+
has_kvpad,
|
|
1011
1203
|
ns10,
|
|
1012
1204
|
ns20,
|
|
1013
1205
|
nsg, nwg);
|
|
@@ -1023,6 +1215,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
|
|
|
1023
1215
|
wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
|
|
1024
1216
|
wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
|
|
1025
1217
|
wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
|
|
1218
|
+
wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
|
|
1026
1219
|
|
|
1027
1220
|
wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
|
|
1028
1221
|
wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
|
|
@@ -1205,11 +1398,12 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_meta
|
|
|
1205
1398
|
|
|
1206
1399
|
const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
|
|
1207
1400
|
const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE;
|
|
1401
|
+
const bool is_imrope = mode == WSP_GGML_ROPE_TYPE_IMROPE;
|
|
1208
1402
|
const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
|
|
1209
1403
|
|
|
1210
1404
|
if (is_neox) {
|
|
1211
1405
|
snprintf(base, 256, "kernel_rope_neox_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1212
|
-
} else if (is_mrope && !is_vision) {
|
|
1406
|
+
} else if ((is_mrope || is_imrope) && !is_vision) {
|
|
1213
1407
|
WSP_GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
|
|
1214
1408
|
snprintf(base, 256, "kernel_rope_multi_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1215
1409
|
} else if (is_vision) {
|
|
@@ -1219,14 +1413,20 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_meta
|
|
|
1219
1413
|
snprintf(base, 256, "kernel_rope_norm_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1220
1414
|
}
|
|
1221
1415
|
|
|
1222
|
-
snprintf(name, 256, "%
|
|
1416
|
+
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
|
|
1223
1417
|
|
|
1224
1418
|
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1225
1419
|
if (res) {
|
|
1226
1420
|
return res;
|
|
1227
1421
|
}
|
|
1228
1422
|
|
|
1229
|
-
|
|
1423
|
+
wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
|
|
1424
|
+
|
|
1425
|
+
wsp_ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
|
|
1426
|
+
|
|
1427
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
|
1428
|
+
|
|
1429
|
+
wsp_ggml_metal_cv_free(cv);
|
|
1230
1430
|
|
|
1231
1431
|
return res;
|
|
1232
1432
|
}
|
|
@@ -1279,6 +1479,55 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(
|
|
|
1279
1479
|
return res;
|
|
1280
1480
|
}
|
|
1281
1481
|
|
|
1482
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1483
|
+
assert(op->op == WSP_GGML_OP_CONV_TRANSPOSE_2D);
|
|
1484
|
+
|
|
1485
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
|
|
1486
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
|
|
1487
|
+
WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
|
|
1488
|
+
WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
|
|
1489
|
+
WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
|
|
1490
|
+
|
|
1491
|
+
char base[256];
|
|
1492
|
+
char name[256];
|
|
1493
|
+
|
|
1494
|
+
snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
|
|
1495
|
+
snprintf(name, 256, "%s", base);
|
|
1496
|
+
|
|
1497
|
+
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1498
|
+
if (res) {
|
|
1499
|
+
return res;
|
|
1500
|
+
}
|
|
1501
|
+
|
|
1502
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1503
|
+
|
|
1504
|
+
return res;
|
|
1505
|
+
}
|
|
1506
|
+
|
|
1507
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1508
|
+
assert(op->op == WSP_GGML_OP_CONV_2D);
|
|
1509
|
+
|
|
1510
|
+
WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
|
|
1511
|
+
WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
|
|
1512
|
+
WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
|
|
1513
|
+
WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
|
|
1514
|
+
|
|
1515
|
+
char base[256];
|
|
1516
|
+
char name[256];
|
|
1517
|
+
|
|
1518
|
+
snprintf(base, 256, "kernel_conv_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
|
|
1519
|
+
snprintf(name, 256, "%s", base);
|
|
1520
|
+
|
|
1521
|
+
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1522
|
+
if (res) {
|
|
1523
|
+
return res;
|
|
1524
|
+
}
|
|
1525
|
+
|
|
1526
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1527
|
+
|
|
1528
|
+
return res;
|
|
1529
|
+
}
|
|
1530
|
+
|
|
1282
1531
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1283
1532
|
assert(op->op == WSP_GGML_OP_UPSCALE);
|
|
1284
1533
|
|
|
@@ -1374,3 +1623,40 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding
|
|
|
1374
1623
|
return res;
|
|
1375
1624
|
}
|
|
1376
1625
|
|
|
1626
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1627
|
+
assert(op->op == WSP_GGML_OP_OPT_STEP_ADAMW);
|
|
1628
|
+
|
|
1629
|
+
char base[256];
|
|
1630
|
+
char name[256];
|
|
1631
|
+
|
|
1632
|
+
snprintf(base, 256, "kernel_opt_step_adamw_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1633
|
+
snprintf(name, 256, "%s", base);
|
|
1634
|
+
|
|
1635
|
+
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1636
|
+
if (res) {
|
|
1637
|
+
return res;
|
|
1638
|
+
}
|
|
1639
|
+
|
|
1640
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1641
|
+
|
|
1642
|
+
return res;
|
|
1643
|
+
}
|
|
1644
|
+
|
|
1645
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
|
|
1646
|
+
assert(op->op == WSP_GGML_OP_OPT_STEP_SGD);
|
|
1647
|
+
|
|
1648
|
+
char base[256];
|
|
1649
|
+
char name[256];
|
|
1650
|
+
|
|
1651
|
+
snprintf(base, 256, "kernel_opt_step_sgd_%s", wsp_ggml_type_name(op->src[0]->type));
|
|
1652
|
+
snprintf(name, 256, "%s", base);
|
|
1653
|
+
|
|
1654
|
+
wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
|
|
1655
|
+
if (res) {
|
|
1656
|
+
return res;
|
|
1657
|
+
}
|
|
1658
|
+
|
|
1659
|
+
res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
|
1660
|
+
|
|
1661
|
+
return res;
|
|
1662
|
+
}
|
|
@@ -95,7 +95,9 @@ void wsp_ggml_metal_encoder_end_encoding(wsp_ggml_metal_encoder_t encoder);
|
|
|
95
95
|
|
|
96
96
|
typedef struct wsp_ggml_metal_library * wsp_ggml_metal_library_t;
|
|
97
97
|
|
|
98
|
-
wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev);
|
|
98
|
+
wsp_ggml_metal_library_t wsp_ggml_metal_library_init (wsp_ggml_metal_device_t dev);
|
|
99
|
+
wsp_ggml_metal_library_t wsp_ggml_metal_library_init_from_source(wsp_ggml_metal_device_t dev, const char * source, bool verbose);
|
|
100
|
+
|
|
99
101
|
void wsp_ggml_metal_library_free(wsp_ggml_metal_library_t lib);
|
|
100
102
|
|
|
101
103
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline (wsp_ggml_metal_library_t lib, const char * name);
|
|
@@ -109,7 +111,10 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_set_rows
|
|
|
109
111
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_repeat (wsp_ggml_metal_library_t lib, enum wsp_ggml_type tsrc);
|
|
110
112
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
111
113
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
114
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
112
115
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
116
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
117
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
113
118
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
114
119
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
115
120
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
@@ -122,6 +127,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id
|
|
|
122
127
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv_id (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
123
128
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argmax (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
124
129
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
130
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
125
131
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_bin (wsp_ggml_metal_library_t lib, enum wsp_ggml_op op, int32_t n_fuse, bool row);
|
|
126
132
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_l2_norm (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
127
133
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_group_norm (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
@@ -129,11 +135,27 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_norm
|
|
|
129
135
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
130
136
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_im2col (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
131
137
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
138
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
139
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
132
140
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
133
141
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
134
142
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad_reflect_1d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
135
143
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_arange (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
136
144
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding(wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
145
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
146
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
|
|
147
|
+
|
|
148
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
|
149
|
+
wsp_ggml_metal_library_t lib,
|
|
150
|
+
const struct wsp_ggml_tensor * op,
|
|
151
|
+
bool has_mask,
|
|
152
|
+
int32_t ncpsg);
|
|
153
|
+
|
|
154
|
+
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
|
155
|
+
wsp_ggml_metal_library_t lib,
|
|
156
|
+
const struct wsp_ggml_tensor * op,
|
|
157
|
+
int32_t nqptg,
|
|
158
|
+
int32_t ncpsg);
|
|
137
159
|
|
|
138
160
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
139
161
|
wsp_ggml_metal_library_t lib,
|
|
@@ -142,6 +164,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
|
|
|
142
164
|
bool has_sinks,
|
|
143
165
|
bool has_bias,
|
|
144
166
|
bool has_scap,
|
|
167
|
+
bool has_kvpad,
|
|
145
168
|
int32_t nsg);
|
|
146
169
|
|
|
147
170
|
wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
|
@@ -151,6 +174,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
|
|
|
151
174
|
bool has_sinks,
|
|
152
175
|
bool has_bias,
|
|
153
176
|
bool has_scap,
|
|
177
|
+
bool has_kvpad,
|
|
154
178
|
int32_t nsg,
|
|
155
179
|
int32_t nwg);
|
|
156
180
|
|
|
@@ -175,6 +199,7 @@ struct wsp_ggml_metal_device_props {
|
|
|
175
199
|
bool has_simdgroup_mm;
|
|
176
200
|
bool has_unified_memory;
|
|
177
201
|
bool has_bfloat;
|
|
202
|
+
bool has_tensor;
|
|
178
203
|
bool use_residency_sets;
|
|
179
204
|
bool use_shared_buffers;
|
|
180
205
|
|