llama_cpp 0.10.0 → 0.10.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/ext/llama_cpp/llama_cpp.cpp +18 -1
- data/ext/llama_cpp/src/ggml-alloc.c +12 -4
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-backend-impl.h +12 -8
- data/ext/llama_cpp/src/ggml-backend.c +75 -5
- data/ext/llama_cpp/src/ggml-backend.h +7 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +952 -232
- data/ext/llama_cpp/src/ggml-metal.h +3 -0
- data/ext/llama_cpp/src/ggml-metal.m +725 -98
- data/ext/llama_cpp/src/ggml-metal.metal +1508 -171
- data/ext/llama_cpp/src/ggml-quants.c +2 -2
- data/ext/llama_cpp/src/ggml.c +554 -215
- data/ext/llama_cpp/src/ggml.h +58 -23
- data/ext/llama_cpp/src/llama.cpp +1157 -851
- data/ext/llama_cpp/src/llama.h +9 -4
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- metadata +2 -2
@@ -66,9 +66,11 @@ struct ggml_metal_context {
|
|
66
66
|
GGML_METAL_DECL_KERNEL(div_row);
|
67
67
|
GGML_METAL_DECL_KERNEL(scale);
|
68
68
|
GGML_METAL_DECL_KERNEL(scale_4);
|
69
|
-
GGML_METAL_DECL_KERNEL(
|
69
|
+
GGML_METAL_DECL_KERNEL(tanh);
|
70
70
|
GGML_METAL_DECL_KERNEL(relu);
|
71
71
|
GGML_METAL_DECL_KERNEL(gelu);
|
72
|
+
GGML_METAL_DECL_KERNEL(gelu_quick);
|
73
|
+
GGML_METAL_DECL_KERNEL(silu);
|
72
74
|
GGML_METAL_DECL_KERNEL(soft_max);
|
73
75
|
GGML_METAL_DECL_KERNEL(soft_max_4);
|
74
76
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
@@ -86,6 +88,7 @@ struct ggml_metal_context {
|
|
86
88
|
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
87
89
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
88
90
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
91
|
+
GGML_METAL_DECL_KERNEL(group_norm);
|
89
92
|
GGML_METAL_DECL_KERNEL(norm);
|
90
93
|
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
91
94
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
@@ -102,6 +105,21 @@ struct ggml_metal_context {
|
|
102
105
|
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
|
103
106
|
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
104
107
|
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
108
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
|
109
|
+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
|
110
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
|
111
|
+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
|
112
|
+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
|
113
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
|
114
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
|
115
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
|
116
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
|
117
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
|
118
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
|
119
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
|
120
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
|
121
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
|
122
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
|
105
123
|
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
106
124
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
107
125
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
@@ -130,8 +148,11 @@ struct ggml_metal_context {
|
|
130
148
|
GGML_METAL_DECL_KERNEL(rope_f16);
|
131
149
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
132
150
|
GGML_METAL_DECL_KERNEL(im2col_f16);
|
151
|
+
GGML_METAL_DECL_KERNEL(upscale_f32);
|
152
|
+
GGML_METAL_DECL_KERNEL(pad_f32);
|
133
153
|
GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
|
134
154
|
GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
|
155
|
+
GGML_METAL_DECL_KERNEL(leaky_relu_f32);
|
135
156
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
136
157
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
137
158
|
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
@@ -140,6 +161,7 @@ struct ggml_metal_context {
|
|
140
161
|
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
141
162
|
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
142
163
|
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
164
|
+
GGML_METAL_DECL_KERNEL(cpy_f16_f32);
|
143
165
|
GGML_METAL_DECL_KERNEL(concat);
|
144
166
|
GGML_METAL_DECL_KERNEL(sqr);
|
145
167
|
GGML_METAL_DECL_KERNEL(sum_rows);
|
@@ -158,7 +180,15 @@ struct ggml_metal_context {
|
|
158
180
|
@implementation GGMLMetalClass
|
159
181
|
@end
|
160
182
|
|
161
|
-
|
183
|
+
|
184
|
+
static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
|
185
|
+
fprintf(stderr, "%s", msg);
|
186
|
+
|
187
|
+
UNUSED(level);
|
188
|
+
UNUSED(user_data);
|
189
|
+
}
|
190
|
+
|
191
|
+
ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback;
|
162
192
|
void * ggml_metal_log_user_data = NULL;
|
163
193
|
|
164
194
|
void ggml_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {
|
@@ -177,6 +207,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
|
177
207
|
ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
|
178
208
|
} else {
|
179
209
|
char* buffer2 = malloc(len+1);
|
210
|
+
va_end(args);
|
211
|
+
va_start(args, format);
|
180
212
|
vsnprintf(buffer2, len+1, format, args);
|
181
213
|
buffer2[len] = 0;
|
182
214
|
ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
|
@@ -316,9 +348,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
316
348
|
GGML_METAL_ADD_KERNEL(div_row);
|
317
349
|
GGML_METAL_ADD_KERNEL(scale);
|
318
350
|
GGML_METAL_ADD_KERNEL(scale_4);
|
319
|
-
GGML_METAL_ADD_KERNEL(
|
351
|
+
GGML_METAL_ADD_KERNEL(tanh);
|
320
352
|
GGML_METAL_ADD_KERNEL(relu);
|
321
353
|
GGML_METAL_ADD_KERNEL(gelu);
|
354
|
+
GGML_METAL_ADD_KERNEL(gelu_quick);
|
355
|
+
GGML_METAL_ADD_KERNEL(silu);
|
322
356
|
GGML_METAL_ADD_KERNEL(soft_max);
|
323
357
|
GGML_METAL_ADD_KERNEL(soft_max_4);
|
324
358
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
@@ -336,6 +370,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
336
370
|
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
337
371
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
338
372
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
373
|
+
GGML_METAL_ADD_KERNEL(group_norm);
|
339
374
|
GGML_METAL_ADD_KERNEL(norm);
|
340
375
|
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
341
376
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
|
@@ -352,6 +387,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
352
387
|
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
|
353
388
|
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
|
354
389
|
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
|
390
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
|
391
|
+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
|
392
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
|
393
|
+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
|
394
|
+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
|
395
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
|
396
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
|
397
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
|
398
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
|
399
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
|
400
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
|
401
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
|
402
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
|
403
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
|
404
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
|
355
405
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
356
406
|
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
357
407
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
@@ -382,8 +432,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
382
432
|
GGML_METAL_ADD_KERNEL(rope_f16);
|
383
433
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
384
434
|
GGML_METAL_ADD_KERNEL(im2col_f16);
|
435
|
+
GGML_METAL_ADD_KERNEL(upscale_f32);
|
436
|
+
GGML_METAL_ADD_KERNEL(pad_f32);
|
385
437
|
GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
|
386
438
|
GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
|
439
|
+
GGML_METAL_ADD_KERNEL(leaky_relu_f32);
|
387
440
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
388
441
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
389
442
|
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
@@ -392,6 +445,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
392
445
|
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
|
393
446
|
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
|
394
447
|
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
448
|
+
GGML_METAL_ADD_KERNEL(cpy_f16_f32);
|
395
449
|
GGML_METAL_ADD_KERNEL(concat);
|
396
450
|
GGML_METAL_ADD_KERNEL(sqr);
|
397
451
|
GGML_METAL_ADD_KERNEL(sum_rows);
|
@@ -416,9 +470,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
416
470
|
GGML_METAL_DEL_KERNEL(div_row);
|
417
471
|
GGML_METAL_DEL_KERNEL(scale);
|
418
472
|
GGML_METAL_DEL_KERNEL(scale_4);
|
419
|
-
GGML_METAL_DEL_KERNEL(
|
473
|
+
GGML_METAL_DEL_KERNEL(tanh);
|
420
474
|
GGML_METAL_DEL_KERNEL(relu);
|
421
475
|
GGML_METAL_DEL_KERNEL(gelu);
|
476
|
+
GGML_METAL_DEL_KERNEL(gelu_quick);
|
477
|
+
GGML_METAL_DEL_KERNEL(silu);
|
422
478
|
GGML_METAL_DEL_KERNEL(soft_max);
|
423
479
|
GGML_METAL_DEL_KERNEL(soft_max_4);
|
424
480
|
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
@@ -436,6 +492,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
436
492
|
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
437
493
|
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
438
494
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
495
|
+
GGML_METAL_DEL_KERNEL(group_norm);
|
439
496
|
GGML_METAL_DEL_KERNEL(norm);
|
440
497
|
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
441
498
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
@@ -452,6 +509,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
452
509
|
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
|
453
510
|
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
454
511
|
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
512
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
|
513
|
+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
|
514
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
|
515
|
+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
|
516
|
+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
|
517
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
|
518
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
|
519
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
|
520
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
|
521
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
|
522
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
|
523
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
|
524
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
|
525
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
|
526
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
|
455
527
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
456
528
|
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
457
529
|
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
@@ -482,8 +554,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
482
554
|
GGML_METAL_DEL_KERNEL(rope_f16);
|
483
555
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
484
556
|
GGML_METAL_DEL_KERNEL(im2col_f16);
|
557
|
+
GGML_METAL_DEL_KERNEL(upscale_f32);
|
558
|
+
GGML_METAL_DEL_KERNEL(pad_f32);
|
485
559
|
GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
|
486
560
|
GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
|
561
|
+
GGML_METAL_DEL_KERNEL(leaky_relu_f32);
|
487
562
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
488
563
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
489
564
|
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
@@ -492,6 +567,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
492
567
|
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
493
568
|
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
494
569
|
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
570
|
+
GGML_METAL_DEL_KERNEL(cpy_f16_f32);
|
495
571
|
GGML_METAL_DEL_KERNEL(concat);
|
496
572
|
GGML_METAL_DEL_KERNEL(sqr);
|
497
573
|
GGML_METAL_DEL_KERNEL(sum_rows);
|
@@ -539,12 +615,24 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
|
539
615
|
}
|
540
616
|
|
541
617
|
// temporarily defined here for compatibility between ggml-backend and the old API
|
542
|
-
|
543
|
-
|
618
|
+
|
619
|
+
struct ggml_backend_metal_buffer {
|
620
|
+
void * data;
|
621
|
+
size_t size;
|
544
622
|
|
545
623
|
id<MTLBuffer> metal;
|
546
624
|
};
|
547
625
|
|
626
|
+
struct ggml_backend_metal_buffer_context {
|
627
|
+
void * all_data;
|
628
|
+
size_t all_size;
|
629
|
+
bool owned;
|
630
|
+
|
631
|
+
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
|
632
|
+
int n_buffers;
|
633
|
+
struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
634
|
+
};
|
635
|
+
|
548
636
|
// finds the Metal buffer that contains the tensor data on the GPU device
|
549
637
|
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
550
638
|
// Metal buffer based on the host memory pointer
|
@@ -554,17 +642,29 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|
554
642
|
|
555
643
|
const int64_t tsize = ggml_nbytes(t);
|
556
644
|
|
645
|
+
ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
|
646
|
+
|
557
647
|
// compatibility with ggml-backend
|
558
|
-
if (
|
559
|
-
struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *)
|
648
|
+
if (buffer && buffer->buft == ggml_backend_metal_buffer_type()) {
|
649
|
+
struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
|
560
650
|
|
561
|
-
|
651
|
+
// find the view that contains the tensor fully
|
652
|
+
for (int i = 0; i < buf_ctx->n_buffers; ++i) {
|
653
|
+
const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
|
562
654
|
|
563
|
-
|
655
|
+
//GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
|
656
|
+
if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
|
657
|
+
*offs = (size_t) ioffs;
|
564
658
|
|
565
|
-
|
659
|
+
//GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
|
566
660
|
|
567
|
-
|
661
|
+
return buf_ctx->buffers[i].metal;
|
662
|
+
}
|
663
|
+
}
|
664
|
+
|
665
|
+
GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
|
666
|
+
|
667
|
+
return nil;
|
568
668
|
}
|
569
669
|
|
570
670
|
// find the view that contains the tensor fully
|
@@ -793,9 +893,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|
793
893
|
switch (op->op) {
|
794
894
|
case GGML_OP_UNARY:
|
795
895
|
switch (ggml_get_unary_op(op)) {
|
796
|
-
case
|
896
|
+
case GGML_UNARY_OP_TANH:
|
797
897
|
case GGML_UNARY_OP_RELU:
|
798
898
|
case GGML_UNARY_OP_GELU:
|
899
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
900
|
+
case GGML_UNARY_OP_SILU:
|
799
901
|
return true;
|
800
902
|
default:
|
801
903
|
return false;
|
@@ -807,6 +909,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|
807
909
|
case GGML_OP_PERMUTE:
|
808
910
|
case GGML_OP_CONCAT:
|
809
911
|
case GGML_OP_ADD:
|
912
|
+
case GGML_OP_ACC:
|
810
913
|
case GGML_OP_MUL:
|
811
914
|
case GGML_OP_DIV:
|
812
915
|
case GGML_OP_SCALE:
|
@@ -814,21 +917,50 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|
814
917
|
case GGML_OP_SUM_ROWS:
|
815
918
|
case GGML_OP_SOFT_MAX:
|
816
919
|
case GGML_OP_RMS_NORM:
|
920
|
+
case GGML_OP_GROUP_NORM:
|
817
921
|
case GGML_OP_NORM:
|
818
922
|
case GGML_OP_ALIBI:
|
819
923
|
case GGML_OP_ROPE:
|
820
924
|
case GGML_OP_IM2COL:
|
925
|
+
case GGML_OP_UPSCALE:
|
926
|
+
case GGML_OP_PAD:
|
821
927
|
case GGML_OP_ARGSORT:
|
822
|
-
case
|
823
|
-
case GGML_OP_CPY:
|
824
|
-
case GGML_OP_CONT:
|
928
|
+
case GGML_OP_LEAKY_RELU:
|
825
929
|
case GGML_OP_MUL_MAT:
|
826
930
|
case GGML_OP_MUL_MAT_ID:
|
827
931
|
return true;
|
932
|
+
case GGML_OP_CPY:
|
933
|
+
case GGML_OP_DUP:
|
934
|
+
case GGML_OP_CONT:
|
935
|
+
{
|
936
|
+
switch (op->src[0]->type) {
|
937
|
+
case GGML_TYPE_F32:
|
938
|
+
switch (op->type) {
|
939
|
+
case GGML_TYPE_F16:
|
940
|
+
case GGML_TYPE_F32:
|
941
|
+
case GGML_TYPE_Q8_0:
|
942
|
+
case GGML_TYPE_Q4_0:
|
943
|
+
case GGML_TYPE_Q4_1:
|
944
|
+
return true;
|
945
|
+
default:
|
946
|
+
return false;
|
947
|
+
}
|
948
|
+
case GGML_TYPE_F16:
|
949
|
+
switch (op->type) {
|
950
|
+
case GGML_TYPE_F16:
|
951
|
+
case GGML_TYPE_F32:
|
952
|
+
return true;
|
953
|
+
default:
|
954
|
+
return false;
|
955
|
+
}
|
956
|
+
default:
|
957
|
+
return false;
|
958
|
+
};
|
959
|
+
}
|
828
960
|
case GGML_OP_DIAG_MASK_INF:
|
829
961
|
case GGML_OP_GET_ROWS:
|
830
962
|
{
|
831
|
-
return op->ne[
|
963
|
+
return op->ne[3] == 1;
|
832
964
|
}
|
833
965
|
default:
|
834
966
|
return false;
|
@@ -904,7 +1036,10 @@ void ggml_metal_graph_compute(
|
|
904
1036
|
} break;
|
905
1037
|
}
|
906
1038
|
|
907
|
-
|
1039
|
+
if (!ggml_metal_supports_op(dst)) {
|
1040
|
+
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
|
1041
|
+
GGML_ASSERT(!"unsupported op");
|
1042
|
+
}
|
908
1043
|
|
909
1044
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
910
1045
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
@@ -1001,34 +1136,39 @@ void ggml_metal_graph_compute(
|
|
1001
1136
|
case GGML_OP_MUL:
|
1002
1137
|
case GGML_OP_DIV:
|
1003
1138
|
{
|
1004
|
-
|
1005
|
-
GGML_ASSERT(ggml_is_contiguous(src1));
|
1139
|
+
const size_t offs = 0;
|
1006
1140
|
|
1007
1141
|
bool bcast_row = false;
|
1008
1142
|
|
1009
1143
|
int64_t nb = ne00;
|
1010
1144
|
|
1011
|
-
|
1145
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1146
|
+
|
1147
|
+
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
1148
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
1149
|
+
|
1012
1150
|
// src1 is a row
|
1013
1151
|
GGML_ASSERT(ne11 == 1);
|
1014
1152
|
|
1015
1153
|
nb = ne00 / 4;
|
1016
1154
|
switch (dst->op) {
|
1017
|
-
case GGML_OP_ADD:
|
1018
|
-
case GGML_OP_MUL:
|
1019
|
-
case GGML_OP_DIV:
|
1155
|
+
case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
|
1156
|
+
case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
|
1157
|
+
case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
|
1020
1158
|
default: GGML_ASSERT(false);
|
1021
1159
|
}
|
1022
1160
|
|
1023
1161
|
bcast_row = true;
|
1024
1162
|
} else {
|
1025
1163
|
switch (dst->op) {
|
1026
|
-
case GGML_OP_ADD:
|
1027
|
-
case GGML_OP_MUL:
|
1028
|
-
case GGML_OP_DIV:
|
1164
|
+
case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
|
1165
|
+
case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
|
1166
|
+
case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
|
1029
1167
|
default: GGML_ASSERT(false);
|
1030
1168
|
}
|
1031
1169
|
}
|
1170
|
+
|
1171
|
+
[encoder setComputePipelineState:pipeline];
|
1032
1172
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1033
1173
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1034
1174
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
@@ -1056,23 +1196,104 @@ void ggml_metal_graph_compute(
|
|
1056
1196
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
1057
1197
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
1058
1198
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
1059
|
-
[encoder setBytes:&
|
1199
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
1200
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
1060
1201
|
|
1061
1202
|
if (bcast_row) {
|
1062
1203
|
const int64_t n = ggml_nelements(dst)/4;
|
1063
1204
|
|
1064
1205
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1065
1206
|
} else {
|
1066
|
-
const int nth = MIN(
|
1207
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
1067
1208
|
|
1068
1209
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1069
1210
|
}
|
1070
1211
|
} break;
|
1212
|
+
case GGML_OP_ACC:
|
1213
|
+
{
|
1214
|
+
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
1215
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1216
|
+
GGML_ASSERT(dstt == GGML_TYPE_F32);
|
1217
|
+
|
1218
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
1219
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
1220
|
+
|
1221
|
+
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
|
1222
|
+
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
|
1223
|
+
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
|
1224
|
+
const size_t offs = ((int32_t *) dst->op_params)[3];
|
1225
|
+
|
1226
|
+
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
1227
|
+
|
1228
|
+
if (!inplace) {
|
1229
|
+
// run a separete kernel to cpy src->dst
|
1230
|
+
// not sure how to avoid this
|
1231
|
+
// TODO: make a simpler cpy_bytes kernel
|
1232
|
+
|
1233
|
+
const int nth = MIN(1024, ne00);
|
1234
|
+
|
1235
|
+
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
1236
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1237
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1238
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1239
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1240
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1241
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
1242
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
1243
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
1244
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
1245
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
1246
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
1247
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
1248
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
1249
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
1250
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
1251
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
1252
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
1253
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1254
|
+
|
1255
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1256
|
+
}
|
1257
|
+
|
1258
|
+
[encoder setComputePipelineState:ctx->pipeline_add];
|
1259
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1260
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1261
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1262
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1263
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1264
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1265
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
1266
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
1267
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
1268
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
1269
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
1270
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
1271
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
1272
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
1273
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
1274
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
1275
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
1276
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
1277
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
1278
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
1279
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
1280
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
1281
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
1282
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
1283
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
1284
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
1285
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
1286
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
1287
|
+
|
1288
|
+
const int nth = MIN(1024, ne0);
|
1289
|
+
|
1290
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1291
|
+
} break;
|
1071
1292
|
case GGML_OP_SCALE:
|
1072
1293
|
{
|
1073
1294
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
1074
1295
|
|
1075
|
-
const float scale = *(const float *)
|
1296
|
+
const float scale = *(const float *) dst->op_params;
|
1076
1297
|
|
1077
1298
|
int64_t n = ggml_nelements(dst);
|
1078
1299
|
|
@@ -1083,24 +1304,23 @@ void ggml_metal_graph_compute(
|
|
1083
1304
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
1084
1305
|
}
|
1085
1306
|
|
1086
|
-
[encoder setBuffer:id_src0
|
1087
|
-
[encoder setBuffer:id_dst
|
1307
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1308
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1088
1309
|
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
1089
1310
|
|
1090
1311
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1091
1312
|
} break;
|
1092
1313
|
case GGML_OP_UNARY:
|
1093
1314
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
1094
|
-
case
|
1315
|
+
case GGML_UNARY_OP_TANH:
|
1095
1316
|
{
|
1096
|
-
[encoder setComputePipelineState:ctx->
|
1317
|
+
[encoder setComputePipelineState:ctx->pipeline_tanh];
|
1097
1318
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1098
1319
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1099
1320
|
|
1100
1321
|
const int64_t n = ggml_nelements(dst);
|
1101
|
-
GGML_ASSERT(n % 4 == 0);
|
1102
1322
|
|
1103
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n
|
1323
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1104
1324
|
} break;
|
1105
1325
|
case GGML_UNARY_OP_RELU:
|
1106
1326
|
{
|
@@ -1121,6 +1341,28 @@ void ggml_metal_graph_compute(
|
|
1121
1341
|
const int64_t n = ggml_nelements(dst);
|
1122
1342
|
GGML_ASSERT(n % 4 == 0);
|
1123
1343
|
|
1344
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1345
|
+
} break;
|
1346
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
1347
|
+
{
|
1348
|
+
[encoder setComputePipelineState:ctx->pipeline_gelu_quick];
|
1349
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1350
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1351
|
+
|
1352
|
+
const int64_t n = ggml_nelements(dst);
|
1353
|
+
GGML_ASSERT(n % 4 == 0);
|
1354
|
+
|
1355
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1356
|
+
} break;
|
1357
|
+
case GGML_UNARY_OP_SILU:
|
1358
|
+
{
|
1359
|
+
[encoder setComputePipelineState:ctx->pipeline_silu];
|
1360
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1361
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1362
|
+
|
1363
|
+
const int64_t n = ggml_nelements(dst);
|
1364
|
+
GGML_ASSERT(n % 4 == 0);
|
1365
|
+
|
1124
1366
|
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1125
1367
|
} break;
|
1126
1368
|
default:
|
@@ -1193,7 +1435,11 @@ void ggml_metal_graph_compute(
|
|
1193
1435
|
const float scale = ((float *) dst->op_params)[0];
|
1194
1436
|
|
1195
1437
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1196
|
-
|
1438
|
+
if (id_src1) {
|
1439
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1440
|
+
} else {
|
1441
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1442
|
+
}
|
1197
1443
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1198
1444
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1199
1445
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
@@ -1444,7 +1690,7 @@ void ggml_metal_graph_compute(
|
|
1444
1690
|
else if (src0t == GGML_TYPE_Q6_K) {
|
1445
1691
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1446
1692
|
} else {
|
1447
|
-
int64_t ny = (ne11 + nrows - 1)/nrows;
|
1693
|
+
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
1448
1694
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1449
1695
|
}
|
1450
1696
|
}
|
@@ -1456,7 +1702,7 @@ void ggml_metal_graph_compute(
|
|
1456
1702
|
|
1457
1703
|
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
1458
1704
|
|
1459
|
-
const int n_as =
|
1705
|
+
const int n_as = ((int32_t *) dst->op_params)[1];
|
1460
1706
|
|
1461
1707
|
// TODO: make this more general
|
1462
1708
|
GGML_ASSERT(n_as <= 8);
|
@@ -1488,14 +1734,22 @@ void ggml_metal_graph_compute(
|
|
1488
1734
|
|
1489
1735
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1490
1736
|
// to the matrix-vector kernel
|
1491
|
-
int ne11_mm_min =
|
1737
|
+
int ne11_mm_min = 1;
|
1492
1738
|
|
1493
1739
|
const int idx = ((int32_t *) dst->op_params)[0];
|
1494
1740
|
|
1741
|
+
// batch size
|
1742
|
+
GGML_ASSERT(ne01 == ne11);
|
1743
|
+
|
1744
|
+
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
|
1745
|
+
|
1495
1746
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1496
1747
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
1497
|
-
|
1498
|
-
|
1748
|
+
// !!!
|
1749
|
+
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
1750
|
+
// indirect matrix multiplication
|
1751
|
+
// !!!
|
1752
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
|
1499
1753
|
switch (src2->type) {
|
1500
1754
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
|
1501
1755
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
|
@@ -1514,19 +1768,22 @@ void ggml_metal_graph_compute(
|
|
1514
1768
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1515
1769
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1516
1770
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1517
|
-
[encoder setBytes:&
|
1518
|
-
[encoder setBytes:&
|
1519
|
-
[encoder setBytes:&
|
1520
|
-
[encoder setBytes:&
|
1521
|
-
[encoder setBytes:&
|
1522
|
-
[encoder setBytes:&
|
1523
|
-
[encoder setBytes:&
|
1524
|
-
[encoder setBytes:&
|
1525
|
-
[encoder setBytes:&
|
1526
|
-
[encoder setBytes:&
|
1527
|
-
[encoder setBytes:&
|
1528
|
-
[encoder setBytes:&
|
1529
|
-
[encoder setBytes:&
|
1771
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
1772
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1773
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
1774
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
1775
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
1776
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
1777
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
1778
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
1779
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
1780
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
1781
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
1782
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
|
1783
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
1784
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
1785
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
1786
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
1530
1787
|
// TODO: how to make this an array? read Metal docs
|
1531
1788
|
for (int j = 0; j < n_as; ++j) {
|
1532
1789
|
struct ggml_tensor * src_cur = dst->src[2 + j];
|
@@ -1534,11 +1791,157 @@ void ggml_metal_graph_compute(
|
|
1534
1791
|
size_t offs_src_cur = 0;
|
1535
1792
|
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
1536
1793
|
|
1537
|
-
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:
|
1794
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
1538
1795
|
}
|
1539
1796
|
|
1540
1797
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
1541
|
-
|
1798
|
+
|
1799
|
+
// TODO: processing one row at a time (ne11 -> 1) is not efficient
|
1800
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1801
|
+
} else {
|
1802
|
+
int nth0 = 32;
|
1803
|
+
int nth1 = 1;
|
1804
|
+
int nrows = 1;
|
1805
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
1806
|
+
|
1807
|
+
// use custom matrix x vector kernel
|
1808
|
+
switch (src2t) {
|
1809
|
+
case GGML_TYPE_F32:
|
1810
|
+
{
|
1811
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1812
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
|
1813
|
+
} break;
|
1814
|
+
case GGML_TYPE_F16:
|
1815
|
+
{
|
1816
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1817
|
+
nth0 = 32;
|
1818
|
+
nth1 = 1;
|
1819
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
|
1820
|
+
} break;
|
1821
|
+
case GGML_TYPE_Q4_0:
|
1822
|
+
{
|
1823
|
+
nth0 = 8;
|
1824
|
+
nth1 = 8;
|
1825
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
|
1826
|
+
} break;
|
1827
|
+
case GGML_TYPE_Q4_1:
|
1828
|
+
{
|
1829
|
+
nth0 = 8;
|
1830
|
+
nth1 = 8;
|
1831
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
|
1832
|
+
} break;
|
1833
|
+
case GGML_TYPE_Q5_0:
|
1834
|
+
{
|
1835
|
+
nth0 = 8;
|
1836
|
+
nth1 = 8;
|
1837
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
|
1838
|
+
} break;
|
1839
|
+
case GGML_TYPE_Q5_1:
|
1840
|
+
{
|
1841
|
+
nth0 = 8;
|
1842
|
+
nth1 = 8;
|
1843
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
|
1844
|
+
} break;
|
1845
|
+
case GGML_TYPE_Q8_0:
|
1846
|
+
{
|
1847
|
+
nth0 = 8;
|
1848
|
+
nth1 = 8;
|
1849
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
|
1850
|
+
} break;
|
1851
|
+
case GGML_TYPE_Q2_K:
|
1852
|
+
{
|
1853
|
+
nth0 = 2;
|
1854
|
+
nth1 = 32;
|
1855
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
|
1856
|
+
} break;
|
1857
|
+
case GGML_TYPE_Q3_K:
|
1858
|
+
{
|
1859
|
+
nth0 = 2;
|
1860
|
+
nth1 = 32;
|
1861
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
|
1862
|
+
} break;
|
1863
|
+
case GGML_TYPE_Q4_K:
|
1864
|
+
{
|
1865
|
+
nth0 = 4; //1;
|
1866
|
+
nth1 = 8; //32;
|
1867
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
|
1868
|
+
} break;
|
1869
|
+
case GGML_TYPE_Q5_K:
|
1870
|
+
{
|
1871
|
+
nth0 = 2;
|
1872
|
+
nth1 = 32;
|
1873
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
|
1874
|
+
} break;
|
1875
|
+
case GGML_TYPE_Q6_K:
|
1876
|
+
{
|
1877
|
+
nth0 = 2;
|
1878
|
+
nth1 = 32;
|
1879
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
|
1880
|
+
} break;
|
1881
|
+
default:
|
1882
|
+
{
|
1883
|
+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
1884
|
+
GGML_ASSERT(false && "not implemented");
|
1885
|
+
}
|
1886
|
+
};
|
1887
|
+
|
1888
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1889
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1890
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1891
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
1892
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1893
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
1894
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
1895
|
+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
1896
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
1897
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
1898
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
1899
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
1900
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1901
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1902
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1903
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1904
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1905
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
1906
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
1907
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
1908
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
1909
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
1910
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
1911
|
+
// TODO: how to make this an array? read Metal docs
|
1912
|
+
for (int j = 0; j < n_as; ++j) {
|
1913
|
+
struct ggml_tensor * src_cur = dst->src[2 + j];
|
1914
|
+
|
1915
|
+
size_t offs_src_cur = 0;
|
1916
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
1917
|
+
|
1918
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
1919
|
+
}
|
1920
|
+
|
1921
|
+
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
1922
|
+
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
1923
|
+
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
1924
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1925
|
+
}
|
1926
|
+
else if (src2t == GGML_TYPE_Q4_K) {
|
1927
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1928
|
+
}
|
1929
|
+
else if (src2t == GGML_TYPE_Q3_K) {
|
1930
|
+
#ifdef GGML_QKK_64
|
1931
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1932
|
+
#else
|
1933
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1934
|
+
#endif
|
1935
|
+
}
|
1936
|
+
else if (src2t == GGML_TYPE_Q5_K) {
|
1937
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1938
|
+
}
|
1939
|
+
else if (src2t == GGML_TYPE_Q6_K) {
|
1940
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1941
|
+
} else {
|
1942
|
+
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
1943
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1944
|
+
}
|
1542
1945
|
}
|
1543
1946
|
} break;
|
1544
1947
|
case GGML_OP_GET_ROWS:
|
@@ -1559,16 +1962,19 @@ void ggml_metal_graph_compute(
|
|
1559
1962
|
default: GGML_ASSERT(false && "not implemented");
|
1560
1963
|
}
|
1561
1964
|
|
1562
|
-
[encoder setBuffer:id_src0
|
1563
|
-
[encoder setBuffer:id_src1
|
1564
|
-
[encoder setBuffer:id_dst
|
1965
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1966
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1967
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1565
1968
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1566
1969
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
1567
|
-
[encoder setBytes:&
|
1568
|
-
|
1569
|
-
|
1570
|
-
|
1571
|
-
[encoder
|
1970
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
1971
|
+
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
1972
|
+
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
1973
|
+
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
1974
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
1975
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
1976
|
+
|
1977
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
1572
1978
|
} break;
|
1573
1979
|
case GGML_OP_RMS_NORM:
|
1574
1980
|
{
|
@@ -1595,6 +2001,38 @@ void ggml_metal_graph_compute(
|
|
1595
2001
|
|
1596
2002
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1597
2003
|
} break;
|
2004
|
+
case GGML_OP_GROUP_NORM:
|
2005
|
+
{
|
2006
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
2007
|
+
|
2008
|
+
//float eps;
|
2009
|
+
//memcpy(&eps, dst->op_params, sizeof(float));
|
2010
|
+
|
2011
|
+
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
2012
|
+
|
2013
|
+
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
2014
|
+
|
2015
|
+
int nth = 32; // SIMD width
|
2016
|
+
|
2017
|
+
//while (nth < ne00/4 && nth < 1024) {
|
2018
|
+
// nth *= 2;
|
2019
|
+
//}
|
2020
|
+
|
2021
|
+
[encoder setComputePipelineState:ctx->pipeline_group_norm];
|
2022
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2023
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2024
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2025
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
2026
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
2027
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
2028
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
2029
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
2030
|
+
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
2031
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
2032
|
+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
2033
|
+
|
2034
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2035
|
+
} break;
|
1598
2036
|
case GGML_OP_NORM:
|
1599
2037
|
{
|
1600
2038
|
float eps;
|
@@ -1764,6 +2202,65 @@ void ggml_metal_graph_compute(
|
|
1764
2202
|
|
1765
2203
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
1766
2204
|
} break;
|
2205
|
+
case GGML_OP_UPSCALE:
|
2206
|
+
{
|
2207
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2208
|
+
|
2209
|
+
const int sf = dst->op_params[0];
|
2210
|
+
|
2211
|
+
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
|
2212
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2213
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2214
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
2215
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
2216
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
2217
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
2218
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
2219
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
2220
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2221
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2222
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
2223
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
2224
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
2225
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
2226
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
2227
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2228
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2229
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2230
|
+
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
2231
|
+
|
2232
|
+
const int nth = MIN(1024, ne0);
|
2233
|
+
|
2234
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2235
|
+
} break;
|
2236
|
+
case GGML_OP_PAD:
|
2237
|
+
{
|
2238
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2239
|
+
|
2240
|
+
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
|
2241
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2242
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2243
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
2244
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
2245
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
2246
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
2247
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
2248
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
2249
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2250
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2251
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
2252
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
2253
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
2254
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
2255
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
2256
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2257
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2258
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2259
|
+
|
2260
|
+
const int nth = MIN(1024, ne0);
|
2261
|
+
|
2262
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2263
|
+
} break;
|
1767
2264
|
case GGML_OP_ARGSORT:
|
1768
2265
|
{
|
1769
2266
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
@@ -1785,6 +2282,22 @@ void ggml_metal_graph_compute(
|
|
1785
2282
|
|
1786
2283
|
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
1787
2284
|
} break;
|
2285
|
+
case GGML_OP_LEAKY_RELU:
|
2286
|
+
{
|
2287
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2288
|
+
|
2289
|
+
float slope;
|
2290
|
+
memcpy(&slope, dst->op_params, sizeof(float));
|
2291
|
+
|
2292
|
+
[encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
|
2293
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2294
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2295
|
+
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
2296
|
+
|
2297
|
+
const int64_t n = ggml_nelements(dst);
|
2298
|
+
|
2299
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2300
|
+
} break;
|
1788
2301
|
case GGML_OP_DUP:
|
1789
2302
|
case GGML_OP_CPY:
|
1790
2303
|
case GGML_OP_CONT:
|
@@ -1813,7 +2326,7 @@ void ggml_metal_graph_compute(
|
|
1813
2326
|
{
|
1814
2327
|
switch (dstt) {
|
1815
2328
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
|
1816
|
-
case GGML_TYPE_F32:
|
2329
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
|
1817
2330
|
default: GGML_ASSERT(false && "not implemented");
|
1818
2331
|
};
|
1819
2332
|
} break;
|
@@ -1880,6 +2393,7 @@ void ggml_metal_graph_compute(
|
|
1880
2393
|
|
1881
2394
|
// backend interface
|
1882
2395
|
|
2396
|
+
// default buffer
|
1883
2397
|
static id<MTLDevice> g_backend_device = nil;
|
1884
2398
|
static int g_backend_device_ref_count = 0;
|
1885
2399
|
|
@@ -1907,34 +2421,31 @@ static void ggml_backend_metal_free_device(void) {
|
|
1907
2421
|
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
1908
2422
|
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
1909
2423
|
|
1910
|
-
return ctx->
|
2424
|
+
return ctx->all_data;
|
1911
2425
|
}
|
1912
2426
|
|
1913
2427
|
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
1914
2428
|
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
1915
2429
|
|
1916
|
-
|
2430
|
+
for (int i = 0; i < ctx->n_buffers; i++) {
|
2431
|
+
[ctx->buffers[i].metal release];
|
2432
|
+
}
|
1917
2433
|
ggml_backend_metal_free_device();
|
1918
2434
|
|
1919
|
-
|
1920
|
-
|
2435
|
+
if (ctx->owned) {
|
2436
|
+
free(ctx->all_data);
|
2437
|
+
}
|
1921
2438
|
|
1922
|
-
|
2439
|
+
free(ctx);
|
1923
2440
|
}
|
1924
2441
|
|
1925
2442
|
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
1926
|
-
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
|
1927
|
-
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
1928
|
-
|
1929
2443
|
memcpy((char *)tensor->data + offset, data, size);
|
1930
2444
|
|
1931
2445
|
UNUSED(buffer);
|
1932
2446
|
}
|
1933
2447
|
|
1934
2448
|
static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
1935
|
-
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
|
1936
|
-
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
1937
|
-
|
1938
2449
|
memcpy(data, (const char *)tensor->data + offset, size);
|
1939
2450
|
|
1940
2451
|
UNUSED(buffer);
|
@@ -1952,7 +2463,13 @@ static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer
|
|
1952
2463
|
UNUSED(buffer);
|
1953
2464
|
}
|
1954
2465
|
|
1955
|
-
static
|
2466
|
+
static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
2467
|
+
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
2468
|
+
|
2469
|
+
memset(ctx->all_data, value, ctx->all_size);
|
2470
|
+
}
|
2471
|
+
|
2472
|
+
static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
|
1956
2473
|
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
|
1957
2474
|
/* .get_base = */ ggml_backend_metal_buffer_get_base,
|
1958
2475
|
/* .init_tensor = */ NULL,
|
@@ -1960,8 +2477,11 @@ static struct ggml_backend_buffer_i metal_backend_buffer_i = {
|
|
1960
2477
|
/* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
|
1961
2478
|
/* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
|
1962
2479
|
/* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
|
2480
|
+
/* .clear = */ ggml_backend_metal_buffer_clear,
|
1963
2481
|
};
|
1964
2482
|
|
2483
|
+
// default buffer type
|
2484
|
+
|
1965
2485
|
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
1966
2486
|
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
|
1967
2487
|
|
@@ -1972,13 +2492,46 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
|
1972
2492
|
size_aligned += (size_page - (size_aligned % size_page));
|
1973
2493
|
}
|
1974
2494
|
|
1975
|
-
|
1976
|
-
|
2495
|
+
id<MTLDevice> device = ggml_backend_metal_get_device();
|
2496
|
+
|
2497
|
+
ctx->all_data = ggml_metal_host_malloc(size_aligned);
|
2498
|
+
ctx->all_size = size_aligned;
|
2499
|
+
ctx->owned = true;
|
2500
|
+
ctx->n_buffers = 1;
|
2501
|
+
|
2502
|
+
ctx->buffers[0].data = ctx->all_data;
|
2503
|
+
ctx->buffers[0].size = size;
|
2504
|
+
ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
|
1977
2505
|
length:size_aligned
|
1978
2506
|
options:MTLResourceStorageModeShared
|
1979
2507
|
deallocator:nil];
|
1980
2508
|
|
1981
|
-
|
2509
|
+
if (ctx->buffers[0].metal == nil) {
|
2510
|
+
GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
2511
|
+
free(ctx);
|
2512
|
+
ggml_backend_metal_free_device();
|
2513
|
+
return NULL;
|
2514
|
+
}
|
2515
|
+
|
2516
|
+
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
|
2517
|
+
|
2518
|
+
|
2519
|
+
#if TARGET_OS_OSX
|
2520
|
+
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
2521
|
+
device.currentAllocatedSize / 1024.0 / 1024.0,
|
2522
|
+
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
2523
|
+
|
2524
|
+
if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
|
2525
|
+
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
2526
|
+
} else {
|
2527
|
+
GGML_METAL_LOG_INFO("\n");
|
2528
|
+
}
|
2529
|
+
#else
|
2530
|
+
GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
|
2531
|
+
#endif
|
2532
|
+
|
2533
|
+
|
2534
|
+
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
|
1982
2535
|
}
|
1983
2536
|
|
1984
2537
|
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
@@ -1989,7 +2542,13 @@ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_t
|
|
1989
2542
|
static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
1990
2543
|
return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
|
1991
2544
|
|
1992
|
-
|
2545
|
+
UNUSED(buft);
|
2546
|
+
}
|
2547
|
+
|
2548
|
+
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
2549
|
+
return true;
|
2550
|
+
|
2551
|
+
UNUSED(buft);
|
1993
2552
|
}
|
1994
2553
|
|
1995
2554
|
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
@@ -1999,6 +2558,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
1999
2558
|
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
2000
2559
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
2001
2560
|
/* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
|
2561
|
+
/* .is_host = */ ggml_backend_metal_buffer_type_is_host,
|
2002
2562
|
},
|
2003
2563
|
/* .context = */ NULL,
|
2004
2564
|
};
|
@@ -2006,6 +2566,87 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
|
2006
2566
|
return &ggml_backend_buffer_type_metal;
|
2007
2567
|
}
|
2008
2568
|
|
2569
|
+
// buffer from ptr
|
2570
|
+
|
2571
|
+
ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
|
2572
|
+
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
|
2573
|
+
|
2574
|
+
ctx->all_data = data;
|
2575
|
+
ctx->all_size = size;
|
2576
|
+
ctx->owned = false;
|
2577
|
+
ctx->n_buffers = 0;
|
2578
|
+
|
2579
|
+
const size_t size_page = sysconf(_SC_PAGESIZE);
|
2580
|
+
size_t size_aligned = size;
|
2581
|
+
if ((size_aligned % size_page) != 0) {
|
2582
|
+
size_aligned += (size_page - (size_aligned % size_page));
|
2583
|
+
}
|
2584
|
+
|
2585
|
+
id<MTLDevice> device = ggml_backend_metal_get_device();
|
2586
|
+
|
2587
|
+
// the buffer fits into the max buffer size allowed by the device
|
2588
|
+
if (size_aligned <= device.maxBufferLength) {
|
2589
|
+
ctx->buffers[ctx->n_buffers].data = data;
|
2590
|
+
ctx->buffers[ctx->n_buffers].size = size;
|
2591
|
+
|
2592
|
+
ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
2593
|
+
|
2594
|
+
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
2595
|
+
GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
2596
|
+
return false;
|
2597
|
+
}
|
2598
|
+
|
2599
|
+
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
|
2600
|
+
|
2601
|
+
++ctx->n_buffers;
|
2602
|
+
} else {
|
2603
|
+
// this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
|
2604
|
+
// one of the views
|
2605
|
+
const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
|
2606
|
+
const size_t size_step = device.maxBufferLength - size_ovlp;
|
2607
|
+
const size_t size_view = device.maxBufferLength;
|
2608
|
+
|
2609
|
+
for (size_t i = 0; i < size; i += size_step) {
|
2610
|
+
const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
|
2611
|
+
|
2612
|
+
ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
|
2613
|
+
ctx->buffers[ctx->n_buffers].size = size_step_aligned;
|
2614
|
+
|
2615
|
+
ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
2616
|
+
|
2617
|
+
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
2618
|
+
GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
|
2619
|
+
return false;
|
2620
|
+
}
|
2621
|
+
|
2622
|
+
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i);
|
2623
|
+
if (i + size_step < size) {
|
2624
|
+
GGML_METAL_LOG_INFO("\n");
|
2625
|
+
}
|
2626
|
+
|
2627
|
+
++ctx->n_buffers;
|
2628
|
+
}
|
2629
|
+
}
|
2630
|
+
|
2631
|
+
#if TARGET_OS_OSX
|
2632
|
+
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
2633
|
+
device.currentAllocatedSize / 1024.0 / 1024.0,
|
2634
|
+
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
2635
|
+
|
2636
|
+
if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
|
2637
|
+
GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
|
2638
|
+
} else {
|
2639
|
+
GGML_METAL_LOG_INFO("\n");
|
2640
|
+
}
|
2641
|
+
#else
|
2642
|
+
GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
|
2643
|
+
#endif
|
2644
|
+
|
2645
|
+
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
|
2646
|
+
}
|
2647
|
+
|
2648
|
+
// backend
|
2649
|
+
|
2009
2650
|
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
2010
2651
|
return "Metal";
|
2011
2652
|
|
@@ -2018,10 +2659,6 @@ static void ggml_backend_metal_free(ggml_backend_t backend) {
|
|
2018
2659
|
free(backend);
|
2019
2660
|
}
|
2020
2661
|
|
2021
|
-
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
|
2022
|
-
UNUSED(backend);
|
2023
|
-
}
|
2024
|
-
|
2025
2662
|
static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
|
2026
2663
|
return ggml_backend_metal_buffer_type();
|
2027
2664
|
|
@@ -2048,25 +2685,15 @@ static struct ggml_backend_i metal_backend_i = {
|
|
2048
2685
|
/* .get_tensor_async = */ NULL,
|
2049
2686
|
/* .cpy_tensor_from_async = */ NULL,
|
2050
2687
|
/* .cpy_tensor_to_async = */ NULL,
|
2051
|
-
/* .synchronize = */
|
2052
|
-
/* .graph_plan_create = */ NULL,
|
2688
|
+
/* .synchronize = */ NULL,
|
2689
|
+
/* .graph_plan_create = */ NULL,
|
2053
2690
|
/* .graph_plan_free = */ NULL,
|
2054
2691
|
/* .graph_plan_compute = */ NULL,
|
2055
2692
|
/* .graph_compute = */ ggml_backend_metal_graph_compute,
|
2056
2693
|
/* .supports_op = */ ggml_backend_metal_supports_op,
|
2057
2694
|
};
|
2058
2695
|
|
2059
|
-
// TODO: make a common log callback for all backends in ggml-backend
|
2060
|
-
static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
|
2061
|
-
fprintf(stderr, "%s", msg);
|
2062
|
-
|
2063
|
-
UNUSED(level);
|
2064
|
-
UNUSED(user_data);
|
2065
|
-
}
|
2066
|
-
|
2067
2696
|
ggml_backend_t ggml_backend_metal_init(void) {
|
2068
|
-
ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
|
2069
|
-
|
2070
2697
|
struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
2071
2698
|
|
2072
2699
|
if (ctx == NULL) {
|