@fugood/llama.node 0.3.13 → 0.3.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +1 -1
- package/package.json +1 -1
- package/src/LlamaContext.cpp +98 -76
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +60 -10
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +3 -3
- package/src/llama.cpp/common/arg.cpp +112 -11
- package/src/llama.cpp/common/chat.cpp +960 -266
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +27 -171
- package/src/llama.cpp/common/common.h +27 -67
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +45 -7
- package/src/llama.cpp/common/speculative.cpp +6 -5
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +45 -7
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +2 -3
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
- package/src/llama.cpp/examples/main/main.cpp +73 -28
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +110 -67
- package/src/llama.cpp/examples/server/server.cpp +82 -87
- package/src/llama.cpp/examples/server/utils.hpp +94 -107
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +251 -142
- package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
- package/src/llama.cpp/ggml/include/ggml.h +5 -1
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1396 -386
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1432 -151
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +220 -116
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +168 -721
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +146 -42
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +13 -3
- package/src/llama.cpp/ggml/src/ggml.c +8 -3
- package/src/llama.cpp/include/llama.h +19 -5
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +21 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +182 -182
- package/src/llama.cpp/src/llama-grammar.h +12 -3
- package/src/llama.cpp/src/llama-kv-cache.h +1 -0
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-model.cpp +69 -5
- package/src/llama.cpp/src/llama-sampling.cpp +43 -10
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +147 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +166 -110
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +593 -395
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -55
- /package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
|
@@ -14,6 +14,10 @@
|
|
|
14
14
|
#include "ggml-cpu-hbm.h"
|
|
15
15
|
#endif
|
|
16
16
|
|
|
17
|
+
#ifdef GGML_USE_CPU_KLEIDIAI
|
|
18
|
+
#include "kleidiai/kleidiai.h"
|
|
19
|
+
#endif
|
|
20
|
+
|
|
17
21
|
#if defined(__APPLE__)
|
|
18
22
|
#include <sys/types.h>
|
|
19
23
|
#include <sys/sysctl.h>
|
|
@@ -39,6 +43,12 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
|
|
|
39
43
|
}
|
|
40
44
|
#endif
|
|
41
45
|
|
|
46
|
+
#ifdef GGML_USE_CPU_KLEIDIAI
|
|
47
|
+
if (ggml_backend_cpu_kleidiai_buffer_type()) {
|
|
48
|
+
bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
|
|
49
|
+
}
|
|
50
|
+
#endif
|
|
51
|
+
|
|
42
52
|
#ifdef GGML_USE_CPU_AARCH64
|
|
43
53
|
if (ggml_backend_cpu_aarch64_buffer_type()) {
|
|
44
54
|
bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
|
|
@@ -501,6 +511,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
|
|
501
511
|
if (ggml_cpu_has_fma()) {
|
|
502
512
|
features.push_back({ "FMA", "1" });
|
|
503
513
|
}
|
|
514
|
+
if (ggml_cpu_has_bmi2()) {
|
|
515
|
+
features.push_back({ "BMI2", "1" });
|
|
516
|
+
}
|
|
504
517
|
if (ggml_cpu_has_avx512()) {
|
|
505
518
|
features.push_back({ "AVX512", "1" });
|
|
506
519
|
}
|
|
@@ -538,12 +551,18 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
|
|
538
551
|
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
|
|
539
552
|
features.push_back({ "SVE_CNT", sve_cnt.c_str() });
|
|
540
553
|
}
|
|
554
|
+
if (ggml_cpu_has_sme()) {
|
|
555
|
+
features.push_back({ "SME", "1" });
|
|
556
|
+
}
|
|
541
557
|
if (ggml_cpu_has_riscv_v()) {
|
|
542
558
|
features.push_back({ "RISCV_V", "1" });
|
|
543
559
|
}
|
|
544
560
|
if (ggml_cpu_has_vsx()) {
|
|
545
561
|
features.push_back({ "VSX", "1" });
|
|
546
562
|
}
|
|
563
|
+
if (ggml_cpu_has_vxe()) {
|
|
564
|
+
features.push_back({ "VXE", "1" });
|
|
565
|
+
}
|
|
547
566
|
if (ggml_cpu_has_wasm_simd()) {
|
|
548
567
|
features.push_back({ "WASM_SIMD", "1" });
|
|
549
568
|
}
|
|
@@ -559,6 +578,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
|
|
559
578
|
#ifdef GGML_USE_OPENMP
|
|
560
579
|
features.push_back({ "OPENMP", "1" });
|
|
561
580
|
#endif
|
|
581
|
+
#ifdef GGML_USE_CPU_KLEIDIAI
|
|
582
|
+
features.push_back({ "KLEIDIAI", "1" });
|
|
583
|
+
#endif
|
|
562
584
|
#ifdef GGML_USE_CPU_AARCH64
|
|
563
585
|
features.push_back({ "AARCH64_REPACK", "1" });
|
|
564
586
|
#endif
|
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
2
|
+
// SPDX-License-Identifier: MIT
|
|
3
|
+
//
|
|
4
|
+
|
|
5
|
+
// KleidiAI micro-kernels
|
|
6
|
+
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
|
7
|
+
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
|
8
|
+
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
|
9
|
+
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
|
10
|
+
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
|
11
|
+
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
|
12
|
+
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
|
13
|
+
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
|
14
|
+
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
|
|
15
|
+
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
|
16
|
+
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
|
17
|
+
#include "kai_common.h"
|
|
18
|
+
|
|
19
|
+
#include "kernels.h"
|
|
20
|
+
|
|
21
|
+
// Element count of a true array object (never a pointer: the result would be meaningless).
// The argument and the whole expansion are parenthesized so the macro composes safely
// inside larger expressions such as `total / NELEMS(a)`.
#define NELEMS(x) (sizeof(x) / sizeof((x)[0]))
|
|
22
|
+
// Helpers that build the KleidiAI kernel table below. Every micro-kernel exposes the same
// family of entry points (kai_get_m_step_<kernel>, ..., kai_run_<kernel>), so one
// token-pasting macro describes a whole kernel_info. Initializers are positional, so the
// field order here must match struct kernel_info / lhs_packing_info / rhs_packing_info.
// NOTE: only /* */ comments may appear inside these multi-line macro bodies.
#define GGML_KAI_KERNEL_INFO(KERNEL)                                            \
    {                                                                           \
        /* .get_m_step            = */ kai_get_m_step_##KERNEL,                 \
        /* .get_n_step            = */ kai_get_n_step_##KERNEL,                 \
        /* .get_mr                = */ kai_get_mr_##KERNEL,                     \
        /* .get_nr                = */ kai_get_nr_##KERNEL,                     \
        /* .get_kr                = */ kai_get_kr_##KERNEL,                     \
        /* .get_sr                = */ kai_get_sr_##KERNEL,                     \
        /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_##KERNEL,      \
        /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_##KERNEL,      \
        /* .get_dst_offset        = */ kai_get_dst_offset_##KERNEL,             \
        /* .get_dst_size          = */ kai_get_dst_size_##KERNEL,               \
        /* .run_kernel            = */ kai_run_##KERNEL,                        \
    }

