@fugood/llama.node 0.3.16 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (281)
  1. package/CMakeLists.txt +6 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +44 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +374 -19
  24. package/src/LlamaCompletionWorker.h +31 -10
  25. package/src/LlamaContext.cpp +216 -7
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
  29. package/src/llama.cpp/.github/workflows/build.yml +89 -767
  30. package/src/llama.cpp/.github/workflows/docker.yml +9 -6
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +19 -23
  33. package/src/llama.cpp/CMakeLists.txt +11 -1
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +35 -4
  37. package/src/llama.cpp/common/arg.cpp +844 -121
  38. package/src/llama.cpp/common/arg.h +9 -0
  39. package/src/llama.cpp/common/chat.cpp +129 -107
  40. package/src/llama.cpp/common/chat.h +2 -0
  41. package/src/llama.cpp/common/common.cpp +64 -518
  42. package/src/llama.cpp/common/common.h +35 -45
  43. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  44. package/src/llama.cpp/common/llguidance.cpp +31 -47
  45. package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
  46. package/src/llama.cpp/common/minja/minja.hpp +186 -127
  47. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  48. package/src/llama.cpp/common/regex-partial.h +56 -0
  49. package/src/llama.cpp/common/sampling.cpp +60 -50
  50. package/src/llama.cpp/docs/build.md +122 -7
  51. package/src/llama.cpp/examples/CMakeLists.txt +2 -32
  52. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
  54. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  55. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  56. package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
  57. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  58. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  59. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  60. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  61. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  62. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  64. package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
  65. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  66. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  67. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  68. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  69. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  70. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  71. package/src/llama.cpp/ggml/include/ggml.h +76 -106
  72. package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
  73. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  74. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  75. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  76. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  77. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  78. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  79. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  80. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  81. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  82. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  83. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
  84. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  85. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  86. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  87. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
  89. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  90. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
  93. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
  94. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
  95. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
  96. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  101. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  102. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
  103. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  104. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  105. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  106. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  107. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  108. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  109. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
  110. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  111. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
  112. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  113. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
  115. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
  116. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
  117. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  120. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
  121. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  122. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  123. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  124. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  136. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  137. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  138. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  140. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  141. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
  143. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
  144. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
  145. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
  146. package/src/llama.cpp/ggml/src/ggml.c +170 -265
  147. package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
  148. package/src/llama.cpp/include/llama.h +82 -22
  149. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  150. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  151. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  152. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  153. package/src/llama.cpp/requirements/requirements-all.txt +5 -3
  154. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  155. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  156. package/src/llama.cpp/src/CMakeLists.txt +4 -2
  157. package/src/llama.cpp/src/llama-adapter.cpp +43 -1
  158. package/src/llama.cpp/src/llama-arch.cpp +163 -17
  159. package/src/llama.cpp/src/llama-arch.h +16 -0
  160. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  161. package/src/llama.cpp/src/llama-batch.h +2 -1
  162. package/src/llama.cpp/src/llama-chat.cpp +91 -16
  163. package/src/llama.cpp/src/llama-chat.h +7 -2
  164. package/src/llama.cpp/src/llama-context.cpp +479 -575
  165. package/src/llama.cpp/src/llama-context.h +44 -33
  166. package/src/llama.cpp/src/llama-cparams.h +1 -0
  167. package/src/llama.cpp/src/llama-graph.cpp +209 -157
  168. package/src/llama.cpp/src/llama-graph.h +38 -14
  169. package/src/llama.cpp/src/llama-hparams.h +13 -0
  170. package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
  171. package/src/llama.cpp/src/llama-kv-cache.h +283 -171
  172. package/src/llama.cpp/src/llama-memory.h +12 -2
  173. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  174. package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
  175. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  176. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  177. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  178. package/src/llama.cpp/src/llama-model.cpp +1803 -330
  179. package/src/llama.cpp/src/llama-model.h +21 -2
  180. package/src/llama.cpp/src/llama-quant.cpp +33 -10
  181. package/src/llama.cpp/src/llama-sampling.cpp +25 -7
  182. package/src/llama.cpp/src/llama-vocab.cpp +86 -10
  183. package/src/llama.cpp/src/llama-vocab.h +6 -0
  184. package/src/llama.cpp/src/llama.cpp +15 -1
  185. package/src/llama.cpp/tests/CMakeLists.txt +52 -31
  186. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  187. package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
  188. package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
  189. package/src/llama.cpp/tests/test-chat.cpp +15 -3
  190. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  191. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  192. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  193. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  194. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  195. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  196. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  197. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  198. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  199. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  200. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  201. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  202. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  203. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  204. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
  205. package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
  206. package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
  207. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  208. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
  209. package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
  210. package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
  211. package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
  212. package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
  213. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  214. package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
  215. package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
  216. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  217. package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
  218. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  219. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
  220. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
  221. package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
  222. package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
  223. package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
  224. package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
  225. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  226. package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
  227. package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
  228. package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
  229. package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
  230. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  231. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  232. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  233. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  234. package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
  235. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  236. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  237. package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
  238. package/src/llama.cpp/examples/llava/clip.h +0 -118
  239. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  240. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  241. package/src/llama.cpp/examples/llava/llava.cpp +0 -574
  242. package/src/llama.cpp/examples/llava/llava.h +0 -49
  243. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  244. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
  245. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  246. package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
  247. package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
  248. package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
  249. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  250. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  251. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  252. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  253. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  254. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  255. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  256. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  257. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  258. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  259. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  260. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  261. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  262. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  263. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  264. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  265. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  266. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  267. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  268. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  269. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  270. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  271. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  272. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  273. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  274. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  275. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  276. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  277. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  278. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  279. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  280. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  281. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
@@ -20,12 +20,6 @@
  #define GROUP_MAX_EPS_IQ1_M 1e-7f
  #define GROUP_MAX_EPS_IQ1_S 1e-12f

- #if defined(_MSC_VER)
- // disable "possible loss of data" to avoid warnings for hundreds of casts
- // we should just be careful :)
- #pragma warning(disable: 4244 4267)
- #endif
-
  #define UNUSED GGML_UNUSED

  // some compilers don't provide _mm256_set_m128i, e.g. gcc 7
@@ -891,15 +885,15 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
  }
  #elif defined(__riscv_v_intrinsic)

- size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+ size_t vl = QK8_0;

  for (int i = 0; i < nb; i++) {
  // load elements
- vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+ vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);

- vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+ vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
  vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
- vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
  float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);

  const float d = amax / ((1 << 7) - 1);
@@ -907,14 +901,14 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i

  y[i].d = GGML_FP32_TO_FP16(d);

- vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+ vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);

  // convert to integer
- vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
- vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+ vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+ vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);

  // store result
- __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+ __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
  }

  #elif defined(__POWER9_VECTOR__)
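The two hunks above change the RISC-V vector path of quantize_row_q8_0 from a `__riscv_vsetvl_e32m4` probe to a fixed vector length equal to the block size (QK8_0 = 32 floats), with LMUL widened from m4 to m8, so a whole block is covered by one vector operation on any core with VLEN >= 128. A minimal standalone sketch of the absolute-maximum step in that pattern, assuming `<riscv_vector.h>` and VLEN >= 128 (illustrative only, not code from the package; the function name is hypothetical):

    #include <riscv_vector.h>

    /* Illustrative sketch: absolute maximum of one 32-float block at LMUL=8,
       mirroring the pattern in the hunks above (assumes VLEN >= 128, so a
       fixed vl = 32 is valid for e32/m8). */
    static float block_absmax_rvv(const float * x) {
        size_t vl = 32;                                    /* fixed block size, no vsetvl probe */
        vfloat32m8_t v    = __riscv_vle32_v_f32m8(x, vl);  /* load the whole block */
        vfloat32m8_t vabs = __riscv_vfabs_v_f32m8(v, vl);  /* |x| per lane */
        vfloat32m1_t zero = __riscv_vfmv_v_f_f32m1(0.0f, vl);
        vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vabs, zero, vl); /* horizontal max */
        return __riscv_vfmv_f_s_f32m1_f32(vmax);
    }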
@@ -1229,15 +1223,15 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
  }
  #elif defined(__riscv_v_intrinsic)

- size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+ size_t vl = QK8_1;

  for (int i = 0; i < nb; i++) {
  // load elements
- vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+ vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);

- vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+ vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
  vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
- vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
  float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);

  const float d = amax / ((1 << 7) - 1);
@@ -1245,18 +1239,18 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i

  y[i].d = GGML_FP32_TO_FP16(d);

- vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+ vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);

  // convert to integer
- vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
- vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+ vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+ vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);

  // store result
- __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+ __riscv_vse8_v_i8m2(y[i].qs , vs, vl);

  // compute sum for y[i].s
  vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
- vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);

  // set y[i].s
  int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
@@ -2391,33 +2385,31 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi

  sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
  #elif defined(__riscv_v_intrinsic)
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
+ size_t vl = qk / 2;

  for (; ib < nb; ++ib) {
  // load elements
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+ vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);

- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+ vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+ vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);

  // mask and store lower part of x, and then upper part
- vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+ vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+ vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);

- vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+ vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+ vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);

  // subtract offset
- vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
- vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
+ vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
+ vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);

- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+ vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+ vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);

  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);

  int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

@@ -2783,29 +2775,27 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi

  sumf = hsum_float_8(acc) + summs;
  #elif defined(__riscv_v_intrinsic)
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
+ size_t vl = qk / 2;

  for (; ib < nb; ++ib) {
  // load elements
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+ vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);

- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+ vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+ vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);

  // mask and store lower part of x, and then upper part
- vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+ vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+ vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);

- vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+ vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+ vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);

- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+ vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+ vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);

  vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);

  int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);

@@ -3132,65 +3122,33 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi

  sumf = hsum_float_8(acc);
  #elif defined(__riscv_v_intrinsic)
- uint32_t qh;
-
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
- // These temporary registers are for masking and shift operations
- vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
- vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
-
- vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
- vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+ size_t vl;
+ size_t vlenb = __riscv_vlenb();

  for (; ib < nb; ++ib) {
- memcpy(&qh, x[ib].qh, sizeof(uint32_t));
-
- // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
- vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
- vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
- vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
-
- // ((qh & (1u << (j + 16))) >> (j + 12));
- vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
- vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
-
- // narrowing
- vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
- vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
-
- vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
- vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
-
- // load
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
-
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
-
- vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
- vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
- vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
- vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
- vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
- vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
-
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
-
- sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+ vl = qk / 2;
+ vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+ vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+ vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+ vint8m2_t v0c;
+ if (vlenb == 16) {
+ v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+ } else {
+ v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+ v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+ }
+
+ vl = qk;
+ vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+ qh = __riscv_vmnand_mm_b4(qh, qh, vl);
+ vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+ vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+ vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+ vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+ int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
+
+ sumf += (GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)) * sumi;
  }

  #elif defined(__POWER9_VECTOR__)
@@ -3503,60 +3461,30 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi

  sumf = hsum_float_8(acc) + summs;
  #elif defined(__riscv_v_intrinsic)
- uint32_t qh;
-
- size_t vl = __riscv_vsetvl_e8m1(qk/2);
-
- // temporary registers for shift operations
- vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
- vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+ size_t vl;
+ size_t vlenb = __riscv_vlenb();

  for (; ib < nb; ++ib) {
- memcpy(&qh, x[ib].qh, sizeof(uint32_t));
-
- // load qh
- vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
-
- // ((qh >> (j + 0)) << 4) & 0x10;
- vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
- vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
- vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
-
- // ((qh >> (j + 12)) ) & 0x10;
- vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
- vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
-
- // narrowing
- vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
- vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
-
- vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
- vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
-
- // load
- vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
-
- vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
- vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
-
- vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
- vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
-
- vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
- vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
-
- vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
- vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
-
- vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
- vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
-
- vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
-
- vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
- vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
-
- int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+ vl = qk / 2;
+ vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+ vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+ vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+ vint8m2_t v0c;
+ if (vlenb == 16) {
+ v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+ } else {
+ v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+ v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+ }
+
+ vl = qk;
+ vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+ vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+ vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+ vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+ vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+ int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);

  sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
  }
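In the rewritten q5_0/q5_1 paths above, the 32 high bits stored in x[ib].qh are loaded directly into a vector mask register and the fifth bit is applied with a single masked operation, replacing the old widen-shift-narrow sequence. A small sketch of that masked-OR step under the same assumptions (VLEN >= 128, RVV intrinsics with policy variants); the helper name is hypothetical:

    #include <riscv_vector.h>

    /* Hypothetical helper: OR 0x10 into the lanes whose qh bit is set, leaving
       the other lanes untouched (the q5_1 hunk above does exactly this; the
       q5_0 hunk first inverts the mask and subtracts 0x10 instead). */
    static vint8m2_t apply_high_bits(vint8m2_t nibbles, const uint8_t * qh_bytes) {
        size_t vl = 32;                                /* 32 unpacked 4-bit values */
        vbool4_t qh = __riscv_vlm_v_b4(qh_bytes, vl);  /* load 32 mask bits (4 bytes) */
        return __riscv_vor_vx_i8m2_mu(qh, nibbles, nibbles, 0x10, vl);
    }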
@@ -3970,17 +3898,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi

  sumf = hsum_float_8(accum);
  #elif defined(__riscv_v_intrinsic)
- size_t vl = __riscv_vsetvl_e8m1(qk);
+ size_t vl = qk;

  for (; ib < nb; ++ib) {
  // load elements
- vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl);
- vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+ vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);
+ vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);

- vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
+ vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);

  vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
- vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+ vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);

  int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);

@@ -5174,84 +5102,182 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi

  #elif defined __riscv_v_intrinsic

+ const int vector_length = __riscv_vlenb() * 8;
  float sumf = 0;
- uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};

- for (int i = 0; i < nb; ++i) {
+ uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
+ uint8_t atmp[16];

- const uint8_t * q2 = x[i].qs;
- const int8_t * q8 = y[i].qs;
- const uint8_t * sc = x[i].scales;
-
- const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
- const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+ switch (vector_length) {
+ case 256:
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * q2 = x[i].qs;
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * sc = x[i].scales;

- size_t vl = 16;
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);

- vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
- vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+ size_t vl = 16;

- vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
+ vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+ vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);

- vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
- vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
- vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
- vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
- vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+ vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);

- sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
+ vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+ vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+ vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+ vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);

- vl = 32;
+ sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);

- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
- vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+ vl = 32;

- uint8_t is=0;
- int isum=0;
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+ vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);

- for (int j = 0; j < QK_K/128; ++j) {
- // load Q2
- vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+ uint8_t is = 0;
+ int isum = 0;

- vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
- vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl);
- vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl);
- vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
+ for (int j = 0; j < QK_K / 128; ++j) {
+ // load Q2
+ vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);

- // duplicate scale elements for product
- vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl);
- vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl);
- vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl);
- vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+ vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+ vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
+ vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
+ vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);

- vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
- vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
- vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
- vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
+ // duplicate scale elements for product
+ vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
+ vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
+ vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
+ vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);

- // load Q8
- vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
- vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
- vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl);
- vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+ vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+ vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+ vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+ vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));

- vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
- vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
- vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
- vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
+ // load Q8
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
+ vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
+ vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);

- vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
- vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
+ vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+ vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+ vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+ vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);

- isum += __riscv_vmv_x_s_i32m1_i32(isum1);
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);

- q2+=32; q8+=128; is=8;
+ isum += __riscv_vmv_x_s_i32m1_i32(isum1);

- }
+ q2 += 32;
+ q8 += 128;
+ is = 8;
+ }

- sumf += dall * isum;
+ sumf += dall * isum;
+ }
+ break;
+ case 128:
+ for (int i = 0; i < nb; ++i) {
+ const uint8_t * q2 = x[i].qs;
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * sc = x[i].scales;
+ const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+ const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+ uint8_t *patmp = atmp;
+ int vsums;
+ int tmp;
+ __asm__ __volatile__(
+ "vsetivli zero, 16, e8, m1\n\t"
+ "vmv.v.x v8, zero\n\t"
+ "vle8.v v1, (%[sc])\n\t"
+ "vand.vi v0, v1, 0xF\n\t"
+ "vsrl.vi v1, v1, 4\n\t"
+ "vse8.v v0, (%[scale])\n\t"
+ "vsetivli zero, 16, e16, m2\n\t"
+ "vle16.v v2, (%[bsums])\n\t"
+ "vzext.vf2 v0, v1\n\t"
+ "vwmul.vv v4, v0, v2\n\t"
+ "vsetivli zero, 16, e32, m4\n\t"
+ "vredsum.vs v8, v4, v8\n\t"
+ "vmv.x.s %[vsums], v8"
+ : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
+ : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ sumf += dmin * vsums;
+ int isum = 0;
+
+ for (int j = 0; j < QK_K/128; ++j) {
+ __asm__ __volatile__(
+ "vsetvli zero, %[vl32], e8, m2\n\t"
+ "vle8.v v0, (%[q2])\n\t"
+ "vsrl.vi v2, v0, 2\n\t"
+ "vsrl.vi v4, v0, 4\n\t"
+ "vsrl.vi v6, v0, 6\n\t"
+ "vand.vi v0, v0, 0x3\n\t"
+ "vand.vi v2, v2, 0x3\n\t"
+ "vand.vi v4, v4, 0x3\n\t"
+ "vsetvli zero, %[vl128], e8, m8\n\t"
+ "vle8.v v8, (%[q8])\n\t"
+ "vsetvli zero, %[vl64], e8, m4\n\t"
+ "vwmul.vv v16, v0, v8\n\t"
+ "vwmul.vv v24, v4, v12\n\t"
+ "vsetivli zero, 16, e16, m2\n\t"
+ "vmv.v.x v0, zero\n\t"
+ "vwredsum.vs v10, v16, v0\n\t"
+ "vwredsum.vs v9, v18, v0\n\t"
+ "vwredsum.vs v8, v20, v0\n\t"
+ "vwredsum.vs v7, v22, v0\n\t"
+ "vwredsum.vs v11, v24, v0\n\t"
+ "vwredsum.vs v12, v26, v0\n\t"
+ "vwredsum.vs v13, v28, v0\n\t"
+ "vwredsum.vs v14, v30, v0\n\t"
+ "vsetivli zero, 4, e32, m1\n\t"
+ "vslideup.vi v10, v9, 1\n\t"
+ "vslideup.vi v8, v7, 1\n\t"
+ "vslideup.vi v11, v12, 1\n\t"
+ "vslideup.vi v13, v14, 1\n\t"
+ "vslideup.vi v10, v8, 2\n\t"
+ "vslideup.vi v11, v13, 2\n\t"
+ "vsetivli zero, 8, e32, m2\n\t"
+ "vle8.v v15, (%[scale])\n\t"
+ "vzext.vf4 v12, v15\n\t"
+ "vmul.vv v10, v10, v12\n\t"
+ "vredsum.vs v0, v10, v0\n\t"
+ "vmv.x.s %[tmp], v0\n\t"
+ "add %[isum], %[isum], %[tmp]"
+ : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+ : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
+ : "memory"
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+ );
+ q2 += 32; q8 += 128; patmp += 8;
+ }

+ sumf += dall * isum;
+ }
+ break;
+ default:
+ assert(false && "Unsupported vector length");
+ break;
  }

  *s = sumf;
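The q2_K hunk above (and the q3_K hunk that follows) switch on the runtime vector register size: the original intrinsics code is kept for 256-bit vectors, and a hand-written inline-assembly path is added for 128-bit vectors. A skeleton of that dispatch, assuming `<riscv_vector.h>`; the function name is hypothetical and the branch bodies are placeholders:

    #include <assert.h>
    #include <riscv_vector.h>

    /* Hypothetical skeleton of the VLEN dispatch used by the new K-quant paths. */
    static void vec_dot_dispatch_sketch(void) {
        const int vector_length = __riscv_vlenb() * 8;   /* VLEN in bits */
        switch (vector_length) {
            case 256:
                /* intrinsics path: 32-byte vector registers at LMUL=1 */
                break;
            case 128:
                /* inline-assembly path tuned for 16-byte vector registers */
                break;
            default:
                assert(false && "Unsupported vector length");
                break;
        }
    }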
@@ -6116,97 +6142,221 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
6116
6142
  uint32_t aux[3];
6117
6143
  uint32_t utmp[4];
6118
6144
 
6145
+ const int vector_length = __riscv_vlenb() * 8;
6119
6146
  float sumf = 0;
6120
- for (int i = 0; i < nb; ++i) {
6121
6147
 
6122
- const uint8_t * GGML_RESTRICT q3 = x[i].qs;
6123
- const uint8_t * GGML_RESTRICT qh = x[i].hmask;
6124
- const int8_t * GGML_RESTRICT q8 = y[i].qs;
6148
+ switch (vector_length) {
6149
+ case 256:
6150
+ for (int i = 0; i < nb; ++i) {
6125
6151
 
6126
- memcpy(aux, x[i].scales, 12);
6127
- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
6128
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
6129
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
6130
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
6152
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
6153
+ const uint8_t * GGML_RESTRICT qh = x[i].hmask;
6154
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
6131
6155
 
6132
- int8_t * scale = (int8_t *)utmp;
6133
- for (int j = 0; j < 16; ++j) scale[j] -= 32;
6156
+ memcpy(aux, x[i].scales, 12);
6157
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
6158
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
6159
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
6160
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
6134
6161
 
6162
+ int8_t * scale = (int8_t *)utmp;
6163
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
6135
6164
 
6136
- size_t vl = 32;
6137
- uint8_t m = 1;
6138
6165
 
6139
- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
6140
- vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
6166
+ size_t vl = 32;
6167
+ uint8_t m = 1;
6141
6168
 
6142
- int sum_t = 0;
6169
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
6170
+ vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
6143
6171
 
6144
- for (int j = 0; j < QK_K; j += 128) {
6172
+ int sum_t = 0;
6145
6173
 
6146
- vl = 32;
6174
+ for (int j = 0; j < QK_K; j += 128) {
6147
6175
 
6148
- // load Q3
6149
- vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
6176
+ vl = 32;
6150
6177
 
6151
- vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
6152
- vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
6153
- vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
6154
- vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
6178
+ // load Q3
6179
+ vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
6155
6180
 
6156
- // compute mask for subtraction
6157
- vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
6158
- vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
6159
- vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
6160
- m <<= 1;
6181
+ vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
6182
+ vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
6183
+ vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
6184
+ vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
6161
6185
 
6162
- vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
6163
- vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
6164
- vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
6165
- m <<= 1;
6186
+ // compute mask for subtraction
6187
+ vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
6188
+ vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
6189
+ vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
6190
+ m <<= 1;
6166
6191
 
6167
- vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
6168
- vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
6169
- vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
6170
- m <<= 1;
6192
+ vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
6193
+ vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
6194
+ vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
6195
+ m <<= 1;
6171
6196
 
6172
- vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
6173
- vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
6174
- vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
6175
- m <<= 1;
6197
+ vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
6198
+ vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
6199
+ vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
6200
+ m <<= 1;
6176
6201
 
6177
- // load Q8 and take product with Q3
6178
- vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
6179
- vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
6180
- vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
6181
- vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
6202
+ vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
6203
+ vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
6204
+ vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
6205
+ m <<= 1;
6182
6206
 
6183
- vl = 16;
6207
+ // load Q8 and take product with Q3
6208
+ vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
6209
+ vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
6210
+ vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
6211
+ vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
6184
6212
 
6185
- // retrieve lane to multiply with scale
6186
- vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
6187
- vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
6188
- vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
6189
- vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
6190
- vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
6191
- vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
6192
- vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
6193
- vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
6213
+ vl = 16;
6194
6214
 
6195
- vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
6196
- vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
6197
- vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
6198
- vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
6215
+ // retrieve lane to multiply with scale
6216
+ vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
6217
+ vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
6218
+ vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
6219
+ vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
6220
+ vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
6221
+ vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
6222
+ vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
6223
+ vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
6199
6224
 
6200
- sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
6225
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
6226
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
6227
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
6228
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
6201
6229
 
6202
- q3 += 32; q8 += 128; scale += 8;
6230
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
6203
6231
 
6204
- }
6232
+ q3 += 32; q8 += 128; scale += 8;
6205
6233
 
6206
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
6234
+ }
6235
+
6236
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
6237
+
6238
+ sumf += d*sum_t;
6239
+
6240
+ }
6241
+ break;
6242
+ case 128:
6243
+ for (int i = 0; i < nb; ++i) {
6244
+ const uint8_t * restrict q3 = x[i].qs;
6245
+ const uint8_t * restrict qh = x[i].hmask;
6246
+ const int8_t * restrict q8 = y[i].qs;
6247
+
6248
+ int8_t * scale = (int8_t *)utmp;
6249
+ int tmp;
6250
+ __asm__ __volatile__(
6251
+ "vsetivli zero, 12, e8, m1\n\t"
6252
+ "vle8.v v0, (%[s6b])\n\t"
6253
+ "vmv1r.v v2, v0\n\t"
6254
+ "vsetivli zero, 2, e64, m1\n\t"
6255
+ "vmv.v.x v9, %[sh]\n\t"\
6256
+ "vslidedown.vi v1, v0, 1\n\t"
6257
+ "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
6258
+ "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
6259
+ "vsetivli zero, 4, e32, m1\n\t"
6260
+ "vid.v v9\n\t"
6261
+ "vmv.x.s %[tmp], v1\n\t"
6262
+ "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
6263
+ "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
6264
+ "vsrl.vv v4, v1, v9\n\t"
6265
+ "vsrl.vv v2, v0, v8\n\t"
6266
+ "vand.vx v5, v4, %[kmask1]\n\t"
6267
+ "vand.vx v3, v2, %[kmask2]\n\t"
6268
+ "vsll.vi v6, v5, 4\n\t"
6269
+ "vor.vv v7, v6, v3\n\t"
6270
+ "vsetivli zero, 16, e8, m1\n\t"
6271
+ "vsub.vx v0, v7, %[c]\n\t"
6272
+ "vse8.v v0, (%[scale])"
6273
+ : [tmp] "=&r" (tmp)
6274
+ : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
6275
+ , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
6276
+ : "memory"
6277
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
6278
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
6279
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
6280
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
6281
+ );
6207
6282
 
6208
- sumf += d*sum_t;
6283
+ uint8_t m = 1;
6284
+ int isum = 0;
6285
+ for (int j = 0; j < QK_K; j += 128) {
6286
+ __asm__ __volatile__(
6287
+ "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
6288
+ "vle8.v v8, (%[q3])\n\t"
6289
+ "vsrl.vi v10, v8, 2\n\t"
6290
+ "vsrl.vi v12, v8, 4\n\t"
6291
+ "vsrl.vi v14, v8, 6\n\t"
6292
+ "vand.vi v8, v8, 3\n\t"
6293
+ "vand.vi v10, v10, 3\n\t"
6294
+ "vand.vi v12, v12, 3\n\t"
6295
+ "vle8.v v2, (%[qh])\n\t"
6296
+ "vand.vx v4, v2, %[m]\n\t"
6297
+ "slli %[m], %[m], 1\n\t"
6298
+ "vmseq.vx v0, v4, zero\n\t"
6299
+ "vadd.vi v8, v8, -4, v0.t\n\t"
6300
+ "vand.vx v4, v2, %[m]\n\t"
6301
+ "slli %[m], %[m], 1\n\t"
6302
+ "vmseq.vx v0, v4, zero\n\t"
6303
+ "vadd.vi v10, v10, -4, v0.t\n\t"
6304
+ "vand.vx v4, v2, %[m]\n\t"
6305
+ "slli %[m], %[m], 1\n\t"
6306
+ "vmseq.vx v0, v4, zero\n\t"
6307
+ "vadd.vi v12, v12, -4, v0.t\n\t"
6308
+ "vand.vx v4, v2, %[m]\n\t"
6309
+ "slli %[m], %[m], 1\n\t"
6310
+ "vmseq.vx v0, v4, zero\n\t"
6311
+ "vadd.vi v14, v14, -4, v0.t\n\t"
6312
+ "vsetvli zero, %[vl128], e8, m8\n\t"
6313
+ "vle8.v v0, (%[q8])\n\t"
6314
+ "vsetvli zero, %[vl64], e8, m4\n\t"
6315
+ "vwmul.vv v16, v0, v8\n\t"
6316
+ "vwmul.vv v24, v4, v12\n\t"
6317
+ "vsetivli zero, 16, e16, m2\n\t"
6318
+ "vmv.v.x v0, zero\n\t"
6319
+ "vwredsum.vs v10, v16, v0\n\t"
6320
+ "vwredsum.vs v9, v18, v0\n\t"
6321
+ "vwredsum.vs v8, v20, v0\n\t"
6322
+ "vwredsum.vs v7, v22, v0\n\t"
6323
+ "vwredsum.vs v11, v24, v0\n\t"
6324
+ "vwredsum.vs v12, v26, v0\n\t"
6325
+ "vwredsum.vs v13, v28, v0\n\t"
6326
+ "vwredsum.vs v14, v30, v0\n\t"
6327
+ "vsetivli zero, 4, e32, m1\n\t"
6328
+ "vslideup.vi v10, v9, 1\n\t"
6329
+ "vslideup.vi v8, v7, 1\n\t"
6330
+ "vslideup.vi v11, v12, 1\n\t"
6331
+ "vslideup.vi v13, v14, 1\n\t"
6332
+ "vslideup.vi v10, v8, 2\n\t"
6333
+ "vslideup.vi v11, v13, 2\n\t"
6334
+ "vsetivli zero, 8, e32, m2\n\t"\
6335
+ "vle8.v v15, (%[scale])\n\t"
6336
+ "vsext.vf4 v12, v15\n\t"
6337
+ "vmul.vv v10, v10, v12\n\t"
6338
+ "vredsum.vs v0, v10, v0\n\t"
6339
+ "vmv.x.s %[tmp], v0\n\t"
6340
+ "add %[isum], %[isum], %[tmp]"
6341
+ : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
6342
+ : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
6343
+ , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
6344
+ : "memory"
6345
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
6346
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
6347
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
6348
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
6349
+ );
6350
+ q3 += 32; q8 += 128; scale += 8;
6351
+ }
6209
6352
 
6353
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
6354
+ sumf += d * isum;
6355
+ }
6356
+ break;
6357
+ default:
6358
+ assert(false && "Unsupported vector length");
6359
+ break;
6210
6360
  }
6211
6361
 
6212
6362
  *s = sumf;
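The RISC-V path above now selects a kernel at run time from the hardware vector length. A minimal sketch of that dispatch pattern is below; the function name is illustrative only, but `__riscv_vlenb()` is the same intrinsic the diff itself uses.

#include <riscv_vector.h>

// Illustrative only: pick an implementation based on the runtime vector
// register width (VLEN, in bits), as the q3_K/q4_K/q6_K kernels above do.
static void dispatch_by_vlen_sketch(void) {
    const int vector_length = __riscv_vlenb() * 8;   // __riscv_vlenb() returns VLEN in bytes
    switch (vector_length) {
        case 256: /* intrinsics path: 32 e8 elements fit in one m1 register */ break;
        case 128: /* inline-asm path: relies on larger LMUL groupings */       break;
        default:  /* unsupported vector length */                             break;
    }
}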
@@ -6440,7 +6590,118 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
6440
6590
  }
6441
6591
 
6442
6592
  *s = hsum_float_8(acc);
6593
+ #elif defined(__VXE__) || defined(__VXE2__)
6594
+ uint32_t aux[3];
6595
+ uint32_t utmp[4];
6596
+
6597
+ const int32x4_t v_z = vec_splat_s32(0);
6598
+ const uint8x16_t v_3m = vec_splat_u8(0x03);
6599
+
6600
+ const uint8x16_t v_0c = vec_splat_u8(1);
6601
+ const uint8x16_t v_1c = vec_sl(v_0c, 1);
6602
+ const uint8x16_t v_2c = vec_sl(v_0c, 2);
6603
+ const uint8x16_t v_3c = vec_sl(v_0c, 3);
6604
+
6605
+ uint8x16_t q3h[4];
6606
+ uint8x16_t q3b[2];
6607
+ int8x16_t q3bytes[4];
6608
+ int8x16_t q8bytes[4];
6609
+ uint8x16_t qhbits[2];
6610
+
6611
+ float sum = 0;
6612
+
6613
+ for (int i = 0; i < nb; ++i) {
6614
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
6615
+
6616
+ const uint8_t * restrict x0l = x[i].qs;
6617
+ const uint8_t * restrict x0h = x[i].hmask;
6618
+ const int8_t * restrict y0 = y[i].qs;
6619
+
6620
+ qhbits[0] = vec_xl(0 , x0h);
6621
+ qhbits[1] = vec_xl(16, x0h);
6622
+
6623
+ int32_t isum = 0;
6443
6624
 
6625
+ memcpy(aux, x[i].scales, 12);
6626
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
6627
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
6628
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
6629
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
6630
+
6631
+ int8_t * scale = (int8_t *)utmp;
6632
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
6633
+
6634
+ for (int j = 0; j < QK_K/128; ++j) {
6635
+ int32x4_t isum0, isum1, isum2, isum3;
6636
+
6637
+ q3b[0] = vec_xl(0 , x0l);
6638
+ q3b[1] = vec_xl(16, x0l);
6639
+ x0l += 32;
6640
+
6641
+ q8bytes[0] = vec_xl(0 , y0);
6642
+ q8bytes[1] = vec_xl(16 , y0);
6643
+ q8bytes[2] = vec_xl(32 , y0);
6644
+ q8bytes[3] = vec_xl(48 , y0);
6645
+ q8bytes[4] = vec_xl(64 , y0);
6646
+ q8bytes[5] = vec_xl(80 , y0);
6647
+ q8bytes[6] = vec_xl(96 , y0);
6648
+ q8bytes[7] = vec_xl(112, y0);
6649
+ y0 += 128;
6650
+
6651
+ q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2);
6652
+ q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2);
6653
+ q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1);
6654
+ q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1);
6655
+
6656
+ q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]);
6657
+ q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]);
6658
+ q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]);
6659
+ q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]);
6660
+
6661
+ isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]);
6662
+ isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]);
6663
+ isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]);
6664
+ isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]);
6665
+
6666
+ isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
6667
+ isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
6668
+ isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
6669
+ isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
6670
+
6671
+ scale += 4;
6672
+
6673
+ q3h[0] = vec_andc(v_2c, qhbits[0]);
6674
+ q3h[1] = vec_andc(v_2c, qhbits[1]);
6675
+ q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1);
6676
+ q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1);
6677
+
6678
+ q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]);
6679
+ q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]);
6680
+ q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]);
6681
+ q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]);
6682
+
6683
+ isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]);
6684
+ isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]);
6685
+ isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
6686
+ isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
6687
+
6688
+ isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
6689
+ isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
6690
+ isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
6691
+ isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
6692
+
6693
+ scale += 4;
6694
+
6695
+ if (j == 0) {
6696
+ qhbits[0] = vec_sr(qhbits[0], 4);
6697
+ qhbits[1] = vec_sr(qhbits[1], 4);
6698
+ }
6699
+ }
6700
+
6701
+ sum += d * isum;
6702
+ }
6703
+
6704
+ *s = sum;
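Both SIMD paths above reconstruct each q3_K weight the same way before the dot product. A minimal scalar sketch of that step follows, assuming the standard q3_K layout (two low bits per value in qs, one high bit per value in hmask); the helper name is hypothetical.

#include <stdint.h>

// Sketch: a q3_K weight is two low bits from qs plus one bit from hmask;
// a *clear* hmask bit shifts the value down by 4, which is what the
// "vmseq ... / vadd.vi ..., -4, v0.t" and vec_andc sequences above compute.
static inline int q3k_weight_sketch(uint8_t low2, int hmask_bit_set) {
    int v = low2 & 3;              // 2-bit field, range [0, 3]
    if (!hmask_bit_set) v -= 4;    // missing high bit -> range [-4, -1]
    return v;                      // final range [-4, 3]
}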
6444
6705
  #else
6445
6706
  // scalar version
6446
6707
  // This function is written like this so the compiler can manage to vectorize most of it
@@ -6924,69 +7185,181 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
6924
7185
  const uint8_t * scales = (const uint8_t*)&utmp[0];
6925
7186
  const uint8_t * mins = (const uint8_t*)&utmp[2];
6926
7187
 
7188
+ const int vector_length = __riscv_vlenb() * 8;
6927
7189
  float sumf = 0;
6928
7190
 
6929
- for (int i = 0; i < nb; ++i) {
7191
+ switch (vector_length) {
7192
+ case 256:
7193
+ for (int i = 0; i < nb; ++i) {
6930
7194
 
6931
- size_t vl = 8;
7195
+ size_t vl = 8;
6932
7196
 
6933
- const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
6934
- const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
7197
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
7198
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
6935
7199
 
6936
- vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
6937
- vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
6938
- vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
7200
+ vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
7201
+ vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
7202
+ vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
6939
7203
 
6940
- memcpy(utmp, x[i].scales, 12);
6941
- utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
6942
- const uint32_t uaux = utmp[1] & kmask1;
6943
- utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
6944
- utmp[2] = uaux;
6945
- utmp[0] &= kmask1;
7204
+ memcpy(utmp, x[i].scales, 12);
7205
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
7206
+ const uint32_t uaux = utmp[1] & kmask1;
7207
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
7208
+ utmp[2] = uaux;
7209
+ utmp[0] &= kmask1;
6946
7210
 
6947
- vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
6948
- vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
6949
- vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
7211
+ vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
7212
+ vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
7213
+ vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
6950
7214
 
6951
- vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
6952
- sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
7215
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
7216
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
6953
7217
 
6954
- const uint8_t * GGML_RESTRICT q4 = x[i].qs;
6955
- const int8_t * GGML_RESTRICT q8 = y[i].qs;
7218
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
7219
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
6956
7220
 
6957
- vl = 32;
7221
+ vl = 32;
6958
7222
 
6959
- int32_t sum_1 = 0;
6960
- int32_t sum_2 = 0;
7223
+ int32_t sum_1 = 0;
7224
+ int32_t sum_2 = 0;
6961
7225
 
6962
- vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
7226
+ vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
6963
7227
 
6964
- for (int j = 0; j < QK_K/64; ++j) {
6965
- // load Q4
6966
- vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
7228
+ for (int j = 0; j < QK_K/64; ++j) {
7229
+ // load Q4
7230
+ vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
6967
7231
 
6968
- // load Q8 and multiply it with lower Q4 nibble
6969
- vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
6970
- vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
6971
- vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
6972
- vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
7232
+ // load Q8 and multiply it with lower Q4 nibble
7233
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
7234
+ vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
7235
+ vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
7236
+ vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
6973
7237
 
6974
- sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
7238
+ sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
6975
7239
 
6976
- // load Q8 and multiply it with upper Q4 nibble
6977
- vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
6978
- vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
6979
- vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
6980
- vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
7240
+ // load Q8 and multiply it with upper Q4 nibble
7241
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
7242
+ vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
7243
+ vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
7244
+ vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
6981
7245
 
6982
- sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
7246
+ sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
6983
7247
 
6984
- q4 += 32; q8 += 64;
7248
+ q4 += 32; q8 += 64;
6985
7249
 
6986
- }
7250
+ }
6987
7251
 
6988
- sumf += d*(sum_1 + sum_2);
7252
+ sumf += d*(sum_1 + sum_2);
7253
+
7254
+ }
7255
+ break;
7256
+ case 128:
7257
+ for (int i = 0; i < nb; ++i) {
7258
+ const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
7259
+ const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
7260
+
7261
+ int tmp, tmp2, sumi;
7262
+ __asm__ __volatile__(
7263
+ "vsetivli zero, 12, e8, m1\n\t"
7264
+ "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
7265
+ "vsetivli zero, 4, e32, m1\n\t"
7266
+ "vslidedown.vi v2, v1, 2\n\t"
7267
+ "vmv1r.v v3, v2\n\t"
7268
+ "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
7269
+ "vsetivli zero, 2, e32, m1\n\t"
7270
+ "vmv.v.i v4, 4\n\t"
7271
+ "vand.vx v8, v1, %[kmask1]\n\t"
7272
+ "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
7273
+ "vsrl.vi v6, v1, 6\n\t"
7274
+ "vsrl.vv v7, v2, v5\n\t"
7275
+ "vand.vx v0, v6, %[kmask3]\n\t"
7276
+ "vand.vx v2, v7, %[kmask2]\n\t"
7277
+ "vsll.vi v6, v0, 4\n\t"
7278
+ "li %[t2], 8\n\t"
7279
+ "addi %[t1], %[utmp], 4\n\t"
7280
+ "vor.vv v1, v6, v2\n\t"
7281
+ "vsse32.v v8, (%[utmp]), %[t2]\n\t"
7282
+ "vsse32.v v1, (%[t1]), %[t2]\n\t"
7283
+ "vsetivli zero, 8, e16, m1\n\t"
7284
+ "vle32.v v2, (%[bsums])\n\t"
7285
+ "vnsrl.wi v0, v2, 0\n\t"
7286
+ "vnsrl.wi v1, v2, 16\n\t"
7287
+ "vadd.vv v2, v0, v1\n\t"
7288
+ "vle8.v v3, (%[mins])\n\t"
7289
+ "vzext.vf2 v4, v3\n\t"
7290
+ "vwmul.vv v6, v4, v2\n\t"
7291
+ "vmv.v.x v0, zero\n\t"
7292
+ "vsetivli zero, 8, e32, m2\n\t"
7293
+ "vredsum.vs v0, v6, v0\n\t"
7294
+ "vmv.x.s %[sumi], v0"
7295
+ : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
7296
+ : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
7297
+ , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
7298
+ , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
7299
+ : "memory"
7300
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
7301
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
7302
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
7303
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
7304
+ );
7305
+ sumf -= dmin * sumi;
7306
+
7307
+ const uint8_t * restrict q4 = x[i].qs;
7308
+ const int8_t * restrict q8 = y[i].qs;
7309
+
7310
+ sumi = 0;
7311
+ const uint8_t * scale = scales;
7312
+
7313
+ for (int j = 0; j < QK_K/128; ++j) {
7314
+ int vl128 = 128, vl64 = 64, vl32 = 32;
7315
+ __asm__ __volatile__(
7316
+ "vsetvli zero, %[vl128], e8, m8\n\t"
7317
+ "vle8.v v8, (%[q8])\n\t"
7318
+ "vsetvli zero, %[vl64], e8, m4\n\t"
7319
+ "vle8.v v0, (%[q4])\n\t"
7320
+ "vsrl.vi v4, v0, 4\n\t"
7321
+ "vand.vi v0, v0, 0xF\n\t"
7322
+ "vsetvli zero, %[vl32], e8, m2\n\t"
7323
+ "vwmul.vv v28, v6, v14\n\t"
7324
+ "vwmul.vv v20, v4, v10\n\t"
7325
+ "vwmul.vv v24, v2, v12\n\t"
7326
+ "vwmul.vv v16, v0, v8\n\t"
7327
+ "vsetivli zero, 4, e32, m1\n\t"
7328
+ "vle8.v v2, (%[scale])\n\t"
7329
+ "vmv.v.x v0, zero\n\t"
7330
+ "vzext.vf4 v1, v2\n\t"
7331
+ "vsetvli zero, %[vl32], e16, m4\n\t"
7332
+ "vwredsum.vs v6, v24, v0\n\t"
7333
+ "vwredsum.vs v7, v28, v0\n\t"
7334
+ "vwredsum.vs v4, v16, v0\n\t"
7335
+ "vwredsum.vs v5, v20, v0\n\t"
7336
+ "vsetivli zero, 4, e32, m1\n\t"
7337
+ "vslideup.vi v6, v7, 1\n\t"
7338
+ "vslideup.vi v4, v5, 1\n\t"
7339
+ "vslideup.vi v4, v6, 2\n\t"
7340
+ "vmul.vv v8, v4, v1\n\t"
7341
+ "vredsum.vs v0, v8, v0\n\t"
7342
+ "vmv.x.s %[tmp], v0\n\t"
7343
+ "add %[sumi], %[sumi], %[tmp]"
7344
+ : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
7345
+ : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
7346
+ , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
7347
+ : "memory"
7348
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
7349
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
7350
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
7351
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
7352
+ );
7353
+
7354
+ q4 += 64; q8 += 128; scale += 4;
7355
+ }
6989
7356
 
7357
+ sumf += d * sumi;
7358
+ }
7359
+ break;
7360
+ default:
7361
+ assert(false && "Unsupported vector length");
7362
+ break;
6990
7363
  }
6991
7364
 
6992
7365
  *s = sumf;
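For reference, a scalar sketch of the mins correction that both q4_K paths above derive from the precomputed q8 block sums (QK_K = 256, eight 32-element sub-blocks per super-block). The function name is illustrative; the arithmetic mirrors the strided bsums loads and the final "sumf -= dmin * sumi" above.

#include <stdint.h>

// Sketch: instead of subtracting each sub-block minimum from every weight,
// the kernel subtracts dmin * sum_j(min_j * bsum_j), where bsum_j is the sum
// of the 32 q8 values covered by sub-block j (two adjacent bsums entries).
static float q4k_min_correction_sketch(float dmin, const uint8_t mins[8],
                                       const int16_t bsums[16]) {
    int32_t sumi = 0;
    for (int j = 0; j < 8; ++j) {
        sumi += (int32_t)mins[j] * (bsums[2*j + 0] + bsums[2*j + 1]);
    }
    return dmin * (float)sumi;   // caller subtracts this from sumf
}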
@@ -7722,9 +8095,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
7722
8095
  const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
7723
8096
  const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
7724
8097
 
7725
- vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
7726
- vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
7727
- vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
8098
+ vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);
8099
+ vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);
8100
+ vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);
7728
8101
 
7729
8102
  memcpy(utmp, x[i].scales, 12);
7730
8103
  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
@@ -7733,11 +8106,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
7733
8106
  utmp[2] = uaux;
7734
8107
  utmp[0] &= kmask1;
7735
8108
 
7736
- vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
7737
- vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
7738
- vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
8109
+ vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);
8110
+ vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
8111
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);
7739
8112
 
7740
- vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
8113
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
7741
8114
  sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
7742
8115
 
7743
8116
  vl = 32;
@@ -7746,43 +8119,42 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
7746
8119
 
7747
8120
  uint8_t m = 1;
7748
8121
  vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
7749
- vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
8122
+ vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);
7750
8123
 
7751
8124
  for (int j = 0; j < QK_K/64; ++j) {
7752
8125
  // load Q5 and Q8
7753
- vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
7754
- vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
7755
- vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
8126
+ vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);
8127
+ vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl);
8128
+ vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);
7756
8129
 
7757
8130
  // compute mask for addition
7758
- vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
7759
- vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
7760
- vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
7761
- vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl);
8131
+ vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));
8132
+ vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);
8133
+ vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);
8134
+ vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);
7762
8135
  m <<= 1;
7763
8136
 
7764
- vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
7765
- vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
7766
- vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
7767
- vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl);
8137
+ vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));
8138
+ vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);
8139
+ vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);
8140
+ vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);
7768
8141
  m <<= 1;
7769
8142
 
7770
- vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
7771
- vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
8143
+ vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);
8144
+ vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);
7772
8145
 
7773
- vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
7774
- vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
8146
+ vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);
8147
+ vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);
7775
8148
 
7776
- vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
7777
- vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
8149
+ vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);
8150
+ vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);
7778
8151
 
7779
- aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
8152
+ aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);
7780
8153
  q5 += 32; q8 += 64;
7781
8154
 
7782
8155
  }
7783
8156
 
7784
- vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
7785
- sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
8157
+ sums += aux32 * d;
7786
8158
 
7787
8159
  }
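The q5_K changes above widen the intrinsic types from m1 to m2 (with the widened products going up to m8). A plausible reading is that with vl = 32 and 8-bit elements, the register group must grow once VLEN can be as small as 128 bits; the helper below is only an illustration of that capacity rule, not code from the package.

// Illustration (RVV rule of thumb): with SEW = 8, one register group holds
// (VLEN / 8) * LMUL elements, so vl = 32 needs LMUL >= 2 when VLEN = 128.
static int rvv_e8_elements_per_group(int vlen_bits, int lmul) {
    return (vlen_bits / 8) * lmul;   // e.g. (128 / 8) * 2 = 32
}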
7788
8160
 
@@ -8147,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
8147
8519
 
8148
8520
  void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
8149
8521
  assert(n % QK_K == 0);
8522
+ #ifdef __ARM_FEATURE_MATMUL_INT8
8523
+ assert((nrc == 2) || (nrc == 1));
8524
+ #else
8150
8525
  assert(nrc == 1);
8526
+ #endif
8151
8527
  UNUSED(nrc);
8152
8528
  UNUSED(bx);
8153
8529
  UNUSED(by);
@@ -8158,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
8158
8534
 
8159
8535
  const int nb = n / QK_K;
8160
8536
 
8537
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
8538
+ if (nrc == 2) {
8539
+ const block_q6_K * GGML_RESTRICT x0 = x;
8540
+ const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
8541
+ const block_q8_K * GGML_RESTRICT y0 = y;
8542
+ const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
8543
+
8544
+ float32x4_t vfsum = vdupq_n_f32(0.0f);
8545
+
8546
+ for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
8547
+ const uint8_t * GGML_RESTRICT ql0 = x0->ql;
8548
+ const uint8_t * GGML_RESTRICT ql1 = x1->ql;
8549
+ const uint8_t * GGML_RESTRICT qh0 = x0->qh;
8550
+ const uint8_t * GGML_RESTRICT qh1 = x1->qh;
8551
+ const int8_t * GGML_RESTRICT qy0 = y0->qs;
8552
+ const int8_t * GGML_RESTRICT qy1 = y1->qs;
8553
+
8554
+ const uint8x16_t mone = vdupq_n_u8(0x30);
8555
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
8556
+
8557
+ int32x4_t visum = vdupq_n_s32(0);
8558
+
8559
+ // process 8 blocks per iteration, totally 16 blocks
8560
+ for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
8561
+ int8x16_t vx0[8], vx1[8];
8562
+
8563
+ // de-quantize vx0[8]
8564
+ {
8565
+ const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
8566
+ const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
8567
+
8568
+ uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
8569
+ uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
8570
+ uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
8571
+ uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
8572
+
8573
+ vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
8574
+ vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
8575
+ vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
8576
+ vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
8577
+
8578
+ q6h_0 = vandq_u8(mone, qh_bits.val[0]);
8579
+ q6h_1 = vandq_u8(mone, qh_bits.val[1]);
8580
+ q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
8581
+ q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
8582
+
8583
+ vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
8584
+ vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
8585
+ vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
8586
+ vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
8587
+ }
8588
+
8589
+ // de-quantize vx1[8]
8590
+ {
8591
+ const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
8592
+ const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
8593
+
8594
+ uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
8595
+ uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
8596
+ uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
8597
+ uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
8598
+
8599
+ vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
8600
+ vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
8601
+ vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
8602
+ vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
8603
+
8604
+ q6h_0 = vandq_u8(mone, qh_bits.val[0]);
8605
+ q6h_1 = vandq_u8(mone, qh_bits.val[1]);
8606
+ q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
8607
+ q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
8608
+
8609
+ vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
8610
+ vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
8611
+ vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
8612
+ vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
8613
+ }
8614
+
8615
+ // process 16 elements (one block with same scale) per iteration
8616
+ // - vx = concat(ql, qh) - 32
8617
+ // - r1,r2,r3,r4 = smmla(vx, vy)
8618
+ for (int k = 0; k < 8; ++k) {
8619
+ const int blk = j * 8 + k;
8620
+
8621
+ const int8x16_t vy0 = vld1q_s8(qy0);
8622
+ const int8x16_t vy1 = vld1q_s8(qy1);
8623
+ qy0 += 16;
8624
+ qy1 += 16;
8625
+
8626
+ const int32x4_t block_scale = {
8627
+ x0->scales[blk],
8628
+ x0->scales[blk],
8629
+ x1->scales[blk],
8630
+ x1->scales[blk],
8631
+ };
8632
+
8633
+ // calculate four results at once with outer product
8634
+ const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
8635
+ const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
8636
+ const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
8637
+ const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
8638
+ int32x4_t vr = vdupq_n_s32(0);
8639
+ vr = vmmlaq_s32(vr, vx_l, vy_l);
8640
+ vr = vmmlaq_s32(vr, vx_h, vy_h);
8641
+
8642
+ // apply block scale, will NOT overflow
8643
+ // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8), i.e. fits in 30 bits
8644
+ visum = vmlaq_s32(visum, vr, block_scale);
8645
+ }
8646
+ }
8647
+
8648
+ // adjust bias, apply superblock scale
8649
+ {
8650
+ int32_t bias[4];
8651
+ #ifdef __ARM_FEATURE_SVE
8652
+ const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
8653
+ const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
8654
+ const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
8655
+ const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
8656
+ const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
8657
+ const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
8658
+ const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
8659
+ const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
8660
+ const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
8661
+ const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
8662
+ const svint64_t zero = svdup_n_s64(0);
8663
+ bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
8664
+ svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
8665
+ bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
8666
+ svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
8667
+ bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
8668
+ svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
8669
+ bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
8670
+ svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
8671
+ #else
8672
+ // NEON doesn't support an int16 dot product, so fall back to separate multiply and add
8673
+ const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
8674
+ const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
8675
+
8676
+ int8x16_t scales_s8 = vld1q_s8(x0->scales);
8677
+ const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
8678
+ scales_s8 = vld1q_s8(x1->scales);
8679
+ const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
8680
+
8681
+ int32x4_t prod;
8682
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
8683
+ vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
8684
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
8685
+ vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
8686
+ bias[0] = vaddvq_s32(prod);
8687
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
8688
+ vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
8689
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
8690
+ vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
8691
+ bias[1] = vaddvq_s32(prod);
8692
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
8693
+ vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
8694
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
8695
+ vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
8696
+ bias[2] = vaddvq_s32(prod);
8697
+ prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
8698
+ vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
8699
+ vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
8700
+ vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
8701
+ bias[3] = vaddvq_s32(prod);
8702
+
8703
+ #endif
8704
+ const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
8705
+
8706
+ const float32x4_t superblock_scale = {
8707
+ GGML_FP16_TO_FP32(x0->d) * y0->d,
8708
+ GGML_FP16_TO_FP32(x0->d) * y1->d,
8709
+ GGML_FP16_TO_FP32(x1->d) * y0->d,
8710
+ GGML_FP16_TO_FP32(x1->d) * y1->d,
8711
+ };
8712
+
8713
+ visum = vsubq_s32(visum, vibias);
8714
+ vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
8715
+ }
8716
+ }
8717
+
8718
+ // vfsum = ABCD -> ACBD
8719
+ // AC -> s, BD -> (s+bs)
8720
+ vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
8721
+ vst1_f32(s, vget_low_f32 (vfsum));
8722
+ vst1_f32(s + bs, vget_high_f32(vfsum));
8723
+
8724
+ return;
8725
+ }
8726
+ #endif
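The nrc == 2 path above is built on the i8mm matrix-multiply-accumulate instruction. A minimal sketch of the primitive it relies on is below (assuming __ARM_FEATURE_MATMUL_INT8 is available; the wrapper name is hypothetical).

#include <arm_neon.h>

// Sketch: vmmlaq_s32 treats each int8x16_t operand as a 2x8 matrix and
// accumulates the 2x2 matrix of row dot products, so packing two x rows and
// two y rows per register yields four dot products per instruction, which is
// the basis of the "calculate four results at once with outer product" step.
static int32x4_t smmla_2x2_sketch(int32x4_t acc, int8x16_t x_2x8, int8x16_t y_2x8) {
    // result lanes: { x0.y0, x0.y1, x1.y0, x1.y1 }
    return vmmlaq_s32(acc, x_2x8, y_2x8);
}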
8727
+
8161
8728
  #ifdef __ARM_FEATURE_SVE
8162
8729
  const int vector_length = ggml_cpu_get_sve_cnt()*8;
8163
8730
  float sum = 0;
@@ -8667,85 +9234,168 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
8667
9234
 
8668
9235
  #elif defined __riscv_v_intrinsic
8669
9236
 
9237
+ const int vector_length = __riscv_vlenb() * 8;
8670
9238
  float sumf = 0;
8671
- for (int i = 0; i < nb; ++i) {
8672
9239
 
8673
- const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
9240
+ switch (vector_length) {
9241
+ case 256:
9242
+ for (int i = 0; i < nb; ++i) {
8674
9243
 
8675
- const uint8_t * GGML_RESTRICT q6 = x[i].ql;
8676
- const uint8_t * GGML_RESTRICT qh = x[i].qh;
8677
- const int8_t * GGML_RESTRICT q8 = y[i].qs;
9244
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
8678
9245
 
8679
- const int8_t * GGML_RESTRICT scale = x[i].scales;
9246
+ const uint8_t * GGML_RESTRICT q6 = x[i].ql;
9247
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
9248
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
8680
9249
 
8681
- size_t vl;
9250
+ const int8_t * GGML_RESTRICT scale = x[i].scales;
8682
9251
 
8683
- vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
9252
+ size_t vl;
8684
9253
 
8685
- int sum_t = 0;
8686
- int is = 0;
9254
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
8687
9255
 
8688
- for (int j = 0; j < QK_K/128; ++j) {
9256
+ int sum_t = 0;
9257
+ int is = 0;
8689
9258
 
8690
- vl = 32;
9259
+ for (int j = 0; j < QK_K/128; ++j) {
8691
9260
 
8692
- // load qh
8693
- vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
9261
+ vl = 32;
8694
9262
 
8695
- // load Q6
8696
- vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
8697
- vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
9263
+ // load qh
9264
+ vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
8698
9265
 
8699
- vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
8700
- vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
8701
- vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
8702
- vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
9266
+ // load Q6
9267
+ vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
9268
+ vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
8703
9269
 
8704
- vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
8705
- vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
8706
- vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
8707
- vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
9270
+ vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
9271
+ vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
9272
+ vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
9273
+ vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
8708
9274
 
8709
- vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
8710
- vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
8711
- vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
8712
- vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
9275
+ vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
9276
+ vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
9277
+ vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
9278
+ vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
8713
9279
 
8714
- vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
8715
- vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
8716
- vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
8717
- vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
9280
+ vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
9281
+ vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
9282
+ vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
9283
+ vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
8718
9284
 
8719
- // load Q8 and take product
8720
- vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
8721
- vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
8722
- vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
8723
- vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
9285
+ vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
9286
+ vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
9287
+ vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
9288
+ vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
8724
9289
 
8725
- vl = 16;
9290
+ // load Q8 and take product
9291
+ vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
9292
+ vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
9293
+ vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
9294
+ vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
8726
9295
 
8727
- vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
8728
- vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
8729
- vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
8730
- vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
8731
- vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
8732
- vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
8733
- vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
8734
- vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
9296
+ vl = 16;
8735
9297
 
8736
- vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
8737
- vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
8738
- vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
8739
- vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
9298
+ vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
9299
+ vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
9300
+ vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
9301
+ vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
9302
+ vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
9303
+ vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
9304
+ vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
9305
+ vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
8740
9306
 
8741
- sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
9307
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
9308
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
9309
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
9310
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
8742
9311
 
8743
- q6 += 64; qh += 32; q8 += 128; is=8;
9312
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
8744
9313
 
8745
- }
9314
+ q6 += 64; qh += 32; q8 += 128; is=8;
8746
9315
 
8747
- sumf += d * sum_t;
9316
+ }
9317
+
9318
+ sumf += d * sum_t;
9319
+
9320
+ }
9321
+ break;
9322
+ case 128:
9323
+ for (int i = 0; i < nb; ++i) {
9324
+
9325
+ const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
9326
+
9327
+ const uint8_t * restrict q6 = x[i].ql;
9328
+ const uint8_t * restrict qh = x[i].qh;
9329
+ const int8_t * restrict q8 = y[i].qs;
9330
+
9331
+ const int8_t * restrict scale = x[i].scales;
9332
+
9333
+ int sum_t = 0;
9334
+ int t0;
9335
+
9336
+ for (int j = 0; j < QK_K/128; ++j) {
9337
+ __asm__ __volatile__(
9338
+ "vsetvli zero, %[vl32], e8, m2\n\t"
9339
+ "vle8.v v4, (%[qh])\n\t"
9340
+ "vsll.vi v0, v4, 4\n\t"
9341
+ "vsll.vi v2, v4, 2\n\t"
9342
+ "vsrl.vi v6, v4, 2\n\t"
9343
+ "vsetvli zero, %[vl64], e8, m4\n\t"
9344
+ "vle8.v v8, (%[q6])\n\t"
9345
+ "vsrl.vi v12, v8, 4\n\t"
9346
+ "vand.vi v8, v8, 0xF\n\t"
9347
+ "vsetvli zero, %[vl128], e8, m8\n\t"
9348
+ "vand.vx v0, v0, %[mask]\n\t"
9349
+ "vor.vv v8, v8, v0\n\t"
9350
+ "vle8.v v0, (%[q8])\n\t"
9351
+ "vsub.vx v8, v8, %[vl32]\n\t"
9352
+ "vsetvli zero, %[vl64], e8, m4\n\t"
9353
+ "vwmul.vv v16, v0, v8\n\t"
9354
+ "vwmul.vv v24, v4, v12\n\t"
9355
+ "vsetivli zero, 16, e16, m2\n\t"
9356
+ "vmv.v.x v0, zero\n\t"
9357
+ "vwredsum.vs v10, v16, v0\n\t"
9358
+ "vwredsum.vs v9, v18, v0\n\t"
9359
+ "vwredsum.vs v8, v20, v0\n\t"
9360
+ "vwredsum.vs v7, v22, v0\n\t"
9361
+ "vwredsum.vs v11, v24, v0\n\t"
9362
+ "vwredsum.vs v12, v26, v0\n\t"
9363
+ "vwredsum.vs v13, v28, v0\n\t"
9364
+ "vwredsum.vs v14, v30, v0\n\t"
9365
+ "vsetivli zero, 4, e32, m1\n\t"
9366
+ "vslideup.vi v10, v9, 1\n\t"
9367
+ "vslideup.vi v8, v7, 1\n\t"
9368
+ "vslideup.vi v11, v12, 1\n\t"
9369
+ "vslideup.vi v13, v14, 1\n\t"
9370
+ "vslideup.vi v10, v8, 2\n\t"
9371
+ "vslideup.vi v11, v13, 2\n\t"
9372
+ "vsetivli zero, 8, e32, m2\n\t"
9373
+ "vle8.v v2, (%[scale])\n\t"
9374
+ "vsext.vf4 v4, v2\n\t"
9375
+ "vmul.vv v2, v4, v10\n\t"
9376
+ "vredsum.vs v0, v2, v0\n\t"
9377
+ "vmv.x.s %[t0], v0\n\t"
9378
+ "add %[sumi], %[sumi], %[t0]"
9379
+ : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
9380
+ : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
9381
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
9382
+ , [mask] "r" (0x30)
9383
+ : "memory"
9384
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
9385
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
9386
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
9387
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
9388
+ );
9389
+ q6 += 64; qh += 32; q8 += 128; scale += 8;
9390
+ }
9391
+
9392
+ sumf += d * sum_t;
8748
9393
 
9394
+ }
9395
+ break;
9396
+ default:
9397
+ assert(false && "Unsupported vector length");
9398
+ break;
8749
9399
  }
8750
9400
 
8751
9401
  *s = sumf;
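Both q6_K paths above reconstruct weights the same way before the dot product. A scalar sketch of that step, assuming the standard q6_K layout (four low bits in ql, two high bits in qh, recentred by 32); the helper name is illustrative.

#include <stdint.h>

// Sketch: a q6_K weight is (high2 << 4) | low4, shifted down by 32, which is
// what the "vor.vv ... / vsub.vx ..., 32" sequence above computes in bulk.
static inline int q6k_weight_sketch(uint8_t low4, uint8_t high2) {
    return (int)(((high2 & 3) << 4) | (low4 & 0x0F)) - 32;   // range [-32, 31]
}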