cui-llama.rn 1.7.4 → 1.7.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +217 -17
- package/android/src/main/CMakeLists.txt +34 -15
- package/android/src/main/java/com/rnllama/LlamaContext.java +79 -5
- package/android/src/main/java/com/rnllama/RNLlama.java +237 -0
- package/android/src/main/jni.cpp +213 -14
- package/android/src/main/jniLibs/arm64-v8a/librnllama.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_dotprod_i8mm.so +0 -0
- package/android/src/main/jniLibs/arm64-v8a/librnllama_v8_2_i8mm.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama.so +0 -0
- package/android/src/main/jniLibs/x86_64/librnllama_x86_64.so +0 -0
- package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +35 -0
- package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +34 -0
- package/cpp/README.md +1 -1
- package/cpp/chat-parser.cpp +385 -0
- package/cpp/chat-parser.h +120 -0
- package/cpp/chat.cpp +726 -596
- package/cpp/chat.h +71 -6
- package/cpp/common.cpp +56 -38
- package/cpp/common.h +9 -3
- package/cpp/ggml-backend-reg.cpp +5 -0
- package/cpp/ggml-backend.cpp +10 -2
- package/cpp/ggml-common.h +4 -0
- package/cpp/ggml-cpu/amx/amx.cpp +1 -1
- package/cpp/ggml-cpu/amx/mmq.cpp +11 -10
- package/cpp/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +4114 -0
- package/cpp/ggml-cpu/arch/arm/repack.cpp +2163 -0
- package/cpp/ggml-cpu/arch/x86/cpu-feats.cpp +327 -0
- package/cpp/ggml-cpu/arch/x86/quants.c +4311 -0
- package/cpp/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
- package/cpp/ggml-cpu/arch-fallback.h +184 -0
- package/cpp/ggml-cpu/common.h +4 -3
- package/cpp/ggml-cpu/ggml-cpu-impl.h +21 -16
- package/cpp/ggml-cpu/ggml-cpu.c +123 -104
- package/cpp/ggml-cpu/ggml-cpu.cpp +11 -8
- package/cpp/ggml-cpu/ops.cpp +330 -148
- package/cpp/ggml-cpu/ops.h +1 -0
- package/cpp/ggml-cpu/quants.c +1158 -0
- package/cpp/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
- package/cpp/ggml-cpu/repack.cpp +1571 -0
- package/cpp/ggml-cpu/repack.h +98 -0
- package/cpp/ggml-cpu/simd-mappings.h +330 -38
- package/cpp/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
- package/cpp/ggml-cpu/vec.cpp +87 -18
- package/cpp/ggml-cpu/vec.h +249 -94
- package/cpp/ggml-cpu.h +1 -0
- package/cpp/ggml-impl.h +63 -183
- package/cpp/ggml-llama-sim.metallib +0 -0
- package/cpp/ggml-llama.metallib +0 -0
- package/cpp/ggml-metal.m +152 -45
- package/cpp/ggml-quants.c +0 -2
- package/cpp/ggml.c +61 -21
- package/cpp/ggml.h +22 -3
- package/cpp/gguf.cpp +24 -3
- package/cpp/json-partial.cpp +256 -0
- package/cpp/json-partial.h +38 -0
- package/cpp/json-schema-to-grammar.cpp +5 -47
- package/cpp/json-schema-to-grammar.h +4 -4
- package/cpp/llama-arch.cpp +153 -3
- package/cpp/llama-arch.h +27 -1
- package/cpp/llama-batch.cpp +741 -272
- package/cpp/llama-batch.h +112 -54
- package/cpp/llama-chat.cpp +30 -8
- package/cpp/llama-chat.h +1 -0
- package/cpp/llama-context.cpp +524 -339
- package/cpp/llama-context.h +38 -17
- package/cpp/llama-cparams.cpp +4 -0
- package/cpp/llama-cparams.h +2 -0
- package/cpp/llama-grammar.cpp +12 -2
- package/cpp/llama-graph.cpp +431 -356
- package/cpp/llama-graph.h +126 -58
- package/cpp/llama-hparams.cpp +10 -2
- package/cpp/llama-hparams.h +19 -2
- package/cpp/llama-kv-cache-unified-iswa.cpp +279 -0
- package/cpp/llama-kv-cache-unified-iswa.h +128 -0
- package/cpp/llama-kv-cache-unified.cpp +1841 -0
- package/cpp/llama-kv-cache-unified.h +303 -0
- package/cpp/llama-kv-cells.h +439 -0
- package/cpp/llama-memory-hybrid.cpp +246 -0
- package/cpp/llama-memory-hybrid.h +138 -0
- package/cpp/llama-memory-recurrent.cpp +1112 -0
- package/cpp/llama-memory-recurrent.h +183 -0
- package/cpp/llama-memory.cpp +41 -0
- package/cpp/llama-memory.h +86 -5
- package/cpp/llama-mmap.cpp +1 -1
- package/cpp/llama-model-loader.cpp +42 -17
- package/cpp/llama-model-saver.cpp +1 -0
- package/cpp/llama-model.cpp +1639 -513
- package/cpp/llama-model.h +26 -0
- package/cpp/llama-sampling.cpp +2 -2
- package/cpp/llama-vocab.cpp +65 -28
- package/cpp/llama-vocab.h +1 -0
- package/cpp/llama.cpp +11 -7
- package/cpp/llama.h +150 -42
- package/cpp/minja/chat-template.hpp +1 -1
- package/cpp/minja/minja.hpp +1 -1
- package/cpp/{json.hpp → nlohmann/json.hpp} +3027 -2267
- package/cpp/nlohmann/json_fwd.hpp +187 -0
- package/cpp/regex-partial.cpp +204 -0
- package/cpp/regex-partial.h +56 -0
- package/cpp/rn-llama.cpp +646 -35
- package/cpp/rn-llama.h +32 -1
- package/cpp/rn-tts.h +39 -0
- package/cpp/sampling.cpp +7 -8
- package/cpp/tools/mtmd/clip-impl.h +5 -0
- package/cpp/tools/mtmd/clip.cpp +572 -436
- package/cpp/tools/mtmd/clip.h +14 -4
- package/cpp/tools/mtmd/mtmd-audio.cpp +0 -86
- package/cpp/tools/mtmd/mtmd-audio.h +2 -17
- package/cpp/tools/mtmd/mtmd-helper.cpp +175 -12
- package/cpp/tools/mtmd/mtmd-helper.h +91 -0
- package/cpp/tools/mtmd/mtmd.cpp +368 -248
- package/cpp/tools/mtmd/mtmd.h +6 -70
- package/cpp/unicode.cpp +5 -0
- package/ios/CMakeLists.txt +26 -6
- package/ios/RNLlama.h +1 -1
- package/ios/RNLlama.mm +153 -3
- package/ios/RNLlamaContext.h +9 -1
- package/ios/RNLlamaContext.mm +112 -9
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/{json.hpp → nlohmann/json.hpp} +3027 -2267
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/{tvos-arm64/rnllama.framework/Headers → ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/{ios-arm64_x86_64-simulator/rnllama.framework/Headers → tvos-arm64/rnllama.framework/Headers/nlohmann}/json.hpp +3027 -2267
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/ggml-llama.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/rnllama +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat-parser.h +120 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/chat.h +71 -6
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/common.h +9 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-common.h +4 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-cpu.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml-impl.h +63 -183
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/ggml.h +22 -3
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-partial.h +38 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json-schema-to-grammar.h +4 -4
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-arch.h +27 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-batch.h +112 -54
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-chat.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-context.h +38 -17
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-cparams.h +2 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-graph.h +126 -58
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-hparams.h +19 -2
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified-iswa.h +128 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache-unified.h +303 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cells.h +439 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-hybrid.h +138 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory-recurrent.h +183 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-memory.h +86 -5
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-model.h +26 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-vocab.h +1 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama.h +150 -42
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/chat-template.hpp +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/minja/minja.hpp +1 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json.hpp +25526 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/nlohmann/json_fwd.hpp +187 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/regex-partial.h +56 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-llama.h +32 -1
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/rn-tts.h +39 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/ggml-llama-sim.metallib +0 -0
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/rnllama +0 -0
- package/jest/mock.js +24 -0
- package/package.json +1 -1
- package/src/NativeRNLlama.ts +46 -2
- package/src/index.ts +105 -1
- package/cpp/ggml-cpu/ggml-cpu-aarch64.h +0 -8
- package/cpp/ggml-cpu/ggml-cpu-quants.c +0 -13326
- package/cpp/ggml-cpu/sgemm.cpp +0 -3544
- package/cpp/ggml-cpu/sgemm.h +0 -14
- package/cpp/llama-kv-cache.cpp +0 -2827
- package/cpp/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/ios-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/ios-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/tvos-arm64/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/json.hpp +0 -24766
- package/ios/rnllama.xcframework/tvos-arm64_x86_64-simulator/rnllama.framework/Headers/llama-kv-cache.h +0 -515
- /package/cpp/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
- /package/cpp/tools/mtmd/{miniaudio.h → miniaudio/miniaudio.h} +0 -0
- /package/cpp/tools/mtmd/{stb_image.h → stb/stb_image.h} +0 -0
@@ -0,0 +1,4114 @@
|
|
1
|
+
#define LM_GGML_COMMON_IMPL_C
|
2
|
+
#include "ggml-common.h"
|
3
|
+
#include "ggml-quants.h"
|
4
|
+
#include "ggml-impl.h"
|
5
|
+
#include "ggml-cpu.h"
|
6
|
+
#include "simd-mappings.h"
|
7
|
+
|
8
|
+
#include "../../quants.h"
|
9
|
+
#include "../../ggml-cpu-impl.h"
|
10
|
+
|
11
|
+
#include <math.h>
|
12
|
+
#include <string.h>
|
13
|
+
#include <assert.h>
|
14
|
+
#include <float.h>
|
15
|
+
#include <stdlib.h> // for qsort
|
16
|
+
#include <stdio.h> // for LM_GGML_ASSERT
|
17
|
+
|
18
|
+
/*
 * Per-format epsilon thresholds. NOTE(review): none of these are referenced in
 * the portion of the file visible here; presumably they bound a group's
 * maximum magnitude in quantizers further down - confirm against their users.
 */
#define GROUP_MAX_EPS 1e-15f
#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
#define GROUP_MAX_EPS_IQ2_S 1e-8f
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f

/* Short local alias for the project's unused-value suppression macro. */
#define UNUSED LM_GGML_UNUSED

#if defined(__ARM_NEON)
/*
 * B1..B8 generate 8-byte hexadecimal literals by token pasting. Each level
 * appends one byte (the `lo` or `hi` pair of hex digits) to the accumulated
 * prefix `acc`, so in the 256 values produced by B8(lo, hi) - listed in index
 * order - byte k of entry b is `lo` when bit k of b is 0 and `hi` when it is 1.
 */
#define B1(lo, hi, acc) 0x ## acc ## lo , 0x ## acc ## hi
#define B2(lo, hi, acc) B1(lo, hi, acc ## lo), B1(lo, hi, acc ## hi)
#define B3(lo, hi, acc) B2(lo, hi, acc ## lo), B2(lo, hi, acc ## hi)
#define B4(lo, hi, acc) B3(lo, hi, acc ## lo), B3(lo, hi, acc ## hi)
#define B5(lo, hi, acc) B4(lo, hi, acc ## lo), B4(lo, hi, acc ## hi)
#define B6(lo, hi, acc) B5(lo, hi, acc ## lo), B5(lo, hi, acc ## hi)
#define B7(lo, hi, acc) B6(lo, hi, acc ## lo), B6(lo, hi, acc ## hi)
#define B8(lo, hi)      B7(lo, hi, lo), B7(lo, hi, hi)

/*
 * Lookup tables that spread the 8 bits of a byte b over 8 bytes:
 *   table_b2b_0[b]: byte k is 0x10 when bit k of b is set   -> ( b) << 4
 *   table_b2b_1[b]: byte k is 0x10 when bit k of b is clear -> (!b) << 4
 */
static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif
|
40
|
+
|
41
|
+
/*
 * Convert k floats at x into consecutive Q8_0 blocks at vy.
 *
 * For every group of 32 inputs the scale d = max|x| / 127 is stored as fp16
 * in .d, and each input is written to .qs as x * (1/d) rounded to the nearest
 * integer (the reciprocal is taken as 0 when d == 0). Without NEON the work
 * is delegated to the scalar quantize_row_q8_0_ref.
 *
 * Preconditions (asserted): QK8_0 == 32 and k is a multiple of QK8_0.
 */
void quantize_row_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
    assert(QK8_0 == 32);
    assert(k % QK8_0 == 0);
    const int nb = k / QK8_0;

    block_q8_0 * LM_GGML_RESTRICT y = vy;

#if defined(__ARM_NEON)
    for (int ib = 0; ib < nb; ++ib) {
        const float * src = x + ib*32;

        float32x4_t v[8];   // the block's 32 inputs, four lanes per vector
        float32x4_t a[8];   // |v|; reduced in place below
        for (int j = 0; j < 8; ++j) {
            v[j] = vld1q_f32(src + 4*j);
            a[j] = vabsq_f32(v[j]);
        }

        // pairwise max tree: 8 vectors -> 4 -> 2 -> 1, result ends up in a[0]
        for (int step = 1; step < 8; step <<= 1) {
            for (int j = 0; j + step < 8; j += 2*step) {
                a[j] = vmaxq_f32(a[j], a[j + step]);
            }
        }
        const float amax = vmaxvq_f32(a[0]);

        const float d  = amax / ((1 << 7) - 1);
        const float id = (d != 0.0f) ? 1.0f/d : 0.0f;   // all-zero block -> d == 0

        y[ib].d = LM_GGML_CPU_FP32_TO_FP16(d);

        for (int j = 0; j < 8; ++j) {
            // scale, then convert to int32 rounding to nearest; store each lane
            const int32x4_t q = vcvtnq_s32_f32(vmulq_n_f32(v[j], id));

            y[ib].qs[4*j + 0] = vgetq_lane_s32(q, 0);
            y[ib].qs[4*j + 1] = vgetq_lane_s32(q, 1);
            y[ib].qs[4*j + 2] = vgetq_lane_s32(q, 2);
            y[ib].qs[4*j + 3] = vgetq_lane_s32(q, 3);
        }
    }
#else
    LM_GGML_UNUSED(nb);
    // portable scalar fallback
    quantize_row_q8_0_ref(x, y, k);
#endif
}
|
84
|
+
|
85
|
+
/*
 * Quantize k floats from x into consecutive Q8_1 blocks at vy.
 *
 * As in Q8_0, each block of QK8_1 (= 32) values stores a scale
 * d = max|x| / 127 (as fp16, in .d) and the inputs rounded to x * (1/d) in .qs.
 * Q8_1 additionally stores .s = d * (sum of the block's quantized values), as fp16.
 * The NEON path computes this inline; other targets defer to
 * quantize_row_q8_1_ref.
 *
 * Preconditions (asserted): QK8_1 == 32 and k is a multiple of QK8_1.
 */
void quantize_row_q8_1(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
    // The NEON path handles exactly 8 x float32x4 = 32 values per block;
    // check that this matches the block size, as quantize_row_q8_0 does.
    assert(QK8_1 == 32);
    assert(k % QK8_1 == 0);
    const int nb = k / QK8_1;

    block_q8_1 * LM_GGML_RESTRICT y = vy;
#if defined(__ARM_NEON)
    for (int i = 0; i < nb; i++) {
        float32x4_t srcv [8];
        float32x4_t asrcv[8];
        float32x4_t amaxv[8];

        for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*QK8_1 + 4*j);
        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);

        // tree-reduce |x| down to one vector, then take the max across its lanes
        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);

        const float amax = vmaxvq_f32(amaxv[0]);

        const float d = amax / ((1 << 7) - 1);
        // an all-zero block yields d == 0; avoid dividing by it
        const float id = d ? 1.0f/d : 0.0f;

        y[i].d = LM_GGML_CPU_FP32_TO_FP16(d);

        // running per-lane sum of the quantized values, used for .s
        int32x4_t accv = vdupq_n_s32(0);

        for (int j = 0; j < 8; j++) {
            // scale, then convert to int32 rounding to nearest
            const float32x4_t v = vmulq_n_f32(srcv[j], id);
            const int32x4_t vi = vcvtnq_s32_f32(v);

            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);

            accv = vaddq_s32(accv, vi);
        }

        y[i].s = LM_GGML_CPU_FP32_TO_FP16(d * vaddvq_s32(accv));
    }
#else
    LM_GGML_UNUSED(nb);
    // scalar
    quantize_row_q8_1_ref(x, y, k);
#endif
}
|
132
|
+
|
133
|
+
// placeholder implementation for Apple targets
|
134
|
+
void quantize_row_q8_K(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k) {
|
135
|
+
quantize_row_q8_K_ref(x, y, k);
|
136
|
+
}
|
137
|
+
|
138
|
+
//===================================== Dot products =================================
|
139
|
+
|
140
|
+
void lm_ggml_vec_dot_q4_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
|
141
|
+
const int qk = QK8_0;
|
142
|
+
const int nb = n / qk;
|
143
|
+
|
144
|
+
assert(n % qk == 0);
|
145
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
146
|
+
assert((nrc == 2) || (nrc == 1));
|
147
|
+
#else
|
148
|
+
assert(nrc == 1);
|
149
|
+
#endif
|
150
|
+
UNUSED(nrc);
|
151
|
+
UNUSED(bx);
|
152
|
+
UNUSED(by);
|
153
|
+
UNUSED(bs);
|
154
|
+
|
155
|
+
const block_q4_0 * LM_GGML_RESTRICT x = vx;
|
156
|
+
const block_q8_0 * LM_GGML_RESTRICT y = vy;
|
157
|
+
|
158
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
159
|
+
if (nrc == 2) {
|
160
|
+
const block_q4_0 * LM_GGML_RESTRICT vx0 = vx;
|
161
|
+
const block_q4_0 * LM_GGML_RESTRICT vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
|
162
|
+
const block_q8_0 * LM_GGML_RESTRICT vy0 = vy;
|
163
|
+
const block_q8_0 * LM_GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
|
164
|
+
|
165
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
166
|
+
|
167
|
+
for (int i = 0; i < nb; i++) {
|
168
|
+
const block_q4_0 * LM_GGML_RESTRICT b_x0 = &vx0[i];
|
169
|
+
const block_q4_0 * LM_GGML_RESTRICT b_x1 = &vx1[i];
|
170
|
+
const block_q8_0 * LM_GGML_RESTRICT b_y0 = &vy0[i];
|
171
|
+
const block_q8_0 * LM_GGML_RESTRICT b_y1 = &vy1[i];
|
172
|
+
|
173
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
174
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
175
|
+
|
176
|
+
const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
|
177
|
+
const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
|
178
|
+
|
179
|
+
// 4-bit -> 8-bit
|
180
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
181
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
182
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
183
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
184
|
+
|
185
|
+
// sub 8
|
186
|
+
const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
|
187
|
+
const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
|
188
|
+
const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
|
189
|
+
const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
|
190
|
+
|
191
|
+
// load y
|
192
|
+
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
|
193
|
+
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
|
194
|
+
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
|
195
|
+
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
196
|
+
|
197
|
+
float32_t _scale[4] = {
|
198
|
+
LM_GGML_CPU_FP16_TO_FP32(b_x0->d)*LM_GGML_CPU_FP16_TO_FP32(b_y0->d),
|
199
|
+
LM_GGML_CPU_FP16_TO_FP32(b_x0->d)*LM_GGML_CPU_FP16_TO_FP32(b_y1->d),
|
200
|
+
LM_GGML_CPU_FP16_TO_FP32(b_x1->d)*LM_GGML_CPU_FP16_TO_FP32(b_y0->d),
|
201
|
+
LM_GGML_CPU_FP16_TO_FP32(b_x1->d)*LM_GGML_CPU_FP16_TO_FP32(b_y1->d)
|
202
|
+
};
|
203
|
+
float32x4_t scale = vld1q_f32(_scale);
|
204
|
+
|
205
|
+
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
206
|
+
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
207
|
+
|
208
|
+
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
|
209
|
+
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
|
210
|
+
|
211
|
+
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
|
212
|
+
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
|
213
|
+
|
214
|
+
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
215
|
+
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
216
|
+
|
217
|
+
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
218
|
+
l1, r1)), l2, r2)), l3, r3))), scale);
|
219
|
+
}
|
220
|
+
|
221
|
+
float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
|
222
|
+
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
223
|
+
|
224
|
+
vst1_f32(s, vget_low_f32 (sumv2));
|
225
|
+
vst1_f32(s + bs, vget_high_f32(sumv2));
|
226
|
+
|
227
|
+
return;
|
228
|
+
}
|
229
|
+
#endif
|
230
|
+
|
231
|
+
int ib = 0;
|
232
|
+
float sumf = 0;
|
233
|
+
|
234
|
+
#if defined(__ARM_FEATURE_SVE)
|
235
|
+
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
236
|
+
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
237
|
+
|
238
|
+
const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
|
239
|
+
|
240
|
+
// VLA Implementation using switch case
|
241
|
+
switch (vector_length) {
|
242
|
+
case 128:
|
243
|
+
{
|
244
|
+
// predicate for activating higher lanes for 4 float32 elements
|
245
|
+
const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
|
246
|
+
|
247
|
+
for (; ib + 1 < nb; ib += 2) {
|
248
|
+
const block_q4_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
|
249
|
+
const block_q4_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
|
250
|
+
const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
|
251
|
+
const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
|
252
|
+
|
253
|
+
// load x
|
254
|
+
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
|
255
|
+
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
|
256
|
+
|
257
|
+
// 4-bit -> 8-bit
|
258
|
+
const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
|
259
|
+
const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
|
260
|
+
const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
|
261
|
+
const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
|
262
|
+
|
263
|
+
// sub 8
|
264
|
+
const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
|
265
|
+
const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
|
266
|
+
const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
|
267
|
+
const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
|
268
|
+
|
269
|
+
// load y
|
270
|
+
const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
|
271
|
+
const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
|
272
|
+
const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
|
273
|
+
const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
|
274
|
+
|
275
|
+
// dot product
|
276
|
+
sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
|
277
|
+
svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
|
278
|
+
svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
|
279
|
+
sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
|
280
|
+
svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
|
281
|
+
svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
|
282
|
+
}
|
283
|
+
|
284
|
+
sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
285
|
+
} break;
|
286
|
+
case 256:
|
287
|
+
{
|
288
|
+
// predicate for activating higher lanes for 16 int8 elements
|
289
|
+
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
290
|
+
// predicate for activating lower lanes for 16 int8 elements
|
291
|
+
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
|
292
|
+
|
293
|
+
for (; ib + 1 < nb; ib += 2) {
|
294
|
+
const block_q4_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
|
295
|
+
const block_q4_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
|
296
|
+
const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
|
297
|
+
const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
|
298
|
+
|
299
|
+
// load x
|
300
|
+
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
|
301
|
+
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
|
302
|
+
|
303
|
+
// 4-bit -> 8-bit
|
304
|
+
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
|
305
|
+
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
|
306
|
+
|
307
|
+
// sub 8
|
308
|
+
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
|
309
|
+
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
|
310
|
+
|
311
|
+
// load y
|
312
|
+
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
313
|
+
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
314
|
+
|
315
|
+
// dot product
|
316
|
+
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
|
317
|
+
svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
|
318
|
+
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
|
319
|
+
svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
|
320
|
+
}
|
321
|
+
|
322
|
+
sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
323
|
+
} break;
|
324
|
+
case 512:
|
325
|
+
{
|
326
|
+
// predicate for activating higher lanes for 32 int8 elements
|
327
|
+
const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
|
328
|
+
|
329
|
+
// predicate for activating higher lanes for 16 int8 elements
|
330
|
+
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
331
|
+
// predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
|
332
|
+
const svbool_t pl16 = svnot_b_z(ph32, ph16);
|
333
|
+
|
334
|
+
for (; ib + 1 < nb; ib += 2) {
|
335
|
+
const block_q4_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
|
336
|
+
const block_q4_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
|
337
|
+
const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
|
338
|
+
const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
|
339
|
+
|
340
|
+
// load x
|
341
|
+
const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
|
342
|
+
const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
|
343
|
+
|
344
|
+
// 4-bit -> 8-bit
|
345
|
+
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
|
346
|
+
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
|
347
|
+
|
348
|
+
// sub 8
|
349
|
+
const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
|
350
|
+
const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
|
351
|
+
|
352
|
+
// load y
|
353
|
+
const svint8_t qy0 = svld1_s8(ph32, y0->qs);
|
354
|
+
const svint8_t qy1 = svld1_s8(ph32, y1->qs);
|
355
|
+
|
356
|
+
// dot product
|
357
|
+
sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
|
358
|
+
svdot_s32(svdup_n_s32(0), qx0s, qy0)), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
|
359
|
+
sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
|
360
|
+
svdot_s32(svdup_n_s32(0), qx1s, qy1)), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
|
361
|
+
}
|
362
|
+
|
363
|
+
sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
|
364
|
+
} break;
|
365
|
+
default:
|
366
|
+
assert(false && "Unsupported vector length");
|
367
|
+
break;
|
368
|
+
}
|
369
|
+
|
370
|
+
#elif defined(__ARM_NEON)
|
371
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
372
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
373
|
+
|
374
|
+
for (; ib + 1 < nb; ib += 2) {
|
375
|
+
const block_q4_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
|
376
|
+
const block_q4_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
|
377
|
+
const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
|
378
|
+
const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
|
379
|
+
|
380
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0F);
|
381
|
+
const int8x16_t s8b = vdupq_n_s8(0x8);
|
382
|
+
|
383
|
+
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
384
|
+
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
385
|
+
|
386
|
+
// 4-bit -> 8-bit
|
387
|
+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
|
388
|
+
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
389
|
+
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
|
390
|
+
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
391
|
+
|
392
|
+
// sub 8
|
393
|
+
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
394
|
+
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
395
|
+
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
396
|
+
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
397
|
+
|
398
|
+
// load y
|
399
|
+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
400
|
+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
401
|
+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
|
402
|
+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
|
403
|
+
|
404
|
+
// dot product into int32x4_t
|
405
|
+
const int32x4_t p_0 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
|
406
|
+
const int32x4_t p_1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
|
407
|
+
|
408
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
|
409
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
|
410
|
+
}
|
411
|
+
|
412
|
+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
413
|
+
#endif
|
414
|
+
for (; ib < nb; ++ib) {
|
415
|
+
int sumi0 = 0;
|
416
|
+
int sumi1 = 0;
|
417
|
+
|
418
|
+
for (int j = 0; j < qk/2; ++j) {
|
419
|
+
const int v0 = (x[ib].qs[j] & 0x0F) - 8;
|
420
|
+
const int v1 = (x[ib].qs[j] >> 4) - 8;
|
421
|
+
|
422
|
+
sumi0 += (v0 * y[ib].qs[j]);
|
423
|
+
sumi1 += (v1 * y[ib].qs[j + qk/2]);
|
424
|
+
}
|
425
|
+
|
426
|
+
int sumi = sumi0 + sumi1;
|
427
|
+
sumf += sumi*LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d);
|
428
|
+
}
|
429
|
+
|
430
|
+
*s = sumf;
|
431
|
+
}
|
432
|
+
|
433
|
+
/*
 * Dot product of a row of q4_1 blocks (vx) with a row of q8_1 blocks (vy).
 *
 * In every q4_1 block, byte j of qs holds two unsigned 4-bit values: the low
 * nibble pairs with y.qs[j] and the high nibble with y.qs[j + qk/2]. Each
 * block pair contributes
 *
 *     (x.d * y.d) * sum(q_x * q_y) + x.m * y.s
 *
 * where x.m * y.s folds in the q4_1 offset (y.s is a per-block field of the
 * q8_1 block; presumably y.d * sum(y.qs) -- not visible here, confirm).
 *
 * Parameters:
 *   n    element count; must be a multiple of QK8_1
 *   s    output: *s for nrc == 1; for nrc == 2 a 2x2 tile laid out as
 *        s[0] = x0.y0, s[1] = x1.y0, s[bs] = x0.y1, s[bs + 1] = x1.y1
 *   bs   offset between the two output pairs (nrc == 2 only)
 *   vx   first q4_1 row; bx is the byte offset to the second row (nrc == 2)
 *   vy   first q8_1 row; by is the byte offset to the second row (nrc == 2)
 *   nrc  1, or 2 when built with __ARM_FEATURE_MATMUL_INT8
 */
