whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/README.md +6 -6
  2. package/android/build.gradle +4 -0
  3. package/android/src/main/CMakeLists.txt +7 -0
  4. package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
  5. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  6. package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
  7. package/android/src/main/jni-utils.h +76 -0
  8. package/android/src/main/jni.cpp +188 -109
  9. package/cpp/README.md +1 -1
  10. package/cpp/coreml/whisper-encoder-impl.h +1 -1
  11. package/cpp/coreml/whisper-encoder.h +4 -0
  12. package/cpp/coreml/whisper-encoder.mm +4 -2
  13. package/cpp/ggml-alloc.c +451 -282
  14. package/cpp/ggml-alloc.h +74 -8
  15. package/cpp/ggml-backend-impl.h +112 -0
  16. package/cpp/ggml-backend.c +1357 -0
  17. package/cpp/ggml-backend.h +181 -0
  18. package/cpp/ggml-impl.h +243 -0
  19. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
  20. package/cpp/ggml-metal.h +28 -1
  21. package/cpp/ggml-metal.m +1128 -308
  22. package/cpp/ggml-quants.c +7382 -0
  23. package/cpp/ggml-quants.h +224 -0
  24. package/cpp/ggml.c +3848 -5245
  25. package/cpp/ggml.h +353 -155
  26. package/cpp/rn-audioutils.cpp +68 -0
  27. package/cpp/rn-audioutils.h +14 -0
  28. package/cpp/rn-whisper-log.h +11 -0
  29. package/cpp/rn-whisper.cpp +141 -59
  30. package/cpp/rn-whisper.h +47 -15
  31. package/cpp/whisper.cpp +1750 -964
  32. package/cpp/whisper.h +97 -15
  33. package/ios/RNWhisper.mm +15 -9
  34. package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
  35. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
  36. package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
  37. package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
  38. package/ios/RNWhisperAudioUtils.h +0 -2
  39. package/ios/RNWhisperAudioUtils.m +0 -56
  40. package/ios/RNWhisperContext.h +8 -12
  41. package/ios/RNWhisperContext.mm +132 -138
  42. package/jest/mock.js +1 -1
  43. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  44. package/lib/commonjs/index.js +28 -9
  45. package/lib/commonjs/index.js.map +1 -1
  46. package/lib/commonjs/version.json +1 -1
  47. package/lib/module/NativeRNWhisper.js.map +1 -1
  48. package/lib/module/index.js +28 -9
  49. package/lib/module/index.js.map +1 -1
  50. package/lib/module/version.json +1 -1
  51. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  52. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  53. package/lib/typescript/index.d.ts +7 -2
  54. package/lib/typescript/index.d.ts.map +1 -1
  55. package/package.json +6 -5
  56. package/src/NativeRNWhisper.ts +8 -1
  57. package/src/index.ts +29 -17
  58. package/src/version.json +1 -1
  59. package/whisper-rn.podspec +1 -2
package/cpp/ggml-metal.m CHANGED
@@ -1,5 +1,6 @@
1
1
  #import "ggml-metal.h"
2
2
 
3
+ #import "ggml-backend-impl.h"
3
4
  #import "ggml.h"
4
5
 
5
6
  #import <Foundation/Foundation.h>
@@ -11,16 +12,19 @@
11
12
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
12
13
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
13
14
 
14
- // TODO: temporary - reuse llama.cpp logging
15
15
  #ifdef WSP_GGML_METAL_NDEBUG
16
- #define metal_printf(...)
16
+ #define WSP_GGML_METAL_LOG_INFO(...)
17
+ #define WSP_GGML_METAL_LOG_WARN(...)
18
+ #define WSP_GGML_METAL_LOG_ERROR(...)
17
19
  #else
18
- #define metal_printf(...) fprintf(stderr, __VA_ARGS__)
20
+ #define WSP_GGML_METAL_LOG_INFO(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_INFO, __VA_ARGS__)
21
+ #define WSP_GGML_METAL_LOG_WARN(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_WARN, __VA_ARGS__)
22
+ #define WSP_GGML_METAL_LOG_ERROR(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
19
23
  #endif
20
24
 
21
25
  #define UNUSED(x) (void)(x)
22
26
 
23
- #define WSP_GGML_MAX_CONCUR (2*WSP_GGML_MAX_NODES)
27
+ #define WSP_GGML_MAX_CONCUR (2*WSP_GGML_DEFAULT_GRAPH_SIZE)
24
28
 
25
29
  struct wsp_ggml_metal_buffer {
26
30
  const char * name;
@@ -58,7 +62,10 @@ struct wsp_ggml_metal_context {
58
62
  WSP_GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
59
63
  WSP_GGML_METAL_DECL_KERNEL(mul);
60
64
  WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
65
+ WSP_GGML_METAL_DECL_KERNEL(div);
66
+ WSP_GGML_METAL_DECL_KERNEL(div_row);
61
67
  WSP_GGML_METAL_DECL_KERNEL(scale);
68
+ WSP_GGML_METAL_DECL_KERNEL(scale_4);
62
69
  WSP_GGML_METAL_DECL_KERNEL(silu);
63
70
  WSP_GGML_METAL_DECL_KERNEL(relu);
64
71
  WSP_GGML_METAL_DECL_KERNEL(gelu);
@@ -70,6 +77,8 @@ struct wsp_ggml_metal_context {
70
77
  WSP_GGML_METAL_DECL_KERNEL(get_rows_f16);
71
78
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_0);
72
79
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_1);
80
+ WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_0);
81
+ WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_1);
73
82
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q8_0);
74
83
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q2_K);
75
84
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -78,33 +87,62 @@ struct wsp_ggml_metal_context {
78
87
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
79
88
  WSP_GGML_METAL_DECL_KERNEL(rms_norm);
80
89
  WSP_GGML_METAL_DECL_KERNEL(norm);
81
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
82
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
83
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
84
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
85
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
86
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
87
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
88
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
89
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
90
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
91
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
92
- WSP_GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
90
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
91
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
92
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
93
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
94
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
95
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
96
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
97
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
98
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
99
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
100
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
101
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
102
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
103
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
104
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
93
105
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
94
106
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
95
107
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
96
108
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
109
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
110
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
97
111
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
98
112
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
99
113
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
100
114
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
101
115
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
102
116
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
103
- WSP_GGML_METAL_DECL_KERNEL(rope);
117
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
118
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
119
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
120
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
121
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
122
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
123
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
124
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
125
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
126
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
127
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
128
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
129
+ WSP_GGML_METAL_DECL_KERNEL(rope_f32);
130
+ WSP_GGML_METAL_DECL_KERNEL(rope_f16);
104
131
  WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
132
+ WSP_GGML_METAL_DECL_KERNEL(im2col_f16);
133
+ WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
134
+ WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
105
135
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
106
136
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
137
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
138
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
139
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
140
+ //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
141
+ //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
107
142
  WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
143
+ WSP_GGML_METAL_DECL_KERNEL(concat);
144
+ WSP_GGML_METAL_DECL_KERNEL(sqr);
145
+ WSP_GGML_METAL_DECL_KERNEL(sum_rows);
108
146
 
109
147
  #undef WSP_GGML_METAL_DECL_KERNEL
110
148
  };
@@ -112,18 +150,48 @@ struct wsp_ggml_metal_context {
112
150
  // MSL code
113
151
  // TODO: move the contents here when ready
114
152
  // for now it is easier to work in a separate file
115
- static NSString * const msl_library_source = @"see metal.metal";
153
+ //static NSString * const msl_library_source = @"see metal.metal";
116
154
 
117
155
  // Here to assist with NSBundle Path Hack
118
- @interface GGMLMetalClass : NSObject
156
+ @interface WSPGGMLMetalClass : NSObject
119
157
  @end
120
- @implementation GGMLMetalClass
158
+ @implementation WSPGGMLMetalClass
121
159
  @end
122
160
 
161
+ wsp_ggml_log_callback wsp_ggml_metal_log_callback = NULL;
162
+ void * wsp_ggml_metal_log_user_data = NULL;
163
+
164
+ void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data) {
165
+ wsp_ggml_metal_log_callback = log_callback;
166
+ wsp_ggml_metal_log_user_data = user_data;
167
+ }
168
+
169
+ WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
170
+ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * format, ...){
171
+ if (wsp_ggml_metal_log_callback != NULL) {
172
+ va_list args;
173
+ va_start(args, format);
174
+ char buffer[128];
175
+ int len = vsnprintf(buffer, 128, format, args);
176
+ if (len < 128) {
177
+ wsp_ggml_metal_log_callback(level, buffer, wsp_ggml_metal_log_user_data);
178
+ } else {
179
+ char* buffer2 = malloc(len+1);
180
+ va_end(args);
181
+ va_start(args, format);
182
+ vsnprintf(buffer2, len+1, format, args);
183
+ buffer2[len] = 0;
184
+ wsp_ggml_metal_log_callback(level, buffer2, wsp_ggml_metal_log_user_data);
185
+ free(buffer2);
186
+ }
187
+ va_end(args);
188
+ }
189
+ }
190
+
123
191
  struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