// Generic (non-_neon) LHS quantize+pack routines shared by the dotprod and i8mm entries.
#define GGML_KAI_LHS_PACK_QSI8D32P_F32                                                         \
    {                                                                                          \
        /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,         \
        /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,  \
        /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,    \
        /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,                    \
        /* .require_aligned_m_idx = */ false,                                                  \
    }

// Generic (non-_neon) RHS packing routines shared by the dotprod and i8mm entries.
#define GGML_KAI_RHS_PACK_QSI4C32PSCALEF16                                                        \
    {                                                                                             \
        /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,  \
        /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,                  \
    }

// NEON dotprod GEMM + GEMV pair.
#define GGML_KAI_DOTPROD_ENTRY                                                                               \
    {                                                                                                        \
        /* .gemm         = */ GGML_KAI_KERNEL_INFO(matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod), \
        /* .gemv         = */ GGML_KAI_KERNEL_INFO(matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod),  \
        /* .lhs_info     = */ GGML_KAI_LHS_PACK_QSI8D32P_F32,                                                  \
        /* .rhs_info     = */ GGML_KAI_RHS_PACK_QSI4C32PSCALEF16,                                              \
        /* .required_cpu = */ CPU_FEATURE_DOTPROD,                                                             \
    }

// NEON i8mm GEMM paired with the dotprod 1x8 GEMV; needs both features.
#define GGML_KAI_I8MM_ENTRY                                                                                    \
    {                                                                                                          \
        /* .gemm         = */ GGML_KAI_KERNEL_INFO(matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm),      \
        /* .gemv         = */ GGML_KAI_KERNEL_INFO(matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod), \
        /* .lhs_info     = */ GGML_KAI_LHS_PACK_QSI8D32P_F32,                                                    \
        /* .rhs_info     = */ GGML_KAI_RHS_PACK_QSI4C32PSCALEF16,                                                \
        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,                                            \
    }

