llama_cpp 0.10.0 → 0.10.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/llama_cpp/llama_cpp.cpp +2 -0
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-cuda.cu +691 -93
- data/ext/llama_cpp/src/ggml-metal.m +535 -54
- data/ext/llama_cpp/src/ggml-metal.metal +1497 -169
- data/ext/llama_cpp/src/ggml-quants.c +2 -2
- data/ext/llama_cpp/src/ggml.c +325 -159
- data/ext/llama_cpp/src/ggml.h +34 -13
- data/ext/llama_cpp/src/llama.cpp +195 -35
- data/ext/llama_cpp/src/llama.h +1 -1
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- metadata +2 -2
data/ext/llama_cpp/src/ggml-metal.m
@@ -66,9 +66,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(div_row);
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(scale_4);
-    GGML_METAL_DECL_KERNEL(silu);
+    GGML_METAL_DECL_KERNEL(tanh);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
+    GGML_METAL_DECL_KERNEL(gelu_quick);
+    GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
@@ -86,6 +88,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q5_K);
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
+    GGML_METAL_DECL_KERNEL(group_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
@@ -102,6 +105,21 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -130,8 +148,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rope_f16);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(im2col_f16);
+    GGML_METAL_DECL_KERNEL(upscale_f32);
+    GGML_METAL_DECL_KERNEL(pad_f32);
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
     GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DECL_KERNEL(leaky_relu_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
@@ -140,6 +161,7 @@ struct ggml_metal_context {
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f32);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
     GGML_METAL_DECL_KERNEL(sum_rows);
@@ -177,6 +199,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
            ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
        } else {
            char* buffer2 = malloc(len+1);
+            va_end(args);
+            va_start(args, format);
            vsnprintf(buffer2, len+1, format, args);
            buffer2[len] = 0;
            ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
@@ -316,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(div_row);
        GGML_METAL_ADD_KERNEL(scale);
        GGML_METAL_ADD_KERNEL(scale_4);
-        GGML_METAL_ADD_KERNEL(silu);
+        GGML_METAL_ADD_KERNEL(tanh);
        GGML_METAL_ADD_KERNEL(relu);
        GGML_METAL_ADD_KERNEL(gelu);
+        GGML_METAL_ADD_KERNEL(gelu_quick);
+        GGML_METAL_ADD_KERNEL(silu);
        GGML_METAL_ADD_KERNEL(soft_max);
        GGML_METAL_ADD_KERNEL(soft_max_4);
        GGML_METAL_ADD_KERNEL(diag_mask_inf);
@@ -336,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(get_rows_q5_K);
        GGML_METAL_ADD_KERNEL(get_rows_q6_K);
        GGML_METAL_ADD_KERNEL(rms_norm);
+        GGML_METAL_ADD_KERNEL(group_norm);
        GGML_METAL_ADD_KERNEL(norm);
        GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
@@ -352,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
        GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
            GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
            GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@@ -382,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        GGML_METAL_ADD_KERNEL(rope_f16);
        GGML_METAL_ADD_KERNEL(alibi_f32);
        GGML_METAL_ADD_KERNEL(im2col_f16);
+        GGML_METAL_ADD_KERNEL(upscale_f32);
+        GGML_METAL_ADD_KERNEL(pad_f32);
        GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
        GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+        GGML_METAL_ADD_KERNEL(leaky_relu_f32);
        GGML_METAL_ADD_KERNEL(cpy_f32_f16);
        GGML_METAL_ADD_KERNEL(cpy_f32_f32);
        GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
@@ -392,6 +437,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
        GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f32);
        GGML_METAL_ADD_KERNEL(concat);
        GGML_METAL_ADD_KERNEL(sqr);
        GGML_METAL_ADD_KERNEL(sum_rows);
@@ -416,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(div_row);
    GGML_METAL_DEL_KERNEL(scale);
    GGML_METAL_DEL_KERNEL(scale_4);
-    GGML_METAL_DEL_KERNEL(silu);
+    GGML_METAL_DEL_KERNEL(tanh);
    GGML_METAL_DEL_KERNEL(relu);
    GGML_METAL_DEL_KERNEL(gelu);
+    GGML_METAL_DEL_KERNEL(gelu_quick);
+    GGML_METAL_DEL_KERNEL(silu);
    GGML_METAL_DEL_KERNEL(soft_max);
    GGML_METAL_DEL_KERNEL(soft_max_4);
    GGML_METAL_DEL_KERNEL(diag_mask_inf);
@@ -436,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
    GGML_METAL_DEL_KERNEL(rms_norm);
+    GGML_METAL_DEL_KERNEL(group_norm);
    GGML_METAL_DEL_KERNEL(norm);
    GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