void lm_ggml_vec_dot_q4_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);
#if defined(__ARM_FEATURE_MATMUL_INT8)
    assert((nrc == 2) || (nrc == 1));
#else
    assert(nrc == 1);
#endif
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q4_1 * LM_GGML_RESTRICT x = vx;
    const block_q8_1 * LM_GGML_RESTRICT y = vy;

#if defined(__ARM_FEATURE_MATMUL_INT8)
    if (nrc == 2) {
        // Two x rows times two y rows: one 2x2 integer tile per block index via SMMLA.
        const block_q4_1 * LM_GGML_RESTRICT xr0 = vx;
        const block_q4_1 * LM_GGML_RESTRICT xr1 = (const block_q4_1 *) ((const uint8_t *) vx + bx);
        const block_q8_1 * LM_GGML_RESTRICT yr0 = vy;
        const block_q8_1 * LM_GGML_RESTRICT yr1 = (const block_q8_1 *) ((const uint8_t *) vy + by);

        const uint8x16_t nibble_mask = vdupq_n_u8(0x0F);

        float32x4_t tile = vdupq_n_f32(0.0f); // lanes: x0y0, x0y1, x1y0, x1y1 (SMMLA order)
        float32x4_t offs = vdupq_n_f32(0.0f); // lanes: x0y0, x1y0, x0y1, x1y1 (output order)

        for (int i = 0; i < nb; i++) {
            const block_q4_1 * LM_GGML_RESTRICT xa = &xr0[i];
            const block_q4_1 * LM_GGML_RESTRICT xb = &xr1[i];
            const block_q8_1 * LM_GGML_RESTRICT ya = &yr0[i];
            const block_q8_1 * LM_GGML_RESTRICT yb = &yr1[i];

            const float32_t mxa = LM_GGML_CPU_FP16_TO_FP32(xa->m);
            const float32_t mxb = LM_GGML_CPU_FP16_TO_FP32(xb->m);
            const float32_t sya = LM_GGML_CPU_FP16_TO_FP32(ya->s);
            const float32_t syb = LM_GGML_CPU_FP16_TO_FP32(yb->s);

            const float32_t off_lanes[4] = { mxa * sya, mxb * sya, mxa * syb, mxb * syb };
            offs = vaddq_f32(offs, vld1q_f32(off_lanes));

            const float32_t dxa = LM_GGML_CPU_FP16_TO_FP32(xa->d);
            const float32_t dxb = LM_GGML_CPU_FP16_TO_FP32(xb->d);
            const float32_t dya = LM_GGML_CPU_FP16_TO_FP32(ya->d);
            const float32_t dyb = LM_GGML_CPU_FP16_TO_FP32(yb->d);

            const float32_t scale_lanes[4] = { dxa * dya, dxa * dyb, dxb * dya, dxb * dyb };
            const float32x4_t scale = vld1q_f32(scale_lanes);

            const uint8x16_t pa = vld1q_u8(xa->qs);
            const uint8x16_t pb = vld1q_u8(xb->qs);

            // unpack nibbles and view every operand as two 8-byte halves
            const int64x2_t xa_lo = vreinterpretq_s64_s8(vreinterpretq_s8_u8(vandq_u8 (pa, nibble_mask)));
            const int64x2_t xa_hi = vreinterpretq_s64_s8(vreinterpretq_s8_u8(vshrq_n_u8(pa, 4)));
            const int64x2_t xb_lo = vreinterpretq_s64_s8(vreinterpretq_s8_u8(vandq_u8 (pb, nibble_mask)));
            const int64x2_t xb_hi = vreinterpretq_s64_s8(vreinterpretq_s8_u8(vshrq_n_u8(pb, 4)));

            const int64x2_t ya_lo = vreinterpretq_s64_s8(vld1q_s8(ya->qs));
            const int64x2_t ya_hi = vreinterpretq_s64_s8(vld1q_s8(ya->qs + 16));
            const int64x2_t yb_lo = vreinterpretq_s64_s8(vld1q_s8(yb->qs));
            const int64x2_t yb_hi = vreinterpretq_s64_s8(vld1q_s8(yb->qs + 16));

            // each SMMLA operand holds row a in its low 8 bytes and row b in its high 8 bytes
            int32x4_t isum = vdupq_n_s32(0);
            isum = vmmlaq_s32(isum, vreinterpretq_s8_s64(vzip1q_s64(xa_lo, xb_lo)), vreinterpretq_s8_s64(vzip1q_s64(ya_lo, yb_lo)));
            isum = vmmlaq_s32(isum, vreinterpretq_s8_s64(vzip2q_s64(xa_lo, xb_lo)), vreinterpretq_s8_s64(vzip2q_s64(ya_lo, yb_lo)));
            isum = vmmlaq_s32(isum, vreinterpretq_s8_s64(vzip1q_s64(xa_hi, xb_hi)), vreinterpretq_s8_s64(vzip1q_s64(ya_hi, yb_hi)));
            isum = vmmlaq_s32(isum, vreinterpretq_s8_s64(vzip2q_s64(xa_hi, xb_hi)), vreinterpretq_s8_s64(vzip2q_s64(ya_hi, yb_hi)));

            tile = vmlaq_f32(tile, vcvtq_f32_s32(isum), scale);
        }

        // reorder x0y0, x0y1, x1y0, x1y1 -> x0y0, x1y0, x0y1, x1y1 so each y row is contiguous
        const float32x4_t rotated = vextq_f32 (tile, tile, 2);
        float32x4_t out = vzip1q_f32(tile, rotated);

        out = vaddq_f32(out, offs);

        vst1_f32(s,      vget_low_f32 (out));
        vst1_f32(s + bs, vget_high_f32(out));

        return;
    }
#endif

    int ib = 0;
    float sumf = 0;

#if defined(__ARM_NEON)
    const uint8x16_t nibble_mask = vdupq_n_u8(0x0F);

    float32x4_t acc0 = vdupq_n_f32(0.0f);
    float32x4_t acc1 = vdupq_n_f32(0.0f);

    float offset_sum = 0;

    // two blocks per iteration, each into its own accumulator
    for (; ib + 1 < nb; ib += 2) {
        const block_q4_1 * LM_GGML_RESTRICT xa = &x[ib + 0];
        const block_q4_1 * LM_GGML_RESTRICT xb = &x[ib + 1];
        const block_q8_1 * LM_GGML_RESTRICT ya = &y[ib + 0];
        const block_q8_1 * LM_GGML_RESTRICT yb = &y[ib + 1];

        offset_sum += LM_GGML_CPU_FP16_TO_FP32(xa->m) * LM_GGML_CPU_FP16_TO_FP32(ya->s)
                    + LM_GGML_CPU_FP16_TO_FP32(xb->m) * LM_GGML_CPU_FP16_TO_FP32(yb->s);

        const uint8x16_t pa = vld1q_u8(xa->qs);
        const uint8x16_t pb = vld1q_u8(xb->qs);

        const int8x16_t qa_lo = vreinterpretq_s8_u8(vandq_u8 (pa, nibble_mask));
        const int8x16_t qa_hi = vreinterpretq_s8_u8(vshrq_n_u8(pa, 4));
        const int8x16_t qb_lo = vreinterpretq_s8_u8(vandq_u8 (pb, nibble_mask));
        const int8x16_t qb_hi = vreinterpretq_s8_u8(vshrq_n_u8(pb, 4));

        const int8x16_t ya_lo = vld1q_s8(ya->qs);
        const int8x16_t ya_hi = vld1q_s8(ya->qs + 16);
        const int8x16_t yb_lo = vld1q_s8(yb->qs);
        const int8x16_t yb_hi = vld1q_s8(yb->qs + 16);

        int32x4_t dot_a = lm_ggml_vdotq_s32(vdupq_n_s32(0), qa_lo, ya_lo);
        dot_a = lm_ggml_vdotq_s32(dot_a, qa_hi, ya_hi);

        int32x4_t dot_b = lm_ggml_vdotq_s32(vdupq_n_s32(0), qb_lo, yb_lo);
        dot_b = lm_ggml_vdotq_s32(dot_b, qb_hi, yb_hi);

        acc0 = vmlaq_n_f32(acc0, vcvtq_f32_s32(dot_a), LM_GGML_CPU_FP16_TO_FP32(xa->d)*LM_GGML_CPU_FP16_TO_FP32(ya->d));
        acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(dot_b), LM_GGML_CPU_FP16_TO_FP32(xb->d)*LM_GGML_CPU_FP16_TO_FP32(yb->d));
    }

    sumf = vaddvq_f32(acc0) + vaddvq_f32(acc1) + offset_sum;

#endif
    // scalar path: leftover block (NEON) or every block (no NEON)
    while (ib < nb) {
        const block_q4_1 * LM_GGML_RESTRICT xblk = &x[ib];
        const block_q8_1 * LM_GGML_RESTRICT yblk = &y[ib];

        int isum = 0;
        for (int j = 0; j < qk/2; ++j) {
            const int lo = (xblk->qs[j] & 0x0F);
            const int hi = (xblk->qs[j] >> 4);

            isum += lo * yblk->qs[j] + hi * yblk->qs[j + qk/2];
        }

        sumf += (LM_GGML_CPU_FP16_TO_FP32(xblk->d)*LM_GGML_CPU_FP16_TO_FP32(yblk->d))*isum
              + LM_GGML_CPU_FP16_TO_FP32(xblk->m)*LM_GGML_CPU_FP16_TO_FP32(yblk->s);
        ++ib;
    }

    *s = sumf;
}
|
591
|
+
|
592
|
+
/*
 * Dot product of a row of q5_0 blocks (vx) with a row of q8_0 blocks (vy).
 *
 * A q5_0 element is a 4-bit nibble from qs plus a 5th bit from the 32-bit qh
 * word, re-centred by subtracting 16:
 *   element j          = ((qs[j] & 0x0F) | bit j      of qh << 4) - 16
 *   element j + qk/2   = ((qs[j] >> 4)   | bit j + 16 of qh << 4) - 16
 * Each block pair contributes (x.d * y.d) * sum(q_x * q_y).
 *
 * Only nrc == 1 is supported; bs, bx and by are unused. The result is
 * written to *s. n must be a multiple of QK8_0 (== QK5_0).
 */
void lm_ggml_vec_dot_q5_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(qk == QK5_0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_0 * LM_GGML_RESTRICT x = vx;
    const block_q8_0 * LM_GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined(__ARM_NEON)
    const uint8x16_t nibble_mask = vdupq_n_u8(0x0F);

    float32x4_t acc0 = vdupq_n_f32(0.0f);
    float32x4_t acc1 = vdupq_n_f32(0.0f);

    for (; ib + 1 < nb; ib += 2) {
        const block_q5_0 * LM_GGML_RESTRICT xa = &x[ib];
        const block_q5_0 * LM_GGML_RESTRICT xb = &x[ib + 1];
        const block_q8_0 * LM_GGML_RESTRICT ya = &y[ib];
        const block_q8_0 * LM_GGML_RESTRICT yb = &y[ib + 1];

        // expand each qh bit b into the byte ((!b) << 4) via table_b2b_1, 8 bits per lookup
        uint32_t hbits_a;
        uint32_t hbits_b;
        memcpy(&hbits_a, xa->qh, sizeof(hbits_a));
        memcpy(&hbits_b, xb->qh, sizeof(hbits_b));

        uint64_t hmask_a[4];
        uint64_t hmask_b[4];
        for (int k = 0; k < 4; ++k) {
            hmask_a[k] = table_b2b_1[(hbits_a >> (8*k)) & 0xFF];
            hmask_b[k] = table_b2b_1[(hbits_b >> (8*k)) & 0xFF];
        }

        const int8x16_t hlo_a = vld1q_s8((const int8_t *)(hmask_a + 0));
        const int8x16_t hhi_a = vld1q_s8((const int8_t *)(hmask_a + 2));
        const int8x16_t hlo_b = vld1q_s8((const int8_t *)(hmask_b + 0));
        const int8x16_t hhi_b = vld1q_s8((const int8_t *)(hmask_b + 2));

        const uint8x16_t pa = vld1q_u8(xa->qs);
        const uint8x16_t pb = vld1q_u8(xb->qs);

        // nibble - ((!bit) << 4) == (nibble | bit << 4) - 16
        const int8x16_t qa_lo = vsubq_s8(vreinterpretq_s8_u8(vandq_u8 (pa, nibble_mask)), hlo_a);
        const int8x16_t qa_hi = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(pa, 4)),           hhi_a);
        const int8x16_t qb_lo = vsubq_s8(vreinterpretq_s8_u8(vandq_u8 (pb, nibble_mask)), hlo_b);
        const int8x16_t qb_hi = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(pb, 4)),           hhi_b);

        const int8x16_t ya_lo = vld1q_s8(ya->qs);
        const int8x16_t ya_hi = vld1q_s8(ya->qs + 16);
        const int8x16_t yb_lo = vld1q_s8(yb->qs);
        const int8x16_t yb_hi = vld1q_s8(yb->qs + 16);

        const int32x4_t dot_a = vaddq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), qa_lo, ya_lo),
                                          lm_ggml_vdotq_s32(vdupq_n_s32(0), qa_hi, ya_hi));
        const int32x4_t dot_b = vaddq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), qb_lo, yb_lo),
                                          lm_ggml_vdotq_s32(vdupq_n_s32(0), qb_hi, yb_hi));

        acc0 = vmlaq_n_f32(acc0, vcvtq_f32_s32(dot_a), LM_GGML_CPU_FP16_TO_FP32(xa->d)*LM_GGML_CPU_FP16_TO_FP32(ya->d));
        acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(dot_b), LM_GGML_CPU_FP16_TO_FP32(xb->d)*LM_GGML_CPU_FP16_TO_FP32(yb->d));
    }

    sumf = vaddvq_f32(acc0) + vaddvq_f32(acc1);

#endif
    // scalar path: leftover block (NEON) or every block (no NEON)
    while (ib < nb) {
        const block_q5_0 * LM_GGML_RESTRICT xblk = &x[ib];
        const block_q8_0 * LM_GGML_RESTRICT yblk = &y[ib];

        uint32_t hbits;
        memcpy(&hbits, xblk->qh, sizeof(hbits));

        int isum = 0;
        for (int j = 0; j < qk/2; ++j) {
            const uint8_t hb_lo = ((hbits >> (j +  0)) & 1u) << 4;
            const uint8_t hb_hi = ((hbits >> (j + 16)) & 1u) << 4;

            const int32_t v_lo = (int8_t)(((xblk->qs[j] & 0x0F) | hb_lo) - 16);
            const int32_t v_hi = (int8_t)(((xblk->qs[j] >>   4) | hb_hi) - 16);

            isum += v_lo * yblk->qs[j] + v_hi * yblk->qs[j + qk/2];
        }

        sumf += (LM_GGML_CPU_FP16_TO_FP32(xblk->d)*LM_GGML_CPU_FP16_TO_FP32(yblk->d)) * isum;
        ++ib;
    }

    *s = sumf;
}
|
703
|
+
|
704
|
+
/*
 * Dot product of a row of q5_1 blocks (vx) with a row of q8_1 blocks (vy).
 *
 * A q5_1 element is an unsigned 5-bit value: a 4-bit nibble from qs plus a
 * 5th bit from the 32-bit qh word:
 *   element j          = (qs[j] & 0x0F) | bit j      of qh << 4
 *   element j + qk/2   = (qs[j] >> 4)   | bit j + 16 of qh << 4
 * Each block pair contributes
 *     (x.d * y.d) * sum(q_x * q_y) + x.m * y.s
 * (y.s is a per-block field of the q8_1 block; presumably y.d * sum(y.qs) --
 * not visible here, confirm).
 *
 * Only nrc == 1 is supported; bs, bx and by are unused. The result is
 * written to *s. n must be a multiple of QK8_1 (== QK5_1).
 */
