llama_cpp 0.14.5 → 0.14.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/vendor/tmp/llama.cpp/Makefile +18 -6
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +135 -46
- data/vendor/tmp/llama.cpp/ggml-impl.h +1 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +130 -83
- data/vendor/tmp/llama.cpp/ggml-metal.metal +505 -1467
- data/vendor/tmp/llama.cpp/ggml-quants.c +1 -1
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +65 -52
- data/vendor/tmp/llama.cpp/ggml.c +153 -87
- data/vendor/tmp/llama.cpp/ggml.h +5 -4
- data/vendor/tmp/llama.cpp/llama.cpp +885 -144
- data/vendor/tmp/llama.cpp/sgemm.cpp +1148 -0
- data/vendor/tmp/llama.cpp/sgemm.h +12 -0
- metadata +4 -2
@@ -37,11 +37,15 @@ enum ggml_metal_kernel_type {
|
|
37
37
|
GGML_METAL_KERNEL_TYPE_DIV_ROW,
|
38
38
|
GGML_METAL_KERNEL_TYPE_SCALE,
|
39
39
|
GGML_METAL_KERNEL_TYPE_SCALE_4,
|
40
|
+
GGML_METAL_KERNEL_TYPE_CLAMP,
|
40
41
|
GGML_METAL_KERNEL_TYPE_TANH,
|
41
42
|
GGML_METAL_KERNEL_TYPE_RELU,
|
42
43
|
GGML_METAL_KERNEL_TYPE_GELU,
|
44
|
+
GGML_METAL_KERNEL_TYPE_GELU_4,
|
43
45
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
46
|
+
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
44
47
|
GGML_METAL_KERNEL_TYPE_SILU,
|
48
|
+
GGML_METAL_KERNEL_TYPE_SILU_4,
|
45
49
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX,
|
46
50
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
|
47
51
|
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
@@ -468,11 +472,15 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
468
472
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
|
469
473
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
|
470
474
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
|
475
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
471
476
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
472
477
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
473
478
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
479
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
474
480
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
481
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
475
482
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
483
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
476
484
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
|
477
485
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
|
478
486
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
@@ -713,6 +721,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
713
721
|
case GGML_OP_MUL:
|
714
722
|
case GGML_OP_DIV:
|
715
723
|
case GGML_OP_SCALE:
|
724
|
+
case GGML_OP_CLAMP:
|
716
725
|
case GGML_OP_SQR:
|
717
726
|
case GGML_OP_SUM_ROWS:
|
718
727
|
return true;
|
@@ -1154,8 +1163,30 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1154
1163
|
|
1155
1164
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1156
1165
|
} break;
|
1166
|
+
case GGML_OP_CLAMP:
|
1167
|
+
{
|
1168
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
|
1169
|
+
|
1170
|
+
float min;
|
1171
|
+
float max;
|
1172
|
+
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
1173
|
+
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
|
1174
|
+
|
1175
|
+
[encoder setComputePipelineState:pipeline];
|
1176
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1177
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1178
|
+
[encoder setBytes:&min length:sizeof(min) atIndex:2];
|
1179
|
+
[encoder setBytes:&max length:sizeof(max) atIndex:3];
|
1180
|
+
|
1181
|
+
const int64_t n = ggml_nelements(dst);
|
1182
|
+
|
1183
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1184
|
+
} break;
|
1157
1185
|
case GGML_OP_UNARY:
|
1158
1186
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
1187
|
+
// we are not taking into account the strides, so for now require contiguous tensors
|
1188
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
1189
|
+
|
1159
1190
|
case GGML_UNARY_OP_TANH:
|
1160
1191
|
{
|
1161
1192
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
|
@@ -1182,42 +1213,60 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1182
1213
|
} break;
|
1183
1214
|
case GGML_UNARY_OP_GELU:
|
1184
1215
|
{
|
1185
|
-
|
1216
|
+
int64_t n = ggml_nelements(dst);
|
1217
|
+
|
1218
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1219
|
+
|
1220
|
+
if (n % 4 == 0) {
|
1221
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
|
1222
|
+
n /= 4;
|
1223
|
+
} else {
|
1224
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
|
1225
|
+
}
|
1186
1226
|
|
1187
1227
|
[encoder setComputePipelineState:pipeline];
|
1188
1228
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1189
1229
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1190
1230
|
|
1191
|
-
|
1192
|
-
GGML_ASSERT(n % 4 == 0);
|
1193
|
-
|
1194
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1231
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1195
1232
|
} break;
|
1196
1233
|
case GGML_UNARY_OP_GELU_QUICK:
|
1197
1234
|
{
|
1198
|
-
|
1235
|
+
int64_t n = ggml_nelements(dst);
|
1236
|
+
|
1237
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1238
|
+
|
1239
|
+
if (n % 4 == 0) {
|
1240
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
|
1241
|
+
n /= 4;
|
1242
|
+
} else {
|
1243
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
|
1244
|
+
}
|
1199
1245
|
|
1200
1246
|
[encoder setComputePipelineState:pipeline];
|
1201
1247
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1202
1248
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1203
1249
|
|
1204
|
-
|
1205
|
-
GGML_ASSERT(n % 4 == 0);
|
1206
|
-
|
1207
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1250
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1208
1251
|
} break;
|
1209
1252
|
case GGML_UNARY_OP_SILU:
|
1210
1253
|
{
|
1211
|
-
|
1254
|
+
int64_t n = ggml_nelements(dst);
|
1255
|
+
|
1256
|
+
id<MTLComputePipelineState> pipeline = nil;
|
1257
|
+
|
1258
|
+
if (n % 4 == 0) {
|
1259
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
|
1260
|
+
n /= 4;
|
1261
|
+
} else {
|
1262
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
|
1263
|
+
}
|
1212
1264
|
|
1213
1265
|
[encoder setComputePipelineState:pipeline];
|
1214
1266
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1215
1267
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1216
1268
|
|
1217
|
-
|
1218
|
-
GGML_ASSERT(n % 4 == 0);
|
1219
|
-
|
1220
|
-
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1269
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1221
1270
|
} break;
|
1222
1271
|
default:
|
1223
1272
|
{
|
@@ -1683,15 +1732,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1683
1732
|
} break;
|
1684
1733
|
case GGML_OP_MUL_MAT_ID:
|
1685
1734
|
{
|
1686
|
-
//GGML_ASSERT(ne00 == ne10);
|
1687
|
-
//GGML_ASSERT(ne03 == ne13);
|
1688
1735
|
const int n_as = src0->ne[2];
|
1689
1736
|
|
1690
|
-
// max size of the src1ids array in the kernel shared buffer
|
1691
|
-
GGML_ASSERT(ne11 <= 4096);
|
1692
|
-
|
1693
1737
|
// src2 = ids
|
1694
|
-
const int64_t ne20 = src2->ne[0];
|
1738
|
+
const int64_t ne20 = src2->ne[0];
|
1695
1739
|
const int64_t ne21 = src2->ne[1];
|
1696
1740
|
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
1697
1741
|
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
@@ -1712,15 +1756,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1712
1756
|
|
1713
1757
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
1714
1758
|
// to the matrix-vector kernel
|
1715
|
-
|
1716
|
-
|
1717
|
-
const int
|
1759
|
+
// ne20 = n_used_experts
|
1760
|
+
// ne21 = n_rows
|
1761
|
+
const int dst_rows = ne20*ne21;
|
1762
|
+
const int dst_rows_min = n_as;
|
1718
1763
|
|
1719
|
-
//
|
1720
|
-
GGML_ASSERT(
|
1721
|
-
GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
|
1722
|
-
const uint r2 = 1;
|
1723
|
-
const uint r3 = 1;
|
1764
|
+
// max size of the rowids array in the kernel shared buffer
|
1765
|
+
GGML_ASSERT(dst_rows <= 2048);
|
1724
1766
|
|
1725
1767
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
1726
1768
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
@@ -1730,7 +1772,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1730
1772
|
// !!!
|
1731
1773
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
1732
1774
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
1733
|
-
|
1775
|
+
dst_rows > dst_rows_min) {
|
1734
1776
|
|
1735
1777
|
// some Metal matrix data types require aligned pointers
|
1736
1778
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
@@ -1772,26 +1814,26 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1772
1814
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1773
1815
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1774
1816
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
1775
|
-
[encoder setBytes:&
|
1776
|
-
[encoder setBytes:&
|
1777
|
-
[encoder setBytes:&
|
1778
|
-
[encoder setBytes:&
|
1779
|
-
[encoder setBytes:&
|
1780
|
-
[encoder setBytes:&
|
1781
|
-
[encoder setBytes:&
|
1782
|
-
[encoder setBytes:&
|
1783
|
-
[encoder setBytes:&
|
1784
|
-
[encoder setBytes:&
|
1785
|
-
[encoder setBytes:&
|
1786
|
-
[encoder setBytes:&
|
1787
|
-
[encoder setBytes:&
|
1788
|
-
[encoder setBytes:&
|
1789
|
-
[encoder setBytes:&
|
1790
|
-
[encoder setBytes:&
|
1791
|
-
|
1792
|
-
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 +
|
1793
|
-
|
1794
|
-
[encoder dispatchThreadgroups:MTLSizeMake((
|
1817
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1818
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
1819
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
1820
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
|
1821
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
|
1822
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
|
1823
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
|
1824
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
1825
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
1826
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
1827
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1828
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1829
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1830
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
1831
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
|
1832
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
1833
|
+
|
1834
|
+
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
|
1835
|
+
|
1836
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1795
1837
|
} else {
|
1796
1838
|
int nth0 = 32;
|
1797
1839
|
int nth1 = 1;
|
@@ -1926,7 +1968,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1926
1968
|
{
|
1927
1969
|
nth0 = 4;
|
1928
1970
|
nth1 = 16;
|
1971
|
+
#if QK_K == 64
|
1972
|
+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
1973
|
+
#else
|
1929
1974
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
1975
|
+
#endif
|
1976
|
+
|
1930
1977
|
} break;
|
1931
1978
|
default:
|
1932
1979
|
{
|
@@ -1939,72 +1986,72 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1939
1986
|
GGML_ASSERT(ne00 >= nth0*nth1);
|
1940
1987
|
}
|
1941
1988
|
|
1942
|
-
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
1943
|
-
|
1944
1989
|
[encoder setComputePipelineState:pipeline];
|
1945
1990
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1946
1991
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1947
1992
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1948
1993
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
1949
|
-
[encoder setBytes:&
|
1950
|
-
[encoder setBytes:&
|
1951
|
-
[encoder setBytes:&
|
1952
|
-
[encoder setBytes:&
|
1953
|
-
[encoder setBytes:&
|
1954
|
-
[encoder setBytes:&
|
1955
|
-
[encoder setBytes:&
|
1956
|
-
[encoder setBytes:&
|
1957
|
-
[encoder setBytes:&
|
1958
|
-
[encoder setBytes:&
|
1959
|
-
[encoder setBytes:&
|
1960
|
-
[encoder setBytes:&
|
1961
|
-
[encoder setBytes:&
|
1962
|
-
[encoder setBytes:&
|
1963
|
-
[encoder setBytes:&
|
1964
|
-
[encoder setBytes:&
|
1965
|
-
[encoder setBytes:&
|
1966
|
-
[encoder setBytes:&
|
1967
|
-
[encoder setBytes:&
|
1968
|
-
|
1994
|
+
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
1995
|
+
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
1996
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
1997
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
|
1998
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
|
1999
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
|
2000
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
|
2001
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
|
2002
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
|
2003
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
|
2004
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
|
2005
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
|
2006
|
+
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
|
2007
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
|
2008
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
|
2009
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
|
2010
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
|
2011
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
|
2012
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
|
2013
|
+
|
2014
|
+
const int64_t _ne1 = 1;
|
2015
|
+
const int tgz = dst_rows;
|
1969
2016
|
|
1970
2017
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
1971
2018
|
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
1972
2019
|
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
1973
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1,
|
2020
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1974
2021
|
}
|
1975
2022
|
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
1976
2023
|
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
1977
2024
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1978
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1,
|
2025
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1979
2026
|
}
|
1980
2027
|
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
1981
2028
|
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
1982
2029
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1983
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1,
|
2030
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1984
2031
|
}
|
1985
2032
|
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
1986
2033
|
const int mem_size = 32*sizeof(float);
|
1987
2034
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
1988
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
|
2035
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1989
2036
|
}
|
1990
2037
|
else if (src0t == GGML_TYPE_Q4_K) {
|
1991
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
|
2038
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1992
2039
|
}
|
1993
2040
|
else if (src0t == GGML_TYPE_Q3_K) {
|
1994
2041
|
#ifdef GGML_QKK_64
|
1995
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1,
|
2042
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1996
2043
|
#else
|
1997
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
|
2044
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1998
2045
|
#endif
|
1999
2046
|
}
|
2000
2047
|
else if (src0t == GGML_TYPE_Q5_K) {
|
2001
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1,
|
2048
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2002
2049
|
}
|
2003
2050
|
else if (src0t == GGML_TYPE_Q6_K) {
|
2004
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1,
|
2051
|
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2005
2052
|
} else {
|
2006
|
-
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
2007
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny,
|
2053
|
+
const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
|
2054
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2008
2055
|
}
|
2009
2056
|
}
|
2010
2057
|
} break;
|