124
- metal_printf("%s: allocating\n", __func__);
192
+ WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
125
193
 
126
- id <MTLDevice> device;
194
+ id<MTLDevice> device;
127
195
  NSString * s;
128
196
 
129
197
  #if TARGET_OS_OSX
@@ -131,17 +199,17 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
131
199
  NSArray * devices = MTLCopyAllDevices();
132
200
  for (device in devices) {
133
201
  s = [device name];
134
- metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
202
+ WSP_GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
135
203
  }
136
204
  #endif
137
205
 
138
206
  // Pick and show default Metal device
139
207
  device = MTLCreateSystemDefaultDevice();
140
208
  s = [device name];
141
- metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
209
+ WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
142
210
 
143
211
  // Configure context
144
- struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
212
+ struct wsp_ggml_metal_context * ctx = calloc(1, sizeof(struct wsp_ggml_metal_context));
145
213
  ctx->device = device;
146
214
  ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
147
215
  ctx->queue = [ctx->device newCommandQueue];
@@ -150,68 +218,95 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
150
218
 
151
219
  ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
152
220
 
153
- #ifdef WSP_GGML_SWIFT
154
- // load the default.metallib file
221
+ // load library
155
222
  {
223
+ NSBundle * bundle = nil;
224
+ #ifdef SWIFT_PACKAGE
225
+ bundle = SWIFTPM_MODULE_BUNDLE;
226
+ #else
227
+ bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]];
228
+ #endif
156
229
  NSError * error = nil;
230
+ NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
231
+ if (libPath != nil) {
232
+ NSURL * libURL = [NSURL fileURLWithPath:libPath];
233
+ WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
234
+ ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
235
+ } else {
236
+ WSP_GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
157
237
 
158
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
159
- NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
160
- NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
161
- NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
162
- NSURL * libURL = [NSURL fileURLWithPath:libPath];
163
-
164
- // Load the metallib file into a Metal library
165
- ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
238
+ NSString * sourcePath;
239
+ NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"WSP_GGML_METAL_PATH_RESOURCES"];
166
240
 
167
- if (error) {
168
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
169
- return NULL;
170
- }
171
- }
172
- #else
173
- UNUSED(msl_library_source);
241
+ WSP_GGML_METAL_LOG_INFO("%s: WSP_GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
174
242
 
175
- // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
176
- {
177
- NSError * error = nil;
243
+ if (ggmlMetalPathResources) {
244
+ sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
245
+ } else {
246
+ sourcePath = [bundle pathForResource:@"ggml-metal-whisper" ofType:@"metal"];
247
+ }
248
+ if (sourcePath == nil) {
249
+ WSP_GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
250
+ sourcePath = @"ggml-metal.metal";
251
+ }
252
+ WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
253
+ NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
254
+ if (error) {
255
+ WSP_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
256
+ return NULL;
257
+ }
178
258
 
179
- //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
180
- NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
181
- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
182
- metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
259
+ MTLCompileOptions* options = nil;
260
+ #ifdef WSP_GGML_QKK_64
261
+ options = [MTLCompileOptions new];
262
+ options.preprocessorMacros = @{ @"QK_K" : @(64) };
263
+ #endif
264
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
265
+ }
183
266
 
184
- NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
185
267
  if (error) {
186
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
268
+ WSP_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
187
269
  return NULL;
188
270
  }
271
+ }
189
272
 
190
- #ifdef WSP_GGML_QKK_64
191
- MTLCompileOptions* options = [MTLCompileOptions new];
192
- options.preprocessorMacros = @{ @"QK_K" : @(64) };
193
- ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
194
- #else
195
- ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
196
- #endif
197
- if (error) {
198
- metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
199
- return NULL;
273
+ #if TARGET_OS_OSX
274
+ // print MTL GPU family:
275
+ WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
276
+
277
+ // determine max supported GPU family
278
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
279
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
280
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
281
+ if ([ctx->device supportsFamily:i]) {
282
+ WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
283
+ break;
200
284
  }
201
285
  }
286
+
287
+ WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
288
+ WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
289
+ if (ctx->device.maxTransferRate != 0) {
290
+ WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
291
+ } else {
292
+ WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
293
+ }
202
294
  #endif
203
295
 
204
296
  // load kernels
205
297
  {
206
298
  NSError * error = nil;
299
+
300
+ /*
301
+ WSP_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
302
+ (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
303
+ (int) ctx->pipeline_##name.threadExecutionWidth); \
304
+ */
207
305
  #define WSP_GGML_METAL_ADD_KERNEL(name) \
208
306
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
209
307
  ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
210
- metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (__bridge void *) ctx->pipeline_##name, \
211
- (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
212
- (int) ctx->pipeline_##name.threadExecutionWidth); \
213
308
  if (error) { \
214
- metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
309
+ WSP_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
215
310
  return NULL; \
216
311
  }
217
312
 
@@ -219,7 +314,10 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
219
314
  WSP_GGML_METAL_ADD_KERNEL(add_row);
220
315
  WSP_GGML_METAL_ADD_KERNEL(mul);
221
316
  WSP_GGML_METAL_ADD_KERNEL(mul_row);
317
+ WSP_GGML_METAL_ADD_KERNEL(div);
318
+ WSP_GGML_METAL_ADD_KERNEL(div_row);
222
319
  WSP_GGML_METAL_ADD_KERNEL(scale);
320
+ WSP_GGML_METAL_ADD_KERNEL(scale_4);
223
321
  WSP_GGML_METAL_ADD_KERNEL(silu);
224
322
  WSP_GGML_METAL_ADD_KERNEL(relu);
225
323
  WSP_GGML_METAL_ADD_KERNEL(gelu);
@@ -231,6 +329,8 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
231
329
  WSP_GGML_METAL_ADD_KERNEL(get_rows_f16);
232
330
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_0);
233
331
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_1);
332
+ WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_0);
333
+ WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_1);
234
334
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q8_0);
235
335
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q2_K);
236
336
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -239,59 +339,83 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
239
339
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K);
240
340
  WSP_GGML_METAL_ADD_KERNEL(rms_norm);
241
341
  WSP_GGML_METAL_ADD_KERNEL(norm);
242
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
243
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
244
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
245
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
246
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
247
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
248
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
249
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
250
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
251
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
252
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
253
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
254
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
255
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
256
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
257
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
258
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
259
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
260
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
261
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
262
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
263
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
264
- WSP_GGML_METAL_ADD_KERNEL(rope);
342
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
343
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
344
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
345
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
346
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
347
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
348
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
349
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
350
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
351
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
352
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
353
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
354
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
355
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
356
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
357
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
358
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
359
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
360
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
361
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
362
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
363
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
364
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
365
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
366
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
367
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
368
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
369
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
370
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
371
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
372
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
373
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
374
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
375
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
376
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
377
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
378
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
379
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
380
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
381
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
382
+ }
383
+ WSP_GGML_METAL_ADD_KERNEL(rope_f32);
384
+ WSP_GGML_METAL_ADD_KERNEL(rope_f16);
265
385
  WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
386
+ WSP_GGML_METAL_ADD_KERNEL(im2col_f16);
387
+ WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
388
+ WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
266
389
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
267
390
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
391
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
392
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
393
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
394
+ //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
395
+ //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
268
396
  WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
397
+ WSP_GGML_METAL_ADD_KERNEL(concat);
398
+ WSP_GGML_METAL_ADD_KERNEL(sqr);
399
+ WSP_GGML_METAL_ADD_KERNEL(sum_rows);
269
400
 
270
401
  #undef WSP_GGML_METAL_ADD_KERNEL
271
402
  }
272
403
 
273
- metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
274
- #if TARGET_OS_OSX
275
- metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
276
- if (ctx->device.maxTransferRate != 0) {
277
- metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
278
- } else {
279
- metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
280
- }
281
- #endif
282
-
283
404
  return ctx;
284
405
  }
285
406
 
286
407
  void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
287
- metal_printf("%s: deallocating\n", __func__);
408
+ WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
288
409
  #define WSP_GGML_METAL_DEL_KERNEL(name) \
289
410
 
290
411
  WSP_GGML_METAL_DEL_KERNEL(add);
291
412
  WSP_GGML_METAL_DEL_KERNEL(add_row);
292
413
  WSP_GGML_METAL_DEL_KERNEL(mul);
293
414
  WSP_GGML_METAL_DEL_KERNEL(mul_row);
415
+ WSP_GGML_METAL_DEL_KERNEL(div);
416
+ WSP_GGML_METAL_DEL_KERNEL(div_row);
294
417
  WSP_GGML_METAL_DEL_KERNEL(scale);
418
+ WSP_GGML_METAL_DEL_KERNEL(scale_4);
295
419
  WSP_GGML_METAL_DEL_KERNEL(silu);
296
420
  WSP_GGML_METAL_DEL_KERNEL(relu);
297
421
  WSP_GGML_METAL_DEL_KERNEL(gelu);
@@ -303,6 +427,8 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
303
427
  WSP_GGML_METAL_DEL_KERNEL(get_rows_f16);
304
428
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_0);
305
429
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_1);
430
+ WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_0);
431
+ WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_1);
306
432
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q8_0);
307
433
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q2_K);
308
434
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -311,33 +437,64 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
311
437
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
312
438
  WSP_GGML_METAL_DEL_KERNEL(rms_norm);
313
439
  WSP_GGML_METAL_DEL_KERNEL(norm);
314
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
315
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
316
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
317
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
318
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
319
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
320
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
321
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
322
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
323
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
324
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
325
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
326
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
327
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
328
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
329
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
330
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
331
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
332
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
333
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
334
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
335
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
336
- WSP_GGML_METAL_DEL_KERNEL(rope);
440
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
441
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
442
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
443
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
444
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
445
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
446
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
447
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
448
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
449
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
450
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
451
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
452
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
453
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
454
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
455
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
456
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
457
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
458
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
459
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
460
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
461
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
462
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
463
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
464
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
465
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
466
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
467
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
468
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
469
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
470
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
471
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
472
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
473
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
474
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
475
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
476
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
477
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
478
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
479
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
480
+ }
481
+ WSP_GGML_METAL_DEL_KERNEL(rope_f32);
482
+ WSP_GGML_METAL_DEL_KERNEL(rope_f16);
337
483
  WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
484
+ WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
485
+ WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
486
+ WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
338
487
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
339
488
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
489
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
490
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
491
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
492
+ //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
493
+ //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
340
494
  WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
495
+ WSP_GGML_METAL_DEL_KERNEL(concat);
496
+ WSP_GGML_METAL_DEL_KERNEL(sqr);
497
+ WSP_GGML_METAL_DEL_KERNEL(sum_rows);
341
498
 
342
499
  #undef WSP_GGML_METAL_DEL_KERNEL
343
500
 
@@ -348,7 +505,7 @@ void * wsp_ggml_metal_host_malloc(size_t n) {
348
505
  void * data = NULL;
349
506
  const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
350
507
  if (result != 0) {
351
- metal_printf("%s: error: posix_memalign failed\n", __func__);
508
+ WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
352
509
  return NULL;
353
510
  }
354
511
 
@@ -371,30 +528,50 @@ int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
371
528
  return ctx->concur_list;
372
529
  }
373
530
 
531
+ // temporarily defined here for compatibility between ggml-backend and the old API
532
+ struct wsp_ggml_backend_metal_buffer_context {
533
+ void * data;
534
+
535
+ id<MTLBuffer> metal;
536
+ };
537
+
374
538
  // finds the Metal buffer that contains the tensor data on the GPU device
375
539
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
376
540
  // Metal buffer based on the host memory pointer
377
541
  //
378
542
  static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_tensor * t, size_t * offs) {
379
- //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
543
+ //WSP_GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
380
544
 
381
545
  const int64_t tsize = wsp_ggml_nbytes(t);
382
546
 
547
+ // compatibility with ggml-backend
548
+ if (t->buffer && t->buffer->buft == wsp_ggml_backend_metal_buffer_type()) {
549
+ struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) t->buffer->context;
550
+
551
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
552
+
553
+ WSP_GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
554
+
555
+ *offs = (size_t) ioffs;
556
+
557
+ return buf_ctx->metal;
558
+ }
559
+
383
560
  // find the view that contains the tensor fully
384
561
  for (int i = 0; i < ctx->n_buffers; ++i) {
385
562
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
386
563
 
387
- //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
564
+ //WSP_GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
388
565
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
389
566
  *offs = (size_t) ioffs;
390
567
 
391
- //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
568
+ //WSP_GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
392
569
 
393
570
  return ctx->buffers[i].metal;
394
571
  }
395
572
  }
396
573
 
397
- metal_printf("%s: error: buffer is nil\n", __func__);
574
+ WSP_GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
398
575
 
399
576
  return nil;
400
577
  }
@@ -406,7 +583,7 @@ bool wsp_ggml_metal_add_buffer(
406
583
  size_t size,
407
584
  size_t max_size) {
408
585
  if (ctx->n_buffers >= WSP_GGML_METAL_MAX_BUFFERS) {
409
- metal_printf("%s: too many buffers\n", __func__);
586
+ WSP_GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
410
587
  return false;
411
588
  }
412
589
 
@@ -416,7 +593,7 @@ bool wsp_ggml_metal_add_buffer(
416
593
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
417
594
 
418
595
  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
419
- metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
596
+ WSP_GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
420
597
  return false;
421
598
  }
422
599
  }
@@ -437,11 +614,11 @@ bool wsp_ggml_metal_add_buffer(
437
614
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
438
615
 
439
616
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
440
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
617
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
441
618
  return false;
442
619
  }
443
620
 
444
- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
621
+ WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
445
622
 
446
623
  ++ctx->n_buffers;
447
624
  } else {
@@ -461,13 +638,13 @@ bool wsp_ggml_metal_add_buffer(
461
638
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
462
639
 
463
640
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
464
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
641
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
465
642
  return false;
466
643
  }
467
644
 
468
- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
645
+ WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
469
646
  if (i + size_step < size) {
470
- metal_printf("\n");
647
+ WSP_GGML_METAL_LOG_INFO("\n");
471
648
  }
472
649
 
473
650
  ++ctx->n_buffers;
@@ -475,17 +652,17 @@ bool wsp_ggml_metal_add_buffer(
475
652
  }
476
653
 
477
654
  #if TARGET_OS_OSX
478
- metal_printf(", (%8.2f / %8.2f)",
655
+ WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
479
656
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
480
657
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
481
658
 
482
659
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
483
- metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
660
+ WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
484
661
  } else {
485
- metal_printf("\n");
662
+ WSP_GGML_METAL_LOG_INFO("\n");
486
663
  }
487
664
  #else
488
- metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
665
+ WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
489
666
  #endif
490
667
  }
491
668
 
@@ -598,10 +775,55 @@ void wsp_ggml_metal_graph_find_concurrency(
598
775
  }
599
776
 
600
777
  if (ctx->concur_list_len > WSP_GGML_MAX_CONCUR) {
601
- metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
778
+ WSP_GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
602
779
  }
603
780
  }
604
781
 
782
+ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
783
+ switch (op->op) {
784
+ case WSP_GGML_OP_UNARY:
785
+ switch (wsp_ggml_get_unary_op(op)) {
786
+ case WSP_GGML_UNARY_OP_SILU:
787
+ case WSP_GGML_UNARY_OP_RELU:
788
+ case WSP_GGML_UNARY_OP_GELU:
789
+ return true;
790
+ default:
791
+ return false;
792
+ }
793
+ case WSP_GGML_OP_NONE:
794
+ case WSP_GGML_OP_RESHAPE:
795
+ case WSP_GGML_OP_VIEW:
796
+ case WSP_GGML_OP_TRANSPOSE:
797
+ case WSP_GGML_OP_PERMUTE:
798
+ case WSP_GGML_OP_CONCAT:
799
+ case WSP_GGML_OP_ADD:
800
+ case WSP_GGML_OP_MUL:
801
+ case WSP_GGML_OP_DIV:
802
+ case WSP_GGML_OP_SCALE:
803
+ case WSP_GGML_OP_SQR:
804
+ case WSP_GGML_OP_SUM_ROWS:
805
+ case WSP_GGML_OP_SOFT_MAX:
806
+ case WSP_GGML_OP_RMS_NORM:
807
+ case WSP_GGML_OP_NORM:
808
+ case WSP_GGML_OP_ALIBI:
809
+ case WSP_GGML_OP_ROPE:
810
+ case WSP_GGML_OP_IM2COL:
811
+ case WSP_GGML_OP_ARGSORT:
812
+ case WSP_GGML_OP_DUP:
813
+ case WSP_GGML_OP_CPY:
814
+ case WSP_GGML_OP_CONT:
815
+ case WSP_GGML_OP_MUL_MAT:
816
+ case WSP_GGML_OP_MUL_MAT_ID:
817
+ return true;
818
+ case WSP_GGML_OP_DIAG_MASK_INF:
819
+ case WSP_GGML_OP_GET_ROWS:
820
+ {
821
+ return op->ne[0] % 4 == 0;
822
+ }
823
+ default:
824
+ return false;
825
+ }
826
+ }
605
827
  void wsp_ggml_metal_graph_compute(
606
828
  struct wsp_ggml_metal_context * ctx,
607
829
  struct wsp_ggml_cgraph * gf) {
@@ -652,12 +874,28 @@ void wsp_ggml_metal_graph_compute(
652
874
  continue;
653
875
  }
654
876
 
655
- //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
877
+ //WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
656
878
 
657
879
  struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0];
658
880
  struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1];
659
881
  struct wsp_ggml_tensor * dst = gf->nodes[i];
660
882
 
883
+ switch (dst->op) {
884
+ case WSP_GGML_OP_NONE:
885
+ case WSP_GGML_OP_RESHAPE:
886
+ case WSP_GGML_OP_VIEW:
887
+ case WSP_GGML_OP_TRANSPOSE:
888
+ case WSP_GGML_OP_PERMUTE:
889
+ {
890
+ // noop -> next node
891
+ } continue;
892
+ default:
893
+ {
894
+ } break;
895
+ }
896
+
897
+ WSP_GGML_ASSERT(wsp_ggml_metal_supports_op(dst));
898
+
661
899
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
662
900
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
663
901
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -696,78 +934,129 @@ void wsp_ggml_metal_graph_compute(
696
934
  id<MTLBuffer> id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
697
935
  id<MTLBuffer> id_dst = dst ? wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
698
936
 
699
- //metal_printf("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
937
+ //WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
700
938
  //if (src0) {
701
- // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
939
+ // WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
702
940
  // wsp_ggml_is_contiguous(src0), src0->name);
703
941
  //}
704
942
  //if (src1) {
705
- // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
943
+ // WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
706
944
  // wsp_ggml_is_contiguous(src1), src1->name);
707
945
  //}
708
946
  //if (dst) {
709
- // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
947
+ // WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
710
948
  // dst->name);
711
949
  //}
712
950
 
713
951
  switch (dst->op) {
714
- case WSP_GGML_OP_NONE:
715
- case WSP_GGML_OP_RESHAPE:
716
- case WSP_GGML_OP_VIEW:
717
- case WSP_GGML_OP_TRANSPOSE:
718
- case WSP_GGML_OP_PERMUTE:
952
+ case WSP_GGML_OP_CONCAT:
719
953
  {
720
- // noop
721
- } break;
722
- case WSP_GGML_OP_ADD:
723
- {
724
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
725
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
726
-
727
- // utilize float4
728
- WSP_GGML_ASSERT(ne00 % 4 == 0);
729
- const int64_t nb = ne00/4;
954
+ const int64_t nb = ne00;
730
955
 
731
- if (wsp_ggml_nelements(src1) == ne10) {
732
- // src1 is a row
733
- WSP_GGML_ASSERT(ne11 == 1);
734
- [encoder setComputePipelineState:ctx->pipeline_add_row];
735
- } else {
736
- [encoder setComputePipelineState:ctx->pipeline_add];
737
- }
956
+ [encoder setComputePipelineState:ctx->pipeline_concat];
738
957
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
739
958
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
740
959
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
741
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
742
-
743
- const int64_t n = wsp_ggml_nelements(dst)/4;
744
-
745
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
960
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
961
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
962
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
963
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
964
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
965
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
966
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
967
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
968
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
969
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
970
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
971
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
972
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
973
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
974
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
975
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
976
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
977
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
978
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
979
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
980
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
981
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
982
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
983
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
984
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
985
+
986
+ const int nth = MIN(1024, ne0);
987
+
988
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
746
989
  } break;
990
+ case WSP_GGML_OP_ADD:
747
991
  case WSP_GGML_OP_MUL:
992
+ case WSP_GGML_OP_DIV:
748
993
  {
749
994
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
750
995
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
751
996
 
752
- // utilize float4
753
- WSP_GGML_ASSERT(ne00 % 4 == 0);
754
- const int64_t nb = ne00/4;
997
+ bool bcast_row = false;
998
+
999
+ int64_t nb = ne00;
755
1000
 
756
- if (wsp_ggml_nelements(src1) == ne10) {
1001
+ if (wsp_ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
757
1002
  // src1 is a row
758
1003
  WSP_GGML_ASSERT(ne11 == 1);
759
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
1004
+
1005
+ nb = ne00 / 4;
1006
+ switch (dst->op) {
1007
+ case WSP_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
1008
+ case WSP_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
1009
+ case WSP_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
1010
+ default: WSP_GGML_ASSERT(false);
1011
+ }
1012
+
1013
+ bcast_row = true;
760
1014
  } else {
761
- [encoder setComputePipelineState:ctx->pipeline_mul];
1015
+ switch (dst->op) {
1016
+ case WSP_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
1017
+ case WSP_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
1018
+ case WSP_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
1019
+ default: WSP_GGML_ASSERT(false);
1020
+ }
762
1021
  }
763
1022
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
764
1023
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
765
1024
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
766
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
767
-
768
- const int64_t n = wsp_ggml_nelements(dst)/4;
1025
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1026
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1027
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1028
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
1029
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
1030
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
1031
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
1032
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
1033
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1034
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
1035
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1036
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1037
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1038
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1039
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1040
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
1041
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
1042
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
1043
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
1044
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
1045
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
1046
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1047
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1048
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1049
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1050
+
1051
+ if (bcast_row) {
1052
+ const int64_t n = wsp_ggml_nelements(dst)/4;
1053
+
1054
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1055
+ } else {
1056
+ const int nth = MIN(1024, ne0);
769
1057
 
770
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1058
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1059
+ }
771
1060
  } break;
772
1061
  case WSP_GGML_OP_SCALE:
773
1062
  {
@@ -775,13 +1064,19 @@ void wsp_ggml_metal_graph_compute(
775
1064
 
776
1065
  const float scale = *(const float *) src1->data;
777
1066
 
778
- [encoder setComputePipelineState:ctx->pipeline_scale];
1067
+ int64_t n = wsp_ggml_nelements(dst);
1068
+
1069
+ if (n % 4 == 0) {
1070
+ n /= 4;
1071
+ [encoder setComputePipelineState:ctx->pipeline_scale_4];
1072
+ } else {
1073
+ [encoder setComputePipelineState:ctx->pipeline_scale];
1074
+ }
1075
+
779
1076
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
780
1077
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
781
1078
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
782
1079
 
783
- const int64_t n = wsp_ggml_nelements(dst)/4;
784
-
785
1080
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
786
1081
  } break;
787
1082
  case WSP_GGML_OP_UNARY:
@@ -792,9 +1087,10 @@ void wsp_ggml_metal_graph_compute(
792
1087
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
793
1088
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
794
1089
 
795
- const int64_t n = wsp_ggml_nelements(dst)/4;
1090
+ const int64_t n = wsp_ggml_nelements(dst);
1091
+ WSP_GGML_ASSERT(n % 4 == 0);
796
1092
 
797
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1093
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
798
1094
  } break;
799
1095
  case WSP_GGML_UNARY_OP_RELU:
800
1096
  {
@@ -812,32 +1108,92 @@ void wsp_ggml_metal_graph_compute(
812
1108
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
813
1109
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
814
1110
 
815
- const int64_t n = wsp_ggml_nelements(dst)/4;
1111
+ const int64_t n = wsp_ggml_nelements(dst);
1112
+ WSP_GGML_ASSERT(n % 4 == 0);
816
1113
 
817
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1114
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
818
1115
  } break;
819
1116
  default:
820
1117
  {
821
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
1118
+ WSP_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
822
1119
  WSP_GGML_ASSERT(false);
823
1120
  }
824
1121
  } break;
1122
+ case WSP_GGML_OP_SQR:
1123
+ {
1124
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
1125
+
1126
+ [encoder setComputePipelineState:ctx->pipeline_sqr];
1127
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1128
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1129
+
1130
+ const int64_t n = wsp_ggml_nelements(dst);
1131
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1132
+ } break;
1133
+ case WSP_GGML_OP_SUM_ROWS:
1134
+ {
1135
+ WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type));
1136
+
1137
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
1138
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1139
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1140
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1141
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1142
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1143
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1144
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1145
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1146
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1147
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1148
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1149
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1150
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1151
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1152
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1153
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1154
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1155
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
1156
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1157
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
1158
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
1159
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
1160
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
1161
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
1162
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
1163
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
1164
+
1165
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1166
+ } break;
825
1167
  case WSP_GGML_OP_SOFT_MAX:
826
1168
  {
827
- const int nth = 32;
1169
+ int nth = 32; // SIMD width
828
1170
 
829
1171
  if (ne00%4 == 0) {
1172
+ while (nth < ne00/4 && nth < 256) {
1173
+ nth *= 2;
1174
+ }
830
1175
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
831
1176
  } else {
1177
+ while (nth < ne00 && nth < 1024) {
1178
+ nth *= 2;
1179
+ }
832
1180
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
833
1181
  }
834
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
835
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
836
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
837
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
838
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
839
1182
 
840
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1183
+ const float scale = ((float *) dst->op_params)[0];
1184
+
1185
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1186
+ if (id_src1) {
1187
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1188
+ }
1189
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1190
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1191
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1192
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1193
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1194
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1195
+
1196
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
841
1197
  } break;
842
1198
  case WSP_GGML_OP_DIAG_MASK_INF:
843
1199
  {
@@ -863,26 +1219,57 @@ void wsp_ggml_metal_graph_compute(
863
1219
  } break;
864
1220
  case WSP_GGML_OP_MUL_MAT:
865
1221
  {
866
- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
867
-
868
1222
  WSP_GGML_ASSERT(ne00 == ne10);
869
- // WSP_GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
870
- uint gqa = ne12/ne02;
871
- WSP_GGML_ASSERT(ne03 == ne13);
1223
+
1224
+ // TODO: assert that dim2 and dim3 are contiguous
1225
+ WSP_GGML_ASSERT(ne12 % ne02 == 0);
1226
+ WSP_GGML_ASSERT(ne13 % ne03 == 0);
1227
+
1228
+ const uint r2 = ne12/ne02;
1229
+ const uint r3 = ne13/ne03;
1230
+
1231
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1232
+ // to the matrix-vector kernel
1233
+ int ne11_mm_min = 1;
1234
+
1235
+ #if 0
1236
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1237
+ // these numbers do not translate to other devices or model sizes
1238
+ // TODO: need to find a better approach
1239
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1240
+ switch (src0t) {
1241
+ case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break;
1242
+ case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1243
+ case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1244
+ case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1245
+ case WSP_GGML_TYPE_Q4_0:
1246
+ case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1247
+ case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1248
+ case WSP_GGML_TYPE_Q5_0: // not tested yet
1249
+ case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1250
+ case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1251
+ case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1252
+ default: ne11_mm_min = 1; break;
1253
+ }
1254
+ }
1255
+ #endif
872
1256
 
873
1257
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
874
1258
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
875
- if (!wsp_ggml_is_transposed(src0) &&
1259
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1260
+ !wsp_ggml_is_transposed(src0) &&
876
1261
  !wsp_ggml_is_transposed(src1) &&
877
1262
  src1t == WSP_GGML_TYPE_F32 &&
878
- [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
879
- ne00%32 == 0 &&
880
- ne11 > 1) {
1263
+ ne00 % 32 == 0 && ne00 >= 64 &&
1264
+ (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
1265
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
881
1266
  switch (src0->type) {
882
1267
  case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
883
1268
  case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
884
1269
  case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
885
1270
  case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
1271
+ case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
1272
+ case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
886
1273
  case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
887
1274
  case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
888
1275
  case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -904,110 +1291,106 @@ void wsp_ggml_metal_graph_compute(
904
1291
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
905
1292
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
906
1293
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
907
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
1294
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1295
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
908
1296
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
909
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1297
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
910
1298
  } else {
911
1299
  int nth0 = 32;
912
1300
  int nth1 = 1;
913
1301
  int nrows = 1;
1302
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
914
1303
 
915
1304
  // use custom matrix x vector kernel
916
1305
  switch (src0t) {
917
1306
  case WSP_GGML_TYPE_F32:
918
1307
  {
919
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
1308
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1309
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
920
1310
  nrows = 4;
921
1311
  } break;
922
1312
  case WSP_GGML_TYPE_F16:
923
1313
  {
924
1314
  nth0 = 32;
925
1315
  nth1 = 1;
926
- if (ne11 * ne12 < 4) {
927
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
928
- } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
929
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
930
- nrows = ne11;
1316
+ if (src1t == WSP_GGML_TYPE_F32) {
1317
+ if (ne11 * ne12 < 4) {
1318
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
1319
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1320
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
1321
+ nrows = ne11;
1322
+ } else {
1323
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
1324
+ nrows = 4;
1325
+ }
931
1326
  } else {
932
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
1327
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
933
1328
  nrows = 4;
934
1329
  }
935
1330
  } break;
936
1331
  case WSP_GGML_TYPE_Q4_0:
937
1332
  {
938
- WSP_GGML_ASSERT(ne02 == 1);
939
- WSP_GGML_ASSERT(ne12 == 1);
940
-
941
1333
  nth0 = 8;
942
1334
  nth1 = 8;
943
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
1335
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
944
1336
  } break;
945
1337
  case WSP_GGML_TYPE_Q4_1:
946
1338
  {
947
- WSP_GGML_ASSERT(ne02 == 1);
948
- WSP_GGML_ASSERT(ne12 == 1);
949
-
950
1339
  nth0 = 8;
951
1340
  nth1 = 8;
952
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
1341
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1342
+ } break;
1343
+ case WSP_GGML_TYPE_Q5_0:
1344
+ {
1345
+ nth0 = 8;
1346
+ nth1 = 8;
1347
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1348
+ } break;
1349
+ case WSP_GGML_TYPE_Q5_1:
1350
+ {
1351
+ nth0 = 8;
1352
+ nth1 = 8;
1353
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
953
1354
  } break;
954
1355
  case WSP_GGML_TYPE_Q8_0:
955
1356
  {
956
- WSP_GGML_ASSERT(ne02 == 1);
957
- WSP_GGML_ASSERT(ne12 == 1);
958
-
959
1357
  nth0 = 8;
960
1358
  nth1 = 8;
961
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
1359
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
962
1360
  } break;
963
1361
  case WSP_GGML_TYPE_Q2_K:
964
1362
  {
965
- WSP_GGML_ASSERT(ne02 == 1);
966
- WSP_GGML_ASSERT(ne12 == 1);
967
-
968
1363
  nth0 = 2;
969
1364
  nth1 = 32;
970
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
1365
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
971
1366
  } break;
972
1367
  case WSP_GGML_TYPE_Q3_K:
973
1368
  {
974
- WSP_GGML_ASSERT(ne02 == 1);
975
- WSP_GGML_ASSERT(ne12 == 1);
976
-
977
1369
  nth0 = 2;
978
1370
  nth1 = 32;
979
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
1371
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
980
1372
  } break;
981
1373
  case WSP_GGML_TYPE_Q4_K:
982
1374
  {
983
- WSP_GGML_ASSERT(ne02 == 1);
984
- WSP_GGML_ASSERT(ne12 == 1);
985
-
986
1375
  nth0 = 4; //1;
987
1376
  nth1 = 8; //32;
988
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
1377
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
989
1378
  } break;
990
1379
  case WSP_GGML_TYPE_Q5_K:
991
1380
  {
992
- WSP_GGML_ASSERT(ne02 == 1);
993
- WSP_GGML_ASSERT(ne12 == 1);
994
-
995
1381
  nth0 = 2;
996
1382
  nth1 = 32;
997
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
1383
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
998
1384
  } break;
999
1385
  case WSP_GGML_TYPE_Q6_K:
1000
1386
  {
1001
- WSP_GGML_ASSERT(ne02 == 1);
1002
- WSP_GGML_ASSERT(ne12 == 1);
1003
-
1004
1387
  nth0 = 2;
1005
1388
  nth1 = 32;
1006
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
1389
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
1007
1390
  } break;
1008
1391
  default:
1009
1392
  {
1010
- metal_printf("Asserting on type %d\n",(int)src0t);
1393
+ WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1011
1394
  WSP_GGML_ASSERT(false && "not implemented");
1012
1395
  }
1013
1396
  };
@@ -1029,31 +1412,125 @@ void wsp_ggml_metal_graph_compute(
1029
1412
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1030
1413
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1031
1414
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1032
- [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
1415
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1416
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1033
1417
 
1034
- if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
1035
- src0t == WSP_GGML_TYPE_Q2_K) {// || src0t == WSP_GGML_TYPE_Q4_K) {
1036
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1418
+ if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
1419
+ src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
1420
+ src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
1421
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1037
1422
  }
1038
1423
  else if (src0t == WSP_GGML_TYPE_Q4_K) {
1039
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1424
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1040
1425
  }
1041
1426
  else if (src0t == WSP_GGML_TYPE_Q3_K) {
1042
1427
  #ifdef WSP_GGML_QKK_64
1043
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1428
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1044
1429
  #else
1045
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1430
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1046
1431
  #endif
1047
1432
  }
1048
1433
  else if (src0t == WSP_GGML_TYPE_Q5_K) {
1049
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1434
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1050
1435
  }
1051
1436
  else if (src0t == WSP_GGML_TYPE_Q6_K) {
1052
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1437
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1053
1438
  } else {
1054
1439
  int64_t ny = (ne11 + nrows - 1)/nrows;
1055
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1440
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1441
+ }
1442
+ }
1443
+ } break;
1444
+ case WSP_GGML_OP_MUL_MAT_ID:
1445
+ {
1446
+ //WSP_GGML_ASSERT(ne00 == ne10);
1447
+ //WSP_GGML_ASSERT(ne03 == ne13);
1448
+
1449
+ WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
1450
+
1451
+ const int n_as = ne00;
1452
+
1453
+ // TODO: make this more general
1454
+ WSP_GGML_ASSERT(n_as <= 8);
1455
+
1456
+ struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2];
1457
+
1458
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
1459
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
1460
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
1461
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23);
1462
+
1463
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20);
1464
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1465
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1466
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23);
1467
+
1468
+ const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
1469
+
1470
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2));
1471
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1));
1472
+
1473
+ WSP_GGML_ASSERT(ne20 % 32 == 0);
1474
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
1475
+ //WSP_GGML_ASSERT(ne20 >= 64);
1476
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
1477
+
1478
+ const uint r2 = ne12/ne22;
1479
+ const uint r3 = ne13/ne23;
1480
+
1481
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1482
+ // to the matrix-vector kernel
1483
+ int ne11_mm_min = 0;
1484
+
1485
+ const int idx = ((int32_t *) dst->op_params)[0];
1486
+
1487
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1488
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1489
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1490
+ ne11 > ne11_mm_min) {
1491
+ switch (src2->type) {
1492
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
1493
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
1494
+ case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
1495
+ case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
1496
+ case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
1497
+ case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
1498
+ case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
1499
+ case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
1500
+ case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
1501
+ case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
1502
+ case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
1503
+ case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
1504
+ default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
1505
+ }
1506
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1507
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1508
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1509
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
1510
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
1511
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
1512
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
1513
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1514
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1515
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1516
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1517
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1518
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1519
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1520
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1521
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
1522
+ // TODO: how to make this an array? read Metal docs
1523
+ for (int j = 0; j < n_as; ++j) {
1524
+ struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
1525
+
1526
+ size_t offs_src_cur = 0;
1527
+ id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
1528
+
1529
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
1056
1530
  }
1531
+
1532
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
1533
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1057
1534
  }
1058
1535
  } break;
1059
1536
  case WSP_GGML_OP_GET_ROWS:
@@ -1063,6 +1540,8 @@ void wsp_ggml_metal_graph_compute(
1063
1540
  case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
1064
1541
  case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
1065
1542
  case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
1543
+ case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
1544
+ case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
1066
1545
  case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
1067
1546
  case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
1068
1547
  case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -1085,18 +1564,24 @@ void wsp_ggml_metal_graph_compute(
1085
1564
  } break;
1086
1565
  case WSP_GGML_OP_RMS_NORM:
1087
1566
  {
1567
+ WSP_GGML_ASSERT(ne00 % 4 == 0);
1568
+
1088
1569
  float eps;
1089
1570
  memcpy(&eps, dst->op_params, sizeof(float));
1090
1571
 
1091
- const int nth = 512;
1572
+ int nth = 32; // SIMD width
1573
+
1574
+ while (nth < ne00/4 && nth < 1024) {
1575
+ nth *= 2;
1576
+ }
1092
1577
 
1093
1578
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1094
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1095
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1096
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1097
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1098
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1099
- [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
1579
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1580
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1581
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1582
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1583
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1584
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1100
1585
 
1101
1586
  const int64_t nrows = wsp_ggml_nrows(src0);
1102
1587
 
@@ -1107,7 +1592,7 @@ void wsp_ggml_metal_graph_compute(
1107
1592
  float eps;
1108
1593
  memcpy(&eps, dst->op_params, sizeof(float));
1109
1594
 
1110
- const int nth = 256;
1595
+ const int nth = MIN(256, ne00);
1111
1596
 
1112
1597
  [encoder setComputePipelineState:ctx->pipeline_norm];
1113
1598
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1115,7 +1600,7 @@ void wsp_ggml_metal_graph_compute(
1115
1600
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1116
1601
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1117
1602
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1118
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
1603
+ [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0];
1119
1604
 
1120
1605
  const int64_t nrows = wsp_ggml_nrows(src0);
1121
1606
 
@@ -1125,17 +1610,16 @@ void wsp_ggml_metal_graph_compute(
1125
1610
  {
1126
1611
  WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32));
1127
1612
 
1128
- const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
1613
+ const int nth = MIN(1024, ne00);
1614
+
1615
+ //const int n_past = ((int32_t *) dst->op_params)[0];
1129
1616
  const int n_head = ((int32_t *) dst->op_params)[1];
1130
1617
  float max_bias;
1131
1618
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1132
1619
 
1133
- if (__builtin_popcount(n_head) != 1) {
1134
- WSP_GGML_ASSERT(false && "only power-of-two n_head implemented");
1135
- }
1136
-
1137
1620
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
1138
1621
  const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
1622
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1139
1623
 
1140
1624
  [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
1141
1625
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1156,62 +1640,164 @@ void wsp_ggml_metal_graph_compute(
1156
1640
  [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1157
1641
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1158
1642
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1159
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1160
-
1161
- const int nth = 32;
1643
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1644
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
1645
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
1162
1646
 
1163
1647
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1164
1648
  } break;
1165
1649
  case WSP_GGML_OP_ROPE:
1166
1650
  {
1167
- const int n_past = ((int32_t *) dst->op_params)[0];
1168
- const int n_dims = ((int32_t *) dst->op_params)[1];
1169
- const int mode = ((int32_t *) dst->op_params)[2];
1651
+ WSP_GGML_ASSERT(ne10 == ne02);
1652
+
1653
+ const int nth = MIN(1024, ne00);
1654
+
1655
+ const int n_past = ((int32_t *) dst->op_params)[0];
1656
+ const int n_dims = ((int32_t *) dst->op_params)[1];
1657
+ const int mode = ((int32_t *) dst->op_params)[2];
1658
+ // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
1659
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
1660
+
1661
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
1662
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
1663
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
1664
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
1665
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
1666
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
1667
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
1668
+
1669
+ switch (src0->type) {
1670
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
1671
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
1672
+ default: WSP_GGML_ASSERT(false);
1673
+ };
1674
+
1675
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1676
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1677
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1678
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
1679
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
1680
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
1681
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
1682
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
1683
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
1684
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
1685
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
1686
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
1687
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
1688
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
1689
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
1690
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
1691
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
1692
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
1693
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
1694
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1695
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
1696
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
1697
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
1698
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
1699
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
1700
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
1701
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
1702
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
1703
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
1704
+
1705
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1706
+ } break;
1707
+ case WSP_GGML_OP_IM2COL:
1708
+ {
1709
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
1710
+ WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
1711
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
1712
+
1713
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
1714
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
1715
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
1716
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
1717
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
1718
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
1719
+ const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
1720
+
1721
+ const int32_t N = src1->ne[is_2D ? 3 : 2];
1722
+ const int32_t IC = src1->ne[is_2D ? 2 : 1];
1723
+ const int32_t IH = is_2D ? src1->ne[1] : 1;
1724
+ const int32_t IW = src1->ne[0];
1725
+
1726
+ const int32_t KH = is_2D ? src0->ne[1] : 1;
1727
+ const int32_t KW = src0->ne[0];
1728
+
1729
+ const int32_t OH = is_2D ? dst->ne[2] : 1;
1730
+ const int32_t OW = dst->ne[1];
1731
+
1732
+ const int32_t CHW = IC * KH * KW;
1733
+
1734
+ const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
1735
+ const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
1736
+
1737
+ switch (src0->type) {
1738
+ case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break;
1739
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
1740
+ default: WSP_GGML_ASSERT(false);
1741
+ };
1742
+
1743
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
1744
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1745
+ [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
1746
+ [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
1747
+ [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
1748
+ [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
1749
+ [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
1750
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
1751
+ [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
1752
+ [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
1753
+ [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
1754
+ [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
1755
+ [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
1756
+
1757
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
1758
+ } break;
1759
+ case WSP_GGML_OP_ARGSORT:
1760
+ {
1761
+ WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
1762
+ WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);
1170
1763
 
1171
- float freq_base;
1172
- float freq_scale;
1173
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
1174
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
1764
+ const int nrows = wsp_ggml_nrows(src0);
1765
+
1766
+ enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
1767
+
1768
+ switch (order) {
1769
+ case WSP_GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
1770
+ case WSP_GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
1771
+ default: WSP_GGML_ASSERT(false);
1772
+ };
1175
1773
 
1176
- [encoder setComputePipelineState:ctx->pipeline_rope];
1177
1774
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1178
1775
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1179
1776
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1180
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1181
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1182
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1183
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1184
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1185
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1186
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1187
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1188
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1189
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1190
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1191
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1192
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1193
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1194
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1195
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
1196
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
1197
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
1198
- [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
1199
- [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
1200
1777
 
1201
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
1778
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
1202
1779
  } break;
1203
1780
  case WSP_GGML_OP_DUP:
1204
1781
  case WSP_GGML_OP_CPY:
1205
1782
  case WSP_GGML_OP_CONT:
1206
1783
  {
1207
- const int nth = 32;
1784
+ WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
1785
+
1786
+ int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));
1208
1787
 
1209
1788
  switch (src0t) {
1210
1789
  case WSP_GGML_TYPE_F32:
1211
1790
  {
1791
+ WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0);
1792
+
1212
1793
  switch (dstt) {
1213
- case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
1214
- case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
1794
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
1795
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
1796
+ case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
1797
+ case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
1798
+ case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
1799
+ //case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
1800
+ //case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
1215
1801
  default: WSP_GGML_ASSERT(false && "not implemented");
1216
1802
  };
1217
1803
  } break;
@@ -1249,7 +1835,7 @@ void wsp_ggml_metal_graph_compute(
1249
1835
  } break;
1250
1836
  default:
1251
1837
  {
1252
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
1838
+ WSP_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
1253
1839
  WSP_GGML_ASSERT(false);
1254
1840
  }
1255
1841
  }
@@ -1274,10 +1860,244 @@ void wsp_ggml_metal_graph_compute(
1274
1860
 
1275
1861
  MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
1276
1862
  if (status != MTLCommandBufferStatusCompleted) {
1277
- metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
1863
+ WSP_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
1278
1864
  WSP_GGML_ASSERT(false);
1279
1865
  }
1280
1866
  }
1281
1867
 
1282
1868
  }
1283
1869
  }
1870
+
1871
+ ////////////////////////////////////////////////////////////////////////////////
1872
+
1873
+ // backend interface
1874
+
1875
// Number of outstanding references to g_backend_device; incremented by
// wsp_ggml_backend_metal_get_device() and decremented by wsp_ggml_backend_metal_free_device().
static int g_backend_device_ref_count = 0;

// Metal device shared by the backend buffers; nil while no device is held.
static id<MTLDevice> g_backend_device = nil;
1877
+
1878
// Returns the shared Metal device, creating it with MTLCreateSystemDefaultDevice()
// when none is currently held. Each call adds one reference to the device; pair it
// with a call to wsp_ggml_backend_metal_free_device().
static id<MTLDevice> wsp_ggml_backend_metal_get_device(void) {
    const bool need_create = (g_backend_device == nil);

    if (need_create) {
        g_backend_device = MTLCreateSystemDefaultDevice();
    }

    g_backend_device_ref_count += 1;

    return g_backend_device;
}
1887
+
1888
// Drops one reference to the shared Metal device. When the last reference goes
// away the global device handle is reset to nil.
static void wsp_ggml_backend_metal_free_device(void) {
    assert(g_backend_device_ref_count > 0);

    if (--g_backend_device_ref_count != 0) {
        // other users still hold the device
        return;
    }

    g_backend_device = nil;
}
1897
+
1898
// Returns the host pointer stored in the buffer's Metal context (its `data` member).
static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
    return ((struct wsp_ggml_backend_metal_buffer_context *) buffer->context)->data;
}
1903
+
1904
// Releases everything owned by a Metal backend buffer: one reference on the shared
// device, the host memory in `data`, and the context struct itself.
// NOTE(review): the `metal` member of the context is not explicitly released here —
// confirm how its lifetime is managed for this translation unit.
static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
    struct wsp_ggml_backend_metal_buffer_context * buf_ctx =
        (struct wsp_ggml_backend_metal_buffer_context *) buffer->context;

    wsp_ggml_backend_metal_free_device();

    free(buf_ctx->data);
    free(buf_ctx);

    UNUSED(buffer);
}
1914
+
1915
// Copies `size` bytes from `data` into the tensor's storage, starting `offset` bytes in.
// Asserts that the write stays inside the tensor and that the tensor has storage.
static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    UNUSED(buffer);

    WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
    WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

    char * dst_bytes = (char *) tensor->data + offset;
    memcpy(dst_bytes, data, size);
}
1923
+
1924
// Copies `size` bytes out of the tensor's storage, starting `offset` bytes in, into `data`.
// Asserts that the read stays inside the tensor and that the tensor has storage.
static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    UNUSED(buffer);

    WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
    WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

    const char * src_bytes = (const char *) tensor->data + offset;
    memcpy(data, src_bytes, size);
}
1932
+
1933
// Reads the full contents of `src` (wsp_ggml_nbytes(src) bytes) through
// wsp_ggml_backend_tensor_get() into the storage of `dst`.
static void wsp_ggml_backend_metal_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
    UNUSED(buffer);

    const size_t n_bytes = wsp_ggml_nbytes(src);

    wsp_ggml_backend_tensor_get(src, dst->data, 0, n_bytes);
}
1938
+
1939
// Writes the full contents of `src` (wsp_ggml_nbytes(src) bytes, taken from src->data)
// into `dst` through wsp_ggml_backend_tensor_set().
static void wsp_ggml_backend_metal_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
    UNUSED(buffer);

    const size_t n_bytes = wsp_ggml_nbytes(src);

    wsp_ggml_backend_tensor_set(dst, src->data, 0, n_bytes);
}
1944
+
1945
// Buffer operations table for Metal backend buffers.
// The initializer is positional, so the order of the entries below must not change;
// the trailing comments name the slot each entry is meant to fill.
static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = {
    wsp_ggml_backend_metal_buffer_free_buffer,     // free_buffer
    wsp_ggml_backend_metal_buffer_get_base,        // get_base
    NULL,                                          // init_tensor (not provided)
    wsp_ggml_backend_metal_buffer_set_tensor,      // set_tensor
    wsp_ggml_backend_metal_buffer_get_tensor,      // get_tensor
    wsp_ggml_backend_metal_buffer_cpy_tensor_from, // cpy_tensor_from
    wsp_ggml_backend_metal_buffer_cpy_tensor_to,   // cpy_tensor_to
};
1954
+
1955
// Allocates a host-backed Metal buffer able to hold `size` bytes.
//
// The length handed to newBufferWithBytesNoCopy is `size` rounded up to a whole number
// of pages, so the host allocation must cover that rounded length as well; otherwise the
// Metal buffer would map memory past the end of the allocation.
//
// Returns the new backend buffer, or NULL if the context, the host memory or the
// MTLBuffer could not be created (nothing is leaked in that case).
static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
    struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(*ctx));
    if (ctx == NULL) {
        WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer context\n", __func__);
        return NULL;
    }

    const size_t size_page = sysconf(_SC_PAGESIZE);

    // round the requested size up to the next multiple of the page size
    size_t size_aligned = size;
    if ((size_aligned % size_page) != 0) {
        size_aligned += (size_page - (size_aligned % size_page));
    }

    // allocate the full page-rounded length, since that is what Metal will map
    ctx->data = wsp_ggml_metal_host_malloc(size_aligned);
    if (ctx->data == NULL) {
        WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate %zu bytes of host memory\n", __func__, size_aligned);
        free(ctx);
        return NULL;
    }

    // note: wsp_ggml_backend_metal_get_device() takes a device reference that is
    //       dropped again in wsp_ggml_backend_metal_buffer_free_buffer()
    ctx->metal = [wsp_ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
                                                                        length:size_aligned
                                                                       options:MTLResourceStorageModeShared
                                                                   deallocator:nil];
    if (ctx->metal == nil) {
        WSP_GGML_METAL_LOG_ERROR("%s: error: failed to create Metal buffer of %zu bytes\n", __func__, size_aligned);
        wsp_ggml_backend_metal_free_device();
        free(ctx->data);
        free(ctx);
        return NULL;
    }

    return wsp_ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
}
1973
+
1974
// Reports the alignment used for this buffer type: a fixed 32.
static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
    UNUSED(buft);

    const size_t alignment = 32;

    return alignment;
}
1978
+
1979
// A Metal buffer type can be used by the Metal backend and by the CPU backend.
static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
    WSP_GGML_UNUSED(buft);

    if (wsp_ggml_backend_is_metal(backend)) {
        return true;
    }

    return wsp_ggml_backend_is_cpu(backend);
}
1984
+
1985
// Returns a pointer to the single, statically allocated Metal buffer type descriptor.
// The initializer is positional; the comments name the slot each entry is meant to fill.
wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
    static struct wsp_ggml_backend_buffer_type buft_metal = {
        // interface
        {
            wsp_ggml_backend_metal_buffer_type_alloc_buffer,     // alloc_buffer
            wsp_ggml_backend_metal_buffer_type_get_alignment,    // get_alignment
            NULL,                                                // get_alloc_size: defaults to wsp_ggml_nbytes
            wsp_ggml_backend_metal_buffer_type_supports_backend, // supports_backend
        },
        // context
        NULL,
    };

    return &buft_metal;
}
1998
+
1999
// Name reported for this backend. The address of this function is also what
// wsp_ggml_backend_is_metal() compares against to recognise a Metal backend.
static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
    UNUSED(backend);

    return "Metal";
}
2004
+
2005
// Destroys a Metal backend: frees its Metal context first, then the backend object.
static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
    wsp_ggml_metal_free((struct wsp_ggml_metal_context *) backend->context);

    free(backend);
}
2010
+
2011
// Synchronization hook for the Metal backend; intentionally performs no work.
static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
    // nothing to do
    UNUSED(backend);
}
2014
+
2015
// The default buffer type of a Metal backend is the shared Metal buffer type.
static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) {
    UNUSED(backend);

    wsp_ggml_backend_buffer_type_t buft = wsp_ggml_backend_metal_buffer_type();

    return buft;
}
2020
+
2021
// Backend entry point for graph evaluation: forwards the graph to
// wsp_ggml_metal_graph_compute() using the backend's Metal context.
static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
    wsp_ggml_metal_graph_compute((struct wsp_ggml_metal_context *) backend->context, cgraph);
}
2026
+
2027
// Backend entry point for the op-support query: defers to wsp_ggml_metal_supports_op().
static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
    UNUSED(backend);

    const bool supported = wsp_ggml_metal_supports_op(op);

    return supported;
}
2032
+
2033
// Backend operations table for the Metal backend.
// The initializer is positional, so the order of the entries below must not change;
// the trailing comments name the slot each entry is meant to fill.
static struct wsp_ggml_backend_i metal_backend_i = {
    wsp_ggml_backend_metal_name,                    // get_name
    wsp_ggml_backend_metal_free,                    // free
    wsp_ggml_backend_metal_get_default_buffer_type, // get_default_buffer_type
    NULL,                                           // set_tensor_async
    NULL,                                           // get_tensor_async
    NULL,                                           // cpy_tensor_from_async
    NULL,                                           // cpy_tensor_to_async
    wsp_ggml_backend_metal_synchronize,             // synchronize
    NULL,                                           // graph_plan_create: the metal implementation does not require creating graph plans atm
    NULL,                                           // graph_plan_free
    NULL,                                           // graph_plan_compute
    wsp_ggml_backend_metal_graph_compute,           // graph_compute
    wsp_ggml_backend_metal_supports_op,             // supports_op
};
2048
+
2049
+ // TODO: make a common log callback for all backends in ggml-backend
2050
// Log sink installed by wsp_ggml_backend_metal_init(): writes every message to stderr
// unchanged, regardless of its level.
static void wsp_ggml_backend_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) {
    UNUSED(level);
    UNUSED(user_data);

    fprintf(stderr, "%s", msg);
}
2056
+
2057
// Creates a Metal backend.
//
// Installs the stderr log callback, initializes a Metal context with
// WSP_GGML_DEFAULT_N_THREADS, and wraps it in a heap-allocated backend object.
//
// Returns the new backend, or NULL if the Metal context could not be initialized or the
// backend object could not be allocated (the context is freed again in the latter case).
wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
    wsp_ggml_metal_log_set_callback(wsp_ggml_backend_log_callback, NULL);

    struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);

    if (ctx == NULL) {
        return NULL;
    }

    wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend));

    if (metal_backend == NULL) {
        WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate backend\n", __func__);
        // do not leak the context that was just created
        wsp_ggml_metal_free(ctx);
        return NULL;
    }

    *metal_backend = (struct wsp_ggml_backend) {
        /* .iface   = */ metal_backend_i,
        /* .context = */ ctx,
    };

    return metal_backend;
}
2075
+
2076
// Tells whether `backend` is a Metal backend, by checking that its get_name entry is
// wsp_ggml_backend_metal_name. A NULL backend is reported as "not Metal" rather than
// dereferenced, so callers that assert on this result fail cleanly.
bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) {
    return backend != NULL && backend->iface.get_name == wsp_ggml_backend_metal_name;
}
2079
+
2080
// Forwards `n_cb` to wsp_ggml_metal_set_n_cb() for the backend's Metal context.
// Asserts that `backend` is a Metal backend.
void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
    WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));

    wsp_ggml_metal_set_n_cb((struct wsp_ggml_metal_context *) backend->context, n_cb);
}
2087
+
2088
// Asks the backend's Metal device whether it supports the Apple GPU family numbered
// `family`, where 1 corresponds to MTLGPUFamilyApple1.
// Asserts that `backend` is a Metal backend.
bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int family) {
    WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));

    struct wsp_ggml_metal_context * metal_ctx = (struct wsp_ggml_metal_context *) backend->context;

    const MTLGPUFamily gpu_family = (MTLGPUFamily) (MTLGPUFamilyApple1 + family - 1);

    return [metal_ctx->device supportsFamily:gpu_family];
}
2095
+
2096
// Prototype declared right before the definition to silence the missing-prototype warning.
wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data);

// Registry entry point: creates a Metal backend with wsp_ggml_backend_metal_init().
// Both arguments are ignored.
wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) {
    WSP_GGML_UNUSED(params);
    WSP_GGML_UNUSED(user_data);

    return wsp_ggml_backend_metal_init();
}