void lm_ggml_vec_dot_q5_1_q8_1(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_1;
    const int nb = n / qk;

    assert(n % qk == 0);
    assert(qk == QK5_1);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_q5_1 * LM_GGML_RESTRICT x = vx;
    const block_q8_1 * LM_GGML_RESTRICT y = vy;

    int ib = 0;
    float sumf = 0;

#if defined(__ARM_NEON)
    const uint8x16_t nibble_mask = vdupq_n_u8(0x0F);

    float32x4_t acc0 = vdupq_n_f32(0.0f);
    float32x4_t acc1 = vdupq_n_f32(0.0f);

    // offset terms kept per stream so the final summation order is fixed
    float off0 = 0.0f;
    float off1 = 0.0f;

    for (; ib + 1 < nb; ib += 2) {
        const block_q5_1 * LM_GGML_RESTRICT xa = &x[ib];
        const block_q5_1 * LM_GGML_RESTRICT xb = &x[ib + 1];
        const block_q8_1 * LM_GGML_RESTRICT ya = &y[ib];
        const block_q8_1 * LM_GGML_RESTRICT yb = &y[ib + 1];

        off0 += LM_GGML_CPU_FP16_TO_FP32(xa->m) * LM_GGML_CPU_FP16_TO_FP32(ya->s);
        off1 += LM_GGML_CPU_FP16_TO_FP32(xb->m) * LM_GGML_CPU_FP16_TO_FP32(yb->s);

        // expand each qh bit b into the byte (b << 4) via table_b2b_0, 8 bits per lookup
        uint32_t hbits_a;
        uint32_t hbits_b;
        memcpy(&hbits_a, xa->qh, sizeof(hbits_a));
        memcpy(&hbits_b, xb->qh, sizeof(hbits_b));

        uint64_t hmask_a[4];
        uint64_t hmask_b[4];
        for (int k = 0; k < 4; ++k) {
            hmask_a[k] = table_b2b_0[(hbits_a >> (8*k)) & 0xFF];
            hmask_b[k] = table_b2b_0[(hbits_b >> (8*k)) & 0xFF];
        }

        const int8x16_t hlo_a = vld1q_s8((const int8_t *)(hmask_a + 0));
        const int8x16_t hhi_a = vld1q_s8((const int8_t *)(hmask_a + 2));
        const int8x16_t hlo_b = vld1q_s8((const int8_t *)(hmask_b + 0));
        const int8x16_t hhi_b = vld1q_s8((const int8_t *)(hmask_b + 2));

        const uint8x16_t pa = vld1q_u8(xa->qs);
        const uint8x16_t pb = vld1q_u8(xb->qs);

        // OR the 5th bit onto each nibble
        const int8x16_t qa_lo = vorrq_s8(vreinterpretq_s8_u8(vandq_u8 (pa, nibble_mask)), hlo_a);
        const int8x16_t qa_hi = vorrq_s8(vreinterpretq_s8_u8(vshrq_n_u8(pa, 4)),           hhi_a);
        const int8x16_t qb_lo = vorrq_s8(vreinterpretq_s8_u8(vandq_u8 (pb, nibble_mask)), hlo_b);
        const int8x16_t qb_hi = vorrq_s8(vreinterpretq_s8_u8(vshrq_n_u8(pb, 4)),           hhi_b);

        const int8x16_t ya_lo = vld1q_s8(ya->qs);
        const int8x16_t ya_hi = vld1q_s8(ya->qs + 16);
        const int8x16_t yb_lo = vld1q_s8(yb->qs);
        const int8x16_t yb_hi = vld1q_s8(yb->qs + 16);

        const int32x4_t dot_a = vaddq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), qa_lo, ya_lo),
                                          lm_ggml_vdotq_s32(vdupq_n_s32(0), qa_hi, ya_hi));
        const int32x4_t dot_b = vaddq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), qb_lo, yb_lo),
                                          lm_ggml_vdotq_s32(vdupq_n_s32(0), qb_hi, yb_hi));

        acc0 = vmlaq_n_f32(acc0, vcvtq_f32_s32(dot_a), LM_GGML_CPU_FP16_TO_FP32(xa->d)*LM_GGML_CPU_FP16_TO_FP32(ya->d));
        acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(dot_b), LM_GGML_CPU_FP16_TO_FP32(xb->d)*LM_GGML_CPU_FP16_TO_FP32(yb->d));
    }

    sumf = vaddvq_f32(acc0) + vaddvq_f32(acc1) + off0 + off1;

#endif
    // scalar path: leftover block (NEON) or every block (no NEON)
    while (ib < nb) {
        const block_q5_1 * LM_GGML_RESTRICT xblk = &x[ib];
        const block_q8_1 * LM_GGML_RESTRICT yblk = &y[ib];

        uint32_t hbits;
        memcpy(&hbits, xblk->qh, sizeof(hbits));

        int isum = 0;
        for (int j = 0; j < qk/2; ++j) {
            const uint8_t hb_lo = ((hbits >> (j +  0)) & 1u) << 4;
            const uint8_t hb_hi = ((hbits >> (j + 16)) & 1u) << 4;

            const int32_t v_lo = (xblk->qs[j] & 0xF) | hb_lo;
            const int32_t v_hi = (xblk->qs[j] >>  4) | hb_hi;

            isum += v_lo * yblk->qs[j] + v_hi * yblk->qs[j + qk/2];
        }

        sumf += (LM_GGML_CPU_FP16_TO_FP32(xblk->d)*LM_GGML_CPU_FP16_TO_FP32(yblk->d))*isum
              + LM_GGML_CPU_FP16_TO_FP32(xblk->m)*LM_GGML_CPU_FP16_TO_FP32(yblk->s);
        ++ib;
    }

    *s = sumf;
}
|
821
|
+
|
822
|
+
void lm_ggml_vec_dot_q8_0_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
|
823
|
+
const int qk = QK8_0;
|
824
|
+
const int nb = n / qk;
|
825
|
+
|
826
|
+
assert(n % qk == 0);
|
827
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
828
|
+
assert((nrc == 2) || (nrc == 1));
|
829
|
+
#else
|
830
|
+
assert(nrc == 1);
|
831
|
+
#endif
|
832
|
+
UNUSED(nrc);
|
833
|
+
UNUSED(bx);
|
834
|
+
UNUSED(by);
|
835
|
+
UNUSED(bs);
|
836
|
+
|
837
|
+
const block_q8_0 * LM_GGML_RESTRICT x = vx;
|
838
|
+
const block_q8_0 * LM_GGML_RESTRICT y = vy;
|
839
|
+
|
840
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
841
|
+
if (nrc == 2) {
|
842
|
+
const block_q8_0 * LM_GGML_RESTRICT vx0 = vx;
|
843
|
+
const block_q8_0 * LM_GGML_RESTRICT vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
|
844
|
+
const block_q8_0 * LM_GGML_RESTRICT vy0 = vy;
|
845
|
+
const block_q8_0 * LM_GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
|
846
|
+
|
847
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
848
|
+
|
849
|
+
for (int i = 0; i < nb; i++) {
|
850
|
+
const block_q8_0 * LM_GGML_RESTRICT b_x0 = &vx0[i];
|
851
|
+
const block_q8_0 * LM_GGML_RESTRICT b_y0 = &vy0[i];
|
852
|
+
|
853
|
+
const block_q8_0 * LM_GGML_RESTRICT b_x1 = &vx1[i];
|
854
|
+
const block_q8_0 * LM_GGML_RESTRICT b_y1 = &vy1[i];
|
855
|
+
|
856
|
+
const int8x16_t x0_l = vld1q_s8(b_x0->qs);
|
857
|
+
const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
|
858
|
+
const int8x16_t x1_l = vld1q_s8(b_x1->qs);
|
859
|
+
const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
|
860
|
+
|
861
|
+
// load y
|
862
|
+
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
|
863
|
+
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
|
864
|
+
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
|
865
|
+
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
|
866
|
+
|
867
|
+
float32_t _scale[4] = {
|
868
|
+
LM_GGML_CPU_FP16_TO_FP32(b_x0->d)*LM_GGML_CPU_FP16_TO_FP32(b_y0->d),
|
869
|
+
LM_GGML_CPU_FP16_TO_FP32(b_x0->d)*LM_GGML_CPU_FP16_TO_FP32(b_y1->d),
|
870
|
+
LM_GGML_CPU_FP16_TO_FP32(b_x1->d)*LM_GGML_CPU_FP16_TO_FP32(b_y0->d),
|
871
|
+
LM_GGML_CPU_FP16_TO_FP32(b_x1->d)*LM_GGML_CPU_FP16_TO_FP32(b_y1->d)
|
872
|
+
};
|
873
|
+
float32x4_t scale = vld1q_f32(_scale);
|
874
|
+
|
875
|
+
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
876
|
+
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
|
877
|
+
|
878
|
+
int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
|
879
|
+
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
|
880
|
+
|
881
|
+
int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
|
882
|
+
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
|
883
|
+
|
884
|
+
int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
885
|
+
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
|
886
|
+
|
887
|
+
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
888
|
+
l1, r1)), l2, r2)), l3, r3))), scale);
|
889
|
+
}
|
890
|
+
|
891
|
+
float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
|
892
|
+
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
893
|
+
|
894
|
+
vst1_f32(s, vget_low_f32 (sumv2));
|
895
|
+
vst1_f32(s + bs, vget_high_f32(sumv2));
|
896
|
+
|
897
|
+
return;
|
898
|
+
}
|
899
|
+
#endif
|
900
|
+
|
901
|
+
int ib = 0;
|
902
|
+
float sumf = 0;
|
903
|
+
|
904
|
+
#if defined(__ARM_FEATURE_SVE)
|
905
|
+
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
906
|
+
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
907
|
+
|
908
|
+
const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
|
909
|
+
|
910
|
+
//VLA Implemenation for SVE
|
911
|
+
switch (vector_length) {
|
912
|
+
case 128:
|
913
|
+
{
|
914
|
+
// predicate for activating lanes for 16 Int8 elements
|
915
|
+
const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
|
916
|
+
const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
|
917
|
+
|
918
|
+
for (; ib + 1 < nb; ib += 2) {
|
919
|
+
const block_q8_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
|
920
|
+
const block_q8_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
|
921
|
+
const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
|
922
|
+
const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
|
923
|
+
|
924
|
+
// load x
|
925
|
+
const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
|
926
|
+
const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
|
927
|
+
const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
|
928
|
+
const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
|
929
|
+
|
930
|
+
// load y
|
931
|
+
const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
|
932
|
+
const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
|
933
|
+
const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
|
934
|
+
const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
|
935
|
+
|
936
|
+
sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
|
937
|
+
svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
|
938
|
+
svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
|
939
|
+
sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
|
940
|
+
svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
|
941
|
+
svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
|
942
|
+
}
|
943
|
+
|
944
|
+
sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
|
945
|
+
} break;
|
946
|
+
case 256:
|
947
|
+
{
|
948
|
+
//printf("sve256");
|
949
|
+
for (; ib + 1 < nb; ib += 2) {
|
950
|
+
const block_q8_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
|
951
|
+
const block_q8_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
|
952
|
+
const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
|
953
|
+
const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
|
954
|
+
|
955
|
+
// load x
|
956
|
+
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
|
957
|
+
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
|
958
|
+
|
959
|
+
// load y
|
960
|
+
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
961
|
+
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
962
|
+
|
963
|
+
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
|
964
|
+
svdot_s32(svdup_n_s32(0), qx0, qy0)), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
|
965
|
+
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
|
966
|
+
svdot_s32(svdup_n_s32(0), qx1, qy1)), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
|
967
|
+
}
|
968
|
+
|
969
|
+
sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
970
|
+
} break;
|
971
|
+
case 512:
|
972
|
+
{
|
973
|
+
// predicate for activating high 256 bit
|
974
|
+
const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
|
975
|
+
// predicate for activating low 256 bit
|
976
|
+
const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
|
977
|
+
|
978
|
+
// predicate for activating high lanes for 8 float32 elements
|
979
|
+
const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
|
980
|
+
// predicate for activating low lanes for 8 float32 elements
|
981
|
+
const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
|
982
|
+
|
983
|
+
svfloat32_t sumv00 = svdup_n_f32(0.0f);
|
984
|
+
|
985
|
+
for (; ib + 1 < nb; ib += 2) {
|
986
|
+
const block_q8_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
|
987
|
+
const block_q8_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
|
988
|
+
const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
|
989
|
+
const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
|
990
|
+
|
991
|
+
//load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits
|
992
|
+
// and add them to make one 64 element vector
|
993
|
+
// load x
|
994
|
+
const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
|
995
|
+
svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
|
996
|
+
|
997
|
+
qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);
|
998
|
+
|
999
|
+
// load y
|
1000
|
+
const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
|
1001
|
+
svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
|
1002
|
+
|
1003
|
+
qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
|
1004
|
+
|
1005
|
+
// scale creation
|
1006
|
+
const float32_t deq1 = LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d);
|
1007
|
+
const float32_t deq2 = LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d);
|
1008
|
+
|
1009
|
+
// duplicate deq1 in first half of vector and deq2 in second half of vector
|
1010
|
+
const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
|
1011
|
+
|
1012
|
+
const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
|
1013
|
+
|
1014
|
+
sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
|
1015
|
+
}
|
1016
|
+
|
1017
|
+
sumf = svaddv_f32(svptrue_b32(), sumv00);
|
1018
|
+
break;
|
1019
|
+
}
|
1020
|
+
default:
|
1021
|
+
assert(false && "Unsupported vector length");
|
1022
|
+
break;
|
1023
|
+
}
|
1024
|
+
#elif defined(__ARM_NEON)
|
1025
|
+
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
1026
|
+
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
1027
|
+
|
1028
|
+
for (; ib + 1 < nb; ib += 2) {
|
1029
|
+
const block_q8_0 * LM_GGML_RESTRICT x0 = &x[ib + 0];
|
1030
|
+
const block_q8_0 * LM_GGML_RESTRICT x1 = &x[ib + 1];
|
1031
|
+
const block_q8_0 * LM_GGML_RESTRICT y0 = &y[ib + 0];
|
1032
|
+
const block_q8_0 * LM_GGML_RESTRICT y1 = &y[ib + 1];
|
1033
|
+
|
1034
|
+
const int8x16_t x0_0 = vld1q_s8(x0->qs);
|
1035
|
+
const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
|
1036
|
+
const int8x16_t x1_0 = vld1q_s8(x1->qs);
|
1037
|
+
const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
|
1038
|
+
|
1039
|
+
// load y
|
1040
|
+
const int8x16_t y0_0 = vld1q_s8(y0->qs);
|
1041
|
+
const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
|
1042
|
+
const int8x16_t y1_0 = vld1q_s8(y1->qs);
|
1043
|
+
const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
|
1044
|
+
|
1045
|
+
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
|
1046
|
+
lm_ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
|
1047
|
+
lm_ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), LM_GGML_CPU_FP16_TO_FP32(x0->d)*LM_GGML_CPU_FP16_TO_FP32(y0->d));
|
1048
|
+
|
1049
|
+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
|
1050
|
+
lm_ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
|
1051
|
+
lm_ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), LM_GGML_CPU_FP16_TO_FP32(x1->d)*LM_GGML_CPU_FP16_TO_FP32(y1->d));
|
1052
|
+
}
|
1053
|
+
|
1054
|
+
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
1055
|
+
#endif
|
1056
|
+
for (; ib < nb; ++ib) {
|
1057
|
+
int sumi = 0;
|
1058
|
+
|
1059
|
+
for (int j = 0; j < qk; j++) {
|
1060
|
+
sumi += x[ib].qs[j]*y[ib].qs[j];
|
1061
|
+
}
|
1062
|
+
|
1063
|
+
sumf += sumi*(LM_GGML_CPU_FP16_TO_FP32(x[ib].d)*LM_GGML_CPU_FP16_TO_FP32(y[ib].d));
|
1064
|
+
}
|
1065
|
+
|
1066
|
+
*s = sumf;
|
1067
|
+
}
|
1068
|
+
|
1069
|
+
#if defined(__ARM_NEON)
// Decodes the leading base-3 digit of every byte lane as an int8 in 0..2.
// This is "multiply by 3 and keep the 2 bits above 8 bits" done without widening:
// vhaddq_u8(q, q >> 1) equals floor(3q / 4), and shifting that right by 6 equals
// floor(3q / 256).
static inline int8x16_t lm_ggml_tq1_0_trit_neon(uint8x16_t q) {
    return vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(q, vshrq_n_u8(q, 1)), 6));
}

#if defined(__ARM_FEATURE_DOTPROD)
// Adds dot(a, b) into acc[slot & 1]; alternating slots spreads work over two accumulators.
static inline void lm_ggml_tq1_0_mac_neon(int32x4_t acc[2], int slot, int8x16_t a, int8x16_t b) {
    acc[slot & 1] = vdotq_s32(acc[slot & 1], a, b);
}
#else
// Without dot-product instructions: widening multiply-accumulate of the low halves
// into acc[0] and of the high halves into acc[1]. slot is ignored.
static inline void lm_ggml_tq1_0_mac_neon(int16x8_t acc[2], int slot, int8x16_t a, int8x16_t b) {
    (void) slot;
    acc[0] = vmlal_s8(acc[0], vget_low_s8(a),  vget_low_s8(b));
    acc[1] = vmlal_s8(acc[1], vget_high_s8(a), vget_high_s8(b));
}
#endif
#endif // __ARM_NEON

// Dot product of n TQ1_0 values (vx) with n Q8_K values (vy); the result is stored in *s.
// Only n / QK_K whole blocks are processed. nrc must be 1 (asserted), and the stride
// arguments bs, bx and by are ignored.
//
// Each qs byte packs five base-3 digits and each qh byte packs four. Digit l of a byte b
// is recovered as (((uint8_t)(b * 3^l)) * 3) >> 8, which yields 0..2. That value is
// shifted to -1..1 before it is multiplied with the matching q8 value. Each block's
// integer sum is scaled by x's fp16 scale times y's scale.
void lm_ggml_vec_dot_tq1_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_tq1_0 * LM_GGML_RESTRICT xt = vx;
    const block_q8_K  * LM_GGML_RESTRICT yk = vy;

    const int nb = n / QK_K;

    float total = 0.0f;

#if defined(__ARM_NEON)
    // The 4 qh bytes are broadcast to all four 32-bit lanes. Each group of four byte
    // lanes is then scaled by a different power of three, exposing a different digit.
    static const uint8_t qh_pow3[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
    const uint8x16_t qh_mul = vld1q_u8(qh_pow3);

    for (int ib = 0; ib < nb; ++ib) {
        const uint8_t * qs = xt[ib].qs;
        const int8_t  * q8 = yk[ib].qs;

#if defined(__ARM_FEATURE_DOTPROD)
        int32x4_t part[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
#else
        int16x8_t part[2] = { vdupq_n_s16(0), vdupq_n_s16(0) };
#endif

        // qs[0..31]: digit l of byte m pairs with q8[32*l + m] (l = 0..4, m = 0..31).
        {
            const uint8x16_t a = vld1q_u8(qs + 0);
            const uint8x16_t b = vld1q_u8(qs + 16);

            lm_ggml_tq1_0_mac_neon(part, 0, lm_ggml_tq1_0_trit_neon(a),                         vld1q_s8(q8 +   0));
            lm_ggml_tq1_0_mac_neon(part, 1, lm_ggml_tq1_0_trit_neon(b),                         vld1q_s8(q8 +  16));
            lm_ggml_tq1_0_mac_neon(part, 0, lm_ggml_tq1_0_trit_neon(vmulq_u8(a, vdupq_n_u8(3))),  vld1q_s8(q8 +  32));
            lm_ggml_tq1_0_mac_neon(part, 1, lm_ggml_tq1_0_trit_neon(vmulq_u8(b, vdupq_n_u8(3))),  vld1q_s8(q8 +  48));
            lm_ggml_tq1_0_mac_neon(part, 0, lm_ggml_tq1_0_trit_neon(vmulq_u8(a, vdupq_n_u8(9))),  vld1q_s8(q8 +  64));
            lm_ggml_tq1_0_mac_neon(part, 1, lm_ggml_tq1_0_trit_neon(vmulq_u8(b, vdupq_n_u8(9))),  vld1q_s8(q8 +  80));
            lm_ggml_tq1_0_mac_neon(part, 0, lm_ggml_tq1_0_trit_neon(vmulq_u8(a, vdupq_n_u8(27))), vld1q_s8(q8 +  96));
            lm_ggml_tq1_0_mac_neon(part, 1, lm_ggml_tq1_0_trit_neon(vmulq_u8(b, vdupq_n_u8(27))), vld1q_s8(q8 + 112));
            lm_ggml_tq1_0_mac_neon(part, 0, lm_ggml_tq1_0_trit_neon(vmulq_u8(a, vdupq_n_u8(81))), vld1q_s8(q8 + 128));
            lm_ggml_tq1_0_mac_neon(part, 1, lm_ggml_tq1_0_trit_neon(vmulq_u8(b, vdupq_n_u8(81))), vld1q_s8(q8 + 144));
        }

        // qs[32..47] (five digits each) pair with q8[160..239].
        // The qh digits pair with q8[240..255].
        {
            const uint8x16_t c = vld1q_u8(qs + 32);

            uint32_t qh_bits;
            memcpy(&qh_bits, xt[ib].qh, sizeof(qh_bits)); // potentially unaligned
            const uint8x16_t hq = vmulq_u8(vreinterpretq_u8_u32(vdupq_n_u32(qh_bits)), qh_mul);

            lm_ggml_tq1_0_mac_neon(part, 0, lm_ggml_tq1_0_trit_neon(c),                         vld1q_s8(q8 + 160));
            lm_ggml_tq1_0_mac_neon(part, 1, lm_ggml_tq1_0_trit_neon(vmulq_u8(c, vdupq_n_u8(3))),  vld1q_s8(q8 + 176));
            lm_ggml_tq1_0_mac_neon(part, 0, lm_ggml_tq1_0_trit_neon(vmulq_u8(c, vdupq_n_u8(9))),  vld1q_s8(q8 + 192));
            lm_ggml_tq1_0_mac_neon(part, 1, lm_ggml_tq1_0_trit_neon(vmulq_u8(c, vdupq_n_u8(27))), vld1q_s8(q8 + 208));
            lm_ggml_tq1_0_mac_neon(part, 0, lm_ggml_tq1_0_trit_neon(vmulq_u8(c, vdupq_n_u8(81))), vld1q_s8(q8 + 224));
            lm_ggml_tq1_0_mac_neon(part, 1, lm_ggml_tq1_0_trit_neon(hq),                        vld1q_s8(q8 + 240));
        }

        // The digits above were used raw (0..2). The -1 offset is applied once per block
        // by subtracting y's bsums (presumably partial sums of the q8 values -- TODO confirm
        // against block_q8_K).
        const int16x8_t ysum_lo = vld1q_s16(yk[ib].bsums);
        const int16x8_t ysum_hi = vld1q_s16(yk[ib].bsums + 8);

        const float d = LM_GGML_CPU_FP16_TO_FP32(xt[ib].d) * yk[ib].d;

#if defined(__ARM_FEATURE_DOTPROD)
        const int32x4_t acc = vsubq_s32(vaddq_s32(part[0], part[1]), vpaddlq_s16(vaddq_s16(ysum_lo, ysum_hi)));
        total += d * (float) vaddvq_s32(acc);
#else
        const int16x8_t acc = vsubq_s16(vaddq_s16(part[0], part[1]), vaddq_s16(ysum_lo, ysum_hi));
        total += d * (float) vaddlvq_s16(acc);
#endif
    }
#else
    // Multiplying a packed byte by 3^l (mod 256) rotates digit l into the leading position.
    static const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};

    const size_t qs_bytes = sizeof(xt->qs);
    const size_t qh_bytes = sizeof(xt->qh);
    const size_t qs_wide  = qs_bytes - qs_bytes % 32; // bytes handled in 32-byte groups

    for (int ib = 0; ib < nb; ++ib) {
        const uint8_t * qs = xt[ib].qs;
        const uint8_t * qh = xt[ib].qh;
        const int8_t  * q8 = yk[ib].qs;
        int isum = 0;

        // 32-byte groups: digit l of byte j + m pairs with q8[5*j + 32*l + m].
        for (size_t j = 0; j < qs_wide; j += 32) {
            for (size_t l = 0; l < 5; ++l) {
                for (size_t m = 0; m < 32; ++m) {
                    const uint8_t rot = (uint8_t) (qs[j + m] * pow3[l]);
                    isum += (((rot * 3) >> 8) - 1) * q8[j*5 + l*32 + m];
                }
            }
        }

        // Remaining 16-byte groups: digit l of byte j + m pairs with q8[5*j + 16*l + m].
        for (size_t j = qs_wide; j < qs_bytes; j += 16) {
            for (size_t l = 0; l < 5; ++l) {
                for (size_t m = 0; m < 16; ++m) {
                    const uint8_t rot = (uint8_t) (qs[j + m] * pow3[l]);
                    isum += (((rot * 3) >> 8) - 1) * q8[j*5 + l*16 + m];
                }
            }
        }

        // qh bytes carry four digits each; their q8 partners follow all qs-derived values.
        for (size_t l = 0; l < 4; ++l) {
            for (size_t j = 0; j < qh_bytes; ++j) {
                const uint8_t rot = (uint8_t) (qh[j] * pow3[l]);
                isum += (((rot * 3) >> 8) - 1) * q8[qs_bytes*5 + l*qh_bytes + j];
            }
        }

        total += (float) isum * (LM_GGML_CPU_FP16_TO_FP32(xt[ib].d) * yk[ib].d);
    }
#endif

    *s = total;
}
|
1279
|
+
|
1280
|
+
#if defined(__ARM_NEON)
// Keeps the low 2 bits of every byte lane and reinterprets the result as int8 (0..3).
static inline int8x16_t lm_ggml_tq2_0_low2_neon(uint8x16_t v) {
    return vreinterpretq_s8_u8(vandq_u8(v, vdupq_n_u8(3)));
}

