llama_cpp 0.14.3 → 0.14.4

Sign up to get free protection for your applications and to get access to all the features.
@@ -1430,6 +1430,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1430
1430
  struct ggml_tensor * dst = gf->nodes[i];
1431
1431
  GGML_ASSERT(dst->data != nullptr);
1432
1432
 
1433
+ if (ggml_is_empty(dst)) {
1434
+ continue;
1435
+ }
1436
+
1433
1437
  switch (dst->op) {
1434
1438
  case GGML_OP_NONE:
1435
1439
  case GGML_OP_RESHAPE:
@@ -64,6 +64,7 @@ enum ggml_metal_kernel_type {
64
64
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
65
65
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
66
66
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
67
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
67
68
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
68
69
  GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
69
70
  GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
91
92
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
92
93
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
93
94
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
95
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
94
96
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
95
97
  GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
96
98
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@@ -114,6 +116,7 @@ enum ggml_metal_kernel_type {
114
116
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
115
117
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
116
118
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
119
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
117
120
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
118
121
  GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
119
122
  GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@@ -134,6 +137,7 @@ enum ggml_metal_kernel_type {
134
137
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
135
138
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
136
139
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
140
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
137
141
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
138
142
  GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
139
143
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@@ -154,6 +158,7 @@ enum ggml_metal_kernel_type {
154
158
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
155
159
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
156
160
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
161
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
157
162
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
158
163
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
159
164
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
@@ -490,6 +495,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
490
495
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
491
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
492
497
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
498
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
493
499
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
494
500
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
495
501
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
@@ -517,6 +523,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
517
523
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
518
524
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
519
525
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
526
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
520
527
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
521
528
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
522
529
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
@@ -540,6 +547,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
540
547
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
541
548
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
542
549
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
550
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
543
551
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
544
552
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
545
553
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
@@ -560,6 +568,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
560
568
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
561
569
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
562
570
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
571
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
563
572
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
564
573
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
565
574
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
@@ -580,6 +589,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
580
589
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
581
590
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
582
591
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
592
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
583
593
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
584
594
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
585
595
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
@@ -837,6 +847,10 @@ static enum ggml_status ggml_metal_graph_compute(
837
847
  struct ggml_tensor * src2 = gf->nodes[i]->src[2];
838
848
  struct ggml_tensor * dst = gf->nodes[i];
839
849
 
850
+ if (ggml_is_empty(dst)) {
851
+ continue;
852
+ }
853
+
840
854
  switch (dst->op) {
841
855
  case GGML_OP_NONE:
842
856
  case GGML_OP_RESHAPE:
@@ -1421,6 +1435,7 @@ static enum ggml_status ggml_metal_graph_compute(
1421
1435
  case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
1422
1436
  case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
1423
1437
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
1438
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
1424
1439
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
1425
1440
  case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
1426
1441
  default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@@ -1575,6 +1590,12 @@ static enum ggml_status ggml_metal_graph_compute(
1575
1590
  nth1 = 16;
1576
1591
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
1577
1592
  } break;
1593
+ case GGML_TYPE_IQ1_M:
1594
+ {
1595
+ nth0 = 4;
1596
+ nth1 = 16;
1597
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
1598
+ } break;
1578
1599
  case GGML_TYPE_IQ4_NL:
1579
1600
  {
1580
1601
  nth0 = 4;
@@ -1619,9 +1640,9 @@ static enum ggml_status ggml_metal_graph_compute(
1619
1640
  [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1620
1641
  [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1621
1642
 
1622
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1623
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1624
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
1643
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1644
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1645
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
1625
1646
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1626
1647
  }
1627
1648
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1664,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
1664
1685
  {
1665
1686
  //GGML_ASSERT(ne00 == ne10);
1666
1687
  //GGML_ASSERT(ne03 == ne13);
1667
-
1668
- GGML_ASSERT(src0t == GGML_TYPE_I32);
1669
-
1670
- const int n_as = ((int32_t *) dst->op_params)[1];
1671
-
1672
- // TODO: make this more general
1673
- GGML_ASSERT(n_as <= 8);
1688
+ const int n_as = src0->ne[2];
1674
1689
 
1675
1690
  // max size of the src1ids array in the kernel shared buffer
1676
1691
  GGML_ASSERT(ne11 <= 4096);
1677
1692
 
1678
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
1679
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
1680
- const int64_t ne22 = src2 ? src2->ne[2] : 0;
1681
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
1693
+ // src2 = ids
1694
+ const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
1695
+ const int64_t ne21 = src2->ne[1];
1696
+ const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1697
+ const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
1682
1698
 
1683
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
1684
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
1685
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
1686
- const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
1699
+ const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
1700
+ const uint64_t nb21 = src2->nb[1];
1701
+ const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
1702
+ const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
1687
1703
 
1688
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
1704
+ const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
1689
1705
 
1690
- GGML_ASSERT(!ggml_is_transposed(src2));
1706
+ GGML_ASSERT(src2t == GGML_TYPE_I32);
1707
+
1708
+ GGML_ASSERT(!ggml_is_transposed(src0));
1691
1709
  GGML_ASSERT(!ggml_is_transposed(src1));
1692
1710
 
1693
1711
  GGML_ASSERT(src1t == GGML_TYPE_F32);
1694
1712
 
1695
- const uint r2 = ne12/ne22;
1696
- const uint r3 = ne13/ne23;
1697
-
1698
1713
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1699
1714
  // to the matrix-vector kernel
1700
1715
  int ne11_mm_min = n_as;
@@ -1702,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
1702
1717
  const int idx = ((int32_t *) dst->op_params)[0];
1703
1718
 
1704
1719
  // batch size
1705
- GGML_ASSERT(ne01 == ne11);
1720
+ GGML_ASSERT(ne21 == ne11); // ?
1721
+ GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
1722
+ const uint r2 = 1;
1723
+ const uint r3 = 1;
1706
1724
 
1707
1725
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1708
1726
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1711,7 +1729,7 @@ static enum ggml_status ggml_metal_graph_compute(
1711
1729
  // indirect matrix multiplication
1712
1730
  // !!!
1713
1731
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1714
- ne20 % 32 == 0 && ne20 >= 64 &&
1732
+ ne00 % 32 == 0 && ne00 >= 64 &&
1715
1733
  ne11 > ne11_mm_min) {
1716
1734
 
1717
1735
  // some Metal matrix data types require aligned pointers
@@ -1724,7 +1742,7 @@ static enum ggml_status ggml_metal_graph_compute(
1724
1742
 
1725
1743
  id<MTLComputePipelineState> pipeline = nil;
1726
1744
 
1727
- switch (src2->type) {
1745
+ switch (src0->type) {
1728
1746
  case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
1729
1747
  case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
1730
1748
  case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
@@ -1743,6 +1761,7 @@ static enum ggml_status ggml_metal_graph_compute(
1743
1761
  case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
1744
1762
  case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
1745
1763
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
1764
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
1746
1765
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
1747
1766
  case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
1748
1767
  default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@@ -1752,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
1752
1771
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1753
1772
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1754
1773
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1755
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1756
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1757
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
1758
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1759
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
1760
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1761
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
1762
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
1763
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
1764
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
1765
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1766
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1767
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
1768
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
1769
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
1770
- [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
1771
- // TODO: how to make this an array? read Metal docs
1772
- for (int j = 0; j < 8; ++j) {
1773
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1774
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1775
-
1776
- size_t offs_src_cur = 0;
1777
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
1778
-
1779
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
1780
- }
1774
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1775
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1776
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1777
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1778
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1779
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1780
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
1781
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
1782
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
1783
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
1784
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
1785
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
1786
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
1787
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
1788
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1789
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1790
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
1781
1791
 
1782
1792
  [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
1783
1793
 
1784
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1794
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1785
1795
  } else {
1786
1796
  int nth0 = 32;
1787
1797
  int nth1 = 1;
@@ -1791,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
1791
1801
  id<MTLComputePipelineState> pipeline = nil;
1792
1802
 
1793
1803
  // use custom matrix x vector kernel
1794
- switch (src2t) {
1804
+ switch (src0t) {
1795
1805
  case GGML_TYPE_F32:
1796
1806
  {
1797
1807
  GGML_ASSERT(src1t == GGML_TYPE_F32);
@@ -1900,6 +1910,12 @@ static enum ggml_status ggml_metal_graph_compute(
1900
1910
  nth1 = 16;
1901
1911
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
1902
1912
  } break;
1913
+ case GGML_TYPE_IQ1_M:
1914
+ {
1915
+ nth0 = 4;
1916
+ nth1 = 16;
1917
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
1918
+ } break;
1903
1919
  case GGML_TYPE_IQ4_NL:
1904
1920
  {
1905
1921
  nth0 = 4;
@@ -1919,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
1919
1935
  }
1920
1936
  };
1921
1937
 
1922
- if (ggml_is_quantized(src2t)) {
1923
- GGML_ASSERT(ne20 >= nth0*nth1);
1938
+ if (ggml_is_quantized(src0t)) {
1939
+ GGML_ASSERT(ne00 >= nth0*nth1);
1924
1940
  }
1925
1941
 
1926
1942
  const int64_t _ne1 = 1; // kernels needs a reference in constant memory
@@ -1929,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
1929
1945
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1930
1946
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1931
1947
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1932
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
1933
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1934
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1935
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
1936
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
1937
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
1938
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
1939
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1940
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
1941
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1942
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1943
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1944
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1945
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1946
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1947
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
1948
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1949
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
1950
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
1951
- [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
1952
- // TODO: how to make this an array? read Metal docs
1953
- for (int j = 0; j < 8; ++j) {
1954
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
1955
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
1956
-
1957
- size_t offs_src_cur = 0;
1958
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
1959
-
1960
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
1948
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1949
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1950
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1951
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
1952
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
1953
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
1954
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1955
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1956
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1957
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
1958
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1959
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1960
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1961
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1962
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1963
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1964
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
1965
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
1966
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
1967
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
1968
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
1969
+
1970
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1971
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1972
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
1973
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1961
1974
  }
1962
-
1963
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1964
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1965
- src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
1966
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1967
- }
1968
- else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
1969
- const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1975
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1976
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1970
1977
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1971
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1978
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1972
1979
  }
1973
- else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
1974
- const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1980
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1981
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1975
1982
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1976
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1983
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1977
1984
  }
1978
- else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
1985
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
1979
1986
  const int mem_size = 32*sizeof(float);
1980
1987
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1981
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1988
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1982
1989
  }
1983
- else if (src2t == GGML_TYPE_Q4_K) {
1984
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1990
+ else if (src0t == GGML_TYPE_Q4_K) {
1991
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1985
1992
  }
1986
- else if (src2t == GGML_TYPE_Q3_K) {
1993
+ else if (src0t == GGML_TYPE_Q3_K) {
1987
1994
  #ifdef GGML_QKK_64
1988
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1995
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1989
1996
  #else
1990
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1997
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1991
1998
  #endif
1992
1999
  }
1993
- else if (src2t == GGML_TYPE_Q5_K) {
1994
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2000
+ else if (src0t == GGML_TYPE_Q5_K) {
2001
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1995
2002
  }
1996
- else if (src2t == GGML_TYPE_Q6_K) {
1997
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2003
+ else if (src0t == GGML_TYPE_Q6_K) {
2004
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1998
2005
  } else {
1999
2006
  const int64_t ny = (_ne1 + nrows - 1)/nrows;
2000
- [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2007
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2001
2008
  }
2002
2009
  }
2003
2010
  } break;
@@ -2024,6 +2031,7 @@ static enum ggml_status ggml_metal_graph_compute(
2024
2031
  case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
2025
2032
  case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
2026
2033
  case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
2034
+ case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
2027
2035
  case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
2028
2036
  case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
2029
2037
  case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
@@ -2403,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(
2403
2411
 
2404
2412
  enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
2405
2413
 
2414
+ // bitonic sort requires the number of elements to be power of 2
2415
+ int64_t ne00_padded = 1;
2416
+ while (ne00_padded < ne00) {
2417
+ ne00_padded *= 2;
2418
+ }
2419
+
2420
+ // Metal kernels require the buffer size to be multiple of 16 bytes
2421
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
2422
+ const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
2423
+
2406
2424
  id<MTLComputePipelineState> pipeline = nil;
2407
2425
 
2408
2426
  switch (order) {
@@ -2412,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute(
2412
2430
  };
2413
2431
 
2414
2432
  [encoder setComputePipelineState:pipeline];
2415
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2416
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2417
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2433
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2434
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2435
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2436
+ [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
2437
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2418
2438
 
2419
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
2439
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
2420
2440
  } break;
2421
2441
  case GGML_OP_LEAKY_RELU:
2422
2442
  {