whisper.rn 0.4.0-rc.4 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +51 -133
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +187 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +7 -0
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +1010 -253
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +618 -187
  21. package/cpp/ggml-quants.c +64 -59
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +751 -1466
  24. package/cpp/ggml.h +90 -25
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +141 -59
  29. package/cpp/rn-whisper.h +47 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +62 -134
  37. package/lib/commonjs/version.json +1 -1
  38. package/lib/module/version.json +1 -1
  39. package/package.json +6 -5
  40. package/src/version.json +1 -1
package/cpp/ggml-metal.m CHANGED
@@ -62,6 +62,8 @@ struct wsp_ggml_metal_context {
62
62
  WSP_GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
63
63
  WSP_GGML_METAL_DECL_KERNEL(mul);
64
64
  WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
65
+ WSP_GGML_METAL_DECL_KERNEL(div);
66
+ WSP_GGML_METAL_DECL_KERNEL(div_row);
65
67
  WSP_GGML_METAL_DECL_KERNEL(scale);
66
68
  WSP_GGML_METAL_DECL_KERNEL(scale_4);
67
69
  WSP_GGML_METAL_DECL_KERNEL(silu);
@@ -86,6 +88,7 @@ struct wsp_ggml_metal_context {
86
88
  WSP_GGML_METAL_DECL_KERNEL(rms_norm);
87
89
  WSP_GGML_METAL_DECL_KERNEL(norm);
88
90
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
91
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
89
92
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
90
93
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
91
94
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -111,14 +114,35 @@ struct wsp_ggml_metal_context {
111
114
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
112
115
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
113
116
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
117
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
118
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
119
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
120
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
121
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
122
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
123
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
124
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
125
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
126
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
127
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
128
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
114
129
  WSP_GGML_METAL_DECL_KERNEL(rope_f32);
115
130
  WSP_GGML_METAL_DECL_KERNEL(rope_f16);
116
131
  WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
132
+ WSP_GGML_METAL_DECL_KERNEL(im2col_f16);
133
+ WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
134
+ WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
117
135
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
118
136
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
137
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
138
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
139
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
140
+ //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
141
+ //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
119
142
  WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
120
143
  WSP_GGML_METAL_DECL_KERNEL(concat);
121
144
  WSP_GGML_METAL_DECL_KERNEL(sqr);
145
+ WSP_GGML_METAL_DECL_KERNEL(sum_rows);
122
146
 
123
147
  #undef WSP_GGML_METAL_DECL_KERNEL
124
148
  };
@@ -126,7 +150,7 @@ struct wsp_ggml_metal_context {
126
150
  // MSL code
127
151
  // TODO: move the contents here when ready
128
152
  // for now it is easier to work in a separate file
129
- static NSString * const msl_library_source = @"see metal.metal";
153
+ //static NSString * const msl_library_source = @"see metal.metal";
130
154
 
131
155
  // Here to assist with NSBundle Path Hack
132
156
  @interface WSPGGMLMetalClass : NSObject
@@ -142,7 +166,8 @@ void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void *
142
166
  wsp_ggml_metal_log_user_data = user_data;
143
167
  }
144
168
 
145
- static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format, ...){
169
+ WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
170
+ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * format, ...){
146
171
  if (wsp_ggml_metal_log_callback != NULL) {
147
172
  va_list args;
148
173
  va_start(args, format);
@@ -152,6 +177,8 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format
152
177
  wsp_ggml_metal_log_callback(level, buffer, wsp_ggml_metal_log_user_data);
153
178
  } else {
154
179
  char* buffer2 = malloc(len+1);
180
+ va_end(args);
181
+ va_start(args, format);
155
182
  vsnprintf(buffer2, len+1, format, args);
156
183
  buffer2[len] = 0;
157
184
  wsp_ggml_metal_log_callback(level, buffer2, wsp_ggml_metal_log_user_data);
@@ -161,12 +188,10 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format
161
188
  }
162
189
  }
163
190
 
164
-
165
-
166
191
  struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
167
192
  WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
168
193
 
169
- id <MTLDevice> device;
194
+ id<MTLDevice> device;
170
195
  NSString * s;
171
196
 
172
197
  #if TARGET_OS_OSX
@@ -184,7 +209,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
184
209
  WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
185
210
 
186
211
  // Configure context
187
- struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
212
+ struct wsp_ggml_metal_context * ctx = calloc(1, sizeof(struct wsp_ggml_metal_context));
188
213
  ctx->device = device;
189
214
  ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
190
215
  ctx->queue = [ctx->device newCommandQueue];
@@ -212,6 +237,9 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
212
237
 
213
238
  NSString * sourcePath;
214
239
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"WSP_GGML_METAL_PATH_RESOURCES"];
240
+
241
+ WSP_GGML_METAL_LOG_INFO("%s: WSP_GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
242
+
215
243
  if (ggmlMetalPathResources) {
216
244
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
217
245
  } else {
@@ -242,6 +270,29 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
242
270
  }
243
271
  }
244
272
 
273
+ #if TARGET_OS_OSX
274
+ // print MTL GPU family:
275
+ WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
276
+
277
+ // determine max supported GPU family
278
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
279
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
280
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
281
+ if ([ctx->device supportsFamily:i]) {
282
+ WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
283
+ break;
284
+ }
285
+ }
286
+
287
+ WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
288
+ WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
289
+ if (ctx->device.maxTransferRate != 0) {
290
+ WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
291
+ } else {
292
+ WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
293
+ }
294
+ #endif
295
+
245
296
  // load kernels