#if defined(__ARM_FEATURE_DOTPROD)
// Adds dot(a, b) into acc[slot & 1]; alternating slots spreads work over two accumulators.
static inline void lm_ggml_tq2_0_mac_neon(int32x4_t acc[2], int slot, int8x16_t a, int8x16_t b) {
    acc[slot & 1] = vdotq_s32(acc[slot & 1], a, b);
}
#else
// Without dot-product instructions: widening multiply-accumulate of the low halves
// into acc[0] and of the high halves into acc[1]. slot is ignored.
static inline void lm_ggml_tq2_0_mac_neon(int16x8_t acc[2], int slot, int8x16_t a, int8x16_t b) {
    (void) slot;
    acc[0] = vmlal_s8(acc[0], vget_low_s8(a),  vget_low_s8(b));
    acc[1] = vmlal_s8(acc[1], vget_high_s8(a), vget_high_s8(b));
}
#endif
#endif // __ARM_NEON

// Dot product of n TQ2_0 values (vx) with n Q8_K values (vy); the result is stored in *s.
// Only n / QK_K whole blocks are processed. nrc must be 1 (asserted), and the stride
// arguments bs, bx and by are ignored.
//
// Each qs byte packs four 2-bit codes. A code c is used as the weight (c - 1).
// Each block's integer sum is scaled by x's fp16 scale times y's scale.
void lm_ggml_vec_dot_tq2_0_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_tq2_0 * LM_GGML_RESTRICT xt = vx;
    const block_q8_K  * LM_GGML_RESTRICT yk = vy;

    const int nb = n / QK_K;

    float total = 0.0f;

#if defined(__ARM_NEON)
    for (int ib = 0; ib < nb; ++ib) {
#if defined(__ARM_FEATURE_DOTPROD)
        int32x4_t part[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
#else
        int16x8_t part[2] = { vdupq_n_s16(0), vdupq_n_s16(0) };
#endif

        // Each 32-byte group of qs covers 128 q8 values. The field at bits 2l..2l+1
        // of byte j + m pairs with q8[4*j + 32*l + m].
        for (size_t j = 0; j < sizeof(xt->qs); j += 32) {
            const uint8x16_t a = vld1q_u8(xt[ib].qs + j);
            const uint8x16_t b = vld1q_u8(xt[ib].qs + j + 16);
            const int8_t * q8 = yk[ib].qs + 4*j;

            lm_ggml_tq2_0_mac_neon(part, 0, lm_ggml_tq2_0_low2_neon(a),                vld1q_s8(q8 +   0));
            lm_ggml_tq2_0_mac_neon(part, 1, lm_ggml_tq2_0_low2_neon(b),                vld1q_s8(q8 +  16));
            lm_ggml_tq2_0_mac_neon(part, 0, lm_ggml_tq2_0_low2_neon(vshrq_n_u8(a, 2)), vld1q_s8(q8 +  32));
            lm_ggml_tq2_0_mac_neon(part, 1, lm_ggml_tq2_0_low2_neon(vshrq_n_u8(b, 2)), vld1q_s8(q8 +  48));
            lm_ggml_tq2_0_mac_neon(part, 0, lm_ggml_tq2_0_low2_neon(vshrq_n_u8(a, 4)), vld1q_s8(q8 +  64));
            lm_ggml_tq2_0_mac_neon(part, 1, lm_ggml_tq2_0_low2_neon(vshrq_n_u8(b, 4)), vld1q_s8(q8 +  80));
            lm_ggml_tq2_0_mac_neon(part, 0, lm_ggml_tq2_0_low2_neon(vshrq_n_u8(a, 6)), vld1q_s8(q8 +  96));
            lm_ggml_tq2_0_mac_neon(part, 1, lm_ggml_tq2_0_low2_neon(vshrq_n_u8(b, 6)), vld1q_s8(q8 + 112));
        }

        // The codes above were used raw. The -1 offset is applied once per block by
        // subtracting y's bsums (presumably partial sums of the q8 values -- TODO confirm
        // against block_q8_K).
        const int16x8_t ysum_lo = vld1q_s16(yk[ib].bsums);
        const int16x8_t ysum_hi = vld1q_s16(yk[ib].bsums + 8);

        const float d = LM_GGML_CPU_FP16_TO_FP32(xt[ib].d) * yk[ib].d;

#if defined(__ARM_FEATURE_DOTPROD)
        const int32x4_t acc = vsubq_s32(vaddq_s32(part[0], part[1]), vpaddlq_s16(vaddq_s16(ysum_lo, ysum_hi)));
        total += d * (float) vaddvq_s32(acc);
#else
        const int16x8_t acc = vsubq_s16(vaddq_s16(part[0], part[1]), vaddq_s16(ysum_lo, ysum_hi));
        total += d * (float) vaddlvq_s16(acc);
#endif
    }
#else
    for (int ib = 0; ib < nb; ++ib) {
        const uint8_t * qs = xt[ib].qs;
        const int8_t  * q8 = yk[ib].qs;
        int32_t isum = 0;

        // Each 32-byte group of qs covers 128 q8 values. The 2-bit field starting at
        // bit `shift` of byte base + k pairs with q8[4*base + 16*shift + k].
        for (size_t base = 0; base < sizeof(xt->qs); base += 32) {
            for (size_t shift = 0; shift < 8; shift += 2) {
                for (size_t k = 0; k < 32; ++k) {
                    const int w = ((qs[base + k] >> shift) & 3) - 1;
                    isum += q8[4*base + 16*shift + k] * w;
                }
            }
        }

        const float d = yk[ib].d * LM_GGML_CPU_FP16_TO_FP32(xt[ib].d);
        total += (float) isum * d;
    }
#endif

    *s = total;
}
|
1405
|
+
|
1406
|
+
void lm_ggml_vec_dot_q2_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
|
1407
|
+
assert(nrc == 1);
|
1408
|
+
UNUSED(nrc);
|
1409
|
+
UNUSED(bx);
|
1410
|
+
UNUSED(by);
|
1411
|
+
UNUSED(bs);
|
1412
|
+
|
1413
|
+
const block_q2_K * LM_GGML_RESTRICT x = vx;
|
1414
|
+
const block_q8_K * LM_GGML_RESTRICT y = vy;
|
1415
|
+
|
1416
|
+
const int nb = n / QK_K;
|
1417
|
+
|
1418
|
+
#ifdef __ARM_FEATURE_SVE
|
1419
|
+
const int vector_length = svcntb()*8;
|
1420
|
+
const svuint8_t m3s = svdup_n_u8(0x3);
|
1421
|
+
const svuint32_t m4s = svdup_n_u32(0xF);
|
1422
|
+
const svint32_t vzero_sv = svdup_n_s32(0);
|
1423
|
+
svfloat32_t acc_sum = svdup_n_f32(0);
|
1424
|
+
svbool_t pred_s32 = svptrue_pat_b32(SV_VL4);
|
1425
|
+
|
1426
|
+
switch (vector_length) {
|
1427
|
+
case 128:
|
1428
|
+
for (int i = 0; i < nb; ++i) {
|
1429
|
+
const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
|
1430
|
+
svfloat32_t d_broad = svdup_n_f32((float32_t)d);
|
1431
|
+
const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
|
1432
|
+
svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
|
1433
|
+
|
1434
|
+
const uint8_t * LM_GGML_RESTRICT q2 = x[i].qs;
|
1435
|
+
const int8_t * LM_GGML_RESTRICT q8_sv = y[i].qs;
|
1436
|
+
const uint8_t * LM_GGML_RESTRICT sc = x[i].scales;
|
1437
|
+
|
1438
|
+
svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc);
|
1439
|
+
const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
|
1440
|
+
|
1441
|
+
mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4);
|
1442
|
+
const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
|
1443
|
+
|
1444
|
+
svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums);
|
1445
|
+
svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4);
|
1446
|
+
|
1447
|
+
const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2));
|
1448
|
+
|
1449
|
+
mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8);
|
1450
|
+
const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
|
1451
|
+
|
1452
|
+
mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12);
|
1453
|
+
const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4));
|
1454
|
+
|
1455
|
+
q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8);
|
1456
|
+
q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12);
|
1457
|
+
|
1458
|
+
svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2));
|
1459
|
+
|
1460
|
+
svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1));
|
1461
|
+
|
1462
|
+
acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad);
|
1463
|
+
|
1464
|
+
svint32_t sumi1 = svdup_n_s32(0);
|
1465
|
+
|
1466
|
+
{
|
1467
|
+
const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2);
|
1468
|
+
svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s));
|
1469
|
+
svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1470
|
+
const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s));
|
1471
|
+
|
1472
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0));
|
1473
|
+
|
1474
|
+
const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16);
|
1475
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s));
|
1476
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1477
|
+
|
1478
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1));
|
1479
|
+
|
1480
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s));
|
1481
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1482
|
+
|
1483
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2));
|
1484
|
+
|
1485
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s));
|
1486
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1487
|
+
|
1488
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3));
|
1489
|
+
|
1490
|
+
|
1491
|
+
const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s));
|
1492
|
+
|
1493
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s));
|
1494
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1495
|
+
|
1496
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0));
|
1497
|
+
|
1498
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s));
|
1499
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1500
|
+
|
1501
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1));
|
1502
|
+
|
1503
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s));
|
1504
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1505
|
+
|
1506
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2));
|
1507
|
+
|
1508
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s));
|
1509
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1510
|
+
|
1511
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3));
|
1512
|
+
|
1513
|
+
//-------------------------------
|
1514
|
+
|
1515
|
+
q2 += 32;
|
1516
|
+
const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s));
|
1517
|
+
const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2);
|
1518
|
+
|
1519
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s));
|
1520
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1521
|
+
|
1522
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0));
|
1523
|
+
|
1524
|
+
const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16);
|
1525
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s));
|
1526
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1527
|
+
|
1528
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1));
|
1529
|
+
|
1530
|
+
|
1531
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s));
|
1532
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1533
|
+
|
1534
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2));
|
1535
|
+
|
1536
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s));
|
1537
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1538
|
+
|
1539
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3));
|
1540
|
+
|
1541
|
+
|
1542
|
+
const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s));
|
1543
|
+
|
1544
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s));
|
1545
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1546
|
+
|
1547
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0));
|
1548
|
+
|
1549
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s));
|
1550
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1551
|
+
|
1552
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1));
|
1553
|
+
|
1554
|
+
|
1555
|
+
|
1556
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s));
|
1557
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1558
|
+
|
1559
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2));
|
1560
|
+
|
1561
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s));
|
1562
|
+
q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
|
1563
|
+
|
1564
|
+
sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3));
|
1565
|
+
}
|
1566
|
+
acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad);
|
1567
|
+
}
|
1568
|
+
*s = svaddv_f32(svptrue_b32(), acc_sum);
|
1569
|
+
break;
|
1570
|
+
|
1571
|
+
case 256:
|
1572
|
+
case 512:
|
1573
|
+
for (int i = 0; i < nb; ++i) {
|
1574
|
+
const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
|
1575
|
+
svfloat32_t d_broad = svdup_n_f32((float32_t)d);
|
1576
|
+
const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
|
1577
|
+
svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin);
|
1578
|
+
|
1579
|
+
const uint8_t * LM_GGML_RESTRICT q2 = x[i].qs;
|
1580
|
+
const int8_t * LM_GGML_RESTRICT q8_sv = y[i].qs;
|
1581
|
+
const uint8_t * LM_GGML_RESTRICT sc = x[i].scales;
|
1582
|
+
|
1583
|
+
const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8;
|
1584
|
+
const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s));
|
1585
|
+
const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4));
|
1586
|
+
svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums);
|
1587
|
+
|
1588
|
+
const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc);
|
1589
|
+
const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s));
|
1590
|
+
const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4));
|
1591
|
+
|
1592
|
+
svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8);
|
1593
|
+
|
1594
|
+
svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2)));
|
1595
|
+
|
1596
|
+
acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad);
|
1597
|
+
|
1598
|
+
svint32_t sumi1 = svdup_n_s32(0);
|
1599
|
+
|
1600
|
+
{
|
1601
|
+
const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
|
1602
|
+
svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s));
|
1603
|
+
svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
1604
|
+
|
1605
|
+
svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1));
|
1606
|
+
sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
|
1607
|
+
|
1608
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s));
|
1609
|
+
q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
1610
|
+
|
1611
|
+
svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3));
|
1612
|
+
sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2);
|
1613
|
+
|
1614
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s));
|
1615
|
+
q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
1616
|
+
|
1617
|
+
scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5));
|
1618
|
+
sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
|
1619
|
+
|
1620
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s));
|
1621
|
+
q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
1622
|
+
|
1623
|
+
scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7));
|
1624
|
+
sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
|
1625
|
+
|
1626
|
+
q2 += 32;
|
1627
|
+
|
1628
|
+
const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2);
|
1629
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s));
|
1630
|
+
q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
1631
|
+
|
1632
|
+
scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1));
|
1633
|
+
sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
|
1634
|
+
|
1635
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s));
|
1636
|
+
q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
1637
|
+
|
1638
|
+
scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3));
|
1639
|
+
sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
|
1640
|
+
|
1641
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s));
|
1642
|
+
q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
1643
|
+
|
1644
|
+
scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5));
|
1645
|
+
sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1);
|
1646
|
+
|
1647
|
+
q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s));
|
1648
|
+
q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
|
1649
|
+
|
1650
|
+
scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7));
|
1651
|
+
sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2);
|
1652
|
+
}
|
1653
|
+
acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad);
|
1654
|
+
}
|
1655
|
+
*s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum);
|
1656
|
+
break;
|
1657
|
+
|
1658
|
+
default:
|
1659
|
+
assert(false && "Unsupported vector length");
|
1660
|
+
break;
|
1661
|
+
}
|
1662
|
+
|
1663
|
+
#elif __ARM_NEON
|
1664
|
+
const uint8x16_t m3 = vdupq_n_u8(0x3);
|
1665
|
+
const uint8x16_t m4 = vdupq_n_u8(0xF);
|
1666
|
+
|
1667
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
1668
|
+
|
1669
|
+
lm_ggml_int8x16x2_t q2bytes;
|
1670
|
+
uint8_t aux[16];
|
1671
|
+
|
1672
|
+
float sum = 0;
|
1673
|
+
|
1674
|
+
for (int i = 0; i < nb; ++i) {
|
1675
|
+
const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
|
1676
|
+
const float dmin = -y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
|
1677
|
+
|
1678
|
+
const uint8_t * LM_GGML_RESTRICT q2 = x[i].qs;
|
1679
|
+
const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
|
1680
|
+
const uint8_t * LM_GGML_RESTRICT sc = x[i].scales;
|
1681
|
+
|
1682
|
+
const uint8x16_t mins_and_scales = vld1q_u8(sc);
|
1683
|
+
const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
|
1684
|
+
vst1q_u8(aux, scales);
|
1685
|
+
|
1686
|
+
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
|
1687
|
+
const lm_ggml_int16x8x2_t q8sums = lm_ggml_vld1q_s16_x2(y[i].bsums);
|
1688
|
+
const lm_ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
|
1689
|
+
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
|
1690
|
+
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
|
1691
|
+
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
|
1692
|
+
vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
|
1693
|
+
sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
|
1694
|
+
|
1695
|
+
int isum = 0;
|
1696
|
+
int is = 0;
|
1697
|
+
|
1698
|
+
// We use this macro instead of a function call because for some reason
|
1699
|
+
// the code runs 2-3% slower, even if the function is declared inline
|
1700
|
+
#define MULTIPLY_ACCUM_WITH_SCALE(index)\
|
1701
|
+
isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
|
1702
|
+
isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
|
1703
|
+
|
1704
|
+
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
|
1705
|
+
q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32;\
|
1706
|
+
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
|
1707
|
+
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
|
1708
|
+
MULTIPLY_ACCUM_WITH_SCALE((index));
|
1709
|
+
|
1710
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
1711
|
+
const lm_ggml_uint8x16x2_t q2bits = lm_ggml_vld1q_u8_x2(q2); q2 += 32;
|
1712
|
+
|
1713
|
+
lm_ggml_int8x16x2_t q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32;
|
1714
|
+
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
|
1715
|
+
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
|
1716
|
+
|
1717
|
+
MULTIPLY_ACCUM_WITH_SCALE(0);
|
1718
|
+
|
1719
|
+
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
|
1720
|
+
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
|
1721
|
+
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
|
1722
|
+
|
1723
|
+
is += 8;
|
1724
|
+
}
|
1725
|
+
|
1726
|
+
sum += d * isum;
|
1727
|
+
}
|
1728
|
+
|
1729
|
+
*s = sum;
|
1730
|
+
|
1731
|
+
#else
|
1732
|
+
|
1733
|
+
float sumf = 0;
|
1734
|
+
|
1735
|
+
for (int i = 0; i < nb; ++i) {
|
1736
|
+
|
1737
|
+
const uint8_t * q2 = x[i].qs;
|
1738
|
+
const int8_t * q8 = y[i].qs;
|
1739
|
+
const uint8_t * sc = x[i].scales;
|
1740
|
+
|
1741
|
+
int summs = 0;
|
1742
|
+
for (int j = 0; j < 16; ++j) {
|
1743
|
+
summs += y[i].bsums[j] * (sc[j] >> 4);
|
1744
|
+
}
|
1745
|
+
|
1746
|
+
const float dall = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
|
1747
|
+
const float dmin = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
|
1748
|
+
|
1749
|
+
int isum = 0;
|
1750
|
+
int is = 0;
|
1751
|
+
int d;
|
1752
|
+
for (int k = 0; k < QK_K/128; ++k) {
|
1753
|
+
int shift = 0;
|
1754
|
+
for (int j = 0; j < 4; ++j) {
|
1755
|
+
d = sc[is++] & 0xF;
|
1756
|
+
int isuml = 0;
|
1757
|
+
for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
|
1758
|
+
isum += d * isuml;
|
1759
|
+
d = sc[is++] & 0xF;
|
1760
|
+
isuml = 0;
|
1761
|
+
for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
|
1762
|
+
isum += d * isuml;
|
1763
|
+
shift += 2;
|
1764
|
+
q8 += 32;
|
1765
|
+
}
|
1766
|
+
q2 += 32;
|
1767
|
+
}
|
1768
|
+
sumf += dall * isum - dmin * summs;
|
1769
|
+
}
|
1770
|
+
*s = sumf;
|
1771
|
+
#endif
|
1772
|
+
}
|
1773
|
+
|
1774
|
+
/*
 * Dot product of n values stored as q3_K blocks (vx) with n values stored as
 * q8_K blocks (vy). The single float result is written to *s.
 *
 * n must be a multiple of QK_K and exactly one row is processed (nrc == 1);
 * the stride arguments bs/bx/by are unused.
 *
 * Per q3_K block (as decoded by the scalar path below):
 *   - each weight is the 2-bit value taken from qs, minus 4 when the matching
 *     bit of hmask is clear;
 *   - each group of 16 weights has a 6-bit scale (low 4 bits and high 2 bits
 *     packed into the 12-byte `scales` field) that is stored biased by 32;
 *   - the block result is scaled by x.d (fp16) * y.d.
 */
void lm_ggml_vec_dot_q3_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    // masks used to unpack the 12 packed scale bytes into 16 6-bit scales
    const uint32_t kmask1 = 0x03030303;
    const uint32_t kmask2 = 0x0f0f0f0f;

    const block_q3_K * LM_GGML_RESTRICT x = vx;
    const block_q8_K * LM_GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_FEATURE_SVE)

    uint32_t aux[3];
    uint32_t utmp[4];

    const int8_t m32 = 32;
    const int vector_length = svcntb()*8;
    const svuint8_t m3b_sv = svdup_n_u8(0x3);
    const svint32_t vzero_sv = svdup_n_s32(0);

    // single-bit masks selecting hmask bit 0..3
    const svuint8_t m0_sv = svdup_n_u8(1);
    const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
    const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
    const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);

    float sum = 0;

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);

        const uint8_t * LM_GGML_RESTRICT q3_sv = x[i].qs;
        const uint8_t * LM_GGML_RESTRICT qh_sv = x[i].hmask;
        const int8_t  * LM_GGML_RESTRICT q8_sv = y[i].qs;

        // Set up scales: 16 x 6-bit values, low nibbles from aux[0]/aux[1],
        // high 2 bits from aux[2]; then remove the +32 bias.
        memcpy(aux, x[i].scales, 12);
        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);

        int8_t * scale = (int8_t *)utmp;

        for (int j = 0; j < 16; ++j) scale[j] -= m32;

        switch (vector_length) {
            case 128:
                {
                    svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
                    svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
                    svuint8_t q3h_sv;

                    svint32_t sumi1_1 = svdup_n_s32(0);
                    svint8_t q3bytes_sv;

                    for (int j = 0; j < QK_K/128; ++j) {

                        const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
                        const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                        // (mask & ~hmask) << k yields 4 where the high bit is clear; subtract it
                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));

                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));

                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));

                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));

                        scale += 4;
                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));

                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));

                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;

                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));

                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));

                        // second half of the block uses hmask bits 4..7
                        if (j == 0) {
                            qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
                            qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
                        }

                        scale += 4;
                    }

                    sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
                } break;
            default:
                // Every SVE vector length other than 128 bits is a multiple of
                // 128 that is >= 256 (256, 512, 1024, 2048, ...). This path only
                // touches the first 256 bits via the SV_VL32 (bytes) and
                // SV_VL4/SV_VL8 (words) predicate patterns, so it is valid for all
                // of them. Routing them here avoids an assert-only default that,
                // under NDEBUG, would silently drop the block from the sum.
                assert(vector_length >= 256);
                {
                    svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
                    svuint8_t q3h_sv;

                    svint32_t sumi1_1 = svdup_n_s32(0);
                    svint8_t q3bytes_sv;

                    for (int j = 0; j < QK_K/128; ++j) {

                        const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;

                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        // lanes 0..3 cover the first 16 weights (scale[0]), lanes 4..7 the next 16 (scale[1])
                        svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);

                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);

                        scale += 4;
                        q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
                        q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;

                        q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);

                        q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));

                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);

                        // second half of the block uses hmask bits 4..7
                        if (j == 0) {
                            qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
                        }

                        scale += 4;
                    }

                    sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
                } break;
        }
    }
    *s = sum;

