llama_cpp 0.15.2 → 0.15.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
35
35
  GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
36
  GGML_METAL_KERNEL_TYPE_DIV,
37
37
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
+ GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39
+ GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40
+ GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41
+ GGML_METAL_KERNEL_TYPE_REPEAT_I16,
38
42
  GGML_METAL_KERNEL_TYPE_SCALE,
39
43
  GGML_METAL_KERNEL_TYPE_SCALE_4,
40
44
  GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -184,9 +188,9 @@ enum ggml_metal_kernel_type {
184
188
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
185
189
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
186
190
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
187
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
191
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
188
192
  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
189
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
193
+ //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
190
194
  GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
191
195
  GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
192
196
  GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -381,10 +385,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
381
385
  // dictionary of preprocessor macros
382
386
  NSMutableDictionary * prep = [NSMutableDictionary dictionary];
383
387
 
384
- #ifdef GGML_QKK_64
385
- prep[@"GGML_QKK_64"] = @(1);
386
- #endif
387
-
388
388
  MTLCompileOptions* options = [MTLCompileOptions new];
389
389
  options.preprocessorMacros = prep;
390
390
 
@@ -489,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
489
489
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
490
490
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
491
491
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
492
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
493
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
494
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
495
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
492
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
493
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
494
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -638,9 +642,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
638
642
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
639
643
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
640
644
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
641
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
645
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
642
646
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
643
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
647
+ //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
644
648
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
645
649
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
646
650
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -750,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
750
754
  case GGML_OP_ACC:
751
755
  case GGML_OP_MUL:
752
756
  case GGML_OP_DIV:
757
+ case GGML_OP_REPEAT:
753
758
  case GGML_OP_SCALE:
754
759
  case GGML_OP_CLAMP:
755
760
  case GGML_OP_SQR:
@@ -774,6 +779,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
774
779
  case GGML_OP_LEAKY_RELU:
775
780
  return true;
776
781
  case GGML_OP_FLASH_ATTN_EXT:
782
+ if (op->src[0]->ne[0] == 256) {
783
+ return false;
784
+ }
777
785
  return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
778
786
  case GGML_OP_MUL_MAT:
779
787
  case GGML_OP_MUL_MAT_ID:
@@ -927,22 +935,32 @@ static enum ggml_status ggml_metal_graph_compute(
927
935
  const int64_t ne10 = src1 ? src1->ne[0] : 0;
928
936
  const int64_t ne11 = src1 ? src1->ne[1] : 0;
929
937
  const int64_t ne12 = src1 ? src1->ne[2] : 0;
930
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
938
+ const int64_t ne13 = src1 ? src1->ne[3] : 0;
931
939
 
932
940
  const uint64_t nb10 = src1 ? src1->nb[0] : 0;
933
941
  const uint64_t nb11 = src1 ? src1->nb[1] : 0;
934
942
  const uint64_t nb12 = src1 ? src1->nb[2] : 0;
935
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
943
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0;
936
944
 
937
- const int64_t ne0 = dst ? dst->ne[0] : 0;
938
- const int64_t ne1 = dst ? dst->ne[1] : 0;
939
- const int64_t ne2 = dst ? dst->ne[2] : 0;
940
- const int64_t ne3 = dst ? dst->ne[3] : 0;
945
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
946
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
947
+ const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
948
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
941
949
 
942
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
943
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
944
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
945
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
950
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
951
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
952
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
953
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
954
+
955
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
956
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
957
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
958
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
959
+
960
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
961
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
962
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
963
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
946
964
 
947
965
  const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
948
966
  const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
@@ -970,10 +988,10 @@ static enum ggml_status ggml_metal_graph_compute(
970
988
  switch (dst->op) {
971
989
  case GGML_OP_CONCAT:
972
990
  {
973
- const int64_t nb = ne00;
974
-
975
991
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
976
992
 
993
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
994
+
977
995
  [encoder setComputePipelineState:pipeline];
978
996
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
979
997
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1002,7 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
1002
1020
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
1003
1021
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
1004
1022
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
1005
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
1023
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
1006
1024
 
1007
1025
  const int nth = MIN(1024, ne0);
1008
1026
 
@@ -1012,11 +1030,14 @@ static enum ggml_status ggml_metal_graph_compute(
1012
1030
  case GGML_OP_MUL:
1013
1031
  case GGML_OP_DIV:
1014
1032
  {
1033
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
1034
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
1035
+
1015
1036
  const size_t offs = 0;
1016
1037
 
1017
1038
  bool bcast_row = false;
1018
1039
 
1019
- int64_t nb = ne00;
1040
+ int64_t nb = ne00; // used by the "row" kernels
1020
1041
 
1021
1042
  id<MTLComputePipelineState> pipeline = nil;
1022
1043
 
@@ -1085,6 +1106,42 @@ static enum ggml_status ggml_metal_graph_compute(
1085
1106
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1086
1107
  }
1087
1108
  } break;
1109
+ case GGML_OP_REPEAT:
1110
+ {
1111
+ id<MTLComputePipelineState> pipeline;
1112
+
1113
+ switch (src0t) {
1114
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
1115
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
1116
+ case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
1117
+ case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
1118
+ default: GGML_ASSERT(false);
1119
+ }
1120
+
1121
+ [encoder setComputePipelineState:pipeline];
1122
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1123
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1124
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
1125
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
1126
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
1127
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
1128
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
1129
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1130
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1131
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1132
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
1133
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
1134
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
1135
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
1136
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
1137
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1138
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
1139
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
1140
+
1141
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
1142
+
1143
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1144
+ } break;
1088
1145
  case GGML_OP_ACC:
1089
1146
  {
1090
1147
  GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -1462,7 +1519,6 @@ static enum ggml_status ggml_metal_graph_compute(
1462
1519
  {
1463
1520
  GGML_ASSERT(ne00 == ne10);
1464
1521
 
1465
- // TODO: assert that dim2 and dim3 are contiguous
1466
1522
  GGML_ASSERT(ne12 % ne02 == 0);
1467
1523
  GGML_ASSERT(ne13 % ne03 == 0);
1468
1524
 
@@ -1763,11 +1819,7 @@ static enum ggml_status ggml_metal_graph_compute(
1763
1819
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1764
1820
  }
1765
1821
  else if (src0t == GGML_TYPE_Q3_K) {
1766
- #ifdef GGML_QKK_64
1767
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1768
- #else
1769
1822
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1770
- #endif
1771
1823
  }
1772
1824
  else if (src0t == GGML_TYPE_Q5_K) {
1773
1825
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1785,16 +1837,6 @@ static enum ggml_status ggml_metal_graph_compute(
1785
1837
  const int n_as = src0->ne[2];
1786
1838
 
1787
1839
  // src2 = ids
1788
- const int64_t ne20 = src2->ne[0];
1789
- const int64_t ne21 = src2->ne[1];
1790
- const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1791
- const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
1792
-
1793
- const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
1794
- const uint64_t nb21 = src2->nb[1];
1795
- const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
1796
- const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
1797
-
1798
1840
  const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1799
1841
 
1800
1842
  GGML_ASSERT(src2t == GGML_TYPE_I32);
@@ -2018,12 +2060,7 @@ static enum ggml_status ggml_metal_graph_compute(
2018
2060
  {
2019
2061
  nth0 = 4;
2020
2062
  nth1 = 16;
2021
- #if QK_K == 64
2022
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
2023
- #else
2024
2063
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
2025
- #endif
2026
-
2027
2064
  } break;
2028
2065
  default:
2029
2066
  {
@@ -2088,11 +2125,7 @@ static enum ggml_status ggml_metal_graph_compute(
2088
2125
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2089
2126
  }
2090
2127
  else if (src0t == GGML_TYPE_Q3_K) {
2091
- #ifdef GGML_QKK_64
2092
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2093
- #else
2094
2128
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2095
- #endif
2096
2129
  }
2097
2130
  else if (src0t == GGML_TYPE_Q5_K) {
2098
2131
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -2153,6 +2186,7 @@ static enum ggml_status ggml_metal_graph_compute(
2153
2186
  case GGML_OP_RMS_NORM:
2154
2187
  {
2155
2188
  GGML_ASSERT(ne00 % 4 == 0);
2189
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2156
2190
 
2157
2191
  float eps;
2158
2192
  memcpy(&eps, dst->op_params, sizeof(float));
@@ -2180,6 +2214,7 @@ static enum ggml_status ggml_metal_graph_compute(
2180
2214
  case GGML_OP_GROUP_NORM:
2181
2215
  {
2182
2216
  GGML_ASSERT(ne00 % 4 == 0);
2217
+ GGML_ASSERT(ggml_is_contiguous(src0));
2183
2218
 
2184
2219
  //float eps;
2185
2220
  //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2213,6 +2248,8 @@ static enum ggml_status ggml_metal_graph_compute(
2213
2248
  } break;
2214
2249
  case GGML_OP_NORM:
2215
2250
  {
2251
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
2252
+
2216
2253
  float eps;
2217
2254
  memcpy(&eps, dst->op_params, sizeof(float));
2218
2255
 
@@ -2244,7 +2281,13 @@ static enum ggml_status ggml_metal_graph_compute(
2244
2281
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2245
2282
  const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2246
2283
 
2247
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
2284
+ float freq_base;
2285
+ float freq_scale;
2286
+ float ext_factor;
2287
+ float attn_factor;
2288
+ float beta_fast;
2289
+ float beta_slow;
2290
+
2248
2291
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2249
2292
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2250
2293
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -2252,6 +2295,15 @@ static enum ggml_status ggml_metal_graph_compute(
2252
2295
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2253
2296
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2254
2297
 
2298
+ const bool is_neox = mode & 2;
2299
+ const bool is_glm = mode & 4;
2300
+
2301
+ GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
2302
+
2303
+ if (!is_neox) {
2304
+ GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
2305
+ }
2306
+
2255
2307
  id<MTLComputePipelineState> pipeline = nil;
2256
2308
 
2257
2309
  switch (src0->type) {
@@ -2263,33 +2315,38 @@ static enum ggml_status ggml_metal_graph_compute(
2263
2315
  [encoder setComputePipelineState:pipeline];
2264
2316
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2265
2317
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2266
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2267
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2268
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
2269
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
2270
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
2271
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
2272
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2273
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2274
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2275
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
2276
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
2277
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
2278
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
2279
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
2280
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
2281
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
2282
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
2283
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
2284
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
2285
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
2286
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
2287
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2288
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2289
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2290
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2291
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2292
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2318
+ if (id_src2 != nil) {
2319
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2320
+ } else {
2321
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
2322
+ }
2323
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2324
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
2325
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2326
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2327
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2328
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
2329
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
2330
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
2331
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
2332
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
2333
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
2334
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
2335
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
2336
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
2337
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
2338
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
2339
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2340
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2341
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2342
+ [encoder setBytes:&mode length:sizeof( int) atIndex:22];
2343
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
2344
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
2345
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
2346
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
2347
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
2348
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
2349
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
2293
2350
 
2294
2351
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2295
2352
  } break;
@@ -2535,11 +2592,6 @@ static enum ggml_status ggml_metal_graph_compute(
2535
2592
  GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2536
2593
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2537
2594
 
2538
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
2539
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
2540
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
2541
- const uint64_t nb23 = src2 ? src2->nb[3] : 0;
2542
-
2543
2595
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2544
2596
  //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2545
2597
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
@@ -2575,7 +2627,7 @@ static enum ggml_status ggml_metal_graph_compute(
2575
2627
  case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
2576
2628
  case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
2577
2629
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
2578
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2630
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2579
2631
  default:
2580
2632
  {
2581
2633
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2588,7 +2640,7 @@ static enum ggml_status ggml_metal_graph_compute(
2588
2640
 
2589
2641
  switch (ne00) {
2590
2642
  case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
2591
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2643
+ //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
2592
2644
  default:
2593
2645
  {
2594
2646
  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);