whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -0,0 +1,2160 @@
1
+ #define GGML_COMMON_IMPL_C
2
+ #include "ggml-common.h"
3
+ #include "ggml-quants.h"
4
+ #include "ggml-impl.h"
5
+ #include "ggml-cpu.h"
6
+ #include "simd-mappings.h"
7
+
8
+ #include "../../quants.h"
9
+ #include "../../ggml-cpu-impl.h"
10
+
11
+ #include <math.h>
12
+ #include <string.h>
13
+ #include <assert.h>
14
+ #include <float.h>
15
+ #include <stdlib.h> // for qsort
16
+ #include <stdio.h> // for GGML_ASSERT
17
+
18
+ #define GROUP_MAX_EPS 1e-15f
19
+ #define GROUP_MAX_EPS_IQ3_XXS 1e-8f
20
+ #define GROUP_MAX_EPS_IQ2_S 1e-8f
21
+ #define GROUP_MAX_EPS_IQ1_M 1e-7f
22
+ #define GROUP_MAX_EPS_IQ1_S 1e-12f
23
+
24
+ #define UNUSED GGML_UNUSED
25
+
26
+ #if defined(__loongarch_sx)
27
+
28
+ static __m128i lsx_packs_w(__m128i a, __m128i b) {
29
+ __m128i tmp, tmp1;
30
+ tmp = __lsx_vsat_w(a, 15);
31
+ tmp1 = __lsx_vsat_w(b, 15);
32
+ return __lsx_vpickev_h(tmp1, tmp);
33
+ }
34
+
35
+ static __m128i lsx_packs_h(__m128i a, __m128i b) {
36
+ __m128i tmp, tmp1;
37
+ tmp = __lsx_vsat_h(a, 7);
38
+ tmp1 = __lsx_vsat_h(b, 7);
39
+ return __lsx_vpickev_b(tmp1, tmp);
40
+ }
41
+
42
+ static __m128i lsx_packus_h(__m128i a, __m128i b) {
43
+ __m128i tmp, tmp1;
44
+ tmp = __lsx_vsat_hu(a, 7);
45
+ tmp1 = __lsx_vsat_hu(b, 7);
46
+ return __lsx_vpickev_b(tmp1, tmp);
47
+ }
48
+
49
+ static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
50
+ __m128i tmp1, tmp2;
51
+ tmp1 = __lsx_vmulwev_h_b(a, b);
52
+ tmp2 = __lsx_vmulwod_h_b(a, b);
53
+ return __lsx_vsadd_h(tmp1, tmp2);
54
+ }
55
+
56
+ static __m128i lsx_madd_h(__m128i a, __m128i b) {
57
+ __m128i tmp1, tmp2;
58
+ tmp1 = __lsx_vmulwev_w_h(a, b);
59
+ tmp2 = __lsx_vmulwod_w_h(a, b);
60
+ return __lsx_vadd_w(tmp1, tmp2);
61
+ }
62
+
63
+ static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
64
+ v4i32 __ret = {d, c, b, a};
65
+ return (__m128i)__ret;
66
+ }
67
+
68
+ static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
69
+ __m128i mask_f, zero, tmp0, tmp2, mask;
70
+ int f = 0x8f;
71
+ mask_f = __lsx_vreplgr2vr_b(f);
72
+ zero = __lsx_vldi(0);
73
+ tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
74
+ tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
75
+ mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
76
+ tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
77
+ return __lsx_vshuf_b(a, zero, tmp2);
78
+ }
79
+
80
+ static __m128i lsx_hadd_h(__m128i a, __m128i b) {
81
+ __m128i tmp1 = __lsx_vpickev_h(b, a);
82
+ __m128i tmp2 = __lsx_vpickod_h(b, a);
83
+ return __lsx_vadd_h(tmp1, tmp2);
84
+ }
85
+
86
+ static __m128i lsx_hadd_w(__m128i a, __m128i b) {
87
+ __m128i tmp1 = __lsx_vpickev_w(b, a);
88
+ __m128i tmp2 = __lsx_vpickod_w(b, a);
89
+ return __lsx_vadd_w(tmp1, tmp2);
90
+ }
91
+
92
+ static __m128 lsx_hadd_s(__m128 a, __m128 b) {
93
+ __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
94
+ __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
95
+
96
+ return __lsx_vfadd_s(tmp1, tmp2);
97
+ }
98
+
99
+ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
100
+ __m128 res_0 =lsx_hadd_s(a, b);
101
+ __m128 res_1 =lsx_hadd_s(c, d);
102
+ __m128 res =lsx_hadd_s(res_0, res_1);
103
+ res =lsx_hadd_s(res, res);
104
+ res =lsx_hadd_s(res, res);
105
+
106
+ return ((v4f32)res)[0];
107
+ }
108
+
109
+ // multiply int8_t, add results pairwise twice
110
+ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
111
+ // Get absolute values of x vectors
112
+ const __m128i ax = __lsx_vsigncov_b(x, x);
113
+ // Sign the values of the y vectors
114
+ const __m128i sy = __lsx_vsigncov_b(x, y);
115
+ // Perform multiplication and create 16-bit values
116
+ const __m128i dot = lsx_maddubs_h(ax, sy);
117
+ const __m128i ones = __lsx_vreplgr2vr_h(1);
118
+ return lsx_madd_h(ones, dot);
119
+ }
120
+ #endif
121
+
122
+ #if defined(__loongarch_asx)
123
+
124
+ #ifdef __clang__
125
+ #define VREGS_PREFIX "$vr"
126
+ #define XREGS_PREFIX "$xr"
127
+ #else // GCC
128
+ #define VREGS_PREFIX "$f"
129
+ #define XREGS_PREFIX "$f"
130
+ #endif
131
+ #define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
132
+ // Convert __m128i to __m256i
133
+ static inline __m256i ____m256i(__m128i in) {
134
+ __m256i out = __lasx_xvldi(0);
135
+ __asm__ volatile (
136
+ ".irp i," __ALL_REGS "\n\t"
137
+ " .ifc %[out], " XREGS_PREFIX"\\i \n\t"
138
+ " .irp j," __ALL_REGS "\n\t"
139
+ " .ifc %[in], " VREGS_PREFIX "\\j \n\t"
140
+ " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
141
+ " .endif \n\t"
142
+ " .endr \n\t"
143
+ " .endif \n\t"
144
+ ".endr \n\t"
145
+ : [out] "+f" (out) : [in] "f" (in)
146
+ );
147
+ return out;
148
+ }
149
+ // Convert two __m128i to __m256i
150
+ static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
151
+ __m256i out;
152
+ __asm__ volatile (
153
+ ".irp i," __ALL_REGS "\n\t"
154
+ " .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
155
+ " .irp j," __ALL_REGS "\n\t"
156
+ " .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
157
+ " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"
158
+ " .endif \n\t"
159
+ " .endr \n\t"
160
+ " .endif \n\t"
161
+ ".endr \n\t"
162
+ ".ifnc %[out], %[hi] \n\t"
163
+ ".irp i," __ALL_REGS "\n\t"
164
+ " .ifc %[out], " XREGS_PREFIX "\\i \n\t"
165
+ " .irp j," __ALL_REGS "\n\t"
166
+ " .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
167
+ " xvori.b $xr\\i, $xr\\j, 0 \n\t"
168
+ " .endif \n\t"
169
+ " .endr \n\t"
170
+ " .endif \n\t"
171
+ ".endr \n\t"
172
+ ".endif \n\t"
173
+ : [out] "=f" (out), [hi] "+f" (inhi)
174
+ : [lo] "f" (inlo)
175
+ );
176
+ return out;
177
+ }
178
+ // Convert __m256i low part to __m128i
179
+ static inline __m128i lasx_extracti128_lo(__m256i in) {
180
+ __m128i out;
181
+ __asm__ volatile (
182
+ ".ifnc %[out], %[in] \n\t"
183
+ ".irp i," __ALL_REGS "\n\t"
184
+ " .ifc %[out], " VREGS_PREFIX "\\i \n\t"
185
+ " .irp j," __ALL_REGS "\n\t"
186
+ " .ifc %[in], " XREGS_PREFIX "\\j \n\t"
187
+ " vori.b $vr\\i, $vr\\j, 0 \n\t"
188
+ " .endif \n\t"
189
+ " .endr \n\t"
190
+ " .endif \n\t"
191
+ ".endr \n\t"
192
+ ".endif \n\t"
193
+ : [out] "=f" (out) : [in] "f" (in)
194
+ );
195
+ return out;
196
+ }
197
+ // Convert __m256i high part to __m128i
198
+ static inline __m128i lasx_extracti128_hi(__m256i in) {
199
+ __m128i out;
200
+ __asm__ volatile (
201
+ ".irp i," __ALL_REGS "\n\t"
202
+ " .ifc %[out], " VREGS_PREFIX "\\i \n\t"
203
+ " .irp j," __ALL_REGS "\n\t"
204
+ " .ifc %[in], " XREGS_PREFIX "\\j \n\t"
205
+ " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
206
+ " .endif \n\t"
207
+ " .endr \n\t"
208
+ " .endif \n\t"
209
+ ".endr \n\t"
210
+ : [out] "=f" (out) : [in] "f" (in)
211
+ );
212
+ return out;
213
+ }
214
+
215
+ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
216
+ v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
217
+ return (__m256i)__ret;
218
+ }
219
+
220
+ static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
221
+ v4i64 __ret = {d, c, b, a};
222
+ return (__m256i)__ret;
223
+ }
224
+
225
+ static __m256i lasx_insertf128( __m128i x, __m128i y) {
226
+ return lasx_set_q(x, y);
227
+ }
228
+
229
+ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
+ __m256i mask_f, zero, tmp0, tmp2, mask;
+ int f = 0x8f;
+ mask_f = __lasx_xvreplgr2vr_b(f);
+ zero = __lasx_xvldi(0);
+ tmp0 = __lasx_xvand_v(b, mask_f); // keep the low 4 index bits and the sign bit of each control byte
+ tmp0 = __lasx_xvori_b(tmp0, 0x10); // OR with 0x10 so in-range indices select from the first source of xvshuf_b
+ mask = __lasx_xvsle_b(zero, tmp0); // mask = all-ones where the control byte is non-negative
+ tmp2 = __lasx_xvand_v(tmp0, mask); // force indices whose control byte had the sign bit set to zero
+ return __lasx_xvshuf_b(a, zero, tmp2);
+ }
240
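
The helper above appears to emulate the per-lane byte shuffle that x86 code gets from _mm256_shuffle_epi8. A scalar sketch of that semantics (the function name is illustrative only):

    #include <stdint.h>

    // Within each 16-byte half, byte i of the result is a[b[i] & 0x0F], or 0
    // when the control byte b[i] has its sign bit set.
    static void shuffle_b_scalar(const uint8_t a[32], const uint8_t b[32], uint8_t out[32]) {
        for (int lane = 0; lane < 2; ++lane) {
            const uint8_t * src = a   + 16 * lane;
            const uint8_t * ctl = b   + 16 * lane;
            uint8_t       * dst = out + 16 * lane;
            for (int i = 0; i < 16; ++i) {
                dst[i] = (ctl[i] & 0x80) ? 0 : src[ctl[i] & 0x0F];
            }
        }
    }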
+
241
+ static __m256i lasx_extu8_16(__m128i a) {
242
+ return __lasx_vext2xv_hu_bu(____m256i(a));
243
+ }
244
+
245
+ static __m256i lasx_ext8_16(__m128i a) {
246
+ return __lasx_vext2xv_h_b(____m256i(a));
247
+ }
248
+
249
+ static __m256i lasx_ext16_32(__m128i a) {
250
+ return __lasx_vext2xv_w_h(____m256i(a));
251
+ }
252
+
253
+ static __m128i lasx_extracti128( __m256i a, int pos) {
254
+ __m128i ret;
255
+ if( pos == 0)
256
+ {
257
+ ret = lasx_extracti128_lo(a);
258
+ } else {
259
+ ret = lasx_extracti128_hi(a);
260
+ }
261
+ return ret;
262
+ }
263
+
264
+ static __m128 lasx_extractf128( __m256 a, int pos) {
265
+ __m128 ret;
266
+ if( pos == 0)
267
+ {
268
+ ret = (__m128)lasx_extracti128_lo((__m256i)a);
269
+ } else {
270
+ ret = (__m128)lasx_extracti128_hi((__m256i)a);
271
+ }
272
+ return ret;
273
+ }
274
+
275
+ static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
276
+ __m256i tmp1, tmp2;
277
+ tmp1 = __lasx_xvmulwev_h_b(a, b);
278
+ tmp2 = __lasx_xvmulwod_h_b(a, b);
279
+ return __lasx_xvsadd_h(tmp1, tmp2);
280
+ }
281
+
282
+ static __m256i lasx_madd_h(__m256i a, __m256i b) {
283
+ __m256i tmp1, tmp2;
284
+ tmp1 = __lasx_xvmulwev_w_h(a, b);
285
+ tmp2 = __lasx_xvmulwod_w_h(a, b);
286
+ return __lasx_xvadd_w(tmp1, tmp2);
287
+ }
288
+
289
+ static __m256i lasx_packs_w(__m256i a, __m256i b) {
290
+ __m256i tmp, tmp1;
291
+ tmp = __lasx_xvsat_w(a, 15);
292
+ tmp1 = __lasx_xvsat_w(b, 15);
293
+ return __lasx_xvpickev_h(tmp1, tmp);
294
+ }
295
+
296
+ static __m256i lasx_packs_h(__m256i a, __m256i b) {
297
+ __m256i tmp, tmp1;
298
+ tmp = __lasx_xvsat_h(a, 7);
299
+ tmp1 = __lasx_xvsat_h(b, 7);
300
+ return __lasx_xvpickev_b(tmp1, tmp);
301
+ }
302
+
303
+ static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
304
+ __m256i tmp1, tmp2;
305
+ tmp1 = __lasx_xvmulwev_h_b(a, b);
306
+ tmp2 = __lasx_xvmulwod_h_b(a, b);
307
+ return __lasx_xvadd_h(tmp1, tmp2);
308
+ }
309
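
A scalar sketch of the two multiply-accumulate helpers above, assuming the even/odd widening multiplies behave as described (helper names are illustrative): lasx_maddubs_h folds byte pairs into saturated 16-bit lanes, and lasx_madd_h folds 16-bit pairs into 32-bit lanes, the same pairwise multiply-add that x86's pmaddwd performs.

    #include <stdint.h>

    // Bytes 2k and 2k+1 of a and b are multiplied and the two products added
    // into 16-bit lane k with saturation, mirroring xvmulwev/xvmulwod + xvsadd_h.
    static void maddubs_h_scalar(const int8_t a[32], const int8_t b[32], int16_t out[16]) {
        for (int k = 0; k < 16; ++k) {
            int32_t sum = (int32_t)a[2*k] * b[2*k] + (int32_t)a[2*k + 1] * b[2*k + 1];
            if (sum >  32767) sum =  32767;   // saturate like __lasx_xvsadd_h
            if (sum < -32768) sum = -32768;
            out[k] = (int16_t)sum;
        }
    }

    // Adjacent 16-bit products accumulated into 32-bit lanes, no saturation.
    static void madd_h_scalar(const int16_t a[16], const int16_t b[16], int32_t out[8]) {
        for (int k = 0; k < 8; ++k) {
            out[k] = (int32_t)a[2*k] * b[2*k] + (int32_t)a[2*k + 1] * b[2*k + 1];
        }
    }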
+
310
+ static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
311
+ switch (b) {
312
+ case 0: return __lasx_xvrepl128vei_h(a, 0);
313
+ case 1: return __lasx_xvrepl128vei_h(a, 1);
314
+ case 2: return __lasx_xvrepl128vei_h(a, 2);
315
+ case 3: return __lasx_xvrepl128vei_h(a, 3);
316
+ case 4: return __lasx_xvrepl128vei_h(a, 4);
317
+ case 5: return __lasx_xvrepl128vei_h(a, 5);
318
+ case 6: return __lasx_xvrepl128vei_h(a, 6);
319
+ case 7: return __lasx_xvrepl128vei_h(a, 7);
320
+ default: __builtin_unreachable();
321
+ }
322
+ }
323
+
324
+ static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
325
+ switch (b) {
326
+ case 0: return __lasx_xvandi_b(a, 1 << 0);
327
+ case 1: return __lasx_xvandi_b(a, 1 << 1);
328
+ case 2: return __lasx_xvandi_b(a, 1 << 2);
329
+ case 3: return __lasx_xvandi_b(a, 1 << 3);
330
+ case 4: return __lasx_xvandi_b(a, 1 << 4);
331
+ case 5: return __lasx_xvandi_b(a, 1 << 5);
332
+ case 6: return __lasx_xvandi_b(a, 1 << 6);
333
+ case 7: return __lasx_xvandi_b(a, 1 << 7);
334
+ default: __builtin_unreachable();
335
+ }
336
+ }
337
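
A note on the two switch wrappers above: intrinsics such as __lasx_xvrepl128vei_h and __lasx_xvandi_b take their lane index or mask as a compile-time immediate, so a runtime value has to be dispatched through constant cases; when the call site passes a constant, the compiler folds the switch away. A minimal sketch of the same pattern, using a hypothetical REQUIRES_IMM macro as a stand-in rather than a real intrinsic:

    #include <stdint.h>

    // Stand-in for an intrinsic whose second operand must be a literal constant.
    #define REQUIRES_IMM(v, imm) ((uint32_t)(v) << (imm))

    // Dispatch a small runtime index to per-case immediates; the unreachable
    // default lets the compiler assume the index is always in range.
    static inline uint32_t dispatch_imm(uint32_t v, unsigned int b) {
        switch (b) {
            case 0: return REQUIRES_IMM(v, 0);
            case 1: return REQUIRES_IMM(v, 1);
            case 2: return REQUIRES_IMM(v, 2);
            case 3: return REQUIRES_IMM(v, 3);
            default: __builtin_unreachable();
        }
    }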
+
338
+ // horizontally add 8 floats
339
+ static inline float hsum_float_8(const __m256 x) {
340
+ __m128 res = lasx_extractf128(x, 1);
341
+ res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
342
+ res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
343
+ res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
344
+ return ((v4f32)res)[0];
345
+ }
346
+
347
+ // horizontally add 8 int32_t
348
+ static inline int hsum_i32_8(const __m256i a) {
349
+
350
+ __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
351
+ __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
352
+
353
+ __m128i tmp1_128 = lasx_extracti128_lo(tmp1);
354
+ __m128i tmp2_128 = lasx_extracti128_lo(tmp2);
355
+
356
+ __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
357
+
358
+ __m128i ev = __lsx_vpickev_w(sum128, sum128);
359
+ __m128i od = __lsx_vpickod_w(sum128, sum128);
360
+ __m128i sum64 = __lsx_vadd_w(ev, od);
361
+
362
+ int sum64_1, sum64_2;
363
+ sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
364
+ sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
365
+
366
+ return sum64_1 + sum64_2;
367
+ }
368
+
369
+ // horizontally add 4 int32_t
370
+ static inline int hsum_i32_4(const __m128i a) {
371
+ __m128i ev = __lsx_vpickev_w(a, a);
372
+ __m128i od = __lsx_vpickod_w(a, a);
373
+ __m128i sum64 = __lsx_vadd_w(ev, od);
374
+
375
+ int sum64_1, sum64_2;
376
+ sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
377
+ sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
378
+
379
+ return sum64_1 + sum64_2;
380
+ }
381
+
382
+ // spread 32 bits to 32 bytes { 0x00, 0xFF }
383
+ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
384
+
385
+ uint32_t x32;
386
+ memcpy(&x32, x, sizeof(uint32_t));
387
+ const __m256i shuf_mask = lasx_set_d(
388
+ 0x0303030303030303, 0x0202020202020202,
389
+ 0x0101010101010101, 0x0000000000000000);
390
+
391
+ __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
392
+ const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
393
+ bytes = __lasx_xvor_v(bytes, bit_mask);
394
+ return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
395
+ }
396
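
A scalar sketch of bytes_from_bits_32 above (illustrative name): the shuffle replicates each source byte across a group of 8 output bytes, the bit mask isolates one bit per byte, and the final compare turns a set bit into 0xFF.

    #include <stdint.h>
    #include <string.h>

    // Expand the 32 bits stored at x into 32 bytes that are 0xFF where the bit
    // is set and 0x00 where it is clear.
    static void bytes_from_bits_32_scalar(const uint8_t * x, uint8_t out[32]) {
        uint32_t x32;
        memcpy(&x32, x, sizeof(uint32_t));
        for (int i = 0; i < 32; ++i) {
            out[i] = ((x32 >> i) & 1) ? 0xFF : 0x00;
        }
    }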
+
397
+ // Unpack 32 4-bit fields into 32 bytes
398
+ // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
399
+ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
400
+ const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
401
+ __m128i hi = __lsx_vsrli_h(lo, 4);
402
+ return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
403
+ }
404
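
A scalar sketch of bytes_from_nibbles_32 above (illustrative name), matching the Q4_0 nibble layout the dot products rely on:

    #include <stdint.h>

    // The 16 source bytes hold 32 packed 4-bit values; the low nibbles land in
    // out[0..15] and the high nibbles in out[16..31], each in the range 0..15.
    static void bytes_from_nibbles_32_scalar(const uint8_t * rsi, uint8_t out[32]) {
        for (int i = 0; i < 16; ++i) {
            out[i]      = rsi[i] & 0x0F;
            out[i + 16] = rsi[i] >> 4;
        }
    }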
+
405
+ // add int16_t pairwise and return as float vector
406
+ static inline __m256 sum_i16_pairs_float(const __m256i x) {
407
+ __m256i v = __lasx_xvpackod_h(x, x);
408
+ __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
409
+ return __lasx_xvffint_s_w(summed_pairs);
410
+ }
411
+
412
+ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
413
+ // Perform multiplication and create 16-bit values
414
+ const __m256i dot = lasx_maddubs_h(ax, sy);
415
+ return sum_i16_pairs_float(dot);
416
+ }
417
+
418
+ // multiply int8_t, add results pairwise twice and return as float vector
419
+ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
420
+ const __m256i dot = lasx_madd_h_b(x, y);
421
+ return sum_i16_pairs_float(dot);
422
+ }
423
+
424
+ static inline __m128i packNibbles( __m256i bytes ) {
425
+ // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
426
+ const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
427
+ __m256i high = __lasx_xvandn_v(lowByte, bytes);
428
+ __m256i low = __lasx_xvand_v(lowByte, bytes);
429
+ high = __lasx_xvsrli_h(high, 4);
430
+ bytes = __lasx_xvor_v(low, high);
431
+ // Compress uint16_t lanes into bytes
432
+ __m128i *r0 = (__m128i *)&bytes;
433
+ __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
434
+ __m128i *r1 = (__m128i *)&tmp_h128;
435
+
436
+ __m128i zero = __lsx_vldi(0);
437
+ __m128i tmp, tmp2, tmp3;
438
+
439
+ tmp = __lsx_vmax_h(zero, *r0);
440
+ tmp2 = __lsx_vsat_hu(tmp, 7);
441
+
442
+ tmp = __lsx_vmax_h(zero, *r1);
443
+ tmp3 = __lsx_vsat_hu(tmp, 7);
444
+ return __lsx_vpickev_b(tmp3, tmp2);
445
+ }
446
+ #endif //__loongarch_asx
447
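
A scalar sketch of packNibbles above (illustrative name), assuming each 16-bit input lane carries the 0000_abcd_0000_efgh pattern described in the comment:

    #include <stdint.h>

    // For each lane, the output byte is (high_nibble << 4) | low_nibble,
    // clamped to 0..255 like the saturating pack in the vector version.
    static void pack_nibbles_scalar(const uint16_t in[16], uint8_t out[16]) {
        for (int i = 0; i < 16; ++i) {
            uint16_t v = (uint16_t)((in[i] & 0x00FF) | ((in[i] & 0xFF00) >> 4));
            out[i] = (uint8_t)(v > 255 ? 255 : v);
        }
    }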
+
448
+ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
449
+ assert(QK8_0 == 32);
450
+ assert(k % QK8_0 == 0);
451
+ const int nb = k / QK8_0;
452
+
453
+ block_q8_0 * GGML_RESTRICT y = vy;
454
+
455
+ #if defined(__loongarch_asx)
456
+ for (int i = 0; i < nb; i++) {
457
+ __m256 v0 = (__m256)__lasx_xvld( x , 0);
458
+ __m256 v1 = (__m256)__lasx_xvld( x , 32);
459
+ __m256 v2 = (__m256)__lasx_xvld( x , 64);
460
+ __m256 v3 = (__m256)__lasx_xvld( x , 96);
461
+ x += 32;
462
+
463
+ // Compute max(abs(e)) for the block
464
+ const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
465
+ __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
466
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
467
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
468
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
469
+
470
+ __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) );
471
+ max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
472
+ __m128 tmp = max4;
473
+ max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
474
+ const float max_scalar = ((v4f32)max4)[0];
475
+
476
+ // Quantize these floats
477
+ const float d = max_scalar / 127.f;
478
+ y[i].d = GGML_CPU_FP32_TO_FP16(d);
479
+ const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
480
+ const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id );
481
+
482
+ // Apply the multiplier
483
+ v0 = __lasx_xvfmul_s( v0, mul );
484
+ v1 = __lasx_xvfmul_s( v1, mul );
485
+ v2 = __lasx_xvfmul_s( v2, mul );
486
+ v3 = __lasx_xvfmul_s( v3, mul );
487
+
488
+ // Round to nearest integer
489
+ __m256i i0 = __lasx_xvftintrne_w_s( v0 );
490
+ __m256i i1 = __lasx_xvftintrne_w_s( v1 );
491
+ __m256i i2 = __lasx_xvftintrne_w_s( v2 );
492
+ __m256i i3 = __lasx_xvftintrne_w_s( v3 );
493
+
494
+ __m128i ni0 = lasx_extracti128( i0, 0 );
495
+ __m128i ni1 = lasx_extracti128( i0, 1);
496
+ __m128i ni2 = lasx_extracti128( i1, 0);
497
+ __m128i ni3 = lasx_extracti128( i1, 1);
498
+ __m128i ni4 = lasx_extracti128( i2, 0);
499
+ __m128i ni5 = lasx_extracti128( i2, 1);
500
+ __m128i ni6 = lasx_extracti128( i3, 0);
501
+ __m128i ni7 = lasx_extracti128( i3, 1);
502
+
503
+ // Convert int32 to int16
504
+ ni0 = lsx_packs_w( ni0, ni1 );
505
+ ni2 = lsx_packs_w( ni2, ni3 );
506
+ ni4 = lsx_packs_w( ni4, ni5 );
507
+ ni6 = lsx_packs_w( ni6, ni7 );
508
+ // Convert int16 to int8
509
+ ni0 = lsx_packs_h( ni0, ni2 );
510
+ ni4 = lsx_packs_h( ni4, ni6 );
511
+
512
+ __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0);
513
+ __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
514
+
515
+ }
516
+ #else
517
+ GGML_UNUSED(nb);
518
+ // scalar
519
+ quantize_row_q8_0_ref(x, y, k);
520
+ #endif
521
+ }
522
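
For orientation, a plain-C reduction of the per-block math the LASX path above implements (this is only a sketch, not the library's quantize_row_q8_0_ref, and the block's fp16 scale is represented by a float to keep it self-contained): the scale is d = max|x| / 127 and each element is rounded to the nearest int8.

    #include <stdint.h>
    #include <math.h>

    // One QK8_0 == 32 element block.
    static void quantize_block_q8_0_sketch(const float x[32], float * d_out, int8_t qs[32]) {
        float amax = 0.0f;
        for (int j = 0; j < 32; ++j) {
            const float ax = fabsf(x[j]);
            if (ax > amax) amax = ax;
        }
        const float d  = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f / d : 0.0f;
        for (int j = 0; j < 32; ++j) {
            qs[j] = (int8_t)roundf(x[j] * id);   // round to nearest integer
        }
        *d_out = d;
    }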
+
523
+ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
524
+ assert(k % QK8_1 == 0);
525
+ const int nb = k / QK8_1;
526
+
527
+ block_q8_1 * GGML_RESTRICT y = vy;
528
+
529
+ #if defined(__loongarch_asx)
530
+ for (int i = 0; i < nb; i++) {
531
+ __m256 v0 = (__m256)__lasx_xvld( x , 0 );
532
+ __m256 v1 = (__m256)__lasx_xvld( x , 32 );
533
+ __m256 v2 = (__m256)__lasx_xvld( x , 64 );
534
+ __m256 v3 = (__m256)__lasx_xvld( x , 96 );
535
+ x += 32;
536
+
537
+ // Compute max(abs(e)) for the block
538
+ const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
539
+ __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
540
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
541
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
542
+ max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
543
+
544
+ __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
545
+ max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
546
+ __m128 tmp = max4;
547
+ max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 ));
548
+ const float max_scalar = ((v4f32)max4)[0];
549
+
550
+ // Quantize these floats
551
+ const float d = max_scalar / 127.f;
552
+ y[i].d = GGML_CPU_FP32_TO_FP16(d);
553
+ const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
554
+ const __m256 mul = __lasx_xvreplfr2vr_s( id );
555
+
556
+ // Apply the multiplier
557
+ v0 = __lasx_xvfmul_s( v0, mul );
558
+ v1 = __lasx_xvfmul_s( v1, mul );
559
+ v2 = __lasx_xvfmul_s( v2, mul );
560
+ v3 = __lasx_xvfmul_s( v3, mul );
561
+
562
+ // Round to nearest integer
563
+ __m256i i0 = __lasx_xvftintrne_w_s( v0 );
564
+ __m256i i1 = __lasx_xvftintrne_w_s( v1 );
565
+ __m256i i2 = __lasx_xvftintrne_w_s( v2 );
566
+ __m256i i3 = __lasx_xvftintrne_w_s( v3 );
567
+
568
+ __m128i ni0 = lasx_extracti128(i0, 0);
569
+ __m128i ni1 = lasx_extracti128( i0, 1);
570
+ __m128i ni2 = lasx_extracti128( i1, 0);
571
+ __m128i ni3 = lasx_extracti128( i1, 1);
572
+ __m128i ni4 = lasx_extracti128( i2, 0 );
573
+ __m128i ni5 = lasx_extracti128( i2, 1);
574
+ __m128i ni6 = lasx_extracti128( i3, 0);
575
+ __m128i ni7 = lasx_extracti128( i3, 1);
576
+
577
+ // Compute the sum of the quants and set y[i].s
578
+ const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3));
579
+ const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7));
580
+ y[i].s = GGML_CPU_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1)));
581
+
582
+ // Convert int32 to int16
583
+ ni0 = lsx_packs_w( ni0, ni1 );
584
+ ni2 = lsx_packs_w( ni2, ni3 );
585
+ ni4 = lsx_packs_w( ni4, ni5 );
586
+ ni6 = lsx_packs_w( ni6, ni7 );
587
+ // Convert int16 to int8
588
+ ni0 = lsx_packs_h( ni0, ni2 );
589
+ ni4 = lsx_packs_h( ni4, ni6 );
590
+
591
+ __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0);
592
+ __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
593
+ }
594
+ #else
595
+ GGML_UNUSED(nb);
596
+ // scalar
597
+ quantize_row_q8_1_ref(x, y, k);
598
+ #endif
599
+ }
600
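
The q8_1 variant differs from q8_0 only in the extra per-block field: besides the scale d it stores s = d * sum(q), which the q4_1/q5_1 dot products above multiply by the block minimum m. A sketch of that extra step (illustrative helper, fp16 fields shown as floats):

    #include <stdint.h>
    #include <math.h>

    static void quantize_block_q8_1_sketch(const float x[32], float * d_out, float * s_out, int8_t qs[32]) {
        float amax = 0.0f;
        for (int j = 0; j < 32; ++j) {
            const float ax = fabsf(x[j]);
            if (ax > amax) amax = ax;
        }
        const float d  = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f / d : 0.0f;
        int sum = 0;
        for (int j = 0; j < 32; ++j) {
            qs[j] = (int8_t)roundf(x[j] * id);
            sum  += qs[j];                      // sum of the quants
        }
        *d_out = d;
        *s_out = d * (float)sum;                // folded into summs by the dot products
    }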
+
601
+
602
+ //===================================== Dot products =================================
603
+
604
+ //
605
+ // Helper functions
606
+ //
607
+
608
+ #if defined(__loongarch_asx)
609
+ // shuffles to pick the required scales in dot products
610
+ static inline __m256i get_scale_shuffle_q3k(int i) {
611
+ static const uint8_t k_shuffle[128] = {
612
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
613
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
614
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
615
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
616
+ };
617
+ return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
618
+ }
619
+ static inline __m256i get_scale_shuffle_k4(int i) {
620
+ static const uint8_t k_shuffle[256] = {
621
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
622
+ 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
623
+ 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
624
+ 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
625
+ 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
626
+ 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
627
+ 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
628
+ 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
629
+ };
630
+ return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
631
+ }
632
+ static inline __m128i get_scale_shuffle(int i) {
633
+ static const uint8_t k_shuffle[128] = {
634
+ 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
635
+ 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
636
+ 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
637
+ 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
638
+ 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
639
+ 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
640
+ 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
641
+ 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
642
+ };
643
+ return __lsx_vld((const __m128i*)k_shuffle + i, 0);
644
+ }
645
+ #endif
646
+
647
+ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
648
+ const int qk = QK8_0;
649
+ const int nb = n / qk;
650
+
651
+ assert(n % qk == 0);
652
+ assert(nrc == 1);
653
+ UNUSED(nrc);
654
+ UNUSED(bx);
655
+ UNUSED(by);
656
+ UNUSED(bs);
657
+
658
+ const block_q4_0 * GGML_RESTRICT x = vx;
659
+ const block_q8_0 * GGML_RESTRICT y = vy;
660
+
661
+ int ib = 0;
662
+ float sumf = 0;
663
+
664
+ #if defined(__loongarch_asx)
665
+ // Initialize accumulator with zeros
666
+ __m256 acc = (__m256)__lasx_xvldi(0);
667
+
668
+ // Main loop
669
+ for (; ib < nb; ++ib) {
670
+ /* Compute combined scale for the block */
671
+ const __m256 d = __lasx_xvreplfr2vr_s( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
672
+
673
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
674
+
675
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
676
+ const __m256i off = __lasx_xvreplgr2vr_b( 8 );
677
+ qx = __lasx_xvsub_b( qx, off );
678
+
679
+ __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
680
+
681
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
682
+
683
+ /* Multiply q with scale and accumulate */
684
+ acc = __lasx_xvfmadd_s( d, q, acc );
685
+ }
686
+
687
+ sumf = hsum_float_8(acc);
688
+
689
+ #elif defined(__loongarch_sx)
690
+ // set constants
691
+ const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
692
+ const __m128i off = __lsx_vreplgr2vr_b(8);
693
+
694
+ // Initialize accumulator with zeros
695
+ __m128 acc_0 = (__m128)__lsx_vldi(0);
696
+ __m128 acc_1 = (__m128)__lsx_vldi(0);
697
+ __m128 acc_2 = (__m128)__lsx_vldi(0);
698
+ __m128 acc_3 = (__m128)__lsx_vldi(0);
699
+
700
+ for (; ib + 1 < nb; ib += 2) {
701
+
702
+ // Compute combined scale for the block 0 and 1
703
+ const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d) );
704
+
705
+ const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
706
+
707
+ __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
708
+ __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0);
709
+ bx_0 = __lsx_vsub_b(bx_0, off);
710
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
711
+
712
+ __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
713
+ __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0);
714
+ bx_1 = __lsx_vsub_b(bx_1, off);
715
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
716
+
717
+ //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
718
+ //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
719
+
720
+ // Compute combined scale for the block 2 and 3
721
+ const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) );
722
+
723
+ const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
724
+
725
+ __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
726
+ __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0);
727
+ bx_2 = __lsx_vsub_b(bx_2, off);
728
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
729
+
730
+ __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
731
+ __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0);
732
+ bx_3 = __lsx_vsub_b(bx_3, off);
733
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
734
+
735
+ // Convert int32_t to float
736
+ __m128 p0 = __lsx_vffint_s_w(i32_0);
737
+ __m128 p1 = __lsx_vffint_s_w(i32_1);
738
+ __m128 p2 = __lsx_vffint_s_w(i32_2);
739
+ __m128 p3 = __lsx_vffint_s_w(i32_3);
740
+
741
+ // Apply the scale
742
+ __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 );
743
+ __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 );
744
+ __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 );
745
+ __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 );
746
+
747
+ // Accumulate
748
+ acc_0 = __lsx_vfadd_s(p0_d, acc_0);
749
+ acc_1 = __lsx_vfadd_s(p1_d, acc_1);
750
+ acc_2 = __lsx_vfadd_s(p2_d, acc_2);
751
+ acc_3 = __lsx_vfadd_s(p3_d, acc_3);
752
+ }
753
+
754
+ sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
755
+
756
+ #endif
757
+ for (; ib < nb; ++ib) {
758
+ int sumi0 = 0;
759
+ int sumi1 = 0;
760
+
761
+ for (int j = 0; j < qk/2; ++j) {
762
+ const int v0 = (x[ib].qs[j] & 0x0F) - 8;
763
+ const int v1 = (x[ib].qs[j] >> 4) - 8;
764
+
765
+ sumi0 += (v0 * y[ib].qs[j]);
766
+ sumi1 += (v1 * y[ib].qs[j + qk/2]);
767
+ }
768
+
769
+ int sumi = sumi0 + sumi1;
770
+ sumf += sumi*GGML_CPU_FP16_TO_FP32(x[ib].d)*GGML_CPU_FP16_TO_FP32(y[ib].d);
771
+ }
772
+
773
+ *s = sumf;
774
+ }
775
+
776
+ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
777
+ const int qk = QK8_1;
778
+ const int nb = n / qk;
779
+
780
+ assert(n % qk == 0);
781
+ assert(nrc == 1);
782
+ UNUSED(nrc);
783
+ UNUSED(bx);
784
+ UNUSED(by);
785
+ UNUSED(bs);
786
+
787
+ const block_q4_1 * GGML_RESTRICT x = vx;
788
+ const block_q8_1 * GGML_RESTRICT y = vy;
789
+
790
+ int ib = 0;
791
+ float sumf = 0;
792
+
793
+ #if defined(__loongarch_asx)
794
+ // Initialize accumulator with zeros
795
+ __m256 acc = (__m256)__lasx_xvldi(0);
796
+
797
+ float summs = 0;
798
+
799
+ // Main loop
800
+ for (; ib < nb; ++ib) {
801
+ const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d);
802
+ const float d1 = GGML_CPU_FP16_TO_FP32(y[ib].d);
803
+
804
+ summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
805
+
806
+ const __m256 d0v = __lasx_xvreplfr2vr_s( d0 );
807
+ const __m256 d1v = __lasx_xvreplfr2vr_s( d1 );
808
+
809
+ // Compute combined scales
810
+ const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v );
811
+
812
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
813
+ const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
814
+ const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0);
815
+
816
+ const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
817
+
818
+ // Accumulate d0*d1*x*y
819
+ acc = __lasx_xvfmadd_s( d0d1, xy, acc );
820
+ }
821
+
822
+ sumf = hsum_float_8(acc) + summs;
823
+
824
+ *s = sumf;
825
+ #else
826
+ UNUSED(nb);
827
+ UNUSED(x);
828
+ UNUSED(y);
829
+ UNUSED(ib);
830
+ UNUSED(sumf);
831
+ ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
832
+ #endif
833
+ }
834
+
835
+ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
836
+ const int qk = QK8_0;
837
+ const int nb = n / qk;
838
+
839
+ int ib = 0;
840
+ float sumf = 0;
841
+
842
+ assert(n % qk == 0);
843
+ assert(qk == QK5_0);
844
+ assert(nrc == 1);
845
+ UNUSED(nrc);
846
+ UNUSED(bx);
847
+ UNUSED(by);
848
+ UNUSED(bs);
849
+
850
+ const block_q5_0 * GGML_RESTRICT x = vx;
851
+ const block_q8_0 * GGML_RESTRICT y = vy;
852
+
853
+ #if defined(__loongarch_asx)
854
+ // Initialize accumulator with zeros
855
+ __m256 acc = (__m256)__lasx_xvldi(0);
856
+
857
+ // Main loop
858
+ for (; ib < nb; ++ib) {
859
+ /* Compute combined scale for the block */
860
+ const __m256 d = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); //FIXME
861
+
862
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
863
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
864
+ bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0));
865
+ qx = __lasx_xvor_v(qx, bxhi);
866
+
867
+ __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
868
+
869
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
870
+
871
+ /* Multiply q with scale and accumulate */
872
+ acc = __lasx_xvfmadd_s(d, q, acc);
873
+ }
874
+
875
+ sumf = hsum_float_8(acc);
876
+
877
+ *s = sumf;
878
+ #else
879
+ UNUSED(nb);
880
+ UNUSED(ib);
881
+ UNUSED(sumf);
882
+ UNUSED(x);
883
+ UNUSED(y);
884
+ ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
885
+ #endif
886
+ }
887
+
888
+ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
889
+ const int qk = QK8_1;
890
+ const int nb = n / qk;
891
+
892
+ int ib = 0;
893
+ float sumf = 0;
894
+
895
+ assert(n % qk == 0);
896
+ assert(qk == QK5_1);
897
+ assert(nrc == 1);
898
+ UNUSED(nrc);
899
+ UNUSED(bx);
900
+ UNUSED(by);
901
+ UNUSED(bs);
902
+
903
+ const block_q5_1 * GGML_RESTRICT x = vx;
904
+ const block_q8_1 * GGML_RESTRICT y = vy;
905
+
906
+ #if defined(__loongarch_asx)
907
+ // Initialize accumulator with zeros
908
+ __m256 acc = (__m256)__lasx_xvldi(0);
909
+
910
+ float summs = 0.0f;
911
+
912
+ // Main loop
913
+ for (; ib < nb; ++ib) {
914
+ const __m256 dx = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d));
915
+
916
+ summs += GGML_CPU_FP16_TO_FP32(x[ib].m) * GGML_CPU_FP16_TO_FP32(y[ib].s);
917
+
918
+ __m256i qx = bytes_from_nibbles_32(x[ib].qs);
919
+ __m256i bxhi = bytes_from_bits_32(x[ib].qh);
920
+ bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10));
921
+ qx = __lasx_xvor_v(qx, bxhi);
922
+
923
+ const __m256 dy = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib].d));
924
+ const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
925
+
926
+ const __m256 q = mul_sum_us8_pairs_float(qx, qy);
927
+
928
+ acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc);
929
+ }
930
+
931
+ sumf = hsum_float_8(acc) + summs;
932
+
933
+ *s = sumf;
934
+ #else
935
+ UNUSED(nb);
936
+ UNUSED(ib);
937
+ UNUSED(sumf);
938
+ UNUSED(x);
939
+ UNUSED(y);
940
+ ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc);
941
+ #endif
942
+ }
943
+
944
+ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
945
+ const int qk = QK8_0;
946
+ const int nb = n / qk;
947
+
948
+ assert(n % qk == 0);
949
+ assert(nrc == 1);
950
+ UNUSED(nrc);
951
+ UNUSED(bx);
952
+ UNUSED(by);
953
+ UNUSED(bs);
954
+
955
+ const block_q8_0 * GGML_RESTRICT x = vx;
956
+ const block_q8_0 * GGML_RESTRICT y = vy;
957
+
958
+ int ib = 0;
959
+ float sumf = 0;
960
+
961
+ #if defined(__loongarch_asx)
962
+ // Initialize accumulator with zeros
963
+ __m256 acc = (__m256)__lasx_xvldi(0);
964
+
965
+ // Main loop
966
+ for (; ib < nb; ++ib) {
967
+ // Compute combined scale for the block
968
+ const __m256 d = __lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
969
+ __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0);
970
+ __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
971
+
972
+ const __m256 q = mul_sum_i8_pairs_float(qx, qy);
973
+
974
+ // Multiply q with scale and accumulate
975
+ acc = __lasx_xvfmadd_s( d, q, acc );
976
+ }
977
+
978
+ sumf = hsum_float_8(acc);
979
+
980
+ *s = sumf;
981
+ #else
982
+ UNUSED(nb);
983
+ UNUSED(ib);
984
+ UNUSED(sumf);
985
+ UNUSED(x);
986
+ UNUSED(y);
987
+ ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
988
+ #endif
989
+ }
990
+
991
+ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
992
+ assert(nrc == 1);
993
+ UNUSED(nrc);
994
+ UNUSED(bx);
995
+ UNUSED(by);
996
+ UNUSED(bs);
997
+
998
+ const block_q2_K * GGML_RESTRICT x = vx;
999
+ const block_q8_K * GGML_RESTRICT y = vy;
1000
+
1001
+ const int nb = n / QK_K;
1002
+
1003
+ #if defined __loongarch_asx
1004
+
1005
+ __m256 acc = (__m256)__lasx_xvldi(0);
1006
+
1007
+ for (int i = 0; i < nb; ++i) {
1008
+
1009
+ const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1010
+ const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1011
+
1012
+ const uint8_t * GGML_RESTRICT q2 = x[i].qs;
1013
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1014
+
1015
+ const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
1016
+ const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
1017
+ const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
1018
+ const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
1019
+
1020
+ acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
1021
+
1022
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
1023
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
1024
+
1025
+ __m256i sumi = __lasx_xvldi(0);
1026
+
1027
+ for (int j = 0; j < QK_K/128; ++j) {
1028
+
1029
+ const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32;
1030
+
1031
+ const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1032
+ const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1033
+ const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1034
+ const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1035
+
1036
+ const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
1037
+ const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
1038
+ const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
1039
+ const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);
1040
+
1041
+ __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
1042
+ __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
1043
+ __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
1044
+ __m256i p3 = lasx_madd_h_b(q2_3, q8_3);
1045
+
1046
+ p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
1047
+ p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
1048
+ p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
1049
+ p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);
1050
+
1051
+ p0 = __lasx_xvadd_w(p0, p1);
1052
+ p2 = __lasx_xvadd_w(p2, p3);
1053
+
1054
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2));
1055
+ }
1056
+
1057
+ acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
1058
+
1059
+ }
1060
+
1061
+ *s = hsum_float_8(acc);
1062
+
1063
+ #else
1064
+ UNUSED(x);
1065
+ UNUSED(y);
1066
+ UNUSED(nb);
1067
+ ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1068
+ #endif
1069
+ }
1070
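
A scalar sketch of the q2_K scale layout the kernel above unpacks with __lsx_vandi_b / __lsx_vsrli_b (illustrative helper name): one byte per 16-element sub-block, scale in the low nibble, minimum in the high nibble.

    #include <stdint.h>

    static void unpack_q2_K_scales_sketch(const uint8_t scales[16], uint8_t sc[16], uint8_t mins[16]) {
        for (int j = 0; j < 16; ++j) {
            sc[j]   = scales[j] & 0x0F;   // multiplies the 2-bit quants
            mins[j] = scales[j] >> 4;     // weighted by the q8 block sums (bsums)
        }
    }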
+
1071
+ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1072
+ assert(n % QK_K == 0);
1073
+ assert(nrc == 1);
1074
+ UNUSED(nrc);
1075
+ UNUSED(bx);
1076
+ UNUSED(by);
1077
+ UNUSED(bs);
1078
+
1079
+ const uint32_t kmask1 = 0x03030303;
1080
+ const uint32_t kmask2 = 0x0f0f0f0f;
1081
+
1082
+ const block_q3_K * GGML_RESTRICT x = vx;
1083
+ const block_q8_K * GGML_RESTRICT y = vy;
1084
+
1085
+ const int nb = n / QK_K;
1086
+
1087
+ #if defined __loongarch_asx
1088
+
1089
+ const __m128i m32 = __lsx_vreplgr2vr_b(32);
1090
+
1091
+ __m256 acc = (__m256)__lasx_xvldi(0);
1092
+
1093
+ uint32_t aux[3];
1094
+
1095
+ for (int i = 0; i < nb; ++i) {
1096
+
1097
+ const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1098
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
1099
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1100
+ // Set up scales
1101
+ memcpy(aux, x[i].scales, 12);
1102
+ __m128i scales128 = lsx_set_w(
1103
+ ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
1104
+ ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
1105
+ (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
1106
+ (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
1107
+ scales128 = __lsx_vsub_b(scales128, m32);
1108
+
1109
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
1110
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
1111
+
1112
+ // high bit
1113
+ const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
1114
+
1115
+ // integer accumulator
1116
+ __m256i sumi = __lasx_xvldi(0);
1117
+
1118
+ for (int j = 0; j < QK_K/128; ++j) {
1119
+ // load low 2 bits
1120
+ const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
1121
+
1122
+ // prepare low and high bits
1123
+ const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
1124
+ const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
1125
+ const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
1126
+ const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
1127
+ const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
1128
+ const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
1129
+ const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
1130
+ const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
1131
+ const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
1132
+ const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
1133
+ const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
1134
+ const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
1135
+
1136
+ // load Q8 quants
1137
+ const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1138
+ const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1139
+ const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1140
+ const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1141
+
1142
+ __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
1143
+ __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
1144
+ __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
1145
+ __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
1146
+
1147
+ // multiply with scales
1148
+ p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
1149
+ p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
1150
+ p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
1151
+ p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
1152
+
1153
+ // accumulate
1154
+ p16_0 = __lasx_xvadd_w(p16_0, p16_1);
1155
+ p16_2 = __lasx_xvadd_w(p16_2, p16_3);
1156
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
1157
+ }
1158
+ // multiply with block scale and accumulate
1159
+ acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
1160
+ }
1161
+
1162
+ *s = hsum_float_8(acc);
1163
+
1164
+ #else
1165
+ UNUSED(kmask1);
1166
+ UNUSED(kmask2);
1167
+ UNUSED(x);
1168
+ UNUSED(y);
1169
+ UNUSED(nb);
1170
+ ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1171
+ #endif
1172
+ }
1173
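
A scalar sketch of the q3_K scale unpacking done with lsx_set_w above (illustrative helper name): 16 six-bit scales are packed into the 12-byte scales field, with the low 4 bits in the first 8 bytes and the high 2 bits in the last 4 bytes, and 32 is subtracted so the scales are signed.

    #include <stdint.h>
    #include <string.h>

    static void unpack_q3_K_scales_sketch(const uint8_t packed[12], int8_t sc[16]) {
        const uint32_t kmask1 = 0x03030303, kmask2 = 0x0f0f0f0f;
        uint32_t aux[3];
        memcpy(aux, packed, 12);
        uint32_t w[4];
        w[0] = (aux[0] & kmask2)        | (((aux[2] >> 0) & kmask1) << 4);
        w[1] = (aux[1] & kmask2)        | (((aux[2] >> 2) & kmask1) << 4);
        w[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
        w[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
        for (int j = 0; j < 16; ++j) {
            sc[j] = (int8_t)((w[j / 4] >> (8 * (j % 4))) & 0xFF) - 32;
        }
    }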
+
1174
+ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1175
+ assert(n % QK_K == 0);
1176
+ assert(nrc == 1);
1177
+ UNUSED(nrc);
1178
+ UNUSED(bx);
1179
+ UNUSED(by);
1180
+ UNUSED(bs);
1181
+
1182
+ const block_q4_K * GGML_RESTRICT x = vx;
1183
+ const block_q8_K * GGML_RESTRICT y = vy;
1184
+
1185
+ const int nb = n / QK_K;
1186
+
1187
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1188
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1189
+ static const uint32_t kmask3 = 0x03030303;
1190
+
1191
+ uint32_t utmp[4];
1192
+
1193
+ #if defined __loongarch_asx
1194
+
1195
+ __m256 acc = (__m256)__lasx_xvldi(0);
1196
+ __m128 acc_m = (__m128)__lsx_vldi(0);
1197
+
1198
+ for (int i = 0; i < nb; ++i) {
1199
+
1200
+ const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1201
+ const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1202
+
1203
+ memcpy(utmp, x[i].scales, 12);
1204
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1205
+ const uint32_t uaux = utmp[1] & kmask1;
1206
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1207
+ utmp[2] = uaux;
1208
+ utmp[0] &= kmask1;
1209
+
1210
+ const uint8_t * GGML_RESTRICT q4 = x[i].qs;
1211
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1212
+
1213
+ const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
1214
+ const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
1215
+ const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
1216
+
1217
+ const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
1218
+ const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
1219
+ const __m128i prod = lsx_madd_h(mins128, q8s);
1220
+ acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
1221
+
1222
+ const __m256i scales = lasx_insertf128(scales128, scales128);
1223
+
1224
+ __m256i sumi = __lasx_xvldi(0);
1225
+
1226
+ for (int j = 0; j < QK_K/64; ++j) {
1227
+
1228
+ const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
1229
+ const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
1230
+
1231
+ const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
1232
+ const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
1233
+ const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
1234
+
1235
+ const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1236
+ __m256i p16l = lasx_madd_h_b(q4l, q8l);
1237
+ p16l = lasx_madd_h(scale_l, p16l);
1238
+
1239
+ const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1240
+ __m256i p16h = lasx_madd_h_b(q4h, q8h);
1241
+ p16h = lasx_madd_h(scale_h, p16h);
1242
+ const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
1243
+
1244
+ sumi = __lasx_xvadd_w(sumi, sumj);
1245
+ }
1246
+
1247
+ __m256 vd = __lasx_xvreplfr2vr_s(d);
1248
+ acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
1249
+
1250
+ }
1251
+
1252
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
1253
+ __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
1254
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
1255
+
1256
+
1257
+ *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
1258
+
1259
+ #else
1260
+ UNUSED(x);
1261
+ UNUSED(y);
1262
+ UNUSED(nb);
1263
+ UNUSED(kmask1);
1264
+ UNUSED(kmask2);
1265
+ UNUSED(kmask3);
1266
+ UNUSED(utmp);
1267
+ ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1268
+ #endif
1269
+ }
1270
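
A scalar sketch of the utmp rearrangement used by the q4_K (and q5_K) kernel above (illustrative helper name): the 12-byte scales field packs 8 six-bit sub-block scales and 8 six-bit mins; after the kmask shuffle, utmp[0..1] hold the scales and utmp[2..3] hold the mins, one byte each.

    #include <stdint.h>
    #include <string.h>

    static void unpack_q4_K_scales_sketch(const uint8_t packed[12], uint8_t sc[8], uint8_t m[8]) {
        const uint32_t kmask1 = 0x3f3f3f3f, kmask2 = 0x0f0f0f0f, kmask3 = 0x03030303;
        uint32_t utmp[4];
        memcpy(utmp, packed, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;
        memcpy(sc, &utmp[0], 8);   // 8 scales, each 0..63
        memcpy(m,  &utmp[2], 8);   // 8 mins, each 0..63
    }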
+
1271
+ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1272
+ assert(n % QK_K == 0);
1273
+ assert(nrc == 1);
1274
+ UNUSED(nrc);
1275
+ UNUSED(bx);
1276
+ UNUSED(by);
1277
+ UNUSED(bs);
1278
+
1279
+ const block_q5_K * GGML_RESTRICT x = vx;
1280
+ const block_q8_K * GGML_RESTRICT y = vy;
1281
+
1282
+ const int nb = n / QK_K;
1283
+
1284
+ static const uint32_t kmask1 = 0x3f3f3f3f;
1285
+ static const uint32_t kmask2 = 0x0f0f0f0f;
1286
+ static const uint32_t kmask3 = 0x03030303;
1287
+
1288
+ uint32_t utmp[4];
1289
+
1290
+ #if defined __loongarch_asx
1291
+
1292
+ __m256 acc = (__m256)__lasx_xvldi(0);
1293
+ __m128 acc_m = (__m128)__lsx_vldi(0);
1294
+
1295
+ for (int i = 0; i < nb; ++i) {
1296
+
1297
+ const uint8_t * GGML_RESTRICT q5 = x[i].qs;
1298
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1299
+
1300
+ const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1301
+ const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
1302
+
1303
+ memcpy(utmp, x[i].scales, 12);
1304
+ utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
1305
+ const uint32_t uaux = utmp[1] & kmask1;
1306
+ utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
1307
+ utmp[2] = uaux;
1308
+ utmp[0] &= kmask1;
1309
+
1310
+ const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
1311
+ const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
1312
+ const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
1313
+
1314
+ const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
1315
+ const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
1316
+ const __m128i prod = lsx_madd_h(mins128, q8s);
1317
+ acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
1318
+
1319
+ const __m256i scales = lasx_insertf128(scales128, scales128);
1320
+
1321
+ const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
1322
+
1323
+ __m256i sumi = __lasx_xvldi(0);
1324
+
1325
+ for (int j = 0; j < QK_K/64; ++j) {
1326
+
1327
+ const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
1328
+ const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
1329
+
1330
+ const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
1331
+
1332
+ const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
1333
+ const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
1334
+ const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
1335
+ const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
1336
+ const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0);
1337
+ const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1);
1338
+
1339
+ const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1340
+ const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1341
+
1342
+ __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
1343
+ __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
1344
+
1345
+ p16_0 = lasx_madd_h(scale_0, p16_0);
1346
+ p16_1 = lasx_madd_h(scale_1, p16_1);
1347
+
1348
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
1349
+
1350
+ }
1351
+
1352
+ __m256 vd = __lasx_xvreplfr2vr_s(d);
1353
+ acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
1354
+
1355
+ }
1356
+
1357
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
1358
+ acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
1359
+
1360
+ *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
1361
+
1362
+ #else
1363
+ UNUSED(x);
1364
+ UNUSED(y);
1365
+ UNUSED(nb);
1366
+ UNUSED(kmask1);
1367
+ UNUSED(kmask2);
1368
+ UNUSED(kmask3);
1369
+ UNUSED(utmp);
1370
+ ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1371
+ #endif
1372
+ }
1373
+
1374
+ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1375
+ assert(n % QK_K == 0);
1376
+ assert(nrc == 1);
1377
+ UNUSED(nrc);
1378
+ UNUSED(bx);
1379
+ UNUSED(by);
1380
+ UNUSED(bs);
1381
+
1382
+ const block_q6_K * GGML_RESTRICT x = vx;
1383
+ const block_q8_K * GGML_RESTRICT y = vy;
1384
+
1385
+ const int nb = n / QK_K;
1386
+
1387
+ #if defined __loongarch_asx
1388
+
1389
+ const __m256i m32s = __lasx_xvreplgr2vr_b(32);
1390
+
1391
+ __m256 acc = (__m256)__lasx_xvldi(0);
1392
+
1393
+ for (int i = 0; i < nb; ++i) {
1394
+
1395
+ const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
1396
+
1397
+ const uint8_t * GGML_RESTRICT q4 = x[i].ql;
1398
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
1399
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1400
+
1401
+ const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
1402
+ const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
1403
+ const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
1404
+
1405
+ __m256i sumi = __lasx_xvldi(0);
1406
+
1407
+ for (int j = 0; j < QK_K/128; ++j) {
1408
+
1409
+ const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
1410
+ const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
1411
+ const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
1412
+
1413
+ const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
1414
+ const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
1415
+ const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
1416
+ const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
1417
+
1418
+ const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
1419
+ const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
1420
+ const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
1421
+ const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
1422
+
1423
+ const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1424
+ const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1425
+ const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1426
+ const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
1427
+
1428
+ __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
1429
+ __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
1430
+ __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
1431
+ __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
1432
+
1433
+ p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
1434
+ p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
1435
+ p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
1436
+ p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
1437
+
1438
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
1439
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
1440
+ }
1441
+
1442
+ acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
1443
+ }
1444
+
1445
+ *s = hsum_float_8(acc);
1446
+
1447
+ #else
1448
+ UNUSED(x);
1449
+ UNUSED(y);
1450
+ UNUSED(nb);
1451
+ ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1452
+ #endif
1453
+ }
1454
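
A scalar sketch of how the q6_K kernel above rebuilds its 6-bit quants for one 128-element chunk (illustrative helper name): ql supplies the low 4 bits, qh supplies the high 2 bits, and 32 is subtracted to centre the result, matching the q4h_0..q4h_3 shifts in the loop.

    #include <stdint.h>

    static void unpack_q6_K_chunk_sketch(const uint8_t ql[64], const uint8_t qh[32], int8_t q[128]) {
        for (int l = 0; l < 32; ++l) {
            q[l +  0] = (int8_t)((ql[l]      & 0x0F) | (((qh[l] >> 0) & 3) << 4)) - 32;
            q[l + 32] = (int8_t)((ql[l + 32] & 0x0F) | (((qh[l] >> 2) & 3) << 4)) - 32;
            q[l + 64] = (int8_t)((ql[l]      >>   4) | (((qh[l] >> 4) & 3) << 4)) - 32;
            q[l + 96] = (int8_t)((ql[l + 32] >>   4) | (((qh[l] >> 6) & 3) << 4)) - 32;
        }
    }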
+
1455
+ #if defined(__loongarch_asx)
1456
+ static const int8_t keven_signs_q2xs[1024] = {
1457
+ 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
1458
+ 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
1459
+ 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
1460
+ 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
1461
+ 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
1462
+ 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
1463
+ 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
1464
+ 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
1465
+ 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
1466
+ 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
1467
+ 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
1468
+ 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
1469
+ 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
1470
+ 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
1471
+ 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
1472
+ 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
1473
+ 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
1474
+ 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
1475
+ 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
1476
+ 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
1477
+ 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
1478
+ 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
1479
+ 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
1480
+ 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
1481
+ 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
1482
+ 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
1483
+ 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
1484
+ 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
1485
+ 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
1486
+ 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
1487
+ 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
1488
+ 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
1489
+ };
1490
+ #endif
1491
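
The table above holds 128 entries of eight signs each, indexed by the 7-bit fields the iq2 kernels extract with `& 127`. A sketch of how one entry is formed (illustrative helper name): the 7 index bits give the first seven signs and the eighth sign is chosen so that the number of -1s is even.

    #include <stdint.h>

    static void keven_signs_entry_sketch(unsigned idx7, int8_t signs[8]) {
        unsigned parity = 0;
        for (int j = 0; j < 7; ++j) {
            unsigned bit = (idx7 >> j) & 1;
            signs[j] = bit ? -1 : 1;
            parity ^= bit;
        }
        signs[7] = parity ? -1 : 1;   // keeps the count of -1s even
    }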
+
1492
+ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
1493
+ assert(n % QK_K == 0);
1494
+ assert(nrc == 1);
1495
+ UNUSED(nrc);
1496
+ UNUSED(bx);
1497
+ UNUSED(by);
1498
+ UNUSED(bs);
1499
+
1500
+ const block_iq2_xxs * GGML_RESTRICT x = vx;
1501
+ const block_q8_K * GGML_RESTRICT y = vy;
1502
+
1503
+ const int nb = n / QK_K;
1504
+
1505
+ #if defined(__loongarch_asx)
1506
+
1507
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
1508
+
1509
+ uint32_t aux32[4];
1510
+ const uint8_t * aux8 = (const uint8_t *)aux32;
1511
+
1512
+ __m256 accumf = (__m256)__lasx_xvldi(0);
1513
+ for (int i = 0; i < nb; ++i) {
1514
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
1515
+ const uint16_t * GGML_RESTRICT q2 = x[i].qs;
1516
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
1517
+ __m256i sumi1 = __lasx_xvldi(0);
1518
+ __m256i sumi2 = __lasx_xvldi(0);
1519
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
1520
+ const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1521
+ const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
1522
+ memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
1523
+
1524
+ const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
1525
+ const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
1526
+ const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
1527
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
1528
+ const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
1529
+ signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
1530
+ const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
1531
+ const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
1532
+ const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
1533
+ const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
1534
+ const uint16_t ls1 = aux32[1] >> 28;
1535
+ const uint16_t ls2 = aux32[3] >> 28;
1536
+ const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
1537
+ const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
1538
+ sumi1 = __lasx_xvadd_w(sumi1, p1);
1539
+ sumi2 = __lasx_xvadd_w(sumi2, p2);
1540
+ }
1541
+
1542
+ accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
1543
+ }
1544
+
1545
+ *s = 0.125f * hsum_float_8(accumf);
1546
+
1547
+ #else
1548
+ UNUSED(x);
1549
+ UNUSED(y);
1550
+ UNUSED(nb);
1551
+ ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
1552
+ #endif
1553
+ }
1554
+
1555
+ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq2_xs * GGML_RESTRICT x = vx;
+ const block_q8_K * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_K;
+
+ #if defined(__loongarch_asx)
+
+ const __m256i mone = __lasx_xvreplgr2vr_b(1);
+ static const char block_sign_shuffle_mask_1[32] = {
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+ 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+ };
+ static const char block_sign_shuffle_mask_2[32] = {
+ 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+ 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+ };
+ static const uint8_t bit_selector_mask_bytes[32] = {
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0);
+ const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0);
+ const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0);
+
+ static const uint8_t k_bit_helper[32] = {
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+ };
+ const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0);
+ const __m256i m511 = __lasx_xvreplgr2vr_h(511);
+ const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
+ const __m128i m1 = __lsx_vreplgr2vr_b(1);
+
+ uint64_t aux64;
+
+ // somewhat hacky, but gives a significant boost in performance
+ __m256i aux_gindex;
+ const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+ __m256 accumf = (__m256)__lasx_xvldi(0);
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint16_t * GGML_RESTRICT q2 = x[i].qs;
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+ memcpy(&aux64, x[i].scales, 8);
+ __m128i stmp = __lsx_vreplgr2vr_d(aux64);
+ stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4));
+ const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1);
+
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+ const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0); q2 += 16;
+ aux_gindex = __lasx_xvand_v(q2_data, m511);
+
+ const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9);
+ const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13);
+ const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper);
+
+ const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting);
+ const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits);
+
+ const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+
+ const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+ iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+ const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+ iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+ const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+ iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+ const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+ iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+ const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0);
+ const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1);
+ const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l);
+ const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h);
+
+ __m256i signs;
+ signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1);
+ signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1);
+
+ signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2);
+ signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2);
+
+ signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1);
+ signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3);
+
+ signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2);
+ signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+ const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4);
+
+ const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
+ const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
+ const __m256i dot3 = lasx_maddubs_h(q2_3, q8s_3);
+ const __m256i dot4 = lasx_maddubs_h(q2_4, q8s_4);
+
+ const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0)));
+ const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1)));
+ const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2)));
+ const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3)));
+
+ sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1));
+ sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2));
+ sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3));
+ sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4));
+ }
+
+ accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+ #else
+ UNUSED(x);
+ UNUSED(y);
+ UNUSED(nb);
+ ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+ #endif
+ }
+
+ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq2_s * GGML_RESTRICT x = vx;
+ const block_q8_K * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_K;
+
+ #if defined(__loongarch_asx)
+
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+
+ const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
+ const __m128i m1 = __lsx_vreplgr2vr_b(1);
+
+ const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
+ const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
+ uint64_t aux64;
+
+ __m256 accumf = (__m256)__lasx_xvldi(0);
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * GGML_RESTRICT qs = x[i].qs;
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
+ const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8);
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
+
+ __m128i tmp1;
+ memcpy(&aux64, x[i].scales, 8);
+ tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0);
+ tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1);
+ const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1);
+ const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
+
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+ iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
+ iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+ iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+ const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+ iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
+ iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+ iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+ qs += 8;
+
+ __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16));
+ aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+ const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
+ const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
+
+ aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16));
+ aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+ const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
+ const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
+
+ signs += 4;
+
+ const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
+ const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
+
+ const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0)));
+ const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1)));
+ sumi1 = __lasx_xvadd_w(sumi1, p1);
+ sumi2 = __lasx_xvadd_w(sumi2, p2);
+ }
+
+ accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+ }
+
+ *s = 0.125f * hsum_float_8(accumf);
+
+ #else
+ UNUSED(x);
+ UNUSED(y);
+ UNUSED(nb);
+ ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+ #endif
+ }
+
+ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq3_xxs * GGML_RESTRICT x = vx;
+ const block_q8_K * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_K;
+
+ #if defined(__loongarch_asx)
+
+ const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+ uint32_t aux32[2];
+
+ __m256 accumf = (__m256)__lasx_xvldi(0);
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * GGML_RESTRICT q3 = x[i].qs;
+ const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4;
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+ q3 += 8;
+ const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+ iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+ q3 += 8;
+ memcpy(aux32, gas, 8); gas += 8;
+
+ const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
+ signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]);
+ const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+ signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
+ const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
+ const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
+ const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
+ const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
+ const uint16_t ls1 = aux32[0] >> 28;
+ const uint16_t ls2 = aux32[1] >> 28;
+
+ const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+ const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+ sumi1 = __lasx_xvadd_w(sumi1, p1);
+ sumi2 = __lasx_xvadd_w(sumi2, p2);
+ }
+
+ accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+ }
+
+ *s = 0.25f * hsum_float_8(accumf);
+
+ #else
+ UNUSED(x);
+ UNUSED(y);
+ UNUSED(nb);
+ ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+ #endif
+ }
+
+ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq3_s * GGML_RESTRICT x = vx;
+ const block_q8_K * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_K;
+
+ #if defined(__loongarch_asx)
+
+ static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+ 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+ };
+
+ static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+ };
+
+ const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
+ const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
+
+ __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8);
+ const __m256i idx_mask = __lasx_xvreplgr2vr_w(256);
+
+ typedef union {
+ __m256i vec[2];
+ uint32_t index[16];
+ } index_t;
+
+ index_t idx;
+
+ __m256 accumf = (__m256)__lasx_xvldi(0);
+ for (int i = 0; i < nb; ++i) {
+ const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+ const uint8_t * GGML_RESTRICT qs = x[i].qs;
+ const uint8_t * GGML_RESTRICT qh = x[i].qh;
+ const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs;
+ const int8_t * GGML_RESTRICT q8 = y[i].qs;
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+ const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16;
+ idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]);
+ idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]);
+ idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask);
+ idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask);
+ idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0)));
+ idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1)));
+
+ // At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
+ //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
+ //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
+ const __m256i q2_1 = lasx_set_w(
+ iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
+ iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
+ );
+ const __m256i q2_2 = lasx_set_w(
+ iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
+ iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
+ );
+
+ __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16));
+ aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+ const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
+ const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
+
+ aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16));
+ aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+ const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
+ const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
+
+ signs += 4;
+
+ const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
+ const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2);
+ const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+ const uint16_t ls2 = x[i].scales[ib32/2] >> 4;
+ const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+ const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+ sumi1 = __lasx_xvadd_w(sumi1, p1);
+ sumi2 = __lasx_xvadd_w(sumi2, p2);
+ }
+
+ accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+ }
+
+ *s = hsum_float_8(accumf);
+
+ #else
+ UNUSED(x);
+ UNUSED(y);
+ UNUSED(nb);
+ ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+ #endif
+ }
+
+ #if defined(__loongarch_asx)
+ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+ const __m256i a = __lasx_xvmulwev_h_b(x, y);
+ const __m256i b = __lasx_xvmulwod_h_b(x, y);
+ return __lasx_xvadd_h(a, b);
+ }
+ #endif
+
+ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(n % QK_K == 0);
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+
+ const block_iq1_s * GGML_RESTRICT x = vx;
+ const block_q8_K * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_K;
+
+ #if defined(__loongarch_asx)
+
+ __m256 accum = (__m256)__lasx_xvldi(0);
+ float accum1 = 0;
+ for (int i = 0; i < nb; ++i) {
+
+ const int8_t * q8 = y[i].qs;
+ const uint8_t * qs = x[i].qs;
+ const uint16_t * qh = x[i].qh;
+
+ __m256i sumi = __lasx_xvldi(0);
+ int sumi1 = 0;
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ __m256i q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0);
+ q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1);
+ q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2);
+ q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3);
+
+ __m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0);
+ q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1);
+ q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2);
+ q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3);
+
+ qs += 8;
+ const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+ const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+ const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+ const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+ const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+ const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+
+ __m256i tmp1, tmp5, tmp6;
+ tmp1 = __lasx_xvreplgr2vr_h(ls1);
+ tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1);
+ tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1);
+ const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6);
+
+ tmp1 = __lasx_xvreplgr2vr_h(ls2);
+ tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1);
+ tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1);
+ const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6);
+
+ sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2));
+ sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+ + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+ }
+
+ const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
+ accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum);
+ accum1 += d * sumi1;
+ }
+
+ *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
+ #else
+ UNUSED(x);
+ UNUSED(y);
+ UNUSED(nb);
+ ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+ #endif
+ }
+
+ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ assert(n % QK4_NL == 0);
+ static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
+
+ const block_iq4_nl * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK4_NL;
+
+ int ib = 0;
+ float sumf = 0;
+
+ #if defined (__loongarch_asx)
+
+ const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
+ const __m128i m4b = __lsx_vreplgr2vr_b(0x0f);
+ const __m256i mone = __lasx_xvreplgr2vr_h(1);
+
+ __m256 accum1 = (__m256)__lasx_xvldi(0);
+ __m256 accum2 = (__m256)__lasx_xvldi(0);
+ for (; ib + 1 < nb; ib += 2) {
+ const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0);
+ const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0);
+ const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0);
+ const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0);
+ const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
+ lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
+ const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
+ lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b)));
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+ const __m256i p_1 = lasx_madd_h(p16_1, mone);
+ const __m256i p_2 = lasx_madd_h(p16_2, mone);
+ accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_FP16_TO_FP32(x[ib + 0].d)),
+ __lasx_xvffint_s_w(p_1), accum1);
+ accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_FP16_TO_FP32(x[ib + 1].d)),
+ __lasx_xvffint_s_w(p_2), accum2);
+ }
+
+ sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
+
+ #endif
+ for (; ib < nb; ++ib) {
+ const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_FP16_TO_FP32(x[ib].d);
+ int sumi1 = 0, sumi2 = 0;
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
+ sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4];
+ }
+ sumf += d * (sumi1 + sumi2);
+ }
+ *s = sumf;
+ }
+
+ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ assert(n % QK_K == 0);
+
+ const block_iq4_xs * GGML_RESTRICT x = vx;
+ const block_q8_K * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_K;
+
+ #if defined(__loongarch_asx)
+
+ const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
+
+ __m256 accum = (__m256)__lasx_xvldi(0);
+
+ for (int ibl = 0; ibl < nb; ++ibl) {
+ const uint8_t * qs = x[ibl].qs;
+ const int8_t * q8 = y[ibl].qs;
+ uint16_t sh = x[ibl].scales_h;
+ __m256i sumi1 = __lasx_xvldi(0);
+ __m256i sumi2 = __lasx_xvldi(0);
+ for (int ib = 0; ib < QK_K/32; ib += 2) {
+ const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
+ const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
+ const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+ const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
+ __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
+ const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
+ __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+ const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+ const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
+ sh >>= 4;
+ const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
+ const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
+ sumi1 = __lasx_xvadd_w(p_1, sumi1);
+ sumi2 = __lasx_xvadd_w(p_2, sumi2);
+ }
+ accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_CPU_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+ __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum);
+ }
+
+ *s = hsum_float_8(accum);
+
+ #else
+ UNUSED(x);
+ UNUSED(y);
+ UNUSED(nb);
+ ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+ #endif
+ }
+