246
297
  {
247
298
  NSError * error = nil;
@@ -263,6 +314,8 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
263
314
  WSP_GGML_METAL_ADD_KERNEL(add_row);
264
315
  WSP_GGML_METAL_ADD_KERNEL(mul);
265
316
  WSP_GGML_METAL_ADD_KERNEL(mul_row);
317
+ WSP_GGML_METAL_ADD_KERNEL(div);
318
+ WSP_GGML_METAL_ADD_KERNEL(div_row);
266
319
  WSP_GGML_METAL_ADD_KERNEL(scale);
267
320
  WSP_GGML_METAL_ADD_KERNEL(scale_4);
268
321
  WSP_GGML_METAL_ADD_KERNEL(silu);
@@ -287,6 +340,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
287
340
  WSP_GGML_METAL_ADD_KERNEL(rms_norm);
288
341
  WSP_GGML_METAL_ADD_KERNEL(norm);
289
342
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
343
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
290
344
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
291
345
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
292
346
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -313,42 +367,40 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
313
367
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
314
368
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
315
369
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
370
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
371
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
372
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
373
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
374
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
375
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
376
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
377
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
378
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
379
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
380
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
381
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
316
382
  }
317
383
  WSP_GGML_METAL_ADD_KERNEL(rope_f32);
318
384
  WSP_GGML_METAL_ADD_KERNEL(rope_f16);
319
385
  WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
386
+ WSP_GGML_METAL_ADD_KERNEL(im2col_f16);
387
+ WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
388
+ WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
320
389
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
321
390
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
391
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
392
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
393
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
394
+ //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
395
+ //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
322
396
  WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
323
397
  WSP_GGML_METAL_ADD_KERNEL(concat);
324
398
  WSP_GGML_METAL_ADD_KERNEL(sqr);
399
+ WSP_GGML_METAL_ADD_KERNEL(sum_rows);
325
400
 
326
401
  #undef WSP_GGML_METAL_ADD_KERNEL
327
402
  }
328
403
 
329
- #if TARGET_OS_OSX
330
- // print MTL GPU family:
331
- WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
332
-
333
- // determine max supported GPU family
334
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
335
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
336
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
337
- if ([ctx->device supportsFamily:i]) {
338
- WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
339
- break;
340
- }
341
- }
342
-
343
- WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
344
- WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
345
- if (ctx->device.maxTransferRate != 0) {
346
- WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
347
- } else {
348
- WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
349
- }
350
- #endif
351
-
352
404
  return ctx;
353
405
  }
354
406
 
@@ -360,6 +412,8 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
360
412
  WSP_GGML_METAL_DEL_KERNEL(add_row);
361
413
  WSP_GGML_METAL_DEL_KERNEL(mul);
362
414
  WSP_GGML_METAL_DEL_KERNEL(mul_row);
415
+ WSP_GGML_METAL_DEL_KERNEL(div);
416
+ WSP_GGML_METAL_DEL_KERNEL(div_row);
363
417
  WSP_GGML_METAL_DEL_KERNEL(scale);
364
418
  WSP_GGML_METAL_DEL_KERNEL(scale_4);
365
419
  WSP_GGML_METAL_DEL_KERNEL(silu);
@@ -384,6 +438,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
384
438
  WSP_GGML_METAL_DEL_KERNEL(rms_norm);
385
439
  WSP_GGML_METAL_DEL_KERNEL(norm);
386
440
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
441
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
387
442
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
388
443
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
389
444
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -410,15 +465,36 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
410
465
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
411
466
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
412
467
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
468
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
469
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
470
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
471
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
472
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
473
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
474
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
475
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
476
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
477
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
478
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
479
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
413
480
  }
414
481
  WSP_GGML_METAL_DEL_KERNEL(rope_f32);
415
482
  WSP_GGML_METAL_DEL_KERNEL(rope_f16);
416
483
  WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
484
+ WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
485
+ WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
486
+ WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
417
487
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
418
488
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
489
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
490
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
491
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
492
+ //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
493
+ //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
419
494
  WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
420
495
  WSP_GGML_METAL_DEL_KERNEL(concat);
421
496
  WSP_GGML_METAL_DEL_KERNEL(sqr);
497
+ WSP_GGML_METAL_DEL_KERNEL(sum_rows);
422
498
 
423
499
  #undef WSP_GGML_METAL_DEL_KERNEL
424
500
 
@@ -452,6 +528,13 @@ int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
452
528
  return ctx->concur_list;
453
529
  }
454
530
 
531
+ // temporarily defined here for compatibility between ggml-backend and the old API
532
+ struct wsp_ggml_backend_metal_buffer_context {
533
+ void * data;
534
+
535
+ id<MTLBuffer> metal;
536
+ };
537
+
455
538
  // finds the Metal buffer that contains the tensor data on the GPU device
456
539
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
457
540
  // Metal buffer based on the host memory pointer
@@ -461,6 +544,19 @@ static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * c
461
544
 
462
545
  const int64_t tsize = wsp_ggml_nbytes(t);
463
546
 
