whisper.rn 0.5.2 → 0.5.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. package/README.md +1 -1
  2. package/cpp/ggml-alloc.c +11 -4
  3. package/cpp/ggml-backend-reg.cpp +8 -0
  4. package/cpp/ggml-backend.cpp +0 -2
  5. package/cpp/ggml-cpu/arch/arm/quants.c +428 -26
  6. package/cpp/ggml-cpu/ggml-cpu-impl.h +3 -1
  7. package/cpp/ggml-cpu/ggml-cpu.c +50 -21
  8. package/cpp/ggml-cpu/ops.cpp +458 -349
  9. package/cpp/ggml-cpu/ops.h +4 -4
  10. package/cpp/ggml-cpu/repack.cpp +143 -29
  11. package/cpp/ggml-cpu/simd-mappings.h +25 -25
  12. package/cpp/ggml-cpu/unary-ops.cpp +16 -0
  13. package/cpp/ggml-cpu/unary-ops.h +2 -0
  14. package/cpp/ggml-cpu/vec.cpp +17 -0
  15. package/cpp/ggml-cpu/vec.h +10 -0
  16. package/cpp/ggml-impl.h +17 -1
  17. package/cpp/ggml-metal/ggml-metal-context.m +5 -6
  18. package/cpp/ggml-metal/ggml-metal-device.cpp +101 -4
  19. package/cpp/ggml-metal/ggml-metal-device.h +8 -1
  20. package/cpp/ggml-metal/ggml-metal-device.m +216 -14
  21. package/cpp/ggml-metal/ggml-metal-impl.h +90 -2
  22. package/cpp/ggml-metal/ggml-metal-ops.cpp +346 -85
  23. package/cpp/ggml-metal/ggml-metal-ops.h +2 -0
  24. package/cpp/ggml-metal/ggml-metal.cpp +5 -0
  25. package/cpp/ggml-metal/ggml-metal.metal +12436 -0
  26. package/cpp/ggml.c +154 -5
  27. package/cpp/ggml.h +73 -0
  28. package/cpp/whisper.cpp +6 -2
  29. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  30. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
  31. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  32. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  33. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  34. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  35. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
  36. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  37. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  38. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  39. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  40. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  41. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +73 -0
  42. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  43. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-metal.metal +12436 -0
  44. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  45. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +17 -1
  46. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +73 -0
  47. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  48. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  49. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-metal.metal +12436 -0
  50. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  51. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +156 -12
  52. package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  53. package/lib/module/realtime-transcription/RealtimeTranscriber.js +155 -12
  54. package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
  55. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts +29 -0
  56. package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
  57. package/lib/typescript/realtime-transcription/types.d.ts +7 -0
  58. package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
  59. package/package.json +1 -1
  60. package/src/realtime-transcription/RealtimeTranscriber.ts +179 -9
  61. package/src/realtime-transcription/types.ts +9 -0
  62. package/whisper-rn.podspec +1 -1
  63. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  64. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  65. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  67. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  68. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
@@ -318,6 +318,44 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows(wsp_ggml_
318
318
  return res;
319
319
  }
320
320
 
321
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
322
+ WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
323
+
324
+ char base[256];
325
+ char name[256];
326
+
327
+ snprintf(base, 256, "kernel_cumsum_blk_%s", wsp_ggml_type_name(op->src[0]->type));
328
+ snprintf(name, 256, "%s", base);
329
+
330
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
331
+ if (res) {
332
+ return res;
333
+ }
334
+
335
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
336
+
337
+ return res;
338
+ }
339
+
340
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
341
+ WSP_GGML_ASSERT(op->op == WSP_GGML_OP_CUMSUM);
342
+
343
+ char base[256];
344
+ char name[256];
345
+
346
+ snprintf(base, 256, "kernel_cumsum_add_%s", wsp_ggml_type_name(op->src[0]->type));
347
+ snprintf(name, 256, "%s", base);
348
+
349
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
350
+ if (res) {
351
+ return res;
352
+ }
353
+
354
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
355
+
356
+ return res;
357
+ }
358
+
321
359
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
322
360
  WSP_GGML_ASSERT(!op->src[1] || op->src[1]->type == WSP_GGML_TYPE_F16 || op->src[1]->type == WSP_GGML_TYPE_F32);
323
361
 
