llama_cpp 0.14.5 → 0.14.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/vendor/tmp/llama.cpp/Makefile +18 -6
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +135 -46
- data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +130 -83
- data/vendor/tmp/llama.cpp/ggml-metal.metal +505 -1467
- data/vendor/tmp/llama.cpp/ggml-quants.c +1 -1
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +65 -52
- data/vendor/tmp/llama.cpp/ggml.c +153 -87
- data/vendor/tmp/llama.cpp/ggml.h +5 -4
- data/vendor/tmp/llama.cpp/llama.cpp +885 -144
- data/vendor/tmp/llama.cpp/sgemm.cpp +1148 -0
- data/vendor/tmp/llama.cpp/sgemm.h +12 -0
- metadata +4 -2
data/vendor/tmp/llama.cpp/ggml-metal.m

@@ -37,11 +37,15 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_DIV_ROW,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
+    GGML_METAL_KERNEL_TYPE_CLAMP,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
     GGML_METAL_KERNEL_TYPE_GELU,
+    GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+    GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
+    GGML_METAL_KERNEL_TYPE_SILU_4,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX,
     GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
@@ -468,11 +472,15 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
@@ -713,6 +721,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_SCALE:
+        case GGML_OP_CLAMP:
         case GGML_OP_SQR:
         case GGML_OP_SUM_ROWS:
             return true;
@@ -1154,8 +1163,30 @@ static enum ggml_status ggml_metal_graph_compute(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+        case GGML_OP_CLAMP:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
+
+                float min;
+                float max;
+                memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+                memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                [encoder setBytes:&min length:sizeof(min) atIndex:2];
+                [encoder setBytes:&max length:sizeof(max) atIndex:3];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(gf->nodes[i])) {
+                // we are not taking into account the strides, so for now require contiguous tensors
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
                 case GGML_UNARY_OP_TANH:
                     {
                         id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
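For context, the new GGML_OP_CLAMP branch reads its bounds from the tensor's op_params, which ggml stores as raw int32_t words holding float bit patterns; the host memcpys them back into floats before passing them to the kernel. A minimal CPU-side sketch of that convention (clamp_ref is an illustrative name, not part of ggml's API):

    #include <stdint.h>
    #include <string.h>

    // Reads the clamp bounds the same way the Metal host code above does:
    // two floats bit-copied out of the int32_t op_params array, then applied per element.
    static void clamp_ref(float * dst, const float * src, int64_t n, const int32_t * op_params) {
        float min, max;
        memcpy(&min, op_params + 0, sizeof(float));
        memcpy(&max, op_params + 1, sizeof(float));
        for (int64_t i = 0; i < n; ++i) {
            const float x = src[i];
            dst[i] = x < min ? min : (x > max ? max : x);
        }
    }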
@@ -1182,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute(
                     } break;
                 case GGML_UNARY_OP_GELU:
                     {
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                        int64_t n = ggml_nelements(dst);
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        if (n % 4 == 0) {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
+                            n /= 4;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                        }
 
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-                        const int64_t n = ggml_nelements(dst);
-                        GGML_ASSERT(n % 4 == 0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_UNARY_OP_GELU_QUICK:
                     {
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                        int64_t n = ggml_nelements(dst);
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        if (n % 4 == 0) {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
+                            n /= 4;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                        }
 
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-                        const int64_t n = ggml_nelements(dst);
-                        GGML_ASSERT(n % 4 == 0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_UNARY_OP_SILU:
                     {
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                        int64_t n = ggml_nelements(dst);
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        if (n % 4 == 0) {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
+                            n /= 4;
+                        } else {
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                        }
 
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-                        const int64_t n = ggml_nelements(dst);
-                        GGML_ASSERT(n % 4 == 0);
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 default:
                     {
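The GELU, GELU_QUICK and SILU cases above all follow the same pattern: when the element count is divisible by 4, the host picks the new vectorized *_4 pipeline and dispatches a quarter as many threadgroups; otherwise it falls back to the scalar kernel, which removes the old GGML_ASSERT(n % 4 == 0) restriction. A schematic of that host-side choice in plain C (plan_unary_launch and the struct are illustrative names, not ggml identifiers):

    #include <stdbool.h>
    #include <stdint.h>

    typedef struct {
        bool    use_vec4; // take the *_4 pipeline
        int64_t grid;     // number of threadgroups to dispatch
    } unary_launch;

    static unary_launch plan_unary_launch(int64_t n_elements) {
        unary_launch p;
        p.use_vec4 = (n_elements % 4 == 0);
        p.grid     = p.use_vec4 ? n_elements / 4 : n_elements;
        return p;
    }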
@@ -1683,15 +1732,10 @@ static enum ggml_status ggml_metal_graph_compute(
            } break;
        case GGML_OP_MUL_MAT_ID:
            {
-                //GGML_ASSERT(ne00 == ne10);
-                //GGML_ASSERT(ne03 == ne13);
                const int n_as = src0->ne[2];
 
-                // max size of the src1ids array in the kernel shared buffer
-                GGML_ASSERT(ne11 <= 4096);
-
                // src2 = ids
-                const int64_t ne20 = src2->ne[0];
+                const int64_t ne20 = src2->ne[0];
                const int64_t ne21 = src2->ne[1];
                const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
                const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
@@ -1712,15 +1756,13 @@ static enum ggml_status ggml_metal_graph_compute(
 
                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                // to the matrix-vector kernel
-
-
-                const int
+                // ne20 = n_used_experts
+                // ne21 = n_rows
+                const int dst_rows = ne20*ne21;
+                const int dst_rows_min = n_as;
 
-                //
-                GGML_ASSERT(
-                GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
-                const uint r2 = 1;
-                const uint r3 = 1;
+                // max size of the rowids array in the kernel shared buffer
+                GGML_ASSERT(dst_rows <= 2048);
 
                // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1730,7 +1772,7 @@ static enum ggml_status ggml_metal_graph_compute(
                // !!!
                if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                    ne00 % 32 == 0 && ne00 >= 64 &&
-
+                    dst_rows > dst_rows_min) {
 
                    // some Metal matrix data types require aligned pointers
                    // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
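The two hunks above replace the old heuristic with one expressed in destination rows: dst_rows = ne20*ne21 (experts used per token times number of rows), the matrix-matrix kernel is chosen only when that exceeds the expert count n_as, and the rowids scratch buffer caps dst_rows at 2048. A condensed sketch of the decision, assuming the Apple7 family check is passed in as a flag (use_mm_id_kernel is an illustrative helper, not a ggml symbol):

    #include <assert.h>
    #include <stdbool.h>
    #include <stdint.h>

    static bool use_mm_id_kernel(bool is_apple7_or_newer, int64_t ne00,
                                 int64_t ne20 /* experts used per token */,
                                 int64_t ne21 /* rows */, int n_as) {
        const int64_t dst_rows = ne20*ne21;
        assert(dst_rows <= 2048); // max size of the rowids array in the kernel shared buffer
        return is_apple7_or_newer && ne00 % 32 == 0 && ne00 >= 64 && dst_rows > n_as;
    }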
@@ -1772,26 +1814,26 @@ static enum ggml_status ggml_metal_graph_compute(
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-
-                    [encoder setThreadgroupMemoryLength:GGML_PAD(8192 +
-
-                    [encoder dispatchThreadgroups:MTLSizeMake((
+                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                    [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
+                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+
+                    [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                } else {
                    int nth0 = 32;
                    int nth1 = 1;
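The setThreadgroupMemoryLength call above reserves 8192 bytes for the matrix tiles plus 4 bytes (one ushort2) per destination row for the row-id map, rounded up to a multiple of 16. A small sketch of that sizing, with GGML_PAD written out as it is defined in ggml.h (mm_id_shared_mem_bytes is an illustrative name):

    #include <stddef.h>
    #include <stdint.h>

    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    static size_t mm_id_shared_mem_bytes(int64_t dst_rows) {
        // 8192 bytes of tile storage + dst_rows ushort2 row ids, padded to 16 bytes
        return GGML_PAD(8192 + (size_t) dst_rows * 4, 16);
    }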
@@ -1926,7 +1968,12 @@ static enum ggml_status ggml_metal_graph_compute(
                        {
                            nth0 = 4;
                            nth1 = 16;
+#if QK_K == 64
+                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+#else
                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+#endif
+
                        } break;
                    default:
                        {
@@ -1939,72 +1986,72 @@ static enum ggml_status ggml_metal_graph_compute(
                        GGML_ASSERT(ne00 >= nth0*nth1);
                    }
 
-                    const int64_t _ne1 = 1; // kernels needs a reference in constant memory
-
                    [encoder setComputePipelineState:pipeline];
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-                    [encoder setBytes:&
-
+                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                    [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
+                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
+                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
+                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
+                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
+
+                    const int64_t _ne1 = 1;
+                    const int tgz = dst_rows;
 
                    if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
                        src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
                        src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1,
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1,
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1,
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                        const int mem_size = 32*sizeof(float);
                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                    else if (src0t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                    else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1,
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                    }
                    else if (src0t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                    else if (src0t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1,
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    } else {
-                        const int64_t ny = (_ne1 + nrows - 1)/nrows;
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny,
+                        const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                    }
                }
            } break;
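In the matrix-vector fallback above, every dispatch now uses tgz = dst_rows as the threadgroup-grid z dimension (one slice per destination row, i.e. per (row, expert) pair), while the x dimension still packs a fixed number of matrix rows per threadgroup, 8, 4 or 2 depending on the quantization type. A sketch of that grid arithmetic (mul_mv_id_grid and rows_per_tg are illustrative names, not ggml symbols):

    #include <stdint.h>

    typedef struct { int64_t x, y, z; } tg_grid;

    static tg_grid mul_mv_id_grid(int64_t ne01, int64_t rows_per_tg, int64_t dst_rows) {
        tg_grid g;
        g.x = (ne01 + rows_per_tg - 1)/rows_per_tg; // e.g. (ne01 + 7)/8 for the Q4_0-class types
        g.y = 1;                                    // _ne1 == 1 in the updated code
        g.z = dst_rows;                             // tgz: one threadgroup slice per destination row
        return g;
    }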