llama_cpp 0.7.0 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/extconf.rb +1 -1
- data/ext/llama_cpp/llama_cpp.cpp +41 -21
- data/ext/llama_cpp/src/ggml-alloc.c +62 -107
- data/ext/llama_cpp/src/ggml-alloc.h +11 -5
- data/ext/llama_cpp/src/ggml-backend.c +385 -0
- data/ext/llama_cpp/src/ggml-backend.h +143 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +500 -78
- data/ext/llama_cpp/src/ggml-cuda.h +4 -0
- data/ext/llama_cpp/src/ggml-metal.h +18 -1
- data/ext/llama_cpp/src/ggml-metal.m +396 -127
- data/ext/llama_cpp/src/ggml-metal.metal +290 -46
- data/ext/llama_cpp/src/ggml-opencl.cpp +47 -71
- data/ext/llama_cpp/src/ggml.c +71 -55
- data/ext/llama_cpp/src/ggml.h +15 -9
- data/ext/llama_cpp/src/k_quants.c +12 -20
- data/ext/llama_cpp/src/k_quants.h +5 -5
- data/ext/llama_cpp/src/llama.cpp +1851 -250
- data/ext/llama_cpp/src/llama.h +18 -12
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -4
- metadata +5 -3
@@ -73,6 +73,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_1);
     GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -81,22 +83,26 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
-    GGML_METAL_DECL_KERNEL(
+    GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -109,6 +115,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DECL_KERNEL(concat);
+    GGML_METAL_DECL_KERNEL(sqr);

     #undef GGML_METAL_DECL_KERNEL
 };
