whisper.rn 0.5.0 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. package/android/build.gradle +2 -1
  2. package/android/gradle.properties +1 -1
  3. package/cpp/ggml-alloc.c +264 -126
  4. package/cpp/ggml-backend-impl.h +4 -1
  5. package/cpp/ggml-backend-reg.cpp +13 -5
  6. package/cpp/ggml-backend.cpp +207 -17
  7. package/cpp/ggml-backend.h +17 -1
  8. package/cpp/ggml-cpu/amx/amx.cpp +4 -2
  9. package/cpp/ggml-cpu/arch/x86/repack.cpp +2 -2
  10. package/cpp/ggml-cpu/arch-fallback.h +0 -4
  11. package/cpp/ggml-cpu/common.h +14 -0
  12. package/cpp/ggml-cpu/ggml-cpu-impl.h +13 -6
  13. package/cpp/ggml-cpu/ggml-cpu.c +48 -41
  14. package/cpp/ggml-cpu/ggml-cpu.cpp +14 -4
  15. package/cpp/ggml-cpu/ops.cpp +518 -767
  16. package/cpp/ggml-cpu/ops.h +2 -0
  17. package/cpp/ggml-cpu/simd-mappings.h +88 -59
  18. package/cpp/ggml-cpu/vec.cpp +161 -20
  19. package/cpp/ggml-cpu/vec.h +400 -51
  20. package/cpp/ggml-cpu.h +1 -1
  21. package/cpp/ggml-impl.h +43 -10
  22. package/cpp/ggml-metal/ggml-metal-common.cpp +446 -0
  23. package/cpp/ggml-metal/ggml-metal-common.h +52 -0
  24. package/cpp/ggml-metal/ggml-metal-context.h +33 -0
  25. package/cpp/ggml-metal/ggml-metal-context.m +600 -0
  26. package/cpp/ggml-metal/ggml-metal-device.cpp +1376 -0
  27. package/cpp/ggml-metal/ggml-metal-device.h +226 -0
  28. package/cpp/ggml-metal/ggml-metal-device.m +1312 -0
  29. package/cpp/ggml-metal/ggml-metal-impl.h +722 -0
  30. package/cpp/ggml-metal/ggml-metal-ops.cpp +3158 -0
  31. package/cpp/ggml-metal/ggml-metal-ops.h +82 -0
  32. package/cpp/ggml-metal/ggml-metal.cpp +718 -0
  33. package/cpp/ggml-metal/ggml-whisper-sim.metallib +0 -0
  34. package/cpp/ggml-metal/ggml-whisper.metallib +0 -0
  35. package/cpp/ggml-metal-impl.h +40 -40
  36. package/cpp/ggml-metal.h +1 -6
  37. package/cpp/ggml-quants.c +1 -0
  38. package/cpp/ggml.c +175 -13
  39. package/cpp/ggml.h +84 -5
  40. package/cpp/jsi/RNWhisperJSI.cpp +2 -0
  41. package/cpp/jsi/ThreadPool.h +3 -3
  42. package/cpp/whisper.cpp +85 -70
  43. package/cpp/whisper.h +1 -0
  44. package/ios/CMakeLists.txt +6 -1
  45. package/ios/RNWhisperVadContext.mm +14 -13
  46. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  47. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  48. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  49. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  50. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  51. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  52. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +84 -5
  53. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  54. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Info.plist +0 -0
  55. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  56. package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
  57. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  58. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  59. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  60. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  61. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  62. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  63. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +84 -5
  64. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  65. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  66. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  67. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  68. package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  69. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  70. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  71. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  72. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  73. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  74. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  75. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +84 -5
  76. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/whisper.h +1 -0
  77. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Info.plist +0 -0
  78. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
  79. package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
  80. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend-impl.h +4 -1
  81. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-backend.h +17 -1
  82. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-cpu.h +1 -1
  83. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +43 -10
  84. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +40 -40
  85. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal.h +1 -6
  86. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +84 -5
  87. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/whisper.h +1 -0
  88. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Info.plist +0 -0
  89. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/_CodeSignature/CodeResources +1 -1
  90. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
  91. package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
  92. package/lib/commonjs/version.json +1 -1
  93. package/lib/module/version.json +1 -1
  94. package/package.json +1 -1
  95. package/src/version.json +1 -1
  96. package/whisper-rn.podspec +8 -9
  97. package/cpp/ggml-metal.m +0 -6779
  98. package/cpp/ggml-whisper-sim.metallib +0 -0
  99. package/cpp/ggml-whisper.metallib +0 -0
