whisper.rn 0.5.1 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. package/android/src/main/jni.cpp +12 -3
  2. package/cpp/ggml-alloc.c +49 -18
  3. package/cpp/ggml-backend-impl.h +0 -3
  4. package/cpp/ggml-backend-reg.cpp +8 -0
  5. package/cpp/ggml-backend.cpp +0 -2
  6. package/cpp/ggml-backend.h +2 -0
  7. package/cpp/ggml-cpu/amx/amx.cpp +1 -0
  8. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  9. package/cpp/ggml-cpu/ggml-cpu-impl.h +4 -2
  10. package/cpp/ggml-cpu/ggml-cpu.c +67 -24
  11. package/cpp/ggml-cpu/ops.cpp +489 -364
  12. package/cpp/ggml-cpu/ops.h +4 -4
  13. package/cpp/ggml-cpu/repack.cpp +143 -29
  14. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  15. package/cpp/ggml-cpu/unary-ops.cpp +151 -0
  16. package/cpp/ggml-cpu/unary-ops.h +7 -0
  17. package/cpp/ggml-cpu/vec.cpp +83 -0
  18. package/cpp/ggml-cpu/vec.h +20 -8
  19. package/cpp/ggml-impl.h +67 -2
  20. package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
  21. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  22. package/cpp/ggml-metal/ggml-metal-device.cpp +300 -14
  23. package/cpp/ggml-metal/ggml-metal-device.h +26 -1
  24. package/cpp/ggml-metal/ggml-metal-device.m +243 -28
  25. package/cpp/ggml-metal/ggml-metal-impl.h +177 -9
  26. package/cpp/ggml-metal/ggml-metal-ops.cpp +843 -157
  27. package/cpp/ggml-metal/ggml-metal-ops.h +8 -0
  28. package/cpp/ggml-metal/ggml-metal.cpp +8 -3
  29. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  30. package/cpp/ggml.c +317 -4
  31. package/cpp/ggml.h +139 -0
  32. package/cpp/jsi/RNWhisperJSI.cpp +7 -2
  33. package/cpp/rn-whisper.h +1 -0
  34. package/cpp/whisper.cpp +8 -2
  35. package/ios/RNWhisperContext.mm +3 -1
  36. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  37. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  38. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  39. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  40. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  43. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  45. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  46. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  47. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  48. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  49. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  50. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  51. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  52. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  53. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  54. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  55. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  56. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +139 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  58. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  59. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  60. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  61. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  62. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  63. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +67 -2
  64. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +139 -0
  65. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  66. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  70. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  71. package/lib/commonjs/version.json +1 -1
  72. package/lib/module/NativeRNWhisper.js.map +1 -1
  73. package/lib/module/version.json +1 -1
  74. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  75. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  76. package/package.json +1 -1
  77. package/src/NativeRNWhisper.ts +2 -0
  78. package/src/version.json +1 -1
  79. package/whisper-rn.podspec +1 -1
  80. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  81. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  82. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  83. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  84. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
@@ -268,6 +268,25 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu(wsp_ggml_metal
268
268
  return res;
269
269
  }
270
270
 
271
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
272
+ assert(op->op == WSP_GGML_OP_SUM);
273
+
274
+ char base[256];
275
+ char name[256];
276
+
277
+ snprintf(base, 256, "kernel_op_sum_%s", wsp_ggml_type_name(op->src[0]->type));
278
+ snprintf(name, 256, "%s", base);
279
+
280
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
281
+ if (res) {
282
+ return res;
283
+ }
284
+
285
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
286
+
287
+ return res;
288
+ }
289
+
271
290
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
272
291
  WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));
273
292
 
@@ -299,6 +318,44 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_
299
318
  return res;
300
319
  }
301
320
 
321
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
322
+ WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
323
+
324
+ char base[256];
325
+ char name[256];
326
+
327
+ snprintf(base, 256, "kernel_cumsum_blk_%s", wsp_ggml_type_name(op->src[0]->type));
328
+ snprintf(name, 256, "%s", base);
329
+
330
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
331
+ if (res) {
332
+ return res;
333
+ }
334
+
335
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
336
+
337
+ return res;
338
+ }
339
+
340
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
341
+ WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
342
+
343
+ char base[256];
344
+ char name[256];
345
+
346
+ snprintf(base, 256, "kernel_cumsum_add_%s", wsp_ggml_type_name(op->src[0]->type));
347
+ snprintf(name, 256, "%s", base);
348
+
349
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
350
+ if (res) {
351
+ return res;
352
+ }
353
+
354
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
355
+
356
+ return res;
357
+ }
358
+
302
359
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
303
360
  WSP_GGML_ASSERT(!op->src[1] || op->src[1]->type == WSP_GGML_TYPE_F16 || op->src[1]->type == WSP_GGML_TYPE_F32);