547
+ // compatibility with ggml-backend
548
+ if (t->buffer && t->buffer->buft == wsp_ggml_backend_metal_buffer_type()) {
549
+ struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) t->buffer->context;
550
+
551
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
552
+
553
+ WSP_GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
554
+
555
+ *offs = (size_t) ioffs;
556
+
557
+ return buf_ctx->metal;
558
+ }
559
+
464
560
  // find the view that contains the tensor fully
465
561
  for (int i = 0; i < ctx->n_buffers; ++i) {
466
562
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -518,11 +614,11 @@ bool wsp_ggml_metal_add_buffer(
518
614
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
519
615
 
520
616
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
521
- WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
617
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
522
618
  return false;
523
619
  }
524
620
 
525
- WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
621
+ WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
526
622
 
527
623
  ++ctx->n_buffers;
528
624
  } else {
@@ -542,11 +638,11 @@ bool wsp_ggml_metal_add_buffer(
542
638
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
543
639
 
544
640
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
545
- WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
641
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
546
642
  return false;
547
643
  }
548
644
 
549
- WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
645
+ WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
550
646
  if (i + size_step < size) {
551
647
  WSP_GGML_METAL_LOG_INFO("\n");
552
648
  }
@@ -561,7 +657,7 @@ bool wsp_ggml_metal_add_buffer(
561
657
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
562
658
 
563
659
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
564
- WSP_GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
660
+ WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
565
661
  } else {
566
662
  WSP_GGML_METAL_LOG_INFO("\n");
567
663
  }
@@ -683,6 +779,51 @@ void wsp_ggml_metal_graph_find_concurrency(
683
779
  }
684
780
  }
685
781
 
782
+ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
783
+ switch (op->op) {
784
+ case WSP_GGML_OP_UNARY:
785
+ switch (wsp_ggml_get_unary_op(op)) {
786
+ case WSP_GGML_UNARY_OP_SILU:
787
+ case WSP_GGML_UNARY_OP_RELU:
788
+ case WSP_GGML_UNARY_OP_GELU:
789
+ return true;
790
+ default:
791
+ return false;
792
+ }
793
+ case WSP_GGML_OP_NONE:
794
+ case WSP_GGML_OP_RESHAPE:
795
+ case WSP_GGML_OP_VIEW:
796
+ case WSP_GGML_OP_TRANSPOSE:
797
+ case WSP_GGML_OP_PERMUTE:
798
+ case WSP_GGML_OP_CONCAT:
799
+ case WSP_GGML_OP_ADD:
800
+ case WSP_GGML_OP_MUL:
801
+ case WSP_GGML_OP_DIV:
802
+ case WSP_GGML_OP_SCALE:
803
+ case WSP_GGML_OP_SQR:
804
+ case WSP_GGML_OP_SUM_ROWS:
805
+ case WSP_GGML_OP_SOFT_MAX:
806
+ case WSP_GGML_OP_RMS_NORM:
807
+ case WSP_GGML_OP_NORM:
808
+ case WSP_GGML_OP_ALIBI:
809
+ case WSP_GGML_OP_ROPE:
810
+ case WSP_GGML_OP_IM2COL:
811
+ case WSP_GGML_OP_ARGSORT:
812
+ case WSP_GGML_OP_DUP:
813
+ case WSP_GGML_OP_CPY:
814
+ case WSP_GGML_OP_CONT:
815
+ case WSP_GGML_OP_MUL_MAT:
816
+ case WSP_GGML_OP_MUL_MAT_ID:
817
+ return true;
818
+ case WSP_GGML_OP_DIAG_MASK_INF:
819
+ case WSP_GGML_OP_GET_ROWS:
820
+ {
821
+ return op->ne[0] % 4 == 0;
822
+ }
823
+ default:
824
+ return false;
825
+ }
826
+ }
686
827
  void wsp_ggml_metal_graph_compute(
687
828
  struct wsp_ggml_metal_context * ctx,
688
829
  struct wsp_ggml_cgraph * gf) {
@@ -753,6 +894,8 @@ void wsp_ggml_metal_graph_compute(
753
894
  } break;
754
895
  }
755
896
 
897
+ WSP_GGML_ASSERT(wsp_ggml_metal_supports_op(dst));
898
+
756
899
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
757
900
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
758
901
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -845,6 +988,8 @@ void wsp_ggml_metal_graph_compute(
845
988
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
846
989
  } break;
847
990
  case WSP_GGML_OP_ADD:
991
+ case WSP_GGML_OP_MUL:
992
+ case WSP_GGML_OP_DIV:
848
993
  {
849
994
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
850
995
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
@@ -858,11 +1003,21 @@ void wsp_ggml_metal_graph_compute(
858
1003
  WSP_GGML_ASSERT(ne11 == 1);
859
1004
 
860
1005
  nb = ne00 / 4;
861
- [encoder setComputePipelineState:ctx->pipeline_add_row];
1006
+ switch (dst->op) {
1007
+ case WSP_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
1008
+ case WSP_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
1009
+ case WSP_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
1010
+ default: WSP_GGML_ASSERT(false);
1011
+ }
862
1012
 
863
1013
  bcast_row = true;
864
1014
  } else {
865
- [encoder setComputePipelineState:ctx->pipeline_add];
1015
+ switch (dst->op) {
1016
+ case WSP_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
1017
+ case WSP_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
1018
+ case WSP_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
1019
+ default: WSP_GGML_ASSERT(false);
1020
+ }
866
1021
  }
867
1022
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
868
1023
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -903,31 +1058,6 @@ void wsp_ggml_metal_graph_compute(
903
1058
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
904
1059
  }
905
1060
  } break;
