@fugood/llama.node 0.3.17 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +39 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +366 -19
- package/src/LlamaCompletionWorker.h +30 -10
- package/src/LlamaContext.cpp +213 -5
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
- package/src/llama.cpp/.github/workflows/build.yml +41 -762
- package/src/llama.cpp/.github/workflows/docker.yml +5 -2
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +12 -12
- package/src/llama.cpp/CMakeLists.txt +5 -17
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +31 -3
- package/src/llama.cpp/common/arg.cpp +48 -29
- package/src/llama.cpp/common/chat.cpp +128 -106
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +37 -1
- package/src/llama.cpp/common/common.h +18 -9
- package/src/llama.cpp/common/llguidance.cpp +1 -0
- package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/src/llama.cpp/common/minja/minja.hpp +69 -36
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +57 -50
- package/src/llama.cpp/examples/CMakeLists.txt +2 -23
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
- package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml.h +10 -7
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
- package/src/llama.cpp/ggml/src/ggml.c +29 -20
- package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/src/llama.cpp/include/llama.h +52 -11
- package/src/llama.cpp/requirements/requirements-all.txt +3 -3
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-adapter.cpp +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +3 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +17 -7
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +389 -501
- package/src/llama.cpp/src/llama-context.h +44 -32
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +20 -38
- package/src/llama.cpp/src/llama-graph.h +12 -8
- package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
- package/src/llama.cpp/src/llama-kv-cache.h +271 -85
- package/src/llama.cpp/src/llama-memory.h +11 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +316 -69
- package/src/llama.cpp/src/llama-model.h +8 -1
- package/src/llama.cpp/src/llama-quant.cpp +15 -13
- package/src/llama.cpp/src/llama-sampling.cpp +18 -6
- package/src/llama.cpp/src/llama-vocab.cpp +42 -4
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +14 -0
- package/src/llama.cpp/tests/CMakeLists.txt +10 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
- package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
- package/src/llama.cpp/tests/test-chat.cpp +3 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
- package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
- package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
- package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
- package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.h +0 -135
- package/src/llama.cpp/examples/llava/llava.cpp +0 -586
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/mtmd.h +0 -168
- package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
|
@@ -4,16 +4,22 @@
|
|
|
4
4
|
|
|
5
5
|
// KleidiAI micro-kernels
|
|
6
6
|
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
|
7
|
-
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
|
8
|
-
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
|
9
|
-
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
|
10
|
-
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
|
11
7
|
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
|
12
8
|
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
|
13
9
|
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
|
14
10
|
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
|
|
15
11
|
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
|
16
12
|
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
|
13
|
+
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
|
|
14
|
+
|
|
15
|
+
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
|
16
|
+
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
|
17
|
+
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
|
18
|
+
|
|
19
|
+
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
|
|
20
|
+
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
|
21
|
+
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
|
22
|
+
|
|
17
23
|
#include "kai_common.h"
|
|
18
24
|
|
|
19
25
|
#include "kernels.h"
|
|
@@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
61
67
|
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
|
62
68
|
},
|
|
63
69
|
/* .required_cpu = */ CPU_FEATURE_SME,
|
|
70
|
+
/* .lhs_type = */ GGML_TYPE_F32,
|
|
71
|
+
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
|
72
|
+
/* .op_type = */ GGML_TYPE_F32,
|
|
73
|
+
},
|
|
74
|
+
{
|
|
75
|
+
/* SME GEMM */
|
|
76
|
+
/* .kern_info = */ {
|
|
77
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
78
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
79
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
80
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
81
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
82
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
83
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
84
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
85
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
86
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
87
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
88
|
+
},
|
|
89
|
+
/* SME GEMV */
|
|
90
|
+
/* .kern_info = */ {
|
|
91
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
92
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
93
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
94
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
95
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
96
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
97
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
98
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
99
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
100
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
101
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
|
102
|
+
},
|
|
103
|
+
/* .lhs_info = */ {
|
|
104
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
|
105
|
+
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
|
106
|
+
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
|
|
107
|
+
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
|
|
108
|
+
},
|
|
109
|
+
/* .rhs_info = */ {
|
|
110
|
+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
|
111
|
+
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
|
112
|
+
},
|
|
113
|
+
/* .required_cpu = */ CPU_FEATURE_SME,
|
|
114
|
+
/* .lhs_type = */ GGML_TYPE_F32,
|
|
115
|
+
/* .rhs_type = */ GGML_TYPE_F16,
|
|
116
|
+
/* .op_type = */ GGML_TYPE_F32,
|
|
64
117
|
},
|
|
65
118
|
#endif
|
|
66
119
|
#if defined(__APPLE__)
|
|
@@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
105
158
|
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
106
159
|
},
|
|
107
160
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
|
161
|
+
/* .lhs_type = */ GGML_TYPE_F32,
|
|
162
|
+
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
|
163
|
+
/* .op_type = */ GGML_TYPE_F32,
|
|
108
164
|
},
|
|
109
165
|
#endif
|
|
110
166
|
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
|
@@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
148
204
|
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
149
205
|
},
|
|
150
206
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
|
207
|
+
/* .lhs_type = */ GGML_TYPE_F32,
|
|
208
|
+
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
|
209
|
+
/* .op_type = */ GGML_TYPE_F32,
|
|
151
210
|
},
|
|
152
211
|
#endif
|
|
153
212
|
#else
|
|
@@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
192
251
|
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
193
252
|
},
|
|
194
253
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
|
254
|
+
/* .lhs_type = */ GGML_TYPE_F32,
|
|
255
|
+
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
|
256
|
+
/* .op_type = */ GGML_TYPE_F32,
|
|
195
257
|
},
|
|
196
258
|
#endif
|
|
197
259
|
#if defined(__ARM_FEATURE_DOTPROD)
|
|
@@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
|
235
297
|
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
236
298
|
},
|
|
237
299
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
|
300
|
+
/* .lhs_type = */ GGML_TYPE_F32,
|
|
301
|
+
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
|
302
|
+
/* .op_type = */ GGML_TYPE_F32,
|
|
238
303
|
},
|
|
239
304
|
#endif
|
|
240
305
|
#endif
|
|
241
306
|
};
|
|
242
307
|
|
|
243
|
-
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature
|
|
308
|
+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
|
|
309
|
+
ggml_kleidiai_kernels * kernel = nullptr;
|
|
310
|
+
|
|
311
|
+
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
|
312
|
+
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
|
313
|
+
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
|
314
|
+
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
|
315
|
+
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
|
|
316
|
+
gemm_gemv_kernels[i].op_type == tensor->type) {
|
|
317
|
+
kernel = &gemm_gemv_kernels[i];
|
|
318
|
+
break;
|
|
319
|
+
}
|
|
320
|
+
}
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
return kernel;
|
|
324
|
+
}
|
|
325
|
+
|
|
326
|
+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
|
|
244
327
|
ggml_kleidiai_kernels * kernels = nullptr;
|
|
245
328
|
|
|
246
329
|
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
|
@@ -4,6 +4,10 @@
|
|
|
4
4
|
|
|
5
5
|
#pragma once
|
|
6
6
|
|
|
7
|
+
#include <functional>
|
|
8
|
+
#include <variant>
|
|
9
|
+
#include "ggml.h"
|
|
10
|
+
|
|
7
11
|
enum cpu_feature {
|
|
8
12
|
CPU_FEATURE_NONE = 0,
|
|
9
13
|
CPU_FEATURE_DOTPROD = 1,
|
|
@@ -26,26 +30,53 @@ struct kernel_info {
|
|
|
26
30
|
size_t (*get_nr)(void);
|
|
27
31
|
size_t (*get_kr)(void);
|
|
28
32
|
size_t (*get_sr)(void);
|
|
29
|
-
|
|
30
|
-
|
|
33
|
+
std::variant<
|
|
34
|
+
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
|
|
35
|
+
std::function<size_t(size_t m_idx, size_t k)>
|
|
36
|
+
> get_lhs_offset;
|
|
37
|
+
std::variant<
|
|
38
|
+
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
|
|
39
|
+
std::function<size_t(size_t n_idx, size_t k)>
|
|
40
|
+
> get_rhs_packed_offset;
|
|
31
41
|
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
|
|
32
42
|
size_t (*get_dst_size)(size_t m, size_t n);
|
|
33
|
-
|
|
34
|
-
|
|
43
|
+
std::variant<
|
|
44
|
+
std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
|
|
45
|
+
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
|
|
46
|
+
std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
|
|
47
|
+
size_t dst_stride_col, float clamp_min, float clamp_max)>
|
|
48
|
+
> run_kernel;
|
|
35
49
|
};
|
|
36
50
|
|
|
37
51
|
struct lhs_packing_info {
|
|
38
52
|
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
53
|
+
std::variant<
|
|
54
|
+
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
|
|
55
|
+
std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
|
|
56
|
+
> get_packed_offset;
|
|
57
|
+
std::variant<
|
|
58
|
+
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
|
|
59
|
+
std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
|
|
60
|
+
> packed_size;
|
|
61
|
+
std::variant<
|
|
62
|
+
std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
|
63
|
+
size_t lhs_stride, void* lhs_packed)>,
|
|
64
|
+
std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
|
|
65
|
+
void* lhs_packed)>
|
|
66
|
+
> pack_func;
|
|
43
67
|
};
|
|
44
68
|
|
|
45
69
|
struct rhs_packing_info {
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
70
|
+
std::variant<
|
|
71
|
+
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
|
|
72
|
+
std::function<size_t(size_t n, size_t k)>
|
|
73
|
+
> packed_size;
|
|
74
|
+
std::variant<
|
|
75
|
+
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
|
76
|
+
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
|
|
77
|
+
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
|
|
78
|
+
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
|
|
79
|
+
> pack_func;
|
|
49
80
|
};
|
|
50
81
|
|
|
51
82
|
struct ggml_kleidiai_kernels {
|
|
@@ -55,6 +86,10 @@ struct ggml_kleidiai_kernels {
|
|
|
55
86
|
rhs_packing_info rhs_info;
|
|
56
87
|
|
|
57
88
|
cpu_feature required_cpu;
|
|
89
|
+
ggml_type lhs_type;
|
|
90
|
+
ggml_type rhs_type;
|
|
91
|
+
ggml_type op_type;
|
|
58
92
|
};
|
|
59
93
|
|
|
60
|
-
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
|
|
94
|
+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
|
|
95
|
+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
|