#elif __ARM_NEON

    uint32_t aux[3];
    uint32_t utmp[4];

    const uint8x16_t m3b = vdupq_n_u8(0x3);
    const int32x4_t vzero = vdupq_n_s32(0);

    // single-bit masks selecting hmask bit 0..3
    const uint8x16_t m0 = vdupq_n_u8(1);
    const uint8x16_t m1 = vshlq_n_u8(m0, 1);
    const uint8x16_t m2 = vshlq_n_u8(m0, 2);
    const uint8x16_t m3 = vshlq_n_u8(m0, 3);
    const int8_t m32 = 32;

    lm_ggml_int8x16x4_t q3bytes;

    float sum = 0;

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);

        const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
        const uint8_t * LM_GGML_RESTRICT qh = x[i].hmask;
        const int8_t  * LM_GGML_RESTRICT q8 = y[i].qs;

        lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh);

        lm_ggml_uint8x16x4_t q3h;

        int32_t isum = 0;

        // Set up scales (same unpacking as the SVE path)
        memcpy(aux, x[i].scales, 12);
        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);

        int8_t * scale = (int8_t *)utmp;
        for (int j = 0; j < 16; ++j) scale[j] -= m32;

        for (int j = 0; j < QK_K/128; ++j) {

            const lm_ggml_uint8x16x2_t q3bits = lm_ggml_vld1q_u8_x2(q3); q3 += 32;
            const lm_ggml_int8x16x4_t q8bytes_1 = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
            const lm_ggml_int8x16x4_t q8bytes_2 = lm_ggml_vld1q_s8_x4(q8); q8 += 64;

            q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
            q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
            q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
            q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);

            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));

            isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
            isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
            isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
            isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];

            scale += 4;

            q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
            q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
            q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
            q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);

            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));

            isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
            isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
            isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
            isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];

            scale += 4;

            // second half of the block uses hmask bits 4..7
            if (j == 0) {
                qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
                qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
            }

        }
        sum += d * isum;

    }

    *s = sum;