@@ -677,7 +715,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id_map0(wsp
677
715
  char name[256];
678
716
 
679
717
  snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
680
- snprintf(name, 256, "%s", base);
718
+ snprintf(name, 256, "%s_ne02=%d", base, ne02);
681
719
 
682
720
  wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
683
721
  if (res) {
@@ -943,6 +981,34 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort(wsp_ggml_m
943
981
  return res;
944
982
  }
945
983
 
984
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
985
+ assert(op->op == WSP_GGML_OP_ARGSORT);
986
+
987
+ char base[256];
988
+ char name[256];
989
+
990
+ wsp_ggml_sort_order order = (wsp_ggml_sort_order) op->op_params[0];
991
+
992
+ const char * order_str = "undefined";
993
+ switch (order) {
994
+ case WSP_GGML_SORT_ORDER_ASC: order_str = "asc"; break;
995
+ case WSP_GGML_SORT_ORDER_DESC: order_str = "desc"; break;
996
+ default: WSP_GGML_ABORT("fatal error");
997
+ };
998
+
999
+ snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->type), order_str);
1000
+ snprintf(name, 256, "%s", base);
1001
+
1002
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1003
+ if (res) {
1004
+ return res;
1005
+ }
1006
+
1007
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1008
+
1009
+ return res;
1010
+ }
1011
+
946
1012
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_flash_attn_ext_pad(
947
1013
  wsp_ggml_metal_library_t lib,
948
1014
  const struct wsp_ggml_tensor * op,
@@ -1332,11 +1398,12 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_meta
1332
1398
 
1333
1399
  const bool is_neox = mode & WSP_GGML_ROPE_TYPE_NEOX;
1334
1400
  const bool is_mrope = mode & WSP_GGML_ROPE_TYPE_MROPE;
1401
+ const bool is_imrope = mode == WSP_GGML_ROPE_TYPE_IMROPE;
1335
1402
  const bool is_vision = mode == WSP_GGML_ROPE_TYPE_VISION;
1336
1403
 
1337
1404
  if (is_neox) {
1338
1405
  snprintf(base, 256, "kernel_rope_neox_%s", wsp_ggml_type_name(op->src[0]->type));
1339
- } else if (is_mrope && !is_vision) {
1406
+ } else if ((is_mrope || is_imrope) && !is_vision) {
1340
1407
  WSP_GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
1341
1408
  snprintf(base, 256, "kernel_rope_multi_%s", wsp_ggml_type_name(op->src[0]->type));
1342
1409
  } else if (is_vision) {
@@ -1346,14 +1413,20 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope(wsp_ggml_meta
1346
1413
  snprintf(base, 256, "kernel_rope_norm_%s", wsp_ggml_type_name(op->src[0]->type));
1347
1414
  }
1348
1415
 
1349
- snprintf(name, 256, "%s", base);
1416
+ snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
1350
1417
 
1351
1418
  wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1352
1419
  if (res) {
1353
1420
  return res;
1354
1421
  }
1355
1422
 
1356
- res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1423
+ wsp_ggml_metal_cv_t cv = wsp_ggml_metal_cv_init();
1424
+
1425
+ wsp_ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1426
+
1427
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, cv);
1428
+
1429
+ wsp_ggml_metal_cv_free(cv);
1357
1430
 
1358
1431
  return res;
1359
1432
  }
@@ -1431,6 +1504,30 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d(
1431
1504
  return res;
1432
1505
  }
1433
1506
 
