llama_cpp 0.15.2 → 0.15.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +61 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +8 -16
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +99 -40
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +44 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +133 -81
- data/vendor/tmp/llama.cpp/ggml-metal.metal +91 -434
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +1962 -2443
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +248 -108
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +375 -657
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +204 -225
- data/vendor/tmp/llama.cpp/ggml.c +498 -836
- data/vendor/tmp/llama.cpp/ggml.h +57 -30
- data/vendor/tmp/llama.cpp/llama.cpp +1477 -859
- data/vendor/tmp/llama.cpp/llama.h +21 -8
- metadata +3 -3
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
|
|
35
35
|
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
36
36
|
GGML_METAL_KERNEL_TYPE_DIV,
|
37
37
|
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
38
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
39
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
40
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
41
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
|
38
42
|
GGML_METAL_KERNEL_TYPE_SCALE,
|
39
43
|
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
40
44
|
GGML_METAL_KERNEL_TYPE_CLAMP,
|
@@ -184,9 +188,9 @@ enum ggml_metal_kernel_type {
|
|
184
188
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
185
189
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
186
190
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
187
|
-
|
191
|
+
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
188
192
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
189
|
-
|
193
|
+
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
190
194
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
191
195
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
192
196
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
@@ -381,10 +385,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
381
385
|
// dictionary of preprocessor macros
|
382
386
|
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
383
387
|
|
384
|
-
#ifdef GGML_QKK_64
|
385
|
-
prep[@"GGML_QKK_64"] = @(1);
|
386
|
-
#endif
|
387
|
-
|
388
388
|
MTLCompileOptions* options = [MTLCompileOptions new];
|
389
389
|
options.preprocessorMacros = prep;
|
390
390
|
|
@@ -489,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
489
489
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
490
490
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
491
491
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
492
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
493
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
494
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
495
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
|
492
496
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
493
497
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
494
498
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
@@ -638,9 +642,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
638
642
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
|
639
643
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
|
640
644
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
|
641
|
-
|
645
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
642
646
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
643
|
-
|
647
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
644
648
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
645
649
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
646
650
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
@@ -750,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
750
754
|
case GGML_OP_ACC:
|
751
755
|
case GGML_OP_MUL:
|
752
756
|
case GGML_OP_DIV:
|
757
|
+
case GGML_OP_REPEAT:
|
753
758
|
case GGML_OP_SCALE:
|
754
759
|
case GGML_OP_CLAMP:
|
755
760
|
case GGML_OP_SQR:
|
@@ -774,6 +779,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
774
779
|
case GGML_OP_LEAKY_RELU:
|
775
780
|
return true;
|
776
781
|
case GGML_OP_FLASH_ATTN_EXT:
|
782
|
+
if (op->src[0]->ne[0] == 256) {
|
783
|
+
return false;
|
784
|
+
}
|
777
785
|
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
778
786
|
case GGML_OP_MUL_MAT:
|
779
787
|
case GGML_OP_MUL_MAT_ID:
|
@@ -927,22 +935,32 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
927
935
|
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
928
936
|
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
929
937
|
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
930
|
-
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
938
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
931
939
|
|
932
940
|
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
933
941
|
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
934
942
|
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
935
|
-
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
943
|
+
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
936
944
|
|
937
|
-
const int64_t
|
938
|
-
const int64_t
|
939
|
-
const int64_t
|
940
|
-
const int64_t
|
945
|
+
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
946
|
+
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
947
|
+
const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
|
948
|
+
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
941
949
|
|
942
|
-
const uint64_t
|
943
|
-
const uint64_t
|
944
|
-
const uint64_t
|
945
|
-
const uint64_t
|
950
|
+
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
951
|
+
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
952
|
+
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
953
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
954
|
+
|
955
|
+
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
956
|
+
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
957
|
+
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
958
|
+
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
959
|
+
|
960
|
+
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
961
|
+
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
962
|
+
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
963
|
+
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
946
964
|
|
947
965
|
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
948
966
|
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
@@ -970,10 +988,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
970
988
|
switch (dst->op) {
|
971
989
|
case GGML_OP_CONCAT:
|
972
990
|
{
|
973
|
-
const int64_t nb = ne00;
|
974
|
-
|
975
991
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
976
992
|
|
993
|
+
const int32_t dim = ((int32_t *) dst->op_params)[0];
|
994
|
+
|
977
995
|
[encoder setComputePipelineState:pipeline];
|
978
996
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
979
997
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
@@ -1002,7 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1002
1020
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
1003
1021
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
1004
1022
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
1005
|
-
[encoder setBytes:&
|
1023
|
+
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
|
1006
1024
|
|
1007
1025
|
const int nth = MIN(1024, ne0);
|
1008
1026
|
|
@@ -1012,11 +1030,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1012
1030
|
case GGML_OP_MUL:
|
1013
1031
|
case GGML_OP_DIV:
|
1014
1032
|
{
|
1033
|
+
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
1034
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1035
|
+
|
1015
1036
|
const size_t offs = 0;
|
1016
1037
|
|
1017
1038
|
bool bcast_row = false;
|
1018
1039
|
|
1019
|
-
int64_t nb = ne00;
|
1040
|
+
int64_t nb = ne00; // used by the "row" kernels
|
1020
1041
|
|
1021
1042
|
id<MTLComputePipelineState> pipeline = nil;
|
1022
1043
|
|
@@ -1085,6 +1106,42 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1085
1106
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1086
1107
|
}
|
1087
1108
|
} break;
|
1109
|
+
case GGML_OP_REPEAT:
|
1110
|
+
{
|
1111
|
+
id<MTLComputePipelineState> pipeline;
|
1112
|
+
|
1113
|
+
switch (src0t) {
|
1114
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
|
1115
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
|
1116
|
+
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
|
1117
|
+
case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
|
1118
|
+
default: GGML_ASSERT(false);
|
1119
|
+
}
|
1120
|
+
|
1121
|
+
[encoder setComputePipelineState:pipeline];
|
1122
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1123
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1124
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
1125
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1126
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1127
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
1128
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1129
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1130
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1131
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
1132
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
1133
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
1134
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
1135
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
1136
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
1137
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
1138
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
1139
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
1140
|
+
|
1141
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
1142
|
+
|
1143
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1144
|
+
} break;
|
1088
1145
|
case GGML_OP_ACC:
|
1089
1146
|
{
|
1090
1147
|
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
@@ -1462,7 +1519,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1462
1519
|
{
|
1463
1520
|
GGML_ASSERT(ne00 == ne10);
|
1464
1521
|
|
1465
|
-
// TODO: assert that dim2 and dim3 are contiguous
|
1466
1522
|
GGML_ASSERT(ne12 % ne02 == 0);
|
1467
1523
|
GGML_ASSERT(ne13 % ne03 == 0);
|
1468
1524
|
|
@@ -1763,11 +1819,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1763
1819
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1764
1820
|
}
|
1765
1821
|
else if (src0t == GGML_TYPE_Q3_K) {
|
1766
|
-
#ifdef GGML_QKK_64
|
1767
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1768
|
-
#else
|
1769
1822
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1770
|
-
#endif
|
1771
1823
|
}
|
1772
1824
|
else if (src0t == GGML_TYPE_Q5_K) {
|
1773
1825
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -1785,16 +1837,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1785
1837
|
const int n_as = src0->ne[2];
|
1786
1838
|
|
1787
1839
|
// src2 = ids
|
1788
|
-
const int64_t ne20 = src2->ne[0];
|
1789
|
-
const int64_t ne21 = src2->ne[1];
|
1790
|
-
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
1791
|
-
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
1792
|
-
|
1793
|
-
const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
|
1794
|
-
const uint64_t nb21 = src2->nb[1];
|
1795
|
-
const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
|
1796
|
-
const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
|
1797
|
-
|
1798
1840
|
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
1799
1841
|
|
1800
1842
|
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
@@ -2018,12 +2060,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2018
2060
|
{
|
2019
2061
|
nth0 = 4;
|
2020
2062
|
nth1 = 16;
|
2021
|
-
#if QK_K == 64
|
2022
|
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
2023
|
-
#else
|
2024
2063
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
2025
|
-
#endif
|
2026
|
-
|
2027
2064
|
} break;
|
2028
2065
|
default:
|
2029
2066
|
{
|
@@ -2088,11 +2125,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2088
2125
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2089
2126
|
}
|
2090
2127
|
else if (src0t == GGML_TYPE_Q3_K) {
|
2091
|
-
#ifdef GGML_QKK_64
|
2092
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2093
|
-
#else
|
2094
2128
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2095
|
-
#endif
|
2096
2129
|
}
|
2097
2130
|
else if (src0t == GGML_TYPE_Q5_K) {
|
2098
2131
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -2153,6 +2186,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2153
2186
|
case GGML_OP_RMS_NORM:
|
2154
2187
|
{
|
2155
2188
|
GGML_ASSERT(ne00 % 4 == 0);
|
2189
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
2156
2190
|
|
2157
2191
|
float eps;
|
2158
2192
|
memcpy(&eps, dst->op_params, sizeof(float));
|
@@ -2180,6 +2214,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2180
2214
|
case GGML_OP_GROUP_NORM:
|
2181
2215
|
{
|
2182
2216
|
GGML_ASSERT(ne00 % 4 == 0);
|
2217
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
2183
2218
|
|
2184
2219
|
//float eps;
|
2185
2220
|
//memcpy(&eps, dst->op_params, sizeof(float));
|
@@ -2213,6 +2248,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2213
2248
|
} break;
|
2214
2249
|
case GGML_OP_NORM:
|
2215
2250
|
{
|
2251
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
2252
|
+
|
2216
2253
|
float eps;
|
2217
2254
|
memcpy(&eps, dst->op_params, sizeof(float));
|
2218
2255
|
|
@@ -2244,7 +2281,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2244
2281
|
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
2245
2282
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
2246
2283
|
|
2247
|
-
float freq_base
|
2284
|
+
float freq_base;
|
2285
|
+
float freq_scale;
|
2286
|
+
float ext_factor;
|
2287
|
+
float attn_factor;
|
2288
|
+
float beta_fast;
|
2289
|
+
float beta_slow;
|
2290
|
+
|
2248
2291
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
2249
2292
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
2250
2293
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
@@ -2252,6 +2295,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2252
2295
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
2253
2296
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
2254
2297
|
|
2298
|
+
const bool is_neox = mode & 2;
|
2299
|
+
const bool is_glm = mode & 4;
|
2300
|
+
|
2301
|
+
GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
|
2302
|
+
|
2303
|
+
if (!is_neox) {
|
2304
|
+
GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
|
2305
|
+
}
|
2306
|
+
|
2255
2307
|
id<MTLComputePipelineState> pipeline = nil;
|
2256
2308
|
|
2257
2309
|
switch (src0->type) {
|
@@ -2263,33 +2315,38 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2263
2315
|
[encoder setComputePipelineState:pipeline];
|
2264
2316
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2265
2317
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2266
|
-
|
2267
|
-
|
2268
|
-
|
2269
|
-
|
2270
|
-
|
2271
|
-
[encoder
|
2272
|
-
[encoder setBytes:&
|
2273
|
-
[encoder setBytes:&
|
2274
|
-
[encoder setBytes:&
|
2275
|
-
[encoder setBytes:&
|
2276
|
-
[encoder setBytes:&
|
2277
|
-
[encoder setBytes:&
|
2278
|
-
[encoder setBytes:&
|
2279
|
-
[encoder setBytes:&
|
2280
|
-
[encoder setBytes:&
|
2281
|
-
[encoder setBytes:&
|
2282
|
-
[encoder setBytes:&
|
2283
|
-
[encoder setBytes:&
|
2284
|
-
[encoder setBytes:&
|
2285
|
-
[encoder setBytes:&
|
2286
|
-
[encoder setBytes:&
|
2287
|
-
[encoder setBytes:&
|
2288
|
-
[encoder setBytes:&
|
2289
|
-
[encoder setBytes:&
|
2290
|
-
[encoder setBytes:&
|
2291
|
-
[encoder setBytes:&
|
2292
|
-
[encoder setBytes:&
|
2318
|
+
if (id_src2 != nil) {
|
2319
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2320
|
+
} else {
|
2321
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
2322
|
+
}
|
2323
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2324
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
|
2325
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
2326
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
2327
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
2328
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
|
2329
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
|
2330
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
|
2331
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
|
2332
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
|
2333
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
|
2334
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
|
2335
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
|
2336
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
|
2337
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
|
2338
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
|
2339
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
2340
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
2341
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
2342
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:22];
|
2343
|
+
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
|
2344
|
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
|
2345
|
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
|
2346
|
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
|
2347
|
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
|
2348
|
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
|
2349
|
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
2293
2350
|
|
2294
2351
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2295
2352
|
} break;
|
@@ -2535,11 +2592,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2535
2592
|
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
2536
2593
|
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
2537
2594
|
|
2538
|
-
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
2539
|
-
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
2540
|
-
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
2541
|
-
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
2542
|
-
|
2543
2595
|
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
2544
2596
|
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
2545
2597
|
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
@@ -2575,7 +2627,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2575
2627
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
2576
2628
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
2577
2629
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
2578
|
-
|
2630
|
+
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
2579
2631
|
default:
|
2580
2632
|
{
|
2581
2633
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
@@ -2588,7 +2640,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2588
2640
|
|
2589
2641
|
switch (ne00) {
|
2590
2642
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
2591
|
-
|
2643
|
+
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
2592
2644
|
default:
|
2593
2645
|
{
|
2594
2646
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|