906
- case WSP_GGML_OP_MUL:
907
- {
908
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
909
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
910
-
911
- // utilize float4
912
- WSP_GGML_ASSERT(ne00 % 4 == 0);
913
- const int64_t nb = ne00/4;
914
-
915
- if (wsp_ggml_nelements(src1) == ne10) {
916
- // src1 is a row
917
- WSP_GGML_ASSERT(ne11 == 1);
918
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
919
- } else {
920
- [encoder setComputePipelineState:ctx->pipeline_mul];
921
- }
922
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
923
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
924
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
925
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
926
-
927
- const int64_t n = wsp_ggml_nelements(dst)/4;
928
-
929
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
930
- } break;
931
1061
  case WSP_GGML_OP_SCALE:
932
1062
  {
933
1063
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
@@ -1000,25 +1130,68 @@ void wsp_ggml_metal_graph_compute(
1000
1130
  const int64_t n = wsp_ggml_nelements(dst);
1001
1131
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1002
1132
  } break;
1133
+ case WSP_GGML_OP_SUM_ROWS:
1134
+ {
1135
+ WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type));
1136
+
1137
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
1138
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1139
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1140
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1141
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1142
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1143
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1144
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1145
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1146
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1147
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1148
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1149
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1150
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1151
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1152
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1153
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1154
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1155
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1156
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1157
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1158
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1159
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1160
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1161
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1162
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1163
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1164
+
1165
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1166
+ } break;
1003
1167
  case WSP_GGML_OP_SOFT_MAX:
1004
1168
  {
1005
1169
  int nth = 32; // SIMD width
1006
1170
 
1007
1171
  if (ne00%4 == 0) {
1172
+ while (nth < ne00/4 && nth < 256) {
1173
+ nth *= 2;
1174
+ }
1008
1175
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
1009
1176
  } else {
1010
- do {
1177
+ while (nth < ne00 && nth < 1024) {
1011
1178
  nth *= 2;
1012
- } while (nth <= ne00 && nth <= 1024);
1013
- nth /= 2;
1179
+ }
1014
1180
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
1015
1181
  }
1016
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1017
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1018
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1019
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1020
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1021
- [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1182
+
1183
+ const float scale = ((float *) dst->op_params)[0];
1184
+
1185
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1186
+ if (id_src1) {
1187
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1188
+ }
1189
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1190
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1191
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1192
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1193
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1194
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1022
1195
 
1023
1196
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1024
1197
  } break;
@@ -1047,9 +1220,13 @@ void wsp_ggml_metal_graph_compute(
1047
1220
  case WSP_GGML_OP_MUL_MAT:
1048
1221
  {
1049
1222
  WSP_GGML_ASSERT(ne00 == ne10);
1050
- WSP_GGML_ASSERT(ne03 == ne13);
1051
1223
 
1052
- const uint gqa = ne12/ne02;
1224
+ // TODO: assert that dim2 and dim3 are contiguous
1225
+ WSP_GGML_ASSERT(ne12 % ne02 == 0);
1226
+ WSP_GGML_ASSERT(ne13 % ne03 == 0);
1227
+
1228
+ const uint r2 = ne12/ne02;
1229
+ const uint r3 = ne13/ne03;
1053
1230
 
1054
1231
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1055
1232
  // to the matrix-vector kernel
@@ -1084,7 +1261,7 @@ void wsp_ggml_metal_graph_compute(
1084
1261
  !wsp_ggml_is_transposed(src1) &&
1085
1262
  src1t == WSP_GGML_TYPE_F32 &&
1086
1263
  ne00 % 32 == 0 && ne00 >= 64 &&
1087
- ne11 > ne11_mm_min) {
1264
+ (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
1088
1265
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1089
1266
  switch (src0->type) {
1090
1267
  case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1114,9 +1291,10 @@ void wsp_ggml_metal_graph_compute(
1114
1291
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1115
1292
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1116
1293
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1117
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
1294
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1295
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1118
1296
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1119
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1297
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1120
1298
  } else {
1121
1299
  int nth0 = 32;
1122
1300
  int nth1 = 1;
@@ -1127,6 +1305,7 @@ void wsp_ggml_metal_graph_compute(
1127
1305
  switch (src0t) {
1128
1306
  case WSP_GGML_TYPE_F32:
1129
1307
  {
1308
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1130
1309
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
1131
1310
  nrows = 4;
1132
1311
  } break;
@@ -1134,102 +1313,77 @@ void wsp_ggml_metal_graph_compute(
1134
1313
  {
1135
1314
  nth0 = 32;
1136
1315
  nth1 = 1;
1137
- if (ne11 * ne12 < 4) {
1138
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1139
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1140
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1141
- nrows = ne11;
1316
+ if (src1t == WSP_GGML_TYPE_F32) {
1317
+ if (ne11 * ne12 < 4) {
1318
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1319
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1320
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1321
+ nrows = ne11;
1322
+ } else {
1323
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1324
+ nrows = 4;
1325
+ }
1142
1326
  } else {
1143
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1327
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
1144
1328
  nrows = 4;
1145
1329
  }
1146
1330
  } break;
1147
1331
  case WSP_GGML_TYPE_Q4_0:
1148
1332
  {
1149
- WSP_GGML_ASSERT(ne02 == 1);
1150
- WSP_GGML_ASSERT(ne12 == 1);
1151
-
1152
1333
  nth0 = 8;
1153
1334
  nth1 = 8;
1154
1335
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
1155
1336
  } break;
1156
1337
  case WSP_GGML_TYPE_Q4_1:
1157
1338
  {
1158
- WSP_GGML_ASSERT(ne02 == 1);
1159
- WSP_GGML_ASSERT(ne12 == 1);
1160
-
1161
1339
  nth0 = 8;
1162
1340
  nth1 = 8;
1163
1341
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1164
1342
  } break;
1165
1343
  case WSP_GGML_TYPE_Q5_0:
1166
1344
  {
1167
- WSP_GGML_ASSERT(ne02 == 1);
1168
- WSP_GGML_ASSERT(ne12 == 1);
1169
-
1170
1345
  nth0 = 8;
1171
1346
  nth1 = 8;
1172
1347
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1173
1348
  } break;
1174
1349
  case WSP_GGML_TYPE_Q5_1:
1175
1350
  {
1176
- WSP_GGML_ASSERT(ne02 == 1);
1177
- WSP_GGML_ASSERT(ne12 == 1);
1178
-
1179
1351
  nth0 = 8;
1180
1352
  nth1 = 8;
1181
1353
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
1182
1354
  } break;
1183
1355
  case WSP_GGML_TYPE_Q8_0:
1184
1356
  {
1185
- WSP_GGML_ASSERT(ne02 == 1);
1186
- WSP_GGML_ASSERT(ne12 == 1);
1187
-
1188
1357
  nth0 = 8;
1189
1358
  nth1 = 8;
1190
1359
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
1191
1360
  } break;
1192
1361
  case WSP_GGML_TYPE_Q2_K:
1193
1362
  {
1194
- WSP_GGML_ASSERT(ne02 == 1);
1195
- WSP_GGML_ASSERT(ne12 == 1);
1196
-
1197
1363
  nth0 = 2;
1198
1364
  nth1 = 32;
1199
1365
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
1200
1366
  } break;
1201
1367
  case WSP_GGML_TYPE_Q3_K:
1202
1368
  {
1203
- WSP_GGML_ASSERT(ne02 == 1);
1204
- WSP_GGML_ASSERT(ne12 == 1);
1205
-
1206
1369
  nth0 = 2;
1207
1370
  nth1 = 32;
1208
1371
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
1209
1372
  } break;
1210
1373
  case WSP_GGML_TYPE_Q4_K:
1211
1374
  {
1212
- WSP_GGML_ASSERT(ne02 == 1);
1213
- WSP_GGML_ASSERT(ne12 == 1);
1214
-
1215
1375
  nth0 = 4; //1;
1216
1376
  nth1 = 8; //32;
1217
1377
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
1218
1378
  } break;
1219
1379
  case WSP_GGML_TYPE_Q5_K:
1220
1380
  {
1221
- WSP_GGML_ASSERT(ne02 == 1);
1222
- WSP_GGML_ASSERT(ne12 == 1);
1223
-
1224
1381
  nth0 = 2;
1225
1382
  nth1 = 32;
1226
1383
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
1227
1384
  } break;
1228
1385
  case WSP_GGML_TYPE_Q6_K:
1229
1386
  {
1230
- WSP_GGML_ASSERT(ne02 == 1);
1231
- WSP_GGML_ASSERT(ne12 == 1);
1232
-
1233
1387
  nth0 = 2;
1234
1388
  nth1 = 32;
1235
1389
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1258,32 +1412,125 @@ void wsp_ggml_metal_graph_compute(
1258
1412
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1259
1413
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1260
1414
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1261
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
1415
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1416
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1262
1417
 
1263
1418
  if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
1264
1419
  src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
1265
1420
  src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
1266
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1421
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1267
1422
  }
1268
1423
  else if (src0t == WSP_GGML_TYPE_Q4_K) {
1269
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1424
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1270
1425
  }
1271
1426
  else if (src0t == WSP_GGML_TYPE_Q3_K) {
1272
1427
  #ifdef WSP_GGML_QKK_64
1273
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1428
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1274
1429
  #else
1275
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1430
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1276
1431
  #endif
1277
1432
  }
1278
1433
  else if (src0t == WSP_GGML_TYPE_Q5_K) {
1279
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1434
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1280
1435
  }
1281
1436
  else if (src0t == WSP_GGML_TYPE_Q6_K) {
1282
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1437
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1283
1438
  } else {
1284
1439
  int64_t ny = (ne11 + nrows - 1)/nrows;
1285
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1440
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1441
+ }
1442
+ }
1443
+ } break;
1444
+ case WSP_GGML_OP_MUL_MAT_ID:
1445
+ {
1446
+ //WSP_GGML_ASSERT(ne00 == ne10);
1447
+ //WSP_GGML_ASSERT(ne03 == ne13);
1448
+
1449
+ WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
1450
+
1451
+ const int n_as = ne00;
1452
+
1453
+ // TODO: make this more general
1454
+ WSP_GGML_ASSERT(n_as <= 8);
1455
+
1456
+ struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2];
1457
+
1458
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
1459
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
1460
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
1461
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23);
1462
+
1463
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20);
1464
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1465
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1466
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23);
1467
+
1468
+ const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
1469
+
1470
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2));
1471
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1));
1472
+
1473
+ WSP_GGML_ASSERT(ne20 % 32 == 0);
1474
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
1475
+ //WSP_GGML_ASSERT(ne20 >= 64);
1476
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1477
+
1478
+ const uint r2 = ne12/ne22;
1479
+ const uint r3 = ne13/ne23;
1480
+
1481
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1482
+ // to the matrix-vector kernel
1483
+ int ne11_mm_min = 0;
1484
+
1485
+ const int idx = ((int32_t *) dst->op_params)[0];
1486
+
1487
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1488
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1489
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1490
+ ne11 > ne11_mm_min) {
1491
+ switch (src2->type) {
1492
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1493
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
1494
+ case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
1495
+ case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
1496
+ case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
1497
+ case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
1498
+ case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
1499
+ case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
1500
+ case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
1501
+ case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
1502
+ case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
1503
+ case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
1504
+ default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1505
+ }
1506
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1507
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1508
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1509
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
1510
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
1511
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
1512
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
1513
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1514
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1515
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1516
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1517
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1518
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1519
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1520
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1521
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
1522
+ // TODO: how to make this an array? read Metal docs
1523
+ for (int j = 0; j < n_as; ++j) {
1524
+ struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
1525
+
1526
+ size_t offs_src_cur = 0;
1527
+ id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1528
+
1529
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
1286
1530
  }
1531
+
1532
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1533
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1287
1534
  }
1288
1535
  } break;
1289
1536
  case WSP_GGML_OP_GET_ROWS:
@@ -1322,15 +1569,19 @@ void wsp_ggml_metal_graph_compute(
1322
1569
  float eps;
1323
1570
  memcpy(&eps, dst->op_params, sizeof(float));
1324
1571
 
1325
- const int nth = MIN(512, ne00);
1572
+ int nth = 32; // SIMD width
1573
+
1574
+ while (nth < ne00/4 && nth < 1024) {
1575
+ nth *= 2;
1576
+ }
1326
1577
 
1327
1578
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1328
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1329
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1330
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1331
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1332
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1333
- [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1579
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1580
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1581
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1582
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1583
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1584
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1334
1585
 
1335
1586
  const int64_t nrows = wsp_ggml_nrows(src0);
1336
1587
 
@@ -1404,7 +1655,8 @@ void wsp_ggml_metal_graph_compute(
1404
1655
  const int n_past = ((int32_t *) dst->op_params)[0];
1405
1656
  const int n_dims = ((int32_t *) dst->op_params)[1];
1406
1657
  const int mode = ((int32_t *) dst->op_params)[2];
1407
- const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
1658
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
1659
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
1408
1660
 
1409
1661
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1410
1662
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1452,18 +1704,100 @@ void wsp_ggml_metal_graph_compute(
1452
1704
 
1453
1705
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1454
1706
  } break;
1707
+ case WSP_GGML_OP_IM2COL:
1708
+ {
1709
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
1710
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
1711
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
1712
+
1713
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
1714
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
1715
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
1716
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
1717
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
1718
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
1719
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
1720
+
1721
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
1722
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
1723
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
1724
+ const int32_t IW = src1->ne[0];
1725
+
1726
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
1727
+ const int32_t KW = src0->ne[0];
1728
+
1729
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
1730
+ const int32_t OW = dst->ne[1];
1731
+
1732
+ const int32_t CHW = IC * KH * KW;
1733
+
1734
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
1735
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
1736
+
1737
+ switch (src0->type) {
1738
+ case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break;
1739
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
1740
+ default: WSP_GGML_ASSERT(false);
1741
+ };
1742
+
1743
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
1744
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1745
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
1746
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
1747
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
1748
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
1749
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
1750
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
1751
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
1752
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
1753
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
1754
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
1755
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
1756
+
1757
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1758
+ } break;
1759
+ case WSP_GGML_OP_ARGSORT:
1760
+ {
1761
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
1762
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);
1763
+
1764
+ const int nrows = wsp_ggml_nrows(src0);
1765
+
1766
+ enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
1767
+
1768
+ switch (order) {
1769
+ case WSP_GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
1770
+ case WSP_GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
1771
+ default: WSP_GGML_ASSERT(false);
1772
+ };
1773
+
1774
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1775
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1776
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1777
+
1778
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
1779
+ } break;
1455
1780
  case WSP_GGML_OP_DUP:
1456
1781
  case WSP_GGML_OP_CPY:
1457
1782
  case WSP_GGML_OP_CONT:
1458
1783
  {
1459
- const int nth = MIN(1024, ne00);
1784
+ WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
1785
+
1786
+ int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
1460
1787
 
1461
1788
  switch (src0t) {
1462
1789
  case WSP_GGML_TYPE_F32:
1463
1790
  {
1791
+ WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0);
1792
+
1464
1793
  switch (dstt) {
1465
- case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
1466
- case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
1794
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
1795
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
1796
+ case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
1797
+ case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
1798
+ case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
1799
+ //case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
1800
+ //case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
1467
1801
  default: WSP_GGML_ASSERT(false && "not implemented");
1468
1802
  };
1469
1803
  } break;
@@ -1538,81 +1872,148 @@ void wsp_ggml_metal_graph_compute(
1538
1872
 
1539
1873
  // backend interface
1540
1874
 
1541
- static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
1542
- return "Metal";
1875
+ static id<MTLDevice> g_backend_device = nil;
1876
+ static int g_backend_device_ref_count = 0;
1543
1877
 
1544
- UNUSED(backend);
1878
+ static id<MTLDevice> wsp_ggml_backend_metal_get_device(void) {
1879
+ if (g_backend_device == nil) {
1880
+ g_backend_device = MTLCreateSystemDefaultDevice();
1881
+ }
1882
+
1883
+ g_backend_device_ref_count++;
1884
+
1885
+ return g_backend_device;
1545
1886
  }
1546
1887
 
1547
- static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
1548
- struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
1549
- wsp_ggml_metal_free(ctx);
1550
- free(backend);
1888
+ static void wsp_ggml_backend_metal_free_device(void) {
1889
+ assert(g_backend_device_ref_count > 0);
1890
+
1891
+ g_backend_device_ref_count--;
1892
+
1893
+ if (g_backend_device_ref_count == 0) {
1894
+ g_backend_device = nil;
1895
+ }
1551
1896
  }
1552
1897
 
1553
1898
  static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
1554
- return (void *)buffer->context;
1899
+ struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
1900
+
1901
+ return ctx->data;
1555
1902
  }
1556
1903
 
1557
1904
  static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
1558
- free(buffer->context);
1905
+ struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
1906
+
1907
+ wsp_ggml_backend_metal_free_device();
1908
+
1909
+ free(ctx->data);
1910
+ free(ctx);
1911
+
1912
+ UNUSED(buffer);
1913
+ }
1914
+
1915
+ static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1916
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
1917
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1918
+
1919
+ memcpy((char *)tensor->data + offset, data, size);
1920
+
1921
+ UNUSED(buffer);
1922
+ }
1923
+
1924
+ static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1925
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
1926
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1927
+
1928
+ memcpy(data, (const char *)tensor->data + offset, size);
1929
+
1930
+ UNUSED(buffer);
1931
+ }
1932
+
1933
+ static void wsp_ggml_backend_metal_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
1934
+ wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
1935
+
1936
+ UNUSED(buffer);
1937
+ }
1938
+
1939
+ static void wsp_ggml_backend_metal_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
1940
+ wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
1941
+
1559
1942
  UNUSED(buffer);
