llama_cpp 0.15.1 → 0.15.3

Sign up to get free protection for your applications and to get access to all the features.
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
40
40
  GGML_METAL_KERNEL_TYPE_CLAMP,
41
41
  GGML_METAL_KERNEL_TYPE_TANH,
42
42
  GGML_METAL_KERNEL_TYPE_RELU,
43
+ GGML_METAL_KERNEL_TYPE_SIGMOID,
43
44
  GGML_METAL_KERNEL_TYPE_GELU,
44
45
  GGML_METAL_KERNEL_TYPE_GELU_4,
45
46
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
169
170
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
170
171
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
171
172
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
172
- GGML_METAL_KERNEL_TYPE_ALIBI_F32,
173
173
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
174
174
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
175
175
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -381,10 +381,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
381
381
  // dictionary of preprocessor macros
382
382
  NSMutableDictionary * prep = [NSMutableDictionary dictionary];
383
383
 
384
- #ifdef GGML_QKK_64
385
- prep[@"GGML_QKK_64"] = @(1);
386
- #endif
387
-
388
384
  MTLCompileOptions* options = [MTLCompileOptions new];
389
385
  options.preprocessorMacros = prep;
390
386
 
@@ -494,6 +490,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
494
490
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
495
491
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
496
492
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
493
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
497
494
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
498
495
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
499
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@@ -623,7 +620,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
623
620
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
624
621
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
625
622
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
626
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
627
623
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
628
624
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
629
625
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -633,14 +629,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
633
629
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
634
630
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
635
631
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
636
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
637
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
638
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
639
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
640
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
641
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
642
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
643
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
632
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
633
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
634
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
635
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
636
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
637
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
638
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
639
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
644
640
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
645
641
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
646
642
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -732,6 +728,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
732
728
  switch (ggml_get_unary_op(op)) {
733
729
  case GGML_UNARY_OP_TANH:
734
730
  case GGML_UNARY_OP_RELU:
731
+ case GGML_UNARY_OP_SIGMOID:
735
732
  case GGML_UNARY_OP_GELU:
736
733
  case GGML_UNARY_OP_GELU_QUICK:
737
734
  case GGML_UNARY_OP_SILU:
@@ -759,7 +756,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
759
756
  case GGML_OP_GROUP_NORM:
760
757
  return ctx->support_simdgroup_reduction;
761
758
  case GGML_OP_NORM:
762
- case GGML_OP_ALIBI:
763
759
  case GGML_OP_ROPE:
764
760
  case GGML_OP_IM2COL:
765
761
  return true;
@@ -772,8 +768,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
772
768
  case GGML_OP_TIMESTEP_EMBEDDING:
773
769
  case GGML_OP_ARGSORT:
774
770
  case GGML_OP_LEAKY_RELU:
775
- case GGML_OP_FLASH_ATTN_EXT:
776
771
  return true;
772
+ case GGML_OP_FLASH_ATTN_EXT:
773
+ return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
777
774
  case GGML_OP_MUL_MAT:
778
775
  case GGML_OP_MUL_MAT_ID:
779
776
  return ctx->support_simdgroup_reduction &&
@@ -926,22 +923,32 @@ static enum ggml_status ggml_metal_graph_compute(
926
923
  const int64_t ne10 = src1 ? src1->ne[0] : 0;
927
924
  const int64_t ne11 = src1 ? src1->ne[1] : 0;
928
925
  const int64_t ne12 = src1 ? src1->ne[2] : 0;
929
- const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
926
+ const int64_t ne13 = src1 ? src1->ne[3] : 0;
930
927
 
931
928
  const uint64_t nb10 = src1 ? src1->nb[0] : 0;
932
929
  const uint64_t nb11 = src1 ? src1->nb[1] : 0;
933
930
  const uint64_t nb12 = src1 ? src1->nb[2] : 0;
934
- const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
931
+ const uint64_t nb13 = src1 ? src1->nb[3] : 0;
932
+
933
+ const int64_t ne20 = src2 ? src2->ne[0] : 0;
934
+ const int64_t ne21 = src2 ? src2->ne[1] : 0;
935
+ const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
936
+ const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
935
937
 
936
- const int64_t ne0 = dst ? dst->ne[0] : 0;
937
- const int64_t ne1 = dst ? dst->ne[1] : 0;
938
- const int64_t ne2 = dst ? dst->ne[2] : 0;
939
- const int64_t ne3 = dst ? dst->ne[3] : 0;
938
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
939
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
940
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
941
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
940
942
 
941
- const uint64_t nb0 = dst ? dst->nb[0] : 0;
942
- const uint64_t nb1 = dst ? dst->nb[1] : 0;
943
- const uint64_t nb2 = dst ? dst->nb[2] : 0;
944
- const uint64_t nb3 = dst ? dst->nb[3] : 0;
943
+ const int64_t ne0 = dst ? dst->ne[0] : 0;
944
+ const int64_t ne1 = dst ? dst->ne[1] : 0;
945
+ const int64_t ne2 = dst ? dst->ne[2] : 0;
946
+ const int64_t ne3 = dst ? dst->ne[3] : 0;
947
+
948
+ const uint64_t nb0 = dst ? dst->nb[0] : 0;
949
+ const uint64_t nb1 = dst ? dst->nb[1] : 0;
950
+ const uint64_t nb2 = dst ? dst->nb[2] : 0;
951
+ const uint64_t nb3 = dst ? dst->nb[3] : 0;
945
952
 
946
953
  const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
947
954
  const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
@@ -1194,24 +1201,24 @@ static enum ggml_status ggml_metal_graph_compute(
1194
1201
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1195
1202
  } break;
1196
1203
  case GGML_OP_CLAMP:
1197
- {
1198
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1204
+ {
1205
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1199
1206
 
1200
- float min;
1201
- float max;
1202
- memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1203
- memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1207
+ float min;
1208
+ float max;
1209
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1210
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1204
1211
 
1205
- [encoder setComputePipelineState:pipeline];
1206
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1207
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1208
- [encoder setBytes:&min length:sizeof(min) atIndex:2];
1209
- [encoder setBytes:&max length:sizeof(max) atIndex:3];
1212
+ [encoder setComputePipelineState:pipeline];
1213
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1214
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1215
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
1216
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
1210
1217
 
1211
- const int64_t n = ggml_nelements(dst);
1218
+ const int64_t n = ggml_nelements(dst);
1212
1219
 
1213
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1214
- } break;
1220
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1221
+ } break;
1215
1222
  case GGML_OP_UNARY:
1216
1223
  switch (ggml_get_unary_op(gf->nodes[i])) {
1217
1224
  // we are not taking into account the strides, so for now require contiguous tensors
@@ -1239,6 +1246,18 @@ static enum ggml_status ggml_metal_graph_compute(
1239
1246
 
1240
1247
  const int64_t n = ggml_nelements(dst);
1241
1248
 
1249
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1250
+ } break;
1251
+ case GGML_UNARY_OP_SIGMOID:
1252
+ {
1253
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
1254
+
1255
+ [encoder setComputePipelineState:pipeline];
1256
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1257
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1258
+
1259
+ const int64_t n = ggml_nelements(dst);
1260
+
1242
1261
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1243
1262
  } break;
1244
1263
  case GGML_UNARY_OP_GELU:
@@ -1357,16 +1376,15 @@ static enum ggml_status ggml_metal_graph_compute(
1357
1376
  case GGML_OP_SOFT_MAX:
1358
1377
  {
1359
1378
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
1360
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
1361
1379
 
1362
1380
  int nth = 32; // SIMD width
1363
1381
 
1364
1382
  id<MTLComputePipelineState> pipeline = nil;
1365
1383
 
1366
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
1384
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
1367
1385
 
1368
1386
  if (ne00%4 == 0) {
1369
- while (nth < ne00/4 && nth < 256) {
1387
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1370
1388
  nth *= 2;
1371
1389
  }
1372
1390
  if (use_f16) {
@@ -1375,7 +1393,7 @@ static enum ggml_status ggml_metal_graph_compute(
1375
1393
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
1376
1394
  }
1377
1395
  } else {
1378
- while (nth < ne00 && nth < 1024) {
1396
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1379
1397
  nth *= 2;
1380
1398
  }
1381
1399
  if (use_f16) {
@@ -1394,8 +1412,8 @@ static enum ggml_status ggml_metal_graph_compute(
1394
1412
  const int64_t nrows_x = ggml_nrows(src0);
1395
1413
  const int64_t nrows_y = src0->ne[1];
1396
1414
 
1397
- const uint32_t n_head_kv = nrows_x/nrows_y;
1398
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
1415
+ const uint32_t n_head = nrows_x/nrows_y;
1416
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1399
1417
 
1400
1418
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1401
1419
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -1407,20 +1425,15 @@ static enum ggml_status ggml_metal_graph_compute(
1407
1425
  } else {
1408
1426
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1409
1427
  }
1410
- if (id_src2) {
1411
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
1412
- } else {
1413
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
1414
- }
1415
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1416
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
1417
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
1418
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1419
- [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
1420
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
1421
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
1422
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
1423
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
1428
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1429
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1430
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1431
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1432
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1433
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
1434
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
1435
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
1436
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
1424
1437
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1425
1438
 
1426
1439
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1756,11 +1769,7 @@ static enum ggml_status ggml_metal_graph_compute(
1756
1769
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1757
1770
  }
1758
1771
  else if (src0t == GGML_TYPE_Q3_K) {
1759
- #ifdef GGML_QKK_64
1760
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1761
- #else
1762
1772
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1763
- #endif
1764
1773
  }
1765
1774
  else if (src0t == GGML_TYPE_Q5_K) {
1766
1775
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -1778,16 +1787,6 @@ static enum ggml_status ggml_metal_graph_compute(
1778
1787
  const int n_as = src0->ne[2];
1779
1788
 
1780
1789
  // src2 = ids
1781
- const int64_t ne20 = src2->ne[0];
1782
- const int64_t ne21 = src2->ne[1];
1783
- const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1784
- const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
1785
-
1786
- const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
1787
- const uint64_t nb21 = src2->nb[1];
1788
- const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
1789
- const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
1790
-
1791
1790
  const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1792
1791
 
1793
1792
  GGML_ASSERT(src2t == GGML_TYPE_I32);
@@ -2011,12 +2010,7 @@ static enum ggml_status ggml_metal_graph_compute(
2011
2010
  {
2012
2011
  nth0 = 4;
2013
2012
  nth1 = 16;
2014
- #if QK_K == 64
2015
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
2016
- #else
2017
2013
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
2018
- #endif
2019
-
2020
2014
  } break;
2021
2015
  default:
2022
2016
  {
@@ -2081,11 +2075,7 @@ static enum ggml_status ggml_metal_graph_compute(
2081
2075
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2082
2076
  }
2083
2077
  else if (src0t == GGML_TYPE_Q3_K) {
2084
- #ifdef GGML_QKK_64
2085
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2086
- #else
2087
2078
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2088
- #endif
2089
2079
  }
2090
2080
  else if (src0t == GGML_TYPE_Q5_K) {
2091
2081
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -2225,49 +2215,6 @@ static enum ggml_status ggml_metal_graph_compute(
2225
2215
 
2226
2216
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2227
2217
  } break;
2228
- case GGML_OP_ALIBI:
2229
- {
2230
- GGML_ASSERT((src0t == GGML_TYPE_F32));
2231
-
2232
- const int nth = MIN(1024, ne00);
2233
-
2234
- //const int n_past = ((int32_t *) dst->op_params)[0];
2235
- const int n_head = ((int32_t *) dst->op_params)[1];
2236
-
2237
- float max_bias;
2238
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
2239
-
2240
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
2241
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
2242
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
2243
-
2244
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
2245
-
2246
- [encoder setComputePipelineState:pipeline];
2247
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2248
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2249
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2250
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2251
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2252
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2253
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2254
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2255
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2256
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2257
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2258
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2259
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2260
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2261
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2262
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2263
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2264
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2265
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
2266
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
2267
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
2268
-
2269
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2270
- } break;
2271
2218
  case GGML_OP_ROPE:
2272
2219
  {
2273
2220
  GGML_ASSERT(ne10 == ne02);
@@ -2280,7 +2227,13 @@ static enum ggml_status ggml_metal_graph_compute(
2280
2227
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
2281
2228
  const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
2282
2229
 
2283
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
2230
+ float freq_base;
2231
+ float freq_scale;
2232
+ float ext_factor;
2233
+ float attn_factor;
2234
+ float beta_fast;
2235
+ float beta_slow;
2236
+
2284
2237
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
2285
2238
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
2286
2239
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -2288,6 +2241,15 @@ static enum ggml_status ggml_metal_graph_compute(
2288
2241
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
2289
2242
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
2290
2243
 
2244
+ const bool is_neox = mode & 2;
2245
+ const bool is_glm = mode & 4;
2246
+
2247
+ GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
2248
+
2249
+ if (!is_neox) {
2250
+ GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
2251
+ }
2252
+
2291
2253
  id<MTLComputePipelineState> pipeline = nil;
2292
2254
 
2293
2255
  switch (src0->type) {
@@ -2299,33 +2261,38 @@ static enum ggml_status ggml_metal_graph_compute(
2299
2261
  [encoder setComputePipelineState:pipeline];
2300
2262
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2301
2263
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2302
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2303
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
2304
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
2305
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
2306
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
2307
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
2308
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2309
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2310
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2311
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
2312
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
2313
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
2314
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
2315
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
2316
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
2317
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
2318
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
2319
- [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
2320
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
2321
- [encoder setBytes:&mode length:sizeof( int) atIndex:21];
2322
- [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
2323
- [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
2324
- [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
2325
- [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
2326
- [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
2327
- [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
2328
- [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
2264
+ if (id_src2 != nil) {
2265
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2266
+ } else {
2267
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
2268
+ }
2269
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2270
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
2271
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2272
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2273
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2274
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
2275
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
2276
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
2277
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
2278
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
2279
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
2280
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
2281
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
2282
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
2283
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
2284
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
2285
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
2286
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
2287
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
2288
+ [encoder setBytes:&mode length:sizeof( int) atIndex:22];
2289
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
2290
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
2291
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
2292
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
2293
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
2294
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
2295
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
2329
2296
 
2330
2297
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2331
2298
  } break;
@@ -2389,7 +2356,10 @@ static enum ggml_status ggml_metal_graph_compute(
2389
2356
  {
2390
2357
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
2391
2358
 
2392
- const int sf = dst->op_params[0];
2359
+ const float sf0 = (float)ne0/src0->ne[0];
2360
+ const float sf1 = (float)ne1/src0->ne[1];
2361
+ const float sf2 = (float)ne2/src0->ne[2];
2362
+ const float sf3 = (float)ne3/src0->ne[3];
2393
2363
 
2394
2364
  const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
2395
2365
 
@@ -2412,7 +2382,10 @@ static enum ggml_status ggml_metal_graph_compute(
2412
2382
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2413
2383
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2414
2384
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2415
- [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2385
+ [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
2386
+ [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
2387
+ [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
2388
+ [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
2416
2389
 
2417
2390
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
2418
2391
 
@@ -2548,13 +2521,14 @@ static enum ggml_status ggml_metal_graph_compute(
2548
2521
  } break;
2549
2522
  case GGML_OP_FLASH_ATTN_EXT:
2550
2523
  {
2551
- GGML_ASSERT(ne00 % 4 == 0);
2524
+ GGML_ASSERT(ne00 % 4 == 0);
2525
+ GGML_ASSERT(ne11 % 32 == 0);
2526
+
2552
2527
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
2553
2528
 
2554
- struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2529
+ GGML_ASSERT(ggml_are_same_shape (src1, src2));
2555
2530
 
2556
- GGML_ASSERT(ggml_are_same_shape(src1, src2));
2557
- GGML_ASSERT(src3);
2531
+ struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2558
2532
 
2559
2533
  size_t offs_src3 = 0;
2560
2534
 
@@ -2565,7 +2539,7 @@ static enum ggml_status ggml_metal_graph_compute(
2565
2539
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2566
2540
 
2567
2541
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2568
- const int64_t ne31 = src3 ? src3->ne[1] : 0;
2542
+ //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2569
2543
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
2570
2544
  const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
2571
2545
 
@@ -2577,7 +2551,16 @@ static enum ggml_status ggml_metal_graph_compute(
2577
2551
  const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2578
2552
 
2579
2553
  float scale;
2580
- memcpy(&scale, dst->op_params, sizeof(float));
2554
+ float max_bias;
2555
+
2556
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2557
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2558
+
2559
+ const uint32_t n_head = src0->ne[2];
2560
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2561
+
2562
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2563
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2581
2564
 
2582
2565
  id<MTLComputePipelineState> pipeline = nil;
2583
2566
 
@@ -2614,34 +2597,38 @@ static enum ggml_status ggml_metal_graph_compute(
2614
2597
  }
2615
2598
 
2616
2599
  [encoder setComputePipelineState:pipeline];
2617
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2618
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2619
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2620
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2621
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2622
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
2623
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
2624
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
2625
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
2626
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
2627
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
2628
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
2629
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
2630
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
2631
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
2632
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
2633
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
2634
- [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
2635
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
2636
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
2637
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
2638
- [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
2639
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
2640
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
2641
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
2642
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
2643
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
2644
- [encoder setBytes:&scale length:sizeof( float) atIndex:27];
2600
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2601
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2602
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2603
+ if (id_src3) {
2604
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2605
+ } else {
2606
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
2607
+ }
2608
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2609
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2610
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2611
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2612
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2613
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2614
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2615
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2616
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2617
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2618
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2619
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2620
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2621
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2622
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2623
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2624
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2625
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2626
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2627
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
2628
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2629
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2630
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2631
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
2645
2632
 
2646
2633
  if (!use_vec_kernel) {
2647
2634
  // half8x8 kernel