#else
    // scalar version
    // This function is written like this so the compiler can manage to vectorize most of it
    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
    // The ideal situation would be if we could just write the code once, and the compiler would
    // automatically produce the best possible set of machine instructions, instead of us having to manually
    // write vectorized versions for AVX, ARM_NEON, etc.

    int8_t  aux8[QK_K];
    int16_t aux16[8];
    float   sums [8];
    int32_t aux32[8];
    memset(sums, 0, 8*sizeof(float));

    uint32_t auxs[4];
    const int8_t * scales = (const int8_t*)auxs;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * LM_GGML_RESTRICT q3 = x[i].qs;
        const uint8_t * LM_GGML_RESTRICT hm = x[i].hmask;
        const  int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * LM_GGML_RESTRICT a = aux8;
        uint8_t m = 1;
        // decode the 3-bit weights: 2 low bits from qs, minus 4 when the hmask bit is clear
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            q3 += 32;
        }
        a = aux8;

        // unpack the 16 6-bit scales in place (still biased by 32)
        memcpy(auxs, x[i].scales, 12);
        uint32_t tmp = auxs[2];
        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
        for (int j = 0; j < QK_K/16; ++j) {
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;

#endif

}
|
2125
|
+
|
2126
|
+
void lm_ggml_vec_dot_q4_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
|
2127
|
+
assert(n % QK_K == 0);
|
2128
|
+
#ifdef __ARM_FEATURE_MATMUL_INT8
|
2129
|
+
assert((nrc == 2) || (nrc == 1));
|
2130
|
+
#else
|
2131
|
+
assert(nrc == 1);
|
2132
|
+
#endif
|
2133
|
+
UNUSED(nrc);
|
2134
|
+
UNUSED(bx);
|
2135
|
+
UNUSED(by);
|
2136
|
+
UNUSED(bs);
|
2137
|
+
|
2138
|
+
const block_q4_K * LM_GGML_RESTRICT x = vx;
|
2139
|
+
const block_q8_K * LM_GGML_RESTRICT y = vy;
|
2140
|
+
|
2141
|
+
const int nb = n / QK_K;
|
2142
|
+
|
2143
|
+
static const uint32_t kmask1 = 0x3f3f3f3f;
|
2144
|
+
static const uint32_t kmask2 = 0x0f0f0f0f;
|
2145
|
+
static const uint32_t kmask3 = 0x03030303;
|
2146
|
+
|
2147
|
+
uint32_t utmp[4];
|
2148
|
+
|
2149
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
2150
|
+
if (nrc == 2) {
|
2151
|
+
const block_q4_K * LM_GGML_RESTRICT x0 = x;
|
2152
|
+
const block_q4_K * LM_GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
|
2153
|
+
const block_q8_K * LM_GGML_RESTRICT y0 = y;
|
2154
|
+
const block_q8_K * LM_GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
|
2155
|
+
|
2156
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
2157
|
+
|
2158
|
+
float32x4_t vfsum = vdupq_n_f32(0.0f);
|
2159
|
+
|
2160
|
+
for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
|
2161
|
+
const uint8_t * LM_GGML_RESTRICT qx0 = x0->qs;
|
2162
|
+
const uint8_t * LM_GGML_RESTRICT qx1 = x1->qs;
|
2163
|
+
const int8_t * LM_GGML_RESTRICT qy0 = y0->qs;
|
2164
|
+
const int8_t * LM_GGML_RESTRICT qy1 = y1->qs;
|
2165
|
+
|
2166
|
+
// decode scales and mins
|
2167
|
+
int8_t x0_scales[8], x1_scales[8];
|
2168
|
+
int16x8_t x0_mins, x1_mins;
|
2169
|
+
{
|
2170
|
+
uint32_t scales_mins[3];
|
2171
|
+
memcpy(scales_mins, x0->scales, 12);
|
2172
|
+
const uint32_t mins_0_3 = scales_mins[1] & kmask1;
|
2173
|
+
const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
|
2174
|
+
const uint32x2_t mins = {mins_0_3, mins_4_7};
|
2175
|
+
x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
|
2176
|
+
uint32_t scales[2];
|
2177
|
+
scales[0] = scales_mins[0] & kmask1; // scales 0~3
|
2178
|
+
scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
|
2179
|
+
memcpy(x0_scales, scales, 8);
|
2180
|
+
}
|
2181
|
+
{
|
2182
|
+
uint32_t scales_mins[3];
|
2183
|
+
memcpy(scales_mins, x1->scales, 12);
|
2184
|
+
const uint32_t mins_0_3 = scales_mins[1] & kmask1;
|
2185
|
+
const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
|
2186
|
+
const uint32x2_t mins = {mins_0_3, mins_4_7};
|
2187
|
+
x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
|
2188
|
+
uint32_t scales[2];
|
2189
|
+
scales[0] = scales_mins[0] & kmask1; // scales 0~3
|
2190
|
+
scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
|
2191
|
+
memcpy(x1_scales, scales, 8);
|
2192
|
+
}
|
2193
|
+
|
2194
|
+
int32x4_t visum = {0};
|
2195
|
+
|
2196
|
+
// process 64 data points per iteration, totally 256 data points
|
2197
|
+
for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) {
|
2198
|
+
const int8x16x4_t vy0 = vld1q_s8_x4(qy0);
|
2199
|
+
const int8x16x4_t vy1 = vld1q_s8_x4(qy1);
|
2200
|
+
|
2201
|
+
int8x16_t vx0[4], vx1[4];
|
2202
|
+
{
|
2203
|
+
const uint8x16x2_t vv = vld1q_u8_x2(qx0);
|
2204
|
+
vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
|
2205
|
+
vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
|
2206
|
+
vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
|
2207
|
+
vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
|
2208
|
+
}
|
2209
|
+
{
|
2210
|
+
const uint8x16x2_t vv = vld1q_u8_x2(qx1);
|
2211
|
+
vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
|
2212
|
+
vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
|
2213
|
+
vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
|
2214
|
+
vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
|
2215
|
+
}
|
2216
|
+
|
2217
|
+
// process 32 data points (share same block scale) per iteration
|
2218
|
+
for (int k = 0; k < 2; ++k) {
|
2219
|
+
const int blk = j * 2 + k;
|
2220
|
+
const int32x4_t block_scale = {
|
2221
|
+
x0_scales[blk],
|
2222
|
+
x0_scales[blk],
|
2223
|
+
x1_scales[blk],
|
2224
|
+
x1_scales[blk],
|
2225
|
+
};
|
2226
|
+
|
2227
|
+
int32x4_t vr = {0};
|
2228
|
+
for (int l = 0; l < 2; ++l) {
|
2229
|
+
const int idx = k * 2 + l;
|
2230
|
+
const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]);
|
2231
|
+
const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]);
|
2232
|
+
const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]);
|
2233
|
+
const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]);
|
2234
|
+
const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64));
|
2235
|
+
const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64));
|
2236
|
+
const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64));
|
2237
|
+
const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64));
|
2238
|
+
vr = vmmlaq_s32(vr, vx_l, vy_l);
|
2239
|
+
vr = vmmlaq_s32(vr, vx_h, vy_h);
|
2240
|
+
}
|
2241
|
+
// apply block scale, will NOT overflow
|
2242
|
+
// block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 28 bits
|
2243
|
+
visum = vmlaq_s32(visum, vr, block_scale);
|
2244
|
+
}
|
2245
|
+
}
|
2246
|
+
|
2247
|
+
// adjust bias, apply superblock scale
|
2248
|
+
{
|
2249
|
+
int32_t bias[4];
|
2250
|
+
// no obvious uplift from sve sdot-16, just use neon mul add
|
2251
|
+
const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8));
|
2252
|
+
const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8));
|
2253
|
+
bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)),
|
2254
|
+
vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins))));
|
2255
|
+
bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)),
|
2256
|
+
vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins))));
|
2257
|
+
bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)),
|
2258
|
+
vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins))));
|
2259
|
+
bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)),
|
2260
|
+
vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins))));
|
2261
|
+
const float32x4_t dmins = {
|
2262
|
+
LM_GGML_CPU_FP16_TO_FP32(x0->dmin) * y0->d,
|
2263
|
+
LM_GGML_CPU_FP16_TO_FP32(x0->dmin) * y1->d,
|
2264
|
+
LM_GGML_CPU_FP16_TO_FP32(x1->dmin) * y0->d,
|
2265
|
+
LM_GGML_CPU_FP16_TO_FP32(x1->dmin) * y1->d,
|
2266
|
+
};
|
2267
|
+
vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins);
|
2268
|
+
|
2269
|
+
const float32x4_t superblock_scale = {
|
2270
|
+
LM_GGML_CPU_FP16_TO_FP32(x0->d) * y0->d,
|
2271
|
+
LM_GGML_CPU_FP16_TO_FP32(x0->d) * y1->d,
|
2272
|
+
LM_GGML_CPU_FP16_TO_FP32(x1->d) * y0->d,
|
2273
|
+
LM_GGML_CPU_FP16_TO_FP32(x1->d) * y1->d,
|
2274
|
+
};
|
2275
|
+
vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
|
2276
|
+
}
|
2277
|
+
}
|
2278
|
+
|
2279
|
+
// vfsum = ABCD -> ACBD
|
2280
|
+
// AC -> s, BD -> (s+bs)
|
2281
|
+
vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
|
2282
|
+
vst1_f32(s, vget_low_f32 (vfsum));
|
2283
|
+
vst1_f32(s + bs, vget_high_f32(vfsum));
|
2284
|
+
|
2285
|
+
return;
|
2286
|
+
}
|
2287
|
+
#endif
|
2288
|
+
|
2289
|
+
#ifdef __ARM_FEATURE_SVE
|
2290
|
+
float sumf = 0;
|
2291
|
+
for (int i = 0; i < nb; ++i) {
|
2292
|
+
|
2293
|
+
const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
|
2294
|
+
const float dmin = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
|
2295
|
+
|
2296
|
+
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
|
2297
|
+
|
2298
|
+
memcpy(utmp, x[i].scales, K_SCALE_SIZE);
|
2299
|
+
|
2300
|
+
uint32x2_t mins8 = { 0 };
|
2301
|
+
mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
|
2302
|
+
mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
|
2303
|
+
|
2304
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
2305
|
+
utmp[0] &= kmask1;
|
2306
|
+
|
2307
|
+
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
|
2308
|
+
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
|
2309
|
+
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
|
2310
|
+
sumf -= dmin * vaddvq_s32(prod);
|
2311
|
+
|
2312
|
+
const uint8_t * scales = (const uint8_t *)utmp;
|
2313
|
+
|
2314
|
+
const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
|
2315
|
+
const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
|
2316
|
+
|
2317
|
+
const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
|
2318
|
+
const svuint8_t m4b = svdup_n_u8(0xf);
|
2319
|
+
const svint32_t mzero = svdup_n_s32(0);
|
2320
|
+
svint32_t sumi1 = svdup_n_s32(0);
|
2321
|
+
svint32_t sumi1_1 = svdup_n_s32(0);
|
2322
|
+
svint32_t sumi1_2 = svdup_n_s32(0);
|
2323
|
+
svint32_t sumi2 = svdup_n_s32(0);
|
2324
|
+
svint32_t sumi2_1 = svdup_n_s32(0);
|
2325
|
+
svint32_t sumi2_2 = svdup_n_s32(0);
|
2326
|
+
switch (vector_length) {
|
2327
|
+
case 128:
|
2328
|
+
{
|
2329
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
2330
|
+
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
|
2331
|
+
svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
2332
|
+
sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
2333
|
+
q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
|
2334
|
+
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
2335
|
+
sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
2336
|
+
|
2337
|
+
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
|
2338
|
+
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
2339
|
+
sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
2340
|
+
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
|
2341
|
+
q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
|
2342
|
+
sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
2343
|
+
q4 += 32;
|
2344
|
+
}
|
2345
|
+
sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
|
2346
|
+
sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
|
2347
|
+
sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
|
2348
|
+
} break;
|
2349
|
+
case 256:
|
2350
|
+
case 512:
|
2351
|
+
{
|
2352
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
2353
|
+
const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
|
2354
|
+
svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
|
2355
|
+
svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
|
2356
|
+
sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
|
2357
|
+
|
2358
|
+
q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
|
2359
|
+
q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
|
2360
|
+
sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
|
2361
|
+
}
|
2362
|
+
sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
|
2363
|
+
} break;
|
2364
|
+
default:
|
2365
|
+
assert(false && "Unsupported vector length");
|
2366
|
+
break;
|
2367
|
+
}
|
2368
|
+
}
|
2369
|
+
*s = sumf;
|
2370
|
+
#elif defined __ARM_NEON
|
2371
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2372
|
+
const int32x4_t mzero = vdupq_n_s32(0);
|
2373
|
+
|
2374
|
+
lm_ggml_int8x16x2_t q4bytes;
|
2375
|
+
lm_ggml_int8x16x2_t q8bytes;
|
2376
|
+
|
2377
|
+
float sumf = 0;
|
2378
|
+
|
2379
|
+
for (int i = 0; i < nb; ++i) {
|
2380
|
+
|
2381
|
+
const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
|
2382
|
+
const float dmin = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
|
2383
|
+
|
2384
|
+
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
|
2385
|
+
|
2386
|
+
memcpy(utmp, x[i].scales, 12);
|
2387
|
+
|
2388
|
+
uint32x2_t mins8 = { 0 };
|
2389
|
+
mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
|
2390
|
+
mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
|
2391
|
+
|
2392
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
2393
|
+
utmp[0] &= kmask1;
|
2394
|
+
|
2395
|
+
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
|
2396
|
+
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
|
2397
|
+
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
|
2398
|
+
sumf -= dmin * vaddvq_s32(prod);
|
2399
|
+
|
2400
|
+
const uint8_t * scales = (const uint8_t *)utmp;
|
2401
|
+
|
2402
|
+
const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
|
2403
|
+
const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
|
2404
|
+
|
2405
|
+
int32_t sumi1 = 0;
|
2406
|
+
int32_t sumi2 = 0;
|
2407
|
+
|
2408
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
2409
|
+
const lm_ggml_uint8x16x2_t q4bits = lm_ggml_vld1q_u8_x2(q4); q4 += 32;
|
2410
|
+
|
2411
|
+
q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32;
|
2412
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
|
2413
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
|
2414
|
+
|
2415
|
+
const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
2416
|
+
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
|
2417
|
+
|
2418
|
+
q8bytes = lm_ggml_vld1q_s8_x2(q8); q8 += 32;
|
2419
|
+
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
|
2420
|
+
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
|
2421
|
+
|
2422
|
+
const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
|
2423
|
+
|
2424
|
+
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
|
2425
|
+
}
|
2426
|
+
|
2427
|
+
sumf += d * (sumi1 + sumi2);
|
2428
|
+
|
2429
|
+
}
|
2430
|
+
|
2431
|
+
*s = sumf;
|
2432
|
+
|
2433
|
+
#else
|
2434
|
+
|
2435
|
+
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
2436
|
+
const uint8_t * mins = (const uint8_t*)&utmp[2];
|
2437
|
+
|
2438
|
+
int8_t aux8[QK_K];
|
2439
|
+
int16_t aux16[8];
|
2440
|
+
float sums [8];
|
2441
|
+
int32_t aux32[8];
|
2442
|
+
memset(sums, 0, 8*sizeof(float));
|
2443
|
+
|
2444
|
+
float sumf = 0;
|
2445
|
+
for (int i = 0; i < nb; ++i) {
|
2446
|
+
const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
|
2447
|
+
const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
|
2448
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
2449
|
+
int8_t * LM_GGML_RESTRICT a = aux8;
|
2450
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
2451
|
+
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
|
2452
|
+
a += 32;
|
2453
|
+
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
|
2454
|
+
a += 32; q4 += 32;
|
2455
|
+
}
|
2456
|
+
memcpy(utmp, x[i].scales, 12);
|
2457
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
2458
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
2459
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
2460
|
+
utmp[2] = uaux;
|
2461
|
+
utmp[0] &= kmask1;
|
2462
|
+
|
2463
|
+
int sumi = 0;
|
2464
|
+
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
|
2465
|
+
a = aux8;
|
2466
|
+
int is = 0;
|
2467
|
+
for (int j = 0; j < QK_K/32; ++j) {
|
2468
|
+
int32_t scale = scales[is++];
|
2469
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2470
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
2471
|
+
q8 += 8; a += 8;
|
2472
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2473
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
2474
|
+
q8 += 8; a += 8;
|
2475
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2476
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
2477
|
+
q8 += 8; a += 8;
|
2478
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2479
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
2480
|
+
q8 += 8; a += 8;
|
2481
|
+
}
|
2482
|
+
const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
2483
|
+
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
2484
|
+
const float dmin = LM_GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
|
2485
|
+
sumf -= dmin * sumi;
|
2486
|
+
}
|
2487
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
2488
|
+
*s = sumf;
|
2489
|
+
#endif
|
2490
|
+
}
|
2491
|
+
|
2492
|
+
void lm_ggml_vec_dot_q5_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
|
2493
|
+
assert(n % QK_K == 0);
|
2494
|
+
assert(nrc == 1);
|
2495
|
+
UNUSED(nrc);
|
2496
|
+
UNUSED(bx);
|
2497
|
+
UNUSED(by);
|
2498
|
+
UNUSED(bs);
|
2499
|
+
|
2500
|
+
const block_q5_K * LM_GGML_RESTRICT x = vx;
|
2501
|
+
const block_q8_K * LM_GGML_RESTRICT y = vy;
|
2502
|
+
|
2503
|
+
const int nb = n / QK_K;
|
2504
|
+
|
2505
|
+
static const uint32_t kmask1 = 0x3f3f3f3f;
|
2506
|
+
static const uint32_t kmask2 = 0x0f0f0f0f;
|
2507
|
+
static const uint32_t kmask3 = 0x03030303;
|
2508
|
+
|
2509
|
+
uint32_t utmp[4];
|
2510
|
+
|
2511
|
+
|
2512
|
+
#ifdef __ARM_NEON
|
2513
|
+
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
2514
|
+
const uint8x16_t mone = vdupq_n_u8(1);
|
2515
|
+
const uint8x16_t mtwo = vdupq_n_u8(2);
|
2516
|
+
const int32x4_t mzero = vdupq_n_s32(0);
|
2517
|
+
|
2518
|
+
lm_ggml_int8x16x4_t q5bytes;
|
2519
|
+
|
2520
|
+
float sumf = 0;
|
2521
|
+
|
2522
|
+
for (int i = 0; i < nb; ++i) {
|
2523
|
+
|
2524
|
+
const float d = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].d);
|
2525
|
+
const float dmin = y[i].d * LM_GGML_CPU_FP16_TO_FP32(x[i].dmin);
|
2526
|
+
|
2527
|
+
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
|
2528
|
+
|
2529
|
+
memcpy(utmp, x[i].scales, 12);
|
2530
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
2531
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
2532
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
2533
|
+
utmp[2] = uaux;
|
2534
|
+
utmp[0] &= kmask1;
|
2535
|
+
|
2536
|
+
const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
|
2537
|
+
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
|
2538
|
+
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
|
2539
|
+
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
|
2540
|
+
int32_t sumi_mins = vaddvq_s32(prod);
|
2541
|
+
|
2542
|
+
const uint8_t * scales = (const uint8_t *)utmp;
|
2543
|
+
|
2544
|
+
const uint8_t * LM_GGML_RESTRICT q5 = x[i].qs;
|
2545
|
+
const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
|
2546
|
+
const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
|
2547
|
+
|
2548
|
+
lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh);
|
2549
|
+
|
2550
|
+
lm_ggml_uint8x16x4_t q5h;
|
2551
|
+
|
2552
|
+
int32_t sumi = 0;
|
2553
|
+
|
2554
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
2555
|
+
|
2556
|
+
const lm_ggml_uint8x16x2_t q5bits = lm_ggml_vld1q_u8_x2(q5); q5 += 32;
|
2557
|
+
const lm_ggml_int8x16x4_t q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
|
2558
|
+
|
2559
|
+
q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
|
2560
|
+
q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
|
2561
|
+
q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
|
2562
|
+
q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
|
2563
|
+
qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
|
2564
|
+
qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
|
2565
|
+
|
2566
|
+
q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
|
2567
|
+
q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
|
2568
|
+
q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
|
2569
|
+
q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
|
2570
|
+
|
2571
|
+
sumi += vaddvq_s32(lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
|
2572
|
+
sumi += vaddvq_s32(lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
|
2573
|
+
}
|
2574
|
+
|
2575
|
+
sumf += d * sumi - dmin * sumi_mins;
|
2576
|
+
}
|
2577
|
+
|
2578
|
+
*s = sumf;
|
2579
|
+
|
2580
|
+
#else
|
2581
|
+
|
2582
|
+
const uint8_t * scales = (const uint8_t*)&utmp[0];
|
2583
|
+
const uint8_t * mins = (const uint8_t*)&utmp[2];
|
2584
|
+
|
2585
|
+
int8_t aux8[QK_K];
|
2586
|
+
int16_t aux16[8];
|
2587
|
+
float sums [8];
|
2588
|
+
int32_t aux32[8];
|
2589
|
+
memset(sums, 0, 8*sizeof(float));
|
2590
|
+
|
2591
|
+
float sumf = 0;
|
2592
|
+
for (int i = 0; i < nb; ++i) {
|
2593
|
+
const uint8_t * LM_GGML_RESTRICT q4 = x[i].qs;
|
2594
|
+
const uint8_t * LM_GGML_RESTRICT hm = x[i].qh;
|
2595
|
+
const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
|
2596
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
2597
|
+
int8_t * LM_GGML_RESTRICT a = aux8;
|
2598
|
+
uint8_t m = 1;
|
2599
|
+
for (int j = 0; j < QK_K/64; ++j) {
|
2600
|
+
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
|
2601
|
+
for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
|
2602
|
+
a += 32; m <<= 1;
|
2603
|
+
for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
|
2604
|
+
for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
|
2605
|
+
a += 32; m <<= 1;
|
2606
|
+
q4 += 32;
|
2607
|
+
}
|
2608
|
+
memcpy(utmp, x[i].scales, 12);
|
2609
|
+
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
2610
|
+
const uint32_t uaux = utmp[1] & kmask1;
|
2611
|
+
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
2612
|
+
utmp[2] = uaux;
|
2613
|
+
utmp[0] &= kmask1;
|
2614
|
+
|
2615
|
+
int sumi = 0;
|
2616
|
+
for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
|
2617
|
+
a = aux8;
|
2618
|
+
int is = 0;
|
2619
|
+
for (int j = 0; j < QK_K/32; ++j) {
|
2620
|
+
int32_t scale = scales[is++];
|
2621
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2622
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
2623
|
+
q8 += 8; a += 8;
|
2624
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2625
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
2626
|
+
q8 += 8; a += 8;
|
2627
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2628
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
2629
|
+
q8 += 8; a += 8;
|
2630
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
2631
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
2632
|
+
q8 += 8; a += 8;
|
2633
|
+
}
|
2634
|
+
const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
2635
|
+
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
2636
|
+
const float dmin = LM_GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
|
2637
|
+
sumf -= dmin * sumi;
|
2638
|
+
}
|
2639
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
2640
|
+
*s = sumf;
|
2641
|
+
#endif
|
2642
|
+
}
|
2643
|
+
|
2644
|
+
void lm_ggml_vec_dot_q6_K_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
|
2645
|
+
assert(n % QK_K == 0);
|
2646
|
+
#ifdef __ARM_FEATURE_MATMUL_INT8
|
2647
|
+
assert((nrc == 2) || (nrc == 1));
|
2648
|
+
#else
|
2649
|
+
assert(nrc == 1);
|
2650
|
+
#endif
|
2651
|
+
UNUSED(nrc);
|
2652
|
+
UNUSED(bx);
|
2653
|
+
UNUSED(by);
|
2654
|
+
UNUSED(bs);
|
2655
|
+
|
2656
|
+
const block_q6_K * LM_GGML_RESTRICT x = vx;
|
2657
|
+
const block_q8_K * LM_GGML_RESTRICT y = vy;
|
2658
|
+
|
2659
|
+
const int nb = n / QK_K;
|
2660
|
+
|
2661
|
+
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
2662
|
+
if (nrc == 2) {
|
2663
|
+
const block_q6_K * LM_GGML_RESTRICT x0 = x;
|
2664
|
+
const block_q6_K * LM_GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
|
2665
|
+
const block_q8_K * LM_GGML_RESTRICT y0 = y;
|
2666
|
+
const block_q8_K * LM_GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
|
2667
|
+
|
2668
|
+
float32x4_t vfsum = vdupq_n_f32(0.0f);
|
2669
|
+
|
2670
|
+
for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
|
2671
|
+
const uint8_t * LM_GGML_RESTRICT ql0 = x0->ql;
|
2672
|
+
const uint8_t * LM_GGML_RESTRICT ql1 = x1->ql;
|
2673
|
+
const uint8_t * LM_GGML_RESTRICT qh0 = x0->qh;
|
2674
|
+
const uint8_t * LM_GGML_RESTRICT qh1 = x1->qh;
|
2675
|
+
const int8_t * LM_GGML_RESTRICT qy0 = y0->qs;
|
2676
|
+
const int8_t * LM_GGML_RESTRICT qy1 = y1->qs;
|
2677
|
+
|
2678
|
+
const uint8x16_t mone = vdupq_n_u8(0x30);
|
2679
|
+
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
2680
|
+
|
2681
|
+
int32x4_t visum = vdupq_n_s32(0);
|
2682
|
+
|
2683
|
+
// process 8 blocks per iteration, totally 16 blocks
|
2684
|
+
for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
|
2685
|
+
int8x16_t vx0[8], vx1[8];
|
2686
|
+
|
2687
|
+
// de-quantize vx0[8]
|
2688
|
+
{
|
2689
|
+
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
|
2690
|
+
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
|
2691
|
+
|
2692
|
+
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
|
2693
|
+
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
|
2694
|
+
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
|
2695
|
+
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
|
2696
|
+
|
2697
|
+
vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
|
2698
|
+
vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
|
2699
|
+
vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
|
2700
|
+
vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
|
2701
|
+
|
2702
|
+
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
|
2703
|
+
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
|
2704
|
+
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
|
2705
|
+
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
|
2706
|
+
|
2707
|
+
vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
|
2708
|
+
vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
|
2709
|
+
vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
|
2710
|
+
vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
|
2711
|
+
}
|
2712
|
+
|
2713
|
+
// de-quantize vx1[8]
|
2714
|
+
{
|
2715
|
+
const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
|
2716
|
+
const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
|
2717
|
+
|
2718
|
+
uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
|
2719
|
+
uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
|
2720
|
+
uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
|
2721
|
+
uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
|
2722
|
+
|
2723
|
+
vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
|
2724
|
+
vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
|
2725
|
+
vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
|
2726
|
+
vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
|
2727
|
+
|
2728
|
+
q6h_0 = vandq_u8(mone, qh_bits.val[0]);
|
2729
|
+
q6h_1 = vandq_u8(mone, qh_bits.val[1]);
|
2730
|
+
q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
|
2731
|
+
q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
|
2732
|
+
|
2733
|
+
vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
|
2734
|
+
vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
|
2735
|
+
vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
|
2736
|
+
vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
|
2737
|
+
}
|
2738
|
+
|
2739
|
+
// process 16 elements (one block with same scale) per iteration
|
2740
|
+
// - vx = concat(ql, qh) - 32
|
2741
|
+
// - r1,r2,r3,r4 = smmla(vx, vy)
|
2742
|
+
for (int k = 0; k < 8; ++k) {
|
2743
|
+
const int blk = j * 8 + k;
|
2744
|
+
|
2745
|
+
const int8x16_t vy0 = vld1q_s8(qy0);
|
2746
|
+
const int8x16_t vy1 = vld1q_s8(qy1);
|
2747
|
+
qy0 += 16;
|
2748
|
+
qy1 += 16;
|
2749
|
+
|
2750
|
+
const int32x4_t block_scale = {
|
2751
|
+
x0->scales[blk],
|
2752
|
+
x0->scales[blk],
|
2753
|
+
x1->scales[blk],
|
2754
|
+
x1->scales[blk],
|
2755
|
+
};
|
2756
|
+
|
2757
|
+
// calculate four results at once with outer product
|
2758
|
+
const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
|
2759
|
+
const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
|
2760
|
+
const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
|
2761
|
+
const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
|
2762
|
+
int32x4_t vr = vdupq_n_s32(0);
|
2763
|
+
vr = vmmlaq_s32(vr, vx_l, vy_l);
|
2764
|
+
vr = vmmlaq_s32(vr, vx_h, vy_h);
|
2765
|
+
|
2766
|
+
// apply block scale, will NOT overflow
|
2767
|
+
// block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
|
2768
|
+
visum = vmlaq_s32(visum, vr, block_scale);
|
2769
|
+
}
|
2770
|
+
}
|
2771
|
+
|
2772
|
+
// adjust bias, apply superblock scale
|
2773
|
+
{
|
2774
|
+
int32_t bias[4];
|
2775
|
+
#ifdef __ARM_FEATURE_SVE
|
2776
|
+
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
2777
|
+
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
|
2778
|
+
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
|
2779
|
+
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
|
2780
|
+
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
|
2781
|
+
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
|
2782
|
+
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
|
2783
|
+
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
|
2784
|
+
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
|
2785
|
+
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
|
2786
|
+
const svint64_t zero = svdup_n_s64(0);
|
2787
|
+
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
|
2788
|
+
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
|
2789
|
+
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
|
2790
|
+
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
|
2791
|
+
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
|
2792
|
+
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
|
2793
|
+
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
|
2794
|
+
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
|
2795
|
+
#else
|
2796
|
+
// NEON doesn't support int16 dot product, fallback to separated mul and add
|
2797
|
+
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
|
2798
|
+
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
|
2799
|
+
|
2800
|
+
int8x16_t scales_s8 = vld1q_s8(x0->scales);
|
2801
|
+
const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
|
2802
|
+
scales_s8 = vld1q_s8(x1->scales);
|
2803
|
+
const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
|
2804
|
+
|
2805
|
+
int32x4_t prod;
|
2806
|
+
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
|
2807
|
+
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
|
2808
|
+
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
|
2809
|
+
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
|
2810
|
+
bias[0] = vaddvq_s32(prod);
|
2811
|
+
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
|
2812
|
+
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
|
2813
|
+
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
|
2814
|
+
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
|
2815
|
+
bias[1] = vaddvq_s32(prod);
|
2816
|
+
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
|
2817
|
+
vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
|
2818
|
+
vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
|
2819
|
+
vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
|
2820
|
+
bias[2] = vaddvq_s32(prod);
|
2821
|
+
prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
|
2822
|
+
vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
|
2823
|
+
vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
|
2824
|
+
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
|
2825
|
+
bias[3] = vaddvq_s32(prod);
|
2826
|
+
|
2827
|
+
#endif
|
2828
|
+
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
|
2829
|
+
|
2830
|
+
const float32x4_t superblock_scale = {
|
2831
|
+
LM_GGML_CPU_FP16_TO_FP32(x0->d) * y0->d,
|
2832
|
+
LM_GGML_CPU_FP16_TO_FP32(x0->d) * y1->d,
|
2833
|
+
LM_GGML_CPU_FP16_TO_FP32(x1->d) * y0->d,
|
2834
|
+
LM_GGML_CPU_FP16_TO_FP32(x1->d) * y1->d,
|
2835
|
+
};
|
2836
|
+
|
2837
|
+
visum = vsubq_s32(visum, vibias);
|
2838
|
+
vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
|
2839
|
+
}
|
2840
|
+
}
|
2841
|
+
|
2842
|
+
// vfsum = ABCD -> ACBD
|
2843
|
+
// AC -> s, BD -> (s+bs)
|
2844
|
+
vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
|
2845
|
+
vst1_f32(s, vget_low_f32 (vfsum));
|
2846
|
+
vst1_f32(s + bs, vget_high_f32(vfsum));
|
2847
|
+
|
2848
|
+
return;
|
2849
|
+
}
|
2850
|
+
#endif
|
2851
|
+
|
2852
|
+
#ifdef __ARM_FEATURE_SVE
|
2853
|
+
const int vector_length = lm_ggml_cpu_get_sve_cnt()*8;
|
2854
|
+
float sum = 0;
|
2855
|
+
svuint8_t m4b = svdup_n_u8(0xf);
|
2856
|
+
svint32_t vzero = svdup_n_s32(0);
|
2857
|
+
svuint8_t mone = svdup_n_u8(0x30);
|
2858
|
+
svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4;
|
2859
|
+
svuint8_t q6h_1, q6h_2, q6h_3, q6h_4;
|
2860
|
+
|
2861
|
+
for (int i = 0; i < nb; ++i) {
|
2862
|
+
const float d_all = LM_GGML_CPU_FP16_TO_FP32(x[i].d);
|
2863
|
+
|
2864
|
+
const uint8_t * LM_GGML_RESTRICT q6 = x[i].ql;
|
2865
|
+
const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
|
2866
|
+
const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
|
2867
|
+
|
2868
|
+
const int8_t * LM_GGML_RESTRICT scale = x[i].scales;
|
2869
|
+
|
2870
|
+
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
|
2871
|
+
const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums);
|
2872
|
+
const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8);
|
2873
|
+
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale));
|
2874
|
+
const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8));
|
2875
|
+
const svint64_t prod = svdup_n_s64(0);
|
2876
|
+
int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1),
|
2877
|
+
svdot_s64(prod, q8sums_2, q6scales_2)));
|
2878
|
+
int32_t isum = 0;
|
2879
|
+
|
2880
|
+
switch (vector_length) {
|
2881
|
+
case 128:
|
2882
|
+
{
|
2883
|
+
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
|
2884
|
+
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
|
2885
|
+
svint32_t isum_tmp = svdup_n_s32(0);
|
2886
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
2887
|
+
svuint8_t qhbits_1 = svld1_u8(pg8_16, qh);
|
2888
|
+
svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16);
|
2889
|
+
qh += 32;
|
2890
|
+
svuint8_t q6bits_1 = svld1_u8(pg8_16, q6);
|
2891
|
+
svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16);
|
2892
|
+
svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32);
|
2893
|
+
svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48);
|
2894
|
+
q6 += 64;
|
2895
|
+
svint8_t q8bytes_1 = svld1_s8(pg8_16, q8);
|
2896
|
+
svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16);
|
2897
|
+
svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32);
|
2898
|
+
svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48);
|
2899
|
+
q8 += 64;
|
2900
|
+
|
2901
|
+
q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4));
|
2902
|
+
q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4));
|
2903
|
+
q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2));
|
2904
|
+
q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2));
|
2905
|
+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1));
|
2906
|
+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2));
|
2907
|
+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3));
|
2908
|
+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4));
|
2909
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
|
2910
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
|
2911
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
|
2912
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
|
2913
|
+
|
2914
|
+
scale += 4;
|
2915
|
+
q8bytes_1 = svld1_s8(pg8_16, q8);
|
2916
|
+
q8bytes_2 = svld1_s8(pg8_16, q8+16);
|
2917
|
+
q8bytes_3 = svld1_s8(pg8_16, q8+32);
|
2918
|
+
q8bytes_4 = svld1_s8(pg8_16, q8+48);
|
2919
|
+
q8 += 64;
|
2920
|
+
|
2921
|
+
q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1);
|
2922
|
+
q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2);
|
2923
|
+
q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2));
|
2924
|
+
q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2));
|
2925
|
+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1));
|
2926
|
+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2));
|
2927
|
+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3));
|
2928
|
+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4));
|
2929
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]);
|
2930
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]);
|
2931
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]);
|
2932
|
+
isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]);
|
2933
|
+
scale += 4;
|
2934
|
+
}
|
2935
|
+
isum += svaddv_s32(pg32_4, isum_tmp);
|
2936
|
+
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
2937
|
+
}
|
2938
|
+
break;
|
2939
|
+
case 256:
|
2940
|
+
case 512:
|
2941
|
+
{
|
2942
|
+
const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2);
|
2943
|
+
const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8);
|
2944
|
+
const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32);
|
2945
|
+
svint32_t isum_tmp = svdup_n_s32(0);
|
2946
|
+
for (int j = 0; j < QK_K/128; j++) {
|
2947
|
+
svuint8_t qhbits_1 = svld1_u8(pg8_32, qh);
|
2948
|
+
qh += 32;
|
2949
|
+
svuint8_t q6bits_1 = svld1_u8(pg8_32, q6);
|
2950
|
+
svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32);
|
2951
|
+
q6 += 64;
|
2952
|
+
svint8_t q8bytes_1 = svld1_s8(pg8_32, q8);
|
2953
|
+
svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32);
|
2954
|
+
svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64);
|
2955
|
+
svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96);
|
2956
|
+
q8 += 128;
|
2957
|
+
q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4));
|
2958
|
+
q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2));
|
2959
|
+
q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1);
|
2960
|
+
q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2));
|
2961
|
+
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1));
|
2962
|
+
q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2));
|
2963
|
+
q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3));
|
2964
|
+
q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4));
|
2965
|
+
|
2966
|
+
svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale);
|
2967
|
+
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
|
2968
|
+
scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp);
|
2969
|
+
svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2);
|
2970
|
+
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
|
2971
|
+
scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp);
|
2972
|
+
svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4);
|
2973
|
+
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
|
2974
|
+
scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp);
|
2975
|
+
svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6);
|
2976
|
+
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
|
2977
|
+
scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp);
|
2978
|
+
svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp));
|
2979
|
+
svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp));
|
2980
|
+
svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp));
|
2981
|
+
svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp));
|
2982
|
+
|
2983
|
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1);
|
2984
|
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2);
|
2985
|
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3);
|
2986
|
+
isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4);
|
2987
|
+
scale += 8;
|
2988
|
+
}
|
2989
|
+
isum += svaddv_s32(pg32_8, isum_tmp);
|
2990
|
+
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
2991
|
+
}
|
2992
|
+
break;
|
2993
|
+
default:
|
2994
|
+
assert(false && "Unsupported vector length");
|
2995
|
+
break;
|
2996
|
+
}
|
2997
|
+
}
|
2998
|
+
|
2999
|
+
*s = sum;
|
3000
|
+
|
3001
|
+
#elif __ARM_NEON
|
3002
|
+
float sum = 0;
|
3003
|
+
|
3004
|
+
const uint8x16_t m4b = vdupq_n_u8(0xF);
|
3005
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
3006
|
+
//const int8x16_t m32s = vdupq_n_s8(32);
|
3007
|
+
|
3008
|
+
const uint8x16_t mone = vdupq_n_u8(3);
|
3009
|
+
|
3010
|
+
lm_ggml_int8x16x4_t q6bytes;
|
3011
|
+
lm_ggml_uint8x16x4_t q6h;
|
3012
|
+
|
3013
|
+
for (int i = 0; i < nb; ++i) {
|
3014
|
+
|
3015
|
+
const float d_all = LM_GGML_CPU_FP16_TO_FP32(x[i].d);
|
3016
|
+
|
3017
|
+
const uint8_t * LM_GGML_RESTRICT q6 = x[i].ql;
|
3018
|
+
const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
|
3019
|
+
const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
|
3020
|
+
|
3021
|
+
const int8_t * LM_GGML_RESTRICT scale = x[i].scales;
|
3022
|
+
|
3023
|
+
const lm_ggml_int16x8x2_t q8sums = lm_ggml_vld1q_s16_x2(y[i].bsums);
|
3024
|
+
const int8x16_t scales = vld1q_s8(scale);
|
3025
|
+
const lm_ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
|
3026
|
+
|
3027
|
+
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
|
3028
|
+
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
|
3029
|
+
vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
|
3030
|
+
vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
|
3031
|
+
int32_t isum_mins = vaddvq_s32(prod);
|
3032
|
+
|
3033
|
+
int32_t isum = 0;
|
3034
|
+
|
3035
|
+
for (int j = 0; j < QK_K/128; ++j) {
|
3036
|
+
|
3037
|
+
lm_ggml_uint8x16x2_t qhbits = lm_ggml_vld1q_u8_x2(qh); qh += 32;
|
3038
|
+
lm_ggml_uint8x16x4_t q6bits = lm_ggml_vld1q_u8_x4(q6); q6 += 64;
|
3039
|
+
lm_ggml_int8x16x4_t q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
|
3040
|
+
|
3041
|
+
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
|
3042
|
+
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
|
3043
|
+
uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
|
3044
|
+
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
3045
|
+
shifted = vshrq_n_u8(qhbits.val[1], 2);
|
3046
|
+
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
3047
|
+
|
3048
|
+
//q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
|
3049
|
+
//q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
|
3050
|
+
//q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
|
3051
|
+
//q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
|
3052
|
+
q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
|
3053
|
+
q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
|
3054
|
+
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
|
3055
|
+
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
|
3056
|
+
|
3057
|
+
isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
3058
|
+
vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
3059
|
+
vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
3060
|
+
vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
3061
|
+
|
3062
|
+
scale += 4;
|
3063
|
+
|
3064
|
+
q8bytes = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
|
3065
|
+
|
3066
|
+
shifted = vshrq_n_u8(qhbits.val[0], 4);
|
3067
|
+
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
3068
|
+
shifted = vshrq_n_u8(qhbits.val[1], 4);
|
3069
|
+
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
3070
|
+
shifted = vshrq_n_u8(qhbits.val[0], 6);
|
3071
|
+
q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
3072
|
+
shifted = vshrq_n_u8(qhbits.val[1], 6);
|
3073
|
+
q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
|
3074
|
+
|
3075
|
+
//q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
|
3076
|
+
//q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
|
3077
|
+
//q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
|
3078
|
+
//q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
|
3079
|
+
q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
|
3080
|
+
q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
|
3081
|
+
q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
|
3082
|
+
q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
|
3083
|
+
|
3084
|
+
isum += vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
|
3085
|
+
vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
|
3086
|
+
vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
|
3087
|
+
vaddvq_s32(lm_ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
|
3088
|
+
scale += 4;
|
3089
|
+
}
|
3090
|
+
//sum += isum * d_all * y[i].d;
|
3091
|
+
sum += d_all * y[i].d * (isum - 32 * isum_mins);
|
3092
|
+
|
3093
|
+
}
|
3094
|
+
*s = sum;
|
3095
|
+
#else
|
3096
|
+
|
3097
|
+
int8_t aux8[QK_K];
|
3098
|
+
int16_t aux16[8];
|
3099
|
+
float sums [8];
|
3100
|
+
int32_t aux32[8];
|
3101
|
+
memset(sums, 0, 8*sizeof(float));
|
3102
|
+
|
3103
|
+
float sumf = 0;
|
3104
|
+
for (int i = 0; i < nb; ++i) {
|
3105
|
+
const uint8_t * LM_GGML_RESTRICT q4 = x[i].ql;
|
3106
|
+
const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
|
3107
|
+
const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
|
3108
|
+
memset(aux32, 0, 8*sizeof(int32_t));
|
3109
|
+
int8_t * LM_GGML_RESTRICT a = aux8;
|
3110
|
+
for (int j = 0; j < QK_K; j += 128) {
|
3111
|
+
for (int l = 0; l < 32; ++l) {
|
3112
|
+
a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
3113
|
+
a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
3114
|
+
a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
3115
|
+
a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
3116
|
+
}
|
3117
|
+
a += 128;
|
3118
|
+
q4 += 64;
|
3119
|
+
qh += 32;
|
3120
|
+
}
|
3121
|
+
a = aux8;
|
3122
|
+
int is = 0;
|
3123
|
+
for (int j = 0; j < QK_K/16; ++j) {
|
3124
|
+
int scale = x[i].scales[is++];
|
3125
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
3126
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
3127
|
+
q8 += 8; a += 8;
|
3128
|
+
for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
|
3129
|
+
for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
|
3130
|
+
q8 += 8; a += 8;
|
3131
|
+
}
|
3132
|
+
const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
3133
|
+
for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
|
3134
|
+
}
|
3135
|
+
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
3136
|
+
*s = sumf;
|
3137
|
+
#endif
|
3138
|
+
}
|
3139
|
+
|
3140
|
+
#if defined (__ARM_NEON)
|
3141
|
+
static const int8_t keven_signs_q2xs[1024] = {
|
3142
|
+
1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
|
3143
|
+
1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
|
3144
|
+
1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
|
3145
|
+
1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
|
3146
|
+
1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
|
3147
|
+
1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
|
3148
|
+
1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
|
3149
|
+
1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
|
3150
|
+
1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
|
3151
|
+
1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
|
3152
|
+
1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
|
3153
|
+
1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
|
3154
|
+
1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
|
3155
|
+
1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
|
3156
|
+
1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
|
3157
|
+
1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
|
3158
|
+
1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
|
3159
|
+
1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
|
3160
|
+
1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
|
3161
|
+
1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
|
3162
|
+
1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
|
3163
|
+
1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
|
3164
|
+
1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
|
3165
|
+
1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
|
3166
|
+
1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
|
3167
|
+
1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
|
3168
|
+
1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
|
3169
|
+
1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
|
3170
|
+
1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
|
3171
|
+
1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
|
3172
|
+
1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
|
3173
|
+
1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
|
3174
|
+
};
|
3175
|
+
#endif
|
3176
|
+
|
3177
|
+
/*
 * Dot product of one IQ2_XXS-quantized row (vx) with one Q8_K-quantized row
 * (vy). The float result is stored in *s.
 *
 *   n          number of elements; must be a multiple of QK_K
 *   nrc        number of rows to process; only nrc == 1 is supported
 *   bs, bx, by row strides; unused by this single-row kernel
 *
 * Every 32-value sub-block is encoded as two uint32 words:
 *   word 0 : four 8-bit indices into iq2xxs_grid (8 values per index)
 *   word 1 : four 7-bit sign indices in bits 0..27 and a 4-bit scale s in
 *            bits 28..31; the sub-block weight is (2*s + 1)
 * The NEON path folds the factor 2 into 0.25f * (0.5f + s), and the scalar
 * path applies 0.125f * (2*s + 1). Both give the same effective scale.
 */
