whisper.rn 0.4.0-rc.4 → 0.4.0-rc.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/android/build.gradle +4 -0
- package/android/src/main/CMakeLists.txt +5 -0
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +57 -134
- package/android/src/main/jni-utils.h +76 -0
- package/android/src/main/jni.cpp +188 -112
- package/cpp/README.md +1 -1
- package/cpp/coreml/whisper-encoder-impl.h +1 -1
- package/cpp/coreml/whisper-encoder.h +4 -0
- package/cpp/coreml/whisper-encoder.mm +4 -2
- package/cpp/ggml-alloc.c +55 -19
- package/cpp/ggml-alloc.h +8 -1
- package/cpp/ggml-backend-impl.h +46 -21
- package/cpp/ggml-backend.c +563 -156
- package/cpp/ggml-backend.h +62 -17
- package/cpp/ggml-impl.h +1 -1
- package/cpp/ggml-metal-whisper.metal +2444 -359
- package/cpp/ggml-metal.h +7 -1
- package/cpp/ggml-metal.m +1105 -197
- package/cpp/ggml-quants.c +66 -61
- package/cpp/ggml-quants.h +40 -40
- package/cpp/ggml.c +1040 -1590
- package/cpp/ggml.h +109 -30
- package/cpp/rn-audioutils.cpp +68 -0
- package/cpp/rn-audioutils.h +14 -0
- package/cpp/rn-whisper-log.h +11 -0
- package/cpp/rn-whisper.cpp +143 -59
- package/cpp/rn-whisper.h +48 -15
- package/cpp/whisper.cpp +1635 -928
- package/cpp/whisper.h +55 -10
- package/ios/RNWhisper.mm +7 -7
- package/ios/RNWhisperAudioUtils.h +0 -2
- package/ios/RNWhisperAudioUtils.m +0 -56
- package/ios/RNWhisperContext.h +3 -11
- package/ios/RNWhisperContext.mm +68 -137
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/index.d.ts +5 -0
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +6 -5
- package/src/index.ts +5 -0
- package/src/version.json +1 -1
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +0 -4
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +0 -8
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +0 -19
package/cpp/ggml-metal.m
CHANGED
@@ -62,11 +62,15 @@ struct wsp_ggml_metal_context {
  WSP_GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
  WSP_GGML_METAL_DECL_KERNEL(mul);
  WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+ WSP_GGML_METAL_DECL_KERNEL(div);
+ WSP_GGML_METAL_DECL_KERNEL(div_row);
  WSP_GGML_METAL_DECL_KERNEL(scale);
  WSP_GGML_METAL_DECL_KERNEL(scale_4);
- WSP_GGML_METAL_DECL_KERNEL(
+ WSP_GGML_METAL_DECL_KERNEL(tanh);
  WSP_GGML_METAL_DECL_KERNEL(relu);
  WSP_GGML_METAL_DECL_KERNEL(gelu);
+ WSP_GGML_METAL_DECL_KERNEL(gelu_quick);
+ WSP_GGML_METAL_DECL_KERNEL(silu);
  WSP_GGML_METAL_DECL_KERNEL(soft_max);
  WSP_GGML_METAL_DECL_KERNEL(soft_max_4);
  WSP_GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -84,8 +88,10 @@ struct wsp_ggml_metal_context {
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_K);
  WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
  WSP_GGML_METAL_DECL_KERNEL(rms_norm);
+ WSP_GGML_METAL_DECL_KERNEL(group_norm);
  WSP_GGML_METAL_DECL_KERNEL(norm);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
@@ -99,6 +105,21 @@ struct wsp_ggml_metal_context {
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+ //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+ //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+ //WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -111,14 +132,39 @@ struct wsp_ggml_metal_context {
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+ WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
  WSP_GGML_METAL_DECL_KERNEL(rope_f32);
  WSP_GGML_METAL_DECL_KERNEL(rope_f16);
  WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
+ WSP_GGML_METAL_DECL_KERNEL(im2col_f16);
+ WSP_GGML_METAL_DECL_KERNEL(upscale_f32);
+ WSP_GGML_METAL_DECL_KERNEL(pad_f32);
+ WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+ WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+ WSP_GGML_METAL_DECL_KERNEL(leaky_relu_f32);
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
  WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+ //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+ //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
  WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+ WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f32);
  WSP_GGML_METAL_DECL_KERNEL(concat);
  WSP_GGML_METAL_DECL_KERNEL(sqr);
+ WSP_GGML_METAL_DECL_KERNEL(sum_rows);

  #undef WSP_GGML_METAL_DECL_KERNEL
  };
@@ -126,7 +172,7 @@ struct wsp_ggml_metal_context {
  // MSL code
  // TODO: move the contents here when ready
  // for now it is easier to work in a separate file
- static NSString * const msl_library_source = @"see metal.metal";
+ //static NSString * const msl_library_source = @"see metal.metal";

  // Here to assist with NSBundle Path Hack
  @interface WSPGGMLMetalClass : NSObject
@@ -142,7 +188,8 @@ void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void *
  wsp_ggml_metal_log_user_data = user_data;
  }

-
+ WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
+ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * format, ...){
  if (wsp_ggml_metal_log_callback != NULL) {
  va_list args;
  va_start(args, format);
@@ -152,6 +199,8 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format
  wsp_ggml_metal_log_callback(level, buffer, wsp_ggml_metal_log_user_data);
  } else {
  char* buffer2 = malloc(len+1);
+ va_end(args);
+ va_start(args, format);
  vsnprintf(buffer2, len+1, format, args);
  buffer2[len] = 0;
  wsp_ggml_metal_log_callback(level, buffer2, wsp_ggml_metal_log_user_data);
@@ -161,12 +210,10 @@ static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format
  }
  }

-
-
  struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);

- id
+ id<MTLDevice> device;
  NSString * s;

  #if TARGET_OS_OSX
@@ -184,7 +231,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);

  // Configure context
- struct wsp_ggml_metal_context * ctx =
+ struct wsp_ggml_metal_context * ctx = calloc(1, sizeof(struct wsp_ggml_metal_context));
  ctx->device = device;
  ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
  ctx->queue = [ctx->device newCommandQueue];
@@ -212,6 +259,9 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {

  NSString * sourcePath;
  NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"WSP_GGML_METAL_PATH_RESOURCES"];
+
+ WSP_GGML_METAL_LOG_INFO("%s: WSP_GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
+
  if (ggmlMetalPathResources) {
  sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
  } else {
@@ -242,6 +292,29 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  }
  }

+ #if TARGET_OS_OSX
+ // print MTL GPU family:
+ WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+ // determine max supported GPU family
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+ for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+ if ([ctx->device supportsFamily:i]) {
+ WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+ break;
+ }
+ }
+
+ WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+ WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+ if (ctx->device.maxTransferRate != 0) {
+ WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+ } else {
+ WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+ }
+ #endif
+
  // load kernels
  {
  NSError * error = nil;
@@ -263,11 +336,15 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_ADD_KERNEL(add_row);
  WSP_GGML_METAL_ADD_KERNEL(mul);
  WSP_GGML_METAL_ADD_KERNEL(mul_row);
+ WSP_GGML_METAL_ADD_KERNEL(div);
+ WSP_GGML_METAL_ADD_KERNEL(div_row);
  WSP_GGML_METAL_ADD_KERNEL(scale);
  WSP_GGML_METAL_ADD_KERNEL(scale_4);
- WSP_GGML_METAL_ADD_KERNEL(
+ WSP_GGML_METAL_ADD_KERNEL(tanh);
  WSP_GGML_METAL_ADD_KERNEL(relu);
  WSP_GGML_METAL_ADD_KERNEL(gelu);
+ WSP_GGML_METAL_ADD_KERNEL(gelu_quick);
+ WSP_GGML_METAL_ADD_KERNEL(silu);
  WSP_GGML_METAL_ADD_KERNEL(soft_max);
  WSP_GGML_METAL_ADD_KERNEL(soft_max_4);
  WSP_GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -285,8 +362,10 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_K);
  WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K);
  WSP_GGML_METAL_ADD_KERNEL(rms_norm);
+ WSP_GGML_METAL_ADD_KERNEL(group_norm);
  WSP_GGML_METAL_ADD_KERNEL(norm);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
@@ -300,6 +379,21 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+ //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+ //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+ //WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -313,42 +407,44 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
  WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+ WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
  }
  WSP_GGML_METAL_ADD_KERNEL(rope_f32);
  WSP_GGML_METAL_ADD_KERNEL(rope_f16);
  WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
+ WSP_GGML_METAL_ADD_KERNEL(im2col_f16);
+ WSP_GGML_METAL_ADD_KERNEL(upscale_f32);
+ WSP_GGML_METAL_ADD_KERNEL(pad_f32);
+ WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+ WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+ WSP_GGML_METAL_ADD_KERNEL(leaky_relu_f32);
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
  WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+ //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+ //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
  WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+ WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f32);
  WSP_GGML_METAL_ADD_KERNEL(concat);
  WSP_GGML_METAL_ADD_KERNEL(sqr);
+ WSP_GGML_METAL_ADD_KERNEL(sum_rows);

  #undef WSP_GGML_METAL_ADD_KERNEL
  }

- #if TARGET_OS_OSX
- // print MTL GPU family:
- WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
-
- // determine max supported GPU family
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
- // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
- for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
- if ([ctx->device supportsFamily:i]) {
- WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
- break;
- }
- }
-
- WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
- WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
- if (ctx->device.maxTransferRate != 0) {
- WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
- } else {
- WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
- }
- #endif
-
  return ctx;
  }

@@ -360,11 +456,15 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
  WSP_GGML_METAL_DEL_KERNEL(add_row);
  WSP_GGML_METAL_DEL_KERNEL(mul);
  WSP_GGML_METAL_DEL_KERNEL(mul_row);
+ WSP_GGML_METAL_DEL_KERNEL(div);
+ WSP_GGML_METAL_DEL_KERNEL(div_row);
  WSP_GGML_METAL_DEL_KERNEL(scale);
  WSP_GGML_METAL_DEL_KERNEL(scale_4);
- WSP_GGML_METAL_DEL_KERNEL(
+ WSP_GGML_METAL_DEL_KERNEL(tanh);
  WSP_GGML_METAL_DEL_KERNEL(relu);
  WSP_GGML_METAL_DEL_KERNEL(gelu);
+ WSP_GGML_METAL_DEL_KERNEL(gelu_quick);
+ WSP_GGML_METAL_DEL_KERNEL(silu);
  WSP_GGML_METAL_DEL_KERNEL(soft_max);
  WSP_GGML_METAL_DEL_KERNEL(soft_max_4);
  WSP_GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -382,8 +482,10 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_K);
  WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
  WSP_GGML_METAL_DEL_KERNEL(rms_norm);
+ WSP_GGML_METAL_DEL_KERNEL(group_norm);
  WSP_GGML_METAL_DEL_KERNEL(norm);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
@@ -397,6 +499,21 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+ //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+ //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+ //WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -410,15 +527,40 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
  WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+ WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
  }
  WSP_GGML_METAL_DEL_KERNEL(rope_f32);
  WSP_GGML_METAL_DEL_KERNEL(rope_f16);
  WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
+ WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
+ WSP_GGML_METAL_DEL_KERNEL(upscale_f32);
+ WSP_GGML_METAL_DEL_KERNEL(pad_f32);
+ WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+ WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+ WSP_GGML_METAL_DEL_KERNEL(leaky_relu_f32);
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
  WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+ //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+ //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
  WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+ WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f32);
  WSP_GGML_METAL_DEL_KERNEL(concat);
  WSP_GGML_METAL_DEL_KERNEL(sqr);
+ WSP_GGML_METAL_DEL_KERNEL(sum_rows);

  #undef WSP_GGML_METAL_DEL_KERNEL

@@ -452,6 +594,13 @@ int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
  return ctx->concur_list;
  }

+ // temporarily defined here for compatibility between ggml-backend and the old API
+ struct wsp_ggml_backend_metal_buffer_context {
+ void * data;
+
+ id<MTLBuffer> metal;
+ };
+
  // finds the Metal buffer that contains the tensor data on the GPU device
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
  // Metal buffer based on the host memory pointer
@@ -461,6 +610,19 @@ static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * c

  const int64_t tsize = wsp_ggml_nbytes(t);

+ // compatibility with ggml-backend
+ if (t->buffer && t->buffer->buft == wsp_ggml_backend_metal_buffer_type()) {
+ struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) t->buffer->context;
+
+ const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+ WSP_GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+ *offs = (size_t) ioffs;
+
+ return buf_ctx->metal;
+ }
+
  // find the view that contains the tensor fully
  for (int i = 0; i < ctx->n_buffers; ++i) {
  const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
@@ -518,11 +680,11 @@ bool wsp_ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
  return false;
  }

- WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f
+ WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);

  ++ctx->n_buffers;
  } else {
@@ -542,11 +704,11 @@ bool wsp_ggml_metal_add_buffer(
  ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];

  if (ctx->buffers[ctx->n_buffers].metal == nil) {
- WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f
+ WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
  return false;
  }

- WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f
+ WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
  if (i + size_step < size) {
  WSP_GGML_METAL_LOG_INFO("\n");
  }
@@ -561,7 +723,7 @@ bool wsp_ggml_metal_add_buffer(
  ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

  if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
- WSP_GGML_METAL_LOG_WARN("
+ WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
  } else {
  WSP_GGML_METAL_LOG_INFO("\n");
  }
@@ -683,6 +845,83 @@ void wsp_ggml_metal_graph_find_concurrency(
  }
  }

+ static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
+ switch (op->op) {
+ case WSP_GGML_OP_UNARY:
+ switch (wsp_ggml_get_unary_op(op)) {
+ case WSP_GGML_UNARY_OP_TANH:
+ case WSP_GGML_UNARY_OP_RELU:
+ case WSP_GGML_UNARY_OP_GELU:
+ case WSP_GGML_UNARY_OP_GELU_QUICK:
+ case WSP_GGML_UNARY_OP_SILU:
+ return true;
+ default:
+ return false;
+ }
+ case WSP_GGML_OP_NONE:
+ case WSP_GGML_OP_RESHAPE:
+ case WSP_GGML_OP_VIEW:
+ case WSP_GGML_OP_TRANSPOSE:
+ case WSP_GGML_OP_PERMUTE:
+ case WSP_GGML_OP_CONCAT:
+ case WSP_GGML_OP_ADD:
+ case WSP_GGML_OP_ACC:
+ case WSP_GGML_OP_MUL:
+ case WSP_GGML_OP_DIV:
+ case WSP_GGML_OP_SCALE:
+ case WSP_GGML_OP_SQR:
+ case WSP_GGML_OP_SUM_ROWS:
+ case WSP_GGML_OP_SOFT_MAX:
+ case WSP_GGML_OP_RMS_NORM:
+ case WSP_GGML_OP_GROUP_NORM:
+ case WSP_GGML_OP_NORM:
+ case WSP_GGML_OP_ALIBI:
+ case WSP_GGML_OP_ROPE:
+ case WSP_GGML_OP_IM2COL:
+ case WSP_GGML_OP_UPSCALE:
+ case WSP_GGML_OP_PAD:
+ case WSP_GGML_OP_ARGSORT:
+ case WSP_GGML_OP_LEAKY_RELU:
+ case WSP_GGML_OP_MUL_MAT:
+ case WSP_GGML_OP_MUL_MAT_ID:
+ return true;
+ case WSP_GGML_OP_CPY:
+ case WSP_GGML_OP_DUP:
+ case WSP_GGML_OP_CONT:
+ {
+ switch (op->src[0]->type) {
+ case WSP_GGML_TYPE_F32:
+ switch (op->type) {
+ case WSP_GGML_TYPE_F16:
+ case WSP_GGML_TYPE_F32:
+ case WSP_GGML_TYPE_Q8_0:
+ case WSP_GGML_TYPE_Q4_0:
+ case WSP_GGML_TYPE_Q4_1:
+ return true;
+ default:
+ return false;
+ }
+ case WSP_GGML_TYPE_F16:
+ switch (op->type) {
+ case WSP_GGML_TYPE_F16:
+ case WSP_GGML_TYPE_F32:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ return false;
+ };
+ }
+ case WSP_GGML_OP_DIAG_MASK_INF:
+ case WSP_GGML_OP_GET_ROWS:
+ {
+ return op->ne[3] == 1;
+ }
+ default:
+ return false;
+ }
+ }
  void wsp_ggml_metal_graph_compute(
  struct wsp_ggml_metal_context * ctx,
  struct wsp_ggml_cgraph * gf) {
@@ -753,6 +992,11 @@ void wsp_ggml_metal_graph_compute(
  } break;
  }

+ if (!wsp_ggml_metal_supports_op(dst)) {
+ WSP_GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, wsp_ggml_op_desc(dst));
+ WSP_GGML_ASSERT(!"unsupported op");
+ }
+
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
  const int64_t ne01 = src0 ? src0->ne[1] : 0;
  const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -845,25 +1089,42 @@ void wsp_ggml_metal_graph_compute(
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case WSP_GGML_OP_ADD:
+ case WSP_GGML_OP_MUL:
+ case WSP_GGML_OP_DIV:
  {
-
- WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
+ const size_t offs = 0;

  bool bcast_row = false;

  int64_t nb = ne00;

-
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (wsp_ggml_nelements(src1) == ne10 && wsp_ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
+
  // src1 is a row
  WSP_GGML_ASSERT(ne11 == 1);

  nb = ne00 / 4;
-
+ switch (dst->op) {
+ case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+ case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+ case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
+ default: WSP_GGML_ASSERT(false);
+ }

  bcast_row = true;
  } else {
-
+ switch (dst->op) {
+ case WSP_GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+ case WSP_GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+ case WSP_GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
+ default: WSP_GGML_ASSERT(false);
+ }
  }
+
+ [encoder setComputePipelineState:pipeline];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -891,42 +1152,98 @@ void wsp_ggml_metal_graph_compute(
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];

  if (bcast_row) {
  const int64_t n = wsp_ggml_nelements(dst)/4;

  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } else {
- const int nth = MIN(
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
  } break;
- case
+ case WSP_GGML_OP_ACC:
  {
+ WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
+ WSP_GGML_ASSERT(dstt == WSP_GGML_TYPE_F32);
+
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));

-
-
- const
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+ const size_t offs = ((int32_t *) dst->op_params)[3];
+
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+ if (!inplace) {
+ // run a separete kernel to cpy src->dst
+ // not sure how to avoid this
+ // TODO: make a simpler cpy_bytes kernel
+
+ const int nth = MIN(1024, ne00);
+
+ [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];

-
- // src1 is a row
- WSP_GGML_ASSERT(ne11 == 1);
- [encoder setComputePipelineState:ctx->pipeline_mul_row];
- } else {
- [encoder setComputePipelineState:ctx->pipeline_mul];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  }
+
+ [encoder setComputePipelineState:ctx->pipeline_add];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];

- const
+ const int nth = MIN(1024, ne0);

- [encoder dispatchThreadgroups:MTLSizeMake(
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case WSP_GGML_OP_SCALE:
  {
@@ -951,16 +1268,15 @@ void wsp_ggml_metal_graph_compute(
  } break;
  case WSP_GGML_OP_UNARY:
  switch (wsp_ggml_get_unary_op(gf->nodes[i])) {
- case
+ case WSP_GGML_UNARY_OP_TANH:
  {
- [encoder setComputePipelineState:ctx->
+ [encoder setComputePipelineState:ctx->pipeline_tanh];
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

  const int64_t n = wsp_ggml_nelements(dst);
- WSP_GGML_ASSERT(n % 4 == 0);

- [encoder dispatchThreadgroups:MTLSizeMake(n
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  case WSP_GGML_UNARY_OP_RELU:
  {
@@ -981,6 +1297,28 @@ void wsp_ggml_metal_graph_compute(
  const int64_t n = wsp_ggml_nelements(dst);
  WSP_GGML_ASSERT(n % 4 == 0);

+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case WSP_GGML_UNARY_OP_GELU_QUICK:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = wsp_ggml_nelements(dst);
+ WSP_GGML_ASSERT(n % 4 == 0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case WSP_GGML_UNARY_OP_SILU:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = wsp_ggml_nelements(dst);
+ WSP_GGML_ASSERT(n % 4 == 0);
+
  [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
  default:
@@ -1000,25 +1338,70 @@ void wsp_ggml_metal_graph_compute(
  const int64_t n = wsp_ggml_nelements(dst);
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
  } break;
+ case WSP_GGML_OP_SUM_ROWS:
+ {
+ WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type));
+
+ [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
  case WSP_GGML_OP_SOFT_MAX:
  {
  int nth = 32; // SIMD width

  if (ne00%4 == 0) {
+ while (nth < ne00/4 && nth < 256) {
+ nth *= 2;
+ }
  [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
  } else {
-
+ while (nth < ne00 && nth < 1024) {
  nth *= 2;
- }
- nth /= 2;
+ }
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
  }
-
- [
-
- [encoder
-
-
+
+ const float scale = ((float *) dst->op_params)[0];
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ if (id_src1) {
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
@@ -1047,9 +1430,13 @@ void wsp_ggml_metal_graph_compute(
  case WSP_GGML_OP_MUL_MAT:
  {
  WSP_GGML_ASSERT(ne00 == ne10);
- WSP_GGML_ASSERT(ne03 == ne13);

-
+ // TODO: assert that dim2 and dim3 are contiguous
+ WSP_GGML_ASSERT(ne12 % ne02 == 0);
+ WSP_GGML_ASSERT(ne13 % ne03 == 0);
+
+ const uint r2 = ne12/ne02;
+ const uint r3 = ne13/ne03;

  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
  // to the matrix-vector kernel
@@ -1084,7 +1471,7 @@ void wsp_ggml_metal_graph_compute(
  !wsp_ggml_is_transposed(src1) &&
  src1t == WSP_GGML_TYPE_F32 &&
  ne00 % 32 == 0 && ne00 >= 64 &&
- ne11 > ne11_mm_min) {
+ (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
  //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
  switch (src0->type) {
  case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@@ -1114,9 +1501,10 @@ void wsp_ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
  [encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
  } else {
  int nth0 = 32;
  int nth1 = 1;
@@ -1127,6 +1515,7 @@ void wsp_ggml_metal_graph_compute(
  switch (src0t) {
  case WSP_GGML_TYPE_F32:
  {
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
  nrows = 4;
  } break;
@@ -1134,102 +1523,77 @@ void wsp_ggml_metal_graph_compute(
  {
  nth0 = 32;
  nth1 = 1;
- if (
-
-
-
-
+ if (src1t == WSP_GGML_TYPE_F32) {
+ if (ne11 * ne12 < 4) {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+ nrows = ne11;
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+ nrows = 4;
+ }
  } else {
- [encoder setComputePipelineState:ctx->
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
  nrows = 4;
  }
  } break;
  case WSP_GGML_TYPE_Q4_0:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
  } break;
  case WSP_GGML_TYPE_Q4_1:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
  } break;
  case WSP_GGML_TYPE_Q5_0:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
  } break;
  case WSP_GGML_TYPE_Q5_1:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
  } break;
  case WSP_GGML_TYPE_Q8_0:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 8;
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
  } break;
  case WSP_GGML_TYPE_Q2_K:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
  } break;
  case WSP_GGML_TYPE_Q3_K:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
  } break;
  case WSP_GGML_TYPE_Q4_K:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 4; //1;
  nth1 = 8; //32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
  } break;
  case WSP_GGML_TYPE_Q5_K:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
  } break;
  case WSP_GGML_TYPE_Q6_K:
  {
- WSP_GGML_ASSERT(ne02 == 1);
- WSP_GGML_ASSERT(ne12 == 1);
-
  nth0 = 2;
  nth1 = 32;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
@@ -1258,31 +1622,281 @@ void wsp_ggml_metal_graph_compute(
  [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
  [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
- [encoder setBytes:&
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];

  if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
  src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
  src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == WSP_GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == WSP_GGML_TYPE_Q3_K) {
  #ifdef WSP_GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  #endif
  }
  else if (src0t == WSP_GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
  else if (src0t == WSP_GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ }
+ } break;
+ case WSP_GGML_OP_MUL_MAT_ID:
+ {
+ //WSP_GGML_ASSERT(ne00 == ne10);
+ //WSP_GGML_ASSERT(ne03 == ne13);
+
+ WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
+
+ const int n_as = ((int32_t *) dst->op_params)[1];
+
+ // TODO: make this more general
+ WSP_GGML_ASSERT(n_as <= 8);
+
+ struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
+ const int64_t ne22 = src2 ? src2->ne[2] : 0;
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23);
+
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23);
+
+ const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
+
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2));
+ WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1));
+
+ WSP_GGML_ASSERT(ne20 % 32 == 0);
+ // !!!!!!!!! TODO: this assert is probably required but not sure!
+ //WSP_GGML_ASSERT(ne20 >= 64);
+ WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
+
+ const uint r2 = ne12/ne22;
+ const uint r3 = ne13/ne23;
+
+ // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+ // to the matrix-vector kernel
+ int ne11_mm_min = 1;
+
+ const int idx = ((int32_t *) dst->op_params)[0];
+
+ // batch size
+ WSP_GGML_ASSERT(ne01 == ne11);
+
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
+ switch (src2->type) {
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
+ case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+ case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+ case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+ case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+ case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+ case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+ case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+ case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+ case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+ case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+ default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+ }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18
|
|
1743
|
+
// TODO: how to make this an array? read Metal docs
|
|
1744
|
+
for (int j = 0; j < n_as; ++j) {
|
|
1745
|
+
struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
|
|
1746
|
+
|
|
1747
|
+
size_t offs_src_cur = 0;
|
|
1748
|
+
id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
1749
|
+
|
|
1750
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
|
1751
|
+
}
|
|
1752
|
+
|
|
1753
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
|
1754
|
+
|
|
1755
|
+
// TODO: processing one row at a time (ne11 -> 1) is not efficient
|
|
1756
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
1757
|
+
} else {
|
|
1758
|
+
int nth0 = 32;
|
|
1759
|
+
int nth1 = 1;
|
|
1760
|
+
int nrows = 1;
|
|
1761
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
|
1762
|
+
|
|
1763
|
+
// use custom matrix x vector kernel
|
|
1764
|
+
switch (src2t) {
|
|
1765
|
+
case WSP_GGML_TYPE_F32:
|
|
1766
|
+
{
|
|
1767
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1768
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
|
|
1769
|
+
} break;
|
|
1770
|
+
case WSP_GGML_TYPE_F16:
|
|
1771
|
+
{
|
|
1772
|
+
WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
|
|
1773
|
+
nth0 = 32;
|
|
1774
|
+
nth1 = 1;
|
|
1775
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
|
|
1776
|
+
} break;
|
|
1777
|
+
case WSP_GGML_TYPE_Q4_0:
|
|
1778
|
+
{
|
|
1779
|
+
nth0 = 8;
|
|
1780
|
+
nth1 = 8;
|
|
1781
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
|
|
1782
|
+
} break;
|
|
1783
|
+
case WSP_GGML_TYPE_Q4_1:
|
|
1784
|
+
{
|
|
1785
|
+
nth0 = 8;
|
|
1786
|
+
nth1 = 8;
|
|
1787
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
|
|
1788
|
+
} break;
|
|
1789
|
+
case WSP_GGML_TYPE_Q5_0:
|
|
1790
|
+
{
|
|
1791
|
+
nth0 = 8;
|
|
1792
|
+
nth1 = 8;
|
|
1793
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
|
|
1794
|
+
} break;
|
|
1795
|
+
case WSP_GGML_TYPE_Q5_1:
|
|
1796
|
+
{
|
|
1797
|
+
nth0 = 8;
|
|
1798
|
+
nth1 = 8;
|
|
1799
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
|
|
1800
|
+
} break;
|
|
1801
|
+
case WSP_GGML_TYPE_Q8_0:
|
|
1802
|
+
{
|
|
1803
|
+
nth0 = 8;
|
|
1804
|
+
nth1 = 8;
|
|
1805
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
|
|
1806
|
+
} break;
|
|
1807
|
+
case WSP_GGML_TYPE_Q2_K:
|
|
1808
|
+
{
|
|
1809
|
+
nth0 = 2;
|
|
1810
|
+
nth1 = 32;
|
|
1811
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
|
|
1812
|
+
} break;
|
|
1813
|
+
case WSP_GGML_TYPE_Q3_K:
|
|
1814
|
+
{
|
|
1815
|
+
nth0 = 2;
|
|
1816
|
+
nth1 = 32;
|
|
1817
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
|
|
1818
|
+
} break;
|
|
1819
|
+
case WSP_GGML_TYPE_Q4_K:
|
|
1820
|
+
{
|
|
1821
|
+
nth0 = 4; //1;
|
|
1822
|
+
nth1 = 8; //32;
|
|
1823
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
|
|
1824
|
+
} break;
|
|
1825
|
+
case WSP_GGML_TYPE_Q5_K:
|
|
1826
|
+
{
|
|
1827
|
+
nth0 = 2;
|
|
1828
|
+
nth1 = 32;
|
|
1829
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
|
|
1830
|
+
} break;
|
|
1831
|
+
case WSP_GGML_TYPE_Q6_K:
|
|
1832
|
+
{
|
|
1833
|
+
nth0 = 2;
|
|
1834
|
+
nth1 = 32;
|
|
1835
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
|
|
1836
|
+
} break;
|
|
1837
|
+
default:
|
|
1838
|
+
{
|
|
1839
|
+
WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
1840
|
+
WSP_GGML_ASSERT(false && "not implemented");
|
|
1841
|
+
}
|
|
1842
|
+
};
|
|
1843
|
+
|
|
1844
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
1845
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
1846
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
1847
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
|
1848
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
|
1849
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
|
1850
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
|
1851
|
+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
|
1852
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
|
1853
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
|
1854
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
|
1855
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
|
1856
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
1857
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
1858
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
1859
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
1860
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
1861
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
|
1862
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
|
1863
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
|
1864
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
|
1865
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
|
1866
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
|
1867
|
+
// TODO: how to make this an array? read Metal docs
|
|
1868
|
+
for (int j = 0; j < n_as; ++j) {
|
|
1869
|
+
struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
|
|
1870
|
+
|
|
1871
|
+
size_t offs_src_cur = 0;
|
|
1872
|
+
id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
|
1873
|
+
|
|
1874
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
|
1875
|
+
}
|
|
1876
|
+
|
|
1877
|
+
if (src2t == WSP_GGML_TYPE_Q4_0 || src2t == WSP_GGML_TYPE_Q4_1 ||
|
|
1878
|
+
src2t == WSP_GGML_TYPE_Q5_0 || src2t == WSP_GGML_TYPE_Q5_1 || src2t == WSP_GGML_TYPE_Q8_0 ||
|
|
1879
|
+
src2t == WSP_GGML_TYPE_Q2_K) { // || src2t == WSP_GGML_TYPE_Q4_K) {
|
|
1880
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1881
|
+
}
|
|
1882
|
+
else if (src2t == WSP_GGML_TYPE_Q4_K) {
|
|
1883
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1884
|
+
}
|
|
1885
|
+
else if (src2t == WSP_GGML_TYPE_Q3_K) {
|
|
1886
|
+
#ifdef WSP_GGML_QKK_64
|
|
1887
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1888
|
+
#else
|
|
1889
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1890
|
+
#endif
|
|
1891
|
+
}
|
|
1892
|
+
else if (src2t == WSP_GGML_TYPE_Q5_K) {
|
|
1893
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1894
|
+
}
|
|
1895
|
+
else if (src2t == WSP_GGML_TYPE_Q6_K) {
|
|
1896
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1283
1897
|
} else {
|
|
1284
|
-
int64_t ny = (
|
|
1285
|
-
[encoder dispatchThreadgroups:MTLSizeMake(
|
|
1898
|
+
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
|
1899
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
1286
1900
|
}
|
|
1287
1901
|
}
|
|
1288
1902
|
} break;
|
|
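In the hunk above, every dispatch expression of the form (ne01 + N-1)/N is an integer ceiling division (rows of src0 covered per threadgroup), and replacing ne12 with ne12*ne13 folds the fourth tensor dimension into the grid's z axis so the same kernels cover batched 4-D tensors. A minimal C sketch of that grid-size arithmetic, with illustrative names that are not part of the source:

    // ceil_div(n, k): number of threadgroups needed when each covers k rows
    static inline int64_t ceil_div(int64_t n, int64_t k) {
        return (n + k - 1) / k;
    }

    // e.g. the Q4_K/Q5_K matrix-vector path dispatches ceil_div(ne01, 4) groups in x
    // and one grid layer per batch element in z:
    //   x = ceil_div(ne01, 4), y = ne11, z = ne12 * ne13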
@@ -1304,16 +1918,19 @@ void wsp_ggml_metal_graph_compute(
 default: WSP_GGML_ASSERT(false && "not implemented");
 }

-[encoder setBuffer:id_src0
-[encoder setBuffer:id_src1
-[encoder setBuffer:id_dst
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
 [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
 [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-[encoder setBytes:&
-
-
-
-[encoder
+[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
 } break;
 case WSP_GGML_OP_RMS_NORM:
 {
@@ -1322,20 +1939,56 @@ void wsp_ggml_metal_graph_compute(
 float eps;
 memcpy(&eps, dst->op_params, sizeof(float));

-
+int nth = 32; // SIMD width
+
+while (nth < ne00/4 && nth < 1024) {
+nth *= 2;
+}

 [encoder setComputePipelineState:ctx->pipeline_rms_norm];
-[encoder setBuffer:id_src0 offset:offs_src0
-[encoder setBuffer:id_dst offset:offs_dst
-[encoder setBytes:&ne00
-[encoder setBytes:&nb01
-[encoder setBytes:&eps
-[encoder setThreadgroupMemoryLength:
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+[encoder setBytes:&eps length:sizeof( float) atIndex:4];
+[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

 const int64_t nrows = wsp_ggml_nrows(src0);

 [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
+case WSP_GGML_OP_GROUP_NORM:
+{
+WSP_GGML_ASSERT(ne00 % 4 == 0);
+
+//float eps;
+//memcpy(&eps, dst->op_params, sizeof(float));
+
+const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+int nth = 32; // SIMD width
+
+//while (nth < ne00/4 && nth < 1024) {
+// nth *= 2;
+//}
+
+[encoder setComputePipelineState:ctx->pipeline_group_norm];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+[encoder setBytes:&eps length:sizeof( float) atIndex:9];
+[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+} break;
 case WSP_GGML_OP_NORM:
 {
 float eps;
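The RMS_NORM hunk above now sizes its threadgroup by doubling from the SIMD width until it reaches either ne00/4 (each thread reduces a float4) or Metal's 1024 threads-per-threadgroup limit. A small C sketch of that selection, under the assumption (not stated in the diff itself) that ne00 is the length of the row being reduced; the helper name is illustrative:

    // pick a power-of-two thread count: >= 32, <= 1024, and <= ne00/4
    static int rms_norm_threads(int64_t ne00) {
        int nth = 32;               // SIMD width
        while (nth < ne00/4 && nth < 1024) {
            nth *= 2;
        }
        return nth;
    }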
@@ -1404,7 +2057,8 @@ void wsp_ggml_metal_graph_compute(
 const int n_past = ((int32_t *) dst->op_params)[0];
 const int n_dims = ((int32_t *) dst->op_params)[1];
 const int mode = ((int32_t *) dst->op_params)[2];
-
+// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
@@ -1452,18 +2106,175 @@ void wsp_ggml_metal_graph_compute(

 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
+case WSP_GGML_OP_IM2COL:
+{
+WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
+WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
+
+const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+const int32_t N = src1->ne[is_2D ? 3 : 2];
+const int32_t IC = src1->ne[is_2D ? 2 : 1];
+const int32_t IH = is_2D ? src1->ne[1] : 1;
+const int32_t IW = src1->ne[0];
+
+const int32_t KH = is_2D ? src0->ne[1] : 1;
+const int32_t KW = src0->ne[0];
+
+const int32_t OH = is_2D ? dst->ne[2] : 1;
+const int32_t OW = dst->ne[1];
+
+const int32_t CHW = IC * KH * KW;
+
+const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+switch (src0->type) {
+case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break;
+case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+default: WSP_GGML_ASSERT(false);
+};
+
+[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
+[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
+[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
+[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
+[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
+[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
+[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
+[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
+[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
+
+[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+} break;
+case WSP_GGML_OP_UPSCALE:
+{
+WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+
+const int sf = dst->op_params[0];
+
+[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+
+const int nth = MIN(1024, ne0);
+
+[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+} break;
+case WSP_GGML_OP_PAD:
+{
+WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+
+[encoder setComputePipelineState:ctx->pipeline_pad_f32];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+const int nth = MIN(1024, ne0);
+
+[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+} break;
+case WSP_GGML_OP_ARGSORT:
+{
+WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);
+
+const int nrows = wsp_ggml_nrows(src0);
+
+enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
+
+switch (order) {
+case WSP_GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
+case WSP_GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+default: WSP_GGML_ASSERT(false);
+};
+
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+
+[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+} break;
+case WSP_GGML_OP_LEAKY_RELU:
+{
+WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+
+float slope;
+memcpy(&slope, dst->op_params, sizeof(float));
+
+[encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+const int64_t n = wsp_ggml_nelements(dst);
+
+[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+} break;
 case WSP_GGML_OP_DUP:
 case WSP_GGML_OP_CPY:
 case WSP_GGML_OP_CONT:
 {
-
+WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
+
+int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));

 switch (src0t) {
 case WSP_GGML_TYPE_F32:
 {
+WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0);
+
 switch (dstt) {
-case WSP_GGML_TYPE_F16:
-case WSP_GGML_TYPE_F32:
+case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+//case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+//case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
 default: WSP_GGML_ASSERT(false && "not implemented");
 };
 } break;
@@ -1471,7 +2282,7 @@ void wsp_ggml_metal_graph_compute(
 {
 switch (dstt) {
 case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-case WSP_GGML_TYPE_F32:
+case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
 default: WSP_GGML_ASSERT(false && "not implemented");
 };
 } break;
@@ -1538,81 +2349,148 @@ void wsp_ggml_metal_graph_compute(

 // backend interface

-static
-
+static id<MTLDevice> g_backend_device = nil;
+static int g_backend_device_ref_count = 0;

-
+static id<MTLDevice> wsp_ggml_backend_metal_get_device(void) {
+if (g_backend_device == nil) {
+g_backend_device = MTLCreateSystemDefaultDevice();
+}
+
+g_backend_device_ref_count++;
+
+return g_backend_device;
 }

-static void
-
-
-
+static void wsp_ggml_backend_metal_free_device(void) {
+assert(g_backend_device_ref_count > 0);
+
+g_backend_device_ref_count--;
+
+if (g_backend_device_ref_count == 0) {
+g_backend_device = nil;
+}
 }

 static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
-
+struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
+
+return ctx->data;
 }

 static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
-
+struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
+
+wsp_ggml_backend_metal_free_device();
+
+free(ctx->data);
+free(ctx);
+
+UNUSED(buffer);
+}
+
+static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
+WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+memcpy((char *)tensor->data + offset, data, size);
+
+UNUSED(buffer);
+}
+
+static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
+WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+memcpy(data, (const char *)tensor->data + offset, size);
+
+UNUSED(buffer);
+}
+
+static void wsp_ggml_backend_metal_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
+
+UNUSED(buffer);
+}
+
+static void wsp_ggml_backend_metal_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
+
 UNUSED(buffer);
 }

 static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = {
-/* .free_buffer
-/* .get_base
-/* .
-/* .
-/* .
+/* .free_buffer = */ wsp_ggml_backend_metal_buffer_free_buffer,
+/* .get_base = */ wsp_ggml_backend_metal_buffer_get_base,
+/* .init_tensor = */ NULL,
+/* .set_tensor = */ wsp_ggml_backend_metal_buffer_set_tensor,
+/* .get_tensor = */ wsp_ggml_backend_metal_buffer_get_tensor,
+/* .cpy_tensor_from = */ wsp_ggml_backend_metal_buffer_cpy_tensor_from,
+/* .cpy_tensor_to = */ wsp_ggml_backend_metal_buffer_cpy_tensor_to,
 };

-static wsp_ggml_backend_buffer_t
-struct
+static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
+struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context));

-
+const size_t size_page = sysconf(_SC_PAGESIZE);

-
-
+size_t size_aligned = size;
+if ((size_aligned % size_page) != 0) {
+size_aligned += (size_page - (size_aligned % size_page));
+}

-
+ctx->data = wsp_ggml_metal_host_malloc(size);
+ctx->metal = [wsp_ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+length:size_aligned
+options:MTLResourceStorageModeShared
+deallocator:nil];
+
+return wsp_ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
 }

-static size_t
+static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
 return 32;
-UNUSED(
+UNUSED(buft);
 }

-static
-
-WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
+return wsp_ggml_backend_is_metal(backend) || wsp_ggml_backend_is_cpu(backend);

-
-
-UNUSED(backend);
+WSP_GGML_UNUSED(buft);
 }

-
-
-
-
-
+wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
+static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_metal = {
+/* .iface = */ {
+/* .alloc_buffer = */ wsp_ggml_backend_metal_buffer_type_alloc_buffer,
+/* .get_alignment = */ wsp_ggml_backend_metal_buffer_type_get_alignment,
+/* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
+/* .supports_backend = */ wsp_ggml_backend_metal_buffer_type_supports_backend,
+},
+/* .context = */ NULL,
+};

-
+return &wsp_ggml_backend_buffer_type_metal;
 }

-static
+static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
+return "Metal";
+
 UNUSED(backend);
 }

-static void
-
+static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
+struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+wsp_ggml_metal_free(ctx);
+free(backend);
+}

+static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
 UNUSED(backend);
 }

-static
-
+static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) {
+return wsp_ggml_backend_metal_buffer_type();

 UNUSED(backend);
 }
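In the buffer-type hunk above, the length handed to newBufferWithBytesNoCopy: is rounded up to a whole number of pages, since Metal's no-copy buffers expect page-aligned storage. The rounding step in isolation, as a C sketch (the helper name is illustrative, not from the source):

    #include <unistd.h>

    // round size up to the next multiple of the system page size
    static size_t round_up_to_page(size_t size) {
        const size_t size_page = (size_t) sysconf(_SC_PAGESIZE);
        size_t size_aligned = size;
        if ((size_aligned % size_page) != 0) {
            size_aligned += size_page - (size_aligned % size_page);
        }
        return size_aligned;
    }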
@@ -1624,32 +2502,43 @@ static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, str
 }

 static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
-return
+return wsp_ggml_metal_supports_op(op);
+
 UNUSED(backend);
-UNUSED(op);
 }

 static struct wsp_ggml_backend_i metal_backend_i = {
-/* .get_name
-/* .free
-/* .
-/* .
-/* .
-/* .
-/* .
-/* .
-/* .
-/* .
-/* .
-/* .
-/* .
-/* .supports_op = */ wsp_ggml_backend_metal_supports_op,
+/* .get_name = */ wsp_ggml_backend_metal_name,
+/* .free = */ wsp_ggml_backend_metal_free,
+/* .get_default_buffer_type = */ wsp_ggml_backend_metal_get_default_buffer_type,
+/* .set_tensor_async = */ NULL,
+/* .get_tensor_async = */ NULL,
+/* .cpy_tensor_from_async = */ NULL,
+/* .cpy_tensor_to_async = */ NULL,
+/* .synchronize = */ wsp_ggml_backend_metal_synchronize,
+/* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
+/* .graph_plan_free = */ NULL,
+/* .graph_plan_compute = */ NULL,
+/* .graph_compute = */ wsp_ggml_backend_metal_graph_compute,
+/* .supports_op = */ wsp_ggml_backend_metal_supports_op,
 };

+// TODO: make a common log callback for all backends in ggml-backend
+static void wsp_ggml_backend_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) {
+fprintf(stderr, "%s", msg);
+
+UNUSED(level);
+UNUSED(user_data);
+}
+
 wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
-
+wsp_ggml_metal_log_set_callback(wsp_ggml_backend_log_callback, NULL);
+
+struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);

-ctx
+if (ctx == NULL) {
+return NULL;
+}

 wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend));

@@ -1666,7 +2555,26 @@ bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) {
 }

 void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
+WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));
+
 struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;

 wsp_ggml_metal_set_n_cb(ctx, n_cb);
 }
+
+bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int family) {
+WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));
+
+struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+}
+
+wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) {
+return wsp_ggml_backend_metal_init();
+
+WSP_GGML_UNUSED(params);
+WSP_GGML_UNUSED(user_data);
+}
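The functions added in this final hunk form the public entry points of the Metal backend. A hedged usage sketch, assuming the usual wsp_-prefixed ggml-backend API from ggml-backend.h (e.g. wsp_ggml_backend_free); error handling and the surrounding graph setup are omitted and are not taken from this package:

    // initialize the Metal backend, then query GPU family support
    wsp_ggml_backend_t backend = wsp_ggml_backend_metal_init();
    if (backend != NULL) {
        wsp_ggml_backend_metal_set_n_cb(backend, 1);                       // number of command buffers
        bool apple7 = wsp_ggml_backend_metal_supports_family(backend, 7);  // MTLGPUFamilyApple7
        (void) apple7;
        wsp_ggml_backend_free(backend);
    }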