1507
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1508
+ assert(op->op == WSP_GGML_OP_CONV_2D);
1509
+
1510
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(op->src[0]));
1511
+ WSP_GGML_ASSERT(op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
1512
+ WSP_GGML_ASSERT(op->src[1]->type == WSP_GGML_TYPE_F32);
1513
+ WSP_GGML_ASSERT(op->type == WSP_GGML_TYPE_F32);
1514
+
1515
+ char base[256];
1516
+ char name[256];
1517
+
1518
+ snprintf(base, 256, "kernel_conv_2d_%s_%s", wsp_ggml_type_name(op->src[0]->type), wsp_ggml_type_name(op->src[1]->type));
1519
+ snprintf(name, 256, "%s", base);
1520
+
1521
+ wsp_ggml_metal_pipeline_t res = wsp_ggml_metal_library_get_pipeline(lib, name);
1522
+ if (res) {
1523
+ return res;
1524
+ }
1525
+
1526
+ res = wsp_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1527
+
1528
+ return res;
1529
+ }
1530
+
1434
1531
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale(wsp_ggml_metal_library_t lib, const wsp_ggml_tensor * op) {
1435
1532
  assert(op->op == WSP_GGML_OP_UPSCALE);
1436
1533
 
@@ -95,7 +95,9 @@ void wsp_ggml_metal_encoder_end_encoding(wsp_ggml_metal_encoder_t encoder);
95
95
 
96
96
  typedef struct wsp_ggml_metal_library * wsp_ggml_metal_library_t;
97
97
 
98
- wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev);
98
+ wsp_ggml_metal_library_t wsp_ggml_metal_library_init (wsp_ggml_metal_device_t dev);
99
+ wsp_ggml_metal_library_t wsp_ggml_metal_library_init_from_source(wsp_ggml_metal_device_t dev, const char * source, bool verbose);
100
+
99
101
  void wsp_ggml_metal_library_free(wsp_ggml_metal_library_t lib);
100
102
 
