llama_cpp 0.10.0 → 0.10.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/llama_cpp/llama_cpp.cpp +2 -0
- data/ext/llama_cpp/src/ggml-alloc.h +1 -1
- data/ext/llama_cpp/src/ggml-cuda.cu +691 -93
- data/ext/llama_cpp/src/ggml-metal.m +535 -54
- data/ext/llama_cpp/src/ggml-metal.metal +1497 -169
- data/ext/llama_cpp/src/ggml-quants.c +2 -2
- data/ext/llama_cpp/src/ggml.c +325 -159
- data/ext/llama_cpp/src/ggml.h +34 -13
- data/ext/llama_cpp/src/llama.cpp +195 -35
- data/ext/llama_cpp/src/llama.h +1 -1
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- metadata +2 -2
@@ -66,9 +66,11 @@ struct ggml_metal_context {
|
|
66
66
|
GGML_METAL_DECL_KERNEL(div_row);
|
67
67
|
GGML_METAL_DECL_KERNEL(scale);
|
68
68
|
GGML_METAL_DECL_KERNEL(scale_4);
|
69
|
-
GGML_METAL_DECL_KERNEL(
|
69
|
+
GGML_METAL_DECL_KERNEL(tanh);
|
70
70
|
GGML_METAL_DECL_KERNEL(relu);
|
71
71
|
GGML_METAL_DECL_KERNEL(gelu);
|
72
|
+
GGML_METAL_DECL_KERNEL(gelu_quick);
|
73
|
+
GGML_METAL_DECL_KERNEL(silu);
|
72
74
|
GGML_METAL_DECL_KERNEL(soft_max);
|
73
75
|
GGML_METAL_DECL_KERNEL(soft_max_4);
|
74
76
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
@@ -86,6 +88,7 @@ struct ggml_metal_context {
|
|
86
88
|
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
87
89
|
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
88
90
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
91
|
+
GGML_METAL_DECL_KERNEL(group_norm);
|
89
92
|
GGML_METAL_DECL_KERNEL(norm);
|
90
93
|
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
91
94
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
@@ -102,6 +105,21 @@ struct ggml_metal_context {
|
|
102
105
|
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
|
103
106
|
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
104
107
|
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
108
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
|
109
|
+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
|
110
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
|
111
|
+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
|
112
|
+
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
|
113
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
|
114
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
|
115
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
|
116
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
|
117
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
|
118
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
|
119
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
|
120
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
|
121
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
|
122
|
+
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
|
105
123
|
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
106
124
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
107
125
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
@@ -130,8 +148,11 @@ struct ggml_metal_context {
|
|
130
148
|
GGML_METAL_DECL_KERNEL(rope_f16);
|
131
149
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
132
150
|
GGML_METAL_DECL_KERNEL(im2col_f16);
|
151
|
+
GGML_METAL_DECL_KERNEL(upscale_f32);
|
152
|
+
GGML_METAL_DECL_KERNEL(pad_f32);
|
133
153
|
GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
|
134
154
|
GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
|
155
|
+
GGML_METAL_DECL_KERNEL(leaky_relu_f32);
|
135
156
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
136
157
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
137
158
|
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
@@ -140,6 +161,7 @@ struct ggml_metal_context {
|
|
140
161
|
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
141
162
|
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
142
163
|
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
164
|
+
GGML_METAL_DECL_KERNEL(cpy_f16_f32);
|
143
165
|
GGML_METAL_DECL_KERNEL(concat);
|
144
166
|
GGML_METAL_DECL_KERNEL(sqr);
|
145
167
|
GGML_METAL_DECL_KERNEL(sum_rows);
|
@@ -177,6 +199,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
|
177
199
|
ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
|
178
200
|
} else {
|
179
201
|
char* buffer2 = malloc(len+1);
|
202
|
+
va_end(args);
|
203
|
+
va_start(args, format);
|
180
204
|
vsnprintf(buffer2, len+1, format, args);
|
181
205
|
buffer2[len] = 0;
|
182
206
|
ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
|
@@ -316,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
316
340
|
GGML_METAL_ADD_KERNEL(div_row);
|
317
341
|
GGML_METAL_ADD_KERNEL(scale);
|
318
342
|
GGML_METAL_ADD_KERNEL(scale_4);
|
319
|
-
GGML_METAL_ADD_KERNEL(
|
343
|
+
GGML_METAL_ADD_KERNEL(tanh);
|
320
344
|
GGML_METAL_ADD_KERNEL(relu);
|
321
345
|
GGML_METAL_ADD_KERNEL(gelu);
|
346
|
+
GGML_METAL_ADD_KERNEL(gelu_quick);
|
347
|
+
GGML_METAL_ADD_KERNEL(silu);
|
322
348
|
GGML_METAL_ADD_KERNEL(soft_max);
|
323
349
|
GGML_METAL_ADD_KERNEL(soft_max_4);
|
324
350
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
@@ -336,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
336
362
|
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
337
363
|
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
338
364
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
365
|
+
GGML_METAL_ADD_KERNEL(group_norm);
|
339
366
|
GGML_METAL_ADD_KERNEL(norm);
|
340
367
|
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
341
368
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
|
@@ -352,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
352
379
|
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
|
353
380
|
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
|
354
381
|
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
|
382
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
|
383
|
+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
|
384
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
|
385
|
+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
|
386
|
+
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
|
387
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
|
388
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
|
389
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
|
390
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
|
391
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
|
392
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
|
393
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
|
394
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
|
395
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
|
396
|
+
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
|
355
397
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
356
398
|
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
357
399
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
@@ -382,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
382
424
|
GGML_METAL_ADD_KERNEL(rope_f16);
|
383
425
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
384
426
|
GGML_METAL_ADD_KERNEL(im2col_f16);
|
427
|
+
GGML_METAL_ADD_KERNEL(upscale_f32);
|
428
|
+
GGML_METAL_ADD_KERNEL(pad_f32);
|
385
429
|
GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
|
386
430
|
GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
|
431
|
+
GGML_METAL_ADD_KERNEL(leaky_relu_f32);
|
387
432
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
388
433
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
389
434
|
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
@@ -392,6 +437,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
392
437
|
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
|
393
438
|
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
|
394
439
|
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
440
|
+
GGML_METAL_ADD_KERNEL(cpy_f16_f32);
|
395
441
|
GGML_METAL_ADD_KERNEL(concat);
|
396
442
|
GGML_METAL_ADD_KERNEL(sqr);
|
397
443
|
GGML_METAL_ADD_KERNEL(sum_rows);
|
@@ -416,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
416
462
|
GGML_METAL_DEL_KERNEL(div_row);
|
417
463
|
GGML_METAL_DEL_KERNEL(scale);
|
418
464
|
GGML_METAL_DEL_KERNEL(scale_4);
|
419
|
-
GGML_METAL_DEL_KERNEL(
|
465
|
+
GGML_METAL_DEL_KERNEL(tanh);
|
420
466
|
GGML_METAL_DEL_KERNEL(relu);
|
421
467
|
GGML_METAL_DEL_KERNEL(gelu);
|
468
|
+
GGML_METAL_DEL_KERNEL(gelu_quick);
|
469
|
+
GGML_METAL_DEL_KERNEL(silu);
|
422
470
|
GGML_METAL_DEL_KERNEL(soft_max);
|
423
471
|
GGML_METAL_DEL_KERNEL(soft_max_4);
|
424
472
|
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
@@ -436,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
436
484
|
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
437
485
|
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
438
486
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
487
|
+
GGML_METAL_DEL_KERNEL(group_norm);
|
439
488
|
GGML_METAL_DEL_KERNEL(norm);
|
440
489
|
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
441
490
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
@@ -452,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
452
501
|
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
|
453
502
|
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
454
503
|
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
504
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
|
505
|
+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
|
506
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
|
507
|
+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
|
508
|
+
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
|
509
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
|
510
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
|
511
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
|
512
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
|
513
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
|
514
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
|
515
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
|
516
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
|
517
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
|
518
|
+
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
|
455
519
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
456
520
|
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
457
521
|
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
@@ -482,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
482
546
|
GGML_METAL_DEL_KERNEL(rope_f16);
|
483
547
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
484
548
|
GGML_METAL_DEL_KERNEL(im2col_f16);
|
549
|
+
GGML_METAL_DEL_KERNEL(upscale_f32);
|
550
|
+
GGML_METAL_DEL_KERNEL(pad_f32);
|
485
551
|
GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
|
486
552
|
GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
|
553
|
+
GGML_METAL_DEL_KERNEL(leaky_relu_f32);
|
487
554
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
488
555
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
489
556
|
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
@@ -492,6 +559,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|
492
559
|
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
493
560
|
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
494
561
|
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
562
|
+
GGML_METAL_DEL_KERNEL(cpy_f16_f32);
|
495
563
|
GGML_METAL_DEL_KERNEL(concat);
|
496
564
|
GGML_METAL_DEL_KERNEL(sqr);
|
497
565
|
GGML_METAL_DEL_KERNEL(sum_rows);
|
@@ -793,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|
793
861
|
switch (op->op) {
|
794
862
|
case GGML_OP_UNARY:
|
795
863
|
switch (ggml_get_unary_op(op)) {
|
796
|
-
case
|
864
|
+
case GGML_UNARY_OP_TANH:
|
797
865
|
case GGML_UNARY_OP_RELU:
|
798
866
|
case GGML_UNARY_OP_GELU:
|
867
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
868
|
+
case GGML_UNARY_OP_SILU:
|
799
869
|
return true;
|
800
870
|
default:
|
801
871
|
return false;
|
@@ -807,6 +877,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|
807
877
|
case GGML_OP_PERMUTE:
|
808
878
|
case GGML_OP_CONCAT:
|
809
879
|
case GGML_OP_ADD:
|
880
|
+
case GGML_OP_ACC:
|
810
881
|
case GGML_OP_MUL:
|
811
882
|
case GGML_OP_DIV:
|
812
883
|
case GGML_OP_SCALE:
|
@@ -814,21 +885,50 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
|
814
885
|
case GGML_OP_SUM_ROWS:
|
815
886
|
case GGML_OP_SOFT_MAX:
|
816
887
|
case GGML_OP_RMS_NORM:
|
888
|
+
case GGML_OP_GROUP_NORM:
|
817
889
|
case GGML_OP_NORM:
|
818
890
|
case GGML_OP_ALIBI:
|
819
891
|
case GGML_OP_ROPE:
|
820
892
|
case GGML_OP_IM2COL:
|
893
|
+
case GGML_OP_UPSCALE:
|
894
|
+
case GGML_OP_PAD:
|
821
895
|
case GGML_OP_ARGSORT:
|
822
|
-
case
|
823
|
-
case GGML_OP_CPY:
|
824
|
-
case GGML_OP_CONT:
|
896
|
+
case GGML_OP_LEAKY_RELU:
|
825
897
|
case GGML_OP_MUL_MAT:
|
826
898
|
case GGML_OP_MUL_MAT_ID:
|
827
899
|
return true;
|
900
|
+
case GGML_OP_CPY:
|
901
|
+
case GGML_OP_DUP:
|
902
|
+
case GGML_OP_CONT:
|
903
|
+
{
|
904
|
+
switch (op->src[0]->type) {
|
905
|
+
case GGML_TYPE_F32:
|
906
|
+
switch (op->type) {
|
907
|
+
case GGML_TYPE_F16:
|
908
|
+
case GGML_TYPE_F32:
|
909
|
+
case GGML_TYPE_Q8_0:
|
910
|
+
case GGML_TYPE_Q4_0:
|
911
|
+
case GGML_TYPE_Q4_1:
|
912
|
+
return true;
|
913
|
+
default:
|
914
|
+
return false;
|
915
|
+
}
|
916
|
+
case GGML_TYPE_F16:
|
917
|
+
switch (op->type) {
|
918
|
+
case GGML_TYPE_F16:
|
919
|
+
case GGML_TYPE_F32:
|
920
|
+
return true;
|
921
|
+
default:
|
922
|
+
return false;
|
923
|
+
}
|
924
|
+
default:
|
925
|
+
return false;
|
926
|
+
};
|
927
|
+
}
|
828
928
|
case GGML_OP_DIAG_MASK_INF:
|
829
929
|
case GGML_OP_GET_ROWS:
|
830
930
|
{
|
831
|
-
return op->ne[
|
931
|
+
return op->ne[3] == 1;
|
832
932
|
}
|
833
933
|
default:
|
834
934
|
return false;
|
@@ -904,7 +1004,10 @@ void ggml_metal_graph_compute(
|
|
904
1004
|
} break;
|
905
1005
|
}
|
906
1006
|
|
907
|
-
|
1007
|
+
if (!ggml_metal_supports_op(dst)) {
|
1008
|
+
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
|
1009
|
+
GGML_ASSERT(!"unsupported op");
|
1010
|
+
}
|
908
1011
|
|
909
1012
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
910
1013
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
@@ -1001,34 +1104,39 @@ void ggml_metal_graph_compute(
|
|
1001
1104
|
case GGML_OP_MUL:
|
1002
1105
|
case GGML_OP_DIV:
|
1003
1106
|
{
|
1004
|
-
|
1005
|
-
GGML_ASSERT(ggml_is_contiguous(src1));
|
1107
|
+
const size_t offs = 0;
|
1006
1108
|
|
1007
1109
|
bool bcast_row = false;
|
1008
1110
|
|
1009
1111
|
int64_t nb = ne00;
|
1010
1112
|
|
1011
|
-
|
1113
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1114
|
+
|
1115
|
+
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
1116
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
1117
|
+
|
1012
1118
|
// src1 is a row
|
1013
1119
|
GGML_ASSERT(ne11 == 1);
|
1014
1120
|
|
1015
1121
|
nb = ne00 / 4;
|
1016
1122
|
switch (dst->op) {
|
1017
|
-
case GGML_OP_ADD:
|
1018
|
-
case GGML_OP_MUL:
|
1019
|
-
case GGML_OP_DIV:
|
1123
|
+
case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
|
1124
|
+
case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
|
1125
|
+
case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
|
1020
1126
|
default: GGML_ASSERT(false);
|
1021
1127
|
}
|
1022
1128
|
|
1023
1129
|
bcast_row = true;
|
1024
1130
|
} else {
|
1025
1131
|
switch (dst->op) {
|
1026
|
-
case GGML_OP_ADD:
|
1027
|
-
case GGML_OP_MUL:
|
1028
|
-
case GGML_OP_DIV:
|
1132
|
+
case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
|
1133
|
+
case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
|
1134
|
+
case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
|
1029
1135
|
default: GGML_ASSERT(false);
|
1030
1136
|
}
|
1031
1137
|
}
|
1138
|
+
|
1139
|
+
[encoder setComputePipelineState:pipeline];
|
1032
1140
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1033
1141
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1034
1142
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
@@ -1056,18 +1164,99 @@ void ggml_metal_graph_compute(
|
|
1056
1164
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
1057
1165
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
1058
1166
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
1059
|
-
[encoder setBytes:&
|
1167
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
1168
|
+
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
1060
1169
|
|
1061
1170
|
if (bcast_row) {
|
1062
1171
|
const int64_t n = ggml_nelements(dst)/4;
|
1063
1172
|
|
1064
1173
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1065
1174
|
} else {
|
1066
|
-
const int nth = MIN(
|
1175
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
1067
1176
|
|
1068
1177
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1069
1178
|
}
|
1070
1179
|
} break;
|
1180
|
+
case GGML_OP_ACC:
|
1181
|
+
{
|
1182
|
+
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
1183
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1184
|
+
GGML_ASSERT(dstt == GGML_TYPE_F32);
|
1185
|
+
|
1186
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
1187
|
+
GGML_ASSERT(ggml_is_contiguous(src1));
|
1188
|
+
|
1189
|
+
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
|
1190
|
+
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
|
1191
|
+
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
|
1192
|
+
const size_t offs = ((int32_t *) dst->op_params)[3];
|
1193
|
+
|
1194
|
+
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
1195
|
+
|
1196
|
+
if (!inplace) {
|
1197
|
+
// run a separete kernel to cpy src->dst
|
1198
|
+
// not sure how to avoid this
|
1199
|
+
// TODO: make a simpler cpy_bytes kernel
|
1200
|
+
|
1201
|
+
const int nth = MIN(1024, ne00);
|
1202
|
+
|
1203
|
+
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
1204
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1205
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1206
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1207
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1208
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1209
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
1210
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
1211
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
1212
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
1213
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
1214
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
1215
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
1216
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
1217
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
1218
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
1219
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
1220
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
1221
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
1222
|
+
|
1223
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1224
|
+
}
|
1225
|
+
|
1226
|
+
[encoder setComputePipelineState:ctx->pipeline_add];
|
1227
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1228
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1229
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1230
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1231
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1232
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1233
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
1234
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
1235
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
1236
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
1237
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
1238
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
1239
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
1240
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
1241
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
1242
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
1243
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
1244
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
1245
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
1246
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
1247
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
1248
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
1249
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
1250
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
1251
|
+
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
1252
|
+
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
1253
|
+
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
1254
|
+
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
1255
|
+
|
1256
|
+
const int nth = MIN(1024, ne0);
|
1257
|
+
|
1258
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1259
|
+
} break;
|
1071
1260
|
case GGML_OP_SCALE:
|
1072
1261
|
{
|
1073
1262
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
@@ -1091,16 +1280,15 @@ void ggml_metal_graph_compute(
|
|
1091
1280
|
} break;
|
1092
1281
|
case GGML_OP_UNARY:
|
1093
1282
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
1094
|
-
case
|
1283
|
+
case GGML_UNARY_OP_TANH:
|
1095
1284
|
{
|
1096
|
-
[encoder setComputePipelineState:ctx->
|
1285
|
+
[encoder setComputePipelineState:ctx->pipeline_tanh];
|
1097
1286
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1098
1287
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1099
1288
|
|
1100
1289
|
const int64_t n = ggml_nelements(dst);
|
1101
|
-
GGML_ASSERT(n % 4 == 0);
|
1102
1290
|
|
1103
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n
|
1291
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1104
1292
|
} break;
|
1105
1293
|
case GGML_UNARY_OP_RELU:
|
1106
1294
|
{
|
@@ -1121,6 +1309,28 @@ void ggml_metal_graph_compute(
|
|
1121
1309
|
const int64_t n = ggml_nelements(dst);
|
1122
1310
|
GGML_ASSERT(n % 4 == 0);
|
1123
1311
|
|
1312
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1313
|
+
} break;
|
1314
|
+
case GGML_UNARY_OP_GELU_QUICK:
|
1315
|
+
{
|
1316
|
+
[encoder setComputePipelineState:ctx->pipeline_gelu_quick];
|
1317
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1318
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1319
|
+
|
1320
|
+
const int64_t n = ggml_nelements(dst);
|
1321
|
+
GGML_ASSERT(n % 4 == 0);
|
1322
|
+
|
1323
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1324
|
+
} break;
|
1325
|
+
case GGML_UNARY_OP_SILU:
|
1326
|
+
{
|
1327
|
+
[encoder setComputePipelineState:ctx->pipeline_silu];
|
1328
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1329
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1330
|
+
|
1331
|
+
const int64_t n = ggml_nelements(dst);
|
1332
|
+
GGML_ASSERT(n % 4 == 0);
|
1333
|
+
|
1124
1334
|
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1125
1335
|
} break;
|
1126
1336
|
default:
|
@@ -1193,7 +1403,11 @@ void ggml_metal_graph_compute(
|
|
1193
1403
|
const float scale = ((float *) dst->op_params)[0];
|
1194
1404
|
|
1195
1405
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1196
|
-
|
1406
|
+
if (id_src1) {
|
1407
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1408
|
+
} else {
|
1409
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1410
|
+
}
|
1197
1411
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1198
1412
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1199
1413
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
@@ -1444,7 +1658,7 @@ void ggml_metal_graph_compute(
|
|
1444
1658
|
else if (src0t == GGML_TYPE_Q6_K) {
|
1445
1659
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1446
1660
|
} else {
|
1447
|
-
int64_t ny = (ne11 + nrows - 1)/nrows;
|
1661
|
+
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
1448
1662
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1449
1663
|
}
|
1450
1664
|
}
|
@@ -1456,7 +1670,7 @@ void ggml_metal_graph_compute(
|
|
1456
1670
|
|
1457
1671
|
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
1458
1672
|
|
1459
|
-
const int n_as =
|
1673
|
+
const int n_as = ((int32_t *) dst->op_params)[1];
|
1460
1674
|
|
1461
1675
|
// TODO: make this more general
|
1462
1676
|
GGML_ASSERT(n_as <= 8);
|
@@ -1488,14 +1702,22 @@ void ggml_metal_graph_compute(
|
|
1488
1702
|
|
1489
1703
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1490
1704
|
// to the matrix-vector kernel
|
1491
|
-
int ne11_mm_min =
|
1705
|
+
int ne11_mm_min = 1;
|
1492
1706
|
|
1493
1707
|
const int idx = ((int32_t *) dst->op_params)[0];
|
1494
1708
|
|
1709
|
+
// batch size
|
1710
|
+
GGML_ASSERT(ne01 == ne11);
|
1711
|
+
|
1712
|
+
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
|
1713
|
+
|
1495
1714
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1496
1715
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
1497
|
-
|
1498
|
-
|
1716
|
+
// !!!
|
1717
|
+
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
|
1718
|
+
// indirect matrix multiplication
|
1719
|
+
// !!!
|
1720
|
+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
|
1499
1721
|
switch (src2->type) {
|
1500
1722
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
|
1501
1723
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
|
@@ -1514,19 +1736,22 @@ void ggml_metal_graph_compute(
|
|
1514
1736
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1515
1737
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1516
1738
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1517
|
-
[encoder setBytes:&
|
1518
|
-
[encoder setBytes:&
|
1519
|
-
[encoder setBytes:&
|
1520
|
-
[encoder setBytes:&
|
1521
|
-
[encoder setBytes:&
|
1522
|
-
[encoder setBytes:&
|
1523
|
-
[encoder setBytes:&
|
1524
|
-
[encoder setBytes:&
|
1525
|
-
[encoder setBytes:&
|
1526
|
-
[encoder setBytes:&
|
1527
|
-
[encoder setBytes:&
|
1528
|
-
[encoder setBytes:&
|
1529
|
-
[encoder setBytes:&
|
1739
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
1740
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1741
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
1742
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
1743
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
1744
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
1745
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
1746
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
1747
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
1748
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
1749
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
1750
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
|
1751
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
1752
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
1753
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
1754
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
1530
1755
|
// TODO: how to make this an array? read Metal docs
|
1531
1756
|
for (int j = 0; j < n_as; ++j) {
|
1532
1757
|
struct ggml_tensor * src_cur = dst->src[2 + j];
|
@@ -1534,11 +1759,157 @@ void ggml_metal_graph_compute(
|
|
1534
1759
|
size_t offs_src_cur = 0;
|
1535
1760
|
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
1536
1761
|
|
1537
|
-
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:
|
1762
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
1538
1763
|
}
|
1539
1764
|
|
1540
1765
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
1541
|
-
|
1766
|
+
|
1767
|
+
// TODO: processing one row at a time (ne11 -> 1) is not efficient
|
1768
|
+
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1769
|
+
} else {
|
1770
|
+
int nth0 = 32;
|
1771
|
+
int nth1 = 1;
|
1772
|
+
int nrows = 1;
|
1773
|
+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
1774
|
+
|
1775
|
+
// use custom matrix x vector kernel
|
1776
|
+
switch (src2t) {
|
1777
|
+
case GGML_TYPE_F32:
|
1778
|
+
{
|
1779
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1780
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
|
1781
|
+
} break;
|
1782
|
+
case GGML_TYPE_F16:
|
1783
|
+
{
|
1784
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1785
|
+
nth0 = 32;
|
1786
|
+
nth1 = 1;
|
1787
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
|
1788
|
+
} break;
|
1789
|
+
case GGML_TYPE_Q4_0:
|
1790
|
+
{
|
1791
|
+
nth0 = 8;
|
1792
|
+
nth1 = 8;
|
1793
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
|
1794
|
+
} break;
|
1795
|
+
case GGML_TYPE_Q4_1:
|
1796
|
+
{
|
1797
|
+
nth0 = 8;
|
1798
|
+
nth1 = 8;
|
1799
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
|
1800
|
+
} break;
|
1801
|
+
case GGML_TYPE_Q5_0:
|
1802
|
+
{
|
1803
|
+
nth0 = 8;
|
1804
|
+
nth1 = 8;
|
1805
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
|
1806
|
+
} break;
|
1807
|
+
case GGML_TYPE_Q5_1:
|
1808
|
+
{
|
1809
|
+
nth0 = 8;
|
1810
|
+
nth1 = 8;
|
1811
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
|
1812
|
+
} break;
|
1813
|
+
case GGML_TYPE_Q8_0:
|
1814
|
+
{
|
1815
|
+
nth0 = 8;
|
1816
|
+
nth1 = 8;
|
1817
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
|
1818
|
+
} break;
|
1819
|
+
case GGML_TYPE_Q2_K:
|
1820
|
+
{
|
1821
|
+
nth0 = 2;
|
1822
|
+
nth1 = 32;
|
1823
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
|
1824
|
+
} break;
|
1825
|
+
case GGML_TYPE_Q3_K:
|
1826
|
+
{
|
1827
|
+
nth0 = 2;
|
1828
|
+
nth1 = 32;
|
1829
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
|
1830
|
+
} break;
|
1831
|
+
case GGML_TYPE_Q4_K:
|
1832
|
+
{
|
1833
|
+
nth0 = 4; //1;
|
1834
|
+
nth1 = 8; //32;
|
1835
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
|
1836
|
+
} break;
|
1837
|
+
case GGML_TYPE_Q5_K:
|
1838
|
+
{
|
1839
|
+
nth0 = 2;
|
1840
|
+
nth1 = 32;
|
1841
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
|
1842
|
+
} break;
|
1843
|
+
case GGML_TYPE_Q6_K:
|
1844
|
+
{
|
1845
|
+
nth0 = 2;
|
1846
|
+
nth1 = 32;
|
1847
|
+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
|
1848
|
+
} break;
|
1849
|
+
default:
|
1850
|
+
{
|
1851
|
+
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
1852
|
+
GGML_ASSERT(false && "not implemented");
|
1853
|
+
}
|
1854
|
+
};
|
1855
|
+
|
1856
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1857
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1858
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1859
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
1860
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1861
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
1862
|
+
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
1863
|
+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
1864
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
1865
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
1866
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
1867
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
1868
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1869
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1870
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1871
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1872
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1873
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
1874
|
+
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
1875
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
1876
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
1877
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
1878
|
+
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
1879
|
+
// TODO: how to make this an array? read Metal docs
|
1880
|
+
for (int j = 0; j < n_as; ++j) {
|
1881
|
+
struct ggml_tensor * src_cur = dst->src[2 + j];
|
1882
|
+
|
1883
|
+
size_t offs_src_cur = 0;
|
1884
|
+
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
1885
|
+
|
1886
|
+
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
1887
|
+
}
|
1888
|
+
|
1889
|
+
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
1890
|
+
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
1891
|
+
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
1892
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1893
|
+
}
|
1894
|
+
else if (src2t == GGML_TYPE_Q4_K) {
|
1895
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1896
|
+
}
|
1897
|
+
else if (src2t == GGML_TYPE_Q3_K) {
|
1898
|
+
#ifdef GGML_QKK_64
|
1899
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1900
|
+
#else
|
1901
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1902
|
+
#endif
|
1903
|
+
}
|
1904
|
+
else if (src2t == GGML_TYPE_Q5_K) {
|
1905
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1906
|
+
}
|
1907
|
+
else if (src2t == GGML_TYPE_Q6_K) {
|
1908
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1909
|
+
} else {
|
1910
|
+
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
1911
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1912
|
+
}
|
1542
1913
|
}
|
1543
1914
|
} break;
|
1544
1915
|
case GGML_OP_GET_ROWS:
|
@@ -1559,16 +1930,19 @@ void ggml_metal_graph_compute(
|
|
1559
1930
|
default: GGML_ASSERT(false && "not implemented");
|
1560
1931
|
}
|
1561
1932
|
|
1562
|
-
[encoder setBuffer:id_src0
|
1563
|
-
[encoder setBuffer:id_src1
|
1564
|
-
[encoder setBuffer:id_dst
|
1933
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1934
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1935
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1565
1936
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
1566
1937
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
1567
|
-
[encoder setBytes:&
|
1568
|
-
|
1569
|
-
|
1570
|
-
|
1571
|
-
[encoder
|
1938
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
1939
|
+
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
1940
|
+
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
1941
|
+
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
1942
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
1943
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
1944
|
+
|
1945
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
1572
1946
|
} break;
|
1573
1947
|
case GGML_OP_RMS_NORM:
|
1574
1948
|
{
|
@@ -1595,6 +1969,38 @@ void ggml_metal_graph_compute(
|
|
1595
1969
|
|
1596
1970
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1597
1971
|
} break;
|
1972
|
+
// GGML_OP_GROUP_NORM: encode one dispatch of the group_norm pipeline.
// Buffer slots: src0 at 0, dst at 1. Inline constants at 2..9 are
// ne00, ne01, ne02, nb00, nb01, nb02, n_groups and eps, in that order.
// NOTE(review): ne0x / nb0x are declared outside this hunk - presumably the
// src0 extents and byte strides; confirm against the top of the function.
case GGML_OP_GROUP_NORM:
|
1973
|
+
{
|
1974
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
1975
|
+
|
1976
|
+
//float eps;
|
1977
|
+
//memcpy(&eps, dst->op_params, sizeof(float));
|
1978
|
+
|
1979
|
+
// eps is a fixed constant here; the disabled lines above would have read it
// from dst->op_params instead.
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
1980
|
+
|
1981
|
+
// The first int32 slot of dst->op_params is used as the group count.
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
1982
|
+
|
1983
|
+
int nth = 32; // SIMD width
|
1984
|
+
|
1985
|
+
//while (nth < ne00/4 && nth < 1024) {
|
1986
|
+
// nth *= 2;
|
1987
|
+
//}
|
1988
|
+
|
1989
|
+
[encoder setComputePipelineState:ctx->pipeline_group_norm];
|
1990
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1991
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1992
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
1993
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
1994
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
1995
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
1996
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
1997
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
1998
|
+
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
1999
|
+
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
2000
|
+
// Reserves room for 32 floats of threadgroup memory at slot 0.
// NOTE(review): presumably scratch for the kernel's cross-thread reduction -
// confirm against the group_norm kernel.
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
2001
|
+
|
2002
|
+
// Grid: n_groups x 1 x 1 threadgroups, nth threads in each.
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2003
|
+
} break;
|
1598
2004
|
case GGML_OP_NORM:
|
1599
2005
|
{
|
1600
2006
|
float eps;
|
@@ -1764,6 +2170,65 @@ void ggml_metal_graph_compute(
|
|
1764
2170
|
|
1765
2171
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
1766
2172
|
} break;
|
2173
|
+
// GGML_OP_UPSCALE: encode one dispatch of the upscale_f32 pipeline.
// Only F32 sources are accepted (asserted below).
// Buffer slots: src0 at 0, dst at 1. Inline constants: ne00..ne03 at 2..5,
// nb00..nb03 at 6..9, ne0..ne3 at 10..13, nb0..nb3 at 14..17, sf at 18.
// NOTE(review): the ne*/nb* locals are declared outside this hunk - presumably
// src0 (ne0x/nb0x) and dst (nex/nbx) extents and byte strides; confirm.
case GGML_OP_UPSCALE:
|
2174
|
+
{
|
2175
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2176
|
+
|
2177
|
+
// Taken from the first op_params slot of dst.
// NOTE(review): presumably the integer scale factor - confirm against the
// code that builds the op.
const int sf = dst->op_params[0];
|
2178
|
+
|
2179
|
+
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
|
2180
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2181
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2182
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
2183
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
2184
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
2185
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
2186
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
2187
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
2188
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2189
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2190
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
2191
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
2192
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
2193
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
2194
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
2195
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2196
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2197
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2198
|
+
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
2199
|
+
|
2200
|
+
// Threads per threadgroup: ne0, capped at 1024.
const int nth = MIN(1024, ne0);
|
2201
|
+
|
2202
|
+
// Grid: ne1 x ne2 x ne3 threadgroups, nth threads in each.
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2203
|
+
} break;
|
2204
|
+
// GGML_OP_PAD: encode one dispatch of the pad_f32 pipeline.
// Only F32 sources are accepted (asserted below).
// Buffer slots: src0 at 0, dst at 1. Inline constants: ne00..ne03 at 2..5,
// nb00..nb03 at 6..9, ne0..ne3 at 10..13, nb0..nb3 at 14..17.
// NOTE(review): the ne*/nb* locals are declared outside this hunk - presumably
// src0 (ne0x/nb0x) and dst (nex/nbx) extents and byte strides; confirm.
case GGML_OP_PAD:
|
2205
|
+
{
|
2206
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2207
|
+
|
2208
|
+
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
|
2209
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2210
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2211
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
2212
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
2213
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
2214
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
2215
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
2216
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
2217
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
2218
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2219
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
2220
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
2221
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
2222
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
2223
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
2224
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2225
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2226
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2227
|
+
|
2228
|
+
// Threads per threadgroup: ne0, capped at 1024.
const int nth = MIN(1024, ne0);
|
2229
|
+
|
2230
|
+
// Grid: ne1 x ne2 x ne3 threadgroups, nth threads in each.
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2231
|
+
} break;
|
1767
2232
|
case GGML_OP_ARGSORT:
|
1768
2233
|
{
|
1769
2234
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
@@ -1785,6 +2250,22 @@ void ggml_metal_graph_compute(
|
|
1785
2250
|
|
1786
2251
|
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
1787
2252
|
} break;
|
2253
|
+
// GGML_OP_LEAKY_RELU: encode one dispatch of the leaky_relu_f32 pipeline.
// Only F32 sources are accepted (asserted below).
// Buffer slots: src0 at 0, dst at 1; the slope is passed inline at 2.
case GGML_OP_LEAKY_RELU:
|
2254
|
+
{
|
2255
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2256
|
+
|
2257
|
+
// The slope is stored as a float at the start of dst->op_params; it is
// copied out bytewise with memcpy instead of being read through a cast.
float slope;
|
2258
|
+
memcpy(&slope, dst->op_params, sizeof(float));
|
2259
|
+
|
2260
|
+
[encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
|
2261
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2262
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2263
|
+
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
2264
|
+
|
2265
|
+
const int64_t n = ggml_nelements(dst);
|
2266
|
+
|
2267
|
+
// Grid: n threadgroups of a single thread each, where n is the value
// returned by ggml_nelements(dst).
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2268
|
+
} break;
|
1788
2269
|
case GGML_OP_DUP:
|
1789
2270
|
case GGML_OP_CPY:
|
1790
2271
|
case GGML_OP_CONT:
|
@@ -1813,7 +2294,7 @@ void ggml_metal_graph_compute(
|
|
1813
2294
|
{
|
1814
2295
|
switch (dstt) {
|
1815
2296
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
|
1816
|
-
case GGML_TYPE_F32:
|
2297
|
+
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
|
1817
2298
|
default: GGML_ASSERT(false && "not implemented");
|
1818
2299
|
};
|
1819
2300
|
} break;
|