@fugood/llama.node 0.3.6 → 0.3.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +17 -2
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +3 -1
- package/lib/index.js +16 -1
- package/lib/index.ts +16 -0
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +4 -3
- package/src/LlamaCompletionWorker.cpp +4 -2
- package/src/LlamaContext.cpp +61 -6
- package/src/LlamaContext.h +1 -0
- package/src/common.hpp +6 -11
- package/src/llama.cpp/.github/workflows/build.yml +19 -17
- package/src/llama.cpp/.github/workflows/docker.yml +77 -30
- package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +22 -3
- package/src/llama.cpp/CMakeLists.txt +49 -24
- package/src/llama.cpp/common/arg.cpp +82 -26
- package/src/llama.cpp/common/arg.h +3 -0
- package/src/llama.cpp/common/common.cpp +192 -72
- package/src/llama.cpp/common/common.h +51 -18
- package/src/llama.cpp/common/ngram-cache.cpp +12 -12
- package/src/llama.cpp/common/ngram-cache.h +2 -2
- package/src/llama.cpp/common/sampling.cpp +11 -6
- package/src/llama.cpp/common/speculative.cpp +18 -15
- package/src/llama.cpp/docs/build.md +2 -0
- package/src/llama.cpp/examples/batched/batched.cpp +9 -7
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
- package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
- package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
- package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
- package/src/llama.cpp/examples/infill/infill.cpp +23 -24
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
- package/src/llama.cpp/examples/llava/clip.cpp +4 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
- package/src/llama.cpp/examples/llava/llava.cpp +2 -2
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
- package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
- package/src/llama.cpp/examples/main/main.cpp +51 -29
- package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
- package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
- package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
- package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
- package/src/llama.cpp/examples/run/run.cpp +175 -61
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
- package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
- package/src/llama.cpp/examples/server/httplib.h +1295 -409
- package/src/llama.cpp/examples/server/server.cpp +387 -181
- package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
- package/src/llama.cpp/examples/server/utils.hpp +170 -58
- package/src/llama.cpp/examples/simple/simple.cpp +9 -8
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
- package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
- package/src/llama.cpp/examples/tts/tts.cpp +64 -23
- package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
- package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
- package/src/llama.cpp/ggml/include/ggml.h +36 -145
- package/src/llama.cpp/ggml/include/gguf.h +202 -0
- package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
- package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
- package/src/llama.cpp/ggml/src/ggml.c +117 -1327
- package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
- package/src/llama.cpp/include/llama-cpp.h +6 -1
- package/src/llama.cpp/include/llama.h +138 -75
- package/src/llama.cpp/src/CMakeLists.txt +13 -1
- package/src/llama.cpp/src/llama-adapter.cpp +347 -0
- package/src/llama.cpp/src/llama-adapter.h +74 -0
- package/src/llama.cpp/src/llama-arch.cpp +1487 -0
- package/src/llama.cpp/src/llama-arch.h +400 -0
- package/src/llama.cpp/src/llama-batch.cpp +368 -0
- package/src/llama.cpp/src/llama-batch.h +88 -0
- package/src/llama.cpp/src/llama-chat.cpp +578 -0
- package/src/llama.cpp/src/llama-chat.h +52 -0
- package/src/llama.cpp/src/llama-context.cpp +1775 -0
- package/src/llama.cpp/src/llama-context.h +128 -0
- package/src/llama.cpp/src/llama-cparams.cpp +1 -0
- package/src/llama.cpp/src/llama-cparams.h +37 -0
- package/src/llama.cpp/src/llama-grammar.cpp +5 -4
- package/src/llama.cpp/src/llama-grammar.h +3 -1
- package/src/llama.cpp/src/llama-hparams.cpp +71 -0
- package/src/llama.cpp/src/llama-hparams.h +139 -0
- package/src/llama.cpp/src/llama-impl.cpp +167 -0
- package/src/llama.cpp/src/llama-impl.h +16 -136
- package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
- package/src/llama.cpp/src/llama-kv-cache.h +218 -0
- package/src/llama.cpp/src/llama-mmap.cpp +589 -0
- package/src/llama.cpp/src/llama-mmap.h +67 -0
- package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
- package/src/llama.cpp/src/llama-model-loader.h +167 -0
- package/src/llama.cpp/src/llama-model.cpp +3953 -0
- package/src/llama.cpp/src/llama-model.h +370 -0
- package/src/llama.cpp/src/llama-quant.cpp +934 -0
- package/src/llama.cpp/src/llama-quant.h +1 -0
- package/src/llama.cpp/src/llama-sampling.cpp +147 -32
- package/src/llama.cpp/src/llama-sampling.h +3 -19
- package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
- package/src/llama.cpp/src/llama-vocab.h +97 -142
- package/src/llama.cpp/src/llama.cpp +7160 -20314
- package/src/llama.cpp/src/unicode.cpp +8 -3
- package/src/llama.cpp/tests/CMakeLists.txt +2 -0
- package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
- package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
- package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
- package/src/llama.cpp/tests/test-gguf.cpp +222 -187
- package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +0 -1
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt:

@@ -82,8 +82,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang")
             message(FATAL_ERROR "MSVC is not supported for ARM, use clang")
         else()
-            check_cxx_compiler_flag(-mfp16-format=ieee
-            if (NOT "${
+            check_cxx_compiler_flag(-mfp16-format=ieee GGML_COMPILER_SUPPORTS_FP16_FORMAT_I3E)
+            if (NOT "${GGML_COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
                 list(APPEND ARCH_FLAGS -mfp16-format=ieee)
             endif()

@@ -106,28 +106,28 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             message(STATUS "ARM -mcpu not found, -mcpu=native will be used")
         endif()

-        set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
         include(CheckCXXSourceRuns)

-        [… 7 removed lines not preserved in this diff view …]
+        function(check_arm_feature tag code)
+            set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
+            set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
+            check_cxx_source_runs(
+                "${code}"
+                GGML_MACHINE_SUPPORTS_${tag}
+            )
+            if (GGML_MACHINE_SUPPORTS_${tag})
+                set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
+            else()
+                set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
+            endif()
+            set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
+        endfunction()

-        [… 3 removed lines not preserved in this diff view …]
-            GGML_COMPILER_SUPPORT_I8MM)
-        if (GGML_COMPILER_SUPPORT_I8MM)
-            set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+i8mm")
-        endif()
+        check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
+        check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
+        check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")

-        set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
         list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
-
     else()
         if (GGML_CPU_ARM_ARCH)
             list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH})

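The new `check_arm_feature` helper turns each ISA probe into a compile-and-run test: if the snippet builds and executes under `${ARM_MCPU_FLAG}+<tag>`, the suffix `+<tag>` is kept for the `-mcpu` flag, otherwise `+no<tag>` is appended. Expanded into a standalone file, the dotprod probe looks roughly like this (a sketch; the file name and build line are illustrative):

```c
// probe-dotprod.c -- what check_arm_feature(dotprod ...) compiles and runs.
// Build sketch: cc -mcpu=native+dotprod probe-dotprod.c
#include <arm_neon.h>

int main(void) {
    int8x16_t a = vdupq_n_s8(1);
    int8x16_t b = vdupq_n_s8(2);
    // vdotq_s32 only compiles and executes when both the compiler and the CPU
    // support the dotprod extension; check_cxx_source_runs() exploits exactly that.
    volatile int32x4_t s = vdotq_s32(vdupq_n_s32(0), a, b);
    (void) s;
    return 0;
}
```
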
@@ -135,14 +135,20 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         endif()

         # show enabled features
+        if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
+            set(FEAT_INPUT_FILE "NUL")
+        else()
+            set(FEAT_INPUT_FILE "/dev/null")
+        endif()
+
         execute_process(
             COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E -
-            INPUT_FILE
+            INPUT_FILE ${FEAT_INPUT_FILE}
             OUTPUT_VARIABLE ARM_FEATURE
             RESULT_VARIABLE ARM_FEATURE_RESULT
         )
         if (ARM_FEATURE_RESULT)
-            message(
+            message(WARNING "Failed to get ARM features")
         else()
             foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC)
                 string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)

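The `execute_process` call dumps the compiler's predefined macros (`-dM -E`) under the chosen `ARCH_FLAGS`, and the `foreach` scans the dump for `__ARM_FEATURE_*` entries; the fix feeds the preprocessor from the host's null device, which is `NUL` on Windows rather than `/dev/null`. A small C program shows the same macros the loop looks for (a sketch covering exactly the five features named above):

```c
// Prints the __ARM_FEATURE_* macros that the CMake step greps out of the
// `cc ${ARCH_FLAGS} -dM -E -` preprocessor dump.
#include <stdio.h>

int main(void) {
#ifdef __ARM_FEATURE_DOTPROD
    puts("__ARM_FEATURE_DOTPROD 1");
#endif
#ifdef __ARM_FEATURE_SVE
    puts("__ARM_FEATURE_SVE 1");
#endif
#ifdef __ARM_FEATURE_MATMUL_INT8
    puts("__ARM_FEATURE_MATMUL_INT8 1");
#endif
#ifdef __ARM_FEATURE_FMA
    puts("__ARM_FEATURE_FMA 1");
#endif
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    puts("__ARM_FEATURE_FP16_VECTOR_ARITHMETIC 1");
#endif
    return 0;
}
```
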
@@ -209,8 +215,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 list(APPEND ARCH_DEFINITIONS GGML_SSE42)
             endif()
             if (GGML_AVX_VNNI)
-                [… removed line not preserved in this diff view …]
-                #list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
+                list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
             endif()
         else ()
             if (GGML_NATIVE)

@@ -317,6 +322,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
     target_compile_definitions(${GGML_CPU_NAME} PRIVATE ${ARCH_DEFINITIONS})

     if (GGML_BACKEND_DL)
+        if (GGML_NATIVE)
+            # the feature check relies on ARCH_DEFINITIONS, but it is not set with GGML_NATIVE
+            message(FATAL_ERROR "GGML_NATIVE is not compatible with GGML_BACKEND_DL, consider using GGML_CPU_ALL_VARIANTS")
+        endif()
+
         # The feature detection code is compiled as a separate target so that
         # it can be built without the architecture flags
         # Since multiple variants of the CPU backend may be included in the same

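With `GGML_BACKEND_DL`, each CPU variant is built as a separately loadable backend whose supported-feature score is derived from `ARCH_DEFINITIONS`; under `GGML_NATIVE` those definitions are never populated, hence the new hard error. A host process then loads and enumerates whatever variants are installed, roughly like this (a sketch against ggml's public registry API; exact headers and names may differ by version):

```c
#include <stdio.h>
#include "ggml-backend.h"

int main(void) {
    // Probe for backend shared libraries and load the best matching ones.
    ggml_backend_load_all();
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        printf("%s: %s\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
    }
    return 0;
}
```
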
package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp:

@@ -194,9 +194,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
 }

 static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
-#if defined(
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
     const __m256i zero = _mm256_setzero_si256();
     return _mm256_dpbusd_epi32(zero, ax, sy);
+#elif defined(__AVXVNNI__)
+    const __m256i zero = _mm256_setzero_si256();
+    return _mm256_dpbusd_avx_epi32(zero, ax, sy);
 #else
     // Perform multiplication and create 16-bit values
     const __m256i dot = _mm256_maddubs_epi16(ax, sy);

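`_mm256_dpbusd_epi32` requires AVX512-VNNI plus AVX512VL for the 256-bit form; the new `#elif` branch uses the VEX-encoded AVX-VNNI equivalent `_mm256_dpbusd_avx_epi32` on CPUs that ship VNNI without AVX-512. Per 32-bit lane, both accumulate a four-way unsigned-byte times signed-byte dot product in one instruction, and unlike the `maddubs` fallback they cannot saturate the intermediate 16-bit sums. A scalar reference (a sketch; the helper name is illustrative):

```c
#include <stdint.h>

// Reference for one dpbusd-style step over 256 bits: 8 lanes x 4 byte products.
static void dpbusd_ref(int32_t dst[8], const uint8_t ax[32], const int8_t sy[32]) {
    for (int lane = 0; lane < 8; ++lane) {
        int32_t acc = 0;  // the `zero` accumulator in the intrinsic paths above
        for (int k = 0; k < 4; ++k) {
            acc += (int32_t) ax[4*lane + k] * (int32_t) sy[4*lane + k];
        }
        dst[lane] = acc;
    }
}
```
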
@@ -564,21 +567,21 @@ static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c

 #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
-        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;

         for (int c = 0; c < nc; c += ncols_interleaved) {
-            const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
+            const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
             float32x4_t acc = vdupq_n_f32(0);
             for (int b = 0; b < nb; b++) {
-                int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
-                int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
-                int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
-                int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
-                float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);
+                int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);

                 int8x16_t a0 = vld1q_s8(a_ptr->qs);
                 int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
-                float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);

                 int32x4_t ret = vdupq_n_s32(0);

@@ -647,72 +650,52 @@ static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);

-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(
-    if (ggml_cpu_has_neon() &&
-        const
-        const void * a_ptr = vy;
-        float * res_ptr = s;
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;

-        [… old inline-assembly kernel; 39 removed lines not preserved in this diff view …]
-                ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n"
-                "fmul v16.4s, v16.4s, v25.4s\n"
-                ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n"
-                ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n"
-                ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n"
-                ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n"
-                "addp v29.4s, v29.4s, v26.4s\n"
-                "scvtf v29.4s, v29.4s, #0x4\n"
-                "fmla v0.4s, v29.4s, v16.4s\n"
-                "cbnz x22, 2b\n"
-                "sub %x[nc], %x[nc], #0x4\n"
-                "str q0, [%x[res_ptr], #0x0]\n"
-                "add %x[res_ptr], %x[res_ptr], #0x10\n"
-                "cbnz %x[nc], 1b\n"
-                : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
-                : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
-                : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23"
-            );
+        for (int c = 0; c < nc; c += ncols_interleaved) {
+            const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+            float32x4_t acc = vdupq_n_f32(0);
+            for (int b = 0; b < nb; b++) {
+                int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+                int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
+                int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
+                int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
+                int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+                int32x4_t ret0 = vdupq_n_s32(0);
+                int32x4_t ret1 = vdupq_n_s32(0);
+
+                ret0 = vdotq_s32(ret0, b0 << 4, a0);
+                ret1 = vdotq_s32(ret1, b1 << 4, a0);
+                ret0 = vdotq_s32(ret0, b2 << 4, a1);
+                ret1 = vdotq_s32(ret1, b3 << 4, a1);
+
+                ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
+                ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
+                ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
+                ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
+
+                int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+                acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
+                                vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+                a_ptr++;
+                b_ptr++;
+            }
+            vst1q_f32(s, acc);
+            s += ncols_interleaved;
+        }
         return;
     }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     float sumf[4];
     int sumi;

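The hand-written AArch64 inline assembly is replaced by equivalent NEON dotprod intrinsics; the old `scvtf v29.4s, v29.4s, #0x4` survives as `vcvtq_n_f32_s32(ret, 4)`. The `(b << 4)` and `(b & 0xf0U)` operands exploit the packed 4-bit layout: each nibble is lifted into a full signed byte scaled by 16, fed straight to `vdotq_s32`, and the fixed-point convert divides the accumulator by 2^4 afterwards. A scalar sketch of the trick, assuming the repacked blocks already store signed nibbles (the names here are illustrative):

```c
#include <stdint.h>

// One packed Q4 byte holds two signed 4-bit weights; both come out times 16.
static int32_t q4_pair_dot16(uint8_t packed, int8_t act_lo, int8_t act_hi) {
    int8_t lo16 = (int8_t) (uint8_t) (packed << 4);  // low nibble,  scaled by 16
    int8_t hi16 = (int8_t) (packed & 0xf0);          // high nibble, scaled by 16
    // Returns 16 * (w_lo*act_lo + w_hi*act_hi); the caller divides by 2^4 once.
    return (int32_t) lo16 * act_lo + (int32_t) hi16 * act_hi;
}
```
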
@@ -4186,6 +4169,8 @@ static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(g
     buffer->buft = buft;
     buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
     buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
+    buffer->iface.get_tensor = nullptr;
+    buffer->iface.cpy_tensor = nullptr;
     return buffer;
 }

package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c:

@@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
 }

 static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
-#if defined(
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
     return _mm256_cvtepi32_ps(summed_pairs);
+#elif defined(__AVXVNNI__)
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
 #else
     // Perform multiplication and create 16-bit values
     const __m256i dot = _mm256_maddubs_epi16(ax, sy);

@@ -5569,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r

     uint32_t utmp[4];

-#ifdef
+#ifdef __ARM_FEATURE_SVE
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
+        switch (vector_length) {
+            case 128:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);

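The new SVE path for `ggml_vec_dot_q4_K_q8_K` dispatches on the hardware vector length: the 128-bit case uses full predicates with split accumulators, while 256/512-bit machines use fixed-width predicate patterns such as `svptrue_pat_b8(SV_VL32)` so each load still consumes exactly 32 quantized bytes. The length query that drives the `switch` reduces to this (a sketch; the diff's `ggml_cpu_get_sve_cnt()*8` computes the same value, assuming the count is in bytes):

```c
#include <arm_sve.h>

// SVE registers have a hardware-chosen width; query it once at runtime.
static int sve_vector_length_bits(void) {
    return (int) svcntb() * 8;  // svcntb() = register width in bytes
}
```
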