llama_cpp 0.15.1 → 0.15.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
40
40
  GGML_METAL_KERNEL_TYPE_CLAMP,
41
41
  GGML_METAL_KERNEL_TYPE_TANH,
42
42
  GGML_METAL_KERNEL_TYPE_RELU,
43
+ GGML_METAL_KERNEL_TYPE_SIGMOID,
43
44
  GGML_METAL_KERNEL_TYPE_GELU,
44
45
  GGML_METAL_KERNEL_TYPE_GELU_4,
45
46
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
169
170
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
170
171
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
171
172
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
172
- GGML_METAL_KERNEL_TYPE_ALIBI_F32,
173
173
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
174
174
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
175
175
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -381,10 +381,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
381
381
  // dictionary of preprocessor macros
382
382
  NSMutableDictionary * prep = [NSMutableDictionary dictionary];
383
383
 
384
- #ifdef GGML_QKK_64
385
- prep[@"GGML_QKK_64"] = @(1);
386
- #endif
387
-
388
384
  MTLCompileOptions* options = [MTLCompileOptions new];
389
385
  options.preprocessorMacros = prep;
390
386
 
@@ -494,6 +490,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
494
490
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
495
491
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
496
492
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
493
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
497
494
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
498
495
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
499
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@@ -623,7 +620,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
623
620
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
624
621
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
625
622
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
626
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
627
623
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
628
624
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
629
625
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -633,14 +629,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
633
629
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
634
630
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
635
631
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
636
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
637
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
638
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
639
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
640
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
641
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
642
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
643
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
632
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
633
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
634
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
635
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
636
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
637
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
638
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
639
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
644
640
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
645
641
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
646
642
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -732,6 +728,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
732
728
  switch (ggml_get_unary_op(op)) {
733
729
  case GGML_UNARY_OP_TANH:
734
730
  case GGML_UNARY_OP_RELU:
731
+ case GGML_UNARY_OP_SIGMOID:
735
732
  case GGML_UNARY_OP_GELU:
736
733
  case GGML_UNARY_OP_GELU_QUICK:
737
734
  case GGML_UNARY_OP_SILU:
@@ -759,7 +756,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
759
756
  case GGML_OP_GROUP_NORM:
760
757
  return ctx->support_simdgroup_reduction;
761
758
  case GGML_OP_NORM:
762
- case GGML_OP_ALIBI:
763
759
  case GGML_OP_ROPE:
764
760
  case GGML_OP_IM2COL:
765
761
  return true;
@@ -772,8 +768,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
772
768
  case GGML_OP_TIMESTEP_EMBEDDING:
773
769
  case GGML_OP_ARGSORT:
774
770
  case GGML_OP_LEAKY_RELU:
775
- case GGML_OP_FLASH_ATTN_EXT:
776
771
  return true;
772
+ case GGML_OP_FLASH_ATTN_EXT:
773
+ return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
777
774
  case GGML_OP_MUL_MAT:
778
775
  case GGML_OP_MUL_MAT_ID:
779
776
  return ctx->support_simdgroup_reduction &&
@@ -926,22 +923,32 @@ static enum ggml_status ggml_metal_graph_compute(
926
923
  const int64_t ne10 = src1 ? src1->ne[0] : 0;
927
924
  const int64_t ne11 = src1 ? src1->ne[1] : 0;
928
925
  const int64_t ne12 = src1 ? src1->ne[2] : 0;
929
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
926
+ const int64_t ne13 = src1 ? src1->ne[3] : 0;
930
927
 
931
928
  const uint64_t nb10 = src1 ? src1->nb[0] : 0;
932
929
  const uint64_t nb11 = src1 ? src1->nb[1] : 0;
933
930
  const uint64_t nb12 = src1 ? src1->nb[2] : 0;
934
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
931
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0;
932
+
933
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
934
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
935
+ const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
936
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
935
937
 
936
- const int64_t ne0 = dst ? dst->ne[0] : 0;
937
- const int64_t ne1 = dst ? dst->ne[1] : 0;
938
- const int64_t ne2 = dst ? dst->ne[2] : 0;
939
- const int64_t ne3 = dst ? dst->ne[3] : 0;
938
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
939
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
940
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
941
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
940
942
 
941
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
942
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
943
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
944
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
943
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
944
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
945
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
946
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
947
+
948
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
949
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
950
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
951
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
945
952
 
946
953
  const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
947
954
  const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
@@ -1194,24 +1201,24 @@ static enum ggml_status ggml_metal_graph_compute(
1194
1201
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1195
1202
  } break;
1196
1203
  case GGML_OP_CLAMP:
1197
- {
1198
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1204
+ {
1205
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1199
1206
 
1200
- float min;
1201
- float max;
1202
- memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1203
- memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1207
+ float min;
1208
+ float max;
1209
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1210
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1204
1211
 
1205
- [encoder setComputePipelineState:pipeline];
1206
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1207
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1208
- [encoder setBytes:&min length:sizeof(min) atIndex:2];
1209
- [encoder setBytes:&max length:sizeof(max) atIndex:3];
1212
+ [encoder setComputePipelineState:pipeline];
1213
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1214
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1215
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
1216
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
1210
1217
 
1211
- const int64_t n = ggml_nelements(dst);
1218
+ const int64_t n = ggml_nelements(dst);
1212
1219
 
1213
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1214
- } break;
1220
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1221
+ } break;
1215
1222
  case GGML_OP_UNARY:
1216
1223
  switch (ggml_get_unary_op(gf->nodes[i])) {
1217
1224
  // we are not taking into account the strides, so for now require contiguous tensors
@@ -1239,6 +1246,18 @@ static enum ggml_status ggml_metal_graph_compute(
1239
1246
 
1240
1247
  const int64_t n = ggml_nelements(dst);
1241
1248
 
1249
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1250
+ } break;
1251
+ case GGML_UNARY_OP_SIGMOID:
1252
+ {
1253
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
1254
+
1255
+ [encoder setComputePipelineState:pipeline];
1256
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1257
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1258
+
1259
+ const int64_t n = ggml_nelements(dst);
1260
+
1242
1261
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1243
1262
  } break;
1244
1263
  case GGML_UNARY_OP_GELU:
@@ -1357,16 +1376,15 @@ static enum ggml_status ggml_metal_graph_compute(
1357
1376
  case GGML_OP_SOFT_MAX:
1358
1377
  {
1359
1378
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
1360
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
1361
1379
 
1362
1380
  int nth = 32; // SIMD width
1363
1381
 
1364
1382
  id<MTLComputePipelineState> pipeline = nil;
1365
1383
 
1366
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
1384
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
1367
1385
 
1368
1386
  if (ne00%4 == 0) {
1369
- while (nth < ne00/4 && nth < 256) {
1387
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1370
1388
  nth *= 2;
1371
1389
  }
1372
1390
  if (use_f16) {
@@ -1375,7 +1393,7 @@ static enum ggml_status ggml_metal_graph_compute(
1375
1393
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
1376
1394
  }
1377
1395
  } else {
1378
- while (nth < ne00 && nth < 1024) {
1396
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1379
1397
  nth *= 2;
1380
1398
  }
1381
1399
  if (use_f16) {
@@ -1394,8 +1412,8 @@ static enum ggml_status ggml_metal_graph_compute(
1394
1412
  const int64_t nrows_x = ggml_nrows(src0);
1395
1413
  const int64_t nrows_y = src0->ne[1];
1396
1414
 
1397
- const uint32_t n_head_kv = nrows_x/nrows_y;
1398
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
1415
+ const uint32_t n_head = nrows_x/nrows_y;
1416
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1399
1417
 
1400
1418
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1401
1419
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -1407,20 +1425,15 @@ static enum ggml_status ggml_metal_graph_compute(
1407
1425
  } else {
1408
1426
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1409
1427
  }
1410
- if (id_src2) {
1411
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
1412
- } else {
1413
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
1414
- }
1415
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1416
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
1417
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
1418
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1419
- [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
1420
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
1421
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
1422
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
1423
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
1428
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1429
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1430
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1431
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1432
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1433
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
1434
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
1435
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
1436
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
1424
1437
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1425
1438
 
1426
1439
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1756,11 +1769,7 @@ static enum ggml_status ggml_metal_graph_compute(
1756
1769
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1757
1770
  }
1758
1771
  else if (src0t == GGML_TYPE_Q3_K) {
1759
- #ifdef GGML_QKK_64
1760
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1761
- #else
1762
1772
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1763
- #endif
1764
1773
  }
1765
1774
  else if (src0t == GGML_TYPE_Q5_K) {
1766
1775
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1778,16 +1787,6 @@ static enum ggml_status ggml_metal_graph_compute(
1778
1787
  const int n_as = src0->ne[2];
1779
1788
 
1780
1789
  // src2 = ids
1781
- const int64_t ne20 = src2->ne[0];
1782
- const int64_t ne21 = src2->ne[1];
1783
- const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1784
- const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
1785
-
1786
- const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
1787
- const uint64_t nb21 = src2->nb[1];
1788
- const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
1789
- const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
1790
-
1791
1790
  const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1792
1791
 
1793
1792
  GGML_ASSERT(src2t == GGML_TYPE_I32);
@@ -2011,12 +2010,7 @@ static enum ggml_status ggml_metal_graph_compute(
2011
2010
  {
2012
2011
  nth0 = 4;
2013
2012
  nth1 = 16;
2014
- #if QK_K == 64
2015
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
2016
- #else
2017
2013
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
2018
- #endif
2019
-
2020
2014
  } break;
2021
2015
  default:
2022
2016
  {
@@ -2081,11 +2075,7 @@ static enum ggml_status ggml_metal_graph_compute(
2081
2075
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2082
2076
  }
2083
2077
  else if (src0t == GGML_TYPE_Q3_K) {
2084
- #ifdef GGML_QKK_64
2085
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2086
- #else
2087
2078
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2088
- #endif
2089
2079
  }
2090
2080
  else if (src0t == GGML_TYPE_Q5_K) {
2091
2081
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -2225,49 +2215,6 @@ static enum ggml_status ggml_metal_graph_compute(
2225
2215
 
2226
2216
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2227
2217
  } break;
2228
- case GGML_OP_ALIBI:
2229
- {
2230
- GGML_ASSERT((src0t == GGML_TYPE_F32));
2231
-
2232
- const int nth = MIN(1024, ne00);
2233
-
2234
- //const int n_past = ((int32_t *) dst->op_params)[0];
2235
- const int n_head = ((int32_t *) dst->op_params)[1];
2236
-
2237
- float max_bias;
2238
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
2239
-
2240
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
2241
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
2242
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
2243
-
2244
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
2245
-
2246
- [encoder setComputePipelineState:pipeline];
2247
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2248
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2249
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2250
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2251
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2252
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2253
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2254
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2255
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2256
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2257
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2258
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2259
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2260
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2261
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2262
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2263
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2264
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2265
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
2266
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
2267
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
2268
-
2269
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2270
- } break;
2271
2218
  case GGML_OP_ROPE:
2272
2219
  {
2273
2220
  GGML_ASSERT(ne10 == ne02);
@@ -2280,7 +2227,13 @@ static enum ggml_status ggml_metal_graph_compute(
2280
2227
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2281
2228
  const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2282
2229
 
2283
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
2230
+ float freq_base;
2231
+ float freq_scale;
2232
+ float ext_factor;
2233
+ float attn_factor;
2234
+ float beta_fast;
2235
+ float beta_slow;
2236
+
2284
2237
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2285
2238
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2286
2239
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -2288,6 +2241,15 @@ static enum ggml_status ggml_metal_graph_compute(
2288
2241
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2289
2242
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2290
2243
 
2244
+ const bool is_neox = mode & 2;
2245
+ const bool is_glm = mode & 4;
2246
+
2247
+ GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
2248
+
2249
+ if (!is_neox) {
2250
+ GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
2251
+ }
2252
+
2291
2253
  id<MTLComputePipelineState> pipeline = nil;
2292
2254
 
2293
2255
  switch (src0->type) {
@@ -2299,33 +2261,38 @@ static enum ggml_status ggml_metal_graph_compute(
2299
2261
  [encoder setComputePipelineState:pipeline];
2300
2262
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2301
2263
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2302
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2303
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2304
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
2305
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
2306
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
2307
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
2308
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2309
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2310
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2311
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
2312
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
2313
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
2314
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
2315
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
2316
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
2317
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
2318
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
2319
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
2320
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
2321
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
2322
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
2323
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2324
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2325
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2326
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2327
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2328
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2264
+ if (id_src2 != nil) {
2265
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2266
+ } else {
2267
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
2268
+ }
2269
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2270
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
2271
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2272
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2273
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2274
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
2275
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
2276
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
2277
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
2278
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
2279
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
2280
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
2281
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
2282
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
2283
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
2284
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
2285
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2286
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2287
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2288
+ [encoder setBytes:&mode length:sizeof( int) atIndex:22];
2289
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
2290
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
2291
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
2292
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
2293
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
2294
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
2295
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
2329
2296
 
2330
2297
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2331
2298
  } break;
@@ -2389,7 +2356,10 @@ static enum ggml_status ggml_metal_graph_compute(
2389
2356
  {
2390
2357
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
2391
2358
 
2392
- const int sf = dst->op_params[0];
2359
+ const float sf0 = (float)ne0/src0->ne[0];
2360
+ const float sf1 = (float)ne1/src0->ne[1];
2361
+ const float sf2 = (float)ne2/src0->ne[2];
2362
+ const float sf3 = (float)ne3/src0->ne[3];
2393
2363
 
2394
2364
  const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
2395
2365
 
@@ -2412,7 +2382,10 @@ static enum ggml_status ggml_metal_graph_compute(
2412
2382
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2413
2383
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2414
2384
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2415
- [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2385
+ [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
2386
+ [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
2387
+ [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
2388
+ [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
2416
2389
 
2417
2390
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
2418
2391
 
@@ -2548,13 +2521,14 @@ static enum ggml_status ggml_metal_graph_compute(
2548
2521
  } break;
2549
2522
  case GGML_OP_FLASH_ATTN_EXT:
2550
2523
  {
2551
- GGML_ASSERT(ne00 % 4 == 0);
2524
+ GGML_ASSERT(ne00 % 4 == 0);
2525
+ GGML_ASSERT(ne11 % 32 == 0);
2526
+
2552
2527
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
2553
2528
 
2554
- struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2529
+ GGML_ASSERT(ggml_are_same_shape (src1, src2));
2555
2530
 
2556
- GGML_ASSERT(ggml_are_same_shape(src1, src2));
2557
- GGML_ASSERT(src3);
2531
+ struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2558
2532
 
2559
2533
  size_t offs_src3 = 0;
2560
2534
 
@@ -2565,7 +2539,7 @@ static enum ggml_status ggml_metal_graph_compute(
2565
2539
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2566
2540
 
2567
2541
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2568
- const int64_t ne31 = src3 ? src3->ne[1] : 0;
2542
+ //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2569
2543
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
2570
2544
  const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
2571
2545
 
@@ -2577,7 +2551,16 @@ static enum ggml_status ggml_metal_graph_compute(
2577
2551
  const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2578
2552
 
2579
2553
  float scale;
2580
- memcpy(&scale, dst->op_params, sizeof(float));
2554
+ float max_bias;
2555
+
2556
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2557
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2558
+
2559
+ const uint32_t n_head = src0->ne[2];
2560
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2561
+
2562
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2563
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2581
2564
 
2582
2565
  id<MTLComputePipelineState> pipeline = nil;
2583
2566
 
@@ -2614,34 +2597,38 @@ static enum ggml_status ggml_metal_graph_compute(
2614
2597
  }
2615
2598
 
2616
2599
  [encoder setComputePipelineState:pipeline];
2617
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2618
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2619
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2620
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2621
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2622
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
2623
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
2624
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
2625
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
2626
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
2627
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
2628
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
2629
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
2630
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
2631
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
2632
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
2633
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
2634
- [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
2635
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
2636
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
2637
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
2638
- [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
2639
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
2640
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
2641
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
2642
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
2643
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
2644
- [encoder setBytes:&scale length:sizeof( float) atIndex:27];
2600
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2601
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2602
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2603
+ if (id_src3) {
2604
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2605
+ } else {
2606
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
2607
+ }
2608
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2609
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2610
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2611
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2612
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2613
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2614
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2615
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2616
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2617
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2618
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2619
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2620
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2621
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2622
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2623
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2624
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2625
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2626
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2627
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
2628
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2629
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2630
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2631
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
2645
2632
 
2646
2633
  if (!use_vec_kernel) {
2647
2634
  // half8x8 kernel