@@ -452,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
    if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
        GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
        GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@@ -482,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    GGML_METAL_DEL_KERNEL(rope_f16);
    GGML_METAL_DEL_KERNEL(alibi_f32);
    GGML_METAL_DEL_KERNEL(im2col_f16);
+    GGML_METAL_DEL_KERNEL(upscale_f32);
+    GGML_METAL_DEL_KERNEL(pad_f32);
    GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
    GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+    GGML_METAL_DEL_KERNEL(leaky_relu_f32);
    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
    GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
@@ -492,6 +559,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f32);
    GGML_METAL_DEL_KERNEL(concat);
    GGML_METAL_DEL_KERNEL(sqr);
    GGML_METAL_DEL_KERNEL(sum_rows);
@@ -793,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_TANH:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_SILU:
                    return true;
                default:
                    return false;
@@ -807,6 +877,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
        case GGML_OP_PERMUTE:
        case GGML_OP_CONCAT:
        case GGML_OP_ADD:
+        case GGML_OP_ACC:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_SCALE:
@@ -814,21 +885,50 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
        case GGML_OP_SUM_ROWS:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_RMS_NORM:
+        case GGML_OP_GROUP_NORM:
        case GGML_OP_NORM:
        case GGML_OP_ALIBI:
        case GGML_OP_ROPE:
        case GGML_OP_IM2COL:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
        case GGML_OP_ARGSORT:
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
-        case GGML_OP_CONT:
+        case GGML_OP_LEAKY_RELU:
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            return true;
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                        switch (op->type) {
+                            case GGML_TYPE_F16:
+                            case GGML_TYPE_F32:
+                            case GGML_TYPE_Q8_0:
+                            case GGML_TYPE_Q4_0:
+                            case GGML_TYPE_Q4_1:
+                                return true;
+                            default:
+                                return false;
+                        }
+                    case GGML_TYPE_F16:
+                        switch (op->type) {
+                            case GGML_TYPE_F16:
+                            case GGML_TYPE_F32:
+                                return true;
+                            default:
+                                return false;
+                        }
+                    default:
+                        return false;
+                };
+            }
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_GET_ROWS:
            {
-                return op->ne[
+                return op->ne[3] == 1;
            }
        default:
            return false;
@@ -904,7 +1004,10 @@ void ggml_metal_graph_compute(
                } break;
            }

-
+            if (!ggml_metal_supports_op(dst)) {
+                GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+                GGML_ASSERT(!"unsupported op");
+            }

            const int64_t ne00 = src0 ? src0->ne[0] : 0;
            const int64_t ne01 = src0 ? src0->ne[1] : 0;