304
361
 
@@ -338,7 +395,13 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_
338
395
  char base[256];
339
396
  char name[256];
340
397
 
341
- snprintf(base, 256, "kernel_ssm_conv_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
398
+ const char * suffix = "";
399
+
400
+ if (op->src[1]->ne[0] % 4 == 0) {
401
+ suffix = "_4";
402
+ }
403
+
404
+ snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type), suffix);
342
405
  snprintf(name, 256, "%s", base);
343
406
 
344
407
  wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
@@ -352,15 +415,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_
352
415
  }
353
416
 
354
417
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
418
+ WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
419
+
355
420
  char base[256];
356
421
  char name[256];
357
422
 
358
- if (op->src[3]->ne[0] == 1) {
359
- snprintf(base, 256, "kernel_ssm_scan_group_%s", wsp_ggml_type_name(op->src[0]->type));
360
- } else {
361
- snprintf(base, 256, "kernel_ssm_scan_%s", wsp_ggml_type_name(op->src[0]->type));
362
- }
363
- snprintf(name, 256, "%s", base);
423
+ const int nsg = (ne00 + 31)/32;
424
+
425
+ snprintf(base, 256, "kernel_ssm_scan_%s", wsp_ggml_type_name(op->src[0]->type));
426
+ snprintf(name, 256, "%s_nsg=%d", base, nsg);
364
427
 
365
428
  wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
366
429
  if (res) {
@@ -369,7 +432,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_
369
432
 
370
433
  res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
371
434
 
372
- wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
435
+ wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
373
436
 
374
437
  return res;
375
438
  }
@@ -652,7 +715,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(wsp
652
715
  char name[256];
653
716
 
654
717
  snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
655
- snprintf(name, 256, "%s", base);
718
+ snprintf(name, 256, "%s_ne02=%d", base, ne02);
656
719
 
657
720
  wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
658
721
  if (res) {
@@ -918,6 +981,124 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_m
918
981
  return res;
919
982
  }
920
983
 
984
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
985
+ assert(op->op == WSP_GGML_OP_ARGSORT);
986
+
987
+ char base[256];
988
+ char name[256];
989
+
990
+ wsp_ggml_sort_order order = (wsp_ggml_sort_order) op->op_params[0];
991
+
992
+ const char * order_str = "undefined";
993
+ switch (order) {
994
+ case WSP_GGML_SORT_ORDER_ASC: order_str = "asc"; break;
995
+ case WSP_GGML_SORT_ORDER_DESC: order_str = "desc"; break;
996
+ default: WSP_GGML_ABORT("fatal error");
997
+ };
998
+
999
+ snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
1000
+ snprintf(name, 256, "%s", base);
1001
+
1002
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1003
+ if (res) {
1004
+ return res;
1005
+ }
1006
+
1007
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1008
+
1009
+ return res;
1010
+ }
1011
+
1012
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
1013
+ wsp_ggml_metal_library_t lib,
1014
+ const struct wsp_ggml_tensor * op,
1015
+ bool has_mask,
1016
+ int32_t ncpsg) {
1017
+ assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
1018
+ WSP_GGML_UNUSED(op);
1019
+
1020
+ char base[256];
1021
+ char name[256];
1022
+
1023
+ snprintf(base, 256, "kernel_%s",
1024
+ "flash_attn_ext_pad");
1025
+
1026
+ snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
1027
+ base,
1028
+ has_mask,
1029
+ ncpsg);
1030
+
1031
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1032
+ if (res) {
1033
+ return res;
1034
+ }
1035
+
1036
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1037
+
1038
+ wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
1039
+ //wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
1040
+ //wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
1041
+ //wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
1042
+
1043
+ //wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
1044
+ //wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
1045
+ //wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
1046
+ //wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
1047
+ //wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
1048
+ wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
1049
+
1050
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1051
+
1052
+ wsp_ggml_metal_cv_free(cv);
1053
+
1054
+ return res;
1055
+ }
1056
+
1057
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
1058
+ wsp_ggml_metal_library_t lib,
1059
+ const struct wsp_ggml_tensor * op,
1060
+ int32_t nqptg,
1061
+ int32_t ncpsg) {
1062
+ assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
1063
+ WSP_GGML_UNUSED(op);
1064
+
1065
+ char base[256];
1066
+ char name[256];
1067
+
1068
+ snprintf(base, 256, "kernel_%s",
1069
+ "flash_attn_ext_blk");
1070
+
1071
+ snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
1072
+ base,
1073
+ nqptg,
1074
+ ncpsg);
1075
+
1076
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1077
+ if (res) {
1078
+ return res;
1079
+ }
1080
+
1081
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1082
+
1083
+ //wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
1084
+ //wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
1085
+ //wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
1086
+ //wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
1087
+
1088
+ //wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
1089
+ //wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
1090
+ //wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
1091
+ //wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
1092
+ wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
1093
+ wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
1094
+
1095
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1096
+
1097
+ wsp_ggml_metal_cv_free(cv);
1098
+
1099
+ return res;
1100
+ }
1101
+
921
1102
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
922
1103
  wsp_ggml_metal_library_t lib,
