@fugood/llama.node 1.3.2 → 1.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +8 -3
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +5 -5
- package/src/LlamaCompletionWorker.cpp +33 -33
- package/src/LlamaContext.cpp +17 -16
- package/src/llama.cpp/CMakeLists.txt +4 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -37
- package/src/llama.cpp/common/common.cpp +1 -5
- package/src/llama.cpp/common/download.cpp +47 -29
- package/src/llama.cpp/common/log.cpp +6 -0
- package/src/llama.cpp/common/log.h +2 -0
- package/src/llama.cpp/ggml/include/ggml.h +71 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -3
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +29 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +283 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +235 -34
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +289 -277
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +95 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +10 -0
- package/src/llama.cpp/src/CMakeLists.txt +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +32 -0
- package/src/llama.cpp/src/llama-arch.h +2 -0
- package/src/llama.cpp/src/llama-graph.cpp +2 -1
- package/src/llama.cpp/src/llama-model.cpp +102 -0
- package/src/llama.cpp/src/llama-model.h +2 -0
- package/src/llama.cpp/src/llama-sampling.cpp +10 -5
- package/src/llama.cpp/src/llama-vocab.cpp +16 -1
- package/src/llama.cpp/src/llama-vocab.h +1 -0
- package/src/llama.cpp/src/models/afmoe.cpp +187 -0
- package/src/llama.cpp/src/models/models.h +4 -0
- package/src/llama.cpp/src/unicode.cpp +77 -0
|
@@ -475,6 +475,7 @@ extern "C" {
|
|
|
475
475
|
GGML_OP_COS,
|
|
476
476
|
GGML_OP_SUM,
|
|
477
477
|
GGML_OP_SUM_ROWS,
|
|
478
|
+
GGML_OP_CUMSUM,
|
|
478
479
|
GGML_OP_MEAN,
|
|
479
480
|
GGML_OP_ARGMAX,
|
|
480
481
|
GGML_OP_COUNT_EQUAL,
|
|
@@ -530,6 +531,8 @@ extern "C" {
|
|
|
530
531
|
GGML_OP_TIMESTEP_EMBEDDING,
|
|
531
532
|
GGML_OP_ARGSORT,
|
|
532
533
|
GGML_OP_LEAKY_RELU,
|
|
534
|
+
GGML_OP_TRI,
|
|
535
|
+
GGML_OP_FILL,
|
|
533
536
|
|
|
534
537
|
GGML_OP_FLASH_ATTN_EXT,
|
|
535
538
|
GGML_OP_FLASH_ATTN_BACK,
|
|
@@ -542,6 +545,7 @@ extern "C" {
|
|
|
542
545
|
GGML_OP_RWKV_WKV6,
|
|
543
546
|
GGML_OP_GATED_LINEAR_ATTN,
|
|
544
547
|
GGML_OP_RWKV_WKV7,
|
|
548
|
+
GGML_OP_SOLVE_TRI,
|
|
545
549
|
|
|
546
550
|
GGML_OP_UNARY,
|
|
547
551
|
|
|
@@ -576,6 +580,8 @@ extern "C" {
|
|
|
576
580
|
GGML_UNARY_OP_HARDSWISH,
|
|
577
581
|
GGML_UNARY_OP_HARDSIGMOID,
|
|
578
582
|
GGML_UNARY_OP_EXP,
|
|
583
|
+
GGML_UNARY_OP_EXPM1,
|
|
584
|
+
GGML_UNARY_OP_SOFTPLUS,
|
|
579
585
|
GGML_UNARY_OP_GELU_ERF,
|
|
580
586
|
GGML_UNARY_OP_XIELU,
|
|
581
587
|
GGML_UNARY_OP_FLOOR,
|
|
@@ -620,6 +626,13 @@ extern "C" {
|
|
|
620
626
|
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
|
621
627
|
};
|
|
622
628
|
|
|
629
|
+
enum ggml_tri_type {
|
|
630
|
+
GGML_TRI_TYPE_UPPER_DIAG = 0,
|
|
631
|
+
GGML_TRI_TYPE_UPPER = 1,
|
|
632
|
+
GGML_TRI_TYPE_LOWER_DIAG = 2,
|
|
633
|
+
GGML_TRI_TYPE_LOWER = 3
|
|
634
|
+
};
|
|
635
|
+
|
|
623
636
|
struct ggml_init_params {
|
|
624
637
|
// memory pool
|
|
625
638
|
size_t mem_size; // bytes
|
|
@@ -957,6 +970,22 @@ extern "C" {
|
|
|
957
970
|
struct ggml_context * ctx,
|
|
958
971
|
struct ggml_tensor * a);
|
|
959
972
|
|
|
973
|
+
GGML_API struct ggml_tensor * ggml_expm1(
|
|
974
|
+
struct ggml_context * ctx,
|
|
975
|
+
struct ggml_tensor * a);
|
|
976
|
+
|
|
977
|
+
GGML_API struct ggml_tensor * ggml_expm1_inplace(
|
|
978
|
+
struct ggml_context * ctx,
|
|
979
|
+
struct ggml_tensor * a);
|
|
980
|
+
|
|
981
|
+
GGML_API struct ggml_tensor * ggml_softplus(
|
|
982
|
+
struct ggml_context * ctx,
|
|
983
|
+
struct ggml_tensor * a);
|
|
984
|
+
|
|
985
|
+
GGML_API struct ggml_tensor * ggml_softplus_inplace(
|
|
986
|
+
struct ggml_context * ctx,
|
|
987
|
+
struct ggml_tensor * a);
|
|
988
|
+
|
|
960
989
|
GGML_API struct ggml_tensor * ggml_sin(
|
|
961
990
|
struct ggml_context * ctx,
|
|
962
991
|
struct ggml_tensor * a);
|
|
@@ -983,6 +1012,10 @@ extern "C" {
|
|
|
983
1012
|
struct ggml_context * ctx,
|
|
984
1013
|
struct ggml_tensor * a);
|
|
985
1014
|
|
|
1015
|
+
GGML_API struct ggml_tensor * ggml_cumsum(
|
|
1016
|
+
struct ggml_context * ctx,
|
|
1017
|
+
struct ggml_tensor * a);
|
|
1018
|
+
|
|
986
1019
|
// mean along rows
|
|
987
1020
|
GGML_API struct ggml_tensor * ggml_mean(
|
|
988
1021
|
struct ggml_context * ctx,
|
|
@@ -2187,6 +2220,23 @@ extern "C" {
|
|
|
2187
2220
|
int shift2,
|
|
2188
2221
|
int shift3);
|
|
2189
2222
|
|
|
2223
|
+
// Convert matrix into a triangular one (upper, strict upper, lower or strict lower) by writing
|
|
2224
|
+
// zeroes everywhere outside the masked area
|
|
2225
|
+
GGML_API struct ggml_tensor * ggml_tri(
|
|
2226
|
+
struct ggml_context * ctx,
|
|
2227
|
+
struct ggml_tensor * a,
|
|
2228
|
+
enum ggml_tri_type type);
|
|
2229
|
+
|
|
2230
|
+
// Fill tensor a with constant c
|
|
2231
|
+
GGML_API struct ggml_tensor * ggml_fill(
|
|
2232
|
+
struct ggml_context * ctx,
|
|
2233
|
+
struct ggml_tensor * a,
|
|
2234
|
+
float c);
|
|
2235
|
+
|
|
2236
|
+
GGML_API struct ggml_tensor * ggml_fill_inplace(
|
|
2237
|
+
struct ggml_context * ctx,
|
|
2238
|
+
struct ggml_tensor * a,
|
|
2239
|
+
float c);
|
|
2190
2240
|
|
|
2191
2241
|
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
|
2192
2242
|
// timesteps: [N,]
|
|
@@ -2356,6 +2406,27 @@ extern "C" {
|
|
|
2356
2406
|
struct ggml_tensor * b,
|
|
2357
2407
|
struct ggml_tensor * state);
|
|
2358
2408
|
|
|
2409
|
+
/* Solves a specific equation of the form Ax=B, where A is a triangular matrix
|
|
2410
|
+
* without zeroes on the diagonal (i.e. invertible).
|
|
2411
|
+
* B can have any number of columns, but must have the same number of rows as A
|
|
2412
|
+
* If A is [n, n] and B is [n, m], then the result will be [n, m] as well
|
|
2413
|
+
* Has O(n^3) complexity (unlike most matrix ops out there), so use on cases
|
|
2414
|
+
* where n > 100 sparingly, pre-chunk if necessary.
|
|
2415
|
+
*
|
|
2416
|
+
* If left = false, solves xA=B instead
|
|
2417
|
+
* If lower = false, assumes upper triangular instead
|
|
2418
|
+
* If uni = true, assumes diagonal of A to be all ones (will override actual values)
|
|
2419
|
+
*
|
|
2420
|
+
* TODO: currently only lower, right, non-unitriangular variant is implemented
|
|
2421
|
+
*/
|
|
2422
|
+
GGML_API struct ggml_tensor * ggml_solve_tri(
|
|
2423
|
+
struct ggml_context * ctx,
|
|
2424
|
+
struct ggml_tensor * a,
|
|
2425
|
+
struct ggml_tensor * b,
|
|
2426
|
+
bool left,
|
|
2427
|
+
bool lower,
|
|
2428
|
+
bool uni);
|
|
2429
|
+
|
|
2359
2430
|
// custom operators
|
|
2360
2431
|
|
|
2361
2432
|
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
|
|
@@ -211,6 +211,11 @@ add_library(ggml-base
|
|
|
211
211
|
ggml-quants.h
|
|
212
212
|
gguf.cpp)
|
|
213
213
|
|
|
214
|
+
set_target_properties(ggml-base PROPERTIES
|
|
215
|
+
VERSION ${GGML_VERSION}
|
|
216
|
+
SOVERSION ${GGML_VERSION_MAJOR}
|
|
217
|
+
)
|
|
218
|
+
|
|
214
219
|
target_include_directories(ggml-base PRIVATE .)
|
|
215
220
|
if (GGML_BACKEND_DL)
|
|
216
221
|
target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
|
|
@@ -220,6 +225,11 @@ add_library(ggml
|
|
|
220
225
|
ggml-backend-reg.cpp)
|
|
221
226
|
add_library(ggml::ggml ALIAS ggml)
|
|
222
227
|
|
|
228
|
+
set_target_properties(ggml PROPERTIES
|
|
229
|
+
VERSION ${GGML_VERSION}
|
|
230
|
+
SOVERSION ${GGML_VERSION_MAJOR}
|
|
231
|
+
)
|
|
232
|
+
|
|
223
233
|
if (GGML_BACKEND_DIR)
|
|
224
234
|
if (NOT GGML_BACKEND_DL)
|
|
225
235
|
message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL")
|
|
@@ -259,6 +269,12 @@ function(ggml_add_backend_library backend)
|
|
|
259
269
|
target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED)
|
|
260
270
|
endif()
|
|
261
271
|
|
|
272
|
+
# Set versioning properties for all backend libraries
|
|
273
|
+
set_target_properties(${backend} PROPERTIES
|
|
274
|
+
VERSION ${GGML_VERSION}
|
|
275
|
+
SOVERSION ${GGML_VERSION_MAJOR}
|
|
276
|
+
)
|
|
277
|
+
|
|
262
278
|
if(NOT GGML_AVAILABLE_BACKENDS)
|
|
263
279
|
set(GGML_AVAILABLE_BACKENDS "${backend}"
|
|
264
280
|
CACHE INTERNAL "List of backends for cmake package")
|
|
@@ -590,6 +590,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|
|
590
590
|
${KLEIDIAI_SRC}/kai/ukernels/
|
|
591
591
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/
|
|
592
592
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
|
|
593
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
|
|
593
594
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
|
|
594
595
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
|
|
595
596
|
|
|
@@ -608,23 +609,34 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|
|
608
609
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
|
|
609
610
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
|
|
610
611
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
|
|
611
|
-
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
|
|
612
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
|
|
613
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
|
|
614
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c)
|
|
612
615
|
|
|
613
616
|
if (NOT DOTPROD_ENABLED MATCHES -1)
|
|
614
617
|
list(APPEND GGML_KLEIDIAI_SOURCES
|
|
615
618
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
|
|
616
619
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
|
|
617
|
-
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c
|
|
620
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c
|
|
621
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c
|
|
622
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
|
|
623
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c)
|
|
618
624
|
endif()
|
|
619
625
|
|
|
620
626
|
if (NOT I8MM_ENABLED MATCHES -1)
|
|
621
|
-
list(APPEND GGML_KLEIDIAI_SOURCES
|
|
627
|
+
list(APPEND GGML_KLEIDIAI_SOURCES
|
|
628
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c
|
|
629
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c)
|
|
622
630
|
endif()
|
|
623
631
|
|
|
624
632
|
if (NOT SME_ENABLED MATCHES -1)
|
|
625
633
|
list(APPEND GGML_KLEIDIAI_SOURCES
|
|
626
634
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
|
|
627
635
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
|
|
636
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
|
|
637
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
|
|
638
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c
|
|
639
|
+
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
|
|
628
640
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
|
|
629
641
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
|
|
630
642
|
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
|
|
@@ -1731,6 +1731,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
|
1731
1731
|
{
|
|
1732
1732
|
ggml_compute_forward_sum_rows(params, tensor);
|
|
1733
1733
|
} break;
|
|
1734
|
+
case GGML_OP_CUMSUM:
|
|
1735
|
+
{
|
|
1736
|
+
ggml_compute_forward_cumsum(params, tensor);
|
|
1737
|
+
} break;
|
|
1734
1738
|
case GGML_OP_MEAN:
|
|
1735
1739
|
{
|
|
1736
1740
|
ggml_compute_forward_mean(params, tensor);
|
|
@@ -1927,6 +1931,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
|
1927
1931
|
{
|
|
1928
1932
|
ggml_compute_forward_leaky_relu(params, tensor);
|
|
1929
1933
|
} break;
|
|
1934
|
+
case GGML_OP_TRI:
|
|
1935
|
+
{
|
|
1936
|
+
ggml_compute_forward_tri(params, tensor);
|
|
1937
|
+
} break;
|
|
1938
|
+
case GGML_OP_FILL:
|
|
1939
|
+
{
|
|
1940
|
+
ggml_compute_forward_fill(params, tensor);
|
|
1941
|
+
} break;
|
|
1930
1942
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
1931
1943
|
{
|
|
1932
1944
|
ggml_compute_forward_flash_attn_ext(params, tensor);
|
|
@@ -1982,6 +1994,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
|
1982
1994
|
{
|
|
1983
1995
|
ggml_compute_forward_rwkv_wkv7(params, tensor);
|
|
1984
1996
|
} break;
|
|
1997
|
+
case GGML_OP_SOLVE_TRI:
|
|
1998
|
+
{
|
|
1999
|
+
ggml_compute_forward_solve_tri(params, tensor);
|
|
2000
|
+
} break;
|
|
1985
2001
|
case GGML_OP_MAP_CUSTOM1:
|
|
1986
2002
|
{
|
|
1987
2003
|
ggml_compute_forward_map_custom1(params, tensor);
|
|
@@ -2140,6 +2156,9 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|
|
2140
2156
|
case GGML_OP_ADD_ID:
|
|
2141
2157
|
case GGML_OP_ADD1:
|
|
2142
2158
|
case GGML_OP_ACC:
|
|
2159
|
+
case GGML_OP_CUMSUM:
|
|
2160
|
+
case GGML_OP_TRI:
|
|
2161
|
+
case GGML_OP_FILL:
|
|
2143
2162
|
{
|
|
2144
2163
|
n_tasks = n_threads;
|
|
2145
2164
|
} break;
|
|
@@ -2157,6 +2176,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|
|
2157
2176
|
n_tasks = 1;
|
|
2158
2177
|
} break;
|
|
2159
2178
|
case GGML_OP_COUNT_EQUAL:
|
|
2179
|
+
case GGML_OP_SOLVE_TRI:
|
|
2160
2180
|
{
|
|
2161
2181
|
n_tasks = n_threads;
|
|
2162
2182
|
} break;
|
|
@@ -2179,6 +2199,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|
|
2179
2199
|
case GGML_UNARY_OP_HARDSWISH:
|
|
2180
2200
|
case GGML_UNARY_OP_HARDSIGMOID:
|
|
2181
2201
|
case GGML_UNARY_OP_EXP:
|
|
2202
|
+
case GGML_UNARY_OP_SOFTPLUS:
|
|
2203
|
+
case GGML_UNARY_OP_EXPM1:
|
|
2182
2204
|
case GGML_UNARY_OP_FLOOR:
|
|
2183
2205
|
case GGML_UNARY_OP_CEIL:
|
|
2184
2206
|
case GGML_UNARY_OP_ROUND:
|
|
@@ -3274,6 +3296,13 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
|
|
|
3274
3296
|
__m128 y_vec = _mm_cvtph_ps(x_vec);
|
|
3275
3297
|
_mm_storeu_ps(y + i, y_vec);
|
|
3276
3298
|
}
|
|
3299
|
+
#elif defined(__riscv_zvfh)
|
|
3300
|
+
for (int vl; i < n; i += vl) {
|
|
3301
|
+
vl = __riscv_vsetvl_e16m1(n - i);
|
|
3302
|
+
vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl);
|
|
3303
|
+
vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl);
|
|
3304
|
+
__riscv_vse32_v_f32m2(&y[i], vy, vl);
|
|
3305
|
+
}
|
|
3277
3306
|
#endif
|
|
3278
3307
|
|
|
3279
3308
|
for (; i < n; ++i) {
|
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
|
|
5
5
|
// KleidiAI micro-kernels
|
|
6
6
|
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
|
7
|
+
#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
|
|
7
8
|
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
|
8
9
|
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
|
9
10
|
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
|
@@ -11,20 +12,31 @@
|
|
|
11
12
|
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
|
12
13
|
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
|
13
14
|
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
|
|
15
|
+
#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
|
|
16
|
+
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
|
|
17
|
+
#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
|
|
18
|
+
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
|
|
19
|
+
#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
|
|
20
|
+
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
|
|
14
21
|
|
|
15
22
|
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
|
16
23
|
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
|
17
24
|
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
|
|
18
25
|
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
|
26
|
+
#include "kai_lhs_quant_pack_qai8dxp_f32.h"
|
|
19
27
|
|
|
20
28
|
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
|
|
21
29
|
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
|
22
30
|
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
|
31
|
+
#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
|
|
23
32
|
|
|
24
33
|
#include "kai_common.h"
|
|
25
34
|
|
|
26
35
|
#include "simd-mappings.h"
|
|
27
36
|
|
|
37
|
+
#define GGML_COMMON_DECL_CPP
|
|
38
|
+
#include "ggml-common.h"
|
|
39
|
+
|
|
28
40
|
#include "kernels.h"
|
|
29
41
|
|
|
30
42
|
#define NELEMS(x) sizeof(x) / sizeof(*x)
|
|
@@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
|
|
55
67
|
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
|
56
68
|
}
|
|
57
69
|
|
|
70
|
+
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
|
|
71
|
+
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
|
72
|
+
const void* lhs, const void* rhs, void* dst,
|
|
73
|
+
size_t dst_stride_row, size_t dst_stride_col,
|
|
74
|
+
float clamp_min, float clamp_max) {
|
|
75
|
+
Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
|
76
|
+
}
|
|
77
|
+
|
|
58
78
|
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
|
|
59
79
|
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
|
|
60
80
|
return Fn(m, k, bl, mr, kr, sr);
|
|
@@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
|
|
|
93
113
|
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
|
|
94
114
|
}
|
|
95
115
|
|
|
116
|
+
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
|
|
117
|
+
static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
|
|
118
|
+
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
|
|
119
|
+
Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
|
|
120
|
+
}
|
|
121
|
+
|
|
96
122
|
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
|
97
123
|
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
|
|
98
124
|
return Fn(n, k, nr, kr, bl);
|
|
@@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
|
|
|
124
150
|
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
|
|
125
151
|
}
|
|
126
152
|
|
|
153
|
+
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
|
|
154
|
+
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
|
155
|
+
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
|
|
156
|
+
void* rhs_packed, size_t extra_bytes, const void* params) {
|
|
157
|
+
Fn(num_groups, n, k, nr, kr, sr,
|
|
158
|
+
static_cast<const int8_t*>(rhs),
|
|
159
|
+
static_cast<const float*>(bias),
|
|
160
|
+
static_cast<const float*>(scale),
|
|
161
|
+
rhs_packed, extra_bytes,
|
|
162
|
+
static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
|
|
163
|
+
}
|
|
164
|
+
|
|
127
165
|
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
|
|
128
166
|
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
|
129
167
|
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
|
|
@@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
|
|
|
213
251
|
GGML_UNUSED(kr);
|
|
214
252
|
}
|
|
215
253
|
|
|
254
|
+
static void dequantize_row_qsi8cxp(
|
|
255
|
+
const void *packed_data,
|
|
256
|
+
int32_t row_idx,
|
|
257
|
+
int64_t k,
|
|
258
|
+
float *out,
|
|
259
|
+
size_t nr,
|
|
260
|
+
size_t packed_row_stride,
|
|
261
|
+
size_t kr,
|
|
262
|
+
size_t bl,
|
|
263
|
+
size_t num_bytes_multiplier
|
|
264
|
+
) {
|
|
265
|
+
GGML_UNUSED(bl);
|
|
266
|
+
GGML_UNUSED(num_bytes_multiplier);
|
|
267
|
+
|
|
268
|
+
const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
|
|
269
|
+
const size_t group_idx = row_idx / nr;
|
|
270
|
+
const size_t row_in_group = row_idx % nr;
|
|
271
|
+
|
|
272
|
+
const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
|
|
273
|
+
const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
|
|
274
|
+
|
|
275
|
+
const size_t num_blocks = k_internal / kr;
|
|
276
|
+
|
|
277
|
+
for (size_t block = 0; block < num_blocks; ++block) {
|
|
278
|
+
const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
|
|
279
|
+
for (size_t i = 0; i < kr; ++i) {
|
|
280
|
+
const size_t k_idx = block * kr + i;
|
|
281
|
+
if (k_idx < (size_t) k) {
|
|
282
|
+
out[k_idx] = static_cast<float>(block_ptr[i]);
|
|
283
|
+
}
|
|
284
|
+
}
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
const uint8_t * sums_ptr = group_ptr + nr * k_internal;
|
|
288
|
+
GGML_UNUSED(sums_ptr);
|
|
289
|
+
|
|
290
|
+
const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
|
|
291
|
+
const float scale = scale_ptr[row_in_group];
|
|
292
|
+
|
|
293
|
+
if (scale == 0.0f) {
|
|
294
|
+
for (size_t i = 0; i < (size_t) k; ++i) {
|
|
295
|
+
out[i] = 0.0f;
|
|
296
|
+
}
|
|
297
|
+
return;
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
for (size_t i = 0; i < (size_t) k; ++i) {
|
|
301
|
+
out[i] *= scale;
|
|
302
|
+
}
|
|
303
|
+
}
|
|
304
|
+
|
|
216
305
|
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
217
306
|
#if defined(__ARM_FEATURE_SME)
|
|
218
307
|
{
|
|
@@ -548,6 +637,174 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
548
637
|
#endif
|
|
549
638
|
};
|
|
550
639
|
|
|
640
|
+
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
|
641
|
+
#if defined(__ARM_FEATURE_SME)
|
|
642
|
+
{
|
|
643
|
+
/* SME GEMM */
|
|
644
|
+
{
|
|
645
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
646
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
647
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
648
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
649
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
650
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
651
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
652
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
|
|
653
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
|
654
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
|
655
|
+
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
|
|
656
|
+
},
|
|
657
|
+
/* .gemm_lhs_info = */ {
|
|
658
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
659
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
660
|
+
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
661
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
662
|
+
},
|
|
663
|
+
/* SME GEMV */
|
|
664
|
+
{
|
|
665
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
666
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
667
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
668
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
669
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
670
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
671
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
672
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
|
|
673
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
|
674
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
|
675
|
+
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
|
|
676
|
+
},
|
|
677
|
+
/* .gemv_lhs_info = */ {
|
|
678
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
679
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
680
|
+
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
681
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
682
|
+
},
|
|
683
|
+
/* .rhs_info = */ {
|
|
684
|
+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
|
685
|
+
/* .to_float = */ dequantize_row_qsi8cxp,
|
|
686
|
+
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
687
|
+
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
688
|
+
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
689
|
+
},
|
|
690
|
+
/* .required_cpu = */ CPU_FEATURE_SME,
|
|
691
|
+
/* .lhs_type = */ GGML_TYPE_F32,
|
|
692
|
+
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
|
693
|
+
/* .op_type = */ GGML_TYPE_F32,
|
|
694
|
+
},
|
|
695
|
+
#endif
|
|
696
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
|
697
|
+
{
|
|
698
|
+
/* I8MM GEMM */
|
|
699
|
+
{
|
|
700
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
701
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
702
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
703
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
704
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
705
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
706
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
707
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
|
|
708
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
|
709
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
|
710
|
+
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
|
|
711
|
+
},
|
|
712
|
+
/* .gemm_lhs_info = */ {
|
|
713
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
714
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
715
|
+
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
716
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
717
|
+
},
|
|
718
|
+
/* I8MM GEMV (dotprod fallback) */
|
|
719
|
+
{
|
|
720
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
721
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
722
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
723
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
724
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
725
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
726
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
727
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
|
|
728
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
|
729
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
|
730
|
+
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
|
|
731
|
+
},
|
|
732
|
+
/* .gemv_lhs_info = */ {
|
|
733
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
734
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
735
|
+
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
736
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
737
|
+
},
|
|
738
|
+
/* .rhs_info = */ {
|
|
739
|
+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
|
740
|
+
/* .to_float = */ dequantize_row_qsi8cxp,
|
|
741
|
+
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
742
|
+
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
743
|
+
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
744
|
+
},
|
|
745
|
+
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
|
746
|
+
/* .lhs_type = */ GGML_TYPE_F32,
|
|
747
|
+
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
|
748
|
+
/* .op_type = */ GGML_TYPE_F32,
|
|
749
|
+
},
|
|
750
|
+
#endif
|
|
751
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
|
752
|
+
{
|
|
753
|
+
/* DOTPROD GEMM */
|
|
754
|
+
{
|
|
755
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
756
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
757
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
758
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
759
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
760
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
761
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
762
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
|
|
763
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
|
764
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
|
765
|
+
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
|
|
766
|
+
},
|
|
767
|
+
/* .gemm_lhs_info = */ {
|
|
768
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
769
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
770
|
+
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
771
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
772
|
+
},
|
|
773
|
+
/* DOTPROD GEMV */
|
|
774
|
+
{
|
|
775
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
776
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
777
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
778
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
779
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
780
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
781
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
782
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
|
|
783
|
+
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
|
784
|
+
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
|
785
|
+
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
|
|
786
|
+
},
|
|
787
|
+
/* .gemv_lhs_info = */ {
|
|
788
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
|
|
789
|
+
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
|
|
790
|
+
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
|
|
791
|
+
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
|
|
792
|
+
},
|
|
793
|
+
/* .rhs_info = */ {
|
|
794
|
+
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
|
|
795
|
+
/* .to_float = */ dequantize_row_qsi8cxp,
|
|
796
|
+
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
797
|
+
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
798
|
+
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
|
799
|
+
},
|
|
800
|
+
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
|
801
|
+
/* .lhs_type = */ GGML_TYPE_F32,
|
|
802
|
+
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
|
803
|
+
/* .op_type = */ GGML_TYPE_F32,
|
|
804
|
+
},
|
|
805
|
+
#endif
|
|
806
|
+
};
|
|
807
|
+
|
|
551
808
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
|
552
809
|
ggml_kleidiai_kernels * kernel = nullptr;
|
|
553
810
|
|
|
@@ -562,6 +819,17 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
|
|
562
819
|
break;
|
|
563
820
|
}
|
|
564
821
|
}
|
|
822
|
+
if (!kernel) {
|
|
823
|
+
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
|
|
824
|
+
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
|
|
825
|
+
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
|
|
826
|
+
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
|
|
827
|
+
gemm_gemv_kernels_q8[i].op_type == tensor->type) {
|
|
828
|
+
kernel = &gemm_gemv_kernels_q8[i];
|
|
829
|
+
break;
|
|
830
|
+
}
|
|
831
|
+
}
|
|
832
|
+
}
|
|
565
833
|
#endif
|
|
566
834
|
}
|
|
567
835
|
|
|
@@ -582,3 +850,18 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
|
|
|
582
850
|
|
|
583
851
|
return kernels;
|
|
584
852
|
}
|
|
853
|
+
|
|
854
|
+
// Pick the first Q8_0 kernel set whose CPU requirements are fully covered by `features`.
//
// The table `gemm_gemv_kernels_q8` is scanned in declaration order and the first
// entry whose `required_cpu` bits are all present in `features` is returned.
// Tensor types are not consulted here; only the CPU feature mask is matched.
//
// Returns nullptr when no entry matches, or when the translation unit is built
// without any of the SME / DOTPROD / MATMUL_INT8 feature macros (the scan is
// compiled out in that case).
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
    for (auto & candidate : gemm_gemv_kernels_q8) {
        const auto needed = candidate.required_cpu;
        // Every required bit must be set in the caller-supplied mask.
        if ((features & needed) == needed) {
            return &candidate;
        }
    }
#endif

    return nullptr;
}
|
|
@@ -87,3 +87,4 @@ struct ggml_kleidiai_kernels {
|
|
|
87
87
|
|
|
88
88
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
|
|
89
89
|
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
|
|
90
|
+
// Returns the first Q8_0 kernel set whose required CPU features are all present
// in `features`, or nullptr if none matches (including builds where none of the
// SME / DOTPROD / MATMUL_INT8 feature macros are defined).
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);
|