@@ -1001,34 +1104,39 @@ void ggml_metal_graph_compute(
            case GGML_OP_MUL:
            case GGML_OP_DIV:
                {
-
-                    GGML_ASSERT(ggml_is_contiguous(src1));
+                    const size_t offs = 0;

                    bool bcast_row = false;

                    int64_t nb = ne00;

-
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
                        // src1 is a row
                        GGML_ASSERT(ne11 == 1);

                        nb = ne00 / 4;
                        switch (dst->op) {
-                            case GGML_OP_ADD:
-                            case GGML_OP_MUL:
-                            case GGML_OP_DIV:
+                            case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+                            case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+                            case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
                            default: GGML_ASSERT(false);
                        }

                        bcast_row = true;
                    } else {
                        switch (dst->op) {
-                            case GGML_OP_ADD:
-                            case GGML_OP_MUL:
-                            case GGML_OP_DIV:
+                            case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+                            case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+                            case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
                            default: GGML_ASSERT(false);
                        }
                    }
+
+                    [encoder setComputePipelineState:pipeline];
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1056,18 +1164,99 @@ void ggml_metal_graph_compute(
                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                    [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                    [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
-                    [encoder setBytes:&
+                    [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                    [encoder setBytes:&nb length:sizeof(nb) atIndex:28];

                    if (bcast_row) {
                        const int64_t n = ggml_nelements(dst)/4;

                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    } else {
-                        const int nth = MIN(
+                        const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);

                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                    }
                } break;
+            case GGML_OP_ACC:
+                {
+                    GGML_ASSERT(src0t == GGML_TYPE_F32);
+                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+                    GGML_ASSERT(dstt == GGML_TYPE_F32);
+
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+                    GGML_ASSERT(ggml_is_contiguous(src1));
+
+                    const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+                    const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+                    const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+                    const size_t offs = ((int32_t *) dst->op_params)[3];
+
+                    const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                    if (!inplace) {
+                        // run a separete kernel to cpy src->dst
+                        // not sure how to avoid this
+                        // TODO: make a simpler cpy_bytes kernel
+
+                        const int nth = MIN(1024, ne00);
+
+                        [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                        [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                        [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                        [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+                        [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+                        [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+                        [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+                        [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+                        [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+                        [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+                        [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+                        [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+                        [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+                        [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+                        [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    }
+
+                    [encoder setComputePipelineState:ctx->pipeline_add];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                    [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                    [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+                    [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+                    [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                    [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+                    [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+                    [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+                    [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+                    [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+                    [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+                    [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+                    [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+                    const int nth = MIN(1024, ne0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
            case GGML_OP_SCALE:
                {
                    GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1091,16 +1280,15 @@ void ggml_metal_graph_compute(
                } break;
            case GGML_OP_UNARY:
                switch (ggml_get_unary_op(gf->nodes[i])) {
-                    case GGML_UNARY_OP_SILU:
+                    case GGML_UNARY_OP_TANH:
                        {
-                            [encoder setComputePipelineState:ctx->pipeline_silu];
+                            [encoder setComputePipelineState:ctx->pipeline_tanh];
                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];

                            const int64_t n = ggml_nelements(dst);
-                            GGML_ASSERT(n % 4 == 0);

-                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    case GGML_UNARY_OP_RELU:
                        {
@@ -1121,6 +1309,28 @@ void ggml_metal_graph_compute(
                            const int64_t n = ggml_nelements(dst);
                            GGML_ASSERT(n % 4 == 0);

+                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_UNARY_OP_GELU_QUICK:
+                        {
+                            [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+                            GGML_ASSERT(n % 4 == 0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
+                    case GGML_UNARY_OP_SILU:
+                        {
+                            [encoder setComputePipelineState:ctx->pipeline_silu];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+                            GGML_ASSERT(n % 4 == 0);
+
                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                        } break;
                    default:
@@ -1193,7 +1403,11 @@ void ggml_metal_graph_compute(
                    const float scale = ((float *) dst->op_params)[0];

                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-
+                    if (id_src1) {
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    } else {
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    }
                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
@@ -1444,7 +1658,7 @@ void ggml_metal_graph_compute(
                    else if (src0t == GGML_TYPE_Q6_K) {
                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    } else {
-                        int64_t ny = (ne11 + nrows - 1)/nrows;
+                        const int64_t ny = (ne11 + nrows - 1)/nrows;
                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                }
@@ -1456,7 +1670,7 @@ void ggml_metal_graph_compute(

                    GGML_ASSERT(src0t == GGML_TYPE_I32);

-                    const int n_as =
+                    const int n_as = ((int32_t *) dst->op_params)[1];

                    // TODO: make this more general
                    GGML_ASSERT(n_as <= 8);
@@ -1488,14 +1702,22 @@ void ggml_metal_graph_compute(

                    // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                    // to the matrix-vector kernel
-                    int ne11_mm_min =
+                    int ne11_mm_min = 1;

                    const int idx = ((int32_t *) dst->op_params)[0];

+                    // batch size
+                    GGML_ASSERT(ne01 == ne11);
+
+                    const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
                    // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                    // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-
-
+                    // !!!
+                    // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+                    // indirect matrix multiplication
+                    // !!!
+                    if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
                        switch (src2->type) {
                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@@ -1514,19 +1736,22 @@ void ggml_metal_graph_compute(
                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
-                        [encoder setBytes:&
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+                        [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                        [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+                        [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                        [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+                        [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+                        [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+                        [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+                        [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+                        [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+                        [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
                        // TODO: how to make this an array? read Metal docs
                        for (int j = 0; j < n_as; ++j) {
                            struct ggml_tensor * src_cur = dst->src[2 + j];
@@ -1534,11 +1759,157 @@ void ggml_metal_graph_compute(
                            size_t offs_src_cur = 0;
                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);

-                            [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:
+                            [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
                        }

                        [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-
+
+                        // TODO: processing one row at a time (ne11 -> 1) is not efficient
+                        [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                    } else {
+                        int nth0 = 32;
+                        int nth1 = 1;
+                        int nrows = 1;
+                        //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                        // use custom matrix x vector kernel
+                        switch (src2t) {
+                            case GGML_TYPE_F32:
+                                {
+                                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+                                } break;
+                            case GGML_TYPE_F16:
+                                {
+                                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                    nth0 = 32;
+                                    nth1 = 1;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+                                } break;
+                            case GGML_TYPE_Q4_0:
+                                {
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+                                } break;
+                            case GGML_TYPE_Q4_1:
+                                {
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+                                } break;
+                            case GGML_TYPE_Q5_0:
+                                {
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+                                } break;
+                            case GGML_TYPE_Q5_1:
+                                {
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+                                } break;
+                            case GGML_TYPE_Q8_0:
+                                {
+                                    nth0 = 8;
+                                    nth1 = 8;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+                                } break;
+                            case GGML_TYPE_Q2_K:
+                                {
+                                    nth0 = 2;
+                                    nth1 = 32;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+                                } break;
+                            case GGML_TYPE_Q3_K:
+                                {
+                                    nth0 = 2;
+                                    nth1 = 32;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+                                } break;
+                            case GGML_TYPE_Q4_K:
+                                {
+                                    nth0 = 4; //1;
+                                    nth1 = 8; //32;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+                                } break;
+                            case GGML_TYPE_Q5_K:
+                                {
+                                    nth0 = 2;
+                                    nth1 = 32;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+                                } break;
+                            case GGML_TYPE_Q6_K:
+                                {
+                                    nth0 = 2;
+                                    nth1 = 32;
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+                                } break;
+                            default:
+                                {
+                                    GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                                    GGML_ASSERT(false && "not implemented");
+                                }
+                        };
+
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                        [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+                        [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                        [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                        [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+                        [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+                        [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+                        [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+                        [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                        [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+                        [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                        [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                        [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                        [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                        [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                        [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+                        [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+                        [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+                        [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
+                        [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
+                        [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
+                        // TODO: how to make this an array? read Metal docs
+                        for (int j = 0; j < n_as; ++j) {
+                            struct ggml_tensor * src_cur = dst->src[2 + j];
+
+                            size_t offs_src_cur = 0;
+                            id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                            [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+                        }
+
+                        if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+                            src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+                            src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
+                        else if (src2t == GGML_TYPE_Q4_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
+                        else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+                        }
+                        else if (src2t == GGML_TYPE_Q5_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
+                        else if (src2t == GGML_TYPE_Q6_K) {
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        } else {
+                            const int64_t ny = (_ne1 + nrows - 1)/nrows;
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        }
                    }
                } break;
            case GGML_OP_GET_ROWS:
@@ -1559,16 +1930,19 @@ void ggml_metal_graph_compute(
                        default: GGML_ASSERT(false && "not implemented");
                    }

-                    [encoder setBuffer:id_src0
-                    [encoder setBuffer:id_src1
-                    [encoder setBuffer:id_dst
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                    [encoder setBytes:&
-
-
-
-                    [encoder
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+                    [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+                    [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+                    [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+                    [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+                    [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
                } break;
            case GGML_OP_RMS_NORM:
                {
@@ -1595,6 +1969,38 @@ void ggml_metal_graph_compute(

                    [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                } break;
+            case GGML_OP_GROUP_NORM:
+                {
+                    GGML_ASSERT(ne00 % 4 == 0);
+
+                    //float eps;
+                    //memcpy(&eps, dst->op_params, sizeof(float));
+
+                    const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+                    const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+                    int nth = 32; // SIMD width
+
+                    //while (nth < ne00/4 && nth < 1024) {
+                    // nth *= 2;
+                    //}
+
+                    [encoder setComputePipelineState:ctx->pipeline_group_norm];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+                    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+                    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+                    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+                    [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+                    [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+                    [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
            case GGML_OP_NORM:
                {
                    float eps;
@@ -1764,6 +2170,65 @@ void ggml_metal_graph_compute(

                    [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                } break;
+            case GGML_OP_UPSCALE:
+                {
+                    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                    const int sf = dst->op_params[0];
+
+                    [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                    [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                    [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+                    [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+                    [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+                    [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+                    [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+                    [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+                    [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+
+                    const int nth = MIN(1024, ne0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
+            case GGML_OP_PAD:
+                {
+                    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                    [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                    [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                    [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+                    [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+                    [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+                    [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+                    [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+                    [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+                    const int nth = MIN(1024, ne0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
            case GGML_OP_ARGSORT:
                {
                    GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -1785,6 +2250,22 @@ void ggml_metal_graph_compute(

                    [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
                } break;
+            case GGML_OP_LEAKY_RELU:
+                {
+                    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                    float slope;
+                    memcpy(&slope, dst->op_params, sizeof(float));
+
+                    [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
            case GGML_OP_DUP:
            case GGML_OP_CPY:
            case GGML_OP_CONT:
@@ -1813,7 +2294,7 @@ void ggml_metal_graph_compute(
                {
                    switch (dstt) {
                        case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-                        case GGML_TYPE_F32:
+                        case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
                        default: GGML_ASSERT(false && "not implemented");
                    };
                } break;