llama_cpp 0.9.5 → 0.10.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/ext/llama_cpp/llama_cpp.cpp +123 -15
- data/ext/llama_cpp/src/ggml-alloc.c +42 -7
- data/ext/llama_cpp/src/ggml-alloc.h +8 -1
- data/ext/llama_cpp/src/ggml-backend-impl.h +46 -21
- data/ext/llama_cpp/src/ggml-backend.c +563 -156
- data/ext/llama_cpp/src/ggml-backend.h +62 -17
- data/ext/llama_cpp/src/ggml-cuda.cu +1796 -413
- data/ext/llama_cpp/src/ggml-cuda.h +9 -1
- data/ext/llama_cpp/src/ggml-impl.h +1 -1
- data/ext/llama_cpp/src/ggml-metal.h +6 -0
- data/ext/llama_cpp/src/ggml-metal.m +998 -169
- data/ext/llama_cpp/src/ggml-metal.metal +2253 -274
- data/ext/llama_cpp/src/ggml-quants.c +2 -2
- data/ext/llama_cpp/src/ggml.c +634 -248
- data/ext/llama_cpp/src/ggml.h +81 -15
- data/ext/llama_cpp/src/llama.cpp +932 -352
- data/ext/llama_cpp/src/llama.h +28 -5
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +22 -2
- metadata +2 -2
@@ -62,11 +62,15 @@ struct ggml_metal_context {
|
|
62
62
|
GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
|
63
63
|
GGML_METAL_DECL_KERNEL(mul);
|
64
64
|
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
65
|
+
GGML_METAL_DECL_KERNEL(div);
|
66
|
+
GGML_METAL_DECL_KERNEL(div_row);
|
65
67
|
GGML_METAL_DECL_KERNEL(scale);
|
66
68
|
GGML_METAL_DECL_KERNEL(scale_4);
|
67
|
-
GGML_METAL_DECL_KERNEL(
|
69
|
+
GGML_METAL_DECL_KERNEL(tanh);
|
68
70
|
GGML_METAL_DECL_KERNEL(relu);
|
69
71
|
GGML_METAL_DECL_KERNEL(gelu);
|
72
|
+
GGML_METAL_DECL_KERNEL(gelu_quick);
|
73
|
+
GGML_METAL_DECL_KERNEL(silu);
|
70
74
|
GGML_METAL_DECL_KERNEL(soft_max);
|
71
75
|
GGML_METAL_DECL_KERNEL(soft_max_4);
|
72
76
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
@@ -84,6 +88,7 @@ struct ggml_metal_context {
|
|
84
88
|
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
85
89
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
86
90
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
91
|
+
GGML_METAL_DECL_KERNEL(group_norm);
|
87
92
|
GGML_METAL_DECL_KERNEL(norm);
|
88
93
|
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
89
94
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
@@ -100,6 +105,21 @@ struct ggml_metal_context {
|
|
100
105
|
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
|
101
106
|
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
102
107
|
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
108
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
|
109
|
+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
|
110
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
|
111
|
+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
|
112
|
+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
|
113
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
|
114
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
|
115
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
|
116
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
|
117
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
|
118
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
|
119
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
|
120
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
|
121
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
|
122
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
|
103
123
|
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
104
124
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
105
125
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
@@ -112,15 +132,39 @@ struct ggml_metal_context {
|
|
112
132
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
113
133
|
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
114
134
|
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
135
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
|
136
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
|
137
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
|
138
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
|
139
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
|
140
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
|
141
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
|
142
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
|
143
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
|
144
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
|
145
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
|
146
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
|
115
147
|
GGML_METAL_DECL_KERNEL(rope_f32);
|
116
148
|
GGML_METAL_DECL_KERNEL(rope_f16);
|
117
149
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
118
150
|
GGML_METAL_DECL_KERNEL(im2col_f16);
|
151
|
+
GGML_METAL_DECL_KERNEL(upscale_f32);
|
152
|
+
GGML_METAL_DECL_KERNEL(pad_f32);
|
153
|
+
GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
|
154
|
+
GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
|
155
|
+
GGML_METAL_DECL_KERNEL(leaky_relu_f32);
|
119
156
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
120
157
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
158
|
+
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
159
|
+
GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
|
160
|
+
GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
|
161
|
+
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
162
|
+
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
121
163
|
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
164
|
+
GGML_METAL_DECL_KERNEL(cpy_f16_f32);
|
122
165
|
GGML_METAL_DECL_KERNEL(concat);
|
123
166
|
GGML_METAL_DECL_KERNEL(sqr);
|
167
|
+
GGML_METAL_DECL_KERNEL(sum_rows);
|
124
168
|
|
125
169
|
#undef GGML_METAL_DECL_KERNEL
|
126
170
|
};
|
@@ -155,6 +199,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
|
155
199
|
ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
|
156
200
|
} else {
|
157
201
|
char* buffer2 = malloc(len+1);
|
202
|
+
va_end(args);
|
203
|
+
va_start(args, format);
|
158
204
|
vsnprintf(buffer2, len+1, format, args);
|
159
205
|
buffer2[len] = 0;
|
160
206
|
ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
|
@@ -164,12 +210,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
|
164
210
|
}
|
165
211
|
}
|
166
212
|
|
167
|
-
|
168
|
-
|
169
213
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
170
214
|
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
171
215
|
|
172
|
-
id
|
216
|
+
id<MTLDevice> device;
|
173
217
|
NSString * s;
|
174
218
|
|
175
219
|
#if TARGET_OS_OSX
|
@@ -215,6 +259,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
215
259
|
|
216
260
|
NSString * sourcePath;
|
217
261
|
NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
262
|
+
|
263
|
+
GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
|
264
|
+
|
218
265
|
if (ggmlMetalPathResources) {
|
219
266
|
sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
|
220
267
|
} else {
|
@@ -245,6 +292,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
245
292
|
}
|
246
293
|
}
|
247
294
|
|
295
|
+
#if TARGET_OS_OSX
|
296
|
+
// print MTL GPU family:
|
297
|
+
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
|
298
|
+
|
299
|
+
// determine max supported GPU family
|
300
|
+
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
301
|
+
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
302
|
+
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
303
|
+
if ([ctx->device supportsFamily:i]) {
|
304
|
+
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
305
|
+
break;
|
306
|
+
}
|
307
|
+
}
|
308
|
+
|
309
|
+
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
310
|
+
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
311
|
+
if (ctx->device.maxTransferRate != 0) {
|
312
|
+
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
|
313
|
+
} else {
|
314
|
+
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
315
|
+
}
|
316
|
+
#endif
|
317
|
+
|
248
318
|
// load kernels
|
249
319
|
{
|
250
320
|
NSError * error = nil;
|
@@ -266,11 +336,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
266
336
|
GGML_METAL_ADD_KERNEL(add_row);
|
267
337
|
GGML_METAL_ADD_KERNEL(mul);
|
268
338
|
GGML_METAL_ADD_KERNEL(mul_row);
|
339
|
+
GGML_METAL_ADD_KERNEL(div);
|
340
|
+
GGML_METAL_ADD_KERNEL(div_row);
|
269
341
|
GGML_METAL_ADD_KERNEL(scale);
|
270
342
|
GGML_METAL_ADD_KERNEL(scale_4);
|
271
|
-
GGML_METAL_ADD_KERNEL(
|
343
|
+
GGML_METAL_ADD_KERNEL(tanh);
|
272
344
|
GGML_METAL_ADD_KERNEL(relu);
|
273
345
|
GGML_METAL_ADD_KERNEL(gelu);
|
346
|
+
GGML_METAL_ADD_KERNEL(gelu_quick);
|
347
|
+
GGML_METAL_ADD_KERNEL(silu);
|
274
348
|
GGML_METAL_ADD_KERNEL(soft_max);
|
275
349
|
GGML_METAL_ADD_KERNEL(soft_max_4);
|
276
350
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
@@ -288,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
288
362
|
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
289
363
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
290
364
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
365
|
+
GGML_METAL_ADD_KERNEL(group_norm);
|
291
366
|
GGML_METAL_ADD_KERNEL(norm);
|
292
367
|
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
293
368
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
|
@@ -304,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
304
379
|
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
|
305
380
|
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
|
306
381
|
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
|
382
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
|
383
|
+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
|
384
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
|
385
|
+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
|
386
|
+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
|
387
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
|
388
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
|
389
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
|
390
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
|
391
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
|
392
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
|
393
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
|
394
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
|
395
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
|
396
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
|
307
397
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
308
398
|
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
309
399
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
@@ -317,43 +407,44 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
317
407
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
318
408
|
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
319
409
|
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
410
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
|
411
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
|
412
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
|
413
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
|
414
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
|
415
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
|
416
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
|
417
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
|
418
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
|
419
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
|
420
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
|
421
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
|
320
422
|
}
|
321
423
|
GGML_METAL_ADD_KERNEL(rope_f32);
|
322
424
|
GGML_METAL_ADD_KERNEL(rope_f16);
|
323
425
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
324
426
|
GGML_METAL_ADD_KERNEL(im2col_f16);
|
427
|
+
GGML_METAL_ADD_KERNEL(upscale_f32);
|
428
|
+
GGML_METAL_ADD_KERNEL(pad_f32);
|
429
|
+
GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
|
430
|
+
GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
|
431
|
+
GGML_METAL_ADD_KERNEL(leaky_relu_f32);
|
325
432
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
326
433
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
434
|
+
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
435
|
+
GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
|
436
|
+
GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
|
437
|
+
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
|
438
|
+
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
|
327
439
|
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
440
|
+
GGML_METAL_ADD_KERNEL(cpy_f16_f32);
|
328
441
|
GGML_METAL_ADD_KERNEL(concat);
|
329
442
|
GGML_METAL_ADD_KERNEL(sqr);
|
443
|
+
GGML_METAL_ADD_KERNEL(sum_rows);
|
330
444
|
|
331
445
|
#undef GGML_METAL_ADD_KERNEL
|
332
446
|
}
|
333
447
|
|
334
|
-
#if TARGET_OS_OSX
|
335
|
-
// print MTL GPU family:
|
336
|
-
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
|
337
|
-
|
338
|
-
// determine max supported GPU family
|
339
|
-
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
340
|
-
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
341
|
-
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
342
|
-
if ([ctx->device supportsFamily:i]) {
|
343
|
-
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
344
|
-
break;
|
345
|
-
}
|
346
|
-
}
|
347
|
-
|
348
|
-
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
349
|
-
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
350
|
-
if (ctx->device.maxTransferRate != 0) {
|
351
|
-
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
352
|
-
} else {
|
353
|
-
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
354
|
-
}
|
355
|
-
#endif
|
356
|
-
|
357
448
|
return ctx;
|
358
449
|
}
|
359
450
|
|
@@ -367,11 +458,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
367
458
|
GGML_METAL_DEL_KERNEL(add_row);
|
368
459
|
GGML_METAL_DEL_KERNEL(mul);
|
369
460
|
GGML_METAL_DEL_KERNEL(mul_row);
|
461
|
+
GGML_METAL_DEL_KERNEL(div);
|
462
|
+
GGML_METAL_DEL_KERNEL(div_row);
|
370
463
|
GGML_METAL_DEL_KERNEL(scale);
|
371
464
|
GGML_METAL_DEL_KERNEL(scale_4);
|
372
|
-
GGML_METAL_DEL_KERNEL(
|
465
|
+
GGML_METAL_DEL_KERNEL(tanh);
|
373
466
|
GGML_METAL_DEL_KERNEL(relu);
|
374
467
|
GGML_METAL_DEL_KERNEL(gelu);
|
468
|
+
GGML_METAL_DEL_KERNEL(gelu_quick);
|
469
|
+
GGML_METAL_DEL_KERNEL(silu);
|
375
470
|
GGML_METAL_DEL_KERNEL(soft_max);
|
376
471
|
GGML_METAL_DEL_KERNEL(soft_max_4);
|
377
472
|
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
@@ -389,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
389
484
|
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
390
485
|
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
391
486
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
487
|
+
GGML_METAL_DEL_KERNEL(group_norm);
|
392
488
|
GGML_METAL_DEL_KERNEL(norm);
|
393
489
|
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
394
490
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
@@ -405,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
405
501
|
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
|
406
502
|
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
407
503
|
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
504
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
|
505
|
+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
|
506
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
|
507
|
+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
|
508
|
+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
|
509
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
|
510
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
|
511
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
|
512
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
|
513
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
|
514
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
|
515
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
|
516
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
|
517
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
|
518
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
|
408
519
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
409
520
|
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
410
521
|
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
@@ -418,16 +529,40 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
418
529
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
419
530
|
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
420
531
|
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
532
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
|
533
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
|
534
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
|
535
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
|
536
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
|
537
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
|
538
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
|
539
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
|
540
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
|
541
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
|
542
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
|
543
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
|
421
544
|
}
|
422
545
|
GGML_METAL_DEL_KERNEL(rope_f32);
|
423
546
|
GGML_METAL_DEL_KERNEL(rope_f16);
|
424
547
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
425
548
|
GGML_METAL_DEL_KERNEL(im2col_f16);
|
549
|
+
GGML_METAL_DEL_KERNEL(upscale_f32);
|
550
|
+
GGML_METAL_DEL_KERNEL(pad_f32);
|
551
|
+
GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
|
552
|
+
GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
|
553
|
+
GGML_METAL_DEL_KERNEL(leaky_relu_f32);
|
426
554
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
427
555
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
556
|
+
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
557
|
+
GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
|
558
|
+
GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
|
559
|
+
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
560
|
+
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
428
561
|
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
562
|
+
GGML_METAL_DEL_KERNEL(cpy_f16_f32);
|
429
563
|
GGML_METAL_DEL_KERNEL(concat);
|
430
564
|
GGML_METAL_DEL_KERNEL(sqr);
|
565
|
+
GGML_METAL_DEL_KERNEL(sum_rows);
|
431
566
|
|
432
567
|
#undef GGML_METAL_DEL_KERNEL
|
433
568
|
|
@@ -471,6 +606,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
|
471
606
|
return ctx->concur_list;
|
472
607
|
}
|
473
608
|
|
609
|
+
// temporarily defined here for compatibility between ggml-backend and the old API
|
610
|
+
struct ggml_backend_metal_buffer_context {
|
611
|
+
void * data;
|
612
|
+
|
613
|
+
id<MTLBuffer> metal;
|
614
|
+
};
|
615
|
+
|
474
616
|
// finds the Metal buffer that contains the tensor data on the GPU device
|
475
617
|
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
476
618
|
// Metal buffer based on the host memory pointer
|
@@ -480,8 +622,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
480
622
|
|
481
623
|
const int64_t tsize = ggml_nbytes(t);
|
482
624
|
|
483
|
-
|
484
|
-
|
625
|
+
// compatibility with ggml-backend
|
626
|
+
if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
|
627
|
+
struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
|
628
|
+
|
629
|
+
const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
|
630
|
+
|
631
|
+
GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
|
632
|
+
|
633
|
+
*offs = (size_t) ioffs;
|
634
|
+
|
635
|
+
return buf_ctx->metal;
|
485
636
|
}
|
486
637
|
|
487
638
|
// find the view that contains the tensor fully
|
@@ -706,6 +857,83 @@ void ggml_metal_graph_find_concurrency(
|
|
706
857
|
}
|
707
858
|
}
|
708
859
|
|
860
|
+
static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
861
|
+
switch (op->op) {
|
862
|
+
case GGML_OP_UNARY:
|
863
|
+
switch (ggml_get_unary_op(op)) {
|
864
|
+
case GGML_UNARY_OP_TANH:
|
865
|
+
case GGML_UNARY_OP_RELU:
|
866
|
+
case GGML_UNARY_OP_GELU:
|
867
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
868
|
+
case GGML_UNARY_OP_SILU:
|
869
|
+
return true;
|
870
|
+
default:
|
871
|
+
return false;
|
872
|
+
}
|
873
|
+
case GGML_OP_NONE:
|
874
|
+
case GGML_OP_RESHAPE:
|
875
|
+
case GGML_OP_VIEW:
|
876
|
+
case GGML_OP_TRANSPOSE:
|
877
|
+
case GGML_OP_PERMUTE:
|
878
|
+
case GGML_OP_CONCAT:
|
879
|
+
case GGML_OP_ADD:
|
880
|
+
case GGML_OP_ACC:
|
881
|
+
case GGML_OP_MUL:
|
882
|
+
case GGML_OP_DIV:
|
883
|
+
case GGML_OP_SCALE:
|
884
|
+
case GGML_OP_SQR:
|
885
|
+
case GGML_OP_SUM_ROWS:
|
886
|
+
case GGML_OP_SOFT_MAX:
|
887
|
+
case GGML_OP_RMS_NORM:
|
888
|
+
case GGML_OP_GROUP_NORM:
|
889
|
+
case GGML_OP_NORM:
|
890
|
+
case GGML_OP_ALIBI:
|
891
|
+
case GGML_OP_ROPE:
|
892
|
+
case GGML_OP_IM2COL:
|
893
|
+
case GGML_OP_UPSCALE:
|
894
|
+
case GGML_OP_PAD:
|
895
|
+
case GGML_OP_ARGSORT:
|
896
|
+
case GGML_OP_LEAKY_RELU:
|
897
|
+
case GGML_OP_MUL_MAT:
|
898
|
+
case GGML_OP_MUL_MAT_ID:
|
899
|
+
return true;
|
900
|
+
case GGML_OP_CPY:
|
901
|
+
case GGML_OP_DUP:
|
902
|
+
case GGML_OP_CONT:
|
903
|
+
{
|
904
|
+
switch (op->src[0]->type) {
|
905
|
+
case GGML_TYPE_F32:
|
906
|
+
switch (op->type) {
|
907
|
+
case GGML_TYPE_F16:
|
908
|
+
case GGML_TYPE_F32:
|
909
|
+
case GGML_TYPE_Q8_0:
|
910
|
+
case GGML_TYPE_Q4_0:
|
911
|
+
case GGML_TYPE_Q4_1:
|
912
|
+
return true;
|
913
|
+
default:
|
914
|
+
return false;
|
915
|
+
}
|
916
|
+
case GGML_TYPE_F16:
|
917
|
+
switch (op->type) {
|
918
|
+
case GGML_TYPE_F16:
|
919
|
+
case GGML_TYPE_F32:
|
920
|
+
return true;
|
921
|
+
default:
|
922
|
+
return false;
|
923
|
+
}
|
924
|
+
default:
|
925
|
+
return false;
|
926
|
+
};
|
927
|
+
}
|
928
|
+
case GGML_OP_DIAG_MASK_INF:
|
929
|
+
case GGML_OP_GET_ROWS:
|
930
|
+
{
|
931
|
+
return op->ne[3] == 1;
|
932
|
+
}
|
933
|
+
default:
|
934
|
+
return false;
|
935
|
+
}
|
936
|
+
}
|
709
937
|
void ggml_metal_graph_compute(
|
710
938
|
struct ggml_metal_context * ctx,
|
711
939
|
struct ggml_cgraph * gf) {
|
@@ -776,6 +1004,11 @@ void ggml_metal_graph_compute(
|
|
776
1004
|
} break;
|
777
1005
|
}
|
778
1006
|
|
1007
|
+
if (!ggml_metal_supports_op(dst)) {
|
1008
|
+
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
|
1009
|
+
GGML_ASSERT(!"unsupported op");
|
1010
|
+
}
|
1011
|
+
|
779
1012
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
780
1013
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
781
1014
|
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
@@ -868,25 +1101,42 @@ void ggml_metal_graph_compute(
|
|
868
1101
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
869
1102
|
} break;
|
870
1103
|
case GGML_OP_ADD:
|
1104
|
+
case GGML_OP_MUL:
|
1105
|
+
case GGML_OP_DIV:
|
871
1106
|
{
|
872
|
-
|
873
|
-
GGML_ASSERT(ggml_is_contiguous(src1));
|
1107
|
+
const size_t offs = 0;
|
874
1108
|
|
875
1109
|
bool bcast_row = false;
|
876
1110
|
|
877
1111
|
int64_t nb = ne00;
|
878
1112
|
|
879
|
-
|
1113
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1114
|
+
|
1115
|
+
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
1116
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
1117
|
+
|
880
1118
|
// src1 is a row
|
881
1119
|
GGML_ASSERT(ne11 == 1);
|
882
1120
|
|
883
1121
|
nb = ne00 / 4;
|
884
|
-
|
1122
|
+
switch (dst->op) {
|
1123
|
+
case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
|
1124
|
+
case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
|
1125
|
+
case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
|
1126
|
+
default: GGML_ASSERT(false);
|
1127
|
+
}
|
885
1128
|
|
886
1129
|
bcast_row = true;
|
887
1130
|
} else {
|
888
|
-
|
1131
|
+
switch (dst->op) {
|
1132
|
+
case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
|
1133
|
+
case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
|
1134
|
+
case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
|
1135
|
+
default: GGML_ASSERT(false);
|
1136
|
+
}
|
889
1137
|
}
|
1138
|
+
|
1139
|
+
[encoder setComputePipelineState:pipeline];
|
890
1140
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
891
1141
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
892
1142
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
@@ -914,42 +1164,98 @@ void ggml_metal_graph_compute(
|
|
914
1164
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
915
1165
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
916
1166
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
917
|
-
[encoder setBytes:&
|
1167
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
1168
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
918
1169
|
|
919
1170
|
if (bcast_row) {
|
920
1171
|
const int64_t n = ggml_nelements(dst)/4;
|
921
1172
|
|
922
1173
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
923
1174
|
} else {
|
924
|
-
const int nth = MIN(
|
1175
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
925
1176
|
|
926
1177
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
927
1178
|
}
|
928
1179
|
} break;
|
929
|
-
case
|
1180
|
+
case GGML_OP_ACC:
|
930
1181
|
{
|
1182
|
+
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
1183
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1184
|
+
GGML_ASSERT(dstt == GGML_TYPE_F32);
|
1185
|
+
|
931
1186
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
932
1187
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
933
1188
|
|
934
|
-
|
935
|
-
|
936
|
-
const
|
1189
|
+
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
|
1190
|
+
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
|
1191
|
+
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
|
1192
|
+
const size_t offs = ((int32_t *) dst->op_params)[3];
|
1193
|
+
|
1194
|
+
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
1195
|
+
|
1196
|
+
if (!inplace) {
|
1197
|
+
// run a separete kernel to cpy src->dst
|
1198
|
+
// not sure how to avoid this
|
1199
|
+
// TODO: make a simpler cpy_bytes kernel
|
1200
|
+
|
1201
|
+
const int nth = MIN(1024, ne00);
|
1202
|
+
|
1203
|
+
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
1204
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1205
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1206
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1207
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1208
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1209
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
1210
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
1211
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
1212
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
1213
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
1214
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
1215
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
1216
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
1217
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
1218
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
1219
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
1220
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
1221
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
937
1222
|
|
938
|
-
|
939
|
-
// src1 is a row
|
940
|
-
GGML_ASSERT(ne11 == 1);
|
941
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
942
|
-
} else {
|
943
|
-
[encoder setComputePipelineState:ctx->pipeline_mul];
|
1223
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
944
1224
|
}
|
1225
|
+
|
1226
|
+
[encoder setComputePipelineState:ctx->pipeline_add];
|
945
1227
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
946
1228
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
947
1229
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
948
|
-
[encoder setBytes:&
|
1230
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1231
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1232
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1233
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
1234
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
1235
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
1236
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
1237
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
1238
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
1239
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
1240
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
1241
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
1242
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
1243
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
1244
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
1245
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
1246
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
1247
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
1248
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
1249
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
1250
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
1251
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
1252
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
1253
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
1254
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
949
1255
|
|
950
|
-
const
|
1256
|
+
const int nth = MIN(1024, ne0);
|
951
1257
|
|
952
|
-
[encoder dispatchThreadgroups:MTLSizeMake(
|
1258
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
953
1259
|
} break;
|
954
1260
|
case GGML_OP_SCALE:
|
955
1261
|
{
|
@@ -974,16 +1280,15 @@ void ggml_metal_graph_compute(
|
|
974
1280
|
} break;
|
975
1281
|
case GGML_OP_UNARY:
|
976
1282
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
977
|
-
case
|
1283
|
+
case GGML_UNARY_OP_TANH:
|
978
1284
|
{
|
979
|
-
[encoder setComputePipelineState:ctx->
|
1285
|
+
[encoder setComputePipelineState:ctx->pipeline_tanh];
|
980
1286
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
981
1287
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
982
1288
|
|
983
1289
|
const int64_t n = ggml_nelements(dst);
|
984
|
-
GGML_ASSERT(n % 4 == 0);
|
985
1290
|
|
986
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n
|
1291
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
987
1292
|
} break;
|
988
1293
|
case GGML_UNARY_OP_RELU:
|
989
1294
|
{
|
@@ -1004,6 +1309,28 @@ void ggml_metal_graph_compute(
|
|
1004
1309
|
const int64_t n = ggml_nelements(dst);
|
1005
1310
|
GGML_ASSERT(n % 4 == 0);
|
1006
1311
|
|
1312
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1313
|
+
} break;
|
1314
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
1315
|
+
{
|
1316
|
+
[encoder setComputePipelineState:ctx->pipeline_gelu_quick];
|
1317
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1318
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1319
|
+
|
1320
|
+
const int64_t n = ggml_nelements(dst);
|
1321
|
+
GGML_ASSERT(n % 4 == 0);
|
1322
|
+
|
1323
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1324
|
+
} break;
|
1325
|
+
case GGML_UNARY_OP_SILU:
|
1326
|
+
{
|
1327
|
+
[encoder setComputePipelineState:ctx->pipeline_silu];
|
1328
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1329
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1330
|
+
|
1331
|
+
const int64_t n = ggml_nelements(dst);
|
1332
|
+
GGML_ASSERT(n % 4 == 0);
|
1333
|
+
|
1007
1334
|
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1008
1335
|
} break;
|
1009
1336
|
default:
|
@@ -1023,6 +1350,40 @@ void ggml_metal_graph_compute(
|
|
1023
1350
|
const int64_t n = ggml_nelements(dst);
|
1024
1351
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1025
1352
|
} break;
|
1353
|
+
case GGML_OP_SUM_ROWS:
|
1354
|
+
{
|
1355
|
+
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
1356
|
+
|
1357
|
+
[encoder setComputePipelineState:ctx->pipeline_sum_rows];
|
1358
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1359
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1360
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
1361
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1362
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1363
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
1364
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1365
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1366
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1367
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
1368
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
1369
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
1370
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1371
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1372
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1373
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1374
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1375
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
1376
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
1377
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
1378
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
1379
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
1380
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
1381
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
1382
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
1383
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
1384
|
+
|
1385
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1386
|
+
} break;
|
1026
1387
|
case GGML_OP_SOFT_MAX:
|
1027
1388
|
{
|
1028
1389
|
int nth = 32; // SIMD width
|
@@ -1042,7 +1403,11 @@ void ggml_metal_graph_compute(
|
|
1042
1403
|
const float scale = ((float *) dst->op_params)[0];
|
1043
1404
|
|
1044
1405
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1045
|
-
|
1406
|
+
if (id_src1) {
|
1407
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1408
|
+
} else {
|
1409
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1410
|
+
}
|
1046
1411
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1047
1412
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1048
1413
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
@@ -1077,9 +1442,13 @@ void ggml_metal_graph_compute(
|
|
1077
1442
|
case GGML_OP_MUL_MAT:
|
1078
1443
|
{
|
1079
1444
|
GGML_ASSERT(ne00 == ne10);
|
1080
|
-
GGML_ASSERT(ne03 == ne13);
|
1081
1445
|
|
1082
|
-
|
1446
|
+
// TODO: assert that dim2 and dim3 are contiguous
|
1447
|
+
GGML_ASSERT(ne12 % ne02 == 0);
|
1448
|
+
GGML_ASSERT(ne13 % ne03 == 0);
|
1449
|
+
|
1450
|
+
const uint r2 = ne12/ne02;
|
1451
|
+
const uint r3 = ne13/ne03;
|
1083
1452
|
|
1084
1453
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1085
1454
|
// to the matrix-vector kernel
|
@@ -1114,7 +1483,7 @@ void ggml_metal_graph_compute(
|
|
1114
1483
|
!ggml_is_transposed(src1) &&
|
1115
1484
|
src1t == GGML_TYPE_F32 &&
|
1116
1485
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
1117
|
-
ne11 > ne11_mm_min) {
|
1486
|
+
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
|
1118
1487
|
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
1119
1488
|
switch (src0->type) {
|
1120
1489
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
@@ -1144,9 +1513,10 @@ void ggml_metal_graph_compute(
|
|
1144
1513
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
1145
1514
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
1146
1515
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
1147
|
-
[encoder setBytes:&
|
1516
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
1517
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
1148
1518
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
1149
|
-
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1519
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1150
1520
|
} else {
|
1151
1521
|
int nth0 = 32;
|
1152
1522
|
int nth1 = 1;
|
@@ -1182,90 +1552,60 @@ void ggml_metal_graph_compute(
|
|
1182
1552
|
} break;
|
1183
1553
|
case GGML_TYPE_Q4_0:
|
1184
1554
|
{
|
1185
|
-
GGML_ASSERT(ne02 == 1);
|
1186
|
-
GGML_ASSERT(ne12 == 1);
|
1187
|
-
|
1188
1555
|
nth0 = 8;
|
1189
1556
|
nth1 = 8;
|
1190
1557
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
|
1191
1558
|
} break;
|
1192
1559
|
case GGML_TYPE_Q4_1:
|
1193
1560
|
{
|
1194
|
-
GGML_ASSERT(ne02 == 1);
|
1195
|
-
GGML_ASSERT(ne12 == 1);
|
1196
|
-
|
1197
1561
|
nth0 = 8;
|
1198
1562
|
nth1 = 8;
|
1199
1563
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
|
1200
1564
|
} break;
|
1201
1565
|
case GGML_TYPE_Q5_0:
|
1202
1566
|
{
|
1203
|
-
GGML_ASSERT(ne02 == 1);
|
1204
|
-
GGML_ASSERT(ne12 == 1);
|
1205
|
-
|
1206
1567
|
nth0 = 8;
|
1207
1568
|
nth1 = 8;
|
1208
1569
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
|
1209
1570
|
} break;
|
1210
1571
|
case GGML_TYPE_Q5_1:
|
1211
1572
|
{
|
1212
|
-
GGML_ASSERT(ne02 == 1);
|
1213
|
-
GGML_ASSERT(ne12 == 1);
|
1214
|
-
|
1215
1573
|
nth0 = 8;
|
1216
1574
|
nth1 = 8;
|
1217
1575
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
|
1218
1576
|
} break;
|
1219
1577
|
case GGML_TYPE_Q8_0:
|
1220
1578
|
{
|
1221
|
-
GGML_ASSERT(ne02 == 1);
|
1222
|
-
GGML_ASSERT(ne12 == 1);
|
1223
|
-
|
1224
1579
|
nth0 = 8;
|
1225
1580
|
nth1 = 8;
|
1226
1581
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
|
1227
1582
|
} break;
|
1228
1583
|
case GGML_TYPE_Q2_K:
|
1229
1584
|
{
|
1230
|
-
GGML_ASSERT(ne02 == 1);
|
1231
|
-
GGML_ASSERT(ne12 == 1);
|
1232
|
-
|
1233
1585
|
nth0 = 2;
|
1234
1586
|
nth1 = 32;
|
1235
1587
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
|
1236
1588
|
} break;
|
1237
1589
|
case GGML_TYPE_Q3_K:
|
1238
1590
|
{
|
1239
|
-
GGML_ASSERT(ne02 == 1);
|
1240
|
-
GGML_ASSERT(ne12 == 1);
|
1241
|
-
|
1242
1591
|
nth0 = 2;
|
1243
1592
|
nth1 = 32;
|
1244
1593
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
|
1245
1594
|
} break;
|
1246
1595
|
case GGML_TYPE_Q4_K:
|
1247
1596
|
{
|
1248
|
-
GGML_ASSERT(ne02 == 1);
|
1249
|
-
GGML_ASSERT(ne12 == 1);
|
1250
|
-
|
1251
1597
|
nth0 = 4; //1;
|
1252
1598
|
nth1 = 8; //32;
|
1253
1599
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
|
1254
1600
|
} break;
|
1255
1601
|
case GGML_TYPE_Q5_K:
|
1256
1602
|
{
|
1257
|
-
GGML_ASSERT(ne02 == 1);
|
1258
|
-
GGML_ASSERT(ne12 == 1);
|
1259
|
-
|
1260
1603
|
nth0 = 2;
|
1261
1604
|
nth1 = 32;
|
1262
1605
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
|
1263
1606
|
} break;
|
1264
1607
|
case GGML_TYPE_Q6_K:
|
1265
1608
|
{
|
1266
|
-
GGML_ASSERT(ne02 == 1);
|
1267
|
-
GGML_ASSERT(ne12 == 1);
|
1268
|
-
|
1269
1609
|
nth0 = 2;
|
1270
1610
|
nth1 = 32;
|
1271
1611
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
|
@@ -1294,31 +1634,281 @@ void ggml_metal_graph_compute(
|
|
1294
1634
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
1295
1635
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
1296
1636
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
1297
|
-
[encoder setBytes:&
|
1637
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
1638
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
1298
1639
|
|
1299
1640
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
1300
1641
|
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
1301
1642
|
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
1302
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1643
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1303
1644
|
}
|
1304
1645
|
else if (src0t == GGML_TYPE_Q4_K) {
|
1305
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1646
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1306
1647
|
}
|
1307
1648
|
else if (src0t == GGML_TYPE_Q3_K) {
|
1308
1649
|
#ifdef GGML_QKK_64
|
1309
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1650
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1310
1651
|
#else
|
1311
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1652
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1312
1653
|
#endif
|
1313
1654
|
}
|
1314
1655
|
else if (src0t == GGML_TYPE_Q5_K) {
|
1315
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1656
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1316
1657
|
}
|
1317
1658
|
else if (src0t == GGML_TYPE_Q6_K) {
|
1318
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1659
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1660
|
+
} else {
|
1661
|
+
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
1662
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1663
|
+
}
|
1664
|
+
}
|
1665
|
+
} break;
|
1666
|
+
case GGML_OP_MUL_MAT_ID:
|
1667
|
+
{
|
1668
|
+
//GGML_ASSERT(ne00 == ne10);
|
1669
|
+
//GGML_ASSERT(ne03 == ne13);
|
1670
|
+
|
1671
|
+
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
1672
|
+
|
1673
|
+
const int n_as = ((int32_t *) dst->op_params)[1];
|
1674
|
+
|
1675
|
+
// TODO: make this more general
|
1676
|
+
GGML_ASSERT(n_as <= 8);
|
1677
|
+
|
1678
|
+
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
1679
|
+
|
1680
|
+
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
1681
|
+
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
1682
|
+
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
1683
|
+
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
1684
|
+
|
1685
|
+
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
1686
|
+
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
1687
|
+
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
1688
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
1689
|
+
|
1690
|
+
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
1691
|
+
|
1692
|
+
GGML_ASSERT(!ggml_is_transposed(src2));
|
1693
|
+
GGML_ASSERT(!ggml_is_transposed(src1));
|
1694
|
+
|
1695
|
+
GGML_ASSERT(ne20 % 32 == 0);
|
1696
|
+
// !!!!!!!!! TODO: this assert is probably required but not sure!
|
1697
|
+
//GGML_ASSERT(ne20 >= 64);
|
1698
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1699
|
+
|
1700
|
+
const uint r2 = ne12/ne22;
|
1701
|
+
const uint r3 = ne13/ne23;
|
1702
|
+
|
1703
|
+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1704
|
+
// to the matrix-vector kernel
|
1705
|
+
int ne11_mm_min = 1;
|
1706
|
+
|
1707
|
+
const int idx = ((int32_t *) dst->op_params)[0];
|
1708
|
+
|
1709
|
+
// batch size
|
1710
|
+
GGML_ASSERT(ne01 == ne11);
|
1711
|
+
|
1712
|
+
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
|
1713
|
+
|
1714
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1715
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
1716
|
+
// !!!
|
1717
|
+
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
1718
|
+
// indirect matrix multiplication
|
1719
|
+
// !!!
|
1720
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
|
1721
|
+
switch (src2->type) {
|
1722
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
|
1723
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
|
1724
|
+
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
|
1725
|
+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
|
1726
|
+
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
|
1727
|
+
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
|
1728
|
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
|
1729
|
+
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
|
1730
|
+
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
|
1731
|
+
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
|
1732
|
+
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
|
1733
|
+
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
1734
|
+
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
1735
|
+
}
|
1736
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1737
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1738
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1739
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
1740
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1741
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
1742
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
1743
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
1744
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
1745
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
1746
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
1747
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
1748
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
1749
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
1750
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
|
1751
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
1752
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
1753
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
1754
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
1755
|
+
// TODO: how to make this an array? read Metal docs
|
1756
|
+
for (int j = 0; j < n_as; ++j) {
|
1757
|
+
struct ggml_tensor * src_cur = dst->src[2 + j];
|
1758
|
+
|
1759
|
+
size_t offs_src_cur = 0;
|
1760
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
1761
|
+
|
1762
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
1763
|
+
}
|
1764
|
+
|
1765
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
1766
|
+
|
1767
|
+
// TODO: processing one row at a time (ne11 -> 1) is not efficient
|
1768
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1769
|
+
} else {
|
1770
|
+
int nth0 = 32;
|
1771
|
+
int nth1 = 1;
|
1772
|
+
int nrows = 1;
|
1773
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
1774
|
+
|
1775
|
+
// use custom matrix x vector kernel
|
1776
|
+
switch (src2t) {
|
1777
|
+
case GGML_TYPE_F32:
|
1778
|
+
{
|
1779
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1780
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
|
1781
|
+
} break;
|
1782
|
+
case GGML_TYPE_F16:
|
1783
|
+
{
|
1784
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1785
|
+
nth0 = 32;
|
1786
|
+
nth1 = 1;
|
1787
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
|
1788
|
+
} break;
|
1789
|
+
case GGML_TYPE_Q4_0:
|
1790
|
+
{
|
1791
|
+
nth0 = 8;
|
1792
|
+
nth1 = 8;
|
1793
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
|
1794
|
+
} break;
|
1795
|
+
case GGML_TYPE_Q4_1:
|
1796
|
+
{
|
1797
|
+
nth0 = 8;
|
1798
|
+
nth1 = 8;
|
1799
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
|
1800
|
+
} break;
|
1801
|
+
case GGML_TYPE_Q5_0:
|
1802
|
+
{
|
1803
|
+
nth0 = 8;
|
1804
|
+
nth1 = 8;
|
1805
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
|
1806
|
+
} break;
|
1807
|
+
case GGML_TYPE_Q5_1:
|
1808
|
+
{
|
1809
|
+
nth0 = 8;
|
1810
|
+
nth1 = 8;
|
1811
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
|
1812
|
+
} break;
|
1813
|
+
case GGML_TYPE_Q8_0:
|
1814
|
+
{
|
1815
|
+
nth0 = 8;
|
1816
|
+
nth1 = 8;
|
1817
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
|
1818
|
+
} break;
|
1819
|
+
case GGML_TYPE_Q2_K:
|
1820
|
+
{
|
1821
|
+
nth0 = 2;
|
1822
|
+
nth1 = 32;
|
1823
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
|
1824
|
+
} break;
|
1825
|
+
case GGML_TYPE_Q3_K:
|
1826
|
+
{
|
1827
|
+
nth0 = 2;
|
1828
|
+
nth1 = 32;
|
1829
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
|
1830
|
+
} break;
|
1831
|
+
case GGML_TYPE_Q4_K:
|
1832
|
+
{
|
1833
|
+
nth0 = 4; //1;
|
1834
|
+
nth1 = 8; //32;
|
1835
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
|
1836
|
+
} break;
|
1837
|
+
case GGML_TYPE_Q5_K:
|
1838
|
+
{
|
1839
|
+
nth0 = 2;
|
1840
|
+
nth1 = 32;
|
1841
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
|
1842
|
+
} break;
|
1843
|
+
case GGML_TYPE_Q6_K:
|
1844
|
+
{
|
1845
|
+
nth0 = 2;
|
1846
|
+
nth1 = 32;
|
1847
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
|
1848
|
+
} break;
|
1849
|
+
default:
|
1850
|
+
{
|
1851
|
+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
1852
|
+
GGML_ASSERT(false && "not implemented");
|
1853
|
+
}
|
1854
|
+
};
|
1855
|
+
|
1856
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1857
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1858
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1859
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
1860
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1861
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
1862
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
1863
|
+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
1864
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
1865
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
1866
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
1867
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
1868
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1869
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1870
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1871
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1872
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1873
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
1874
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
1875
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
1876
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
1877
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
1878
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
1879
|
+
// TODO: how to make this an array? read Metal docs
|
1880
|
+
for (int j = 0; j < n_as; ++j) {
|
1881
|
+
struct ggml_tensor * src_cur = dst->src[2 + j];
|
1882
|
+
|
1883
|
+
size_t offs_src_cur = 0;
|
1884
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
1885
|
+
|
1886
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
1887
|
+
}
|
1888
|
+
|
1889
|
+
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
1890
|
+
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
1891
|
+
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
1892
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1893
|
+
}
|
1894
|
+
else if (src2t == GGML_TYPE_Q4_K) {
|
1895
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1896
|
+
}
|
1897
|
+
else if (src2t == GGML_TYPE_Q3_K) {
|
1898
|
+
#ifdef GGML_QKK_64
|
1899
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1900
|
+
#else
|
1901
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1902
|
+
#endif
|
1903
|
+
}
|
1904
|
+
else if (src2t == GGML_TYPE_Q5_K) {
|
1905
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1906
|
+
}
|
1907
|
+
else if (src2t == GGML_TYPE_Q6_K) {
|
1908
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1319
1909
|
} else {
|
1320
|
-
int64_t ny = (
|
1321
|
-
[encoder dispatchThreadgroups:MTLSizeMake(
|
1910
|
+
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
1911
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1322
1912
|
}
|
1323
1913
|
}
|
1324
1914
|
} break;
|
@@ -1340,16 +1930,19 @@ void ggml_metal_graph_compute(
|
|
1340
1930
|
default: GGML_ASSERT(false && "not implemented");
|
1341
1931
|
}
|
1342
1932
|
|
1343
|
-
[encoder setBuffer:id_src0
|
1344
|
-
[encoder setBuffer:id_src1
|
1345
|
-
[encoder setBuffer:id_dst
|
1933
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1934
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1935
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1346
1936
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1347
1937
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
1348
|
-
[encoder setBytes:&
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
[encoder
|
1938
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
1939
|
+
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
1940
|
+
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
1941
|
+
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
1942
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
1943
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
1944
|
+
|
1945
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
1353
1946
|
} break;
|
1354
1947
|
case GGML_OP_RMS_NORM:
|
1355
1948
|
{
|
@@ -1376,6 +1969,38 @@ void ggml_metal_graph_compute(
|
|
1376
1969
|
|
1377
1970
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1378
1971
|
} break;
|
1972
|
+
case GGML_OP_GROUP_NORM:
|
1973
|
+
{
|
1974
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
1975
|
+
|
1976
|
+
//float eps;
|
1977
|
+
//memcpy(&eps, dst->op_params, sizeof(float));
|
1978
|
+
|
1979
|
+
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
1980
|
+
|
1981
|
+
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
1982
|
+
|
1983
|
+
int nth = 32; // SIMD width
|
1984
|
+
|
1985
|
+
//while (nth < ne00/4 && nth < 1024) {
|
1986
|
+
// nth *= 2;
|
1987
|
+
//}
|
1988
|
+
|
1989
|
+
[encoder setComputePipelineState:ctx->pipeline_group_norm];
|
1990
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1991
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1992
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1993
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1994
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1995
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
1996
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
1997
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
1998
|
+
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
1999
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
2000
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
2001
|
+
|
2002
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2003
|
+
} break;
|
1379
2004
|
case GGML_OP_NORM:
|
1380
2005
|
{
|
1381
2006
|
float eps;
|
@@ -1545,18 +2170,123 @@ void ggml_metal_graph_compute(
|
|
1545
2170
|
|
1546
2171
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
1547
2172
|
} break;
|
2173
|
+
case GGML_OP_UPSCALE:
|
2174
|
+
{
|
2175
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2176
|
+
|
2177
|
+
const int sf = dst->op_params[0];
|
2178
|
+
|
2179
|
+
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
|
2180
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2181
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2182
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
2183
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
2184
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
2185
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
2186
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
2187
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
2188
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2189
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2190
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
2191
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
2192
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
2193
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
2194
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
2195
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2196
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2197
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2198
|
+
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
2199
|
+
|
2200
|
+
const int nth = MIN(1024, ne0);
|
2201
|
+
|
2202
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2203
|
+
} break;
|
2204
|
+
case GGML_OP_PAD:
|
2205
|
+
{
|
2206
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2207
|
+
|
2208
|
+
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
|
2209
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2210
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2211
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
2212
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
2213
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
2214
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
2215
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
2216
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
2217
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2218
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2219
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
2220
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
2221
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
2222
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
2223
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
2224
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2225
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2226
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2227
|
+
|
2228
|
+
const int nth = MIN(1024, ne0);
|
2229
|
+
|
2230
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2231
|
+
} break;
|
2232
|
+
case GGML_OP_ARGSORT:
|
2233
|
+
{
|
2234
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2235
|
+
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
2236
|
+
|
2237
|
+
const int nrows = ggml_nrows(src0);
|
2238
|
+
|
2239
|
+
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
2240
|
+
|
2241
|
+
switch (order) {
|
2242
|
+
case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
|
2243
|
+
case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
|
2244
|
+
default: GGML_ASSERT(false);
|
2245
|
+
};
|
2246
|
+
|
2247
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2248
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2249
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2250
|
+
|
2251
|
+
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
2252
|
+
} break;
|
2253
|
+
case GGML_OP_LEAKY_RELU:
|
2254
|
+
{
|
2255
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2256
|
+
|
2257
|
+
float slope;
|
2258
|
+
memcpy(&slope, dst->op_params, sizeof(float));
|
2259
|
+
|
2260
|
+
[encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
|
2261
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2262
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2263
|
+
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
2264
|
+
|
2265
|
+
const int64_t n = ggml_nelements(dst);
|
2266
|
+
|
2267
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2268
|
+
} break;
|
1548
2269
|
case GGML_OP_DUP:
|
1549
2270
|
case GGML_OP_CPY:
|
1550
2271
|
case GGML_OP_CONT:
|
1551
2272
|
{
|
1552
|
-
|
2273
|
+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
2274
|
+
|
2275
|
+
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
1553
2276
|
|
1554
2277
|
switch (src0t) {
|
1555
2278
|
case GGML_TYPE_F32:
|
1556
2279
|
{
|
2280
|
+
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
2281
|
+
|
1557
2282
|
switch (dstt) {
|
1558
|
-
case GGML_TYPE_F16:
|
1559
|
-
case GGML_TYPE_F32:
|
2283
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
|
2284
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
|
2285
|
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
|
2286
|
+
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
|
2287
|
+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
|
2288
|
+
//case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
|
2289
|
+
//case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
|
1560
2290
|
default: GGML_ASSERT(false && "not implemented");
|
1561
2291
|
};
|
1562
2292
|
} break;
|
@@ -1564,7 +2294,7 @@ void ggml_metal_graph_compute(
|
|
1564
2294
|
{
|
1565
2295
|
switch (dstt) {
|
1566
2296
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
|
1567
|
-
case GGML_TYPE_F32:
|
2297
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
|
1568
2298
|
default: GGML_ASSERT(false && "not implemented");
|
1569
2299
|
};
|
1570
2300
|
} break;
|
@@ -1631,81 +2361,150 @@ void ggml_metal_graph_compute(
|
|
1631
2361
|
|
1632
2362
|
// backend interface
|
1633
2363
|
|
1634
|
-
static
|
1635
|
-
|
2364
|
+
static id<MTLDevice> g_backend_device = nil;
|
2365
|
+
static int g_backend_device_ref_count = 0;
|
1636
2366
|
|
1637
|
-
|
2367
|
+
static id<MTLDevice> ggml_backend_metal_get_device(void) {
|
2368
|
+
if (g_backend_device == nil) {
|
2369
|
+
g_backend_device = MTLCreateSystemDefaultDevice();
|
2370
|
+
}
|
2371
|
+
|
2372
|
+
g_backend_device_ref_count++;
|
2373
|
+
|
2374
|
+
return g_backend_device;
|
1638
2375
|
}
|
1639
2376
|
|
1640
|
-
static void
|
1641
|
-
|
1642
|
-
|
1643
|
-
|
2377
|
+
static void ggml_backend_metal_free_device(void) {
|
2378
|
+
assert(g_backend_device_ref_count > 0);
|
2379
|
+
|
2380
|
+
g_backend_device_ref_count--;
|
2381
|
+
|
2382
|
+
if (g_backend_device_ref_count == 0) {
|
2383
|
+
[g_backend_device release];
|
2384
|
+
g_backend_device = nil;
|
2385
|
+
}
|
1644
2386
|
}
|
1645
2387
|
|
1646
2388
|
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
1647
|
-
|
2389
|
+
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
2390
|
+
|
2391
|
+
return ctx->data;
|
1648
2392
|
}
|
1649
2393
|
|
1650
2394
|
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
1651
|
-
|
2395
|
+
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
2396
|
+
|
2397
|
+
[ctx->metal release];
|
2398
|
+
ggml_backend_metal_free_device();
|
2399
|
+
|
2400
|
+
free(ctx->data);
|
2401
|
+
free(ctx);
|
2402
|
+
|
2403
|
+
UNUSED(buffer);
|
2404
|
+
}
|
2405
|
+
|
2406
|
+
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
2407
|
+
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
|
2408
|
+
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
2409
|
+
|
2410
|
+
memcpy((char *)tensor->data + offset, data, size);
|
2411
|
+
|
2412
|
+
UNUSED(buffer);
|
2413
|
+
}
|
2414
|
+
|
2415
|
+
static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
2416
|
+
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
|
2417
|
+
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
2418
|
+
|
2419
|
+
memcpy(data, (const char *)tensor->data + offset, size);
|
2420
|
+
|
2421
|
+
UNUSED(buffer);
|
2422
|
+
}
|
2423
|
+
|
2424
|
+
static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
|
2425
|
+
ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
|
2426
|
+
|
2427
|
+
UNUSED(buffer);
|
2428
|
+
}
|
2429
|
+
|
2430
|
+
static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
|
2431
|
+
ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
|
2432
|
+
|
1652
2433
|
UNUSED(buffer);
|
1653
2434
|
}
|
1654
2435
|
|
1655
2436
|
static struct ggml_backend_buffer_i metal_backend_buffer_i = {
|
1656
|
-
/* .free_buffer
|
1657
|
-
/* .get_base
|
1658
|
-
/* .
|
1659
|
-
/* .
|
1660
|
-
/* .
|
2437
|
+
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
|
2438
|
+
/* .get_base = */ ggml_backend_metal_buffer_get_base,
|
2439
|
+
/* .init_tensor = */ NULL,
|
2440
|
+
/* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
|
2441
|
+
/* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
|
2442
|
+
/* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
|
2443
|
+
/* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
|
1661
2444
|
};
|
1662
2445
|
|
1663
|
-
static ggml_backend_buffer_t
|
1664
|
-
struct
|
2446
|
+
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
2447
|
+
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
|
2448
|
+
|
2449
|
+
const size_t size_page = sysconf(_SC_PAGESIZE);
|
1665
2450
|
|
1666
|
-
|
2451
|
+
size_t size_aligned = size;
|
2452
|
+
if ((size_aligned % size_page) != 0) {
|
2453
|
+
size_aligned += (size_page - (size_aligned % size_page));
|
2454
|
+
}
|
1667
2455
|
|
1668
|
-
|
1669
|
-
|
2456
|
+
ctx->data = ggml_metal_host_malloc(size);
|
2457
|
+
ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
|
2458
|
+
length:size_aligned
|
2459
|
+
options:MTLResourceStorageModeShared
|
2460
|
+
deallocator:nil];
|
1670
2461
|
|
1671
|
-
return ggml_backend_buffer_init(
|
2462
|
+
return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
|
1672
2463
|
}
|
1673
2464
|
|
1674
|
-
static size_t
|
2465
|
+
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
1675
2466
|
return 32;
|
1676
|
-
UNUSED(
|
2467
|
+
UNUSED(buft);
|
1677
2468
|
}
|
1678
2469
|
|
1679
|
-
static
|
1680
|
-
|
1681
|
-
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
1682
|
-
|
1683
|
-
memcpy((char *)tensor->data + offset, data, size);
|
2470
|
+
static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
2471
|
+
return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
|
1684
2472
|
|
1685
|
-
|
2473
|
+
GGML_UNUSED(buft);
|
1686
2474
|
}
|
1687
2475
|
|
1688
|
-
|
1689
|
-
|
1690
|
-
|
1691
|
-
|
1692
|
-
|
2476
|
+
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
2477
|
+
static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
2478
|
+
/* .iface = */ {
|
2479
|
+
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
|
2480
|
+
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
2481
|
+
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
2482
|
+
/* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
|
2483
|
+
},
|
2484
|
+
/* .context = */ NULL,
|
2485
|
+
};
|
1693
2486
|
|
1694
|
-
|
2487
|
+
return &ggml_backend_buffer_type_metal;
|
1695
2488
|
}
|
1696
2489
|
|
1697
|
-
static
|
2490
|
+
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
2491
|
+
return "Metal";
|
2492
|
+
|
1698
2493
|
UNUSED(backend);
|
1699
2494
|
}
|
1700
2495
|
|
1701
|
-
static void
|
1702
|
-
|
2496
|
+
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
2497
|
+
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
2498
|
+
ggml_metal_free(ctx);
|
2499
|
+
free(backend);
|
2500
|
+
}
|
1703
2501
|
|
2502
|
+
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
|
1704
2503
|
UNUSED(backend);
|
1705
2504
|
}
|
1706
2505
|
|
1707
|
-
static
|
1708
|
-
|
2506
|
+
static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
|
2507
|
+
return ggml_backend_metal_buffer_type();
|
1709
2508
|
|
1710
2509
|
UNUSED(backend);
|
1711
2510
|
}
|
@@ -1717,32 +2516,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
|
|
1717
2516
|
}
|
1718
2517
|
|
1719
2518
|
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
1720
|
-
return
|
2519
|
+
return ggml_metal_supports_op(op);
|
2520
|
+
|
1721
2521
|
UNUSED(backend);
|
1722
|
-
UNUSED(op);
|
1723
2522
|
}
|
1724
2523
|
|
1725
2524
|
static struct ggml_backend_i metal_backend_i = {
|
1726
|
-
/* .get_name
|
1727
|
-
/* .free
|
1728
|
-
/* .
|
1729
|
-
/* .
|
1730
|
-
/* .
|
1731
|
-
/* .
|
1732
|
-
/* .
|
1733
|
-
/* .
|
1734
|
-
/* .
|
1735
|
-
/* .
|
1736
|
-
/* .
|
1737
|
-
/* .
|
1738
|
-
/* .
|
1739
|
-
/* .supports_op = */ ggml_backend_metal_supports_op,
|
2525
|
+
/* .get_name = */ ggml_backend_metal_name,
|
2526
|
+
/* .free = */ ggml_backend_metal_free,
|
2527
|
+
/* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
|
2528
|
+
/* .set_tensor_async = */ NULL,
|
2529
|
+
/* .get_tensor_async = */ NULL,
|
2530
|
+
/* .cpy_tensor_from_async = */ NULL,
|
2531
|
+
/* .cpy_tensor_to_async = */ NULL,
|
2532
|
+
/* .synchronize = */ ggml_backend_metal_synchronize,
|
2533
|
+
/* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
|
2534
|
+
/* .graph_plan_free = */ NULL,
|
2535
|
+
/* .graph_plan_compute = */ NULL,
|
2536
|
+
/* .graph_compute = */ ggml_backend_metal_graph_compute,
|
2537
|
+
/* .supports_op = */ ggml_backend_metal_supports_op,
|
1740
2538
|
};
|
1741
2539
|
|
2540
|
+
// TODO: make a common log callback for all backends in ggml-backend
|
2541
|
+
static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
|
2542
|
+
fprintf(stderr, "%s", msg);
|
2543
|
+
|
2544
|
+
UNUSED(level);
|
2545
|
+
UNUSED(user_data);
|
2546
|
+
}
|
2547
|
+
|
1742
2548
|
ggml_backend_t ggml_backend_metal_init(void) {
|
1743
|
-
|
2549
|
+
ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
|
1744
2550
|
|
1745
|
-
ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
2551
|
+
struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
2552
|
+
|
2553
|
+
if (ctx == NULL) {
|
2554
|
+
return NULL;
|
2555
|
+
}
|
1746
2556
|
|
1747
2557
|
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
1748
2558
|
|
@@ -1759,7 +2569,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
|
|
1759
2569
|
}
|
1760
2570
|
|
1761
2571
|
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
2572
|
+
GGML_ASSERT(ggml_backend_is_metal(backend));
|
2573
|
+
|
1762
2574
|
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
1763
2575
|
|
1764
2576
|
ggml_metal_set_n_cb(ctx, n_cb);
|
1765
2577
|
}
|
2578
|
+
|
2579
|
+
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
2580
|
+
GGML_ASSERT(ggml_backend_is_metal(backend));
|
2581
|
+
|
2582
|
+
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
2583
|
+
|
2584
|
+
return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
2585
|
+
}
|
2586
|
+
|
2587
|
+
ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
2588
|
+
|
2589
|
+
ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|
2590
|
+
return ggml_backend_metal_init();
|
2591
|
+
|
2592
|
+
GGML_UNUSED(params);
|
2593
|
+
GGML_UNUSED(user_data);
|
2594
|
+
}
|