llama_cpp 0.15.2 → 0.15.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
35
35
  GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
36
  GGML_METAL_KERNEL_TYPE_DIV,
37
37
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
+ GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39
+ GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40
+ GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41
+ GGML_METAL_KERNEL_TYPE_REPEAT_I16,
38
42
  GGML_METAL_KERNEL_TYPE_SCALE,
39
43
  GGML_METAL_KERNEL_TYPE_SCALE_4,
40
44
  GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -184,9 +188,9 @@ enum ggml_metal_kernel_type {
184
188
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
185
189
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
186
190
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
187
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
191
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
188
192
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
189
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
193
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
190
194
  GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
191
195
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
192
196
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -381,10 +385,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
381
385
  // dictionary of preprocessor macros
382
386
  NSMutableDictionary * prep = [NSMutableDictionary dictionary];
383
387
 
384
- #ifdef GGML_QKK_64
385
- prep[@"GGML_QKK_64"] = @(1);
386
- #endif
387
-
388
388
  MTLCompileOptions* options = [MTLCompileOptions new];
389
389
  options.preprocessorMacros = prep;
390
390
 
@@ -489,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
489
489
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
490
490
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
491
491
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
492
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
493
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
494
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
495
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
492
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
493
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
494
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -638,9 +642,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
638
642
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
639
643
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
640
644
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
641
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
645
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
642
646
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
643
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
647
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
644
648
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
645
649
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
646
650
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -750,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
750
754
  case GGML_OP_ACC:
751
755
  case GGML_OP_MUL:
752
756
  case GGML_OP_DIV:
757
+ case GGML_OP_REPEAT:
753
758
  case GGML_OP_SCALE:
754
759
  case GGML_OP_CLAMP:
755
760
  case GGML_OP_SQR:
@@ -774,6 +779,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
774
779
  case GGML_OP_LEAKY_RELU:
775
780
  return true;
776
781
  case GGML_OP_FLASH_ATTN_EXT:
782
+ if (op->src[0]->ne[0] == 256) {
783
+ return false;
784
+ }
777
785
  return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
778
786
  case GGML_OP_MUL_MAT:
779
787
  case GGML_OP_MUL_MAT_ID:
@@ -927,22 +935,32 @@ static enum ggml_status ggml_metal_graph_compute(
927
935
  const int64_t ne10 = src1 ? src1->ne[0] : 0;
928
936
  const int64_t ne11 = src1 ? src1->ne[1] : 0;
929
937
  const int64_t ne12 = src1 ? src1->ne[2] : 0;
930
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
938
+ const int64_t ne13 = src1 ? src1->ne[3] : 0;
931
939
 
932
940
  const uint64_t nb10 = src1 ? src1->nb[0] : 0;
933
941
  const uint64_t nb11 = src1 ? src1->nb[1] : 0;
934
942
  const uint64_t nb12 = src1 ? src1->nb[2] : 0;
935
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
943
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0;
936
944
 
937
- const int64_t ne0 = dst ? dst->ne[0] : 0;
938
- const int64_t ne1 = dst ? dst->ne[1] : 0;
939
- const int64_t ne2 = dst ? dst->ne[2] : 0;
940
- const int64_t ne3 = dst ? dst->ne[3] : 0;
945
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
946
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
947
+ const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
948
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
941
949
 
942
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
943
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
944
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
945
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
950
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
951
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
952
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
953
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
954
+
955
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
956
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
957
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
958
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
959
+
960
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
961
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
962
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
963
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
946
964
 
947
965
  const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
948
966
  const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
@@ -970,10 +988,10 @@ static enum ggml_status ggml_metal_graph_compute(
970
988
  switch (dst->op) {
971
989
  case GGML_OP_CONCAT:
972
990
  {
973
- const int64_t nb = ne00;
974
-
975
991
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
976
992
 
993
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
994
+
977
995
  [encoder setComputePipelineState:pipeline];
978
996
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
979
997
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1002,7 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
1002
1020
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1003
1021
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1004
1022
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1005
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1023
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
1006
1024
 
1007
1025
  const int nth = MIN(1024, ne0);
1008
1026
 
@@ -1012,11 +1030,14 @@ static enum ggml_status ggml_metal_graph_compute(
1012
1030
  case GGML_OP_MUL:
1013
1031
  case GGML_OP_DIV:
1014
1032
  {
1033
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
1034
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1035
+
1015
1036
  const size_t offs = 0;
1016
1037
 
1017
1038
  bool bcast_row = false;
1018
1039
 
1019
- int64_t nb = ne00;
1040
+ int64_t nb = ne00; // used by the "row" kernels
1020
1041
 
1021
1042
  id<MTLComputePipelineState> pipeline = nil;
1022
1043
 
@@ -1085,6 +1106,42 @@ static enum ggml_status ggml_metal_graph_compute(
1085
1106
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1086
1107
  }
1087
1108
  } break;
1109
+ case GGML_OP_REPEAT:
1110
+ {
1111
+ id<MTLComputePipelineState> pipeline;
1112
+
1113
+ switch (src0t) {
1114
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
1115
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
1116
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
1117
+ case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
1118
+ default: GGML_ASSERT(false);
1119
+ }
1120
+
1121
+ [encoder setComputePipelineState:pipeline];
1122
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1123
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1124
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1125
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1126
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1127
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1128
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1129
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1130
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1131
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1132
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1133
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1134
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1135
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1136
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1137
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1138
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1139
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1140
+
1141
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1142
+
1143
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1144
+ } break;
1088
1145
  case GGML_OP_ACC:
1089
1146
  {
1090
1147
  GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -1462,7 +1519,6 @@ static enum ggml_status ggml_metal_graph_compute(
1462
1519
  {
1463
1520
  GGML_ASSERT(ne00 == ne10);
1464
1521
 
1465
- // TODO: assert that dim2 and dim3 are contiguous
1466
1522
  GGML_ASSERT(ne12 % ne02 == 0);
1467
1523
  GGML_ASSERT(ne13 % ne03 == 0);
1468
1524
 
@@ -1763,11 +1819,7 @@ static enum ggml_status ggml_metal_graph_compute(
1763
1819
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1764
1820
  }
1765
1821
  else if (src0t == GGML_TYPE_Q3_K) {
1766
- #ifdef GGML_QKK_64
1767
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1768
- #else
1769
1822
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1770
- #endif
1771
1823
  }
1772
1824
  else if (src0t == GGML_TYPE_Q5_K) {
1773
1825
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1785,16 +1837,6 @@ static enum ggml_status ggml_metal_graph_compute(
1785
1837
  const int n_as = src0->ne[2];
1786
1838
 
1787
1839
  // src2 = ids
1788
- const int64_t ne20 = src2->ne[0];
1789
- const int64_t ne21 = src2->ne[1];
1790
- const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1791
- const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
1792
-
1793
- const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
1794
- const uint64_t nb21 = src2->nb[1];
1795
- const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
1796
- const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
1797
-
1798
1840
  const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1799
1841
 
1800
1842
  GGML_ASSERT(src2t == GGML_TYPE_I32);
@@ -2018,12 +2060,7 @@ static enum ggml_status ggml_metal_graph_compute(
2018
2060
  {
2019
2061
  nth0 = 4;
2020
2062
  nth1 = 16;
2021
- #if QK_K == 64
2022
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
2023
- #else
2024
2063
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
2025
- #endif
2026
-
2027
2064
  } break;
2028
2065
  default:
2029
2066
  {
@@ -2088,11 +2125,7 @@ static enum ggml_status ggml_metal_graph_compute(
2088
2125
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2089
2126
  }
2090
2127
  else if (src0t == GGML_TYPE_Q3_K) {
2091
- #ifdef GGML_QKK_64
2092
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2093
- #else
2094
2128
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2095
- #endif
2096
2129
  }
2097
2130
  else if (src0t == GGML_TYPE_Q5_K) {
2098
2131
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -2153,6 +2186,7 @@ static enum ggml_status ggml_metal_graph_compute(
2153
2186
  case GGML_OP_RMS_NORM:
2154
2187
  {
2155
2188
  GGML_ASSERT(ne00 % 4 == 0);
2189
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2156
2190
 
2157
2191
  float eps;
2158
2192
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -2180,6 +2214,7 @@ static enum ggml_status ggml_metal_graph_compute(
2180
2214
  case GGML_OP_GROUP_NORM:
2181
2215
  {
2182
2216
  GGML_ASSERT(ne00 % 4 == 0);
2217
+ GGML_ASSERT(ggml_is_contiguous(src0));
2183
2218
 
2184
2219
  //float eps;
2185
2220
  //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2213,6 +2248,8 @@ static enum ggml_status ggml_metal_graph_compute(
2213
2248
  } break;
2214
2249
  case GGML_OP_NORM:
2215
2250
  {
2251
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2252
+
2216
2253
  float eps;
2217
2254
  memcpy(&eps, dst->op_params, sizeof(float));
2218
2255
 
@@ -2244,7 +2281,13 @@ static enum ggml_status ggml_metal_graph_compute(
2244
2281
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2245
2282
  const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2246
2283
 
2247
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
2284
+ float freq_base;
2285
+ float freq_scale;
2286
+ float ext_factor;
2287
+ float attn_factor;
2288
+ float beta_fast;
2289
+ float beta_slow;
2290
+
2248
2291
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2249
2292
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2250
2293
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -2252,6 +2295,15 @@ static enum ggml_status ggml_metal_graph_compute(
2252
2295
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2253
2296
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2254
2297
 
2298
+ const bool is_neox = mode & 2;
2299
+ const bool is_glm = mode & 4;
2300
+
2301
+ GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
2302
+
2303
+ if (!is_neox) {
2304
+ GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
2305
+ }
2306
+
2255
2307
  id<MTLComputePipelineState> pipeline = nil;
2256
2308
 
2257
2309
  switch (src0->type) {
@@ -2263,33 +2315,38 @@ static enum ggml_status ggml_metal_graph_compute(
2263
2315
  [encoder setComputePipelineState:pipeline];
2264
2316
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2265
2317
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2266
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2267
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2268
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
2269
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
2270
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
2271
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
2272
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2273
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2274
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2275
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
2276
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
2277
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
2278
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
2279
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
2280
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
2281
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
2282
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
2283
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
2284
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
2285
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
2286
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
2287
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2288
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2289
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2290
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2291
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2292
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2318
+ if (id_src2 != nil) {
2319
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2320
+ } else {
2321
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
2322
+ }
2323
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2324
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
2325
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2326
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2327
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2328
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
2329
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
2330
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
2331
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
2332
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
2333
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
2334
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
2335
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
2336
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
2337
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
2338
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
2339
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2340
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2341
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2342
+ [encoder setBytes:&mode length:sizeof( int) atIndex:22];
2343
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
2344
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
2345
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
2346
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
2347
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
2348
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
2349
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
2293
2350
 
2294
2351
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2295
2352
  } break;
@@ -2535,11 +2592,6 @@ static enum ggml_status ggml_metal_graph_compute(
2535
2592
  GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2536
2593
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2537
2594
 
2538
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
2539
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
2540
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
2541
- const uint64_t nb23 = src2 ? src2->nb[3] : 0;
2542
-
2543
2595
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2544
2596
  //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2545
2597
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
@@ -2575,7 +2627,7 @@ static enum ggml_status ggml_metal_graph_compute(
2575
2627
  case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
2576
2628
  case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
2577
2629
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
2578
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2630
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2579
2631
  default:
2580
2632
  {
2581
2633
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2588,7 +2640,7 @@ static enum ggml_status ggml_metal_graph_compute(
2588
2640
 
2589
2641
  switch (ne00) {
2590
2642
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
2591
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2643
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2592
2644
  default:
2593
2645
  {
2594
2646
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);