whisper.rn 0.4.0-rc.4 → 0.4.0-rc.6

This diff shows the publicly available content of package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (49)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +5 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/WhisperContext.java +57 -134
  6. package/android/src/main/jni-utils.h +76 -0
  7. package/android/src/main/jni.cpp +188 -112
  8. package/cpp/README.md +1 -1
  9. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  10. package/cpp/coreml/whisper-encoder.h +4 -0
  11. package/cpp/coreml/whisper-encoder.mm +4 -2
  12. package/cpp/ggml-alloc.c +55 -19
  13. package/cpp/ggml-alloc.h +8 -1
  14. package/cpp/ggml-backend-impl.h +46 -21
  15. package/cpp/ggml-backend.c +563 -156
  16. package/cpp/ggml-backend.h +62 -17
  17. package/cpp/ggml-impl.h +1 -1
  18. package/cpp/ggml-metal-whisper.metal +2444 -359
  19. package/cpp/ggml-metal.h +7 -1
  20. package/cpp/ggml-metal.m +1105 -197
  21. package/cpp/ggml-quants.c +66 -61
  22. package/cpp/ggml-quants.h +40 -40
  23. package/cpp/ggml.c +1040 -1590
  24. package/cpp/ggml.h +109 -30
  25. package/cpp/rn-audioutils.cpp +68 -0
  26. package/cpp/rn-audioutils.h +14 -0
  27. package/cpp/rn-whisper-log.h +11 -0
  28. package/cpp/rn-whisper.cpp +143 -59
  29. package/cpp/rn-whisper.h +48 -15
  30. package/cpp/whisper.cpp +1635 -928
  31. package/cpp/whisper.h +55 -10
  32. package/ios/RNWhisper.mm +7 -7
  33. package/ios/RNWhisperAudioUtils.h +0 -2
  34. package/ios/RNWhisperAudioUtils.m +0 -56
  35. package/ios/RNWhisperContext.h +3 -11
  36. package/ios/RNWhisperContext.mm +68 -137
  37. package/lib/commonjs/index.js.map +1 -1
  38. package/lib/commonjs/version.json +1 -1
  39. package/lib/module/index.js.map +1 -1
  40. package/lib/module/version.json +1 -1
  41. package/lib/typescript/index.d.ts +5 -0
  42. package/lib/typescript/index.d.ts.map +1 -1
  43. package/package.json +6 -5
  44. package/src/index.ts +5 -0
  45. package/src/version.json +1 -1
  46. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
  47. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
  48. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  49. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/ggml-metal.m CHANGED
@@ -62,11 +62,15 @@ struct wsp_ggml_metal_context {
  WSP_GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  WSP_GGML_METAL_DECL_KERNEL(mul);
  WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+ WSP_GGML_METAL_DECL_KERNEL(div);
+ WSP_GGML_METAL_DECL_KERNEL(div_row);
  WSP_GGML_METAL_DECL_KERNEL(scale);
  WSP_GGML_METAL_DECL_KERNEL(scale_4);
- WSP_GGML_METAL_DECL_KERNEL(silu);
+ WSP_GGML_METAL_DECL_KERNEL(tanh);
  WSP_GGML_METAL_DECL_KERNEL(relu);
  WSP_GGML_METAL_DECL_KERNEL(gelu);
+ WSP_GGML_METAL_DECL_KERNEL(gelu_quick);
+ WSP_GGML_METAL_DECL_KERNEL(silu);
  WSP_GGML_METAL_DECL_KERNEL(soft_max);
  WSP_GGML_METAL_DECL_KERNEL(soft_max_4);
  WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -84,8 +88,10 @@ struct wsp_ggml_metal_context {
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_K);
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
  WSP_GGML_METAL_DECL_KERNEL(rms_norm);
+ WSP_GGML_METAL_DECL_KERNEL(group_norm);
  WSP_GGML_METAL_DECL_KERNEL(norm);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -99,6 +105,21 @@ struct wsp_ggml_metal_context {
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+ //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+ //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+ //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -111,14 +132,39 @@ struct wsp_ggml_metal_context {
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(rope_f32);
  WSP_GGML_METAL_DECL_KERNEL(rope_f16);
  WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
+ WSP_GGML_METAL_DECL_KERNEL(im2col_f16);
+ WSP_GGML_METAL_DECL_KERNEL(upscale_f32);
+ WSP_GGML_METAL_DECL_KERNEL(pad_f32);
+ WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+ WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+ WSP_GGML_METAL_DECL_KERNEL(leaky_relu_f32);
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+ //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+ //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
  WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f32);
  WSP_GGML_METAL_DECL_KERNEL(concat);
  WSP_GGML_METAL_DECL_KERNEL(sqr);
+ WSP_GGML_METAL_DECL_KERNEL(sum_rows);

  #undef WSP_GGML_METAL_DECL_KERNEL
  };
@@ -126,7 +172,7 @@ struct wsp_ggml_metal_context {
  // MSL code
  // TODO: move the contents here when ready
  // for now it is easier to work in a separate file
- static NSString * const msl_library_source = @"see metal.metal";
+ //static NSString * const msl_library_source = @"see metal.metal";

  // Here to assist with NSBundle Path Hack
  @interface WSPGGMLMetalClass : NSObject
@@ -142,7 +188,8 @@ void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void *
  wsp_ggml_metal_log_user_data = user_data;
  }

- static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format, ...){
+ WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
+ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * format, ...){
  if (wsp_ggml_metal_log_callback != NULL) {
  va_list args;
  va_start(args, format);
@@ -152,6 +199,8 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format
  wsp_ggml_metal_log_callback(level, buffer, wsp_ggml_metal_log_user_data);
  } else {
  char* buffer2 = malloc(len+1);
+ va_end(args);
+ va_start(args, format);
  vsnprintf(buffer2, len+1, format, args);
  buffer2[len] = 0;
  wsp_ggml_metal_log_callback(level, buffer2, wsp_ggml_metal_log_user_data);
@@ -161,12 +210,10 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format
  }
  }

-
-
  struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

- id <MTLDevice> device;
+ id<MTLDevice> device;
  NSString * s;

  #if TARGET_OS_OSX
@@ -184,7 +231,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);

  // Configure context
- struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
+ struct wsp_ggml_metal_context * ctx = calloc(1, sizeof(struct wsp_ggml_metal_context));
  ctx->device = device;
  ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
  ctx->queue = [ctx->device newCommandQueue];
@@ -212,6 +259,9 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {

  NSString * sourcePath;
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"WSP_GGML_METAL_PATH_RESOURCES"];
+
+ WSP_GGML_METAL_LOG_INFO("%s: WSP_GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
  if (ggmlMetalPathResources) {
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
  } else {
@@ -242,6 +292,29 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  }
  }

+ #if TARGET_OS_OSX
+ // print MTL GPU family:
+ WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ if (ctx->device.maxTransferRate != 0) {
+ WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+ } else {
+ WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+ }
+ #endif
+
  // load kernels
  {
  NSError * error = nil;
@@ -263,11 +336,15 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_ADD_KERNEL(add_row);
  WSP_GGML_METAL_ADD_KERNEL(mul);
  WSP_GGML_METAL_ADD_KERNEL(mul_row);
+ WSP_GGML_METAL_ADD_KERNEL(div);
+ WSP_GGML_METAL_ADD_KERNEL(div_row);
  WSP_GGML_METAL_ADD_KERNEL(scale);
  WSP_GGML_METAL_ADD_KERNEL(scale_4);
- WSP_GGML_METAL_ADD_KERNEL(silu);
+ WSP_GGML_METAL_ADD_KERNEL(tanh);
  WSP_GGML_METAL_ADD_KERNEL(relu);
  WSP_GGML_METAL_ADD_KERNEL(gelu);
+ WSP_GGML_METAL_ADD_KERNEL(gelu_quick);
+ WSP_GGML_METAL_ADD_KERNEL(silu);
  WSP_GGML_METAL_ADD_KERNEL(soft_max);
  WSP_GGML_METAL_ADD_KERNEL(soft_max_4);
  WSP_GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -285,8 +362,10 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_K);
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K);
  WSP_GGML_METAL_ADD_KERNEL(rms_norm);
+ WSP_GGML_METAL_ADD_KERNEL(group_norm);
  WSP_GGML_METAL_ADD_KERNEL(norm);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -300,6 +379,21 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+ //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+ //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+ //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -313,42 +407,44 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
  }
  WSP_GGML_METAL_ADD_KERNEL(rope_f32);
  WSP_GGML_METAL_ADD_KERNEL(rope_f16);
  WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
+ WSP_GGML_METAL_ADD_KERNEL(im2col_f16);
+ WSP_GGML_METAL_ADD_KERNEL(upscale_f32);
+ WSP_GGML_METAL_ADD_KERNEL(pad_f32);
+ WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+ WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+ WSP_GGML_METAL_ADD_KERNEL(leaky_relu_f32);
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+ //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+ //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
  WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f32);
  WSP_GGML_METAL_ADD_KERNEL(concat);
  WSP_GGML_METAL_ADD_KERNEL(sqr);
+ WSP_GGML_METAL_ADD_KERNEL(sum_rows);

  #undef WSP_GGML_METAL_ADD_KERNEL
  }

- #if TARGET_OS_OSX
- // print MTL GPU family:
- WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
-
- // determine max supported GPU family
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
- break;
- }
- }
-
- WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- if (ctx->device.maxTransferRate != 0) {
- WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
- } else {
- WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
- }
- #endif
-
  return ctx;
  }

@@ -360,11 +456,15 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
  WSP_GGML_METAL_DEL_KERNEL(add_row);
  WSP_GGML_METAL_DEL_KERNEL(mul);
  WSP_GGML_METAL_DEL_KERNEL(mul_row);
+ WSP_GGML_METAL_DEL_KERNEL(div);
+ WSP_GGML_METAL_DEL_KERNEL(div_row);
  WSP_GGML_METAL_DEL_KERNEL(scale);
  WSP_GGML_METAL_DEL_KERNEL(scale_4);
- WSP_GGML_METAL_DEL_KERNEL(silu);
+ WSP_GGML_METAL_DEL_KERNEL(tanh);
  WSP_GGML_METAL_DEL_KERNEL(relu);
  WSP_GGML_METAL_DEL_KERNEL(gelu);
+ WSP_GGML_METAL_DEL_KERNEL(gelu_quick);
+ WSP_GGML_METAL_DEL_KERNEL(silu);
  WSP_GGML_METAL_DEL_KERNEL(soft_max);
  WSP_GGML_METAL_DEL_KERNEL(soft_max_4);
  WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -382,8 +482,10 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_K);
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
  WSP_GGML_METAL_DEL_KERNEL(rms_norm);
+ WSP_GGML_METAL_DEL_KERNEL(group_norm);
  WSP_GGML_METAL_DEL_KERNEL(norm);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -397,6 +499,21 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+ //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+ //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+ //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -410,15 +527,40 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
  }
  WSP_GGML_METAL_DEL_KERNEL(rope_f32);
  WSP_GGML_METAL_DEL_KERNEL(rope_f16);
  WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
+ WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
+ WSP_GGML_METAL_DEL_KERNEL(upscale_f32);
+ WSP_GGML_METAL_DEL_KERNEL(pad_f32);
+ WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+ WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+ WSP_GGML_METAL_DEL_KERNEL(leaky_relu_f32);
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+ //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+ //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
  WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f32);
  WSP_GGML_METAL_DEL_KERNEL(concat);
  WSP_GGML_METAL_DEL_KERNEL(sqr);
+ WSP_GGML_METAL_DEL_KERNEL(sum_rows);

  #undef WSP_GGML_METAL_DEL_KERNEL

@@ -452,6 +594,13 @@ int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
  return ctx->concur_list;
  }

+ // temporarily defined here for compatibility between ggml-backend and the old API
+ struct wsp_ggml_backend_metal_buffer_context {
+ void * data;
+
+ id<MTLBuffer> metal;
+ };
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -461,6 +610,19 @@ static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * c

  const int64_t tsize = wsp_ggml_nbytes(t);

+ // compatibility with ggml-backend
+ if (t->buffer && t->buffer->buft == wsp_ggml_backend_metal_buffer_type()) {
+ struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) t->buffer->context;
+
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+ WSP_GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+ *offs = (size_t) ioffs;
+
+ return buf_ctx->metal;
+ }
+
  // find the view that contains the tensor fully
  for (int i = 0; i < ctx->n_buffers; ++i) {
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -518,11 +680,11 @@ bool wsp_ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
  return false;
  }

- WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+ WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);

  ++ctx->n_buffers;
  } else {
@@ -542,11 +704,11 @@ bool wsp_ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
  return false;
  }

- WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+ WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
  if (i + size_step < size) {
  WSP_GGML_METAL_LOG_INFO("\n");
  }
@@ -561,7 +723,7 @@ bool wsp_ggml_metal_add_buffer(
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
- WSP_GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
+ WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
  } else {
  WSP_GGML_METAL_LOG_INFO("\n");
  }
@@ -683,6 +845,83 @@ void wsp_ggml_metal_graph_find_concurrency(
  }
  }

+ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
+ switch (op->op) {
+ case WSP_GGML_OP_UNARY:
+ switch (wsp_ggml_get_unary_op(op)) {
+ case WSP_GGML_UNARY_OP_TANH:
+ case WSP_GGML_UNARY_OP_RELU:
+ case WSP_GGML_UNARY_OP_GELU:
+ case WSP_GGML_UNARY_OP_GELU_QUICK:
+ case WSP_GGML_UNARY_OP_SILU:
+ return true;
+ default:
+ return false;
+ }
+ case WSP_GGML_OP_NONE:
+ case WSP_GGML_OP_RESHAPE:
+ case WSP_GGML_OP_VIEW:
+ case WSP_GGML_OP_TRANSPOSE:
+ case WSP_GGML_OP_PERMUTE:
+ case WSP_GGML_OP_CONCAT:
+ case WSP_GGML_OP_ADD:
+ case WSP_GGML_OP_ACC:
+ case WSP_GGML_OP_MUL:
+ case WSP_GGML_OP_DIV:
+ case WSP_GGML_OP_SCALE:
+ case WSP_GGML_OP_SQR:
+ case WSP_GGML_OP_SUM_ROWS:
+ case WSP_GGML_OP_SOFT_MAX:
+ case WSP_GGML_OP_RMS_NORM:
+ case WSP_GGML_OP_GROUP_NORM:
+ case WSP_GGML_OP_NORM:
+ case WSP_GGML_OP_ALIBI:
+ case WSP_GGML_OP_ROPE:
+ case WSP_GGML_OP_IM2COL:
+ case WSP_GGML_OP_UPSCALE:
+ case WSP_GGML_OP_PAD:
+ case WSP_GGML_OP_ARGSORT:
+ case WSP_GGML_OP_LEAKY_RELU:
+ case WSP_GGML_OP_MUL_MAT:
+ case WSP_GGML_OP_MUL_MAT_ID:
+ return true;
+ case WSP_GGML_OP_CPY:
+ case WSP_GGML_OP_DUP:
+ case WSP_GGML_OP_CONT:
+ {
+ switch (op->src[0]->type) {
+ case WSP_GGML_TYPE_F32:
+ switch (op->type) {
+ case WSP_GGML_TYPE_F16:
+ case WSP_GGML_TYPE_F32:
+ case WSP_GGML_TYPE_Q8_0:
+ case WSP_GGML_TYPE_Q4_0:
+ case WSP_GGML_TYPE_Q4_1:
+ return true;
+ default:
+ return false;
+ }
+ case WSP_GGML_TYPE_F16:
+ switch (op->type) {
+ case WSP_GGML_TYPE_F16:
+ case WSP_GGML_TYPE_F32:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ return false;
+ };
+ }
+ case WSP_GGML_OP_DIAG_MASK_INF:
+ case WSP_GGML_OP_GET_ROWS:
+ {
+ return op->ne[3] == 1;
+ }
+ default:
+ return false;
+ }
+ }
  void wsp_ggml_metal_graph_compute(
  struct wsp_ggml_metal_context * ctx,
  struct wsp_ggml_cgraph * gf) {
@@ -753,6 +992,11 @@ void wsp_ggml_metal_graph_compute(
  } break;
  }

+ if (!wsp_ggml_metal_supports_op(dst)) {
+ WSP_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(dst));
+ WSP_GGML_ASSERT(!"unsupported op");
+ }
+
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -845,25 +1089,42 @@ void wsp_ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case WSP_GGML_OP_ADD:
+ case WSP_GGML_OP_MUL:
+ case WSP_GGML_OP_DIV:
  {
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
+ const size_t offs = 0;

  bool bcast_row = false;

  int64_t nb = ne00;

- if (wsp_ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
+
  // src1 is a row
  WSP_GGML_ASSERT(ne11 == 1);

  nb = ne00 / 4;
- [encoder setComputePipelineState:ctx->pipeline_add_row];
+ switch (dst->op) {
+ case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+ case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+ case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
+ default: WSP_GGML_ASSERT(false);
+ }

  bcast_row = true;
  } else {
- [encoder setComputePipelineState:ctx->pipeline_add];
+ switch (dst->op) {
+ case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+ case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+ case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
+ default: WSP_GGML_ASSERT(false);
+ }
  }
+
+ [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -891,42 +1152,98 @@ void wsp_ggml_metal_graph_compute(
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];

  if (bcast_row) {
  const int64_t n = wsp_ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } else {
- const int nth = MIN(1024, ne0);
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
  } break;
- case WSP_GGML_OP_MUL:
+ case WSP_GGML_OP_ACC:
  {
+ WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32);
+
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));

- // utilize float4
- WSP_GGML_ASSERT(ne00 % 4 == 0);
- const int64_t nb = ne00/4;
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+ const size_t offs = ((int32_t *) dst->op_params)[3];
+
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+ if (!inplace) {
+ // run a separete kernel to cpy src->dst
+ // not sure how to avoid this
+ // TODO: make a simpler cpy_bytes kernel
+
+ const int nth = MIN(1024, ne00);
+
+ [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];

- if (wsp_ggml_nelements(src1) == ne10) {
- // src1 is a row
- WSP_GGML_ASSERT(ne11 == 1);
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
- } else {
- [encoder setComputePipelineState:ctx->pipeline_mul];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
+
+ [encoder setComputePipelineState:ctx->pipeline_add];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];

- const int64_t n = wsp_ggml_nelements(dst)/4;
+ const int nth = MIN(1024, ne0);

- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case WSP_GGML_OP_SCALE:
  {
@@ -951,16 +1268,15 @@ void wsp_ggml_metal_graph_compute(
  } break;
  case WSP_GGML_OP_UNARY:
  switch (wsp_ggml_get_unary_op(gf->nodes[i])) {
- case WSP_GGML_UNARY_OP_SILU:
+ case WSP_GGML_UNARY_OP_TANH:
  {
- [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setComputePipelineState:ctx->pipeline_tanh];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

  const int64_t n = wsp_ggml_nelements(dst);
- WSP_GGML_ASSERT(n % 4 == 0);

- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case WSP_GGML_UNARY_OP_RELU:
  {
@@ -981,6 +1297,28 @@ void wsp_ggml_metal_graph_compute(
  const int64_t n = wsp_ggml_nelements(dst);
  WSP_GGML_ASSERT(n % 4 == 0);

+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case WSP_GGML_UNARY_OP_GELU_QUICK:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = wsp_ggml_nelements(dst);
+ WSP_GGML_ASSERT(n % 4 == 0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case WSP_GGML_UNARY_OP_SILU:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = wsp_ggml_nelements(dst);
+ WSP_GGML_ASSERT(n % 4 == 0);
+
  [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  default:
@@ -1000,25 +1338,70 @@ void wsp_ggml_metal_graph_compute(
  const int64_t n = wsp_ggml_nelements(dst);
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case WSP_GGML_OP_SUM_ROWS:
+ {
+ WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type));
+
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case WSP_GGML_OP_SOFT_MAX:
  {
  int nth = 32; // SIMD width

  if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth < 256) {
+ nth *= 2;
+ }
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
  } else {
- do {
+ while (nth < ne00 && nth < 1024) {
  nth *= 2;
- } while (nth <= ne00 && nth <= 1024);
- nth /= 2;
+ }
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
  }
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+
+ const float scale = ((float *) dst->op_params)[0];
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
@@ -1047,9 +1430,13 @@ void wsp_ggml_metal_graph_compute(
  case WSP_GGML_OP_MUL_MAT:
  {
  WSP_GGML_ASSERT(ne00 == ne10);
- WSP_GGML_ASSERT(ne03 == ne13);

- const uint gqa = ne12/ne02;
+ // TODO: assert that dim2 and dim3 are contiguous
+ WSP_GGML_ASSERT(ne12 % ne02 == 0);
+ WSP_GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint r2 = ne12/ne02;
+ const uint r3 = ne13/ne03;

  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
  // to the matrix-vector kernel
@@ -1084,7 +1471,7 @@ void wsp_ggml_metal_graph_compute(
  !wsp_ggml_is_transposed(src1) &&
  src1t == WSP_GGML_TYPE_F32 &&
  ne00 % 32 == 0 && ne00 >= 64 &&
- ne11 > ne11_mm_min) {
+ (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
  switch (src0->type) {
  case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1114,9 +1501,10 @@ void wsp_ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  int nth0 = 32;
  int nth1 = 1;
@@ -1127,6 +1515,7 @@ void wsp_ggml_metal_graph_compute(
  switch (src0t) {
  case WSP_GGML_TYPE_F32:
  {
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
  nrows = 4;
  } break;
@@ -1134,102 +1523,77 @@ void wsp_ggml_metal_graph_compute(
  {
  nth0 = 32;
  nth1 = 1;
- if (ne11 * ne12 < 4) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
- nrows = ne11;
+ if (src1t == WSP_GGML_TYPE_F32) {
+ if (ne11 * ne12 < 4) {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+ nrows = ne11;
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+ nrows = 4;
+ }
  } else {
- [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
  nrows = 4;
  }
  } break;
  case WSP_GGML_TYPE_Q4_0:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
  } break;
  case WSP_GGML_TYPE_Q4_1:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
  } break;
  case WSP_GGML_TYPE_Q5_0:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
  } break;
  case WSP_GGML_TYPE_Q5_1:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
  } break;
  case WSP_GGML_TYPE_Q8_0:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
  } break;
  case WSP_GGML_TYPE_Q2_K:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
  } break;
  case WSP_GGML_TYPE_Q3_K:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
  } break;
  case WSP_GGML_TYPE_Q4_K:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 4; //1;
  nth1 = 8; //32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
  } break;
  case WSP_GGML_TYPE_Q5_K:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
  } break;
  case WSP_GGML_TYPE_Q6_K:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1258,31 +1622,281 @@ void wsp_ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];

  if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
  src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
  src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == WSP_GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == WSP_GGML_TYPE_Q3_K) {
  #ifdef WSP_GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == WSP_GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == WSP_GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case WSP_GGML_OP_MUL_MAT_ID:
+ {
+ //WSP_GGML_ASSERT(ne00 == ne10);
+ //WSP_GGML_ASSERT(ne03 == ne13);
+
+ WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
+
+ const int n_as = ((int32_t *) dst->op_params)[1];
+
+ // TODO: make this more general
+ WSP_GGML_ASSERT(n_as <= 8);
+
+ struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23);
+
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23);
+
+ const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
+
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2));
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1));
+
+ WSP_GGML_ASSERT(ne20 % 32 == 0);
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
+ //WSP_GGML_ASSERT(ne20 >= 64);
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
+
+ const uint r2 = ne12/ne22;
+ const uint r3 = ne13/ne23;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 1;
+
+ const int idx = ((int32_t *) dst->op_params)[0];
+
+ // batch size
+ WSP_GGML_ASSERT(ne01 == ne11);
+
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
+ switch (src2->type) {
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
+ case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+ case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+ case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+ case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+ case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+ case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+ case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+ case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+ case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+ case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+ default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+ }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
+ }
+
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ // use custom matrix x vector kernel
+ switch (src2t) {
+ case WSP_GGML_TYPE_F32:
+ {
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+ } break;
+ case WSP_GGML_TYPE_F16:
+ {
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+ } break;
+ case WSP_GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+ } break;
+ case WSP_GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+ } break;
+ case WSP_GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+ } break;
+ case WSP_GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+ } break;
+ case WSP_GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+ } break;
+ case WSP_GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+ } break;
+ case WSP_GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+ } break;
+ case WSP_GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+ } break;
+ case WSP_GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+ } break;
+ case WSP_GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+ } break;
+ default:
+ {
+ WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+ WSP_GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1856
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1857
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1858
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1859
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1860
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1861
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1862
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1863
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1864
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1865
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1866
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1867
+ // TODO: how to make this an array? read Metal docs
1868
+ for (int j = 0; j < n_as; ++j) {
1869
+ struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
1870
+
1871
+ size_t offs_src_cur = 0;
1872
+ id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1873
+
1874
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1875
+ }
1876
+
1877
+ if (src2t == WSP_GGML_TYPE_Q4_0 || src2t == WSP_GGML_TYPE_Q4_1 ||
1878
+ src2t == WSP_GGML_TYPE_Q5_0 || src2t == WSP_GGML_TYPE_Q5_1 || src2t == WSP_GGML_TYPE_Q8_0 ||
1879
+ src2t == WSP_GGML_TYPE_Q2_K) { // || src2t == WSP_GGML_TYPE_Q4_K) {
1880
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1881
+ }
1882
+ else if (src2t == WSP_GGML_TYPE_Q4_K) {
1883
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1884
+ }
1885
+ else if (src2t == WSP_GGML_TYPE_Q3_K) {
1886
+ #ifdef WSP_GGML_QKK_64
1887
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1888
+ #else
1889
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1890
+ #endif
1891
+ }
1892
+ else if (src2t == WSP_GGML_TYPE_Q5_K) {
1893
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1894
+ }
1895
+ else if (src2t == WSP_GGML_TYPE_Q6_K) {
1896
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1283
1897
  } else {
1284
- int64_t ny = (ne11 + nrows - 1)/nrows;
1285
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1898
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
1899
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1286
1900
  }
1287
1901
  }
1288
1902
  } break;
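The branch above chooses between the mat-mat and mat-vec MUL_MAT_ID kernels and sizes every dispatch with ceiling division so that a final partial tile of rows is still covered. A minimal sketch of that sizing rule in plain C; the helper name is illustrative and not part of the package:

    #include <stdint.h>

    // Ceiling division: number of threadgroups needed so that `tile`-sized
    // groups cover all `n` rows, including a trailing partial tile.
    static int64_t n_groups_for(int64_t n, int64_t tile) {
        return (n + tile - 1) / tile;
    }

    // e.g. the mat-mat path dispatches n_groups_for(_ne1, 32) x n_groups_for(ne21, 64)
    // threadgroups, and the Q4_K mat-vec path dispatches n_groups_for(ne21, 4)
    // threadgroups, i.e. 4 rows per group.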
@@ -1304,16 +1918,19 @@ void wsp_ggml_metal_graph_compute(
 default: WSP_GGML_ASSERT(false && "not implemented");
 }

- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
 [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
 [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
-
- const int64_t n = wsp_ggml_nelements(src1);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
 } break;
 case WSP_GGML_OP_RMS_NORM:
 {
@@ -1322,20 +1939,56 @@ void wsp_ggml_metal_graph_compute(
 float eps;
 memcpy(&eps, dst->op_params, sizeof(float));

- const int nth = MIN(512, ne00);
+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }

 [encoder setComputePipelineState:ctx->pipeline_rms_norm];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
- [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

 const int64_t nrows = wsp_ggml_nrows(src0);

 [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
+ case WSP_GGML_OP_GROUP_NORM:
+ {
+ WSP_GGML_ASSERT(ne00 % 4 == 0);
+
+ //float eps;
+ //memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+ int nth = 32; // SIMD width
+
+ //while (nth < ne00/4 && nth < 1024) {
+ // nth *= 2;
+ //}
+
+ [encoder setComputePipelineState:ctx->pipeline_group_norm];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
 case WSP_GGML_OP_NORM:
 {
 float eps;
@@ -1404,7 +2057,8 @@ void wsp_ggml_metal_graph_compute(
 const int n_past = ((int32_t *) dst->op_params)[0];
 const int n_dims = ((int32_t *) dst->op_params)[1];
 const int mode = ((int32_t *) dst->op_params)[2];
- const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1452,18 +2106,175 @@ void wsp_ggml_metal_graph_compute(

 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
+ case WSP_GGML_OP_IM2COL:
+ {
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
+ const int32_t IW = src1->ne[0];
+
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
+ const int32_t KW = src0->ne[0];
+
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
+ const int32_t OW = dst->ne[1];
+
+ const int32_t CHW = IC * KH * KW;
+
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+ switch (src0->type) {
+ case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break;
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+ default: WSP_GGML_ASSERT(false);
+ };
+
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+ } break;
+ case WSP_GGML_OP_UPSCALE:
+ {
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+
+ const int sf = dst->op_params[0];
+
+ [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case WSP_GGML_OP_PAD:
+ {
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+
+ [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case WSP_GGML_OP_ARGSORT:
+ {
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);
+
+ const int nrows = wsp_ggml_nrows(src0);
+
+ enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
+
+ switch (order) {
+ case WSP_GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
+ case WSP_GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+ default: WSP_GGML_ASSERT(false);
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+ } break;
+ case WSP_GGML_OP_LEAKY_RELU:
+ {
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+
+ float slope;
+ memcpy(&slope, dst->op_params, sizeof(float));
+
+ [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+ const int64_t n = wsp_ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
 case WSP_GGML_OP_DUP:
 case WSP_GGML_OP_CPY:
 case WSP_GGML_OP_CONT:
 {
- const int nth = MIN(1024, ne00);
+ WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
+
+ int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));

 switch (src0t) {
 case WSP_GGML_TYPE_F32:
 {
+ WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0);
+
 switch (dstt) {
- case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
- case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+ case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+ case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+ case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+ //case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+ //case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
 default: WSP_GGML_ASSERT(false && "not implemented");
 };
 } break;
@@ -1471,7 +2282,7 @@ void wsp_ggml_metal_graph_compute(
 {
 switch (dstt) {
 case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
- case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
 default: WSP_GGML_ASSERT(false && "not implemented");
 };
 } break;
@@ -1538,81 +2349,148 @@ void wsp_ggml_metal_graph_compute(

 // backend interface

- static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
- return "Metal";
+ static id<MTLDevice> g_backend_device = nil;
+ static int g_backend_device_ref_count = 0;

- UNUSED(backend);
+ static id<MTLDevice> wsp_ggml_backend_metal_get_device(void) {
+ if (g_backend_device == nil) {
+ g_backend_device = MTLCreateSystemDefaultDevice();
+ }
+
+ g_backend_device_ref_count++;
+
+ return g_backend_device;
 }

- static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
- struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
- wsp_ggml_metal_free(ctx);
- free(backend);
+ static void wsp_ggml_backend_metal_free_device(void) {
+ assert(g_backend_device_ref_count > 0);
+
+ g_backend_device_ref_count--;
+
+ if (g_backend_device_ref_count == 0) {
+ g_backend_device = nil;
+ }
 }

 static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
- return (void *)buffer->context;
+ struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
+
+ return ctx->data;
 }

 static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
- free(buffer->context);
+ struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
+
+ wsp_ggml_backend_metal_free_device();
+
+ free(ctx->data);
+ free(ctx);
+
+ UNUSED(buffer);
+ }
+
+ static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(buffer);
+ }
+
+ static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(buffer);
+ }
+
+ static void wsp_ggml_backend_metal_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+ wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
+
+ UNUSED(buffer);
+ }
+
+ static void wsp_ggml_backend_metal_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+ wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
+
 UNUSED(buffer);
 }

 static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = {
- /* .free_buffer = */ wsp_ggml_backend_metal_buffer_free_buffer,
- /* .get_base = */ wsp_ggml_backend_metal_buffer_get_base,
- /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
- /* .init_tensor = */ NULL, // no initialization required
- /* .free_tensor = */ NULL, // no cleanup required
+ /* .free_buffer = */ wsp_ggml_backend_metal_buffer_free_buffer,
+ /* .get_base = */ wsp_ggml_backend_metal_buffer_get_base,
+ /* .init_tensor = */ NULL,
+ /* .set_tensor = */ wsp_ggml_backend_metal_buffer_set_tensor,
+ /* .get_tensor = */ wsp_ggml_backend_metal_buffer_get_tensor,
+ /* .cpy_tensor_from = */ wsp_ggml_backend_metal_buffer_cpy_tensor_from,
+ /* .cpy_tensor_to = */ wsp_ggml_backend_metal_buffer_cpy_tensor_to,
 };

- static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_alloc_buffer(wsp_ggml_backend_t backend, size_t size) {
- struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
+ struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context));

- void * data = wsp_ggml_metal_host_malloc(size);
+ const size_t size_page = sysconf(_SC_PAGESIZE);

- // TODO: set proper name of the buffers
- wsp_ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+ size_t size_aligned = size;
+ if ((size_aligned % size_page) != 0) {
+ size_aligned += (size_page - (size_aligned % size_page));
+ }

- return wsp_ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+ ctx->data = wsp_ggml_metal_host_malloc(size);
+ ctx->metal = [wsp_ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+ length:size_aligned
+ options:MTLResourceStorageModeShared
+ deallocator:nil];
+
+ return wsp_ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
 }

- static size_t wsp_ggml_backend_metal_get_alignment(wsp_ggml_backend_t backend) {
+ static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
 return 32;
- UNUSED(backend);
+ UNUSED(buft);
 }

- static void wsp_ggml_backend_metal_set_tensor_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+ static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
+ return wsp_ggml_backend_is_metal(backend) || wsp_ggml_backend_is_cpu(backend);

- memcpy((char *)tensor->data + offset, data, size);
-
- UNUSED(backend);
+ WSP_GGML_UNUSED(buft);
 }

- static void wsp_ggml_backend_metal_get_tensor_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
- WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
- WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
- memcpy(data, (const char *)tensor->data + offset, size);
+ wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
+ static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_metal = {
+ /* .iface = */ {
+ /* .alloc_buffer = */ wsp_ggml_backend_metal_buffer_type_alloc_buffer,
+ /* .get_alignment = */ wsp_ggml_backend_metal_buffer_type_get_alignment,
+ /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
+ /* .supports_backend = */ wsp_ggml_backend_metal_buffer_type_supports_backend,
+ },
+ /* .context = */ NULL,
+ };

- UNUSED(backend);
+ return &wsp_ggml_backend_buffer_type_metal;
 }

- static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
+ static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
+ return "Metal";
+
 UNUSED(backend);
 }

- static void wsp_ggml_backend_metal_cpy_tensor_from(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
- wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
+ static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
+ struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+ wsp_ggml_metal_free(ctx);
+ free(backend);
+ }

+ static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
 UNUSED(backend);
 }

- static void wsp_ggml_backend_metal_cpy_tensor_to(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
- wsp_ggml_backend_tensor_set_async(dst, src->data, 0, wsp_ggml_nbytes(src));
+ static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) {
+ return wsp_ggml_backend_metal_buffer_type();

 UNUSED(backend);
 }
@@ -1624,32 +2502,43 @@ static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, str
 }

 static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
- return true;
+ return wsp_ggml_metal_supports_op(op);
+
 UNUSED(backend);
- UNUSED(op);
 }

 static struct wsp_ggml_backend_i metal_backend_i = {
- /* .get_name = */ wsp_ggml_backend_metal_name,
- /* .free = */ wsp_ggml_backend_metal_free,
- /* .alloc_buffer = */ wsp_ggml_backend_metal_alloc_buffer,
- /* .get_alignment = */ wsp_ggml_backend_metal_get_alignment,
- /* .set_tensor_async = */ wsp_ggml_backend_metal_set_tensor_async,
- /* .get_tensor_async = */ wsp_ggml_backend_metal_get_tensor_async,
- /* .synchronize = */ wsp_ggml_backend_metal_synchronize,
- /* .cpy_tensor_from = */ wsp_ggml_backend_metal_cpy_tensor_from,
- /* .cpy_tensor_to = */ wsp_ggml_backend_metal_cpy_tensor_to,
- /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
- /* .graph_plan_free = */ NULL,
- /* .graph_plan_compute = */ NULL,
- /* .graph_compute = */ wsp_ggml_backend_metal_graph_compute,
- /* .supports_op = */ wsp_ggml_backend_metal_supports_op,
+ /* .get_name = */ wsp_ggml_backend_metal_name,
+ /* .free = */ wsp_ggml_backend_metal_free,
+ /* .get_default_buffer_type = */ wsp_ggml_backend_metal_get_default_buffer_type,
+ /* .set_tensor_async = */ NULL,
+ /* .get_tensor_async = */ NULL,
+ /* .cpy_tensor_from_async = */ NULL,
+ /* .cpy_tensor_to_async = */ NULL,
+ /* .synchronize = */ wsp_ggml_backend_metal_synchronize,
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ wsp_ggml_backend_metal_graph_compute,
+ /* .supports_op = */ wsp_ggml_backend_metal_supports_op,
 };

+ // TODO: make a common log callback for all backends in ggml-backend
+ static void wsp_ggml_backend_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) {
+ fprintf(stderr, "%s", msg);
+
+ UNUSED(level);
+ UNUSED(user_data);
+ }
+
 wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
- struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
+ wsp_ggml_metal_log_set_callback(wsp_ggml_backend_log_callback, NULL);
+
+ struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);

- ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);
+ if (ctx == NULL) {
+ return NULL;
+ }

 wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend));

@@ -1666,7 +2555,26 @@ bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) {
 }

 void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
+ WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));
+
 struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;

 wsp_ggml_metal_set_n_cb(ctx, n_cb);
 }
+
+ bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int family) {
+ WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));
+
+ struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+ return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+ }
+
+ wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+ wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) {
+ return wsp_ggml_backend_metal_init();
+
+ WSP_GGML_UNUSED(params);
+ WSP_GGML_UNUSED(user_data);
+ }
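The hunks above move the Metal backend from per-backend buffer helpers to the shared buffer-type interface and add backend registration plus a GPU-family query. A minimal sketch of how client code might drive these entry points, assuming the declarations from ggml-backend.h and ggml-metal.h elsewhere in this diff; the header paths, helper name, and fallback policy are illustrative only:

    #include "ggml-backend.h"
    #include "ggml-metal.h"

    // Create the Metal backend; returns NULL when wsp_ggml_metal_init() fails,
    // in which case the caller is expected to fall back to the CPU backend.
    static wsp_ggml_backend_t create_metal_backend(void) {
        wsp_ggml_backend_t backend = wsp_ggml_backend_metal_init();
        if (backend == NULL) {
            return NULL;
        }

        // limit the number of command buffers used per graph
        wsp_ggml_backend_metal_set_n_cb(backend, 1);

        // Apple7 (A14/M1) and newer GPUs get the mat-mat kernels; older chips
        // stay on the mat-vec kernels, as noted in the MUL_MAT_ID hunk above
        if (!wsp_ggml_backend_metal_supports_family(backend, 7)) {
            // keep the backend, just expect the slower matrix-vector path
        }

        // tensor buffers for this backend come from the shared Metal buffer type
        wsp_ggml_backend_buffer_type_t buft = wsp_ggml_backend_metal_buffer_type();
        (void) buft; // pass this to the allocator / scheduler as needed

        return backend;
    }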