@fugood/llama.node 0.3.2 → 0.3.4
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +7 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +17 -7
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +89 -27
- package/src/LlamaContext.h +2 -0
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +240 -168
- package/src/llama.cpp/.github/workflows/docker.yml +8 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +14 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -4
- package/src/llama.cpp/common/arg.cpp +986 -770
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +212 -351
- package/src/llama.cpp/common/common.h +204 -117
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +163 -121
- package/src/llama.cpp/common/sampling.h +41 -20
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +134 -161
- package/src/llama.cpp/examples/CMakeLists.txt +33 -14
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +19 -18
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +41 -87
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +263 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +83 -22
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
- package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +73 -114
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
- package/src/llama.cpp/examples/server/server.cpp +2073 -1339
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +354 -277
- package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/simple/simple.cpp +130 -94
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
- package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +159 -417
- package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +93 -52
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +4 -8
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +779 -194
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +55 -10
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +4317 -2979
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -38
- package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
- package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +62 -20
- package/src/llama.cpp/tests/test-sampling.cpp +163 -138
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -1,22 +1,56 @@
-
-
-//
-
-#define GGML_COMMON_IMPL_C
+#define GGML_COMMON_IMPL_CPP
+#define GGML_COMMON_DECL_CPP
 #include "ggml-common.h"
+#include "ggml-backend-impl.h"
 
 #include "ggml-quants.h"
 #include "ggml-impl.h"
+#include "ggml-cpu.h"
 #include "ggml-cpu-impl.h"
+#include "ggml-cpu-traits.h"
+
+#include <cmath>
+#include <cstring>
+#include <cassert>
+#include <cfloat>
+#include <cstdlib> // for qsort
+#include <cstdio> // for GGML_ASSERT
+
+#include "ggml-cpu-aarch64.h"
+
+// TODO: move to include file?
+template <int K> constexpr int QK_0() {
+    if constexpr (K == 4) {
+        return QK4_0;
+    }
+    if constexpr (K == 8) {
+        return QK8_0;
+    }
+    return -1;
+}
 
-
-
-
-
-#include <stdlib.h> // for qsort
-#include <stdio.h> // for GGML_ASSERT
+template <int K, int N> struct block {
+    ggml_half d[N]; // deltas for N qK_0 blocks
+    int8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
+};
 
-
+// control size
+static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
+static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
+static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+
+using block_q4_0x4 = block<4, 4>;
+using block_q4_0x8 = block<4, 8>;
+using block_q8_0x4 = block<8, 4>;
+using block_q8_0x8 = block<8, 8>;
+
+struct block_iq4_nlx4 {
+    ggml_half d[4]; // deltas for 4 iq4_nl blocks
+    uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
+};
+
+static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
 
 #if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Woverlength-strings"
@@ -131,7 +165,7 @@ static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
 }
 
 static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
-#if defined(
+#if defined(__AVX512VNNI__)
     const __m512i zero = _mm512_setzero_si512();
     return _mm512_dpbusd_epi32(zero, ax, sy);
 #else
@@ -186,52 +220,14 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 }
 #endif
 
-static
-    block_q4_0x4 out;
-
-    for (int i = 0; i < 4; i++) {
-        out.d[i] = in[i].d;
-    }
-
-    for (int i = 0; i < QK4_0 * 2; i++) {
-        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
-        int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave;
-        src_offset += (i % blck_size_interleave);
-
-        out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
-    }
-
-    return out;
-}
-
-// interleave 8 block_q4_0s in blocks of blck_size_interleave
-// returns an interleaved block_q4_0x8
-// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
-// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
-static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
-    block_q4_0x8 out;
-
-    for (int i = 0; i < 8; i++) {
-        out.d[i] = in[i].d;
-    }
-
-    for (int i = 0; i < QK4_0 * 4; i++) {
-        int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
-        int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave;
-        src_offset += (i % blck_size_interleave);
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
-
-    }
-
-    return out;
-}
-
-void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
+static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
 
-    block_q8_0x4 *
+    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
 
 #if defined(__ARM_NEON)
     float32x4_t srcv[4][8];
@@ -320,12 +316,12 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k)
 #endif
 }
 
-void quantize_q8_0_4x8(const float *
+static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
 
-    block_q8_0x4 *
+    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
 
 #if defined(__ARM_NEON)
     float32x4_t srcv[4][8];
@@ -535,7 +531,7 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k)
 #endif
 }
 
-void quantize_mat_q8_0(const float *
+static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
     assert(nrow == 4);
     UNUSED(nrow);
     if (blck_size_interleave == 4) {
@@ -547,58 +543,7 @@ void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nro
     }
 }
 
-static
-    assert(n_per_row % QK4_0 == 0);
-    const int nb = n_per_row / QK4_0;
-
-    void * out_ptr = NULL;
-    if (nrows_interleaved == 8) {
-        out_ptr = (block_q4_0x8 *) dst;
-    }
-    else if (nrows_interleaved == 4) {
-        out_ptr = (block_q4_0x4 *) dst;
-    }
-    assert(nrows_interleaved <= 8);
-    block_q4_0 dst_tmp[8];
-
-    for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) {
-
-        for (int64_t x = 0; x < nb; x++) {
-
-            for (int i = 0; i < nrows_interleaved; i++ ) {
-                quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
-            }
-
-            if (nrows_interleaved == 8) {
-                *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88);
-                out_ptr = (block_q4_0x8 *) out_ptr + 1;
-            }
-            else if (nrows_interleaved == 4) {
-                *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88);
-                out_ptr = (block_q4_0x4 *) out_ptr + 1;
-            }
-        }
-    }
-
-    return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0));
-}
-
-size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    UNUSED(quant_weights);
-    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4);
-}
-
-size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    UNUSED(quant_weights);
-    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8);
-}
-
-size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    UNUSED(quant_weights);
-    return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8);
-}
-
-void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -617,67 +562,47 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    if (ggml_cpu_has_neon()) {
-        const
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            "fcvtl v16.4s, v20.4h\n"
-            ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n"
-            "fmul v16.4s, v16.4s, v21.4s\n"
-            ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n"
-            ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n"
-            ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n"
-            ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n"
-            ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n"
-            ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n"
-            "scvtf v26.4s, v26.4s, #0x4\n"
-            "fmla v29.4s, v26.4s, v16.4s\n"
-            "cbnz x21, 2b\n"
-            "sub %x[nc], %x[nc], #0x4\n"
-            "str q29, [%x[res_ptr], #0x0]\n"
-            "add %x[res_ptr], %x[res_ptr], #0x10\n"
-            "cbnz %x[nc], 1b\n"
-            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
-            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
-            : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22"
-        );
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *)vx;
+
+        for (int c = 0; c < nc; c += ncols_interleaved) {
+            const block_q8_0 * a_ptr = (const block_q8_0 *)vy;
+            float32x4_t acc = vdupq_n_f32(0);
+            for (int b = 0; b < nb; b++) {
+                int8x16_t b0 = vld1q_s8((const int8_t *)b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *)b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *)b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *)b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *)b_ptr->d);
+
+                int8x16_t a0 = vld1q_s8(a_ptr->qs);
+                int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *)&a_ptr->d);
+
+                int32x4_t ret = vdupq_n_s32(0);
+
+                ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
+                ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
+                ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
+                ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);
+
+                ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
+                ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
+                ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
+                ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);
+
+                acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
+                                vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+                a_ptr++;
+                b_ptr++;
+            }
+            vst1q_f32(s, acc);
+            s += ncols_interleaved;
+        }
         return;
     }
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     float sumf[4];
     int sumi;
 
@@ -703,7 +628,7 @@ void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }
 
-void ggml_gemv_q4_0_4x8_q8_0(int n, float *
+static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -813,7 +738,7 @@ void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }
 
-void ggml_gemv_q4_0_8x8_q8_0(int n, float *
+static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
@@ -991,6 +916,73 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
         }
     }
     return;
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+            vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+            for (int l = 0; l < nb; l++) {
+                const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
+                const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
+                const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
+                const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
+                __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));
+                const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));
+                const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));
+                const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));
+
+                const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+                const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
+                const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
+                const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
+                const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
+
+                const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m));
+                const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                // vector version needs Zvfhmin extension
+                const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
+                const float b_scales[8] = {
+                    GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                };
+                const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+                const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
+                sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4);
+            }
+            __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
+        }
+        return;
+    }
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
     {
         float sumf[8];
@@ -1019,7 +1011,103 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }
 
-void
+static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float * res_ptr = s;
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            float32x4_t sumf = vdupq_n_f32(0);
+            for (int l = 0; l < nb; l++) {
+                uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+                uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+                uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+                uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+                int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+                int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+                int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+                int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+                int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+                int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+                int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+                int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+                int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+                int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+                int32x4_t sumi = vdupq_n_s32(0);
+                sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+                sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+                sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+                sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+                sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+                sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+                sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+                sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+                float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                float32x4_t d = a_d * b_d;
+
+                sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+            }
+
+            vst1q_f32(res_ptr + x * 4, sumf);
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
+        float sumf[4];
+        int sumi;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        sumi = 0;
+                        for (int i = 0; i < blocklen; ++i) {
+                            const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                            const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                        }
+                        sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                    }
+                }
+            }
+            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
+static void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -1040,7 +1128,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
     UNUSED(blocklen);
 
 #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    if (ggml_cpu_has_neon()) {
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
         const void * b_ptr = vx;
         const void * a_ptr = vy;
         float * res_ptr = s;
@@ -1535,7 +1623,7 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }
 
-void ggml_gemm_q4_0_4x8_q8_0(int n, float *
+static void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
     const int ncols_interleaved = 4;
@@ -1989,7 +2077,7 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }
 
-void ggml_gemm_q4_0_8x8_q8_0(int n, float *
+static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 8;
@@ -2509,31 +2597,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
             const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
 
             // Shuffle pattern one - right side input
-            const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
-            const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+            const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+            const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
 
-            const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
-            const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+            const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+            const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
 
-            const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
-            const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+            const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+            const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
 
-            const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
-            const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+            const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+            const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
 
             // Shuffle pattern two - right side input
 
-            const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
-            const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+            const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+            const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
 
-            const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
-            const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+            const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+            const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
 
-            const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
-            const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+            const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+            const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
 
-            const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
-            const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+            const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+            const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
 
             // Scale values - Load the weight scale values of two block_q4_0x8
             const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
@@ -2567,31 +2655,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
 
             // Shuffle pattern one - left side input
 
-            const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
-            const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+            const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+            const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
 
-            const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
-            const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+            const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+            const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
 
-            const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
-            const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+            const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+            const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
 
-            const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
-            const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+            const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+            const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
 
             // Shuffle pattern two - left side input
 
-            const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
-            const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+            const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+            const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
 
-            const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
-            const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+            const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+            const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
 
-            const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
-            const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+            const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+            const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
 
-            const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
-            const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+            const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
|
|
2682
|
+
const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
|
|
2595
2683
|
|
|
2596
2684
|
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
|
|
2597
2685
|
// Resembles MMLAs into 2x2 matrices in ARM Version
|
|
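The change repeated throughout these hunks is mechanical: every shuffle control passed to `_mm512_shuffle_epi32` (and, further down, inside the `iacc_*` blends) gains a `(_MM_PERM_ENUM)` cast. The intrinsic's control operand is declared as `_MM_PERM_ENUM` rather than `int` in the Intel/GCC headers, so a bare literal such as `160`, `245`, `136`, `221` or `78` is an int-to-enum conversion that stricter builds reject; the cast keeps the identical control byte. A minimal sketch of the pattern, assuming an AVX-512F toolchain (the helper name is illustrative, not part of the diff):

```c
#include <immintrin.h>

// 160 == 0xA0 == _MM_PERM_CCAA: within every 128-bit lane, replicate dwords 0
// and 2, giving the "A0 A0 A1 A1" layout described in the comments above.
// The explicit cast is what the diff adds; the shuffle itself is unchanged.
static inline __m512i repeat_even_dwords(__m512i v) {
    return _mm512_shuffle_epi32(v, (_MM_PERM_ENUM)160);
}
```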
@@ -2620,10 +2708,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *


  // Straighten out to make 4 row vectors
- __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);

  // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
  const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
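The "Straighten out" step converts the four 2x2-tile accumulators (`iacc_mat_00/01/10/11`) back into four per-row vectors: blending one tile against the 64-bit-swapped other tile (blend mask `0xCCCC`, shuffle control `78`) concatenates the matching halves. A scalar sketch of what one 128-bit lane of that blend produces, with each tile viewed as four ints (names and the 4-int view are illustrative):

```c
// row0 takes the low halves of tiles 00 and 01; row1 takes their high halves.
// Rows 2 and 3 are formed the same way from tiles 10 and 11.
static void straighten_lane(const int t00[4], const int t01[4],
                            int row0[4], int row1[4]) {
    row0[0] = t00[0]; row0[1] = t00[1];  // kept by blend mask 0xCCCC
    row0[2] = t01[0]; row0[3] = t01[1];  // brought in by the 64-bit swap (control 78)
    row1[0] = t00[2]; row1[1] = t00[3];
    row1[2] = t01[2]; row1[3] = t01[3];
}
```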
@@ -2702,31 +2790,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
  const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)

  // Shuffle pattern one - right side input
- const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
- const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+ const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+ const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)

- const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
- const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+ const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+ const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)

- const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
- const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+ const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+ const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)

- const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
- const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+ const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+ const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)

  // Shuffle pattern two - right side input

- const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
- const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+ const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+ const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)

- const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
- const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+ const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+ const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)

- const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
- const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+ const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+ const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)

- const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
- const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+ const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+ const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)


  // Scale values - Load the weight scale values of two block_q4_0x8
@@ -2758,31 +2846,31 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *

  // Shuffle pattern one - left side input

- const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
- const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+ const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+ const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)

- const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
- const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+ const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+ const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)

- const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
- const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+ const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+ const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)

- const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
- const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+ const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+ const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)

  // Shuffle pattern two - left side input

- const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
- const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+ const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+ const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)

- const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
- const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+ const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+ const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)

- const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
- const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+ const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+ const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)

- const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
- const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+ const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+ const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)

  // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
  // Resembles MMLAs into 2x2 matrices in ARM Version
@@ -2811,10 +2899,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *


  // Straighten out to make 4 row vectors
- __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78));
- __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01);
- __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78));
- __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11);
+ __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+ __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+ __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+ __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);

  // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
  const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
@@ -3171,6 +3259,207 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
3171
3259
|
}
|
|
3172
3260
|
}
|
|
3173
3261
|
}
|
|
3262
|
+
return;
|
|
3263
|
+
}
|
|
3264
|
+
#elif defined(__riscv_v_intrinsic)
|
|
3265
|
+
if (__riscv_vlenb() >= QK4_0) {
|
|
3266
|
+
const size_t vl = QK4_0;
|
|
3267
|
+
|
|
3268
|
+
for (int y = 0; y < nr / 4; y++) {
|
|
3269
|
+
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
|
3270
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
3271
|
+
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
|
3272
|
+
vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
|
3273
|
+
vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
|
3274
|
+
vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
|
3275
|
+
vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
|
3276
|
+
for (int l = 0; l < nb; l++) {
|
|
3277
|
+
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
|
|
3278
|
+
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
|
|
3279
|
+
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
|
|
3280
|
+
const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
|
|
3281
|
+
const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
|
|
3282
|
+
const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
|
|
3283
|
+
const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
|
|
3284
|
+
|
|
3285
|
+
// vector version needs Zvfhmin extension
|
|
3286
|
+
const float a_scales[4] = {
|
|
3287
|
+
GGML_FP16_TO_FP32(a_ptr[l].d[0]),
|
|
3288
|
+
GGML_FP16_TO_FP32(a_ptr[l].d[1]),
|
|
3289
|
+
GGML_FP16_TO_FP32(a_ptr[l].d[2]),
|
|
3290
|
+
GGML_FP16_TO_FP32(a_ptr[l].d[3])
|
|
3291
|
+
};
|
|
3292
|
+
const float b_scales[8] = {
|
|
3293
|
+
GGML_FP16_TO_FP32(b_ptr[l].d[0]),
|
|
3294
|
+
GGML_FP16_TO_FP32(b_ptr[l].d[1]),
|
|
3295
|
+
GGML_FP16_TO_FP32(b_ptr[l].d[2]),
|
|
3296
|
+
GGML_FP16_TO_FP32(b_ptr[l].d[3]),
|
|
3297
|
+
GGML_FP16_TO_FP32(b_ptr[l].d[4]),
|
|
3298
|
+
GGML_FP16_TO_FP32(b_ptr[l].d[5]),
|
|
3299
|
+
GGML_FP16_TO_FP32(b_ptr[l].d[6]),
|
|
3300
|
+
GGML_FP16_TO_FP32(b_ptr[l].d[7])
|
|
3301
|
+
};
|
|
3302
|
+
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
|
|
3303
|
+
|
|
3304
|
+
const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];
|
|
3305
|
+
const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];
|
|
3306
|
+
const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];
|
|
3307
|
+
const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];
|
|
3308
|
+
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
|
3309
|
+
vint16m4_t sumi_l0;
|
|
3310
|
+
{
|
|
3311
|
+
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));
|
|
3312
|
+
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));
|
|
3313
|
+
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));
|
|
3314
|
+
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));
|
|
3315
|
+
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
|
3316
|
+
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
|
3317
|
+
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
|
3318
|
+
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
|
3319
|
+
|
|
3320
|
+
sumi_l0 = sumi_hi_m;
|
|
3321
|
+
}
|
|
3322
|
+
|
|
3323
|
+
{
|
|
3324
|
+
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0));
|
|
3325
|
+
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
|
3326
|
+
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
|
3327
|
+
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
|
3328
|
+
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
|
3329
|
+
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
|
3330
|
+
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
|
3331
|
+
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
|
3332
|
+
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
|
3333
|
+
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
|
3334
|
+
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
|
3335
|
+
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
|
3336
|
+
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
|
3337
|
+
|
|
3338
|
+
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
|
|
3339
|
+
sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);
|
|
3340
|
+
}
|
|
3341
|
+
|
|
3342
|
+
const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];
|
|
3343
|
+
const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];
|
|
3344
|
+
const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];
|
|
3345
|
+
const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];
|
|
3346
|
+
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
|
3347
|
+
vint16m4_t sumi_l1;
|
|
3348
|
+
{
|
|
3349
|
+
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));
|
|
3350
|
+
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));
|
|
3351
|
+
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
|
|
3352
|
+
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
|
|
3353
|
+
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
|
3354
|
+
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
|
3355
|
+
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
|
3356
|
+
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
|
3357
|
+
|
|
3358
|
+
sumi_l1 = sumi_hi_m;
|
|
3359
|
+
}
|
|
3360
|
+
|
|
3361
|
+
{
|
|
3362
|
+
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1));
|
|
3363
|
+
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
|
3364
|
+
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
|
3365
|
+
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
|
3366
|
+
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
|
3367
|
+
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
|
3368
|
+
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
|
3369
|
+
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
|
3370
|
+
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
|
3371
|
+
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
|
3372
|
+
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
|
3373
|
+
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
|
3374
|
+
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
|
3375
|
+
|
|
3376
|
+
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);
|
|
3377
|
+
sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
|
|
3378
|
+
}
|
|
3379
|
+
|
|
3380
|
+
const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
|
|
3381
|
+
const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
|
|
3382
|
+
const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
|
|
3383
|
+
const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
|
|
3384
|
+
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
|
3385
|
+
vint16m4_t sumi_l2;
|
|
3386
|
+
{
|
|
3387
|
+
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
|
|
3388
|
+
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
|
|
3389
|
+
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
|
|
3390
|
+
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
|
|
3391
|
+
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
|
3392
|
+
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
|
3393
|
+
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
|
3394
|
+
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
|
3395
|
+
|
|
3396
|
+
sumi_l2 = sumi_hi_m;
|
|
3397
|
+
}
|
|
3398
|
+
|
|
3399
|
+
{
|
|
3400
|
+
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));
|
|
3401
|
+
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
|
3402
|
+
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
|
3403
|
+
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
|
3404
|
+
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
|
3405
|
+
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
|
3406
|
+
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
|
3407
|
+
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
|
3408
|
+
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
|
3409
|
+
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
|
3410
|
+
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
|
3411
|
+
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
|
3412
|
+
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
|
3413
|
+
|
|
3414
|
+
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
|
|
3415
|
+
sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
|
|
3416
|
+
}
|
|
3417
|
+
|
|
3418
|
+
const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
|
|
3419
|
+
const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
|
|
3420
|
+
const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
|
|
3421
|
+
const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
|
|
3422
|
+
__asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
|
|
3423
|
+
vint16m4_t sumi_l3;
|
|
3424
|
+
{
|
|
3425
|
+
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
|
|
3426
|
+
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
|
|
3427
|
+
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
|
|
3428
|
+
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
|
|
3429
|
+
const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
|
|
3430
|
+
const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
|
|
3431
|
+
const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
|
|
3432
|
+
const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
|
|
3433
|
+
|
|
3434
|
+
sumi_l3 = sumi_hi_m;
|
|
3435
|
+
}
|
|
3436
|
+
|
|
3437
|
+
{
|
|
3438
|
+
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3));
|
|
3439
|
+
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
|
3440
|
+
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
|
3441
|
+
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
|
3442
|
+
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
|
3443
|
+
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
|
3444
|
+
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
|
3445
|
+
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
|
3446
|
+
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
|
3447
|
+
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
|
3448
|
+
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
|
3449
|
+
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
|
3450
|
+
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
|
3451
|
+
|
|
3452
|
+
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4);
|
|
3453
|
+
sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4);
|
|
3454
|
+
}
|
|
3455
|
+
}
|
|
3456
|
+
__riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
|
|
3457
|
+
__riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
|
|
3458
|
+
__riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
|
|
3459
|
+
__riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
|
|
3460
|
+
}
|
|
3461
|
+
}
|
|
3462
|
+
|
|
3174
3463
|
return;
|
|
3175
3464
|
}
|
|
3176
3465
|
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
|
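The new `__riscv_v_intrinsic` branch above computes, for each of the four activation rows, 64 widened int16 products against the interleaved `block_q4_0x8` quants and then folds them into 8 per-column int32 sums with a ladder of narrowing shifts (`vnsrl`) and adds before applying the scales. A plain-C sketch of that pairwise folding, assuming QK4_0 == 32 and 8 interleaved columns (the helper is illustrative; the real kernel only widens to 32 bits on the final step):

```c
#include <stdint.h>

// Repeatedly add adjacent pairs: 64 partial products collapse to 32, 16 and
// finally 8 sums, one per interleaved column of the block_q4_0x8.
static void fold_partials(const int16_t partial[64], int32_t colsum[8]) {
    int32_t acc[64];
    for (int i = 0; i < 64; ++i) acc[i] = partial[i];
    for (int width = 64; width > 8; width /= 2) {
        for (int i = 0; i < width / 2; ++i) {
            acc[i] = acc[2 * i] + acc[2 * i + 1];
        }
    }
    for (int j = 0; j < 8; ++j) colsum[j] = acc[j];
}
```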
@@ -3207,3 +3496,767 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|
|
3207
3496
|
}
|
|
3208
3497
|
}
|
|
3209
3498
|
}
|
|
3499
|
+
|
|
3500
|
+
static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
|
3501
|
+
const int qk = QK8_0;
|
|
3502
|
+
const int nb = n / qk;
|
|
3503
|
+
const int ncols_interleaved = 4;
|
|
3504
|
+
const int blocklen = 4;
|
|
3505
|
+
|
|
3506
|
+
assert (n % qk == 0);
|
|
3507
|
+
assert (nr % 4 == 0);
|
|
3508
|
+
assert (nc % ncols_interleaved == 0);
|
|
3509
|
+
|
|
3510
|
+
UNUSED(s);
|
|
3511
|
+
UNUSED(bs);
|
|
3512
|
+
UNUSED(vx);
|
|
3513
|
+
UNUSED(vy);
|
|
3514
|
+
UNUSED(nr);
|
|
3515
|
+
UNUSED(nc);
|
|
3516
|
+
UNUSED(nb);
|
|
3517
|
+
UNUSED(ncols_interleaved);
|
|
3518
|
+
UNUSED(blocklen);
|
|
3519
|
+
|
|
3520
|
+
#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
|
3521
|
+
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
|
3522
|
+
const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
|
|
3523
|
+
|
|
3524
|
+
for (int y = 0; y < nr / 4; y++) {
|
|
3525
|
+
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
|
3526
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
3527
|
+
const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
|
|
3528
|
+
|
|
3529
|
+
float32x4_t sumf[4];
|
|
3530
|
+
for (int m = 0; m < 4; m++) {
|
|
3531
|
+
sumf[m] = vdupq_n_f32(0);
|
|
3532
|
+
}
|
|
3533
|
+
|
|
3534
|
+
for (int l = 0; l < nb; l++) {
|
|
3535
|
+
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
|
|
3536
|
+
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
|
|
3537
|
+
|
|
3538
|
+
int32x4_t sumi_0 = vdupq_n_s32(0);
|
|
3539
|
+
int32x4_t sumi_1 = vdupq_n_s32(0);
|
|
3540
|
+
int32x4_t sumi_2 = vdupq_n_s32(0);
|
|
3541
|
+
int32x4_t sumi_3 = vdupq_n_s32(0);
|
|
3542
|
+
|
|
3543
|
+
for (int k = 0; k < 4; k++) {
|
|
3544
|
+
int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
|
|
3545
|
+
int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
|
|
3546
|
+
|
|
3547
|
+
uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
|
|
3548
|
+
int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
|
|
3549
|
+
int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
|
|
3550
|
+
|
|
3551
|
+
sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
|
|
3552
|
+
sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
|
|
3553
|
+
sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
|
|
3554
|
+
sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
|
|
3555
|
+
sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
|
|
3556
|
+
sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
|
|
3557
|
+
sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
|
|
3558
|
+
sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
|
|
3559
|
+
}
|
|
3560
|
+
|
|
3561
|
+
sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
|
|
3562
|
+
sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
|
|
3563
|
+
sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
|
|
3564
|
+
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
|
|
3565
|
+
}
|
|
3566
|
+
|
|
3567
|
+
for (int m = 0; m < 4; m++) {
|
|
3568
|
+
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
|
|
3569
|
+
}
|
|
3570
|
+
}
|
|
3571
|
+
}
|
|
3572
|
+
return;
|
|
3573
|
+
}
|
|
3574
|
+
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
|
|
3575
|
+
{
|
|
3576
|
+
float sumf[4][4];
|
|
3577
|
+
int sumi;
|
|
3578
|
+
|
|
3579
|
+
for (int y = 0; y < nr / 4; y++) {
|
|
3580
|
+
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
|
3581
|
+
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
|
3582
|
+
const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
|
|
3583
|
+
for (int m = 0; m < 4; m++) {
|
|
3584
|
+
for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
|
|
3585
|
+
}
|
|
3586
|
+
for (int l = 0; l < nb; l++) {
|
|
3587
|
+
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
|
3588
|
+
for (int m = 0; m < 4; m++) {
|
|
3589
|
+
for (int j = 0; j < ncols_interleaved; j++) {
|
|
3590
|
+
sumi = 0;
|
|
3591
|
+
for (int i = 0; i < blocklen; ++i) {
|
|
3592
|
+
const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
|
|
3593
|
+
const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
|
|
3594
|
+
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
|
|
3595
|
+
(v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
|
|
3596
|
+
}
|
|
3597
|
+
sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
|
|
3598
|
+
}
|
|
3599
|
+
}
|
|
3600
|
+
}
|
|
3601
|
+
}
|
|
3602
|
+
for (int m = 0; m < 4; m++) {
|
|
3603
|
+
for (int j = 0; j < ncols_interleaved; j++)
|
|
3604
|
+
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
|
3605
|
+
}
|
|
3606
|
+
}
|
|
3607
|
+
}
|
|
3608
|
+
}
|
|
3609
|
+
}
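In IQ4_NL the 4-bit code is not a biased integer as in Q4_0 but an index into the non-linear `kvalues_iq4nl` codebook, which is why the NEON path uses a table lookup (`vqtbl1q_s8`) and the scalar fallback above indexes `kvalues_iq4nl` directly. A scalar sketch of unpacking one packed byte; the table is passed in as a parameter here because its linkage outside this file is an assumption:

```c
#include <stdint.h>

// Each byte of block_iq4_nl.qs holds two codebook indices; the low nibble maps
// to weight i and the high nibble to weight i + 16 within the 32-weight block.
// The per-block scale d is applied after the lookup.
static inline void iq4_nl_unpack_byte(const int8_t lut[16], uint8_t packed,
                                      float d, float out[2]) {
    out[0] = d * (float) lut[packed & 0x0F];
    out[1] = d * (float) lut[packed >> 4];
}
```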
|
|
3610
|
+
|
|
3611
|
+
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
|
|
3612
|
+
block_q4_0x4 out;
|
|
3613
|
+
|
|
3614
|
+
for (int i = 0; i < 4; i++) {
|
|
3615
|
+
out.d[i] = in[i].d;
|
|
3616
|
+
}
|
|
3617
|
+
|
|
3618
|
+
const int end = QK4_0 * 2 / blck_size_interleave;
|
|
3619
|
+
|
|
3620
|
+
if (blck_size_interleave == 8) {
|
|
3621
|
+
const uint64_t xor_mask = 0x8888888888888888ULL;
|
|
3622
|
+
for (int i = 0; i < end; ++i) {
|
|
3623
|
+
int src_id = i % 4;
|
|
3624
|
+
int src_offset = (i / 4) * blck_size_interleave;
|
|
3625
|
+
int dst_offset = i * blck_size_interleave;
|
|
3626
|
+
|
|
3627
|
+
uint64_t elems;
|
|
3628
|
+
// Using memcpy to avoid unaligned memory accesses
|
|
3629
|
+
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
|
3630
|
+
elems ^= xor_mask;
|
|
3631
|
+
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
|
3632
|
+
}
|
|
3633
|
+
} else if (blck_size_interleave == 4) {
|
|
3634
|
+
const uint32_t xor_mask = 0x88888888;
|
|
3635
|
+
for (int i = 0; i < end; ++i) {
|
|
3636
|
+
int src_id = i % 4;
|
|
3637
|
+
int src_offset = (i / 4) * blck_size_interleave;
|
|
3638
|
+
int dst_offset = i * blck_size_interleave;
|
|
3639
|
+
|
|
3640
|
+
uint32_t elems;
|
|
3641
|
+
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
|
|
3642
|
+
elems ^= xor_mask;
|
|
3643
|
+
memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
|
|
3644
|
+
}
|
|
3645
|
+
} else {
|
|
3646
|
+
GGML_ASSERT(false);
|
|
3647
|
+
}
|
|
3648
|
+
|
|
3649
|
+
return out;
|
|
3650
|
+
}
|
|
3651
|
+
|
|
3652
|
+
// interleave 8 block_q4_0s in blocks of blck_size_interleave
|
|
3653
|
+
// returns an interleaved block_q4_0x8
|
|
3654
|
+
// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
|
|
3655
|
+
// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
|
|
3656
|
+
static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
|
|
3657
|
+
block_q4_0x8 out;
|
|
3658
|
+
|
|
3659
|
+
for (int i = 0; i < 8; i++) {
|
|
3660
|
+
out.d[i] = in[i].d;
|
|
3661
|
+
}
|
|
3662
|
+
|
|
3663
|
+
const int end = QK4_0 * 4 / blck_size_interleave;
|
|
3664
|
+
const uint64_t xor_mask = 0x8888888888888888ULL;
|
|
3665
|
+
|
|
3666
|
+
for (int i = 0; i < end; ++i) {
|
|
3667
|
+
int src_id = i % 8;
|
|
3668
|
+
int src_offset = (i / 8) * blck_size_interleave;
|
|
3669
|
+
int dst_offset = i * blck_size_interleave;
|
|
3670
|
+
|
|
3671
|
+
uint64_t elems;
|
|
3672
|
+
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
|
3673
|
+
elems ^= xor_mask;
|
|
3674
|
+
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
|
3675
|
+
}
|
|
3676
|
+
|
|
3677
|
+
return out;
|
|
3678
|
+
}
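Both `make_block_q4_0x4` and `make_block_q4_0x8` XOR the interleaved quants with `0x88` per byte. Q4_0 stores each weight as an unsigned nibble `q` in [0, 15] meaning `q - 8`; flipping each nibble's top bit turns that biased value into the same number in 4-bit two's complement, so kernels consuming the repacked blocks can sign-extend nibbles (as the `signextendlut` shuffles earlier in this file do) instead of subtracting 8 per element. A one-nibble sketch of the equivalence (the helper name is mine):

```c
#include <assert.h>
#include <stdint.h>

// (q ^ 8), read as a signed 4-bit value, equals q - 8 for every q in [0, 15].
static inline int8_t q4_0_nibble_value(uint8_t q) {
    assert(q < 16);
    int8_t nib = (int8_t) (uint8_t) ((q ^ 8) << 4);  // move the nibble to the top bits
    return nib >> 4;                                 // arithmetic shift sign-extends it
}
```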
|
|
3679
|
+
|
|
3680
|
+
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
|
3681
|
+
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
|
3682
|
+
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
|
3683
|
+
constexpr int nrows_interleaved = 4;
|
|
3684
|
+
|
|
3685
|
+
block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
|
|
3686
|
+
const block_q4_0 * src = (const block_q4_0 *)data;
|
|
3687
|
+
block_q4_0 dst_tmp[4];
|
|
3688
|
+
int nrow = ggml_nrows(t);
|
|
3689
|
+
int nblocks = t->ne[0] / QK4_0;
|
|
3690
|
+
|
|
3691
|
+
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
|
|
3692
|
+
|
|
3693
|
+
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
|
3694
|
+
return -1;
|
|
3695
|
+
}
|
|
3696
|
+
|
|
3697
|
+
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
|
3698
|
+
for (int64_t x = 0; x < nblocks; x++) {
|
|
3699
|
+
for (int i = 0; i < nrows_interleaved; i++) {
|
|
3700
|
+
dst_tmp[i] = src[x + i * nblocks];
|
|
3701
|
+
}
|
|
3702
|
+
*dst++ = make_block_q4_0x4(dst_tmp, interleave_block);
|
|
3703
|
+
}
|
|
3704
|
+
src += nrows_interleaved * nblocks;
|
|
3705
|
+
}
|
|
3706
|
+
return 0;
|
|
3707
|
+
|
|
3708
|
+
GGML_UNUSED(data_size);
|
|
3709
|
+
}
|
|
3710
|
+
|
|
3711
|
+
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
|
3712
|
+
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
|
3713
|
+
GGML_ASSERT(interleave_block == 8);
|
|
3714
|
+
constexpr int nrows_interleaved = 8;
|
|
3715
|
+
|
|
3716
|
+
block_q4_0x8 * dst = (block_q4_0x8*)t->data;
|
|
3717
|
+
const block_q4_0 * src = (const block_q4_0*) data;
|
|
3718
|
+
block_q4_0 dst_tmp[8];
|
|
3719
|
+
int nrow = ggml_nrows(t);
|
|
3720
|
+
int nblocks = t->ne[0] / QK4_0;
|
|
3721
|
+
|
|
3722
|
+
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
|
|
3723
|
+
|
|
3724
|
+
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
|
3725
|
+
return -1;
|
|
3726
|
+
}
|
|
3727
|
+
|
|
3728
|
+
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
|
3729
|
+
for (int64_t x = 0; x < nblocks; x++) {
|
|
3730
|
+
for (int i = 0; i < nrows_interleaved; i++ ) {
|
|
3731
|
+
dst_tmp[i] = src[x + i * nblocks];
|
|
3732
|
+
}
|
|
3733
|
+
*dst++ = make_block_q4_0x8(dst_tmp, interleave_block);
|
|
3734
|
+
}
|
|
3735
|
+
src += nrows_interleaved * nblocks;
|
|
3736
|
+
}
|
|
3737
|
+
return 0;
|
|
3738
|
+
|
|
3739
|
+
GGML_UNUSED(data_size);
|
|
3740
|
+
}
|
|
3741
|
+
|
|
3742
|
+
static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
|
|
3743
|
+
block_iq4_nlx4 out;
|
|
3744
|
+
|
|
3745
|
+
for (int i = 0; i < 4; i++) {
|
|
3746
|
+
out.d[i] = in[i].d;
|
|
3747
|
+
}
|
|
3748
|
+
|
|
3749
|
+
const int end = QK4_NL * 2 / blck_size_interleave;
|
|
3750
|
+
|
|
3751
|
+
// TODO: this branch seems wrong
|
|
3752
|
+
//if (blck_size_interleave == 8) {
|
|
3753
|
+
// for (int i = 0; i < end; ++i) {
|
|
3754
|
+
// int src_id = i % 4;
|
|
3755
|
+
// int src_offset = (i / 4) * blck_size_interleave;
|
|
3756
|
+
// int dst_offset = i * blck_size_interleave;
|
|
3757
|
+
|
|
3758
|
+
// // Using memcpy to avoid unaligned memory accesses
|
|
3759
|
+
// memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
|
|
3760
|
+
// }
|
|
3761
|
+
//} else
|
|
3762
|
+
if (blck_size_interleave == 4) {
|
|
3763
|
+
for (int i = 0; i < end; ++i) {
|
|
3764
|
+
int src_id = i % 4;
|
|
3765
|
+
int src_offset = (i / 4) * blck_size_interleave;
|
|
3766
|
+
int dst_offset = i * blck_size_interleave;
|
|
3767
|
+
|
|
3768
|
+
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
|
|
3769
|
+
}
|
|
3770
|
+
} else {
|
|
3771
|
+
GGML_ASSERT(false);
|
|
3772
|
+
}
|
|
3773
|
+
|
|
3774
|
+
return out;
|
|
3775
|
+
}
|
|
3776
|
+
|
|
3777
|
+
static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
|
3778
|
+
GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
|
|
3779
|
+
//GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
|
3780
|
+
GGML_ASSERT(interleave_block == 4);
|
|
3781
|
+
|
|
3782
|
+
block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
|
|
3783
|
+
const block_iq4_nl * src = (const block_iq4_nl *)data;
|
|
3784
|
+
block_iq4_nl dst_tmp[4];
|
|
3785
|
+
int nrow = ggml_nrows(t);
|
|
3786
|
+
int nrows_interleaved = 4;
|
|
3787
|
+
int nblocks = t->ne[0] / QK4_0;
|
|
3788
|
+
|
|
3789
|
+
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
|
|
3790
|
+
|
|
3791
|
+
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
|
3792
|
+
return -1;
|
|
3793
|
+
}
|
|
3794
|
+
|
|
3795
|
+
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
|
3796
|
+
for (int64_t x = 0; x < nblocks; x++) {
|
|
3797
|
+
for (int i = 0; i < nrows_interleaved; i++) {
|
|
3798
|
+
dst_tmp[i] = src[x + i * nblocks];
|
|
3799
|
+
}
|
|
3800
|
+
*dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
|
|
3801
|
+
}
|
|
3802
|
+
src += nrows_interleaved * nblocks;
|
|
3803
|
+
}
|
|
3804
|
+
return 0;
|
|
3805
|
+
|
|
3806
|
+
GGML_UNUSED(data_size);
|
|
3807
|
+
}
|
|
3808
|
+
|
|
3809
|
+
namespace ggml::cpu::aarch64 {
|
|
3810
|
+
// repack
|
|
3811
|
+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
|
3812
|
+
int repack(struct ggml_tensor *, const void *, size_t);
|
|
3813
|
+
|
|
3814
|
+
// TODO: generalise.
|
|
3815
|
+
template <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
3816
|
+
return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
|
|
3817
|
+
}
|
|
3818
|
+
|
|
3819
|
+
template <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
3820
|
+
return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
|
|
3821
|
+
}
|
|
3822
|
+
|
|
3823
|
+
template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
3824
|
+
return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
|
|
3825
|
+
}
|
|
3826
|
+
|
|
3827
|
+
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
3828
|
+
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
|
3829
|
+
}
|
|
3830
|
+
|
|
3831
|
+
// TODO: needs to be revisited
|
|
3832
|
+
//template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
|
3833
|
+
// return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
|
|
3834
|
+
//}
|
|
3835
|
+
|
|
3836
|
+
// gemv
|
|
3837
|
+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
|
3838
|
+
void gemv(int, float *, size_t, const void *, const void *, int, int);
|
|
3839
|
+
|
|
3840
|
+
template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3841
|
+
ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3842
|
+
}
|
|
3843
|
+
|
|
3844
|
+
template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3845
|
+
ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3846
|
+
}
|
|
3847
|
+
|
|
3848
|
+
template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3849
|
+
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3850
|
+
}
|
|
3851
|
+
|
|
3852
|
+
template <>
|
|
3853
|
+
void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3854
|
+
ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3855
|
+
}
|
|
3856
|
+
|
|
3857
|
+
// gemm
|
|
3858
|
+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
|
|
3859
|
+
void gemm(int, float *, size_t, const void *, const void *, int, int);
|
|
3860
|
+
|
|
3861
|
+
template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3862
|
+
ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3863
|
+
}
|
|
3864
|
+
|
|
3865
|
+
template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3866
|
+
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3867
|
+
}
|
|
3868
|
+
|
|
3869
|
+
template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3870
|
+
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3871
|
+
}
|
|
3872
|
+
|
|
3873
|
+
template <>
|
|
3874
|
+
void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
|
3875
|
+
ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
|
|
3876
|
+
}
|
|
3877
|
+
|
|
3878
|
+
class tensor_traits_base : public ggml::cpu::tensor_traits {
|
|
3879
|
+
public:
|
|
3880
|
+
virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
|
|
3881
|
+
};
|
|
3882
|
+
|
|
3883
|
+
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
|
|
3884
|
+
|
|
3885
|
+
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
|
3886
|
+
// not really a GGML_TYPE_Q8_0 but same size.
|
|
3887
|
+
switch (op->op) {
|
|
3888
|
+
case GGML_OP_MUL_MAT:
|
|
3889
|
+
size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
|
|
3890
|
+
return true;
|
|
3891
|
+
case GGML_OP_MUL_MAT_ID:
|
|
3892
|
+
size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
|
|
3893
|
+
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
|
|
3894
|
+
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
|
|
3895
|
+
return true;
|
|
3896
|
+
default:
|
|
3897
|
+
// GGML_ABORT("fatal error");
|
|
3898
|
+
break;
|
|
3899
|
+
}
|
|
3900
|
+
return false;
|
|
3901
|
+
}
|
|
3902
|
+
|
|
3903
|
+
+    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
+        switch (op->op) {
+            case GGML_OP_MUL_MAT:
+                forward_mul_mat(params, op);
+                return true;
+            case GGML_OP_MUL_MAT_ID:
+                forward_mul_mat_id(params, op);
+                return true;
+            default:
+                // GGML_ABORT("fatal error");
+                break;
+        }
+        return false;
+    }
+
+    void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
+        const ggml_tensor * src0 = op->src[0];
+        const ggml_tensor * src1 = op->src[1];
+        ggml_tensor * dst = op;
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        GGML_ASSERT(ne0 == ne01);
+        GGML_ASSERT(ne1 == ne11);
+        GGML_ASSERT(ne2 == ne12);
+        GGML_ASSERT(ne3 == ne13);
+
+        // dst cannot be transposed or permuted
+        GGML_ASSERT(nb0 == sizeof(float));
+        GGML_ASSERT(nb0 <= nb1);
+        GGML_ASSERT(nb1 <= nb2);
+        GGML_ASSERT(nb2 <= nb3);
+
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
+        // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
+
+        char * wdata = static_cast<char *>(params->wdata);
+        const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+
+        assert(params->wsize >= nbw1 * ne11);
+
+        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
+
+        int64_t i11_processed = 0;
+        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+            quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
+                              INTER_SIZE);
+        }
+        i11_processed = ne11 - ne11 % 4;
+        for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+            from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        const void * src1_wdata = params->wdata;
+        const size_t src1_col_stride = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+        int64_t src0_start = (ith * ne01) / nth;
+        int64_t src0_end = ((ith + 1) * ne01) / nth;
+        src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+        src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
+        if (src0_start >= src0_end) {
+            return;
+        }
+
+        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+        if (ne11 > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
+                                                 (const char *) src0->data + src0_start * nb01,
+                                                 (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        }
+        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                                                 (const char *) src0->data + src0_start * nb01,
+                                                 (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                                                 src0_end - src0_start);
+        }
+    }
+
+    void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
+        const ggml_tensor * src0 = op->src[0];
+        const ggml_tensor * src1 = op->src[1];
+        const ggml_tensor * ids = op->src[2];
+        ggml_tensor * dst = op;
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
+
+        // we don't support permuted src0 or src1
+        GGML_ASSERT(nb00 == ggml_type_size(src0->type));
+        GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+        // dst cannot be transposed or permuted
+        GGML_ASSERT(nb0 == sizeof(float));
+        GGML_ASSERT(nb0 <= nb1);
+        GGML_ASSERT(nb1 <= nb2);
+        GGML_ASSERT(nb2 <= nb3);
+
+        GGML_ASSERT(ne03 == 1);
+        GGML_ASSERT(ne13 == 1);
+        GGML_ASSERT(ne3 == 1);
+
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        // row groups
+        const int n_ids = ids->ne[0]; // n_expert_used
+        const int n_as = ne02;        // n_expert
+
+        const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+        const size_t nbw2 = nbw1 * ne11;
+        const size_t nbw3 = nbw2 * ne12;
+
+        struct mmid_row_mapping {
+            int32_t i1;
+            int32_t i2;
+        };
+
+        GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
+                                      n_as * ne12 * sizeof(mmid_row_mapping)));
+
+        auto wdata = (char *) params->wdata;
+        auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
+        int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
+
+        // src1: float32 => block_q8_0
+        for (int64_t i12 = 0; i12 < ne12; ++i12) {
+            for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                from_float((float *) ((char *) src1->data + i12 * nb12 + i11 * nb11),
+                           (void *) (wdata + i12 * nbw2 + i11 * nbw1),
+                           ne10);
+            }
+        }
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
+
+        if (ith == 0) {
+            // initialize matrix_row_counts
+            memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
+
+            // group rows by src0 matrix
+            for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+                for (int32_t id = 0; id < n_ids; ++id) {
+                    const int32_t i02 =
+                        *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+
+                    GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                    MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
+                    matrix_row_counts[i02] += 1;
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
+        // compute each matrix multiplication in sequence
+        for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+            const int64_t cne1 = matrix_row_counts[cur_a];
+
+            if (cne1 == 0) {
+                continue;
+            }
+
+            auto src0_cur = (const char *) src0->data + cur_a * nb02;
+
+            //const int64_t nr0 = ne01; // src0 rows
+            const int64_t nr1 = cne1; // src1 rows
+
+            int64_t src0_cur_start = (ith * ne01) / nth;
+            int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
+            src0_cur_start =
+                (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
+            src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+
+            if (src0_cur_start >= src0_cur_end) return;
+
+            for (int ir1 = 0; ir1 < nr1; ir1++) {
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
+                const int id = row_mapping.i1; // selected expert index
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = row_mapping.i2; // row index in src1
+
+                const int64_t i1 = id;  // selected expert index
+                const int64_t i2 = i12; // row
+
+                auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+
+                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
+                    ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
+                    ne01, src0_cur + src0_cur_start * nb01,
+                    src1_col, 1, src0_cur_end - src0_cur_start);
+            }
+        }
+#undef MMID_MATRIX_ROW
+    }
+
+    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
+        GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
+                       (int) NB_COLS, (int) INTER_SIZE);
+        return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+    }
+};
+
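In both forward_mul_mat and forward_mul_mat_id above, the src0 rows are split evenly across threads and each boundary is then rounded up to the next multiple of NB_COLS, so every thread only ever touches whole interleaved column groups; a thread whose rounded range comes out empty simply does no work. A minimal editorial sketch of that partitioning in plain C++ (not part of the diff; the numbers in main are arbitrary):

#include <cstdint>
#include <cstdio>

static void partition_rows(int64_t ne01, int nth, int64_t nb_cols) {
    for (int ith = 0; ith < nth; ++ith) {
        int64_t start = (ith * ne01) / nth;
        int64_t end   = ((ith + 1) * ne01) / nth;
        // round both ends up to a multiple of nb_cols, mirroring the code above
        start = (start % nb_cols) ? start + nb_cols - (start % nb_cols) : start;
        end   = (end   % nb_cols) ? end   + nb_cols - (end   % nb_cols) : end;
        if (start >= end) {
            continue; // this thread has no whole column group left to process
        }
        std::printf("thread %d: rows [%lld, %lld)\n", ith, (long long) start, (long long) end);
    }
}

int main() {
    partition_rows(/*ne01=*/4096, /*nth=*/3, /*nb_cols=*/8); // e.g. the 8x8 Q4_0 layout
    return 0;
}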
+// instance for Q4
+static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
+static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
+static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
+
+// instance for IQ4
+static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
+
+}  // namespace ggml::cpu::aarch64
+
+static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
+    if (cur->type == GGML_TYPE_Q4_0) {
+        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
+            if (cur->ne[1] % 8 == 0) {
+                return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
+            }
+        }
+    } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
+            }
+        }
+    }
+
+    return nullptr;
+}
+
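ggml_aarch64_get_optimal_repack_type above picks the widest repacked layout the CPU can actually exploit, and only when the tensor's column count ne[1] is divisible by that layout's block width; otherwise it returns nullptr and the tensor is left in its regular layout. A rough editorial sketch of the Q4_0 branch (not part of the diff), with hypothetical booleans standing in for the ggml_cpu_has_* feature probes:

#include <cstdint>
#include <cstdio>

static const char * pick_q4_0_layout(bool has_avx2_or_sve256_i8mm, bool has_neon_i8mm,
                                     bool has_neon_dotprod, int64_t ne1) {
    // widest usable layout wins; each also requires a compatible column count
    if (has_avx2_or_sve256_i8mm && ne1 % 8 == 0) return "q4_0_8x8_q8_0";
    if (has_neon_i8mm && ne1 % 4 == 0)           return "q4_0_4x8_q8_0";
    if (has_neon_dotprod && ne1 % 4 == 0)        return "q4_0_4x4_q8_0";
    return "no repacking (regular Q4_0 kernels)";
}

int main() {
    // e.g. a NEON CPU with i8mm and dotprod but no AVX2 / 256-bit SVE
    std::printf("%s\n", pick_q4_0_layout(false, true, true, 4096));
    return 0;
}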
+static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_aarch64_get_optimal_repack_type(tensor));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+                                                       const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
+    auto OK = tensor_traits->repack(tensor, data, size);
+
+    GGML_ASSERT(OK == 0);
+    GGML_UNUSED(buffer);
+}
+
+static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_AARCH64";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+
+    if (buffer == nullptr) {
+        return nullptr;
+    }
+
+    buffer->buft = buft;
+    buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
+    buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor;
+    return buffer;
+}
+
+static size_t ggml_backend_cpu_aarch64_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    GGML_UNUSED(buft);
+}
+
+namespace ggml::cpu::aarch64 {
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT &&
+            op->src[0]->buffer &&
+            (ggml_n_dims(op->src[0]) == 2) &&
+            op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() &&
+            ggml_aarch64_get_optimal_repack_type(op->src[0])
+            ) {
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_F32) {
+                return true;
+            }
+            //if (op->src[1]->type == GGML_TYPE_Q8_0) {
+            //    return true;
+            //}
+            // may be possible if Q8_0 packed...
+        } else if (op->op == GGML_OP_MUL_MAT_ID
+                   && op->src[0]->buffer
+                   && (ggml_n_dims(op->src[0]) == 3)
+                   && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()
+                   && ggml_aarch64_get_optimal_repack_type(op->src[0])
+                   ) {
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_F32) {
+                return true;
+            }
+            //if (op->src[1]->type == GGML_TYPE_Q8_0) {
+            //    return true;
+            //}
+        }
+        return false;
+    }
+
+    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
+            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()) {
+                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+            }
+        }
+        return nullptr;
+    }
+};
+}  // namespace ggml::cpu::aarch64
+
+ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = {
+        /* .iface    = */ {
+            /* .get_name         = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_aarch64_buffer_type_get_alignment,
+            /* .get_max_size     = */ nullptr, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ nullptr, // defaults to ggml_nbytes
+            /* .is_host          = */ nullptr,
+        },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ new ggml::cpu::aarch64::extra_buffer_type(),
+    };
+
+    return &ggml_backend_cpu_buffer_type_aarch64;
+}