@fugood/llama.node 0.3.17 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +39 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +366 -19
- package/src/LlamaCompletionWorker.h +30 -10
- package/src/LlamaContext.cpp +213 -5
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
- package/src/llama.cpp/.github/workflows/build.yml +41 -762
- package/src/llama.cpp/.github/workflows/docker.yml +5 -2
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +12 -12
- package/src/llama.cpp/CMakeLists.txt +5 -17
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +31 -3
- package/src/llama.cpp/common/arg.cpp +48 -29
- package/src/llama.cpp/common/chat.cpp +128 -106
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +37 -1
- package/src/llama.cpp/common/common.h +18 -9
- package/src/llama.cpp/common/llguidance.cpp +1 -0
- package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/src/llama.cpp/common/minja/minja.hpp +69 -36
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +57 -50
- package/src/llama.cpp/examples/CMakeLists.txt +2 -23
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
- package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml.h +10 -7
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
- package/src/llama.cpp/ggml/src/ggml.c +29 -20
- package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/src/llama.cpp/include/llama.h +52 -11
- package/src/llama.cpp/requirements/requirements-all.txt +3 -3
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-adapter.cpp +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +3 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +17 -7
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +389 -501
- package/src/llama.cpp/src/llama-context.h +44 -32
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +20 -38
- package/src/llama.cpp/src/llama-graph.h +12 -8
- package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
- package/src/llama.cpp/src/llama-kv-cache.h +271 -85
- package/src/llama.cpp/src/llama-memory.h +11 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +316 -69
- package/src/llama.cpp/src/llama-model.h +8 -1
- package/src/llama.cpp/src/llama-quant.cpp +15 -13
- package/src/llama.cpp/src/llama-sampling.cpp +18 -6
- package/src/llama.cpp/src/llama-vocab.cpp +42 -4
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +14 -0
- package/src/llama.cpp/tests/CMakeLists.txt +10 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
- package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
- package/src/llama.cpp/tests/test-chat.cpp +3 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
- package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
- package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
- package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
- package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.h +0 -135
- package/src/llama.cpp/examples/llava/llava.cpp +0 -586
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/mtmd.h +0 -168
- package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
|
@@ -3,7 +3,9 @@
|
|
|
3
3
|
//
|
|
4
4
|
#include <arm_neon.h>
|
|
5
5
|
#include <assert.h>
|
|
6
|
+
#include <atomic>
|
|
6
7
|
#include <cfloat>
|
|
8
|
+
#include <stdexcept>
|
|
7
9
|
#include <stdint.h>
|
|
8
10
|
#include <string.h>
|
|
9
11
|
#if defined(__linux__)
|
|
@@ -34,8 +36,9 @@
|
|
|
34
36
|
#include "ggml-common.h"
|
|
35
37
|
|
|
36
38
|
struct ggml_kleidiai_context {
|
|
39
|
+
cpu_feature features;
|
|
37
40
|
ggml_kleidiai_kernels * kernels;
|
|
38
|
-
} static ctx = { NULL };
|
|
41
|
+
} static ctx = { CPU_FEATURE_NONE, NULL };
|
|
39
42
|
|
|
40
43
|
static void init_kleidiai_context(void) {
|
|
41
44
|
|
|
@@ -47,18 +50,18 @@ static void init_kleidiai_context(void) {
|
|
|
47
50
|
const char *env_var = getenv("GGML_KLEIDIAI_SME");
|
|
48
51
|
int sme_enabled = 0;
|
|
49
52
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
+
ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
|
|
54
|
+
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
|
|
55
|
+
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
|
53
56
|
|
|
54
57
|
if (env_var) {
|
|
55
58
|
sme_enabled = atoi(env_var);
|
|
56
59
|
}
|
|
57
60
|
|
|
58
61
|
if (sme_enabled != 0) {
|
|
59
|
-
features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
|
62
|
+
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
|
60
63
|
}
|
|
61
|
-
ctx.kernels =
|
|
64
|
+
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
|
62
65
|
}
|
|
63
66
|
ggml_critical_section_end();
|
|
64
67
|
}
|
|
@@ -68,95 +71,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
|
|
68
71
|
return tensor->ne[dim];
|
|
69
72
|
}
|
|
70
73
|
|
|
74
|
+
template<typename Ret, typename Variant, typename... Args>
|
|
75
|
+
static Ret variant_call(const Variant & var, Args&&... args) {
|
|
76
|
+
return std::visit([&](auto&& func) -> Ret {
|
|
77
|
+
if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
|
|
78
|
+
return func(std::forward<Args>(args)...);
|
|
79
|
+
} else {
|
|
80
|
+
throw std::runtime_error("Invalid function type in variant_call");
|
|
81
|
+
}
|
|
82
|
+
}, var);
|
|
83
|
+
}
|
|
84
|
+
|
|
71
85
|
namespace ggml::cpu::kleidiai {
|
|
86
|
+
|
|
87
|
+
static size_t round_down(size_t x, size_t y) {
|
|
88
|
+
return y == 0 ? x : x - (x % y);
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
|
|
92
|
+
size_t src_stride = rhs_stride / sizeof(uint16_t);
|
|
93
|
+
size_t dst_stride = n;
|
|
94
|
+
|
|
95
|
+
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
|
96
|
+
for (size_t n_idx = 0; n_idx < n; ++n_idx) {
|
|
97
|
+
uint16_t v = *(src + k_idx + n_idx * src_stride);
|
|
98
|
+
*(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
|
|
72
103
|
class tensor_traits : public ggml::cpu::tensor_traits {
|
|
73
104
|
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
|
74
|
-
|
|
75
|
-
|
|
105
|
+
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
|
106
|
+
GGML_ASSERT(kernels);
|
|
107
|
+
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
|
76
108
|
|
|
77
109
|
size_t k = op->src[0]->ne[0];
|
|
110
|
+
size_t n = op->src[0]->ne[1];
|
|
78
111
|
size_t m = op->src[1]->ne[1];
|
|
79
112
|
|
|
80
113
|
size_t mr = kernel->get_mr();
|
|
81
114
|
size_t kr = kernel->get_kr();
|
|
82
115
|
size_t sr = kernel->get_sr();
|
|
83
116
|
|
|
84
|
-
|
|
117
|
+
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
|
118
|
+
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
|
|
119
|
+
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
|
120
|
+
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
|
|
121
|
+
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
|
|
122
|
+
k * n * sizeof(float) + n * sizeof(float);
|
|
123
|
+
} else {
|
|
124
|
+
GGML_ASSERT(false);
|
|
125
|
+
}
|
|
85
126
|
|
|
86
127
|
return true;
|
|
87
128
|
}
|
|
88
129
|
|
|
130
|
+
|
|
89
131
|
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
|
90
132
|
if (dst->op == GGML_OP_MUL_MAT) {
|
|
91
|
-
|
|
92
|
-
|
|
133
|
+
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
|
134
|
+
return compute_forward_q4_0(params, dst);
|
|
135
|
+
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
|
136
|
+
return compute_forward_kv_cache(params, dst);
|
|
137
|
+
}
|
|
138
|
+
}
|
|
139
|
+
return false;
|
|
140
|
+
}
|
|
93
141
|
|
|
94
|
-
|
|
142
|
+
bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
143
|
+
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
|
|
95
144
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
|
|
145
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
146
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
99
147
|
|
|
100
|
-
|
|
148
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
|
101
149
|
|
|
102
|
-
|
|
103
|
-
|
|
150
|
+
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
|
151
|
+
GGML_ASSERT(kernels);
|
|
104
152
|
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
const size_t n = ne01;
|
|
153
|
+
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
|
154
|
+
GGML_ASSERT(kernel);
|
|
108
155
|
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
const size_t n_start = ith * num_n_per_thread;
|
|
156
|
+
const int nth = params->nth;
|
|
157
|
+
const int ith = params->ith;
|
|
112
158
|
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
159
|
+
const int64_t lhs_batch_size0 = ne12;
|
|
160
|
+
const int64_t rhs_batch_size0 = ne02;
|
|
161
|
+
const int64_t batch_size = rhs_batch_size0;
|
|
162
|
+
|
|
163
|
+
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
|
164
|
+
|
|
165
|
+
const int64_t m = ne11 * r;
|
|
166
|
+
const int64_t n = ne01;
|
|
167
|
+
const int64_t k = ne00;
|
|
168
|
+
|
|
169
|
+
const size_t lhs_stride = src1->nb[1];
|
|
170
|
+
const size_t rhs_stride = src0->nb[1];
|
|
171
|
+
const size_t dst_stride = dst->nb[1];
|
|
172
|
+
|
|
173
|
+
const int64_t mr = static_cast<int64_t>(kernel->get_mr());
|
|
174
|
+
const int64_t nr = static_cast<int64_t>(kernel->get_nr());
|
|
175
|
+
const int64_t kr = static_cast<int64_t>(kernel->get_kr());
|
|
176
|
+
const int64_t sr = static_cast<int64_t>(kernel->get_sr());
|
|
177
|
+
|
|
178
|
+
const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
|
|
179
|
+
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
|
|
180
|
+
const size_t kxn_size = k * n * sizeof(float);
|
|
181
|
+
const size_t bias_size = n * sizeof(float);
|
|
182
|
+
|
|
183
|
+
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
|
|
184
|
+
GGML_ASSERT(wsize_required <= params->wsize);
|
|
185
|
+
|
|
186
|
+
uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
|
|
187
|
+
uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
|
|
188
|
+
uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
|
|
189
|
+
uint8_t * bias = rhs_kxn + kxn_size;
|
|
190
|
+
|
|
191
|
+
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
|
192
|
+
const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
|
|
193
|
+
const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
|
|
194
|
+
uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
|
|
117
195
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
196
|
+
// LHS packing
|
|
197
|
+
{
|
|
198
|
+
const int64_t m_roundup_mr = kai_roundup(m, mr);
|
|
199
|
+
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
|
|
121
200
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
201
|
+
if (ith < num_threads) {
|
|
202
|
+
const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
|
|
203
|
+
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
|
|
125
204
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
205
|
+
const int64_t m_start = ith * num_m_per_thread0;
|
|
206
|
+
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
|
207
|
+
|
|
208
|
+
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
|
|
209
|
+
const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
|
|
210
|
+
|
|
211
|
+
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
|
|
212
|
+
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
|
|
213
|
+
|
|
214
|
+
variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
|
|
215
|
+
}
|
|
132
216
|
}
|
|
133
217
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
218
|
+
// RHS packing
|
|
219
|
+
if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
|
|
220
|
+
// First thread to reach this point handles RHS packing
|
|
221
|
+
memset(bias, 0, n * sizeof(float));
|
|
222
|
+
transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
|
|
223
|
+
reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
|
|
140
224
|
|
|
141
|
-
|
|
225
|
+
variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
|
|
226
|
+
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
|
|
142
227
|
}
|
|
143
228
|
|
|
144
229
|
ggml_barrier(params->threadpool);
|
|
145
230
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
231
|
+
first_to_arrive.clear(std::memory_order_release);
|
|
232
|
+
|
|
233
|
+
// Perform the matmul
|
|
234
|
+
{
|
|
235
|
+
const int64_t m_to_process = m;
|
|
236
|
+
const int64_t m_start = 0;
|
|
237
|
+
|
|
238
|
+
const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
|
|
239
|
+
const int64_t num_threads = KAI_MIN(n / n_step, nth);
|
|
240
|
+
|
|
241
|
+
if (ith < num_threads) {
|
|
242
|
+
const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
|
|
243
|
+
const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
|
|
244
|
+
|
|
245
|
+
const int64_t n_start = ith * num_n_per_thread0;
|
|
246
|
+
const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
|
|
247
|
+
|
|
248
|
+
const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
|
|
249
|
+
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
|
|
250
|
+
const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
|
|
251
|
+
|
|
252
|
+
const void * lhs_ptr = lhs_packed + lhs_packed_offset;
|
|
253
|
+
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
|
|
254
|
+
float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
|
|
255
|
+
|
|
256
|
+
variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
|
257
|
+
}
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
if (batch_idx != batch_size - 1) {
|
|
261
|
+
// This barrier is necessary when the batch size is larger than 1. While processing a batch,
|
|
262
|
+
// the work data buffer (params->wdata) is used as temporary storage which means that only
|
|
263
|
+
// a single batch can be processed at any given time. No barrier is needed for the last
|
|
264
|
+
// batch since GGML inserts a barrier between the execution of every operator.
|
|
265
|
+
ggml_barrier(params->threadpool);
|
|
266
|
+
}
|
|
158
267
|
}
|
|
159
|
-
|
|
268
|
+
|
|
269
|
+
return true;
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
|
273
|
+
const ggml_tensor * src0 = dst->src[0];
|
|
274
|
+
const ggml_tensor * src1 = dst->src[1];
|
|
275
|
+
|
|
276
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
|
277
|
+
|
|
278
|
+
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
|
279
|
+
GGML_ASSERT(kernels);
|
|
280
|
+
|
|
281
|
+
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
|
282
|
+
lhs_packing_info * lhs_info = &kernels->lhs_info;
|
|
283
|
+
|
|
284
|
+
GGML_ASSERT(kernel);
|
|
285
|
+
|
|
286
|
+
const int ith = params->ith;
|
|
287
|
+
const int nth = params->nth;
|
|
288
|
+
|
|
289
|
+
const size_t k = ne00;
|
|
290
|
+
const size_t m = ne11;
|
|
291
|
+
const size_t n = ne01;
|
|
292
|
+
|
|
293
|
+
size_t mr = kernel->get_mr();
|
|
294
|
+
size_t kr = kernel->get_kr();
|
|
295
|
+
size_t sr = kernel->get_sr();
|
|
296
|
+
|
|
297
|
+
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
|
|
298
|
+
uint8_t * lhs_packed = (uint8_t*)params->wdata;
|
|
299
|
+
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
|
|
300
|
+
|
|
301
|
+
const size_t n_step = kernel->get_n_step();
|
|
302
|
+
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
|
|
303
|
+
const size_t n_start = ith * num_n_per_thread;
|
|
304
|
+
|
|
305
|
+
size_t n_to_process = num_n_per_thread;
|
|
306
|
+
if ((n_start + n_to_process) > n) {
|
|
307
|
+
n_to_process = n - n_start;
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
// Calculate number of columns to be processed per thread
|
|
311
|
+
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
|
|
312
|
+
const size_t m_start = ith * num_m_per_thread;
|
|
313
|
+
size_t m_to_process = num_m_per_thread;
|
|
314
|
+
if ((m_start + m_to_process) > m) {
|
|
315
|
+
m_to_process = m - m_start;
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
if (m_start < m) {
|
|
319
|
+
// Transform LHS
|
|
320
|
+
const size_t src_stride = src1->nb[1];
|
|
321
|
+
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
|
322
|
+
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
|
|
323
|
+
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
|
324
|
+
|
|
325
|
+
variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
|
326
|
+
}
|
|
327
|
+
|
|
328
|
+
ggml_barrier(params->threadpool);
|
|
329
|
+
|
|
330
|
+
// Perform the operation
|
|
331
|
+
const size_t dst_stride = dst->nb[1];
|
|
332
|
+
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
|
|
333
|
+
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
|
|
334
|
+
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
|
335
|
+
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
|
336
|
+
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
|
|
337
|
+
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
|
338
|
+
|
|
339
|
+
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
|
|
340
|
+
sizeof(float), -FLT_MAX, FLT_MAX);
|
|
341
|
+
|
|
342
|
+
return true;
|
|
160
343
|
}
|
|
161
344
|
|
|
162
345
|
public:
|
|
@@ -169,13 +352,13 @@ public:
|
|
|
169
352
|
size_t sr = ctx.kernels->gemm.get_sr();
|
|
170
353
|
|
|
171
354
|
#ifndef NDEBUG
|
|
172
|
-
const size_t repacked_size = ctx.kernels->rhs_info.packed_size
|
|
355
|
+
const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
|
|
173
356
|
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
|
|
174
357
|
#endif
|
|
175
358
|
struct kai_rhs_pack_qs4cxs1s0_param params;
|
|
176
359
|
params.lhs_zero_point = 1;
|
|
177
360
|
params.rhs_zero_point = 8;
|
|
178
|
-
ctx.kernels->rhs_info.pack_func
|
|
361
|
+
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
|
|
179
362
|
|
|
180
363
|
return 0;
|
|
181
364
|
|
|
@@ -189,7 +372,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
|
|
|
189
372
|
}
|
|
190
373
|
} // namespace ggml::cpu::kleidiai
|
|
191
374
|
|
|
192
|
-
|
|
375
|
+
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
|
193
376
|
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
|
|
194
377
|
|
|
195
378
|
GGML_UNUSED(buffer);
|
|
@@ -238,12 +421,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
|
|
|
238
421
|
namespace ggml::cpu::kleidiai {
|
|
239
422
|
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
240
423
|
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
|
241
|
-
if (
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
) {
|
|
424
|
+
if (op->op == GGML_OP_MUL_MAT &&
|
|
425
|
+
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
|
426
|
+
op->src[0]->buffer &&
|
|
427
|
+
(ggml_n_dims(op->src[0]) == 2) &&
|
|
428
|
+
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
|
|
247
429
|
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
|
248
430
|
return false;
|
|
249
431
|
}
|
|
@@ -260,6 +442,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|
|
260
442
|
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
|
261
443
|
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
|
262
444
|
}
|
|
445
|
+
else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
|
|
446
|
+
op->src[0]->op == GGML_OP_VIEW &&
|
|
447
|
+
(op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
|
|
448
|
+
op->src[1]->ne[1] > 1) {
|
|
449
|
+
if ((op->src[0]->nb[0] != 2) ||
|
|
450
|
+
(op->src[1]->nb[0] != 4) ||
|
|
451
|
+
(op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
|
|
452
|
+
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
|
|
453
|
+
return nullptr;
|
|
454
|
+
}
|
|
455
|
+
|
|
456
|
+
return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
|
|
457
|
+
}
|
|
263
458
|
}
|
|
264
459
|
return nullptr;
|
|
265
460
|
}
|