1560
1943
  }
1561
1944
 
1562
1945
  static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = {
1563
- /* .free_buffer = */ wsp_ggml_backend_metal_buffer_free_buffer,
1564
- /* .get_base = */ wsp_ggml_backend_metal_buffer_get_base,
1565
- /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
1566
- /* .init_tensor = */ NULL, // no initialization required
1567
- /* .free_tensor = */ NULL, // no cleanup required
1946
+ /* .free_buffer = */ wsp_ggml_backend_metal_buffer_free_buffer,
1947
+ /* .get_base = */ wsp_ggml_backend_metal_buffer_get_base,
1948
+ /* .init_tensor = */ NULL,
1949
+ /* .set_tensor = */ wsp_ggml_backend_metal_buffer_set_tensor,
1950
+ /* .get_tensor = */ wsp_ggml_backend_metal_buffer_get_tensor,
1951
+ /* .cpy_tensor_from = */ wsp_ggml_backend_metal_buffer_cpy_tensor_from,
1952
+ /* .cpy_tensor_to = */ wsp_ggml_backend_metal_buffer_cpy_tensor_to,
1568
1953
  };
1569
1954
 
1570
- static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_alloc_buffer(wsp_ggml_backend_t backend, size_t size) {
1571
- struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
1955
+ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
1956
+ struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context));
1572
1957
 
