llama_cpp 0.7.0 → 0.8.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +41 -21
- data/ext/llama_cpp/src/ggml-alloc.c +62 -107
- data/ext/llama_cpp/src/ggml-alloc.h +11 -5
- data/ext/llama_cpp/src/ggml-backend.c +385 -0
- data/ext/llama_cpp/src/ggml-backend.h +143 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +500 -78
- data/ext/llama_cpp/src/ggml-cuda.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.h +18 -1
- data/ext/llama_cpp/src/ggml-metal.m +396 -127
- data/ext/llama_cpp/src/ggml-metal.metal +290 -46
- data/ext/llama_cpp/src/ggml-opencl.cpp +47 -71
- data/ext/llama_cpp/src/ggml.c +71 -55
- data/ext/llama_cpp/src/ggml.h +15 -9
- data/ext/llama_cpp/src/k_quants.c +12 -20
- data/ext/llama_cpp/src/k_quants.h +5 -5
- data/ext/llama_cpp/src/llama.cpp +1851 -250
- data/ext/llama_cpp/src/llama.h +18 -12
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -4
- metadata +5 -3
@@ -73,6 +73,8 @@ struct ggml_metal_context {
|
|
73
73
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
74
74
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
75
75
|
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
76
|
+
GGML_METAL_DECL_KERNEL(get_rows_q5_0);
|
77
|
+
GGML_METAL_DECL_KERNEL(get_rows_q5_1);
|
76
78
|
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
|
77
79
|
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
78
80
|
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
@@ -81,22 +83,26 @@ struct ggml_metal_context {
|
|
81
83
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
82
84
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
83
85
|
GGML_METAL_DECL_KERNEL(norm);
|
84
|
-
GGML_METAL_DECL_KERNEL(
|
85
|
-
GGML_METAL_DECL_KERNEL(
|
86
|
-
GGML_METAL_DECL_KERNEL(
|
87
|
-
GGML_METAL_DECL_KERNEL(
|
88
|
-
GGML_METAL_DECL_KERNEL(
|
89
|
-
GGML_METAL_DECL_KERNEL(
|
90
|
-
GGML_METAL_DECL_KERNEL(
|
91
|
-
GGML_METAL_DECL_KERNEL(
|
92
|
-
GGML_METAL_DECL_KERNEL(
|
93
|
-
GGML_METAL_DECL_KERNEL(
|
94
|
-
GGML_METAL_DECL_KERNEL(
|
95
|
-
GGML_METAL_DECL_KERNEL(
|
86
|
+
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
87
|
+
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
|
88
|
+
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
|
89
|
+
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
90
|
+
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
|
91
|
+
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
|
92
|
+
GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
|
93
|
+
GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
|
94
|
+
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
|
95
|
+
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
|
96
|
+
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
|
97
|
+
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
|
98
|
+
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
99
|
+
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
96
100
|
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
97
101
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
98
102
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
99
103
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
104
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
|
105
|
+
GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
|
100
106
|
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
|
101
107
|
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
102
108
|
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
@@ -109,6 +115,8 @@ struct ggml_metal_context {
|
|
109
115
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
110
116
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
111
117
|
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
118
|
+
GGML_METAL_DECL_KERNEL(concat);
|
119
|
+
GGML_METAL_DECL_KERNEL(sqr);
|
112
120
|
|
113
121
|
#undef GGML_METAL_DECL_KERNEL
|
114
122
|
};
|
@@ -183,56 +191,44 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
183
191
|
|
184
192
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
185
193
|
|
186
|
-
|
187
|
-
// load the default.metallib file
|
194
|
+
// load library
|
188
195
|
{
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
|
193
|
-
NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
|
194
|
-
NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
|
195
|
-
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
196
|
-
|
197
|
-
// Load the metallib file into a Metal library
|
198
|
-
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
199
|
-
|
200
|
-
if (error) {
|
201
|
-
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
202
|
-
return NULL;
|
203
|
-
}
|
204
|
-
}
|
196
|
+
NSBundle * bundle = nil;
|
197
|
+
#ifdef SWIFT_PACKAGE
|
198
|
+
bundle = SWIFTPM_MODULE_BUNDLE;
|
205
199
|
#else
|
206
|
-
|
207
|
-
|
208
|
-
// read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
|
209
|
-
{
|
200
|
+
bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
201
|
+
#endif
|
210
202
|
NSError * error = nil;
|
203
|
+
NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
|
204
|
+
if (libPath != nil) {
|
205
|
+
NSURL * libURL = [NSURL fileURLWithPath:libPath];
|
206
|
+
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
|
207
|
+
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
|
208
|
+
} else {
|
209
|
+
GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
210
|
+
|
211
|
+
NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
212
|
+
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
|
213
|
+
NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
|
214
|
+
if (error) {
|
215
|
+
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
216
|
+
return NULL;
|
217
|
+
}
|
211
218
|
|
212
|
-
|
213
|
-
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
214
|
-
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
215
|
-
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
|
216
|
-
|
217
|
-
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
218
|
-
if (error) {
|
219
|
-
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
220
|
-
return NULL;
|
221
|
-
}
|
222
|
-
|
219
|
+
MTLCompileOptions* options = nil;
|
223
220
|
#ifdef GGML_QKK_64
|
224
|
-
|
225
|
-
|
226
|
-
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
227
|
-
#else
|
228
|
-
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
|
221
|
+
options = [MTLCompileOptions new];
|
222
|
+
options.preprocessorMacros = @{ @"QK_K" : @(64) };
|
229
223
|
#endif
|
224
|
+
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
|
225
|
+
}
|
226
|
+
|
230
227
|
if (error) {
|
231
228
|
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
232
229
|
return NULL;
|
233
230
|
}
|
234
231
|
}
|
235
|
-
#endif
|
236
232
|
|
237
233
|
// load kernels
|
238
234
|
{
|
@@ -264,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
264
260
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
265
261
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
266
262
|
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
263
|
+
GGML_METAL_ADD_KERNEL(get_rows_q5_0);
|
264
|
+
GGML_METAL_ADD_KERNEL(get_rows_q5_1);
|
267
265
|
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
|
268
266
|
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
|
269
267
|
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
|
@@ -272,40 +270,61 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
272
270
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
273
271
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
274
272
|
GGML_METAL_ADD_KERNEL(norm);
|
275
|
-
GGML_METAL_ADD_KERNEL(
|
276
|
-
GGML_METAL_ADD_KERNEL(
|
277
|
-
GGML_METAL_ADD_KERNEL(
|
278
|
-
GGML_METAL_ADD_KERNEL(
|
279
|
-
GGML_METAL_ADD_KERNEL(
|
280
|
-
GGML_METAL_ADD_KERNEL(
|
281
|
-
GGML_METAL_ADD_KERNEL(
|
282
|
-
GGML_METAL_ADD_KERNEL(
|
283
|
-
GGML_METAL_ADD_KERNEL(
|
284
|
-
GGML_METAL_ADD_KERNEL(
|
285
|
-
GGML_METAL_ADD_KERNEL(
|
286
|
-
GGML_METAL_ADD_KERNEL(
|
287
|
-
GGML_METAL_ADD_KERNEL(
|
288
|
-
GGML_METAL_ADD_KERNEL(
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
273
|
+
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
274
|
+
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
|
275
|
+
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
|
276
|
+
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
|
277
|
+
GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
|
278
|
+
GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
|
279
|
+
GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
|
280
|
+
GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
|
281
|
+
GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
|
282
|
+
GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
|
283
|
+
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
|
284
|
+
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
|
285
|
+
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
|
286
|
+
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
|
287
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
288
|
+
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
289
|
+
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
290
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
291
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
|
292
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
|
293
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
|
294
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
295
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
|
296
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
|
297
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
298
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
299
|
+
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
300
|
+
}
|
297
301
|
GGML_METAL_ADD_KERNEL(rope_f32);
|
298
302
|
GGML_METAL_ADD_KERNEL(rope_f16);
|
299
303
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
300
304
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
301
305
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
302
306
|
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
307
|
+
GGML_METAL_ADD_KERNEL(concat);
|
308
|
+
GGML_METAL_ADD_KERNEL(sqr);
|
303
309
|
|
304
310
|
#undef GGML_METAL_ADD_KERNEL
|
305
311
|
}
|
306
312
|
|
307
|
-
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
308
313
|
#if TARGET_OS_OSX
|
314
|
+
// print MTL GPU family:
|
315
|
+
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
|
316
|
+
|
317
|
+
// determine max supported GPU family
|
318
|
+
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
319
|
+
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
320
|
+
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
321
|
+
if ([ctx->device supportsFamily:i]) {
|
322
|
+
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
|
323
|
+
break;
|
324
|
+
}
|
325
|
+
}
|
326
|
+
|
327
|
+
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
309
328
|
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
310
329
|
if (ctx->device.maxTransferRate != 0) {
|
311
330
|
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
@@ -339,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
339
358
|
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
340
359
|
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
341
360
|
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
361
|
+
GGML_METAL_DEL_KERNEL(get_rows_q5_0);
|
362
|
+
GGML_METAL_DEL_KERNEL(get_rows_q5_1);
|
342
363
|
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
|
343
364
|
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
|
344
365
|
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
|
@@ -347,34 +368,42 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
347
368
|
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
348
369
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
349
370
|
GGML_METAL_DEL_KERNEL(norm);
|
350
|
-
GGML_METAL_DEL_KERNEL(
|
351
|
-
GGML_METAL_DEL_KERNEL(
|
352
|
-
GGML_METAL_DEL_KERNEL(
|
353
|
-
GGML_METAL_DEL_KERNEL(
|
354
|
-
GGML_METAL_DEL_KERNEL(
|
355
|
-
GGML_METAL_DEL_KERNEL(
|
356
|
-
GGML_METAL_DEL_KERNEL(
|
357
|
-
GGML_METAL_DEL_KERNEL(
|
358
|
-
GGML_METAL_DEL_KERNEL(
|
359
|
-
GGML_METAL_DEL_KERNEL(
|
360
|
-
GGML_METAL_DEL_KERNEL(
|
361
|
-
GGML_METAL_DEL_KERNEL(
|
362
|
-
GGML_METAL_DEL_KERNEL(
|
363
|
-
GGML_METAL_DEL_KERNEL(
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
371
|
+
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
372
|
+
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
|
373
|
+
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
|
374
|
+
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
375
|
+
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
|
376
|
+
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
|
377
|
+
GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
|
378
|
+
GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
|
379
|
+
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
|
380
|
+
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
|
381
|
+
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
|
382
|
+
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
|
383
|
+
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
384
|
+
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
385
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
386
|
+
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
387
|
+
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
388
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
389
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
|
390
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
|
391
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
|
392
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
393
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
|
394
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
|
395
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
396
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
397
|
+
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
398
|
+
}
|
372
399
|
GGML_METAL_DEL_KERNEL(rope_f32);
|
373
400
|
GGML_METAL_DEL_KERNEL(rope_f16);
|
374
401
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
375
402
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
376
403
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
377
404
|
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
405
|
+
GGML_METAL_DEL_KERNEL(concat);
|
406
|
+
GGML_METAL_DEL_KERNEL(sqr);
|
378
407
|
|
379
408
|
#undef GGML_METAL_DEL_KERNEL
|
380
409
|
|
@@ -431,7 +460,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
431
460
|
for (int i = 0; i < ctx->n_buffers; ++i) {
|
432
461
|
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
|
433
462
|
|
434
|
-
//
|
463
|
+
//GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
|
435
464
|
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
436
465
|
*offs = (size_t) ioffs;
|
437
466
|
|
@@ -766,6 +795,44 @@ void ggml_metal_graph_compute(
|
|
766
795
|
{
|
767
796
|
// noop
|
768
797
|
} break;
|
798
|
+
case GGML_OP_CONCAT:
|
799
|
+
{
|
800
|
+
const int64_t nb = ne00;
|
801
|
+
|
802
|
+
[encoder setComputePipelineState:ctx->pipeline_concat];
|
803
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
804
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
805
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
806
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
807
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
808
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
809
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
810
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
811
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
812
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
813
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
814
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
815
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
816
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
817
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
818
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
819
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
820
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
821
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
822
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
823
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
824
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
825
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
826
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
827
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
828
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
829
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
830
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
831
|
+
|
832
|
+
const int nth = MIN(1024, ne0);
|
833
|
+
|
834
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
835
|
+
} break;
|
769
836
|
case GGML_OP_ADD:
|
770
837
|
{
|
771
838
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
@@ -861,9 +928,10 @@ void ggml_metal_graph_compute(
|
|
861
928
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
862
929
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
863
930
|
|
864
|
-
const int64_t n = ggml_nelements(dst)
|
931
|
+
const int64_t n = ggml_nelements(dst);
|
932
|
+
GGML_ASSERT(n % 4 == 0);
|
865
933
|
|
866
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
934
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
867
935
|
} break;
|
868
936
|
case GGML_OP_UNARY:
|
869
937
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
@@ -873,9 +941,10 @@ void ggml_metal_graph_compute(
|
|
873
941
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
874
942
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
875
943
|
|
876
|
-
const int64_t n = ggml_nelements(dst)
|
944
|
+
const int64_t n = ggml_nelements(dst);
|
945
|
+
GGML_ASSERT(n % 4 == 0);
|
877
946
|
|
878
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
947
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
879
948
|
} break;
|
880
949
|
case GGML_UNARY_OP_RELU:
|
881
950
|
{
|
@@ -893,9 +962,10 @@ void ggml_metal_graph_compute(
|
|
893
962
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
894
963
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
895
964
|
|
896
|
-
const int64_t n = ggml_nelements(dst)
|
965
|
+
const int64_t n = ggml_nelements(dst);
|
966
|
+
GGML_ASSERT(n % 4 == 0);
|
897
967
|
|
898
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
968
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
899
969
|
} break;
|
900
970
|
default:
|
901
971
|
{
|
@@ -903,6 +973,17 @@ void ggml_metal_graph_compute(
|
|
903
973
|
GGML_ASSERT(false);
|
904
974
|
}
|
905
975
|
} break;
|
976
|
+
case GGML_OP_SQR:
|
977
|
+
{
|
978
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
979
|
+
|
980
|
+
[encoder setComputePipelineState:ctx->pipeline_sqr];
|
981
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
982
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
983
|
+
|
984
|
+
const int64_t n = ggml_nelements(dst);
|
985
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
986
|
+
} break;
|
906
987
|
case GGML_OP_SOFT_MAX:
|
907
988
|
{
|
908
989
|
const int nth = MIN(32, ne00);
|
@@ -944,26 +1025,53 @@ void ggml_metal_graph_compute(
|
|
944
1025
|
} break;
|
945
1026
|
case GGML_OP_MUL_MAT:
|
946
1027
|
{
|
947
|
-
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
|
948
|
-
|
949
1028
|
GGML_ASSERT(ne00 == ne10);
|
950
|
-
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
951
|
-
uint gqa = ne12/ne02;
|
952
1029
|
GGML_ASSERT(ne03 == ne13);
|
953
1030
|
|
1031
|
+
const uint gqa = ne12/ne02;
|
1032
|
+
|
1033
|
+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1034
|
+
// to the matrix-vector kernel
|
1035
|
+
int ne11_mm_min = 1;
|
1036
|
+
|
1037
|
+
#if 0
|
1038
|
+
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
1039
|
+
// these numbers do not translate to other devices or model sizes
|
1040
|
+
// TODO: need to find a better approach
|
1041
|
+
if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
|
1042
|
+
switch (src0t) {
|
1043
|
+
case GGML_TYPE_F16: ne11_mm_min = 2; break;
|
1044
|
+
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
1045
|
+
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
1046
|
+
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
1047
|
+
case GGML_TYPE_Q4_0:
|
1048
|
+
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
1049
|
+
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
1050
|
+
case GGML_TYPE_Q5_0: // not tested yet
|
1051
|
+
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
1052
|
+
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
1053
|
+
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
1054
|
+
default: ne11_mm_min = 1; break;
|
1055
|
+
}
|
1056
|
+
}
|
1057
|
+
#endif
|
1058
|
+
|
954
1059
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
955
1060
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
956
|
-
if (
|
1061
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
1062
|
+
!ggml_is_transposed(src0) &&
|
957
1063
|
!ggml_is_transposed(src1) &&
|
958
1064
|
src1t == GGML_TYPE_F32 &&
|
959
|
-
|
960
|
-
|
961
|
-
ne11
|
1065
|
+
ne00 % 32 == 0 && ne00 >= 64 &&
|
1066
|
+
ne11 > ne11_mm_min) {
|
1067
|
+
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
962
1068
|
switch (src0->type) {
|
963
1069
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
964
1070
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
965
1071
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
966
1072
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
1073
|
+
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
|
1074
|
+
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
|
967
1075
|
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
|
968
1076
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
|
969
1077
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
|
@@ -987,17 +1095,18 @@ void ggml_metal_graph_compute(
|
|
987
1095
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
988
1096
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
|
989
1097
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
990
|
-
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63)
|
1098
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
991
1099
|
} else {
|
992
1100
|
int nth0 = 32;
|
993
1101
|
int nth1 = 1;
|
994
1102
|
int nrows = 1;
|
1103
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
995
1104
|
|
996
1105
|
// use custom matrix x vector kernel
|
997
1106
|
switch (src0t) {
|
998
1107
|
case GGML_TYPE_F32:
|
999
1108
|
{
|
1000
|
-
[encoder setComputePipelineState:ctx->
|
1109
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
|
1001
1110
|
nrows = 4;
|
1002
1111
|
} break;
|
1003
1112
|
case GGML_TYPE_F16:
|
@@ -1005,12 +1114,12 @@ void ggml_metal_graph_compute(
|
|
1005
1114
|
nth0 = 32;
|
1006
1115
|
nth1 = 1;
|
1007
1116
|
if (ne11 * ne12 < 4) {
|
1008
|
-
[encoder setComputePipelineState:ctx->
|
1117
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
|
1009
1118
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
1010
|
-
[encoder setComputePipelineState:ctx->
|
1119
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
|
1011
1120
|
nrows = ne11;
|
1012
1121
|
} else {
|
1013
|
-
[encoder setComputePipelineState:ctx->
|
1122
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
|
1014
1123
|
nrows = 4;
|
1015
1124
|
}
|
1016
1125
|
} break;
|
@@ -1021,7 +1130,7 @@ void ggml_metal_graph_compute(
|
|
1021
1130
|
|
1022
1131
|
nth0 = 8;
|
1023
1132
|
nth1 = 8;
|
1024
|
-
[encoder setComputePipelineState:ctx->
|
1133
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
|
1025
1134
|
} break;
|
1026
1135
|
case GGML_TYPE_Q4_1:
|
1027
1136
|
{
|
@@ -1030,7 +1139,25 @@ void ggml_metal_graph_compute(
|
|
1030
1139
|
|
1031
1140
|
nth0 = 8;
|
1032
1141
|
nth1 = 8;
|
1033
|
-
[encoder setComputePipelineState:ctx->
|
1142
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
|
1143
|
+
} break;
|
1144
|
+
case GGML_TYPE_Q5_0:
|
1145
|
+
{
|
1146
|
+
GGML_ASSERT(ne02 == 1);
|
1147
|
+
GGML_ASSERT(ne12 == 1);
|
1148
|
+
|
1149
|
+
nth0 = 8;
|
1150
|
+
nth1 = 8;
|
1151
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
|
1152
|
+
} break;
|
1153
|
+
case GGML_TYPE_Q5_1:
|
1154
|
+
{
|
1155
|
+
GGML_ASSERT(ne02 == 1);
|
1156
|
+
GGML_ASSERT(ne12 == 1);
|
1157
|
+
|
1158
|
+
nth0 = 8;
|
1159
|
+
nth1 = 8;
|
1160
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
|
1034
1161
|
} break;
|
1035
1162
|
case GGML_TYPE_Q8_0:
|
1036
1163
|
{
|
@@ -1039,7 +1166,7 @@ void ggml_metal_graph_compute(
|
|
1039
1166
|
|
1040
1167
|
nth0 = 8;
|
1041
1168
|
nth1 = 8;
|
1042
|
-
[encoder setComputePipelineState:ctx->
|
1169
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
|
1043
1170
|
} break;
|
1044
1171
|
case GGML_TYPE_Q2_K:
|
1045
1172
|
{
|
@@ -1048,7 +1175,7 @@ void ggml_metal_graph_compute(
|
|
1048
1175
|
|
1049
1176
|
nth0 = 2;
|
1050
1177
|
nth1 = 32;
|
1051
|
-
[encoder setComputePipelineState:ctx->
|
1178
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
|
1052
1179
|
} break;
|
1053
1180
|
case GGML_TYPE_Q3_K:
|
1054
1181
|
{
|
@@ -1057,7 +1184,7 @@ void ggml_metal_graph_compute(
|
|
1057
1184
|
|
1058
1185
|
nth0 = 2;
|
1059
1186
|
nth1 = 32;
|
1060
|
-
[encoder setComputePipelineState:ctx->
|
1187
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
|
1061
1188
|
} break;
|
1062
1189
|
case GGML_TYPE_Q4_K:
|
1063
1190
|
{
|
@@ -1066,7 +1193,7 @@ void ggml_metal_graph_compute(
|
|
1066
1193
|
|
1067
1194
|
nth0 = 4; //1;
|
1068
1195
|
nth1 = 8; //32;
|
1069
|
-
[encoder setComputePipelineState:ctx->
|
1196
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
|
1070
1197
|
} break;
|
1071
1198
|
case GGML_TYPE_Q5_K:
|
1072
1199
|
{
|
@@ -1075,7 +1202,7 @@ void ggml_metal_graph_compute(
|
|
1075
1202
|
|
1076
1203
|
nth0 = 2;
|
1077
1204
|
nth1 = 32;
|
1078
|
-
[encoder setComputePipelineState:ctx->
|
1205
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
|
1079
1206
|
} break;
|
1080
1207
|
case GGML_TYPE_Q6_K:
|
1081
1208
|
{
|
@@ -1084,7 +1211,7 @@ void ggml_metal_graph_compute(
|
|
1084
1211
|
|
1085
1212
|
nth0 = 2;
|
1086
1213
|
nth1 = 32;
|
1087
|
-
[encoder setComputePipelineState:ctx->
|
1214
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
|
1088
1215
|
} break;
|
1089
1216
|
default:
|
1090
1217
|
{
|
@@ -1112,8 +1239,9 @@ void ggml_metal_graph_compute(
|
|
1112
1239
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
1113
1240
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
1114
1241
|
|
1115
|
-
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
1116
|
-
src0t ==
|
1242
|
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
1243
|
+
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
1244
|
+
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
1117
1245
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1118
1246
|
}
|
1119
1247
|
else if (src0t == GGML_TYPE_Q4_K) {
|
@@ -1144,6 +1272,8 @@ void ggml_metal_graph_compute(
|
|
1144
1272
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
1145
1273
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
1146
1274
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
1275
|
+
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
|
1276
|
+
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
|
1147
1277
|
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
1148
1278
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
1149
1279
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
@@ -1166,6 +1296,8 @@ void ggml_metal_graph_compute(
|
|
1166
1296
|
} break;
|
1167
1297
|
case GGML_OP_RMS_NORM:
|
1168
1298
|
{
|
1299
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
1300
|
+
|
1169
1301
|
float eps;
|
1170
1302
|
memcpy(&eps, dst->op_params, sizeof(float));
|
1171
1303
|
|
@@ -1208,7 +1340,7 @@ void ggml_metal_graph_compute(
|
|
1208
1340
|
|
1209
1341
|
const int nth = MIN(1024, ne00);
|
1210
1342
|
|
1211
|
-
const int n_past = ((int32_t *) dst->op_params)[0];
|
1343
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
1212
1344
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
1213
1345
|
float max_bias;
|
1214
1346
|
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
@@ -1371,3 +1503,140 @@ void ggml_metal_graph_compute(
|
|
1371
1503
|
|
1372
1504
|
}
|
1373
1505
|
}
|
1506
|
+
|
1507
|
+
////////////////////////////////////////////////////////////////////////////////
|
1508
|
+
|
1509
|
+
// backend interface
|
1510
|
+
|
1511
|
+
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
1512
|
+
return "Metal";
|
1513
|
+
|
1514
|
+
UNUSED(backend);
|
1515
|
+
}
|
1516
|
+
|
1517
|
+
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
1518
|
+
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
1519
|
+
ggml_metal_free(ctx);
|
1520
|
+
free(backend);
|
1521
|
+
}
|
1522
|
+
|
1523
|
+
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
1524
|
+
return (void *)buffer->context;
|
1525
|
+
}
|
1526
|
+
|
1527
|
+
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
1528
|
+
free(buffer->context);
|
1529
|
+
UNUSED(buffer);
|
1530
|
+
}
|
1531
|
+
|
1532
|
+
static struct ggml_backend_buffer_i metal_backend_buffer_i = {
|
1533
|
+
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
|
1534
|
+
/* .get_base = */ ggml_backend_metal_buffer_get_base,
|
1535
|
+
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
1536
|
+
/* .init_tensor = */ NULL, // no initialization required
|
1537
|
+
/* .free_tensor = */ NULL, // no cleanup required
|
1538
|
+
};
|
1539
|
+
|
1540
|
+
// alloc_buffer callback of the Metal backend interface.
// Allocates `size` bytes with ggml_metal_host_malloc, registers that memory
// with the backend's Metal context via ggml_metal_add_buffer, and wraps it in a
// ggml backend buffer using metal_backend_buffer_i (whose free_buffer releases it).
// Aborts via GGML_ASSERT if the host allocation fails or the registration is
// rejected, instead of returning a buffer with a NULL base or one the Metal
// context does not know about.
static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;

    void * data = ggml_metal_host_malloc(size);
    GGML_ASSERT(data != NULL && "failed to allocate Metal host buffer");

    // TODO: set proper name of the buffers
    const bool registered = ggml_metal_add_buffer(ctx, "backend", data, size, 0);
    GGML_ASSERT(registered && "failed to register buffer with the Metal context");

    return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
}
|
1550
|
+
|
1551
|
+
// get_alignment callback of the Metal backend interface.
// Reports a fixed 32-byte alignment regardless of which backend is passed.
static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
    UNUSED(backend);

    const size_t alignment_bytes = 32;
    return alignment_bytes;
}
|
1555
|
+
|
1556
|
+
// set_tensor_async callback of the Metal backend interface.
// Copies `size` bytes from `data` into `tensor`, starting `offset` bytes into
// the tensor's storage. The copy is a plain memcpy into tensor->data, so it has
// completed by the time this returns despite the _async name.
// Asserts that the write stays within ggml_nbytes(tensor) and that the tensor
// has been allocated.
static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    UNUSED(backend);

    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

    char * dst_bytes = (char *) tensor->data;
    memcpy(dst_bytes + offset, data, size);
}
|
1564
|
+
|
1565
|
+
// get_tensor_async callback of the Metal backend interface.
// Copies `size` bytes out of `tensor`, starting `offset` bytes into its
// storage, into `data`. The copy is a plain memcpy from tensor->data, so it has
// completed by the time this returns despite the _async name.
// Asserts that the read stays within ggml_nbytes(tensor) and that the tensor
// has been allocated.
static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    UNUSED(backend);

    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

    const char * src_bytes = (const char *) tensor->data;
    memcpy(data, src_bytes + offset, size);
}
|
1573
|
+
|
1574
|
+
// synchronize callback of the Metal backend interface; intentionally empty.
// The tensor get/set callbacks of this backend finish their memcpy before
// returning, so there is no queued copy to wait for here.
// NOTE(review): assumes ggml_metal_graph_compute waits for its own GPU work
// before returning; confirm if asynchronous compute is ever added.
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
    (void) backend;
}
|
1577
|
+
|
1578
|
+
// cpy_tensor_from callback of the Metal backend interface.
// Fills dst->data with all ggml_nbytes(src) bytes of `src`, read through
// ggml_backend_tensor_get (presumably `src` lives on another backend).
static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
    UNUSED(backend);

    const size_t n_bytes = ggml_nbytes(src);
    ggml_backend_tensor_get(src, dst->data, 0, n_bytes);
}
|
1583
|
+
|
1584
|
+
// cpy_tensor_to callback of the Metal backend interface.
// Pushes all ggml_nbytes(src) bytes of src->data into `dst` via
// ggml_backend_tensor_set_async (presumably `dst` lives on another backend).
// The copy is queued with the _async setter, so the caller may need to
// synchronize dst's backend before reading it.
static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
    UNUSED(backend);

    const size_t n_bytes = ggml_nbytes(src);
    ggml_backend_tensor_set_async(dst, src->data, 0, n_bytes);
}
|
1589
|
+
|
1590
|
+
// graph_compute callback of the Metal backend interface.
// Runs `cgraph` on the Metal context stored in backend->context by handing it
// to ggml_metal_graph_compute.
static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    ggml_metal_graph_compute((struct ggml_metal_context *) backend->context, cgraph);
}
|
1595
|
+
|
1596
|
+
// supports_op callback of the Metal backend interface.
// Claims support for every op; neither the backend nor the op is inspected.
// NOTE(review): returns true unconditionally, even for ops that may lack a
// Metal kernel; confirm whether a real per-op check is needed.
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
    UNUSED(backend);
    UNUSED(op);

    return true;
}
|
1601
|
+
|
1602
|
+
// Backend interface installed by ggml_backend_metal_init.
// Designated initializers tie each callback to its slot by name, so the table
// does not depend on field order. The graph-plan slots stay NULL because the
// Metal implementation computes graphs directly via graph_compute.
static struct ggml_backend_i metal_backend_i = {
    .get_name           = ggml_backend_metal_name,
    .free               = ggml_backend_metal_free,
    .alloc_buffer       = ggml_backend_metal_alloc_buffer,
    .get_alignment      = ggml_backend_metal_get_alignment,
    .set_tensor_async   = ggml_backend_metal_set_tensor_async,
    .get_tensor_async   = ggml_backend_metal_get_tensor_async,
    .synchronize        = ggml_backend_metal_synchronize,
    .cpy_tensor_from    = ggml_backend_metal_cpy_tensor_from,
    .cpy_tensor_to      = ggml_backend_metal_cpy_tensor_to,
    .graph_plan_create  = NULL, // the metal implementation does not require creating graph plans atm
    .graph_plan_free    = NULL,
    .graph_plan_compute = NULL,
    .graph_compute      = ggml_backend_metal_graph_compute,
    .supports_op        = ggml_backend_metal_supports_op,
};
|
1618
|
+
|
1619
|
+
// Creates a Metal backend.
// A Metal context is obtained from ggml_metal_init(GGML_DEFAULT_N_THREADS) and
// wrapped in a heap-allocated ggml_backend that uses metal_backend_i.
// Returns NULL if ggml_metal_init fails or the backend struct cannot be
// allocated; in the latter case the context is released before returning.
// A returned backend is released through its free callback
// (ggml_backend_metal_free).
ggml_backend_t ggml_backend_metal_init(void) {
    // Use the context returned by ggml_metal_init directly; no separate
    // allocation is needed (a preallocated pointer would just be overwritten).
    struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
    if (ctx == NULL) {
        return NULL;
    }

    ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
    if (metal_backend == NULL) {
        ggml_metal_free(ctx);
        return NULL;
    }

    *metal_backend = (struct ggml_backend) {
        /* .interface = */ metal_backend_i,
        /* .context   = */ ctx,
    };

    return metal_backend;
}
|
1633
|
+
|
1634
|
+
// Reports whether `backend` uses this file's Metal interface.
// Identification is by function pointer: a Metal backend's get_name callback is
// ggml_backend_metal_name.
bool ggml_backend_is_metal(ggml_backend_t backend) {
    const bool has_metal_name_fn = (backend->iface.get_name == ggml_backend_metal_name);
    return has_metal_name_fn;
}
|
1637
|
+
|
1638
|
+
// Forwards `n_cb` to ggml_metal_set_n_cb on the Metal context stored in
// backend->context (presumably the number of command buffers; see
// ggml_metal_set_n_cb).
// NOTE(review): `backend` is assumed to be a Metal backend; its context is
// cast without a ggml_backend_is_metal check.
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
    ggml_metal_set_n_cb((struct ggml_metal_context *) backend->context, n_cb);
}
|