llama_cpp 0.15.2 → 0.15.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +61 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -0
- data/vendor/tmp/llama.cpp/Makefile +8 -16
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +99 -40
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +44 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +133 -81
- data/vendor/tmp/llama.cpp/ggml-metal.metal +91 -434
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +1962 -2443
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +248 -108
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +375 -657
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +204 -225
- data/vendor/tmp/llama.cpp/ggml.c +498 -836
- data/vendor/tmp/llama.cpp/ggml.h +57 -30
- data/vendor/tmp/llama.cpp/llama.cpp +1477 -859
- data/vendor/tmp/llama.cpp/llama.h +21 -8
- metadata +3 -3
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
|
|
35
35
|
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
36
36
|
GGML_METAL_KERNEL_TYPE_DIV,
|
37
37
|
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
38
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
39
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
40
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
|
41
|
+
GGML_METAL_KERNEL_TYPE_REPEAT_I16,
|
38
42
|
GGML_METAL_KERNEL_TYPE_SCALE,
|
39
43
|
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
40
44
|
GGML_METAL_KERNEL_TYPE_CLAMP,
|
@@ -184,9 +188,9 @@ enum ggml_metal_kernel_type {
|
|
184
188
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
|
185
189
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
186
190
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
187
|
-
|
191
|
+
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
188
192
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
189
|
-
|
193
|
+
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
190
194
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
191
195
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
192
196
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
@@ -381,10 +385,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
381
385
|
// dictionary of preprocessor macros
|
382
386
|
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
383
387
|
|
384
|
-
#ifdef GGML_QKK_64
|
385
|
-
prep[@"GGML_QKK_64"] = @(1);
|
386
|
-
#endif
|
387
|
-
|
388
388
|
MTLCompileOptions* options = [MTLCompileOptions new];
|
389
389
|
options.preprocessorMacros = prep;
|
390
390
|
|
@@ -489,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
489
489
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
490
490
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
491
491
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
492
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
493
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
494
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
|
495
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
|
492
496
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
493
497
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
494
498
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
@@ -638,9 +642,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
638
642
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
|
639
643
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
|
640
644
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
|
641
|
-
|
645
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
642
646
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
643
|
-
|
647
|
+
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
644
648
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
645
649
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
646
650
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
@@ -750,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
750
754
|
case GGML_OP_ACC:
|
751
755
|
case GGML_OP_MUL:
|
752
756
|
case GGML_OP_DIV:
|
757
|
+
case GGML_OP_REPEAT:
|
753
758
|
case GGML_OP_SCALE:
|
754
759
|
case GGML_OP_CLAMP:
|
755
760
|
case GGML_OP_SQR:
|
@@ -774,6 +779,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
774
779
|
case GGML_OP_LEAKY_RELU:
|
775
780
|
return true;
|
776
781
|
case GGML_OP_FLASH_ATTN_EXT:
|
782
|
+
if (op->src[0]->ne[0] == 256) {
|
783
|
+
return false;
|
784
|
+
}
|
777
785
|
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
778
786
|
case GGML_OP_MUL_MAT:
|
779
787
|
case GGML_OP_MUL_MAT_ID:
|
@@ -927,22 +935,32 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
927
935
|
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
928
936
|
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
929
937
|
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
930
|
-
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
938
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
931
939
|
|
932
940
|
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
933
941
|
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
934
942
|
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
935
|
-
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
943
|
+
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
936
944
|
|
937
|
-
const int64_t
|
938
|
-
const int64_t
|
939
|
-
const int64_t
|
940
|
-
const int64_t
|
945
|
+
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
946
|
+
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
947
|
+
const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
|
948
|
+
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
941
949
|
|
942
|
-
const uint64_t
|
943
|
-
const uint64_t
|
944
|
-
const uint64_t
|
945
|
-
const uint64_t
|
950
|
+
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
951
|
+
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
952
|
+
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
953
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
954
|
+
|
955
|
+
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
956
|
+
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
957
|
+
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
958
|
+
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
959
|
+
|
960
|
+
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
961
|
+
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
962
|
+
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
963
|
+
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
946
964
|
|
947
965
|
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
948
966
|
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
@@ -970,10 +988,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
970
988
|
switch (dst->op) {
|
971
989
|
case GGML_OP_CONCAT:
|
972
990
|
{
|
973
|
-
const int64_t nb = ne00;
|
974
|
-
|
975
991
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
976
992
|
|
993
|
+
const int32_t dim = ((int32_t *) dst->op_params)[0];
|
994
|
+
|
977
995
|
[encoder setComputePipelineState:pipeline];
|
978
996
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
979
997
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
@@ -1002,7 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1002
1020
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
1003
1021
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
1004
1022
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
1005
|
-
[encoder setBytes:&
|
1023
|
+
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
|
1006
1024
|
|
1007
1025
|
const int nth = MIN(1024, ne0);
|
1008
1026
|
|
@@ -1012,11 +1030,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1012
1030
|
case GGML_OP_MUL:
|
1013
1031
|
case GGML_OP_DIV:
|
1014
1032
|
{
|
1033
|
+
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
1034
|
+
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
1035
|
+
|
1015
1036
|
const size_t offs = 0;
|
1016
1037
|
|
1017
1038
|
bool bcast_row = false;
|
1018
1039
|
|
1019
|
-
int64_t nb = ne00;
|
1040
|
+
int64_t nb = ne00; // used by the "row" kernels
|
1020
1041
|
|
1021
1042
|
id<MTLComputePipelineState> pipeline = nil;
|
1022
1043
|
|
@@ -1085,6 +1106,42 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1085
1106
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1086
1107
|
}
|
1087
1108
|
} break;
|
1109
|
+
case GGML_OP_REPEAT:
|
1110
|
+
{
|
1111
|
+
id<MTLComputePipelineState> pipeline;
|
1112
|
+
|
1113
|
+
switch (src0t) {
|
1114
|
+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
|
1115
|
+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
|
1116
|
+
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
|
1117
|
+
case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
|
1118
|
+
default: GGML_ASSERT(false);
|
1119
|
+
}
|
1120
|
+
|
1121
|
+
[encoder setComputePipelineState:pipeline];
|
1122
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1123
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1124
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
1125
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
1126
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1127
|
+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
1128
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1129
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1130
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1131
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
1132
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
1133
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
1134
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
1135
|
+
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
1136
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
1137
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
1138
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
1139
|
+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
1140
|
+
|
1141
|
+
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
1142
|
+
|
1143
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
1144
|
+
} break;
|
1088
1145
|
case GGML_OP_ACC:
|
1089
1146
|
{
|
1090
1147
|
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
@@ -1462,7 +1519,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1462
1519
|
{
|
1463
1520
|
GGML_ASSERT(ne00 == ne10);
|
1464
1521
|
|
1465
|
-
// TODO: assert that dim2 and dim3 are contiguous
|
1466
1522
|
GGML_ASSERT(ne12 % ne02 == 0);
|
1467
1523
|
GGML_ASSERT(ne13 % ne03 == 0);
|
1468
1524
|
|
@@ -1763,11 +1819,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1763
1819
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1764
1820
|
}
|
1765
1821
|
else if (src0t == GGML_TYPE_Q3_K) {
|
1766
|
-
#ifdef GGML_QKK_64
|
1767
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1768
|
-
#else
|
1769
1822
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1770
|
-
#endif
|
1771
1823
|
}
|
1772
1824
|
else if (src0t == GGML_TYPE_Q5_K) {
|
1773
1825
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -1785,16 +1837,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1785
1837
|
const int n_as = src0->ne[2];
|
1786
1838
|
|
1787
1839
|
// src2 = ids
|
1788
|
-
const int64_t ne20 = src2->ne[0];
|
1789
|
-
const int64_t ne21 = src2->ne[1];
|
1790
|
-
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
1791
|
-
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
1792
|
-
|
1793
|
-
const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
|
1794
|
-
const uint64_t nb21 = src2->nb[1];
|
1795
|
-
const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
|
1796
|
-
const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
|
1797
|
-
|
1798
1840
|
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
1799
1841
|
|
1800
1842
|
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
@@ -2018,12 +2060,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2018
2060
|
{
|
2019
2061
|
nth0 = 4;
|
2020
2062
|
nth1 = 16;
|
2021
|
-
#if QK_K == 64
|
2022
|
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
2023
|
-
#else
|
2024
2063
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
2025
|
-
#endif
|
2026
|
-
|
2027
2064
|
} break;
|
2028
2065
|
default:
|
2029
2066
|
{
|
@@ -2088,11 +2125,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2088
2125
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2089
2126
|
}
|
2090
2127
|
else if (src0t == GGML_TYPE_Q3_K) {
|
2091
|
-
#ifdef GGML_QKK_64
|
2092
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2093
|
-
#else
|
2094
2128
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2095
|
-
#endif
|
2096
2129
|
}
|
2097
2130
|
else if (src0t == GGML_TYPE_Q5_K) {
|
2098
2131
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -2153,6 +2186,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2153
2186
|
case GGML_OP_RMS_NORM:
|
2154
2187
|
{
|
2155
2188
|
GGML_ASSERT(ne00 % 4 == 0);
|
2189
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
2156
2190
|
|
2157
2191
|
float eps;
|
2158
2192
|
memcpy(&eps, dst->op_params, sizeof(float));
|
@@ -2180,6 +2214,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2180
2214
|
case GGML_OP_GROUP_NORM:
|
2181
2215
|
{
|
2182
2216
|
GGML_ASSERT(ne00 % 4 == 0);
|
2217
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
2183
2218
|
|
2184
2219
|
//float eps;
|
2185
2220
|
//memcpy(&eps, dst->op_params, sizeof(float));
|
@@ -2213,6 +2248,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2213
2248
|
} break;
|
2214
2249
|
case GGML_OP_NORM:
|
2215
2250
|
{
|
2251
|
+
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
2252
|
+
|
2216
2253
|
float eps;
|
2217
2254
|
memcpy(&eps, dst->op_params, sizeof(float));
|
2218
2255
|
|
@@ -2244,7 +2281,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2244
2281
|
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
2245
2282
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
2246
2283
|
|
2247
|
-
float freq_base
|
2284
|
+
float freq_base;
|
2285
|
+
float freq_scale;
|
2286
|
+
float ext_factor;
|
2287
|
+
float attn_factor;
|
2288
|
+
float beta_fast;
|
2289
|
+
float beta_slow;
|
2290
|
+
|
2248
2291
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
2249
2292
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
2250
2293
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
@@ -2252,6 +2295,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2252
2295
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
2253
2296
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
2254
2297
|
|
2298
|
+
const bool is_neox = mode & 2;
|
2299
|
+
const bool is_glm = mode & 4;
|
2300
|
+
|
2301
|
+
GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
|
2302
|
+
|
2303
|
+
if (!is_neox) {
|
2304
|
+
GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
|
2305
|
+
}
|
2306
|
+
|
2255
2307
|
id<MTLComputePipelineState> pipeline = nil;
|
2256
2308
|
|
2257
2309
|
switch (src0->type) {
|
@@ -2263,33 +2315,38 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2263
2315
|
[encoder setComputePipelineState:pipeline];
|
2264
2316
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2265
2317
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2266
|
-
|
2267
|
-
|
2268
|
-
|
2269
|
-
|
2270
|
-
|
2271
|
-
[encoder
|
2272
|
-
[encoder setBytes:&
|
2273
|
-
[encoder setBytes:&
|
2274
|
-
[encoder setBytes:&
|
2275
|
-
[encoder setBytes:&
|
2276
|
-
[encoder setBytes:&
|
2277
|
-
[encoder setBytes:&
|
2278
|
-
[encoder setBytes:&
|
2279
|
-
[encoder setBytes:&
|
2280
|
-
[encoder setBytes:&
|
2281
|
-
[encoder setBytes:&
|
2282
|
-
[encoder setBytes:&
|
2283
|
-
[encoder setBytes:&
|
2284
|
-
[encoder setBytes:&
|
2285
|
-
[encoder setBytes:&
|
2286
|
-
[encoder setBytes:&
|
2287
|
-
[encoder setBytes:&
|
2288
|
-
[encoder setBytes:&
|
2289
|
-
[encoder setBytes:&
|
2290
|
-
[encoder setBytes:&
|
2291
|
-
[encoder setBytes:&
|
2292
|
-
[encoder setBytes:&
|
2318
|
+
if (id_src2 != nil) {
|
2319
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2320
|
+
} else {
|
2321
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
2322
|
+
}
|
2323
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2324
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
|
2325
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
2326
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
2327
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
2328
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
|
2329
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
|
2330
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
|
2331
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
|
2332
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
|
2333
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
|
2334
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
|
2335
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
|
2336
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
|
2337
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
|
2338
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
|
2339
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
2340
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
2341
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
2342
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:22];
|
2343
|
+
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
|
2344
|
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
|
2345
|
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
|
2346
|
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
|
2347
|
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
|
2348
|
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
|
2349
|
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
2293
2350
|
|
2294
2351
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2295
2352
|
} break;
|
@@ -2535,11 +2592,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2535
2592
|
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
|
2536
2593
|
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
2537
2594
|
|
2538
|
-
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
2539
|
-
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
2540
|
-
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
2541
|
-
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
2542
|
-
|
2543
2595
|
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
2544
2596
|
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
2545
2597
|
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
@@ -2575,7 +2627,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2575
2627
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
2576
2628
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
2577
2629
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
2578
|
-
|
2630
|
+
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
2579
2631
|
default:
|
2580
2632
|
{
|
2581
2633
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
@@ -2588,7 +2640,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2588
2640
|
|
2589
2641
|
switch (ne00) {
|
2590
2642
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
2591
|
-
|
2643
|
+
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
2592
2644
|
default:
|
2593
2645
|
{
|
2594
2646
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|