llama_cpp 0.9.4 → 0.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +121 -15
- data/ext/llama_cpp/src/ggml-alloc.c +43 -8
- data/ext/llama_cpp/src/ggml-alloc.h +7 -0
- data/ext/llama_cpp/src/ggml-backend-impl.h +46 -21
- data/ext/llama_cpp/src/ggml-backend.c +563 -156
- data/ext/llama_cpp/src/ggml-backend.h +62 -17
- data/ext/llama_cpp/src/ggml-cuda.cu +1270 -434
- data/ext/llama_cpp/src/ggml-cuda.h +9 -1
- data/ext/llama_cpp/src/ggml-impl.h +1 -1
- data/ext/llama_cpp/src/ggml-metal.h +6 -0
- data/ext/llama_cpp/src/ggml-metal.m +535 -175
- data/ext/llama_cpp/src/ggml-metal.metal +888 -237
- data/ext/llama_cpp/src/ggml-opencl.cpp +5 -7
- data/ext/llama_cpp/src/ggml.c +393 -127
- data/ext/llama_cpp/src/ggml.h +59 -7
- data/ext/llama_cpp/src/llama.cpp +791 -357
- data/ext/llama_cpp/src/llama.h +29 -6
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +20 -2
- metadata +3 -3
@@ -62,6 +62,8 @@ struct ggml_metal_context {
|
|
62
62
|
GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
|
63
63
|
GGML_METAL_DECL_KERNEL(mul);
|
64
64
|
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
65
|
+
GGML_METAL_DECL_KERNEL(div);
|
66
|
+
GGML_METAL_DECL_KERNEL(div_row);
|
65
67
|
GGML_METAL_DECL_KERNEL(scale);
|
66
68
|
GGML_METAL_DECL_KERNEL(scale_4);
|
67
69
|
GGML_METAL_DECL_KERNEL(silu);
|
@@ -112,15 +114,35 @@ struct ggml_metal_context {
|
|
112
114
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
113
115
|
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
114
116
|
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
117
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
|
118
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
|
119
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
|
120
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
|
121
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
|
122
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
|
123
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
|
124
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
|
125
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
|
126
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
|
127
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
|
128
|
+
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
|
115
129
|
GGML_METAL_DECL_KERNEL(rope_f32);
|
116
130
|
GGML_METAL_DECL_KERNEL(rope_f16);
|
117
131
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
118
132
|
GGML_METAL_DECL_KERNEL(im2col_f16);
|
133
|
+
GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
|
134
|
+
GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
|
119
135
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
120
136
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
137
|
+
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
138
|
+
GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
|
139
|
+
GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
|
140
|
+
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
141
|
+
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
121
142
|
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
122
143
|
GGML_METAL_DECL_KERNEL(concat);
|
123
144
|
GGML_METAL_DECL_KERNEL(sqr);
|
145
|
+
GGML_METAL_DECL_KERNEL(sum_rows);
|
124
146
|
|
125
147
|
#undef GGML_METAL_DECL_KERNEL
|
126
148
|
};
|
@@ -164,12 +186,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
|
164
186
|
}
|
165
187
|
}
|
166
188
|
|
167
|
-
|
168
|
-
|
169
189
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
170
190
|
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
171
191
|
|
172
|
-
id <MTLDevice> device;
|
192
|
+
id<MTLDevice> device;
|
173
193
|
NSString * s;
|
174
194
|
|
175
195
|
#if TARGET_OS_OSX
|
@@ -215,6 +235,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
215
235
|
|
216
236
|
NSString * sourcePath;
|
217
237
|
NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
238
|
+
|
239
|
+
GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
|
240
|
+
|
218
241
|
if (ggmlMetalPathResources) {
|
219
242
|
sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
|
220
243
|
} else {
|
@@ -245,6 +268,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
245
268
|
}
|
246
269
|
}
|
247
270
|
|
271
|
+
#if TARGET_OS_OSX
|
272
|
+
// print MTL GPU family:
|
273
|
+
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
|
274
|
+
|
275
|
+
// determine max supported GPU family
|
276
|
+
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
277
|
+
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
278
|
+
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
279
|
+
if ([ctx->device supportsFamily:i]) {
|
280
|
+
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
281
|
+
break;
|
282
|
+
}
|
283
|
+
}
|
284
|
+
|
285
|
+
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
286
|
+
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
287
|
+
if (ctx->device.maxTransferRate != 0) {
|
288
|
+
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
|
289
|
+
} else {
|
290
|
+
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
291
|
+
}
|
292
|
+
#endif
|
293
|
+
|
248
294
|
// load kernels
|
249
295
|
{
|
250
296
|
NSError * error = nil;
|
@@ -266,6 +312,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
266
312
|
GGML_METAL_ADD_KERNEL(add_row);
|
267
313
|
GGML_METAL_ADD_KERNEL(mul);
|
268
314
|
GGML_METAL_ADD_KERNEL(mul_row);
|
315
|
+
GGML_METAL_ADD_KERNEL(div);
|
316
|
+
GGML_METAL_ADD_KERNEL(div_row);
|
269
317
|
GGML_METAL_ADD_KERNEL(scale);
|
270
318
|
GGML_METAL_ADD_KERNEL(scale_4);
|
271
319
|
GGML_METAL_ADD_KERNEL(silu);
|
@@ -317,43 +365,40 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
317
365
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
318
366
|
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
319
367
|
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
368
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
|
369
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
|
370
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
|
371
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
|
372
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
|
373
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
|
374
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
|
375
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
|
376
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
|
377
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
|
378
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
|
379
|
+
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
|
320
380
|
}
|
321
381
|
GGML_METAL_ADD_KERNEL(rope_f32);
|
322
382
|
GGML_METAL_ADD_KERNEL(rope_f16);
|
323
383
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
324
384
|
GGML_METAL_ADD_KERNEL(im2col_f16);
|
385
|
+
GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
|
386
|
+
GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
|
325
387
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
326
388
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
389
|
+
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
390
|
+
GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
|
391
|
+
GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
|
392
|
+
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
|
393
|
+
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
|
327
394
|
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
328
395
|
GGML_METAL_ADD_KERNEL(concat);
|
329
396
|
GGML_METAL_ADD_KERNEL(sqr);
|
397
|
+
GGML_METAL_ADD_KERNEL(sum_rows);
|
330
398
|
|
331
399
|
#undef GGML_METAL_ADD_KERNEL
|
332
400
|
}
|
333
401
|
|
334
|
-
#if TARGET_OS_OSX
|
335
|
-
// print MTL GPU family:
|
336
|
-
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
|
337
|
-
|
338
|
-
// determine max supported GPU family
|
339
|
-
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
340
|
-
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
341
|
-
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
342
|
-
if ([ctx->device supportsFamily:i]) {
|
343
|
-
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
344
|
-
break;
|
345
|
-
}
|
346
|
-
}
|
347
|
-
|
348
|
-
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
349
|
-
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
350
|
-
if (ctx->device.maxTransferRate != 0) {
|
351
|
-
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
352
|
-
} else {
|
353
|
-
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
354
|
-
}
|
355
|
-
#endif
|
356
|
-
|
357
402
|
return ctx;
|
358
403
|
}
|
359
404
|
|
@@ -367,6 +412,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
367
412
|
GGML_METAL_DEL_KERNEL(add_row);
|
368
413
|
GGML_METAL_DEL_KERNEL(mul);
|
369
414
|
GGML_METAL_DEL_KERNEL(mul_row);
|
415
|
+
GGML_METAL_DEL_KERNEL(div);
|
416
|
+
GGML_METAL_DEL_KERNEL(div_row);
|
370
417
|
GGML_METAL_DEL_KERNEL(scale);
|
371
418
|
GGML_METAL_DEL_KERNEL(scale_4);
|
372
419
|
GGML_METAL_DEL_KERNEL(silu);
|
@@ -418,16 +465,36 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
418
465
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
419
466
|
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
420
467
|
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
468
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
|
469
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
|
470
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
|
471
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
|
472
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
|
473
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
|
474
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
|
475
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
|
476
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
|
477
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
|
478
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
|
479
|
+
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
|
421
480
|
}
|
422
481
|
GGML_METAL_DEL_KERNEL(rope_f32);
|
423
482
|
GGML_METAL_DEL_KERNEL(rope_f16);
|
424
483
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
425
484
|
GGML_METAL_DEL_KERNEL(im2col_f16);
|
485
|
+
GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
|
486
|
+
GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
|
426
487
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
427
488
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
489
|
+
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
490
|
+
GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
|
491
|
+
GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
|
492
|
+
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
493
|
+
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
428
494
|
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
429
495
|
GGML_METAL_DEL_KERNEL(concat);
|
430
496
|
GGML_METAL_DEL_KERNEL(sqr);
|
497
|
+
GGML_METAL_DEL_KERNEL(sum_rows);
|
431
498
|
|
432
499
|
#undef GGML_METAL_DEL_KERNEL
|
433
500
|
|
@@ -471,6 +538,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
|
471
538
|
return ctx->concur_list;
|
472
539
|
}
|
473
540
|
|
541
|
+
// temporarily defined here for compatibility between ggml-backend and the old API
|
542
|
+
struct ggml_backend_metal_buffer_context {
|
543
|
+
void * data;
|
544
|
+
|
545
|
+
id<MTLBuffer> metal;
|
546
|
+
};
|
547
|
+
|
474
548
|
// finds the Metal buffer that contains the tensor data on the GPU device
|
475
549
|
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
476
550
|
// Metal buffer based on the host memory pointer
|
@@ -480,8 +554,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
480
554
|
|
481
555
|
const int64_t tsize = ggml_nbytes(t);
|
482
556
|
|
483
|
-
|
484
|
-
|
557
|
+
// compatibility with ggml-backend
|
558
|
+
if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
|
559
|
+
struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
|
560
|
+
|
561
|
+
const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
|
562
|
+
|
563
|
+
GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
|
564
|
+
|
565
|
+
*offs = (size_t) ioffs;
|
566
|
+
|
567
|
+
return buf_ctx->metal;
|
485
568
|
}
|
486
569
|
|
487
570
|
// find the view that contains the tensor fully
|
@@ -706,6 +789,51 @@ void ggml_metal_graph_find_concurrency(
|
|
706
789
|
}
|
707
790
|
}
|
708
791
|
|
792
|
+
static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
793
|
+
switch (op->op) {
|
794
|
+
case GGML_OP_UNARY:
|
795
|
+
switch (ggml_get_unary_op(op)) {
|
796
|
+
case GGML_UNARY_OP_SILU:
|
797
|
+
case GGML_UNARY_OP_RELU:
|
798
|
+
case GGML_UNARY_OP_GELU:
|
799
|
+
return true;
|
800
|
+
default:
|
801
|
+
return false;
|
802
|
+
}
|
803
|
+
case GGML_OP_NONE:
|
804
|
+
case GGML_OP_RESHAPE:
|
805
|
+
case GGML_OP_VIEW:
|
806
|
+
case GGML_OP_TRANSPOSE:
|
807
|
+
case GGML_OP_PERMUTE:
|
808
|
+
case GGML_OP_CONCAT:
|
809
|
+
case GGML_OP_ADD:
|
810
|
+
case GGML_OP_MUL:
|
811
|
+
case GGML_OP_DIV:
|
812
|
+
case GGML_OP_SCALE:
|
813
|
+
case GGML_OP_SQR:
|
814
|
+
case GGML_OP_SUM_ROWS:
|
815
|
+
case GGML_OP_SOFT_MAX:
|
816
|
+
case GGML_OP_RMS_NORM:
|
817
|
+
case GGML_OP_NORM:
|
818
|
+
case GGML_OP_ALIBI:
|
819
|
+
case GGML_OP_ROPE:
|
820
|
+
case GGML_OP_IM2COL:
|
821
|
+
case GGML_OP_ARGSORT:
|
822
|
+
case GGML_OP_DUP:
|
823
|
+
case GGML_OP_CPY:
|
824
|
+
case GGML_OP_CONT:
|
825
|
+
case GGML_OP_MUL_MAT:
|
826
|
+
case GGML_OP_MUL_MAT_ID:
|
827
|
+
return true;
|
828
|
+
case GGML_OP_DIAG_MASK_INF:
|
829
|
+
case GGML_OP_GET_ROWS:
|
830
|
+
{
|
831
|
+
return op->ne[0] % 4 == 0;
|
832
|
+
}
|
833
|
+
default:
|
834
|
+
return false;
|
835
|
+
}
|
836
|
+
}
|
709
837
|
void ggml_metal_graph_compute(
|
710
838
|
struct ggml_metal_context * ctx,
|
711
839
|
struct ggml_cgraph * gf) {
|
@@ -776,6 +904,8 @@ void ggml_metal_graph_compute(
|
|
776
904
|
} break;
|
777
905
|
}
|
778
906
|
|
907
|
+
GGML_ASSERT(ggml_metal_supports_op(dst));
|
908
|
+
|
779
909
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
780
910
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
781
911
|
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
@@ -868,6 +998,8 @@ void ggml_metal_graph_compute(
|
|
868
998
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
869
999
|
} break;
|
870
1000
|
case GGML_OP_ADD:
|
1001
|
+
case GGML_OP_MUL:
|
1002
|
+
case GGML_OP_DIV:
|
871
1003
|
{
|
872
1004
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
873
1005
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
@@ -881,11 +1013,21 @@ void ggml_metal_graph_compute(
|
|
881
1013
|
GGML_ASSERT(ne11 == 1);
|
882
1014
|
|
883
1015
|
nb = ne00 / 4;
|
884
|
-
|
1016
|
+
switch (dst->op) {
|
1017
|
+
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
|
1018
|
+
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
|
1019
|
+
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
|
1020
|
+
default: GGML_ASSERT(false);
|
1021
|
+
}
|
885
1022
|
|
886
1023
|
bcast_row = true;
|
887
1024
|
} else {
|
888
|
-
|
1025
|
+
switch (dst->op) {
|
1026
|
+
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
|
1027
|
+
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
|
1028
|
+
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
|
1029
|
+
default: GGML_ASSERT(false);
|
1030
|
+
}
|
889
1031
|
}
|
890
1032
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
891
1033
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
@@ -926,31 +1068,6 @@ void ggml_metal_graph_compute(
|
|
926
1068
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
927
1069
|
}
|
928
1070
|
} break;
|
929
|
-
case GGML_OP_MUL:
|
930
|
-
{
|
931
|
-
GGML_ASSERT(ggml_is_contiguous(src0));
|
932
|
-
GGML_ASSERT(ggml_is_contiguous(src1));
|
933
|
-
|
934
|
-
// utilize float4
|
935
|
-
GGML_ASSERT(ne00 % 4 == 0);
|
936
|
-
const int64_t nb = ne00/4;
|
937
|
-
|
938
|
-
if (ggml_nelements(src1) == ne10) {
|
939
|
-
// src1 is a row
|
940
|
-
GGML_ASSERT(ne11 == 1);
|
941
|
-
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
942
|
-
} else {
|
943
|
-
[encoder setComputePipelineState:ctx->pipeline_mul];
|
944
|
-
}
|
945
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
946
|
-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
947
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
948
|
-
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
949
|
-
|
950
|
-
const int64_t n = ggml_nelements(dst)/4;
|
951
|
-
|
952
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
953
|
-
} break;
|
954
1071
|
case GGML_OP_SCALE:
|
955
1072
|
{
|
956
1073
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
@@ -1023,25 +1140,66 @@ void ggml_metal_graph_compute(
|
|
1023
1140
|
const int64_t n = ggml_nelements(dst);
|
1024
1141
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1025
1142
|
} break;
|
1143
|
+
case GGML_OP_SUM_ROWS:
|
1144
|
+
{
|
1145
|
+
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
1146
|
+
|
1147
|
+
[encoder setComputePipelineState:ctx->pipeline_sum_rows];
|
1148
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1149
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1150
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
1151
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1152
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1153
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
1154
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1155
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1156
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1157
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
1158
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
1159
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
1160
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1161
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1162
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1163
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1164
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1165
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
1166
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
1167
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
1168
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
1169
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
1170
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
1171
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
1172
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
1173
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
1174
|
+
|
1175
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1176
|
+
} break;
|
1026
1177
|
case GGML_OP_SOFT_MAX:
|
1027
1178
|
{
|
1028
1179
|
int nth = 32; // SIMD width
|
1029
1180
|
|
1030
1181
|
if (ne00%4 == 0) {
|
1182
|
+
while (nth < ne00/4 && nth < 256) {
|
1183
|
+
nth *= 2;
|
1184
|
+
}
|
1031
1185
|
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
|
1032
1186
|
} else {
|
1033
|
-
|
1187
|
+
while (nth < ne00 && nth < 1024) {
|
1034
1188
|
nth *= 2;
|
1035
|
-
}
|
1036
|
-
nth /= 2;
|
1189
|
+
}
|
1037
1190
|
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
1038
1191
|
}
|
1039
|
-
|
1040
|
-
[
|
1041
|
-
|
1042
|
-
[encoder
|
1043
|
-
[encoder
|
1044
|
-
[encoder
|
1192
|
+
|
1193
|
+
const float scale = ((float *) dst->op_params)[0];
|
1194
|
+
|
1195
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1196
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1197
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1198
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1199
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1200
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1201
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
1202
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1045
1203
|
|
1046
1204
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1047
1205
|
} break;
|
@@ -1070,9 +1228,13 @@ void ggml_metal_graph_compute(
|
|
1070
1228
|
case GGML_OP_MUL_MAT:
|
1071
1229
|
{
|
1072
1230
|
GGML_ASSERT(ne00 == ne10);
|
1073
|
-
GGML_ASSERT(ne03 == ne13);
|
1074
1231
|
|
1075
|
-
|
1232
|
+
// TODO: assert that dim2 and dim3 are contiguous
|
1233
|
+
GGML_ASSERT(ne12 % ne02 == 0);
|
1234
|
+
GGML_ASSERT(ne13 % ne03 == 0);
|
1235
|
+
|
1236
|
+
const uint r2 = ne12/ne02;
|
1237
|
+
const uint r3 = ne13/ne03;
|
1076
1238
|
|
1077
1239
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1078
1240
|
// to the matrix-vector kernel
|
@@ -1107,7 +1269,7 @@ void ggml_metal_graph_compute(
|
|
1107
1269
|
!ggml_is_transposed(src1) &&
|
1108
1270
|
src1t == GGML_TYPE_F32 &&
|
1109
1271
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
1110
|
-
ne11 > ne11_mm_min) {
|
1272
|
+
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
|
1111
1273
|
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
1112
1274
|
switch (src0->type) {
|
1113
1275
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
@@ -1137,9 +1299,10 @@ void ggml_metal_graph_compute(
|
|
1137
1299
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
1138
1300
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
1139
1301
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
1140
|
-
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
|
1302
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
1303
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
1141
1304
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
1142
|
-
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1305
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1143
1306
|
} else {
|
1144
1307
|
int nth0 = 32;
|
1145
1308
|
int nth1 = 1;
|
@@ -1175,90 +1338,60 @@ void ggml_metal_graph_compute(
|
|
1175
1338
|
} break;
|
1176
1339
|
case GGML_TYPE_Q4_0:
|
1177
1340
|
{
|
1178
|
-
GGML_ASSERT(ne02 == 1);
|
1179
|
-
GGML_ASSERT(ne12 == 1);
|
1180
|
-
|
1181
1341
|
nth0 = 8;
|
1182
1342
|
nth1 = 8;
|
1183
1343
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
|
1184
1344
|
} break;
|
1185
1345
|
case GGML_TYPE_Q4_1:
|
1186
1346
|
{
|
1187
|
-
GGML_ASSERT(ne02 == 1);
|
1188
|
-
GGML_ASSERT(ne12 == 1);
|
1189
|
-
|
1190
1347
|
nth0 = 8;
|
1191
1348
|
nth1 = 8;
|
1192
1349
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
|
1193
1350
|
} break;
|
1194
1351
|
case GGML_TYPE_Q5_0:
|
1195
1352
|
{
|
1196
|
-
GGML_ASSERT(ne02 == 1);
|
1197
|
-
GGML_ASSERT(ne12 == 1);
|
1198
|
-
|
1199
1353
|
nth0 = 8;
|
1200
1354
|
nth1 = 8;
|
1201
1355
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
|
1202
1356
|
} break;
|
1203
1357
|
case GGML_TYPE_Q5_1:
|
1204
1358
|
{
|
1205
|
-
GGML_ASSERT(ne02 == 1);
|
1206
|
-
GGML_ASSERT(ne12 == 1);
|
1207
|
-
|
1208
1359
|
nth0 = 8;
|
1209
1360
|
nth1 = 8;
|
1210
1361
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
|
1211
1362
|
} break;
|
1212
1363
|
case GGML_TYPE_Q8_0:
|
1213
1364
|
{
|
1214
|
-
GGML_ASSERT(ne02 == 1);
|
1215
|
-
GGML_ASSERT(ne12 == 1);
|
1216
|
-
|
1217
1365
|
nth0 = 8;
|
1218
1366
|
nth1 = 8;
|
1219
1367
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
|
1220
1368
|
} break;
|
1221
1369
|
case GGML_TYPE_Q2_K:
|
1222
1370
|
{
|
1223
|
-
GGML_ASSERT(ne02 == 1);
|
1224
|
-
GGML_ASSERT(ne12 == 1);
|
1225
|
-
|
1226
1371
|
nth0 = 2;
|
1227
1372
|
nth1 = 32;
|
1228
1373
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
|
1229
1374
|
} break;
|
1230
1375
|
case GGML_TYPE_Q3_K:
|
1231
1376
|
{
|
1232
|
-
GGML_ASSERT(ne02 == 1);
|
1233
|
-
GGML_ASSERT(ne12 == 1);
|
1234
|
-
|
1235
1377
|
nth0 = 2;
|
1236
1378
|
nth1 = 32;
|
1237
1379
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
|
1238
1380
|
} break;
|
1239
1381
|
case GGML_TYPE_Q4_K:
|
1240
1382
|
{
|
1241
|
-
GGML_ASSERT(ne02 == 1);
|
1242
|
-
GGML_ASSERT(ne12 == 1);
|
1243
|
-
|
1244
1383
|
nth0 = 4; //1;
|
1245
1384
|
nth1 = 8; //32;
|
1246
1385
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
|
1247
1386
|
} break;
|
1248
1387
|
case GGML_TYPE_Q5_K:
|
1249
1388
|
{
|
1250
|
-
GGML_ASSERT(ne02 == 1);
|
1251
|
-
GGML_ASSERT(ne12 == 1);
|
1252
|
-
|
1253
1389
|
nth0 = 2;
|
1254
1390
|
nth1 = 32;
|
1255
1391
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
|
1256
1392
|
} break;
|
1257
1393
|
case GGML_TYPE_Q6_K:
|
1258
1394
|
{
|
1259
|
-
GGML_ASSERT(ne02 == 1);
|
1260
|
-
GGML_ASSERT(ne12 == 1);
|
1261
|
-
|
1262
1395
|
nth0 = 2;
|
1263
1396
|
nth1 = 32;
|
1264
1397
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
|
@@ -1287,32 +1420,125 @@ void ggml_metal_graph_compute(
|
|
1287
1420
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
1288
1421
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
1289
1422
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
1290
|
-
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
1423
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
1424
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
1291
1425
|
|
1292
1426
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
1293
1427
|
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
1294
1428
|
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
1295
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1429
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1296
1430
|
}
|
1297
1431
|
else if (src0t == GGML_TYPE_Q4_K) {
|
1298
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1432
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1299
1433
|
}
|
1300
1434
|
else if (src0t == GGML_TYPE_Q3_K) {
|
1301
1435
|
#ifdef GGML_QKK_64
|
1302
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1436
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1303
1437
|
#else
|
1304
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1438
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1305
1439
|
#endif
|
1306
1440
|
}
|
1307
1441
|
else if (src0t == GGML_TYPE_Q5_K) {
|
1308
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1442
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1309
1443
|
}
|
1310
1444
|
else if (src0t == GGML_TYPE_Q6_K) {
|
1311
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1445
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1312
1446
|
} else {
|
1313
1447
|
int64_t ny = (ne11 + nrows - 1)/nrows;
|
1314
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1448
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1449
|
+
}
|
1450
|
+
}
|
1451
|
+
} break;
|
1452
|
+
case GGML_OP_MUL_MAT_ID:
|
1453
|
+
{
|
1454
|
+
//GGML_ASSERT(ne00 == ne10);
|
1455
|
+
//GGML_ASSERT(ne03 == ne13);
|
1456
|
+
|
1457
|
+
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
1458
|
+
|
1459
|
+
const int n_as = ne00;
|
1460
|
+
|
1461
|
+
// TODO: make this more general
|
1462
|
+
GGML_ASSERT(n_as <= 8);
|
1463
|
+
|
1464
|
+
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
1465
|
+
|
1466
|
+
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
1467
|
+
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
1468
|
+
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
1469
|
+
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
1470
|
+
|
1471
|
+
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
1472
|
+
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
1473
|
+
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
1474
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
1475
|
+
|
1476
|
+
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
1477
|
+
|
1478
|
+
GGML_ASSERT(!ggml_is_transposed(src2));
|
1479
|
+
GGML_ASSERT(!ggml_is_transposed(src1));
|
1480
|
+
|
1481
|
+
GGML_ASSERT(ne20 % 32 == 0);
|
1482
|
+
// !!!!!!!!! TODO: this assert is probably required but not sure!
|
1483
|
+
//GGML_ASSERT(ne20 >= 64);
|
1484
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1485
|
+
|
1486
|
+
const uint r2 = ne12/ne22;
|
1487
|
+
const uint r3 = ne13/ne23;
|
1488
|
+
|
1489
|
+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1490
|
+
// to the matrix-vector kernel
|
1491
|
+
int ne11_mm_min = 0;
|
1492
|
+
|
1493
|
+
const int idx = ((int32_t *) dst->op_params)[0];
|
1494
|
+
|
1495
|
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1496
|
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
1497
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
1498
|
+
ne11 > ne11_mm_min) {
|
1499
|
+
switch (src2->type) {
|
1500
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
|
1501
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
|
1502
|
+
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
|
1503
|
+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
|
1504
|
+
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
|
1505
|
+
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
|
1506
|
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
|
1507
|
+
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
|
1508
|
+
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
|
1509
|
+
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
|
1510
|
+
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
|
1511
|
+
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
1512
|
+
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
1315
1513
|
}
|
1514
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1515
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1516
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1517
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
|
1518
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
|
1519
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
|
1520
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
|
1521
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
1522
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
1523
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
1524
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
1525
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
1526
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
1527
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
1528
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
1529
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:15];
|
1530
|
+
// TODO: how to make this an array? read Metal docs
|
1531
|
+
for (int j = 0; j < n_as; ++j) {
|
1532
|
+
struct ggml_tensor * src_cur = dst->src[2 + j];
|
1533
|
+
|
1534
|
+
size_t offs_src_cur = 0;
|
1535
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
1536
|
+
|
1537
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
|
1538
|
+
}
|
1539
|
+
|
1540
|
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
1541
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1316
1542
|
}
|
1317
1543
|
} break;
|
1318
1544
|
case GGML_OP_GET_ROWS:
|
@@ -1351,15 +1577,19 @@ void ggml_metal_graph_compute(
|
|
1351
1577
|
float eps;
|
1352
1578
|
memcpy(&eps, dst->op_params, sizeof(float));
|
1353
1579
|
|
1354
|
-
|
1580
|
+
int nth = 32; // SIMD width
|
1581
|
+
|
1582
|
+
while (nth < ne00/4 && nth < 1024) {
|
1583
|
+
nth *= 2;
|
1584
|
+
}
|
1355
1585
|
|
1356
1586
|
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
|
1357
|
-
[encoder setBuffer:id_src0 offset:offs_src0
|
1358
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
1359
|
-
[encoder setBytes:&ne00
|
1360
|
-
[encoder setBytes:&nb01
|
1361
|
-
[encoder setBytes:&eps
|
1362
|
-
[encoder setThreadgroupMemoryLength:
|
1587
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1588
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1589
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1590
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
1591
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
|
1592
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1363
1593
|
|
1364
1594
|
const int64_t nrows = ggml_nrows(src0);
|
1365
1595
|
|
@@ -1433,7 +1663,8 @@ void ggml_metal_graph_compute(
|
|
1433
1663
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
1434
1664
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
1435
1665
|
const int mode = ((int32_t *) dst->op_params)[2];
|
1436
|
-
|
1666
|
+
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
1667
|
+
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
1437
1668
|
|
1438
1669
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
1439
1670
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
@@ -1533,18 +1764,48 @@ void ggml_metal_graph_compute(
|
|
1533
1764
|
|
1534
1765
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
1535
1766
|
} break;
|
1767
|
+
case GGML_OP_ARGSORT:
|
1768
|
+
{
|
1769
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
1770
|
+
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
1771
|
+
|
1772
|
+
const int nrows = ggml_nrows(src0);
|
1773
|
+
|
1774
|
+
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
1775
|
+
|
1776
|
+
switch (order) {
|
1777
|
+
case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
|
1778
|
+
case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
|
1779
|
+
default: GGML_ASSERT(false);
|
1780
|
+
};
|
1781
|
+
|
1782
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1783
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1784
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1785
|
+
|
1786
|
+
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
1787
|
+
} break;
|
1536
1788
|
case GGML_OP_DUP:
|
1537
1789
|
case GGML_OP_CPY:
|
1538
1790
|
case GGML_OP_CONT:
|
1539
1791
|
{
|
1540
|
-
|
1792
|
+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
1793
|
+
|
1794
|
+
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
1541
1795
|
|
1542
1796
|
switch (src0t) {
|
1543
1797
|
case GGML_TYPE_F32:
|
1544
1798
|
{
|
1799
|
+
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
1800
|
+
|
1545
1801
|
switch (dstt) {
|
1546
|
-
case GGML_TYPE_F16:
|
1547
|
-
case GGML_TYPE_F32:
|
1802
|
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
|
1803
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
|
1804
|
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
|
1805
|
+
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
|
1806
|
+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
|
1807
|
+
//case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
|
1808
|
+
//case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
|
1548
1809
|
default: GGML_ASSERT(false && "not implemented");
|
1549
1810
|
};
|
1550
1811
|
} break;
|
@@ -1619,81 +1880,150 @@ void ggml_metal_graph_compute(
|
|
1619
1880
|
|
1620
1881
|
// backend interface
|
1621
1882
|
|
1622
|
-
static
|
1623
|
-
|
1883
|
+
// Metal device shared by every Metal backend buffer, together with the number
// of references currently handed out by ggml_backend_metal_get_device().
static id<MTLDevice> g_backend_device = nil;
static int g_backend_device_ref_count = 0;

// Returns the shared Metal device and takes one reference on it. The device is
// created lazily on the first call. Each call is meant to be balanced by a call
// to ggml_backend_metal_free_device().
// NOTE(review): the counter is not synchronized; assumes single-threaded use — confirm.
static id<MTLDevice> ggml_backend_metal_get_device(void) {
    if (g_backend_device != nil) {
        // already created: just hand out another reference
        g_backend_device_ref_count += 1;
        return g_backend_device;
    }

    g_backend_device = MTLCreateSystemDefaultDevice();
    g_backend_device_ref_count += 1;

    return g_backend_device;
}
|
1627
1895
|
|
1628
|
-
static void
|
1629
|
-
|
1630
|
-
|
1631
|
-
|
1896
|
+
// Drops one reference previously taken by ggml_backend_metal_get_device().
// When the last reference goes away the shared device is released and the
// cached handle is cleared, so a later get_device() creates a fresh one.
static void ggml_backend_metal_free_device(void) {
    // GGML_ASSERT (unlike assert) is not compiled out under NDEBUG. An
    // unbalanced call therefore aborts instead of silently making the counter
    // negative, which would keep the device from ever being released.
    GGML_ASSERT(g_backend_device_ref_count > 0);

    g_backend_device_ref_count--;

    if (g_backend_device_ref_count == 0) {
        [g_backend_device release];
        g_backend_device = nil;
    }
}
|
1633
1906
|
|
1634
1907
|
// Returns the host pointer stored in the buffer's context; this is the memory
// that the buffer's MTLBuffer wraps without copying.
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
    return ((struct ggml_backend_metal_buffer_context *) buffer->context)->data;
}
|
1637
1912
|
|
1638
1913
|
// Tears down a buffer produced by ggml_backend_metal_buffer_type_alloc_buffer.
// It releases, in order: the MTLBuffer wrapper, the device reference taken at
// allocation time, the host memory and the context struct.
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;

    // The MTLBuffer was created over buf_ctx->data with a nil deallocator, so
    // it is released first and never outlives the host memory it wraps.
    [buf_ctx->metal release];
    ggml_backend_metal_free_device();

    free(buf_ctx->data);
    free(buf_ctx);
}
|
1924
|
+
|
1925
|
+
// Writes `size` bytes from host memory `data` into `tensor`, starting at byte
// `offset` of the tensor's storage. The tensor data pointer is written with a
// plain memcpy, i.e. it is assumed to be host-addressable.
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

    char * dst_bytes = (char *) tensor->data;
    memcpy(dst_bytes + offset, data, size);

    UNUSED(buffer);
}
|
1933
|
+
|
1934
|
+
// Reads `size` bytes of `tensor`, starting at byte `offset` of its storage,
// into host memory `data`. This is the plain-memcpy counterpart of set_tensor.
static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");

    const char * src_bytes = (const char *) tensor->data;
    memcpy(data, src_bytes + offset, size);

    UNUSED(buffer);
}
|
1942
|
+
|
1943
|
+
// Copies all ggml_nbytes(src) bytes of `src` into dst->data. The read goes
// through ggml_backend_tensor_get, starting at offset 0 of `src`.
static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
    const size_t n_bytes = ggml_nbytes(src);
    ggml_backend_tensor_get(src, dst->data, 0, n_bytes);

    UNUSED(buffer);
}
|
1948
|
+
|
1949
|
+
// Copies all ggml_nbytes(src) bytes from src->data into `dst`. The write goes
// through ggml_backend_tensor_set, starting at offset 0 of `dst`.
static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
    const size_t n_bytes = ggml_nbytes(src);
    ggml_backend_tensor_set(dst, src->data, 0, n_bytes);

    UNUSED(buffer);
}
|
1642
1954
|
|
1643
1955
|
// Function table for buffers allocated by the Metal buffer type.
static struct ggml_backend_buffer_i metal_backend_buffer_i = {
    .free_buffer     = ggml_backend_metal_buffer_free_buffer,
    .get_base        = ggml_backend_metal_buffer_get_base,
    .init_tensor     = NULL, // no per-tensor setup
    .set_tensor      = ggml_backend_metal_buffer_set_tensor,
    .get_tensor      = ggml_backend_metal_buffer_get_tensor,
    .cpy_tensor_from = ggml_backend_metal_buffer_cpy_tensor_from,
    .cpy_tensor_to   = ggml_backend_metal_buffer_cpy_tensor_to,
};
|
1650
1964
|
|
1651
|
-
static ggml_backend_buffer_t
|
1652
|
-
struct
|
1965
|
+
// Allocates a Metal backend buffer that can hold `size` bytes.
//
// The storage is host memory from ggml_metal_host_malloc(). It is wrapped,
// without copying, in an MTLBuffer created with MTLResourceStorageModeShared.
// The wrapped length is `size` rounded up to a whole number of pages.
//
// Returns NULL if the context, the host memory or (for a non-empty request)
// the MTLBuffer cannot be created. In that case everything acquired so far,
// including the device reference, is released.
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
    if (ctx == NULL) {
        fprintf(stderr, "%s: error: failed to allocate buffer context\n", __func__);
        return NULL;
    }

    const size_t size_page = sysconf(_SC_PAGESIZE);

    size_t size_aligned = size;
    if ((size_aligned % size_page) != 0) {
        size_aligned += (size_page - (size_aligned % size_page));
    }

    // Allocate the full rounded-up length. The MTLBuffer below wraps
    // size_aligned bytes of this memory, so allocating only `size` bytes would
    // let the GPU buffer extend past the end of the host allocation.
    ctx->data = ggml_metal_host_malloc(size_aligned);
    if (size_aligned > 0 && ctx->data == NULL) {
        fprintf(stderr, "%s: error: failed to allocate %zu bytes of host memory\n", __func__, size_aligned);
        free(ctx);
        return NULL;
    }

    // This takes a device reference, which ggml_backend_metal_buffer_free_buffer drops.
    id<MTLDevice> device = ggml_backend_metal_get_device();

    ctx->metal = [device newBufferWithBytesNoCopy:ctx->data
                                           length:size_aligned
                                          options:MTLResourceStorageModeShared
                                      deallocator:nil];
    if (size_aligned > 0 && ctx->metal == nil) {
        fprintf(stderr, "%s: error: failed to wrap %zu bytes in an MTLBuffer\n", __func__, size_aligned);
        ggml_backend_metal_free_device();
        free(ctx->data);
        free(ctx);
        return NULL;
    }

    return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
}
|
1661
1983
|
|
1662
|
-
static size_t
|
1984
|
+
// Reports a fixed 32-byte alignment for the Metal buffer type.
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    UNUSED(buft);

    const size_t alignment = 32;
    return alignment;
}
|
1666
1988
|
|
1667
|
-
static
|
1668
|
-
|
1669
|
-
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
1670
|
-
|
1671
|
-
memcpy((char *)tensor->data + offset, data, size);
|
1989
|
+
// Reports whether buffers of this type can be used by `backend`. The answer is
// true for the Metal backend and for the CPU backend, since these buffers are
// host memory wrapped with shared storage mode. All other backends get false.
static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
    GGML_UNUSED(buft);

    if (ggml_backend_is_metal(backend)) {
        return true;
    }

    return ggml_backend_is_cpu(backend);
}
|
1675
1994
|
|
1676
|
-
|
1677
|
-
|
1678
|
-
|
1679
|
-
|
1680
|
-
|
1995
|
+
// Returns a pointer to the single, statically allocated Metal buffer type.
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
    static struct ggml_backend_buffer_type metal_buffer_type = {
        .iface = {
            .alloc_buffer     = ggml_backend_metal_buffer_type_alloc_buffer,
            .get_alignment    = ggml_backend_metal_buffer_type_get_alignment,
            .get_alloc_size   = NULL, // defaults to ggml_nbytes
            .supports_backend = ggml_backend_metal_buffer_type_supports_backend,
        },
        .context = NULL,
    };

    return &metal_buffer_type;
}
|
1684
2008
|
|
1685
|
-
static
|
2009
|
+
// Returns the constant name of this backend.
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
    UNUSED(backend);

    return "Metal";
}
|
1688
2014
|
|
1689
|
-
static void
|
1690
|
-
|
2015
|
+
// Destroys a Metal backend. The Metal context stored in backend->context is
// freed first, then the heap-allocated backend struct itself.
static void ggml_backend_metal_free(ggml_backend_t backend) {
    ggml_metal_free((struct ggml_metal_context *) backend->context);
    free(backend);
}
|
1691
2020
|
|
2021
|
+
// Intentionally does nothing.
// NOTE(review): presumably ggml_backend_metal_graph_compute already waits for
// the GPU work it submits, leaving nothing pending here — confirm.
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
    UNUSED(backend);
}
|
1694
2024
|
|
1695
|
-
static
|
1696
|
-
|
2025
|
+
// A Metal backend always allocates from the shared Metal buffer type.
static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
    UNUSED(backend);

    return ggml_backend_metal_buffer_type();
}
|
@@ -1705,32 +2035,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
|
|
1705
2035
|
}
|
1706
2036
|
|
1707
2037
|
// Delegates the op-support query to ggml_metal_supports_op; the backend
// instance itself does not affect the answer.
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
    UNUSED(backend);

    return ggml_metal_supports_op(op);
}
|
1712
2042
|
|
1713
2043
|
// Function table for the Metal backend. The async transfer hooks are left
// unset, and so are the graph-plan hooks: the Metal implementation does not
// require creating graph plans atm.
static struct ggml_backend_i metal_backend_i = {
    .get_name                = ggml_backend_metal_name,
    .free                    = ggml_backend_metal_free,
    .get_default_buffer_type = ggml_backend_metal_get_default_buffer_type,
    .set_tensor_async        = NULL,
    .get_tensor_async        = NULL,
    .cpy_tensor_from_async   = NULL,
    .cpy_tensor_to_async     = NULL,
    .synchronize             = ggml_backend_metal_synchronize,
    .graph_plan_create       = NULL,
    .graph_plan_free         = NULL,
    .graph_plan_compute      = NULL,
    .graph_compute           = ggml_backend_metal_graph_compute,
    .supports_op             = ggml_backend_metal_supports_op,
};
|
1729
2058
|
|
2059
|
+
// TODO: make a common log callback for all backends in ggml-backend
|
2060
|
+
// TODO: make a common log callback for all backends in ggml-backend
// Log sink installed by ggml_backend_metal_init: writes every message verbatim
// to stderr, whatever its level. user_data is ignored.
static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
    UNUSED(level);
    UNUSED(user_data);

    fputs(msg, stderr);
}
|
2066
|
+
|
1730
2067
|
ggml_backend_t ggml_backend_metal_init(void) {
|
1731
|
-
|
2068
|
+
ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
|
2069
|
+
|
2070
|
+
struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
1732
2071
|
|
1733
|
-
ctx
|
2072
|
+
if (ctx == NULL) {
|
2073
|
+
return NULL;
|
2074
|
+
}
|
1734
2075
|
|
1735
2076
|
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
1736
2077
|
|
@@ -1747,7 +2088,26 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
|
|
1747
2088
|
}
|
1748
2089
|
|
1749
2090
|
// Forwards `n_cb` to ggml_metal_set_n_cb for this backend's Metal context.
// `backend` must be a Metal backend.
// NOTE(review): n_cb presumably controls the number of command buffers — confirm.
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    ggml_metal_set_n_cb((struct ggml_metal_context *) backend->context, n_cb);
}
|
2097
|
+
|
2098
|
+
// Reports whether the backend's Metal device supports Apple GPU family
// `family`, numbered from 1 (1 => MTLGPUFamilyApple1, 2 => MTLGPUFamilyApple2, ...).
// `backend` must be a Metal backend.
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    // A non-positive number would map below MTLGPUFamilyApple1, which is not
    // an Apple GPU family. Answer "no" instead of passing an invalid
    // MTLGPUFamily value to Metal.
    if (family < 1) {
        return false;
    }

    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;

    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
}
|
2105
|
+
|
2106
|
+
// Prototype declared here to silence the missing-prototype warning.
ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning

// Init hook with a (params, user_data) signature. Both arguments are ignored,
// and the result is whatever ggml_backend_metal_init() returns.
ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
    GGML_UNUSED(params);
    GGML_UNUSED(user_data);

    return ggml_backend_metal_init();
}
|