@fugood/llama.node 0.2.3 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (319)
  1. package/CMakeLists.txt +6 -3
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +8 -1
  17. package/package.json +3 -3
  18. package/patches/llama.patch +12 -12
  19. package/src/DetokenizeWorker.cpp +1 -1
  20. package/src/LlamaContext.cpp +33 -1
  21. package/src/LlamaContext.h +1 -0
  22. package/src/llama.cpp/.github/workflows/bench.yml +310 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +1315 -0
  24. package/src/llama.cpp/.github/workflows/close-issue.yml +23 -0
  25. package/src/llama.cpp/.github/workflows/docker.yml +116 -0
  26. package/src/llama.cpp/.github/workflows/editorconfig.yml +27 -0
  27. package/src/llama.cpp/.github/workflows/gguf-publish.yml +44 -0
  28. package/src/llama.cpp/.github/workflows/labeler.yml +17 -0
  29. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +65 -0
  30. package/src/llama.cpp/.github/workflows/nix-ci.yml +72 -0
  31. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +22 -0
  32. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +36 -0
  33. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +35 -0
  34. package/src/llama.cpp/.github/workflows/python-lint.yml +23 -0
  35. package/src/llama.cpp/.github/workflows/python-type-check.yml +38 -0
  36. package/src/llama.cpp/.github/workflows/server.yml +183 -0
  37. package/src/llama.cpp/CMakeLists.txt +91 -1245
  38. package/src/llama.cpp/cmake/arm64-windows-llvm.cmake +1 -1
  39. package/src/llama.cpp/cmake/build-info.cmake +58 -0
  40. package/src/llama.cpp/cmake/git-vars.cmake +22 -0
  41. package/src/llama.cpp/common/CMakeLists.txt +4 -3
  42. package/src/llama.cpp/common/build-info.cpp.in +4 -0
  43. package/src/llama.cpp/common/common.cpp +1116 -877
  44. package/src/llama.cpp/common/common.h +191 -77
  45. package/src/llama.cpp/common/grammar-parser.cpp +118 -31
  46. package/src/llama.cpp/common/json-schema-to-grammar.cpp +346 -65
  47. package/src/llama.cpp/common/log.h +1 -1
  48. package/src/llama.cpp/common/ngram-cache.h +10 -3
  49. package/src/llama.cpp/common/sampling.cpp +19 -10
  50. package/src/llama.cpp/docs/build.md +353 -0
  51. package/src/llama.cpp/examples/CMakeLists.txt +22 -22
  52. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +6 -6
  54. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched/batched.cpp +52 -55
  56. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +20 -72
  58. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/chat-13B.bat +57 -0
  60. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  61. package/src/llama.cpp/examples/{finetune → cvector-generator}/CMakeLists.txt +2 -2
  62. package/src/llama.cpp/examples/cvector-generator/completions.txt +582 -0
  63. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +503 -0
  64. package/src/llama.cpp/examples/cvector-generator/mean.hpp +48 -0
  65. package/src/llama.cpp/examples/cvector-generator/negative.txt +4 -0
  66. package/src/llama.cpp/examples/cvector-generator/pca.hpp +325 -0
  67. package/src/llama.cpp/examples/cvector-generator/positive.txt +4 -0
  68. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +35 -0
  69. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/embedding/embedding.cpp +94 -46
  71. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +2 -2
  72. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +4 -6
  73. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/export-lora/export-lora.cpp +344 -386
  75. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +2 -2
  76. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +30 -25
  77. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/gguf/gguf.cpp +5 -0
  79. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +15 -0
  80. package/src/llama.cpp/examples/gguf-hash/deps/rotate-bits/rotate-bits.h +46 -0
  81. package/src/llama.cpp/examples/gguf-hash/deps/sha1/sha1.c +295 -0
  82. package/src/llama.cpp/examples/gguf-hash/deps/sha1/sha1.h +52 -0
  83. package/src/llama.cpp/examples/gguf-hash/deps/sha256/sha256.c +221 -0
  84. package/src/llama.cpp/examples/gguf-hash/deps/sha256/sha256.h +24 -0
  85. package/src/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.c +42 -0
  86. package/src/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.h +7093 -0
  87. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +693 -0
  88. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +3 -3
  90. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  91. package/src/llama.cpp/examples/gritlm/gritlm.cpp +6 -2
  92. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/imatrix/imatrix.cpp +137 -176
  94. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  95. package/src/llama.cpp/examples/infill/infill.cpp +38 -153
  96. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +175 -94
  97. package/src/llama.cpp/examples/llama.android/app/build.gradle.kts +65 -0
  98. package/src/llama.cpp/examples/llama.android/build.gradle.kts +6 -0
  99. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +68 -0
  100. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +11 -7
  101. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +2 -2
  102. package/src/llama.cpp/examples/llama.android/settings.gradle.kts +18 -0
  103. package/src/llama.cpp/examples/llava/CMakeLists.txt +6 -5
  104. package/src/llama.cpp/examples/llava/android/build_64.sh +8 -0
  105. package/src/llama.cpp/examples/llava/clip.cpp +23 -14
  106. package/src/llama.cpp/examples/llava/llava-cli.cpp +8 -6
  107. package/src/llama.cpp/examples/llava/requirements.txt +3 -2
  108. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/lookahead/lookahead.cpp +2 -1
  110. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  111. package/src/llama.cpp/examples/lookup/lookup-create.cpp +2 -0
  112. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  113. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -2
  114. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  115. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  116. package/src/llama.cpp/examples/main/main.cpp +98 -75
  117. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +4 -5
  118. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  119. package/src/llama.cpp/examples/parallel/parallel.cpp +2 -1
  120. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  121. package/src/llama.cpp/examples/passkey/passkey.cpp +23 -43
  122. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  123. package/src/llama.cpp/examples/perplexity/perplexity.cpp +13 -10
  124. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  125. package/src/llama.cpp/examples/quantize/quantize.cpp +37 -34
  126. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  127. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +1 -1
  128. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  129. package/src/llama.cpp/examples/retrieval/retrieval.cpp +26 -77
  130. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  131. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +14 -7
  132. package/src/llama.cpp/examples/server/CMakeLists.txt +26 -2
  133. package/src/llama.cpp/examples/server/server.cpp +274 -671
  134. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  135. package/src/llama.cpp/examples/server/utils.hpp +28 -29
  136. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  137. package/src/llama.cpp/examples/simple/simple.cpp +21 -29
  138. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  139. package/src/llama.cpp/examples/speculative/speculative.cpp +2 -1
  140. package/src/llama.cpp/examples/sycl/CMakeLists.txt +1 -1
  141. package/src/llama.cpp/examples/sycl/build.sh +23 -0
  142. package/src/llama.cpp/examples/sycl/run-llama2.sh +36 -0
  143. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +33 -0
  144. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +9 -0
  145. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  146. package/src/llama.cpp/examples/tokenize/tokenize.cpp +16 -2
  147. package/src/llama.cpp/ggml/CMakeLists.txt +253 -0
  148. package/src/llama.cpp/{cmake → ggml/cmake}/FindSIMD.cmake +6 -6
  149. package/src/llama.cpp/{ggml-backend.h → ggml/include/ggml-backend.h} +22 -17
  150. package/src/llama.cpp/ggml/include/ggml-blas.h +23 -0
  151. package/src/llama.cpp/ggml/include/ggml-cann.h +125 -0
  152. package/src/llama.cpp/{ggml-cuda.h → ggml/include/ggml-cuda.h} +3 -0
  153. package/src/llama.cpp/{ggml-metal.h → ggml/include/ggml-metal.h} +1 -2
  154. package/src/llama.cpp/{ggml-sycl.h → ggml/include/ggml-sycl.h} +3 -10
  155. package/src/llama.cpp/{ggml.h → ggml/include/ggml.h} +80 -85
  156. package/src/llama.cpp/ggml/src/CMakeLists.txt +1329 -0
  157. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2193 -0
  158. package/src/llama.cpp/ggml/src/ggml-aarch64.h +39 -0
  159. package/src/llama.cpp/{ggml-alloc.c → ggml/src/ggml-alloc.c} +100 -49
  160. package/src/llama.cpp/{ggml-backend-impl.h → ggml/src/ggml-backend-impl.h} +20 -8
  161. package/src/llama.cpp/{ggml-backend.c → ggml/src/ggml-backend.c} +307 -167
  162. package/src/llama.cpp/ggml/src/ggml-blas.cpp +367 -0
  163. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +198 -0
  164. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +230 -0
  165. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +2944 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/common.h +282 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +32 -0
  169. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +17 -0
  170. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +223 -0
  171. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +186 -0
  172. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +180 -0
  173. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +193 -0
  174. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  175. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +208 -0
  176. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +206 -0
  177. package/src/llama.cpp/ggml/src/ggml-cann.cpp +2023 -0
  178. package/src/llama.cpp/{ggml-common.h → ggml/src/ggml-common.h} +41 -7
  179. package/src/llama.cpp/{ggml-impl.h → ggml/src/ggml-impl.h} +113 -9
  180. package/src/llama.cpp/{ggml-kompute.cpp → ggml/src/ggml-kompute.cpp} +33 -18
  181. package/src/llama.cpp/{ggml-quants.c → ggml/src/ggml-quants.c} +1460 -940
  182. package/src/llama.cpp/{ggml-quants.h → ggml/src/ggml-quants.h} +19 -20
  183. package/src/llama.cpp/{ggml-rpc.cpp → ggml/src/ggml-rpc.cpp} +95 -72
  184. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +27 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +53 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +355 -0
  187. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +195 -0
  188. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +21 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +547 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +27 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +698 -0
  192. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +3011 -0
  195. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.hpp +33 -0
  197. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1027 -0
  198. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  199. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +374 -0
  200. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +35 -0
  201. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +66 -0
  202. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +275 -0
  203. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +22 -0
  204. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +251 -0
  205. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +24 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +1140 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +5314 -0
  208. package/src/llama.cpp/{ggml-vulkan.cpp → ggml/src/ggml-vulkan.cpp} +1781 -1868
  209. package/src/llama.cpp/{ggml.c → ggml/src/ggml.c} +1245 -2087
  210. package/src/llama.cpp/{sgemm.cpp → ggml/src/llamafile/sgemm.cpp} +21 -24
  211. package/src/llama.cpp/{sgemm.h → ggml/src/llamafile/sgemm.h} +1 -1
  212. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +5 -0
  213. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +552 -0
  214. package/src/llama.cpp/{llama.h → include/llama.h} +175 -100
  215. package/src/llama.cpp/models/.editorconfig +1 -0
  216. package/src/llama.cpp/models/ggml-vocab-aquila.gguf +0 -0
  217. package/src/llama.cpp/models/ggml-vocab-baichuan.gguf +0 -0
  218. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf +0 -0
  219. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +112 -0
  220. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +46 -0
  221. package/src/llama.cpp/models/ggml-vocab-command-r.gguf +0 -0
  222. package/src/llama.cpp/models/ggml-vocab-command-r.gguf.inp +112 -0
  223. package/src/llama.cpp/models/ggml-vocab-command-r.gguf.out +46 -0
  224. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf +0 -0
  225. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +112 -0
  226. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +46 -0
  227. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf +0 -0
  228. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +112 -0
  229. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +46 -0
  230. package/src/llama.cpp/models/ggml-vocab-falcon.gguf +0 -0
  231. package/src/llama.cpp/models/ggml-vocab-falcon.gguf.inp +112 -0
  232. package/src/llama.cpp/models/ggml-vocab-falcon.gguf.out +46 -0
  233. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf +0 -0
  234. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +112 -0
  235. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +46 -0
  236. package/src/llama.cpp/models/ggml-vocab-gpt-neox.gguf +0 -0
  237. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf +0 -0
  238. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +112 -0
  239. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +46 -0
  240. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf +0 -0
  241. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +112 -0
  242. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +46 -0
  243. package/src/llama.cpp/models/ggml-vocab-mpt.gguf +0 -0
  244. package/src/llama.cpp/models/ggml-vocab-mpt.gguf.inp +112 -0
  245. package/src/llama.cpp/models/ggml-vocab-mpt.gguf.out +46 -0
  246. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf +0 -0
  247. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +112 -0
  248. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf.out +46 -0
  249. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf +0 -0
  250. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +112 -0
  251. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf.out +46 -0
  252. package/src/llama.cpp/models/ggml-vocab-refact.gguf +0 -0
  253. package/src/llama.cpp/models/ggml-vocab-refact.gguf.inp +112 -0
  254. package/src/llama.cpp/models/ggml-vocab-refact.gguf.out +46 -0
  255. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf +0 -0
  256. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +112 -0
  257. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf.out +46 -0
  258. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  259. package/src/llama.cpp/requirements/requirements-all.txt +12 -0
  260. package/src/llama.cpp/requirements/requirements-compare-llama-bench.txt +2 -0
  261. package/src/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -0
  262. package/src/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +3 -0
  263. package/src/llama.cpp/requirements/{requirements-convert.txt → requirements-convert_legacy_llama.txt} +1 -1
  264. package/src/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt +1 -0
  265. package/src/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  266. package/src/llama.cpp/requirements/requirements-pydantic.txt +3 -0
  267. package/src/llama.cpp/requirements/requirements-test-tokenizer-random.txt +1 -0
  268. package/src/llama.cpp/requirements.txt +5 -4
  269. package/src/llama.cpp/scripts/build-info.sh +30 -0
  270. package/src/llama.cpp/scripts/install-oneapi.bat +19 -0
  271. package/src/llama.cpp/src/CMakeLists.txt +33 -0
  272. package/src/llama.cpp/src/llama-grammar.cpp +539 -0
  273. package/src/llama.cpp/src/llama-grammar.h +39 -0
  274. package/src/llama.cpp/src/llama-impl.h +26 -0
  275. package/src/llama.cpp/src/llama-sampling.cpp +635 -0
  276. package/src/llama.cpp/src/llama-sampling.h +56 -0
  277. package/src/llama.cpp/src/llama-vocab.cpp +1721 -0
  278. package/src/llama.cpp/src/llama-vocab.h +130 -0
  279. package/src/llama.cpp/{llama.cpp → src/llama.cpp} +5979 -5260
  280. package/src/llama.cpp/{unicode-data.cpp → src/unicode-data.cpp} +851 -802
  281. package/src/llama.cpp/{unicode.cpp → src/unicode.cpp} +52 -30
  282. package/src/llama.cpp/{unicode.h → src/unicode.h} +5 -1
  283. package/src/llama.cpp/tests/CMakeLists.txt +19 -20
  284. package/src/llama.cpp/tests/test-backend-ops.cpp +245 -67
  285. package/src/llama.cpp/tests/test-chat-template.cpp +57 -3
  286. package/src/llama.cpp/tests/test-double-float.cpp +2 -2
  287. package/src/llama.cpp/tests/test-grad0.cpp +2 -2
  288. package/src/llama.cpp/tests/test-grammar-integration.cpp +978 -31
  289. package/src/llama.cpp/tests/test-grammar-parser.cpp +423 -158
  290. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +508 -135
  291. package/src/llama.cpp/tests/test-llama-grammar.cpp +15 -9
  292. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -1
  293. package/src/llama.cpp/tests/test-quantize-perf.cpp +1 -1
  294. package/src/llama.cpp/tests/test-rope.cpp +3 -4
  295. package/src/llama.cpp/tests/test-sampling.cpp +5 -5
  296. package/src/llama.cpp/tests/test-tokenizer-0.cpp +6 -6
  297. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +20 -15
  298. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +22 -11
  299. package/bin/darwin/arm64/default.metallib +0 -0
  300. package/bin/darwin/x64/default.metallib +0 -0
  301. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +0 -5
  302. package/src/llama.cpp/examples/beam-search/beam-search.cpp +0 -188
  303. package/src/llama.cpp/examples/finetune/finetune.cpp +0 -1862
  304. package/src/llama.cpp/examples/llama.android/llama/CMakeLists.txt +0 -55
  305. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +0 -5
  306. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +0 -1253
  307. package/src/llama.cpp/ggml-opencl.cpp +0 -2305
  308. package/src/llama.cpp/ggml-opencl.h +0 -36
  309. package/src/llama.cpp/ggml-sycl.cpp +0 -17340
  310. package/src/llama.cpp/ggml-vulkan-shaders.hpp +0 -81211
  311. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf-update.txt +0 -2
  312. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +0 -2
  313. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +0 -1
  314. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +0 -24
  315. /package/src/llama.cpp/{ggml-alloc.h → ggml/include/ggml-alloc.h} +0 -0
  316. /package/src/llama.cpp/{ggml-kompute.h → ggml/include/ggml-kompute.h} +0 -0
  317. /package/src/llama.cpp/{ggml-rpc.h → ggml/include/ggml-rpc.h} +0 -0
  318. /package/src/llama.cpp/{ggml-vulkan.h → ggml/include/ggml-vulkan.h} +0 -0
  319. /package/src/llama.cpp/{unicode-data.h → src/unicode-data.h} +0 -0
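The excerpted hunks below appear to come from the relocated quantization source (entry 181 above, ggml-quants.c, now under ggml/src/). The most visible API change in them is the rename of the `quantize_row_*_reference` helpers to `quantize_row_*_ref`. As a minimal sketch (not part of the diff), a caller that quantized a row with the old name would only need the name swap, assuming the vendored `ggml-quants.h` is on the include path and exposes `block_q4_0`/`QK4_0`:

    #include <stdlib.h>
    #include "ggml-quants.h"   // vendored header from this package (assumption)

    // Quantize k floats (k must be a multiple of QK4_0 == 32) into q4_0 blocks.
    static block_q4_0 * quantize_row_copy(const float * src, int64_t k) {
        block_q4_0 * blocks = malloc(sizeof(block_q4_0) * (k / QK4_0)); // one block per 32 floats
        // 0.2.x called quantize_row_q4_0_reference(); 0.3.x renames it to *_ref
        quantize_row_q4_0_ref(src, blocks, k);
        return blocks;
    }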
@@ -4,8 +4,6 @@
 #include "ggml-quants.h"
 #include "ggml-impl.h"
 
-#define GGML_COMMON_IMPL_C
-#include "ggml-common.h"
 
 #include <math.h>
 #include <string.h>
@@ -660,7 +658,7 @@ static inline __m128i packNibbles( __m256i bytes ) {
 #endif //__loongarch_asx
 
 // reference implementation for deterministic creation of model files
-void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
+void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
     static const int qk = QK4_0;
 
     assert(k % qk == 0);
@@ -698,11 +696,11 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict
 }
 
 void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q4_0_reference(x, y, k);
+    quantize_row_q4_0_ref(x, y, k);
 }
 
 
-void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
+void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
     const int qk = QK4_1;
 
     assert(k % qk == 0);
@@ -740,10 +738,10 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict
 }
 
 void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q4_1_reference(x, y, k);
+    quantize_row_q4_1_ref(x, y, k);
 }
 
-void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
+void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
     static const int qk = QK5_0;
 
     assert(k % qk == 0);
@@ -788,10 +786,10 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict
 }
 
 void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q5_0_reference(x, y, k);
+    quantize_row_q5_0_ref(x, y, k);
 }
 
-void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
+void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
     const int qk = QK5_1;
 
     assert(k % qk == 0);
@@ -836,11 +834,11 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict
 }
 
 void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q5_1_reference(x, y, k);
+    quantize_row_q5_1_ref(x, y, k);
 }
 
 // reference implementation for deterministic creation of model files
-void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
+void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
 
@@ -1078,6 +1076,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
         }
         vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]);
         vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);
+    }
 
 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
@@ -1145,12 +1144,12 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
 #else
     GGML_UNUSED(nb);
     // scalar
-    quantize_row_q8_0_reference(x, y, k);
+    quantize_row_q8_0_ref(x, y, k);
 #endif
 }
 
 // reference implementation for deterministic creation of model files
-void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
+void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
     assert(QK8_1 == 32);
     assert(k % QK8_1 == 0);
     const int nb = k / QK8_1;
@@ -1437,6 +1436,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
         accv = vec_add(accv, vec_sld(accv, accv, 4));
         accv = vec_add(accv, vec_sld(accv, accv, 8));
         y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0));
+    }
 
 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
@@ -1508,7 +1508,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 #else
     GGML_UNUSED(nb);
     // scalar
-    quantize_row_q8_1_reference(x, y, k);
+    quantize_row_q8_1_ref(x, y, k);
 #endif
 }
 
@@ -1899,7 +1899,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *
 
 //========================- 2-bit (de)-quantization
 
-void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) {
+void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -2002,7 +2002,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6
 }
 
 void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
-    quantize_row_q2_K_reference(x, vy, k);
+    quantize_row_q2_K_ref(x, vy, k);
 }
 
 static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
@@ -2226,7 +2226,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
 size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
@@ -2241,7 +2241,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nr
 
 //========================= 3-bit (de)-quantization
 
-void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) {
+void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -2368,7 +2368,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6
 }
 
 void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
-    quantize_row_q3_K_reference(x, vy, k);
+    quantize_row_q3_K_ref(x, vy, k);
 }
 
 static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
@@ -2458,7 +2458,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
 size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
@@ -2473,7 +2473,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr
 
 // ====================== 4-bit (de)-quantization
 
-void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) {
+void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
 
@@ -2572,7 +2572,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6
 void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_q4_K * restrict y = vy;
-    quantize_row_q4_K_reference(x, y, k);
+    quantize_row_q4_K_ref(x, y, k);
 }
 
 static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@@ -2651,7 +2651,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
 size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
@@ -2666,7 +2666,7 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nr
 
 // ====================== 5-bit (de)-quantization
 
-void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) {
+void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -2783,7 +2783,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6
 void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_q5_K * restrict y = vy;
-    quantize_row_q5_K_reference(x, y, k);
+    quantize_row_q5_K_ref(x, y, k);
 }
 
 static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@@ -2882,7 +2882,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
 size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
@@ -2897,7 +2897,7 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nr
 
 // ====================== 6-bit (de)-quantization
 
-void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) {
+void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -3001,7 +3001,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6
 void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_q6_K * restrict y = vy;
-    quantize_row_q6_K_reference(x, y, k);
+    quantize_row_q6_K_ref(x, y, k);
 }
 
 static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@@ -3091,7 +3091,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
 size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
@@ -3108,7 +3108,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
     static_assert(QK4_0 == 32, "QK4_0 must be 32");
 
     if (!quant_weights) {
-        quantize_row_q4_0_reference(x, y, n_per_row);
+        quantize_row_q4_0_ref(x, y, n_per_row);
         return;
     }
 
@@ -3134,7 +3134,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
 
 size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
@@ -3151,7 +3151,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
     static_assert(QK4_1 == 32, "QK4_1 must be 32");
 
     if (!quant_weights) {
-        quantize_row_q4_1_reference(x, y, n_per_row);
+        quantize_row_q4_1_ref(x, y, n_per_row);
         return;
     }
 
@@ -3179,7 +3179,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
 
 size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
@@ -3196,7 +3196,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
     static_assert(QK5_0 == 32, "QK5_0 must be 32");
 
     if (!quant_weights) {
-        quantize_row_q5_0_reference(x, y, n_per_row);
+        quantize_row_q5_0_ref(x, y, n_per_row);
         return;
     }
 
@@ -3233,7 +3233,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
 
 size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
@@ -3250,7 +3250,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
     static_assert(QK5_1 == 32, "QK5_1 must be 32");
 
     if (!quant_weights) {
-        quantize_row_q5_1_reference(x, y, n_per_row);
+        quantize_row_q5_1_ref(x, y, n_per_row);
         return;
     }
 
@@ -3286,7 +3286,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
 
 size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
@@ -3302,7 +3302,7 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr
 size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     (void)quant_weights; // not used
     const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
-    quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row);
+    quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row);
     return nrow * row_size;
 }
 
@@ -3590,7 +3590,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,
 
 //===================================== Q8_K ==============================================
 
-void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) {
+void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;
 
@@ -3641,7 +3641,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6
 }
 
 void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q8_K_reference(x, y, k);
+    quantize_row_q8_K_ref(x, y, k);
 }
 
 //===================================== Dot ptoducts =================================
@@ -3808,59 +3808,61 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
 
-        vst1_f32(s, vget_low_f32(sumv2));
+        vst1_f32(s, vget_low_f32(sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));
         return;
     }
 #endif
+
+    int ib = 0;
+    float sumf = 0;
+
 #if defined(__ARM_FEATURE_SVE)
-    const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
-    const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+    if (svcntb() == QK8_0) {
+        const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+        const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
 
-    svfloat32_t sumv0 = svdup_n_f32(0.0f);
-    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+        svfloat32_t sumv0 = svdup_n_f32(0.0f);
+        svfloat32_t sumv1 = svdup_n_f32(0.0f);
 
-    assert(nb % 2 == 0); // TODO: handle odd nb
+        for (; ib + 1 < nb; ib += 2) {
+            const block_q4_0 * restrict x0 = &x[ib + 0];
+            const block_q4_0 * restrict x1 = &x[ib + 1];
+            const block_q8_0 * restrict y0 = &y[ib + 0];
+            const block_q8_0 * restrict y1 = &y[ib + 1];
 
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
+            // load x
+            const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+            const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
 
-        // load x
-        const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
-        const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+            // 4-bit -> 8-bit
+            const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+            const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
 
-        // 4-bit -> 8-bit
-        const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
-        const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+            // sub 8
+            const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+            const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
 
-        // sub 8
-        const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
-        const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+            // load y
+            const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+            const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
 
-        // load y
-        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
-        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+            // dot product
+            sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+            sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+        }
 
-        // dot product
-        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
-        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+        sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
     }
-
-    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
 #elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-    assert(nb % 2 == 0); // TODO: handle odd nb
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q4_0 * restrict x0 = &x[ib + 0];
+        const block_q4_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib + 0];
+        const block_q8_0 * restrict y1 = &y[ib + 1];
 
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
         const int8x16_t s8b = vdupq_n_s8(0x8);
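The hunks above also show the structural change behind most of the remaining edits to ggml_vec_dot_q4_0_q8_0: a shared block index `ib` and scalar accumulator `sumf` replace per-branch `int i` counters and direct `*s = ...` stores, so each SIMD branch can consume pairs of blocks and leave any remainder to a common scalar tail instead of asserting `nb % 2 == 0`. A minimal sketch of that loop shape, not the ggml code itself:

    // Sketch only: SIMD-style main loop over pairs plus a scalar remainder,
    // sharing one index (ib) and one accumulator (sumf), as the diff introduces.
    static float dot_blocks(const float * x, const float * y, int nb) {
        int   ib   = 0;
        float sumf = 0.0f;

        for (; ib + 1 < nb; ib += 2) {   // "vectorized" part: two blocks per step
            sumf += x[ib + 0] * y[ib + 0];
            sumf += x[ib + 1] * y[ib + 1];
        }
        for (; ib < nb; ++ib) {          // scalar tail: handles an odd final block
            sumf += x[ib] * y[ib];
        }
        return sumf;                     // the real function writes this to *s
    }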
@@ -3894,23 +3896,23 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
     }
 
-    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
     // Main loop
-    for (int i = 0; i < nb; ++i) {
+    for (; ib < nb; ++ib) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
 
-        __m256i qx = bytes_from_nibbles_32(x[i].qs);
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
 
         // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
         const __m256i off = _mm256_set1_epi8( 8 );
         qx = _mm256_sub_epi8( qx, off );
 
-        __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
+        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
 
         const __m256 q = mul_sum_i8_pairs_float(qx, qy);
 
@@ -3918,28 +3920,28 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         acc = _mm256_fmadd_ps( d, q, acc );
     }
 
-    *s = hsum_float_8(acc);
+    sumf = hsum_float_8(acc);
 #elif defined(__AVX__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
     // Main loop
-    for (int i = 0; i < nb; ++i) {
+    for (; ib < nb; ++ib) {
         // Compute combined scale for the block
-        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
 
         const __m128i lowMask = _mm_set1_epi8(0xF);
         const __m128i off = _mm_set1_epi8(8);
 
-        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
+        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs);
 
         __m128i bx_0 = _mm_and_si128(lowMask, tmp);
-        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
         bx_0 = _mm_sub_epi8(bx_0, off);
         const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
 
         bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
-        by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
         bx_0 = _mm_sub_epi8(bx_0, off);
         const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
 
 
@@ -3950,7 +3952,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
3950
3952
  acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
3951
3953
  }
3952
3954
 
3953
- *s = hsum_float_8(acc);
3955
+ sumf = hsum_float_8(acc);
3954
3956
  #elif defined(__SSSE3__)
3955
3957
  // set constants
3956
3958
  const __m128i lowMask = _mm_set1_epi8(0xF);
@@ -3962,94 +3964,40 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     __m128 acc_2 = _mm_setzero_ps();
     __m128 acc_3 = _mm_setzero_ps();
 
-    // First round without accumulation
-    {
-        _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
-        _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
-
-        // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
-
-        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
-
-        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
-        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
-        bx_0 = _mm_sub_epi8(bx_0, off);
-        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
-
-        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
-        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
-        bx_1 = _mm_sub_epi8(bx_1, off);
-        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
-
-        _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
-        _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
-
-        // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
-
-        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
-
-        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
-        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
-        bx_2 = _mm_sub_epi8(bx_2, off);
-        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
-
-        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
-        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
-        bx_3 = _mm_sub_epi8(bx_3, off);
-        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
-
-        // Convert int32_t to float
-        __m128 p0 = _mm_cvtepi32_ps(i32_0);
-        __m128 p1 = _mm_cvtepi32_ps(i32_1);
-        __m128 p2 = _mm_cvtepi32_ps(i32_2);
-        __m128 p3 = _mm_cvtepi32_ps(i32_3);
-
-        // Apply the scale
-        acc_0 = _mm_mul_ps( d_0_1, p0 );
-        acc_1 = _mm_mul_ps( d_0_1, p1 );
-        acc_2 = _mm_mul_ps( d_2_3, p2 );
-        acc_3 = _mm_mul_ps( d_2_3, p3 );
-    }
-
-    assert(nb % 2 == 0); // TODO: handle odd nb
-
-    // Main loop
-    for (int i = 2; i < nb; i+=2) {
-        _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
-        _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
+    for (; ib + 1 < nb; ib += 2) {
+        _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
 
-        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs);
 
         __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
-        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
         bx_0 = _mm_sub_epi8(bx_0, off);
         const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
 
         __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
-        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
         bx_1 = _mm_sub_epi8(bx_1, off);
         const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
 
-        _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
-        _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+        _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
 
-        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
 
         __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
-        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
         bx_2 = _mm_sub_epi8(bx_2, off);
         const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
 
         __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
-        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16));
         bx_3 = _mm_sub_epi8(bx_3, off);
         const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
 
@@ -4072,18 +4020,16 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         acc_3 = _mm_add_ps(p3_d, acc_3);
     }
 
-    *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
 #elif defined(__riscv_v_intrinsic)
-    float sumf = 0.0;
-
     size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
-    for (int i = 0; i < nb; i++) {
+    for (; ib < nb; ++ib) {
         // load elements
-        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
 
-        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
-        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
 
         // mask and store lower part of x, and then upper part
         vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
@@ -4106,30 +4052,29 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
         int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
-        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
+        sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
     }
 
-    *s = sumf;
-
 #elif defined(__POWER9_VECTOR__)
     const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
     const vector unsigned char v4 = vec_splats((unsigned char)0x4);
     const vector signed char v8 = vec_splats((signed char)0x8);
 
     vector float vsumf0 = vec_splats(0.0f);
 
-#pragma GCC unroll 4
-    for (int i = 0; i < nb; i++) {
-        __builtin_prefetch(x[i].qs, 0, 1);
-        __builtin_prefetch(y[i].qs, 0, 1);
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
 
-        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
-        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d));
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
         vector float vd = vec_mul(vxd, vyd);
 
-        vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs);
-        vector signed char q8y0 = vec_xl( 0, y[i].qs);
-        vector signed char q8y1 = vec_xl(16, y[i].qs);
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
 
         vector signed char q4x0 = vec_and(qxs, lowMask);
         vector signed char q4x1 = vec_sr(qxs, v4);
@@ -4140,9 +4085,10 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
         vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
 
-        qv0 = vec_add(qv0, qv1);
+        vector signed int vsumi0 = v0;
 
-        vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0));
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi0 = vec_sum4s(qv1, vsumi0);
 
         vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
     }
@@ -4150,24 +4096,24 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
4150
4096
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
4151
4097
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
4152
4098
 
4153
- *s = vec_extract(vsumf0, 0);
4099
+ sumf = vec_extract(vsumf0, 0);
4154
4100
 
4155
4101
  #elif defined(__loongarch_asx)
4156
4102
  // Initialize accumulator with zeros
4157
4103
  __m256 acc = (__m256)__lasx_xvldi(0);
4158
4104
 
4159
4105
  // Main loop
4160
- for (int i = 0; i < nb; ++i) {
4106
+ for (; ib < nb; ++ib) {
4161
4107
  /* Compute combined scale for the block */
4162
- const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
4108
+ const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
4163
4109
 
4164
- __m256i qx = bytes_from_nibbles_32(x[i].qs);
4110
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
4165
4111
 
4166
4112
  // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
4167
4113
  const __m256i off = __lasx_xvreplgr2vr_b( 8 );
4168
4114
  qx = __lasx_xvsub_b( qx, off );
4169
4115
 
4170
- __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0);
4116
+ __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
4171
4117
 
4172
4118
  const __m256 q = mul_sum_i8_pairs_float(qx, qy);
4173
4119
 
@@ -4175,7 +4121,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
4175
4121
  acc = __lasx_xvfmadd_s( d, q, acc );
4176
4122
  }
4177
4123
 
4178
- *s = hsum_float_8(acc);
4124
+ sumf = hsum_float_8(acc);
4179
4125
  #elif defined(__loongarch_sx)
4180
4126
  // set constants
4181
4127
  const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
@@ -4187,89 +4133,38 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  __m128 acc_2 = __lsx_vldi(0);
  __m128 acc_3 = __lsx_vldi(0);
 
- // First round without accumulation
- {
- _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
- _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
-
- // Compute combined scale for the block 0 and 1
- const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
-
- const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[0].qs, 0);
-
- __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
- __m128i by_0 = __lsx_vld((const __m128i *)y[0].qs, 0);
- bx_0 = __lsx_vsub_b(bx_0, off);
- const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
-
- __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
- __m128i by_1 = __lsx_vld((const __m128i *)(y[0].qs + 16), 0);
- bx_1 = __lsx_vsub_b(bx_1, off);
- const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
-
- // Compute combined scale for the block 2 and 3
- const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
-
- const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[1].qs, 0);
-
- __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
- __m128i by_2 = __lsx_vld((const __m128i *)y[1].qs, 0);
- bx_2 = __lsx_vsub_b(bx_2, off);
- const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
-
- __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
- __m128i by_3 = __lsx_vld((const __m128i *)(y[1].qs + 16), 0);
- bx_3 = __lsx_vsub_b(bx_3, off);
- const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
-
- // Convert int32_t to float
- __m128 p0 = __lsx_vffint_s_w(i32_0);
- __m128 p1 = __lsx_vffint_s_w(i32_1);
- __m128 p2 = __lsx_vffint_s_w(i32_2);
- __m128 p3 = __lsx_vffint_s_w(i32_3);
-
- // Apply the scale
- acc_0 = __lsx_vfmul_s( d_0_1, p0 );
- acc_1 = __lsx_vfmul_s( d_0_1, p1 );
- acc_2 = __lsx_vfmul_s( d_2_3, p2 );
- acc_3 = __lsx_vfmul_s( d_2_3, p3 );
- }
-
- assert(nb % 2 == 0); // TODO: handle odd nb
-
- // Main loop
- for (int i = 2; i < nb; i+=2) {
+ for (; ib + 1 < nb; ib += 2) {
 
  // Compute combined scale for the block 0 and 1
- const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
+ const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
 
- const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[i].qs, 0);
+ const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
 
  __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
- __m128i by_0 = __lsx_vld((const __m128i *)y[i].qs, 0);
+ __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0);
  bx_0 = __lsx_vsub_b(bx_0, off);
  const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
 
  __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
- __m128i by_1 = __lsx_vld((const __m128i *)(y[i].qs + 16), 0);
+ __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0);
  bx_1 = __lsx_vsub_b(bx_1, off);
  const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
 
- //_mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
- //_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+ //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+ //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
 
  // Compute combined scale for the block 2 and 3
- const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
+ const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
 
- const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[i + 1].qs, 0);
+ const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
 
  __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
- __m128i by_2 = __lsx_vld((const __m128i *)y[i + 1].qs, 0);
+ __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0);
  bx_2 = __lsx_vsub_b(bx_2, off);
  const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
 
  __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
- __m128i by_3 = __lsx_vld((const __m128i *)(y[i + 1].qs + 16), 0);
+ __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0);
  bx_3 = __lsx_vsub_b(bx_3, off);
  const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
 
@@ -4292,27 +4187,25 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  acc_3 = __lsx_vfadd_s(p3_d, acc_3);
  }
 
- *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
-
- #else
- // scalar
- float sumf = 0.0;
-
- for (int i = 0; i < nb; i++) {
- int sumi = 0;
+ sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+ #endif
+ for (; ib < nb; ++ib) {
+ int sumi0 = 0;
+ int sumi1 = 0;
 
  for (int j = 0; j < qk/2; ++j) {
- const int v0 = (x[i].qs[j] & 0x0F) - 8;
- const int v1 = (x[i].qs[j] >> 4) - 8;
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
 
- sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
+ sumi0 += (v0 * y[ib].qs[j]);
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
  }
 
- sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
+ int sumi = sumi0 + sumi1;
+ sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
  }
 
  *s = sumf;
- #endif
  }
 
  void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
@@ -4398,11 +4291,15 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
  sumv2 = vaddq_f32(sumv2, summs0);
 
- vst1_f32(s, vget_low_f32(sumv2));
+ vst1_f32(s, vget_low_f32 (sumv2));
  vst1_f32(s + bs, vget_high_f32(sumv2));
  return;
  }
  #endif
+
+ int ib = 0;
+ float sumf = 0;
+
  // TODO: add WASM SIMD
  #if defined(__ARM_NEON)
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
@@ -4410,13 +4307,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 
  float summs = 0;
 
- assert(nb % 2 == 0); // TODO: handle odd nb
-
- for (int i = 0; i < nb; i += 2) {
- const block_q4_1 * restrict x0 = &x[i + 0];
- const block_q4_1 * restrict x1 = &x[i + 1];
- const block_q8_1 * restrict y0 = &y[i + 0];
- const block_q8_1 * restrict y1 = &y[i + 1];
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q4_1 * restrict x0 = &x[ib + 0];
+ const block_q4_1 * restrict x1 = &x[ib + 1];
+ const block_q8_1 * restrict y0 = &y[ib + 0];
+ const block_q8_1 * restrict y1 = &y[ib + 1];
 
  summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
 
@@ -4445,7 +4340,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  }
 
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
  #elif defined(__AVX2__) || defined(__AVX__)
  // Initialize accumulator with zeros
  __m256 acc = _mm256_setzero_ps();
@@ -4453,11 +4348,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  float summs = 0;
 
  // Main loop
- for (int i = 0; i < nb; ++i) {
- const float d0 = GGML_FP16_TO_FP32(x[i].d);
- const float d1 = GGML_FP16_TO_FP32(y[i].d);
+ for (; ib < nb; ++ib) {
+ const float d0 = GGML_FP16_TO_FP32(x[ib].d);
+ const float d1 = GGML_FP16_TO_FP32(y[ib].d);
 
- summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s);
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
 
  const __m256 d0v = _mm256_set1_ps( d0 );
  const __m256 d1v = _mm256_set1_ps( d1 );
@@ -4466,8 +4361,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
 
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
- const __m256i qx = bytes_from_nibbles_32(x[i].qs);
- const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+ const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs );
 
  const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
 
@@ -4479,18 +4374,16 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  #endif
  }
 
- *s = hsum_float_8(acc) + summs;
+ sumf = hsum_float_8(acc) + summs;
  #elif defined(__riscv_v_intrinsic)
- float sumf = 0.0;
-
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
 
- for (int i = 0; i < nb; i++) {
+ for (; ib < nb; ++ib) {
  // load elements
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
 
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
 
  // mask and store lower part of x, and then upper part
  vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
@@ -4509,43 +4402,40 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 
  int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s);
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
  }
 
- *s = sumf;
-
  #elif defined(__POWER9_VECTOR__)
  const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector signed int v0 = vec_splats((int32_t)0);
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
 
  vector float vsumf0 = vec_splats(0.0f);
 
  #pragma GCC unroll 4
- for (int i = 0; i < nb; i++) {
- __builtin_prefetch(x[i].qs, 0, 1);
- __builtin_prefetch(y[i].qs, 0, 1);
+ for (; ib < nb; ++ib) {
+ __builtin_prefetch(x[ib].qs, 0, 1);
+ __builtin_prefetch(y[ib].qs, 0, 1);
 
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
- vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d));
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+ vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
  vector float vd = vec_mul(vxd, vyd);
 
- vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].m));
- vector float vys = {GGML_FP16_TO_FP32(y[i].s), 0.0f, 0.0f, 0.0f};
+ vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m));
+ vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f};
  vsumf0 = vec_madd(vxmin, vys, vsumf0);
 
- vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs);
- vector signed char q8y0 = vec_xl( 0, y[i].qs);
- vector signed char q8y1 = vec_xl(16, y[i].qs);
-
- vector signed char q4x0 = vec_and(qxs, lowMask);
- vector signed char q4x1 = vec_sr(qxs, v4);
+ vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+ vector signed char q8y1 = vec_xl(16, y[ib].qs);
 
- vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
- vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+ vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask);
+ vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4);
 
- qv0 = vec_add(qv0, qv1);
+ vector signed int vsumi0 = v0;
 
- vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0));
+ vsumi0 = vec_msum(q8y0, q4x0, vsumi0);
+ vsumi0 = vec_msum(q8y1, q4x1, vsumi0);
 
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
  }
@@ -4553,7 +4443,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
 
- *s = vec_extract(vsumf0, 0);
+ sumf = vec_extract(vsumf0, 0);
 
  #elif defined(__loongarch_asx)
  // Initialize accumulator with zeros
@@ -4562,11 +4452,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  float summs = 0;
 
  // Main loop
- for (int i = 0; i < nb; ++i) {
- const float d0 = GGML_FP16_TO_FP32(x[i].d);
- const float d1 = GGML_FP16_TO_FP32(y[i].d);
+ for (; ib < nb; ++ib) {
+ const float d0 = GGML_FP16_TO_FP32(x[ib].d);
+ const float d1 = GGML_FP16_TO_FP32(y[ib].d);
 
- summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s);
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
 
  const __m256 d0v = __lasx_xvreplfr2vr_s( d0 );
  const __m256 d1v = __lasx_xvreplfr2vr_s( d1 );
@@ -4575,8 +4465,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v );
 
  // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
- const __m256i qx = bytes_from_nibbles_32(x[i].qs);
- const __m256i qy = __lasx_xvld( (const __m256i *)y[i].qs, 0);
+ const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0);
 
  const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
 
@@ -4584,33 +4474,34 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  acc = __lasx_xvfmadd_s( d0d1, xy, acc );
  }
 
- *s = hsum_float_8(acc) + summs;
-
- #else
- // scalar
- float sumf = 0.0;
-
- for (int i = 0; i < nb; i++) {
- int sumi = 0;
+ sumf = hsum_float_8(acc) + summs;
+ #endif
+ for (; ib < nb; ++ib) {
+ int sumi0 = 0;
+ int sumi1 = 0;
 
  for (int j = 0; j < qk/2; ++j) {
- const int v0 = (x[i].qs[j] & 0x0F);
- const int v1 = (x[i].qs[j] >> 4);
+ const int v0 = (x[ib].qs[j] & 0x0F);
+ const int v1 = (x[ib].qs[j] >> 4);
 
- sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
+ sumi0 += (v0 * y[ib].qs[j]);
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
  }
 
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s);
+ int sumi = sumi0 + sumi1;
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
  }
 
  *s = sumf;
- #endif
  }
 
  void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
  const int qk = QK8_0;
  const int nb = n / qk;
 
+ int ib = 0;
+ float sumf = 0;
+
  assert(n % qk == 0);
  assert(qk == QK5_0);
  assert(nrc == 1);
@@ -4632,13 +4523,11 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  uint64_t tmp0[4];
  uint64_t tmp1[4];
 
- assert(nb % 2 == 0); // TODO: handle odd nb
-
- for (int i = 0; i < nb; i += 2) {
- const block_q5_0 * restrict x0 = &x[i];
- const block_q5_0 * restrict x1 = &x[i + 1];
- const block_q8_0 * restrict y0 = &y[i];
- const block_q8_0 * restrict y1 = &y[i + 1];
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q5_0 * restrict x0 = &x[ib];
+ const block_q5_0 * restrict x1 = &x[ib + 1];
+ const block_q8_0 * restrict y0 = &y[ib];
+ const block_q8_0 * restrict y1 = &y[ib + 1];
 
  const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
@@ -4690,7 +4579,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  }
 
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
  #elif defined(__wasm_simd128__)
  v128_t sumv = wasm_f32x4_splat(0.0f);
 
@@ -4698,9 +4587,9 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  uint64_t tmp[4];
 
  // TODO: check if unrolling this is better
- for (int i = 0; i < nb; ++i) {
- const block_q5_0 * restrict x0 = &x[i];
- const block_q8_0 * restrict y0 = &y[i];
+ for (; ib < nb; ++ib) {
+ const block_q5_0 * restrict x0 = &x[ib];
+ const block_q8_0 * restrict y0 = &y[ib];
 
  const v128_t m4b = wasm_i8x16_splat(0x0F);
 
@@ -4750,23 +4639,23 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
  }
 
- *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
- wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
  #elif defined(__AVX2__)
  // Initialize accumulator with zeros
  __m256 acc = _mm256_setzero_ps();
 
  // Main loop
- for (int i = 0; i < nb; i++) {
+ for (; ib < nb; ++ib) {
  /* Compute combined scale for the block */
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
 
- __m256i qx = bytes_from_nibbles_32(x[i].qs);
- __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
  bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
  qx = _mm256_or_si256(qx, bxhi);
 
- __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
 
  const __m256 q = mul_sum_i8_pairs_float(qx, qy);
 
@@ -4774,19 +4663,19 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  acc = _mm256_fmadd_ps(d, q, acc);
  }
 
- *s = hsum_float_8(acc);
+ sumf = hsum_float_8(acc);
  #elif defined(__AVX__)
  // Initialize accumulator with zeros
  __m256 acc = _mm256_setzero_ps();
  __m128i mask = _mm_set1_epi8((char)0xF0);
 
  // Main loop
- for (int i = 0; i < nb; i++) {
+ for (; ib < nb; ++ib) {
  /* Compute combined scale for the block */
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
 
- __m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
- const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
+ const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
  __m128i bxhil = _mm256_castsi256_si128(bxhi);
  __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
  bxhil = _mm_andnot_si128(bxhil, mask);
@@ -4797,7 +4686,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  bxh = _mm_or_si128(bxh, bxhih);
  bx_0 = MM256_SET_M128I(bxh, bxl);
 
- const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
+ const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
 
  const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
 
@@ -4805,10 +4694,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
  }
 
- *s = hsum_float_8(acc);
+ sumf = hsum_float_8(acc);
  #elif defined(__riscv_v_intrinsic)
- float sumf = 0.0;
-
  uint32_t qh;
 
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
@@ -4820,8 +4707,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
  vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
 
- for (int i = 0; i < nb; i++) {
- memcpy(&qh, x[i].qh, sizeof(uint32_t));
+ for (; ib < nb; ++ib) {
+ memcpy(&qh, x[ib].qh, sizeof(uint32_t));
 
  // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
  vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
@@ -4840,10 +4727,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
 
  // load
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
 
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
 
  vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
  vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
@@ -4867,11 +4754,9 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
  int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
  }
 
- *s = sumf;
-
  #elif defined(__POWER9_VECTOR__)
  const vector signed char lowMask = vec_splats((signed char)0xF);
  const vector unsigned char v4 = vec_splats((unsigned char)4);
@@ -4879,27 +4764,27 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  vector float vsumf0 = vec_splats(0.0f);
 
  #pragma GCC unroll 4
- for (int i = 0; i < nb; ++i) {
- __builtin_prefetch(x[i].qs, 0, 1);
- __builtin_prefetch(y[i].qs, 0, 1);
+ for (; ib < nb; ++ib) {
+ __builtin_prefetch(x[ib].qs, 0, 1);
+ __builtin_prefetch(y[ib].qs, 0, 1);
 
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
- vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d));
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+ vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
  vector float vd = vec_mul(vxd, vyd);
 
- vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[i].qh[0]]), (uint64_t)(table_b2b_1[x[i].qh[1]])};
- vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[i].qh[2]]), (uint64_t)(table_b2b_1[x[i].qh[3]])};
+ vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])};
+ vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])};
 
  vector signed char qh0 = (vector signed char)aux64x2_0;
  vector signed char qh1 = (vector signed char)aux64x2_1;
 
- vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs);
+ vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
 
  vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0);
  vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1);
 
- vector signed char q8y0 = vec_xl( 0, y[i].qs);
- vector signed char q8y1 = vec_xl( 16, y[i].qs);
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+ vector signed char q8y1 = vec_xl( 16, y[ib].qs);
 
  vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0));
  vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1));
@@ -4914,23 +4799,23 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
 
- *s = vec_extract(vsumf0, 0);
+ sumf = vec_extract(vsumf0, 0);
 
  #elif defined(__loongarch_asx)
  // Initialize accumulator with zeros
  __m256 acc = (__m256)__lasx_xvldi(0);
 
  // Main loop
- for (int i = 0; i < nb; i++) {
+ for (; ib < nb; ++ib) {
  /* Compute combined scale for the block */
- const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); //FIXME
+ const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); //FIXME
 
- __m256i qx = bytes_from_nibbles_32(x[i].qs);
- __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
  bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0));
  qx = __lasx_xvor_v(qx, bxhi);
 
- __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0);
+ __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
 
  const __m256 q = mul_sum_i8_pairs_float(qx, qy);
 
@@ -4938,39 +4823,40 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  acc = __lasx_xvfmadd_s(d, q, acc);
  }
 
- *s = hsum_float_8(acc);
-
- #else
- // scalar
- float sumf = 0.0;
-
- for (int i = 0; i < nb; i++) {
+ sumf = hsum_float_8(acc);
+ #endif
+ for (; ib < nb; ++ib) {
  uint32_t qh;
- memcpy(&qh, x[i].qh, sizeof(qh));
+ memcpy(&qh, x[ib].qh, sizeof(qh));
 
- int sumi = 0;
+ int sumi0 = 0;
+ int sumi1 = 0;
 
  for (int j = 0; j < qk/2; ++j) {
  const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
  const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
 
- const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
- const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
+ const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
+ const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
 
- sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
+ sumi0 += (x0 * y[ib].qs[j]);
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
  }
 
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
+ int sumi = sumi0 + sumi1;
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
  }
 
  *s = sumf;
- #endif
  }
 
  void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
  const int qk = QK8_1;
  const int nb = n / qk;
 
+ int ib = 0;
+ float sumf = 0;
+
  assert(n % qk == 0);
  assert(qk == QK5_1);
  assert(nrc == 1);
@@ -4995,13 +4881,11 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  uint64_t tmp0[4];
  uint64_t tmp1[4];
 
- assert(nb % 2 == 0); // TODO: handle odd nb
-
- for (int i = 0; i < nb; i += 2) {
- const block_q5_1 * restrict x0 = &x[i];
- const block_q5_1 * restrict x1 = &x[i + 1];
- const block_q8_1 * restrict y0 = &y[i];
- const block_q8_1 * restrict y1 = &y[i + 1];
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q5_1 * restrict x0 = &x[ib];
+ const block_q5_1 * restrict x1 = &x[ib + 1];
+ const block_q8_1 * restrict y0 = &y[ib];
+ const block_q8_1 * restrict y1 = &y[ib + 1];
 
  const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
@@ -5056,7 +4940,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  }
 
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
  #elif defined(__wasm_simd128__)
  v128_t sumv = wasm_f32x4_splat(0.0f);
 
@@ -5066,9 +4950,9 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  uint64_t tmp[4];
 
  // TODO: check if unrolling this is better
- for (int i = 0; i < nb; ++i) {
- const block_q5_1 * restrict x0 = &x[i];
- const block_q8_1 * restrict y0 = &y[i];
+ for (; ib < nb; ++ib) {
+ const block_q5_1 * restrict x0 = &x[ib];
+ const block_q8_1 * restrict y0 = &y[ib];
 
  summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
 
@@ -5120,8 +5004,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
  }
 
- *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
- wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
+ sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+ wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
  #elif defined(__AVX2__)
  // Initialize accumulator with zeros
  __m256 acc = _mm256_setzero_ps();
@@ -5129,25 +5013,25 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  float summs = 0.0f;
 
  // Main loop
- for (int i = 0; i < nb; i++) {
- const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+ for (; ib < nb; ++ib) {
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d));
 
- summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s);
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
 
- __m256i qx = bytes_from_nibbles_32(x[i].qs);
- __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
  bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
  qx = _mm256_or_si256(qx, bxhi);
 
- const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[i].d));
- const __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
+ const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d));
+ const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
 
  const __m256 q = mul_sum_us8_pairs_float(qx, qy);
 
  acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
  }
 
- *s = hsum_float_8(acc) + summs;
+ sumf = hsum_float_8(acc) + summs;
  #elif defined(__AVX__)
  // Initialize accumulator with zeros
  __m256 acc = _mm256_setzero_ps();
@@ -5156,13 +5040,13 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  float summs = 0.0f;
 
  // Main loop
- for (int i = 0; i < nb; i++) {
- const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+ for (; ib < nb; ++ib) {
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d));
 
- summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s);
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
 
- __m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
- const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
+ const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
  __m128i bxhil = _mm256_castsi256_si128(bxhi);
  __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
  bxhil = _mm_and_si128(bxhil, mask);
@@ -5173,18 +5057,16 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  bxh = _mm_or_si128(bxh, bxhih);
  bx_0 = MM256_SET_M128I(bxh, bxl);
 
- const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[i].d));
- const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
+ const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d));
+ const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
 
  const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
 
  acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
  }
 
- *s = hsum_float_8(acc) + summs;
+ sumf = hsum_float_8(acc) + summs;
  #elif defined(__riscv_v_intrinsic)
- float sumf = 0.0;
-
  uint32_t qh;
 
  size_t vl = __riscv_vsetvl_e8m1(qk/2);
@@ -5193,8 +5075,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
  vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
 
- for (int i = 0; i < nb; i++) {
- memcpy(&qh, x[i].qh, sizeof(uint32_t));
+ for (; ib < nb; ++ib) {
+ memcpy(&qh, x[ib].qh, sizeof(uint32_t));
 
  // load qh
  vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
@@ -5216,10 +5098,10 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
 
  // load
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl);
+ vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
 
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl);
+ vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+ vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
 
  vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
  vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
@@ -5240,50 +5122,47 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
 
  int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
 
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s);
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
  }
 
- *s = sumf;
-
  #elif defined(__POWER9_VECTOR__)
  const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector signed int v0 = vec_splats((int32_t)0);
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
 
  vector float vsumf0 = vec_splats(0.0f);
 
  #pragma GCC unroll 4
- for (int i = 0; i < nb; ++i) {
- __builtin_prefetch(x[i].qs, 0, 1);
- __builtin_prefetch(y[i].qs, 0, 1);
+ for (; ib < nb; ++ib) {
+ __builtin_prefetch(x[ib].qs, 0, 1);
+ __builtin_prefetch(y[ib].qs, 0, 1);
 
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
- vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d));
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+ vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
  vector float vd = vec_mul(vxd, vyd);
 
- vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].m));
- vector float vys = {GGML_FP16_TO_FP32(y[i].s), 0.f, 0.f, 0.f};
+ vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m));
+ vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f};
  vsumf0 = vec_madd(vxmin, vys, vsumf0);
 
- vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[i].qh[0]]), (uint64_t)(table_b2b_0[x[i].qh[1]])};
- vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[i].qh[2]]), (uint64_t)(table_b2b_0[x[i].qh[3]])};
+ vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])};
+ vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])};
 
  vector signed char qh0 = (vector signed char)aux64x2_0;
  vector signed char qh1 = (vector signed char)aux64x2_1;
 
- vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs);
-
- vector signed char q5x0 = vec_or(vec_and(qxs, lowMask), qh0);
- vector signed char q5x1 = vec_or(vec_sr(qxs, v4), qh1);
+ vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
 
- vector signed char q8y0 = vec_xl( 0, y[i].qs);
- vector signed char q8y1 = vec_xl( 16, y[i].qs);
+ vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0);
+ vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1);
 
- vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0));
- vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1));
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+ vector signed char q8y1 = vec_xl( 16, y[ib].qs);
 
- qv0 = vec_add(qv0, qv1);
+ vector signed int vsumi0 = v0;
 
- vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0));
+ vsumi0 = vec_msum(q8y0, q5x0, vsumi0);
+ vsumi0 = vec_msum(q8y1, q5x1, vsumi0);
 
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
  }
@@ -5291,7 +5170,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
 
- *s = vec_extract(vsumf0, 0);
+ sumf = vec_extract(vsumf0, 0);
 
  #elif defined(__loongarch_asx)
  // Initialize accumulator with zeros
@@ -5300,51 +5179,49 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
  float summs = 0.0f;
 
  // Main loop
- for (int i = 0; i < nb; i++) {
- const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d));
+ for (; ib < nb; ++ib) {
+ const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d));
 
- summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s);
+ summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
 
- __m256i qx = bytes_from_nibbles_32(x[i].qs);
- __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
  bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10));
  qx = __lasx_xvor_v(qx, bxhi);
 
- const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[i].d));
- const __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0);
+ const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib].d));
+ const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
 
  const __m256 q = mul_sum_us8_pairs_float(qx, qy);
 
  acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc);
  }
 
- *s = hsum_float_8(acc) + summs;
-
- #else
- // scalar
- float sumf = 0.0;
-
- for (int i = 0; i < nb; i++) {
+ sumf = hsum_float_8(acc) + summs;
+ #endif
+ for (; ib < nb; ++ib) {
  uint32_t qh;
- memcpy(&qh, x[i].qh, sizeof(qh));
+ memcpy(&qh, x[ib].qh, sizeof(qh));
 
- int sumi = 0;
+ int sumi0 = 0;
+ int sumi1 = 0;
 
  for (int j = 0; j < qk/2; ++j) {
  const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
  const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
 
- const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
- const int32_t x1 = (x[i].qs[j] >> 4) | xh_1;
+ const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
+ const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
 
- sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
+ sumi0 += (x0 * y[ib].qs[j]);
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
  }
 
- sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s);
+ int sumi = sumi0 + sumi1;
+ sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
  }
 
  *s = sumf;
- #endif
  }
 
  void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
@@ -5421,42 +5298,44 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  return;
  }
  #endif
+
+ int ib = 0;
+ float sumf = 0;
+
  #if defined(__ARM_FEATURE_SVE)
- svfloat32_t sumv0 = svdup_n_f32(0.0f);
- svfloat32_t sumv1 = svdup_n_f32(0.0f);
+ if (svcntb() == QK8_0) {
+ svfloat32_t sumv0 = svdup_n_f32(0.0f);
+ svfloat32_t sumv1 = svdup_n_f32(0.0f);
 
- assert(nb % 2 == 0); // TODO: handle odd nb
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q8_0 * restrict x0 = &x[ib + 0];
+ const block_q8_0 * restrict x1 = &x[ib + 1];
+ const block_q8_0 * restrict y0 = &y[ib + 0];
+ const block_q8_0 * restrict y1 = &y[ib + 1];
 
- for (int i = 0; i < nb; i += 2) {
- const block_q8_0 * restrict x0 = &x[i + 0];
- const block_q8_0 * restrict x1 = &x[i + 1];
- const block_q8_0 * restrict y0 = &y[i + 0];
- const block_q8_0 * restrict y1 = &y[i + 1];
+ // load x
+ const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+ const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
 
- // load x
- const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
- const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+ // load y
+ const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+ const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
 
- // load y
- const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
- const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+ sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+ sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ }
 
- sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
- sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+ sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
  }
-
- *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
  #elif defined(__ARM_NEON)
  float32x4_t sumv0 = vdupq_n_f32(0.0f);
  float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
- assert(nb % 2 == 0); // TODO: handle odd nb
-
- for (int i = 0; i < nb; i += 2) {
- const block_q8_0 * restrict x0 = &x[i + 0];
- const block_q8_0 * restrict x1 = &x[i + 1];
- const block_q8_0 * restrict y0 = &y[i + 0];
- const block_q8_0 * restrict y1 = &y[i + 1];
+ for (; ib + 1 < nb; ib += 2) {
+ const block_q8_0 * restrict x0 = &x[ib + 0];
+ const block_q8_0 * restrict x1 = &x[ib + 1];
+ const block_q8_0 * restrict y0 = &y[ib + 0];
+ const block_q8_0 * restrict y1 = &y[ib + 1];
 
  const int8x16_t x0_0 = vld1q_s8(x0->qs);
  const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
@@ -5478,17 +5357,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
  }
 
- *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+ sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
  #elif defined(__AVX2__) || defined(__AVX__)
  // Initialize accumulator with zeros
  __m256 acc = _mm256_setzero_ps();
 
  // Main loop
- for (int i = 0; i < nb; ++i) {
+ for (; ib < nb; ++ib) {
  // Compute combined scale for the block
- const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
- __m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs);
- __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs);
+ const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+ __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs);
+ __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
 
  const __m256 q = mul_sum_i8_pairs_float(qx, qy);
 
@@ -5500,15 +5379,14 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  #endif
  }
 
- *s = hsum_float_8(acc);
+ sumf = hsum_float_8(acc);
  #elif defined(__riscv_v_intrinsic)
- float sumf = 0.0;
  size_t vl = __riscv_vsetvl_e8m1(qk);
 
- for (int i = 0; i < nb; i++) {
+ for (; ib < nb; ++ib) {
  // load elements
- vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl);
- vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
+ vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl);
+ vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
 
  vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
 
@@ -5517,40 +5395,38 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
 
  int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
 
- sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
  }
-
- *s = sumf;
-
  #elif defined(__POWER9_VECTOR__)
+ const vector signed int v0 = vec_splats((int32_t)0);
  vector float vsumf0 = vec_splats(0.0f);
 
- #pragma GCC unroll 4
- for (int i = 0; i < nb; i++) {
- __builtin_prefetch(x[i].qs, 0, 1);
- __builtin_prefetch(y[i].qs, 0, 1);
+ #pragma GCC unroll 8
+ for (; ib < nb; ++ib) {
+ __builtin_prefetch(x[ib].qs, 0, 1);
+ __builtin_prefetch(y[ib].qs, 0, 1);
 
- vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
- vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d));
+ vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+ vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
  vector float vd = vec_mul(vxd, vyd);
 
- vector signed char q8x0 = vec_xl( 0, x[i].qs);
- vector signed char q8x1 = vec_xl(16, x[i].qs);
- vector signed char q8y0 = vec_xl( 0, y[i].qs);
- vector signed char q8y1 = vec_xl(16, y[i].qs);
+ vector signed char q8x0 = vec_xl( 0, x[ib].qs);
+ vector signed char q8x1 = vec_xl(16, x[ib].qs);
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+ vector signed char q8y1 = vec_xl(16, y[ib].qs);
 
  vector signed short qv0 = vec_mule(q8x0, q8y0);
  vector signed short qv1 = vec_mulo(q8x0, q8y0);
  vector signed short qv2 = vec_mule(q8x1, q8y1);
  vector signed short qv3 = vec_mulo(q8x1, q8y1);
 
- vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackh(qv1));
- vector signed int vsumi1 = vec_add(vec_unpackl(qv0), vec_unpackl(qv1));
- vector signed int vsumi2 = vec_add(vec_unpackh(qv2), vec_unpackh(qv3));
- vector signed int vsumi3 = vec_add(vec_unpackl(qv2), vec_unpackl(qv3));
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
 
- vsumi0 = vec_add(vsumi0, vsumi2);
- vsumi1 = vec_add(vsumi1, vsumi3);
+ vsumi0 = vec_sum4s(qv0, vsumi0);
+ vsumi1 = vec_sum4s(qv1, vsumi1);
+ vsumi0 = vec_sum4s(qv2, vsumi0);
+ vsumi1 = vec_sum4s(qv3, vsumi1);
 
  vsumi0 = vec_add(vsumi0, vsumi1);
 
@@ -5560,18 +5436,18 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
 
- *s = vec_extract(vsumf0, 0);
+ sumf = vec_extract(vsumf0, 0);
 
  #elif defined(__loongarch_asx)
  // Initialize accumulator with zeros
  __m256 acc = (__m256)__lasx_xvldi(0);
 
  // Main loop
- for (int i = 0; i < nb; ++i) {
+ for (; ib < nb; ++ib) {
  // Compute combined scale for the block
- const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
- __m256i qx = __lasx_xvld((const __m256i *)x[i].qs, 0);
- __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0);
+ const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+ __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0);
+ __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
 
  const __m256 q = mul_sum_i8_pairs_float(qx, qy);
 
@@ -5579,24 +5455,19 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
  acc = __lasx_xvfmadd_s( d, q, acc );
  }
 
- *s = hsum_float_8(acc);
-
- #else
- // scalar
- float sumf = 0.0;
-
- for (int i = 0; i < nb; i++) {
+ sumf = hsum_float_8(acc);
+ #endif
+ for (; ib < nb; ++ib) {
  int sumi = 0;
 
  for (int j = 0; j < qk; j++) {
- sumi += x[i].qs[j]*y[i].qs[j];
+ sumi += x[ib].qs[j]*y[ib].qs[j];
  }
 
- sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
+ sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
  }
 
  *s = sumf;
- #endif
  }
 
  void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
@@ -5938,6 +5809,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  #elif defined(__POWER9_VECTOR__)
  const vector signed char lowMask = vec_splats((signed char)0x3);
  const vector signed char lowScaleMask = vec_splats((signed char)0xF);
+ const vector int v0 = vec_splats((int32_t)0);
  const vector unsigned char v2 = vec_splats((unsigned char)0x2);
  const vector unsigned char v6 = vec_splats((unsigned char)0x6);
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
@@ -5975,15 +5847,17 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
  vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
  vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);

- vector signed int vsumi0 = vec_splats((int32_t)0);
- vector signed int vsumi1 = vec_splats((int32_t)0);
- vector signed int vsumi2 = vec_splats((int32_t)0);
- vector signed int vsumi3 = vec_splats((int32_t)0);
- vector signed int vsumi4 = vec_splats((int32_t)0);
- vector signed int vsumi5 = vec_splats((int32_t)0);
- vector signed int vsumi6 = vec_splats((int32_t)0);
- vector signed int vsumi7 = vec_splats((int32_t)0);
+ vector signed int vsumi0 = v0;
+ vector signed int vsumi1 = v0;
+ vector signed int vsumi2 = v0;
+ vector signed int vsumi3 = v0;
+ vector signed int vsumi4 = v0;
+ vector signed int vsumi5 = v0;
+ vector signed int vsumi6 = v0;
+ vector signed int vsumi7 = v0;

+ const uint8_t * restrict q2 = x[i].qs;
+ const int8_t * restrict q8 = y[i].qs;

  for (int j = 0; j < QK_K/128; ++j) {
5989
5863
  __builtin_prefetch(q2, 0, 1);
@@ -5993,14 +5867,14 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
5993
5867
  vector signed char qxs1 = (vector signed char)vec_xl(16, q2);
5994
5868
  q2 += 32;
5995
5869
 
5996
- vector signed char q2x00 = vec_and(qxs0, lowMask);
5997
- vector signed char q2x01 = vec_and(vec_sr(qxs0, v2), lowMask);
5998
- vector signed char q2x02 = vec_and(vec_sr(qxs0, v4), lowMask);
5999
- vector signed char q2x03 = vec_and(vec_sr(qxs0, v6), lowMask);
6000
- vector signed char q2x10 = vec_and(qxs1, lowMask);
6001
- vector signed char q2x11 = vec_and(vec_sr(qxs1, v2), lowMask);
6002
- vector signed char q2x12 = vec_and(vec_sr(qxs1, v4), lowMask);
6003
- vector signed char q2x13 = vec_and(vec_sr(qxs1, v6), lowMask);
5870
+ vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask);
5871
+ vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask);
5872
+ vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask);
5873
+ vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask);
5874
+ vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask);
5875
+ vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask);
5876
+ vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask);
5877
+ vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask);
6004
5878
 
6005
5879
  vector signed char q8y00 = vec_xl( 0, q8);
6006
5880
  vector signed char q8y10 = vec_xl( 16, q8);
@@ -6012,45 +5886,36 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
6012
5886
  vector signed char q8y13 = vec_xl(112, q8);
6013
5887
  q8 += 128;
6014
5888
 
6015
- vector signed short qv0 = vec_add(vec_mule(q2x00, q8y00), vec_mulo(q2x00, q8y00));
6016
- vector signed short qv1 = vec_add(vec_mule(q2x01, q8y01), vec_mulo(q2x01, q8y01));
6017
- vector signed short qv2 = vec_add(vec_mule(q2x02, q8y02), vec_mulo(q2x02, q8y02));
6018
- vector signed short qv3 = vec_add(vec_mule(q2x03, q8y03), vec_mulo(q2x03, q8y03));
6019
- vector signed short qv4 = vec_add(vec_mule(q2x10, q8y10), vec_mulo(q2x10, q8y10));
6020
- vector signed short qv5 = vec_add(vec_mule(q2x11, q8y11), vec_mulo(q2x11, q8y11));
6021
- vector signed short qv6 = vec_add(vec_mule(q2x12, q8y12), vec_mulo(q2x12, q8y12));
6022
- vector signed short qv7 = vec_add(vec_mule(q2x13, q8y13), vec_mulo(q2x13, q8y13));
6023
-
6024
- vector signed short vscales_h = vec_unpackh(vscales);
6025
- vector signed short vs0 = vec_splat(vscales_h, 0);
6026
- vector signed short vs1 = vec_splat(vscales_h, 1);
6027
- vector signed short vs2 = vec_splat(vscales_h, 2);
6028
- vector signed short vs3 = vec_splat(vscales_h, 3);
6029
- vector signed short vs4 = vec_splat(vscales_h, 4);
6030
- vector signed short vs5 = vec_splat(vscales_h, 5);
6031
- vector signed short vs6 = vec_splat(vscales_h, 6);
6032
- vector signed short vs7 = vec_splat(vscales_h, 7);
5889
+ vector signed int qv0 = vec_msum(q8y00, q2x00, v0);
5890
+ vector signed int qv1 = vec_msum(q8y01, q2x01, v0);
5891
+ vector signed int qv2 = vec_msum(q8y02, q2x02, v0);
5892
+ vector signed int qv3 = vec_msum(q8y03, q2x03, v0);
5893
+ vector signed int qv4 = vec_msum(q8y10, q2x10, v0);
5894
+ vector signed int qv5 = vec_msum(q8y11, q2x11, v0);
5895
+ vector signed int qv6 = vec_msum(q8y12, q2x12, v0);
5896
+ vector signed int qv7 = vec_msum(q8y13, q2x13, v0);
5897
+
5898
+ vector signed short vscales_07 = vec_unpackh(vscales);
5899
+ vector signed int vscales_03 = vec_unpackh(vscales_07);
5900
+ vector signed int vscales_47 = vec_unpackl(vscales_07);
5901
+ vector signed int vs0 = vec_splat(vscales_03, 0);
5902
+ vector signed int vs1 = vec_splat(vscales_03, 1);
5903
+ vector signed int vs2 = vec_splat(vscales_03, 2);
5904
+ vector signed int vs3 = vec_splat(vscales_03, 3);
5905
+ vector signed int vs4 = vec_splat(vscales_47, 0);
5906
+ vector signed int vs5 = vec_splat(vscales_47, 1);
5907
+ vector signed int vs6 = vec_splat(vscales_47, 2);
5908
+ vector signed int vs7 = vec_splat(vscales_47, 3);
6033
5909
  vscales = vec_sld(vscales, vscales, 8);
6034
5910
 
6035
- qv0 = vec_mul(qv0, vs0);
6036
- qv1 = vec_mul(qv1, vs2);
6037
- qv2 = vec_mul(qv2, vs4);
6038
- qv3 = vec_mul(qv3, vs6);
6039
-
6040
- qv0 = vec_madd(qv4, vs1, qv0);
6041
- qv1 = vec_madd(qv5, vs3, qv1);
6042
- qv2 = vec_madd(qv6, vs5, qv2);
6043
- qv3 = vec_madd(qv7, vs7, qv3);
6044
-
6045
- vsumi0 = vec_add(vec_unpackh(qv0), vsumi0);
6046
- vsumi1 = vec_add(vec_unpackh(qv1), vsumi1);
6047
- vsumi2 = vec_add(vec_unpackh(qv2), vsumi2);
6048
- vsumi3 = vec_add(vec_unpackh(qv3), vsumi3);
6049
-
6050
- vsumi4 = vec_add(vec_unpackl(qv0), vsumi4);
6051
- vsumi5 = vec_add(vec_unpackl(qv1), vsumi5);
6052
- vsumi6 = vec_add(vec_unpackl(qv2), vsumi6);
6053
- vsumi7 = vec_add(vec_unpackl(qv3), vsumi7);
5911
+ vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0);
5912
+ vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1);
5913
+ vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2);
5914
+ vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3);
5915
+ vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4);
5916
+ vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5);
5917
+ vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6);
5918
+ vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7);
6054
5919
  }
6055
5920
 
6056
5921
  vsumi0 = vec_add(vsumi0, vsumi4);
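Note on the vec_msum rewrites in this and the surrounding POWER9 hunks, offered as a hedged reading of the diff: the removed code widened byte products with vec_mule/vec_mulo and added the 16-bit halves before accumulating, while vec_msum does the multiply and the four-way add in one step, summing four adjacent byte products into each 32-bit lane of the accumulator. A scalar model of that single step, with illustrative names and a 16-byte lane view (this is not the VSX API itself):

    #include <stdint.h>

    /* Scalar model of vec_msum(signed bytes, unsigned bytes, int accumulator):
     * each of the four 32-bit lanes gains the sum of its four byte products. */
    static void msum_model(const int8_t y8[16], const uint8_t x8[16], int32_t acc[4]) {
        for (int lane = 0; lane < 4; ++lane) {
            int32_t sum = acc[lane];
            for (int k = 0; k < 4; ++k) {
                sum += (int32_t)y8[4 * lane + k] * (int32_t)x8[4 * lane + k];
            }
            acc[lane] = sum;
        }
    }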
@@ -6088,6 +5953,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r

  const uint8_t * restrict q2 = x[i].qs;
  const int8_t * restrict q8 = y[i].qs;
+
  const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
  const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
  const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
@@ -6640,6 +6506,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
6640
6506
 
6641
6507
  #elif defined(__POWER9_VECTOR__)
6642
6508
  const vector signed char lowMask = vec_splats((signed char)0x3);
6509
+ const vector signed char lowMask1 = vec_splats((int8_t)0xf);
6510
+ const vector signed char lowMask2 = vec_splats((int8_t)0x30);
6511
+ const vector int v0 = vec_splats((int32_t)0);
6643
6512
  const vector signed char v1 = vec_splats((signed char)0x1);
6644
6513
  const vector unsigned char v2 = vec_splats((unsigned char)0x2);
6645
6514
  const vector unsigned char v3 = vec_splats((unsigned char)0x3);
@@ -6657,30 +6526,33 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
6657
6526
  vector float vyd = vec_splats(y[i].d);
6658
6527
  vector float vd = vec_mul(vxd, vyd);
6659
6528
 
6660
- uint32_t aux[3];
6661
- uint32_t utmp[4];
6529
+ UNUSED(kmask1);
6530
+ UNUSED(kmask2);
6662
6531
 
6663
- memcpy(aux, x[i].scales, 12);
6664
- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
6665
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
6666
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
6667
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
6532
+ vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
6533
+ vector signed char u1 = vec_and(u0, lowMask1);
6534
+ vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
6535
+ vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2));
6536
+ vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4);
6537
+ vector signed char u31 = vec_and(u3, lowMask2);
6538
+
6539
+ u1 = vec_or(u1, u30);
6540
+ u2 = vec_or(vec_sr(u0, v4), u31);
6668
6541
 
6669
- vector signed char vscales = (vector signed char)vec_xl( 0, utmp);
6542
+ vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2);
6670
6543
  vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask);
6671
6544
  vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask);
6672
6545
 
6673
6546
  vscales = vec_sub(vscales, off);
6674
6547
 
6675
- vector signed int vsumi0 = vec_splats((int32_t)0);
6676
- vector signed int vsumi1 = vec_splats((int32_t)0);
6677
- vector signed int vsumi2 = vec_splats((int32_t)0);
6678
- vector signed int vsumi3 = vec_splats((int32_t)0);
6679
- vector signed int vsumi4 = vec_splats((int32_t)0);
6680
- vector signed int vsumi5 = vec_splats((int32_t)0);
6681
- vector signed int vsumi6 = vec_splats((int32_t)0);
6682
- vector signed int vsumi7 = vec_splats((int32_t)0);
6683
-
6548
+ vector signed int vsumi0 = v0;
6549
+ vector signed int vsumi1 = v0;
6550
+ vector signed int vsumi2 = v0;
6551
+ vector signed int vsumi3 = v0;
6552
+ vector signed int vsumi4 = v0;
6553
+ vector signed int vsumi5 = v0;
6554
+ vector signed int vsumi6 = v0;
6555
+ vector signed int vsumi7 = v0;
6684
6556
 
6685
6557
  const uint8_t * restrict q3 = x[i].qs;
6686
6558
  const int8_t * restrict q8 = y[i].qs;
@@ -6754,23 +6626,14 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
6754
6626
  vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12));
6755
6627
  vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13));
6756
6628
 
6757
- vector signed int vsum0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0));
6758
- vector signed int vsum1 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2));
6759
- vector signed int vsum2 = vec_add(vec_mule(qv02, vs4), vec_mulo(qv02, vs4));
6760
- vector signed int vsum3 = vec_add(vec_mule(qv03, vs6), vec_mulo(qv03, vs6));
6761
- vector signed int vsum4 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1));
6762
- vector signed int vsum5 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3));
6763
- vector signed int vsum6 = vec_add(vec_mule(qv12, vs5), vec_mulo(qv12, vs5));
6764
- vector signed int vsum7 = vec_add(vec_mule(qv13, vs7), vec_mulo(qv13, vs7));
6765
-
6766
- vsumi0 = vec_add(vsum0, vsumi0);
6767
- vsumi1 = vec_add(vsum1, vsumi1);
6768
- vsumi2 = vec_add(vsum2, vsumi2);
6769
- vsumi3 = vec_add(vsum3, vsumi3);
6770
- vsumi4 = vec_add(vsum4, vsumi4);
6771
- vsumi5 = vec_add(vsum5, vsumi5);
6772
- vsumi6 = vec_add(vsum6, vsumi6);
6773
- vsumi7 = vec_add(vsum7, vsumi7);
6629
+ vsumi0 = vec_msum(qv00, vs0, vsumi0);
6630
+ vsumi1 = vec_msum(qv01, vs2, vsumi1);
6631
+ vsumi2 = vec_msum(qv02, vs4, vsumi2);
6632
+ vsumi3 = vec_msum(qv03, vs6, vsumi3);
6633
+ vsumi4 = vec_msum(qv10, vs1, vsumi4);
6634
+ vsumi5 = vec_msum(qv11, vs3, vsumi5);
6635
+ vsumi6 = vec_msum(qv12, vs5, vsumi6);
6636
+ vsumi7 = vec_msum(qv13, vs7, vsumi7);
6774
6637
  }
6775
6638
 
6776
6639
  vsumi0 = vec_add(vsumi0, vsumi4);
@@ -6807,6 +6670,8 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
6807
6670
  for (int i = 0; i < nb; ++i) {
6808
6671
 
6809
6672
  const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
6673
+ const uint8_t * restrict q3 = x[i].qs;
6674
+ const int8_t * restrict q8 = y[i].qs;
6810
6675
  // Set up scales
6811
6676
  memcpy(aux, x[i].scales, 12);
6812
6677
  __m128i scales128 = lsx_set_w(
@@ -6828,29 +6693,32 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
6828
6693
 
6829
6694
  int bit = 0;
6830
6695
  int is = 0;
6696
+ __m256i xvbit;
6831
6697
 
6832
- const uint8_t * restrict q3 = x[i].qs;
6833
- const int8_t * restrict q8 = y[i].qs;
6834
6698
 
6835
6699
  for (int j = 0; j < QK_K/128; ++j) {
6836
6700
  // load low 2 bits
6837
6701
  const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
6838
6702
 
6703
+ xvbit = __lasx_xvreplgr2vr_h(bit);
6839
6704
  // prepare low and high bits
6840
6705
  const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
6841
- const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
6706
+ const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
6842
6707
  ++bit;
6843
6708
 
6709
+ xvbit = __lasx_xvreplgr2vr_h(bit);
6844
6710
  const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
6845
- const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
6711
+ const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
6846
6712
  ++bit;
6847
6713
 
6714
+ xvbit = __lasx_xvreplgr2vr_h(bit);
6848
6715
  const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
6849
- const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
6716
+ const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
6850
6717
  ++bit;
6851
6718
 
6719
+ xvbit = __lasx_xvreplgr2vr_h(bit);
6852
6720
  const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
6853
- const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
6721
+ const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
6854
6722
  ++bit;
6855
6723
 
6856
6724
  // load Q8 quants
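Note on the xvbit changes in the LoongArch hunk above, hedged as a reading of the diff: the immediate-count shifts __lasx_xvslli_h/__lasx_xvsrli_h were being fed the loop variable `bit`, so the new code broadcasts the count into a vector with __lasx_xvreplgr2vr_h and uses the register-count shifts __lasx_xvsll_h/__lasx_xvsrl_h; the value extracted per lane is meant to be the same. A scalar model of that per-lane expression, assuming mone contributes a single 1 bit in the lane and that the andn step keeps (1 << bit) only where hbits has that bit clear, mirroring the equivalent AVX2 code (helper name is illustrative):

    #include <stdint.h>

    /* Per-lane model of xvslli_h(xvsrl_h(xvandn_v(hbits, xvsll_h(mone, xvbit)), xvbit), 2):
     * select bit `bit` where hbits has it clear, shift it down, scale to 0 or 4. */
    static uint16_t q3_high_correction(uint16_t hbits, unsigned bit) {
        uint16_t selected = (uint16_t)(~hbits & (1u << bit));
        return (uint16_t)(((selected >> bit) & 1u) << 2);   /* -> 0 or 4 */
    }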
@@ -7264,6 +7132,10 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r

  #elif defined(__POWER9_VECTOR__)
  const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector signed char lowMask1 = vec_splats((int8_t)0x3f);
+ const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+ const vector int v0 = vec_splats((int32_t)0);
+ const vector unsigned char v2 = vec_splats((uint8_t)2);
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);

  vector float vsumf0 = vec_splats(0.0f);
@@ -7282,15 +7154,24 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7282
7154
  vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
7283
7155
  vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
7284
7156
 
7285
- memcpy(utmp, x[i].scales, 12);
7157
+ UNUSED(kmask1);
7158
+ UNUSED(kmask2);
7159
+ UNUSED(kmask3);
7160
+ UNUSED(utmp);
7286
7161
 
7287
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
7288
- const uint32_t uaux = utmp[1] & kmask1;
7289
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
7290
- utmp[2] = uaux;
7291
- utmp[0] &= kmask1;
7162
+ vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
7163
+ vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);
7164
+ vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
7165
+ vector signed char u3 = vec_sr(u2, v4);
7166
+
7167
+ vector signed char u30 = u1;
7168
+ vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);
7169
+
7170
+ u1 = vec_and(u0, lowMask1);
7171
+ u2 = vec_or(u30, u31);
7172
+
7173
+ vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);
7292
7174
 
7293
- vector signed char utmps = (vector signed char)vec_xl( 0, utmp);
7294
7175
  vector signed short vscales = vec_unpackh(utmps);
7295
7176
  vector signed short q4xmins = vec_unpackl(utmps);
7296
7177
  vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins);
@@ -7306,14 +7187,10 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7306
7187
  vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
7307
7188
  vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
7308
7189
 
7309
- vector signed int vsumi0 = vec_splats((int32_t)0);
7310
- vector signed int vsumi1 = vec_splats((int32_t)0);
7311
- vector signed int vsumi2 = vec_splats((int32_t)0);
7312
- vector signed int vsumi3 = vec_splats((int32_t)0);
7313
- vector signed int vsumi4 = vec_splats((int32_t)0);
7314
- vector signed int vsumi5 = vec_splats((int32_t)0);
7315
- vector signed int vsumi6 = vec_splats((int32_t)0);
7316
- vector signed int vsumi7 = vec_splats((int32_t)0);
7190
+ vector signed int vsumi0 = v0;
7191
+ vector signed int vsumi1 = v0;
7192
+ vector signed int vsumi2 = v0;
7193
+ vector signed int vsumi3 = v0;
7317
7194
 
7318
7195
  const uint8_t * restrict q4 = x[i].qs;
7319
7196
  const int8_t * restrict q8 = y[i].qs;
@@ -7328,14 +7205,14 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7328
7205
  vector signed char qxs3 = (vector signed char)vec_xl(48, q4);
7329
7206
  q4 += 64;
7330
7207
 
7331
- vector signed char q4x00 = vec_and(qxs0, lowMask);
7332
- vector signed char q4x01 = vec_sr(qxs0, v4);
7333
- vector signed char q4x10 = vec_and(qxs1, lowMask);
7334
- vector signed char q4x11 = vec_sr(qxs1, v4);
7335
- vector signed char q4x20 = vec_and(qxs2, lowMask);
7336
- vector signed char q4x21 = vec_sr(qxs2, v4);
7337
- vector signed char q4x30 = vec_and(qxs3, lowMask);
7338
- vector signed char q4x31 = vec_sr(qxs3, v4);
7208
+ vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask);
7209
+ vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4);
7210
+ vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask);
7211
+ vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4);
7212
+ vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask);
7213
+ vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4);
7214
+ vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask);
7215
+ vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4);
7339
7216
 
7340
7217
  vector signed char q8y00 = vec_xl( 0, q8);
7341
7218
  vector signed char q8y10 = vec_xl( 16, q8);
@@ -7347,41 +7224,33 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7347
7224
  vector signed char q8y31 = vec_xl(112, q8);
7348
7225
  q8 += 128;
7349
7226
 
7350
- vector signed short qv00 = vec_add(vec_mule(q4x00, q8y00), vec_mulo(q4x00, q8y00));
7351
- vector signed short qv01 = vec_add(vec_mule(q4x01, q8y01), vec_mulo(q4x01, q8y01));
7352
- vector signed short qv10 = vec_add(vec_mule(q4x10, q8y10), vec_mulo(q4x10, q8y10));
7353
- vector signed short qv11 = vec_add(vec_mule(q4x11, q8y11), vec_mulo(q4x11, q8y11));
7354
- vector signed short qv20 = vec_add(vec_mule(q4x20, q8y20), vec_mulo(q4x20, q8y20));
7355
- vector signed short qv21 = vec_add(vec_mule(q4x21, q8y21), vec_mulo(q4x21, q8y21));
7356
- vector signed short qv30 = vec_add(vec_mule(q4x30, q8y30), vec_mulo(q4x30, q8y30));
7357
- vector signed short qv31 = vec_add(vec_mule(q4x31, q8y31), vec_mulo(q4x31, q8y31));
7358
-
7359
- vector signed short vs0 = vec_splat(vscales, 0);
7360
- vector signed short vs1 = vec_splat(vscales, 1);
7361
- vector signed short vs2 = vec_splat(vscales, 2);
7362
- vector signed short vs3 = vec_splat(vscales, 3);
7227
+ vector signed int qv00 = vec_msum(q8y00, q4x00, v0);
7228
+ vector signed int qv01 = vec_msum(q8y01, q4x01, v0);
7229
+ vector signed int qv10 = vec_msum(q8y10, q4x10, v0);
7230
+ vector signed int qv11 = vec_msum(q8y11, q4x11, v0);
7231
+ vector signed int qv20 = vec_msum(q8y20, q4x20, v0);
7232
+ vector signed int qv21 = vec_msum(q8y21, q4x21, v0);
7233
+ vector signed int qv30 = vec_msum(q8y30, q4x30, v0);
7234
+ vector signed int qv31 = vec_msum(q8y31, q4x31, v0);
7235
+
7236
+ vector signed int vscales_h = vec_unpackh(vscales);
7237
+ vector signed int vs0 = vec_splat(vscales_h, 0);
7238
+ vector signed int vs1 = vec_splat(vscales_h, 1);
7239
+ vector signed int vs2 = vec_splat(vscales_h, 2);
7240
+ vector signed int vs3 = vec_splat(vscales_h, 3);
7363
7241
  vscales = vec_sld(vscales, vscales, 8);
7364
7242
 
7365
- qv00 = vec_add(qv00, qv10);
7366
- qv10 = vec_add(qv01, qv11);
7367
- qv20 = vec_add(qv20, qv30);
7368
- qv30 = vec_add(qv21, qv31);
7243
+ vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);
7244
+ vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1);
7245
+ vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2);
7246
+ vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3);
7369
7247
 
7370
- vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0);
7371
- vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1);
7372
- vsumi2 = vec_add(vec_mule(qv10, vs1), vsumi2);
7373
- vsumi3 = vec_add(vec_mulo(qv10, vs1), vsumi3);
7374
- vsumi4 = vec_add(vec_mule(qv20, vs2), vsumi4);
7375
- vsumi5 = vec_add(vec_mulo(qv20, vs2), vsumi5);
7376
- vsumi6 = vec_add(vec_mule(qv30, vs3), vsumi6);
7377
- vsumi7 = vec_add(vec_mulo(qv30, vs3), vsumi7);
7248
+ vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0);
7249
+ vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1);
7250
+ vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2);
7251
+ vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3);
7378
7252
  }
7379
7253
 
7380
- vsumi0 = vec_add(vsumi0, vsumi4);
7381
- vsumi1 = vec_add(vsumi1, vsumi5);
7382
- vsumi2 = vec_add(vsumi2, vsumi6);
7383
- vsumi3 = vec_add(vsumi3, vsumi7);
7384
-
7385
7254
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
7386
7255
  vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
7387
7256
  vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
@@ -7399,6 +7268,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7399
7268
  *s = vec_extract(vsumf0, 0);
7400
7269
 
7401
7270
  #elif defined __loongarch_asx
7271
+ GGML_UNUSED(kmask1);
7272
+ GGML_UNUSED(kmask2);
7273
+ GGML_UNUSED(kmask3);
7402
7274
 
7403
7275
  const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
7404
7276
 
@@ -7411,6 +7283,11 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7411
7283
  const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
7412
7284
 
7413
7285
  memcpy(utmp, x[i].scales, 12);
7286
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
7287
+ const uint32_t uaux = utmp[1] & kmask1;
7288
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
7289
+ utmp[2] = uaux;
7290
+ utmp[0] &= kmask1;
7414
7291
 
7415
7292
  const uint8_t * restrict q4 = x[i].qs;
7416
7293
  const int8_t * restrict q8 = y[i].qs;
@@ -7450,16 +7327,17 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7450
7327
 
7451
7328
  __m256 vd = __lasx_xvreplfr2vr_s(d);
7452
7329
  acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
7330
+
7453
7331
  }
7454
7332
 
7455
7333
  acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
7456
7334
  __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
7457
7335
  acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
7458
7336
 
7337
+
7459
7338
  ft_union fi;
7460
7339
  fi.i = __lsx_vpickve2gr_w(acc_m, 0);
7461
7340
  *s = hsum_float_8(acc) + fi.f ;
7462
-
7463
7341
  #else
7464
7342
 
7465
7343
  const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -7874,6 +7752,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r

  #elif defined(__POWER9_VECTOR__)
  const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector signed char lowMask1 = vec_splats((int8_t)0x3f);
+ const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+ const vector int v0 = vec_splats((int32_t)0);
  const vector unsigned char v1 = vec_splats((unsigned char)0x1);
  const vector unsigned char v2 = vec_splats((unsigned char)0x2);
  const vector unsigned char v3 = vec_splats((unsigned char)0x3);
@@ -7892,18 +7773,27 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7892
7773
  vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
7893
7774
  vector float vdmin = vec_mul(vxmin, vyd);
7894
7775
 
7895
- memcpy(utmp, x[i].scales, 12);
7776
+ UNUSED(kmask1);
7777
+ UNUSED(kmask2);
7778
+ UNUSED(kmask3);
7779
+ UNUSED(utmp);
7896
7780
 
7897
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
7898
- const uint32_t uaux = utmp[1] & kmask1;
7899
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
7900
- utmp[2] = uaux;
7901
- utmp[0] &= kmask1;
7781
+ vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
7782
+ vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);
7783
+ vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
7784
+ vector signed char u3 = vec_sr(u2, v4);
7785
+
7786
+ vector signed char u30 = u1;
7787
+ vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);
7788
+
7789
+ u1 = vec_and(u0, lowMask1);
7790
+ u2 = vec_or(u30, u31);
7791
+
7792
+ vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);
7902
7793
 
7903
7794
  vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
7904
7795
  vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
7905
7796
 
7906
- vector signed char utmps = (vector signed char)vec_xl( 0, utmp);
7907
7797
  vector signed short vscales = vec_unpackh(utmps);
7908
7798
 
7909
7799
  vector signed short q5xmins = vec_unpackl(utmps);
@@ -7923,10 +7813,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7923
7813
  vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh);
7924
7814
  vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh);
7925
7815
 
7926
- vector signed int vsumi0 = vec_splats((int32_t)0);
7927
- vector signed int vsumi1 = vec_splats((int32_t)0);
7928
- vector signed int vsumi2 = vec_splats((int32_t)0);
7929
- vector signed int vsumi3 = vec_splats((int32_t)0);
7816
+ vector signed int vsumi0 = v0;
7817
+ vector signed int vsumi1 = v0;
7818
+ vector signed int vsumi2 = v0;
7819
+ vector signed int vsumi3 = v0;
7930
7820
 
7931
7821
  const uint8_t * restrict q5 = x[i].qs;
7932
7822
  const int8_t * restrict q8 = y[i].qs;
@@ -7951,10 +7841,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7951
7841
  qxhs0 = vec_sr(qxhs0, v2);
7952
7842
  qxhs1 = vec_sr(qxhs1, v2);
7953
7843
 
7954
- vector signed char q5x00 = vec_or(q5h00, qxs00);
7955
- vector signed char q5x01 = vec_or(q5h01, qxs01);
7956
- vector signed char q5x10 = vec_or(q5h10, qxs10);
7957
- vector signed char q5x11 = vec_or(q5h11, qxs11);
7844
+ vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00);
7845
+ vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01);
7846
+ vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10);
7847
+ vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11);
7958
7848
 
7959
7849
  vector signed char q8y00 = vec_xl( 0, q8);
7960
7850
  vector signed char q8y10 = vec_xl(16, q8);
@@ -7962,22 +7852,20 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7962
7852
  vector signed char q8y11 = vec_xl(48, q8);
7963
7853
  q8 += 64;
7964
7854
 
7965
- vector signed short qv00 = vec_add(vec_mule(q5x00, q8y00), vec_mulo(q5x00, q8y00));
7966
- vector signed short qv01 = vec_add(vec_mule(q5x01, q8y01), vec_mulo(q5x01, q8y01));
7967
- vector signed short qv10 = vec_add(vec_mule(q5x10, q8y10), vec_mulo(q5x10, q8y10));
7968
- vector signed short qv11 = vec_add(vec_mule(q5x11, q8y11), vec_mulo(q5x11, q8y11));
7855
+ vector signed int qv00 = vec_msum(q8y00, q5x00, v0);
7856
+ vector signed int qv01 = vec_msum(q8y01, q5x01, v0);
7857
+ vector signed int qv10 = vec_msum(q8y10, q5x10, v0);
7858
+ vector signed int qv11 = vec_msum(q8y11, q5x11, v0);
7969
7859
 
7970
- vector signed short vs0 = vec_splat(vscales, 0);
7971
- vector signed short vs1 = vec_splat(vscales, 1);
7860
+ vector signed int vscales_h = vec_unpackh(vscales);
7861
+ vector signed int vs0 = vec_splat(vscales_h, 0);
7862
+ vector signed int vs1 = vec_splat(vscales_h, 1);
7972
7863
  vscales = vec_sld(vscales, vscales, 12);
7973
7864
 
7974
- qv00 = vec_add(qv00, qv10);
7975
- qv01 = vec_add(qv01, qv11);
7976
-
7977
- vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0);
7978
- vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1);
7979
- vsumi2 = vec_add(vec_mule(qv01, vs1), vsumi2);
7980
- vsumi3 = vec_add(vec_mulo(qv01, vs1), vsumi3);
7865
+ vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);
7866
+ vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1);
7867
+ vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2);
7868
+ vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3);
7981
7869
  }
7982
7870
 
7983
7871
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
@@ -7997,6 +7885,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
7997
7885
  *s = vec_extract(vsumf0, 0);
7998
7886
 
7999
7887
  #elif defined __loongarch_asx
7888
+ GGML_UNUSED(kmask1);
7889
+ GGML_UNUSED(kmask2);
7890
+ GGML_UNUSED(kmask3);
8000
7891
 
8001
7892
  const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
8002
7893
  const __m128i mzero = __lsx_vldi(0);
@@ -8015,6 +7906,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
8015
7906
  const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
8016
7907
 
8017
7908
  memcpy(utmp, x[i].scales, 12);
7909
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
7910
+ const uint32_t uaux = utmp[1] & kmask1;
7911
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
7912
+ utmp[2] = uaux;
7913
+ utmp[0] &= kmask1;
8018
7914
 
8019
7915
  const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
8020
7916
 
@@ -8033,6 +7929,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
8033
7929
  __m256i sumi = __lasx_xvldi(0);
8034
7930
 
8035
7931
  int bit = 0;
7932
+ __m256i xvbit;
8036
7933
 
8037
7934
  for (int j = 0; j < QK_K/64; ++j) {
8038
7935
 
@@ -8041,13 +7938,15 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
8041
7938
 
8042
7939
  const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
8043
7940
 
7941
+ xvbit = __lasx_xvreplgr2vr_h(bit++);
8044
7942
  const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
8045
- const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
7943
+ const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
8046
7944
  const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0);
8047
7945
  hmask = __lasx_xvslli_h(hmask, 1);
8048
7946
 
7947
+ xvbit = __lasx_xvreplgr2vr_h(bit++);
8049
7948
  const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
8050
- const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
7949
+ const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
8051
7950
  const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1);
8052
7951
  hmask = __lasx_xvslli_h(hmask, 1);
8053
7952
 
@@ -8061,10 +7960,12 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
8061
7960
  p16_1 = lasx_madd_h(scale_1, p16_1);
8062
7961
 
8063
7962
  sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
7963
+
8064
7964
  }
8065
7965
 
8066
7966
  __m256 vd = __lasx_xvreplfr2vr_s(d);
8067
7967
  acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
7968
+
8068
7969
  }
8069
7970
 
8070
7971
  *s = hsum_float_8(acc) + summs;
@@ -8525,6 +8426,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r

  #elif defined(__POWER9_VECTOR__)
  const vector signed char lowMask = vec_splats((signed char)0xF);
+ const vector int v0 = vec_splats((int32_t)0);
  const vector unsigned char v2 = vec_splats((unsigned char)0x2);
  const vector unsigned char v3 = vec_splats((unsigned char)0x3);
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
@@ -8541,14 +8443,14 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
8541
8443
  vector float vyd = vec_splats(y[i].d);
8542
8444
  vector float vd = vec_mul(vxd, vyd);
8543
8445
 
8544
- vector signed int vsumi0 = vec_splats((int32_t)0);
8545
- vector signed int vsumi1 = vec_splats((int32_t)0);
8546
- vector signed int vsumi2 = vec_splats((int32_t)0);
8547
- vector signed int vsumi3 = vec_splats((int32_t)0);
8548
- vector signed int vsumi4 = vec_splats((int32_t)0);
8549
- vector signed int vsumi5 = vec_splats((int32_t)0);
8550
- vector signed int vsumi6 = vec_splats((int32_t)0);
8551
- vector signed int vsumi7 = vec_splats((int32_t)0);
8446
+ vector signed int vsumi0 = v0;
8447
+ vector signed int vsumi1 = v0;
8448
+ vector signed int vsumi2 = v0;
8449
+ vector signed int vsumi3 = v0;
8450
+ vector signed int vsumi4 = v0;
8451
+ vector signed int vsumi5 = v0;
8452
+ vector signed int vsumi6 = v0;
8453
+ vector signed int vsumi7 = v0;
8552
8454
 
8553
8455
  const uint8_t * restrict q6 = x[i].ql;
8554
8456
  const uint8_t * restrict qh = x[i].qh;
@@ -8628,23 +8530,14 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
8628
8530
  vector signed short vs6 = vec_splat(vscales, 6);
8629
8531
  vector signed short vs7 = vec_splat(vscales, 7);
8630
8532
 
8631
- vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0);
8632
- vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1);
8633
- vsumi2 = vec_add(vec_mule(qv01, vs4), vsumi2);
8634
- vsumi3 = vec_add(vec_mulo(qv01, vs4), vsumi3);
8635
- vsumi4 = vec_add(vec_mule(qv10, vs1), vsumi4);
8636
- vsumi5 = vec_add(vec_mulo(qv10, vs1), vsumi5);
8637
- vsumi6 = vec_add(vec_mule(qv11, vs5), vsumi6);
8638
- vsumi7 = vec_add(vec_mulo(qv11, vs5), vsumi7);
8639
-
8640
- vsumi0 = vec_add(vec_mule(qv20, vs2), vsumi0);
8641
- vsumi1 = vec_add(vec_mulo(qv20, vs2), vsumi1);
8642
- vsumi2 = vec_add(vec_mule(qv21, vs6), vsumi2);
8643
- vsumi3 = vec_add(vec_mulo(qv21, vs6), vsumi3);
8644
- vsumi4 = vec_add(vec_mule(qv30, vs3), vsumi4);
8645
- vsumi5 = vec_add(vec_mulo(qv30, vs3), vsumi5);
8646
- vsumi6 = vec_add(vec_mule(qv31, vs7), vsumi6);
8647
- vsumi7 = vec_add(vec_mulo(qv31, vs7), vsumi7);
8533
+ vsumi0 = vec_msum(qv00, vs0, vsumi0);
8534
+ vsumi1 = vec_msum(qv01, vs4, vsumi1);
8535
+ vsumi2 = vec_msum(qv10, vs1, vsumi2);
8536
+ vsumi3 = vec_msum(qv11, vs5, vsumi3);
8537
+ vsumi4 = vec_msum(qv20, vs2, vsumi4);
8538
+ vsumi5 = vec_msum(qv21, vs6, vsumi5);
8539
+ vsumi6 = vec_msum(qv30, vs3, vsumi6);
8540
+ vsumi7 = vec_msum(qv31, vs7, vsumi7);
8648
8541
  }
8649
8542
 
8650
8543
  vsumi0 = vec_add(vsumi0, vsumi4);
@@ -8791,7 +8684,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
8791
8684
  #endif
8792
8685
  }
8793
8686
 
8794
- #if defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
8687
+ #if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
8795
8688
  static const int8_t keven_signs_q2xs[1024] = {
8796
8689
  1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
8797
8690
  1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
@@ -8924,7 +8817,63 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
8924
8817
 
8925
8818
  *s = 0.125f * hsum_float_8(accumf);
8926
8819
 
8820
+ #elif defined(__AVX__)
8821
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
8822
+
8823
+ uint32_t aux32[4];
8824
+ const uint8_t * aux8 = (const uint8_t *)aux32;
8825
+
8826
+ __m256 accumf = _mm256_setzero_ps();
8827
+ for (int i = 0; i < nb; ++i) {
8828
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8829
+ const uint16_t * restrict q2 = x[i].qs;
8830
+ const int8_t * restrict q8 = y[i].qs;
8831
+ __m128i sumi1_0 = _mm_setzero_si128();
8832
+ __m128i sumi1_1 = _mm_setzero_si128();
8833
+ __m128i sumi2_0 = _mm_setzero_si128();
8834
+ __m128i sumi2_1 = _mm_setzero_si128();
8835
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
8836
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
8837
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
8838
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
8839
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
8840
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
8841
+ const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
8842
+ const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
8843
+ const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
8844
+ const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
8845
+ const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
8846
+ const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
8847
+ const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
8848
+ const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
8849
+ const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
8850
+ const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
8851
+ const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
8852
+ const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
8853
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
8854
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
8855
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
8856
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
8857
+ const uint16_t ls1 = aux32[1] >> 28;
8858
+ const uint16_t ls2 = aux32[3] >> 28;
8859
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
8860
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
8861
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
8862
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
8863
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
8864
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
8865
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
8866
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
8867
+ }
8868
+
8869
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
8870
+
8871
+ }
8872
+
8873
+ *s = 0.125f * hsum_float_8(accumf);
8874
+
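Note on the new __AVX__ branch above (and the similar ones added below for iq2_xs and iq2_s), hedged as a reading of the diff: it mirrors the AVX2 path by holding every 256-bit value as a low/high __m128i pair, doing the sign flips and the maddubs/madd steps per half, and only joining the halves when the integer sums are scaled by d and folded into the float accumulator. A plain-C stand-in for that join-at-the-end shape, with arrays standing in for the __m128i halves and illustrative names:

    #include <stdint.h>

    /* The low/high accumulators (sumi*_0 / sumi*_1 in the diff) stay separate
     * until the block is done, then are combined and scaled once. Returning a
     * single float is a simplification of the 8-wide AVX accumulator. */
    static float block_result(const int32_t sum_lo[4], const int32_t sum_hi[4], float d) {
        int32_t total = 0;
        for (int k = 0; k < 4; ++k) {
            total += sum_lo[k] + sum_hi[k];   /* join the halves at the end */
        }
        return d * (float)total;
    }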
8927
8875
  #elif defined(__POWER9_VECTOR__)
8876
+ const vector int v0 = vec_splats((int32_t)0);
8928
8877
  vector float vsumf0 = vec_splats(0.0f);
8929
8878
  vector float vsumf1 = vec_splats(0.0f);
8930
8879
  vector float vsumf2 = vec_splats(0.0f);
@@ -8937,14 +8886,10 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
8937
8886
  vector float vyd = vec_splats(y[i].d);
8938
8887
  vector float vd = vec_mul(vxd, vyd);
8939
8888
 
8940
- vector signed int vsumi0 = vec_splats((int32_t)0);
8941
- vector signed int vsumi1 = vec_splats((int32_t)0);
8942
- vector signed int vsumi2 = vec_splats((int32_t)0);
8943
- vector signed int vsumi3 = vec_splats((int32_t)0);
8944
- vector signed int vsumi4 = vec_splats((int32_t)0);
8945
- vector signed int vsumi5 = vec_splats((int32_t)0);
8946
- vector signed int vsumi6 = vec_splats((int32_t)0);
8947
- vector signed int vsumi7 = vec_splats((int32_t)0);
8889
+ vector signed int vsumi0 = v0;
8890
+ vector signed int vsumi1 = v0;
8891
+ vector signed int vsumi2 = v0;
8892
+ vector signed int vsumi3 = v0;
8948
8893
 
8949
8894
  const uint16_t * restrict q2 = x[i].qs;
8950
8895
  const int8_t * restrict q8 = y[i].qs;
@@ -8991,21 +8936,12 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
8991
8936
  vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1));
8992
8937
  vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1));
8993
8938
 
8994
- vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0);
8995
- vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1);
8996
- vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2);
8997
- vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3);
8998
- vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4);
8999
- vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5);
9000
- vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6);
9001
- vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7);
8939
+ vsumi0 = vec_msum(qv0, vscales01, vsumi0);
8940
+ vsumi1 = vec_msum(qv1, vscales01, vsumi1);
8941
+ vsumi2 = vec_msum(qv2, vscales23, vsumi2);
8942
+ vsumi3 = vec_msum(qv3, vscales23, vsumi3);
9002
8943
  }
9003
8944
 
9004
- vsumi0 = vec_add(vsumi0, vsumi4);
9005
- vsumi1 = vec_add(vsumi1, vsumi5);
9006
- vsumi2 = vec_add(vsumi2, vsumi6);
9007
- vsumi3 = vec_add(vsumi3, vsumi7);
9008
-
9009
8945
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
9010
8946
  vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
9011
8947
  vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
@@ -9279,6 +9215,165 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
9279
9215
  }
9280
9216
 
9281
9217
  *s = 0.125f * hsum_float_8(accumf);
9218
+
9219
+ #elif defined(__AVX__)
9220
+ const __m128i mone = _mm_set1_epi8(1);
9221
+ static const char block_sign_shuffle_mask_1[32] = {
9222
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
9223
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
9224
+ };
9225
+ static const char block_sign_shuffle_mask_2[32] = {
9226
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
9227
+ 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
9228
+ };
9229
+ static const uint8_t bit_selector_mask_bytes[32] = {
9230
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
9231
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
9232
+ };
9233
+
9234
+ const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
9235
+ const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
9236
+ const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
9237
+ const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
9238
+ const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
9239
+ const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
9240
+
9241
+ static const uint8_t k_bit_helper[32] = {
9242
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
9243
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
9244
+ };
9245
+ const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
9246
+ const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
9247
+ const __m128i m511 = _mm_set1_epi16(511);
9248
+ const __m128i m4 = _mm_set1_epi8(0xf);
9249
+ const __m128i m1 = _mm_set1_epi8(1);
9250
+
9251
+ uint64_t aux64;
9252
+
9253
+ // somewhat hacky, but gives a significant boost in performance
9254
+ __m256i aux_gindex;
9255
+ const uint16_t * gindex = (const uint16_t *)&aux_gindex;
9256
+
9257
+ __m256 accumf = _mm256_setzero_ps();
9258
+ for (int i = 0; i < nb; ++i) {
9259
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
9260
+ const uint16_t * restrict q2 = x[i].qs;
9261
+ const int8_t * restrict q8 = y[i].qs;
9262
+
9263
+ memcpy(&aux64, x[i].scales, 8);
9264
+ __m128i stmp = _mm_set1_epi64x(aux64);
9265
+ stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
9266
+ const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
9267
+
9268
+ __m128i sumi1_0 = _mm_setzero_si128();
9269
+ __m128i sumi1_1 = _mm_setzero_si128();
9270
+ __m128i sumi2_0 = _mm_setzero_si128();
9271
+ __m128i sumi2_1 = _mm_setzero_si128();
9272
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
9273
+
9274
+ const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
9275
+ const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16;
9276
+ aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
9277
+
9278
+ const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
9279
+ const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
9280
+ const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
9281
+ const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
9282
+ const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
9283
+ const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
9284
+
9285
+ const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
9286
+ const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
9287
+ const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
9288
+ const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
9289
+
9290
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9291
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9292
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9293
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9294
+ const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9295
+ const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9296
+ const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9297
+ const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9298
+
9299
+ const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
9300
+ const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
9301
+ const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
9302
+ const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
9303
+ const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
9304
+ const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
9305
+ const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
9306
+ const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
9307
+
9308
+ // AVX2 full_signs_1 is full_sign_bits_0 here
9309
+ // AVX2 full_signs_2 is full_sign_bits_1 here
9310
+ __m128i signs_0, signs_1;
9311
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
9312
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
9313
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
9314
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
9315
+ const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
9316
+ const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
9317
+
9318
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
9319
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
9320
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
9321
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
9322
+ const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
9323
+ const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
9324
+
9325
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
9326
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
9327
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
9328
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
9329
+ const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
9330
+ const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
9331
+
9332
+ signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
9333
+ signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
9334
+ signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
9335
+ signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
9336
+ const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
9337
+ const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
9338
+
9339
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
9340
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
9341
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
9342
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
9343
+ const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
9344
+ const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
9345
+ const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
9346
+ const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
9347
+
9348
+ __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
9349
+ const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
9350
+ const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
9351
+ sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
9352
+ const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
9353
+ const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
9354
+ sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
9355
+ const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
9356
+ const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
9357
+ sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
9358
+ const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
9359
+ const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
9360
+
9361
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
9362
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
9363
+ sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
9364
+ sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
9365
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
9366
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
9367
+ sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
9368
+ sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
9369
+ }
9370
+
9371
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
9372
+
9373
+ }
9374
+
9375
+ *s = 0.125f * hsum_float_8(accumf);
9376
+
9282
9377
  #elif defined(__loongarch_asx)
9283
9378
 
9284
9379
  const __m256i mone = __lasx_xvreplgr2vr_b(1);
@@ -9397,6 +9492,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
9397
9492
 
9398
9493
  *s = 0.125f * hsum_float_8(accumf);
9399
9494
  #elif defined(__POWER9_VECTOR__)
9495
+ const vector int v0 = vec_splats((int32_t)0);
9400
9496
  vector float vsumf0 = vec_splats(0.0f);
9401
9497
  vector float vsumf1 = vec_splats(0.0f);
9402
9498
  vector float vsumf2 = vec_splats(0.0f);
@@ -9409,14 +9505,10 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
9409
9505
  vector float vyd = vec_splats(y[i].d);
9410
9506
  vector float vd = vec_mul(vxd, vyd);
9411
9507
 
9412
- vector signed int vsumi0 = vec_splats((int32_t)0);
9413
- vector signed int vsumi1 = vec_splats((int32_t)0);
9414
- vector signed int vsumi2 = vec_splats((int32_t)0);
9415
- vector signed int vsumi3 = vec_splats((int32_t)0);
9416
- vector signed int vsumi4 = vec_splats((int32_t)0);
9417
- vector signed int vsumi5 = vec_splats((int32_t)0);
9418
- vector signed int vsumi6 = vec_splats((int32_t)0);
9419
- vector signed int vsumi7 = vec_splats((int32_t)0);
9508
+ vector signed int vsumi0 = v0;
9509
+ vector signed int vsumi1 = v0;
9510
+ vector signed int vsumi2 = v0;
9511
+ vector signed int vsumi3 = v0;
9420
9512
 
9421
9513
  const uint16_t * restrict q2 = x[i].qs;
9422
9514
  const uint8_t * restrict sc = x[i].scales;
@@ -9464,21 +9556,12 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void *
9464
9556
  vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));
9465
9557
  vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));
9466
9558
 
9467
- vsumi0 = vec_add(vec_mule(qv0, vscales0), vsumi0);
9468
- vsumi1 = vec_add(vec_mule(qv1, vscales1), vsumi1);
9469
- vsumi2 = vec_add(vec_mule(qv2, vscales2), vsumi2);
9470
- vsumi3 = vec_add(vec_mule(qv3, vscales3), vsumi3);
9471
- vsumi4 = vec_add(vec_mulo(qv0, vscales0), vsumi4);
9472
- vsumi5 = vec_add(vec_mulo(qv1, vscales1), vsumi5);
9473
- vsumi6 = vec_add(vec_mulo(qv2, vscales2), vsumi6);
9474
- vsumi7 = vec_add(vec_mulo(qv3, vscales3), vsumi7);
9559
+ vsumi0 = vec_msum(qv0, vscales0, vsumi0);
9560
+ vsumi1 = vec_msum(qv1, vscales1, vsumi1);
9561
+ vsumi2 = vec_msum(qv2, vscales2, vsumi2);
9562
+ vsumi3 = vec_msum(qv3, vscales3, vsumi3);
9475
9563
  }
9476
9564
 
9477
- vsumi0 = vec_add(vsumi0, vsumi4);
9478
- vsumi1 = vec_add(vsumi1, vsumi5);
9479
- vsumi2 = vec_add(vsumi2, vsumi6);
9480
- vsumi3 = vec_add(vsumi3, vsumi7);
9481
-
9482
9565
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
9483
9566
  vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
9484
9567
  vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
@@ -9694,6 +9777,98 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
9694
9777
 
9695
9778
  *s = 0.125f * hsum_float_8(accumf);
9696
9779
 
9780
+ #elif defined(__AVX__)
9781
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
9782
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
9783
+ };
9784
+
9785
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
9786
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
9787
+ };
9788
+
9789
+ const __m128i m4 = _mm_set1_epi8(0xf);
9790
+ const __m128i m1 = _mm_set1_epi8(1);
9791
+
9792
+ const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
9793
+ const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
9794
+ const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
9795
+ const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
9796
+
9797
+ uint64_t aux64;
9798
+
9799
+ __m256 accumf = _mm256_setzero_ps();
9800
+ for (int i = 0; i < nb; ++i) {
9801
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
9802
+ const uint8_t * restrict qs = x[i].qs;
9803
+ const uint8_t * restrict qh = x[i].qh;
9804
+ const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
9805
+ const int8_t * restrict q8 = y[i].qs;
9806
+
9807
+ memcpy(&aux64, x[i].scales, 8);
9808
+ const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
9809
+ const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
9810
+ const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
9811
+
9812
+ __m128i sumi1_0 = _mm_setzero_si128();
9813
+ __m128i sumi1_1 = _mm_setzero_si128();
9814
+ __m128i sumi2_0 = _mm_setzero_si128();
9815
+ __m128i sumi2_1 = _mm_setzero_si128();
9816
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
9817
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9818
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9819
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9820
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
9821
+ const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
9822
+ iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
9823
+ const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
9824
+ iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
9825
+ const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
9826
+ iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
9827
+ const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
9828
+ iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
9829
+ qs += 8;
9830
+
9831
+ __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
9832
+ __m128i aux128_1 = aux128_0;
9833
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
9834
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
9835
+ const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
9836
+ const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
9837
+ const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
9838
+ const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
9839
+
9840
+ aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
9841
+ aux128_1 = aux128_0;
9842
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
9843
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
9844
+ const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
9845
+ const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
9846
+ const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
9847
+ const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
9848
+
9849
+ signs += 4;
9850
+
9851
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
9852
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
9853
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
9854
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
9855
+
9856
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
9857
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
9858
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
9859
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
9860
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
9861
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
9862
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
9863
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
9864
+ }
9865
+
9866
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
9867
+
9868
+ }
9869
+
9870
+ *s = 0.125f * hsum_float_8(accumf);
9871
+
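The __AVX__ branch added above mirrors the existing __AVX2__ path with 128-bit halves. One detail worth spelling out is how the packed sign bits are applied: after the shuffle/and against mask1/mask2, a byte is 0xFF where the stored sign bit requests negation and 0x00 otherwise, and q8 is then flipped with an XOR/subtract pair. Scalar sketch of that trick (the helper name is illustrative):

    /* s is 0xFF (-1) where the sign bit is set, 0x00 otherwise.
     * (q8 ^ -1) - (-1) == ~q8 + 1 == -q8; (q8 ^ 0) - 0 == q8. */
    static inline int8_t apply_packed_sign(int8_t q8, int8_t s) {
        return (int8_t)((q8 ^ s) - s);
    }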
9697
9872
  #elif defined(__POWER9_VECTOR__)
9698
9873
  static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
9699
9874
  0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
@@ -9701,6 +9876,8 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
9701
9876
 
9702
9877
  static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
9703
9878
 
9879
+ const vector int v0 = vec_splats((int32_t)0);
9880
+
9704
9881
  vector float vsumf0 = vec_splats(0.0f);
9705
9882
  vector float vsumf1 = vec_splats(0.0f);
9706
9883
  vector float vsumf2 = vec_splats(0.0f);
@@ -9715,14 +9892,10 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
9715
9892
  vector float vyd = vec_splats(y[i].d);
9716
9893
  vector float vd = vec_mul(vxd, vyd);
9717
9894
 
9718
- vector signed int vsumi0 = vec_splats((int32_t)0);
9719
- vector signed int vsumi1 = vec_splats((int32_t)0);
9720
- vector signed int vsumi2 = vec_splats((int32_t)0);
9721
- vector signed int vsumi3 = vec_splats((int32_t)0);
9722
- vector signed int vsumi4 = vec_splats((int32_t)0);
9723
- vector signed int vsumi5 = vec_splats((int32_t)0);
9724
- vector signed int vsumi6 = vec_splats((int32_t)0);
9725
- vector signed int vsumi7 = vec_splats((int32_t)0);
9895
+ vector signed int vsumi0 = v0;
9896
+ vector signed int vsumi1 = v0;
9897
+ vector signed int vsumi2 = v0;
9898
+ vector signed int vsumi3 = v0;
9726
9899
 
9727
9900
  const uint8_t * restrict q2 = x[i].qs;
9728
9901
  const uint8_t * restrict qh = x[i].qh;
@@ -9782,21 +9955,12 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void *
9782
9955
  vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));
9783
9956
  vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));
9784
9957
 
9785
- vsumi0 = vec_add(vec_mule(qv0, vscales0), vsumi0);
9786
- vsumi1 = vec_add(vec_mule(qv1, vscales1), vsumi1);
9787
- vsumi2 = vec_add(vec_mule(qv2, vscales2), vsumi2);
9788
- vsumi3 = vec_add(vec_mule(qv3, vscales3), vsumi3);
9789
- vsumi4 = vec_add(vec_mulo(qv0, vscales0), vsumi4);
9790
- vsumi5 = vec_add(vec_mulo(qv1, vscales1), vsumi5);
9791
- vsumi6 = vec_add(vec_mulo(qv2, vscales2), vsumi6);
9792
- vsumi7 = vec_add(vec_mulo(qv3, vscales3), vsumi7);
9958
+ vsumi0 = vec_msum(qv0, vscales0, vsumi0);
9959
+ vsumi1 = vec_msum(qv1, vscales1, vsumi1);
9960
+ vsumi2 = vec_msum(qv2, vscales2, vsumi2);
9961
+ vsumi3 = vec_msum(qv3, vscales3, vsumi3);
9793
9962
  }
9794
9963
 
9795
- vsumi0 = vec_add(vsumi0, vsumi4);
9796
- vsumi1 = vec_add(vsumi1, vsumi5);
9797
- vsumi2 = vec_add(vsumi2, vsumi6);
9798
- vsumi3 = vec_add(vsumi3, vsumi7);
9799
-
9800
9964
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
9801
9965
  vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
9802
9966
  vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
@@ -10031,9 +10195,68 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
10031
10195
 
10032
10196
  *s = 0.25f * hsum_float_8(accumf);
10033
10197
 
10198
+ #elif defined(__AVX__)
10199
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
10200
+
10201
+ uint32_t aux32[2];
10202
+
10203
+ __m256 accumf = _mm256_setzero_ps();
10204
+ for (int i = 0; i < nb; ++i) {
10205
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
10206
+ const uint8_t * restrict q3 = x[i].qs;
10207
+ const uint8_t * restrict gas = x[i].qs + QK_K/4;
10208
+ const int8_t * restrict q8 = y[i].qs;
10209
+ __m128i sumi1_0 = _mm_setzero_si128();
10210
+ __m128i sumi1_1 = _mm_setzero_si128();
10211
+ __m128i sumi2_0 = _mm_setzero_si128();
10212
+ __m128i sumi2_1 = _mm_setzero_si128();
10213
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
10214
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
10215
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
10216
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
10217
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
10218
+ const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
10219
+ const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
10220
+ q3 += 8;
10221
+ const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
10222
+ const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
10223
+ q3 += 8;
10224
+ memcpy(aux32, gas, 8); gas += 8;
10225
+ const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
10226
+ const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
10227
+ const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
10228
+ const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
10229
+ const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
10230
+ const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
10231
+ const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
10232
+ const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
10233
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
10234
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
10235
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
10236
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
10237
+ const uint16_t ls1 = aux32[0] >> 28;
10238
+ const uint16_t ls2 = aux32[1] >> 28;
10239
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
10240
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
10241
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
10242
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
10243
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
10244
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
10245
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
10246
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
10247
+ }
10248
+
10249
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
10250
+
10251
+ }
10252
+
10253
+ *s = 0.25f * hsum_float_8(accumf);
10254
+
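As in the AVX2 version of this kernel, the grid bytes (q2_*) are non-negative, so _mm_maddubs_epi16 can take them as its unsigned operand while the signs looked up in keven_signs_q2xs are folded into the signed q8 operand through _mm_sign_epi8. The 4-bit block scale sits in the top bits of each aux32 word and is applied as 2*ls + 1. A scalar sketch of one product term under that layout:

    /* One 16-bit lane of _mm_maddubs_epi16(q2, q8s), then scaled by (2*ls + 1). */
    static inline int32_t iq3xxs_pair_sketch(uint8_t g0, uint8_t g1,
                                             int8_t q8s0, int8_t q8s1, uint32_t aux32) {
        const int ls  = (int)(aux32 >> 28);        /* 4-bit sub-block scale     */
        const int dot = g0 * q8s0 + g1 * q8s1;     /* unsigned grid x signed q8 */
        return dot * (2*ls + 1);
    }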
10034
10255
  #elif defined(__POWER9_VECTOR__)
10035
10256
  const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
10036
10257
 
10258
+ const vector int v0 = vec_splats((int32_t)0);
10259
+
10037
10260
  vector float vsumf0 = vec_splats(0.0f);
10038
10261
  vector float vsumf1 = vec_splats(0.0f);
10039
10262
  vector float vsumf2 = vec_splats(0.0f);
@@ -10044,14 +10267,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
10044
10267
  vector float vyd = vec_splats(y[i].d);
10045
10268
  vector float vd = vec_mul(vxd, vyd);
10046
10269
 
10047
- vector signed int vsumi0 = vec_splats((int32_t)0);
10048
- vector signed int vsumi1 = vec_splats((int32_t)0);
10049
- vector signed int vsumi2 = vec_splats((int32_t)0);
10050
- vector signed int vsumi3 = vec_splats((int32_t)0);
10051
- vector signed int vsumi4 = vec_splats((int32_t)0);
10052
- vector signed int vsumi5 = vec_splats((int32_t)0);
10053
- vector signed int vsumi6 = vec_splats((int32_t)0);
10054
- vector signed int vsumi7 = vec_splats((int32_t)0);
10270
+ vector signed int vsumi0 = v0;
10271
+ vector signed int vsumi1 = v0;
10272
+ vector signed int vsumi2 = v0;
10273
+ vector signed int vsumi3 = v0;
10055
10274
 
10056
10275
  const uint8_t * restrict q3 = x[i].qs;
10057
10276
  const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4);
@@ -10096,21 +10315,12 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
10096
10315
  vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
10097
10316
  vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
10098
10317
 
10099
- vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0);
10100
- vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1);
10101
- vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2);
10102
- vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3);
10103
- vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4);
10104
- vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5);
10105
- vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6);
10106
- vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7);
10318
+ vsumi0 = vec_msum(qv0, vscales01, vsumi0);
10319
+ vsumi1 = vec_msum(qv1, vscales01, vsumi1);
10320
+ vsumi2 = vec_msum(qv2, vscales23, vsumi2);
10321
+ vsumi3 = vec_msum(qv3, vscales23, vsumi3);
10107
10322
  }
10108
10323
 
10109
- vsumi0 = vec_add(vsumi0, vsumi4);
10110
- vsumi1 = vec_add(vsumi1, vsumi5);
10111
- vsumi2 = vec_add(vsumi2, vsumi6);
10112
- vsumi3 = vec_add(vsumi3, vsumi7);
10113
-
10114
10324
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
10115
10325
  vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
10116
10326
  vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
@@ -10393,6 +10603,112 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
10393
10603
 
10394
10604
  *s = hsum_float_8(accumf);
10395
10605
 
10606
+ #elif defined(__AVX__)
10607
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
10608
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
10609
+ };
10610
+
10611
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
10612
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
10613
+ };
10614
+
10615
+ const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
10616
+ const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
10617
+ const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
10618
+ const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
10619
+
10620
+ const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
10621
+ const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
10622
+ const __m128i idx_mask = _mm_set1_epi32(256);
10623
+
10624
+ typedef union {
10625
+ __m128i vec[4];
10626
+ uint32_t index[16];
10627
+ } index_t;
10628
+
10629
+ index_t idx;
10630
+
10631
+ __m256 accumf = _mm256_setzero_ps();
10632
+ for (int i = 0; i < nb; ++i) {
10633
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
10634
+ const uint8_t * restrict qs = x[i].qs;
10635
+ const uint8_t * restrict qh = x[i].qh;
10636
+ const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
10637
+ const int8_t * restrict q8 = y[i].qs;
10638
+ __m128i sumi1_0 = _mm_setzero_si128();
10639
+ __m128i sumi1_1 = _mm_setzero_si128();
10640
+ __m128i sumi2_0 = _mm_setzero_si128();
10641
+ __m128i sumi2_1 = _mm_setzero_si128();
10642
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
10643
+ const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
10644
+ const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
10645
+ const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
10646
+ const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
10647
+ const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
10648
+ const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
10649
+ const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
10650
+ idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
10651
+ idx.vec[1] = idx.vec[0];
10652
+ idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
10653
+ idx.vec[3] = idx.vec[2];
10654
+
10655
+ idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
10656
+ idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
10657
+ idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
10658
+ idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);
10659
+
10660
+ idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
10661
+ idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
10662
+ idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
10663
+ idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
10664
+
10665
+ const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
10666
+ const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
10667
+ const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
10668
+ const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
10669
+
10670
+ __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16));
10671
+ __m128i aux128_1 = aux128_0;
10672
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
10673
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
10674
+ const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
10675
+ const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
10676
+ const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
10677
+ const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
10678
+
10679
+ aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16));
10680
+ aux128_1 = aux128_0;
10681
+ aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
10682
+ aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
10683
+ const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
10684
+ const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
10685
+ const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
10686
+ const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
10687
+
10688
+ signs += 4;
10689
+
10690
+ const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
10691
+ const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
10692
+ const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
10693
+ const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
10694
+ const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
10695
+ const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
10696
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
10697
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
10698
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
10699
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
10700
+ sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
10701
+ sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
10702
+ sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
10703
+ sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
10704
+ }
10705
+
10706
+ accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
10707
+
10708
+ }
10709
+
10710
+ *s = hsum_float_8(accumf);
10711
+
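In the block just added, the 9-bit iq3_s grid index is rebuilt without per-bit shifts: each qh byte is broadcast, multiplied by a lane-specific power of two (idx_mul_0/idx_mul_1) and masked with 256, which drops bit j of qh into bit 8 of lane j, while the low 8 bits come straight from qs. Scalar sketch of what one lane of idx.vec computes (names illustrative):

    /* Lane j (0..7) of the index vectors built above. */
    static inline uint32_t iq3s_index_sketch(uint8_t qs_byte, uint8_t qh_byte, int j) {
        const uint32_t hi = ((uint32_t)qh_byte * (256u >> j)) & 256u;  /* == ((qh_byte >> j) & 1) << 8 */
        return hi | qs_byte;
    }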
10396
10712
  #elif defined(__POWER9_VECTOR__)
10397
10713
  static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
10398
10714
  0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
@@ -10400,6 +10716,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
10400
10716
 
10401
10717
  static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
10402
10718
 
10719
+ const vector int v0 = vec_splats((int32_t)0);
10720
+
10403
10721
  vector float vsumf0 = vec_splats(0.0f);
10404
10722
  vector float vsumf1 = vec_splats(0.0f);
10405
10723
  vector float vsumf2 = vec_splats(0.0f);
@@ -10420,14 +10738,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
10420
10738
  const uint8_t * restrict sc = x[i].scales;
10421
10739
  const int8_t * restrict q8 = y[i].qs;
10422
10740
 
10423
- vector signed int vsumi0 = vec_splats((int32_t)0);
10424
- vector signed int vsumi1 = vec_splats((int32_t)0);
10425
- vector signed int vsumi2 = vec_splats((int32_t)0);
10426
- vector signed int vsumi3 = vec_splats((int32_t)0);
10427
- vector signed int vsumi4 = vec_splats((int32_t)0);
10428
- vector signed int vsumi5 = vec_splats((int32_t)0);
10429
- vector signed int vsumi6 = vec_splats((int32_t)0);
10430
- vector signed int vsumi7 = vec_splats((int32_t)0);
10741
+ vector signed int vsumi0 = v0;
10742
+ vector signed int vsumi1 = v0;
10743
+ vector signed int vsumi2 = v0;
10744
+ vector signed int vsumi3 = v0;
10431
10745
 
10432
10746
  for (int j = 0; j < QK_K/32; j += 2) {
10433
10747
  __builtin_prefetch(q3, 0, 1);
@@ -10481,21 +10795,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
10481
10795
  vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
10482
10796
  vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
10483
10797
 
10484
- vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0);
10485
- vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1);
10486
- vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2);
10487
- vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3);
10488
- vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4);
10489
- vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5);
10490
- vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6);
10491
- vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7);
10798
+ vsumi0 = vec_msum(qv0, vscales01, vsumi0);
10799
+ vsumi1 = vec_msum(qv1, vscales01, vsumi1);
10800
+ vsumi2 = vec_msum(qv2, vscales23, vsumi2);
10801
+ vsumi3 = vec_msum(qv3, vscales23, vsumi3);
10492
10802
  }
10493
10803
 
10494
- vsumi0 = vec_add(vsumi0, vsumi4);
10495
- vsumi1 = vec_add(vsumi1, vsumi5);
10496
- vsumi2 = vec_add(vsumi2, vsumi6);
10497
- vsumi3 = vec_add(vsumi3, vsumi7);
10498
-
10499
10804
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
10500
10805
  vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
10501
10806
  vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
@@ -10641,6 +10946,14 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
10641
10946
  }
10642
10947
 
10643
10948
 
10949
+ #if defined(__AVX__)
10950
+ static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
10951
+ const __m128i ax = _mm_sign_epi8(x, x);
10952
+ const __m128i sy = _mm_sign_epi8(y, x);
10953
+ return _mm_maddubs_epi16(ax, sy);
10954
+ }
10955
+ #endif
10956
+
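mul_add_epi8_sse is the 128-bit counterpart of the mul_add_epi8 helper right below it: _mm_maddubs_epi16 multiplies an unsigned byte by a signed byte, so the helper feeds it |x| and y with x's sign copied onto it, which recovers the signed product x*y (and _mm_sign_epi8(y, x) also zeroes y where x == 0). Scalar sketch of the identity for one 16-bit output lane:

    /* |x| * (sign(x) * y) == x * y for each byte pair; x == -128 would wrap,
     * but the quantization grids used with this helper stay well inside range. */
    static inline int16_t mul_add_pair_sketch(int8_t x0, int8_t y0, int8_t x1, int8_t y1) {
        const int ax0 = x0 < 0 ? -x0 : x0,  ax1 = x1 < 0 ? -x1 : x1;
        const int sy0 = x0 < 0 ? -y0 : (x0 ? y0 : 0);
        const int sy1 = x1 < 0 ? -y1 : (x1 ? y1 : 0);
        return (int16_t)(ax0 * sy0 + ax1 * sy1);   /* == x0*y0 + x1*y1 */
    }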
10644
10957
  #if defined(__AVX2__)
10645
10958
  static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
10646
10959
  const __m256i ax = _mm256_sign_epi8(x, x);
@@ -10758,6 +11071,54 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
10758
11071
 
10759
11072
  *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
10760
11073
 
11074
+ #elif defined __AVX__
11075
+ __m256 accum = _mm256_setzero_ps();
11076
+ float accum1 = 0;
11077
+ for (int i = 0; i < nb; ++i) {
11078
+
11079
+ const int8_t * q8 = y[i].qs;
11080
+ const uint8_t * qs = x[i].qs;
11081
+ const uint16_t * qh = x[i].qh;
11082
+
11083
+ __m128i sumi1_0 = _mm_setzero_si128();
11084
+ __m128i sumi1_1 = _mm_setzero_si128();
11085
+ int sumi1 = 0;
11086
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
11087
+ const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
11088
+ const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
11089
+ const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
11090
+ const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
11091
+ qs += 8;
11092
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11093
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11094
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11095
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11096
+
11097
+ const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
11098
+ const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
11099
+ const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
11100
+ const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
11101
+ const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
11102
+ const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
11103
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
11104
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
11105
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
11106
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));
11107
+
11108
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
11109
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
11110
+ sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
11111
+ + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
11112
+ }
11113
+
11114
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
11115
+ accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
11116
+ accum1 += d * sumi1;
11117
+
11118
+ }
11119
+
11120
+ *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
11121
+
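Beyond the grid dot product, iq1_s adds a constant IQ1S_DELTA to every weight, with a per-sub-block sign taken from bit 15 of qh and the same 2*s + 1 scale; that contribution reduces to delta * scale * sum(q8), which is read from the precomputed y[i].bsums instead of being recomputed, and it is what the separate sumi1/accum1 term above carries. Scalar sketch of the correction for one 32-weight sub-block (hypothetical helper):

    /* Contribution of the delta term to *s for one sub-block. */
    static inline float iq1s_delta_term_sketch(float d, int bsum0, int bsum1,
                                               uint16_t qh_word, int ls /* = 2*((qh>>12)&7)+1 */) {
        const int sign = (qh_word & 0x8000) ? -1 : 1;
        return d * IQ1S_DELTA * (float)(sign * ls * (bsum0 + bsum1));
    }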
10761
11122
  #elif defined(__POWER9_VECTOR__)
10762
11123
  const vector unsigned char v0 = vec_splats((unsigned char)0x0);
10763
11124
  const vector unsigned short vsign = vec_splats((unsigned short)0x8000);
@@ -10776,10 +11137,6 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
10776
11137
  vector signed int vsumi1 = vec_splats((int32_t)0);
10777
11138
  vector signed int vsumi2 = vec_splats((int32_t)0);
10778
11139
  vector signed int vsumi3 = vec_splats((int32_t)0);
10779
- vector signed int vsumi4 = vec_splats((int32_t)0);
10780
- vector signed int vsumi5 = vec_splats((int32_t)0);
10781
- vector signed int vsumi6 = vec_splats((int32_t)0);
10782
- vector signed int vsumi7 = vec_splats((int32_t)0);
10783
11140
  vector signed int vsumi8 = vec_splats((int32_t)0);
10784
11141
 
10785
11142
  const uint8_t * restrict q1 = x[i].qs;
@@ -10821,14 +11178,10 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
10821
11178
  vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
10822
11179
  vector signed short vscales = vec_sld(vscales23, vscales01, 8);
10823
11180
 
10824
- vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0);
10825
- vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1);
10826
- vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2);
10827
- vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3);
10828
- vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4);
10829
- vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5);
10830
- vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6);
10831
- vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7);
11181
+ vsumi0 = vec_msum(qv0, vscales01, vsumi0);
11182
+ vsumi1 = vec_msum(qv1, vscales01, vsumi1);
11183
+ vsumi2 = vec_msum(qv2, vscales23, vsumi2);
11184
+ vsumi3 = vec_msum(qv3, vscales23, vsumi3);
10832
11185
 
10833
11186
  vector signed short q8ysums = vec_xl_len(qs, 8);
10834
11187
  qs += 4;
@@ -10843,11 +11196,6 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
10843
11196
  vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8);
10844
11197
  }
10845
11198
 
10846
- vsumi0 = vec_add(vsumi0, vsumi4);
10847
- vsumi1 = vec_add(vsumi1, vsumi5);
10848
- vsumi2 = vec_add(vsumi2, vsumi6);
10849
- vsumi3 = vec_add(vsumi3, vsumi7);
10850
-
10851
11199
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
10852
11200
  vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
10853
11201
  vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
@@ -11109,6 +11457,92 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
11109
11457
 
11110
11458
  *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
11111
11459
 
11460
+ #elif defined __AVX__
11461
+ const __m128i mask = _mm_set1_epi16(0x7);
11462
+ const __m128i mone = _mm_set1_epi16(1);
11463
+
11464
+ __m256 accum1 = _mm256_setzero_ps();
11465
+ __m256 accum2 = _mm256_setzero_ps();
11466
+ for (int i = 0; i < nb; ++i) {
11467
+
11468
+ const int8_t * q8 = y[i].qs;
11469
+ const uint8_t * qs = x[i].qs;
11470
+ const uint8_t * qh = x[i].qh;
11471
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
11472
+
11473
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
11474
+
11475
+ __m128i sumi1_0 = _mm_setzero_si128();
11476
+ __m128i sumi1_1 = _mm_setzero_si128();
11477
+ __m128i sumi2_0 = _mm_setzero_si128();
11478
+ __m128i sumi2_1 = _mm_setzero_si128();
11479
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
11480
+ const __m128i q1b_1_0 = _mm_set_epi64x(
11481
+ iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
11482
+ const __m128i q1b_1_1 = _mm_set_epi64x(
11483
+ iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
11484
+ const __m128i q1b_2_0 = _mm_set_epi64x(
11485
+ iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
11486
+ const __m128i q1b_2_1 = _mm_set_epi64x(
11487
+ iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
11488
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11489
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11490
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11491
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11492
+
11493
+ const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
11494
+ const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
11495
+ const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
11496
+ const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
11497
+
11498
+ const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
11499
+ qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
11500
+ const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
11501
+ qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
11502
+ const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
11503
+ qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
11504
+ const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
11505
+ qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
11506
+
11507
+ const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
11508
+ const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
11509
+ const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
11510
+ const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);
11511
+
11512
+ __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
11513
+ __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
11514
+ __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
11515
+ __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);
11516
+
11517
+ scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
11518
+ scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
11519
+ scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
11520
+ scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
11521
+ const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
11522
+ const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
11523
+ const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
11524
+ const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
11525
+ const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
11526
+ const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
11527
+ const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
11528
+ const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);
11529
+
11530
+ sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
11531
+ sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
11532
+ sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
11533
+ sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));
11534
+
11535
+ qs += 8; qh += 4;
11536
+ }
11537
+
11538
+ const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
11539
+
11540
+ accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
11541
+ accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
11542
+ }
11543
+
11544
+ *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
11545
+
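For iq1_m, each 16-bit word of x[i].scales packs four 3-bit sub-scales (extracted above with shifts of 0/3/6/9 and mask 0x7, then mapped to 2*s + 1), and the top nibble of each of the four words is reassembled into scale.u16, the fp16 row scale. Bits 0x08 and 0x80 of every qh byte choose the sign of the per-half-block delta, which is why delta*_* is either all 0x01 or all 0xFF bytes before being weighted by IQ1M_DELTA at the end. Scalar sketch of the sub-scale decode under that packing:

    /* Sub-scale k (0..3) of one 16-bit scale word, as used for scale*_* above. */
    static inline int iq1m_subscale_sketch(uint16_t sc_word, int k) {
        return 2 * ((sc_word >> (3*k)) & 0x7) + 1;
    }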
11112
11546
  #else
11113
11547
 
11114
11548
  int sum1[2], sum2[2], delta[4];
@@ -11173,6 +11607,9 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11173
11607
 
11174
11608
  const int nb = n / QK4_NL;
11175
11609
 
11610
+ int ib = 0;
11611
+ float sumf = 0;
11612
+
11176
11613
  #if defined __ARM_NEON
11177
11614
  const int8x16_t values = vld1q_s8(kvalues_iq4nl);
11178
11615
  const uint8x16_t m4b = vdupq_n_u8(0x0f);
@@ -11181,16 +11618,14 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11181
11618
  int8x16x4_t q8b;
11182
11619
  int32x4_t prod_1, prod_2;
11183
11620
 
11184
- float sumf = 0;
11185
-
11186
- for (int ib = 0; ib < nb; ib += 2) {
11621
+ for (; ib + 1 < nb; ib += 2) {
11187
11622
 
11188
- q4bits.val[0] = vld1q_u8(x[ib+0].qs);
11189
- q4bits.val[1] = vld1q_u8(x[ib+1].qs);
11190
- q8b.val[0] = vld1q_s8(y[ib+0].qs);
11191
- q8b.val[1] = vld1q_s8(y[ib+0].qs + 16);
11192
- q8b.val[2] = vld1q_s8(y[ib+1].qs);
11193
- q8b.val[3] = vld1q_s8(y[ib+1].qs + 16);
11623
+ q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
11624
+ q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
11625
+ q8b.val[0] = vld1q_s8(y[ib + 0].qs);
11626
+ q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16);
11627
+ q8b.val[2] = vld1q_s8(y[ib + 1].qs);
11628
+ q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16);
11194
11629
 
11195
11630
  q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
11196
11631
  q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
@@ -11201,12 +11636,10 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11201
11636
  prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
11202
11637
 
11203
11638
  sumf +=
11204
- GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) +
11205
- GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2);
11639
+ GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
11640
+ GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
11206
11641
  }
11207
11642
 
11208
- *s = sumf;
11209
-
11210
11643
  #elif defined __AVX2__
11211
11644
 
11212
11645
  const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
@@ -11215,11 +11648,11 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11215
11648
 
11216
11649
  __m256 accum1 = _mm256_setzero_ps();
11217
11650
  __m256 accum2 = _mm256_setzero_ps();
11218
- for (int ib = 0; ib < nb; ib += 2) {
11219
- const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs);
11220
- const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
11221
- const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
11222
- const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
11651
+ for (; ib + 1 < nb; ib += 2) {
11652
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
11653
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
11654
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
11655
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
11223
11656
  const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
11224
11657
  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
11225
11658
  const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
@@ -11228,19 +11661,52 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11228
11661
  const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
11229
11662
  const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
11230
11663
  const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
11231
- accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
11664
+ accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
11232
11665
  _mm256_cvtepi32_ps(p_1), accum1);
11233
- accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
11666
+ accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
11234
11667
  _mm256_cvtepi32_ps(p_2), accum2);
11235
-
11236
- y += 2;
11237
- x += 2;
11238
11668
  }
11239
11669
 
11240
- *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
11670
+ sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
11671
+
11672
+ #elif defined __AVX__
11673
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
11674
+ const __m128i m4b = _mm_set1_epi8(0x0f);
11675
+ const __m128i mone = _mm_set1_epi16(1);
11676
+
11677
+ __m256 accum1 = _mm256_setzero_ps();
11678
+ __m256 accum2 = _mm256_setzero_ps();
11679
+ for (; ib + 1 < nb; ib += 2) {
11680
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
11681
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
11682
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
11683
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
11684
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
11685
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
11686
+
11687
+ const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
11688
+ const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
11689
+ const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
11690
+ const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
11691
+ const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
11692
+ const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
11693
+ const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
11694
+ const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
11695
+ const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
11696
+ const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
11697
+ const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
11698
+ const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
11699
+ accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
11700
+ _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
11701
+ accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
11702
+ _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
11703
+ }
11704
+
11705
+ sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
11241
11706
 
11242
11707
  #elif defined(__POWER9_VECTOR__)
11243
11708
  const vector signed char lowMask = vec_splats((signed char)0xF);
11709
+ const vector signed int v0 = vec_splats((int32_t)0);
11244
11710
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
11245
11711
 
11246
11712
  vector float vsumf0 = vec_splats(0.0f);
@@ -11249,7 +11715,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11249
11715
  const vector signed char values = vec_xl( 0, kvalues_iq4nl);
11250
11716
 
11251
11717
  #pragma GCC unroll 4
11252
- for (int ib = 0; ib < nb; ++ib) {
11718
+ for (; ib < nb; ++ib) {
11253
11719
  __builtin_prefetch(x[ib].qs, 0, 1);
11254
11720
  __builtin_prefetch(y[ib].qs, 0, 1);
11255
11721
 
@@ -11271,8 +11737,11 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11271
11737
  vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
11272
11738
  vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
11273
11739
 
11274
- vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0));
11275
- vector signed int vsumi1 = vec_add(vec_unpackh(qv1), vec_unpackl(qv1));
11740
+ vector signed int vsumi0 = v0;
11741
+ vector signed int vsumi1 = v0;
11742
+
11743
+ vsumi0 = vec_sum4s(qv0, vsumi0);
11744
+ vsumi1 = vec_sum4s(qv1, vsumi1);
11276
11745
 
11277
11746
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
11278
11747
  vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
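Same idea as the vec_msum changes earlier in the file: instead of widening qv with vec_unpackh/vec_unpackl and adding the halves, vec_sum4s folds each adjacent pair of 16-bit partial sums straight into the int32 accumulator (with saturation, which these small values never approach). Scalar sketch:

    /* Scalar sketch of: vsum = vec_sum4s(qv, vsum) for int16 input (saturation omitted). */
    static inline void vec_sum4s_sketch(const int16_t qv[8], int32_t vsum[4]) {
        for (int i = 0; i < 4; ++i) {
            vsum[i] += qv[2*i + 0] + qv[2*i + 1];
        }
    }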
@@ -11283,7 +11752,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11283
11752
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
11284
11753
  vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
11285
11754
 
11286
- *s = vec_extract(vsumf0, 0);
11755
+ sumf = vec_extract(vsumf0, 0);
11287
11756
 
11288
11757
  #elif defined (__loongarch_asx)
11289
11758
 
@@ -11293,11 +11762,11 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11293
11762
 
11294
11763
  __m256 accum1 = (__m256)__lasx_xvldi(0);
11295
11764
  __m256 accum2 = (__m256)__lasx_xvldi(0);
11296
- for (int ib = 0; ib < nb; ib += 2) {
11297
- const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[0].qs, 0);
11298
- const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[1].qs, 0);
11299
- const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[0].qs, 0);
11300
- const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[1].qs, 0);
11765
+ for (; ib + 1 < nb; ib += 2) {
11766
+ const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0);
11767
+ const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0);
11768
+ const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0);
11769
+ const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0);
11301
11770
  const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
11302
11771
  lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
11303
11772
  const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
@@ -11306,20 +11775,16 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11306
11775
  const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
11307
11776
  const __m256i p_1 = lasx_madd_h(p16_1, mone);
11308
11777
  const __m256i p_2 = lasx_madd_h(p16_2, mone);
11309
- accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
11778
+ accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
11310
11779
  __lasx_xvffint_s_w(p_1), accum1);
11311
- accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
11780
+ accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
11312
11781
  __lasx_xvffint_s_w(p_2), accum2);
11313
-
11314
- y += 2;
11315
- x += 2;
11316
11782
  }
11317
11783
 
11318
- *s = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
11784
+ sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
11319
11785
 
11320
- #else
11321
- float sumf = 0;
11322
- for (int ib = 0; ib < nb; ++ib) {
11786
+ #endif
11787
+ for (; ib < nb; ++ib) {
11323
11788
  const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
11324
11789
  int sumi1 = 0, sumi2 = 0;
11325
11790
  for (int j = 0; j < QK4_NL/2; ++j) {
@@ -11329,7 +11794,6 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
11329
11794
  sumf += d * (sumi1 + sumi2);
11330
11795
  }
11331
11796
  *s = sumf;
11332
- #endif
11333
11797
  }
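The hunks above restructure ggml_vec_dot_iq4_nl_q8_0 so that ib and sumf are declared once before the #if ladder: each SIMD branch now walks pairs of blocks with for (; ib + 1 < nb; ib += 2), writes into the shared sumf, and leaves any leftover blocks to the common scalar loop at the end; the AVX2/LASX branches also index x[ib + 0]/x[ib + 1] instead of advancing the x/y pointers from x[0]/x[1]. A minimal, self-contained sketch of the resulting shape (names simplified, not the real kernel):

    /* Illustrative skeleton only: SIMD main loop over pairs, shared scalar tail. */
    static float dot_with_shared_tail(int nb, const float * xs, const float * ys) {
        int   ib   = 0;
        float sumf = 0.0f;
    #if defined(__AVX2__)                        /* stands in for any SIMD branch */
        for (; ib + 1 < nb; ib += 2) {
            sumf += xs[ib + 0] * ys[ib + 0];     /* placeholder for the vector path */
            sumf += xs[ib + 1] * ys[ib + 1];
        }
    #endif
        for (; ib < nb; ++ib) {                  /* shared remainder / no-SIMD path */
            sumf += xs[ib] * ys[ib];
        }
        return sumf;
    }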
11334
11798
 
11335
11799
  void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
@@ -11425,8 +11889,57 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
11425
11889
 
11426
11890
  *s = hsum_float_8(accum);
11427
11891
 
11892
+ #elif defined __AVX__
11893
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
11894
+ const __m128i m4b = _mm_set1_epi8(0x0f);
11895
+
11896
+ __m256 accum = _mm256_setzero_ps();
11897
+ for (int ibl = 0; ibl < nb; ++ibl) {
11898
+ const uint8_t * qs = x[ibl].qs;
11899
+ const int8_t * q8 = y[ibl].qs;
11900
+ uint16_t sh = x[ibl].scales_h;
11901
+ __m128i sumi1_0 = _mm_setzero_si128();
11902
+ __m128i sumi1_1 = _mm_setzero_si128();
11903
+ __m128i sumi2_0 = _mm_setzero_si128();
11904
+ __m128i sumi2_1 = _mm_setzero_si128();
11905
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
11906
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
11907
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
11908
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11909
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11910
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11911
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
11912
+ const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
11913
+ const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
11914
+ const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
11915
+ const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
11916
+ const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
11917
+ const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
11918
+ const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
11919
+ const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
11920
+ const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
11921
+ const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
11922
+ sh >>= 4;
11923
+ const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
11924
+ const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
11925
+ const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
11926
+ const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
11927
+ sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
11928
+ sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
11929
+ sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
11930
+ sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
11931
+ }
11932
+ __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
11933
+ __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
11934
+ accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
11935
+ _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
11936
+ }
11937
+
11938
+ *s = hsum_float_8(accum);
11939
+
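The per-sub-block scale in iq4_xs is 6 bits wide: the low 4 bits come from a nibble of scales_l and the high 2 bits from scales_h, consumed two bits at a time (hence sh >>= 4 per pair of sub-blocks), with a bias of -32. For example, a low nibble of 0x5 with high bits 0b01 decodes to (5 | 0x10) - 32 = -11. Scalar sketch of the decode used above:

    /* 6-bit scales for sub-blocks 2*j (low nibble) and 2*j + 1 (high nibble). */
    static inline void iq4xs_scales_sketch(uint8_t scales_l_byte, uint16_t sh,
                                           int * ls1, int * ls2) {
        *ls1 = (int)((scales_l_byte & 0xf) | ((sh << 4) & 0x30)) - 32;
        *ls2 = (int)((scales_l_byte >> 4)  | ((sh << 2) & 0x30)) - 32;
    }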
11428
11940
  #elif defined(__POWER9_VECTOR__)
11429
11941
  const vector signed char lowMask = vec_splats((signed char)0xF);
11942
+ const vector int v0 = vec_splats((int32_t)0);
11430
11943
  const vector unsigned char v4 = vec_splats((unsigned char)0x4);
11431
11944
 
11432
11945
  vector float vsumf0 = vec_splats(0.0f);
@@ -11442,14 +11955,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
11442
11955
  vector float vyd = vec_splats(y[ibl].d);
11443
11956
  vector float vd = vec_mul(vxd, vyd);
11444
11957
 
11445
- vector signed int vsumi0 = vec_splats((int32_t)0);
11446
- vector signed int vsumi1 = vec_splats((int32_t)0);
11447
- vector signed int vsumi2 = vec_splats((int32_t)0);
11448
- vector signed int vsumi3 = vec_splats((int32_t)0);
11449
- vector signed int vsumi4 = vec_splats((int32_t)0);
11450
- vector signed int vsumi5 = vec_splats((int32_t)0);
11451
- vector signed int vsumi6 = vec_splats((int32_t)0);
11452
- vector signed int vsumi7 = vec_splats((int32_t)0);
11958
+ vector signed int vsumi0 = v0;
11959
+ vector signed int vsumi1 = v0;
11960
+ vector signed int vsumi2 = v0;
11961
+ vector signed int vsumi3 = v0;
11453
11962
 
11454
11963
  uint16_t h = x[ibl].scales_h;
11455
11964
 
@@ -11494,21 +12003,12 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
11494
12003
  vector signed short vscales01 = vec_splats((int16_t)ls0);
11495
12004
  vector signed short vscales23 = vec_splats((int16_t)ls1);
11496
12005
 
11497
- vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0);
11498
- vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1);
11499
- vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2);
11500
- vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3);
11501
- vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4);
11502
- vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5);
11503
- vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6);
11504
- vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7);
12006
+ vsumi0 = vec_msum(qv0, vscales01, vsumi0);
12007
+ vsumi1 = vec_msum(qv1, vscales01, vsumi1);
12008
+ vsumi2 = vec_msum(qv2, vscales23, vsumi2);
12009
+ vsumi3 = vec_msum(qv3, vscales23, vsumi3);
11505
12010
  }
11506
12011
 
11507
- vsumi0 = vec_add(vsumi0, vsumi4);
11508
- vsumi1 = vec_add(vsumi1, vsumi5);
11509
- vsumi2 = vec_add(vsumi2, vsumi6);
11510
- vsumi3 = vec_add(vsumi3, vsumi7);
11511
-
11512
12012
  vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
11513
12013
  vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
11514
12014
  vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
@@ -12204,7 +12704,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
12204
12704
  printf("Oops: found point %u not on grid:", u);
12205
12705
  for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
12206
12706
  printf("\n");
12207
- GGML_ASSERT(false);
12707
+ GGML_ABORT("fatal error");
12208
12708
  }
12209
12709
  q2[2*ib+0] |= ((uint32_t) grid_index << 8*k);
12210
12710
  q2[2*ib+1] |= (block_signs[k] << 7*k);
@@ -12383,7 +12883,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v
12383
12883
  printf("Oops: found point %u not on grid:", u);
12384
12884
  for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
12385
12885
  printf("\n");
12386
- GGML_ASSERT(false);
12886
+ GGML_ABORT("fatal error");
12387
12887
  }
12388
12888
  q2[2*ib+k] = grid_index | (block_signs[k] << 9);
12389
12889
  }
@@ -12826,7 +13326,7 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v
12826
13326
  printf("Oops: found point %u not on grid:", u);
12827
13327
  for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
12828
13328
  printf("\n");
12829
- GGML_ASSERT(false);
13329
+ GGML_ABORT("fatal error");
12830
13330
  }
12831
13331
  if (grid_size == 256) {
12832
13332
  q3[8*ib+k] = grid_index;
@@ -12880,10 +13380,10 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t
12880
13380
  void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) {
12881
13381
  assert(k % QK_K == 0);
12882
13382
  block_iq3_xxs * restrict y = vy;
12883
- quantize_row_iq3_xxs_reference(x, y, k);
13383
+ quantize_row_iq3_xxs_ref(x, y, k);
12884
13384
  }
12885
13385
 
12886
- void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
13386
+ void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
12887
13387
  assert(k % QK_K == 0);
12888
13388
  quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
12889
13389
  }
@@ -13039,7 +13539,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo
13039
13539
  printf("Oops: found point %u not on grid:", u);
13040
13540
  for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
13041
13541
  printf("\n");
13042
- GGML_ASSERT(false);
13542
+ GGML_ABORT("fatal error");
13043
13543
  }
13044
13544
  qs[k] = grid_index & 255;
13045
13545
  qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8));
@@ -13096,10 +13596,10 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n
13096
13596
  void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) {
13097
13597
  assert(k % QK_K == 0);
13098
13598
  block_iq3_s * restrict y = vy;
13099
- quantize_row_iq3_s_reference(x, y, k);
13599
+ quantize_row_iq3_s_ref(x, y, k);
13100
13600
  }
13101
13601
 
13102
- void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
13602
+ void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
13103
13603
  assert(k % QK_K == 0);
13104
13604
  quantize_iq3_s(x, y, 1, k, NULL);
13105
13605
  }
@@ -13111,7 +13611,7 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u
13111
13611
  const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
13112
13612
  int num_neighbors = neighbours[0];
13113
13613
  GGML_ASSERT(num_neighbors > 0);
13114
- float best_score = 0;
13614
+ float best_score = -FLT_MAX;
13115
13615
  int grid_index = -1;
13116
13616
  for (int j = 1; j <= num_neighbors; ++j) {
13117
13617
  const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
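The initializer change above matters when every candidate grid point scores negative: with best_score = 0 the loop could finish without ever updating grid_index, leaving it at -1, whereas seeding with -FLT_MAX guarantees some neighbour is selected (the two similar hunks further down apply the same idea with -FLT_MIN). Minimal sketch of the selection loop, with a hypothetical score() callback standing in for the weighted projection computed in the real code:

    #include <float.h>
    #include <stdint.h>

    static int pick_best_neighbour_sketch(int num_neighbors, const uint16_t * neighbours,
                                          float (*score)(uint16_t)) {
        float best_score = -FLT_MAX;   /* was 0, which could leave grid_index == -1 */
        int   grid_index = -1;
        for (int j = 1; j <= num_neighbors; ++j) {
            const float sc = score(neighbours[j]);
            if (sc > best_score) { best_score = sc; grid_index = neighbours[j]; }
        }
        return grid_index;
    }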
@@ -13309,7 +13809,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
13309
13809
  sumw[j+1] = sumw[j] + weight[i];
13310
13810
  }
13311
13811
  }
13312
- float best_score = 0, scale = max;
13812
+ float best_score = -FLT_MIN, scale = max;
13313
13813
  int besti1 = -1, besti2 = -1, best_shift = 0;
13314
13814
  for (int i1 = 0; i1 <= block_size; ++i1) {
13315
13815
  for (int i2 = i1; i2 <= block_size; ++i2) {
@@ -13485,7 +13985,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
13485
13985
  idx[2*j] = j;
13486
13986
  }
13487
13987
  qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
13488
- float best_score = 0, scale = max;
13988
+ float best_score = -FLT_MIN, scale = max;
13489
13989
  int besti1 = -1, besti2 = -1, best_k = -1;
13490
13990
  // 0: +, +
13491
13991
  // 1: +, -
@@ -13837,7 +14337,7 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k
13837
14337
  }
13838
14338
  }
13839
14339
 
13840
- void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
14340
+ void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
13841
14341
  assert(k % QK4_NL == 0);
13842
14342
  quantize_row_iq4_nl(x, y, k);
13843
14343
  }
@@ -13865,10 +14365,10 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t
13865
14365
  void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) {
13866
14366
  assert(k % QK_K == 0);
13867
14367
  block_iq4_xs * restrict y = vy;
13868
- quantize_row_iq4_xs_reference(x, y, k);
14368
+ quantize_row_iq4_xs_ref(x, y, k);
13869
14369
  }
13870
14370
 
13871
- void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
14371
+ void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
13872
14372
  assert(k % QK_K == 0);
13873
14373
  quantize_iq4_xs(x, y, 1, k, NULL);
13874
14374
  }
@@ -14015,7 +14515,7 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy
14015
14515
  printf("Oops: found point %u not on grid:", u);
14016
14516
  for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
14017
14517
  printf("\n");
14018
- GGML_ASSERT(false);
14518
+ GGML_ABORT("fatal error");
14019
14519
  }
14020
14520
  const int i8 = 2*ib + k;
14021
14521
  y[ibl].qs[i8] = grid_index & 255;
@@ -14055,7 +14555,7 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t n
14055
14555
  return nrow * nblock * sizeof(block_iq2_s);
14056
14556
  }
14057
14557
 
14058
- void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
14558
+ void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
14059
14559
  assert(k % QK_K == 0);
14060
14560
  quantize_iq2_s(x, y, 1, k, NULL);
14061
14561
  }
@@ -14063,7 +14563,7 @@ void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restri
14063
14563
  void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) {
14064
14564
  assert(k % QK_K == 0);
14065
14565
  block_iq2_s * restrict y = vy;
14066
- quantize_row_iq2_s_reference(x, y, k);
14566
+ quantize_row_iq2_s_ref(x, y, k);
14067
14567
  }
14068
14568
 
14069
14569
  static bool validate_float(float f, size_t i) {
@@ -14118,6 +14618,16 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
14118
14618
  } \
14119
14619
  }
14120
14620
 
14621
+ #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
14622
+ const type * q = (const type *) (data); \
14623
+ for (size_t i = 0; i < (nb); ++i) { \
14624
+ for (size_t j = 0; j < (nr); ++j) { \
14625
+ if (!validate_fp16(q[i].d[j], i)) { \
14626
+ return false; \
14627
+ } \
14628
+ } \
14629
+ }
14630
+
14121
14631
  bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
14122
14632
  if (type < 0 || type >= GGML_TYPE_COUNT) {
14123
14633
  fprintf(stderr, "%s: invalid type %d\n", __func__, type);
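The VALIDATE_ROW_DATA_DVEC_F16_IMPL macro added in the hunk above generalizes the single-scale check to blocks that carry a small vector of per-row fp16 scales (d[nr]); the Q4_0_4_4/Q4_0_4_8/Q4_0_8_8 cases added further down use it for the repacked block_q4_0x4 and block_q4_0x8 layouts. A sketch of the kind of block it expects (field sizes illustrative, not the exact ggml definition):

    /* Hypothetical x4-repacked block: four interleaved rows, one d per row. */
    typedef struct {
        ggml_fp16_t d[4];      /* validated as q[i].d[j] for j = 0..nr-1  */
        uint8_t     qs[64];    /* packed 4-bit quants (size illustrative) */
    } example_block_q4_0x4;
    /* usage: VALIDATE_ROW_DATA_DVEC_F16_IMPL(example_block_q4_0x4, data,
     *                                        nbytes / sizeof(example_block_q4_0x4), 4); */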
@@ -14125,7 +14635,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
14125
14635
  }
14126
14636
 
14127
14637
  if (nbytes % ggml_type_size(type) != 0) {
14128
- fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type);
14638
+ fprintf(stderr, "%s: invalid size %zu for type %s (type size = %zu)\n", __func__, nbytes, ggml_type_name(type), ggml_type_size(type));
14129
14639
  return false;
14130
14640
  }
14131
14641
 
@@ -14335,6 +14845,16 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
14335
14845
  {
14336
14846
  VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
14337
14847
  } break;
14848
+ case GGML_TYPE_Q4_0_4_4:
14849
+ case GGML_TYPE_Q4_0_4_8:
14850
+ {
14851
+ VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
14852
+ } break;
14853
+ case GGML_TYPE_Q4_0_8_8:
14854
+ {
14855
+ VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
14856
+ } break;
14857
+
14338
14858
  case GGML_TYPE_I8:
14339
14859
  case GGML_TYPE_I16:
14340
14860
  case GGML_TYPE_I32: