whisper.rn 0.4.0-rc.2 → 0.4.0-rc.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/android/src/main/CMakeLists.txt +2 -0
  2. package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
  3. package/android/src/main/java/com/rnwhisper/WhisperContext.java +29 -15
  4. package/android/src/main/jni.cpp +6 -2
  5. package/cpp/ggml-alloc.c +413 -280
  6. package/cpp/ggml-alloc.h +67 -8
  7. package/cpp/ggml-backend-impl.h +87 -0
  8. package/cpp/ggml-backend.c +950 -0
  9. package/cpp/ggml-backend.h +136 -0
  10. package/cpp/ggml-impl.h +243 -0
  11. package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +591 -121
  12. package/cpp/ggml-metal.h +21 -0
  13. package/cpp/ggml-metal.m +623 -234
  14. package/cpp/ggml-quants.c +7377 -0
  15. package/cpp/ggml-quants.h +224 -0
  16. package/cpp/ggml.c +3773 -4455
  17. package/cpp/ggml.h +279 -146
  18. package/cpp/whisper.cpp +182 -103
  19. package/cpp/whisper.h +48 -11
  20. package/ios/RNWhisper.mm +8 -2
  21. package/ios/RNWhisperContext.h +6 -2
  22. package/ios/RNWhisperContext.mm +97 -26
  23. package/jest/mock.js +1 -1
  24. package/lib/commonjs/NativeRNWhisper.js.map +1 -1
  25. package/lib/commonjs/index.js +28 -9
  26. package/lib/commonjs/index.js.map +1 -1
  27. package/lib/commonjs/version.json +1 -1
  28. package/lib/module/NativeRNWhisper.js.map +1 -1
  29. package/lib/module/index.js +28 -9
  30. package/lib/module/index.js.map +1 -1
  31. package/lib/module/version.json +1 -1
  32. package/lib/typescript/NativeRNWhisper.d.ts +7 -1
  33. package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
  34. package/lib/typescript/index.d.ts +8 -3
  35. package/lib/typescript/index.d.ts.map +1 -1
  36. package/package.json +1 -1
  37. package/src/NativeRNWhisper.ts +8 -1
  38. package/src/index.ts +30 -18
  39. package/src/version.json +1 -1
  40. package/whisper-rn.podspec +1 -2
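One of the more visible changes in the ggml-metal.m diff below is the replacement of the old metal_printf macros with a log-callback mechanism (wsp_ggml_metal_log_set_callback plus the WSP_GGML_METAL_LOG_INFO/WARN/ERROR macros). As a minimal, hedged sketch of how a host application might hook into it, assuming the wsp_ggml_log_callback typedef in ggml.h matches the three-argument call sites shown in the diff (level, text, user_data) and that ggml-metal.h declares the setter:

// Sketch only, not part of the package; the assumed signatures are noted above.
#include <stdio.h>
#include "ggml-metal.h"

static void metal_log_to_stderr(enum wsp_ggml_log_level level, const char * text, void * user_data) {
    (void) user_data;
    // "text" arrives already formatted by wsp_ggml_metal_log(), so forward it as-is
    fprintf(stderr, "[metal:%d] %s", (int) level, text);
}

// Register before creating the Metal context:
//   wsp_ggml_metal_log_set_callback(metal_log_to_stderr, NULL);
//   struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(1);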
package/cpp/ggml-metal.m CHANGED
@@ -1,5 +1,6 @@
 #import "ggml-metal.h"
 
+#import "ggml-backend-impl.h"
 #import "ggml.h"
 
 #import <Foundation/Foundation.h>
@@ -11,16 +12,19 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
-// TODO: temporary - reuse llama.cpp logging
 #ifdef WSP_GGML_METAL_NDEBUG
-#define metal_printf(...)
+#define WSP_GGML_METAL_LOG_INFO(...)
+#define WSP_GGML_METAL_LOG_WARN(...)
+#define WSP_GGML_METAL_LOG_ERROR(...)
 #else
-#define metal_printf(...) fprintf(stderr, __VA_ARGS__)
+#define WSP_GGML_METAL_LOG_INFO(...)  wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define WSP_GGML_METAL_LOG_WARN(...)  wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define WSP_GGML_METAL_LOG_ERROR(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
 #endif
 
 #define UNUSED(x) (void)(x)
 
-#define WSP_GGML_MAX_CONCUR (2*WSP_GGML_MAX_NODES)
+#define WSP_GGML_MAX_CONCUR (2*WSP_GGML_DEFAULT_GRAPH_SIZE)
 
 struct wsp_ggml_metal_buffer {
     const char * name;
@@ -59,6 +63,7 @@ struct wsp_ggml_metal_context {
     WSP_GGML_METAL_DECL_KERNEL(mul);
     WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     WSP_GGML_METAL_DECL_KERNEL(scale);
+    WSP_GGML_METAL_DECL_KERNEL(scale_4);
     WSP_GGML_METAL_DECL_KERNEL(silu);
     WSP_GGML_METAL_DECL_KERNEL(relu);
     WSP_GGML_METAL_DECL_KERNEL(gelu);
@@ -70,6 +75,8 @@ struct wsp_ggml_metal_context {
     WSP_GGML_METAL_DECL_KERNEL(get_rows_f16);
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+    WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_1);
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -78,33 +85,40 @@ struct wsp_ggml_metal_context {
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     WSP_GGML_METAL_DECL_KERNEL(rms_norm);
     WSP_GGML_METAL_DECL_KERNEL(norm);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
-    WSP_GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
-    WSP_GGML_METAL_DECL_KERNEL(rope);
+    WSP_GGML_METAL_DECL_KERNEL(rope_f32);
+    WSP_GGML_METAL_DECL_KERNEL(rope_f16);
     WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
     WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    WSP_GGML_METAL_DECL_KERNEL(concat);
+    WSP_GGML_METAL_DECL_KERNEL(sqr);
 
 #undef WSP_GGML_METAL_DECL_KERNEL
 };
@@ -115,13 +129,42 @@ struct wsp_ggml_metal_context {
 static NSString * const msl_library_source = @"see metal.metal";
 
 // Here to assist with NSBundle Path Hack
-@interface GGMLMetalClass : NSObject
+@interface WSPGGMLMetalClass : NSObject
 @end
-@implementation GGMLMetalClass
+@implementation WSPGGMLMetalClass
 @end
 
+wsp_ggml_log_callback wsp_ggml_metal_log_callback = NULL;
+void * wsp_ggml_metal_log_user_data = NULL;
+
+void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data) {
+    wsp_ggml_metal_log_callback = log_callback;
+    wsp_ggml_metal_log_user_data = user_data;
+}
+
+static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format, ...){
+    if (wsp_ggml_metal_log_callback != NULL) {
+        va_list args;
+        va_start(args, format);
+        char buffer[128];
+        int len = vsnprintf(buffer, 128, format, args);
+        if (len < 128) {
+            wsp_ggml_metal_log_callback(level, buffer, wsp_ggml_metal_log_user_data);
+        } else {
+            char* buffer2 = malloc(len+1);
+            vsnprintf(buffer2, len+1, format, args);
+            buffer2[len] = 0;
+            wsp_ggml_metal_log_callback(level, buffer2, wsp_ggml_metal_log_user_data);
+            free(buffer2);
+        }
+        va_end(args);
+    }
+}
+
+
+
 struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
-    metal_printf("%s: allocating\n", __func__);
+    WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
     id <MTLDevice> device;
     NSString * s;
@@ -131,14 +174,14 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
     NSArray * devices = MTLCopyAllDevices();
     for (device in devices) {
         s = [device name];
-        metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
+        WSP_GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
     }
 #endif
 
     // Pick and show default Metal device
     device = MTLCreateSystemDefaultDevice();
     s = [device name];
-    metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
+    WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
 
     // Configure context
     struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
@@ -150,68 +193,69 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
-#ifdef WSP_GGML_SWIFT
-    // load the default.metallib file
+    // load library
     {
-        NSError * error = nil;
-
-        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-        NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
-        NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
-        NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
-        NSURL * libURL = [NSURL fileURLWithPath:libPath];
-
-        // Load the metallib file into a Metal library
-        ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
-
-        if (error) {
-            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
-            return NULL;
-        }
-    }
+        NSBundle * bundle = nil;
+#ifdef SWIFT_PACKAGE
+        bundle = SWIFTPM_MODULE_BUNDLE;
 #else
-    UNUSED(msl_library_source);
-
-    // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
-    {
+        bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]];
+#endif
         NSError * error = nil;
+        NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
+        if (libPath != nil) {
+            NSURL * libURL = [NSURL fileURLWithPath:libPath];
+            WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
+            ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+        } else {
+            WSP_GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
+
+            NSString * sourcePath;
+            NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"WSP_GGML_METAL_PATH_RESOURCES"];
+            if (ggmlMetalPathResources) {
+                sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
+            } else {
+                sourcePath = [bundle pathForResource:@"ggml-metal-whisper" ofType:@"metal"];
+            }
+            if (sourcePath == nil) {
+                WSP_GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
+                sourcePath = @"ggml-metal.metal";
+            }
+            WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
+            NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
+            if (error) {
+                WSP_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return NULL;
+            }
 
-        //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
-        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-        NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-        metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
-
-        NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
-        if (error) {
-            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
-            return NULL;
-        }
-
+            MTLCompileOptions* options = nil;
 #ifdef WSP_GGML_QKK_64
-        MTLCompileOptions* options = [MTLCompileOptions new];
-        options.preprocessorMacros = @{ @"QK_K" : @(64) };
-        ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
-#else
-        ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+            options = [MTLCompileOptions new];
+            options.preprocessorMacros = @{ @"QK_K" : @(64) };
 #endif
+            ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+        }
+
         if (error) {
-            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
+            WSP_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
-#endif
 
     // load kernels
     {
         NSError * error = nil;
+
+        /*
+        WSP_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+                (int) ctx->pipeline_##name.threadExecutionWidth); \
+        */
 #define WSP_GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (__bridge void *) ctx->pipeline_##name, \
-                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
-                (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
-            metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            WSP_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
         }
 
@@ -220,6 +264,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
         WSP_GGML_METAL_ADD_KERNEL(mul);
         WSP_GGML_METAL_ADD_KERNEL(mul_row);
         WSP_GGML_METAL_ADD_KERNEL(scale);
+        WSP_GGML_METAL_ADD_KERNEL(scale_4);
         WSP_GGML_METAL_ADD_KERNEL(silu);
         WSP_GGML_METAL_ADD_KERNEL(relu);
         WSP_GGML_METAL_ADD_KERNEL(gelu);
@@ -231,6 +276,8 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
         WSP_GGML_METAL_ADD_KERNEL(get_rows_f16);
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+        WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_1);
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -239,44 +286,66 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
239
286
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K);
240
287
  WSP_GGML_METAL_ADD_KERNEL(rms_norm);
241
288
  WSP_GGML_METAL_ADD_KERNEL(norm);
242
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
243
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
244
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
245
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
246
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
247
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
248
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
249
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
250
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
251
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
252
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
253
- WSP_GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
254
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
255
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
256
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
257
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
258
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
259
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
260
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
261
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
262
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
263
- WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
264
- WSP_GGML_METAL_ADD_KERNEL(rope);
289
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
290
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
291
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
292
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
293
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
294
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
295
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
296
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
297
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
298
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
299
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
300
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
301
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
302
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
303
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
304
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
305
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
306
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
307
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
308
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
309
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
310
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
311
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
312
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
313
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
314
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
315
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
316
+ }
317
+ WSP_GGML_METAL_ADD_KERNEL(rope_f32);
318
+ WSP_GGML_METAL_ADD_KERNEL(rope_f16);
265
319
  WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
266
320
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
267
321
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
268
322
  WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
323
+ WSP_GGML_METAL_ADD_KERNEL(concat);
324
+ WSP_GGML_METAL_ADD_KERNEL(sqr);
269
325
 
270
326
  #undef WSP_GGML_METAL_ADD_KERNEL
271
327
  }
272
328
 
273
- metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
274
329
  #if TARGET_OS_OSX
275
- metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
330
+ // print MTL GPU family:
331
+ WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
332
+
333
+ // determine max supported GPU family
334
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
335
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
336
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
337
+ if ([ctx->device supportsFamily:i]) {
338
+ WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
339
+ break;
340
+ }
341
+ }
342
+
343
+ WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
344
+ WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
276
345
  if (ctx->device.maxTransferRate != 0) {
277
- metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
346
+ WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
278
347
  } else {
279
- metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
348
+ WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
280
349
  }
281
350
  #endif
282
351
 
@@ -284,7 +353,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
284
353
  }
285
354
 
286
355
  void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
287
- metal_printf("%s: deallocating\n", __func__);
356
+ WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
288
357
  #define WSP_GGML_METAL_DEL_KERNEL(name) \
289
358
 
290
359
  WSP_GGML_METAL_DEL_KERNEL(add);
@@ -292,6 +361,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
292
361
  WSP_GGML_METAL_DEL_KERNEL(mul);
293
362
  WSP_GGML_METAL_DEL_KERNEL(mul_row);
294
363
  WSP_GGML_METAL_DEL_KERNEL(scale);
364
+ WSP_GGML_METAL_DEL_KERNEL(scale_4);
295
365
  WSP_GGML_METAL_DEL_KERNEL(silu);
296
366
  WSP_GGML_METAL_DEL_KERNEL(relu);
297
367
  WSP_GGML_METAL_DEL_KERNEL(gelu);
@@ -303,6 +373,8 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
303
373
  WSP_GGML_METAL_DEL_KERNEL(get_rows_f16);
304
374
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_0);
305
375
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_1);
376
+ WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_0);
377
+ WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_1);
306
378
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q8_0);
307
379
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q2_K);
308
380
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -311,33 +383,42 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
311
383
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
312
384
  WSP_GGML_METAL_DEL_KERNEL(rms_norm);
313
385
  WSP_GGML_METAL_DEL_KERNEL(norm);
314
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
315
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
316
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
317
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
318
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
319
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
320
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
321
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
322
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
323
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
324
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
325
- WSP_GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
326
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
327
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
328
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
329
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
330
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
331
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
332
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
333
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
334
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
335
- WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
336
- WSP_GGML_METAL_DEL_KERNEL(rope);
386
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
387
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
388
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
389
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
390
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
391
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
392
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
393
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
394
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
395
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
396
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
397
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
398
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
399
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
400
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
401
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
402
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
403
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
404
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
405
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
406
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
407
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
408
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
409
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
410
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
411
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
412
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
413
+ }
414
+ WSP_GGML_METAL_DEL_KERNEL(rope_f32);
415
+ WSP_GGML_METAL_DEL_KERNEL(rope_f16);
337
416
  WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
338
417
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
339
418
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
340
419
  WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
420
+ WSP_GGML_METAL_DEL_KERNEL(concat);
421
+ WSP_GGML_METAL_DEL_KERNEL(sqr);
341
422
 
342
423
  #undef WSP_GGML_METAL_DEL_KERNEL
343
424
 
@@ -348,7 +429,7 @@ void * wsp_ggml_metal_host_malloc(size_t n) {
348
429
  void * data = NULL;
349
430
  const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
350
431
  if (result != 0) {
351
- metal_printf("%s: error: posix_memalign failed\n", __func__);
432
+ WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
352
433
  return NULL;
353
434
  }
354
435
 
@@ -376,7 +457,7 @@ int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
376
457
  // Metal buffer based on the host memory pointer
377
458
  //
378
459
  static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_tensor * t, size_t * offs) {
379
- //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
460
+ //WSP_GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
380
461
 
381
462
  const int64_t tsize = wsp_ggml_nbytes(t);
382
463
 
@@ -384,17 +465,17 @@ static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * c
384
465
  for (int i = 0; i < ctx->n_buffers; ++i) {
385
466
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
386
467
 
387
- //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
468
+ //WSP_GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
388
469
  if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
389
470
  *offs = (size_t) ioffs;
390
471
 
391
- //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
472
+ //WSP_GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
392
473
 
393
474
  return ctx->buffers[i].metal;
394
475
  }
395
476
  }
396
477
 
397
- metal_printf("%s: error: buffer is nil\n", __func__);
478
+ WSP_GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
398
479
 
399
480
  return nil;
400
481
  }
@@ -406,7 +487,7 @@ bool wsp_ggml_metal_add_buffer(
406
487
  size_t size,
407
488
  size_t max_size) {
408
489
  if (ctx->n_buffers >= WSP_GGML_METAL_MAX_BUFFERS) {
409
- metal_printf("%s: too many buffers\n", __func__);
490
+ WSP_GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
410
491
  return false;
411
492
  }
412
493
 
@@ -416,7 +497,7 @@ bool wsp_ggml_metal_add_buffer(
416
497
  const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
417
498
 
418
499
  if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
419
- metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
500
+ WSP_GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
420
501
  return false;
421
502
  }
422
503
  }
@@ -437,11 +518,11 @@ bool wsp_ggml_metal_add_buffer(
437
518
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
438
519
 
439
520
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
440
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
521
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
441
522
  return false;
442
523
  }
443
524
 
444
- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
525
+ WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
445
526
 
446
527
  ++ctx->n_buffers;
447
528
  } else {
@@ -461,13 +542,13 @@ bool wsp_ggml_metal_add_buffer(
461
542
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
462
543
 
463
544
  if (ctx->buffers[ctx->n_buffers].metal == nil) {
464
- metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
545
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
465
546
  return false;
466
547
  }
467
548
 
468
- metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
549
+ WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
469
550
  if (i + size_step < size) {
470
- metal_printf("\n");
551
+ WSP_GGML_METAL_LOG_INFO("\n");
471
552
  }
472
553
 
473
554
  ++ctx->n_buffers;
@@ -475,17 +556,17 @@ bool wsp_ggml_metal_add_buffer(
475
556
  }
476
557
 
477
558
  #if TARGET_OS_OSX
478
- metal_printf(", (%8.2f / %8.2f)",
559
+ WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
479
560
  ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
480
561
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
481
562
 
482
563
  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
483
- metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
564
+ WSP_GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
484
565
  } else {
485
- metal_printf("\n");
566
+ WSP_GGML_METAL_LOG_INFO("\n");
486
567
  }
487
568
  #else
488
- metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
569
+ WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
489
570
  #endif
490
571
  }
491
572
 
@@ -598,7 +679,7 @@ void wsp_ggml_metal_graph_find_concurrency(
598
679
  }
599
680
 
600
681
  if (ctx->concur_list_len > WSP_GGML_MAX_CONCUR) {
601
- metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
682
+ WSP_GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
602
683
  }
603
684
  }
604
685
 
@@ -652,12 +733,26 @@ void wsp_ggml_metal_graph_compute(
652
733
  continue;
653
734
  }
654
735
 
655
- //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
736
+ //WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
656
737
 
657
738
  struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0];
658
739
  struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1];
659
740
  struct wsp_ggml_tensor * dst = gf->nodes[i];
660
741
 
742
+ switch (dst->op) {
743
+ case WSP_GGML_OP_NONE:
744
+ case WSP_GGML_OP_RESHAPE:
745
+ case WSP_GGML_OP_VIEW:
746
+ case WSP_GGML_OP_TRANSPOSE:
747
+ case WSP_GGML_OP_PERMUTE:
748
+ {
749
+ // noop -> next node
750
+ } continue;
751
+ default:
752
+ {
753
+ } break;
754
+ }
755
+
661
756
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
662
757
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
663
758
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -696,53 +791,117 @@ void wsp_ggml_metal_graph_compute(
696
791
  id<MTLBuffer> id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
697
792
  id<MTLBuffer> id_dst = dst ? wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
698
793
 
699
- //metal_printf("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
794
+ //WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
700
795
  //if (src0) {
701
- // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
796
+ // WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
702
797
  // wsp_ggml_is_contiguous(src0), src0->name);
703
798
  //}
704
799
  //if (src1) {
705
- // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
800
+ // WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
706
801
  // wsp_ggml_is_contiguous(src1), src1->name);
707
802
  //}
708
803
  //if (dst) {
709
- // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
804
+ // WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
710
805
  // dst->name);
711
806
  //}
712
807
 
713
808
  switch (dst->op) {
714
- case WSP_GGML_OP_NONE:
715
- case WSP_GGML_OP_RESHAPE:
716
- case WSP_GGML_OP_VIEW:
717
- case WSP_GGML_OP_TRANSPOSE:
718
- case WSP_GGML_OP_PERMUTE:
809
+ case WSP_GGML_OP_CONCAT:
719
810
  {
720
- // noop
811
+ const int64_t nb = ne00;
812
+
813
+ [encoder setComputePipelineState:ctx->pipeline_concat];
814
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
815
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
816
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
817
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
818
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
819
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
820
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
821
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
822
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
823
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
824
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
825
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
826
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
827
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
828
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
829
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
830
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
831
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
832
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
833
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
834
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
835
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
836
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
837
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
838
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
839
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
840
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
841
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
842
+
843
+ const int nth = MIN(1024, ne0);
844
+
845
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
721
846
  } break;
722
847
  case WSP_GGML_OP_ADD:
723
848
  {
724
849
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
725
850
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
726
851
 
727
- // utilize float4
728
- WSP_GGML_ASSERT(ne00 % 4 == 0);
729
- const int64_t nb = ne00/4;
852
+ bool bcast_row = false;
730
853
 
731
- if (wsp_ggml_nelements(src1) == ne10) {
854
+ int64_t nb = ne00;
855
+
856
+ if (wsp_ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
732
857
  // src1 is a row
733
858
  WSP_GGML_ASSERT(ne11 == 1);
859
+
860
+ nb = ne00 / 4;
734
861
  [encoder setComputePipelineState:ctx->pipeline_add_row];
862
+
863
+ bcast_row = true;
735
864
  } else {
736
865
  [encoder setComputePipelineState:ctx->pipeline_add];
737
866
  }
738
867
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
739
868
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
740
869
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
741
- [encoder setBytes:&nb length:sizeof(nb) atIndex:3];
742
-
743
- const int64_t n = wsp_ggml_nelements(dst)/4;
870
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
871
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
872
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
873
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
874
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
875
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
876
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
877
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
878
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
879
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
880
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
881
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
882
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
883
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
884
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
885
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
886
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
887
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
888
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
889
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
890
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
891
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
892
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
893
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
894
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
895
+
896
+ if (bcast_row) {
897
+ const int64_t n = wsp_ggml_nelements(dst)/4;
898
+
899
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
900
+ } else {
901
+ const int nth = MIN(1024, ne0);
744
902
 
745
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
903
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
904
+ }
746
905
  } break;
747
906
  case WSP_GGML_OP_MUL:
748
907
  {
@@ -775,13 +934,19 @@ void wsp_ggml_metal_graph_compute(
775
934
 
776
935
  const float scale = *(const float *) src1->data;
777
936
 
778
- [encoder setComputePipelineState:ctx->pipeline_scale];
937
+ int64_t n = wsp_ggml_nelements(dst);
938
+
939
+ if (n % 4 == 0) {
940
+ n /= 4;
941
+ [encoder setComputePipelineState:ctx->pipeline_scale_4];
942
+ } else {
943
+ [encoder setComputePipelineState:ctx->pipeline_scale];
944
+ }
945
+
779
946
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
780
947
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
781
948
  [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
782
949
 
783
- const int64_t n = wsp_ggml_nelements(dst)/4;
784
-
785
950
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
786
951
  } break;
787
952
  case WSP_GGML_OP_UNARY:
@@ -792,9 +957,10 @@ void wsp_ggml_metal_graph_compute(
792
957
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
793
958
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
794
959
 
795
- const int64_t n = wsp_ggml_nelements(dst)/4;
960
+ const int64_t n = wsp_ggml_nelements(dst);
961
+ WSP_GGML_ASSERT(n % 4 == 0);
796
962
 
797
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
963
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
798
964
  } break;
799
965
  case WSP_GGML_UNARY_OP_RELU:
800
966
  {
@@ -812,23 +978,39 @@ void wsp_ggml_metal_graph_compute(
812
978
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
813
979
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
814
980
 
815
- const int64_t n = wsp_ggml_nelements(dst)/4;
981
+ const int64_t n = wsp_ggml_nelements(dst);
982
+ WSP_GGML_ASSERT(n % 4 == 0);
816
983
 
817
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
984
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
818
985
  } break;
819
986
  default:
820
987
  {
821
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
988
+ WSP_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
822
989
  WSP_GGML_ASSERT(false);
823
990
  }
824
991
  } break;
992
+ case WSP_GGML_OP_SQR:
993
+ {
994
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
995
+
996
+ [encoder setComputePipelineState:ctx->pipeline_sqr];
997
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
998
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
999
+
1000
+ const int64_t n = wsp_ggml_nelements(dst);
1001
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1002
+ } break;
825
1003
  case WSP_GGML_OP_SOFT_MAX:
826
1004
  {
827
- const int nth = 32;
1005
+ int nth = 32; // SIMD width
828
1006
 
829
1007
  if (ne00%4 == 0) {
830
1008
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
831
1009
  } else {
1010
+ do {
1011
+ nth *= 2;
1012
+ } while (nth <= ne00 && nth <= 1024);
1013
+ nth /= 2;
832
1014
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
833
1015
  }
834
1016
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -836,8 +1018,9 @@ void wsp_ggml_metal_graph_compute(
836
1018
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
837
1019
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
838
1020
  [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1021
+ [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
839
1022
 
840
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1023
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
841
1024
  } break;
842
1025
  case WSP_GGML_OP_DIAG_MASK_INF:
843
1026
  {
@@ -863,26 +1046,53 @@ void wsp_ggml_metal_graph_compute(
863
1046
  } break;
864
1047
  case WSP_GGML_OP_MUL_MAT:
865
1048
  {
866
- // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
867
-
868
1049
  WSP_GGML_ASSERT(ne00 == ne10);
869
- // WSP_GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
870
- uint gqa = ne12/ne02;
871
1050
  WSP_GGML_ASSERT(ne03 == ne13);
872
1051
 
1052
+ const uint gqa = ne12/ne02;
1053
+
1054
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1055
+ // to the matrix-vector kernel
1056
+ int ne11_mm_min = 1;
1057
+
1058
+ #if 0
1059
+ // the numbers below are measured on M2 Ultra for 7B and 13B models
1060
+ // these numbers do not translate to other devices or model sizes
1061
+ // TODO: need to find a better approach
1062
+ if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1063
+ switch (src0t) {
1064
+ case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break;
1065
+ case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1066
+ case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1067
+ case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1068
+ case WSP_GGML_TYPE_Q4_0:
1069
+ case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1070
+ case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1071
+ case WSP_GGML_TYPE_Q5_0: // not tested yet
1072
+ case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1073
+ case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1074
+ case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1075
+ default: ne11_mm_min = 1; break;
1076
+ }
1077
+ }
1078
+ #endif
1079
+
873
1080
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
874
1081
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
875
- if (!wsp_ggml_is_transposed(src0) &&
1082
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1083
+ !wsp_ggml_is_transposed(src0) &&
876
1084
  !wsp_ggml_is_transposed(src1) &&
877
1085
  src1t == WSP_GGML_TYPE_F32 &&
878
- [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
879
- ne00%32 == 0 &&
880
- ne11 > 1) {
1086
+ ne00 % 32 == 0 && ne00 >= 64 &&
1087
+ ne11 > ne11_mm_min) {
1088
+ //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
881
1089
  switch (src0->type) {
882
1090
  case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
883
1091
  case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
884
1092
  case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
885
1093
  case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
1094
+ case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
1095
+ case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
886
1096
  case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
887
1097
  case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
888
1098
  case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -906,17 +1116,18 @@ void wsp_ggml_metal_graph_compute(
906
1116
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
907
1117
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
908
1118
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
909
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1119
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
910
1120
  } else {
911
1121
  int nth0 = 32;
912
1122
  int nth1 = 1;
913
1123
  int nrows = 1;
1124
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
914
1125
 
915
1126
  // use custom matrix x vector kernel
916
1127
  switch (src0t) {
917
1128
  case WSP_GGML_TYPE_F32:
918
1129
  {
919
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
1130
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
920
1131
  nrows = 4;
921
1132
  } break;
922
1133
  case WSP_GGML_TYPE_F16:
@@ -924,12 +1135,12 @@ void wsp_ggml_metal_graph_compute(
924
1135
  nth0 = 32;
925
1136
  nth1 = 1;
926
1137
  if (ne11 * ne12 < 4) {
927
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
1138
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
928
1139
  } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
929
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
1140
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
930
1141
  nrows = ne11;
931
1142
  } else {
932
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
1143
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
933
1144
  nrows = 4;
934
1145
  }
935
1146
  } break;
@@ -940,7 +1151,7 @@ void wsp_ggml_metal_graph_compute(
940
1151
 
941
1152
  nth0 = 8;
942
1153
  nth1 = 8;
943
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
1154
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
944
1155
  } break;
945
1156
  case WSP_GGML_TYPE_Q4_1:
946
1157
  {
@@ -949,7 +1160,25 @@ void wsp_ggml_metal_graph_compute(
949
1160
 
950
1161
  nth0 = 8;
951
1162
  nth1 = 8;
952
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
1163
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
1164
+ } break;
1165
+ case WSP_GGML_TYPE_Q5_0:
1166
+ {
1167
+ WSP_GGML_ASSERT(ne02 == 1);
1168
+ WSP_GGML_ASSERT(ne12 == 1);
1169
+
1170
+ nth0 = 8;
1171
+ nth1 = 8;
1172
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
1173
+ } break;
1174
+ case WSP_GGML_TYPE_Q5_1:
1175
+ {
1176
+ WSP_GGML_ASSERT(ne02 == 1);
1177
+ WSP_GGML_ASSERT(ne12 == 1);
1178
+
1179
+ nth0 = 8;
1180
+ nth1 = 8;
1181
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
953
1182
  } break;
954
1183
  case WSP_GGML_TYPE_Q8_0:
955
1184
  {
@@ -958,7 +1187,7 @@ void wsp_ggml_metal_graph_compute(
958
1187
 
959
1188
  nth0 = 8;
960
1189
  nth1 = 8;
961
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
1190
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
962
1191
  } break;
963
1192
  case WSP_GGML_TYPE_Q2_K:
964
1193
  {
@@ -967,7 +1196,7 @@ void wsp_ggml_metal_graph_compute(
967
1196
 
968
1197
  nth0 = 2;
969
1198
  nth1 = 32;
970
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
1199
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
971
1200
  } break;
972
1201
  case WSP_GGML_TYPE_Q3_K:
973
1202
  {
@@ -976,7 +1205,7 @@ void wsp_ggml_metal_graph_compute(
976
1205
 
977
1206
  nth0 = 2;
978
1207
  nth1 = 32;
979
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
1208
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
980
1209
  } break;
981
1210
  case WSP_GGML_TYPE_Q4_K:
982
1211
  {
@@ -985,7 +1214,7 @@ void wsp_ggml_metal_graph_compute(
985
1214
 
986
1215
  nth0 = 4; //1;
987
1216
  nth1 = 8; //32;
988
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
1217
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
989
1218
  } break;
990
1219
  case WSP_GGML_TYPE_Q5_K:
991
1220
  {
@@ -994,7 +1223,7 @@ void wsp_ggml_metal_graph_compute(
994
1223
 
995
1224
  nth0 = 2;
996
1225
  nth1 = 32;
997
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
1226
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
998
1227
  } break;
999
1228
  case WSP_GGML_TYPE_Q6_K:
1000
1229
  {
@@ -1003,11 +1232,11 @@ void wsp_ggml_metal_graph_compute(
1003
1232
 
1004
1233
  nth0 = 2;
1005
1234
  nth1 = 32;
1006
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
1235
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
1007
1236
  } break;
1008
1237
  default:
1009
1238
  {
1010
- metal_printf("Asserting on type %d\n",(int)src0t);
1239
+ WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
1011
1240
  WSP_GGML_ASSERT(false && "not implemented");
1012
1241
  }
1013
1242
  };
@@ -1031,8 +1260,9 @@ void wsp_ggml_metal_graph_compute(
1031
1260
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1032
1261
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
1033
1262
 
1034
- if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
1035
- src0t == WSP_GGML_TYPE_Q2_K) {// || src0t == WSP_GGML_TYPE_Q4_K) {
1263
+ if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
1264
+ src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
1265
+ src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
1036
1266
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1037
1267
  }
1038
1268
  else if (src0t == WSP_GGML_TYPE_Q4_K) {
@@ -1063,6 +1293,8 @@ void wsp_ggml_metal_graph_compute(
1063
1293
  case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
1064
1294
  case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
1065
1295
  case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
1296
+ case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
1297
+ case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
1066
1298
  case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
1067
1299
  case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
1068
1300
  case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -1085,10 +1317,12 @@ void wsp_ggml_metal_graph_compute(
1085
1317
  } break;
1086
1318
  case WSP_GGML_OP_RMS_NORM:
1087
1319
  {
1320
+ WSP_GGML_ASSERT(ne00 % 4 == 0);
1321
+
1088
1322
  float eps;
1089
1323
  memcpy(&eps, dst->op_params, sizeof(float));
1090
1324
 
1091
- const int nth = 512;
1325
+ const int nth = MIN(512, ne00);
1092
1326
 
1093
1327
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1094
1328
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1096,7 +1330,7 @@ void wsp_ggml_metal_graph_compute(
1096
1330
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1097
1331
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1098
1332
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1099
- [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
1333
+ [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
1100
1334
 
1101
1335
  const int64_t nrows = wsp_ggml_nrows(src0);
1102
1336
 
@@ -1107,7 +1341,7 @@ void wsp_ggml_metal_graph_compute(
1107
1341
  float eps;
1108
1342
  memcpy(&eps, dst->op_params, sizeof(float));
1109
1343
 
1110
- const int nth = 256;
1344
+ const int nth = MIN(256, ne00);
1111
1345
 
1112
1346
  [encoder setComputePipelineState:ctx->pipeline_norm];
1113
1347
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1115,7 +1349,7 @@ void wsp_ggml_metal_graph_compute(
1115
1349
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1116
1350
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1117
1351
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1118
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
1352
+ [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0];
1119
1353
 
1120
1354
  const int64_t nrows = wsp_ggml_nrows(src0);
1121
1355
 
@@ -1125,17 +1359,16 @@ void wsp_ggml_metal_graph_compute(
1125
1359
  {
1126
1360
  WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32));
1127
1361
 
1128
- const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
1362
+ const int nth = MIN(1024, ne00);
1363
+
1364
+ //const int n_past = ((int32_t *) dst->op_params)[0];
1129
1365
  const int n_head = ((int32_t *) dst->op_params)[1];
1130
1366
  float max_bias;
1131
1367
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1132
1368
 
1133
- if (__builtin_popcount(n_head) != 1) {
1134
- WSP_GGML_ASSERT(false && "only power-of-two n_head implemented");
1135
- }
1136
-
1137
1369
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
1138
1370
  const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
1371
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1139
1372
 
1140
1373
  [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
1141
1374
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1156,55 +1389,74 @@ void wsp_ggml_metal_graph_compute(
  [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
-
- const int nth = 32;
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
 
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case WSP_GGML_OP_ROPE:
  {
- const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
+ WSP_GGML_ASSERT(ne10 == ne02);
 
- float freq_base;
- float freq_scale;
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+ const int nth = MIN(1024, ne00);
 
- [encoder setComputePipelineState:ctx->pipeline_rope];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
- [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
- [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
 
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+ switch (src0->type) {
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+ default: WSP_GGML_ASSERT(false);
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case WSP_GGML_OP_DUP:
  case WSP_GGML_OP_CPY:
  case WSP_GGML_OP_CONT:
  {
- const int nth = 32;
+ const int nth = MIN(1024, ne00);
 
  switch (src0t) {
  case WSP_GGML_TYPE_F32:
@@ -1249,7 +1501,7 @@ void wsp_ggml_metal_graph_compute(
  } break;
  default:
  {
- metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
+ WSP_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
  WSP_GGML_ASSERT(false);
  }
  }
@@ -1274,10 +1526,147 @@ void wsp_ggml_metal_graph_compute(
 
  MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
  if (status != MTLCommandBufferStatusCompleted) {
- metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+ WSP_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
  WSP_GGML_ASSERT(false);
  }
  }
 
  }
  }
+
+ ////////////////////////////////////////////////////////////////////////////////
+
+ // backend interface
+
+ static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
+ return "Metal";
+
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
+ struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+ wsp_ggml_metal_free(ctx);
+ free(backend);
+ }
+
+ static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
+ return (void *)buffer->context;
+ }
+
+ static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
+ free(buffer->context);
+ UNUSED(buffer);
+ }
+
+ static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = {
+ /* .free_buffer = */ wsp_ggml_backend_metal_buffer_free_buffer,
+ /* .get_base = */ wsp_ggml_backend_metal_buffer_get_base,
+ /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
+ /* .init_tensor = */ NULL, // no initialization required
+ /* .free_tensor = */ NULL, // no cleanup required
+ };
+
+ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_alloc_buffer(wsp_ggml_backend_t backend, size_t size) {
+ struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+ void * data = wsp_ggml_metal_host_malloc(size);
+
+ // TODO: set proper name of the buffers
+ wsp_ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+
+ return wsp_ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+ }
+
+ static size_t wsp_ggml_backend_metal_get_alignment(wsp_ggml_backend_t backend) {
+ return 32;
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_set_tensor_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_get_tensor_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_cpy_tensor_from(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+ wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
+
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_cpy_tensor_to(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+ wsp_ggml_backend_tensor_set_async(dst, src->data, 0, wsp_ggml_nbytes(src));
+
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
+ struct wsp_ggml_metal_context * metal_ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+ wsp_ggml_metal_graph_compute(metal_ctx, cgraph);
+ }
+
+ static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
+ return true;
+ UNUSED(backend);
+ UNUSED(op);
+ }
+
+ static struct wsp_ggml_backend_i metal_backend_i = {
+ /* .get_name = */ wsp_ggml_backend_metal_name,
+ /* .free = */ wsp_ggml_backend_metal_free,
+ /* .alloc_buffer = */ wsp_ggml_backend_metal_alloc_buffer,
+ /* .get_alignment = */ wsp_ggml_backend_metal_get_alignment,
+ /* .set_tensor_async = */ wsp_ggml_backend_metal_set_tensor_async,
+ /* .get_tensor_async = */ wsp_ggml_backend_metal_get_tensor_async,
+ /* .synchronize = */ wsp_ggml_backend_metal_synchronize,
+ /* .cpy_tensor_from = */ wsp_ggml_backend_metal_cpy_tensor_from,
+ /* .cpy_tensor_to = */ wsp_ggml_backend_metal_cpy_tensor_to,
+ /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ wsp_ggml_backend_metal_graph_compute,
+ /* .supports_op = */ wsp_ggml_backend_metal_supports_op,
+ };
+
+ wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
+ struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
+
+ ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);
+
+ wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend));
+
+ *metal_backend = (struct wsp_ggml_backend) {
+ /* .interface = */ metal_backend_i,
+ /* .context = */ ctx,
+ };
+
+ return metal_backend;
+ }
+
+ bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) {
+ return backend->iface.get_name == wsp_ggml_backend_metal_name;
+ }
+
+ void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
+ struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+ wsp_ggml_metal_set_n_cb(ctx, n_cb);
+ }
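
Usage note (not part of the diff): the functions added above expose ggml-metal.m through the new wsp_ggml_backend interface instead of calling wsp_ggml_metal_graph_compute directly. A minimal sketch of how a caller could drive it is shown below. It assumes the bundled ggml-backend.h provides the wsp_-prefixed equivalents of the standard entry points (wsp_ggml_backend_graph_compute, wsp_ggml_backend_free); run_graph_on_metal is a hypothetical helper, not something defined in this package.

  // sketch only: drive the Metal backend added in this version
  #include "ggml.h"
  #include "ggml-backend.h"
  #include "ggml-metal.h"

  static void run_graph_on_metal(struct wsp_ggml_cgraph * gf) {
      // creates the backend object wrapping wsp_ggml_metal_init()
      wsp_ggml_backend_t backend = wsp_ggml_backend_metal_init();

      if (wsp_ggml_backend_is_metal(backend)) {
          // number of command buffers used to split the graph
          wsp_ggml_backend_metal_set_n_cb(backend, 1);
      }

      // dispatches through metal_backend_i.graph_compute,
      // which forwards to wsp_ggml_metal_graph_compute
      wsp_ggml_backend_graph_compute(backend, gf);

      // releases the Metal context via wsp_ggml_backend_metal_free
      wsp_ggml_backend_free(backend);
  }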