// Candidate kernel sets, in priority order: ggml_kleidiai_select_kernels returns the first
// entry whose required_cpu bits are all available. Only entries the compiler can target
// (per the __ARM_FEATURE_* macros) are compiled in.
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
    {
        /* .gemm         = */ GGML_KAI_KERNEL_INFO(matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa),
        /* .gemv         = */ GGML_KAI_KERNEL_INFO(matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot),
        /* .lhs_info     = */ {
            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
            /* .require_aligned_m_idx = */ true,
        },
        /* .rhs_info     = */ {
            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
        },
        /* .required_cpu = */ CPU_FEATURE_SME,
    },
#endif
#if defined(__APPLE__)
    // Apple builds try the dotprod pair before the i8mm pair.
#  if defined(__ARM_FEATURE_DOTPROD)
    GGML_KAI_DOTPROD_ENTRY,
#  endif
#  if defined(__ARM_FEATURE_MATMUL_INT8)
    GGML_KAI_I8MM_ENTRY,
#  endif
#else
    // Other platforms try the i8mm pair before the dotprod pair.
#  if defined(__ARM_FEATURE_MATMUL_INT8)
    GGML_KAI_I8MM_ENTRY,
#  endif
#  if defined(__ARM_FEATURE_DOTPROD)
    GGML_KAI_DOTPROD_ENTRY,
#  endif
#endif
};

#undef GGML_KAI_I8MM_ENTRY
#undef GGML_KAI_DOTPROD_ENTRY
#undef GGML_KAI_RHS_PACK_QSI4C32PSCALEF16
#undef GGML_KAI_LHS_PACK_QSI8D32P_F32
#undef GGML_KAI_KERNEL_INFO
|
|
247
|
+
|
|
248
|
+
// Picks the highest-priority kernel set usable on this CPU.
//
// features: bitmask of cpu_feature flags available at runtime.
// Returns a pointer into the static gemm_gemv_kernels table: the first entry whose
// required_cpu bits are all contained in `features`, or nullptr if none qualifies.
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
    const size_t n_candidates = NELEMS(gemm_gemv_kernels);

    for (size_t idx = 0; idx < n_candidates; ++idx) {
        ggml_kleidiai_kernels * candidate = &gemm_gemv_kernels[idx];
        const cpu_feature needed = candidate->required_cpu;
        // Subset test: every required bit must also be present in `features`.
        if ((features & needed) == needed) {
            return candidate;
        }
    }

    return nullptr;
}
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
2
|
+
// SPDX-License-Identifier: MIT
|
|
3
|
+
//
|
|
4
|
+
|
|
5
|
+
#pragma once
|
|
6
|
+
|
|
7
|
+
// Bit flags naming the Arm CPU extensions a KleidiAI kernel set may depend on.
// Each flag is a distinct power of two so sets can be OR-combined and tested with `&`.
enum cpu_feature {
    CPU_FEATURE_NONE    = 0,
    CPU_FEATURE_DOTPROD = 1 << 0,
    CPU_FEATURE_I8MM    = 1 << 1,
    CPU_FEATURE_SVE     = 1 << 2,
    CPU_FEATURE_SME     = 1 << 3,
};
|
|
14
|
+
// Union of two feature sets. The enum is converted to int for the bitwise OR and the
// result converted back, since `|` on unscoped enums otherwise yields an int.
inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) {
    const int merged = static_cast<int>(lhs) | static_cast<int>(rhs);
    return static_cast<cpu_feature>(merged);
}

