whisper.rn 0.4.2 → 0.5.0-rc.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. package/README.md +1 -3
  2. package/android/build.gradle +70 -11
  3. package/android/src/main/CMakeLists.txt +28 -1
  4. package/android/src/main/java/com/rnwhisper/JSCallInvokerResolver.java +40 -0
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +80 -27
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +21 -9
  7. package/android/src/main/java/com/rnwhisper/WhisperVadContext.java +1 -1
  8. package/android/src/main/jni.cpp +79 -2
  9. package/android/src/main/jniLibs/arm64-v8a/librnwhisper.so +0 -0
  10. package/android/src/main/jniLibs/arm64-v8a/librnwhisper_v8fp16_va_2.so +0 -0
  11. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper.so +0 -0
  12. package/android/src/main/jniLibs/armeabi-v7a/librnwhisper_vfpv4.so +0 -0
  13. package/android/src/main/jniLibs/x86_64/librnwhisper.so +0 -0
  14. package/android/src/main/jniLibs/x86_64/librnwhisper_x86_64.so +0 -0
  15. package/android/src/newarch/java/com/rnwhisper/RNWhisperModule.java +5 -0
  16. package/android/src/oldarch/java/com/rnwhisper/RNWhisperModule.java +5 -0
  17. package/cpp/ggml-backend.cpp +36 -18
  18. package/cpp/ggml-backend.h +1 -1
  19. package/cpp/ggml-cpu/amx/mmq.cpp +10 -9
  20. package/cpp/ggml-cpu/arch/arm/quants.c +109 -108
  21. package/cpp/ggml-cpu/arch/arm/repack.cpp +13 -12
  22. package/cpp/ggml-cpu/arch/x86/quants.c +83 -82
  23. package/cpp/ggml-cpu/arch/x86/repack.cpp +20 -19
  24. package/cpp/ggml-cpu/common.h +3 -2
  25. package/cpp/ggml-cpu/ggml-cpu-impl.h +9 -3
  26. package/cpp/ggml-cpu/ggml-cpu.c +95 -17
  27. package/cpp/ggml-cpu/ggml-cpu.cpp +4 -0
  28. package/cpp/ggml-cpu/ops.cpp +775 -74
  29. package/cpp/ggml-cpu/ops.h +7 -0
  30. package/cpp/ggml-cpu/quants.c +25 -24
  31. package/cpp/ggml-cpu/repack.cpp +15 -14
  32. package/cpp/ggml-cpu/simd-mappings.h +211 -33
  33. package/cpp/ggml-cpu/vec.cpp +26 -2
  34. package/cpp/ggml-cpu/vec.h +99 -45
  35. package/cpp/ggml-cpu.h +2 -0
  36. package/cpp/ggml-impl.h +125 -183
  37. package/cpp/ggml-metal-impl.h +27 -0
  38. package/cpp/ggml-metal.m +298 -41
  39. package/cpp/ggml-quants.c +6 -6
  40. package/cpp/ggml-whisper-sim.metallib +0 -0
  41. package/cpp/ggml-whisper.metallib +0 -0
  42. package/cpp/ggml.c +269 -40
  43. package/cpp/ggml.h +122 -2
  44. package/cpp/gguf.cpp +5 -1
  45. package/cpp/jsi/RNWhisperJSI.cpp +681 -0
  46. package/cpp/jsi/RNWhisperJSI.h +44 -0
  47. package/cpp/jsi/ThreadPool.h +100 -0
  48. package/cpp/whisper.cpp +4 -0
  49. package/cpp/whisper.h +2 -0
  50. package/ios/RNWhisper.h +3 -0
  51. package/ios/RNWhisper.mm +66 -31
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
  57. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
  58. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  59. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  60. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  61. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  62. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  68. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  69. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  70. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  71. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  72. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +122 -2
  73. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +2 -0
  74. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  75. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  76. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +1 -1
  77. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +2 -0
  78. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +125 -183
  79. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +27 -0
  80. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +122 -2
  81. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +2 -0
  82. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  83. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  84. package/jest/mock.js +1 -0
  85. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  86. package/lib/commonjs/index.js +83 -2
  87. package/lib/commonjs/index.js.map +1 -1
  88. package/lib/module/NativeRNWhisper.js.map +1 -1
  89. package/lib/module/index.js +83 -2
  90. package/lib/module/index.js.map +1 -1
  91. package/lib/typescript/NativeRNWhisper.d.ts +4 -0
  92. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  93. package/lib/typescript/index.d.ts +18 -6
  94. package/lib/typescript/index.d.ts.map +1 -1
  95. package/package.json +2 -3
  96. package/src/NativeRNWhisper.ts +2 -0
  97. package/src/index.ts +162 -33
  98. package/whisper-rn.podspec +6 -3
