llama_cpp 0.14.7 → 0.15.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +53 -9
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +18 -3
- data/vendor/tmp/llama.cpp/Makefile +41 -16
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -5
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +6 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +376 -176
- data/vendor/tmp/llama.cpp/ggml-metal.metal +654 -18
- data/vendor/tmp/llama.cpp/ggml-quants.c +284 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +17 -7
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml.c +391 -27
- data/vendor/tmp/llama.cpp/ggml.h +22 -0
- data/vendor/tmp/llama.cpp/llama.cpp +623 -395
- data/vendor/tmp/llama.cpp/llama.h +27 -9
- data/vendor/tmp/llama.cpp/sgemm.cpp +83 -87
- data/vendor/tmp/llama.cpp/sgemm.h +4 -2
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1 -1
- data/vendor/tmp/llama.cpp/unicode-data.h +2 -2
- data/vendor/tmp/llama.cpp/unicode.cpp +448 -39
- data/vendor/tmp/llama.cpp/unicode.h +2 -1
- metadata +3 -3
@@ -46,8 +46,10 @@ enum ggml_metal_kernel_type {
|
|
46
46
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
47
47
|
GGML_METAL_KERNEL_TYPE_SILU,
|
48
48
|
GGML_METAL_KERNEL_TYPE_SILU_4,
|
49
|
-
|
50
|
-
|
49
|
+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
50
|
+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
51
|
+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
52
|
+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
|
51
53
|
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
52
54
|
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
53
55
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
@@ -177,6 +179,14 @@ enum ggml_metal_kernel_type {
|
|
177
179
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
178
180
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
179
181
|
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
182
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
|
183
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
|
184
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
185
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
186
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
187
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
188
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
189
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
180
190
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
181
191
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
182
192
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
@@ -443,7 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
443
453
|
}
|
444
454
|
|
445
455
|
/*
|
446
|
-
GGML_METAL_LOG_INFO("%s: loaded %-
|
456
|
+
GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
|
447
457
|
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
|
448
458
|
(int) kernel->pipeline.threadExecutionWidth); \
|
449
459
|
*/
|
@@ -459,172 +469,182 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
459
469
|
return NULL; \
|
460
470
|
} \
|
461
471
|
} else { \
|
462
|
-
GGML_METAL_LOG_WARN("%s: skipping %-
|
472
|
+
GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
|
463
473
|
}
|
464
474
|
|
465
475
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
466
476
|
|
467
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,
|
468
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
469
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,
|
470
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
471
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,
|
472
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
473
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,
|
474
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,
|
475
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,
|
476
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,
|
477
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,
|
478
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,
|
479
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,
|
480
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
481
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
482
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,
|
483
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,
|
484
|
-
GGML_METAL_ADD_KERNEL(
|
485
|
-
GGML_METAL_ADD_KERNEL(
|
486
|
-
GGML_METAL_ADD_KERNEL(
|
487
|
-
GGML_METAL_ADD_KERNEL(
|
488
|
-
GGML_METAL_ADD_KERNEL(
|
489
|
-
GGML_METAL_ADD_KERNEL(
|
490
|
-
GGML_METAL_ADD_KERNEL(
|
491
|
-
GGML_METAL_ADD_KERNEL(
|
492
|
-
GGML_METAL_ADD_KERNEL(
|
493
|
-
GGML_METAL_ADD_KERNEL(
|
494
|
-
GGML_METAL_ADD_KERNEL(
|
495
|
-
GGML_METAL_ADD_KERNEL(
|
496
|
-
GGML_METAL_ADD_KERNEL(
|
497
|
-
GGML_METAL_ADD_KERNEL(
|
498
|
-
GGML_METAL_ADD_KERNEL(
|
499
|
-
GGML_METAL_ADD_KERNEL(
|
500
|
-
GGML_METAL_ADD_KERNEL(
|
501
|
-
GGML_METAL_ADD_KERNEL(
|
502
|
-
GGML_METAL_ADD_KERNEL(
|
503
|
-
GGML_METAL_ADD_KERNEL(
|
504
|
-
GGML_METAL_ADD_KERNEL(
|
505
|
-
GGML_METAL_ADD_KERNEL(
|
506
|
-
GGML_METAL_ADD_KERNEL(
|
507
|
-
GGML_METAL_ADD_KERNEL(
|
508
|
-
GGML_METAL_ADD_KERNEL(
|
509
|
-
GGML_METAL_ADD_KERNEL(
|
510
|
-
GGML_METAL_ADD_KERNEL(
|
511
|
-
GGML_METAL_ADD_KERNEL(
|
512
|
-
GGML_METAL_ADD_KERNEL(
|
513
|
-
GGML_METAL_ADD_KERNEL(
|
514
|
-
GGML_METAL_ADD_KERNEL(
|
515
|
-
GGML_METAL_ADD_KERNEL(
|
516
|
-
GGML_METAL_ADD_KERNEL(
|
517
|
-
GGML_METAL_ADD_KERNEL(
|
518
|
-
GGML_METAL_ADD_KERNEL(
|
519
|
-
GGML_METAL_ADD_KERNEL(
|
520
|
-
GGML_METAL_ADD_KERNEL(
|
521
|
-
GGML_METAL_ADD_KERNEL(
|
522
|
-
GGML_METAL_ADD_KERNEL(
|
523
|
-
GGML_METAL_ADD_KERNEL(
|
524
|
-
GGML_METAL_ADD_KERNEL(
|
525
|
-
GGML_METAL_ADD_KERNEL(
|
526
|
-
GGML_METAL_ADD_KERNEL(
|
527
|
-
GGML_METAL_ADD_KERNEL(
|
528
|
-
GGML_METAL_ADD_KERNEL(
|
529
|
-
GGML_METAL_ADD_KERNEL(
|
530
|
-
GGML_METAL_ADD_KERNEL(
|
531
|
-
GGML_METAL_ADD_KERNEL(
|
532
|
-
GGML_METAL_ADD_KERNEL(
|
533
|
-
GGML_METAL_ADD_KERNEL(
|
534
|
-
GGML_METAL_ADD_KERNEL(
|
535
|
-
GGML_METAL_ADD_KERNEL(
|
536
|
-
GGML_METAL_ADD_KERNEL(
|
537
|
-
GGML_METAL_ADD_KERNEL(
|
538
|
-
|
539
|
-
GGML_METAL_ADD_KERNEL(
|
540
|
-
//GGML_METAL_ADD_KERNEL(
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
GGML_METAL_ADD_KERNEL(
|
545
|
-
GGML_METAL_ADD_KERNEL(
|
546
|
-
GGML_METAL_ADD_KERNEL(
|
547
|
-
GGML_METAL_ADD_KERNEL(
|
548
|
-
GGML_METAL_ADD_KERNEL(
|
549
|
-
GGML_METAL_ADD_KERNEL(
|
550
|
-
GGML_METAL_ADD_KERNEL(
|
551
|
-
GGML_METAL_ADD_KERNEL(
|
552
|
-
GGML_METAL_ADD_KERNEL(
|
553
|
-
GGML_METAL_ADD_KERNEL(
|
554
|
-
GGML_METAL_ADD_KERNEL(
|
555
|
-
GGML_METAL_ADD_KERNEL(
|
556
|
-
GGML_METAL_ADD_KERNEL(
|
557
|
-
GGML_METAL_ADD_KERNEL(
|
558
|
-
GGML_METAL_ADD_KERNEL(
|
559
|
-
GGML_METAL_ADD_KERNEL(
|
560
|
-
GGML_METAL_ADD_KERNEL(
|
561
|
-
GGML_METAL_ADD_KERNEL(
|
562
|
-
GGML_METAL_ADD_KERNEL(
|
563
|
-
GGML_METAL_ADD_KERNEL(
|
564
|
-
GGML_METAL_ADD_KERNEL(
|
565
|
-
GGML_METAL_ADD_KERNEL(
|
566
|
-
GGML_METAL_ADD_KERNEL(
|
567
|
-
GGML_METAL_ADD_KERNEL(
|
568
|
-
GGML_METAL_ADD_KERNEL(
|
569
|
-
GGML_METAL_ADD_KERNEL(
|
570
|
-
GGML_METAL_ADD_KERNEL(
|
571
|
-
GGML_METAL_ADD_KERNEL(
|
572
|
-
GGML_METAL_ADD_KERNEL(
|
573
|
-
GGML_METAL_ADD_KERNEL(
|
574
|
-
GGML_METAL_ADD_KERNEL(
|
575
|
-
GGML_METAL_ADD_KERNEL(
|
576
|
-
GGML_METAL_ADD_KERNEL(
|
577
|
-
GGML_METAL_ADD_KERNEL(
|
578
|
-
GGML_METAL_ADD_KERNEL(
|
579
|
-
GGML_METAL_ADD_KERNEL(
|
580
|
-
GGML_METAL_ADD_KERNEL(
|
581
|
-
GGML_METAL_ADD_KERNEL(
|
582
|
-
GGML_METAL_ADD_KERNEL(
|
583
|
-
GGML_METAL_ADD_KERNEL(
|
584
|
-
GGML_METAL_ADD_KERNEL(
|
585
|
-
GGML_METAL_ADD_KERNEL(
|
586
|
-
GGML_METAL_ADD_KERNEL(
|
587
|
-
GGML_METAL_ADD_KERNEL(
|
588
|
-
GGML_METAL_ADD_KERNEL(
|
589
|
-
GGML_METAL_ADD_KERNEL(
|
590
|
-
GGML_METAL_ADD_KERNEL(
|
591
|
-
GGML_METAL_ADD_KERNEL(
|
592
|
-
GGML_METAL_ADD_KERNEL(
|
593
|
-
GGML_METAL_ADD_KERNEL(
|
594
|
-
GGML_METAL_ADD_KERNEL(
|
595
|
-
GGML_METAL_ADD_KERNEL(
|
596
|
-
GGML_METAL_ADD_KERNEL(
|
597
|
-
GGML_METAL_ADD_KERNEL(
|
598
|
-
GGML_METAL_ADD_KERNEL(
|
599
|
-
GGML_METAL_ADD_KERNEL(
|
600
|
-
GGML_METAL_ADD_KERNEL(
|
601
|
-
GGML_METAL_ADD_KERNEL(
|
602
|
-
GGML_METAL_ADD_KERNEL(
|
603
|
-
GGML_METAL_ADD_KERNEL(
|
604
|
-
GGML_METAL_ADD_KERNEL(
|
605
|
-
GGML_METAL_ADD_KERNEL(
|
606
|
-
GGML_METAL_ADD_KERNEL(
|
607
|
-
GGML_METAL_ADD_KERNEL(
|
608
|
-
GGML_METAL_ADD_KERNEL(
|
609
|
-
GGML_METAL_ADD_KERNEL(
|
610
|
-
GGML_METAL_ADD_KERNEL(
|
611
|
-
GGML_METAL_ADD_KERNEL(
|
612
|
-
GGML_METAL_ADD_KERNEL(
|
613
|
-
GGML_METAL_ADD_KERNEL(
|
614
|
-
GGML_METAL_ADD_KERNEL(
|
615
|
-
GGML_METAL_ADD_KERNEL(
|
616
|
-
GGML_METAL_ADD_KERNEL(
|
617
|
-
GGML_METAL_ADD_KERNEL(
|
618
|
-
GGML_METAL_ADD_KERNEL(
|
619
|
-
GGML_METAL_ADD_KERNEL(
|
620
|
-
GGML_METAL_ADD_KERNEL(
|
621
|
-
GGML_METAL_ADD_KERNEL(
|
622
|
-
GGML_METAL_ADD_KERNEL(
|
623
|
-
GGML_METAL_ADD_KERNEL(
|
624
|
-
GGML_METAL_ADD_KERNEL(
|
625
|
-
GGML_METAL_ADD_KERNEL(
|
626
|
-
GGML_METAL_ADD_KERNEL(
|
627
|
-
GGML_METAL_ADD_KERNEL(
|
477
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
478
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
479
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
480
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
481
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
482
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
483
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
484
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
485
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
486
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
487
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
488
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
489
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
490
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
491
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
492
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
493
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
494
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
495
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
496
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
497
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
|
498
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
499
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
500
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
501
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
502
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
503
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
504
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
505
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
506
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
507
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
508
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
509
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
510
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
|
511
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
512
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
513
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
514
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
515
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
516
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
517
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
518
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
|
519
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
520
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
521
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
522
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
523
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
524
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
525
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
526
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
527
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
528
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
529
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
530
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
531
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
532
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
533
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
|
534
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
|
535
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
|
536
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
|
537
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
|
538
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
|
539
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
540
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
541
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
542
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
543
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
544
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
545
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
546
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
|
547
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
548
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
549
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
550
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
551
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
552
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
|
553
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
|
554
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
|
555
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
|
556
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
|
557
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
|
558
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
|
559
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
|
560
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
|
561
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
|
562
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
|
563
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
564
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
565
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
566
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
567
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
568
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
569
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
570
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
|
571
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
572
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
573
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
574
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
575
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
576
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
577
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
578
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
|
579
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
|
580
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
|
581
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
|
582
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
|
583
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
|
584
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
585
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
586
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
587
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
588
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
589
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
590
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
591
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
|
592
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
593
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
594
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
595
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
596
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
597
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
598
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
599
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
|
600
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
|
601
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
|
602
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
|
603
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
|
604
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
|
605
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
606
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
607
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
608
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
609
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
610
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
611
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
612
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
613
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
614
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
615
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
616
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
617
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
618
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
619
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
620
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
621
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
622
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
623
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
624
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
625
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
626
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
627
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
628
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
629
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
630
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
631
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
632
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
633
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
|
634
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
|
635
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
636
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
637
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
638
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
639
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
640
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
641
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
642
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
643
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
644
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
645
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
646
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
647
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
628
648
|
}
|
629
649
|
|
630
650
|
[metal_library release];
|
@@ -743,6 +763,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
743
763
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
744
764
|
case GGML_OP_ARGSORT:
|
745
765
|
case GGML_OP_LEAKY_RELU:
|
766
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
746
767
|
return true;
|
747
768
|
case GGML_OP_MUL_MAT:
|
748
769
|
case GGML_OP_MUL_MAT_ID:
|
@@ -1326,20 +1347,33 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1326
1347
|
} break;
|
1327
1348
|
case GGML_OP_SOFT_MAX:
|
1328
1349
|
{
|
1350
|
+
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
1351
|
+
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
|
1352
|
+
|
1329
1353
|
int nth = 32; // SIMD width
|
1330
1354
|
|
1331
1355
|
id<MTLComputePipelineState> pipeline = nil;
|
1332
1356
|
|
1357
|
+
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
1358
|
+
|
1333
1359
|
if (ne00%4 == 0) {
|
1334
1360
|
while (nth < ne00/4 && nth < 256) {
|
1335
1361
|
nth *= 2;
|
1336
1362
|
}
|
1337
|
-
|
1363
|
+
if (use_f16) {
|
1364
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
|
1365
|
+
} else {
|
1366
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
|
1367
|
+
}
|
1338
1368
|
} else {
|
1339
1369
|
while (nth < ne00 && nth < 1024) {
|
1340
1370
|
nth *= 2;
|
1341
1371
|
}
|
1342
|
-
|
1372
|
+
if (use_f16) {
|
1373
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
|
1374
|
+
} else {
|
1375
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
|
1376
|
+
}
|
1343
1377
|
}
|
1344
1378
|
|
1345
1379
|
float scale;
|
@@ -2503,6 +2537,161 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2503
2537
|
|
2504
2538
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2505
2539
|
} break;
|
2540
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
2541
|
+
{
|
2542
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
2543
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2544
|
+
|
2545
|
+
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
2546
|
+
|
2547
|
+
GGML_ASSERT(ggml_are_same_shape(src1, src2));
|
2548
|
+
GGML_ASSERT(src3);
|
2549
|
+
|
2550
|
+
size_t offs_src3 = 0;
|
2551
|
+
|
2552
|
+
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
2553
|
+
|
2554
|
+
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
|
2555
|
+
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
2556
|
+
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
2557
|
+
|
2558
|
+
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
2559
|
+
const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
2560
|
+
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
2561
|
+
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
|
2562
|
+
|
2563
|
+
const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
|
2564
|
+
const uint64_t nb31 = src3 ? src3->nb[1] : 0;
|
2565
|
+
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
|
2566
|
+
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
|
2567
|
+
|
2568
|
+
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
2569
|
+
|
2570
|
+
float scale;
|
2571
|
+
memcpy(&scale, dst->op_params, sizeof(float));
|
2572
|
+
|
2573
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2574
|
+
|
2575
|
+
bool use_vec_kernel = false;
|
2576
|
+
|
2577
|
+
if (ne01 >= 4 || (ne00%128 != 0)) {
|
2578
|
+
switch (ne00) {
|
2579
|
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
2580
|
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
2581
|
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
2582
|
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
2583
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
2584
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
2585
|
+
default:
|
2586
|
+
{
|
2587
|
+
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
2588
|
+
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
2589
|
+
GGML_ASSERT(false && "add template specialization for this size");
|
2590
|
+
}
|
2591
|
+
}
|
2592
|
+
} else {
|
2593
|
+
use_vec_kernel = true;
|
2594
|
+
|
2595
|
+
switch (ne00) {
|
2596
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
2597
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
2598
|
+
default:
|
2599
|
+
{
|
2600
|
+
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
2601
|
+
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
2602
|
+
GGML_ASSERT(false && "add template specialization for this size");
|
2603
|
+
}
|
2604
|
+
}
|
2605
|
+
}
|
2606
|
+
|
2607
|
+
[encoder setComputePipelineState:pipeline];
|
2608
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2609
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2610
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2611
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
2612
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
2613
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
|
2614
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
|
2615
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
|
2616
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
|
2617
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
|
2618
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
|
2619
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
|
2620
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
|
2621
|
+
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
|
2622
|
+
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
|
2623
|
+
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
|
2624
|
+
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
|
2625
|
+
[encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
|
2626
|
+
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
|
2627
|
+
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
|
2628
|
+
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
|
2629
|
+
[encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
|
2630
|
+
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
|
2631
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
|
2632
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
|
2633
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
|
2634
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
2635
|
+
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
2636
|
+
|
2637
|
+
if (!use_vec_kernel) {
|
2638
|
+
// half8x8 kernel
|
2639
|
+
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
2640
|
+
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
2641
|
+
|
2642
|
+
GGML_ASSERT(nqptg <= 32);
|
2643
|
+
GGML_ASSERT(nqptg % 8 == 0);
|
2644
|
+
GGML_ASSERT(ncpsg % 32 == 0);
|
2645
|
+
|
2646
|
+
int64_t nsgmax = 2;
|
2647
|
+
|
2648
|
+
while (true) {
|
2649
|
+
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
|
2650
|
+
if (smem > ctx->device.maxThreadgroupMemoryLength) {
|
2651
|
+
break;
|
2652
|
+
}
|
2653
|
+
nsgmax *= 2;
|
2654
|
+
}
|
2655
|
+
nsgmax /= 2;
|
2656
|
+
|
2657
|
+
// simdgroups per threadgroup (a.k.a. warps)
|
2658
|
+
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
2659
|
+
|
2660
|
+
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
2661
|
+
|
2662
|
+
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
2663
|
+
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
2664
|
+
|
2665
|
+
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
2666
|
+
|
2667
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
2668
|
+
} else {
|
2669
|
+
// half1x4 kernel
|
2670
|
+
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
2671
|
+
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
2672
|
+
|
2673
|
+
GGML_ASSERT(nqptg <= 32);
|
2674
|
+
GGML_ASSERT(nqptg % 1 == 0);
|
2675
|
+
GGML_ASSERT(ncpsg % 32 == 0);
|
2676
|
+
|
2677
|
+
// simdgroups per threadgroup (a.k.a. warps)
|
2678
|
+
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
2679
|
+
|
2680
|
+
int64_t nsg = 1;
|
2681
|
+
while (nsg <= nsgt) {
|
2682
|
+
nsg *= 2;
|
2683
|
+
}
|
2684
|
+
nsg /= 2;
|
2685
|
+
|
2686
|
+
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
2687
|
+
|
2688
|
+
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
2689
|
+
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
2690
|
+
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
2691
|
+
|
2692
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
2693
|
+
}
|
2694
|
+
} break;
|
2506
2695
|
case GGML_OP_DUP:
|
2507
2696
|
case GGML_OP_CPY:
|
2508
2697
|
case GGML_OP_CONT:
|
@@ -2590,6 +2779,11 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2590
2779
|
MTLCommandBufferStatus status = [command_buffer status];
|
2591
2780
|
if (status != MTLCommandBufferStatusCompleted) {
|
2592
2781
|
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
2782
|
+
if (status == MTLCommandBufferStatusError) {
|
2783
|
+
NSString * error_code = [command_buffer error].localizedDescription;
|
2784
|
+
GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]);
|
2785
|
+
}
|
2786
|
+
|
2593
2787
|
return GGML_STATUS_FAILED;
|
2594
2788
|
}
|
2595
2789
|
}
|
@@ -2706,10 +2900,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe
|
|
2706
2900
|
UNUSED(buft);
|
2707
2901
|
}
|
2708
2902
|
|
2709
|
-
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
|
2903
|
+
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
|
2904
|
+
#ifndef GGML_METAL_NDEBUG
|
2710
2905
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
2711
2906
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
2712
|
-
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
2907
|
+
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
|
2908
|
+
__func__,
|
2909
|
+
size_aligned / 1024.0 / 1024.0,
|
2713
2910
|
device.currentAllocatedSize / 1024.0 / 1024.0,
|
2714
2911
|
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
2715
2912
|
|
@@ -2719,10 +2916,15 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
|
|
2719
2916
|
GGML_METAL_LOG_INFO("\n");
|
2720
2917
|
}
|
2721
2918
|
} else {
|
2722
|
-
GGML_METAL_LOG_INFO(", (%8.2f)\n",
|
2919
|
+
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
|
2920
|
+
__func__,
|
2921
|
+
size_aligned / 1024.0 / 1024.0,
|
2922
|
+
device.currentAllocatedSize / 1024.0 / 1024.0);
|
2723
2923
|
}
|
2924
|
+
#endif
|
2724
2925
|
#endif
|
2725
2926
|
UNUSED(device);
|
2927
|
+
UNUSED(size_aligned);
|
2726
2928
|
}
|
2727
2929
|
|
2728
2930
|
GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
@@ -2756,8 +2958,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
|
|
2756
2958
|
return NULL;
|
2757
2959
|
}
|
2758
2960
|
|
2759
|
-
|
2760
|
-
ggml_backend_metal_log_allocated_size(device);
|
2961
|
+
//ggml_backend_metal_log_allocated_size(device, size_aligned);
|
2761
2962
|
|
2762
2963
|
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
|
2763
2964
|
}
|
@@ -2844,7 +3045,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
|
|
2844
3045
|
return false;
|
2845
3046
|
}
|
2846
3047
|
|
2847
|
-
|
3048
|
+
ggml_backend_metal_log_allocated_size(device, size_aligned);
|
2848
3049
|
|
2849
3050
|
++ctx->n_buffers;
|
2850
3051
|
} else {
|
@@ -2867,7 +3068,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
|
|
2867
3068
|
return false;
|
2868
3069
|
}
|
2869
3070
|
|
2870
|
-
|
3071
|
+
ggml_backend_metal_log_allocated_size(device, size_step_aligned);
|
3072
|
+
|
2871
3073
|
if (i + size_step < size) {
|
2872
3074
|
GGML_METAL_LOG_INFO("\n");
|
2873
3075
|
}
|
@@ -2876,8 +3078,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
|
|
2876
3078
|
}
|
2877
3079
|
}
|
2878
3080
|
|
2879
|
-
ggml_backend_metal_log_allocated_size(device);
|
2880
|
-
|
2881
3081
|
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
|
2882
3082
|
}
|
2883
3083
|
|