1573
- void * data = wsp_ggml_metal_host_malloc(size);
1958
+ const size_t size_page = sysconf(_SC_PAGESIZE);
1574
1959
 
1575
- // TODO: set proper name of the buffers
1576
- wsp_ggml_metal_add_buffer(ctx, "backend", data, size, 0);
1960
+ size_t size_aligned = size;
1961
+ if ((size_aligned % size_page) != 0) {
1962
+ size_aligned += (size_page - (size_aligned % size_page));
1963
+ }
1577
1964
 
1578
- return wsp_ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
1965
+ ctx->data = wsp_ggml_metal_host_malloc(size);
1966
+ ctx->metal = [wsp_ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
1967
+ length:size_aligned
1968
+ options:MTLResourceStorageModeShared
1969
+ deallocator:nil];
1970
+
1971
+ return wsp_ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
1579
1972
  }
1580
1973
 
1581
- static size_t wsp_ggml_backend_metal_get_alignment(wsp_ggml_backend_t backend) {
1974
+ static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
1582
1975
  return 32;
1583
- UNUSED(backend);
1976
+ UNUSED(buft);
1584
1977
  }
1585
1978
 
1586
- static void wsp_ggml_backend_metal_set_tensor_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1587
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
1588
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1589
-
1590
- memcpy((char *)tensor->data + offset, data, size);
1979
+ static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
1980
+ return wsp_ggml_backend_is_metal(backend) || wsp_ggml_backend_is_cpu(backend);
1591
1981
 
1592
- UNUSED(backend);
1982
+ WSP_GGML_UNUSED(buft);
1593
1983
  }
1594
1984
 