package/cpp/ggml-metal.m CHANGED
@@ -48,22 +48,28 @@ static struct wsp_ggml_backend_metal_device_context {
48
48
  int mtl_device_ref_count;
49
49
  id<MTLLibrary> mtl_library;
50
50
 
51
+ NSLock * mtl_lock;
52
+
51
53
  bool has_simdgroup_reduction;
52
54
  bool has_simdgroup_mm;
53
55
  bool has_residency_sets;
54
56
  bool has_bfloat;
55
57
  bool use_bfloat;
56
58
 
59
+ size_t max_size;
60
+
57
61
  char name[128];
58
62
  } g_wsp_ggml_ctx_dev_main = {
59
63
  /*.mtl_device =*/ nil,
60
64
  /*.mtl_device_ref_count =*/ 0,
61
65
  /*.mtl_library =*/ nil,
66
+ /*.mtl_lock =*/ nil,
62
67
  /*.has_simdgroup_reduction =*/ false,
63
68
  /*.has_simdgroup_mm =*/ false,
64
69
  /*.has_residency_sets =*/ false,
65
70
  /*.has_bfloat =*/ false,
66
71
  /*.use_bfloat =*/ false,
72
+ /*.max_size =*/ 0,
67
73
  /*.name =*/ "",
68
74
  };
69
75
 
