cui-llama.rn 1.2.3 → 1.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +0 -2
- package/android/src/main/CMakeLists.txt +1 -0
- package/android/src/main/java/com/rnllama/LlamaContext.java +0 -3
- package/android/src/main/jni.cpp +9 -11
- package/cpp/common.cpp +85 -75
- package/cpp/common.h +127 -91
- package/cpp/ggml-aarch64.c +269 -0
- package/cpp/ggml-alloc.c +17 -19
- package/cpp/ggml-backend-impl.h +4 -15
- package/cpp/ggml-backend.cpp +1697 -1626
- package/cpp/ggml-backend.h +13 -25
- package/cpp/ggml-cpp.h +38 -0
- package/cpp/ggml-cpu.c +13720 -0
- package/cpp/ggml-cpu.h +150 -0
- package/cpp/ggml-impl.h +95 -0
- package/cpp/ggml-metal.m +185 -71
- package/cpp/ggml-quants.c +38 -51
- package/cpp/ggml.c +4468 -19500
- package/cpp/ggml.h +26 -146
- package/cpp/json-schema-to-grammar.cpp +1 -1
- package/cpp/llama-sampling.cpp +742 -249
- package/cpp/llama-sampling.h +21 -2
- package/cpp/llama-vocab.cpp +49 -9
- package/cpp/llama-vocab.h +35 -11
- package/cpp/llama.cpp +2468 -2307
- package/cpp/llama.h +65 -32
- package/cpp/log.cpp +50 -50
- package/cpp/log.h +18 -18
- package/cpp/rn-llama.hpp +23 -22
- package/cpp/sampling.cpp +117 -118
- package/cpp/sampling.h +20 -20
- package/cpp/sgemm.cpp +57 -0
- package/lib/commonjs/NativeRNLlama.js.map +1 -1
- package/lib/module/NativeRNLlama.js.map +1 -1
- package/lib/typescript/NativeRNLlama.d.ts +0 -1
- package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +0 -1
package/cpp/ggml-metal.m
CHANGED
@@ -241,6 +241,8 @@ enum lm_ggml_metal_kernel_type {
|
|
241
241
|
LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
242
242
|
LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
243
243
|
LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
244
|
+
LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
245
|
+
LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
|
244
246
|
LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
245
247
|
LM_GGML_METAL_KERNEL_TYPE_PAD_F32,
|
246
248
|
LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
@@ -272,6 +274,8 @@ enum lm_ggml_metal_kernel_type {
|
|
272
274
|
LM_GGML_METAL_KERNEL_TYPE_SIN,
|
273
275
|
LM_GGML_METAL_KERNEL_TYPE_COS,
|
274
276
|
LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
277
|
+
LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
278
|
+
LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
275
279
|
|
276
280
|
LM_GGML_METAL_KERNEL_TYPE_COUNT
|
277
281
|
};
|
@@ -446,7 +450,14 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
446
450
|
LM_GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
447
451
|
return NULL;
|
448
452
|
}
|
453
|
+
|
454
|
+
#if !__has_feature(objc_arc)
|
455
|
+
[options release];
|
456
|
+
#endif
|
449
457
|
}
|
458
|
+
#if LM_GGML_METAL_EMBED_LIBRARY
|
459
|
+
[src release];
|
460
|
+
#endif // LM_GGML_METAL_EMBED_LIBRARY
|
450
461
|
}
|
451
462
|
}
|
452
463
|
|
@@ -685,6 +696,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
685
696
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
686
697
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
687
698
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
699
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
700
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
688
701
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
689
702
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
690
703
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
@@ -716,6 +729,8 @@ static struct lm_ggml_backend_metal_context * lm_ggml_metal_init(lm_ggml_backend
|
|
716
729
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
717
730
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
718
731
|
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
732
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
733
|
+
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
719
734
|
}
|
720
735
|
|
721
736
|
[metal_library release];
|
@@ -844,8 +859,8 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_
|
|
844
859
|
case LM_GGML_OP_IM2COL:
|
845
860
|
return op->src[0]->type == LM_GGML_TYPE_F16;
|
846
861
|
case LM_GGML_OP_POOL_1D:
|
847
|
-
case LM_GGML_OP_POOL_2D:
|
848
862
|
return false;
|
863
|
+
case LM_GGML_OP_POOL_2D:
|
849
864
|
case LM_GGML_OP_UPSCALE:
|
850
865
|
case LM_GGML_OP_PAD:
|
851
866
|
case LM_GGML_OP_ARANGE:
|
@@ -1007,19 +1022,21 @@ static void lm_ggml_metal_encode_node(
|
|
1007
1022
|
id<MTLBuffer> id_src2 = src2 ? lm_ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
1008
1023
|
id<MTLBuffer> id_dst = dst ? lm_ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
1009
1024
|
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1025
|
+
#if 0
|
1026
|
+
LM_GGML_LOG_INFO("%s: op - %s\n", __func__, lm_ggml_op_name(dst->op));
|
1027
|
+
if (src0) {
|
1028
|
+
LM_GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
|
1029
|
+
lm_ggml_is_contiguous(src0), src0->name);
|
1030
|
+
}
|
1031
|
+
if (src1) {
|
1032
|
+
LM_GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, lm_ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
|
1033
|
+
lm_ggml_is_contiguous(src1), src1->name);
|
1034
|
+
}
|
1035
|
+
if (dst) {
|
1036
|
+
LM_GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, lm_ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
|
1037
|
+
dst->name);
|
1038
|
+
}
|
1039
|
+
#endif
|
1023
1040
|
|
1024
1041
|
id<MTLDevice> device = ctx_dev->mtl_device;
|
1025
1042
|
|
@@ -1802,14 +1819,16 @@ static void lm_ggml_metal_encode_node(
|
|
1802
1819
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
1803
1820
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
1804
1821
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
1805
|
-
[encoder setBytes:&
|
1806
|
-
[encoder setBytes:&
|
1807
|
-
[encoder setBytes:&
|
1808
|
-
[encoder setBytes:&
|
1809
|
-
[encoder setBytes:&
|
1810
|
-
[encoder setBytes:&
|
1811
|
-
[encoder setBytes:&
|
1812
|
-
[encoder setBytes:&
|
1822
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
|
1823
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
1824
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
|
1825
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
|
1826
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
|
1827
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
|
1828
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
1829
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
1830
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
|
1831
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
|
1813
1832
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
1814
1833
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
1815
1834
|
} else {
|
@@ -1978,20 +1997,22 @@ static void lm_ggml_metal_encode_node(
|
|
1978
1997
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
1979
1998
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
1980
1999
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
1981
|
-
[encoder setBytes:&
|
1982
|
-
[encoder setBytes:&
|
1983
|
-
[encoder setBytes:&
|
1984
|
-
[encoder setBytes:&
|
1985
|
-
[encoder setBytes:&
|
1986
|
-
[encoder setBytes:&
|
1987
|
-
[encoder setBytes:&
|
1988
|
-
[encoder setBytes:&
|
1989
|
-
[encoder setBytes:&
|
1990
|
-
[encoder setBytes:&
|
2000
|
+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
2001
|
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
2002
|
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
2003
|
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
2004
|
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
|
2005
|
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
|
2006
|
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
|
2007
|
+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
|
2008
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
2009
|
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
|
2010
|
+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
|
2011
|
+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
|
1991
2012
|
|
1992
2013
|
if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 ||
|
1993
|
-
|
1994
|
-
|
2014
|
+
src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K ||
|
2015
|
+
src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) {
|
1995
2016
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
1996
2017
|
}
|
1997
2018
|
else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) {
|
@@ -2040,6 +2061,9 @@ static void lm_ggml_metal_encode_node(
|
|
2040
2061
|
|
2041
2062
|
LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32);
|
2042
2063
|
|
2064
|
+
LM_GGML_ASSERT(ne03 == 1);
|
2065
|
+
LM_GGML_ASSERT(ne13 == 1);
|
2066
|
+
|
2043
2067
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
2044
2068
|
// to the matrix-vector kernel
|
2045
2069
|
// ne20 = n_used_experts
|
@@ -2545,6 +2569,8 @@ static void lm_ggml_metal_encode_node(
|
|
2545
2569
|
} break;
|
2546
2570
|
case LM_GGML_OP_IM2COL:
|
2547
2571
|
{
|
2572
|
+
LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
|
2573
|
+
LM_GGML_ASSERT(lm_ggml_is_contiguous(src1));
|
2548
2574
|
LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16);
|
2549
2575
|
LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
|
2550
2576
|
LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F16 || dst->type == LM_GGML_TYPE_F32);
|
@@ -2574,30 +2600,54 @@ static void lm_ggml_metal_encode_node(
|
|
2574
2600
|
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
2575
2601
|
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
2576
2602
|
|
2577
|
-
id<MTLComputePipelineState> pipeline =
|
2603
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
2604
|
+
|
2605
|
+
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
|
2578
2606
|
|
2579
2607
|
switch (dst->type) {
|
2580
|
-
case LM_GGML_TYPE_F32:
|
2581
|
-
|
2608
|
+
case LM_GGML_TYPE_F32: {
|
2609
|
+
pipeline = (is_gt_mttpt ?
|
2610
|
+
ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
|
2611
|
+
:
|
2612
|
+
ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
|
2613
|
+
} break;
|
2614
|
+
case LM_GGML_TYPE_F16: {
|
2615
|
+
pipeline = (is_gt_mttpt ?
|
2616
|
+
ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
|
2617
|
+
:
|
2618
|
+
ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
|
2619
|
+
} break;
|
2582
2620
|
default: LM_GGML_ABORT("fatal error");
|
2583
2621
|
};
|
2584
2622
|
|
2585
2623
|
[encoder setComputePipelineState:pipeline];
|
2586
|
-
[encoder setBuffer:id_src1 offset:offs_src1
|
2587
|
-
[encoder setBuffer:id_dst offset:offs_dst
|
2588
|
-
[encoder setBytes:&ofs0 length:sizeof(
|
2589
|
-
[encoder setBytes:&ofs1 length:sizeof(
|
2590
|
-
[encoder setBytes:&IW length:sizeof(
|
2591
|
-
[encoder setBytes:&IH length:sizeof(
|
2592
|
-
[encoder setBytes:&CHW length:sizeof(
|
2593
|
-
[encoder setBytes:&s0 length:sizeof(
|
2594
|
-
[encoder setBytes:&s1 length:sizeof(
|
2595
|
-
[encoder setBytes:&p0 length:sizeof(
|
2596
|
-
[encoder setBytes:&p1 length:sizeof(
|
2597
|
-
[encoder setBytes:&d0 length:sizeof(
|
2598
|
-
[encoder setBytes:&d1 length:sizeof(
|
2599
|
-
|
2600
|
-
|
2624
|
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
2625
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
2626
|
+
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
|
2627
|
+
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
|
2628
|
+
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
|
2629
|
+
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
|
2630
|
+
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
|
2631
|
+
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
|
2632
|
+
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
|
2633
|
+
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
|
2634
|
+
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
|
2635
|
+
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
|
2636
|
+
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
|
2637
|
+
|
2638
|
+
if (is_gt_mttpt) {
|
2639
|
+
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
2640
|
+
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
2641
|
+
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
2642
|
+
|
2643
|
+
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
2644
|
+
|
2645
|
+
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
2646
|
+
|
2647
|
+
[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
2648
|
+
} else {
|
2649
|
+
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
2650
|
+
}
|
2601
2651
|
} break;
|
2602
2652
|
case LM_GGML_OP_UPSCALE:
|
2603
2653
|
{
|
@@ -3001,6 +3051,64 @@ static void lm_ggml_metal_encode_node(
|
|
3001
3051
|
|
3002
3052
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
3003
3053
|
} break;
|
3054
|
+
case LM_GGML_OP_POOL_2D:
|
3055
|
+
{
|
3056
|
+
LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
|
3057
|
+
LM_GGML_ASSERT(src0t == LM_GGML_TYPE_F32 && src0t == dstt);
|
3058
|
+
|
3059
|
+
const int32_t * opts = dst->op_params;
|
3060
|
+
enum lm_ggml_op_pool op = opts[0];
|
3061
|
+
|
3062
|
+
id<MTLComputePipelineState> pipeline = nil;
|
3063
|
+
switch (src0t) {
|
3064
|
+
case LM_GGML_TYPE_F32: {
|
3065
|
+
switch(op) {
|
3066
|
+
case LM_GGML_OP_POOL_AVG:
|
3067
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
|
3068
|
+
case LM_GGML_OP_POOL_MAX:
|
3069
|
+
pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
|
3070
|
+
default: LM_GGML_ASSERT(false && "not implemented");
|
3071
|
+
}
|
3072
|
+
} break;
|
3073
|
+
default: LM_GGML_ASSERT(false && "not implemented");
|
3074
|
+
}
|
3075
|
+
|
3076
|
+
const int32_t k0 = opts[1];
|
3077
|
+
const int32_t k1 = opts[2];
|
3078
|
+
const int32_t s0 = opts[3];
|
3079
|
+
const int32_t s1 = opts[4];
|
3080
|
+
const int32_t p0 = opts[5];
|
3081
|
+
const int32_t p1 = opts[6];
|
3082
|
+
|
3083
|
+
const int64_t IH = src0->ne[1];
|
3084
|
+
const int64_t IW = src0->ne[0];
|
3085
|
+
|
3086
|
+
const int64_t N = dst->ne[3];
|
3087
|
+
const int64_t OC = dst->ne[2];
|
3088
|
+
const int64_t OH = dst->ne[1];
|
3089
|
+
const int64_t OW = dst->ne[0];
|
3090
|
+
|
3091
|
+
const int64_t parallel_elements = N * OC * OH * OW;
|
3092
|
+
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
3093
|
+
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
3094
|
+
|
3095
|
+
[encoder setComputePipelineState:pipeline];
|
3096
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
3097
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
3098
|
+
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
|
3099
|
+
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
|
3100
|
+
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
|
3101
|
+
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
|
3102
|
+
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
|
3103
|
+
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
|
3104
|
+
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
|
3105
|
+
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
|
3106
|
+
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
|
3107
|
+
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
|
3108
|
+
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
|
3109
|
+
|
3110
|
+
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
3111
|
+
} break;
|
3004
3112
|
default:
|
3005
3113
|
{
|
3006
3114
|
LM_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op));
|
@@ -3146,12 +3254,6 @@ static enum lm_ggml_status lm_ggml_metal_graph_compute(
|
|
3146
3254
|
|
3147
3255
|
// backend interface
|
3148
3256
|
|
3149
|
-
static const char * lm_ggml_backend_metal_buffer_get_name(lm_ggml_backend_buffer_t buffer) {
|
3150
|
-
return "Metal";
|
3151
|
-
|
3152
|
-
UNUSED(buffer);
|
3153
|
-
}
|
3154
|
-
|
3155
3257
|
static void lm_ggml_backend_metal_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) {
|
3156
3258
|
struct lm_ggml_backend_metal_buffer_context * ctx = (struct lm_ggml_backend_metal_buffer_context *)buffer->context;
|
3157
3259
|
|
@@ -3206,7 +3308,6 @@ static void lm_ggml_backend_metal_buffer_clear(lm_ggml_backend_buffer_t buffer,
|
|
3206
3308
|
}
|
3207
3309
|
|
3208
3310
|
static struct lm_ggml_backend_buffer_i lm_ggml_backend_metal_buffer_i = {
|
3209
|
-
/* .get_name = */ lm_ggml_backend_metal_buffer_get_name,
|
3210
3311
|
/* .free_buffer = */ lm_ggml_backend_metal_buffer_free_buffer,
|
3211
3312
|
/* .get_base = */ lm_ggml_backend_metal_buffer_get_base,
|
3212
3313
|
/* .init_tensor = */ NULL,
|
@@ -3331,6 +3432,29 @@ lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_type(void) {
|
|
3331
3432
|
return &lm_ggml_backend_buffer_type_metal;
|
3332
3433
|
}
|
3333
3434
|
|
3435
|
+
static const char * lm_ggml_backend_metal_buffer_from_ptr_type_get_name(lm_ggml_backend_buffer_type_t buft) {
|
3436
|
+
return "Metal_Mapped";
|
3437
|
+
|
3438
|
+
UNUSED(buft);
|
3439
|
+
}
|
3440
|
+
|
3441
|
+
static lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_buffer_from_ptr_type(void) {
|
3442
|
+
static struct lm_ggml_backend_buffer_type lm_ggml_backend_buffer_from_ptr_type_metal = {
|
3443
|
+
/* .iface = */ {
|
3444
|
+
/* .get_name = */ lm_ggml_backend_metal_buffer_from_ptr_type_get_name,
|
3445
|
+
/* .alloc_buffer = */ lm_ggml_backend_metal_buffer_type_alloc_buffer,
|
3446
|
+
/* .get_alignment = */ lm_ggml_backend_metal_buffer_type_get_alignment,
|
3447
|
+
/* .get_max_size = */ lm_ggml_backend_metal_buffer_type_get_max_size,
|
3448
|
+
/* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes
|
3449
|
+
/* .is_host = */ lm_ggml_backend_metal_buffer_type_is_host,
|
3450
|
+
},
|
3451
|
+
/* .device = */ &g_lm_ggml_backend_metal_device,
|
3452
|
+
/* .context = */ NULL,
|
3453
|
+
};
|
3454
|
+
|
3455
|
+
return &lm_ggml_backend_buffer_from_ptr_type_metal;
|
3456
|
+
}
|
3457
|
+
|
3334
3458
|
// TODO: obsoleted by lm_ggml_backend_metal_device_buffer_from_ptr
|
3335
3459
|
lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
|
3336
3460
|
struct lm_ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct lm_ggml_backend_metal_buffer_context));
|
@@ -3407,7 +3531,7 @@ lm_ggml_backend_buffer_t lm_ggml_backend_metal_buffer_from_ptr(void * data, size
|
|
3407
3531
|
}
|
3408
3532
|
}
|
3409
3533
|
|
3410
|
-
return lm_ggml_backend_buffer_init(
|
3534
|
+
return lm_ggml_backend_buffer_init(lm_ggml_backend_metal_buffer_from_ptr_type(), lm_ggml_backend_metal_buffer_i, ctx, size);
|
3411
3535
|
}
|
3412
3536
|
|
3413
3537
|
// backend
|
@@ -3428,12 +3552,6 @@ static void lm_ggml_backend_metal_free(lm_ggml_backend_t backend) {
|
|
3428
3552
|
free(backend);
|
3429
3553
|
}
|
3430
3554
|
|
3431
|
-
static lm_ggml_backend_buffer_type_t lm_ggml_backend_metal_get_default_buffer_type(lm_ggml_backend_t backend) {
|
3432
|
-
return lm_ggml_backend_metal_buffer_type();
|
3433
|
-
|
3434
|
-
UNUSED(backend);
|
3435
|
-
}
|
3436
|
-
|
3437
3555
|
static enum lm_ggml_status lm_ggml_backend_metal_graph_compute(lm_ggml_backend_t backend, struct lm_ggml_cgraph * cgraph) {
|
3438
3556
|
return lm_ggml_metal_graph_compute(backend, cgraph);
|
3439
3557
|
}
|
@@ -3500,7 +3618,6 @@ static void lm_ggml_backend_metal_set_n_cb(lm_ggml_backend_t backend, int n_cb)
|
|
3500
3618
|
static struct lm_ggml_backend_i lm_ggml_backend_metal_i = {
|
3501
3619
|
/* .get_name = */ lm_ggml_backend_metal_name,
|
3502
3620
|
/* .free = */ lm_ggml_backend_metal_free,
|
3503
|
-
/* .get_default_buffer_type = */ lm_ggml_backend_metal_get_default_buffer_type,
|
3504
3621
|
/* .set_tensor_async = */ NULL,
|
3505
3622
|
/* .get_tensor_async = */ NULL,
|
3506
3623
|
/* .cpy_tensor_async = */ NULL,
|
@@ -3510,9 +3627,6 @@ static struct lm_ggml_backend_i lm_ggml_backend_metal_i = {
|
|
3510
3627
|
/* .graph_plan_update = */ NULL,
|
3511
3628
|
/* .graph_plan_compute = */ NULL,
|
3512
3629
|
/* .graph_compute = */ lm_ggml_backend_metal_graph_compute,
|
3513
|
-
/* .supports_op = */ NULL,
|
3514
|
-
/* .supports_buft = */ NULL,
|
3515
|
-
/* .offload_op = */ NULL,
|
3516
3630
|
/* .event_record = */ NULL,
|
3517
3631
|
/* .event_wait = */ NULL,
|
3518
3632
|
};
|
@@ -3607,7 +3721,7 @@ static void lm_ggml_backend_metal_device_get_memory(lm_ggml_backend_dev_t dev, s
|
|
3607
3721
|
}
|
3608
3722
|
|
3609
3723
|
static enum lm_ggml_backend_dev_type lm_ggml_backend_metal_device_get_type(lm_ggml_backend_dev_t dev) {
|
3610
|
-
return
|
3724
|
+
return LM_GGML_BACKEND_DEVICE_TYPE_GPU;
|
3611
3725
|
|
3612
3726
|
LM_GGML_UNUSED(dev);
|
3613
3727
|
}
|
package/cpp/ggml-quants.c
CHANGED
@@ -4,7 +4,7 @@
|
|
4
4
|
#include "ggml-quants.h"
|
5
5
|
#include "ggml-impl.h"
|
6
6
|
#include "ggml-cpu-impl.h"
|
7
|
-
|
7
|
+
#include "ggml-cpu.h"
|
8
8
|
|
9
9
|
#include <math.h>
|
10
10
|
#include <string.h>
|
@@ -9104,10 +9104,8 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
9104
9104
|
|
9105
9105
|
#elif defined __AVX__
|
9106
9106
|
|
9107
|
-
const __m128i m4 = _mm_set1_epi8(0xF);
|
9108
9107
|
const __m128i m3 = _mm_set1_epi8(3);
|
9109
|
-
const __m128i
|
9110
|
-
const __m128i m2 = _mm_set1_epi8(2);
|
9108
|
+
const __m128i m15 = _mm_set1_epi8(15);
|
9111
9109
|
|
9112
9110
|
__m256 acc = _mm256_setzero_ps();
|
9113
9111
|
|
@@ -9119,12 +9117,20 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
9119
9117
|
const uint8_t * restrict qh = x[i].qh;
|
9120
9118
|
const int8_t * restrict q8 = y[i].qs;
|
9121
9119
|
|
9120
|
+
// handle the q6_k -32 offset separately using bsums
|
9121
|
+
const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
|
9122
|
+
const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
|
9122
9123
|
const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
|
9124
|
+
const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
|
9125
|
+
const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
|
9126
|
+
const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
|
9127
|
+
const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
|
9123
9128
|
|
9124
9129
|
__m128i sumi_0 = _mm_setzero_si128();
|
9125
9130
|
__m128i sumi_1 = _mm_setzero_si128();
|
9126
9131
|
|
9127
|
-
|
9132
|
+
int is = 0;
|
9133
|
+
|
9128
9134
|
for (int j = 0; j < QK_K/128; ++j) {
|
9129
9135
|
|
9130
9136
|
const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
|
@@ -9132,26 +9138,26 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
9132
9138
|
|
9133
9139
|
const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
|
9134
9140
|
const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
|
9135
|
-
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(
|
9136
|
-
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(
|
9137
|
-
const __m128i q4h_4 =
|
9138
|
-
const __m128i q4h_5 =
|
9139
|
-
const __m128i q4h_6 =
|
9140
|
-
const __m128i q4h_7 =
|
9141
|
+
const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
|
9142
|
+
const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
|
9143
|
+
const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
|
9144
|
+
const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
|
9145
|
+
const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
|
9146
|
+
const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
|
9141
9147
|
|
9142
9148
|
const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
9143
9149
|
const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
9144
9150
|
const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
9145
9151
|
const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
|
9146
9152
|
|
9147
|
-
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0,
|
9148
|
-
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1,
|
9149
|
-
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0,
|
9150
|
-
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1,
|
9151
|
-
const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4),
|
9152
|
-
const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4),
|
9153
|
-
const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4),
|
9154
|
-
const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4),
|
9153
|
+
const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
|
9154
|
+
const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
|
9155
|
+
const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
|
9156
|
+
const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
|
9157
|
+
const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
|
9158
|
+
const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
|
9159
|
+
const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
|
9160
|
+
const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
|
9155
9161
|
|
9156
9162
|
const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
9157
9163
|
const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
@@ -9162,15 +9168,6 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
9162
9168
|
const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
9163
9169
|
const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
|
9164
9170
|
|
9165
|
-
__m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0);
|
9166
|
-
__m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1);
|
9167
|
-
__m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2);
|
9168
|
-
__m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3);
|
9169
|
-
__m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4);
|
9170
|
-
__m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5);
|
9171
|
-
__m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6);
|
9172
|
-
__m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7);
|
9173
|
-
|
9174
9171
|
__m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
|
9175
9172
|
__m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
|
9176
9173
|
__m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
|
@@ -9180,32 +9177,20 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
9180
9177
|
__m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
|
9181
9178
|
__m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
|
9182
9179
|
|
9183
|
-
|
9184
|
-
|
9185
|
-
|
9186
|
-
|
9187
|
-
|
9188
|
-
p16_5 = _mm_sub_epi16(p16_5, q8s_5);
|
9189
|
-
p16_6 = _mm_sub_epi16(p16_6, q8s_6);
|
9190
|
-
p16_7 = _mm_sub_epi16(p16_7, q8s_7);
|
9191
|
-
|
9192
|
-
const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
|
9193
|
-
shuffle = _mm_add_epi8(shuffle, m2);
|
9194
|
-
const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
|
9195
|
-
shuffle = _mm_add_epi8(shuffle, m2);
|
9196
|
-
const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle);
|
9197
|
-
shuffle = _mm_add_epi8(shuffle, m2);
|
9198
|
-
const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle);
|
9199
|
-
shuffle = _mm_add_epi8(shuffle, m2);
|
9180
|
+
const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
|
9181
|
+
const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
|
9182
|
+
const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
|
9183
|
+
const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
|
9184
|
+
is += 4;
|
9200
9185
|
|
9201
9186
|
p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
|
9202
|
-
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(
|
9187
|
+
p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
|
9203
9188
|
p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
|
9204
|
-
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(
|
9189
|
+
p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
|
9205
9190
|
p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
|
9206
|
-
p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(
|
9191
|
+
p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
|
9207
9192
|
p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
|
9208
|
-
p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(
|
9193
|
+
p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
|
9209
9194
|
|
9210
9195
|
sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
|
9211
9196
|
sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
|
@@ -9214,8 +9199,10 @@ void lm_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void
|
|
9214
9199
|
|
9215
9200
|
}
|
9216
9201
|
|
9217
|
-
|
9218
|
-
|
9202
|
+
sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
|
9203
|
+
sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
|
9204
|
+
const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
|
9205
|
+
acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
|
9219
9206
|
}
|
9220
9207
|
|
9221
9208
|
*s = hsum_float_8(acc);
|