llama_cpp 0.14.5 → 0.14.6

Sign up to get free protection for your applications and to get access to all the features.
@@ -37,11 +37,15 @@ enum ggml_metal_kernel_type {
37
37
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
38
  GGML_METAL_KERNEL_TYPE_SCALE,
39
39
  GGML_METAL_KERNEL_TYPE_SCALE_4,
40
+ GGML_METAL_KERNEL_TYPE_CLAMP,
40
41
  GGML_METAL_KERNEL_TYPE_TANH,
41
42
  GGML_METAL_KERNEL_TYPE_RELU,
42
43
  GGML_METAL_KERNEL_TYPE_GELU,
44
+ GGML_METAL_KERNEL_TYPE_GELU_4,
43
45
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
46
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
44
47
  GGML_METAL_KERNEL_TYPE_SILU,
48
+ GGML_METAL_KERNEL_TYPE_SILU_4,
45
49
  GGML_METAL_KERNEL_TYPE_SOFT_MAX,
46
50
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
47
51
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@@ -468,11 +472,15 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
468
472
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
469
473
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
470
474
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
475
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
471
476
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
472
477
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
473
478
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
479
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
474
480
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
481
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
475
482
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
483
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
476
484
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
477
485
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
478
486
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
@@ -713,6 +721,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
713
721
  case GGML_OP_MUL:
714
722
  case GGML_OP_DIV:
715
723
  case GGML_OP_SCALE:
724
+ case GGML_OP_CLAMP:
716
725
  case GGML_OP_SQR:
717
726
  case GGML_OP_SUM_ROWS:
718
727
  return true;
@@ -1154,8 +1163,30 @@ static enum ggml_status ggml_metal_graph_compute(
1154
1163
 
1155
1164
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1156
1165
  } break;
1166
+ case GGML_OP_CLAMP:
1167
+ {
1168
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1169
+
1170
+ float min;
1171
+ float max;
1172
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1173
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1174
+
1175
+ [encoder setComputePipelineState:pipeline];
1176
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1177
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1178
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
1179
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
1180
+
1181
+ const int64_t n = ggml_nelements(dst);
1182
+
1183
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1184
+ } break;
1157
1185
  case GGML_OP_UNARY:
1158
1186
  switch (ggml_get_unary_op(gf->nodes[i])) {
1187
+ // we are not taking into account the strides, so for now require contiguous tensors
1188
+ GGML_ASSERT(ggml_is_contiguous(src0));
1189
+
1159
1190
  case GGML_UNARY_OP_TANH:
1160
1191
  {
1161
1192
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
@@ -1182,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute(
1182
1213
  } break;
1183
1214
  case GGML_UNARY_OP_GELU:
1184
1215
  {
1185
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1216
+ int64_t n = ggml_nelements(dst);
1217
+
1218
+ id<MTLComputePipelineState> pipeline = nil;
1219
+
1220
+ if (n % 4 == 0) {
1221
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
1222
+ n /= 4;
1223
+ } else {
1224
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1225
+ }
1186
1226
 
1187
1227
  [encoder setComputePipelineState:pipeline];
1188
1228
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1189
1229
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1190
1230
 
1191
- const int64_t n = ggml_nelements(dst);
1192
- GGML_ASSERT(n % 4 == 0);
1193
-
1194
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1231
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1195
1232
  } break;
1196
1233
  case GGML_UNARY_OP_GELU_QUICK:
1197
1234
  {
1198
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1235
+ int64_t n = ggml_nelements(dst);
1236
+
1237
+ id<MTLComputePipelineState> pipeline = nil;
1238
+
1239
+ if (n % 4 == 0) {
1240
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
1241
+ n /= 4;
1242
+ } else {
1243
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1244
+ }
1199
1245
 
1200
1246
  [encoder setComputePipelineState:pipeline];
1201
1247
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1202
1248
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1203
1249
 
1204
- const int64_t n = ggml_nelements(dst);
1205
- GGML_ASSERT(n % 4 == 0);
1206
-
1207
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1250
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1208
1251
  } break;
1209
1252
  case GGML_UNARY_OP_SILU:
1210
1253
  {
1211
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1254
+ int64_t n = ggml_nelements(dst);
1255
+
1256
+ id<MTLComputePipelineState> pipeline = nil;
1257
+
1258
+ if (n % 4 == 0) {
1259
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
1260
+ n /= 4;
1261
+ } else {
1262
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1263
+ }
1212
1264
 
1213
1265
  [encoder setComputePipelineState:pipeline];
1214
1266
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1215
1267
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1216
1268
 
1217
- const int64_t n = ggml_nelements(dst);
1218
- GGML_ASSERT(n % 4 == 0);
1219
-
1220
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1269
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1221
1270
  } break;
1222
1271
  default:
1223
1272
  {
@@ -1683,15 +1732,10 @@ static enum ggml_status ggml_metal_graph_compute(
1683
1732
  } break;
1684
1733
  case GGML_OP_MUL_MAT_ID:
1685
1734
  {
1686
- //GGML_ASSERT(ne00 == ne10);
1687
- //GGML_ASSERT(ne03 == ne13);
1688
1735
  const int n_as = src0->ne[2];
1689
1736
 
1690
- // max size of the src1ids array in the kernel shared buffer
1691
- GGML_ASSERT(ne11 <= 4096);
1692
-
1693
1737
  // src2 = ids
1694
- const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
1738
+ const int64_t ne20 = src2->ne[0];
1695
1739
  const int64_t ne21 = src2->ne[1];
1696
1740
  const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1697
1741
  const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
@@ -1712,15 +1756,13 @@ static enum ggml_status ggml_metal_graph_compute(
1712
1756
 
1713
1757
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1714
1758
  // to the matrix-vector kernel
1715
- int ne11_mm_min = n_as;
1716
-
1717
- const int idx = ((int32_t *) dst->op_params)[0];
1759
+ // ne20 = n_used_experts
1760
+ // ne21 = n_rows
1761
+ const int dst_rows = ne20*ne21;
1762
+ const int dst_rows_min = n_as;
1718
1763
 
1719
- // batch size
1720
- GGML_ASSERT(ne21 == ne11); // ?
1721
- GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
1722
- const uint r2 = 1;
1723
- const uint r3 = 1;
1764
+ // max size of the rowids array in the kernel shared buffer
1765
+ GGML_ASSERT(dst_rows <= 2048);
1724
1766
 
1725
1767
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1726
1768
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1730,7 +1772,7 @@ static enum ggml_status ggml_metal_graph_compute(
1730
1772
  // !!!
1731
1773
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1732
1774
  ne00 % 32 == 0 && ne00 >= 64 &&
1733
- ne11 > ne11_mm_min) {
1775
+ dst_rows > dst_rows_min) {
1734
1776
 
1735
1777
  // some Metal matrix data types require aligned pointers
1736
1778
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1772,26 +1814,26 @@ static enum ggml_status ggml_metal_graph_compute(
1772
1814
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1773
1815
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1774
1816
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1775
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1776
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1777
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1778
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1779
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1780
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
1781
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
1782
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
1783
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
1784
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
1785
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
1786
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
1787
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
1788
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1789
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1790
- [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
1791
-
1792
- [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
1793
-
1794
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1817
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1818
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1819
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1820
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
1821
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
1822
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1823
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1824
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1825
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1826
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1827
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1828
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1829
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1830
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1831
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
1832
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1833
+
1834
+ [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
1835
+
1836
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1795
1837
  } else {
1796
1838
  int nth0 = 32;
1797
1839
  int nth1 = 1;
@@ -1926,7 +1968,12 @@ static enum ggml_status ggml_metal_graph_compute(
1926
1968
  {
1927
1969
  nth0 = 4;
1928
1970
  nth1 = 16;
1971
+ #if QK_K == 64
1972
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
1973
+ #else
1929
1974
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
1975
+ #endif
1976
+
1930
1977
  } break;
1931
1978
  default:
1932
1979
  {
@@ -1939,72 +1986,72 @@ static enum ggml_status ggml_metal_graph_compute(
1939
1986
  GGML_ASSERT(ne00 >= nth0*nth1);
1940
1987
  }
1941
1988
 
1942
- const int64_t _ne1 = 1; // kernels needs a reference in constant memory
1943
-
1944
1989
  [encoder setComputePipelineState:pipeline];
1945
1990
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1946
1991
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1947
1992
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1948
1993
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1949
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1950
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1951
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
1952
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
1953
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
1954
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1955
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1956
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1957
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
1958
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1959
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1960
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1961
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1962
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1963
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1964
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
1965
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
1966
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
1967
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
1968
- [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
1994
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1995
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1996
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1997
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
1998
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
1999
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
2000
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
2001
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
2002
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
2003
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
2004
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
2005
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
2006
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
2007
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
2008
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
2009
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
2010
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
2011
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
2012
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
2013
+
2014
+ const int64_t _ne1 = 1;
2015
+ const int tgz = dst_rows;
1969
2016
 
1970
2017
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1971
2018
  src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1972
2019
  src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
1973
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2020
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1974
2021
  }
1975
2022
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1976
2023
  const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1977
2024
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1978
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2025
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1979
2026
  }
1980
2027
  else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1981
2028
  const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1982
2029
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1983
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2030
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1984
2031
  }
1985
2032
  else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
1986
2033
  const int mem_size = 32*sizeof(float);
1987
2034
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1988
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2035
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1989
2036
  }
1990
2037
  else if (src0t == GGML_TYPE_Q4_K) {
1991
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2038
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1992
2039
  }
1993
2040
  else if (src0t == GGML_TYPE_Q3_K) {
1994
2041
  #ifdef GGML_QKK_64
1995
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2042
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1996
2043
  #else
1997
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2044
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1998
2045
  #endif
1999
2046
  }
2000
2047
  else if (src0t == GGML_TYPE_Q5_K) {
2001
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2048
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2002
2049
  }
2003
2050
  else if (src0t == GGML_TYPE_Q6_K) {
2004
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2051
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2005
2052
  } else {
2006
- const int64_t ny = (_ne1 + nrows - 1)/nrows;
2007
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2053
+ const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
2054
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2008
2055
  }
2009
2056
  }
2010
2057
  } break;