whisper.rn 0.4.0-rc.4 → 0.4.0-rc.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/android/build.gradle +4 -0
- package/android/src/main/CMakeLists.txt +5 -0
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +51 -133
- package/android/src/main/jni-utils.h +76 -0
- package/android/src/main/jni.cpp +187 -112
- package/cpp/README.md +1 -1
- package/cpp/coreml/whisper-encoder-impl.h +1 -1
- package/cpp/coreml/whisper-encoder.h +4 -0
- package/cpp/coreml/whisper-encoder.mm +4 -2
- package/cpp/ggml-alloc.c +55 -19
- package/cpp/ggml-alloc.h +7 -0
- package/cpp/ggml-backend-impl.h +46 -21
- package/cpp/ggml-backend.c +563 -156
- package/cpp/ggml-backend.h +62 -17
- package/cpp/ggml-impl.h +1 -1
- package/cpp/ggml-metal-whisper.metal +1010 -253
- package/cpp/ggml-metal.h +7 -1
- package/cpp/ggml-metal.m +618 -187
- package/cpp/ggml-quants.c +64 -59
- package/cpp/ggml-quants.h +40 -40
- package/cpp/ggml.c +751 -1466
- package/cpp/ggml.h +90 -25
- package/cpp/rn-audioutils.cpp +68 -0
- package/cpp/rn-audioutils.h +14 -0
- package/cpp/rn-whisper-log.h +11 -0
- package/cpp/rn-whisper.cpp +141 -59
- package/cpp/rn-whisper.h +47 -15
- package/cpp/whisper.cpp +1635 -928
- package/cpp/whisper.h +55 -10
- package/ios/RNWhisper.mm +7 -7
- package/ios/RNWhisperAudioUtils.h +0 -2
- package/ios/RNWhisperAudioUtils.m +0 -56
- package/ios/RNWhisperContext.h +3 -11
- package/ios/RNWhisperContext.mm +62 -134
- package/lib/commonjs/version.json +1 -1
- package/lib/module/version.json +1 -1
- package/package.json +6 -5
- package/src/version.json +1 -1
package/cpp/ggml-metal.m
CHANGED
@@ -62,6 +62,8 @@ struct wsp_ggml_metal_context {
 WSP_GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
 WSP_GGML_METAL_DECL_KERNEL(mul);
 WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+WSP_GGML_METAL_DECL_KERNEL(div);
+WSP_GGML_METAL_DECL_KERNEL(div_row);
 WSP_GGML_METAL_DECL_KERNEL(scale);
 WSP_GGML_METAL_DECL_KERNEL(scale_4);
 WSP_GGML_METAL_DECL_KERNEL(silu);
@@ -86,6 +88,7 @@ struct wsp_ggml_metal_context {
 WSP_GGML_METAL_DECL_KERNEL(rms_norm);
 WSP_GGML_METAL_DECL_KERNEL(norm);
 WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
 WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
 WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -111,14 +114,35 @@ struct wsp_ggml_metal_context {
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
 WSP_GGML_METAL_DECL_KERNEL(rope_f32);
 WSP_GGML_METAL_DECL_KERNEL(rope_f16);
 WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
+WSP_GGML_METAL_DECL_KERNEL(im2col_f16);
+WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
 WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
 WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+//WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+//WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
 WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
 WSP_GGML_METAL_DECL_KERNEL(concat);
 WSP_GGML_METAL_DECL_KERNEL(sqr);
+WSP_GGML_METAL_DECL_KERNEL(sum_rows);

 #undef WSP_GGML_METAL_DECL_KERNEL
 };
@@ -126,7 +150,7 @@ struct wsp_ggml_metal_context {
 // MSL code
 // TODO: move the contents here when ready
 // for now it is easier to work in a separate file
-static NSString * const msl_library_source = @"see metal.metal";
+//static NSString * const msl_library_source = @"see metal.metal";

 // Here to assist with NSBundle Path Hack
 @interface WSPGGMLMetalClass : NSObject
@@ -142,7 +166,8 @@ void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void *
 wsp_ggml_metal_log_user_data = user_data;
 }

-
+WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
+static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * format, ...){
 if (wsp_ggml_metal_log_callback != NULL) {
 va_list args;
 va_start(args, format);
@@ -152,6 +177,8 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format
 wsp_ggml_metal_log_callback(level, buffer, wsp_ggml_metal_log_user_data);
 } else {
 char* buffer2 = malloc(len+1);
+va_end(args);
+va_start(args, format);
 vsnprintf(buffer2, len+1, format, args);
 buffer2[len] = 0;
 wsp_ggml_metal_log_callback(level, buffer2, wsp_ggml_metal_log_user_data);
@@ -161,12 +188,10 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format
 }
 }

-
-
 struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

-id
+id<MTLDevice> device;
 NSString * s;

 #if TARGET_OS_OSX
@@ -184,7 +209,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);

 // Configure context
-struct wsp_ggml_metal_context * ctx =
+struct wsp_ggml_metal_context * ctx = calloc(1, sizeof(struct wsp_ggml_metal_context));
 ctx->device = device;
 ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
 ctx->queue = [ctx->device newCommandQueue];
@@ -212,6 +237,9 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {

 NSString * sourcePath;
 NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"WSP_GGML_METAL_PATH_RESOURCES"];
+
+WSP_GGML_METAL_LOG_INFO("%s: WSP_GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
 if (ggmlMetalPathResources) {
 sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
 } else {
@@ -242,6 +270,29 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 }
 }

+#if TARGET_OS_OSX
+// print MTL GPU family:
+WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+// determine max supported GPU family
+// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+if ([ctx->device supportsFamily:i]) {
+WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+break;
+}
+}
+
+WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+if (ctx->device.maxTransferRate != 0) {
+WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+} else {
+WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+}
+#endif
+
 // load kernels
 {
 NSError * error = nil;
@@ -263,6 +314,8 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 WSP_GGML_METAL_ADD_KERNEL(add_row);
 WSP_GGML_METAL_ADD_KERNEL(mul);
 WSP_GGML_METAL_ADD_KERNEL(mul_row);
+WSP_GGML_METAL_ADD_KERNEL(div);
+WSP_GGML_METAL_ADD_KERNEL(div_row);
 WSP_GGML_METAL_ADD_KERNEL(scale);
 WSP_GGML_METAL_ADD_KERNEL(scale_4);
 WSP_GGML_METAL_ADD_KERNEL(silu);
@@ -287,6 +340,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 WSP_GGML_METAL_ADD_KERNEL(rms_norm);
 WSP_GGML_METAL_ADD_KERNEL(norm);
 WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
 WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
 WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
 WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -313,42 +367,40 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
 WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
 WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
 }
 WSP_GGML_METAL_ADD_KERNEL(rope_f32);
 WSP_GGML_METAL_ADD_KERNEL(rope_f16);
 WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
+WSP_GGML_METAL_ADD_KERNEL(im2col_f16);
+WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
 WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
 WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+//WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+//WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
 WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
 WSP_GGML_METAL_ADD_KERNEL(concat);
 WSP_GGML_METAL_ADD_KERNEL(sqr);
+WSP_GGML_METAL_ADD_KERNEL(sum_rows);

 #undef WSP_GGML_METAL_ADD_KERNEL
 }

-#if TARGET_OS_OSX
-// print MTL GPU family:
-WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
-
-// determine max supported GPU family
-// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
-// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
-for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
-if ([ctx->device supportsFamily:i]) {
-WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
-break;
-}
-}
-
-WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
-WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-if (ctx->device.maxTransferRate != 0) {
-WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
-} else {
-WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
-}
-#endif
-
 return ctx;
 }

@@ -360,6 +412,8 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
 WSP_GGML_METAL_DEL_KERNEL(add_row);
 WSP_GGML_METAL_DEL_KERNEL(mul);
 WSP_GGML_METAL_DEL_KERNEL(mul_row);
+WSP_GGML_METAL_DEL_KERNEL(div);
+WSP_GGML_METAL_DEL_KERNEL(div_row);
 WSP_GGML_METAL_DEL_KERNEL(scale);
 WSP_GGML_METAL_DEL_KERNEL(scale_4);
 WSP_GGML_METAL_DEL_KERNEL(silu);
@@ -384,6 +438,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
 WSP_GGML_METAL_DEL_KERNEL(rms_norm);
 WSP_GGML_METAL_DEL_KERNEL(norm);
 WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
 WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
 WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
 WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -410,15 +465,36 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
 WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
 WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
 WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
 }
 WSP_GGML_METAL_DEL_KERNEL(rope_f32);
 WSP_GGML_METAL_DEL_KERNEL(rope_f16);
 WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
+WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
+WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
 WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
 WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+//WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+//WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
 WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
 WSP_GGML_METAL_DEL_KERNEL(concat);
 WSP_GGML_METAL_DEL_KERNEL(sqr);
+WSP_GGML_METAL_DEL_KERNEL(sum_rows);

 #undef WSP_GGML_METAL_DEL_KERNEL

@@ -452,6 +528,13 @@ int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
 return ctx->concur_list;
 }

+// temporarily defined here for compatibility between ggml-backend and the old API
+struct wsp_ggml_backend_metal_buffer_context {
+void * data;
+
+id<MTLBuffer> metal;
+};
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -461,6 +544,19 @@ static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * c

 const int64_t tsize = wsp_ggml_nbytes(t);

+// compatibility with ggml-backend
+if (t->buffer && t->buffer->buft == wsp_ggml_backend_metal_buffer_type()) {
+struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) t->buffer->context;
+
+const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+WSP_GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+*offs = (size_t) ioffs;
+
+return buf_ctx->metal;
+}
+
 // find the view that contains the tensor fully
 for (int i = 0; i < ctx->n_buffers; ++i) {
 const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -518,11 +614,11 @@ bool wsp_ggml_metal_add_buffer(
 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f
+WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
 return false;
 }

-WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f
+WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);

 ++ctx->n_buffers;
 } else {
@@ -542,11 +638,11 @@ bool wsp_ggml_metal_add_buffer(
 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f
+WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
 return false;
 }

-WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f
+WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
 if (i + size_step < size) {
 WSP_GGML_METAL_LOG_INFO("\n");
 }
@@ -561,7 +657,7 @@ bool wsp_ggml_metal_add_buffer(
 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

 if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-WSP_GGML_METAL_LOG_WARN("
+WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
 } else {
 WSP_GGML_METAL_LOG_INFO("\n");
 }
@@ -683,6 +779,51 @@ void wsp_ggml_metal_graph_find_concurrency(
 }
 }

+static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
+switch (op->op) {
+case WSP_GGML_OP_UNARY:
+switch (wsp_ggml_get_unary_op(op)) {
+case WSP_GGML_UNARY_OP_SILU:
+case WSP_GGML_UNARY_OP_RELU:
+case WSP_GGML_UNARY_OP_GELU:
+return true;
+default:
+return false;
+}
+case WSP_GGML_OP_NONE:
+case WSP_GGML_OP_RESHAPE:
+case WSP_GGML_OP_VIEW:
+case WSP_GGML_OP_TRANSPOSE:
+case WSP_GGML_OP_PERMUTE:
+case WSP_GGML_OP_CONCAT:
+case WSP_GGML_OP_ADD:
+case WSP_GGML_OP_MUL:
+case WSP_GGML_OP_DIV:
+case WSP_GGML_OP_SCALE:
+case WSP_GGML_OP_SQR:
+case WSP_GGML_OP_SUM_ROWS:
+case WSP_GGML_OP_SOFT_MAX:
+case WSP_GGML_OP_RMS_NORM:
+case WSP_GGML_OP_NORM:
+case WSP_GGML_OP_ALIBI:
+case WSP_GGML_OP_ROPE:
+case WSP_GGML_OP_IM2COL:
+case WSP_GGML_OP_ARGSORT:
+case WSP_GGML_OP_DUP:
+case WSP_GGML_OP_CPY:
+case WSP_GGML_OP_CONT:
+case WSP_GGML_OP_MUL_MAT:
+case WSP_GGML_OP_MUL_MAT_ID:
+return true;
+case WSP_GGML_OP_DIAG_MASK_INF:
+case WSP_GGML_OP_GET_ROWS:
+{
+return op->ne[0] % 4 == 0;
+}
+default:
+return false;
+}
+}
 void wsp_ggml_metal_graph_compute(
 struct wsp_ggml_metal_context * ctx,
 struct wsp_ggml_cgraph * gf) {
@@ -753,6 +894,8 @@ void wsp_ggml_metal_graph_compute(
 } break;
 }

+WSP_GGML_ASSERT(wsp_ggml_metal_supports_op(dst));
+
 const int64_t ne00 = src0 ? src0->ne[0] : 0;
 const int64_t ne01 = src0 ? src0->ne[1] : 0;
 const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -845,6 +988,8 @@ void wsp_ggml_metal_graph_compute(
 [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
 case WSP_GGML_OP_ADD:
+case WSP_GGML_OP_MUL:
+case WSP_GGML_OP_DIV:
 {
 WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
 WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
@@ -858,11 +1003,21 @@ void wsp_ggml_metal_graph_compute(
 WSP_GGML_ASSERT(ne11 == 1);

 nb = ne00 / 4;
-
+switch (dst->op) {
+case WSP_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
+case WSP_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
+case WSP_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+default: WSP_GGML_ASSERT(false);
+}

 bcast_row = true;
 } else {
-
+switch (dst->op) {
+case WSP_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
+case WSP_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
+case WSP_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+default: WSP_GGML_ASSERT(false);
+}
 }
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -903,31 +1058,6 @@ void wsp_ggml_metal_graph_compute(
 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 }
 } break;
-case WSP_GGML_OP_MUL:
-{
-WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
-WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
-
-// utilize float4
-WSP_GGML_ASSERT(ne00 % 4 == 0);
-const int64_t nb = ne00/4;
-
-if (wsp_ggml_nelements(src1) == ne10) {
-// src1 is a row
-WSP_GGML_ASSERT(ne11 == 1);
-[encoder setComputePipelineState:ctx->pipeline_mul_row];
-} else {
-[encoder setComputePipelineState:ctx->pipeline_mul];
-}
-[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
-
-const int64_t n = wsp_ggml_nelements(dst)/4;
-
-[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-} break;
 case WSP_GGML_OP_SCALE:
 {
 WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
@@ -1000,25 +1130,68 @@ void wsp_ggml_metal_graph_compute(
 const int64_t n = wsp_ggml_nelements(dst);
 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 } break;
+case WSP_GGML_OP_SUM_ROWS:
+{
+WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type));
+
+[encoder setComputePipelineState:ctx->pipeline_sum_rows];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+} break;
 case WSP_GGML_OP_SOFT_MAX:
 {
 int nth = 32; // SIMD width

 if (ne00%4 == 0) {
+while (nth < ne00/4 && nth < 256) {
+nth *= 2;
+}
 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
 } else {
-
+while (nth < ne00 && nth < 1024) {
 nth *= 2;
-}
-nth /= 2;
+}
 [encoder setComputePipelineState:ctx->pipeline_soft_max];
 }
-
-[
-
-[encoder
-
-
+
+const float scale = ((float *) dst->op_params)[0];
+
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+if (id_src1) {
+[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+}
+[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

 [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
@@ -1047,9 +1220,13 @@ void wsp_ggml_metal_graph_compute(
 case WSP_GGML_OP_MUL_MAT:
 {
 WSP_GGML_ASSERT(ne00 == ne10);
-WSP_GGML_ASSERT(ne03 == ne13);

-
+// TODO: assert that dim2 and dim3 are contiguous
+WSP_GGML_ASSERT(ne12 % ne02 == 0);
+WSP_GGML_ASSERT(ne13 % ne03 == 0);
+
+const uint r2 = ne12/ne02;
+const uint r3 = ne13/ne03;

 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
 // to the matrix-vector kernel
@@ -1084,7 +1261,7 @@ void wsp_ggml_metal_graph_compute(
 !wsp_ggml_is_transposed(src1) &&
 src1t == WSP_GGML_TYPE_F32 &&
 ne00 % 32 == 0 && ne00 >= 64 &&
-ne11 > ne11_mm_min) {
+(ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
 //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 switch (src0->type) {
 case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1114,9 +1291,10 @@ void wsp_ggml_metal_graph_compute(
 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
 [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
-[encoder setBytes:&
+[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
 } else {
 int nth0 = 32;
 int nth1 = 1;
@@ -1127,6 +1305,7 @@ void wsp_ggml_metal_graph_compute(
 switch (src0t) {
 case WSP_GGML_TYPE_F32:
 {
+WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
 nrows = 4;
 } break;
@@ -1134,102 +1313,77 @@ void wsp_ggml_metal_graph_compute(
 {
 nth0 = 32;
 nth1 = 1;
-if (
-
-
-
-
+if (src1t == WSP_GGML_TYPE_F32) {
+if (ne11 * ne12 < 4) {
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+nrows = ne11;
+} else {
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+nrows = 4;
+}
 } else {
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
 nrows = 4;
 }
 } break;
 case WSP_GGML_TYPE_Q4_0:
 {
-WSP_GGML_ASSERT(ne02 == 1);
-WSP_GGML_ASSERT(ne12 == 1);
-
 nth0 = 8;
 nth1 = 8;
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
 } break;
 case WSP_GGML_TYPE_Q4_1:
 {
-WSP_GGML_ASSERT(ne02 == 1);
-WSP_GGML_ASSERT(ne12 == 1);
-
 nth0 = 8;
 nth1 = 8;
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
 } break;
 case WSP_GGML_TYPE_Q5_0:
 {
-WSP_GGML_ASSERT(ne02 == 1);
-WSP_GGML_ASSERT(ne12 == 1);
-
 nth0 = 8;
 nth1 = 8;
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
 } break;
 case WSP_GGML_TYPE_Q5_1:
 {
-WSP_GGML_ASSERT(ne02 == 1);
-WSP_GGML_ASSERT(ne12 == 1);
-
 nth0 = 8;
 nth1 = 8;
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
 } break;
 case WSP_GGML_TYPE_Q8_0:
 {
-WSP_GGML_ASSERT(ne02 == 1);
-WSP_GGML_ASSERT(ne12 == 1);
-
 nth0 = 8;
 nth1 = 8;
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
 } break;
 case WSP_GGML_TYPE_Q2_K:
 {
-WSP_GGML_ASSERT(ne02 == 1);
-WSP_GGML_ASSERT(ne12 == 1);
-
 nth0 = 2;
 nth1 = 32;
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
 } break;
 case WSP_GGML_TYPE_Q3_K:
 {
-WSP_GGML_ASSERT(ne02 == 1);
-WSP_GGML_ASSERT(ne12 == 1);
-
 nth0 = 2;
 nth1 = 32;
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
 } break;
 case WSP_GGML_TYPE_Q4_K:
 {
-WSP_GGML_ASSERT(ne02 == 1);
-WSP_GGML_ASSERT(ne12 == 1);
-
 nth0 = 4; //1;
 nth1 = 8; //32;
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
 } break;
 case WSP_GGML_TYPE_Q5_K:
 {
-WSP_GGML_ASSERT(ne02 == 1);
-WSP_GGML_ASSERT(ne12 == 1);
-
 nth0 = 2;
 nth1 = 32;
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
 } break;
 case WSP_GGML_TYPE_Q6_K:
 {
-WSP_GGML_ASSERT(ne02 == 1);
-WSP_GGML_ASSERT(ne12 == 1);
-
 nth0 = 2;
 nth1 = 32;
 [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1258,32 +1412,125 @@ void wsp_ggml_metal_graph_compute(
 [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
 [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
-[encoder setBytes:&
+[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];

 if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
 src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
 src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
-[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 }
 else if (src0t == WSP_GGML_TYPE_Q4_K) {
-[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 }
 else if (src0t == WSP_GGML_TYPE_Q3_K) {
 #ifdef WSP_GGML_QKK_64
-[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
 }
 else if (src0t == WSP_GGML_TYPE_Q5_K) {
-[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 }
 else if (src0t == WSP_GGML_TYPE_Q6_K) {
-[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 } else {
 int64_t ny = (ne11 + nrows - 1)/nrows;
-[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+}
+}
+} break;
+case WSP_GGML_OP_MUL_MAT_ID:
+{
+//WSP_GGML_ASSERT(ne00 == ne10);
+//WSP_GGML_ASSERT(ne03 == ne13);
+
+WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
+
+const int n_as = ne00;
+
+// TODO: make this more general
+WSP_GGML_ASSERT(n_as <= 8);
+
+struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+const int64_t ne20 = src2 ? src2->ne[0] : 0;
+const int64_t ne21 = src2 ? src2->ne[1] : 0;
+const int64_t ne22 = src2 ? src2->ne[2] : 0;
+const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23);
+
+const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20);
+const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23);
+
+const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
+
+WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2));
+WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1));
+
+WSP_GGML_ASSERT(ne20 % 32 == 0);
+// !!!!!!!!! TODO: this assert is probably required but not sure!
+//WSP_GGML_ASSERT(ne20 >= 64);
+WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
+
+const uint r2 = ne12/ne22;
+const uint r3 = ne13/ne23;
+
+// find the break-even point where the matrix-matrix kernel becomes more efficient compared
+// to the matrix-vector kernel
+int ne11_mm_min = 0;
+
+const int idx = ((int32_t *) dst->op_params)[0];
+
+// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ne11 > ne11_mm_min) {
+switch (src2->type) {
+case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
+case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
+case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+}
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
+[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
+[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
+[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
+[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
+[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
+[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
+[encoder setBytes:&idx length:sizeof(idx) atIndex:15];
+// TODO: how to make this an array? read Metal docs
+for (int j = 0; j < n_as; ++j) {
+struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
+
+size_t offs_src_cur = 0;
+id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
 }
+
+[encoder setThreadgroupMemoryLength:8192 atIndex:0];
+[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
 }
 } break;
 case WSP_GGML_OP_GET_ROWS:
@@ -1322,15 +1569,19 @@ void wsp_ggml_metal_graph_compute(
 float eps;
 memcpy(&eps, dst->op_params, sizeof(float));

-
+int nth = 32; // SIMD width
+
+while (nth < ne00/4 && nth < 1024) {
+nth *= 2;
+}

 [encoder setComputePipelineState:ctx->pipeline_rms_norm];
-[encoder setBuffer:id_src0 offset:offs_src0
-[encoder setBuffer:id_dst offset:offs_dst
-[encoder setBytes:&ne00
-[encoder setBytes:&nb01
-[encoder setBytes:&eps
-[encoder setThreadgroupMemoryLength:
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+[encoder setBytes:&eps length:sizeof( float) atIndex:4];
+[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

 const int64_t nrows = wsp_ggml_nrows(src0);

@@ -1404,7 +1655,8 @@ void wsp_ggml_metal_graph_compute(
 const int n_past = ((int32_t *) dst->op_params)[0];
 const int n_dims = ((int32_t *) dst->op_params)[1];
 const int mode = ((int32_t *) dst->op_params)[2];
-
+// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1452,18 +1704,100 @@ void wsp_ggml_metal_graph_compute(

 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
+case WSP_GGML_OP_IM2COL:
+{
+WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
+WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
+
+const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+const int32_t N = src1->ne[is_2D ? 3 : 2];
+const int32_t IC = src1->ne[is_2D ? 2 : 1];
+const int32_t IH = is_2D ? src1->ne[1] : 1;
+const int32_t IW = src1->ne[0];
+
+const int32_t KH = is_2D ? src0->ne[1] : 1;
+const int32_t KW = src0->ne[0];
+
+const int32_t OH = is_2D ? dst->ne[2] : 1;
+const int32_t OW = dst->ne[1];
+
+const int32_t CHW = IC * KH * KW;
+
+const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+switch (src0->type) {
+case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break;
+case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+default: WSP_GGML_ASSERT(false);
+};
+
+[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
+[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
+[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
+[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
+[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
+[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
+[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
+[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
+[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+
+[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+} break;
+case WSP_GGML_OP_ARGSORT:
+{
+WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);
+
+const int nrows = wsp_ggml_nrows(src0);
+
+enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
+
+switch (order) {
+case WSP_GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
+case WSP_GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+default: WSP_GGML_ASSERT(false);
+};
+
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+} break;
 case WSP_GGML_OP_DUP:
 case WSP_GGML_OP_CPY:
 case WSP_GGML_OP_CONT:
 {
-
+WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
+
+int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));

 switch (src0t) {
 case WSP_GGML_TYPE_F32:
 {
+WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0);
+
 switch (dstt) {
-case WSP_GGML_TYPE_F16:
-case WSP_GGML_TYPE_F32:
+case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+//case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+//case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
 default: WSP_GGML_ASSERT(false && "not implemented");
 };
 } break;
@@ -1538,81 +1872,148 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1538
1872
|
|
|
1539
1873
|
// backend interface
|
|
1540
1874
|
|
|
1541
|
-
static
|
|
1542
|
-
|
|
1875
|
+
static id<MTLDevice> g_backend_device = nil;
|
|
1876
|
+
static int g_backend_device_ref_count = 0;
|
|
1543
1877
|
|
|
1544
|
-
|
|
1878
|
+
static id<MTLDevice> wsp_ggml_backend_metal_get_device(void) {
|
|
1879
|
+
if (g_backend_device == nil) {
|
|
1880
|
+
g_backend_device = MTLCreateSystemDefaultDevice();
|
|
1881
|
+
}
|
|
1882
|
+
|
|
1883
|
+
g_backend_device_ref_count++;
|
|
1884
|
+
|
|
1885
|
+
return g_backend_device;
|
|
1545
1886
|
}
|
|
1546
1887
|
|
|
1547
|
-
static void
|
|
1548
|
-
|
|
1549
|
-
|
|
1550
|
-
|
|
1888
|
+
static void wsp_ggml_backend_metal_free_device(void) {
|
|
1889
|
+
assert(g_backend_device_ref_count > 0);
|
|
1890
|
+
|
|
1891
|
+
g_backend_device_ref_count--;
|
|
1892
|
+
|
|
1893
|
+
if (g_backend_device_ref_count == 0) {
|
|
1894
|
+
g_backend_device = nil;
|
|
1895
|
+
}
|
|
1551
1896
|
}
|
|
1552
1897
|
|
|
1553
1898
|
static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
|
|
1554
|
-
|
|
1899
|
+
struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
|
|
1900
|
+
|
|
1901
|
+
return ctx->data;
|
|
1555
1902
|
}
|
|
1556
1903
|
|
|
1557
1904
|
static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
|
|
1558
|
-
|
|
1905
|
+
struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
|
|
1906
|
+
|
|
1907
|
+
wsp_ggml_backend_metal_free_device();
|
|
1908
|
+
|
|
1909
|
+
free(ctx->data);
|
|
1910
|
+
free(ctx);
|
|
1911
|
+
|
|
1912
|
+
UNUSED(buffer);
|
|
1913
|
+
}
|
|
1914
|
+
|
|
1915
|
+
static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
|
1916
|
+
WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
|
|
1917
|
+
WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
|
1918
|
+
|
|
1919
|
+
memcpy((char *)tensor->data + offset, data, size);
|
|
1920
|
+
|
|
1921
|
+
UNUSED(buffer);
|
|
1922
|
+
}
|
|
1923
|
+
|
|
1924
|
+
static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
|
1925
|
+
WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
|
|
1926
|
+
WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
|
1927
|
+
|
|
1928
|
+
memcpy(data, (const char *)tensor->data + offset, size);
|
|
1929
|
+
|
|
1930
|
+
UNUSED(buffer);
|
|
1931
|
+
}
|
|
1932
|
+
|
|
1933
|
+
static void wsp_ggml_backend_metal_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
|
|
1934
|
+
wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
|
|
1935
|
+
|
|
1936
|
+
UNUSED(buffer);
|
|
1937
|
+
}
|
|
1938
|
+
|
|
1939
|
+
static void wsp_ggml_backend_metal_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
|
|
1940
|
+
wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
|
|
1941
|
+
|
|
1559
1942
|
UNUSED(buffer);
|
|
1560
1943
|
}
|
|
1561
1944
|
|
|
1562
1945
|
static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = {
|
|
1563
|
-
/* .free_buffer
|
|
1564
|
-
/* .get_base
|
|
1565
|
-
/* .
|
|
1566
|
-
/* .
|
|
1567
|
-
/* .
|
|
1946
|
+
/* .free_buffer = */ wsp_ggml_backend_metal_buffer_free_buffer,
|
|
1947
|
+
/* .get_base = */ wsp_ggml_backend_metal_buffer_get_base,
|
|
1948
|
+
/* .init_tensor = */ NULL,
|
|
1949
|
+
/* .set_tensor = */ wsp_ggml_backend_metal_buffer_set_tensor,
|
|
1950
|
+
/* .get_tensor = */ wsp_ggml_backend_metal_buffer_get_tensor,
|
|
1951
|
+
/* .cpy_tensor_from = */ wsp_ggml_backend_metal_buffer_cpy_tensor_from,
|
|
1952
|
+
/* .cpy_tensor_to = */ wsp_ggml_backend_metal_buffer_cpy_tensor_to,
|
|
1568
1953
|
};

-static wsp_ggml_backend_buffer_t
-    struct
+static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
+    struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context));

-
+    const size_t size_page = sysconf(_SC_PAGESIZE);

-
-
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }

-
+    ctx->data  = wsp_ggml_metal_host_malloc(size);
+    ctx->metal = [wsp_ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+                    length:size_aligned
+                    options:MTLResourceStorageModeShared
+                    deallocator:nil];
+
+    return wsp_ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
 }

-static size_t
+static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
     return 32;
-    UNUSED(
+    UNUSED(buft);
 }

-static
-
-    WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
-
-    memcpy((char *)tensor->data + offset, data, size);
+static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
+    return wsp_ggml_backend_is_metal(backend) || wsp_ggml_backend_is_cpu(backend);

-
+    WSP_GGML_UNUSED(buft);
 }

-
-
-
-
-
+wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
+    static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_metal = {
+        /* .iface = */ {
+            /* .alloc_buffer     = */ wsp_ggml_backend_metal_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ wsp_ggml_backend_metal_buffer_type_get_alignment,
+            /* .get_alloc_size   = */ NULL, // defaults to wsp_ggml_nbytes
+            /* .supports_backend = */ wsp_ggml_backend_metal_buffer_type_supports_backend,
+        },
+        /* .context = */ NULL,
+    };

-
+    return &wsp_ggml_backend_buffer_type_metal;
 }
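For illustration only (not code from this diff): a minimal sketch of how client code could exercise the new Metal buffer type. wsp_ggml_backend_metal_buffer_type() comes from this diff; wsp_ggml_backend_buft_alloc_buffer, wsp_ggml_backend_buffer_get_base and wsp_ggml_backend_buffer_free are assumed to be declared in this package's cpp/ggml-backend.h, mirroring upstream ggml; metal_buffer_smoke_test is a hypothetical helper.

#include <stdio.h>
#include <string.h>

#include "ggml-backend.h"
#include "ggml-metal.h"

// Allocate a page-aligned, CPU/GPU-shared buffer through the Metal buffer type,
// touch it from the CPU, then release it.
static void metal_buffer_smoke_test(void) {
    wsp_ggml_backend_buffer_type_t buft = wsp_ggml_backend_metal_buffer_type();

    wsp_ggml_backend_buffer_t buf = wsp_ggml_backend_buft_alloc_buffer(buft, 1024 * 1024);

    // get_base returns ctx->data, i.e. ordinary host memory backed by a shared MTLBuffer,
    // so the CPU can write it directly with no staging copy.
    void * base = wsp_ggml_backend_buffer_get_base(buf);
    memset(base, 0, 1024 * 1024);

    printf("metal buffer base = %p\n", base);

    wsp_ggml_backend_buffer_free(buf);
}

The page-size rounding in the alloc_buffer hunk above exists because newBufferWithBytesNoCopy requires page-aligned lengths; the example relies on the buffer type doing that rounding internally.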

-static
+static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
+    return "Metal";
+
     UNUSED(backend);
 }

-static void
-
+static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
+    struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+    wsp_ggml_metal_free(ctx);
+    free(backend);
+}

+static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
     UNUSED(backend);
 }

-static
-
+static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) {
+    return wsp_ggml_backend_metal_buffer_type();

     UNUSED(backend);
 }

@@ -1624,32 +2025,43 @@ static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, str
 }

 static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
-    return
+    return wsp_ggml_metal_supports_op(op);
+
     UNUSED(backend);
-    UNUSED(op);
 }

 static struct wsp_ggml_backend_i metal_backend_i = {
-    /* .get_name
-    /* .free
-    /* .
-    /* .
-    /* .
-    /* .
-    /* .
-    /* .
-    /* .
-    /* .
-    /* .
-    /* .
-    /* .
-    /* .supports_op = */ wsp_ggml_backend_metal_supports_op,
+    /* .get_name                = */ wsp_ggml_backend_metal_name,
+    /* .free                    = */ wsp_ggml_backend_metal_free,
+    /* .get_default_buffer_type = */ wsp_ggml_backend_metal_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ wsp_ggml_backend_metal_synchronize,
+    /* .graph_plan_create       = */ NULL, // the metal implementation does not require creating graph plans atm
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ wsp_ggml_backend_metal_graph_compute,
+    /* .supports_op             = */ wsp_ggml_backend_metal_supports_op,
 };
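For illustration only (not code from this diff): metal_backend_i is the table that the generic wsp_ggml_backend_* entry points route through. Below is a minimal sketch of driving it from client code, assuming a compute graph gf has already been built and allocated in Metal-visible memory, and that wsp_ggml_backend_graph_compute and wsp_ggml_backend_free are declared in cpp/ggml-backend.h as they are upstream; run_on_metal is a hypothetical helper.

#include <stdbool.h>

#include "ggml.h"
#include "ggml-backend.h"
#include "ggml-metal.h"

// Run an already-built cgraph on the Metal backend; returns false if Metal is unavailable.
static bool run_on_metal(struct wsp_ggml_cgraph * gf) {
    wsp_ggml_backend_t backend = wsp_ggml_backend_metal_init(); // NULL if the Metal context can't be created
    if (backend == NULL) {
        return false;
    }

    wsp_ggml_backend_metal_set_n_cb(backend, 1);  // number of command buffers used by the Metal context

    wsp_ggml_backend_graph_compute(backend, gf);  // dispatches via metal_backend_i.graph_compute

    wsp_ggml_backend_free(backend);               // reaches wsp_ggml_backend_metal_free() through the vtable
    return true;
}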

+// TODO: make a common log callback for all backends in ggml-backend
+static void wsp_ggml_backend_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) {
+    fprintf(stderr, "%s", msg);
+
+    UNUSED(level);
+    UNUSED(user_data);
+}
+
 wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
-
+    wsp_ggml_metal_log_set_callback(wsp_ggml_backend_log_callback, NULL);

-    ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);
+    struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);
+
+    if (ctx == NULL) {
+        return NULL;
+    }

     wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend));

@@ -1666,7 +2078,26 @@ bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) {
 }

 void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
+    WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));
+
     struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;

     wsp_ggml_metal_set_n_cb(ctx, n_cb);
 }
+
+bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int family) {
+    WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));
+
+    struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+}
+
+wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) {
+    return wsp_ggml_backend_metal_init();
+
+    WSP_GGML_UNUSED(params);
+    WSP_GGML_UNUSED(user_data);
+}