1595
- static void wsp_ggml_backend_metal_get_tensor_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1596
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
1597
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
1598
-
1599
- memcpy(data, (const char *)tensor->data + offset, size);
1985
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
1986
+ static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_metal = {
1987
+ /* .iface = */ {
1988
+ /* .alloc_buffer = */ wsp_ggml_backend_metal_buffer_type_alloc_buffer,
1989
+ /* .get_alignment = */ wsp_ggml_backend_metal_buffer_type_get_alignment,
1990
+ /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
1991
+ /* .supports_backend = */ wsp_ggml_backend_metal_buffer_type_supports_backend,
1992
+ },
1993
+ /* .context = */ NULL,
1994
+ };
1600
1995
 
1601
- UNUSED(backend);
1996
+ return &wsp_ggml_backend_buffer_type_metal;
1602
1997
  }
1603
1998
 
1604
- static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
1999
+ static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
2000
+ return "Metal";
2001
+
1605
2002
  UNUSED(backend);
1606
2003
  }
1607
2004
 
1608
- static void wsp_ggml_backend_metal_cpy_tensor_from(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
1609
- wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
2005
+ static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
2006
+ struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
2007
+ wsp_ggml_metal_free(ctx);
2008
+ free(backend);
2009
+ }
1610
2010
 
2011
+ static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
1611
2012
  UNUSED(backend);
1612
2013
  }
1613
2014
 
1614
- static void wsp_ggml_backend_metal_cpy_tensor_to(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
1615
- wsp_ggml_backend_tensor_set_async(dst, src->data, 0, wsp_ggml_nbytes(src));
2015
+ static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) {
2016
+ return wsp_ggml_backend_metal_buffer_type();
1616
2017
 
1617
2018
  UNUSED(backend);
1618
2019
  }
@@ -1624,32 +2025,43 @@ static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, str
1624
2025
  }
1625
2026
 
1626
2027
  static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
1627
- return true;
2028
+ return wsp_ggml_metal_supports_op(op);
2029
+
1628
2030
  UNUSED(backend);
1629
- UNUSED(op);
1630
2031
  }
1631
2032
 
1632
2033
  static struct wsp_ggml_backend_i metal_backend_i = {
1633
- /* .get_name = */ wsp_ggml_backend_metal_name,
1634
- /* .free = */ wsp_ggml_backend_metal_free,
1635
- /* .alloc_buffer = */ wsp_ggml_backend_metal_alloc_buffer,
1636
- /* .get_alignment = */ wsp_ggml_backend_metal_get_alignment,
1637
- /* .set_tensor_async = */ wsp_ggml_backend_metal_set_tensor_async,
1638
- /* .get_tensor_async = */ wsp_ggml_backend_metal_get_tensor_async,
1639
- /* .synchronize = */ wsp_ggml_backend_metal_synchronize,
1640
- /* .cpy_tensor_from = */ wsp_ggml_backend_metal_cpy_tensor_from,
1641
- /* .cpy_tensor_to = */ wsp_ggml_backend_metal_cpy_tensor_to,
1642
- /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
1643
- /* .graph_plan_free = */ NULL,
1644
- /* .graph_plan_compute = */ NULL,
1645
- /* .graph_compute = */ wsp_ggml_backend_metal_graph_compute,
1646
- /* .supports_op = */ wsp_ggml_backend_metal_supports_op,
2034
+ /* .get_name = */ wsp_ggml_backend_metal_name,
2035
+ /* .free = */ wsp_ggml_backend_metal_free,
2036
+ /* .get_default_buffer_type = */ wsp_ggml_backend_metal_get_default_buffer_type,
2037
+ /* .set_tensor_async = */ NULL,
2038
+ /* .get_tensor_async = */ NULL,
2039
+ /* .cpy_tensor_from_async = */ NULL,
2040
+ /* .cpy_tensor_to_async = */ NULL,
2041
+ /* .synchronize = */ wsp_ggml_backend_metal_synchronize,
2042
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
2043
+ /* .graph_plan_free = */ NULL,
2044
+ /* .graph_plan_compute = */ NULL,
2045
+ /* .graph_compute = */ wsp_ggml_backend_metal_graph_compute,
2046
+ /* .supports_op = */ wsp_ggml_backend_metal_supports_op,
1647
2047
  };
1648
2048
 
2049
+ // TODO: make a common log callback for all backends in ggml-backend
2050
+ static void wsp_ggml_backend_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) {
2051
+ fprintf(stderr, "%s", msg);
2052
+
2053
+ UNUSED(level);
2054
+ UNUSED(user_data);
2055
+ }
2056
+
1649
2057
  wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
1650
- struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
2058
+ wsp_ggml_metal_log_set_callback(wsp_ggml_backend_log_callback, NULL);
1651
2059
 
1652
- ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);
2060
+ struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);
2061
+
2062
+ if (ctx == NULL) {
2063
+ return NULL;
2064
+ }
1653
2065
 
1654
2066
  wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend));
1655
2067
 
@@ -1666,7 +2078,26 @@ bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) {
1666
2078
  }
1667
2079
 
1668
2080
  void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
2081
+ WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));
2082
+
1669
2083
  struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
1670
2084
 
1671
2085
  wsp_ggml_metal_set_n_cb(ctx, n_cb);
1672
2086
  }
2087
+
2088
+ bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int family) {
2089
+ WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));
2090
+
2091
+ struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
2092
+
2093
+ return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
2094
+ }
2095
+
2096
+ wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
2097
+
2098
+ wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) {
2099
+ return wsp_ggml_backend_metal_init();
2100
+
2101
+ WSP_GGML_UNUSED(params);
2102
+ WSP_GGML_UNUSED(user_data);
2103
+ }