923
1104
  const wsp_ggml_tensor * op,
@@ -925,6 +1106,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
925
1106
  bool has_sinks,
926
1107
  bool has_bias,
927
1108
  bool has_scap,
1109
+ bool has_kvpad,
928
1110
  int32_t nsg) {
929
1111
  assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
930
1112
 
@@ -937,18 +1119,23 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
937
1119
  const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
938
1120
  const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
939
1121
 
1122
+ // do bounds checks for the mask?
1123
+ const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
1124
+
940
1125
  snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
941
1126
  "flash_attn_ext",
942
1127
  wsp_ggml_type_name(op->src[1]->type),
943
1128
  dk,
944
1129
  dv);
945
1130
 
946
- snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
1131
+ snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
947
1132
  base,
948
1133
  has_mask,
949
1134
  has_sinks,
950
1135
  has_bias,
951
1136
  has_scap,
1137
+ has_kvpad,
1138
+ bc_mask,
952
1139
  ns10,
953
1140
  ns20,
954
1141
  nsg);
@@ -964,6 +1151,9 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
964
1151
  wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
965
1152
  wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
966
1153
  wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1154
+ wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
1155
+
1156
+ wsp_ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
967
1157
 
968
1158
  wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
969
1159
  wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -983,6 +1173,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
983
1173
  bool has_sinks,
984
1174
  bool has_bias,
985
1175
  bool has_scap,
1176
+ bool has_kvpad,
986
1177
  int32_t nsg,
987
1178
  int32_t nwg) {
988
1179
  assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1193,13 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
1002
1193
  dk,
1003
1194
  dv);
1004
1195
 
1005
- snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1196
+ snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1006
1197
  base,
1007
1198
  has_mask,
1008
1199
  has_sinks,
1009
1200
  has_bias,
1010
1201
  has_scap,
1202
+ has_kvpad,
1011
1203
  ns10,
1012
1204
  ns20,
1013
1205
  nsg, nwg);
@@ -1023,6 +1215,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
1023
1215
  wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
1024
1216
  wsp_ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
1025
1217
  wsp_ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1218
+ wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
1026
1219
 
1027
1220
  wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
1028
1221
  wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
@@ -1205,11 +1398,12 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_meta
1205
1398
 
1206
1399
  const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
1207
1400
  const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE;
1401
+ const bool is_imrope = mode == WSP_GGML_ROPE_TYPE_IMROPE;
1208
1402
  const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
1209
1403
 
1210
1404
  if (is_neox) {
1211
1405
  snprintf(base, 256, "kernel_rope_neox_%s", wsp_ggml_type_name(op->src[0]->type));
1212
- } else if (is_mrope && !is_vision) {
1406
+ } else if ((is_mrope || is_imrope) && !is_vision) {
1213
1407
  WSP_GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
1214
1408
  snprintf(base, 256, "kernel_rope_multi_%s", wsp_ggml_type_name(op->src[0]->type));
1215
1409
  } else if (is_vision) {
@@ -1219,14 +1413,20 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_meta
1219
1413
  snprintf(base, 256, "kernel_rope_norm_%s", wsp_ggml_type_name(op->src[0]->type));
1220
1414
  }
1221
1415
 
1222
- snprintf(name, 256, "%s", base);
1416
+ snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
1223
1417
 
1224
1418
  wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1225
1419
  if (res) {
1226
1420
  return res;
1227
1421
  }
1228
1422
 
1229
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1423
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1424
+
1425
+ wsp_ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1426
+
1427
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1428
+
1429
+ wsp_ggml_metal_cv_free(cv);
1230
1430
 
1231
1431
  return res;
1232
1432
  }
@@ -1279,6 +1479,55 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(
1279
1479
  return res;
1280
1480
  }
1281
1481
 
1482
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1483
+ assert(op->op == WSP_GGML_OP_CONV_TRANSPOSE_2D);
1484
+
1485
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
1486
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
1487
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
1488
+ WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
1489
+ WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
1490
+
1491
+ char base[256];
1492
+ char name[256];
1493
+
1494
+ snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
1495
+ snprintf(name, 256, "%s", base);
1496
+
1497
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1498
+ if (res) {
1499
+ return res;
1500
+ }
1501
+
1502
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1503
+
1504
+ return res;
1505
+ }
1506
+
1507
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1508
+ assert(op->op == WSP_GGML_OP_CONV_2D);
1509
+
1510
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
1511
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
1512
+ WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
1513
+ WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
1514
+
1515
+ char base[256];
1516
+ char name[256];
1517
+
1518
+ snprintf(base, 256, "kernel_conv_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
1519
+ snprintf(name, 256, "%s", base);
1520
+
1521
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1522
+ if (res) {
1523
+ return res;
1524
+ }
1525
+
1526
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1527
+
1528
+ return res;
1529
+ }
1530
+
1282
1531
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1283
1532
  assert(op->op == WSP_GGML_OP_UPSCALE);
