llama_cpp 0.14.5 → 0.14.6

This diff shows the differences between the contents of two publicly available package versions, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -37,11 +37,15 @@ enum ggml_metal_kernel_type {
37
37
  GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
38
  GGML_METAL_KERNEL_TYPE_SCALE,
39
39
  GGML_METAL_KERNEL_TYPE_SCALE_4,
40
+ GGML_METAL_KERNEL_TYPE_CLAMP,
40
41
  GGML_METAL_KERNEL_TYPE_TANH,
41
42
  GGML_METAL_KERNEL_TYPE_RELU,
42
43
  GGML_METAL_KERNEL_TYPE_GELU,
44
+ GGML_METAL_KERNEL_TYPE_GELU_4,
43
45
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
46
+ GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
44
47
  GGML_METAL_KERNEL_TYPE_SILU,
48
+ GGML_METAL_KERNEL_TYPE_SILU_4,
45
49
  GGML_METAL_KERNEL_TYPE_SOFT_MAX,
46
50
  GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
47
51
  GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@@ -468,11 +472,15 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
468
472
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
469
473
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
470
474
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
475
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
471
476
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
472
477
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
473
478
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
479
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
474
480
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
481
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
475
482
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
483
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
476
484
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
477
485
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
478
486
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
@@ -713,6 +721,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
713
721
  case GGML_OP_MUL:
714
722
  case GGML_OP_DIV:
715
723
  case GGML_OP_SCALE:
724
+ case GGML_OP_CLAMP:
716
725
  case GGML_OP_SQR:
717
726
  case GGML_OP_SUM_ROWS:
718
727
  return true;
@@ -1154,8 +1163,30 @@ static enum ggml_status ggml_metal_graph_compute(
1154
1163
 
1155
1164
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1156
1165
  } break;
1166
+ case GGML_OP_CLAMP:
1167
+ {
1168
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1169
+
1170
+ float min;
1171
+ float max;
1172
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1173
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1174
+
1175
+ [encoder setComputePipelineState:pipeline];
1176
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1177
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1178
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
1179
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
1180
+
1181
+ const int64_t n = ggml_nelements(dst);
1182
+
1183
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1184
+ } break;
1157
1185
  case GGML_OP_UNARY:
1158
1186
  switch (ggml_get_unary_op(gf->nodes[i])) {
1187
+ // we are not taking into account the strides, so for now require contiguous tensors
1188
+ GGML_ASSERT(ggml_is_contiguous(src0));
1189
+
1159
1190
  case GGML_UNARY_OP_TANH:
1160
1191
  {
1161
1192
  id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
@@ -1182,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute(
1182
1213
  } break;
1183
1214
  case GGML_UNARY_OP_GELU:
1184
1215
  {
1185
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1216
+ int64_t n = ggml_nelements(dst);
1217
+
1218
+ id<MTLComputePipelineState> pipeline = nil;
1219
+
1220
+ if (n % 4 == 0) {
1221
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
1222
+ n /= 4;
1223
+ } else {
1224
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
1225
+ }
1186
1226
 
1187
1227
  [encoder setComputePipelineState:pipeline];
1188
1228
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1189
1229
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1190
1230
 
1191
- const int64_t n = ggml_nelements(dst);
1192
- GGML_ASSERT(n % 4 == 0);
1193
-
1194
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1231
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1195
1232
  } break;
1196
1233
  case GGML_UNARY_OP_GELU_QUICK:
1197
1234
  {
1198
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1235
+ int64_t n = ggml_nelements(dst);
1236
+
1237
+ id<MTLComputePipelineState> pipeline = nil;
1238
+
1239
+ if (n % 4 == 0) {
1240
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
1241
+ n /= 4;
1242
+ } else {
1243
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
1244
+ }
1199
1245
 
1200
1246
  [encoder setComputePipelineState:pipeline];
1201
1247
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1202
1248
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1203
1249
 
1204
- const int64_t n = ggml_nelements(dst);
1205
- GGML_ASSERT(n % 4 == 0);
1206
-
1207
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1250
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1208
1251
  } break;
1209
1252
  case GGML_UNARY_OP_SILU:
1210
1253
  {
1211
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1254
+ int64_t n = ggml_nelements(dst);
1255
+
1256
+ id<MTLComputePipelineState> pipeline = nil;
1257
+
1258
+ if (n % 4 == 0) {
1259
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
1260
+ n /= 4;
1261
+ } else {
1262
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
1263
+ }
1212
1264
 
1213
1265
  [encoder setComputePipelineState:pipeline];
1214
1266
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1215
1267
  [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1216
1268
 
1217
- const int64_t n = ggml_nelements(dst);
1218
- GGML_ASSERT(n % 4 == 0);
1219
-
1220
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1269
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1221
1270
  } break;
1222
1271
  default:
1223
1272
  {
@@ -1683,15 +1732,10 @@ static enum ggml_status ggml_metal_graph_compute(
1683
1732
  } break;
1684
1733
  case GGML_OP_MUL_MAT_ID:
1685
1734
  {
1686
- //GGML_ASSERT(ne00 == ne10);
1687
- //GGML_ASSERT(ne03 == ne13);
1688
1735
  const int n_as = src0->ne[2];
1689
1736
 
1690
- // max size of the src1ids array in the kernel shared buffer
1691
- GGML_ASSERT(ne11 <= 4096);
1692
-
1693
1737
  // src2 = ids
1694
- const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
1738
+ const int64_t ne20 = src2->ne[0];
1695
1739
  const int64_t ne21 = src2->ne[1];
1696
1740
  const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1697
1741
  const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
@@ -1712,15 +1756,13 @@ static enum ggml_status ggml_metal_graph_compute(
1712
1756
 
1713
1757
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1714
1758
  // to the matrix-vector kernel
1715
- int ne11_mm_min = n_as;
1716
-
1717
- const int idx = ((int32_t *) dst->op_params)[0];
1759
+ // ne20 = n_used_experts
1760
+ // ne21 = n_rows
1761
+ const int dst_rows = ne20*ne21;
1762
+ const int dst_rows_min = n_as;
1718
1763
 
1719
- // batch size
1720
- GGML_ASSERT(ne21 == ne11); // ?
1721
- GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
1722
- const uint r2 = 1;
1723
- const uint r3 = 1;
1764
+ // max size of the rowids array in the kernel shared buffer
1765
+ GGML_ASSERT(dst_rows <= 2048);
1724
1766
 
1725
1767
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1726
1768
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1730,7 +1772,7 @@ static enum ggml_status ggml_metal_graph_compute(
1730
1772
  // !!!
1731
1773
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1732
1774
  ne00 % 32 == 0 && ne00 >= 64 &&
1733
- ne11 > ne11_mm_min) {
1775
+ dst_rows > dst_rows_min) {
1734
1776
 
1735
1777
  // some Metal matrix data types require aligned pointers
1736
1778
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1772,26 +1814,26 @@ static enum ggml_status ggml_metal_graph_compute(
1772
1814
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1773
1815
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1774
1816
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1775
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1776
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1777
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1778
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1779
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1780
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
1781
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
1782
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
1783
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
1784
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
1785
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
1786
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
1787
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
1788
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1789
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1790
- [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
1791
-
1792
- [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
1793
-
1794
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1817
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1818
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1819
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1820
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
1821
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
1822
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1823
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1824
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1825
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1826
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1827
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1828
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1829
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1830
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1831
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
1832
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1833
+
1834
+ [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
1835
+
1836
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1795
1837
  } else {
1796
1838
  int nth0 = 32;
1797
1839
  int nth1 = 1;
@@ -1926,7 +1968,12 @@ static enum ggml_status ggml_metal_graph_compute(
1926
1968
  {
1927
1969
  nth0 = 4;
1928
1970
  nth1 = 16;
1971
+ #if QK_K == 64
1972
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
1973
+ #else
1929
1974
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
1975
+ #endif
1976
+
1930
1977
  } break;
1931
1978
  default:
1932
1979
  {
@@ -1939,72 +1986,72 @@ static enum ggml_status ggml_metal_graph_compute(
1939
1986
  GGML_ASSERT(ne00 >= nth0*nth1);
1940
1987
  }
1941
1988
 
1942
- const int64_t _ne1 = 1; // kernels needs a reference in constant memory
1943
-
1944
1989
  [encoder setComputePipelineState:pipeline];
1945
1990
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1946
1991
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1947
1992
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1948
1993
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1949
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1950
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1951
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
1952
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
1953
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
1954
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1955
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1956
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
1957
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
1958
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
1959
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
1960
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
1961
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
1962
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
1963
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
1964
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
1965
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
1966
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
1967
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
1968
- [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
1994
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1995
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1996
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1997
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
1998
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
1999
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
2000
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
2001
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
2002
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
2003
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
2004
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
2005
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
2006
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
2007
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
2008
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
2009
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
2010
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
2011
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
2012
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
2013
+
2014
+ const int64_t _ne1 = 1;
2015
+ const int tgz = dst_rows;
1969
2016
 
1970
2017
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
1971
2018
  src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
1972
2019
  src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
1973
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2020
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1974
2021
  }
1975
2022
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
1976
2023
  const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
1977
2024
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1978
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2025
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1979
2026
  }
1980
2027
  else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
1981
2028
  const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
1982
2029
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1983
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2030
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1984
2031
  }
1985
2032
  else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
1986
2033
  const int mem_size = 32*sizeof(float);
1987
2034
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
1988
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2035
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1989
2036
  }
1990
2037
  else if (src0t == GGML_TYPE_Q4_K) {
1991
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2038
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1992
2039
  }
1993
2040
  else if (src0t == GGML_TYPE_Q3_K) {
1994
2041
  #ifdef GGML_QKK_64
1995
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2042
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1996
2043
  #else
1997
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2044
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1998
2045
  #endif
1999
2046
  }
2000
2047
  else if (src0t == GGML_TYPE_Q5_K) {
2001
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2048
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2002
2049
  }
2003
2050
  else if (src0t == GGML_TYPE_Q6_K) {
2004
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2051
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2005
2052
  } else {
2006
- const int64_t ny = (_ne1 + nrows - 1)/nrows;
2007
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2053
+ const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
2054
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2008
2055
  }
2009
2056
  }
2010
2057
  } break;