cui-llama.rn 1.1.0 → 1.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/common.cpp +19 -6
- package/cpp/ggml-aarch64.c +6 -21
- package/cpp/ggml-metal.m +154 -26
- package/cpp/ggml.c +115 -195
- package/cpp/ggml.h +5 -7
- package/cpp/llama-impl.h +10 -4
- package/cpp/llama-sampling.cpp +16 -14
- package/cpp/llama.cpp +1048 -500
- package/cpp/llama.h +3 -0
- package/package.json +1 -1
package/cpp/common.cpp
CHANGED
@@ -333,6 +333,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|
333
333
|
void gpt_params_parse_from_env(gpt_params & params) {
|
334
334
|
// we only care about server-related params for now
|
335
335
|
get_env("LLAMA_ARG_MODEL", params.model);
|
336
|
+
get_env("LLAMA_ARG_MODEL_URL", params.model_url);
|
337
|
+
get_env("LLAMA_ARG_MODEL_ALIAS", params.model_alias);
|
338
|
+
get_env("LLAMA_ARG_HF_REPO", params.hf_repo);
|
339
|
+
get_env("LLAMA_ARG_HF_FILE", params.hf_file);
|
336
340
|
get_env("LLAMA_ARG_THREADS", params.n_threads);
|
337
341
|
get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx);
|
338
342
|
get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel);
|
@@ -347,6 +351,9 @@ void gpt_params_parse_from_env(gpt_params & params) {
|
|
347
351
|
get_env("LLAMA_ARG_EMBEDDINGS", params.embedding);
|
348
352
|
get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn);
|
349
353
|
get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold);
|
354
|
+
get_env("LLAMA_ARG_CONT_BATCHING", params.cont_batching);
|
355
|
+
get_env("LLAMA_ARG_HOST", params.hostname);
|
356
|
+
get_env("LLAMA_ARG_PORT", params.port);
|
350
357
|
}
|
351
358
|
|
352
359
|
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
@@ -907,7 +914,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
|
907
914
|
}
|
908
915
|
return true;
|
909
916
|
}
|
910
|
-
if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--gpu-layers-draft") {
|
917
|
+
if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--n-gpu-layers-draft") {
|
911
918
|
CHECK_ARG
|
912
919
|
params.n_gpu_layers_draft = std::stoi(argv[i]);
|
913
920
|
if (!llama_supports_gpu_offload()) {
|
@@ -1867,13 +1874,19 @@ std::string string_get_sortable_timestamp() {
|
|
1867
1874
|
|
1868
1875
|
void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
|
1869
1876
|
if (search.empty()) {
|
1870
|
-
return;
|
1877
|
+
return;
|
1871
1878
|
}
|
1879
|
+
std::string builder;
|
1880
|
+
builder.reserve(s.length());
|
1872
1881
|
size_t pos = 0;
|
1873
|
-
|
1874
|
-
|
1875
|
-
pos
|
1876
|
-
|
1882
|
+
size_t last_pos = 0;
|
1883
|
+
while ((pos = s.find(search, last_pos)) != std::string::npos) {
|
1884
|
+
builder.append(s, last_pos, pos - last_pos);
|
1885
|
+
builder.append(replace);
|
1886
|
+
last_pos = pos + search.length();
|
1887
|
+
}
|
1888
|
+
builder.append(s, last_pos, std::string::npos);
|
1889
|
+
s = std::move(builder);
|
1877
1890
|
}
|
1878
1891
|
|
1879
1892
|
void string_process_escapes(std::string & input) {
|
package/cpp/ggml-aarch64.c
CHANGED
@@ -337,33 +337,18 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds
|
|
337
337
|
}
|
338
338
|
|
339
339
|
size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
340
|
-
|
341
|
-
|
342
|
-
}
|
343
|
-
else {
|
344
|
-
assert(false);
|
345
|
-
return 0;
|
346
|
-
}
|
340
|
+
UNUSED(quant_weights);
|
341
|
+
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
|
347
342
|
}
|
348
343
|
|
349
344
|
size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
350
|
-
|
351
|
-
|
352
|
-
}
|
353
|
-
else {
|
354
|
-
assert(false);
|
355
|
-
return 0;
|
356
|
-
}
|
345
|
+
UNUSED(quant_weights);
|
346
|
+
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
|
357
347
|
}
|
358
348
|
|
359
349
|
size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
360
|
-
|
361
|
-
|
362
|
-
}
|
363
|
-
else {
|
364
|
-
assert(false);
|
365
|
-
return 0;
|
366
|
-
}
|
350
|
+
UNUSED(quant_weights);
|
351
|
+
return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
|
367
352
|
}
|
368
353
|
|
369
354
|
void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
|
package/cpp/ggml-metal.m
CHANGED
@@ -82,6 +82,8 @@ enum lm_ggml_metal_kernel_type {
|
|
82
82
|
LM_GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
83
83
|
LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
84
84
|
LM_GGML_METAL_KERNEL_TYPE_NORM,
|
85
|
+
LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
86
|
+
LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
85
87
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
86
88
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
87
89
|
LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
@@ -542,6 +544,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(int n_cb) {
|
|
542
544
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
543
545
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
544
546
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
547
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
548
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
545
549
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
546
550
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
547
551
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
@@ -803,6 +807,9 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_context
|
|
803
807
|
return false;
|
804
808
|
}
|
805
809
|
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
810
|
+
case LM_GGML_OP_SSM_CONV:
|
811
|
+
case LM_GGML_OP_SSM_SCAN:
|
812
|
+
return true;
|
806
813
|
case LM_GGML_OP_MUL_MAT:
|
807
814
|
case LM_GGML_OP_MUL_MAT_ID:
|
808
815
|
return ctx->support_simdgroup_reduction &&
|
@@ -1538,6 +1545,121 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
|
|
1538
1545
|
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1539
1546
|
}
|
1540
1547
|
} break;
|
1548
|
+
case LM_GGML_OP_SSM_CONV:
|
1549
|
+
{
|
1550
|
+
LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32);
|
1551
|
+
LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
|
1552
|
+
|
1553
|
+
LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
|
1554
|
+
LM_GGML_ASSERT(lm_ggml_is_contiguous(src1));
|
1555
|
+
|
1556
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
|
1557
|
+
|
1558
|
+
[encoder setComputePipelineState:pipeline];
|
1559
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1560
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1561
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
1562
|
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
1563
|
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
1564
|
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
1565
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1566
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1567
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1568
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
1569
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
1570
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
|
1571
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
|
1572
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
1573
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
1574
|
+
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
|
1575
|
+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
|
1576
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
|
1577
|
+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
|
1578
|
+
|
1579
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1580
|
+
} break;
|
1581
|
+
case LM_GGML_OP_SSM_SCAN:
|
1582
|
+
{
|
1583
|
+
struct lm_ggml_tensor * src3 = gf->nodes[i]->src[3];
|
1584
|
+
struct lm_ggml_tensor * src4 = gf->nodes[i]->src[4];
|
1585
|
+
struct lm_ggml_tensor * src5 = gf->nodes[i]->src[5];
|
1586
|
+
|
1587
|
+
LM_GGML_ASSERT(src3);
|
1588
|
+
LM_GGML_ASSERT(src4);
|
1589
|
+
LM_GGML_ASSERT(src5);
|
1590
|
+
|
1591
|
+
size_t offs_src3 = 0;
|
1592
|
+
size_t offs_src4 = 0;
|
1593
|
+
size_t offs_src5 = 0;
|
1594
|
+
|
1595
|
+
id<MTLBuffer> id_src3 = src3 ? lm_ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
1596
|
+
id<MTLBuffer> id_src4 = src4 ? lm_ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
1597
|
+
id<MTLBuffer> id_src5 = src5 ? lm_ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
1598
|
+
|
1599
|
+
const int64_t ne30 = src3->ne[0]; LM_GGML_UNUSED(ne30);
|
1600
|
+
const int64_t ne31 = src3->ne[1]; LM_GGML_UNUSED(ne31);
|
1601
|
+
|
1602
|
+
const uint64_t nb30 = src3->nb[0];
|
1603
|
+
const uint64_t nb31 = src3->nb[1];
|
1604
|
+
|
1605
|
+
const int64_t ne40 = src4->ne[0]; LM_GGML_UNUSED(ne40);
|
1606
|
+
const int64_t ne41 = src4->ne[1]; LM_GGML_UNUSED(ne41);
|
1607
|
+
const int64_t ne42 = src4->ne[2]; LM_GGML_UNUSED(ne42);
|
1608
|
+
|
1609
|
+
const uint64_t nb40 = src4->nb[0];
|
1610
|
+
const uint64_t nb41 = src4->nb[1];
|
1611
|
+
const uint64_t nb42 = src4->nb[2];
|
1612
|
+
|
1613
|
+
const int64_t ne50 = src5->ne[0]; LM_GGML_UNUSED(ne50);
|
1614
|
+
const int64_t ne51 = src5->ne[1]; LM_GGML_UNUSED(ne51);
|
1615
|
+
const int64_t ne52 = src5->ne[2]; LM_GGML_UNUSED(ne52);
|
1616
|
+
|
1617
|
+
const uint64_t nb50 = src5->nb[0];
|
1618
|
+
const uint64_t nb51 = src5->nb[1];
|
1619
|
+
const uint64_t nb52 = src5->nb[2];
|
1620
|
+
|
1621
|
+
const int64_t d_state = ne00;
|
1622
|
+
const int64_t d_inner = ne01;
|
1623
|
+
const int64_t n_seq_tokens = ne11;
|
1624
|
+
const int64_t n_seqs = ne02;
|
1625
|
+
|
1626
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
1627
|
+
|
1628
|
+
[encoder setComputePipelineState:pipeline];
|
1629
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
1630
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
1631
|
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
1632
|
+
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
1633
|
+
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
1634
|
+
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
1635
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
1636
|
+
|
1637
|
+
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
|
1638
|
+
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
|
1639
|
+
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
|
1640
|
+
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
|
1641
|
+
|
1642
|
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
|
1643
|
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
|
1644
|
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
|
1645
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
1646
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
1647
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
1648
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
1649
|
+
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
|
1650
|
+
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
|
1651
|
+
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
|
1652
|
+
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
|
1653
|
+
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
|
1654
|
+
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
|
1655
|
+
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
|
1656
|
+
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
|
1657
|
+
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
|
1658
|
+
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
|
1659
|
+
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
|
1660
|
+
|
1661
|
+
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1662
|
+
} break;
|
1541
1663
|
case LM_GGML_OP_MUL_MAT:
|
1542
1664
|
{
|
1543
1665
|
LM_GGML_ASSERT(ne00 == ne10);
|
@@ -2624,9 +2746,14 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
|
|
2624
2746
|
|
2625
2747
|
float scale;
|
2626
2748
|
float max_bias;
|
2749
|
+
float logit_softcap;
|
2750
|
+
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
2751
|
+
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
2752
|
+
memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
|
2627
2753
|
|
2628
|
-
|
2629
|
-
|
2754
|
+
if (logit_softcap != 0.0f) {
|
2755
|
+
scale /= logit_softcap;
|
2756
|
+
}
|
2630
2757
|
|
2631
2758
|
const uint32_t n_head = src0->ne[2];
|
2632
2759
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
@@ -2677,30 +2804,31 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
|
|
2677
2804
|
} else {
|
2678
2805
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
2679
2806
|
}
|
2680
|
-
[encoder setBuffer:id_dst
|
2681
|
-
[encoder setBytes:&ne01
|
2682
|
-
[encoder setBytes:&ne02
|
2683
|
-
[encoder setBytes:&ne03
|
2684
|
-
[encoder setBytes:&nb01
|
2685
|
-
[encoder setBytes:&nb02
|
2686
|
-
[encoder setBytes:&nb03
|
2687
|
-
[encoder setBytes:&ne11
|
2688
|
-
[encoder setBytes:&ne12
|
2689
|
-
[encoder setBytes:&ne13
|
2690
|
-
[encoder setBytes:&nb11
|
2691
|
-
[encoder setBytes:&nb12
|
2692
|
-
[encoder setBytes:&nb13
|
2693
|
-
[encoder setBytes:&nb21
|
2694
|
-
[encoder setBytes:&nb22
|
2695
|
-
[encoder setBytes:&nb23
|
2696
|
-
[encoder setBytes:&nb31
|
2697
|
-
[encoder setBytes:&ne1
|
2698
|
-
[encoder setBytes:&ne2
|
2699
|
-
[encoder setBytes:&scale
|
2700
|
-
[encoder setBytes:&max_bias
|
2701
|
-
[encoder setBytes:&m0
|
2702
|
-
[encoder setBytes:&m1
|
2703
|
-
[encoder setBytes:&n_head_log2
|
2807
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
2808
|
+
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
2809
|
+
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
2810
|
+
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
2811
|
+
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
2812
|
+
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
2813
|
+
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
2814
|
+
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
2815
|
+
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
2816
|
+
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
2817
|
+
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
2818
|
+
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
2819
|
+
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
2820
|
+
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
2821
|
+
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
2822
|
+
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
2823
|
+
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
2824
|
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
2825
|
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
2826
|
+
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
2827
|
+
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
2828
|
+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
2829
|
+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
2830
|
+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
2831
|
+
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
|
2704
2832
|
|
2705
2833
|
if (!use_vec_kernel) {
|
2706
2834
|
// half8x8 kernel
|