@@ -0,0 +1,600 @@
1
+ #import "ggml-metal-context.h"
2
+
3
+ #import "ggml-impl.h"
4
+ #import "ggml-backend-impl.h"
5
+
6
+ #import "ggml-metal-impl.h"
7
+ #import "ggml-metal-common.h"
8
+ #import "ggml-metal-ops.h"
9
+
10
+ #import <Foundation/Foundation.h>
11
+
12
+ #import <Metal/Metal.h>
13
+
14
+ #undef MIN
15
+ #undef MAX
16
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
17
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
18
+
19
+ // max number of MTLCommandBuffer used to submit a graph for processing
20
+ #define WSP_GGML_METAL_MAX_COMMAND_BUFFERS 8
21
+
22
// thin wrapper around a single MTLCommandBuffer so that an array of these can
// be stored by value inside struct wsp_ggml_metal
struct wsp_ggml_metal_command_buffer {
    id<MTLCommandBuffer> obj;
};
25
+
26
// Metal backend context: device/queue handles, feature toggles and the
// per-graph command-buffer state used by wsp_ggml_metal_graph_compute
struct wsp_ggml_metal {
    id<MTLDevice>       device;
    id<MTLCommandQueue> queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND]

    wsp_ggml_metal_device_t  dev;
    wsp_ggml_metal_library_t lib;

    // concurrent queue used by dispatch_apply to run encode_async in parallel
    dispatch_queue_t d_queue;

    // additional, inference-time compiled pipelines
    wsp_ggml_metal_pipelines_t pipelines_ext;

    // feature toggles, resolved once in wsp_ggml_metal_init (partly from env vars)
    bool use_bfloat;
    bool use_fusion;
    bool use_concurrency;
    bool use_graph_optimize;

    // debug levels read from WSP_GGML_METAL_GRAPH_DEBUG / WSP_GGML_METAL_FUSION_DEBUG
    int debug_graph;
    int debug_fusion;

    // how many times a given op was fused
    uint64_t fuse_cnt[WSP_GGML_OP_COUNT];

    // capture state
    bool capture_next_compute;
    bool capture_started;

    id<MTLCaptureScope> capture_scope;

    // command buffer state
    int n_cb;           // number of extra threads used to submit the command buffers
    int n_nodes_0;      // number of nodes submitted by the main thread
    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads
    int n_nodes_per_cb; // slice size per worker command buffer

    // graph currently being encoded (set by wsp_ggml_metal_graph_compute)
    struct wsp_ggml_cgraph * gf;

    // the callback given to the thread pool
    void (^encode_async)(size_t ith);

    // n_cb command buffers + 1 used by the main thread
    struct wsp_ggml_metal_command_buffer cmd_bufs[WSP_GGML_METAL_MAX_COMMAND_BUFFERS + 1];

    // extra command buffers for things like getting, setting and copying tensors
    NSMutableArray * cmd_bufs_ext;

    // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend
    id<MTLCommandBuffer> cmd_buf_last;

    // abort wsp_ggml_metal_graph_compute if callback returns true
    wsp_ggml_abort_callback abort_callback;
    void *                  abort_callback_data;
};
79
+
80
// allocate and initialize a Metal backend context for the given device
// returns NULL on failure; the caller owns the result and must release it
// with wsp_ggml_metal_free
wsp_ggml_metal_t wsp_ggml_metal_init(wsp_ggml_metal_device_t dev) {
    WSP_GGML_LOG_INFO("%s: allocating\n", __func__);

#if TARGET_OS_OSX && !WSP_GGML_METAL_NDEBUG
    // Show all the Metal device instances in the system
    NSArray * devices = MTLCopyAllDevices();
    for (id<MTLDevice> device in devices) {
        WSP_GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
    }
    [devices release]; // since it was created by a *Copy* C method
#endif

    // init context
    wsp_ggml_metal_t res = calloc(1, sizeof(struct wsp_ggml_metal));
    if (res == NULL) {
        // robustness: calloc can fail; do not dereference a NULL context
        WSP_GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
        return NULL;
    }

    res->device = wsp_ggml_metal_device_get_obj(dev);

    WSP_GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[res->device name] UTF8String]);

    // TODO: would it be better to have one queue for the backend and one queue for the device?
    //       the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
    //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
    res->queue = wsp_ggml_metal_device_get_queue(dev);
    if (res->queue == nil) {
        WSP_GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);

        // fix: this path previously returned without freeing the context,
        // leaking the calloc'd struct (the library-failure path below frees it)
        free(res);

        return NULL;
    }

    res->dev = dev;
    res->lib = wsp_ggml_metal_device_get_library(dev);
    if (res->lib == NULL) {
        WSP_GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__);
        WSP_GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__);

        res->lib = wsp_ggml_metal_library_init(dev);
        if (res->lib == NULL) {
            WSP_GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__);

            free(res);

            return NULL;
        }
    }

    const struct wsp_ggml_metal_device_props * props_dev = wsp_ggml_metal_device_get_props(dev);

    res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

    // feature toggles: fusion/concurrency are on unless disabled via env vars
    res->use_bfloat      = props_dev->has_bfloat;
    res->use_fusion      = getenv("WSP_GGML_METAL_FUSION_DISABLE") == nil;
    res->use_concurrency = getenv("WSP_GGML_METAL_CONCURRENCY_DISABLE") == nil;

    {
        const char * val = getenv("WSP_GGML_METAL_GRAPH_DEBUG");
        res->debug_graph = val ? atoi(val) : 0;
    }

    {
        const char * val = getenv("WSP_GGML_METAL_FUSION_DEBUG");
        res->debug_fusion = val ? atoi(val) : 0;
    }

    res->use_graph_optimize = true;

    if (getenv("WSP_GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
        res->use_graph_optimize = false;
    }

    memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));

    WSP_GGML_LOG_INFO("%s: use bfloat         = %s\n", __func__, res->use_bfloat         ? "true" : "false");
    WSP_GGML_LOG_INFO("%s: use fusion         = %s\n", __func__, res->use_fusion         ? "true" : "false");
    WSP_GGML_LOG_INFO("%s: use concurrency    = %s\n", __func__, res->use_concurrency    ? "true" : "false");
    WSP_GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");

    res->capture_next_compute = false;
    res->capture_started      = false;
    res->capture_scope        = nil;

    res->gf           = nil;
    res->encode_async = nil;
    for (int i = 0; i < WSP_GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
        res->cmd_bufs[i].obj = nil;
    }

    res->cmd_bufs_ext = [[NSMutableArray alloc] init];

    res->cmd_buf_last = nil;

    res->pipelines_ext = wsp_ggml_metal_pipelines_init();

    return res;
}
173
+
174
// release the Metal backend context and all command buffers it retains
// note: ctx->queue is owned by the device and is deliberately not released
// here [TAG_QUEUE_PER_BACKEND]
void wsp_ggml_metal_free(wsp_ggml_metal_t ctx) {
    WSP_GGML_LOG_INFO("%s: deallocating\n", __func__);

    // fix: cmd_bufs holds WSP_GGML_METAL_MAX_COMMAND_BUFFERS + 1 entries - the
    // extra one is the main-thread buffer stored at index n_cb, which can be
    // equal to WSP_GGML_METAL_MAX_COMMAND_BUFFERS; the previous bound
    // (i < MAX) leaked a retained command buffer in that last slot
    for (int i = 0; i < WSP_GGML_METAL_MAX_COMMAND_BUFFERS + 1; ++i) {
        if (ctx->cmd_bufs[i].obj) {
            [ctx->cmd_bufs[i].obj release];
        }
    }

    // release any extra (get/set/copy) command buffers still tracked
    for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) {
        if (ctx->cmd_bufs_ext[i]) {
            [ctx->cmd_bufs_ext[i] release];
        }
    }

    [ctx->cmd_bufs_ext removeAllObjects];
    [ctx->cmd_bufs_ext release];

    if (ctx->pipelines_ext) {
        wsp_ggml_metal_pipelines_free(ctx->pipelines_ext);
        ctx->pipelines_ext = nil;
    }

    // dump per-op fusion statistics when fusion debugging is enabled
    if (ctx->debug_fusion > 0) {
        WSP_GGML_LOG_DEBUG("%s: fusion stats:\n", __func__);
        for (int i = 0; i < WSP_GGML_OP_COUNT; i++) {
            if (ctx->fuse_cnt[i] == 0) {
                continue;
            }

            // note: cannot use wsp_ggml_log here
            WSP_GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, wsp_ggml_op_name((enum wsp_ggml_op) i), ctx->fuse_cnt[i]);
        }
    }

    // safe even if wsp_ggml_metal_set_n_cb was never called - the Blocks
    // runtime treats a NULL argument as a no-op
    Block_release(ctx->encode_async);

    //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND]

    dispatch_release(ctx->d_queue);

    free(ctx);
}
217
+
218
// block until all previously submitted GPU work for this context has finished,
// then verify every command buffer completed successfully
// aborts the process on any command-buffer failure (e.g. device OOM)
void wsp_ggml_metal_synchronize(wsp_ggml_metal_t ctx) {
    // waiting on the most recently queued buffer is enough to drain the queue,
    // since buffers are executed in submission order
    if (ctx->cmd_buf_last) {
        [ctx->cmd_buf_last waitUntilCompleted];
        ctx->cmd_buf_last = nil;
    }

    // verify the graph command buffers (workers 0..n_cb-1 plus the main-thread
    // buffer at index n_cb)
    for (int i = 0; i <= ctx->n_cb; ++i) {
        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
        if (cmd_buf == nil) {
            continue;
        }

        const MTLCommandBufferStatus status = [cmd_buf status];
        if (status != MTLCommandBufferStatusCompleted) {
            WSP_GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, i, (int) status);
            if (status == MTLCommandBufferStatusError) {
                WSP_GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
            }
            WSP_GGML_ABORT("fatal error");
        }
    }

    // verify and release the extra (get/set/copy-tensor) command buffers, then
    // drop them from the tracking array
    const size_t n_ext = ctx->cmd_bufs_ext.count;
    for (size_t i = 0; i < n_ext; ++i) {
        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];

        const MTLCommandBufferStatus status = [cmd_buf status];
        if (status != MTLCommandBufferStatusCompleted) {
            WSP_GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status);
            if (status == MTLCommandBufferStatusError) {
                WSP_GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
            }
            WSP_GGML_ABORT("fatal error");
        }

        [cmd_buf release];
    }

    [ctx->cmd_bufs_ext removeAllObjects];
}
266
+
267
// resolve the Metal buffer id (buffer handle + offset) backing a tensor
// for tensor views, the storage belongs to the view source
// returns { nil, 0 } when the tensor itself is NULL
static struct wsp_ggml_metal_buffer_id wsp_ggml_metal_get_buffer_id(const struct wsp_ggml_tensor * t) {
    if (t == NULL) {
        struct wsp_ggml_metal_buffer_id none = { nil, 0 };
        return none;
    }