// In-place union: adds the bits of `rhs` to `lhs` and returns `lhs` for chaining.
inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
    lhs = lhs | rhs;
    return lhs;
}
|
|
21
|
+
|
|
22
|
+
// Function table for one KleidiAI matmul micro-kernel (the kai_*_matmul_clamp_f32_*
// entry points). The gemm_gemv_kernels table fills this struct positionally, so the
// member order must not change.
struct kernel_info {
    // Blocking/packing parameters reported by the kernel; see the KleidiAI micro-kernel
    // interface for their exact meaning.
    size_t (*get_m_step)(void);
    size_t (*get_n_step)(void);
    size_t (*get_mr)(void);
    size_t (*get_nr)(void);
    size_t (*get_kr)(void);
    size_t (*get_sr)(void);
    // Offset into the *packed* LHS buffer (wired to kai_get_lhs_packed_offset_* in the table).
    size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
    // Offset into the packed RHS buffer for the column block starting at n_idx.
    size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
    // Offset and size of the destination region.
    size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
    size_t (*get_dst_size)(size_t m, size_t n);
    // Runs the matmul from packed LHS/RHS into a float dst. Per the kernel names ("clamp"),
    // scalar_min/scalar_max are presumably output clamp bounds -- TODO confirm.
    void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
                       float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
};
|
|
36
|
+
|
|
37
|
+
// Routines that quantize and pack the float LHS operand for a kernel set
// (kai_*lhs_quant_pack_qsi8d32p_f32* in the kernel table). Filled positionally by the
// gemm_gemv_kernels table, so the member order must not change.
struct lhs_packing_info {
    // Offset of row m_idx in the unpacked float LHS.
    size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
    // Offset of row m_idx in the packed LHS buffer.
    size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
    // Bytes needed for the packed LHS.
    size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
    // Packs m rows of the float LHS starting at m_idx_start into lhs_packed.
    void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
                      size_t lhs_stride, void* lhs_packed);
    // Set true only for the SME entry in the kernel table. NOTE(review): presumably means
    // m_idx_start must be a multiple of mr when packing -- confirm at the call site.
    bool require_aligned_m_idx;
};
|
|
45
|
+
|
|
46
|
+
// Routines that pack the 4-bit quantized RHS (weights) for a kernel set. Filled positionally
// by the gemm_gemv_kernels table, so the member order must not change.
// NOTE(review): this header includes nothing itself; size_t, uint8_t and
// struct kai_rhs_pack_qs4cxs1s0_param must already be declared by the includer.
struct rhs_packing_info {
    // Bytes needed for the packed RHS. Note the argument order: bl comes last here,
    // unlike lhs_packing_info::packed_size.
    size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
    // Packs the raw uint8_t RHS (plus optional bias) into rhs_packed.
    void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
                      const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
};
|
|
51
|
+
|
|
52
|
+
// A complete KleidiAI kernel set: GEMM and GEMV micro-kernels plus the packing
// routines that produce their LHS/RHS layouts.
struct ggml_kleidiai_kernels {
    kernel_info gemm;   // chosen by kleidiai.cpp when src1->ne[1] != 1
    kernel_info gemv;   // chosen by kleidiai.cpp when src1->ne[1] == 1
    lhs_packing_info lhs_info;
    rhs_packing_info rhs_info;