101
103
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline (wsp_ggml_metal_library_t lib, const char * name);
@@ -111,6 +113,8 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_unary
111
113
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_glu (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
112
114
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
113
115
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_sum_rows (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
116
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_blk (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
117
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_cumsum_add (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
114
118
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_soft_max (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
115
119
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_conv (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
116
120
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_ssm_scan (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
@@ -123,6 +127,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mm_id
123
127
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_mul_mv_id (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
124
128
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argmax (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
125
129
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
130
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_argsort_merge (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
126
131
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_bin (wsp_ggml_metal_library_t lib, enum wsp_ggml_op op, int32_t n_fuse, bool row);
127
132
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_l2_norm (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
128
133
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_group_norm (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
@@ -131,6 +136,7 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_rope
131
136
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_im2col (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
132
137
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_1d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
133
138
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_transpose_2d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
139
+ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_conv_2d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
134
140
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_upscale (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
135
141
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
136
142
  wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_get_pipeline_pad_reflect_1d (wsp_ggml_metal_library_t lib, const struct wsp_ggml_tensor * op);
@@ -193,6 +199,7 @@ struct wsp_ggml_metal_device_props {
193
199
  bool has_simdgroup_mm;
194
200
  bool has_unified_memory;
195
201
  bool has_bfloat;
202
+ bool has_tensor;
196
203
  bool use_residency_sets;
197
204
  bool use_shared_buffers;
198
205
 
@@ -21,8 +21,9 @@
21
21
  #define WSP_GGML_METAL_HAS_RESIDENCY_SETS 1
22
22
  #endif
23
23
 
24
- // overload of MTLGPUFamilyMetal3 (not available in some environments)
24
+ // overload of MTLGPUFamilyMetalX (not available in some environments)
25
25
  static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
26
+ static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;
26
27
 
27
28
  // virtual address for GPU memory allocations
28
29
  static atomic_uintptr_t g_addr_device = 0x000000400ULL;
@@ -180,11 +181,7 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
180
181
  NSBundle * bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]];
181
182
  #endif
182
183
 
183
- #if TARGET_OS_SIMULATOR
184
- NSString * path_lib = [bundle pathForResource:@"ggml-whisper-sim" ofType:@"metallib"];
185
- #else
186
- NSString * path_lib = [bundle pathForResource:@"ggml-whisper" ofType:@"metallib"];
187
- #endif
184
+ NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
188
185
  if (path_lib == nil) {
189
186
  // Try to find the resource in the directory where the current binary located.
190
187
  NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
@@ -265,6 +262,10 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
265
262
  [prep setObject:@"1" forKey:@"WSP_GGML_METAL_HAS_BF16"];
266
263
  }
267
264
 
265
+ if (wsp_ggml_metal_device_get_props(dev)->has_tensor) {
266
+ [prep setObject:@"1" forKey:@"WSP_GGML_METAL_HAS_TENSOR"];
267
+ }
268
+
268
269
  #if WSP_GGML_METAL_EMBED_LIBRARY
269
270
  [prep setObject:@"1" forKey:@"WSP_GGML_METAL_EMBED_LIBRARY"];
270
271
  #endif
@@ -302,6 +303,72 @@ wsp_ggml_metal_library_t wsp_ggml_metal_library_init(wsp_ggml_metal_device_t dev
302
303
  return res;
303
304
  }
304
305
 
306
+ wsp_ggml_metal_library_t wsp_ggml_metal_library_init_from_source(wsp_ggml_metal_device_t dev, const char * source, bool verbose) {
307
+ if (source == NULL) {
308
+ WSP_GGML_LOG_ERROR("%s: source is NULL\n", __func__);
309
+ return NULL;
310
+ }
311
+
312
+ id<MTLDevice> device = wsp_ggml_metal_device_get_obj(dev);
313
+ id<MTLLibrary> library = nil;
314
+ NSError * error = nil;
315
+
316
+ const int64_t t_start = wsp_ggml_time_us();
317
+
318
+ NSString * src = [[NSString alloc] initWithBytes:source
319
+ length:strlen(source)
320
+ encoding:NSUTF8StringEncoding];
321
+ if (!src) {
322
+ WSP_GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
323
+ return NULL;
324
+ }
325
+
326
+ @autoreleasepool {
327
+ NSMutableDictionary * prep = [NSMutableDictionary dictionary];
328
+
329
+ MTLCompileOptions * options = [MTLCompileOptions new];
330
+ options.preprocessorMacros = prep;
331
+
332
+ library = [device newLibraryWithSource:src options:options error:&error];
333
+ if (error) {
334
+ if (verbose) {
335
+ WSP_GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
336
+ } else {
337
+ WSP_GGML_LOG_ERROR("%s: error compiling source\n", __func__);
338
+ }
339
+ library = nil;
340
+ }
341
+
342
+ [options release];
343
+ }
344
+
345
+ [src release];
346
+
347
+ if (!library) {
348
+ if (verbose) {
349
+ WSP_GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
350
+ }
351
+
352
+ return NULL;
353
+ }
354
+
355
+ if (verbose) {
356
+ WSP_GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (wsp_ggml_time_us() - t_start) / 1e6);
357
+ }
358
+
359
+ wsp_ggml_metal_library_t res = calloc(1, sizeof(struct wsp_ggml_metal_library));
360
+ if (!res) {
361
+ WSP_GGML_LOG_ERROR("%s: calloc failed\n", __func__);
362
+ return NULL;
363
+ }
364
+
365
+ res->obj = library;
366
+ res->device = device;
367
+ res->pipelines = wsp_ggml_metal_pipelines_init();
368
+
369
+ return res;
370
+ }
371
+
305
372
  void wsp_ggml_metal_library_free(wsp_ggml_metal_library_t lib) {
306
373
  if (!lib) {
307
374
  return;
@@ -349,9 +416,9 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_compile_pipeline(wsp_ggml_metal
349
416
  if (!mtl_function) {
350
417
  wsp_ggml_critical_section_end();
351
418
 
352
- WSP_GGML_LOG_ERROR("%s: error: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
419
+ WSP_GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
353
420
  if (error) {
354
- WSP_GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
421
+ WSP_GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
355
422
  }
356
423
 
357
424
  return nil;
@@ -359,13 +426,21 @@ wsp_ggml_metal_pipeline_t wsp_ggml_metal_library_compile_pipeline(wsp_ggml_metal
359
426
 
360
427
  res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
361
428
 
362
- wsp_ggml_metal_pipelines_add(lib->pipelines, name, res);
363
-
364
429
  [mtl_function release];
365
430
 
366
431
  WSP_GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
367
432
  (int) res->obj.maxTotalThreadsPerThreadgroup,
368
433
  (int) res->obj.threadExecutionWidth);
434
+
435
+ if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
436
+ wsp_ggml_critical_section_end();
437
+
438
+ WSP_GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
439
+
440
+ return nil;
441
+ }
442
+
443
+ wsp_ggml_metal_pipelines_add(lib->pipelines, name, res);
369
444
  }
370
445
 
371
446
  wsp_ggml_critical_section_end();
@@ -473,6 +548,128 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
473
548
 
474
549
  dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
475
550
  dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
551
+ if (getenv("WSP_GGML_METAL_BF16_DISABLE") != NULL) {
552
+ dev->props.has_bfloat = false;
553
+ }
554
+
555
+ dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
556
+ if (getenv("WSP_GGML_METAL_TENSOR_DISABLE") != NULL) {
557
+ dev->props.has_tensor = false;
558
+ }
559
+
560
+ // note: disable the tensor API by default for old chips because with the current implementation it is not useful
561
+ // - M2 Ultra: ~5% slower
562
+ // - M4, M4 Max: no significant difference
563
+ //
564
+ // TODO: try to update the tensor API kernels to at least match the simdgroup performance
565
+ if (getenv("WSP_GGML_METAL_TENSOR_ENABLE") == NULL &&
566
+ ![[dev->mtl_device name] containsString:@"M5"] &&
567
+ ![[dev->mtl_device name] containsString:@"M6"] &&
568
+ ![[dev->mtl_device name] containsString:@"A19"] &&
569
+ ![[dev->mtl_device name] containsString:@"A20"]) {
570
+ WSP_GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__);
571
+ dev->props.has_tensor = false;
572
+ }
573
+
574
+ // double-check that the tensor API compiles
575
+ if (dev->props.has_tensor) {
576
+ const char * src_tensor_f16 = "\n"
577
+ "#include <metal_stdlib> \n"
578
+ "#include <metal_tensor> \n"
579
+ "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
580
+ " \n"
581
+ "using namespace metal; \n"
582
+ "using namespace mpp::tensor_ops; \n"
583
+ " \n"
584
+ "kernel void dummy_kernel( \n"
585
+ " tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n"
586
+ " tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n"
587
+ " device float * C [[buffer(2)]], \n"
588
+ " uint2 tgid [[threadgroup_position_in_grid]]) \n"
589
+ "{ \n"
590
+ " auto tA = A.slice(0, (int)tgid.y); \n"
591
+ " auto tB = B.slice((int)tgid.x, 0); \n"
592
+ " \n"
593
+ " matmul2d< \n"
594
+ " matmul2d_descriptor(8, 8, dynamic_extent), \n"
595
+ " execution_simdgroups<4>> mm; \n"
596
+ " \n"
597
+ " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
598
+ " \n"
599
+ " auto sA = tA.slice(0, 0); \n"
600
+ " auto sB = tB.slice(0, 0); \n"
601
+ " mm.run(sB, sA, cT); \n"
602
+ " \n"
603
+ " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
604
+ " \n"
605
+ " cT.store(tC); \n"
606
+ "}";
607
+
608
+ WSP_GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__);
609
+ wsp_ggml_metal_library_t lib = wsp_ggml_metal_library_init_from_source(dev, src_tensor_f16, false);
610
+ if (lib == NULL) {
611
+ WSP_GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
612
+ dev->props.has_tensor = false;
613
+ } else {
614
+ wsp_ggml_metal_pipeline_t ppl = wsp_ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
615
+ if (!ppl) {
616
+ WSP_GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
617
+ dev->props.has_tensor = false;
618
+ }
619
+
620
+ wsp_ggml_metal_library_free(lib);
621
+ }
622
+ }
623
+
624
+ // try to compile a dummy kernel to determine if the tensor API is supported for bfloat
625
+ if (dev->props.has_tensor && dev->props.has_bfloat) {
626
+ const char * src_tensor_bf16 = "\n"
627
+ "#include <metal_stdlib> \n"
628
+ "#include <metal_tensor> \n"
629
+ "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
630
+ " \n"
631
+ "using namespace metal; \n"
632
+ "using namespace mpp::tensor_ops; \n"
633
+ " \n"
634
+ "kernel void dummy_kernel( \n"
635
+ " tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
636
+ " tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
637
+ " device float * C [[buffer(2)]], \n"
638
+ " uint2 tgid [[threadgroup_position_in_grid]]) \n"
639
+ "{ \n"
640
+ " auto tA = A.slice(0, (int)tgid.y); \n"
641
+ " auto tB = B.slice((int)tgid.x, 0); \n"
642
+ " \n"
643
+ " matmul2d< \n"
644
+ " matmul2d_descriptor(8, 8, dynamic_extent), \n"
645
+ " execution_simdgroups<4>> mm; \n"
646
+ " \n"
647
+ " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
648
+ " \n"
649
+ " auto sA = tA.slice(0, 0); \n"
650
+ " auto sB = tB.slice(0, 0); \n"
651
+ " mm.run(sB, sA, cT); \n"
652
+ " \n"
653
+ " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
654
+ " \n"
655
+ " cT.store(tC); \n"
656
+ "}";
657
+
658
+ WSP_GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
659
+ wsp_ggml_metal_library_t lib = wsp_ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);
660
+ if (lib == NULL) {
661
+ WSP_GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
662
+ dev->props.has_bfloat = false;
663
+ } else {
664
+ wsp_ggml_metal_pipeline_t ppl = wsp_ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
665
+ if (!ppl) {
666
+ WSP_GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
667
+ dev->props.has_bfloat = false;
668
+ }
669
+
670
+ wsp_ggml_metal_library_free(lib);
671
+ }
672
+ }
476
673
 
477
674
  dev->props.use_residency_sets = true;
478
675
  #if defined(WSP_GGML_METAL_HAS_RESIDENCY_SETS)
@@ -480,7 +677,6 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
480
677
  #endif
481
678
 
482
679
  dev->props.use_shared_buffers = dev->props.has_unified_memory;
483
-
484
680
  if (getenv("WSP_GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) {
485
681
  dev->props.use_shared_buffers = false;
486
682
  }
@@ -533,6 +729,7 @@ wsp_ggml_metal_device_t wsp_ggml_metal_device_init(void) {
533
729
  WSP_GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false");
534
730
  WSP_GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false");
535
731
  WSP_GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false");
732
+ WSP_GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false");
536
733
  WSP_GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false");
537
734
  WSP_GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false");
538
735
 
@@ -673,6 +870,7 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
673
870
  case WSP_GGML_OP_SUM:
674
871
  return has_simdgroup_reduction && wsp_ggml_is_contiguous(op->src[0]);
675
872
  case WSP_GGML_OP_SUM_ROWS:
873
+ case WSP_GGML_OP_CUMSUM:
676
874
  case WSP_GGML_OP_MEAN:
677
875
  case WSP_GGML_OP_SOFT_MAX:
678
876
  case WSP_GGML_OP_GROUP_NORM:
@@ -688,6 +886,11 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
688
886
  return true;
689
887
  case WSP_GGML_OP_IM2COL:
690
888
  return wsp_ggml_is_contiguous(op->src[1]) && op->src[1]->type == WSP_GGML_TYPE_F32 && (op->type == WSP_GGML_TYPE_F16 || op->type == WSP_GGML_TYPE_F32);
889
+ case WSP_GGML_OP_CONV_2D:
890
+ return wsp_ggml_is_contiguous(op->src[0]) &&
891
+ op->src[1]->type == WSP_GGML_TYPE_F32 &&
892
+ op->type == WSP_GGML_TYPE_F32 &&
893
+ (op->src[0]->type == WSP_GGML_TYPE_F16 || op->src[0]->type == WSP_GGML_TYPE_F32);
691
894
  case WSP_GGML_OP_POOL_1D:
692
895
  return false;
693
896
  case WSP_GGML_OP_UPSCALE:
@@ -702,8 +905,6 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
702
905
  case WSP_GGML_OP_LEAKY_RELU:
703
906
  return op->src[0]->type == WSP_GGML_TYPE_F32;
704
907
  case WSP_GGML_OP_ARGSORT:
705
- // TODO: Support arbitrary column width
706
- return op->src[0]->ne[0] <= 1024;
707
908
  case WSP_GGML_OP_ARANGE:
708
909
  return true;
709
910
  case WSP_GGML_OP_FLASH_ATTN_EXT:
@@ -711,6 +912,7 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
711
912
  if (op->src[0]->ne[0] != 32 &&
712
913
  op->src[0]->ne[0] != 40 &&
713
914
  op->src[0]->ne[0] != 64 &&
915
+ op->src[0]->ne[0] != 72 &&
714
916
  op->src[0]->ne[0] != 80 &&
715
917
  op->src[0]->ne[0] != 96 &&
716
918
  op->src[0]->ne[0] != 112 &&
@@ -787,7 +989,7 @@ bool wsp_ggml_metal_device_supports_op(wsp_ggml_metal_device_t dev, const struct
787
989
  return false;
788
990
  }
789
991
  case WSP_GGML_TYPE_I32:
790
- return op->type == WSP_GGML_TYPE_F32;
992
+ return op->type == WSP_GGML_TYPE_F32 || op->type == WSP_GGML_TYPE_I32;
791
993
  default:
792
994
  return false;
793
995
  };
@@ -76,6 +76,7 @@
76
76
  #define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
77
77
  #define FC_MUL_MV 600
78
78
  #define FC_MUL_MM 700
79
+ #define FC_ROPE 800
79
80
 
80
81
  // op-specific constants
81
82
  #define OP_FLASH_ATTN_EXT_NQPTG 8
@@ -527,6 +528,36 @@ typedef struct {
527
528
  uint64_t nb2;
528
529
  } wsp_ggml_metal_kargs_conv_transpose_2d;
529
530
 
531
+ typedef struct {
532
+ uint64_t nb00;
533
+ uint64_t nb01;
534
+ uint64_t nb02;
535
+ uint64_t nb03;
536
+ uint64_t nb10;
537
+ uint64_t nb11;
538
+ uint64_t nb12;
539
+ uint64_t nb13;
540
+ uint64_t nb0;
541
+ uint64_t nb1;
542
+ uint64_t nb2;
543
+ uint64_t nb3;
544
+ int32_t IW;
545
+ int32_t IH;
546
+ int32_t KW;
547
+ int32_t KH;
548
+ int32_t IC;
549
+ int32_t OC;
550
+ int32_t OW;
551
+ int32_t OH;
552
+ int32_t N;
553
+ int32_t s0;
554
+ int32_t s1;
555
+ int32_t p0;
556
+ int32_t p1;
557
+ int32_t d0;
558
+ int32_t d1;
559
+ } wsp_ggml_metal_kargs_conv_2d;
560
+
530
561
  typedef struct {
531
562
  uint64_t ofs0;
532
563
  uint64_t ofs1;
@@ -581,6 +612,45 @@ typedef struct {
581
612
  uint64_t nb3;
582
613
  } wsp_ggml_metal_kargs_sum_rows;
583
614
 
615
+ typedef struct {
616
+ int64_t ne00;
617
+ int64_t ne01;
618
+ int64_t ne02;
619
+ int64_t ne03;
620
+ uint64_t nb00;
621
+ uint64_t nb01;
622
+ uint64_t nb02;
623
+ uint64_t nb03;
624
+ int64_t net0;
625
+ int64_t net1;
626
+ int64_t net2;
627
+ int64_t net3;
628
+ uint64_t nbt0;
629
+ uint64_t nbt1;
630
+ uint64_t nbt2;
631
+ uint64_t nbt3;
632
+ bool outb;
633
+ } wsp_ggml_metal_kargs_cumsum_blk;
634
+
635
+ typedef struct {
636
+ int64_t ne00;
637
+ int64_t ne01;
638
+ int64_t ne02;
639
+ int64_t ne03;
640
+ uint64_t nb00;
641
+ uint64_t nb01;
642
+ uint64_t nb02;
643
+ uint64_t nb03;
644
+ int64_t net0;
645
+ int64_t net1;
646
+ int64_t net2;
647
+ int64_t net3;
648
+ uint64_t nbt0;
649
+ uint64_t nbt1;
650
+ uint64_t nbt2;
651
+ uint64_t nbt3;
652
+ } wsp_ggml_metal_kargs_cumsum_add;
653
+
584
654
  typedef struct {
585
655
  int32_t ne00;
586
656
  int32_t ne01;
@@ -762,10 +832,28 @@ typedef struct {
762
832
  } wsp_ggml_metal_kargs_leaky_relu;
763
833
 
764
834
  typedef struct {
765
- int64_t ncols;
766
- int64_t ncols_pad;
835
+ int64_t ne00;
836
+ int64_t ne01;
837
+ int64_t ne02;
838
+ int64_t ne03;
839
+ uint64_t nb00;
840
+ uint64_t nb01;
841
+ uint64_t nb02;
842
+ uint64_t nb03;
767
843
  } wsp_ggml_metal_kargs_argsort;
768
844
 
845
+ typedef struct {
846
+ int64_t ne00;
847
+ int64_t ne01;
848
+ int64_t ne02;
849
+ int64_t ne03;
850
+ uint64_t nb00;
851
+ uint64_t nb01;
852
+ uint64_t nb02;
853
+ uint64_t nb03;
854
+ int32_t len;
855
+ } wsp_ggml_metal_kargs_argsort_merge;
856
+
769
857
  typedef struct {
770
858
  int64_t ne0;
771
859
  float start;