@fugood/llama.node 0.3.12 → 0.3.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +2 -1
- package/package.json +1 -1
- package/src/LlamaCompletionWorker.cpp +14 -0
- package/src/LlamaContext.cpp +110 -79
- package/src/LlamaContext.h +1 -1
- package/src/common.hpp +1 -2
- package/src/llama.cpp/.github/workflows/build.yml +95 -13
- package/src/llama.cpp/.github/workflows/docker.yml +2 -0
- package/src/llama.cpp/.github/workflows/labeler.yml +1 -1
- package/src/llama.cpp/.github/workflows/server.yml +2 -0
- package/src/llama.cpp/common/CMakeLists.txt +23 -6
- package/src/llama.cpp/common/arg.cpp +292 -14
- package/src/llama.cpp/common/chat.cpp +1128 -315
- package/src/llama.cpp/common/chat.h +135 -0
- package/src/llama.cpp/common/common.cpp +27 -171
- package/src/llama.cpp/common/common.h +41 -73
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
- package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
- package/src/llama.cpp/common/llguidance.cpp +3 -3
- package/src/llama.cpp/common/log.cpp +1 -0
- package/src/llama.cpp/common/log.h +2 -1
- package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +21 -7
- package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +61 -14
- package/src/llama.cpp/common/ngram-cache.cpp +1 -0
- package/src/llama.cpp/common/sampling.cpp +93 -49
- package/src/llama.cpp/common/speculative.cpp +6 -5
- package/src/llama.cpp/common/speculative.h +1 -1
- package/src/llama.cpp/docs/build.md +47 -9
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +3 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +1 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +4 -4
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +6 -5
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +1 -1
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +373 -107
- package/src/llama.cpp/examples/llava/clip.h +19 -3
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
- package/src/llama.cpp/examples/llava/llava.cpp +4 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -0
- package/src/llama.cpp/examples/main/main.cpp +73 -28
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -0
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +1 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
- package/src/llama.cpp/examples/run/run.cpp +115 -79
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/server/httplib.h +381 -292
- package/src/llama.cpp/examples/server/server.cpp +134 -128
- package/src/llama.cpp/examples/server/utils.hpp +95 -106
- package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
- package/src/llama.cpp/examples/tts/tts.cpp +251 -142
- package/src/llama.cpp/ggml/CMakeLists.txt +13 -1
- package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cpu.h +4 -1
- package/src/llama.cpp/ggml/include/ggml-metal.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +0 -2
- package/src/llama.cpp/ggml/include/ggml.h +6 -2
- package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
- package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
- package/src/llama.cpp/ggml/src/ggml-common.h +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +132 -17
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +156 -11
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +2235 -641
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1572 -198
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +24 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +16 -3
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +235 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -2
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +246 -120
- package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +51 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +136 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +308 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +174 -728
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -77
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +949 -602
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +37 -3
- package/src/llama.cpp/ggml/src/ggml.c +9 -4
- package/src/llama.cpp/include/llama.h +32 -14
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +1 -0
- package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
- package/src/llama.cpp/requirements.txt +1 -0
- package/src/llama.cpp/src/llama-arch.cpp +21 -0
- package/src/llama.cpp/src/llama-arch.h +1 -0
- package/src/llama.cpp/src/llama-chat.cpp +1 -0
- package/src/llama.cpp/src/llama-grammar.cpp +183 -183
- package/src/llama.cpp/src/llama-grammar.h +13 -4
- package/src/llama.cpp/src/llama-impl.h +6 -6
- package/src/llama.cpp/src/llama-kv-cache.h +2 -1
- package/src/llama.cpp/src/llama-mmap.cpp +11 -1
- package/src/llama.cpp/src/llama-mmap.h +1 -0
- package/src/llama.cpp/src/llama-model.cpp +70 -6
- package/src/llama.cpp/src/llama-sampling.cpp +174 -67
- package/src/llama.cpp/src/llama-vocab.cpp +12 -0
- package/src/llama.cpp/src/llama.cpp +154 -5
- package/src/llama.cpp/src/unicode.cpp +9 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +171 -115
- package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
- package/src/llama.cpp/tests/test-chat.cpp +691 -325
- package/src/llama.cpp/tests/test-gguf.cpp +4 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
- package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
- package/src/llama.cpp/tests/test-sampling.cpp +15 -0
- package/src/llama.cpp/Sources/llama/llama.h +0 -4
- package/src/llama.cpp/common/chat.hpp +0 -52
|
@@ -14,6 +14,10 @@
|
|
|
14
14
|
#include "ggml-cpu-hbm.h"
|
|
15
15
|
#endif
|
|
16
16
|
|
|
17
|
+
#ifdef GGML_USE_CPU_KLEIDIAI
|
|
18
|
+
#include "kleidiai/kleidiai.h"
|
|
19
|
+
#endif
|
|
20
|
+
|
|
17
21
|
#if defined(__APPLE__)
|
|
18
22
|
#include <sys/types.h>
|
|
19
23
|
#include <sys/sysctl.h>
|
|
@@ -39,6 +43,12 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
|
|
|
39
43
|
}
|
|
40
44
|
#endif
|
|
41
45
|
|
|
46
|
+
#ifdef GGML_USE_CPU_KLEIDIAI
|
|
47
|
+
if (ggml_backend_cpu_kleidiai_buffer_type()) {
|
|
48
|
+
bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
|
|
49
|
+
}
|
|
50
|
+
#endif
|
|
51
|
+
|
|
42
52
|
#ifdef GGML_USE_CPU_AARCH64
|
|
43
53
|
if (ggml_backend_cpu_aarch64_buffer_type()) {
|
|
44
54
|
bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
|
|
@@ -284,14 +294,14 @@ struct ggml_backend_cpu_device_context {
|
|
|
284
294
|
&hKey) == ERROR_SUCCESS) {
|
|
285
295
|
DWORD cpu_brand_size = 0;
|
|
286
296
|
if (RegQueryValueExA(hKey,
|
|
287
|
-
|
|
297
|
+
"ProcessorNameString",
|
|
288
298
|
NULL,
|
|
289
299
|
NULL,
|
|
290
300
|
NULL,
|
|
291
301
|
&cpu_brand_size) == ERROR_SUCCESS) {
|
|
292
302
|
description.resize(cpu_brand_size);
|
|
293
303
|
if (RegQueryValueExA(hKey,
|
|
294
|
-
|
|
304
|
+
"ProcessorNameString",
|
|
295
305
|
NULL,
|
|
296
306
|
NULL,
|
|
297
307
|
(LPBYTE)&description[0], // NOLINT
|
|
@@ -501,6 +511,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
|
|
501
511
|
if (ggml_cpu_has_fma()) {
|
|
502
512
|
features.push_back({ "FMA", "1" });
|
|
503
513
|
}
|
|
514
|
+
if (ggml_cpu_has_bmi2()) {
|
|
515
|
+
features.push_back({ "BMI2", "1" });
|
|
516
|
+
}
|
|
504
517
|
if (ggml_cpu_has_avx512()) {
|
|
505
518
|
features.push_back({ "AVX512", "1" });
|
|
506
519
|
}
|
|
@@ -534,19 +547,22 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
|
|
534
547
|
if (ggml_cpu_has_dotprod()) {
|
|
535
548
|
features.push_back({ "DOTPROD", "1" });
|
|
536
549
|
}
|
|
537
|
-
if (ggml_cpu_has_matmul_int8()) {
|
|
538
|
-
features.push_back({ "MATMUL_INT8", "1" });
|
|
539
|
-
}
|
|
540
550
|
if (ggml_cpu_get_sve_cnt() > 0) {
|
|
541
551
|
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
|
|
542
552
|
features.push_back({ "SVE_CNT", sve_cnt.c_str() });
|
|
543
553
|
}
|
|
554
|
+
if (ggml_cpu_has_sme()) {
|
|
555
|
+
features.push_back({ "SME", "1" });
|
|
556
|
+
}
|
|
544
557
|
if (ggml_cpu_has_riscv_v()) {
|
|
545
558
|
features.push_back({ "RISCV_V", "1" });
|
|
546
559
|
}
|
|
547
560
|
if (ggml_cpu_has_vsx()) {
|
|
548
561
|
features.push_back({ "VSX", "1" });
|
|
549
562
|
}
|
|
563
|
+
if (ggml_cpu_has_vxe()) {
|
|
564
|
+
features.push_back({ "VXE", "1" });
|
|
565
|
+
}
|
|
550
566
|
if (ggml_cpu_has_wasm_simd()) {
|
|
551
567
|
features.push_back({ "WASM_SIMD", "1" });
|
|
552
568
|
}
|
|
@@ -562,6 +578,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
|
|
562
578
|
#ifdef GGML_USE_OPENMP
|
|
563
579
|
features.push_back({ "OPENMP", "1" });
|
|
564
580
|
#endif
|
|
581
|
+
#ifdef GGML_USE_CPU_KLEIDIAI
|
|
582
|
+
features.push_back({ "KLEIDIAI", "1" });
|
|
583
|
+
#endif
|
|
565
584
|
#ifdef GGML_USE_CPU_AARCH64
|
|
566
585
|
features.push_back({ "AARCH64_REPACK", "1" });
|
|
567
586
|
#endif
|
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
2
|
+
// SPDX-License-Identifier: MIT
|
|
3
|
+
//
|
|
4
|
+
|
|
5
|
+
// KleidiAI micro-kernels
|
|
6
|
+
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
|
|
7
|
+
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
|
8
|
+
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
|
|
9
|
+
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
|
|
10
|
+
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
|
|
11
|
+
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
|
|
12
|
+
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
|
|
13
|
+
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
|
|
14
|
+
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
|
|
15
|
+
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
|
|
16
|
+
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
|
|
17
|
+
#include "kai_common.h"
|
|
18
|
+
|
|
19
|
+
#include "kernels.h"
|
|
20
|
+
|
|
21
|
+
#define NELEMS(x) sizeof(x) / sizeof(*x)
|
|
22
|
+
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|
23
|
+
#if defined(__ARM_FEATURE_SME)
|
|
24
|
+
{
|
|
25
|
+
/* SME GEMM */
|
|
26
|
+
/* .kern_info = */ {
|
|
27
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
28
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
29
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
30
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
31
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
32
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
33
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
34
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
35
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
36
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
37
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
|
38
|
+
},
|
|
39
|
+
/* SME GEMV */
|
|
40
|
+
/* .kern_info = */ {
|
|
41
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
42
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
43
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
44
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
45
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
46
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
47
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
48
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
49
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
50
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
51
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
|
52
|
+
},
|
|
53
|
+
/* .lhs_info = */ {
|
|
54
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
55
|
+
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
56
|
+
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
|
57
|
+
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
|
58
|
+
/* .require_aligned_m_idx = */ true,
|
|
59
|
+
},
|
|
60
|
+
/* .rhs_info = */ {
|
|
61
|
+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
|
62
|
+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
|
63
|
+
},
|
|
64
|
+
/* .required_cpu = */ CPU_FEATURE_SME,
|
|
65
|
+
},
|
|
66
|
+
#endif
|
|
67
|
+
#if defined(__APPLE__)
|
|
68
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
|
69
|
+
{
|
|
70
|
+
/* DOTPROD GEMM */
|
|
71
|
+
/* .kern_info = */ {
|
|
72
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
73
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
74
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
75
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
76
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
77
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
78
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
79
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
80
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
81
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
82
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
83
|
+
},
|
|
84
|
+
/* DOTPROD GEMV */
|
|
85
|
+
/* .kern_info = */ {
|
|
86
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
87
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
88
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
89
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
90
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
91
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
92
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
93
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
94
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
95
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
96
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
97
|
+
},
|
|
98
|
+
/* .lhs_info = */ {
|
|
99
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
100
|
+
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
101
|
+
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
|
102
|
+
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
|
103
|
+
/* .require_aligned_m_idx = */ false,
|
|
104
|
+
},
|
|
105
|
+
/* .rhs_info = */ {
|
|
106
|
+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
107
|
+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
108
|
+
},
|
|
109
|
+
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
|
110
|
+
},
|
|
111
|
+
#endif
|
|
112
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
|
113
|
+
{
|
|
114
|
+
/* i8mm GEMM */
|
|
115
|
+
/* .kern_info = */ {
|
|
116
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
117
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
118
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
119
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
120
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
121
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
122
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
123
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
124
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
125
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
126
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
127
|
+
},
|
|
128
|
+
/* i8mm GEMV */
|
|
129
|
+
/* .kern_info = */ {
|
|
130
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
131
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
132
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
133
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
134
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
135
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
136
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
137
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
138
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
139
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
140
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
141
|
+
},
|
|
142
|
+
/* .lhs_info = */ {
|
|
143
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
144
|
+
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
145
|
+
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
|
146
|
+
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
|
147
|
+
/* .require_aligned_m_idx = */ false,
|
|
148
|
+
},
|
|
149
|
+
/* .rhs_info = */ {
|
|
150
|
+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
151
|
+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
152
|
+
},
|
|
153
|
+
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
|
154
|
+
},
|
|
155
|
+
#endif
|
|
156
|
+
#else
|
|
157
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
|
158
|
+
{
|
|
159
|
+
/* i8mm GEMM */
|
|
160
|
+
/* .kern_info = */ {
|
|
161
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
162
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
163
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
164
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
165
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
166
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
167
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
168
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
169
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
170
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
171
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
|
172
|
+
},
|
|
173
|
+
/* i8mm GEMV */
|
|
174
|
+
/* .kern_info = */ {
|
|
175
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
176
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
177
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
178
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
179
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
180
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
181
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
182
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
183
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
184
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
185
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
|
186
|
+
},
|
|
187
|
+
/* .lhs_info = */ {
|
|
188
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
189
|
+
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
190
|
+
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
|
191
|
+
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
|
192
|
+
/* .require_aligned_m_idx = */ false,
|
|
193
|
+
},
|
|
194
|
+
/* .rhs_info = */ {
|
|
195
|
+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
196
|
+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
197
|
+
},
|
|
198
|
+
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
|
199
|
+
},
|
|
200
|
+
#endif
|
|
201
|
+
#if defined(__ARM_FEATURE_DOTPROD)
|
|
202
|
+
{
|
|
203
|
+
/* DOTPROD GEMM */
|
|
204
|
+
/* .kern_info = */ {
|
|
205
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
206
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
207
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
208
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
209
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
210
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
211
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
212
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
213
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
214
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
215
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
|
216
|
+
},
|
|
217
|
+
/* DOTPROD GEMV */
|
|
218
|
+
/* .kern_info = */ {
|
|
219
|
+
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
220
|
+
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
221
|
+
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
222
|
+
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
223
|
+
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
224
|
+
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
225
|
+
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
226
|
+
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
227
|
+
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
228
|
+
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
229
|
+
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
|
230
|
+
},
|
|
231
|
+
/* .lhs_info = */ {
|
|
232
|
+
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
233
|
+
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
|
234
|
+
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
|
235
|
+
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
|
236
|
+
/* .require_aligned_m_idx = */ false,
|
|
237
|
+
},
|
|
238
|
+
/* .rhs_info = */ {
|
|
239
|
+
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
240
|
+
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
|
241
|
+
},
|
|
242
|
+
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
|
243
|
+
},
|
|
244
|
+
#endif
|
|
245
|
+
#endif
|
|
246
|
+
};
|
|
247
|
+
|
|
248
|
+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
|
|
249
|
+
ggml_kleidiai_kernels * kernels = nullptr;
|
|
250
|
+
|
|
251
|
+
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
|
252
|
+
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
|
253
|
+
kernels = &gemm_gemv_kernels[i];
|
|
254
|
+
break;
|
|
255
|
+
}
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
return kernels;
|
|
259
|
+
}
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
2
|
+
// SPDX-License-Identifier: MIT
|
|
3
|
+
//
|
|
4
|
+
|
|
5
|
+
#pragma once
|
|
6
|
+
|
|
7
|
+
enum cpu_feature {
|
|
8
|
+
CPU_FEATURE_NONE = 0,
|
|
9
|
+
CPU_FEATURE_DOTPROD = 1,
|
|
10
|
+
CPU_FEATURE_I8MM = 2,
|
|
11
|
+
CPU_FEATURE_SVE = 4,
|
|
12
|
+
CPU_FEATURE_SME = 8
|
|
13
|
+
};
|
|
14
|
+
inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
|
|
15
|
+
lhs = static_cast<cpu_feature>(lhs | rhs);
|
|
16
|
+
return lhs;
|
|
17
|
+
}
|
|
18
|
+
inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) {
|
|
19
|
+
return static_cast<cpu_feature>(static_cast<int>(lhs) | static_cast<int>(rhs));
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
struct kernel_info {
|
|
23
|
+
size_t (*get_m_step)(void);
|
|
24
|
+
size_t (*get_n_step)(void);
|
|
25
|
+
size_t (*get_mr)(void);
|
|
26
|
+
size_t (*get_nr)(void);
|
|
27
|
+
size_t (*get_kr)(void);
|
|
28
|
+
size_t (*get_sr)(void);
|
|
29
|
+
size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
|
|
30
|
+
size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
|
|
31
|
+
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
|
|
32
|
+
size_t (*get_dst_size)(size_t m, size_t n);
|
|
33
|
+
void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
|
|
34
|
+
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
|
|
35
|
+
};
|
|
36
|
+
|
|
37
|
+
struct lhs_packing_info {
|
|
38
|
+
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
|
|
39
|
+
size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
|
40
|
+
size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
|
41
|
+
void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
|
42
|
+
size_t lhs_stride, void* lhs_packed);
|
|
43
|
+
bool require_aligned_m_idx;
|
|
44
|
+
};
|
|
45
|
+
|
|
46
|
+
struct rhs_packing_info {
|
|
47
|
+
size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
|
|
48
|
+
void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
|
49
|
+
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
|
|
50
|
+
};
|
|
51
|
+
|
|
52
|
+
struct ggml_kleidiai_kernels {
|
|
53
|
+
kernel_info gemm;
|
|
54
|
+
kernel_info gemv;
|
|
55
|
+
lhs_packing_info lhs_info;
|
|
56
|
+
rhs_packing_info rhs_info;
|
|
57
|
+
|
|
58
|
+
cpu_feature required_cpu;
|
|
59
|
+
};
|
|
60
|
+
|
|
61
|
+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
|