llama_cpp 0.15.1 → 0.15.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/ext/llama_cpp/llama_cpp.cpp +49 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -0
- data/vendor/tmp/llama.cpp/Makefile +9 -20
- data/vendor/tmp/llama.cpp/ggml-backend.c +2 -3
- data/vendor/tmp/llama.cpp/ggml-common.h +0 -54
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +87 -37
- data/vendor/tmp/llama.cpp/ggml-cuda.h +1 -0
- data/vendor/tmp/llama.cpp/ggml-impl.h +47 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +13 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +177 -190
- data/vendor/tmp/llama.cpp/ggml-metal.metal +97 -505
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +4 -1
- data/vendor/tmp/llama.cpp/ggml-quants.c +3660 -2057
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +1155 -0
- data/vendor/tmp/llama.cpp/ggml-rpc.h +24 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +60 -639
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +9351 -5627
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +203 -224
- data/vendor/tmp/llama.cpp/ggml.c +1168 -1470
- data/vendor/tmp/llama.cpp/ggml.h +67 -44
- data/vendor/tmp/llama.cpp/llama.cpp +1371 -944
- data/vendor/tmp/llama.cpp/llama.h +13 -3
- data/vendor/tmp/llama.cpp/unicode-data.cpp +6969 -2169
- data/vendor/tmp/llama.cpp/unicode-data.h +15 -12
- data/vendor/tmp/llama.cpp/unicode.cpp +89 -111
- data/vendor/tmp/llama.cpp/unicode.h +44 -12
- metadata +5 -3
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
|
|
40
40
|
GGML_METAL_KERNEL_TYPE_CLAMP,
|
41
41
|
GGML_METAL_KERNEL_TYPE_TANH,
|
42
42
|
GGML_METAL_KERNEL_TYPE_RELU,
|
43
|
+
GGML_METAL_KERNEL_TYPE_SIGMOID,
|
43
44
|
GGML_METAL_KERNEL_TYPE_GELU,
|
44
45
|
GGML_METAL_KERNEL_TYPE_GELU_4,
|
45
46
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
|
|
169
170
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
170
171
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
171
172
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
172
|
-
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
173
173
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
174
174
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
175
175
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
@@ -381,10 +381,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
381
381
|
// dictionary of preprocessor macros
|
382
382
|
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
383
383
|
|
384
|
-
#ifdef GGML_QKK_64
|
385
|
-
prep[@"GGML_QKK_64"] = @(1);
|
386
|
-
#endif
|
387
|
-
|
388
384
|
MTLCompileOptions* options = [MTLCompileOptions new];
|
389
385
|
options.preprocessorMacros = prep;
|
390
386
|
|
@@ -494,6 +490,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
494
490
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
|
495
491
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
|
496
492
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
|
493
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
497
494
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
498
495
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
499
496
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
@@ -623,7 +620,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
623
620
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
624
621
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
625
622
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
626
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
627
623
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
628
624
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
629
625
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
@@ -633,14 +629,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
633
629
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
634
630
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
635
631
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
636
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64,
|
637
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80,
|
638
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96,
|
639
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112,
|
640
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128,
|
641
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256,
|
642
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128,
|
643
|
-
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256,
|
632
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
|
633
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
|
634
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
|
635
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
|
636
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
|
637
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
638
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
639
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
644
640
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
645
641
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
646
642
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
@@ -732,6 +728,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
732
728
|
switch (ggml_get_unary_op(op)) {
|
733
729
|
case GGML_UNARY_OP_TANH:
|
734
730
|
case GGML_UNARY_OP_RELU:
|
731
|
+
case GGML_UNARY_OP_SIGMOID:
|
735
732
|
case GGML_UNARY_OP_GELU:
|
736
733
|
case GGML_UNARY_OP_GELU_QUICK:
|
737
734
|
case GGML_UNARY_OP_SILU:
|
@@ -759,7 +756,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
759
756
|
case GGML_OP_GROUP_NORM:
|
760
757
|
return ctx->support_simdgroup_reduction;
|
761
758
|
case GGML_OP_NORM:
|
762
|
-
case GGML_OP_ALIBI:
|
763
759
|
case GGML_OP_ROPE:
|
764
760
|
case GGML_OP_IM2COL:
|
765
761
|
return true;
|
@@ -772,8 +768,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
772
768
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
773
769
|
case GGML_OP_ARGSORT:
|
774
770
|
case GGML_OP_LEAKY_RELU:
|
775
|
-
case GGML_OP_FLASH_ATTN_EXT:
|
776
771
|
return true;
|
772
|
+
case GGML_OP_FLASH_ATTN_EXT:
|
773
|
+
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
777
774
|
case GGML_OP_MUL_MAT:
|
778
775
|
case GGML_OP_MUL_MAT_ID:
|
779
776
|
return ctx->support_simdgroup_reduction &&
|
@@ -926,22 +923,32 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
926
923
|
const int64_t ne10 = src1 ? src1->ne[0] : 0;
|
927
924
|
const int64_t ne11 = src1 ? src1->ne[1] : 0;
|
928
925
|
const int64_t ne12 = src1 ? src1->ne[2] : 0;
|
929
|
-
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
926
|
+
const int64_t ne13 = src1 ? src1->ne[3] : 0;
|
930
927
|
|
931
928
|
const uint64_t nb10 = src1 ? src1->nb[0] : 0;
|
932
929
|
const uint64_t nb11 = src1 ? src1->nb[1] : 0;
|
933
930
|
const uint64_t nb12 = src1 ? src1->nb[2] : 0;
|
934
|
-
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
931
|
+
const uint64_t nb13 = src1 ? src1->nb[3] : 0;
|
932
|
+
|
933
|
+
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
934
|
+
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
935
|
+
const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
|
936
|
+
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
935
937
|
|
936
|
-
const
|
937
|
-
const
|
938
|
-
const
|
939
|
-
const
|
938
|
+
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
939
|
+
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
940
|
+
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
941
|
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
940
942
|
|
941
|
-
const
|
942
|
-
const
|
943
|
-
const
|
944
|
-
const
|
943
|
+
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
944
|
+
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
945
|
+
const int64_t ne2 = dst ? dst->ne[2] : 0;
|
946
|
+
const int64_t ne3 = dst ? dst->ne[3] : 0;
|
947
|
+
|
948
|
+
const uint64_t nb0 = dst ? dst->nb[0] : 0;
|
949
|
+
const uint64_t nb1 = dst ? dst->nb[1] : 0;
|
950
|
+
const uint64_t nb2 = dst ? dst->nb[2] : 0;
|
951
|
+
const uint64_t nb3 = dst ? dst->nb[3] : 0;
|
945
952
|
|
946
953
|
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
|
947
954
|
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
@@ -1194,24 +1201,24 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1194
1201
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1195
1202
|
} break;
|
1196
1203
|
case GGML_OP_CLAMP:
|
1197
|
-
|
1198
|
-
|
1204
|
+
{
|
1205
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
|
1199
1206
|
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1207
|
+
float min;
|
1208
|
+
float max;
|
1209
|
+
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
1210
|
+
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
|
1204
1211
|
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1212
|
+
[encoder setComputePipelineState:pipeline];
|
1213
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1214
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1215
|
+
[encoder setBytes:&min length:sizeof(min) atIndex:2];
|
1216
|
+
[encoder setBytes:&max length:sizeof(max) atIndex:3];
|
1210
1217
|
|
1211
|
-
|
1218
|
+
const int64_t n = ggml_nelements(dst);
|
1212
1219
|
|
1213
|
-
|
1214
|
-
|
1220
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1221
|
+
} break;
|
1215
1222
|
case GGML_OP_UNARY:
|
1216
1223
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
1217
1224
|
// we are not taking into account the strides, so for now require contiguous tensors
|
@@ -1239,6 +1246,18 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1239
1246
|
|
1240
1247
|
const int64_t n = ggml_nelements(dst);
|
1241
1248
|
|
1249
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1250
|
+
} break;
|
1251
|
+
case GGML_UNARY_OP_SIGMOID:
|
1252
|
+
{
|
1253
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
|
1254
|
+
|
1255
|
+
[encoder setComputePipelineState:pipeline];
|
1256
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1257
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
1258
|
+
|
1259
|
+
const int64_t n = ggml_nelements(dst);
|
1260
|
+
|
1242
1261
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1243
1262
|
} break;
|
1244
1263
|
case GGML_UNARY_OP_GELU:
|
@@ -1357,16 +1376,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1357
1376
|
case GGML_OP_SOFT_MAX:
|
1358
1377
|
{
|
1359
1378
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
1360
|
-
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
|
1361
1379
|
|
1362
1380
|
int nth = 32; // SIMD width
|
1363
1381
|
|
1364
1382
|
id<MTLComputePipelineState> pipeline = nil;
|
1365
1383
|
|
1366
|
-
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16)
|
1384
|
+
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
1367
1385
|
|
1368
1386
|
if (ne00%4 == 0) {
|
1369
|
-
while (nth < ne00/4 && nth < 256) {
|
1387
|
+
while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
|
1370
1388
|
nth *= 2;
|
1371
1389
|
}
|
1372
1390
|
if (use_f16) {
|
@@ -1375,7 +1393,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1375
1393
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
|
1376
1394
|
}
|
1377
1395
|
} else {
|
1378
|
-
while (nth < ne00 && nth <
|
1396
|
+
while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
|
1379
1397
|
nth *= 2;
|
1380
1398
|
}
|
1381
1399
|
if (use_f16) {
|
@@ -1394,8 +1412,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1394
1412
|
const int64_t nrows_x = ggml_nrows(src0);
|
1395
1413
|
const int64_t nrows_y = src0->ne[1];
|
1396
1414
|
|
1397
|
-
const uint32_t
|
1398
|
-
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float)
|
1415
|
+
const uint32_t n_head = nrows_x/nrows_y;
|
1416
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
1399
1417
|
|
1400
1418
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
1401
1419
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
@@ -1407,20 +1425,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1407
1425
|
} else {
|
1408
1426
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1409
1427
|
}
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
[encoder
|
1416
|
-
[encoder setBytes:&
|
1417
|
-
[encoder setBytes:&
|
1418
|
-
[encoder setBytes:&
|
1419
|
-
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
|
1420
|
-
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
|
1421
|
-
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
|
1422
|
-
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
|
1423
|
-
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
|
1428
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1429
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1430
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1431
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1432
|
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
1433
|
+
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
|
1434
|
+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
|
1435
|
+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
|
1436
|
+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
|
1424
1437
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1425
1438
|
|
1426
1439
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
@@ -1756,11 +1769,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1756
1769
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1757
1770
|
}
|
1758
1771
|
else if (src0t == GGML_TYPE_Q3_K) {
|
1759
|
-
#ifdef GGML_QKK_64
|
1760
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1761
|
-
#else
|
1762
1772
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1763
|
-
#endif
|
1764
1773
|
}
|
1765
1774
|
else if (src0t == GGML_TYPE_Q5_K) {
|
1766
1775
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -1778,16 +1787,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
1778
1787
|
const int n_as = src0->ne[2];
|
1779
1788
|
|
1780
1789
|
// src2 = ids
|
1781
|
-
const int64_t ne20 = src2->ne[0];
|
1782
|
-
const int64_t ne21 = src2->ne[1];
|
1783
|
-
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
1784
|
-
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
1785
|
-
|
1786
|
-
const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
|
1787
|
-
const uint64_t nb21 = src2->nb[1];
|
1788
|
-
const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
|
1789
|
-
const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
|
1790
|
-
|
1791
1790
|
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
1792
1791
|
|
1793
1792
|
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
@@ -2011,12 +2010,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2011
2010
|
{
|
2012
2011
|
nth0 = 4;
|
2013
2012
|
nth1 = 16;
|
2014
|
-
#if QK_K == 64
|
2015
|
-
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
2016
|
-
#else
|
2017
2013
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
2018
|
-
#endif
|
2019
|
-
|
2020
2014
|
} break;
|
2021
2015
|
default:
|
2022
2016
|
{
|
@@ -2081,11 +2075,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2081
2075
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2082
2076
|
}
|
2083
2077
|
else if (src0t == GGML_TYPE_Q3_K) {
|
2084
|
-
#ifdef GGML_QKK_64
|
2085
|
-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2086
|
-
#else
|
2087
2078
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
2088
|
-
#endif
|
2089
2079
|
}
|
2090
2080
|
else if (src0t == GGML_TYPE_Q5_K) {
|
2091
2081
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
@@ -2225,49 +2215,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2225
2215
|
|
2226
2216
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2227
2217
|
} break;
|
2228
|
-
case GGML_OP_ALIBI:
|
2229
|
-
{
|
2230
|
-
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
2231
|
-
|
2232
|
-
const int nth = MIN(1024, ne00);
|
2233
|
-
|
2234
|
-
//const int n_past = ((int32_t *) dst->op_params)[0];
|
2235
|
-
const int n_head = ((int32_t *) dst->op_params)[1];
|
2236
|
-
|
2237
|
-
float max_bias;
|
2238
|
-
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
2239
|
-
|
2240
|
-
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
2241
|
-
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
2242
|
-
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
2243
|
-
|
2244
|
-
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
|
2245
|
-
|
2246
|
-
[encoder setComputePipelineState:pipeline];
|
2247
|
-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2248
|
-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2249
|
-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
2250
|
-
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
2251
|
-
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
2252
|
-
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
2253
|
-
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
2254
|
-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
2255
|
-
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
2256
|
-
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
2257
|
-
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
2258
|
-
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
2259
|
-
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
2260
|
-
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
2261
|
-
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
2262
|
-
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
2263
|
-
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
2264
|
-
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
2265
|
-
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
|
2266
|
-
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
|
2267
|
-
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
|
2268
|
-
|
2269
|
-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2270
|
-
} break;
|
2271
2218
|
case GGML_OP_ROPE:
|
2272
2219
|
{
|
2273
2220
|
GGML_ASSERT(ne10 == ne02);
|
@@ -2280,7 +2227,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2280
2227
|
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
2281
2228
|
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
|
2282
2229
|
|
2283
|
-
float freq_base
|
2230
|
+
float freq_base;
|
2231
|
+
float freq_scale;
|
2232
|
+
float ext_factor;
|
2233
|
+
float attn_factor;
|
2234
|
+
float beta_fast;
|
2235
|
+
float beta_slow;
|
2236
|
+
|
2284
2237
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
2285
2238
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
2286
2239
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
@@ -2288,6 +2241,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2288
2241
|
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
2289
2242
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
2290
2243
|
|
2244
|
+
const bool is_neox = mode & 2;
|
2245
|
+
const bool is_glm = mode & 4;
|
2246
|
+
|
2247
|
+
GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal");
|
2248
|
+
|
2249
|
+
if (!is_neox) {
|
2250
|
+
GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
|
2251
|
+
}
|
2252
|
+
|
2291
2253
|
id<MTLComputePipelineState> pipeline = nil;
|
2292
2254
|
|
2293
2255
|
switch (src0->type) {
|
@@ -2299,33 +2261,38 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2299
2261
|
[encoder setComputePipelineState:pipeline];
|
2300
2262
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2301
2263
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2302
|
-
|
2303
|
-
|
2304
|
-
|
2305
|
-
|
2306
|
-
|
2307
|
-
[encoder
|
2308
|
-
[encoder setBytes:&
|
2309
|
-
[encoder setBytes:&
|
2310
|
-
[encoder setBytes:&
|
2311
|
-
[encoder setBytes:&
|
2312
|
-
[encoder setBytes:&
|
2313
|
-
[encoder setBytes:&
|
2314
|
-
[encoder setBytes:&
|
2315
|
-
[encoder setBytes:&
|
2316
|
-
[encoder setBytes:&
|
2317
|
-
[encoder setBytes:&
|
2318
|
-
[encoder setBytes:&
|
2319
|
-
[encoder setBytes:&
|
2320
|
-
[encoder setBytes:&
|
2321
|
-
[encoder setBytes:&
|
2322
|
-
[encoder setBytes:&
|
2323
|
-
[encoder setBytes:&
|
2324
|
-
[encoder setBytes:&
|
2325
|
-
[encoder setBytes:&
|
2326
|
-
[encoder setBytes:&
|
2327
|
-
[encoder setBytes:&
|
2328
|
-
[encoder setBytes:&
|
2264
|
+
if (id_src2 != nil) {
|
2265
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2266
|
+
} else {
|
2267
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
2268
|
+
}
|
2269
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
2270
|
+
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4];
|
2271
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
2272
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
2273
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
2274
|
+
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8];
|
2275
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9];
|
2276
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10];
|
2277
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11];
|
2278
|
+
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12];
|
2279
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13];
|
2280
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14];
|
2281
|
+
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15];
|
2282
|
+
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16];
|
2283
|
+
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17];
|
2284
|
+
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18];
|
2285
|
+
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
|
2286
|
+
[encoder setBytes:&n_past length:sizeof( int) atIndex:20];
|
2287
|
+
[encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
|
2288
|
+
[encoder setBytes:&mode length:sizeof( int) atIndex:22];
|
2289
|
+
[encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
|
2290
|
+
[encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
|
2291
|
+
[encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
|
2292
|
+
[encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
|
2293
|
+
[encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
|
2294
|
+
[encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
|
2295
|
+
[encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
|
2329
2296
|
|
2330
2297
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2331
2298
|
} break;
|
@@ -2389,7 +2356,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2389
2356
|
{
|
2390
2357
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2391
2358
|
|
2392
|
-
const
|
2359
|
+
const float sf0 = (float)ne0/src0->ne[0];
|
2360
|
+
const float sf1 = (float)ne1/src0->ne[1];
|
2361
|
+
const float sf2 = (float)ne2/src0->ne[2];
|
2362
|
+
const float sf3 = (float)ne3/src0->ne[3];
|
2393
2363
|
|
2394
2364
|
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
|
2395
2365
|
|
@@ -2412,7 +2382,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2412
2382
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
2413
2383
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
2414
2384
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
2415
|
-
[encoder setBytes:&
|
2385
|
+
[encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
|
2386
|
+
[encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
|
2387
|
+
[encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
|
2388
|
+
[encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
|
2416
2389
|
|
2417
2390
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
2418
2391
|
|
@@ -2548,13 +2521,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2548
2521
|
} break;
|
2549
2522
|
case GGML_OP_FLASH_ATTN_EXT:
|
2550
2523
|
{
|
2551
|
-
GGML_ASSERT(ne00 % 4
|
2524
|
+
GGML_ASSERT(ne00 % 4 == 0);
|
2525
|
+
GGML_ASSERT(ne11 % 32 == 0);
|
2526
|
+
|
2552
2527
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
2553
2528
|
|
2554
|
-
|
2529
|
+
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
2555
2530
|
|
2556
|
-
|
2557
|
-
GGML_ASSERT(src3);
|
2531
|
+
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
2558
2532
|
|
2559
2533
|
size_t offs_src3 = 0;
|
2560
2534
|
|
@@ -2565,7 +2539,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2565
2539
|
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
2566
2540
|
|
2567
2541
|
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
|
2568
|
-
|
2542
|
+
//const int64_t ne31 = src3 ? src3->ne[1] : 0;
|
2569
2543
|
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
|
2570
2544
|
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
|
2571
2545
|
|
@@ -2577,7 +2551,16 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2577
2551
|
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
2578
2552
|
|
2579
2553
|
float scale;
|
2580
|
-
|
2554
|
+
float max_bias;
|
2555
|
+
|
2556
|
+
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
2557
|
+
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
2558
|
+
|
2559
|
+
const uint32_t n_head = src0->ne[2];
|
2560
|
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
2561
|
+
|
2562
|
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
2563
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
2581
2564
|
|
2582
2565
|
id<MTLComputePipelineState> pipeline = nil;
|
2583
2566
|
|
@@ -2614,34 +2597,38 @@ static enum ggml_status ggml_metal_graph_compute(
|
|
2614
2597
|
}
|
2615
2598
|
|
2616
2599
|
[encoder setComputePipelineState:pipeline];
|
2617
|
-
[encoder setBuffer:id_src0
|
2618
|
-
[encoder setBuffer:id_src1
|
2619
|
-
[encoder setBuffer:id_src2
|
2620
|
-
|
2621
|
-
|
2622
|
-
|
2623
|
-
|
2624
|
-
|
2625
|
-
[encoder
|
2626
|
-
[encoder setBytes:&
|
2627
|
-
[encoder setBytes:&
|
2628
|
-
[encoder setBytes:&
|
2629
|
-
[encoder setBytes:&
|
2630
|
-
[encoder setBytes:&
|
2631
|
-
[encoder setBytes:&
|
2632
|
-
[encoder setBytes:&
|
2633
|
-
[encoder setBytes:&
|
2634
|
-
[encoder setBytes:&
|
2635
|
-
[encoder setBytes:&nb11
|
2636
|
-
[encoder setBytes:&nb12
|
2637
|
-
[encoder setBytes:&nb13
|
2638
|
-
[encoder setBytes:&
|
2639
|
-
[encoder setBytes:&
|
2640
|
-
[encoder setBytes:&
|
2641
|
-
[encoder setBytes:&
|
2642
|
-
[encoder setBytes:&
|
2643
|
-
[encoder setBytes:&
|
2644
|
-
[encoder setBytes:&scale
|
2600
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
2601
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
2602
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
2603
|
+
if (id_src3) {
|
2604
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
2605
|
+
} else {
|
2606
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
2607
|
+
}
|
2608
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
2609
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
2610
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
2611
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
2612
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
2613
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
2614
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
2615
|
+
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
2616
|
+
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
2617
|
+
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
2618
|
+
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
2619
|
+
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
2620
|
+
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
2621
|
+
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
2622
|
+
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
2623
|
+
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
2624
|
+
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
2625
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
2626
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
2627
|
+
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
2628
|
+
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
2629
|
+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
2630
|
+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
2631
|
+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
2645
2632
|
|
2646
2633
|
if (!use_vec_kernel) {
|
2647
2634
|
// half8x8 kernel
|