    // CPU features this set needs; presumably matched against the detected features
    // by ggml_kleidiai_select_kernels — TODO confirm.
    cpu_feature required_cpu;
};
|
|
60
|
+
|
|
61
|
+
// Returns a kernel set usable with the given CPU feature mask.
// NOTE(review): kleidiai.cpp checks the stored result for NULL, which suggests NULL
// means "no matching set" — confirm in kernels.cpp.
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
2
|
+
// SPDX-License-Identifier: MIT
|
|
3
|
+
//
|
|
4
|
+
#include <arm_neon.h>
#include <assert.h>
#include <cfloat>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
#elif defined(__APPLE__)
#include <string_view>
#include <sys/sysctl.h>
#include <sys/types.h>
#elif defined(_WIN32)
#include <windows.h>
#include <excpt.h>
#endif

#include "kleidiai.h"

#include "ggml-cpu.h"
#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-threading.h"
#include "ggml-cpu-traits.h"

#include "kernels.h"

#include "kai_common.h"

#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
|
|
35
|
+
|
|
36
|
+
struct ggml_kleidiai_context {
|
|
37
|
+
ggml_kleidiai_kernels * kernels;
|
|
38
|
+
} static ctx = { NULL };
|
|
39
|
+
|
|
40
|
+
static void init_kleidiai_context(void) {
|
|
41
|
+
|
|
42
|
+
ggml_critical_section_start();
|
|
43
|
+
static bool initialized = false;
|
|
44
|
+
|
|
45
|
+
if (!initialized) {
|
|
46
|
+
initialized = true;
|
|
47
|
+
const char *env_var = getenv("GGML_KLEIDIAI_SME");
|
|
48
|
+
int sme_enabled = 0;
|
|
49
|
+
|
|
50
|
+
cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
|
|
51
|
+
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
|
|
52
|
+
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
|
53
|
+
|
|
54
|
+
if (env_var) {
|
|
55
|
+
sme_enabled = atoi(env_var);
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
if (sme_enabled != 0) {
|
|
59
|
+
features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
|
60
|
+
}
|
|
61
|
+
ctx.kernels = ggml_kleidiai_select_kernels(features);
|
|
62
|
+
}
|
|
63
|
+
ggml_critical_section_end();
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
|
67
|
+
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
|
|
68
|
+
return tensor->ne[dim];
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
namespace ggml::cpu::kleidiai {
|
|
72
|
+
class tensor_traits : public ggml::cpu::tensor_traits {
|
|
73
|
+
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
|
74
|
+
GGML_ASSERT(ctx.kernels);
|
|
75
|
+
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
|
|
76
|
+
|
|
77
|
+
size_t k = op->src[0]->ne[0];
|
|
78
|
+
size_t m = op->src[1]->ne[1];
|
|
79
|
+
|
|
80
|
+
size_t mr = kernel->get_mr();
|
|
81
|
+
size_t kr = kernel->get_kr();
|
|
82
|
+
size_t sr = kernel->get_sr();
|
|
83
|
+
|
|
84
|
+
size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
|
|
85
|
+
|
|
86
|
+
return true;
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
|
90
|
+
if (dst->op == GGML_OP_MUL_MAT) {
|
|
91
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
92
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
93
|
+
|
|
94
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
|
95
|
+
|
|
96
|
+
GGML_ASSERT(ctx.kernels);
|
|
97
|
+
kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
|
|
98
|
+
lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
|
|
99
|
+
|
|
100
|
+
GGML_ASSERT(kernel);
|
|
101
|
+
|
|
102
|
+
const int ith = params->ith;
|
|
103
|
+
const int nth = params->nth;
|
|
104
|
+
|
|
105
|
+
const size_t k = ne00;
|
|
106
|
+
const size_t m = ne11;
|
|
107
|
+
const size_t n = ne01;
|
|
108
|
+
|
|
109
|
+
const size_t n_step = kernel->get_n_step();
|
|
110
|
+
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
|
|
111
|
+
const size_t n_start = ith * num_n_per_thread;
|
|
112
|
+
|
|
113
|
+
size_t n_to_process = num_n_per_thread;
|
|
114
|
+
if ((n_start + n_to_process) > n) {
|
|
115
|
+
n_to_process = n - n_start;
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
|
|
119
|
+
uint8_t * lhs_packed = (uint8_t*)params->wdata;
|
|
120
|
+
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
|
|
121
|
+
|
|
122
|
+
size_t mr = kernel->get_mr();
|
|
123
|
+
size_t kr = kernel->get_kr();
|
|
124
|
+
size_t sr = kernel->get_sr();
|
|
125
|
+
|
|
126
|
+
// Calculate number of columns to be processed per thread
|
|
127
|
+
const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true;
|
|
128
|
+
const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
|
|
129
|
+
const size_t m_start = ith * num_m_per_thread;
|
|
130
|
+
size_t m_to_process = num_m_per_thread;
|
|
131
|
+
if ((m_start + m_to_process) > m) {
|
|
132
|
+
m_to_process = m - m_start;
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
if(m_start < m) {
|
|
136
|
+
// Transform LHS
|
|
137
|
+
const size_t src_stride = src1->nb[1];
|
|
138
|
+
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
|
|
139
|
+
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
|
|
140
|
+
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
|
141
|
+
|
|
142
|
+
lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
ggml_barrier(params->threadpool);
|
|
146
|
+
|
|
147
|
+
// Perform the operation
|
|
148
|
+
const size_t dst_stride = dst->nb[1];
|
|
149
|
+
const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
|
|
150
|
+
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
|
|
151
|
+
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
|
152
|
+
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
|
153
|
+
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
|
|
154
|
+
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
|
155
|
+
|
|
156
|
+
kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
|
|
157
|
+
dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
|
158
|
+
return true;
|
|
159
|
+
}
|
|
160
|
+
return false;
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
public:
|
|
164
|
+
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
|
|
165
|
+
GGML_ASSERT(ctx.kernels);
|
|
166
|
+
const size_t n = tensor->ne[1];
|
|
167
|
+
const size_t k = tensor->ne[0];
|
|
168
|
+
size_t nr = ctx.kernels->gemm.get_nr();
|
|
169
|
+
size_t kr = ctx.kernels->gemm.get_kr();
|
|
170
|
+
size_t sr = ctx.kernels->gemm.get_sr();
|
|
171
|
+
|
|
172
|
+
#ifndef NDEBUG
|
|
173
|
+
const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
|
|
174
|
+
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
|
|
175
|
+
#endif
|
|
176
|
+
struct kai_rhs_pack_qs4cxs1s0_param params;
|
|
177
|
+
params.lhs_zero_point = 1;
|
|
178
|
+
params.rhs_zero_point = 8;
|
|
179
|
+
ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, ¶ms);
|
|
180
|
+
|
|
181
|
+
return 0;
|
|
182
|
+
|
|
183
|
+
GGML_UNUSED(data_size);
|
|
184
|
+
}
|
|
185
|
+
};
|
|
186
|
+
|
|
187
|
+
static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
|
|
188
|
+
static tensor_traits traits;
|
|
189
|
+
return &traits;
|
|
190
|
+
}
|
|
191
|
+
} // namespace ggml::cpu::kleidiai
|
|
192
|
+
|
|
193
|
+
// init_tensor hook for CPU_KLEIDIAI buffers: attaches the KleidiAI tensor traits to
// tensor->extra so set_tensor and the CPU backend can find them later. Always succeeds.
GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    ggml::cpu::tensor_traits * traits = ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
    tensor->extra = static_cast<void *>(traits);

    GGML_UNUSED(buffer);
    return GGML_STATUS_SUCCESS;
}
|
|
199
|
+
|
|
200
|
+
// set_tensor hook for CPU_KLEIDIAI buffers. Only whole-tensor uploads (offset 0,
// size == ggml_nbytes) are accepted. The data is repacked into the KleidiAI layout
// through the traits that init_tensor stored in tensor->extra.
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                        const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));

    ggml::cpu::kleidiai::tensor_traits * traits = static_cast<ggml::cpu::kleidiai::tensor_traits *>(tensor->extra);
    const int OK = traits->repack(tensor, data, size);

    GGML_ASSERT(OK == 0);
    GGML_UNUSED(buffer);
}
|
|
211
|
+
|
|
212
|
+
// Human-readable name of the KleidiAI buffer type.
static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);
    static const char * const name = "CPU_KLEIDIAI";
    return name;
}
|
|
217
|
+
|
|
218
|
+
// Allocates a KleidiAI buffer. The memory comes from the plain CPU buffer type; the
// buffer is then re-tagged with `buft`, given the KleidiAI init_tensor/set_tensor
// hooks, and has get_tensor/cpy_tensor disabled (set to nullptr).
// Returns nullptr if the underlying CPU allocation fails.
static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
    if (!buf) {
        return nullptr;
    }

    buf->buft = buft;

    auto & iface = buf->iface;
    iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
    iface.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor;
    iface.get_tensor  = nullptr;
    iface.cpy_tensor  = nullptr;

    return buf;
}
|
|
232
|
+
|
|
233
|
+
// KleidiAI buffers use ggml's standard TENSOR_ALIGNMENT.
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);
    const size_t alignment = TENSOR_ALIGNMENT;
    return alignment;
}
|
|
238
|
+
|
|
239
|
+
namespace ggml::cpu::kleidiai {
|
|
240
|
+
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
241
|
+
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
|
242
|
+
if ( op->op == GGML_OP_MUL_MAT &&
|
|
243
|
+
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
|
244
|
+
op->src[0]->buffer &&
|
|
245
|
+
(ggml_n_dims(op->src[0]) == 2) &&
|
|
246
|
+
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
|
|
247
|
+
) {
|
|
248
|
+
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
|
249
|
+
return false;
|
|
250
|
+
}
|
|
251
|
+
if (op->src[1]->type == GGML_TYPE_F32 &&
|
|
252
|
+
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
|
|
253
|
+
return true;
|
|
254
|
+
}
|
|
255
|
+
}
|
|
256
|
+
return false;
|
|
257
|
+
}
|
|
258
|
+
|
|
259
|
+
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
|
260
|
+
if (op->op == GGML_OP_MUL_MAT) {
|
|
261
|
+
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
|
262
|
+
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
return nullptr;
|
|
266
|
+
}
|
|
267
|
+
};
|
|
268
|
+
} // namespace ggml::cpu::kleidiai
|
|
269
|
+
|
|
270
|
+
// Returns the singleton CPU_KLEIDIAI buffer type. Every call also invokes
// init_kleidiai_context(), which selects the kernel set on its first run only.
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
    // Named distinctly from the file-scope `ctx` (the kernel context) to avoid shadowing it.
    static ggml::cpu::kleidiai::extra_buffer_type extra_buft_ctx;
    static struct ggml_backend_buffer_type kleidiai_buft = {
        /* .iface    = */ {
            /* .get_name       = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
            /* .alloc_buffer   = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
            /* .get_max_size   = */ nullptr, // defaults to SIZE_MAX
            /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
            /* .is_host        = */ nullptr,
        },
        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context  = */ &extra_buft_ctx,
    };

    init_kleidiai_context();

    return &kleidiai_buft;
}
|