whisper.rn 0.5.1 → 0.5.2

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (66)
  1. package/android/src/main/jni.cpp +12 -3
  2. package/cpp/ggml-alloc.c +38 -14
  3. package/cpp/ggml-backend-impl.h +0 -3
  4. package/cpp/ggml-backend.h +2 -0
  5. package/cpp/ggml-cpu/amx/amx.cpp +1 -0
  6. package/cpp/ggml-cpu/ggml-cpu-impl.h +1 -1
  7. package/cpp/ggml-cpu/ggml-cpu.c +17 -3
  8. package/cpp/ggml-cpu/ops.cpp +33 -17
  9. package/cpp/ggml-cpu/unary-ops.cpp +135 -0
  10. package/cpp/ggml-cpu/unary-ops.h +5 -0
  11. package/cpp/ggml-cpu/vec.cpp +66 -0
  12. package/cpp/ggml-cpu/vec.h +10 -8
  13. package/cpp/ggml-impl.h +51 -2
  14. package/cpp/ggml-metal/ggml-metal-common.cpp +2 -2
  15. package/cpp/ggml-metal/ggml-metal-device.cpp +199 -10
  16. package/cpp/ggml-metal/ggml-metal-device.h +18 -0
  17. package/cpp/ggml-metal/ggml-metal-device.m +27 -14
  18. package/cpp/ggml-metal/ggml-metal-impl.h +87 -7
  19. package/cpp/ggml-metal/ggml-metal-ops.cpp +513 -88
  20. package/cpp/ggml-metal/ggml-metal-ops.h +6 -0
  21. package/cpp/ggml-metal/ggml-metal.cpp +3 -3
  22. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  23. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  24. package/cpp/ggml.c +166 -2
  25. package/cpp/ggml.h +66 -0
  26. package/cpp/jsi/RNWhisperJSI.cpp +7 -2
  27. package/cpp/rn-whisper.h +1 -0
  28. package/cpp/whisper.cpp +4 -2
  29. package/ios/RNWhisperContext.mm +3 -1
  30. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  31. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  32. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  33. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +66 -0
  34. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  35. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  36. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  37. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  38. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  39. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  40. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +66 -0
  41. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  42. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  43. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  44. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  45. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  46. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  47. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +66 -0
  48. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  49. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  50. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  51. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +0 -3
  52. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +2 -0
  53. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +51 -2
  54. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +66 -0
  55. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/rn-whisper.h +1 -0
  56. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  57. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  58. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  59. package/lib/commonjs/version.json +1 -1
  60. package/lib/module/NativeRNWhisper.js.map +1 -1
  61. package/lib/module/version.json +1 -1
  62. package/lib/typescript/NativeRNWhisper.d.ts +2 -0
  63. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  64. package/package.json +1 -1
  65. package/src/NativeRNWhisper.ts +2 -0
  66. package/src/version.json +1 -1
package/cpp/ggml-impl.h CHANGED
@@ -102,6 +102,9 @@ static bool wsp_ggml_op_is_empty(enum wsp_ggml_op op) {
     }
 }
 
+static inline float wsp_ggml_softplus(float input) {
+    return (input > 20.0f) ? input : logf(1 + expf(input));
+}
 //
 // logging
 //
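Note: the softplus helper is written for numerical stability. For large inputs expf() overflows to inf, while softplus(x) ~= x anyway, so the function short-circuits above 20. A standalone illustration (hypothetical test code, not part of the package):

    #include <math.h>
    #include <stdio.h>

    // stable softplus: avoids evaluating expf() where it would overflow
    static float softplus_stable(float x) {
        return (x > 20.0f) ? x : logf(1.0f + expf(x));
    }

    int main(void) {
        printf("%f\n", softplus_stable(0.5f));   // ~0.974077
        printf("%f\n", softplus_stable(100.0f)); // 100.0; logf(1 + expf(100)) would be inf
        return 0;
    }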
@@ -562,14 +565,23 @@ static inline wsp_ggml_bf16_t wsp_ggml_compute_fp32_to_bf16(float s) {
 #define WSP_GGML_FP32_TO_BF16(x) wsp_ggml_compute_fp32_to_bf16(x)
 #define WSP_GGML_BF16_TO_FP32(x) wsp_ggml_compute_bf16_to_fp32(x)
 
+static inline int32_t wsp_ggml_node_get_use_count(const struct wsp_ggml_cgraph * cgraph, int node_idx) {
+    const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
+
+    size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
+        return 0;
+    }
+    return cgraph->use_counts[hash_pos];
+}
+
 // return true if the node's results are only used by N other nodes
 // and can be fused into their calculations.
 static inline bool wsp_ggml_node_has_n_uses(const struct wsp_ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
     const struct wsp_ggml_tensor * node = cgraph->nodes[node_idx];
 
     // check the use count against how many we're replacing
-    size_t hash_pos = wsp_ggml_hash_find(&cgraph->visited_hash_set, node);
-    if (!wsp_ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+    if (wsp_ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
        return false;
     }
 
@@ -635,6 +647,36 @@ static inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int
     return wsp_ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
 }
 
+WSP_GGML_API bool wsp_ggml_can_fuse_subgraph_ext(const struct wsp_ggml_cgraph * cgraph,
+                                                 const int * node_idxs,
+                                                 int count,
+                                                 const enum wsp_ggml_op * ops,
+                                                 const int * outputs,
+                                                 int num_outputs);
+
+// Returns true if the subgraph formed by {node_idxs} can be fused
+// checks whether all nodes which are not part of outputs can be elided
+// by checking if their num_uses are confined to the subgraph
+static inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgraph,
+                                              int node_idx,
+                                              int count,
+                                              const enum wsp_ggml_op * ops,
+                                              const int * outputs,
+                                              int num_outputs) {
+    WSP_GGML_ASSERT(count < 32);
+    if (node_idx + count > cgraph->n_nodes) {
+        return false;
+    }
+
+    int idxs[32];
+
+    for (int i = 0; i < count; ++i) {
+        idxs[i] = node_idx + i;
+    }
+
+    return wsp_ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
+}
+
 #ifdef __cplusplus
 }
 #endif
@@ -648,6 +690,13 @@ inline bool wsp_ggml_can_fuse(const struct wsp_ggml_cgraph * cgraph, int node_id
     return wsp_ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
 }
 
+inline bool wsp_ggml_can_fuse_subgraph(const struct wsp_ggml_cgraph * cgraph,
+                                       int start_idx,
+                                       std::initializer_list<enum wsp_ggml_op> ops,
+                                       std::initializer_list<int> outputs = {}) {
+    return wsp_ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
+}
+
 // expose GGUF internals for test code
 WSP_GGML_API size_t wsp_gguf_type_size(enum wsp_gguf_type type);
 WSP_GGML_API struct wsp_gguf_context * wsp_gguf_init_from_file_impl(FILE * file, struct wsp_gguf_init_params params);
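Note: a hypothetical call site for the new helper (the op sequence and output index below are illustrative, not taken from this diff). A backend asks whether nodes [i, i + 2) form a MUL_MAT -> ADD chain whose intermediate result never escapes the subgraph, in which case the pair can be emitted as one fused kernel:

    // ops the subgraph must match, in graph order
    const enum wsp_ggml_op ops[2] = { WSP_GGML_OP_MUL_MAT, WSP_GGML_OP_ADD };
    // only the ADD result (node i + 1) may be used outside the subgraph
    const int outputs[1] = { i + 1 };

    if (wsp_ggml_can_fuse_subgraph(cgraph, i, 2, ops, outputs, 1)) {
        // safe to replace both nodes with a single fused kernel
    }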
package/cpp/ggml-metal/ggml-metal-common.cpp CHANGED
@@ -112,7 +112,7 @@ static bool wsp_ggml_mem_ranges_add_dst(wsp_ggml_mem_ranges_t mrs, const wsp_ggm
 }
 
 bool wsp_ggml_mem_ranges_add(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
-    for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
+    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
         if (tensor->src[i]) {
             wsp_ggml_mem_ranges_add_src(mrs, tensor->src[i]);
         }
@@ -173,7 +173,7 @@ static bool wsp_ggml_mem_ranges_check_dst(wsp_ggml_mem_ranges_t mrs, const wsp_g
 }
 
 bool wsp_ggml_mem_ranges_check(wsp_ggml_mem_ranges_t mrs, const wsp_ggml_tensor * tensor) {
-    for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
+    for (int i = 0; i < WSP_GGML_MAX_SRC; i++) {
         if (tensor->src[i]) {
             if (!wsp_ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
                 return false;
package/cpp/ggml-metal/ggml-metal-device.cpp CHANGED
@@ -268,6 +268,25 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu(wsp_ggml_metal
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_SUM);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_op_sum_%s", wsp_ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     WSP_GGML_ASSERT(op->src[0]->nb[0] == wsp_ggml_type_size(op->src[0]->type));
 
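Note: every new pipeline getter in this file follows the same lookup-then-compile idiom: the kernel name doubles as the cache key, so each specialization is compiled at most once and reused afterwards. A minimal standalone sketch of the pattern in plain C (the fixed-size cache and dummy handles are illustrative, not the package's implementation):

    #include <stdio.h>
    #include <string.h>

    #define CACHE_CAP 64

    typedef struct { char name[256]; int pipeline; } entry_t;
    static entry_t cache[CACHE_CAP];
    static int     n_cached = 0;

    static int compile_pipeline(const char * name) {
        printf("compiling %s\n", name); // expensive step, runs once per key
        return 1000 + n_cached;         // dummy pipeline handle
    }

    static int get_pipeline(const char * name) {
        for (int i = 0; i < n_cached; ++i) {
            if (strcmp(cache[i].name, name) == 0) {
                return cache[i].pipeline; // cache hit: reuse compiled pipeline
            }
        }
        if (n_cached == CACHE_CAP) {
            return -1; // illustrative bound; the real cache grows dynamically
        }
        snprintf(cache[n_cached].name, 256, "%s", name);
        cache[n_cached].pipeline = compile_pipeline(name);
        return cache[n_cached++].pipeline;
    }

    int main(void) {
        get_pipeline("kernel_op_sum_f32"); // compiles
        get_pipeline("kernel_op_sum_f32"); // hit, no recompile
        return 0;
    }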
@@ -338,7 +357,13 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_ssm_conv_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
+    const char * suffix = "";
+
+    if (op->src[1]->ne[0] % 4 == 0) {
+        suffix = "_4";
+    }
+
+    snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type), suffix);
     snprintf(name, 256, "%s", base);
 
     wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
@@ -352,15 +377,15 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv(wsp_ggml_
 }
 
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    WSP_GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+
     char base[256];
     char name[256];
 
-    if (op->src[3]->ne[0] == 1) {
-        snprintf(base, 256, "kernel_ssm_scan_group_%s", wsp_ggml_type_name(op->src[0]->type));
-    } else {
-        snprintf(base, 256, "kernel_ssm_scan_%s", wsp_ggml_type_name(op->src[0]->type));
-    }
-    snprintf(name, 256, "%s", base);
+    const int nsg = (ne00 + 31)/32;
+
+    snprintf(base, 256, "kernel_ssm_scan_%s", wsp_ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
     wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
     if (res) {
@@ -369,7 +394,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan(wsp_ggml_
 
     res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 
-    wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
+    wsp_ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
 
     return res;
 }
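Note: nsg is the number of 32-thread simdgroups needed to cover the ne00 dimension of src[0], and the threadgroup (shared) memory is sized so each simdgroup gets its own 32-float reduction scratchpad. A quick check of the arithmetic (ne00 = 96 is an illustrative value):

    #include <stdio.h>

    int main(void) {
        const int ne00 = 96;                        // first dimension of src[0]
        const int nsg  = (ne00 + 31)/32;            // ceil(ne00/32) = 3 simdgroups
        const unsigned smem = 32*sizeof(float)*nsg; // 384 bytes of threadgroup memory
        printf("nsg=%d smem=%u\n", nsg, smem);
        return 0;
    }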
@@ -918,6 +943,96 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_m
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+        wsp_ggml_metal_library_t lib,
+        const struct wsp_ggml_tensor * op,
+        bool has_mask,
+        int32_t ncpsg) {
+    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
+    WSP_GGML_UNUSED(op);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_%s",
+            "flash_attn_ext_pad");
+
+    snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
+            base,
+            has_mask,
+            ncpsg);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
+
+    wsp_ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
+    //wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
+    //wsp_ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_PAD + 2);
+    //wsp_ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_PAD + 3);
+
+    //wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
+    //wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
+    //wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
+    //wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
+    //wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
+    wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    wsp_ggml_metal_cv_free(cv);
+
+    return res;
+}
+
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
+        wsp_ggml_metal_library_t lib,
+        const struct wsp_ggml_tensor * op,
+        int32_t nqptg,
+        int32_t ncpsg) {
+    assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
+    WSP_GGML_UNUSED(op);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_%s",
+            "flash_attn_ext_blk");
+
+    snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
+            base,
+            nqptg,
+            ncpsg);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
+
+    //wsp_ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_BLK + 0);
+    //wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
+    //wsp_ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_BLK + 2);
+    //wsp_ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_BLK + 3);
+
+    //wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
+    //wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
+    //wsp_ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
+    //wsp_ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
+    wsp_ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
+    wsp_ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    wsp_ggml_metal_cv_free(cv);
+
+    return res;
+}
+
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
         wsp_ggml_metal_library_t lib,
         const wsp_ggml_tensor * op,
@@ -925,6 +1040,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg) {
     assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
 
@@ -937,18 +1053,23 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
     const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
     const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
 
+    // do bounds checks for the mask?
+    const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
+
     snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
             "flash_attn_ext",
             wsp_ggml_type_name(op->src[1]->type),
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
+    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
             base,
            has_mask,
            has_sinks,
            has_bias,
            has_scap,
+           has_kvpad,
+           bc_mask,
            ns10,
            ns20,
            nsg);
@@ -964,6 +1085,9 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
     wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
     wsp_ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT + 2);
     wsp_ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT + 3);
+    wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
+
+    wsp_ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
 
     wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
     wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
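Note: each feature flag is handled twice: it extends the cached pipeline name, and it is baked into the kernel as a Metal function constant (the FC_FLASH_ATTN_EXT + n indices). Every distinct combination of flags and strides therefore compiles its own pipeline variant, found again later by name. A sketch of how such a cache key is assembled (the flag values are illustrative):

    #include <stdio.h>

    int main(void) {
        char name[256];
        const int has_mask = 1, has_sinks = 0, has_bias = 0, has_scap = 0, has_kvpad = 1, bc_mask = 0;
        const int ns10 = 1, ns20 = 1, nsg = 4;
        snprintf(name, sizeof(name),
                 "kernel_flash_attn_ext_f16_dk64_dv64_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
                 has_mask, has_sinks, has_bias, has_scap, has_kvpad, bc_mask, ns10, ns20, nsg);
        printf("%s\n", name); // one compiled pipeline per distinct key
        return 0;
    }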
@@ -983,6 +1107,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg,
         int32_t nwg) {
     assert(op->op == WSP_GGML_OP_FLASH_ATTN_EXT);
@@ -1002,12 +1127,13 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
+    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
             base,
            has_mask,
            has_sinks,
            has_bias,
            has_scap,
+           has_kvpad,
            ns10,
            ns20,
            nsg, nwg);
@@ -1023,6 +1149,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
     wsp_ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
     wsp_ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2);
     wsp_ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3);
+    wsp_ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
 
     wsp_ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
     wsp_ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
@@ -1279,6 +1406,31 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d(
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_CONV_TRANSPOSE_2D);
+
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
+    WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[1]));
+    WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
+    WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
     assert(op->op == WSP_GGML_OP_UPSCALE);
 
@@ -1374,3 +1526,40 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding
     return res;
 }
 
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_OPT_STEP_ADAMW);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_opt_step_adamw_%s", wsp_ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
+    assert(op->op == WSP_GGML_OP_OPT_STEP_SGD);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_opt_step_sgd_%s", wsp_ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
package/cpp/ggml-metal/ggml-metal-device.h CHANGED
@@ -109,6 +109,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_set_rows
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_repeat            (wsp_ggml_metal_library_t lib, enum wsp_ggml_type tsrc);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary             (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu               (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum               (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows          (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max          (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv          (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
@@ -129,11 +130,26 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_norm
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope              (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_im2col            (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale           (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad               (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad_reflect_1d    (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_arange            (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_timestep_embedding(wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_adamw    (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_opt_step_sgd      (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
+
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+        wsp_ggml_metal_library_t lib,
+        const struct wsp_ggml_tensor * op,
+        bool has_mask,
+        int32_t ncpsg);
+
+wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_blk(
+        wsp_ggml_metal_library_t lib,
+        const struct wsp_ggml_tensor * op,
+        int32_t nqptg,
+        int32_t ncpsg);
 
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
         wsp_ggml_metal_library_t lib,
@@ -142,6 +158,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg);
 
 wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +168,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_vec
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg,
         int32_t nwg);
 
package/cpp/ggml-metal/ggml-metal-device.m CHANGED
@@ -7,6 +7,8 @@
 
 #include <Metal/Metal.h>
 
+#include <stdatomic.h>
+
 #ifndef TARGET_OS_VISION
 #define TARGET_OS_VISION 0
 #endif
@@ -22,6 +24,9 @@
 // overload of MTLGPUFamilyMetal3 (not available in some environments)
 static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
 
+// virtual address for GPU memory allocations
+static atomic_uintptr_t g_addr_device = 0x000000400ULL;
+
 #if !WSP_GGML_METAL_EMBED_LIBRARY
 // Here to assist with NSBundle Path Hack
 @interface WSPGGMLMetalClass : NSObject
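Note: private (non-shared) Metal buffers have no host pointer, but ggml still needs a stable, unique address per buffer so tensors can be addressed by offset. The g_addr_device counter hands out fake virtual addresses by atomically bumping itself by each buffer's size (see the wsp_ggml_metal_buffer_init hunk below), so ranges never overlap even across threads. A standalone sketch of the idiom (sizes are illustrative):

    #include <stdatomic.h>
    #include <stdint.h>
    #include <stdio.h>

    static atomic_uintptr_t g_addr = 0x400ULL; // monotonically increasing fake base address

    static void * reserve_range(size_t size) {
        // each caller gets a disjoint [base, base + size) range
        return (void *) atomic_fetch_add_explicit(&g_addr, size, memory_order_relaxed);
    }

    int main(void) {
        void * a = reserve_range(1024);
        void * b = reserve_range(4096);
        printf("%p %p\n", a, b); // 0x400, 0x800: never dereferenced, only used as keys
        return 0;
    }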
@@ -652,6 +657,11 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
     case WSP_GGML_OP_SCALE:
     case WSP_GGML_OP_CONV_TRANSPOSE_1D:
         return true;
+    case WSP_GGML_OP_CONV_TRANSPOSE_2D:
+        return wsp_ggml_is_contiguous(op->src[0]) && wsp_ggml_is_contiguous(op->src[1]) &&
+               (op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32) &&
+               op->src[1]->type == WSP_GGML_TYPE_F32 &&
+               op->type == WSP_GGML_TYPE_F32;
     case WSP_GGML_OP_CLAMP:
         return op->src[0]->type == WSP_GGML_TYPE_F32;
     case WSP_GGML_OP_SQR:
@@ -660,6 +670,8 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
     case WSP_GGML_OP_COS:
     case WSP_GGML_OP_LOG:
         return wsp_ggml_is_contiguous(op->src[0]) && op->src[0]->type == WSP_GGML_TYPE_F32;
+    case WSP_GGML_OP_SUM:
+        return has_simdgroup_reduction && wsp_ggml_is_contiguous(op->src[0]);
     case WSP_GGML_OP_SUM_ROWS:
     case WSP_GGML_OP_MEAN:
     case WSP_GGML_OP_SOFT_MAX:
@@ -696,7 +708,8 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
         return true;
     case WSP_GGML_OP_FLASH_ATTN_EXT:
         // for new head sizes, add checks here
-        if (op->src[0]->ne[0] != 40 &&
+        if (op->src[0]->ne[0] != 32 &&
+            op->src[0]->ne[0] != 40 &&
             op->src[0]->ne[0] != 64 &&
             op->src[0]->ne[0] != 80 &&
             op->src[0]->ne[0] != 96 &&
@@ -780,9 +793,7 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
         };
     }
     case WSP_GGML_OP_GET_ROWS:
-    {
-        return op->ne[3] == 1;
-    }
+        return true;
     case WSP_GGML_OP_SET_ROWS:
     {
         if (op->src[0]->type != WSP_GGML_TYPE_F32) {
@@ -804,6 +815,9 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
             return false;
         };
     }
+    case WSP_GGML_OP_OPT_STEP_ADAMW:
+    case WSP_GGML_OP_OPT_STEP_SGD:
+        return has_simdgroup_reduction;
     default:
         return false;
     }
@@ -828,7 +842,7 @@ struct wsp_ggml_metal_buffer_wrapper {
 };
 
 struct wsp_ggml_metal_buffer {
-    void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985
+    void * all_data;
     size_t all_size;
 
     // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
@@ -966,14 +980,15 @@ wsp_ggml_metal_buffer_t wsp_ggml_metal_buffer_init(wsp_ggml_metal_device_t dev,
     if (shared) {
         res->all_data = wsp_ggml_metal_host_malloc(size_aligned);
         res->is_shared = true;
-        res->owned = true;
     } else {
-        // dummy, non-NULL value - we'll populate this after creating the Metal buffer below
-        res->all_data = (void *) 0x000000400ULL;
+        // use virtual address from g_addr_device counter
+        res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
         res->is_shared = false;
     }
     res->all_size = size_aligned;
 
+    res->owned = true;
+
     res->device = wsp_ggml_metal_device_get_obj(dev);
     res->queue = wsp_ggml_metal_device_get_queue(dev);
 
@@ -984,15 +999,13 @@ wsp_ggml_metal_buffer_t wsp_ggml_metal_buffer_init(wsp_ggml_metal_device_t dev,
     res->buffers[0].metal = nil;
 
     if (size_aligned > 0) {
-        if (props_dev->use_shared_buffers &&shared) {
+        if (props_dev->use_shared_buffers && shared) {
             res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
                                                 length:size_aligned
                                                 options:MTLResourceStorageModeShared
                                                 deallocator:nil];
         } else {
             res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
-
-            res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
         }
     }
 
@@ -1140,7 +1153,7 @@ bool wsp_ggml_metal_buffer_is_shared(wsp_ggml_metal_buffer_t buf) {
 
 void wsp_ggml_metal_buffer_memset_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
     if (buf->is_shared) {
-        memset((char *)tensor->data + offset, value, size);
+        memset((char *) tensor->data + offset, value, size);
         return;
     }
 
@@ -1169,7 +1182,7 @@ void wsp_ggml_metal_buffer_memset_tensor(wsp_ggml_metal_buffer_t buf, struct wsp
 
 void wsp_ggml_metal_buffer_set_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     if (buf->is_shared) {
-        memcpy((char *)tensor->data + offset, data, size);
+        memcpy((char *) tensor->data + offset, data, size);
        return;
     }
 
@@ -1224,7 +1237,7 @@ void wsp_ggml_metal_buffer_set_tensor(wsp_ggml_metal_buffer_t buf, struct wsp_gg
 
 void wsp_ggml_metal_buffer_get_tensor(wsp_ggml_metal_buffer_t buf, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     if (buf->is_shared) {
-        memcpy(data, (const char *)tensor->data + offset, size);
+        memcpy(data, (const char *) tensor->data + offset, size);
         return;
     }
 