void lm_ggml_vec_dot_iq2_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq2_xxs * LM_GGML_RESTRICT xb = vx;
    const block_q8_K    * LM_GGML_RESTRICT yb = vy;

    const int nblocks = n / QK_K;

#if defined(__ARM_NEON)
    // keven_signs_q2xs viewed as 128 rows of 8 signed bytes (one uint64 per row)
    const uint64_t * sign_rows = (const uint64_t *)keven_signs_q2xs;

    // Two sub-blocks per step: words 0/1 describe the first, words 2/3 the second.
    uint32_t words[4];
    const uint8_t * idx = (const uint8_t *)words;

    lm_ggml_int8x16x4_t wv;   // grid values; the sign is applied in place
    lm_ggml_int8x16x4_t sv;   // +1/-1 sign vectors
    lm_ggml_int8x16x4_t av;   // Q8 activations

    float total = 0;
    for (int ib = 0; ib < nblocks; ++ib) {
        const float dscale = LM_GGML_CPU_FP16_TO_FP32(xb[ib].d) * yb[ib].d;
        const uint16_t * LM_GGML_RESTRICT qw = xb[ib].qs;
        const int8_t   * LM_GGML_RESTRICT qa = yb[ib].qs;

        float acc_a = 0;   // first sub-block of each pair
        float acc_b = 0;   // second sub-block of each pair
        for (int sub = 0; sub < QK_K/32; sub += 2) {
            av = lm_ggml_vld1q_s8_x4(qa);
            qa += 64;
            memcpy(words, qw, 4*sizeof(uint32_t));
            qw += 8;

            for (int k = 0; k < 4; ++k) {
                // k = 0,1 -> grid bytes 0..3 of word 0, signs from word 1
                // k = 2,3 -> grid bytes 8..11 of word 2, signs from word 3
                const int      g  = 2*k + (k >= 2 ? 4 : 0);
                const uint32_t sw = words[k < 2 ? 1 : 3];
                const int      sh = (k & 1) ? 14 : 0;
                wv.val[k] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + idx[g + 0])),
                                        vld1_s8((const void *)(iq2xxs_grid + idx[g + 1])));
                sv.val[k] = vcombine_s8(vld1_s8((const void *)(sign_rows + ((sw >> (sh + 0)) & 127))),
                                        vld1_s8((const void *)(sign_rows + ((sw >> (sh + 7)) & 127))));
                wv.val[k] = vmulq_s8(wv.val[k], sv.val[k]);
            }

            const int32x4_t dot_a = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), wv.val[0], av.val[0]), wv.val[1], av.val[1]);
            const int32x4_t dot_b = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), wv.val[2], av.val[2]), wv.val[3], av.val[3]);
            acc_a += vaddvq_s32(dot_a) * (0.5f + (words[1] >> 28));
            acc_b += vaddvq_s32(dot_b) * (0.5f + (words[3] >> 28));
        }
        total += dscale*(acc_a + acc_b);
    }
    *s = 0.25f * total;

#else

    uint32_t words[2];
    const uint8_t * idx = (const uint8_t *)words;

    float total = 0.f;
    for (int ib = 0; ib < nblocks; ++ib) {
        const float dscale = LM_GGML_CPU_FP16_TO_FP32(xb[ib].d) * yb[ib].d;
        const uint16_t * LM_GGML_RESTRICT qw = xb[ib].qs;
        const int8_t   * LM_GGML_RESTRICT qa = yb[ib].qs;

        int32_t block_sum = 0;
        for (int sub = 0; sub < QK_K/32; ++sub, qw += 4) {
            memcpy(words, qw, 2*sizeof(uint32_t));
            const uint32_t sub_scale = 2*(words[1] >> 28) + 1;

            int32_t dot = 0;
            for (int l = 0; l < 4; ++l, qa += 8) {
                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + idx[l]);
                const uint8_t sign_bits = ksigns_iq2xs[(words[1] >> 7*l) & 127];
                for (int j = 0; j < 8; ++j) {
                    const int v = grid[j] * qa[j];
                    dot += (sign_bits & kmask_iq2xs[j]) ? -v : v;
                }
            }
            block_sum += dot * sub_scale;
        }
        total += dscale * block_sum;
    }
    *s = 0.125f * total;
#endif
}
|
3262
|
+
|
3263
|
+
/*
 * Dot product of one IQ2_XS-quantized row (vx) with one Q8_K-quantized row
 * (vy). The float result is stored in *s.
 *
 *   n          number of elements; must be a multiple of QK_K
 *   nrc        number of rows to process; only nrc == 1 is supported
 *   bs, bx, by row strides; unused by this single-row kernel
 *
 * Each uint16 code selects 8 values: bits 0..8 index iq2xs_grid, and bits
 * 9..15 index a sign pattern. Each scale byte covers 32 values. Its low
 * nibble weights the first 16 values and its high nibble the next 16, and a
 * nibble s maps to the weight (2*s + 1). The final result is multiplied by
 * 0.125f.
 */
void lm_ggml_vec_dot_iq2_xs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq2_xs * LM_GGML_RESTRICT xb = vx;
    const block_q8_K   * LM_GGML_RESTRICT yb = vy;

    const int nblocks = n / QK_K;

#if defined(__ARM_NEON)
    // keven_signs_q2xs viewed as 128 rows of 8 signed bytes (one uint64 per row)
    const uint64_t * sign_rows = (const uint64_t *)keven_signs_q2xs;

    lm_ggml_int8x16x4_t wv;   // grid values; the sign is applied in place
    lm_ggml_int8x16x4_t sv;   // +1/-1 sign vectors
    lm_ggml_int8x16x4_t av;   // Q8 activations

    int32x4x4_t sub_scales;   // 16 per-16-value weights, 4 per 64-value step

    float total = 0;
    for (int ib = 0; ib < nblocks; ++ib) {
        const float dscale = LM_GGML_CPU_FP16_TO_FP32(xb[ib].d) * yb[ib].d;
        const uint16_t * LM_GGML_RESTRICT qw = xb[ib].qs;
        const int8_t   * LM_GGML_RESTRICT qa = yb[ib].qs;

        // Split the 8 scale bytes into 16 nibbles in the order lo0, hi0, lo1,
        // hi1, ..., map each nibble s to 2*s + 1, then widen to int32 lanes.
        const uint8x8_t packed = vld1_u8(xb[ib].scales);
        const uint8x8_t nib_lo = vand_u8(packed, vdup_n_u8(0xf));
        const uint8x8_t nib_hi = vshr_n_u8(packed, 4);
        uint8x16_t sc16 = vcombine_u8(vzip1_u8(nib_lo, nib_hi), vzip2_u8(nib_lo, nib_hi));
        sc16 = vaddq_u8(vshlq_n_u8(sc16, 1), vdupq_n_u8(1));
        const uint16x8_t sc_first  = vmovl_u8(vget_low_u8(sc16));
        const uint16x8_t sc_second = vmovl_u8(vget_high_u8(sc16));
        sub_scales.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (sc_first)));
        sub_scales.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(sc_first)));
        sub_scales.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (sc_second)));
        sub_scales.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(sc_second)));

        int32x4_t isum = vdupq_n_s32(0);
        for (int ib64 = 0; ib64 < QK_K/64; ++ib64, qw += 8) {
            av = lm_ggml_vld1q_s8_x4(qa);
            qa += 64;

            int32x4_t dots[4];
            for (int k = 0; k < 4; ++k) {
                const uint16_t e0 = qw[2*k + 0];
                const uint16_t e1 = qw[2*k + 1];
                wv.val[k] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (e0 & 511))),
                                        vld1_s8((const void *)(iq2xs_grid + (e1 & 511))));
                sv.val[k] = vcombine_s8(vld1_s8((const void *)(sign_rows + (e0 >> 9))),
                                        vld1_s8((const void *)(sign_rows + (e1 >> 9))));
                wv.val[k] = vmulq_s8(wv.val[k], sv.val[k]);
                dots[k] = lm_ggml_vdotq_s32(vdupq_n_s32(0), wv.val[k], av.val[k]);
            }

            // Pairwise adds reduce to one lane per 16-value group, in the same
            // order as the interleaved scales in sub_scales.val[ib64].
            const int32x4_t grouped = vpaddq_s32(vpaddq_s32(dots[0], dots[1]), vpaddq_s32(dots[2], dots[3]));
            isum = vmlaq_s32(isum, grouped, sub_scales.val[ib64]);
        }
        total += dscale*vaddvq_s32(isum);
    }
    *s = 0.125f * total;

#else

    float total = 0.f;
    for (int ib = 0; ib < nblocks; ++ib) {
        const float dscale = LM_GGML_CPU_FP16_TO_FP32(xb[ib].d) * yb[ib].d;
        const uint16_t * LM_GGML_RESTRICT qw = xb[ib].qs;
        const uint8_t  * LM_GGML_RESTRICT sc = xb[ib].scales;
        const int8_t   * LM_GGML_RESTRICT qa = yb[ib].qs;

        int32_t block_sum = 0;
        for (int sub = 0; sub < QK_K/32; ++sub, qw += 4) {
            // low nibble -> codes 0,1 ; high nibble -> codes 2,3
            const uint16_t half_scale[2] = { 2*(sc[sub] & 0xf) + 1, 2*(sc[sub] >> 4) + 1 };
            for (int h = 0; h < 2; ++h) {
                int32_t dot = 0;
                for (int l = 2*h; l < 2*h + 2; ++l, qa += 8) {
                    const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (qw[l] & 511));
                    const uint8_t sign_bits = ksigns_iq2xs[qw[l] >> 9];
                    for (int j = 0; j < 8; ++j) {
                        const int v = grid[j] * qa[j];
                        dot += (sign_bits & kmask_iq2xs[j]) ? -v : v;
                    }
                }
                block_sum += dot * half_scale[h];
            }
        }
        total += dscale * block_sum;
    }
    *s = 0.125f * total;
#endif
}
|
3368
|
+
|
3369
|
+
void lm_ggml_vec_dot_iq2_s_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
|
3370
|
+
assert(n % QK_K == 0);
|
3371
|
+
assert(nrc == 1);
|
3372
|
+
UNUSED(nrc);
|
3373
|
+
UNUSED(bx);
|
3374
|
+
UNUSED(by);
|
3375
|
+
UNUSED(bs);
|
3376
|
+
|
3377
|
+
const block_iq2_s * LM_GGML_RESTRICT x = vx;
|
3378
|
+
const block_q8_K * LM_GGML_RESTRICT y = vy;
|
3379
|
+
|
3380
|
+
const int nb = n / QK_K;
|
3381
|
+
|
3382
|
+
#if defined(__ARM_NEON)
|
3383
|
+
|
3384
|
+
static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
|
3385
|
+
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
|
3386
|
+
};
|
3387
|
+
|
3388
|
+
static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
|
3389
|
+
|
3390
|
+
const lm_ggml_uint8x16x2_t mask1 = lm_ggml_vld1q_u8_x2(k_mask1);
|
3391
|
+
const uint8x16_t mask2 = vld1q_u8(k_mask2);
|
3392
|
+
const uint8x16_t m1 = vdupq_n_u8(1);
|
3393
|
+
const int32x4_t vzero = vdupq_n_s32(0);
|
3394
|
+
|
3395
|
+
uint8x16x2_t vs;
|
3396
|
+
lm_ggml_int8x16x4_t q2s;
|
3397
|
+
lm_ggml_int8x16x4_t q8b;
|
3398
|
+
|
3399
|
+
float sumf = 0;
|
3400
|
+
for (int i = 0; i < nb; ++i) {
|
3401
|
+
|
3402
|
+
const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
3403
|
+
|
3404
|
+
const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
|
3405
|
+
const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
|
3406
|
+
const uint16_t * LM_GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
|
3407
|
+
const int8_t * LM_GGML_RESTRICT q8 = y[i].qs;
|
3408
|
+
|
3409
|
+
int sumi1 = 0, sumi2 = 0;
|
3410
|
+
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
3411
|
+
q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;
|
3412
|
+
q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))),
|
3413
|
+
vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300)))));
|
3414
|
+
q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))),
|
3415
|
+
vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300)))));
|
3416
|
+
q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))),
|
3417
|
+
vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300)))));
|
3418
|
+
q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))),
|
3419
|
+
vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
|
3420
|
+
qs += 8;
|
3421
|
+
|
3422
|
+
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
|
3423
|
+
vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
3424
|
+
vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
3425
|
+
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
3426
|
+
vs.val[1] = vceqq_u8(vs.val[1], mask2);
|
3427
|
+
|
3428
|
+
q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
|
3429
|
+
q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
|
3430
|
+
|
3431
|
+
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
|
3432
|
+
vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
3433
|
+
vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
|
3434
|
+
vs.val[0] = vceqq_u8(vs.val[0], mask2);
|
3435
|
+
vs.val[1] = vceqq_u8(vs.val[1], mask2);
|
3436
|
+
|
3437
|
+
signs += 4;
|
3438
|
+
|
3439
|
+
q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]);
|
3440
|
+
q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]);
|
3441
|
+
|
3442
|
+
const int32x4_t p1 = lm_ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]);
|
3443
|
+
const int32x4_t p2 = lm_ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]);
|
3444
|
+
const int32x4_t p3 = lm_ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]);
|
3445
|
+
const int32x4_t p4 = lm_ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]);
|
3446
|
+
|
3447
|
+
sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf));
|
3448
|
+
sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4));
|
3449
|
+
sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf));
|
3450
|
+
sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4));
|
3451
|
+
}
|
3452
|
+
sumf += d*(sumi1 + sumi2);
|
3453
|
+
}
|
3454
|
+
|
3455
|
+
*s = 0.125f * sumf;
|
3456
|
+
|
3457
|
+
#else
|
3458
|
+
|
3459
|
+
float sumf = 0;
|
3460
|
+
for (int i = 0; i < nb; i++) {
|
3461
|
+
|
3462
|
+
const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
3463
|
+
const int8_t * q8 = y[i].qs;
|
3464
|
+
const uint8_t * qs = x[i].qs;
|
3465
|
+
const uint8_t * qh = x[i].qh;
|
3466
|
+
const uint8_t * signs = qs + QK_K/8;
|
3467
|
+
|
3468
|
+
int bsum = 0;
|
3469
|
+
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
|
3470
|
+
int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
|
3471
|
+
int ls2 = 1 + 2*(x[i].scales[ib32] >> 4);
|
3472
|
+
int sumi1 = 0, sumi2 = 0;
|
3473
|
+
for (int l = 0; l < 2; ++l) {
|
3474
|
+
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
|
3475
|
+
for (int j = 0; j < 8; ++j) {
|
3476
|
+
sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
|
3477
|
+
}
|
3478
|
+
q8 += 8;
|
3479
|
+
}
|
3480
|
+
for (int l = 2; l < 4; ++l) {
|
3481
|
+
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
|
3482
|
+
for (int j = 0; j < 8; ++j) {
|
3483
|
+
sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
|
3484
|
+
}
|
3485
|
+
q8 += 8;
|
3486
|
+
}
|
3487
|
+
bsum += ls1 * sumi1 + ls2 * sumi2;
|
3488
|
+
qs += 4;
|
3489
|
+
signs += 4;
|
3490
|
+
}
|
3491
|
+
|
3492
|
+
sumf += d * bsum;
|
3493
|
+
}
|
3494
|
+
|
3495
|
+
*s = 0.125f * sumf;
|
3496
|
+
|
3497
|
+
#endif
|
3498
|
+
|
3499
|
+
}
|
3500
|
+
|
3501
|
+
void lm_ggml_vec_dot_iq3_xxs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    // Dot product of one row of IQ3_XXS-quantized values (vx) with one row of
    // Q8_K-quantized values (vy), both n elements long; the scalar result is
    // written to *s. Only nrc == 1 is supported and the strides bs/bx/by are unused.
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq3_xxs * LM_GGML_RESTRICT x = vx;
    const block_q8_K    * LM_GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_NEON)

    // keven_signs_q2xs viewed as 8-byte rows; a row is multiplied element-wise
    // with 8 grid values, so it holds per-value sign factors (presumably +-1).
    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
    const int32x4_t  vzero   = vdupq_n_s32(0);

    uint32_t            words[2];
    lm_ggml_int8x16x4_t wq;
    lm_ggml_int8x16x4_t aq;

    float total = 0;
    for (int ibl = 0; ibl < nb; ++ibl) {
        const float dd = LM_GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
        // First QK_K/4 bytes: grid indices; after them: one 32-bit word per 32 values
        // packing four 7-bit sign indices (bits 0..27) and a 4-bit scale (bits 28..31).
        const uint8_t * LM_GGML_RESTRICT gidx = x[ibl].qs;
        const uint8_t * LM_GGML_RESTRICT wsrc = x[ibl].qs + QK_K/4;
        const int8_t  * LM_GGML_RESTRICT q8   = y[ibl].qs;

        float acc0 = 0, acc1 = 0;
        // Two 32-value sub-blocks (64 values) per iteration.
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            aq = lm_ggml_vld1q_s8_x4(q8);
            q8 += 64;

            memcpy(words, wsrc, 2*sizeof(uint32_t));
            wsrc += 2*sizeof(uint32_t);

            const uint32x4_t g0 = lm_ggml_vld1q_u32(iq3xxs_grid[gidx[ 0]], iq3xxs_grid[gidx[ 1]], iq3xxs_grid[gidx[ 2]], iq3xxs_grid[gidx[ 3]]);
            const uint32x4_t g1 = lm_ggml_vld1q_u32(iq3xxs_grid[gidx[ 4]], iq3xxs_grid[gidx[ 5]], iq3xxs_grid[gidx[ 6]], iq3xxs_grid[gidx[ 7]]);
            const uint32x4_t g2 = lm_ggml_vld1q_u32(iq3xxs_grid[gidx[ 8]], iq3xxs_grid[gidx[ 9]], iq3xxs_grid[gidx[10]], iq3xxs_grid[gidx[11]]);
            const uint32x4_t g3 = lm_ggml_vld1q_u32(iq3xxs_grid[gidx[12]], iq3xxs_grid[gidx[13]], iq3xxs_grid[gidx[14]], iq3xxs_grid[gidx[15]]);
            gidx += 16;

            // Each 7-bit field of a word selects the sign row for 8 values.
            wq.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((words[0] >>  0) & 127))),
                                    vld1_s8((const void *)(signs64 + ((words[0] >>  7) & 127))));
            wq.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((words[0] >> 14) & 127))),
                                    vld1_s8((const void *)(signs64 + ((words[0] >> 21) & 127))));
            wq.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((words[1] >>  0) & 127))),
                                    vld1_s8((const void *)(signs64 + ((words[1] >>  7) & 127))));
            wq.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((words[1] >> 14) & 127))),
                                    vld1_s8((const void *)(signs64 + ((words[1] >> 21) & 127))));

            wq.val[0] = vmulq_s8(wq.val[0], vreinterpretq_s8_u32(g0));
            wq.val[1] = vmulq_s8(wq.val[1], vreinterpretq_s8_u32(g1));
            wq.val[2] = vmulq_s8(wq.val[2], vreinterpretq_s8_u32(g2));
            wq.val[3] = vmulq_s8(wq.val[3], vreinterpretq_s8_u32(g3));

            const int32x4_t dot0 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vzero, wq.val[0], aq.val[0]), wq.val[1], aq.val[1]);
            const int32x4_t dot1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vzero, wq.val[2], aq.val[2]), wq.val[3], aq.val[3]);

            // (0.5 + s) here combined with the final 0.5 factor equals 0.25 * (2*s + 1).
            acc0 += vaddvq_s32(dot0) * (0.5f + (words[0] >> 28));
            acc1 += vaddvq_s32(dot1) * (0.5f + (words[1] >> 28));
        }
        total += dd*(acc0 + acc1);
    }
    *s = 0.5f * total;