1284
1533
 
@@ -1374,3 +1623,40 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding
1374
1623
  return res;
1375
1624
  }
1376
1625
 
1626
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1627
+ assert(op->op == WSP_GGML_OP_OPT_STEP_ADAMW);
1628
+
1629
+ char base[256];
1630
+ char name[256];
1631
+
1632
+ snprintf(base, 256, "kernel_opt_step_adamw_%s", wsp_ggml_type_name(op->src[0]->type));
1633
+ snprintf(name, 256, "%s", base);
1634
+
1635
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1636
+ if (res) {
1637
+ return res;
1638
+ }
1639
+
1640
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1641
+
1642
+ return res;
1643
+ }
1644
+
1645
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1646
+ assert(op->op == WSP_GGML_OP_OPT_STEP_SGD);
1647
+
1648
+ char base[256];
1649
+ char name[256];
1650
+
1651
+ snprintf(base, 256, "kernel_opt_step_sgd_%s", wsp_ggml_type_name(op->src[0]->type));
1652
+ snprintf(name, 256, "%s", base);
1653
+
1654
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1655
+ if (res) {
1656
+ return res;
1657
+ }
1658
+
1659
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1660
+
1661
+ return res;
1662
+ }
@@ -95,7 +95,9 @@ void wsp_ggml_metal_encoder_end_encoding(wsp_ggml_metal_encoder_t encoder);
95
95
 
96
96
  typedef struct wsp_ggml_metal_library * wsp_ggml_metal_library_t;
97
97
 
98
- wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev);
98
+ wsp_ggml_metal_library_t wsp_ggml_metal_library_init (wsp_ggml_metal_device_t dev);
99
+ wsp_ggml_metal_library_t wsp_ggml_metal_library_init_from_source(wsp_ggml_metal_device_t dev, const char * source, bool verbose);
100
+
99
101
  void wsp_ggml_metal_library_free(wsp_ggml_metal_library_t lib);
100
102
 
101
103
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline (wsp_ggml_metal_library_t lib, const char * name);
@@ -109,7 +111,10 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_set_rows
109
111
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_repeat (wsp_ggml_metal_library_t lib, enum wsp_ggml_type tsrc);
110
112
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
111
113
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
114
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
112
115
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
116
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
117
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
113
118
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
114
119
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
115
120
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
@@ -122,6 +127,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id
122
127
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv_id (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
123
128
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argmax (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
124
129
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
130
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
125
131
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_bin (wsp_ggml_metal_library_t lib, enum wsp_ggml_op op, int32_t n_fuse, bool row);
126
132
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_l2_norm (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
127
133
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_group_norm (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
@@ -129,11 +135,27 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_norm
129
135
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
130
136
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_im2col (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
131
137
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
138
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
139
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
132
140
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
133
141
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
134
142
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad_reflect_1d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
135
143
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_arange (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
136
144
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding(wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
145
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
146
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
147
+
148
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
149
+ wsp_ggml_metal_library_t lib,
150
+ const struct wsp_ggml_tensor * op,
151
+ bool has_mask,
152
+ int32_t ncpsg);
153
+
154
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
155
+ wsp_ggml_metal_library_t lib,
156
+ const struct wsp_ggml_tensor * op,
157
+ int32_t nqptg,
158
+ int32_t ncpsg);
137
159
 
138
160
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
139
161
  wsp_ggml_metal_library_t lib,
@@ -142,6 +164,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
142
164
  bool has_sinks,
143
165
  bool has_bias,
144
166
  bool has_scap,
167
+ bool has_kvpad,
145
168
  int32_t nsg);
146
169
 
147
170
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +174,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
151
174
  bool has_sinks,
152
175
  bool has_bias,
153
176
  bool has_scap,
177
+ bool has_kvpad,
154
178
  int32_t nsg,
155
179
  int32_t nwg);
156
180
 
@@ -175,6 +199,7 @@ struct wsp_ggml_metal_device_props {
175
199
  bool has_simdgroup_mm;
176
200
  bool has_unified_memory;
177
201
  bool has_bfloat;
202
+ bool has_tensor;
178
203
  bool use_residency_sets;
179
204
  bool use_shared_buffers;
180
205