    wsp_ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;

    return wsp_ggml_metal_buffer_get_id(buf->context, t);
}
276
+
277
// asynchronously upload `size` bytes from host memory into a tensor's Metal
// buffer at `offset`; the copy is queued behind any in-flight GPU work and
// completion is observed later via wsp_ggml_metal_synchronize()
void wsp_ggml_metal_set_tensor_async(wsp_ggml_metal_t ctx, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    @autoreleasepool {
        // stage the host data in a shared-storage Metal buffer
        // NOTE(review): src is not explicitly released here - with
        // commandBufferWithUnretainedReferences the buffer must outlive the
        // blit; verify its lifetime is handled elsewhere
        id<MTLBuffer> src = [ctx->device newBufferWithBytes:data
                                                     length:size
                                                    options:MTLResourceStorageModeShared];

        WSP_GGML_ASSERT(src);

        struct wsp_ggml_metal_buffer_id dst = wsp_ggml_metal_get_buffer_id(tensor);
        if (dst.metal == nil) {
            WSP_GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }

        dst.offs += offset;

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandBuffer>      cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
        id<MTLBlitCommandEncoder> blit    = [cmd_buf blitCommandEncoder];

        [blit copyFromBuffer:src
                sourceOffset:0
                    toBuffer:dst.metal
           destinationOffset:dst.offs
                        size:size];

        [blit endEncoding];
        [cmd_buf commit];

        // do not wait for completion here; instead, retain and remember the
        // command buffer so synchronize() can wait on it later if needed
        [cmd_buf retain];
        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;
    }
}
317
+
318
// asynchronously download `size` bytes of a tensor's data into host memory at
// `data`; the copy is queued behind any in-flight GPU work - the caller must
// call wsp_ggml_metal_synchronize() before reading `data`
void wsp_ggml_metal_get_tensor_async(wsp_ggml_metal_t ctx, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    @autoreleasepool {
        // wrap the destination host memory in a zero-copy Metal buffer
        // NOTE(review): newBufferWithBytesNoCopy requires page-aligned storage
        // and length - confirm callers always pass suitably aligned `data`
        id<MTLBuffer> buf_dst = [ctx->device newBufferWithBytesNoCopy:data
                                   length:size
                                  options:MTLResourceStorageModeShared
                              deallocator:nil];

        WSP_GGML_ASSERT(buf_dst);

        struct wsp_ggml_metal_buffer_id bid_src = wsp_ggml_metal_get_buffer_id(tensor);
        if (bid_src.metal == nil) {
            WSP_GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }

        bid_src.offs += offset;

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:bid_src.metal
                   sourceOffset:bid_src.offs
                       toBuffer:buf_dst
              destinationOffset:0
                           size:size];

        [encoder endEncoding];
        [cmd_buf commit];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

        // keep the autoreleased command buffer alive until synchronize() releases it
        [cmd_buf retain];
    }
}
358
+
359
// encode and submit a ggml compute graph to the GPU
// the first n_nodes_0 nodes are encoded synchronously by the calling thread;
// the remaining n_nodes_1 nodes are split across n_cb extra command buffers
// encoded in parallel via dispatch_apply
// in the common (non-capture) case this returns without waiting for the GPU;
// completion is observed later via wsp_ggml_metal_synchronize()
enum wsp_ggml_status wsp_ggml_metal_graph_compute(wsp_ggml_metal_t ctx, struct wsp_ggml_cgraph * gf) {
    // number of nodes encoded by the main thread (empirically determined)
    const int n_main = 64;