#else

    float total = 0.f;
    for (int ibl = 0; ibl < nb; ++ibl) {
        const float dd = LM_GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
        const uint8_t * LM_GGML_RESTRICT gidx = x[ibl].qs;
        const uint8_t * LM_GGML_RESTRICT wsrc = x[ibl].qs + QK_K/4;
        const int8_t  * LM_GGML_RESTRICT q8   = y[ibl].qs;

        int32_t block_sum = 0;
        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
            uint32_t packed;
            memcpy(&packed, wsrc, sizeof(uint32_t));
            wsrc += sizeof(uint32_t);

            // Odd sub-block scale from the top 4 bits.
            const uint32_t ls = 2*(packed >> 28) + 1;

            int32_t sub = 0;
            for (int l = 0; l < 4; ++l) {
                const uint8_t * g1 = (const uint8_t *)(iq3xxs_grid + gidx[2*l+0]);
                const uint8_t * g2 = (const uint8_t *)(iq3xxs_grid + gidx[2*l+1]);
                const uint8_t sign_bits = ksigns_iq2xs[(packed >> (7*l)) & 127];
                for (int j = 0; j < 4; ++j) {
                    const int t1 = g1[j] * q8[j + 0];
                    const int t2 = g2[j] * q8[j + 4];
                    sub += (sign_bits & kmask_iq2xs[j + 0]) ? -t1 : t1;
                    sub += (sign_bits & kmask_iq2xs[j + 4]) ? -t2 : t2;
                }
                q8 += 8;
            }
            gidx += 8;
            block_sum += sub * ls;
        }
        total += dd * block_sum;
    }
    *s = 0.25f * total;
#endif
}
|
3588
|
+
|
3589
|
+
void lm_ggml_vec_dot_iq3_s_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    // Dot product of one row of IQ3_S-quantized values (vx) with one row of
    // Q8_K-quantized values (vy), both n elements long; the scalar result is
    // written to *s. Only nrc == 1 is supported and the strides bs/bx/by are unused.
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq3_s * LM_GGML_RESTRICT x = vx;
    const block_q8_K  * LM_GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined(__ARM_NEON)

    // Lets 9-bit grid indices computed in a vector register be read back lane by lane.
    typedef union {
        uint16x8_t vec_index;
        uint16_t   index[8];
    } vec_index_t;

    // Table-lookup patterns: val[0] replicates byte 0 into lanes 0-7 and byte 1 into
    // lanes 8-15 of the broadcast 32-bit sign word; val[1] does the same for bytes 2 and 3.
    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
    };

    // Isolates bit (lane % 8) of each replicated byte.
    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};

    // Lane l shifts qh left by (8 - l) so that bit l lands in bit 8 (the 9th index bit).
    static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};

    const lm_ggml_uint8x16x2_t mask1 = lm_ggml_vld1q_u8_x2(k_mask1);
    const uint8x16_t mask2 = vld1q_u8(k_mask2);

    const int16x8_t hshift = vld1q_s16(k_shift);
    const uint16x8_t m256 = vdupq_n_u16(256);
    const uint8x16_t m1 = vdupq_n_u8(1);

    uint8x16x2_t vs;
    lm_ggml_int8x16x4_t q3s;
    lm_ggml_int8x16x4_t q8b;
    vec_index_t idx;

    // scales32[0] holds 2*low_nibble+1 and scales32[1] holds 2*high_nibble+1 of the 4 scale bytes.
    uint32_t scales32[2];
    const uint8_t * scales8 = (const uint8_t *)scales32;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
        const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
        // x[i].signs is a byte array. The 16-bit sign words are fetched with memcpy
        // below instead of through a (const uint16_t *) cast, which would violate
        // strict aliasing and assume 2-byte alignment of the block data.
        const uint8_t * LM_GGML_RESTRICT signs = x[i].signs;
        const int8_t  * LM_GGML_RESTRICT q8 = y[i].qs;

        memcpy(scales32, x[i].scales, 4);
        scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
        scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;

        int sumi1 = 0, sumi2 = 0;
        // Two 32-value sub-blocks (64 values) per iteration.
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;

            // 9-bit grid index = 8 bits from qs | 1 bit from qh.
            const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
            idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));
            const uint32x4_t aux32x4_0 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
                                                           iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
            const uint32x4_t aux32x4_1 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
                                                           iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
            idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));
            const uint32x4_t aux32x4_2 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
                                                           iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
            const uint32x4_t aux32x4_3 = lm_ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
                                                           iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);

            // Four 16-bit sign words (8 bytes) cover the 64 values of this iteration.
            uint16_t sg[4];
            memcpy(sg, signs, sizeof(sg));
            signs += sizeof(sg);

            // Expand sign bits to per-byte factors: 0xFF | 1 = -1 where the bit is set, 0 | 1 = +1 otherwise.
            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(sg[0] | ((uint32_t) sg[1] << 16)));
            vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
            vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
            vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
            vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);

            q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
            q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));

            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(sg[2] | ((uint32_t) sg[3] << 16)));
            vs.val[1] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
            vs.val[0] = vandq_u8(lm_ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
            vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
            vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);

            q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));
            q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));

            const int32x4_t p1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
            const int32x4_t p2 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);

            // Sub-block ib32 uses the low nibble, ib32+1 the high nibble of scales[ib32/2].
            sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
            sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
        }
        sumf += d*(sumi1 + sumi2);
    }
    *s = sumf;

#else

    float sumf = 0.f;
    for (int i = 0; i < nb; ++i) {
        const float d = LM_GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
        const uint8_t * LM_GGML_RESTRICT qs = x[i].qs;
        const uint8_t * LM_GGML_RESTRICT qh = x[i].qh;
        const uint8_t * LM_GGML_RESTRICT signs = x[i].signs;
        const int8_t  * LM_GGML_RESTRICT q8 = y[i].qs;
        int32_t bsum = 0;
        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
            // One scale byte covers two 32-value sub-blocks (low nibble, then high nibble).
            const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
            const uint32_t ls2 = 2*(x[i].scales[ib32/2] >>  4) + 1;
            int32_t sumi = 0;
            for (int l = 0; l < 4; ++l) {
                // Bit 2l / 2l+1 of qh supplies the 9th bit of each grid index.
                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
                for (int j = 0; j < 4; ++j) {
                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
                }
                q8 += 8;
            }
            qs += 8;
            signs += 4;
            bsum += sumi * ls1;
            sumi = 0;
            for (int l = 0; l < 4; ++l) {
                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
                for (int j = 0; j < 4; ++j) {
                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
                }
                q8 += 8;
            }
            qs += 8;
            signs += 4;
            bsum += sumi * ls2;
        }
        sumf += d * bsum;
    }
    *s = sumf;
#endif
}
|
3736
|
+
|
3737
|
+
void lm_ggml_vec_dot_iq1_s_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    // Dot product of one row of IQ1_S-quantized values (vx) with one row of
    // Q8_K-quantized values (vy), both n elements long; the scalar result is
    // written to *s. Only nrc == 1 is supported and the strides bs/bx/by are unused.
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq1_s * LM_GGML_RESTRICT x = vx;
    const block_q8_K  * LM_GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __ARM_NEON

    lm_ggml_int8x16x4_t qg;
    lm_ggml_int8x16x4_t qy;

    float sumf = 0;
    for (int ibl = 0; ibl < nb; ++ibl) {

        const int8_t   * q8 = y[ibl].qs;
        const uint8_t  * qs = x[ibl].qs;
        const uint16_t * qh = x[ibl].qh;

        int acc_lo = 0, acc_hi = 0, acc_delta = 0;

        // Two 32-value sub-blocks per iteration; each has one 16-bit qh word holding
        // four 3-bit grid-index extensions (bits 0..11), a 3-bit scale (bits 12..14)
        // and the delta sign (bit 15).
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            const uint16_t h0 = qh[ib + 0];
            const uint16_t h1 = qh[ib + 1];

            qg.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((h0 << 8) & 0x700)))),
                                    vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((h0 << 5) & 0x700)))));
            qg.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((h0 << 2) & 0x700)))),
                                    vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((h0 >> 1) & 0x700)))));
            qg.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((h1 << 8) & 0x700)))),
                                    vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((h1 << 5) & 0x700)))));
            qg.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((h1 << 2) & 0x700)))),
                                    vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((h1 >> 1) & 0x700)))));
            qs += 8;

            qy = lm_ggml_vld1q_s8_x4(q8);
            q8 += 64;

            const int32x4_t dot_lo = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), qg.val[0], qy.val[0]), qg.val[1], qy.val[1]);
            const int32x4_t dot_hi = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), qg.val[2], qy.val[2]), qg.val[3], qy.val[3]);

            const int ls0   = 2*((h0 >> 12) & 7) + 1;
            const int ls1   = 2*((h1 >> 12) & 7) + 1;
            const int sign0 = (h0 & 0x8000) ? -1 : 1;
            const int sign1 = (h1 & 0x8000) ? -1 : 1;

            acc_lo += vaddvq_s32(dot_lo) * ls0;
            acc_hi += vaddvq_s32(dot_hi) * ls1;
            // The delta term only needs the sums of the q8 values (bsums, 16 values each).
            acc_delta += (y[ibl].bsums[2*ib+0] + y[ibl].bsums[2*ib+1]) * ls0 * sign0
                       + (y[ibl].bsums[2*ib+2] + y[ibl].bsums[2*ib+3]) * ls1 * sign1;
        }

        sumf += y[ibl].d * LM_GGML_CPU_FP16_TO_FP32(x[ibl].d) * (acc_lo + acc_hi + IQ1S_DELTA * acc_delta);
    }

    *s = sumf;

#else

    float sumf = 0;
    for (int ibl = 0; ibl < nb; ibl++) {

        const int8_t   * q8 = y[ibl].qs;
        const uint8_t  * qs = x[ibl].qs;
        const uint16_t * qh = x[ibl].qh;

        int grid_sum = 0, delta_sum = 0;
        for (int ib = 0; ib < QK_K/32; ++ib, qs += 4) {
            const uint16_t h    = qh[ib];
            const int      ls   = 2*((h >> 12) & 7) + 1;
            const int      sign = (h & 0x8000) ? -1 : 1;

            int dot = 0;
            for (int l = 0; l < 4; ++l, q8 += 8) {
                // 11-bit grid index: 8 bits from qs, 3 bits from the l-th field of h.
                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((h >> (3*l)) & 7) << 8)));
                for (int j = 0; j < 8; ++j) {
                    dot += q8[j] * grid[j];
                }
            }
            grid_sum  += ls * dot;
            delta_sum += ls * sign * (y[ibl].bsums[2*ib+0] + y[ibl].bsums[2*ib+1]);
        }

        sumf += LM_GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (grid_sum + IQ1S_DELTA * delta_sum);
    }

    *s = sumf;

#endif
}
|
3828
|
+
|
3829
|
+
void lm_ggml_vec_dot_iq1_m_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    // Dot product of one row of IQ1_M-quantized values (vx) with one row of
    // Q8_K-quantized values (vy), both n elements long; the scalar result is
    // written to *s. Only nrc == 1 is supported and the strides bs/bx/by are unused.
    assert(n % QK_K == 0);
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq1_m * LM_GGML_RESTRICT x = vx;
    const block_q8_K  * LM_GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

    iq1m_scale_t scale;

#if defined __ARM_NEON
    const int32x4_t mask  = vdupq_n_s32(0x7);
    const int32x4_t mone  = vdupq_n_s32(1);
    const int32x4_t mzero = vdupq_n_s32(0);

    // +-1 patterns for the delta term, one sign per 8-value half; selected by aux8[k] below
    // (bit 0 -> first half negative, bit 1 -> second half negative).
    lm_ggml_int8x16x4_t deltas;
    deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));
    deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));
    deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));
    deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));

    lm_ggml_int8x16x4_t q1b;
    lm_ggml_int8x16x4_t q8b;

    uint32_t aux32;
    const uint8_t * aux8 = (const uint8_t *)&aux32;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {

        const int8_t  * q8 = y[i].qs;
        const uint8_t * qs = x[i].qs;
        const uint8_t * qh = x[i].qh;

        // x[i].scales is a byte array: copy its four 16-bit words with memcpy instead of
        // reading through a (const uint16_t *) cast (strict-aliasing violation, and the
        // block data carries no 2-byte alignment guarantee).
        uint16_t sc[4];
        memcpy(sc, x[i].scales, sizeof(sc));

        // The fp16 super-block scale is assembled from the top nibble of each scale word.
        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

        int32x4_t sumi1 = mzero;
        int32x4_t sumi2 = mzero;

        // Two 32-value sub-blocks (8 qs bytes, 4 qh bytes) per iteration.
        for (int ib = 0; ib < QK_K/32; ib += 2) {

            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));
            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));
            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));
            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),
                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));

            q8b = lm_ggml_vld1q_s8_x4(q8); q8 += 64;

            const int32x4_t p1 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), lm_ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
            const int32x4_t p2 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), lm_ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
            const int32x4_t p12 = vpaddq_s32(p1, p2);

            // Read the 4 qh bytes with memcpy rather than a (const uint32_t *) cast:
            // avoids the aliasing violation and makes no alignment assumption.
            uint32_t qh32;
            memcpy(&qh32, qh, sizeof(qh32));
            // Per qh byte: bit 3 -> bit 0, bit 7 -> bit 1 (delta sign of each 8-value half).
            aux32 = ((qh32 >> 3) & 0x01010101) | ((qh32 >> 6) & 0x02020202);

            const int32x4_t p3 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), lm_ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1]));
            const int32x4_t p4 = vpaddq_s32(lm_ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), lm_ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
            const int32x4_t p34 = vpaddq_s32(p3, p4);

            // Four 3-bit scales of this pair of sub-blocks; reinterpret explicitly instead of
            // relying on an implicit uint32x4_t -> int32x4_t (lax vector) conversion.
            int32x4_t scales_4 = vreinterpretq_s32_u32(lm_ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9));

            scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);

            sumi1 = vmlaq_s32(sumi1, scales_4, p12);
            sumi2 = vmlaq_s32(sumi2, scales_4, p34);

            qs += 8; qh += 4;

        }

        sumf += y[i].d * LM_GGML_CPU_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
    }

    *s = sumf;

#else

    float sumf = 0;
    for (int i = 0; i < nb; i++) {

        const int8_t  * q8 = y[i].qs;
        const uint8_t * qs = x[i].qs;
        const uint8_t * qh = x[i].qh;

        // Byte array copied into 16-bit words via memcpy (no aliasing/alignment assumptions).
        uint16_t sc[4];
        memcpy(sc, x[i].scales, sizeof(sc));

        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

        int sumi1 = 0, sumi2 = 0;
        for (int ib = 0; ib < QK_K/32; ++ib) {
            // Delta sign per 8 values: bits 3 and 7 of the two qh bytes of this sub-block.
            int delta[4];
            delta[0] = qh[0] & 0x08 ? -1 : 1;
            delta[1] = qh[0] & 0x80 ? -1 : 1;
            delta[2] = qh[1] & 0x08 ? -1 : 1;
            delta[3] = qh[1] & 0x80 ? -1 : 1;
            int sum1[2] = {0, 0};
            int sum2[2] = {0, 0};
            for (int l = 0; l < 4; ++l) {
                // 11-bit grid index: 8 bits from qs, 3 bits from a nibble of qh.
                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
                int lsum1 = 0, lsum2 = 0;
                for (int j = 0; j < 8; ++j) {
                    lsum1 += q8[j] * grid[j];
                    lsum2 += q8[j];
                }
                q8 += 8;
                sum1[l/2] += lsum1;
                sum2[l/2] += lsum2*delta[l];
            }

            // Two 3-bit scales per sub-block (one per 16 values), 6 bits per sub-block in sc[ib/2].
            const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
            const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;

            sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
            sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
            qs += 4;
            qh += 2;
        }

        sumf += LM_GGML_CPU_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
    }

    *s = sumf;

#endif
}
|
3963
|
+
|
3964
|
+
void lm_ggml_vec_dot_iq4_nl_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    // Dot product of one row of IQ4_NL-quantized values (vx) with one row of
    // Q8_0-quantized values (vy), both n elements long; the scalar result is
    // written to *s. Only nrc == 1 is supported and the strides bs/bx/by are unused.
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK4_NL == 0);
    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");

    const block_iq4_nl * LM_GGML_RESTRICT x = vx;
    const block_q8_0   * LM_GGML_RESTRICT y = vy;

    const int nb = n / QK4_NL;

    int   blk = 0;
    float acc = 0;

#if defined __ARM_NEON
    // 4-bit codes are mapped to int8 values through the kvalues_iq4nl table.
    const int8x16_t  lut = vld1q_s8(kvalues_iq4nl);
    const uint8x16_t lo4 = vdupq_n_u8(0x0f);

    // Two blocks per iteration; an odd trailing block falls through to the scalar loop.
    for (; blk + 1 < nb; blk += 2) {
        uint8x16x2_t packed;
        int8x16x4_t  wv;
        int8x16x4_t  av;

        packed.val[0] = vld1q_u8(x[blk + 0].qs);
        packed.val[1] = vld1q_u8(x[blk + 1].qs);
        av.val[0] = vld1q_s8(y[blk + 0].qs);
        av.val[1] = vld1q_s8(y[blk + 0].qs + 16);
        av.val[2] = vld1q_s8(y[blk + 1].qs);
        av.val[3] = vld1q_s8(y[blk + 1].qs + 16);

        // Low nibbles are the first half of a block, high nibbles the second half.
        wv.val[0] = lm_ggml_vqtbl1q_s8(lut, vandq_u8 (packed.val[0], lo4));
        wv.val[1] = lm_ggml_vqtbl1q_s8(lut, vshrq_n_u8(packed.val[0], 4));
        wv.val[2] = lm_ggml_vqtbl1q_s8(lut, vandq_u8 (packed.val[1], lo4));
        wv.val[3] = lm_ggml_vqtbl1q_s8(lut, vshrq_n_u8(packed.val[1], 4));

        const int32x4_t dot0 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), wv.val[0], av.val[0]), wv.val[1], av.val[1]);
        const int32x4_t dot1 = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), wv.val[2], av.val[2]), wv.val[3], av.val[3]);

        acc +=
            LM_GGML_CPU_FP16_TO_FP32(x[blk+0].d) * LM_GGML_CPU_FP16_TO_FP32(y[blk + 0].d) * vaddvq_s32(dot0) +
            LM_GGML_CPU_FP16_TO_FP32(x[blk+1].d) * LM_GGML_CPU_FP16_TO_FP32(y[blk + 1].d) * vaddvq_s32(dot1);
    }

#endif
    // Scalar path: all blocks without NEON, otherwise only the leftover block.
    for (; blk < nb; ++blk) {
        const block_iq4_nl * LM_GGML_RESTRICT xb = &x[blk];
        const block_q8_0   * LM_GGML_RESTRICT yb = &y[blk];

        const float d = LM_GGML_CPU_FP16_TO_FP32(yb->d)*LM_GGML_CPU_FP16_TO_FP32(xb->d);

        int lo = 0, hi = 0;
        for (int j = 0; j < QK4_NL/2; ++j) {
            const uint8_t code = xb->qs[j];
            lo += yb->qs[j +        0] * kvalues_iq4nl[code & 0xf];
            hi += yb->qs[j + QK4_NL/2] * kvalues_iq4nl[code >> 4];
        }
        acc += d * (lo + hi);
    }
    *s = acc;
}
|
4023
|
+
|
4024
|
+
void lm_ggml_vec_dot_iq4_xs_q8_K(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, size_t bx, const void * LM_GGML_RESTRICT vy, size_t by, int nrc) {
    // Dot product of one row of IQ4_XS-quantized values (vx) with one row of
    // Q8_K-quantized values (vy), both n elements long; the scalar result is
    // written to *s. Only nrc == 1 is supported and the strides bs/bx/by are unused.
    assert(nrc == 1);
    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);
    assert(n % QK_K == 0);

    const block_iq4_xs * LM_GGML_RESTRICT x = vx;
    const block_q8_K   * LM_GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __ARM_NEON
    const int8x16_t  lut = vld1q_s8(kvalues_iq4nl);
    const uint8x16_t lo4 = vdupq_n_u8(0x0f);
    lm_ggml_uint8x16x2_t packed;
    lm_ggml_int8x16x4_t  wv;
    lm_ggml_int8x16x4_t  av;

    float sumf = 0;

    for (int ibl = 0; ibl < nb; ++ibl) {

        const int8_t  * q8 = y[ibl].qs;
        const uint8_t * q4 = x[ibl].qs;
        // Two high scale bits per 32-value sub-block, consumed 4 bits per iteration.
        uint16_t h = x[ibl].scales_h;

        int acc_a = 0, acc_b = 0;
        // Two 32-value sub-blocks (64 values, 32 packed bytes) per iteration.
        for (int ib = 0; ib < QK_K/64; ++ib) {

            packed = lm_ggml_vld1q_u8_x2(q4);
            q4 += 32;
            av = lm_ggml_vld1q_s8_x4(q8);
            q8 += 64;

            wv.val[0] = lm_ggml_vqtbl1q_s8(lut, vandq_u8 (packed.val[0], lo4));
            wv.val[1] = lm_ggml_vqtbl1q_s8(lut, vshrq_n_u8(packed.val[0], 4));
            wv.val[2] = lm_ggml_vqtbl1q_s8(lut, vandq_u8 (packed.val[1], lo4));
            wv.val[3] = lm_ggml_vqtbl1q_s8(lut, vshrq_n_u8(packed.val[1], 4));

            const int32x4_t dot_a = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), wv.val[0], av.val[0]), wv.val[1], av.val[1]);
            const int32x4_t dot_b = lm_ggml_vdotq_s32(lm_ggml_vdotq_s32(vdupq_n_s32(0), wv.val[2], av.val[2]), wv.val[3], av.val[3]);

            // 6-bit scales (4 low bits from scales_l, 2 high bits from h), offset by -32.
            const int ls_a = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
            const int ls_b = ((x[ibl].scales_l[ib] >>  4) | ((h << 2) & 0x30)) - 32;
            h >>= 4;

            acc_a += vaddvq_s32(dot_a) * ls_a;
            acc_b += vaddvq_s32(dot_b) * ls_b;

        }

        sumf += LM_GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (acc_a + acc_b);
    }

    *s = sumf;

#else
    float sumf = 0;
    for (int ibl = 0; ibl < nb; ++ibl) {
        const float d4d8 = LM_GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
        uint16_t h = x[ibl].scales_h;
        const uint8_t * qs = x[ibl].qs;
        const int8_t  * q8 = y[ibl].qs;
        for (int ib = 0; ib < QK_K/32; ib += 2) {
            // 6-bit scales for sub-blocks ib and ib+1: low nibble / high nibble of
            // scales_l[ib/2], plus two bits each from h.
            const int ls[2] = {
                (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30),
                (x[ibl].scales_l[ib/2] >>  4) | ((h << 2) & 0x30),
            };
            h >>= 4;
            for (int half = 0; half < 2; ++half) {
                const float dh = d4d8*(ls[half] - 32);
                // Each of the 16 bytes holds value j (low nibble) and value j+16 (high nibble).
                int lo = 0, hi = 0;
                for (int j = 0; j < 16; ++j) {
                    lo += q8[j +  0] * kvalues_iq4nl[qs[j] & 0xf];
                    hi += q8[j + 16] * kvalues_iq4nl[qs[j] >> 4];
                }
                sumf += dh * (lo + hi);
                qs += 16;
                q8 += 32;
            }
        }
    }
    *s = sumf;
#endif
}
|
4114
|
+
|