llama_cpp 0.15.1 → 0.15.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +49 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +9 -20
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +87 -37
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +47 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +13 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +177 -190
- data/vendor/tmp/llama.cpp/ggml-metal.metal +97 -505
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +3660 -2057
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1155 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +60 -639
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +203 -224
- data/vendor/tmp/llama.cpp/ggml.c +1168 -1470
- data/vendor/tmp/llama.cpp/ggml.h +67 -44
- data/vendor/tmp/llama.cpp/llama.cpp +1371 -944
- data/vendor/tmp/llama.cpp/llama.h +13 -3
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -2169
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -12
- data/vendor/tmp/llama.cpp/unicode.cpp +89 -111
- data/vendor/tmp/llama.cpp/unicode.h +44 -12
- metadata +5 -3
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
|
|
40
40
|
GGML_METAL_KERNEL_TYPE_CLAMP,
|
41
41
|
GGML_METAL_KERNEL_TYPE_TANH,
|
42
42
|
GGML_METAL_KERNEL_TYPE_RELU,
|
43
|
+
GGML_METAL_KERNEL_TYPE_SIGMOID,
|
43
44
|
GGML_METAL_KERNEL_TYPE_GELU,
|
44
45
|
GGML_METAL_KERNEL_TYPE_GELU_4,
|
45
46
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
|
|
169
170
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
170
171
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
171
172
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
172
|
-
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
173
173
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
174
174
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
175
175
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
@@ -381,10 +381,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
381
381
|
// dictionary of preprocessor macros
|
382
382
|
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
383
383
|
|
384
|
-
#ifdef GGML_QKK_64
|
385
|
-
prep[@"GGML_QKK_64"] = @(1);
|
386
|
-
#endif
|
387
|
-
|
388
384
|
MTLCompileOptions* options = [MTLCompileOptions new];
|
389
385
|
options.preprocessorMacros = prep;
|
390
386
|
|
@@ -494,6 +490,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
494
490
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
495
491
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
496
492
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
493
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
497
494
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
498
495
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
499
496
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
@@ -623,7 +620,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
623
620
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
624
621
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
625
622
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
626
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
627
623
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
628
624
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
629
625
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
@@ -633,14 +629,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
633
629
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
634
630
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
635
631
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
636
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
637
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
638
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
639
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
640
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
641
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
642
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
|
643
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
|
632
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
|
633
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
|
634
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
|
635
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
|
636
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
|
637
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
638
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
639
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
644
640
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
645
641
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
646
642
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
@@ -732,6 +728,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
732
728
|
switch (ggml_get_unary_op(op)) {
|
733
729
|
case GGML_UNARY_OP_TANH:
|
734
730
|
case GGML_UNARY_OP_RELU:
|
731
|
+
case GGML_UNARY_OP_SIGMOID:
|
735
732
|
case GGML_UNARY_OP_GELU:
|
736
733
|
case GGML_UNARY_OP_GELU_QUICK:
|
737
734
|
case GGML_UNARY_OP_SILU:
|
@@ -759,7 +756,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
759
756
|
case GGML_OP_GROUP_NORM:
|
760
757
|
return ctx->support_simdgroup_reduction;
|
761
758
|
case GGML_OP_NORM:
|
762
|
-
case GGML_OP_ALIBI:
|
763
759
|
case GGML_OP_ROPE:
|
764
760
|
case GGML_OP_IM2COL:
|
765
761
|
return true;
|
@@ -772,8 +768,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
772
768
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
773
769
|
case GGML_OP_ARGSORT:
|
774
770
|
case GGML_OP_LEAKY_RELU:
|
775
|
-
case GGML_OP_FLASH_ATTN_EXT:
|
776
771
|
return true;
|
772
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
773
|
+
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
777
774
|
case GGML_OP_MUL_MAT:
|
778
775
|
case GGML_OP_MUL_MAT_ID:
|
779
776
|
return ctx->support_simdgroup_reduction &&
|
@@ -926,22 +923,32 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
926
923
|
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
927
924
|
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
928
925
|
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
929
|
-
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
926
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
930
927
|
|
931
928
|
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
932
929
|
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
933
930
|
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
934
|
-
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
931
|
+
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
932
|
+
|
933
|
+
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
934
|
+
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
935
|
+
const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
|
936
|
+
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
935
937
|
|
936
|
-
const
|
937
|
-
const
|
938
|
-
const
|
939
|
-
const
|
938
|
+
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
939
|
+
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
940
|
+
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
941
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
940
942
|
|
941
|
-
const
|
942
|
-
const
|
943
|
-
const
|
944
|
-
const
|
943
|
+
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
944
|
+
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
945
|
+
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
946
|
+
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
947
|
+
|
948
|
+
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
949
|
+
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
950
|
+
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
951
|
+
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
945
952
|
|
946
953
|
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
947
954
|
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
@@ -1194,24 +1201,24 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1194
1201
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1195
1202
|
} break;
|
1196
1203
|
case GGML_OP_CLAMP:
|
1197
|
-
|
1198
|
-
|
1204
|
+
{
|
1205
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
|
1199
1206
|
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1207
|
+
float min;
|
1208
|
+
float max;
|
1209
|
+
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
1210
|
+
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
|
1204
1211
|
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1212
|
+
[encoder setComputePipelineState:pipeline];
|
1213
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1214
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1215
|
+
[encoder setBytes:&min length:sizeof(min) atIndex:2];
|
1216
|
+
[encoder setBytes:&max length:sizeof(max) atIndex:3];
|
1210
1217
|
|
1211
|
-
|
1218
|
+
const int64_t n = ggml_nelements(dst);
|
1212
1219
|
|
1213
|
-
|
1214
|
-
|
1220
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1221
|
+
} break;
|
1215
1222
|
case GGML_OP_UNARY:
|
1216
1223
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
1217
1224
|
// we are not taking into account the strides, so for now require contiguous tensors
|
@@ -1239,6 +1246,18 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1239
1246
|
|
1240
1247
|
const int64_t n = ggml_nelements(dst);
|
1241
1248
|
|
1249
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1250
|
+
} break;
|
1251
|
+
case GGML_UNARY_OP_SIGMOID:
|
1252
|
+
{
|
1253
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
|
1254
|
+
|
1255
|
+
[encoder setComputePipelineState:pipeline];
|
1256
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1257
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1258
|
+
|
1259
|
+
const int64_t n = ggml_nelements(dst);
|
1260
|
+
|
1242
1261
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1243
1262
|
} break;
|
1244
1263
|
case GGML_UNARY_OP_GELU:
|
@@ -1357,16 +1376,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1357
1376
|
case GGML_OP_SOFT_MAX:
|
1358
1377
|
{
|
1359
1378
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
1360
|
-
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
|
1361
1379
|
|
1362
1380
|
int nth = 32; // SIMD width
|
1363
1381
|
|
1364
1382
|
id<MTLComputePipelineState> pipeline = nil;
|
1365
1383
|
|
1366
|
-
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
1384
|
+
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
1367
1385
|
|
1368
1386
|
if (ne00%4 == 0) {
|
1369
|
-
while (nth < ne00/4 && nth < 256) {
|
1387
|
+
while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
|
1370
1388
|
nth *= 2;
|
1371
1389
|
}
|
1372
1390
|
if (use_f16) {
|
@@ -1375,7 +1393,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1375
1393
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
|
1376
1394
|
}
|
1377
1395
|
} else {
|
1378
|
-
while (nth < ne00 && nth < 1024) {
|
1396
|
+
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
|
1379
1397
|
nth *= 2;
|
1380
1398
|
}
|
1381
1399
|
if (use_f16) {
|
@@ -1394,8 +1412,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1394
1412
|
const int64_t nrows_x = ggml_nrows(src0);
|
1395
1413
|
const int64_t nrows_y = src0->ne[1];
|
1396
1414
|
|
1397
|
-
const uint32_t n_head_kv = nrows_x/nrows_y;
|
1398
|
-
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
1415
|
+
const uint32_t n_head = nrows_x/nrows_y;
|
1416
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
1399
1417
|
|
1400
1418
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
1401
1419
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
@@ -1407,20 +1425,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1407
1425
|
} else {
|
1408
1426
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1409
1427
|
}
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
[encoder
|
1416
|
-
[encoder setBytes:&
|
1417
|
-
[encoder setBytes:&
|
1418
|
-
[encoder setBytes:&
|
1419
|
-
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
|
1420
|
-
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
|
1421
|
-
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
|
1422
|
-
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
|
1423
|
-
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
|
1428
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1429
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1430
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1431
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1432
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
1433
|
+
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
|
1434
|
+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
|
1435
|
+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
|
1436
|
+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
|
1424
1437
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1425
1438
|
|
1426
1439
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
@@ -1756,11 +1769,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1756
1769
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1757
1770
|
}
|
1758
1771
|
else if (src0t == GGML_TYPE_Q3_K) {
|
1759
|
-
#ifdef GGML_QKK_64
|
1760
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1761
|
-
#else
|
1762
1772
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1763
|
-
#endif
|
1764
1773
|
}
|
1765
1774
|
else if (src0t == GGML_TYPE_Q5_K) {
|
1766
1775
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -1778,16 +1787,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1778
1787
|
const int n_as = src0->ne[2];
|
1779
1788
|
|
1780
1789
|
// src2 = ids
|
1781
|
-
const int64_t ne20 = src2->ne[0];
|
1782
|
-
const int64_t ne21 = src2->ne[1];
|
1783
|
-
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
1784
|
-
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
1785
|
-
|
1786
|
-
const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
|
1787
|
-
const uint64_t nb21 = src2->nb[1];
|
1788
|
-
const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
|
1789
|
-
const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
|
1790
|
-
|
1791
1790
|
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
1792
1791
|
|
1793
1792
|
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
@@ -2011,12 +2010,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2011
2010
|
{
|
2012
2011
|
nth0 = 4;
|
2013
2012
|
nth1 = 16;
|
2014
|
-
#if QK_K == 64
|
2015
|
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
2016
|
-
#else
|
2017
2013
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
2018
|
-
#endif
|
2019
|
-
|
2020
2014
|
} break;
|
2021
2015
|
default:
|
2022
2016
|
{
|
@@ -2081,11 +2075,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2081
2075
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2082
2076
|
}
|
2083
2077
|
else if (src0t == GGML_TYPE_Q3_K) {
|
2084
|
-
#ifdef GGML_QKK_64
|
2085
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2086
|
-
#else
|
2087
2078
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2088
|
-
#endif
|
2089
2079
|
}
|
2090
2080
|
else if (src0t == GGML_TYPE_Q5_K) {
|
2091
2081
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -2225,49 +2215,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2225
2215
|
|
2226
2216
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2227
2217
|
} break;
|
2228
|
-
case GGML_OP_ALIBI:
|
2229
|
-
{
|
2230
|
-
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
2231
|
-
|
2232
|
-
const int nth = MIN(1024, ne00);
|
2233
|
-
|
2234
|
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
2235
|
-
const int n_head = ((int32_t *) dst->op_params)[1];
|
2236
|
-
|
2237
|
-
float max_bias;
|
2238
|
-
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
2239
|
-
|
2240
|
-
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
2241
|
-
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
2242
|
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
2243
|
-
|
2244
|
-
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
|
2245
|
-
|
2246
|
-
[encoder setComputePipelineState:pipeline];
|
2247
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2248
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2249
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2250
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
2251
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
2252
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
2253
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
2254
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
2255
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
2256
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
2257
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
2258
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
2259
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
2260
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
2261
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
2262
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
2263
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
2264
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
2265
|
-
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
2266
|
-
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
2267
|
-
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
2268
|
-
|
2269
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2270
|
-
} break;
|
2271
2218
|
case GGML_OP_ROPE:
|
2272
2219
|
{
|
2273
2220
|
GGML_ASSERT(ne10 == ne02);
|
@@ -2280,7 +2227,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2280
2227
|
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
2281
2228
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
2282
2229
|
|
2283
|
-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
2230
|
+
float freq_base;
|
2231
|
+
float freq_scale;
|
2232
|
+
float ext_factor;
|
2233
|
+
float attn_factor;
|
2234
|
+
float beta_fast;
|
2235
|
+
float beta_slow;
|
2236
|
+
|
2284
2237
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
2285
2238
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
2286
2239
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
@@ -2288,6 +2241,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2288
2241
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
2289
2242
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
2290
2243
|
|
2244
|
+
const bool is_neox = mode & 2;
|
2245
|
+
const bool is_glm = mode & 4;
|
2246
|
+
|
2247
|
+
GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
|
2248
|
+
|
2249
|
+
if (!is_neox) {
|
2250
|
+
GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
|
2251
|
+
}
|
2252
|
+
|
2291
2253
|
id<MTLComputePipelineState> pipeline = nil;
|
2292
2254
|
|
2293
2255
|
switch (src0->type) {
|
@@ -2299,33 +2261,38 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2299
2261
|
[encoder setComputePipelineState:pipeline];
|
2300
2262
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2301
2263
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2302
|
-
|
2303
|
-
|
2304
|
-
|
2305
|
-
|
2306
|
-
|
2307
|
-
[encoder
|
2308
|
-
[encoder setBytes:&
|
2309
|
-
[encoder setBytes:&
|
2310
|
-
[encoder setBytes:&
|
2311
|
-
[encoder setBytes:&
|
2312
|
-
[encoder setBytes:&
|
2313
|
-
[encoder setBytes:&
|
2314
|
-
[encoder setBytes:&
|
2315
|
-
[encoder setBytes:&
|
2316
|
-
[encoder setBytes:&
|
2317
|
-
[encoder setBytes:&
|
2318
|
-
[encoder setBytes:&
|
2319
|
-
[encoder setBytes:&
|
2320
|
-
[encoder setBytes:&
|
2321
|
-
[encoder setBytes:&
|
2322
|
-
[encoder setBytes:&
|
2323
|
-
[encoder setBytes:&
|
2324
|
-
[encoder setBytes:&
|
2325
|
-
[encoder setBytes:&
|
2326
|
-
[encoder setBytes:&
|
2327
|
-
[encoder setBytes:&
|
2328
|
-
[encoder setBytes:&
|
2264
|
+
if (id_src2 != nil) {
|
2265
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2266
|
+
} else {
|
2267
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
2268
|
+
}
|
2269
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2270
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
|
2271
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
2272
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
2273
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
2274
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
|
2275
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
|
2276
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
|
2277
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
|
2278
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
|
2279
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
|
2280
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
|
2281
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
|
2282
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
|
2283
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
|
2284
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
|
2285
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
2286
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
2287
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
2288
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:22];
|
2289
|
+
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
|
2290
|
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
|
2291
|
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
|
2292
|
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
|
2293
|
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
|
2294
|
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
|
2295
|
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
2329
2296
|
|
2330
2297
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2331
2298
|
} break;
|
@@ -2389,7 +2356,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2389
2356
|
{
|
2390
2357
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2391
2358
|
|
2392
|
-
const
|
2359
|
+
const float sf0 = (float)ne0/src0->ne[0];
|
2360
|
+
const float sf1 = (float)ne1/src0->ne[1];
|
2361
|
+
const float sf2 = (float)ne2/src0->ne[2];
|
2362
|
+
const float sf3 = (float)ne3/src0->ne[3];
|
2393
2363
|
|
2394
2364
|
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
2395
2365
|
|
@@ -2412,7 +2382,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2412
2382
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2413
2383
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2414
2384
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2415
|
-
[encoder setBytes:&
|
2385
|
+
[encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
|
2386
|
+
[encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
|
2387
|
+
[encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
|
2388
|
+
[encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
|
2416
2389
|
|
2417
2390
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
2418
2391
|
|
@@ -2548,13 +2521,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2548
2521
|
} break;
|
2549
2522
|
case GGML_OP_FLASH_ATTN_EXT:
|
2550
2523
|
{
|
2551
|
-
GGML_ASSERT(ne00 % 4 == 0);
|
2524
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
2525
|
+
GGML_ASSERT(ne11 % 32 == 0);
|
2526
|
+
|
2552
2527
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2553
2528
|
|
2554
|
-
|
2529
|
+
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
2555
2530
|
|
2556
|
-
|
2557
|
-
GGML_ASSERT(src3);
|
2531
|
+
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
2558
2532
|
|
2559
2533
|
size_t offs_src3 = 0;
|
2560
2534
|
|
@@ -2565,7 +2539,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2565
2539
|
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
2566
2540
|
|
2567
2541
|
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
2568
|
-
|
2542
|
+
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
2569
2543
|
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
2570
2544
|
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
|
2571
2545
|
|
@@ -2577,7 +2551,16 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2577
2551
|
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
2578
2552
|
|
2579
2553
|
float scale;
|
2580
|
-
|
2554
|
+
float max_bias;
|
2555
|
+
|
2556
|
+
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
2557
|
+
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
2558
|
+
|
2559
|
+
const uint32_t n_head = src0->ne[2];
|
2560
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
2561
|
+
|
2562
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
2563
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
2581
2564
|
|
2582
2565
|
id<MTLComputePipelineState> pipeline = nil;
|
2583
2566
|
|
@@ -2614,34 +2597,38 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2614
2597
|
}
|
2615
2598
|
|
2616
2599
|
[encoder setComputePipelineState:pipeline];
|
2617
|
-
[encoder setBuffer:id_src0
|
2618
|
-
[encoder setBuffer:id_src1
|
2619
|
-
[encoder setBuffer:id_src2
|
2620
|
-
|
2621
|
-
|
2622
|
-
|
2623
|
-
|
2624
|
-
|
2625
|
-
[encoder
|
2626
|
-
[encoder setBytes:&
|
2627
|
-
[encoder setBytes:&
|
2628
|
-
[encoder setBytes:&
|
2629
|
-
[encoder setBytes:&
|
2630
|
-
[encoder setBytes:&
|
2631
|
-
[encoder setBytes:&
|
2632
|
-
[encoder setBytes:&
|
2633
|
-
[encoder setBytes:&
|
2634
|
-
[encoder setBytes:&
|
2635
|
-
[encoder setBytes:&nb11
|
2636
|
-
[encoder setBytes:&nb12
|
2637
|
-
[encoder setBytes:&nb13
|
2638
|
-
[encoder setBytes:&
|
2639
|
-
[encoder setBytes:&
|
2640
|
-
[encoder setBytes:&
|
2641
|
-
[encoder setBytes:&
|
2642
|
-
[encoder setBytes:&
|
2643
|
-
[encoder setBytes:&
|
2644
|
-
[encoder setBytes:&scale
|
2600
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2601
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2602
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2603
|
+
if (id_src3) {
|
2604
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
2605
|
+
} else {
|
2606
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
2607
|
+
}
|
2608
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
2609
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
2610
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
2611
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
2612
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
2613
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
2614
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
2615
|
+
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
2616
|
+
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
2617
|
+
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
2618
|
+
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
2619
|
+
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
2620
|
+
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
2621
|
+
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
2622
|
+
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
2623
|
+
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
2624
|
+
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
2625
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
2626
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
2627
|
+
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
2628
|
+
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
2629
|
+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
2630
|
+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
2631
|
+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
2645
2632
|
|
2646
2633
|
if (!use_vec_kernel) {
|
2647
2634
|
// half8x8 kernel
|