llama_cpp 0.14.3 → 0.14.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1430,6 +1430,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         struct ggml_tensor * dst = gf->nodes[i];
         GGML_ASSERT(dst->data != nullptr);

+        if (ggml_is_empty(dst)) {
+            continue;
+        }
+
         switch (dst->op) {
             case GGML_OP_NONE:
             case GGML_OP_RESHAPE:
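Both the Kompute and Metal graph-compute loops gain the same guard: a node whose destination tensor is empty is skipped before the op dispatch. As a rough illustration only (not the packaged implementation of ggml_is_empty), such a check needs nothing more than testing whether any dimension has zero elements; GGML_MAX_DIMS and the ne[] array are existing ggml_tensor fields:

// hypothetical sketch of an "is empty" predicate for a ggml tensor:
// if any dimension has zero elements there is nothing to compute,
// so the graph loop can safely skip the node
static bool tensor_has_no_elements(const struct ggml_tensor * t) {
    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
        if (t->ne[i] == 0) {
            return true;
        }
    }
    return false;
}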
@@ -64,6 +64,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
@@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
@@ -114,6 +116,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
@@ -134,6 +137,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
@@ -154,6 +158,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
@@ -490,6 +495,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
@@ -517,6 +523,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
@@ -540,6 +547,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
@@ -560,6 +568,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
@@ -580,6 +589,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
@@ -837,6 +847,10 @@ static enum ggml_status ggml_metal_graph_compute(
         struct ggml_tensor * src2 = gf->nodes[i]->src[2];
         struct ggml_tensor * dst = gf->nodes[i];

+        if (ggml_is_empty(dst)) {
+            continue;
+        }
+
         switch (dst->op) {
             case GGML_OP_NONE:
             case GGML_OP_RESHAPE:
@@ -1421,6 +1435,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break;
                 case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
                 case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
+                case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break;
                 case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
                 case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
                 default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
@@ -1575,6 +1590,12 @@ static enum ggml_status ggml_metal_graph_compute(
                         nth1 = 16;
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
                     } break;
+                case GGML_TYPE_IQ1_M:
+                    {
+                        nth0 = 4;
+                        nth1 = 16;
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
+                    } break;
                 case GGML_TYPE_IQ4_NL:
                     {
                         nth0 = 4;
@@ -1619,9 +1640,9 @@ static enum ggml_status ggml_metal_graph_compute(
             [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
             [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];

-            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
-                src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
-                src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S) {
+            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+                src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
             }
             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1664,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
         {
             //GGML_ASSERT(ne00 == ne10);
             //GGML_ASSERT(ne03 == ne13);
-
-            GGML_ASSERT(src0t == GGML_TYPE_I32);
-
-            const int n_as = ((int32_t *) dst->op_params)[1];
-
-            // TODO: make this more general
-            GGML_ASSERT(n_as <= 8);
+            const int n_as = src0->ne[2];

             // max size of the src1ids array in the kernel shared buffer
             GGML_ASSERT(ne11 <= 4096);

-            const int64_t ne20 = src2 ? src2->ne[0] : 0;
-            const int64_t ne21 = src2 ? src2->ne[1] : 0;
-            const int64_t ne22 = src2 ? src2->ne[2] : 0;
-            const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+            // src2 = ids
+            const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+            const int64_t ne21 = src2->ne[1];
+            const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
+            const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);

-            const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
-            const uint64_t nb21 = src2 ? src2->nb[1] : 0;
-            const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-            const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+            const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
+            const uint64_t nb21 = src2->nb[1];
+            const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
+            const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);

-            const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+            const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);

-            GGML_ASSERT(!ggml_is_transposed(src2));
+            GGML_ASSERT(src2t == GGML_TYPE_I32);
+
+            GGML_ASSERT(!ggml_is_transposed(src0));
             GGML_ASSERT(!ggml_is_transposed(src1));

             GGML_ASSERT(src1t == GGML_TYPE_F32);

-            const uint r2 = ne12/ne22;
-            const uint r3 = ne13/ne23;
-
             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
             // to the matrix-vector kernel
             int ne11_mm_min = n_as;
@@ -1702,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
             const int idx = ((int32_t *) dst->op_params)[0];

             // batch size
-            GGML_ASSERT(ne01 == ne11);
+            GGML_ASSERT(ne21 == ne11); // ?
+            GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
+            const uint r2 = 1;
+            const uint r3 = 1;

             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1711,7 +1729,7 @@ static enum ggml_status ggml_metal_graph_compute(
             // indirect matrix multiplication
             // !!!
             if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                ne20 % 32 == 0 && ne20 >= 64 &&
+                ne00 % 32 == 0 && ne00 >= 64 &&
                 ne11 > ne11_mm_min) {

                 // some Metal matrix data types require aligned pointers
@@ -1724,7 +1742,7 @@ static enum ggml_status ggml_metal_graph_compute(

                 id<MTLComputePipelineState> pipeline = nil;

-                switch (src2->type) {
+                switch (src0->type) {
                     case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
                     case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
                     case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
@@ -1743,6 +1761,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
                     case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
                     case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
+                    case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
                     case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
                     case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
                     default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
@@ -1752,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
-                [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
-                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
-                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
-                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
-                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
-                [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
-                [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
-                [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
-                [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
-                // TODO: how to make this an array? read Metal docs
-                for (int j = 0; j < 8; ++j) {
-                    // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                    struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                    size_t offs_src_cur = 0;
-                    id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
-                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
-                }
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
+                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
+                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
+                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
+                [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
+                [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+                [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
+                [encoder setBytes:&idx length:sizeof(idx) atIndex:19];

                 [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];

-                [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
             } else {
                 int nth0 = 32;
                 int nth1 = 1;
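The reworked MUL_MAT_ID encoding above no longer binds each expert as a separate buffer (the removed 8-iteration loop); all experts are packed along dimension 2 of src0 (n_as = src0->ne[2]), the ids tensor is bound as src2 at buffer index 3, and the src0 strides nb01/nb02 are passed so the kernel can address individual experts. A hypothetical host-side sketch of that addressing, assuming ggml's usual convention of nb01 as the row stride and nb02 as the per-matrix (per-expert) stride:

// hypothetical helper (not part of the package): locate row `r` of expert `e`
// when all experts are stacked in a single src0 buffer
static const char * expert_row_ptr(const char * src0_base, int64_t e, int64_t r,
                                   uint64_t nb01, uint64_t nb02) {
    return src0_base + e*nb02 + r*nb01;
}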
@@ -1791,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 id<MTLComputePipelineState> pipeline = nil;

                 // use custom matrix x vector kernel
-                switch (src2t) {
+                switch (src0t) {
                     case GGML_TYPE_F32:
                         {
                             GGML_ASSERT(src1t == GGML_TYPE_F32);
@@ -1900,6 +1910,12 @@ static enum ggml_status ggml_metal_graph_compute(
                             nth1 = 16;
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
                         } break;
+                    case GGML_TYPE_IQ1_M:
+                        {
+                            nth0 = 4;
+                            nth1 = 16;
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
+                        } break;
                     case GGML_TYPE_IQ4_NL:
                         {
                             nth0 = 4;
@@ -1919,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
                         }
                 };

-                if (ggml_is_quantized(src2t)) {
-                    GGML_ASSERT(ne20 >= nth0*nth1);
+                if (ggml_is_quantized(src0t)) {
+                    GGML_ASSERT(ne00 >= nth0*nth1);
                 }

                 const int64_t _ne1 = 1; // kernels needs a reference in constant memory
@@ -1929,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
-                [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
-                [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
-                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
-                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
-                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
-                [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
-                [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
-                [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
-                [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
-                [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
-                // TODO: how to make this an array? read Metal docs
-                for (int j = 0; j < 8; ++j) {
-                    // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
-                    struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
-                    size_t offs_src_cur = 0;
-                    id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
-                    [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+                [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
+                [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
+                [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
+                [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
+                [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
+
+                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+                    src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+                    src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
-
-                if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
-                    src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
-                    src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                }
-                else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
-                    const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+                else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+                    const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
-                else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
-                    const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+                else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                    const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
-                else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
+                else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                     const int mem_size = 32*sizeof(float);
                     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
-                else if (src2t == GGML_TYPE_Q4_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                else if (src0t == GGML_TYPE_Q4_K) {
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
-                else if (src2t == GGML_TYPE_Q3_K) {
+                else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                 }
-                else if (src2t == GGML_TYPE_Q5_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                else if (src0t == GGML_TYPE_Q5_K) {
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
-                else if (src2t == GGML_TYPE_Q6_K) {
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                else if (src0t == GGML_TYPE_Q6_K) {
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 } else {
                     const int64_t ny = (_ne1 + nrows - 1)/nrows;
-                    [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
             }
         } break;
@@ -2024,6 +2031,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break;
                 case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
                 case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
+                case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break;
                 case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
                 case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
                 case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
@@ -2403,6 +2411,16 @@ static enum ggml_status ggml_metal_graph_compute(

             enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

+            // bitonic sort requires the number of elements to be power of 2
+            int64_t ne00_padded = 1;
+            while (ne00_padded < ne00) {
+                ne00_padded *= 2;
+            }
+
+            // Metal kernels require the buffer size to be multiple of 16 bytes
+            // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+            const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+
             id<MTLComputePipelineState> pipeline = nil;

             switch (order) {
@@ -2412,11 +2430,13 @@ static enum ggml_status ggml_metal_graph_compute(
             };

             [encoder setComputePipelineState:pipeline];
-            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+            [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+            [encoder setThreadgroupMemoryLength:mem_size atIndex:0];

-            [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+            [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
         } break;
     case GGML_OP_LEAKY_RELU:
         {
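The argsort changes above pad each row to the next power of two, since a bitonic sorting network only handles power-of-two sequence lengths, and round the threadgroup memory up to a 16-byte multiple as Metal requires. A standalone sketch of the same two computations (a hypothetical helper, mirroring the diff's GGML_PAD rounding with explicit bit math):

// hypothetical re-statement of the padding logic added in the argsort hunks
static void argsort_padding(int64_t ne00, int64_t * ne00_padded, size_t * mem_size) {
    int64_t p = 1;
    while (p < ne00) {
        p *= 2;                                // smallest power of two >= ne00
    }
    *ne00_padded = p;
    // one int32_t index per padded element, rounded up to a 16-byte multiple
    *mem_size = ((size_t)p*sizeof(int32_t) + 15) & ~(size_t)15;
}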