@fugood/llama.node 0.3.16 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c

@@ -20,12 +20,6 @@
 #define GROUP_MAX_EPS_IQ1_M 1e-7f
 #define GROUP_MAX_EPS_IQ1_S 1e-12f
 
-#if defined(_MSC_VER)
-// disable "possible loss of data" to avoid warnings for hundreds of casts
-// we should just be careful :)
-#pragma warning(disable: 4244 4267)
-#endif
-
 #define UNUSED GGML_UNUSED
 
 // some compilers don't provide _mm256_set_m128i, e.g. gcc 7
@@ -891,15 +885,15 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
     }
 #elif defined(__riscv_v_intrinsic)
 
-    size_t vl =
+    size_t vl = QK8_0;
 
     for (int i = 0; i < nb; i++) {
         // load elements
-
+        vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);
 
-
+        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
         vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
-        vfloat32m1_t vmax =
+        vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
         float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
 
         const float d = amax / ((1 << 7) - 1);
@@ -907,14 +901,14 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
 
         y[i].d = GGML_FP32_TO_FP16(d);
 
-
+        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);
 
         // convert to integer
-
-
+        vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+        vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);
 
         // store result
-
+        __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
     }
 
 #elif defined(__POWER9_VECTOR__)
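For readers comparing the old and new RISC-V paths above: both implement the standard Q8_0/Q8_1 block quantization, only the vector register grouping and intrinsic calls changed. A minimal scalar sketch of what one block computes (an illustrative helper, not the actual ggml reference function):

    #include <math.h>
    #include <stdint.h>

    // Scalar sketch of Q8_0 quantization for one block of 32 floats:
    // scale d = max|x| / 127, each value is rounded to a signed 8-bit integer.
    static void quantize_block_q8_0_sketch(const float * x, float * d_out, int8_t * qs) {
        float amax = 0.0f;
        for (int j = 0; j < 32; ++j) {
            const float ax = fabsf(x[j]);
            if (ax > amax) amax = ax;
        }
        const float d  = amax / ((1 << 7) - 1);
        const float id = d != 0.0f ? 1.0f / d : 0.0f;
        for (int j = 0; j < 32; ++j) {
            qs[j] = (int8_t) roundf(x[j] * id);
        }
        *d_out = d;   // stored as fp16 in the real block layout
    }

The Q8_1 variant in the next hunks additionally stores the sum of the quantized values, which is what the `vwredsum` into `y[i].s` computes.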
@@ -1229,15 +1223,15 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
     }
 #elif defined(__riscv_v_intrinsic)
 
-    size_t vl =
+    size_t vl = QK8_1;
 
     for (int i = 0; i < nb; i++) {
         // load elements
-
+        vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);
 
-
+        vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
         vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
-        vfloat32m1_t vmax =
+        vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
         float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
 
         const float d = amax / ((1 << 7) - 1);
@@ -1245,18 +1239,18 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
 
         y[i].d = GGML_FP32_TO_FP16(d);
 
-
+        vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);
 
         // convert to integer
-
-
+        vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+        vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);
 
         // store result
-
+        __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
 
         // compute sum for y[i].s
         vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
-        vint16m1_t vwrs =
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);
 
         // set y[i].s
         int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
@@ -2391,33 +2385,31 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
 #elif defined(__riscv_v_intrinsic)
-    size_t vl =
+    size_t vl = qk / 2;
 
     for (; ib < nb; ++ib) {
         // load elements
-
+        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);
 
-
-
+        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);
 
         // mask and store lower part of x, and then upper part
-
-
+        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
 
-
-
+        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
 
         // subtract offset
-
-
+        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
+        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
 
-
-
+        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
 
         int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
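The rewritten Q4_0 x Q8_0 kernel above still computes the same per-block arithmetic as before: unpack two 4-bit values per byte, recenter by 8, multiply with the int8 values of the Q8_0 block, and scale the integer sum by both block scales. A scalar sketch of one block (illustrative names, assuming 32 values per block):

    #include <stdint.h>

    // Scalar sketch of one q4_0 x q8_0 block dot product.
    // xqs: 16 bytes holding 32 4-bit values; yqs: 32 int8 values.
    // xd, yd: the two block scales already converted to float.
    static float dot_q4_0_q8_0_block_sketch(const uint8_t * xqs, const int8_t * yqs,
                                            float xd, float yd) {
        int sumi = 0;
        for (int j = 0; j < 16; ++j) {
            const int v0 = (xqs[j] & 0x0F) - 8;   // low nibble  -> elements 0..15
            const int v1 = (xqs[j] >>   4) - 8;   // high nibble -> elements 16..31
            sumi += v0 * yqs[j] + v1 * yqs[j + 16];
        }
        return xd * yd * (float) sumi;
    }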
@@ -2783,29 +2775,27 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = hsum_float_8(acc) + summs;
 #elif defined(__riscv_v_intrinsic)
-    size_t vl =
+    size_t vl = qk / 2;
 
     for (; ib < nb; ++ib) {
         // load elements
-
+        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);
 
-
-
+        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);
 
         // mask and store lower part of x, and then upper part
-
-
+        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
 
-
-
+        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
 
-
-
+        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+        vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);
 
         vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
 
         int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
@@ -3132,65 +3122,33 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = hsum_float_8(acc);
 #elif defined(__riscv_v_intrinsic)
-
-
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
-    // These temporary registers are for masking and shift operations
-    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
-    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
-
-    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
-    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+    size_t vl;
+    size_t vlenb = __riscv_vlenb();
 
     for (; ib < nb; ++ib) {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
-        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
-        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
-        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
-        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
-        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
-
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
-        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-
-        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+        vl = qk / 2;
+        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+        vint8m2_t v0c;
+        if (vlenb == 16) {
+            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+        } else {
+            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+        }
+
+        vl = qk;
+        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+        qh = __riscv_vmnand_mm_b4(qh, qh, vl);
+        vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
+
+        sumf += (GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)) * sumi;
     }
 
 #elif defined(__POWER9_VECTOR__)
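The new Q5_0/Q5_1 paths above no longer rebuild the high bits with shift tables; they load x[ib].qh directly as an RVV mask (vbool4_t) and apply the fifth bit to the unpacked nibbles with a single masked subtract (Q5_0) or masked OR (Q5_1), choosing between __riscv_vcreate_v_i8m1_i8m2 and a vslideup based on __riscv_vlenb(). For reference, a scalar sketch of how one Q5_0 value is reconstructed (illustrative helper, not ggml's reference code):

    #include <stdint.h>
    #include <string.h>

    // Scalar sketch of decoding element j (0..31) of a q5_0 block: the 4-bit
    // nibble supplies the low bits, bit j of qh supplies bit 4, then recenter.
    static int decode_q5_0_value_sketch(const uint8_t * qs, const uint8_t * qh_bytes, int j) {
        uint32_t qh;
        memcpy(&qh, qh_bytes, sizeof(qh));               // 32 high bits for 32 values
        const int nibble = j < 16 ? (qs[j] & 0x0F) : (qs[j - 16] >> 4);
        const int high   = (qh >> j) & 1;
        return (nibble | (high << 4)) - 16;              // range [-16, 15]
    }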
@@ -3503,60 +3461,30 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = hsum_float_8(acc) + summs;
 #elif defined(__riscv_v_intrinsic)
-
-
-    size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
-    // temporary registers for shift operations
-    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
-    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+    size_t vl;
+    size_t vlenb = __riscv_vlenb();
 
     for (; ib < nb; ++ib) {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        // load
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
-
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
-
-        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
-        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
-        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
-        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
-        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
-        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
-        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
-        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
-        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
-        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
-        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
-        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+        vl = qk / 2;
+        vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+        vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+        vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+        vint8m2_t v0c;
+        if (vlenb == 16) {
+            v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+        } else {
+            v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+            v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+        }
+
+        vl = qk;
+        vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+        vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+        vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+        vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+        vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+        int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
 
         sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
     }
@@ -3970,17 +3898,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     sumf = hsum_float_8(accum);
 #elif defined(__riscv_v_intrinsic)
-    size_t vl =
+    size_t vl = qk;
 
     for (; ib < nb; ++ib) {
         // load elements
-
-
+        vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);
+        vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
 
-
+        vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);
 
         vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
-        vint32m1_t v_sum =
+        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);
 
         int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
 
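The Q8_0 x Q8_0 case above becomes a single widening multiply plus reduction over the whole block; the scalar equivalent is just an int8 dot product scaled by both block scales (illustrative sketch):

    #include <stdint.h>

    // Scalar sketch of one q8_0 x q8_0 block dot product (32 values per block).
    static float dot_q8_0_q8_0_block_sketch(const int8_t * xqs, const int8_t * yqs,
                                            float xd, float yd) {
        int sumi = 0;
        for (int j = 0; j < 32; ++j) {
            sumi += (int) xqs[j] * (int) yqs[j];
        }
        return xd * yd * (float) sumi;
    }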
@@ -5174,84 +5102,182 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 #elif defined __riscv_v_intrinsic
 
+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;
-    uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-                           1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
 
-
+    uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
+    uint8_t atmp[16];
 
-
-
-
-
-
-
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {
+            const uint8_t * q2 = x[i].qs;
+            const int8_t  * q8 = y[i].qs;
+            const uint8_t * sc = x[i].scales;
 
-
+            const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-
-    vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+            size_t vl = 16;
 
-
+            vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+            vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
 
-
-    vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
-    vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
-    vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
-    vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+            vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
 
-
+            vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+            vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+            vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+            vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+            vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
 
-
+            sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
 
-
-    vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+            vl = 32;
 
-
-
+            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+            vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
 
-
-
-    vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+            uint8_t is = 0;
+            int isum = 0;
 
-
-
-
-    vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
+            for (int j = 0; j < QK_K / 128; ++j) {
+                // load Q2
+                vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
 
-
-
-
-
-    vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+                vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+                vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
+                vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
+                vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);
 
-
-
-
-
+                // duplicate scale elements for product
+                vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
+                vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
+                vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
+                vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);
 
-
-
-
-
-    vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+                vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+                vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+                vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+                vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
 
-
-
-
-
+                // load Q8
+                vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+                vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
+                vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
+                vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);
 
-
-
+                vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+                vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+                vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+                vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
 
-
+                vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+                vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
 
-
+                isum += __riscv_vmv_x_s_i32m1_i32(isum1);
 
-
+                q2 += 32;
+                q8 += 128;
+                is = 8;
+            }
 
-
+            sumf += dall * isum;
+        }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+            const uint8_t * q2 = x[i].qs;
+            const int8_t  * q8 = y[i].qs;
+            const uint8_t * sc = x[i].scales;
+            const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+            uint8_t *patmp = atmp;
+            int vsums;
+            int tmp;
+            __asm__ __volatile__(
+                "vsetivli zero, 16, e8, m1\n\t"
+                "vmv.v.x v8, zero\n\t"
+                "vle8.v v1, (%[sc])\n\t"
+                "vand.vi v0, v1, 0xF\n\t"
+                "vsrl.vi v1, v1, 4\n\t"
+                "vse8.v v0, (%[scale])\n\t"
+                "vsetivli zero, 16, e16, m2\n\t"
+                "vle16.v v2, (%[bsums])\n\t"
+                "vzext.vf2 v0, v1\n\t"
+                "vwmul.vv v4, v0, v2\n\t"
+                "vsetivli zero, 16, e32, m4\n\t"
+                "vredsum.vs v8, v4, v8\n\t"
+                "vmv.x.s %[vsums], v8"
+                : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
+                : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
+                : "memory"
+                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+            );
+            sumf += dmin * vsums;
+            int isum = 0;
+
+            for (int j = 0; j < QK_K/128; ++j) {
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vle8.v v0, (%[q2])\n\t"
+                    "vsrl.vi v2, v0, 2\n\t"
+                    "vsrl.vi v4, v0, 4\n\t"
+                    "vsrl.vi v6, v0, 6\n\t"
+                    "vand.vi v0, v0, 0x3\n\t"
+                    "vand.vi v2, v2, 0x3\n\t"
+                    "vand.vi v4, v4, 0x3\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vle8.v v8, (%[q8])\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vwmul.vv v24, v4, v12\n\t"
+                    "vsetivli zero, 16, e16, m2\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vwredsum.vs v10, v16, v0\n\t"
+                    "vwredsum.vs v9, v18, v0\n\t"
+                    "vwredsum.vs v8, v20, v0\n\t"
+                    "vwredsum.vs v7, v22, v0\n\t"
+                    "vwredsum.vs v11, v24, v0\n\t"
+                    "vwredsum.vs v12, v26, v0\n\t"
+                    "vwredsum.vs v13, v28, v0\n\t"
+                    "vwredsum.vs v14, v30, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v10, v9, 1\n\t"
+                    "vslideup.vi v8, v7, 1\n\t"
+                    "vslideup.vi v11, v12, 1\n\t"
+                    "vslideup.vi v13, v14, 1\n\t"
+                    "vslideup.vi v10, v8, 2\n\t"
+                    "vslideup.vi v11, v13, 2\n\t"
+                    "vsetivli zero, 8, e32, m2\n\t"
+                    "vle8.v v15, (%[scale])\n\t"
+                    "vzext.vf4 v12, v15\n\t"
+                    "vmul.vv v10, v10, v12\n\t"
+                    "vredsum.vs v0, v10, v0\n\t"
+                    "vmv.x.s %[tmp], v0\n\t"
+                    "add %[isum], %[isum], %[tmp]"
+                    : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+                    : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
+                    , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+                q2 += 32; q8 += 128; patmp += 8;
+            }
 
+            sumf += dall * isum;
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }
 
     *s = sumf;
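The k-quant kernels above (and the Q3_K one that follows) now pick their code path at run time from the vector register width: __riscv_vlenb() returns VLEN in bytes, so vector_length = __riscv_vlenb() * 8 is the width in bits, and the switch selects either the 256-bit intrinsic path or the 128-bit inline-assembly path. A minimal sketch of that dispatch pattern (assumes <riscv_vector.h> and an RVV-enabled toolchain):

    #include <riscv_vector.h>

    // Run-time probe of the vector register width, as used by the kernels above.
    static inline int rvv_vector_length_bits(void) {
        return (int) __riscv_vlenb() * 8;   // vlenb CSR holds VLEN in bytes
    }

    // Typical dispatch:
    //   switch (rvv_vector_length_bits()) {
    //       case 256: /* LMUL-friendly intrinsic path   */ break;
    //       case 128: /* hand-scheduled inline assembly */ break;
    //       default:  /* unsupported vector length      */ break;
    //   }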
@@ -6116,97 +6142,221 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     uint32_t aux[3];
     uint32_t utmp[4];
 
+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
 
-
-
-
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {
 
-
-
-
-        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
-        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+            const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+            const uint8_t * GGML_RESTRICT qh = x[i].hmask;
+            const  int8_t * GGML_RESTRICT q8 = y[i].qs;
 
-
-
+            memcpy(aux, x[i].scales, 12);
+            utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+            utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+            utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+            utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
 
+            int8_t * scale = (int8_t *)utmp;
+            for (int j = 0; j < 16; ++j) scale[j] -= 32;
 
-        size_t vl = 32;
-        uint8_t m = 1;
 
-
-
+            size_t vl = 32;
+            uint8_t m = 1;
 
-
+            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+            vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
 
-
+            int sum_t = 0;
 
-
+            for (int j = 0; j < QK_K; j += 128) {
 
-
-        vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+                vl = 32;
 
-
-
-        vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
-        vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
+                // load Q3
+                vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
 
-
-
-
-
-        m <<= 1;
+                vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+                vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
+                vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
+                vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
 
-
-
-
-
+                // compute mask for subtraction
+                vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+                vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
+                m <<= 1;
 
-
-
-
-
+                vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+                vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
+                m <<= 1;
 
-
-
-
-
-        vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+                vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+                vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
+                m <<= 1;
 
-
-        vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
-        vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
-        vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
-        vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
-        vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
-        vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
-        vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
-        vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+                vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+                vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+                vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
+                m <<= 1;
 
-
+                // load Q8 and take product with Q3
+                vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+                vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+                vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+                vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
 
-
-
-
-
+                vl = 16;
 
-
+                // retrieve lane to multiply with scale
+                vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+                vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+                vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+                vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+                vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+                vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+                vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+                vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
 
-
+                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
 
-
+                sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
 
-
+                q3 += 32; q8 += 128; scale += 8;
 
-
+            }
+
+            const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+            sumf += d*sum_t;
+
+        }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+            const uint8_t * restrict q3 = x[i].qs;
+            const uint8_t * restrict qh = x[i].hmask;
+            const  int8_t * restrict q8 = y[i].qs;
+
+            int8_t * scale = (int8_t *)utmp;
+            int tmp;
+            __asm__ __volatile__(
+                "vsetivli zero, 12, e8, m1\n\t"
+                "vle8.v v0, (%[s6b])\n\t"
+                "vmv1r.v v2, v0\n\t"
+                "vsetivli zero, 2, e64, m1\n\t"
+                "vmv.v.x v9, %[sh]\n\t"\
+                "vslidedown.vi v1, v0, 1\n\t"
+                "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
+                "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
+                "vsetivli zero, 4, e32, m1\n\t"
+                "vid.v v9\n\t"
+                "vmv.x.s %[tmp], v1\n\t"
+                "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
+                "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
+                "vsrl.vv v4, v1, v9\n\t"
+                "vsrl.vv v2, v0, v8\n\t"
+                "vand.vx v5, v4, %[kmask1]\n\t"
+                "vand.vx v3, v2, %[kmask2]\n\t"
+                "vsll.vi v6, v5, 4\n\t"
+                "vor.vv v7, v6, v3\n\t"
+                "vsetivli zero, 16, e8, m1\n\t"
+                "vsub.vx v0, v7, %[c]\n\t"
+                "vse8.v v0, (%[scale])"
+                : [tmp] "=&r" (tmp)
+                : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
+                , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
+                : "memory"
+                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+            );
 
-
+            uint8_t m = 1;
+            int isum = 0;
+            for (int j = 0; j < QK_K; j += 128) {
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
+                    "vle8.v v8, (%[q3])\n\t"
+                    "vsrl.vi v10, v8, 2\n\t"
+                    "vsrl.vi v12, v8, 4\n\t"
+                    "vsrl.vi v14, v8, 6\n\t"
+                    "vand.vi v8, v8, 3\n\t"
+                    "vand.vi v10, v10, 3\n\t"
+                    "vand.vi v12, v12, 3\n\t"
+                    "vle8.v v2, (%[qh])\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v8, v8, -4, v0.t\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v10, v10, -4, v0.t\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v12, v12, -4, v0.t\n\t"
+                    "vand.vx v4, v2, %[m]\n\t"
+                    "slli %[m], %[m], 1\n\t"
+                    "vmseq.vx v0, v4, zero\n\t"
+                    "vadd.vi v14, v14, -4, v0.t\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vle8.v v0, (%[q8])\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vwmul.vv v24, v4, v12\n\t"
+                    "vsetivli zero, 16, e16, m2\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vwredsum.vs v10, v16, v0\n\t"
+                    "vwredsum.vs v9, v18, v0\n\t"
+                    "vwredsum.vs v8, v20, v0\n\t"
+                    "vwredsum.vs v7, v22, v0\n\t"
+                    "vwredsum.vs v11, v24, v0\n\t"
+                    "vwredsum.vs v12, v26, v0\n\t"
+                    "vwredsum.vs v13, v28, v0\n\t"
+                    "vwredsum.vs v14, v30, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v10, v9, 1\n\t"
+                    "vslideup.vi v8, v7, 1\n\t"
+                    "vslideup.vi v11, v12, 1\n\t"
+                    "vslideup.vi v13, v14, 1\n\t"
+                    "vslideup.vi v10, v8, 2\n\t"
+                    "vslideup.vi v11, v13, 2\n\t"
+                    "vsetivli zero, 8, e32, m2\n\t"\
+                    "vle8.v v15, (%[scale])\n\t"
+                    "vsext.vf4 v12, v15\n\t"
+                    "vmul.vv v10, v10, v12\n\t"
+                    "vredsum.vs v0, v10, v0\n\t"
+                    "vmv.x.s %[tmp], v0\n\t"
+                    "add %[isum], %[isum], %[tmp]"
+                    : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
+                    : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
+                    , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+                q3 += 32; q8 += 128; scale += 8;
+            }
 
+            const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+            sumf += d * isum;
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }
 
     *s = sumf;
@@ -6440,7 +6590,118 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     }
 
     *s = hsum_float_8(acc);
+#elif defined(__VXE__) || defined(__VXE2__)
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    const int32x4_t v_z = vec_splat_s32(0);
+    const uint8x16_t v_3m = vec_splat_u8(0x03);
+
+    const uint8x16_t v_0c = vec_splat_u8(1);
+    const uint8x16_t v_1c = vec_sl(v_0c, 1);
+    const uint8x16_t v_2c = vec_sl(v_0c, 2);
+    const uint8x16_t v_3c = vec_sl(v_0c, 3);
+
+    uint8x16_t q3h[4];
+    uint8x16_t q3b[2];
+    int8x16_t q3bytes[4];
+    int8x16_t q8bytes[4];
+    uint8x16_t qhbits[2];
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict x0l = x[i].qs;
+        const uint8_t * restrict x0h = x[i].hmask;
+        const int8_t  * restrict y0 = y[i].qs;
+
+        qhbits[0] = vec_xl(0 , x0h);
+        qhbits[1] = vec_xl(16, x0h);
+
+        int32_t isum = 0;
 
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= 32;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            int32x4_t isum0, isum1, isum2, isum3;
+
+            q3b[0] = vec_xl(0 , x0l);
+            q3b[1] = vec_xl(16, x0l);
+            x0l += 32;
+
+            q8bytes[0] = vec_xl(0  , y0);
+            q8bytes[1] = vec_xl(16 , y0);
+            q8bytes[2] = vec_xl(32 , y0);
+            q8bytes[3] = vec_xl(48 , y0);
+            q8bytes[4] = vec_xl(64 , y0);
+            q8bytes[5] = vec_xl(80 , y0);
+            q8bytes[6] = vec_xl(96 , y0);
+            q8bytes[7] = vec_xl(112, y0);
+            y0 += 128;
+
+            q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2);
+            q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2);
+            q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1);
+            q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1);
+
+            q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]);
+            q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]);
+            q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]);
+            q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]);
+
+            isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]);
+            isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]);
+            isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]);
+            isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]);
+
+            isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
+            isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
+            isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
+            isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
+
+            scale += 4;
+
+            q3h[0] = vec_andc(v_2c, qhbits[0]);
+            q3h[1] = vec_andc(v_2c, qhbits[1]);
+            q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1);
+            q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1);
+
+            q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]);
+            q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]);
+            q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]);
+            q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]);
+
+            isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]);
+            isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]);
+            isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
+            isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
+
+            isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
+            isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
+            isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
+            isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
+
+            scale += 4;
+
+            if (j == 0) {
+                qhbits[0] = vec_sr(qhbits[0], 4);
+                qhbits[1] = vec_sr(qhbits[1], 4);
|
+
}
|
|
6699
|
+
}
|
|
6700
|
+
|
|
6701
|
+
sum += d * isum;
|
|
6702
|
+
}
|
|
6703
|
+
|
|
6704
|
+
*s = sum;
|
|
6444
6705
|
#else
|
|
6445
6706
|
// scalar version
|
|
6446
6707
|
// This function is written like this so the compiler can manage to vectorize most of it
|
|
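In the VXE path above, q3h is built as (~hmask & bit) shifted into place, so subtracting it from the masked 2-bit value applies a -4 offset exactly when the high bit is clear. A minimal scalar sketch of that decode, assuming plain C (illustrative only, not the upstream kernel):

```c
#include <stdint.h>
#include <stdio.h>

/* q3_K decode: two low bits from qs; subtract 4 when the hmask bit is NOT set
 * (the vec_andc/vec_sub pair above computes the same thing vector-wide). */
static int8_t q3_value(uint8_t qs_2bits, int hmask_bit_set) {
    int8_t v = (int8_t)(qs_2bits & 0x03);
    return hmask_bit_set ? v : (int8_t)(v - 4);
}

int main(void) {
    printf("%d %d\n", q3_value(3, 0), q3_value(3, 1)); /* -1 3 */
    return 0;
}
```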
@@ -6924,69 +7185,181 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     const uint8_t * scales = (const uint8_t*)&utmp[0];
     const uint8_t * mins = (const uint8_t*)&utmp[2];
 
+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;
 
-
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {
 
-
+            size_t vl = 8;
 
-
-
+            const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
 
-
-
-
+            vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+            vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+            vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
 
-
-
-
-
-
-
+            memcpy(utmp, x[i].scales, 12);
+            utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+            const uint32_t uaux = utmp[1] & kmask1;
+            utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+            utmp[2] = uaux;
+            utmp[0] &= kmask1;
 
-
-
-
+            vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+            vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+            vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
 
-
-
+            vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+            sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
 
-
-
+            const uint8_t * GGML_RESTRICT q4 = x[i].qs;
+            const int8_t * GGML_RESTRICT q8 = y[i].qs;
 
-
+            vl = 32;
 
-
-
+            int32_t sum_1 = 0;
+            int32_t sum_2 = 0;
 
-
+            vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
 
-
-
-
+            for (int j = 0; j < QK_K/64; ++j) {
+                // load Q4
+                vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
 
-
-
-
-
-
+                // load Q8 and multiply it with lower Q4 nibble
+                vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+                vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+                vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+                vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
 
-
+                sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
 
-
-
-
-
-
+                // load Q8 and multiply it with upper Q4 nibble
+                vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+                vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+                vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+                vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
 
-
+                sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
 
-
+                q4 += 32; q8 += 64;
 
-
+            }
 
-
+            sumf += d*(sum_1 + sum_2);
+
+        }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+            const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+            const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+            int tmp, tmp2, sumi;
+            __asm__ __volatile__(
+                "vsetivli zero, 12, e8, m1\n\t"
+                "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
+                "vsetivli zero, 4, e32, m1\n\t"
+                "vslidedown.vi v2, v1, 2\n\t"
+                "vmv1r.v v3, v2\n\t"
+                "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
+                "vsetivli zero, 2, e32, m1\n\t"
+                "vmv.v.i v4, 4\n\t"
+                "vand.vx v8, v1, %[kmask1]\n\t"
+                "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
+                "vsrl.vi v6, v1, 6\n\t"
+                "vsrl.vv v7, v2, v5\n\t"
+                "vand.vx v0, v6, %[kmask3]\n\t"
+                "vand.vx v2, v7, %[kmask2]\n\t"
+                "vsll.vi v6, v0, 4\n\t"
+                "li %[t2], 8\n\t"
+                "addi %[t1], %[utmp], 4\n\t"
+                "vor.vv v1, v6, v2\n\t"
+                "vsse32.v v8, (%[utmp]), %[t2]\n\t"
+                "vsse32.v v1, (%[t1]), %[t2]\n\t"
+                "vsetivli zero, 8, e16, m1\n\t"
+                "vle32.v v2, (%[bsums])\n\t"
+                "vnsrl.wi v0, v2, 0\n\t"
+                "vnsrl.wi v1, v2, 16\n\t"
+                "vadd.vv v2, v0, v1\n\t"
+                "vle8.v v3, (%[mins])\n\t"
+                "vzext.vf2 v4, v3\n\t"
+                "vwmul.vv v6, v4, v2\n\t"
+                "vmv.v.x v0, zero\n\t"
+                "vsetivli zero, 8, e32, m2\n\t"
+                "vredsum.vs v0, v6, v0\n\t"
+                "vmv.x.s %[sumi], v0"
+                : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
+                : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
+                , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
+                , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
+                : "memory"
+                , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+            );
+            sumf -= dmin * sumi;
+
+            const uint8_t * restrict q4 = x[i].qs;
+            const int8_t * restrict q8 = y[i].qs;
+
+            sumi = 0;
+            const uint8_t * scale = scales;
+
+            for (int j = 0; j < QK_K/128; ++j) {
+                int vl128 = 128, vl64 = 64, vl32 = 32;
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vle8.v v8, (%[q8])\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vle8.v v0, (%[q4])\n\t"
+                    "vsrl.vi v4, v0, 4\n\t"
+                    "vand.vi v0, v0, 0xF\n\t"
+                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vwmul.vv v28, v6, v14\n\t"
+                    "vwmul.vv v20, v4, v10\n\t"
+                    "vwmul.vv v24, v2, v12\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vle8.v v2, (%[scale])\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vzext.vf4 v1, v2\n\t"
+                    "vsetvli zero, %[vl32], e16, m4\n\t"
+                    "vwredsum.vs v6, v24, v0\n\t"
+                    "vwredsum.vs v7, v28, v0\n\t"
+                    "vwredsum.vs v4, v16, v0\n\t"
+                    "vwredsum.vs v5, v20, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v6, v7, 1\n\t"
+                    "vslideup.vi v4, v5, 1\n\t"
+                    "vslideup.vi v4, v6, 2\n\t"
+                    "vmul.vv v8, v4, v1\n\t"
+                    "vredsum.vs v0, v8, v0\n\t"
+                    "vmv.x.s %[tmp], v0\n\t"
+                    "add %[sumi], %[sumi], %[tmp]"
+                    : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
+                    : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
+                    , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+
+                q4 += 64; q8 += 128; scale += 4;
+            }
 
+            sumf += d * sumi;
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }
 
     *s = sumf;
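The 256-bit RVV q4_K path above splits each 32-byte group of q4 into low and high nibbles, dots them against consecutive 32-value slices of q8 weighted by scales[2*j] and scales[2*j+1], and subtracts the dmin-weighted mins/bsums product once per super-block. A scalar sketch of that per-super-block math, assuming QK_K == 256 and plain C (illustrative only, not the upstream kernel):

```c
#include <stdint.h>
#include <stdio.h>

/* One q4_K super-block: low nibbles of each 32-byte q4 group against one
 * 32-value q8 slice (weight scales[2*j]), high nibbles against the next slice
 * (weight scales[2*j+1]); the mins/bsums product is subtracted, scaled by dmin. */
static float q4k_superblock(const uint8_t q4[128], const int8_t q8[256],
                            const uint8_t scales[8], const uint8_t mins[8],
                            const int16_t bsums[16], float d, float dmin) {
    int32_t sumi = 0;
    for (int j = 0; j < 4; ++j) {
        int32_t lo = 0, hi = 0;
        for (int k = 0; k < 32; ++k) {
            lo += (q4[32*j + k] & 0x0F) * q8[64*j + k];
            hi += (q4[32*j + k] >> 4)   * q8[64*j + 32 + k];
        }
        sumi += lo * scales[2*j] + hi * scales[2*j + 1];
    }
    int32_t mins_dot = 0;
    for (int j = 0; j < 8; ++j) {
        mins_dot += (int32_t)mins[j] * (bsums[2*j] + bsums[2*j + 1]);
    }
    return d * sumi - dmin * mins_dot;
}

int main(void) {
    uint8_t q4[128] = {0}, scales[8] = {0}, mins[8] = {0};
    int8_t  q8[256] = {0};
    int16_t bsums[16] = {0};
    printf("%f\n", q4k_superblock(q4, q8, scales, mins, bsums, 1.0f, 1.0f));
    return 0;
}
```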
@@ -7722,9 +8095,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
         const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
 
-
-
-
+        vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);
+        vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);
+        vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);
 
         memcpy(utmp, x[i].scales, 12);
         utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -7733,11 +8106,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-
-
-
+        vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);
+        vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);
 
-        vint32m1_t sumi =
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
         sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
 
         vl = 32;
@@ -7746,43 +8119,42 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
         uint8_t m = 1;
         vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
-
+        vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);
 
         for (int j = 0; j < QK_K/64; ++j) {
            // load Q5 and Q8
-
-
-
+            vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);
+            vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl);
+            vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);
 
            // compute mask for addition
-
-
-
-
+            vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));
+            vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);
+            vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);
+            vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);
            m <<= 1;
 
-
-
-
-
+            vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));
+            vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);
+            vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);
+            vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);
            m <<= 1;
 
-
-
+            vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);
+            vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);
 
-
-
+            vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);
+            vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);
 
-            vint32m1_t vacc1 =
-            vint32m1_t vacc2 =
+            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);
+            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);
 
-            aux32 += __riscv_vmv_x_s_i32m1_i32(
+            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);
            q5 += 32; q8 += 64;
 
        }
 
-
-        sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+        sums += aux32 * d;
 
    }
 
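The q5_K intrinsics above add 16 to a 4-bit nibble whenever the matching qh bit is set, via the masked __riscv_vadd_vx_i8m2_mu(..., 16, ...) call. A minimal scalar sketch of that decode, assuming plain C (illustrative only, not the upstream kernel):

```c
#include <stdint.h>
#include <stdio.h>

/* q5_K decode: 4-bit nibble from qs plus 16 when the matching qh bit is set,
 * as the masked add-16 above expresses vector-wide. */
static int8_t q5_value(uint8_t qs_nibble, int qh_bit_set) {
    return (int8_t)((qs_nibble & 0x0F) + (qh_bit_set ? 16 : 0));
}

int main(void) {
    printf("%d %d\n", q5_value(0x0F, 0), q5_value(0x0F, 1)); /* 15 31 */
    return 0;
}
```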
@@ -8147,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+#ifdef __ARM_FEATURE_MATMUL_INT8
+    assert((nrc == 2) || (nrc == 1));
+#else
     assert(nrc == 1);
+#endif
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
@@ -8158,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q6_K * GGML_RESTRICT x0 = x;
+        const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
+        const block_q8_K * GGML_RESTRICT y0 = y;
+        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+        float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+            const uint8_t * GGML_RESTRICT ql0 = x0->ql;
+            const uint8_t * GGML_RESTRICT ql1 = x1->ql;
+            const uint8_t * GGML_RESTRICT qh0 = x0->qh;
+            const uint8_t * GGML_RESTRICT qh1 = x1->qh;
+            const int8_t * GGML_RESTRICT qy0 = y0->qs;
+            const int8_t * GGML_RESTRICT qy1 = y1->qs;
+
+            const uint8x16_t mone = vdupq_n_u8(0x30);
+            const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+            int32x4_t visum = vdupq_n_s32(0);
+
+            // process 8 blocks per iteration, totally 16 blocks
+            for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
+                int8x16_t vx0[8], vx1[8];
+
+                // de-quantize vx0[8]
+                {
+                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
+                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
+
+                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+                    vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+                    vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+                    vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+                    vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+                    vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+                    vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+                    vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+                    vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+                }
+
+                // de-quantize vx1[8]
+                {
+                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
+                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
+
+                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+                    vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+                    vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+                    vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+                    vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+                    vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+                    vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+                    vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+                    vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+                }
+
+                // process 16 elements (one block with same scale) per iteration
+                // - vx = concat(ql, qh) - 32
+                // - r1,r2,r3,r4 = smmla(vx, vy)
+                for (int k = 0; k < 8; ++k) {
+                    const int blk = j * 8 + k;
+
+                    const int8x16_t vy0 = vld1q_s8(qy0);
+                    const int8x16_t vy1 = vld1q_s8(qy1);
+                    qy0 += 16;
+                    qy1 += 16;
+
+                    const int32x4_t block_scale = {
+                        x0->scales[blk],
+                        x0->scales[blk],
+                        x1->scales[blk],
+                        x1->scales[blk],
+                    };
+
+                    // calculate four results at once with outer product
+                    const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+                    const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+                    const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+                    const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+                    int32x4_t vr = vdupq_n_s32(0);
+                    vr = vmmlaq_s32(vr, vx_l, vy_l);
+                    vr = vmmlaq_s32(vr, vx_h, vy_h);
+
+                    // apply block scale, will NOT overflow
+                    // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
+                    visum = vmlaq_s32(visum, vr, block_scale);
+                }
+            }
+
+            // adjust bias, apply superblock scale
+            {
+                int32_t bias[4];
+#ifdef __ARM_FEATURE_SVE
+                const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+                const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
+                const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
+                const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
+                const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
+                const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
+                const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
+                const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
+                const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
+                const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
+                const svint64_t zero = svdup_n_s64(0);
+                bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
+                                                                               svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
+                bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
+                                                                               svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
+                bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
+                                                                               svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
+                bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
+                                                                               svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
+#else
+                // NEON doesn't support int16 dot product, fallback to separated mul and add
+                const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
+                const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
+
+                int8x16_t scales_s8 = vld1q_s8(x0->scales);
+                const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+                scales_s8 = vld1q_s8(x1->scales);
+                const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+
+                int32x4_t prod;
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
+                bias[0] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
+                bias[1] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
+                bias[2] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
+                bias[3] = vaddvq_s32(prod);
+
+#endif
+                const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
+
+                const float32x4_t superblock_scale = {
+                    GGML_FP16_TO_FP32(x0->d) * y0->d,
+                    GGML_FP16_TO_FP32(x0->d) * y1->d,
+                    GGML_FP16_TO_FP32(x1->d) * y0->d,
+                    GGML_FP16_TO_FP32(x1->d) * y1->d,
+                };
+
+                visum = vsubq_s32(visum, vibias);
+                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+            }
+        }
+
+        // vfsum = ABCD -> ACBD
+        // AC -> s, BD -> (s+bs)
+        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+        vst1_f32(s, vget_low_f32 (vfsum));
+        vst1_f32(s + bs, vget_high_f32(vfsum));
+
+        return;
+    }
+#endif
+
 #ifdef __ARM_FEATURE_SVE
     const int vector_length = ggml_cpu_get_sve_cnt()*8;
     float sum = 0;
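The two vmmlaq_s32 calls in the i8mm path above accumulate, for each 16-byte block, all four dot products between the two q6_K rows (x0, x1) and the two q8_K rows (y0, y1). A scalar sketch of that 2x2 tile, assuming plain C (illustrative only, not the upstream kernel):

```c
#include <stdint.h>
#include <stdio.h>

/* Four row-pair dot products over one 16-byte block: the same 2x2 tile the
 * pair of vmmlaq_s32 calls above accumulates into visum. */
static void dot_2x2(const int8_t x0[16], const int8_t x1[16],
                    const int8_t y0[16], const int8_t y1[16], int32_t out[4]) {
    out[0] = out[1] = out[2] = out[3] = 0;
    for (int k = 0; k < 16; ++k) {
        out[0] += x0[k] * y0[k]; /* lane 0: x0 . y0 */
        out[1] += x0[k] * y1[k]; /* lane 1: x0 . y1 */
        out[2] += x1[k] * y0[k]; /* lane 2: x1 . y0 */
        out[3] += x1[k] * y1[k]; /* lane 3: x1 . y1 */
    }
}

int main(void) {
    int8_t a[16] = {1}, b[16] = {2};
    int32_t r[4];
    dot_2x2(a, a, b, b, r);
    printf("%d %d %d %d\n", r[0], r[1], r[2], r[3]); /* 2 2 2 2 */
    return 0;
}
```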
@@ -8667,85 +9234,168 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 #elif defined __riscv_v_intrinsic
 
+    const int vector_length = __riscv_vlenb() * 8;
     float sumf = 0;
-    for (int i = 0; i < nb; ++i) {
 
-
+    switch (vector_length) {
+    case 256:
+        for (int i = 0; i < nb; ++i) {
 
-
-        const uint8_t * GGML_RESTRICT qh = x[i].qh;
-        const int8_t * GGML_RESTRICT q8 = y[i].qs;
+            const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
 
-
+            const uint8_t * GGML_RESTRICT q6 = x[i].ql;
+            const uint8_t * GGML_RESTRICT qh = x[i].qh;
+            const int8_t * GGML_RESTRICT q8 = y[i].qs;
 
-
+            const int8_t * GGML_RESTRICT scale = x[i].scales;
 
-
+            size_t vl;
 
-
-        int is = 0;
+            vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
 
-
+            int sum_t = 0;
+            int is = 0;
 
-
+            for (int j = 0; j < QK_K/128; ++j) {
 
-
-        vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+                vl = 32;
 
-
-
-        vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+                // load qh
+                vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
 
-
-
-
-        vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+                // load Q6
+                vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+                vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
 
-
-
-
-
+                vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+                vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+                vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+                vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
 
-
-
-
-
+                vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+                vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+                vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+                vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
 
-
-
-
-
+                vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+                vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+                vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+                vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
 
-
-
-
-
-        vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+                vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+                vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+                vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+                vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
 
-
+                // load Q8 and take product
+                vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+                vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+                vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+                vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
 
-
-        vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
-        vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
-        vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
-        vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
-        vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
-        vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
-        vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+                vl = 16;
 
-
-
-
-
+                vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+                vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+                vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+                vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+                vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+                vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+                vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+                vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
 
-
+                vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+                vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+                vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+                vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
 
-
+                sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
 
-
+                q6 += 64; qh += 32; q8 += 128; is=8;
 
-
+            }
+
+            sumf += d * sum_t;
+
+        }
+        break;
+    case 128:
+        for (int i = 0; i < nb; ++i) {
+
+            const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+            const uint8_t * restrict q6 = x[i].ql;
+            const uint8_t * restrict qh = x[i].qh;
+            const int8_t * restrict q8 = y[i].qs;
+
+            const int8_t * restrict scale = x[i].scales;
+
+            int sum_t = 0;
+            int t0;
+
+            for (int j = 0; j < QK_K/128; ++j) {
+                __asm__ __volatile__(
+                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vle8.v v4, (%[qh])\n\t"
+                    "vsll.vi v0, v4, 4\n\t"
+                    "vsll.vi v2, v4, 2\n\t"
+                    "vsrl.vi v6, v4, 2\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vle8.v v8, (%[q6])\n\t"
+                    "vsrl.vi v12, v8, 4\n\t"
+                    "vand.vi v8, v8, 0xF\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
+                    "vand.vx v0, v0, %[mask]\n\t"
+                    "vor.vv v8, v8, v0\n\t"
+                    "vle8.v v0, (%[q8])\n\t"
+                    "vsub.vx v8, v8, %[vl32]\n\t"
+                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vwmul.vv v16, v0, v8\n\t"
+                    "vwmul.vv v24, v4, v12\n\t"
+                    "vsetivli zero, 16, e16, m2\n\t"
+                    "vmv.v.x v0, zero\n\t"
+                    "vwredsum.vs v10, v16, v0\n\t"
+                    "vwredsum.vs v9, v18, v0\n\t"
+                    "vwredsum.vs v8, v20, v0\n\t"
+                    "vwredsum.vs v7, v22, v0\n\t"
+                    "vwredsum.vs v11, v24, v0\n\t"
+                    "vwredsum.vs v12, v26, v0\n\t"
+                    "vwredsum.vs v13, v28, v0\n\t"
+                    "vwredsum.vs v14, v30, v0\n\t"
+                    "vsetivli zero, 4, e32, m1\n\t"
+                    "vslideup.vi v10, v9, 1\n\t"
+                    "vslideup.vi v8, v7, 1\n\t"
+                    "vslideup.vi v11, v12, 1\n\t"
+                    "vslideup.vi v13, v14, 1\n\t"
+                    "vslideup.vi v10, v8, 2\n\t"
+                    "vslideup.vi v11, v13, 2\n\t"
+                    "vsetivli zero, 8, e32, m2\n\t"
+                    "vle8.v v2, (%[scale])\n\t"
+                    "vsext.vf4 v4, v2\n\t"
+                    "vmul.vv v2, v4, v10\n\t"
+                    "vredsum.vs v0, v2, v0\n\t"
+                    "vmv.x.s %[t0], v0\n\t"
+                    "add %[sumi], %[sumi], %[t0]"
+                    : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
+                    : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
+                    , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+                    , [mask] "r" (0x30)
+                    : "memory"
+                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                );
+                q6 += 64; qh += 32; q8 += 128; scale += 8;
+            }
+
+            sumf += d * sum_t;
 
+        }
+        break;
+    default:
+        assert(false && "Unsupported vector length");
+        break;
     }
 
     *s = sumf;
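Both q6_K RVV variants above rebuild each 6-bit value by OR-ing the low nibble from ql with two qh bits shifted left by 4, then subtracting 32 (vor.vv / vsub.vx in the assembly, __riscv_vor_vv / __riscv_vsub_vx in the intrinsics). A minimal scalar sketch of that decode, assuming plain C (illustrative only, not the upstream kernel):

```c
#include <stdint.h>
#include <stdio.h>

/* q6_K decode: low nibble from ql OR two qh bits shifted into bits 4..5,
 * then re-centered by subtracting 32. */
static int8_t q6_value(uint8_t ql_nibble, uint8_t qh_2bits) {
    int v = (ql_nibble & 0x0F) | ((qh_2bits & 0x03) << 4);
    return (int8_t)(v - 32);
}

int main(void) {
    printf("%d %d\n", q6_value(0x0F, 0x3), q6_value(0x00, 0x0)); /* 31 -32 */
    return 0;
}
```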