    // number of threads in addition to the main thread
    const int n_cb = ctx->n_cb;

    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
    // each thread creates it's own command buffer and enqueues the ops in parallel
    //
    // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2

    @autoreleasepool {
        ctx->gf = gf;

        // partition the graph: main-thread slice + remainder for the workers
        ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;

        // ceil-divide the remainder across the n_cb worker command buffers
        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;

        const bool use_capture = ctx->capture_next_compute;
        if (use_capture) {
            ctx->capture_next_compute = false;

            // make sure all previous computations have finished before starting the capture
            if (ctx->cmd_buf_last) {
                [ctx->cmd_buf_last waitUntilCompleted];
                ctx->cmd_buf_last = nil;
            }

            if (!ctx->capture_started) {
                // create capture scope
                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];

                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
                descriptor.captureObject = ctx->capture_scope;
                descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
                descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];

                NSError * error = nil;
                if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
                    // capture failure is non-fatal: computation proceeds uncaptured
                    WSP_GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
                } else {
                    [ctx->capture_scope beginScope];
                    ctx->capture_started = true;
                }
            }
        }

        // the main thread commits the first few commands immediately
        // cmd_buf[n_cb]
        {
            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            // replace (and release) the buffer left over from the previous graph
            if (ctx->cmd_bufs[n_cb].obj) {
                [ctx->cmd_bufs[n_cb].obj release];
            }
            ctx->cmd_bufs[n_cb].obj = cmd_buf;

            [cmd_buf enqueue];

            // encode nodes [0, n_nodes_0) on this thread (see encode_async in
            // wsp_ggml_metal_set_n_cb: iter == n_cb selects the main slice)
            ctx->encode_async(n_cb);
        }

        // remember the command buffer for the next iteration
        ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;

        // prepare the rest of the command buffers asynchronously (optional)
        // cmd_buf[0.. n_cb)
        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            if (ctx->cmd_bufs[cb_idx].obj) {
                [ctx->cmd_bufs[cb_idx].obj release];
            }
            ctx->cmd_bufs[cb_idx].obj = cmd_buf;

            // always enqueue the first two command buffers
            // enqueue all of the command buffers if we don't need to abort
            if (cb_idx < 2 || ctx->abort_callback == NULL) {
                [cmd_buf enqueue];

                // update the pointer to the last queued command buffer
                // this is needed to implement synchronize()
                ctx->cmd_buf_last = cmd_buf;
            }
        }

        // encode the n_cb worker slices in parallel on the concurrent queue
        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);

        // for debugging: block until graph is computed
        //[ctx->cmd_buf_last waitUntilCompleted];

        // enter here only when capturing in order to wait for all computation to finish
        // otherwise, we leave the graph to compute asynchronously
        // NOTE(review): the guard reads (!use_capture && capture_started) -
        // since capture_started is never reset to false in this file, this
        // appears to end the capture on the compute *after* the captured one;
        // confirm this ordering is intended
        if (!use_capture && ctx->capture_started) {
            // wait for completion and check status of each command buffer
            // needed to detect if the device ran out-of-memory for example (#1881)
            {
                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
                [cmd_buf waitUntilCompleted];

                MTLCommandBufferStatus status = [cmd_buf status];
                if (status != MTLCommandBufferStatusCompleted) {
                    WSP_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
                    if (status == MTLCommandBufferStatusError) {
                        WSP_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                    }

                    return WSP_GGML_STATUS_FAILED;
                }
            }

            for (int i = 0; i < n_cb; ++i) {
                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
                [cmd_buf waitUntilCompleted];

                MTLCommandBufferStatus status = [cmd_buf status];
                if (status != MTLCommandBufferStatusCompleted) {
                    WSP_GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
                    if (status == MTLCommandBufferStatusError) {
                        WSP_GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                    }

                    return WSP_GGML_STATUS_FAILED;
                }

                // buffers beyond the first two are not enqueued up-front when
                // an abort callback is installed; commit the next one here,
                // giving the abort callback a chance to cancel first
                id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
                if (!next_buffer) {
                    continue;
                }

                const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
                if (next_queued) {
                    continue;
                }

                if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
                    WSP_GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
                    return WSP_GGML_STATUS_ABORTED;
                }

                [next_buffer commit];
            }

            [ctx->capture_scope endScope];
            [[MTLCaptureManager sharedCaptureManager] stopCapture];
        }
    }

    return WSP_GGML_STATUS_SUCCESS;
}
515
+
516
// run the generic ggml graph optimizer on `gf`, unless it was disabled via
// WSP_GGML_METAL_GRAPH_OPTIMIZE_DISABLE at context-init time
void wsp_ggml_metal_graph_optimize(wsp_ggml_metal_t ctx, struct wsp_ggml_cgraph * gf) {
    //const int64_t t_start = wsp_ggml_time_us();

    if (!ctx->use_graph_optimize) {
        return;
    }

    wsp_ggml_graph_optimize(gf);

    //printf("%s: graph optimize took %.3f ms\n", __func__, (wsp_ggml_time_us() - t_start) / 1000.0);
}
525
+
526
// set the number of extra encoder threads / command buffers, and (re)build the
// encode_async block invoked by the main thread and dispatch_apply
// the block encodes one slice of ctx->gf into cmd_bufs[iter]:
//   iter == n_cb  -> nodes [0, n_nodes_0)            (main thread)
//   iter <  n_cb  -> an n_nodes_per_cb-sized slice of the remaining nodes
void wsp_ggml_metal_set_n_cb(wsp_ggml_metal_t ctx, int n_cb) {
    if (ctx->n_cb != n_cb) {
        // clamp to the number of preallocated command-buffer slots
        ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_COMMAND_BUFFERS);

        if (ctx->n_cb > 2) {
            // NOTE(review): this logs the requested n_cb, not the clamped
            // ctx->n_cb - confirm which value is meant to be reported
            WSP_GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
        }
    }

    // replace any previously installed block
    if (ctx->encode_async) {
        Block_release(ctx->encode_async);
    }

    // the block captures the raw ctx pointer; it must not outlive the context
    ctx->encode_async = Block_copy(^(size_t iter) {
        const int cb_idx = iter;
        const int n_cb_l = ctx->n_cb;

        const int n_nodes_0 = ctx->n_nodes_0;
        const int n_nodes_1 = ctx->n_nodes_1;

        const int n_nodes_per_cb = ctx->n_nodes_per_cb;

        // default (cb_idx == n_cb_l): the main-thread slice [0, n_nodes_0)
        int idx_start = 0;
        int idx_end   = n_nodes_0;

        if (cb_idx < n_cb_l) {
            // worker slice; the last worker takes whatever remains
            idx_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
            idx_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
        }

        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;

        wsp_ggml_metal_op_t ctx_op = wsp_ggml_metal_op_init(
                ctx->dev,
                cmd_buf,
                ctx->gf,
                idx_start,
                idx_end,
                ctx->use_fusion,
                ctx->use_concurrency,
                ctx->capture_next_compute,
                ctx->debug_graph,
                ctx->debug_fusion);

        // encode the ops; per the idx arithmetic below, each call returns the
        // number of nodes it consumed, or 0 to stop early
        for (int idx = 0; idx < wsp_ggml_metal_op_n_nodes(ctx_op); ++idx) {
            const int res = wsp_ggml_metal_op_encode(ctx_op, idx);
            if (res == 0) {
                break;
            }

            // skip over any extra nodes consumed (e.g. fused) by this op
            idx += res - 1;
        }

        wsp_ggml_metal_op_free(ctx_op);

        // mirror of the enqueue policy in graph_compute: only the first two
        // buffers commit here when an abort callback is installed; the rest
        // are committed lazily so the computation can be aborted in between
        if (cb_idx < 2 || ctx->abort_callback == NULL) {
            [cmd_buf commit];
        }
    });
}
586
+
587
// install an abort callback; when it returns true, wsp_ggml_metal_graph_compute
// stops committing further command buffers and returns WSP_GGML_STATUS_ABORTED
// pass NULL to disable (which also re-enables up-front enqueuing of all buffers)
void wsp_ggml_metal_set_abort_callback(wsp_ggml_metal_t ctx, wsp_ggml_abort_callback abort_callback, void * user_data) {
    ctx->abort_callback = abort_callback;
    ctx->abort_callback_data = user_data;
}
591
+
592
// report whether the underlying device supports the given Apple GPU family
// `family` is 1-based: 1 maps to MTLGPUFamilyApple1, 2 to MTLGPUFamilyApple2, ...
bool wsp_ggml_metal_supports_family(wsp_ggml_metal_t ctx, int family) {
    WSP_GGML_ASSERT(ctx->device != nil);

    const MTLGPUFamily target = MTLGPUFamilyApple1 + family - 1;

    return [ctx->device supportsFamily:target];
}
597
+
598
// request a Metal GPU trace of the next graph computation; the capture itself
// is started lazily inside wsp_ggml_metal_graph_compute (written to
// /tmp/perf-metal.gputrace)
void wsp_ggml_metal_capture_next_compute(wsp_ggml_metal_t ctx) {
    ctx->capture_next_compute = true;
}