llama_cpp 0.14.7 → 0.15.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/README.md +2 -2
- data/ext/llama_cpp/extconf.rb +2 -1
- data/ext/llama_cpp/llama_cpp.cpp +53 -9
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +18 -3
- data/vendor/tmp/llama.cpp/Makefile +41 -16
- data/vendor/tmp/llama.cpp/ggml-backend.c +7 -5
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +6 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +7 -0
- data/vendor/tmp/llama.cpp/ggml-metal.m +376 -176
- data/vendor/tmp/llama.cpp/ggml-metal.metal +654 -18
- data/vendor/tmp/llama.cpp/ggml-quants.c +284 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +17 -7
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +5 -0
- data/vendor/tmp/llama.cpp/ggml.c +391 -27
- data/vendor/tmp/llama.cpp/ggml.h +22 -0
- data/vendor/tmp/llama.cpp/llama.cpp +623 -395
- data/vendor/tmp/llama.cpp/llama.h +27 -9
- data/vendor/tmp/llama.cpp/sgemm.cpp +83 -87
- data/vendor/tmp/llama.cpp/sgemm.h +4 -2
- data/vendor/tmp/llama.cpp/unicode-data.cpp +1 -1
- data/vendor/tmp/llama.cpp/unicode-data.h +2 -2
- data/vendor/tmp/llama.cpp/unicode.cpp +448 -39
- data/vendor/tmp/llama.cpp/unicode.h +2 -1
- metadata +3 -3
@@ -46,8 +46,10 @@ enum ggml_metal_kernel_type {
|
|
46
46
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
47
47
|
GGML_METAL_KERNEL_TYPE_SILU,
|
48
48
|
GGML_METAL_KERNEL_TYPE_SILU_4,
|
49
|
-
|
50
|
-
|
49
|
+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
50
|
+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
51
|
+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
52
|
+
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
|
51
53
|
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
52
54
|
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
53
55
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
@@ -177,6 +179,14 @@ enum ggml_metal_kernel_type {
|
|
177
179
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
178
180
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
179
181
|
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
182
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
|
183
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
|
184
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
185
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
186
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
187
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
188
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
189
|
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
180
190
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
181
191
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
182
192
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
@@ -443,7 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
443
453
|
}
|
444
454
|
|
445
455
|
/*
|
446
|
-
GGML_METAL_LOG_INFO("%s: loaded %-
|
456
|
+
GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
|
447
457
|
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
|
448
458
|
(int) kernel->pipeline.threadExecutionWidth); \
|
449
459
|
*/
|
@@ -459,172 +469,182 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
459
469
|
return NULL; \
|
460
470
|
} \
|
461
471
|
} else { \
|
462
|
-
GGML_METAL_LOG_WARN("%s: skipping %-
|
472
|
+
GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
|
463
473
|
}
|
464
474
|
|
465
475
|
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
466
476
|
|
467
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,
|
468
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
469
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,
|
470
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
471
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,
|
472
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
473
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,
|
474
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,
|
475
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,
|
476
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,
|
477
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,
|
478
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,
|
479
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,
|
480
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
481
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
482
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,
|
483
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,
|
484
|
-
GGML_METAL_ADD_KERNEL(
|
485
|
-
GGML_METAL_ADD_KERNEL(
|
486
|
-
GGML_METAL_ADD_KERNEL(
|
487
|
-
GGML_METAL_ADD_KERNEL(
|
488
|
-
GGML_METAL_ADD_KERNEL(
|
489
|
-
GGML_METAL_ADD_KERNEL(
|
490
|
-
GGML_METAL_ADD_KERNEL(
|
491
|
-
GGML_METAL_ADD_KERNEL(
|
492
|
-
GGML_METAL_ADD_KERNEL(
|
493
|
-
GGML_METAL_ADD_KERNEL(
|
494
|
-
GGML_METAL_ADD_KERNEL(
|
495
|
-
GGML_METAL_ADD_KERNEL(
|
496
|
-
GGML_METAL_ADD_KERNEL(
|
497
|
-
GGML_METAL_ADD_KERNEL(
|
498
|
-
GGML_METAL_ADD_KERNEL(
|
499
|
-
GGML_METAL_ADD_KERNEL(
|
500
|
-
GGML_METAL_ADD_KERNEL(
|
501
|
-
GGML_METAL_ADD_KERNEL(
|
502
|
-
GGML_METAL_ADD_KERNEL(
|
503
|
-
GGML_METAL_ADD_KERNEL(
|
504
|
-
GGML_METAL_ADD_KERNEL(
|
505
|
-
GGML_METAL_ADD_KERNEL(
|
506
|
-
GGML_METAL_ADD_KERNEL(
|
507
|
-
GGML_METAL_ADD_KERNEL(
|
508
|
-
GGML_METAL_ADD_KERNEL(
|
509
|
-
GGML_METAL_ADD_KERNEL(
|
510
|
-
GGML_METAL_ADD_KERNEL(
|
511
|
-
GGML_METAL_ADD_KERNEL(
|
512
|
-
GGML_METAL_ADD_KERNEL(
|
513
|
-
GGML_METAL_ADD_KERNEL(
|
514
|
-
GGML_METAL_ADD_KERNEL(
|
515
|
-
GGML_METAL_ADD_KERNEL(
|
516
|
-
GGML_METAL_ADD_KERNEL(
|
517
|
-
GGML_METAL_ADD_KERNEL(
|
518
|
-
GGML_METAL_ADD_KERNEL(
|
519
|
-
GGML_METAL_ADD_KERNEL(
|
520
|
-
GGML_METAL_ADD_KERNEL(
|
521
|
-
GGML_METAL_ADD_KERNEL(
|
522
|
-
GGML_METAL_ADD_KERNEL(
|
523
|
-
GGML_METAL_ADD_KERNEL(
|
524
|
-
GGML_METAL_ADD_KERNEL(
|
525
|
-
GGML_METAL_ADD_KERNEL(
|
526
|
-
GGML_METAL_ADD_KERNEL(
|
527
|
-
GGML_METAL_ADD_KERNEL(
|
528
|
-
GGML_METAL_ADD_KERNEL(
|
529
|
-
GGML_METAL_ADD_KERNEL(
|
530
|
-
GGML_METAL_ADD_KERNEL(
|
531
|
-
GGML_METAL_ADD_KERNEL(
|
532
|
-
GGML_METAL_ADD_KERNEL(
|
533
|
-
GGML_METAL_ADD_KERNEL(
|
534
|
-
GGML_METAL_ADD_KERNEL(
|
535
|
-
GGML_METAL_ADD_KERNEL(
|
536
|
-
GGML_METAL_ADD_KERNEL(
|
537
|
-
GGML_METAL_ADD_KERNEL(
|
538
|
-
|
539
|
-
GGML_METAL_ADD_KERNEL(
|
540
|
-
//GGML_METAL_ADD_KERNEL(
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
GGML_METAL_ADD_KERNEL(
|
545
|
-
GGML_METAL_ADD_KERNEL(
|
546
|
-
GGML_METAL_ADD_KERNEL(
|
547
|
-
GGML_METAL_ADD_KERNEL(
|
548
|
-
GGML_METAL_ADD_KERNEL(
|
549
|
-
GGML_METAL_ADD_KERNEL(
|
550
|
-
GGML_METAL_ADD_KERNEL(
|
551
|
-
GGML_METAL_ADD_KERNEL(
|
552
|
-
GGML_METAL_ADD_KERNEL(
|
553
|
-
GGML_METAL_ADD_KERNEL(
|
554
|
-
GGML_METAL_ADD_KERNEL(
|
555
|
-
GGML_METAL_ADD_KERNEL(
|
556
|
-
GGML_METAL_ADD_KERNEL(
|
557
|
-
GGML_METAL_ADD_KERNEL(
|
558
|
-
GGML_METAL_ADD_KERNEL(
|
559
|
-
GGML_METAL_ADD_KERNEL(
|
560
|
-
GGML_METAL_ADD_KERNEL(
|
561
|
-
GGML_METAL_ADD_KERNEL(
|
562
|
-
GGML_METAL_ADD_KERNEL(
|
563
|
-
GGML_METAL_ADD_KERNEL(
|
564
|
-
GGML_METAL_ADD_KERNEL(
|
565
|
-
GGML_METAL_ADD_KERNEL(
|
566
|
-
GGML_METAL_ADD_KERNEL(
|
567
|
-
GGML_METAL_ADD_KERNEL(
|
568
|
-
GGML_METAL_ADD_KERNEL(
|
569
|
-
GGML_METAL_ADD_KERNEL(
|
570
|
-
GGML_METAL_ADD_KERNEL(
|
571
|
-
GGML_METAL_ADD_KERNEL(
|
572
|
-
GGML_METAL_ADD_KERNEL(
|
573
|
-
GGML_METAL_ADD_KERNEL(
|
574
|
-
GGML_METAL_ADD_KERNEL(
|
575
|
-
GGML_METAL_ADD_KERNEL(
|
576
|
-
GGML_METAL_ADD_KERNEL(
|
577
|
-
GGML_METAL_ADD_KERNEL(
|
578
|
-
GGML_METAL_ADD_KERNEL(
|
579
|
-
GGML_METAL_ADD_KERNEL(
|
580
|
-
GGML_METAL_ADD_KERNEL(
|
581
|
-
GGML_METAL_ADD_KERNEL(
|
582
|
-
GGML_METAL_ADD_KERNEL(
|
583
|
-
GGML_METAL_ADD_KERNEL(
|
584
|
-
GGML_METAL_ADD_KERNEL(
|
585
|
-
GGML_METAL_ADD_KERNEL(
|
586
|
-
GGML_METAL_ADD_KERNEL(
|
587
|
-
GGML_METAL_ADD_KERNEL(
|
588
|
-
GGML_METAL_ADD_KERNEL(
|
589
|
-
GGML_METAL_ADD_KERNEL(
|
590
|
-
GGML_METAL_ADD_KERNEL(
|
591
|
-
GGML_METAL_ADD_KERNEL(
|
592
|
-
GGML_METAL_ADD_KERNEL(
|
593
|
-
GGML_METAL_ADD_KERNEL(
|
594
|
-
GGML_METAL_ADD_KERNEL(
|
595
|
-
GGML_METAL_ADD_KERNEL(
|
596
|
-
GGML_METAL_ADD_KERNEL(
|
597
|
-
GGML_METAL_ADD_KERNEL(
|
598
|
-
GGML_METAL_ADD_KERNEL(
|
599
|
-
GGML_METAL_ADD_KERNEL(
|
600
|
-
GGML_METAL_ADD_KERNEL(
|
601
|
-
GGML_METAL_ADD_KERNEL(
|
602
|
-
GGML_METAL_ADD_KERNEL(
|
603
|
-
GGML_METAL_ADD_KERNEL(
|
604
|
-
GGML_METAL_ADD_KERNEL(
|
605
|
-
GGML_METAL_ADD_KERNEL(
|
606
|
-
GGML_METAL_ADD_KERNEL(
|
607
|
-
GGML_METAL_ADD_KERNEL(
|
608
|
-
GGML_METAL_ADD_KERNEL(
|
609
|
-
GGML_METAL_ADD_KERNEL(
|
610
|
-
GGML_METAL_ADD_KERNEL(
|
611
|
-
GGML_METAL_ADD_KERNEL(
|
612
|
-
GGML_METAL_ADD_KERNEL(
|
613
|
-
GGML_METAL_ADD_KERNEL(
|
614
|
-
GGML_METAL_ADD_KERNEL(
|
615
|
-
GGML_METAL_ADD_KERNEL(
|
616
|
-
GGML_METAL_ADD_KERNEL(
|
617
|
-
GGML_METAL_ADD_KERNEL(
|
618
|
-
GGML_METAL_ADD_KERNEL(
|
619
|
-
GGML_METAL_ADD_KERNEL(
|
620
|
-
GGML_METAL_ADD_KERNEL(
|
621
|
-
GGML_METAL_ADD_KERNEL(
|
622
|
-
GGML_METAL_ADD_KERNEL(
|
623
|
-
GGML_METAL_ADD_KERNEL(
|
624
|
-
GGML_METAL_ADD_KERNEL(
|
625
|
-
GGML_METAL_ADD_KERNEL(
|
626
|
-
GGML_METAL_ADD_KERNEL(
|
627
|
-
GGML_METAL_ADD_KERNEL(
|
477
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
478
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
479
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
480
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
481
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
482
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
483
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
484
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
485
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
486
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
487
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
488
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
489
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
490
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
491
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
492
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
493
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
494
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
495
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
496
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
497
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
|
498
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
499
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
500
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
501
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
|
502
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
|
503
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
|
504
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
|
505
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
|
506
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
|
507
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
|
508
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
|
509
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
|
510
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
|
511
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
|
512
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
|
513
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
|
514
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
|
515
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
|
516
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
517
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
518
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
|
519
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
520
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
521
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
522
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
523
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
524
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
525
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
526
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
527
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
528
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
|
529
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
|
530
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
|
531
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
|
532
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
|
533
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
|
534
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
|
535
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
|
536
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
|
537
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
|
538
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
|
539
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
|
540
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
541
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
542
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
543
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
|
544
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
545
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
546
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
|
547
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
548
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
549
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
550
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
551
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
552
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
|
553
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
|
554
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
|
555
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
|
556
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
|
557
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
|
558
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
|
559
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
|
560
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
|
561
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
|
562
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
|
563
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
564
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
565
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
566
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
567
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
|
568
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
569
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
570
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
|
571
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
572
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
573
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
574
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
575
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
576
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
577
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
578
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
|
579
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
|
580
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
|
581
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
|
582
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
|
583
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
|
584
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
585
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
586
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
587
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
588
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
|
589
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
590
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
591
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
|
592
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
593
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
594
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
595
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
596
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
597
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
598
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
599
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
|
600
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
|
601
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
|
602
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
|
603
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
|
604
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
|
605
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
606
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
607
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
608
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
609
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
|
610
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
611
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
612
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
|
613
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
614
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
615
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
616
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
617
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
618
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
619
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
620
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
621
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
622
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
623
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
624
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
625
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
626
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
627
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
628
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
629
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
630
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
631
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
632
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
633
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
|
634
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
|
635
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
636
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
637
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
638
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
639
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
640
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
641
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
642
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
643
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
644
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
645
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
646
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
647
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
628
648
|
}
|
629
649
|
|
630
650
|
[metal_library release];
|
@@ -743,6 +763,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
743
763
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
744
764
|
case GGML_OP_ARGSORT:
|
745
765
|
case GGML_OP_LEAKY_RELU:
|
766
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
746
767
|
return true;
|
747
768
|
case GGML_OP_MUL_MAT:
|
748
769
|
case GGML_OP_MUL_MAT_ID:
|
@@ -1326,20 +1347,33 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1326
1347
|
} break;
|
1327
1348
|
case GGML_OP_SOFT_MAX:
|
1328
1349
|
{
|
1350
|
+
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
1351
|
+
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
|
1352
|
+
|
1329
1353
|
int nth = 32; // SIMD width
|
1330
1354
|
|
1331
1355
|
id<MTLComputePipelineState> pipeline = nil;
|
1332
1356
|
|
1357
|
+
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
1358
|
+
|
1333
1359
|
if (ne00%4 == 0) {
|
1334
1360
|
while (nth < ne00/4 && nth < 256) {
|
1335
1361
|
nth *= 2;
|
1336
1362
|
}
|
1337
|
-
|
1363
|
+
if (use_f16) {
|
1364
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
|
1365
|
+
} else {
|
1366
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
|
1367
|
+
}
|
1338
1368
|
} else {
|
1339
1369
|
while (nth < ne00 && nth < 1024) {
|
1340
1370
|
nth *= 2;
|
1341
1371
|
}
|
1342
|
-
|
1372
|
+
if (use_f16) {
|
1373
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
|
1374
|
+
} else {
|
1375
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
|
1376
|
+
}
|
1343
1377
|
}
|
1344
1378
|
|
1345
1379
|
float scale;
|
@@ -2503,6 +2537,161 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2503
2537
|
|
2504
2538
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
2505
2539
|
} break;
|
2540
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
2541
|
+
{
|
2542
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
2543
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2544
|
+
|
2545
|
+
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
2546
|
+
|
2547
|
+
GGML_ASSERT(ggml_are_same_shape(src1, src2));
|
2548
|
+
GGML_ASSERT(src3);
|
2549
|
+
|
2550
|
+
size_t offs_src3 = 0;
|
2551
|
+
|
2552
|
+
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
2553
|
+
|
2554
|
+
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
|
2555
|
+
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
2556
|
+
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
2557
|
+
|
2558
|
+
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
2559
|
+
const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
2560
|
+
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
2561
|
+
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
|
2562
|
+
|
2563
|
+
const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
|
2564
|
+
const uint64_t nb31 = src3 ? src3->nb[1] : 0;
|
2565
|
+
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
|
2566
|
+
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
|
2567
|
+
|
2568
|
+
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
2569
|
+
|
2570
|
+
float scale;
|
2571
|
+
memcpy(&scale, dst->op_params, sizeof(float));
|
2572
|
+
|
2573
|
+
id<MTLComputePipelineState> pipeline = nil;
|
2574
|
+
|
2575
|
+
bool use_vec_kernel = false;
|
2576
|
+
|
2577
|
+
if (ne01 >= 4 || (ne00%128 != 0)) {
|
2578
|
+
switch (ne00) {
|
2579
|
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
2580
|
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
2581
|
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
2582
|
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
2583
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
2584
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
2585
|
+
default:
|
2586
|
+
{
|
2587
|
+
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
2588
|
+
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
2589
|
+
GGML_ASSERT(false && "add template specialization for this size");
|
2590
|
+
}
|
2591
|
+
}
|
2592
|
+
} else {
|
2593
|
+
use_vec_kernel = true;
|
2594
|
+
|
2595
|
+
switch (ne00) {
|
2596
|
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
2597
|
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
2598
|
+
default:
|
2599
|
+
{
|
2600
|
+
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
2601
|
+
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
2602
|
+
GGML_ASSERT(false && "add template specialization for this size");
|
2603
|
+
}
|
2604
|
+
}
|
2605
|
+
}
|
2606
|
+
|
2607
|
+
[encoder setComputePipelineState:pipeline];
|
2608
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2609
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2610
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2611
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
2612
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
2613
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
|
2614
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
|
2615
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
|
2616
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
|
2617
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
|
2618
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
|
2619
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
|
2620
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
|
2621
|
+
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
|
2622
|
+
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
|
2623
|
+
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
|
2624
|
+
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
|
2625
|
+
[encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
|
2626
|
+
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
|
2627
|
+
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
|
2628
|
+
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
|
2629
|
+
[encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
|
2630
|
+
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
|
2631
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
|
2632
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
|
2633
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
|
2634
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
2635
|
+
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
2636
|
+
|
2637
|
+
if (!use_vec_kernel) {
|
2638
|
+
// half8x8 kernel
|
2639
|
+
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
2640
|
+
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
2641
|
+
|
2642
|
+
GGML_ASSERT(nqptg <= 32);
|
2643
|
+
GGML_ASSERT(nqptg % 8 == 0);
|
2644
|
+
GGML_ASSERT(ncpsg % 32 == 0);
|
2645
|
+
|
2646
|
+
int64_t nsgmax = 2;
|
2647
|
+
|
2648
|
+
while (true) {
|
2649
|
+
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
|
2650
|
+
if (smem > ctx->device.maxThreadgroupMemoryLength) {
|
2651
|
+
break;
|
2652
|
+
}
|
2653
|
+
nsgmax *= 2;
|
2654
|
+
}
|
2655
|
+
nsgmax /= 2;
|
2656
|
+
|
2657
|
+
// simdgroups per threadgroup (a.k.a. warps)
|
2658
|
+
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
2659
|
+
|
2660
|
+
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
|
2661
|
+
|
2662
|
+
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
2663
|
+
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
2664
|
+
|
2665
|
+
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
2666
|
+
|
2667
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
2668
|
+
} else {
|
2669
|
+
// half1x4 kernel
|
2670
|
+
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
2671
|
+
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
2672
|
+
|
2673
|
+
GGML_ASSERT(nqptg <= 32);
|
2674
|
+
GGML_ASSERT(nqptg % 1 == 0);
|
2675
|
+
GGML_ASSERT(ncpsg % 32 == 0);
|
2676
|
+
|
2677
|
+
// simdgroups per threadgroup (a.k.a. warps)
|
2678
|
+
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
|
2679
|
+
|
2680
|
+
int64_t nsg = 1;
|
2681
|
+
while (nsg <= nsgt) {
|
2682
|
+
nsg *= 2;
|
2683
|
+
}
|
2684
|
+
nsg /= 2;
|
2685
|
+
|
2686
|
+
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
|
2687
|
+
|
2688
|
+
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
2689
|
+
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
2690
|
+
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
2691
|
+
|
2692
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
2693
|
+
}
|
2694
|
+
} break;
|
2506
2695
|
case GGML_OP_DUP:
|
2507
2696
|
case GGML_OP_CPY:
|
2508
2697
|
case GGML_OP_CONT:
|
@@ -2590,6 +2779,11 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2590
2779
|
MTLCommandBufferStatus status = [command_buffer status];
|
2591
2780
|
if (status != MTLCommandBufferStatusCompleted) {
|
2592
2781
|
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
2782
|
+
if (status == MTLCommandBufferStatusError) {
|
2783
|
+
NSString * error_code = [command_buffer error].localizedDescription;
|
2784
|
+
GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]);
|
2785
|
+
}
|
2786
|
+
|
2593
2787
|
return GGML_STATUS_FAILED;
|
2594
2788
|
}
|
2595
2789
|
}
|
@@ -2706,10 +2900,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe
|
|
2706
2900
|
UNUSED(buft);
|
2707
2901
|
}
|
2708
2902
|
|
2709
|
-
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
|
2903
|
+
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
|
2904
|
+
#ifndef GGML_METAL_NDEBUG
|
2710
2905
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
2711
2906
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
2712
|
-
GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
|
2907
|
+
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
|
2908
|
+
__func__,
|
2909
|
+
size_aligned / 1024.0 / 1024.0,
|
2713
2910
|
device.currentAllocatedSize / 1024.0 / 1024.0,
|
2714
2911
|
device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
2715
2912
|
|
@@ -2719,10 +2916,15 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
|
|
2719
2916
|
GGML_METAL_LOG_INFO("\n");
|
2720
2917
|
}
|
2721
2918
|
} else {
|
2722
|
-
GGML_METAL_LOG_INFO(", (%8.2f)\n",
|
2919
|
+
GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
|
2920
|
+
__func__,
|
2921
|
+
size_aligned / 1024.0 / 1024.0,
|
2922
|
+
device.currentAllocatedSize / 1024.0 / 1024.0);
|
2723
2923
|
}
|
2924
|
+
#endif
|
2724
2925
|
#endif
|
2725
2926
|
UNUSED(device);
|
2927
|
+
UNUSED(size_aligned);
|
2726
2928
|
}
|
2727
2929
|
|
2728
2930
|
GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
@@ -2756,8 +2958,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
|
|
2756
2958
|
return NULL;
|
2757
2959
|
}
|
2758
2960
|
|
2759
|
-
|
2760
|
-
ggml_backend_metal_log_allocated_size(device);
|
2961
|
+
//ggml_backend_metal_log_allocated_size(device, size_aligned);
|
2761
2962
|
|
2762
2963
|
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
|
2763
2964
|
}
|
@@ -2844,7 +3045,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
|
|
2844
3045
|
return false;
|
2845
3046
|
}
|
2846
3047
|
|
2847
|
-
|
3048
|
+
ggml_backend_metal_log_allocated_size(device, size_aligned);
|
2848
3049
|
|
2849
3050
|
++ctx->n_buffers;
|
2850
3051
|
} else {
|
@@ -2867,7 +3068,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
|
|
2867
3068
|
return false;
|
2868
3069
|
}
|
2869
3070
|
|
2870
|
-
|
3071
|
+
ggml_backend_metal_log_allocated_size(device, size_step_aligned);
|
3072
|
+
|
2871
3073
|
if (i + size_step < size) {
|
2872
3074
|
GGML_METAL_LOG_INFO("\n");
|
2873
3075
|
}
|
@@ -2876,8 +3078,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
|
|
2876
3078
|
}
|
2877
3079
|
}
|
2878
3080
|
|
2879
|
-
ggml_backend_metal_log_allocated_size(device);
|
2880
|
-
|
2881
3081
|
return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
|
2882
3082
|
}
|
2883
3083
|
|