@@ -71,6 +77,10 @@ static struct wsp_ggml_backend_metal_device_context {
71
77
  static id<MTLDevice> wsp_ggml_backend_metal_device_acq(struct wsp_ggml_backend_metal_device_context * ctx) {
72
78
  assert(ctx != NULL);
73
79
 
80
+ if (ctx->mtl_lock == nil) {
81
+ ctx->mtl_lock = [[NSLock alloc] init];
82
+ }
83
+
74
84
  if (ctx->mtl_device == nil) {
75
85
  ctx->mtl_device = MTLCreateSystemDefaultDevice();
76
86
  }
@@ -94,6 +104,8 @@ static id<MTLDevice> wsp_ggml_backend_metal_device_acq(struct wsp_ggml_backend_m
94
104
  ctx->use_bfloat = false;
95
105
  #endif
96
106
 
107
+ ctx->max_size = ctx->mtl_device.maxBufferLength;
108
+
97
109
  strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
98
110
  }
99
111
 
@@ -110,6 +122,11 @@ static void wsp_ggml_backend_metal_device_rel(struct wsp_ggml_backend_metal_devi
110
122
  ctx->mtl_device_ref_count--;
111
123
 
112
124
  if (ctx->mtl_device_ref_count == 0) {
125
+ if (ctx->mtl_lock) {
126
+ [ctx->mtl_lock release];
127
+ ctx->mtl_lock = nil;
128
+ }
129
+
113
130
  if (ctx->mtl_library) {
114
131
  [ctx->mtl_library release];
115
132
  ctx->mtl_library = nil;
@@ -185,6 +202,15 @@ enum wsp_ggml_metal_kernel_type {
185
202
  WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
186
203
  WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
187
204
  WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205
+ WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206
+ WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207
+ WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208
+ WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209
+ WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210
+ WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211
+ WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212
+ WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213
+ WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
188
214
  WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM,
189
215
  WSP_GGML_METAL_KERNEL_TYPE_L2_NORM,
190
216
  WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -194,11 +220,14 @@ enum wsp_ggml_metal_kernel_type {
194
220
  WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
195
221
  WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
196
222
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
223
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
197
224
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
225
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
198
226
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
199
227
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
200
228
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
201
229
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
230
+ WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
202
231
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
203
232
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
204
233
  WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
@@ -497,6 +526,9 @@ enum wsp_ggml_metal_kernel_type {
497
526
  WSP_GGML_METAL_KERNEL_TYPE_SIN,
498
527
  WSP_GGML_METAL_KERNEL_TYPE_COS,
499
528
  WSP_GGML_METAL_KERNEL_TYPE_NEG,
529
+ WSP_GGML_METAL_KERNEL_TYPE_REGLU,
530
+ WSP_GGML_METAL_KERNEL_TYPE_GEGLU,
531
+ WSP_GGML_METAL_KERNEL_TYPE_SWIGLU,
500
532
  WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
501
533
  WSP_GGML_METAL_KERNEL_TYPE_MEAN,
502
534
  WSP_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -981,7 +1013,7 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
981
1013
  struct wsp_ggml_backend_metal_context * ctx = calloc(1, sizeof(struct wsp_ggml_backend_metal_context));
982
1014
  struct wsp_ggml_backend_metal_device_context * ctx_dev = dev->context;
983
1015
 
984
- id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(ctx_dev);
1016
+ id<MTLDevice> device = ctx_dev->mtl_device;
985
1017
 
986
1018
  WSP_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
987
1019
 
@@ -995,9 +1027,16 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
995
1027
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
996
1028
 
997
1029
  // load library
998
- if (ctx_dev->mtl_library == nil) {
999
- ctx_dev->mtl_library = wsp_ggml_metal_load_library(device, ctx_dev->use_bfloat);
1030
+ {
1031
+ [ctx_dev->mtl_lock lock];
1032
+
1033
+ if (ctx_dev->mtl_library == nil) {
1034
+ ctx_dev->mtl_library = wsp_ggml_metal_load_library(device, ctx_dev->use_bfloat);
1035
+ }
1036
+
1037
+ [ctx_dev->mtl_lock unlock];
1000
1038
  }
1039
+
1001
1040
  id<MTLLibrary> metal_library = ctx_dev->mtl_library;
1002
1041
  if (metal_library == nil) {
1003
1042
  WSP_GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
@@ -1146,6 +1185,15 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1146
1185
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
1147
1186
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
1148
1187
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
1188
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
1189
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
1190
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1191
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
1192
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
1193
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
1194
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
1195
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
1196
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
1149
1197
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1150
1198
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1151
1199
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1155,11 +1203,14 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1155
1203
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
1156
1204
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
1157
1205
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
1206
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
1158
1207
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
1208
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
1159
1209
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
1160
1210
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
1161
1211
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
1162
1212
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
1213
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
1163
1214
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
1164
1215
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
1165
1216
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
@@ -1458,6 +1509,9 @@ static struct wsp_ggml_backend_metal_context * wsp_ggml_metal_init(wsp_ggml_back
1458
1509
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
1459
1510
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_COS, cos, true);
1460
1511
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_NEG, neg, true);
1512
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
1513
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
1514
+ WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
1461
1515
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
1462
1516
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
1463
1517
  WSP_GGML_METAL_ADD_KERNEL(WSP_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@@ -1609,6 +1663,10 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
1609
1663
  const bool use_bfloat = ctx_dev->use_bfloat;
1610
1664
 
1611
1665
  if (!use_bfloat) {
1666
+ if (op->type == WSP_GGML_TYPE_BF16) {
1667
+ return false;
1668
+ }
1669
+
1612
1670
  for (size_t i = 0, n = 3; i < n; ++i) {
1613
1671
  if (op->src[i] != NULL && op->src[i]->type == WSP_GGML_TYPE_BF16) {
1614
1672
  return false;
@@ -1632,6 +1690,15 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
1632
1690
  default:
1633
1691
  return false;
1634
1692
  }
1693
+ case WSP_GGML_OP_GLU:
1694
+ switch (wsp_ggml_get_glu_op(op)) {
1695
+ case WSP_GGML_GLU_OP_REGLU:
1696
+ case WSP_GGML_GLU_OP_GEGLU:
1697
+ case WSP_GGML_GLU_OP_SWIGLU:
1698
+ return wsp_ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == WSP_GGML_TYPE_F32;
1699
+ default:
1700
+ return false;
1701
+ }
1635
1702
  case WSP_GGML_OP_NONE:
1636
1703
  case WSP_GGML_OP_RESHAPE:
1637
1704
  case WSP_GGML_OP_VIEW:
@@ -1778,6 +1845,27 @@ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_backend_metal_devic
1778
1845
  {
1779
1846
  return op->ne[3] == 1;
1780
1847
  }
1848
+ case WSP_GGML_OP_SET_ROWS:
1849
+ {
1850
+ if (op->src[0]->type != WSP_GGML_TYPE_F32) {
1851
+ return false;
1852
+ }
1853
+
1854
+ switch (op->type) {
1855
+ case WSP_GGML_TYPE_F32:
1856
+ case WSP_GGML_TYPE_F16:
1857
+ case WSP_GGML_TYPE_BF16:
1858
+ case WSP_GGML_TYPE_Q8_0:
1859
+ case WSP_GGML_TYPE_Q4_0:
1860
+ case WSP_GGML_TYPE_Q4_1:
1861
+ case WSP_GGML_TYPE_Q5_0:
1862
+ case WSP_GGML_TYPE_Q5_1:
1863
+ case WSP_GGML_TYPE_IQ4_NL:
1864
+ return true;
1865
+ default:
1866
+ return false;
1867
+ };
1868
+ }
1781
1869
  default:
1782
1870
  return false;
1783
1871
  }
@@ -2350,6 +2438,62 @@ static bool wsp_ggml_metal_encode_node(
2350
2438
  WSP_GGML_ABORT("fatal error");
2351
2439
  }
2352
2440
  } break;
2441
+ case WSP_GGML_OP_GLU:
2442
+ {
2443
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous_1(src0));
2444
+
2445
+ if (src1) {
2446
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1));
2447
+ }
2448
+
2449
+ id<MTLComputePipelineState> pipeline = nil;
2450
+
2451
+ switch (wsp_ggml_get_glu_op(node)) {
2452
+ case WSP_GGML_GLU_OP_REGLU:
2453
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
2454
+ break;
2455
+ case WSP_GGML_GLU_OP_GEGLU:
2456
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
2457
+ break;
2458
+ case WSP_GGML_GLU_OP_SWIGLU:
2459
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
2460
+ break;
2461
+ default:
2462
+ WSP_GGML_ABORT("fatal error");
2463
+ }
2464
+
2465
+ const int32_t swp = ((const int32_t *) dst->op_params)[1];
2466
+
2467
+ const int32_t i00 = swp ? ne0 : 0;
2468
+ const int32_t i10 = swp ? 0 : ne0;
2469
+
2470
+ wsp_ggml_metal_kargs_glu args = {
2471
+ /*.ne00 =*/ ne00,
2472
+ /*.nb01 =*/ nb01,
2473
+ /*.ne10 =*/ src1 ? ne10 : ne00,
2474
+ /*.nb11 =*/ src1 ? nb11 : nb01,
2475
+ /*.ne0 =*/ ne0,
2476
+ /*.nb1 =*/ nb1,
2477
+ /*.i00 =*/ src1 ? 0 : i00,
2478
+ /*.i10 =*/ src1 ? 0 : i10,
2479
+ };
2480
+
2481
+ [encoder setComputePipelineState:pipeline];
2482
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2483
+ if (src1) {
2484
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2485
+ } else {
2486
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2487
+ }
2488
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2489
+ [encoder setBytes:&args length:sizeof(args) atIndex:3];
2490
+
2491
+ const int64_t nrows = wsp_ggml_nrows(src0);
2492
+
2493
+ const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
2494
+
2495
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2496
+ } break;
2353
2497
  case WSP_GGML_OP_SQR:
2354
2498
  {
2355
2499
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
@@ -2430,6 +2574,7 @@ static bool wsp_ggml_metal_encode_node(
2430
2574
  nth *= 2;
2431
2575
  }
2432
2576
 
2577
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
2433
2578
  nth = MIN(nth, ne00);
2434
2579
 
2435
2580
  wsp_ggml_metal_kargs_sum_rows args = {
@@ -3090,14 +3235,23 @@ static bool wsp_ggml_metal_encode_node(
3090
3235
  nsg = 1;
3091
3236
  nr0 = 1;
3092
3237
  nr1 = 4;
3093
- pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3238
+ if (ne00 == 4) {
3239
+ nr0 = 32;
3240
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
3241
+ } else {
3242
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
3243
+ }
3094
3244
  } break;
3095
3245
  case WSP_GGML_TYPE_F16:
3096
3246
  {
3097
3247
  nsg = 1;
3098
3248
  nr0 = 1;
3099
3249
  if (src1t == WSP_GGML_TYPE_F32) {
3100
- if (ne11 * ne12 < 4) {
3250
+ if (ne00 == 4) {
3251
+ nr0 = 32;
3252
+ nr1 = 4;
3253
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
3254
+ } else if (ne11 * ne12 < 4) {
3101
3255
  pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
3102
3256
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
3103
3257
  pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
@@ -3116,7 +3270,11 @@ static bool wsp_ggml_metal_encode_node(
3116
3270
  nsg = 1;
3117
3271
  nr0 = 1;
3118
3272
  if (src1t == WSP_GGML_TYPE_F32) {
3119
- if (ne11 * ne12 < 4) {
3273
+ if (ne00 == 4) {
3274
+ nr0 = 32;
3275
+ nr1 = 4;
3276
+ pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
3277
+ } else if (ne11 * ne12 < 4) {
3120
3278
  pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
3121
3279
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
3122
3280
  pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
@@ -3737,13 +3895,74 @@ static bool wsp_ggml_metal_encode_node(
3737
3895
  };
3738
3896
 
3739
3897
  [encoder setComputePipelineState:pipeline];
3740
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3741
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3742
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3743
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
3898
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3899
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3900
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3901
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3744
3902
 
3745
3903
  [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
3746
3904
  } break;
3905
+ case WSP_GGML_OP_SET_ROWS:
3906
+ {
3907
+ id<MTLComputePipelineState> pipeline = nil;
3908
+
3909
+ switch (dst->type) {
3910
+ case WSP_GGML_TYPE_F32: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
3911
+ case WSP_GGML_TYPE_F16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
3912
+ case WSP_GGML_TYPE_BF16: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
3913
+ case WSP_GGML_TYPE_Q8_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
3914
+ case WSP_GGML_TYPE_Q4_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
3915
+ case WSP_GGML_TYPE_Q4_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
3916
+ case WSP_GGML_TYPE_Q5_0: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
3917
+ case WSP_GGML_TYPE_Q5_1: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
3918
+ case WSP_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[WSP_GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
3919
+ default: WSP_GGML_ABORT("not implemented");
3920
+ }
3921
+
3922
+ const int32_t nk0 = ne0/wsp_ggml_blck_size(dst->type);
3923
+
3924
+ int nth = 32; // SIMD width
3925
+
3926
+ while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3927
+ nth *= 2;
3928
+ }
3929
+
3930
+ int nrptg = 1;
3931
+ if (nth > nk0) {
3932
+ nrptg = (nth + nk0 - 1)/nk0;
3933
+ nth = nk0;
3934
+
3935
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
3936
+ nrptg--;
3937
+ }
3938
+ }
3939
+
3940
+ nth = MIN(nth, nk0);
3941
+
3942
+ wsp_ggml_metal_kargs_set_rows args = {
3943
+ /*.nk0 =*/ nk0,
3944
+ /*.ne01 =*/ ne01,
3945
+ /*.nb01 =*/ nb01,
3946
+ /*.nb02 =*/ nb02,
3947
+ /*.nb03 =*/ nb03,
3948
+ /*.ne11 =*/ ne11,
3949
+ /*.ne12 =*/ ne12,
3950
+ /*.nb10 =*/ nb10,
3951
+ /*.nb11 =*/ nb11,
3952
+ /*.nb12 =*/ nb12,
3953
+ /*.nb1 =*/ nb1,
3954
+ /*.nb2 =*/ nb2,
3955
+ /*.nb3 =*/ nb3,
3956
+ };
3957
+
3958
+ [encoder setComputePipelineState:pipeline];
3959
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
3960
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3961
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3962
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3963
+
3964
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
3965
+ } break;
3747
3966
  case WSP_GGML_OP_RMS_NORM:
3748
3967
  {
3749
3968
  WSP_GGML_ASSERT(ne00 % 4 == 0);
@@ -3760,6 +3979,7 @@ static bool wsp_ggml_metal_encode_node(
3760
3979
  nth *= 2;
3761
3980
  }
3762
3981
 
3982
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3763
3983
  nth = MIN(nth, ne00/4);
3764
3984
 
3765
3985
  wsp_ggml_metal_kargs_rms_norm args = {
@@ -3796,6 +4016,7 @@ static bool wsp_ggml_metal_encode_node(
3796
4016
  nth *= 2;
3797
4017
  }
3798
4018
 
4019
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3799
4020
  nth = MIN(nth, ne00/4);
3800
4021
 
3801
4022
  wsp_ggml_metal_kargs_l2_norm args = {
@@ -3868,6 +4089,7 @@ static bool wsp_ggml_metal_encode_node(
3868
4089
  nth *= 2;
3869
4090
  }
3870
4091
 
4092
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3871
4093
  nth = MIN(nth, ne00/4);
3872
4094
 
3873
4095
  wsp_ggml_metal_kargs_norm args = {
@@ -4954,8 +5176,39 @@ static bool wsp_ggml_metal_encode_node(
4954
5176
  default: WSP_GGML_ABORT("not implemented");
4955
5177
  }
4956
5178
 
5179
+ WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
5180
+
5181
+ // TODO: support
5182
+ //const int32_t nk00 = ne00/wsp_ggml_blck_size(dst->type);
5183
+ const int32_t nk00 = ne00;
5184
+
5185
+ int nth = 32; // SIMD width
5186
+
5187
+ while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
5188
+ nth *= 2;
5189
+ }
5190
+
5191
+ nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
5192
+
5193
+ // when rows are small, we can batch them together in a single threadgroup
5194
+ int nrptg = 1;
5195
+
5196
+ // TODO: relax this constraint in the future
5197
+ if (wsp_ggml_blck_size(src0->type) == 1 && wsp_ggml_blck_size(dst->type) == 1) {
5198
+ if (nth > nk00) {
5199
+ nrptg = (nth + nk00 - 1)/nk00;
5200
+ nth = nk00;
5201
+
5202
+ if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
5203
+ nrptg--;
5204
+ }
5205
+ }
5206
+ }
5207
+
5208
+ nth = MIN(nth, nk00);
5209
+
4957
5210
  wsp_ggml_metal_kargs_cpy args = {
4958
- /*.ne00 =*/ ne00,
5211
+ /*.ne00 =*/ nk00,
4959
5212
  /*.ne01 =*/ ne01,
4960
5213
  /*.ne02 =*/ ne02,
4961
5214
  /*.ne03 =*/ ne03,
@@ -4978,11 +5231,7 @@ static bool wsp_ggml_metal_encode_node(
4978
5231
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
4979
5232
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
4980
5233
 
4981
- WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
4982
- int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
4983
-
4984
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
4985
-
5234
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
4986
5235
  } break;
4987
5236
  case WSP_GGML_OP_SET:
4988
5237
  {
@@ -5288,7 +5537,6 @@ static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t
5288
5537
  }
5289
5538
 
5290
5539
  wsp_ggml_backend_metal_buffer_rset_free(ctx);
5291
- wsp_ggml_backend_metal_device_rel(buffer->buft->device->context);
5292
5540
 
5293
5541
  if (ctx->owned) {
5294
5542
  #if TARGET_OS_OSX
@@ -5397,7 +5645,10 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
5397
5645
  }
5398
5646
 
5399
5647
  struct wsp_ggml_backend_metal_device_context * ctx_dev = (struct wsp_ggml_backend_metal_device_context *)buft->device->context;
5400
- id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(ctx_dev);
5648
+
5649
+ WSP_GGML_ASSERT(ctx_dev->mtl_device != nil);
5650
+
5651
+ id<MTLDevice> device = ctx_dev->mtl_device;
5401
5652
 
5402
5653
  ctx->all_data = wsp_ggml_metal_host_malloc(size_aligned);
5403
5654
  ctx->all_size = size_aligned;
@@ -5420,14 +5671,12 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
5420
5671
  if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
5421
5672
  WSP_GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
5422
5673
  free(ctx);
5423
- wsp_ggml_backend_metal_device_rel(ctx_dev);
5424
5674
  return NULL;
5425
5675
  }
5426
5676
 
5427
5677
  if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5428
5678
  WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5429
5679
  free(ctx);
5430
- wsp_ggml_backend_metal_device_rel(ctx_dev);
5431
5680
  return NULL;
5432
5681
  }
5433
5682
 
@@ -5438,17 +5687,14 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer
5438
5687
 
5439
5688
  static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
5440
5689
  return 32;
5690
+
5441
5691
  WSP_GGML_UNUSED(buft);
5442
5692
  }
5443
5693
 
5444
5694
  static size_t wsp_ggml_backend_metal_buffer_type_get_max_size(wsp_ggml_backend_buffer_type_t buft) {
5445
- id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(buft->device->context);
5446
- const size_t max_size = device.maxBufferLength;
5447
- wsp_ggml_backend_metal_device_rel(buft->device->context);
5695
+ const size_t max_size = ((struct wsp_ggml_backend_metal_device_context *)buft->device->context)->max_size;
5448
5696
 
5449
5697
  return max_size;
5450
-
5451
- WSP_GGML_UNUSED(buft);
5452
5698
  }
5453
5699
 
5454
5700
  static bool wsp_ggml_backend_metal_buffer_type_is_host(wsp_ggml_backend_buffer_type_t buft) {
@@ -5521,7 +5767,10 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, si
5521
5767
  }
5522
5768
 
5523
5769
  struct wsp_ggml_backend_metal_device_context * ctx_dev = &g_wsp_ggml_ctx_dev_main;
5524
- id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(ctx_dev);
5770
+
5771
+ WSP_GGML_ASSERT(ctx_dev->mtl_device != nil);
5772
+
5773
+ id<MTLDevice> device = ctx_dev->mtl_device;
5525
5774
 
5526
5775
  // the buffer fits into the max buffer size allowed by the device
5527
5776
  if (size_aligned <= device.maxBufferLength) {
@@ -5577,7 +5826,6 @@ wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_from_ptr(void * data, si
5577
5826
  if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5578
5827
  WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5579
5828
  free(ctx);
5580
- wsp_ggml_backend_metal_device_rel(ctx_dev);
5581
5829
  return NULL;
5582
5830
  }
5583
5831
 
@@ -5593,10 +5841,8 @@ static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
5593
5841
  }
5594
5842
 
5595
5843
  static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
5596
- struct wsp_ggml_backend_metal_context * ctx = backend->context;
5597
- struct wsp_ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5844
+ struct wsp_ggml_backend_metal_context * ctx = backend->context;
5598
5845
 
5599
- wsp_ggml_backend_metal_device_rel(ctx_dev);
5600
5846
  wsp_ggml_metal_free(ctx);
5601
5847
 
5602
5848
  free(backend);
@@ -5736,6 +5982,8 @@ bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int fami
5736
5982
 
5737
5983
  struct wsp_ggml_backend_metal_device_context * ctx_dev = backend->device->context;
5738
5984
 
5985
+ WSP_GGML_ASSERT(ctx_dev->mtl_device != nil);
5986
+
5739
5987
  return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
5740
5988
  }
5741
5989
 
@@ -5755,10 +6003,7 @@ static const char * wsp_ggml_backend_metal_device_get_name(wsp_ggml_backend_dev_
5755
6003
  }
5756
6004
 
5757
6005
  static const char * wsp_ggml_backend_metal_device_get_description(wsp_ggml_backend_dev_t dev) {
5758
- // acq/rel just to populate ctx->name in case it hasn't been done yet
5759
6006
  struct wsp_ggml_backend_metal_device_context * ctx_dev = (struct wsp_ggml_backend_metal_device_context *)dev->context;
5760
- wsp_ggml_backend_metal_device_acq(ctx_dev);
5761
- wsp_ggml_backend_metal_device_rel(ctx_dev);
5762
6007
 
5763
6008
  return ctx_dev->name;
5764
6009
  }
@@ -5766,12 +6011,10 @@ static const char * wsp_ggml_backend_metal_device_get_description(wsp_ggml_backe
5766
6011
  static void wsp_ggml_backend_metal_device_get_memory(wsp_ggml_backend_dev_t dev, size_t * free, size_t * total) {
5767
6012
  if (@available(macOS 10.12, iOS 16.0, *)) {
5768
6013
  struct wsp_ggml_backend_metal_device_context * ctx_dev = (struct wsp_ggml_backend_metal_device_context *)dev->context;
5769
- id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(ctx_dev);
6014
+ id<MTLDevice> device = ctx_dev->mtl_device;
5770
6015
 
5771
6016
  *total = device.recommendedMaxWorkingSetSize;
5772
6017
  *free = *total - device.currentAllocatedSize;
5773
-
5774
- wsp_ggml_backend_metal_device_rel(ctx_dev);
5775
6018
  } else {
5776
6019
  *free = 1;
5777
6020
  *total = 1;
@@ -5849,7 +6092,10 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_device_buffer_from_ptr(w
5849
6092
  }
5850
6093
 
5851
6094
  struct wsp_ggml_backend_metal_device_context * ctx_dev = (struct wsp_ggml_backend_metal_device_context *)dev->context;
5852
- id<MTLDevice> device = wsp_ggml_backend_metal_device_acq(ctx_dev);
6095
+
6096
+ WSP_GGML_ASSERT(ctx_dev->mtl_device != nil);
6097
+
6098
+ id<MTLDevice> device = ctx_dev->mtl_device;
5853
6099
 
5854
6100
  // the buffer fits into the max buffer size allowed by the device
5855
6101
  if (size_aligned <= device.maxBufferLength) {
@@ -5905,7 +6151,6 @@ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_device_buffer_from_ptr(w
5905
6151
  if (!wsp_ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
5906
6152
  WSP_GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
5907
6153
  free(ctx);
5908
- wsp_ggml_backend_metal_device_rel(ctx_dev);
5909
6154
  return NULL;
5910
6155
  }
5911
6156
 
@@ -5919,8 +6164,9 @@ static bool wsp_ggml_backend_metal_device_supports_op(wsp_ggml_backend_dev_t dev
5919
6164
  }
5920
6165
 
5921
6166
  static bool wsp_ggml_backend_metal_device_supports_buft(wsp_ggml_backend_dev_t dev, wsp_ggml_backend_buffer_type_t buft) {
5922
- return buft->iface.get_name == wsp_ggml_backend_metal_buffer_type_get_name ||
5923
- buft->iface.get_name == wsp_ggml_backend_metal_buffer_from_ptr_type_get_name;
6167
+ return
6168
+ buft->iface.get_name == wsp_ggml_backend_metal_buffer_type_get_name ||
6169
+ buft->iface.get_name == wsp_ggml_backend_metal_buffer_from_ptr_type_get_name;
5924
6170
 
5925
6171
  WSP_GGML_UNUSED(dev);
5926
6172
  }
@@ -6005,8 +6251,19 @@ static struct wsp_ggml_backend_reg_i wsp_ggml_backend_metal_reg_i = {
6005
6251
  /* .get_proc_address = */ wsp_ggml_backend_metal_get_proc_address,
6006
6252
  };
6007
6253
 
6254
+ // called upon program exit
6255
+ static void wsp_ggml_metal_cleanup(void) {
6256
+ wsp_ggml_backend_metal_device_rel(&g_wsp_ggml_ctx_dev_main);
6257
+ }
6258
+
6259
+ // TODO: make thread-safe
6008
6260
  wsp_ggml_backend_reg_t wsp_ggml_backend_metal_reg(void) {
6009
- // TODO: make this thread-safe somehow?
6261
+ wsp_ggml_backend_metal_device_acq(&g_wsp_ggml_ctx_dev_main);
6262
+
6263
+ // register cleanup callback
6264
+ // TODO: not ideal, but not sure if there is a better way to do this in Objective-C
6265
+ atexit(wsp_ggml_metal_cleanup);
6266
+
6010
6267
  {
6011
6268
  g_wsp_ggml_backend_metal_reg = (struct wsp_ggml_backend_reg) {
6012
6269
  /* .api_version = */ WSP_GGML_BACKEND_API_VERSION,