whispercpp 1.3.2 → 1.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +4 -2
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  19. data/ext/sources/examples/addon.node/addon.cpp +150 -31
  20. data/ext/sources/examples/addon.node/index.js +3 -0
  21. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  22. data/ext/sources/examples/bench/bench.cpp +3 -2
  23. data/ext/sources/examples/cli/cli.cpp +3 -2
  24. data/ext/sources/examples/command/command.cpp +32 -8
  25. data/ext/sources/examples/common-whisper.cpp +14 -7
  26. data/ext/sources/examples/lsp/lsp.cpp +2 -0
  27. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  28. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  29. data/ext/sources/examples/server/server.cpp +169 -22
  30. data/ext/sources/examples/stream/stream.cpp +6 -0
  31. data/ext/sources/examples/talk-llama/CMakeLists.txt +4 -1
  32. data/ext/sources/examples/talk-llama/llama-arch.cpp +171 -3
  33. data/ext/sources/examples/talk-llama/llama-arch.h +28 -1
  34. data/ext/sources/examples/talk-llama/llama-batch.cpp +741 -272
  35. data/ext/sources/examples/talk-llama/llama-batch.h +112 -54
  36. data/ext/sources/examples/talk-llama/llama-chat.cpp +30 -8
  37. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  38. data/ext/sources/examples/talk-llama/llama-context.cpp +520 -351
  39. data/ext/sources/examples/talk-llama/llama-context.h +38 -17
  40. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  41. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -1
  42. data/ext/sources/examples/talk-llama/llama-graph.cpp +447 -372
  43. data/ext/sources/examples/talk-llama/llama-graph.h +128 -58
  44. data/ext/sources/examples/talk-llama/llama-hparams.cpp +10 -2
  45. data/ext/sources/examples/talk-llama/llama-hparams.h +19 -2
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +279 -0
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.h +128 -0
  48. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +1841 -0
  49. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +303 -0
  50. data/ext/sources/examples/talk-llama/llama-kv-cache.h +14 -472
  51. data/ext/sources/examples/talk-llama/llama-kv-cells.h +86 -26
  52. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +246 -0
  53. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +138 -0
  54. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1125 -0
  55. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +183 -0
  56. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  57. data/ext/sources/examples/talk-llama/llama-memory.h +88 -4
  58. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  59. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +42 -17
  60. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  61. data/ext/sources/examples/talk-llama/llama-model.cpp +1863 -563
  62. data/ext/sources/examples/talk-llama/llama-model.h +27 -0
  63. data/ext/sources/examples/talk-llama/llama-quant.cpp +89 -6
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +65 -28
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +1 -0
  66. data/ext/sources/examples/talk-llama/llama.cpp +11 -7
  67. data/ext/sources/examples/talk-llama/llama.h +147 -40
  68. data/ext/sources/examples/talk-llama/talk-llama.cpp +2 -0
  69. data/ext/sources/examples/talk-llama/unicode.cpp +5 -0
  70. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  71. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +2 -0
  72. data/ext/sources/ggml/CMakeLists.txt +48 -3
  73. data/ext/sources/ggml/cmake/common.cmake +24 -0
  74. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  75. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  76. data/ext/sources/ggml/include/ggml.h +144 -5
  77. data/ext/sources/ggml/src/CMakeLists.txt +82 -24
  78. data/ext/sources/ggml/src/ggml-backend-reg.cpp +5 -0
  79. data/ext/sources/ggml/src/ggml-backend.cpp +46 -23
  80. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +3 -3
  81. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +1 -0
  82. data/ext/sources/ggml/src/ggml-cann/common.h +6 -1
  83. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +33 -9
  84. data/ext/sources/ggml/src/ggml-common.h +4 -0
  85. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +133 -40
  86. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -1
  87. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +4114 -0
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2163 -0
  91. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2639 -0
  92. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2732 -0
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +2069 -0
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +397 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1300 -0
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1481 -0
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +4311 -0
  99. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-aarch64.cpp → arch/x86/repack.cpp} +79 -3225
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +184 -0
  101. data/ext/sources/ggml/src/ggml-cpu/common.h +4 -3
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +16 -7
  103. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +146 -105
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +12 -8
  105. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  106. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1 -1
  107. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +58 -8
  108. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  109. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1057 -174
  110. data/ext/sources/ggml/src/ggml-cpu/ops.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/quants.c +1158 -0
  112. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +26 -0
  113. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1571 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.h +98 -0
  115. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +330 -38
  116. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +1 -1
  117. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +111 -18
  118. data/ext/sources/ggml/src/ggml-cpu/vec.h +303 -94
  119. data/ext/sources/ggml/src/ggml-cuda/common.cuh +60 -37
  120. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  121. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  122. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  123. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  124. data/ext/sources/ggml/src/ggml-cuda/convert.cu +22 -0
  125. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +5 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +5 -2
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +4 -0
  129. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +265 -123
  130. data/ext/sources/ggml/src/ggml-cuda/mean.cu +19 -0
  131. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  132. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +257 -87
  133. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +2 -3
  134. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +6 -4
  135. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +5 -18
  136. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  137. data/ext/sources/ggml/src/ggml-cuda/unary.cu +89 -0
  138. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +7 -0
  139. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +4 -0
  140. data/ext/sources/ggml/src/ggml-impl.h +127 -183
  141. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +11 -10
  142. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +27 -0
  143. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +331 -49
  144. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +564 -282
  145. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  146. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +14 -0
  147. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1859 -489
  148. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  149. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  150. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +72 -0
  151. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +201 -0
  152. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +72 -0
  153. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  154. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  155. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  156. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  157. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +72 -0
  158. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  159. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  160. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  161. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +121 -0
  162. data/ext/sources/ggml/src/ggml-quants.c +6 -8
  163. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +18 -15
  164. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  165. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +5 -6
  166. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  167. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +28 -41
  168. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +4 -10
  169. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +117 -165
  170. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +192 -53
  171. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  172. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +49 -67
  173. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +31 -1
  174. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +648 -1039
  175. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +18 -9
  176. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  177. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +8 -105
  178. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -100
  179. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  180. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  181. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +60 -80
  182. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +158 -203
  183. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +55 -74
  184. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -10
  185. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +138 -27
  186. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  187. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  188. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  189. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +3 -8
  190. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +108 -16
  191. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +12 -16
  192. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  193. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +726 -282
  194. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  195. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  196. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  197. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +15 -0
  198. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  199. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  200. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +12 -3
  201. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  202. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +10 -1
  203. data/ext/sources/ggml/src/ggml.c +328 -48
  204. data/ext/sources/ggml/src/ggml.cpp +26 -0
  205. data/ext/sources/ggml/src/gguf.cpp +24 -3
  206. data/ext/sources/include/whisper.h +2 -0
  207. data/ext/sources/src/CMakeLists.txt +2 -0
  208. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  209. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  210. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  211. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  212. data/ext/sources/src/whisper.cpp +218 -169
  213. data/extsources.rb +15 -9
  214. data/lib/whisper/context.rb +15 -0
  215. data/lib/whisper/model/uri.rb +56 -1
  216. data/lib/whisper/segment.rb +58 -0
  217. data/sig/whisper.rbs +68 -38
  218. data/{tests → test}/helper.rb +1 -12
  219. data/{tests → test}/test_model.rb +9 -0
  220. data/test/test_package.rb +51 -0
  221. data/test/test_segment.rb +146 -0
  222. data/{tests → test}/test_whisper.rb +70 -0
  223. data/whispercpp.gemspec +2 -3
  224. metadata +91 -43
  225. data/ext/sources/.dockerignore +0 -3
  226. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  227. data/ext/sources/ci/run.sh +0 -336
  228. data/ext/sources/close-issue.yml +0 -28
  229. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2739
  230. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  231. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  232. data/tests/test_package.rb +0 -46
  233. data/tests/test_segment.rb +0 -74
  234. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  235. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  236. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +0 -0
  237. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  238. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  239. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  240. /data/{tests → test}/test_callback.rb +0 -0
  241. /data/{tests → test}/test_error.rb +0 -0
  242. /data/{tests → test}/test_params.rb +0 -0
  243. /data/{tests → test}/test_vad.rb +0 -0
  244. /data/{tests → test}/test_vad_params.rb +0 -0
data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c
@@ -0,0 +1,2069 @@
+ #define GGML_COMMON_IMPL_C
+ #include "ggml-common.h"
+ #include "ggml-quants.h"
+ #include "ggml-impl.h"
+ #include "ggml-cpu.h"
+ #include "simd-mappings.h"
+
+ #include "../../quants.h"
+ #include "../../ggml-cpu-impl.h"
+
+ #include <math.h>
+ #include <string.h>
+ #include <assert.h>
+ #include <float.h>
+ #include <stdlib.h> // for qsort
+ #include <stdio.h> // for GGML_ASSERT
+
+ #define GROUP_MAX_EPS 1e-15f
+ #define GROUP_MAX_EPS_IQ3_XXS 1e-8f
+ #define GROUP_MAX_EPS_IQ2_S 1e-8f
+ #define GROUP_MAX_EPS_IQ1_M 1e-7f
+ #define GROUP_MAX_EPS_IQ1_S 1e-12f
+
+ #define UNUSED GGML_UNUSED
+
+ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+ assert(QK8_0 == 32);
+ assert(k % QK8_0 == 0);
+ const int nb = k / QK8_0;
+
+ block_q8_0 * GGML_RESTRICT y = vy;
+
+ #if defined(__riscv_v)
+
+ size_t vl = QK8_0;
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl);
+
+ vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_CPU_FP32_TO_FP16(d);
+
+ vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);
+
+ // convert to integer
+ vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+ vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);
+
+ // store result
+ __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
+ }
+ #else
+ GGML_UNUSED(nb);
+ // scalar
+ quantize_row_q8_0_ref(x, y, k);
+ #endif
+ }
+
+ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+ assert(k % QK8_1 == 0);
+ const int nb = k / QK8_1;
+
+ block_q8_1 * GGML_RESTRICT y = vy;
+
+ #if defined(__riscv_v)
+
+ size_t vl = QK8_1;
+
+ for (int i = 0; i < nb; i++) {
+ // load elements
+ vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl);
+
+ vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl);
+ vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl);
+ vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl);
+ float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = GGML_CPU_FP32_TO_FP16(d);
+
+ vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl);
+
+ // convert to integer
+ vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl);
+ vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl);
+
+ // store result
+ __riscv_vse8_v_i8m2(y[i].qs , vs, vl);
+
+ // compute sum for y[i].s
+ vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+ vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl);
+
+ // set y[i].s
+ int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+ y[i].s = GGML_CPU_FP32_TO_FP16(sum*d);
+ }
+
+ #else
+ GGML_UNUSED(nb);
+ // scalar
+ quantize_row_q8_1_ref(x, y, k);
+ #endif
+ }
+
+ //===================================== Dot products =================================
+
+ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q4_0 * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ int ib = 0;
+ float sumf = 0;
+
+ #if defined(__riscv_v)
+ size_t vl = qk / 2;
+
+ for (; ib < nb; ++ib) {
+ // load elements
+ vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+
+ vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+ vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);
+
+ // mask and store lower part of x, and then upper part
+ vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+ vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+
+ vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+ vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+
+ // subtract offset
+ vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
+ vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
+
+ vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+ vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);
+
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+ sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
+ }
+
+ #endif
+ for (; ib < nb; ++ib) {
+ int sumi0 = 0;
+ int sumi1 = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
+
+ sumi0 += (v0 * y[ib].qs[j]);
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
+ }
+
+ int sumi = sumi0 + sumi1;
+ sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
+ }
+
+ *s = sumf;
+ }
+
+ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ const int qk = QK8_1;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q4_1 * GGML_RESTRICT x = vx;
+ const block_q8_1 * GGML_RESTRICT y = vy;
+
+ int ib = 0;
+ float sumf = 0;
+
+ #if defined(__riscv_v)
+ size_t vl = qk / 2;
+
+ for (; ib < nb; ++ib) {
+ // load elements
+ vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+
+ vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+ vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl);
+
+ // mask and store lower part of x, and then upper part
+ vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+ vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+
+ vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+ vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+
+ vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+ vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl);
+
+ vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+ sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
+ }
+
+ #endif
+ for (; ib < nb; ++ib) {
+ int sumi0 = 0;
+ int sumi1 = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int v0 = (x[ib].qs[j] & 0x0F);
+ const int v1 = (x[ib].qs[j] >> 4);
+
+ sumi0 += (v0 * y[ib].qs[j]);
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
+ }
+
+ int sumi = sumi0 + sumi1;
+ sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
+ }
+
+ *s = sumf;
+ }
+
+ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+
+ int ib = 0;
+ float sumf = 0;
+
+ assert(n % qk == 0);
+ assert(qk == QK5_0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q5_0 * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ #if defined(__riscv_v)
+ size_t vl;
+ size_t vlenb = __riscv_vlenb();
+
+ for (; ib < nb; ++ib) {
+ vl = qk / 2;
+ vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+ vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+ vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+ vint8m2_t v0c;
+ if (vlenb == 16) {
+ v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+ } else {
+ v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+ v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+ }
+
+ vl = qk;
+ vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+ qh = __riscv_vmnand_mm_b4(qh, qh, vl);
+ vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+ vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+ vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+ vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+ int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
+
+ sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
+ }
+
+ #endif
+ for (; ib < nb; ++ib) {
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ int sumi0 = 0;
+ int sumi1 = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+ const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
+
+ const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
+ const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16);
+
+ sumi0 += (x0 * y[ib].qs[j]);
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
+ }
+
+ int sumi = sumi0 + sumi1;
+ sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d)) * sumi;
+ }
+
+ *s = sumf;
+ }
+
+ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ const int qk = QK8_1;
+ const int nb = n / qk;
+
+ int ib = 0;
+ float sumf = 0;
+
+ assert(n % qk == 0);
+ assert(qk == QK5_1);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q5_1 * GGML_RESTRICT x = vx;
+ const block_q8_1 * GGML_RESTRICT y = vy;
+
+ #if defined(__riscv_v)
+ size_t vl;
+ size_t vlenb = __riscv_vlenb();
+
+ for (; ib < nb; ++ib) {
+ vl = qk / 2;
+ vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl);
+ vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl));
+ vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl));
+ vint8m2_t v0c;
+ if (vlenb == 16) {
+ v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h);
+ } else {
+ v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32);
+ v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l);
+ }
+
+ vl = qk;
+ vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl);
+ vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl);
+ vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+ vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl);
+ vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl);
+ int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum);
+
+ sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
+ }
+
+ #endif
+ for (; ib < nb; ++ib) {
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ int sumi0 = 0;
+ int sumi1 = 0;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+ const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
+
+ const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
+ const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1;
+
+ sumi0 += (x0 * y[ib].qs[j]);
+ sumi1 += (x1 * y[ib].qs[j + qk/2]);
+ }
+
+ int sumi = sumi0 + sumi1;
+ sumf += (GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d))*sumi + GGML_CPU_FP16_TO_FP32(x[ib].m)*GGML_CPU_FP16_TO_FP32(y[ib].s);
+ }
+
+ *s = sumf;
+ }
+
+ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+
+ assert(n % qk == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_q8_0 * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ int ib = 0;
+ float sumf = 0;
+
+ #if defined(__riscv_v)
+ size_t vl = qk;
+
+ for (; ib < nb; ++ib) {
+ // load elements
+ vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl);
+ vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl);
+
+ vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl);
+
+ vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
+ vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl);
+
+ int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
+
+ sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
+ }
+
+ #endif
+ for (; ib < nb; ++ib) {
+ int sumi = 0;
+
+ for (int j = 0; j < qk; j++) {
+ sumi += x[ib].qs[j]*y[ib].qs[j];
+ }
+
+ sumf += sumi*(GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d));
+ }
+
+ *s = sumf;
+ }
+
+ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
449
+ assert(nrc == 1);
450
+ UNUSED(nrc);
451
+ UNUSED(bx);
452
+ UNUSED(by);
453
+ UNUSED(bs);
454
+
455
+ const block_q2_K * GGML_RESTRICT x = vx;
456
+ const block_q8_K * GGML_RESTRICT y = vy;
457
+
458
+ const int nb = n / QK_K;
459
+
460
+ #if defined __riscv_xtheadvector
461
+
462
+ float sumf = 0;
463
+ uint8_t atmp[16];
464
+
465
+ for (int i = 0; i < nb; ++i) {
466
+ const uint8_t * q2 = x[i].qs;
467
+ const int8_t * q8 = y[i].qs;
468
+ const uint8_t * sc = x[i].scales;
469
+ const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
470
+ const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
471
+ uint8_t *patmp = atmp;
472
+ int vsums;
473
+ int tmp;
474
+ __asm__ __volatile__(
475
+ "th.vsetvli zero, %[vl16], e8, m1\n\t"
476
+ "th.vmv.v.x v8, zero\n\t"
477
+ "th.vlb.v v1, (%[sc])\n\t"
478
+ "th.vand.vi v0, v1, 0xF\n\t"
479
+ "th.vsrl.vi v1, v1, 4\n\t"
480
+ "th.vsb.v v0, (%[scale])\n\t"
481
+ "th.vwaddu.vx v16, v1, zero\n\t"
482
+ "th.vsetvli zero, %[vl16], e16, m2\n\t"
483
+ "th.vlh.v v2, (%[bsums])\n\t"
484
+ "th.vwmul.vv v4, v16, v2\n\t"
485
+ "th.vsetvli zero, %[vl16], e32, m4\n\t"
486
+ "th.vredsum.vs v8, v4, v8\n\t"
487
+ "th.vmv.x.s %[vsums], v8"
488
+ : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
489
+ : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
490
+ , [vl16] "r" (16)
491
+ : "memory"
492
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
493
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
494
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
495
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
496
+ );
497
+ sumf += dmin * vsums;
498
+ int isum = 0;
499
+
500
+ for (int j = 0; j < QK_K/128; ++j) {
501
+ __asm__ __volatile__(
502
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
503
+ "th.vlb.v v0, (%[q2])\n\t"
504
+ "th.vsrl.vi v2, v0, 2\n\t"
505
+ "th.vsrl.vi v4, v0, 4\n\t"
506
+ "th.vsrl.vi v6, v0, 6\n\t"
507
+ "th.vand.vi v0, v0, 0x3\n\t"
508
+ "th.vand.vi v2, v2, 0x3\n\t"
509
+ "th.vand.vi v4, v4, 0x3\n\t"
510
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
511
+ "th.vlb.v v8, (%[q8])\n\t"
512
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
513
+ "th.vwmul.vv v16, v0, v8\n\t"
514
+ "th.vwmul.vv v24, v4, v12\n\t"
515
+ "th.vsetvli zero, %[vl16], e16, m2\n\t"
516
+ "th.vmv.v.x v0, zero\n\t"
517
+ "th.vwredsum.vs v10, v16, v0\n\t"
518
+ "th.vwredsum.vs v9, v18, v0\n\t"
519
+ "th.vwredsum.vs v8, v20, v0\n\t"
520
+ "th.vwredsum.vs v7, v22, v0\n\t"
521
+ "th.vwredsum.vs v11, v24, v0\n\t"
522
+ "th.vwredsum.vs v12, v26, v0\n\t"
523
+ "th.vwredsum.vs v13, v28, v0\n\t"
524
+ "th.vwredsum.vs v14, v30, v0\n\t"
525
+ "li %[tmp], 4\n\t"
526
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
527
+ "th.vslideup.vi v10, v9, 1\n\t"
528
+ "th.vslideup.vi v8, v7, 1\n\t"
529
+ "th.vslideup.vi v11, v12, 1\n\t"
530
+ "th.vslideup.vi v13, v14, 1\n\t"
531
+ "th.vslideup.vi v10, v8, 2\n\t"
532
+ "th.vslideup.vi v11, v13, 2\n\t"
533
+ "li %[tmp], 8\n\t"
534
+ "th.vsetvli zero, %[tmp], e32, m2\n\t"
535
+ "th.vlbu.v v12, (%[scale])\n\t"
536
+ "th.vmul.vv v10, v10, v12\n\t"
537
+ "th.vredsum.vs v0, v10, v0\n\t"
538
+ "th.vmv.x.s %[tmp], v0\n\t"
539
+ "add %[isum], %[isum], %[tmp]"
540
+ : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
541
+ : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
542
+ , [vl16] "r" (16), [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
543
+ : "memory"
544
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
545
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
546
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
547
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
548
+ );
549
+ q2 += 32; q8 += 128; patmp += 8;
550
+ }
551
+
552
+ sumf += dall * isum;
553
+ }
554
+
555
+ *s = sumf;
556
+
557
+ #elif defined __riscv_v
558
+
559
+ float sumf = 0;
560
+ uint8_t atmp[16];
561
+
562
+ const int vector_length = __riscv_vlenb() * 8;
563
+ uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
564
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
565
+
566
+ switch (vector_length) {
567
+ case 256:
568
+ for (int i = 0; i < nb; ++i) {
569
+ const uint8_t * q2 = x[i].qs;
570
+ const int8_t * q8 = y[i].qs;
571
+ const uint8_t * sc = x[i].scales;
572
+
573
+ const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
574
+ const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
575
+
576
+ size_t vl = 16;
577
+
578
+ vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
579
+ vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
580
+
581
+ vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
582
+
583
+ vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
584
+ vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
585
+ vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
586
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
587
+ vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
588
+
589
+ sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
590
+
591
+ vl = 32;
592
+
593
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
594
+ vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
595
+
596
+ uint8_t is = 0;
597
+ int isum = 0;
598
+
599
+ for (int j = 0; j < QK_K / 128; ++j) {
600
+ // load Q2
601
+ vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
602
+
603
+ vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
604
+ vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
605
+ vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
606
+ vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);
607
+
608
+ // duplicate scale elements for product
609
+ vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
610
+ vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
611
+ vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
612
+ vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);
613
+
614
+ vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
615
+ vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
616
+ vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
617
+ vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
618
+
619
+ // load Q8
620
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
621
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
622
+ vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
623
+ vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);
624
+
625
+ vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
626
+ vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
627
+ vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
628
+ vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
629
+
630
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
631
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
632
+
633
+ isum += __riscv_vmv_x_s_i32m1_i32(isum1);
634
+
635
+ q2 += 32;
636
+ q8 += 128;
637
+ is = 8;
638
+ }
639
+
640
+ sumf += dall * isum;
641
+ }
642
+ break;
643
+ case 128:
644
+ for (int i = 0; i < nb; ++i) {
645
+ const uint8_t * q2 = x[i].qs;
646
+ const int8_t * q8 = y[i].qs;
647
+ const uint8_t * sc = x[i].scales;
648
+ const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
649
+ const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
650
+ uint8_t *patmp = atmp;
651
+ int vsums;
652
+ int tmp;
653
+ __asm__ __volatile__(
654
+ "vsetivli zero, 16, e8, m1\n\t"
655
+ "vmv.v.x v8, zero\n\t"
656
+ "vle8.v v1, (%[sc])\n\t"
657
+ "vand.vi v0, v1, 0xF\n\t"
658
+ "vsrl.vi v1, v1, 4\n\t"
659
+ "vse8.v v0, (%[scale])\n\t"
660
+ "vsetivli zero, 16, e16, m2\n\t"
661
+ "vle16.v v2, (%[bsums])\n\t"
662
+ "vzext.vf2 v0, v1\n\t"
663
+ "vwmul.vv v4, v0, v2\n\t"
664
+ "vsetivli zero, 16, e32, m4\n\t"
665
+ "vredsum.vs v8, v4, v8\n\t"
666
+ "vmv.x.s %[vsums], v8"
667
+ : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
668
+ : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
669
+ : "memory"
670
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
671
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
672
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
673
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
674
+ );
675
+ sumf += dmin * vsums;
676
+ int isum = 0;
677
+
678
+ for (int j = 0; j < QK_K/128; ++j) {
679
+ __asm__ __volatile__(
680
+ "vsetvli zero, %[vl32], e8, m2\n\t"
681
+ "vle8.v v0, (%[q2])\n\t"
682
+ "vsrl.vi v2, v0, 2\n\t"
683
+ "vsrl.vi v4, v0, 4\n\t"
684
+ "vsrl.vi v6, v0, 6\n\t"
685
+ "vand.vi v0, v0, 0x3\n\t"
686
+ "vand.vi v2, v2, 0x3\n\t"
687
+ "vand.vi v4, v4, 0x3\n\t"
688
+ "vsetvli zero, %[vl128], e8, m8\n\t"
689
+ "vle8.v v8, (%[q8])\n\t"
690
+ "vsetvli zero, %[vl64], e8, m4\n\t"
691
+ "vwmul.vv v16, v0, v8\n\t"
692
+ "vwmul.vv v24, v4, v12\n\t"
693
+ "vsetivli zero, 16, e16, m2\n\t"
694
+ "vmv.v.x v0, zero\n\t"
695
+ "vwredsum.vs v10, v16, v0\n\t"
696
+ "vwredsum.vs v9, v18, v0\n\t"
697
+ "vwredsum.vs v8, v20, v0\n\t"
698
+ "vwredsum.vs v7, v22, v0\n\t"
699
+ "vwredsum.vs v11, v24, v0\n\t"
700
+ "vwredsum.vs v12, v26, v0\n\t"
701
+ "vwredsum.vs v13, v28, v0\n\t"
702
+ "vwredsum.vs v14, v30, v0\n\t"
703
+ "vsetivli zero, 4, e32, m1\n\t"
704
+ "vslideup.vi v10, v9, 1\n\t"
705
+ "vslideup.vi v8, v7, 1\n\t"
706
+ "vslideup.vi v11, v12, 1\n\t"
707
+ "vslideup.vi v13, v14, 1\n\t"
708
+ "vslideup.vi v10, v8, 2\n\t"
709
+ "vslideup.vi v11, v13, 2\n\t"
710
+ "vsetivli zero, 8, e32, m2\n\t"
711
+ "vle8.v v15, (%[scale])\n\t"
712
+ "vzext.vf4 v12, v15\n\t"
713
+ "vmul.vv v10, v10, v12\n\t"
714
+ "vredsum.vs v0, v10, v0\n\t"
715
+ "vmv.x.s %[tmp], v0\n\t"
716
+ "add %[isum], %[isum], %[tmp]"
717
+ : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
718
+ : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
719
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
720
+ : "memory"
721
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
722
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
723
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
724
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
725
+ );
726
+ q2 += 32; q8 += 128; patmp += 8;
727
+ }
728
+
729
+ sumf += dall * isum;
730
+ }
731
+ break;
732
+ default:
733
+ assert(false && "Unsupported vector length");
734
+ break;
735
+ }
736
+
737
+ *s = sumf;
738
+
739
+ #else
740
+
741
+ float sumf = 0;
742
+
743
+ for (int i = 0; i < nb; ++i) {
744
+
745
+ const uint8_t * q2 = x[i].qs;
746
+ const int8_t * q8 = y[i].qs;
747
+ const uint8_t * sc = x[i].scales;
748
+
749
+ int summs = 0;
750
+ for (int j = 0; j < 16; ++j) {
751
+ summs += y[i].bsums[j] * (sc[j] >> 4);
752
+ }
753
+
754
+ const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
755
+ const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
756
+
757
+ int isum = 0;
758
+ int is = 0;
759
+ int d;
760
+ for (int k = 0; k < QK_K/128; ++k) {
761
+ int shift = 0;
762
+ for (int j = 0; j < 4; ++j) {
763
+ d = sc[is++] & 0xF;
764
+ int isuml = 0;
765
+ for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
766
+ isum += d * isuml;
767
+ d = sc[is++] & 0xF;
768
+ isuml = 0;
769
+ for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
770
+ isum += d * isuml;
771
+ shift += 2;
772
+ q8 += 32;
773
+ }
774
+ q2 += 32;
775
+ }
776
+ sumf += dall * isum - dmin * summs;
777
+ }
778
+ *s = sumf;
779
+ #endif
780
+ }
781
+
782
+ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
783
+ assert(n % QK_K == 0);
784
+ assert(nrc == 1);
785
+ UNUSED(nrc);
786
+ UNUSED(bx);
787
+ UNUSED(by);
788
+ UNUSED(bs);
789
+
790
+ const uint32_t kmask1 = 0x03030303;
791
+ const uint32_t kmask2 = 0x0f0f0f0f;
792
+
793
+ const block_q3_K * GGML_RESTRICT x = vx;
794
+ const block_q8_K * GGML_RESTRICT y = vy;
795
+
796
+ const int nb = n / QK_K;
797
+
798
+ #if defined __riscv_xtheadvector
799
+
800
+ uint32_t utmp[4];
801
+ float sumf = 0;
802
+
803
+ for (int i = 0; i < nb; ++i) {
804
+ const uint8_t * restrict q3 = x[i].qs;
805
+ const uint8_t * restrict qh = x[i].hmask;
806
+ const int8_t * restrict q8 = y[i].qs;
807
+
808
+ int8_t * scale = (int8_t *)utmp;
809
+ int tmp;
810
+ __asm__ __volatile__(
811
+ "li %[tmp], 12\n\t"
812
+ "th.vsetvli zero, %[tmp], e8, m1\n\t"
813
+ "th.vlb.v v0, (%[s6b])\n\t"
814
+ "th.vmv.v.v v2, v0\n\t"
815
+ "li %[tmp], 2\n\t"
816
+ "th.vsetvli zero, %[tmp], e64, m1\n\t"
817
+ "th.vmv.v.x v9, %[sh]\n\t"\
818
+ "th.vslidedown.vi v1, v0, 1\n\t"
819
+ "th.vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
820
+ "th.vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
821
+ "li %[tmp], 4\n\t"
822
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
823
+ "th.vid.v v9\n\t"
824
+ "th.vmv.x.s %[tmp], v1\n\t"
825
+ "th.vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
826
+ "th.vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
827
+ "th.vsrl.vv v4, v1, v9\n\t"
828
+ "th.vsrl.vv v2, v0, v8\n\t"
829
+ "th.vand.vx v5, v4, %[kmask1]\n\t"
830
+ "th.vand.vx v3, v2, %[kmask2]\n\t"
831
+ "th.vsll.vi v6, v5, 4\n\t"
832
+ "th.vor.vv v7, v6, v3\n\t"
833
+ "li %[tmp], 16\n\t"
834
+ "th.vsetvli zero, %[tmp], e8, m1\n\t"
835
+ "th.vsub.vx v0, v7, %[c]\n\t"
836
+ "th.vsb.v v0, (%[scale])"
837
+ : [tmp] "=&r" (tmp)
838
+ : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
839
+ , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
840
+ : "memory"
841
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
842
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
843
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
844
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
845
+ );
846
+
847
+ uint8_t m = 1;
848
+ int isum = 0;
849
+ for (int j = 0; j < QK_K; j += 128) {
850
+ __asm__ __volatile__(
851
+ // fixme: use v0p7 mask layout directly
852
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
853
+ "th.vlb.v v8, (%[q3])\n\t"
854
+ "th.vsrl.vi v10, v8, 2\n\t"
855
+ "th.vsrl.vi v12, v8, 4\n\t"
856
+ "th.vsrl.vi v14, v8, 6\n\t"
857
+ "th.vand.vi v8, v8, 3\n\t"
858
+ "th.vand.vi v10, v10, 3\n\t"
859
+ "th.vand.vi v12, v12, 3\n\t"
860
+ "th.vlb.v v2, (%[qh])\n\t"
861
+ "th.vand.vx v4, v2, %[m]\n\t"
862
+ "slli %[m], %[m], 1\n\t"
863
+ "th.vmseq.vx v0, v4, zero\n\t"
864
+ "th.vadd.vi v8, v8, -4, v0.t\n\t"
865
+ "th.vand.vx v4, v2, %[m]\n\t"
866
+ "slli %[m], %[m], 1\n\t"
867
+ "th.vmseq.vx v0, v4, zero\n\t"
868
+ "th.vadd.vi v10, v10, -4, v0.t\n\t"
869
+ "th.vand.vx v4, v2, %[m]\n\t"
870
+ "slli %[m], %[m], 1\n\t"
871
+ "th.vmseq.vx v0, v4, zero\n\t"
872
+ "th.vadd.vi v12, v12, -4, v0.t\n\t"
873
+ "th.vand.vx v4, v2, %[m]\n\t"
874
+ "slli %[m], %[m], 1\n\t"
875
+ "th.vmseq.vx v0, v4, zero\n\t"
876
+ "th.vadd.vi v14, v14, -4, v0.t\n\t"
877
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
878
+ "th.vlb.v v0, (%[q8])\n\t"
879
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
880
+ "th.vwmul.vv v16, v0, v8\n\t"
881
+ "th.vwmul.vv v24, v4, v12\n\t"
882
+ "li %[tmp], 16\n\t"
883
+ "th.vsetvli zero, %[tmp], e16, m2\n\t"
884
+ "th.vmv.v.x v0, zero\n\t"
885
+ "th.vwredsum.vs v10, v16, v0\n\t"
886
+ "th.vwredsum.vs v9, v18, v0\n\t"
887
+ "th.vwredsum.vs v8, v20, v0\n\t"
888
+ "th.vwredsum.vs v7, v22, v0\n\t"
889
+ "th.vwredsum.vs v11, v24, v0\n\t"
890
+ "th.vwredsum.vs v12, v26, v0\n\t"
891
+ "th.vwredsum.vs v13, v28, v0\n\t"
892
+ "th.vwredsum.vs v14, v30, v0\n\t"
893
+ "li %[tmp], 4\n\t"
894
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
895
+ "th.vslideup.vi v10, v9, 1\n\t"
896
+ "th.vslideup.vi v8, v7, 1\n\t"
897
+ "th.vslideup.vi v11, v12, 1\n\t"
898
+ "th.vslideup.vi v13, v14, 1\n\t"
899
+ "th.vslideup.vi v10, v8, 2\n\t"
900
+ "th.vslideup.vi v11, v13, 2\n\t"
901
+ "li %[tmp], 8\n\t"
902
+ "th.vsetvli zero, %[tmp], e32, m2\n\t"
903
+ "th.vlb.v v12, (%[scale])\n\t"
904
+ "th.vmul.vv v10, v10, v12\n\t"
905
+ "th.vredsum.vs v0, v10, v0\n\t"
906
+ "th.vmv.x.s %[tmp], v0\n\t"
907
+ "add %[isum], %[isum], %[tmp]"
908
+ : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
909
+ : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
910
+ , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
911
+ : "memory"
912
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
913
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
914
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
915
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
916
+ );
917
+ q3 += 32; q8 += 128; scale += 8;
918
+ }
919
+
920
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
921
+ sumf += d * isum;
922
+ }
923
+
924
+ *s = sumf;
925
+
926
+ #elif defined __riscv_v
927
+
928
+ uint32_t utmp[4];
929
+ float sumf = 0;
930
+ uint32_t aux[3];
931
+ const int vector_length = __riscv_vlenb() * 8;
932
+
933
+ switch (vector_length) {
934
+ case 256:
935
+ for (int i = 0; i < nb; ++i) {
936
+
937
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
938
+ const uint8_t * GGML_RESTRICT qh = x[i].hmask;
939
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
940
+
941
+ memcpy(aux, x[i].scales, 12);
942
+ utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
943
+ utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
944
+ utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
945
+ utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
946
+
947
+ int8_t * scale = (int8_t *)utmp;
948
+ for (int j = 0; j < 16; ++j) scale[j] -= 32;
949
+
950
+
951
+ size_t vl = 32;
952
+ uint8_t m = 1;
953
+
954
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
955
+ vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
956
+
957
+ int sum_t = 0;
958
+
959
+ for (int j = 0; j < QK_K; j += 128) {
960
+
961
+ vl = 32;
962
+
963
+ // load Q3
964
+ vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
965
+
966
+ vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
967
+ vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
968
+ vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
969
+ vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
970
+
971
+ // compute mask for subtraction
972
+ vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
973
+ vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
974
+ vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
975
+ m <<= 1;
976
+
977
+ vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
978
+ vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
979
+ vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
980
+ m <<= 1;
981
+
982
+ vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
983
+ vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
984
+ vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
985
+ m <<= 1;
986
+
987
+ vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
988
+ vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
989
+ vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
990
+ m <<= 1;
991
+
992
+ // load Q8 and take product with Q3
993
+ vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
994
+ vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
995
+ vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
996
+ vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
997
+
998
+ vl = 16;
999
+
1000
+ // retrieve lane to multiply with scale
1001
+ vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
1002
+ vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
1003
+ vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
1004
+ vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
1005
+ vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
1006
+ vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
1007
+ vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
1008
+ vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
1009
+
1010
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
1011
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
1012
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
1013
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
1014
+
1015
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
1016
+
1017
+ q3 += 32; q8 += 128; scale += 8;
1018
+
1019
+ }
1020
+
1021
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1022
+
1023
+ sumf += d*sum_t;
1024
+
1025
+ }
1026
+ break;
1027
+ case 128:
1028
+ for (int i = 0; i < nb; ++i) {
1029
+ const uint8_t * restrict q3 = x[i].qs;
1030
+ const uint8_t * restrict qh = x[i].hmask;
1031
+ const int8_t * restrict q8 = y[i].qs;
1032
+
1033
+ int8_t * scale = (int8_t *)utmp;
1034
+ int tmp;
1035
+ __asm__ __volatile__(
1036
+ "vsetivli zero, 12, e8, m1\n\t"
1037
+ "vle8.v v0, (%[s6b])\n\t"
1038
+ "vmv1r.v v2, v0\n\t"
1039
+ "vsetivli zero, 2, e64, m1\n\t"
1040
+ "vmv.v.x v9, %[sh]\n\t"\
1041
+ "vslidedown.vi v1, v0, 1\n\t"
1042
+ "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4}
1043
+ "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]}
1044
+ "vsetivli zero, 4, e32, m1\n\t"
1045
+ "vid.v v9\n\t"
1046
+ "vmv.x.s %[tmp], v1\n\t"
1047
+ "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6}
1048
+ "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]}
1049
+ "vsrl.vv v4, v1, v9\n\t"
1050
+ "vsrl.vv v2, v0, v8\n\t"
1051
+ "vand.vx v5, v4, %[kmask1]\n\t"
1052
+ "vand.vx v3, v2, %[kmask2]\n\t"
1053
+ "vsll.vi v6, v5, 4\n\t"
1054
+ "vor.vv v7, v6, v3\n\t"
1055
+ "vsetivli zero, 16, e8, m1\n\t"
1056
+ "vsub.vx v0, v7, %[c]\n\t"
1057
+ "vse8.v v0, (%[scale])"
1058
+ : [tmp] "=&r" (tmp)
1059
+ : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32)
1060
+ , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2)
1061
+ : "memory"
1062
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1063
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1064
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1065
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1066
+ );
1067
+
1068
+ uint8_t m = 1;
1069
+ int isum = 0;
1070
+ for (int j = 0; j < QK_K; j += 128) {
1071
+ __asm__ __volatile__(
1072
+ "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
1073
+ "vle8.v v8, (%[q3])\n\t"
1074
+ "vsrl.vi v10, v8, 2\n\t"
1075
+ "vsrl.vi v12, v8, 4\n\t"
1076
+ "vsrl.vi v14, v8, 6\n\t"
1077
+ "vand.vi v8, v8, 3\n\t"
1078
+ "vand.vi v10, v10, 3\n\t"
1079
+ "vand.vi v12, v12, 3\n\t"
1080
+ "vle8.v v2, (%[qh])\n\t"
1081
+ "vand.vx v4, v2, %[m]\n\t"
1082
+ "slli %[m], %[m], 1\n\t"
1083
+ "vmseq.vx v0, v4, zero\n\t"
1084
+ "vadd.vi v8, v8, -4, v0.t\n\t"
1085
+ "vand.vx v4, v2, %[m]\n\t"
1086
+ "slli %[m], %[m], 1\n\t"
1087
+ "vmseq.vx v0, v4, zero\n\t"
1088
+ "vadd.vi v10, v10, -4, v0.t\n\t"
1089
+ "vand.vx v4, v2, %[m]\n\t"
1090
+ "slli %[m], %[m], 1\n\t"
1091
+ "vmseq.vx v0, v4, zero\n\t"
1092
+ "vadd.vi v12, v12, -4, v0.t\n\t"
1093
+ "vand.vx v4, v2, %[m]\n\t"
1094
+ "slli %[m], %[m], 1\n\t"
1095
+ "vmseq.vx v0, v4, zero\n\t"
1096
+ "vadd.vi v14, v14, -4, v0.t\n\t"
1097
+ "vsetvli zero, %[vl128], e8, m8\n\t"
1098
+ "vle8.v v0, (%[q8])\n\t"
1099
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1100
+ "vwmul.vv v16, v0, v8\n\t"
1101
+ "vwmul.vv v24, v4, v12\n\t"
1102
+ "vsetivli zero, 16, e16, m2\n\t"
1103
+ "vmv.v.x v0, zero\n\t"
1104
+ "vwredsum.vs v10, v16, v0\n\t"
1105
+ "vwredsum.vs v9, v18, v0\n\t"
1106
+ "vwredsum.vs v8, v20, v0\n\t"
1107
+ "vwredsum.vs v7, v22, v0\n\t"
1108
+ "vwredsum.vs v11, v24, v0\n\t"
1109
+ "vwredsum.vs v12, v26, v0\n\t"
1110
+ "vwredsum.vs v13, v28, v0\n\t"
1111
+ "vwredsum.vs v14, v30, v0\n\t"
1112
+ "vsetivli zero, 4, e32, m1\n\t"
1113
+ "vslideup.vi v10, v9, 1\n\t"
1114
+ "vslideup.vi v8, v7, 1\n\t"
1115
+ "vslideup.vi v11, v12, 1\n\t"
1116
+ "vslideup.vi v13, v14, 1\n\t"
1117
+ "vslideup.vi v10, v8, 2\n\t"
1118
+ "vslideup.vi v11, v13, 2\n\t"
1119
+ "vsetivli zero, 8, e32, m2\n\t"
1120
+ "vle8.v v15, (%[scale])\n\t"
1121
+ "vsext.vf4 v12, v15\n\t"
1122
+ "vmul.vv v10, v10, v12\n\t"
1123
+ "vredsum.vs v0, v10, v0\n\t"
1124
+ "vmv.x.s %[tmp], v0\n\t"
1125
+ "add %[isum], %[isum], %[tmp]"
1126
+ : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
1127
+ : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
1128
+ , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
1129
+ : "memory"
1130
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1131
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1132
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1133
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1134
+ );
1135
+ q3 += 32; q8 += 128; scale += 8;
1136
+ }
1137
+
1138
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1139
+ sumf += d * isum;
1140
+ }
1141
+ break;
1142
+ default:
1143
+ assert(false && "Unsupported vector length");
1144
+ break;
1145
+ }
1146
+
1147
+ *s = sumf;
1148
+
1149
+ #else
1150
+ // scalar version
1151
+ // This function is written like this so the compiler can manage to vectorize most of it
1152
+ // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
1153
+ // manually vectorized version above. Every other version I tried would run at least 4 times slower.
1154
+ // The ideal situation would be if we could just write the code once, and the compiler would
1155
+ // automatically produce the best possible set of machine instructions, instead of us having to manually
1156
+ // write vectorized versions for AVX, ARM_NEON, etc.
1157
+
1158
+ int8_t aux8[QK_K];
1159
+ int16_t aux16[8];
1160
+ float sums [8];
1161
+ int32_t aux32[8];
1162
+ memset(sums, 0, 8*sizeof(float));
1163
+
1164
+ uint32_t auxs[4];
1165
+ const int8_t * scales = (const int8_t*)auxs;
1166
+
1167
+ float sumf = 0;
1168
+ for (int i = 0; i < nb; ++i) {
1169
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1170
+ const uint8_t * GGML_RESTRICT hm = x[i].hmask;
1171
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1172
+ memset(aux32, 0, 8*sizeof(int32_t));
1173
+ int8_t * GGML_RESTRICT a = aux8;
1174
+ uint8_t m = 1;
1175
+ for (int j = 0; j < QK_K; j += 128) {
1176
+ for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
1177
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1178
+ a += 32; m <<= 1;
1179
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
1180
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1181
+ a += 32; m <<= 1;
1182
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
1183
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1184
+ a += 32; m <<= 1;
1185
+ for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
1186
+ for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
1187
+ a += 32; m <<= 1;
1188
+ q3 += 32;
1189
+ }
1190
+ a = aux8;
1191
+
1192
+ memcpy(auxs, x[i].scales, 12);
1193
+ uint32_t tmp = auxs[2];
1194
+ auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
1195
+ auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
1196
+ auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
1197
+ auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
1198
+ for (int j = 0; j < QK_K/16; ++j) {
1199
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1200
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
1201
+ q8 += 8; a += 8;
1202
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1203
+ for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
1204
+ q8 += 8; a += 8;
1205
+ }
1206
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1207
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1208
+ }
1209
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1210
+ *s = sumf;
1211
+
1212
+ #endif
1213
+
1214
+ }
1215
+
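Editorial note (not part of the diff): the scalar fallback above spells out the Q3_K element decode that the vector paths implement with masks and shifts: each weight is a 2-bit value from x[i].qs plus one high bit from x[i].hmask, giving a signed value in [-4, 3] before the per-group scale is applied. A minimal sketch of that per-element decode, with a hypothetical helper name, could look like:

#include <stdint.h>

// Illustration of the Q3_K element decode used by the scalar fallback above:
// 2 low bits from qs, 1 high bit from hmask; result lies in [-4, 3].
static inline int8_t q3_k_decode_one(const uint8_t *qs, const uint8_t *hmask,
                                     int l, int shift, uint8_t mbit) {
    int8_t v = (int8_t)((qs[l] >> shift) & 3);   // 2-bit quant
    if (!(hmask[l] & mbit)) {
        v -= 4;                                  // high bit clear -> subtract 4
    }
    return v;                                    // later multiplied by (scales[j] - 32) and d
}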
1216
+ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1217
+ assert(n % QK_K == 0);
1218
+ assert(nrc == 1);
1219
+ UNUSED(nrc);
1220
+ UNUSED(bx);
1221
+ UNUSED(by);
1222
+ UNUSED(bs);
1223
+
1224
+ const block_q4_K * GGML_RESTRICT x = vx;
1225
+ const block_q8_K * GGML_RESTRICT y = vy;
1226
+
1227
+ const int nb = n / QK_K;
1228
+
1229
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1230
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1231
+ static const uint32_t kmask3 = 0x03030303;
1232
+
1233
+ uint32_t utmp[4];
1234
+
1235
+ #if defined __riscv_xtheadvector
1236
+
1237
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1238
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1239
+
1240
+ float sumf = 0;
1241
+
1242
+ for (int i = 0; i < nb; ++i) {
1243
+ const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1244
+ const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1245
+
1246
+ int tmp, tmp2, sumi;
1247
+ __asm__ __volatile__(
1248
+ "li %[t1], 12\n\t"
1249
+ "th.vsetvli zero, %[t1], e8, m1\n\t"
1250
+ "th.vlb.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
1251
+ "li %[t1], 4\n\t"
1252
+ "th.vsetvli zero, %[t1], e32, m1\n\t"
1253
+ "th.vslidedown.vi v2, v1, 2\n\t"
1254
+ "th.vmv.v.v v3, v2\n\t"
1255
+ "th.vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
1256
+ "li %[t1], 2\n\t"
1257
+ "th.vsetvli zero, %[t1], e32, m1\n\t"
1258
+ "th.vmv.v.i v4, 4\n\t"
1259
+ "th.vand.vx v8, v1, %[kmask1]\n\t"
1260
+ "th.vslide1up.vx v5, v4, zero\n\t" // {0, 4}
1261
+ "th.vsrl.vi v6, v1, 6\n\t"
1262
+ "th.vsrl.vv v7, v2, v5\n\t"
1263
+ "th.vand.vx v0, v6, %[kmask3]\n\t"
1264
+ "th.vand.vx v2, v7, %[kmask2]\n\t"
1265
+ "th.vsll.vi v6, v0, 4\n\t"
1266
+ "li %[t2], 8\n\t"
1267
+ "addi %[t1], %[utmp], 4\n\t"
1268
+ "th.vor.vv v1, v6, v2\n\t"
1269
+ "th.vssw.v v8, (%[utmp]), %[t2]\n\t"
1270
+ "th.vssw.v v1, (%[t1]), %[t2]\n\t"
1271
+ "th.vsetvli zero, zero, e32, m2\n\t" // vl == 8
1272
+ "th.vlw.v v2, (%[bsums])\n\t"
1273
+ "th.vsetvli zero, %[t2], e16, m1\n\t"
1274
+ "th.vnsrl.vi v0, v2, 0\n\t"
1275
+ "th.vnsrl.vi v1, v2, 16\n\t"
1276
+ "th.vadd.vv v2, v0, v1\n\t"
1277
+ "th.vlbu.v v4, (%[mins])\n\t"
1278
+ "th.vwmul.vv v6, v4, v2\n\t"
1279
+ "th.vmv.v.x v0, zero\n\t"
1280
+ "th.vsetvli zero, %[t2], e32, m2\n\t"
1281
+ "th.vredsum.vs v0, v6, v0\n\t"
1282
+ "th.vmv.x.s %[sumi], v0"
1283
+ : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
1284
+ : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
1285
+ , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
1286
+ , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
1287
+ : "memory"
1288
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1289
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1290
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1291
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1292
+ );
1293
+ sumf -= dmin * sumi;
1294
+
1295
+ const uint8_t * restrict q4 = x[i].qs;
1296
+ const int8_t * restrict q8 = y[i].qs;
1297
+
1298
+ sumi = 0;
1299
+ const uint8_t * scale = scales;
1300
+
1301
+ for (int j = 0; j < QK_K/128; ++j) {
1302
+ int vl128 = 128, vl64 = 64, vl32 = 32;
1303
+ __asm__ __volatile__(
1304
+ "th.vsetvli zero, %[vl128], e8, m8\n\t"
1305
+ "th.vlb.v v8, (%[q8])\n\t"
1306
+ "th.vsetvli zero, %[vl64], e8, m4\n\t"
1307
+ "th.vlb.v v0, (%[q4])\n\t"
1308
+ "th.vsrl.vi v4, v0, 4\n\t"
1309
+ "th.vand.vi v0, v0, 0xF\n\t"
1310
+ "th.vsetvli zero, %[vl32], e8, m2\n\t"
1311
+ "th.vwmul.vv v28, v6, v14\n\t"
1312
+ "th.vwmul.vv v20, v4, v10\n\t"
1313
+ "th.vwmul.vv v24, v2, v12\n\t"
1314
+ "th.vwmul.vv v16, v0, v8\n\t"
1315
+ "li %[tmp], 4\n\t"
1316
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
1317
+ "th.vlbu.v v1, (%[scale])\n\t"
1318
+ "th.vmv.v.x v0, zero\n\t"
1319
+ "th.vsetvli zero, %[vl32], e16, m4\n\t"
1320
+ "th.vwredsum.vs v6, v24, v0\n\t"
1321
+ "th.vwredsum.vs v7, v28, v0\n\t"
1322
+ "th.vwredsum.vs v4, v16, v0\n\t"
1323
+ "th.vwredsum.vs v5, v20, v0\n\t"
1324
+ "th.vsetvli zero, %[tmp], e32, m1\n\t"
1325
+ "th.vslideup.vi v6, v7, 1\n\t"
1326
+ "th.vslideup.vi v4, v5, 1\n\t"
1327
+ "th.vslideup.vi v4, v6, 2\n\t"
1328
+ "th.vmul.vv v8, v4, v1\n\t"
1329
+ "th.vredsum.vs v0, v8, v0\n\t"
1330
+ "th.vmv.x.s %[tmp], v0\n\t"
1331
+ "add %[sumi], %[sumi], %[tmp]"
1332
+ : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
1333
+ : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
1334
+ , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
1335
+ : "memory"
1336
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1337
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1338
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1339
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1340
+ );
1341
+
1342
+ q4 += 64; q8 += 128; scale += 4;
1343
+ }
1344
+
1345
+ sumf += d * sumi;
1346
+
1347
+ }
1348
+
1349
+ *s = sumf;
1350
+
1351
+ #elif defined __riscv_v
1352
+
1353
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1354
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1355
+
1356
+ float sumf = 0;
1357
+ const int vector_length = __riscv_vlenb() * 8;
1358
+
1359
+ switch (vector_length) {
1360
+ case 256:
1361
+ for (int i = 0; i < nb; ++i) {
1362
+
1363
+ size_t vl = 8;
1364
+
1365
+ const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1366
+ const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1367
+
1368
+ vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
1369
+ vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
1370
+ vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
1371
+
1372
+ memcpy(utmp, x[i].scales, 12);
1373
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1374
+ const uint32_t uaux = utmp[1] & kmask1;
1375
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1376
+ utmp[2] = uaux;
1377
+ utmp[0] &= kmask1;
1378
+
1379
+ vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
1380
+ vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
1381
+ vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
1382
+
1383
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
1384
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
1385
+
1386
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1387
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1388
+
1389
+ vl = 32;
1390
+
1391
+ int32_t sum_1 = 0;
1392
+ int32_t sum_2 = 0;
1393
+
1394
+ vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
1395
+
1396
+ for (int j = 0; j < QK_K/64; ++j) {
1397
+ // load Q4
1398
+ vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
1399
+
1400
+ // load Q8 and multiply it with lower Q4 nibble
1401
+ vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
1402
+ vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
1403
+ vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
1404
+ vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
1405
+
1406
+ sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
1407
+
1408
+ // load Q8 and multiply it with upper Q4 nibble
1409
+ vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
1410
+ vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
1411
+ vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
1412
+ vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
1413
+
1414
+ sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
1415
+
1416
+ q4 += 32; q8 += 64;
1417
+
1418
+ }
1419
+
1420
+ sumf += d*(sum_1 + sum_2);
1421
+
1422
+ }
1423
+ break;
1424
+ case 128:
1425
+ for (int i = 0; i < nb; ++i) {
1426
+ const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1427
+ const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1428
+
1429
+ int tmp, tmp2, sumi;
1430
+ __asm__ __volatile__(
1431
+ "vsetivli zero, 12, e8, m1\n\t"
1432
+ "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
1433
+ "vsetivli zero, 4, e32, m1\n\t"
1434
+ "vslidedown.vi v2, v1, 2\n\t"
1435
+ "vmv1r.v v3, v2\n\t"
1436
+ "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
1437
+ "vsetivli zero, 2, e32, m1\n\t"
1438
+ "vmv.v.i v4, 4\n\t"
1439
+ "vand.vx v8, v1, %[kmask1]\n\t"
1440
+ "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
1441
+ "vsrl.vi v6, v1, 6\n\t"
1442
+ "vsrl.vv v7, v2, v5\n\t"
1443
+ "vand.vx v0, v6, %[kmask3]\n\t"
1444
+ "vand.vx v2, v7, %[kmask2]\n\t"
1445
+ "vsll.vi v6, v0, 4\n\t"
1446
+ "li %[t2], 8\n\t"
1447
+ "addi %[t1], %[utmp], 4\n\t"
1448
+ "vor.vv v1, v6, v2\n\t"
1449
+ "vsse32.v v8, (%[utmp]), %[t2]\n\t"
1450
+ "vsse32.v v1, (%[t1]), %[t2]\n\t"
1451
+ "vsetivli zero, 8, e16, m1\n\t"
1452
+ "vle32.v v2, (%[bsums])\n\t"
1453
+ "vnsrl.wi v0, v2, 0\n\t"
1454
+ "vnsrl.wi v1, v2, 16\n\t"
1455
+ "vadd.vv v2, v0, v1\n\t"
1456
+ "vle8.v v3, (%[mins])\n\t"
1457
+ "vzext.vf2 v4, v3\n\t"
1458
+ "vwmul.vv v6, v4, v2\n\t"
1459
+ "vmv.v.x v0, zero\n\t"
1460
+ "vsetivli zero, 8, e32, m2\n\t"
1461
+ "vredsum.vs v0, v6, v0\n\t"
1462
+ "vmv.x.s %[sumi], v0"
1463
+ : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
1464
+ : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
1465
+ , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
1466
+ , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
1467
+ : "memory"
1468
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1469
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1470
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1471
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1472
+ );
1473
+ sumf -= dmin * sumi;
1474
+
1475
+ const uint8_t * restrict q4 = x[i].qs;
1476
+ const int8_t * restrict q8 = y[i].qs;
1477
+
1478
+ sumi = 0;
1479
+ const uint8_t * scale = scales;
1480
+
1481
+ for (int j = 0; j < QK_K/128; ++j) {
1482
+ int vl128 = 128, vl64 = 64, vl32 = 32;
1483
+ __asm__ __volatile__(
1484
+ "vsetvli zero, %[vl128], e8, m8\n\t"
1485
+ "vle8.v v8, (%[q8])\n\t"
1486
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1487
+ "vle8.v v0, (%[q4])\n\t"
1488
+ "vsrl.vi v4, v0, 4\n\t"
1489
+ "vand.vi v0, v0, 0xF\n\t"
1490
+ "vsetvli zero, %[vl32], e8, m2\n\t"
1491
+ "vwmul.vv v28, v6, v14\n\t"
1492
+ "vwmul.vv v20, v4, v10\n\t"
1493
+ "vwmul.vv v24, v2, v12\n\t"
1494
+ "vwmul.vv v16, v0, v8\n\t"
1495
+ "vsetivli zero, 4, e32, m1\n\t"
1496
+ "vle8.v v2, (%[scale])\n\t"
1497
+ "vmv.v.x v0, zero\n\t"
1498
+ "vzext.vf4 v1, v2\n\t"
1499
+ "vsetvli zero, %[vl32], e16, m4\n\t"
1500
+ "vwredsum.vs v6, v24, v0\n\t"
1501
+ "vwredsum.vs v7, v28, v0\n\t"
1502
+ "vwredsum.vs v4, v16, v0\n\t"
1503
+ "vwredsum.vs v5, v20, v0\n\t"
1504
+ "vsetivli zero, 4, e32, m1\n\t"
1505
+ "vslideup.vi v6, v7, 1\n\t"
1506
+ "vslideup.vi v4, v5, 1\n\t"
1507
+ "vslideup.vi v4, v6, 2\n\t"
1508
+ "vmul.vv v8, v4, v1\n\t"
1509
+ "vredsum.vs v0, v8, v0\n\t"
1510
+ "vmv.x.s %[tmp], v0\n\t"
1511
+ "add %[sumi], %[sumi], %[tmp]"
1512
+ : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
1513
+ : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
1514
+ , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
1515
+ : "memory"
1516
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1517
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1518
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1519
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1520
+ );
1521
+
1522
+ q4 += 64; q8 += 128; scale += 4;
1523
+ }
1524
+
1525
+ sumf += d * sumi;
1526
+ }
1527
+ break;
1528
+ default:
1529
+ assert(false && "Unsupported vector length");
1530
+ break;
1531
+ }
1532
+
1533
+ *s = sumf;
1534
+
1535
+ #else
1536
+
1537
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1538
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1539
+
1540
+ int8_t aux8[QK_K];
1541
+ int16_t aux16[8];
1542
+ float sums [8];
1543
+ int32_t aux32[8];
1544
+ memset(sums, 0, 8*sizeof(float));
1545
+
1546
+ float sumf = 0;
1547
+ for (int i = 0; i < nb; ++i) {
1548
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1549
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1550
+ memset(aux32, 0, 8*sizeof(int32_t));
1551
+ int8_t * GGML_RESTRICT a = aux8;
1552
+ for (int j = 0; j < QK_K/64; ++j) {
1553
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
1554
+ a += 32;
1555
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
1556
+ a += 32; q4 += 32;
1557
+ }
1558
+ memcpy(utmp, x[i].scales, 12);
1559
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1560
+ const uint32_t uaux = utmp[1] & kmask1;
1561
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1562
+ utmp[2] = uaux;
1563
+ utmp[0] &= kmask1;
1564
+
1565
+ int sumi = 0;
1566
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
1567
+ a = aux8;
1568
+ int is = 0;
1569
+ for (int j = 0; j < QK_K/32; ++j) {
1570
+ int32_t scale = scales[is++];
1571
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1572
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1573
+ q8 += 8; a += 8;
1574
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1575
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1576
+ q8 += 8; a += 8;
1577
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1578
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1579
+ q8 += 8; a += 8;
1580
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1581
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1582
+ q8 += 8; a += 8;
1583
+ }
1584
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1585
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1586
+ const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
1587
+ sumf -= dmin * sumi;
1588
+ }
1589
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1590
+ *s = sumf;
1591
+ #endif
1592
+ }
1593
+
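Editorial note (not part of the diff): the 12-byte x[i].scales field of Q4_K packs eight 6-bit scales and eight 6-bit mins, and the kmask1/kmask2/kmask3 shuffle seen above rearranges them in utmp so both can then be read as plain bytes. A standalone sketch of that unpacking, using the same masks but hypothetical function and parameter names:

#include <stdint.h>
#include <string.h>

// Illustration of the Q4_K/Q5_K scale unpacking used above: after the shuffle,
// bytes 0..7 of utmp hold the eight 6-bit scales and bytes 8..15 hold the mins.
static void unpack_q4_k_scales(const uint8_t packed[12],
                               uint8_t scales_out[8], uint8_t mins_out[8]) {
    const uint32_t kmask1 = 0x3f3f3f3f, kmask2 = 0x0f0f0f0f, kmask3 = 0x03030303;
    uint32_t utmp[4];
    memcpy(utmp, packed, 12);
    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
    const uint32_t uaux = utmp[1] & kmask1;
    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
    utmp[2] = uaux;
    utmp[0] &= kmask1;
    memcpy(scales_out, (const uint8_t *)&utmp[0], 8);   // scales = bytes 0..7
    memcpy(mins_out,   (const uint8_t *)&utmp[2], 8);   // mins   = bytes 8..15
}

With scales and mins in hand, the block contribution is accumulated exactly as the code above does it: d times the scale-weighted dot products (the sumf += d * sumi lines), minus dmin times the min-weighted block sums of q8 (the sumf -= dmin * sumi lines).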
1594
+ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1595
+ assert(n % QK_K == 0);
1596
+ assert(nrc == 1);
1597
+ UNUSED(nrc);
1598
+ UNUSED(bx);
1599
+ UNUSED(by);
1600
+ UNUSED(bs);
1601
+
1602
+ const block_q5_K * GGML_RESTRICT x = vx;
1603
+ const block_q8_K * GGML_RESTRICT y = vy;
1604
+
1605
+ const int nb = n / QK_K;
1606
+
1607
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1608
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1609
+ static const uint32_t kmask3 = 0x03030303;
1610
+
1611
+ uint32_t utmp[4];
1612
+
1613
+ #if defined __riscv_v
1614
+
1615
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1616
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1617
+
1618
+ float sumf = 0;
1619
+ float sums = 0.0;
1620
+
1621
+ size_t vl;
1622
+
1623
+ for (int i = 0; i < nb; ++i) {
1624
+
1625
+ vl = 8;
1626
+
1627
+ const uint8_t * GGML_RESTRICT q5 = x[i].qs;
1628
+ const uint8_t * GGML_RESTRICT hm = x[i].qh;
1629
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1630
+
1631
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1632
+ const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
1633
+
1634
+ vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl);
1635
+ vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl);
1636
+ vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl);
1637
+
1638
+ memcpy(utmp, x[i].scales, 12);
1639
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1640
+ const uint32_t uaux = utmp[1] & kmask1;
1641
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1642
+ utmp[2] = uaux;
1643
+ utmp[0] &= kmask1;
1644
+
1645
+ vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl);
1646
+ vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
1647
+ vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl);
1648
+
1649
+ vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
1650
+ sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
1651
+
1652
+ vl = 32;
1653
+ int32_t aux32 = 0;
1654
+ int is = 0;
1655
+
1656
+ uint8_t m = 1;
1657
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
1658
+ vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl);
1659
+
1660
+ for (int j = 0; j < QK_K/64; ++j) {
1661
+ // load Q5 and Q8
1662
+ vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl);
1663
+ vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl);
1664
+ vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl);
1665
+
1666
+ // compute mask for addition
1667
+ vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl));
1668
+ vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl);
1669
+ vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl);
1670
+ vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl);
1671
+ m <<= 1;
1672
+
1673
+ vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl));
1674
+ vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl);
1675
+ vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl);
1676
+ vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl);
1677
+ m <<= 1;
1678
+
1679
+ vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl);
1680
+ vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl);
1681
+
1682
+ vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl);
1683
+ vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl);
1684
+
1685
+ vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl);
1686
+ vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl);
1687
+
1688
+ aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2);
1689
+ q5 += 32; q8 += 64;
1690
+
1691
+ }
1692
+
1693
+ sums += aux32 * d;
1694
+
1695
+ }
1696
+
1697
+ *s = sumf+sums;
1698
+
1699
+ #else
1700
+
1701
+ const uint8_t * scales = (const uint8_t*)&utmp[0];
1702
+ const uint8_t * mins = (const uint8_t*)&utmp[2];
1703
+
1704
+ int8_t aux8[QK_K];
1705
+ int16_t aux16[8];
1706
+ float sums [8];
1707
+ int32_t aux32[8];
1708
+ memset(sums, 0, 8*sizeof(float));
1709
+
1710
+ float sumf = 0;
1711
+ for (int i = 0; i < nb; ++i) {
1712
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1713
+ const uint8_t * GGML_RESTRICT hm = x[i].qh;
1714
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1715
+ memset(aux32, 0, 8*sizeof(int32_t));
1716
+ int8_t * GGML_RESTRICT a = aux8;
1717
+ uint8_t m = 1;
1718
+ for (int j = 0; j < QK_K/64; ++j) {
1719
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
1720
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
1721
+ a += 32; m <<= 1;
1722
+ for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
1723
+ for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
1724
+ a += 32; m <<= 1;
1725
+ q4 += 32;
1726
+ }
1727
+ memcpy(utmp, x[i].scales, 12);
1728
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1729
+ const uint32_t uaux = utmp[1] & kmask1;
1730
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1731
+ utmp[2] = uaux;
1732
+ utmp[0] &= kmask1;
1733
+
1734
+ int sumi = 0;
1735
+ for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
1736
+ a = aux8;
1737
+ int is = 0;
1738
+ for (int j = 0; j < QK_K/32; ++j) {
1739
+ int32_t scale = scales[is++];
1740
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1741
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1742
+ q8 += 8; a += 8;
1743
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1744
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1745
+ q8 += 8; a += 8;
1746
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1747
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1748
+ q8 += 8; a += 8;
1749
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
1750
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
1751
+ q8 += 8; a += 8;
1752
+ }
1753
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1754
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
1755
+ const float dmin = GGML_CPU_FP16_TO_FP32(x[i].dmin) * y[i].d;
1756
+ sumf -= dmin * sumi;
1757
+ }
1758
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
1759
+ *s = sumf;
1760
+ #endif
1761
+ }
1762
+
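Editorial note (not part of the diff): Q5_K extends the Q4_K nibbles with one extra bit per weight taken from x[i].qh, which is why the scalar loop above adds 16 whenever the corresponding qh bit is set. A small sketch of that per-element decode, with illustrative names only:

#include <stdint.h>

// Illustration of the Q5_K element decode used above: a 4-bit nibble from qs
// plus one bit from qh, giving an unsigned value in [0, 31] before scaling.
static inline uint8_t q5_k_decode_one(const uint8_t *qs, const uint8_t *qh,
                                      int l, int use_high_nibble, uint8_t mbit) {
    uint8_t v = use_high_nibble ? (uint8_t)(qs[l] >> 4) : (uint8_t)(qs[l] & 0x0F);
    if (qh[l] & mbit) {
        v += 16;                 // fifth bit comes from the separate qh plane
    }
    return v;                    // later: d * scale * v, with a dmin * min correction
}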
1763
+ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1764
+ assert(n % QK_K == 0);
1765
+ assert(nrc == 1);
1766
+ UNUSED(nrc);
1767
+ UNUSED(bx);
1768
+ UNUSED(by);
1769
+ UNUSED(bs);
1770
+
1771
+ const block_q6_K * GGML_RESTRICT x = vx;
1772
+ const block_q8_K * GGML_RESTRICT y = vy;
1773
+
1774
+ const int nb = n / QK_K;
1775
+
1776
+ #if defined __riscv_xtheadvector
1777
+
1778
+ float sumf = 0;
1779
+
1780
+ for (int i = 0; i < nb; ++i) {
1781
+
1782
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1783
+
1784
+ const uint8_t * restrict q6 = x[i].ql;
1785
+ const uint8_t * restrict qh = x[i].qh;
1786
+ const int8_t * restrict q8 = y[i].qs;
1787
+
1788
+ const int8_t * restrict scale = x[i].scales;
1789
+
1790
+ int sum_t = 0;
1791
+ int t0;
1792
+
1793
+ for (int j = 0; j < QK_K/128; ++j) {
1794
+ __asm__ __volatile__(
1795
+ "th.vsetvli zero, %[vl32], e8, m2\n\t" // vl == 32
1796
+ "th.vlb.v v4, (%[qh])\n\t"
1797
+ "th.vsll.vi v0, v4, 4\n\t"
1798
+ "th.vsll.vi v2, v4, 2\n\t"
1799
+ "th.vsrl.vi v6, v4, 2\n\t"
1800
+ "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
1801
+ "th.vlb.v v8, (%[q6])\n\t"
1802
+ "th.vsrl.vi v12, v8, 4\n\t"
1803
+ "th.vand.vi v8, v8, 0xF\n\t"
1804
+ "th.vsetvli zero, %[vl128], e8, m8\n\t" // vl == 128
1805
+ "th.vand.vx v0, v0, %[mask]\n\t"
1806
+ "th.vor.vv v8, v8, v0\n\t"
1807
+ "th.vlb.v v0, (%[q8])\n\t"
1808
+ "th.vsub.vx v8, v8, %[vl32]\n\t"
1809
+ "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64
1810
+ "th.vwmul.vv v16, v0, v8\n\t"
1811
+ "th.vwmul.vv v24, v4, v12\n\t"
1812
+ "li %[t0], 16\n\t"
1813
+ "th.vsetvli zero, %[t0], e16, m2\n\t" // vl == 16
1814
+ "th.vmv.v.x v0, zero\n\t"
1815
+ "th.vwredsum.vs v10, v16, v0\n\t"
1816
+ "th.vwredsum.vs v9, v18, v0\n\t"
1817
+ "th.vwredsum.vs v8, v20, v0\n\t"
1818
+ "th.vwredsum.vs v7, v22, v0\n\t"
1819
+ "th.vwredsum.vs v11, v24, v0\n\t"
1820
+ "th.vwredsum.vs v12, v26, v0\n\t"
1821
+ "th.vwredsum.vs v13, v28, v0\n\t"
1822
+ "th.vwredsum.vs v14, v30, v0\n\t"
1823
+ "li %[t0], 4\n\t"
1824
+ "th.vsetvli zero, %[t0], e32, m1\n\t" // vl == 4
1825
+ "th.vslideup.vi v10, v9, 1\n\t"
1826
+ "th.vslideup.vi v8, v7, 1\n\t"
1827
+ "th.vslideup.vi v11, v12, 1\n\t"
1828
+ "th.vslideup.vi v13, v14, 1\n\t"
1829
+ "th.vslideup.vi v10, v8, 2\n\t"
1830
+ "th.vslideup.vi v11, v13, 2\n\t"
1831
+ "li %[t0], 8\n\t"
1832
+ "th.vsetvli zero, %[t0], e32, m2\n\t" // vl == 8
1833
+ "th.vlb.v v4, (%[scale])\n\t"
1834
+ "th.vmul.vv v2, v4, v10\n\t"
1835
+ "th.vredsum.vs v0, v2, v0\n\t"
1836
+ "th.vmv.x.s %[t0], v0\n\t"
1837
+ "add %[sumi], %[sumi], %[t0]"
1838
+ : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
1839
+ : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
1840
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
1841
+ , [mask] "r" (0x30)
1842
+ : "memory"
1843
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
1844
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
1845
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
1846
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
1847
+ );
1848
+ q6 += 64; qh += 32; q8 += 128; scale += 8;
1849
+ }
1850
+
1851
+ sumf += d * sum_t;
1852
+
1853
+ }
1854
+
1855
+ *s = sumf;
1856
+
1857
+ #elif defined __riscv_v
1858
+
1859
+ float sumf = 0;
1860
+ const int vector_length = __riscv_vlenb() * 8;
1861
+
1862
+ switch (vector_length) {
1863
+ case 256:
1864
+ for (int i = 0; i < nb; ++i) {
1865
+
1866
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1867
+
1868
+ const uint8_t * GGML_RESTRICT q6 = x[i].ql;
1869
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
1870
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1871
+
1872
+ const int8_t * GGML_RESTRICT scale = x[i].scales;
1873
+
1874
+ size_t vl;
1875
+
1876
+ vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
1877
+
1878
+ int sum_t = 0;
1879
+ int is = 0;
1880
+
1881
+ for (int j = 0; j < QK_K/128; ++j) {
1882
+
1883
+ vl = 32;
1884
+
1885
+ // load qh
1886
+ vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
1887
+
1888
+ // load Q6
1889
+ vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
1890
+ vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
1891
+
1892
+ vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
1893
+ vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
1894
+ vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
1895
+ vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
1896
+
1897
+ vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
1898
+ vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
1899
+ vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
1900
+ vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
1901
+
1902
+ vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
1903
+ vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
1904
+ vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
1905
+ vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
1906
+
1907
+ vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
1908
+ vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
1909
+ vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
1910
+ vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
1911
+
1912
+ // load Q8 and take product
1913
+ vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
1914
+ vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
1915
+ vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
1916
+ vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
1917
+
1918
+ vl = 16;
1919
+
1920
+ vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
1921
+ vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
1922
+ vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
1923
+ vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
1924
+ vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
1925
+ vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
1926
+ vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
1927
+ vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
1928
+
1929
+ vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
1930
+ vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
1931
+ vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
1932
+ vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
1933
+
1934
+ sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
1935
+
1936
+ q6 += 64; qh += 32; q8 += 128; is=8;
1937
+
1938
+ }
1939
+
1940
+ sumf += d * sum_t;
1941
+
1942
+ }
1943
+ break;
1944
+ case 128:
1945
+ for (int i = 0; i < nb; ++i) {
1946
+
1947
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1948
+
1949
+ const uint8_t * restrict q6 = x[i].ql;
1950
+ const uint8_t * restrict qh = x[i].qh;
1951
+ const int8_t * restrict q8 = y[i].qs;
1952
+
1953
+ const int8_t * restrict scale = x[i].scales;
1954
+
1955
+ int sum_t = 0;
1956
+ int t0;
1957
+
1958
+ for (int j = 0; j < QK_K/128; ++j) {
1959
+ __asm__ __volatile__(
1960
+ "vsetvli zero, %[vl32], e8, m2\n\t"
1961
+ "vle8.v v4, (%[qh])\n\t"
1962
+ "vsll.vi v0, v4, 4\n\t"
1963
+ "vsll.vi v2, v4, 2\n\t"
1964
+ "vsrl.vi v6, v4, 2\n\t"
1965
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1966
+ "vle8.v v8, (%[q6])\n\t"
1967
+ "vsrl.vi v12, v8, 4\n\t"
1968
+ "vand.vi v8, v8, 0xF\n\t"
1969
+ "vsetvli zero, %[vl128], e8, m8\n\t"
1970
+ "vand.vx v0, v0, %[mask]\n\t"
1971
+ "vor.vv v8, v8, v0\n\t"
1972
+ "vle8.v v0, (%[q8])\n\t"
1973
+ "vsub.vx v8, v8, %[vl32]\n\t"
1974
+ "vsetvli zero, %[vl64], e8, m4\n\t"
1975
+ "vwmul.vv v16, v0, v8\n\t"
1976
+ "vwmul.vv v24, v4, v12\n\t"
1977
+ "vsetivli zero, 16, e16, m2\n\t"
1978
+ "vmv.v.x v0, zero\n\t"
1979
+ "vwredsum.vs v10, v16, v0\n\t"
1980
+ "vwredsum.vs v9, v18, v0\n\t"
1981
+ "vwredsum.vs v8, v20, v0\n\t"
1982
+ "vwredsum.vs v7, v22, v0\n\t"
1983
+ "vwredsum.vs v11, v24, v0\n\t"
1984
+ "vwredsum.vs v12, v26, v0\n\t"
1985
+ "vwredsum.vs v13, v28, v0\n\t"
1986
+ "vwredsum.vs v14, v30, v0\n\t"
1987
+ "vsetivli zero, 4, e32, m1\n\t"
1988
+ "vslideup.vi v10, v9, 1\n\t"
1989
+ "vslideup.vi v8, v7, 1\n\t"
1990
+ "vslideup.vi v11, v12, 1\n\t"
1991
+ "vslideup.vi v13, v14, 1\n\t"
1992
+ "vslideup.vi v10, v8, 2\n\t"
1993
+ "vslideup.vi v11, v13, 2\n\t"
1994
+ "vsetivli zero, 8, e32, m2\n\t"
1995
+ "vle8.v v2, (%[scale])\n\t"
1996
+ "vsext.vf4 v4, v2\n\t"
1997
+ "vmul.vv v2, v4, v10\n\t"
1998
+ "vredsum.vs v0, v2, v0\n\t"
1999
+ "vmv.x.s %[t0], v0\n\t"
2000
+ "add %[sumi], %[sumi], %[t0]"
2001
+ : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
2002
+ : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
2003
+ , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
2004
+ , [mask] "r" (0x30)
2005
+ : "memory"
2006
+ , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
2007
+ , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
2008
+ , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
2009
+ , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
2010
+ );
2011
+ q6 += 64; qh += 32; q8 += 128; scale += 8;
2012
+ }
2013
+
2014
+ sumf += d * sum_t;
2015
+
2016
+ }
2017
+ break;
2018
+ default:
2019
+ assert(false && "Unsupported vector length");
2020
+ break;
2021
+ }
2022
+
2023
+ *s = sumf;
2024
+
2025
+ #else
2026
+
2027
+ int8_t aux8[QK_K];
2028
+ int16_t aux16[8];
2029
+ float sums [8];
2030
+ int32_t aux32[8];
2031
+ memset(sums, 0, 8*sizeof(float));
2032
+
2033
+ float sumf = 0;
2034
+ for (int i = 0; i < nb; ++i) {
2035
+ const uint8_t * GGML_RESTRICT q4 = x[i].ql;
2036
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
2037
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
2038
+ memset(aux32, 0, 8*sizeof(int32_t));
2039
+ int8_t * GGML_RESTRICT a = aux8;
2040
+ for (int j = 0; j < QK_K; j += 128) {
2041
+ for (int l = 0; l < 32; ++l) {
2042
+ a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
2043
+ a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
2044
+ a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
2045
+ a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
2046
+ }
2047
+ a += 128;
2048
+ q4 += 64;
2049
+ qh += 32;
2050
+ }
2051
+ a = aux8;
2052
+ int is = 0;
2053
+ for (int j = 0; j < QK_K/16; ++j) {
2054
+ int scale = x[i].scales[is++];
2055
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2056
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2057
+ q8 += 8; a += 8;
2058
+ for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
2059
+ for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
2060
+ q8 += 8; a += 8;
2061
+ }
2062
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
2063
+ for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
2064
+ }
2065
+ for (int l = 0; l < 8; ++l) sumf += sums[l];
2066
+ *s = sumf;
2067
+ #endif
2068
+ }
2069
+
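Editorial note (not part of the diff): the Q6_K fallback above reconstructs each weight from a 4-bit nibble in ql and a 2-bit field in qh, then recenters by 32, so the signed value lies in [-32, 31] before the per-group scale and d are applied. A compact sketch of that mapping, mirroring the scalar loop with a hypothetical helper:

#include <stdint.h>

// Illustration of the Q6_K element decode used by the scalar fallback above:
// low 4 bits from ql, 2 bits from qh shifted into bits 4..5, minus 32.
static inline int8_t q6_k_decode_one(uint8_t ql_nibble, uint8_t qh_bits2) {
    return (int8_t)((ql_nibble & 0x0F) | ((qh_bits2 & 3) << 4)) - 32;  // range [-32, 31]
}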