@@ -183,56 +191,44 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {

     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

-
-    // load the default.metallib file
+    // load library
     {
-
-
-
-        NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
-        NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
-        NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
-        NSURL * libURL = [NSURL fileURLWithPath:libPath];
-
-        // Load the metallib file into a Metal library
-        ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
-
-        if (error) {
-            GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-            return NULL;
-        }
-    }
+        NSBundle * bundle = nil;
+#ifdef SWIFT_PACKAGE
+        bundle = SWIFTPM_MODULE_BUNDLE;
 #else
-
-
-    // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
-    {
+        bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+#endif
         NSError * error = nil;
+        NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
+        if (libPath != nil) {
+            NSURL * libURL = [NSURL fileURLWithPath:libPath];
+            GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
+            ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+        } else {
+            GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
+
+            NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
+            NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
+            if (error) {
+                GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return NULL;
+            }

-
-        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-        NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-        GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
-
-        NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
-        if (error) {
-            GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-            return NULL;
-        }
-
+            MTLCompileOptions* options = nil;
 #ifdef GGML_QKK_64
-
-
-        ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
-#else
-        ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+            options = [MTLCompileOptions new];
+            options.preprocessorMacros = @{ @"QK_K" : @(64) };
 #endif
+            ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+        }
+
         if (error) {
             GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
-#endif

     // load kernels
     {
@@ -264,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_1);
         GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -272,40 +270,61 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-        GGML_METAL_ADD_KERNEL(
-
-
-
-
-
-
-
-
+        GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+            GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+        }
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+        GGML_METAL_ADD_KERNEL(concat);
+        GGML_METAL_ADD_KERNEL(sqr);

         #undef GGML_METAL_ADD_KERNEL
     }

-    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 #if TARGET_OS_OSX
+    // print MTL GPU family:
+    GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+    // determine max supported GPU family
+    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+        if ([ctx->device supportsFamily:i]) {
+            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+            break;
+        }
+    }
+
+    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     if (ctx->device.maxTransferRate != 0) {
         GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
@@ -339,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(get_rows_f16);
    GGML_METAL_DEL_KERNEL(get_rows_q4_0);
    GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_1);
    GGML_METAL_DEL_KERNEL(get_rows_q8_0);
    GGML_METAL_DEL_KERNEL(get_rows_q2_K);
    GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -347,34 +368,42 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
    GGML_METAL_DEL_KERNEL(rms_norm);
    GGML_METAL_DEL_KERNEL(norm);
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-    GGML_METAL_DEL_KERNEL(
-
-
-
-
-
-
-
-
+    GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+        GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    }
    GGML_METAL_DEL_KERNEL(rope_f32);
    GGML_METAL_DEL_KERNEL(rope_f16);
    GGML_METAL_DEL_KERNEL(alibi_f32);
    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DEL_KERNEL(concat);
+    GGML_METAL_DEL_KERNEL(sqr);

    #undef GGML_METAL_DEL_KERNEL

@@ -431,7 +460,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;

-        //
+        //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;

@@ -766,6 +795,44 @@ void ggml_metal_graph_compute(
            {
                // noop
            } break;
+        case GGML_OP_CONCAT:
+            {
+                const int64_t nb = ne00;
+
+                [encoder setComputePipelineState:ctx->pipeline_concat];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+                [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+                [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+                [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+                [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+                [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+                [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+                [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+
+                const int nth = MIN(1024, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
        case GGML_OP_ADD:
            {
                GGML_ASSERT(ggml_is_contiguous(src0));
@@ -861,9 +928,10 @@ void ggml_metal_graph_compute(
                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                [encoder setBytes:&scale length:sizeof(scale) atIndex:2];

-                const int64_t n = ggml_nelements(dst)
+                const int64_t n = ggml_nelements(dst);
+                GGML_ASSERT(n % 4 == 0);

-                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(gf->nodes[i])) {
@@ -873,9 +941,10 @@ void ggml_metal_graph_compute(
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-                const int64_t n = ggml_nelements(dst)
+                const int64_t n = ggml_nelements(dst);
+                GGML_ASSERT(n % 4 == 0);

-                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
        case GGML_UNARY_OP_RELU:
            {
@@ -893,9 +962,10 @@ void ggml_metal_graph_compute(
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

-                const int64_t n = ggml_nelements(dst)
+                const int64_t n = ggml_nelements(dst);
+                GGML_ASSERT(n % 4 == 0);

-                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
        default:
            {
@@ -903,6 +973,17 @@ void ggml_metal_graph_compute(
                GGML_ASSERT(false);
            }
            } break;
+        case GGML_OP_SQR:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                [encoder setComputePipelineState:ctx->pipeline_sqr];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
        case GGML_OP_SOFT_MAX:
            {
                const int nth = MIN(32, ne00);
@@ -944,26 +1025,53 @@ void ggml_metal_graph_compute(
            } break;
        case GGML_OP_MUL_MAT:
            {
-                // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
                GGML_ASSERT(ne00 == ne10);
-                // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
-                uint gqa = ne12/ne02;
                GGML_ASSERT(ne03 == ne13);

+                const uint gqa = ne12/ne02;
+
+                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                // to the matrix-vector kernel
+                int ne11_mm_min = 1;
+
+#if 0
+                // the numbers below are measured on M2 Ultra for 7B and 13B models
+                // these numbers do not translate to other devices or model sizes
+                // TODO: need to find a better approach
+                if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+                    switch (src0t) {
+                        case GGML_TYPE_F16: ne11_mm_min = 2; break;
+                        case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
+                        case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+                        case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
+                        case GGML_TYPE_Q4_0:
+                        case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+                        case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+                        case GGML_TYPE_Q5_0: // not tested yet
+                        case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+                        case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
+                        case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
+                        default: ne11_mm_min = 1; break;
+                    }
+                }
+#endif
+
                // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                if (
+                if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                    !ggml_is_transposed(src0) &&
                    !ggml_is_transposed(src1) &&
                    src1t == GGML_TYPE_F32 &&
-
-
-                    ne11
+                    ne00 % 32 == 0 && ne00 >= 64 &&
+                    ne11 > ne11_mm_min) {
+                    //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
                    switch (src0->type) {
                        case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
                        case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
                        case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                        case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                        case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+                        case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
                        case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
                        case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
                        case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -987,17 +1095,18 @@ void ggml_metal_graph_compute(
                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
                    [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
                    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63)
+                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                } else {
                    int nth0 = 32;
                    int nth1 = 1;
                    int nrows = 1;
+                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

                    // use custom matrix x vector kernel
                    switch (src0t) {
                        case GGML_TYPE_F32:
                            {
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                nrows = 4;
                            } break;
                        case GGML_TYPE_F16:
@@ -1005,12 +1114,12 @@ void ggml_metal_graph_compute(
                                nth0 = 32;
                                nth1 = 1;
                                if (ne11 * ne12 < 4) {
-                                    [encoder setComputePipelineState:ctx->
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
                                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                    [encoder setComputePipelineState:ctx->
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
                                    nrows = ne11;
                                } else {
-                                    [encoder setComputePipelineState:ctx->
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
                                    nrows = 4;
                                }
                            } break;
@@ -1021,7 +1130,7 @@ void ggml_metal_graph_compute(

                                nth0 = 8;
                                nth1 = 8;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
                            } break;
                        case GGML_TYPE_Q4_1:
                            {
@@ -1030,7 +1139,25 @@ void ggml_metal_graph_compute(

                                nth0 = 8;
                                nth1 = 8;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
+                            } break;
+                        case GGML_TYPE_Q5_0:
+                            {
+                                GGML_ASSERT(ne02 == 1);
+                                GGML_ASSERT(ne12 == 1);
+
+                                nth0 = 8;
+                                nth1 = 8;
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+                            } break;
+                        case GGML_TYPE_Q5_1:
+                            {
+                                GGML_ASSERT(ne02 == 1);
+                                GGML_ASSERT(ne12 == 1);
+
+                                nth0 = 8;
+                                nth1 = 8;
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
                            } break;
                        case GGML_TYPE_Q8_0:
                            {
@@ -1039,7 +1166,7 @@ void ggml_metal_graph_compute(

                                nth0 = 8;
                                nth1 = 8;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
                            } break;
                        case GGML_TYPE_Q2_K:
                            {
@@ -1048,7 +1175,7 @@ void ggml_metal_graph_compute(

                                nth0 = 2;
                                nth1 = 32;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
                            } break;
                        case GGML_TYPE_Q3_K:
                            {
@@ -1057,7 +1184,7 @@ void ggml_metal_graph_compute(

                                nth0 = 2;
                                nth1 = 32;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
                            } break;
                        case GGML_TYPE_Q4_K:
                            {
@@ -1066,7 +1193,7 @@ void ggml_metal_graph_compute(

                                nth0 = 4; //1;
                                nth1 = 8; //32;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
                            } break;
                        case GGML_TYPE_Q5_K:
                            {
@@ -1075,7 +1202,7 @@ void ggml_metal_graph_compute(

                                nth0 = 2;
                                nth1 = 32;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
                            } break;
                        case GGML_TYPE_Q6_K:
                            {
@@ -1084,7 +1211,7 @@ void ggml_metal_graph_compute(

                                nth0 = 2;
                                nth1 = 32;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
                            } break;
                        default:
                            {
@@ -1112,8 +1239,9 @@ void ggml_metal_graph_compute(
                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
                    [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

-                    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
-                        src0t ==
+                    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                        src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
+                        src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                    else if (src0t == GGML_TYPE_Q4_K) {
@@ -1144,6 +1272,8 @@ void ggml_metal_graph_compute(
                    case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                    case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+                    case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -1166,6 +1296,8 @@ void ggml_metal_graph_compute(
            } break;
        case GGML_OP_RMS_NORM:
            {
+                GGML_ASSERT(ne00 % 4 == 0);
+
                float eps;
                memcpy(&eps, dst->op_params, sizeof(float));

@@ -1208,7 +1340,7 @@ void ggml_metal_graph_compute(

                const int nth = MIN(1024, ne00);

-                const int n_past = ((int32_t *) dst->op_params)[0];
+                //const int n_past = ((int32_t *) dst->op_params)[0];
                const int n_head = ((int32_t *) dst->op_params)[1];
                float max_bias;
                memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -1371,3 +1503,140 @@ void ggml_metal_graph_compute(

    }
 }
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend interface
+
+static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+    return "Metal";
+
+    UNUSED(backend);
+}
+
+static void ggml_backend_metal_free(ggml_backend_t backend) {
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+    ggml_metal_free(ctx);
+    free(backend);
+}
+
+static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
+    return (void *)buffer->context;
+}
+
+static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    free(buffer->context);
+    UNUSED(buffer);
+}
+
+static struct ggml_backend_buffer_i metal_backend_buffer_i = {
+    /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
+    /* .get_base = */ ggml_backend_metal_buffer_get_base,
+    /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+    /* .init_tensor = */ NULL, // no initialization required
+    /* .free_tensor = */ NULL, // no cleanup required
+};
+
+static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+    void * data = ggml_metal_host_malloc(size);
+
+    // TODO: set proper name of the buffers
+    ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+
+    return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+}
+
+static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
+    return 32;
+    UNUSED(backend);
+}
+
+static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy((char *)tensor->data + offset, data, size);
+
+    UNUSED(backend);
+}
+
+static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    UNUSED(backend);
+}
+
+static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
+    UNUSED(backend);
+}
+
+static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+
+    UNUSED(backend);
+}
+
+static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
+
+    UNUSED(backend);
+}
+
+static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
+
+    ggml_metal_graph_compute(metal_ctx, cgraph);
+}
+
+static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    return true;
+    UNUSED(backend);
+    UNUSED(op);
+}
+
+static struct ggml_backend_i metal_backend_i = {
+    /* .get_name = */ ggml_backend_metal_name,
+    /* .free = */ ggml_backend_metal_free,
+    /* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
+    /* .get_alignment = */ ggml_backend_metal_get_alignment,
+    /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
+    /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
+    /* .synchronize = */ ggml_backend_metal_synchronize,
+    /* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
+    /* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
+    /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
+    /* .graph_plan_free = */ NULL,
+    /* .graph_plan_compute = */ NULL,
+    /* .graph_compute = */ ggml_backend_metal_graph_compute,
+    /* .supports_op = */ ggml_backend_metal_supports_op,
+};
+
+ggml_backend_t ggml_backend_metal_init(void) {
+    struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+
+    ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
+
+    ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
+
+    *metal_backend = (struct ggml_backend) {
+        /* .interface = */ metal_backend_i,
+        /* .context = */ ctx,
+    };
+
+    return metal_backend;
+}
+
+bool ggml_backend_is_metal(ggml_backend_t backend) {
+    return backend->iface.get_name == ggml_backend_metal_name;
+}
+
+void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+    ggml_metal_set_n_cb(ctx, n_cb);
+}