whispercpp 1.3.2 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (664)
  1. checksums.yaml +4 -4
  2. data/.gitignore +6 -3
  3. data/README.md +71 -14
  4. data/Rakefile +20 -7
  5. data/ext/.gitignore +4 -6
  6. data/ext/dependencies.rb +36 -24
  7. data/ext/extconf.rb +1 -1
  8. data/ext/options.rb +48 -184
  9. data/ext/ruby_whisper.c +18 -0
  10. data/ext/ruby_whisper_context.c +43 -12
  11. data/ext/ruby_whisper_model.c +1 -1
  12. data/ext/ruby_whisper_params.c +59 -27
  13. data/ext/ruby_whisper_segment.c +81 -4
  14. data/ext/ruby_whisper_transcribe.cpp +13 -7
  15. data/ext/ruby_whisper_vad_params.c +1 -1
  16. data/ext/sources/CMakeLists.txt +5 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/build-xcframework.sh +24 -0
  19. data/ext/sources/examples/CMakeLists.txt +1 -0
  20. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +120 -24
  21. data/ext/sources/examples/addon.node/addon.cpp +154 -35
  22. data/ext/sources/examples/addon.node/index.js +10 -5
  23. data/ext/sources/examples/addon.node/vad-example.js +132 -0
  24. data/ext/sources/examples/bench/bench.cpp +29 -18
  25. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  26. data/ext/sources/examples/cli/cli.cpp +7 -4
  27. data/ext/sources/examples/command/command.cpp +58 -32
  28. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/common-whisper.cpp +14 -7
  31. data/ext/sources/examples/lsp/lsp.cpp +21 -17
  32. data/ext/sources/examples/quantize/quantize.cpp +3 -0
  33. data/ext/sources/examples/server/CMakeLists.txt +3 -0
  34. data/ext/sources/examples/server/server.cpp +193 -35
  35. data/ext/sources/examples/server.py +6 -1
  36. data/ext/sources/examples/stream/stream.cpp +10 -2
  37. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  38. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  39. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -0
  40. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  41. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  42. data/ext/sources/examples/talk-llama/llama-arch.cpp +756 -15
  43. data/ext/sources/examples/talk-llama/llama-arch.h +85 -1
  44. data/ext/sources/examples/talk-llama/llama-batch.cpp +773 -272
  45. data/ext/sources/examples/talk-llama/llama-batch.h +126 -55
  46. data/ext/sources/examples/talk-llama/llama-chat.cpp +150 -13
  47. data/ext/sources/examples/talk-llama/llama-chat.h +8 -0
  48. data/ext/sources/examples/talk-llama/llama-context.cpp +814 -542
  49. data/ext/sources/examples/talk-llama/llama-context.h +68 -32
  50. data/ext/sources/examples/talk-llama/llama-cparams.cpp +1 -1
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  52. data/ext/sources/examples/talk-llama/llama-graph.cpp +787 -440
  53. data/ext/sources/examples/talk-llama/llama-graph.h +333 -153
  54. data/ext/sources/examples/talk-llama/llama-hparams.cpp +128 -6
  55. data/ext/sources/examples/talk-llama/llama-hparams.h +80 -17
  56. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  57. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  58. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +137 -0
  59. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +1248 -1967
  60. data/ext/sources/examples/talk-llama/llama-kv-cache.h +218 -345
  61. data/ext/sources/examples/talk-llama/llama-kv-cells.h +164 -52
  62. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +266 -0
  63. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +139 -0
  64. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +1154 -0
  65. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +182 -0
  66. data/ext/sources/examples/talk-llama/llama-memory.cpp +58 -0
  67. data/ext/sources/examples/talk-llama/llama-memory.h +94 -4
  68. data/ext/sources/examples/talk-llama/llama-mmap.cpp +1 -1
  69. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +44 -17
  70. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  71. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +1 -0
  72. data/ext/sources/examples/talk-llama/llama-model.cpp +11377 -5248
  73. data/ext/sources/examples/talk-llama/llama-model.h +87 -9
  74. data/ext/sources/examples/talk-llama/llama-quant.cpp +137 -16
  75. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  76. data/ext/sources/examples/talk-llama/llama-vocab.cpp +502 -38
  77. data/ext/sources/examples/talk-llama/llama-vocab.h +46 -0
  78. data/ext/sources/examples/talk-llama/llama.cpp +76 -17
  79. data/ext/sources/examples/talk-llama/llama.h +176 -151
  80. data/ext/sources/examples/talk-llama/talk-llama.cpp +11 -6
  81. data/ext/sources/examples/talk-llama/unicode.cpp +212 -0
  82. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  83. data/ext/sources/examples/vad-speech-segments/speech.cpp +6 -0
  84. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -2
  85. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  86. data/ext/sources/ggml/CMakeLists.txt +106 -33
  87. data/ext/sources/ggml/cmake/common.cmake +24 -0
  88. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  89. data/ext/sources/ggml/include/ggml-backend.h +18 -2
  90. data/ext/sources/ggml/include/ggml-cpu.h +2 -0
  91. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  92. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  93. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  94. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  95. data/ext/sources/ggml/include/ggml.h +365 -21
  96. data/ext/sources/ggml/src/CMakeLists.txt +98 -25
  97. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  98. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  99. data/ext/sources/ggml/src/ggml-backend-reg.cpp +35 -13
  100. data/ext/sources/ggml/src/ggml-backend.cpp +266 -60
  101. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +4 -4
  102. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  103. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +15 -0
  104. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  105. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  106. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  107. data/ext/sources/ggml/src/ggml-cann/common.h +149 -2
  108. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +521 -78
  109. data/ext/sources/ggml/src/ggml-common.h +21 -0
  110. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +165 -50
  111. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -3
  112. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +11 -10
  113. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +94 -0
  114. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +3650 -0
  115. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1891 -0
  116. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +2160 -0
  117. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp +82 -0
  118. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +2305 -0
  119. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1897 -0
  120. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +342 -0
  121. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +1468 -0
  122. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +1221 -0
  123. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +3820 -0
  124. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6307 -0
  125. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +214 -0
  126. data/ext/sources/ggml/src/ggml-cpu/common.h +18 -3
  127. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +23 -7
  128. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +179 -110
  129. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +44 -33
  130. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.cpp → hbm.cpp} +1 -1
  131. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  132. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  133. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +228 -98
  134. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +532 -1124
  135. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +5 -0
  136. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +3374 -2081
  137. data/ext/sources/ggml/src/ggml-cpu/ops.h +13 -8
  138. data/ext/sources/ggml/src/ggml-cpu/quants.c +1193 -0
  139. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-quants.h → quants.h} +34 -0
  140. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1982 -0
  141. data/ext/sources/ggml/src/ggml-cpu/repack.h +120 -0
  142. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +367 -46
  143. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  144. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  145. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  146. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  147. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.cpp → traits.cpp} +3 -3
  148. data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-traits.h → traits.h} +1 -1
  149. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +272 -35
  150. data/ext/sources/ggml/src/ggml-cpu/vec.h +794 -142
  151. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  152. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  153. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  154. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  155. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  156. data/ext/sources/ggml/src/ggml-cuda/common.cuh +291 -81
  157. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  158. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cu +161 -0
  159. data/ext/sources/ggml/src/ggml-cuda/conv2d-dw.cuh +5 -0
  160. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +91 -0
  161. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +4 -0
  162. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  163. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  164. data/ext/sources/ggml/src/ggml-cuda/convert.cu +117 -22
  165. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +20 -0
  166. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  167. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  168. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  169. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  170. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +499 -368
  171. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +142 -93
  172. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  173. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  174. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  175. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +90 -50
  176. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  177. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  178. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  179. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +636 -222
  180. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  181. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  182. data/ext/sources/ggml/src/ggml-cuda/mean.cu +73 -0
  183. data/ext/sources/ggml/src/ggml-cuda/mean.cuh +3 -0
  184. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  185. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  186. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  187. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  188. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  189. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +506 -0
  190. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +4 -5
  191. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  192. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  193. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  195. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  196. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  197. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  198. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  199. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  200. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  201. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  202. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  203. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  204. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  205. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  206. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  208. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  210. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  211. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -98
  212. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  213. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +23 -19
  214. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +0 -1
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  234. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  235. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  236. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  237. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  238. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  239. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  240. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  241. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  242. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  243. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  244. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  245. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  246. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  247. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  248. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  249. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  251. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  252. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  254. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  255. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  256. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  258. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  259. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  260. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  261. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  262. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  269. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  270. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  271. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  272. data/ext/sources/ggml/src/ggml-cuda/unary.cu +179 -0
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +15 -0
  274. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  277. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  278. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  279. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -2
  280. data/ext/sources/ggml/src/ggml-impl.h +229 -175
  281. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +21 -17
  282. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  283. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  284. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  285. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  286. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  287. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  288. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  289. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +163 -63
  290. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  291. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  292. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  293. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3208 -1575
  294. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  295. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +2 -2
  296. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +32 -0
  297. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4430 -792
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/argsort.cl +86 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +109 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +138 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +378 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +121 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl +283 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +30 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +39 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/sigmoid.cl +29 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +138 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +39 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +63 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +48 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +120 -0
  344. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  345. data/ext/sources/ggml/src/ggml-quants.c +117 -24
  346. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  347. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +85 -62
  348. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +3 -3
  349. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  350. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +9 -0
  351. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  352. data/ext/sources/ggml/src/ggml-sycl/common.hpp +20 -48
  353. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +13 -17
  354. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +21 -2
  355. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +116 -211
  356. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  357. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +32 -0
  358. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +700 -1041
  359. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +20 -9
  360. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +17 -26
  361. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +2 -96
  362. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +393 -250
  363. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +1 -1
  364. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +32 -8
  365. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  366. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +38 -11
  367. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +125 -21
  368. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  369. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  370. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +3 -1
  371. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +3 -0
  372. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +4 -3
  373. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +105 -17
  374. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +36 -32
  375. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +4198 -1145
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -12
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +98 -0
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  401. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  402. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  403. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +13 -0
  404. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +19 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp +29 -0
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +9 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +64 -11
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +9 -0
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +338 -71
  449. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  450. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  451. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  452. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  453. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  454. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  455. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  456. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  457. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  458. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  459. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  460. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  461. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  462. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  463. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  464. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  465. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  466. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  467. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  468. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  469. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  470. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  471. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  472. data/ext/sources/ggml/src/ggml.c +802 -142
  473. data/ext/sources/ggml/src/ggml.cpp +26 -0
  474. data/ext/sources/ggml/src/gguf.cpp +32 -4
  475. data/ext/sources/include/whisper.h +2 -0
  476. data/ext/sources/src/CMakeLists.txt +2 -0
  477. data/ext/sources/src/coreml/whisper-compat.h +10 -0
  478. data/ext/sources/src/coreml/whisper-compat.m +35 -0
  479. data/ext/sources/src/coreml/whisper-decoder-impl.m +1 -0
  480. data/ext/sources/src/coreml/whisper-encoder-impl.m +1 -0
  481. data/ext/sources/src/whisper.cpp +241 -215
  482. data/ext/sources/tests/CMakeLists.txt +8 -1
  483. data/ext/sources/tests/test-vad-full.cpp +3 -3
  484. data/ext/sources/tests/test-vad.cpp +2 -2
  485. data/extsources.rb +15 -9
  486. data/lib/whisper/context.rb +15 -0
  487. data/lib/whisper/model/uri.rb +57 -2
  488. data/lib/whisper/segment.rb +58 -0
  489. data/sig/whisper.rbs +75 -38
  490. data/{tests → test}/helper.rb +1 -12
  491. data/{tests → test}/test_model.rb +9 -0
  492. data/test/test_package.rb +51 -0
  493. data/{tests → test}/test_params.rb +8 -0
  494. data/test/test_segment.rb +146 -0
  495. data/{tests → test}/test_whisper.rb +70 -0
  496. data/whispercpp.gemspec +2 -3
  497. metadata +246 -191
  498. data/ext/sources/.dockerignore +0 -3
  499. data/ext/sources/.github/workflows/bindings-ruby.yml +0 -21
  500. data/ext/sources/ci/run.sh +0 -336
  501. data/ext/sources/close-issue.yml +0 -28
  502. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  503. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  504. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  505. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  506. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  507. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  508. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  509. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  510. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  511. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  512. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  513. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  514. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  515. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  516. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  517. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  518. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -6431
  519. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +0 -8
  520. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -13747
  521. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  522. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  523. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  524. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  525. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  526. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  527. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -336
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  548. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  549. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  550. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  551. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  552. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  553. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  554. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  555. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  556. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  557. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  558. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  559. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  560. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  561. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  562. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  563. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  564. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  565. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  566. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  567. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  568. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  569. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  570. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  571. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  572. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  573. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  574. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  575. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  576. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  577. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  578. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  579. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  580. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  581. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  582. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  583. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  584. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  585. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  586. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  587. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  588. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  589. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  590. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  591. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  592. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  593. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  594. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  595. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  596. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  597. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  598. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  599. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  600. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  601. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  602. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  603. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  604. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  605. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  606. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  607. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  608. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  609. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  610. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  611. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  612. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  613. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  614. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  615. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  616. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  617. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  618. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  619. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  620. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  621. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  622. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  623. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  624. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  625. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  626. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  627. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  628. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  629. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  630. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  631. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  632. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  633. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  634. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  635. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  636. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  637. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  638. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  639. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  640. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  641. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  642. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  643. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  644. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  645. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  646. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  647. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  648. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  649. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  650. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  651. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  652. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  653. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -5998
  654. data/tests/test_package.rb +0 -46
  655. data/tests/test_segment.rb +0 -74
  656. /data/ext/sources/ggml/src/ggml-cpu/{cpu-feats-x86.cpp → arch/x86/cpu-feats.cpp} +0 -0
  657. /data/ext/sources/ggml/src/ggml-cpu/{ggml-cpu-hbm.h → hbm.h} +0 -0
  658. /data/{tests → test}/jfk_reader/.gitignore +0 -0
  659. /data/{tests → test}/jfk_reader/extconf.rb +0 -0
  660. /data/{tests → test}/jfk_reader/jfk_reader.c +0 -0
  661. /data/{tests → test}/test_callback.rb +0 -0
  662. /data/{tests → test}/test_error.rb +0 -0
  663. /data/{tests → test}/test_vad.rb +0 -0
  664. /data/{tests → test}/test_vad_params.rb +0 -0
@@ -0,0 +1,3196 @@
+ #include "ggml.h"
+ #include "ime_kernels.h"
+
+ #include <algorithm>
+ #include <cmath>
+
+ // clang-format off
+ #if defined(__GNUC__)
+ #pragma GCC diagnostic ignored "-Woverlength-strings"
+ #pragma GCC diagnostic ignored "-Wcast-qual"
+ #pragma GCC diagnostic ignored "-Wunused-parameter"
+ #endif
+ // clang-format on
+ namespace sqnbitgemm_spacemit_ime {
+
+ #define QUANTIZEM4ROW_KERNEL \
+ "vmv.s.x v16, zero \n\t" \
+ "vfabs.v v8, v0 \n\t" \
+ "vfredmax.vs v16, v8, v16 \n\t" \
+ "vfmv.f.s f10, v16 \n\t" \
+ "fmul.s f10, f10, %[RMAXREC] \n\t" \
+ "fsw f10, (a1) \n\t" \
+ "fdiv.s f11, %[FONE], f10 \n\t" \
+ "vfmul.vf v16, v0, f11 \n\t" \
+ "vfcvt.x.f.v v16, v16 \n\t" \
+ "vsetvli t0, zero, e16, mf2 \n\t" \
+ "vnclip.wx v16, v16, zero \n\t" \
+ "vnclip.wx v17, v17, zero \n\t" \
+ "vnclip.wx v18, v18, zero \n\t" \
+ "vnclip.wx v19, v19, zero \n\t" \
+ "vnclip.wx v20, v20, zero \n\t" \
+ "vnclip.wx v21, v21, zero \n\t" \
+ "vnclip.wx v22, v22, zero \n\t" \
+ "vnclip.wx v23, v23, zero \n\t" \
+ "vsetvli t0, zero, e8, mf4 \n\t" \
+ "vnclip.wx v24, v16, zero \n\t" \
+ "vnclip.wx v25, v17, zero \n\t" \
+ "vnclip.wx v26, v18, zero \n\t" \
+ "vnclip.wx v27, v19, zero \n\t" \
+ "vnclip.wx v28, v20, zero \n\t" \
+ "vnclip.wx v29, v21, zero \n\t" \
+ "vnclip.wx v30, v22, zero \n\t" \
+ "vnclip.wx v31, v23, zero \n\t"
+
+ #define QUANTIZEM4ROW_STORE \
+ "addi t1, %[BlkLen], 0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v24, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v25, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v26, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v27, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v28, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v29, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v30, (s1) \n\t" \
+ "addi s1, s1, 32 \n\t" \
+ "sub t1, t1, t0 \n\t" \
+ "vsetvli t0, t1, e8, mf4 \n\t" \
+ "vse8.v v31, (s1) \n\t"
+
+ namespace ime1 {
+ void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
+ constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
+ const float fone = 1.0f;
+
+ if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
+ const float * SRC = A + row_index * CountK;
+ std::byte * DST = QuantA + row_index * sizeof(float);
+
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "addi t2, %[CountK], 0 \n\t"
+ "addi a1, %[DST], 0 \n\t"
+ "blt t2, %[BlkLen], TAIL%= \n\t"
+
+ "LOOP%=: \n\t"
+ "vsetvli t0, %[BlkLen], e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "sub t2, t2, t0 \n\t"
+ "slli t1, t0, 2 \n\t"
+ "add %[SRC], %[SRC], t1 \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+
+ QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
+
+ "add a1, a1, %[STRIDE] \n\t"
+ "bge t2, %[BlkLen], LOOP%= \n\t"
+
+ "TAIL%=: \n\t"
+ "blez t2, QUIT%= \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "vsetvli t0, t2, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+
+ QUANTIZEM4ROW_KERNEL
+
+ "addi t3, %[BlkLen], 0 \n\t"
+ "addi s2, s1, 0 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vxor.vv v8, v8, v8 \n\t"
+ "SET_ZERO%=: \n\t"
+ "vse8.v v8, (s2) \n\t"
+ "addi s2, s2, 32 \n\t"
+ "addi t3, t3, -8 \n\t"
+ "bnez t3, SET_ZERO%= \n\t"
+
+ QUANTIZEM4ROW_STORE
+
+ "QUIT%=: \n\t"
+ : [SRC] "+r"(SRC)
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
+ }
+ } else if (BlkLen == 128) {
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
+ const float * SRC = A + row_index * CountK;
+ std::byte * DST = QuantA + row_index * sizeof(float);
+
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "li t6, 32 \n\t"
+ "addi t2, %[CountK], 0 \n\t"
+ "addi a1, %[DST], 0 \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "blt t2, %[BlkLen], TAIL%= \n\t"
+
+ "LOOP%=: \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "addi t2, t2, -128 \n\t"
+
+ "QUANTIZE%=: \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vfabs.v v24, v8 \n\t"
+ "vfmax.vv v16, v24, v16 \n\t"
+ "vfredmax.vs v24, v16, v24 \n\t"
+ "vfmv.f.s f10, v24 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (a1) \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vfmul.vf v16, v0, f11 \n\t"
+ "vfmul.vf v24, v8, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v20, v24, zero \n\t"
+ "vsetvli t0, zero, e8, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vsetvli t0, zero, e64, m4 \n\t"
+ "vsse64.v v16, (s1), t6 \n\t"
+ "add a1, a1, %[STRIDE] \n\t"
+ "bge t2, %[BlkLen], LOOP%= \n\t"
+
+ "TAIL%=: \n\t"
+ "blez t2, QUIT%= \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v8, v8, v8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "vsetvli t0, t2, e32, m8 \n\t"
+ "sub t2, t2, t0 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vsetvli t0, t2, e32, m8 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "sub t2, t2, t2 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "jal x0, QUANTIZE%= \n\t"
+
+ "QUIT%=: \n\t"
+ : [SRC] "+r"(SRC)
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+ : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
+ }
+ } else if (BlkLen == 256) {
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
+ const float * SRC = A + row_index * CountK;
+ std::byte * DST = QuantA + row_index * sizeof(float);
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "li t6, 32 \n\t"
+ "addi t2, %[CountK], 0 \n\t"
+ "addi a1, %[DST], 0 \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "blt t2, %[BlkLen], TAIL%= \n\t"
+
+ "LOOP%=: \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], -768 \n\t"
+ "addi t2, t2, -256 \n\t"
+ "vfabs.v v0, v0 \n\t"
+ "vfabs.v v8, v8 \n\t"
+ "vfabs.v v16, v16 \n\t"
+ "vfabs.v v24, v24 \n\t"
+ "vfmax.vv v8, v0, v8 \n\t"
+ "vfmax.vv v24, v24, v16 \n\t"
+ "vfmax.vv v8, v8, v24 \n\t"
+ "vfredmax.vs v24, v8, v24 \n\t"
+ "vfmv.f.s f10, v24 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+
+ "QUANTIZE%=: \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (a1) \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vfmul.vf v0, v0, f11 \n\t"
+ "vfmul.vf v8, v8, f11 \n\t"
+ "vfmul.vf v16, v16, f11 \n\t"
+ "vfmul.vf v24, v24, f11 \n\t"
+ "vfcvt.x.f.v v0, v0 \n\t"
+ "vfcvt.x.f.v v8, v8 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vnclip.wx v4, v8, zero \n\t"
+ "vnclip.wx v8, v16, zero \n\t"
+ "vnclip.wx v12, v24, zero \n\t"
+ "vsetvli t0, zero, e8, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vnclip.wx v4, v8, zero \n\t"
+ "vsetvli t0, zero, e64, m8 \n\t"
+ "vsse64.v v0, (s1), t6 \n\t"
+ "add a1, a1, %[STRIDE] \n\t"
+ "bge t2, %[BlkLen], LOOP%= \n\t"
+
+ "TAIL%=: \n\t"
+ "blez t2, QUIT%= \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v8, v8, v8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t1, t2, 0 \n\t"
285
+ "vsetvli t0, t1, e32, m8 \n\t"
286
+ "sub t1, t1, t0 \n\t"
287
+ "vle32.v v0, (%[SRC]) \n\t"
288
+ "addi %[SRC], %[SRC], 256 \n\t"
289
+ "vsetvli t0, t1, e32, m8 \n\t"
290
+ "sub t1, t1, t0 \n\t"
291
+ "vle32.v v8, (%[SRC]) \n\t"
292
+ "addi %[SRC], %[SRC], 256 \n\t"
293
+ "vsetvli t0, t1, e32, m8 \n\t"
294
+ "sub t1, t1, t0 \n\t"
295
+ "vle32.v v16, (%[SRC]) \n\t"
296
+ "addi %[SRC], %[SRC], 256 \n\t"
297
+ "vsetvli t0, t1, e32, m8 \n\t"
298
+ "vle32.v v24, (%[SRC]) \n\t"
299
+ "addi %[SRC], %[SRC], -768 \n\t"
300
+ "vsetvli t0, zero, e32, m8 \n\t"
301
+ "vfabs.v v0, v0 \n\t"
302
+ "vfabs.v v8, v8 \n\t"
303
+ "vfabs.v v16, v16 \n\t"
304
+ "vfabs.v v24, v24 \n\t"
305
+ "vfmax.vv v8, v0, v8 \n\t"
306
+ "vfmax.vv v24, v16, v24 \n\t"
307
+ "vfmax.vv v8, v8, v24 \n\t"
308
+ "vfredmax.vs v24, v8, v24 \n\t"
309
+ "vfmv.f.s f10, v24 \n\t"
310
+ "add s1, a1, %[OFFSET] \n\t"
311
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
312
+ "fsw f10, (a1) \n\t"
313
+ "fdiv.s f11, %[FONE], f10 \n\t"
314
+ "vsetvli t0, zero, e64, m8 \n\t"
315
+ "vxor.vv v0, v0, v0 \n\t"
316
+ "vsse64.v v0, (s1), t6 \n\t"
317
+
318
+ "TAIL_LOOP%=: \n\t"
319
+ "vsetvli t0, zero, e32, m4 \n\t"
320
+ "vxor.vv v0, v0, v0 \n\t"
321
+ "vsetvli t0, t2, e32, m1 \n\t"
322
+ "sub t2, t2, t0 \n\t"
323
+ "vle32.v v0, (%[SRC]) \n\t"
324
+ "addi %[SRC], %[SRC], 32 \n\t"
325
+ "vfmul.vf v1, v0, f11 \n\t"
326
+ "vfcvt.x.f.v v2, v1 \n\t"
327
+ "vsetvli t0, zero, e16, mf2 \n\t"
328
+ "vnclip.wx v3, v2, zero \n\t"
329
+ "vsetvli t0, zero, e8, mf4 \n\t"
330
+ "vnclip.wx v3, v3, zero \n\t"
331
+ "vse8.v v3, (s1) \n\t"
332
+ "addi s1, s1, 32 \n\t"
333
+ "bnez t2, TAIL_LOOP%= \n\t"
334
+
335
+ "QUIT%=: \n\t"
336
+ : [SRC] "+r"(SRC)
337
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
338
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
339
+ : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
340
+ }
341
+ }
342
+ }
343
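+
+      // For reference, a minimal scalar sketch of the per-block scheme the kernels
+      // above implement (assuming round-to-nearest and saturating narrowing, as
+      // vfcvt.x.f.v/vnclip provide, and, like the asm, a nonzero amax):
+      //
+      //     float amax = 0.0f;
+      //     for (size_t k = 0; k < BlkLen; ++k) amax = std::max(amax, fabsf(src[k]));
+      //     const float scale = amax / 127.0f;  // RMAXREC == 1.0f / 127
+      //     const float inv   = 1.0f / scale;   // FONE / scale, as in the asm
+      //     for (size_t k = 0; k < BlkLen; ++k)
+      //         dst[k] = (int8_t) std::clamp(roundf(src[k] * inv), -128.0f, 127.0f);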
+
344
+ void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
345
+ const float * SRC = A;
346
+ std::byte * DST = QuantA;
347
+ constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
348
+ const float fone = 1.0f;
349
+ std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
350
+ size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
351
+
352
+ if (CountK <= BlkLen) {
353
+ float max_abs_A = 0.0f;
354
+ for (size_t k = 0; k < CountK; k++) {
355
+ max_abs_A = std::max(max_abs_A, fabsf(A[k]));
356
+ }
357
+ float scale_A = max_abs_A * range_max_reciprocal;
358
+
359
+ ((float *) QuantA)[0] = scale_A;
360
+
361
+ auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
362
+
363
+ for (size_t k = 0; k < CountK; k++) {
364
+ QuantAData_offset[k] =
365
+ (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
366
+ (float) std::numeric_limits<int8_t>::max());
367
+ }
368
+ for (size_t k = CountK; k < BlkLen; k++) {
369
+ QuantAData_offset[k] = 0;
370
+ }
371
+
372
+ return;
373
+ }
374
+
375
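+      // Pre-zero the tail padding for the block sizes whose stores below are
+      // length-limited (16 and anything above 128); the 32/64/128 paths store full
+      // blocks from zeroed registers and cover their own tails.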
+      if (BlkLen != 32 && BlkLen != 64 && BlkLen != 128) {
376
+ __asm__ volatile(
377
+ "vsetvli t0, zero, e8, m8 \n\t"
378
+ "vxor.vv v24, v24, v24 \n\t"
379
+ "LOOP%=: \n\t"
380
+ "vsetvli t0, %[CNT], e8, m8 \n\t"
381
+ "vse8.v v24, (%[DST]) \n\t"
382
+ "addi %[DST], %[DST], 128 \n\t"
383
+ "sub %[CNT], %[CNT], t0 \n\t"
384
+ "bnez %[CNT], LOOP%= \n\t"
385
+ : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
386
+ :
387
+ : "cc", "t0");
388
+ }
389
+ if (BlkLen == 16) {
390
+ float buffer[64] = { 0.0f };
391
+ __asm__ volatile(
392
+ "addi t3, zero, 16*8 \n\t"
393
+ "addi t2, zero, 16 \n\t"
394
+ "blt %[K], t3, LOOP_K%= \n\t"
395
+ "blt %[K], t2, TAIL%= \n\t"
396
+ "LOOP_MAIN%=: \n\t"
397
+ "vsetvli t1, zero, e32, m2 \n\t"
398
+ "addi %[K], %[K], -128 \n\t"
399
+ "vle32.v v0, (%[SRC]) \n\t"
400
+ "addi %[SRC], %[SRC], 64 \n\t"
401
+ "vle32.v v2, (%[SRC]) \n\t"
402
+ "addi %[SRC], %[SRC], 64 \n\t"
403
+ "vle32.v v4, (%[SRC]) \n\t"
404
+ "addi %[SRC], %[SRC], 64 \n\t"
405
+ "vle32.v v6, (%[SRC]) \n\t"
406
+ "addi %[SRC], %[SRC], 64 \n\t"
407
+ "vle32.v v8, (%[SRC]) \n\t"
408
+ "addi %[SRC], %[SRC], 64 \n\t"
409
+ "vle32.v v10, (%[SRC]) \n\t"
410
+ "addi %[SRC], %[SRC], 64 \n\t"
411
+ "vle32.v v12, (%[SRC]) \n\t"
412
+ "addi %[SRC], %[SRC], 64 \n\t"
413
+ "vle32.v v14, (%[SRC]) \n\t"
414
+ "addi %[SRC], %[SRC], 64 \n\t"
415
+ "addi a1, %[BUFFER], 0 \n\t"
416
+ "vfabs.v v16, v0 \n\t"
417
+ "vfabs.v v18, v2 \n\t"
418
+ "vfabs.v v20, v4 \n\t"
419
+ "vfabs.v v22, v6 \n\t"
420
+ "vfabs.v v24, v8 \n\t"
421
+ "vfabs.v v26, v10 \n\t"
422
+ "vfabs.v v28, v12 \n\t"
423
+ "vfabs.v v30, v14 \n\t"
424
+ "vsetvli t0, zero, e32, m1 \n\t"
425
+ "vfmax.vv v16, v16, v17 \n\t"
426
+ "vfmax.vv v18, v18, v19 \n\t"
427
+ "vfmax.vv v20, v20, v21 \n\t"
428
+ "vfmax.vv v22, v22, v23 \n\t"
429
+ "vfmax.vv v24, v24, v25 \n\t"
430
+ "vfmax.vv v26, v26, v27 \n\t"
431
+ "vfmax.vv v28, v28, v29 \n\t"
432
+ "vfmax.vv v30, v30, v31 \n\t"
433
+ "vse32.v v16, (a1) \n\t"
434
+ "addi a1, a1, 32 \n\t"
435
+ "vse32.v v18, (a1) \n\t"
436
+ "addi a1, a1, 32 \n\t"
437
+ "vse32.v v20, (a1) \n\t"
438
+ "addi a1, a1, 32 \n\t"
439
+ "vse32.v v22, (a1) \n\t"
440
+ "addi a1, a1, 32 \n\t"
441
+ "vse32.v v24, (a1) \n\t"
442
+ "addi a1, a1, 32 \n\t"
443
+ "vse32.v v26, (a1) \n\t"
444
+ "addi a1, a1, 32 \n\t"
445
+ "vse32.v v28, (a1) \n\t"
446
+ "addi a1, a1, 32 \n\t"
447
+ "vse32.v v30, (a1) \n\t"
448
+ "addi a1, %[BUFFER], 0 \n\t"
449
+ "flw f0, (a1) \n\t"
450
+ "flw f1, 4(a1) \n\t"
451
+ "flw f2, 8(a1) \n\t"
452
+ "flw f3, 12(a1) \n\t"
453
+ "flw f4, 16(a1) \n\t"
454
+ "flw f5, 20(a1) \n\t"
455
+ "flw f6, 24(a1) \n\t"
456
+ "flw f7, 28(a1) \n\t"
457
+ "addi a1, a1, 32 \n\t"
458
+ "fmax.s f1, f0, f1 \n\t"
459
+ "fmax.s f3, f2, f3 \n\t"
460
+ "fmax.s f5, f4, f5 \n\t"
461
+ "fmax.s f7, f6, f7 \n\t"
462
+ "fmax.s f3, f1, f3 \n\t"
463
+ "fmax.s f7, f5, f7 \n\t"
464
+ "fmax.s f10, f3, f7 \n\t"
465
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
466
+ "fsw f10, (%[DST]) \n\t"
467
+ "addi %[DST], %[DST], 20 \n\t"
468
+ "fdiv.s f10, %[FONE], f10 \n\t"
469
+ "flw f0, (a1) \n\t"
470
+ "flw f1, 4(a1) \n\t"
471
+ "flw f2, 8(a1) \n\t"
472
+ "flw f3, 12(a1) \n\t"
473
+ "flw f4, 16(a1) \n\t"
474
+ "flw f5, 20(a1) \n\t"
475
+ "flw f6, 24(a1) \n\t"
476
+ "flw f7, 28(a1) \n\t"
477
+ "addi a1, a1, 32 \n\t"
478
+ "fmax.s f1, f0, f1 \n\t"
479
+ "fmax.s f3, f2, f3 \n\t"
480
+ "fmax.s f5, f4, f5 \n\t"
481
+ "fmax.s f7, f6, f7 \n\t"
482
+ "fmax.s f3, f1, f3 \n\t"
483
+ "fmax.s f7, f5, f7 \n\t"
484
+ "fmax.s f11, f3, f7 \n\t"
485
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
486
+ "fsw f11, (%[DST]) \n\t"
487
+ "addi %[DST], %[DST], 20 \n\t"
488
+ "fdiv.s f11, %[FONE], f11 \n\t"
489
+ "flw f0, (a1) \n\t"
490
+ "flw f1, 4(a1) \n\t"
491
+ "flw f2, 8(a1) \n\t"
492
+ "flw f3, 12(a1) \n\t"
493
+ "flw f4, 16(a1) \n\t"
494
+ "flw f5, 20(a1) \n\t"
495
+ "flw f6, 24(a1) \n\t"
496
+ "flw f7, 28(a1) \n\t"
497
+ "addi a1, a1, 32 \n\t"
498
+ "fmax.s f1, f0, f1 \n\t"
499
+ "fmax.s f3, f2, f3 \n\t"
500
+ "fmax.s f5, f4, f5 \n\t"
501
+ "fmax.s f7, f6, f7 \n\t"
502
+ "fmax.s f3, f1, f3 \n\t"
503
+ "fmax.s f7, f5, f7 \n\t"
504
+ "fmax.s f12, f3, f7 \n\t"
505
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
506
+ "fsw f12, (%[DST]) \n\t"
507
+ "addi %[DST], %[DST], 20 \n\t"
508
+ "fdiv.s f12, %[FONE], f12 \n\t"
509
+ "flw f0, (a1) \n\t"
510
+ "flw f1, 4(a1) \n\t"
511
+ "flw f2, 8(a1) \n\t"
512
+ "flw f3, 12(a1) \n\t"
513
+ "flw f4, 16(a1) \n\t"
514
+ "flw f5, 20(a1) \n\t"
515
+ "flw f6, 24(a1) \n\t"
516
+ "flw f7, 28(a1) \n\t"
517
+ "addi a1, a1, 32 \n\t"
518
+ "fmax.s f1, f0, f1 \n\t"
519
+ "fmax.s f3, f2, f3 \n\t"
520
+ "fmax.s f5, f4, f5 \n\t"
521
+ "fmax.s f7, f6, f7 \n\t"
522
+ "fmax.s f3, f1, f3 \n\t"
523
+ "fmax.s f7, f5, f7 \n\t"
524
+ "fmax.s f13, f3, f7 \n\t"
525
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
526
+ "fsw f13, (%[DST]) \n\t"
527
+ "addi %[DST], %[DST], 20 \n\t"
528
+ "fdiv.s f13, %[FONE], f13 \n\t"
529
+ "flw f0, (a1) \n\t"
530
+ "flw f1, 4(a1) \n\t"
531
+ "flw f2, 8(a1) \n\t"
532
+ "flw f3, 12(a1) \n\t"
533
+ "flw f4, 16(a1) \n\t"
534
+ "flw f5, 20(a1) \n\t"
535
+ "flw f6, 24(a1) \n\t"
536
+ "flw f7, 28(a1) \n\t"
537
+ "addi a1, a1, 32 \n\t"
538
+ "fmax.s f1, f0, f1 \n\t"
539
+ "fmax.s f3, f2, f3 \n\t"
540
+ "fmax.s f5, f4, f5 \n\t"
541
+ "fmax.s f7, f6, f7 \n\t"
542
+ "fmax.s f3, f1, f3 \n\t"
543
+ "fmax.s f7, f5, f7 \n\t"
544
+ "fmax.s f14, f3, f7 \n\t"
545
+ "fmul.s f14, f14, %[RMAXREC] \n\t"
546
+ "fsw f14, (%[DST]) \n\t"
547
+ "addi %[DST], %[DST], 20 \n\t"
548
+ "fdiv.s f14, %[FONE], f14 \n\t"
549
+ "flw f0, (a1) \n\t"
550
+ "flw f1, 4(a1) \n\t"
551
+ "flw f2, 8(a1) \n\t"
552
+ "flw f3, 12(a1) \n\t"
553
+ "flw f4, 16(a1) \n\t"
554
+ "flw f5, 20(a1) \n\t"
555
+ "flw f6, 24(a1) \n\t"
556
+ "flw f7, 28(a1) \n\t"
557
+ "addi a1, a1, 32 \n\t"
558
+ "fmax.s f1, f0, f1 \n\t"
559
+ "fmax.s f3, f2, f3 \n\t"
560
+ "fmax.s f5, f4, f5 \n\t"
561
+ "fmax.s f7, f6, f7 \n\t"
562
+ "fmax.s f3, f1, f3 \n\t"
563
+ "fmax.s f7, f5, f7 \n\t"
564
+ "fmax.s f15, f3, f7 \n\t"
565
+ "fmul.s f15, f15, %[RMAXREC] \n\t"
566
+ "fsw f15, (%[DST]) \n\t"
567
+ "addi %[DST], %[DST], 20 \n\t"
568
+ "fdiv.s f15, %[FONE], f15 \n\t"
569
+ "flw f0, (a1) \n\t"
570
+ "flw f1, 4(a1) \n\t"
571
+ "flw f2, 8(a1) \n\t"
572
+ "flw f3, 12(a1) \n\t"
573
+ "flw f4, 16(a1) \n\t"
574
+ "flw f5, 20(a1) \n\t"
575
+ "flw f6, 24(a1) \n\t"
576
+ "flw f7, 28(a1) \n\t"
577
+ "addi a1, a1, 32 \n\t"
578
+ "fmax.s f1, f0, f1 \n\t"
579
+ "fmax.s f3, f2, f3 \n\t"
580
+ "fmax.s f5, f4, f5 \n\t"
581
+ "fmax.s f7, f6, f7 \n\t"
582
+ "fmax.s f3, f1, f3 \n\t"
583
+ "fmax.s f7, f5, f7 \n\t"
584
+ "fmax.s f16, f3, f7 \n\t"
585
+ "fmul.s f16, f16, %[RMAXREC] \n\t"
586
+ "fsw f16, (%[DST]) \n\t"
587
+ "addi %[DST], %[DST], 20 \n\t"
588
+ "fdiv.s f16, %[FONE], f16 \n\t"
589
+ "flw f0, (a1) \n\t"
590
+ "flw f1, 4(a1) \n\t"
591
+ "flw f2, 8(a1) \n\t"
592
+ "flw f3, 12(a1) \n\t"
593
+ "flw f4, 16(a1) \n\t"
594
+ "flw f5, 20(a1) \n\t"
595
+ "flw f6, 24(a1) \n\t"
596
+ "flw f7, 28(a1) \n\t"
597
+ "addi a1, a1, 32 \n\t"
598
+ "fmax.s f1, f0, f1 \n\t"
599
+ "fmax.s f3, f2, f3 \n\t"
600
+ "fmax.s f5, f4, f5 \n\t"
601
+ "fmax.s f7, f6, f7 \n\t"
602
+ "fmax.s f3, f1, f3 \n\t"
603
+ "fmax.s f7, f5, f7 \n\t"
604
+ "fmax.s f17, f3, f7 \n\t"
605
+ "fmul.s f17, f17, %[RMAXREC] \n\t"
606
+ "fsw f17, (%[DST]) \n\t"
607
+ "addi %[DST], %[DST], -136 \n\t"
608
+ "fdiv.s f17, %[FONE], f17 \n\t"
609
+ "vsetvli t0, zero, e32, m2 \n\t"
610
+ "vfmul.vf v16, v0, f10 \n\t"
611
+ "vfmul.vf v18, v2, f11 \n\t"
612
+ "vfmul.vf v20, v4, f12 \n\t"
613
+ "vfmul.vf v22, v6, f13 \n\t"
614
+ "vfmul.vf v24, v8, f14 \n\t"
615
+ "vfmul.vf v26, v10, f15 \n\t"
616
+ "vfmul.vf v28, v12, f16 \n\t"
617
+ "vfmul.vf v30, v14, f17 \n\t"
618
+ "vfcvt.x.f.v v16, v16 \n\t"
619
+ "vfcvt.x.f.v v18, v18 \n\t"
620
+ "vfcvt.x.f.v v20, v20 \n\t"
621
+ "vfcvt.x.f.v v22, v22 \n\t"
622
+ "vfcvt.x.f.v v24, v24 \n\t"
623
+ "vfcvt.x.f.v v26, v26 \n\t"
624
+ "vfcvt.x.f.v v28, v28 \n\t"
625
+ "vfcvt.x.f.v v30, v30 \n\t"
626
+ "vsetvli t0, zero, e16, m1 \n\t"
627
+ "vnclip.wx v16, v16, zero \n\t"
628
+ "vnclip.wx v18, v18, zero \n\t"
629
+ "vnclip.wx v20, v20, zero \n\t"
630
+ "vnclip.wx v22, v22, zero \n\t"
631
+ "vnclip.wx v24, v24, zero \n\t"
632
+ "vnclip.wx v26, v26, zero \n\t"
633
+ "vnclip.wx v28, v28, zero \n\t"
634
+ "vnclip.wx v30, v30, zero \n\t"
635
+ "vsetvli t0, t1, e8, mf2 \n\t"
636
+ "vnclip.wx v16, v16, zero \n\t"
637
+ "vnclip.wx v18, v18, zero \n\t"
638
+ "vnclip.wx v20, v20, zero \n\t"
639
+ "vnclip.wx v22, v22, zero \n\t"
640
+ "vnclip.wx v24, v24, zero \n\t"
641
+ "vnclip.wx v26, v26, zero \n\t"
642
+ "vnclip.wx v28, v28, zero \n\t"
643
+ "vnclip.wx v30, v30, zero \n\t"
644
+ "vse8.v v16, (%[DST]) \n\t"
645
+ "addi %[DST], %[DST], 20 \n\t"
646
+ "vse8.v v18, (%[DST]) \n\t"
647
+ "addi %[DST], %[DST], 20 \n\t"
648
+ "vse8.v v20, (%[DST]) \n\t"
649
+ "addi %[DST], %[DST], 20 \n\t"
650
+ "vse8.v v22, (%[DST]) \n\t"
651
+ "addi %[DST], %[DST], 20 \n\t"
652
+ "vse8.v v24, (%[DST]) \n\t"
653
+ "addi %[DST], %[DST], 20 \n\t"
654
+ "vse8.v v26, (%[DST]) \n\t"
655
+ "addi %[DST], %[DST], 20 \n\t"
656
+ "vse8.v v28, (%[DST]) \n\t"
657
+ "addi %[DST], %[DST], 20 \n\t"
658
+ "vse8.v v30, (%[DST]) \n\t"
659
+ "addi %[DST], %[DST], 16 \n\t"
660
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
661
+ "blt %[K], t2, TAIL%= \n\t"
662
+ "LOOP_K%=: \n\t"
663
+ "vsetvli t1, %[K], e32, m2 \n\t"
664
+ "vle32.v v0, (%[SRC]) \n\t"
665
+ "addi %[SRC], %[SRC], 64 \n\t"
666
+ "sub %[K], %[K], t1 \n\t"
667
+ "vfabs.v v16, v0 \n\t"
668
+ "vsetvli t0, zero, e32, m1 \n\t"
669
+ "vfmax.vv v16, v16, v17 \n\t"
670
+ "vse32.v v16, (%[BUFFER]) \n\t"
671
+ "flw f0, (%[BUFFER]) \n\t"
672
+ "flw f1, 4(%[BUFFER]) \n\t"
673
+ "flw f2, 8(%[BUFFER]) \n\t"
674
+ "flw f3, 12(%[BUFFER]) \n\t"
675
+ "flw f4, 16(%[BUFFER]) \n\t"
676
+ "flw f5, 20(%[BUFFER]) \n\t"
677
+ "flw f6, 24(%[BUFFER]) \n\t"
678
+ "flw f7, 28(%[BUFFER]) \n\t"
679
+ "fmax.s f1, f0, f1 \n\t"
680
+ "fmax.s f3, f2, f3 \n\t"
681
+ "fmax.s f5, f4, f5 \n\t"
682
+ "fmax.s f7, f6, f7 \n\t"
683
+ "fmax.s f3, f1, f3 \n\t"
684
+ "fmax.s f7, f5, f7 \n\t"
685
+ "fmax.s f10, f3, f7 \n\t"
686
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
687
+ "fsw f10, (%[DST]) \n\t"
688
+ "addi %[DST], %[DST], 4 \n\t"
689
+ "fdiv.s f11, %[FONE], f10 \n\t"
690
+ "vsetvli t0, zero, e32, m2 \n\t"
691
+ "vfmul.vf v16, v0, f11 \n\t"
692
+ "vfcvt.x.f.v v16, v16 \n\t"
693
+ "vsetvli t0, zero, e16, m1 \n\t"
694
+ "vnclip.wx v16, v16, zero \n\t"
695
+ "vsetvli t0, t1, e8, mf2 \n\t"
696
+ "vnclip.wx v16, v16, zero \n\t"
697
+ "vse8.v v16, (%[DST]) \n\t"
698
+ "addi %[DST], %[DST], 16 \n\t"
699
+ "bge %[K], t2, LOOP_K%= \n\t"
700
+ "TAIL%=: \n\t"
701
+ "blez %[K], END%= \n\t"
702
+ "vsetvli t0, t3, e32, m2 \n\t"
703
+ "vxor.vv v16, v16, v16 \n\t"
704
+ "jal x0, LOOP_K%= \n\t"
705
+ "END%=: \n\t"
706
+ : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
707
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer)
708
+ : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12",
709
+ "f13", "f14", "f15", "f16", "f17");
710
+ } else if (BlkLen == 32) {
711
+ __asm__ volatile(
712
+ "addi t3, zero, 32*4 \n\t"
713
+ "addi t2, zero, 32 \n\t"
714
+
715
+ "addi a1, %[SRC], 0 \n\t"
716
+ "addi a2, %[SRC], 128 \n\t"
717
+ "addi a3, %[SRC], 256 \n\t"
718
+ "addi a4, %[SRC], 384 \n\t"
719
+
720
+ "addi s1, %[DST], 0 \n\t"
721
+ "addi s2, %[DST], 36 \n\t"
722
+ "addi s3, %[DST], 72 \n\t"
723
+ "addi s4, %[DST], 108 \n\t"
724
+ "blt %[K], t3, LOOP_K%= \n\t"
725
+ "blt %[K], t2, TAIL%= \n\t"
726
+
727
+ "LOOP_MAIN%=: \n\t"
728
+ "vsetvli t1, zero, e32, m4 \n\t"
729
+ "addi %[K], %[K], -128 \n\t"
730
+ "vle32.v v0, (a1) \n\t"
731
+ "addi a1, a1, 512 \n\t"
732
+ "vle32.v v4, (a2) \n\t"
733
+ "addi a2, a2, 512 \n\t"
734
+ "vle32.v v8, (a3) \n\t"
735
+ "addi a3, a3, 512 \n\t"
736
+ "vle32.v v12, (a4) \n\t"
737
+ "addi a4, a4, 512 \n\t"
738
+ "vfabs.v v16, v0 \n\t"
739
+ "vfabs.v v20, v4 \n\t"
740
+ "vfabs.v v24, v8 \n\t"
741
+ "vfabs.v v28, v12 \n\t"
742
+ "vsetvli t0, zero, e32, m2 \n\t"
743
+ "vfmax.vv v16, v16, v18 \n\t"
744
+ "vfmax.vv v20, v20, v22 \n\t"
745
+ "vfmax.vv v24, v24, v26 \n\t"
746
+ "vfmax.vv v28, v28, v30 \n\t"
747
+ "vsetvli t0, zero, e32, m1 \n\t"
748
+ "vfmax.vv v16, v16, v17 \n\t"
749
+ "vfmax.vv v20, v20, v21 \n\t"
750
+ "vfmax.vv v24, v24, v25 \n\t"
751
+ "vfmax.vv v28, v28, v29 \n\t"
752
+
753
+ "vfredmax.vs v17, v16, v17 \n\t"
754
+ "vfredmax.vs v21, v20, v21 \n\t"
755
+ "vfredmax.vs v25, v24, v25 \n\t"
756
+ "vfredmax.vs v29, v28, v29 \n\t"
757
+ "vfmv.f.s f10, v17 \n\t"
758
+ "vfmv.f.s f11, v21 \n\t"
759
+ "vfmv.f.s f12, v25 \n\t"
760
+ "vfmv.f.s f13, v29 \n\t"
761
+
762
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
763
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
764
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
765
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
766
+ "fsw f10, (s1) \n\t"
767
+ "addi s1, s1, 4 \n\t"
768
+
769
+ "fsw f11, (s2) \n\t"
770
+ "addi s2, s2, 4 \n\t"
771
+ "fsw f12, (s3) \n\t"
772
+ "addi s3, s3, 4 \n\t"
773
+ "fsw f13, (s4) \n\t"
774
+ "addi s4, s4, 4 \n\t"
775
+ "fdiv.s f10, %[FONE], f10 \n\t"
776
+ "fdiv.s f11, %[FONE], f11 \n\t"
777
+ "fdiv.s f12, %[FONE], f12 \n\t"
778
+ "fdiv.s f13, %[FONE], f13 \n\t"
779
+ "vsetvli t0, zero, e32, m4 \n\t"
780
+ "vfmul.vf v16, v0, f10 \n\t"
781
+ "vfmul.vf v20, v4, f11 \n\t"
782
+ "vfmul.vf v24, v8, f12 \n\t"
783
+ "vfmul.vf v28, v12, f13 \n\t"
784
+ "vfcvt.x.f.v v16, v16 \n\t"
785
+ "vfcvt.x.f.v v20, v20 \n\t"
786
+ "vfcvt.x.f.v v24, v24 \n\t"
787
+ "vfcvt.x.f.v v28, v28 \n\t"
788
+ "vsetvli t0, zero, e16, m2 \n\t"
789
+ "vnclip.wx v16, v16, zero \n\t"
790
+ "vnclip.wx v20, v20, zero \n\t"
791
+ "vnclip.wx v24, v24, zero \n\t"
792
+ "vnclip.wx v28, v28, zero \n\t"
793
+ "vsetvli t0, t1, e8, m1 \n\t"
794
+ "vnclip.wx v16, v16, zero \n\t"
795
+ "vnclip.wx v20, v20, zero \n\t"
796
+ "vnclip.wx v24, v24, zero \n\t"
797
+ "vnclip.wx v28, v28, zero \n\t"
798
+ "vse8.v v16, (s1) \n\t"
799
+ "addi s1, s1, 140 \n\t"
800
+ "vse8.v v20, (s2) \n\t"
801
+ "addi s2, s2, 140 \n\t"
802
+ "vse8.v v24, (s3) \n\t"
803
+ "addi s3, s3, 140 \n\t"
804
+ "vse8.v v28, (s4) \n\t"
805
+ "addi s4, s4, 140 \n\t"
806
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
807
+ "blt %[K], t2, TAIL%= \n\t"
808
+ "LOOP_K%=: \n\t"
809
+ "vsetvli t1, %[K], e32, m4 \n\t"
810
+ "vle32.v v0, (a1) \n\t"
811
+ "addi a1, a1, 128 \n\t"
812
+ "sub %[K], %[K], t1 \n\t"
813
+ "vfabs.v v16, v0 \n\t"
814
+ "vsetvli t0, zero, e32, m2 \n\t"
815
+ "vfmax.vv v16, v16, v18 \n\t"
816
+ "vsetvli t0, zero, e32, m1 \n\t"
817
+ "vfmax.vv v16, v16, v17 \n\t"
818
+ "vfredmax.vs v17, v16, v17 \n\t"
819
+ "vfmv.f.s f10, v17 \n\t"
820
+
821
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
822
+ "fsw f10, (s1) \n\t"
823
+ "addi s1, s1, 4 \n\t"
824
+ "fdiv.s f11, %[FONE], f10 \n\t"
825
+ "vsetvli t0, zero, e32, m4 \n\t"
826
+ "vfmul.vf v16, v0, f11 \n\t"
827
+ "vfcvt.x.f.v v16, v16 \n\t"
828
+ "vsetvli t0, zero, e16, m2 \n\t"
829
+ "vnclip.wx v16, v16, zero \n\t"
830
+ "vsetvli t0, zero, e8, m1 \n\t"
831
+ "vnclip.wx v16, v16, zero \n\t"
832
+ "vse8.v v16, (s1) \n\t"
833
+ "addi s1, s1, 32 \n\t"
834
+ "bge %[K], t2, LOOP_K%= \n\t"
835
+ "TAIL%=: \n\t"
836
+ "blez %[K], END%= \n\t"
837
+ "vsetvli t0, t3, e32, m4 \n\t"
838
+ "vxor.vv v0, v0, v0 \n\t"
839
+ "vxor.vv v16, v16, v16 \n\t"
840
+ "jal x0, LOOP_K%= \n\t"
841
+ "END%=: \n\t"
842
+ : [K] "+r"(CountK)
843
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
844
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
845
+ } else if (BlkLen == 64) {
846
+ __asm__ volatile(
847
+ "addi t3, zero, 64*2 \n\t"
848
+ "addi t2, zero, 64 \n\t"
849
+ "addi a1, %[SRC], 0 \n\t"
850
+ "addi a2, %[SRC], 256 \n\t"
851
+ "addi s1, %[DST], 0 \n\t"
852
+ "addi s2, %[DST], 68 \n\t"
853
+ "blt %[K], t3, LOOP_K%= \n\t"
854
+ "blt %[K], t2, TAIL%= \n\t"
855
+ "LOOP_MAIN%=: \n\t"
856
+ "vsetvli t1, zero, e32, m8 \n\t"
857
+ "addi %[K], %[K], -128 \n\t"
858
+ "vle32.v v0, (a1) \n\t"
859
+ "addi a1, a1, 512 \n\t"
860
+ "vle32.v v8, (a2) \n\t"
861
+ "addi a2, a2, 512 \n\t"
862
+ "vfabs.v v16, v0 \n\t"
863
+ "vfabs.v v24, v8 \n\t"
864
+ "vsetvli t0, zero, e32, m4 \n\t"
865
+ "vfmax.vv v16, v16, v20 \n\t"
866
+ "vfmax.vv v24, v24, v28 \n\t"
867
+ "vsetvli t0, zero, e32, m2 \n\t"
868
+ "vfmax.vv v16, v16, v18 \n\t"
869
+ "vfmax.vv v24, v24, v26 \n\t"
870
+ "vsetvli t0, zero, e32, m1 \n\t"
871
+ "vfmax.vv v16, v16, v17 \n\t"
872
+ "vfmax.vv v24, v24, v25 \n\t"
873
+ "vfredmax.vs v17, v16, v17 \n\t"
874
+ "vfredmax.vs v25, v24, v25 \n\t"
875
+ "vfmv.f.s f10, v17 \n\t"
876
+ "vfmv.f.s f11, v25 \n\t"
877
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
878
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
879
+ "fsw f10, (s1) \n\t"
880
+ "addi s1, s1, 4 \n\t"
881
+ "fsw f11, (s2) \n\t"
882
+ "addi s2, s2, 4 \n\t"
883
+ "fdiv.s f10, %[FONE], f10 \n\t"
884
+ "fdiv.s f11, %[FONE], f11 \n\t"
885
+ "vsetvli t0, zero, e32, m8 \n\t"
886
+ "vfmul.vf v16, v0, f10 \n\t"
887
+ "vfmul.vf v24, v8, f11 \n\t"
888
+ "vfcvt.x.f.v v16, v16 \n\t"
889
+ "vfcvt.x.f.v v24, v24 \n\t"
890
+ "vsetvli t0, zero, e16, m4 \n\t"
891
+ "vnclip.wx v16, v16, zero \n\t"
892
+ "vnclip.wx v24, v24, zero \n\t"
893
+ "vsetvli t0, t1, e8, m2 \n\t"
894
+ "vnclip.wx v16, v16, zero \n\t"
895
+ "vnclip.wx v24, v24, zero \n\t"
896
+ "vse8.v v16, (s1) \n\t"
897
+ "addi s1, s1, 132 \n\t"
898
+ "vse8.v v24, (s2) \n\t"
899
+ "addi s2, s2, 132 \n\t"
900
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
901
+ "blt %[K], t2, TAIL%= \n\t"
902
+ "LOOP_K%=: \n\t"
903
+ "vsetvli t1, %[K], e32, m8 \n\t"
904
+ "vle32.v v0, (a1) \n\t"
905
+ "addi a1, a1, 256 \n\t"
906
+ "sub %[K], %[K], t1 \n\t"
907
+ "vfabs.v v16, v0 \n\t"
908
+ "vsetvli t0, zero, e32, m4 \n\t"
909
+ "vfmax.vv v16, v16, v20 \n\t"
910
+ "vsetvli t0, zero, e32, m2 \n\t"
911
+ "vfmax.vv v16, v16, v18 \n\t"
912
+ "vsetvli t0, zero, e32, m1 \n\t"
913
+ "vfmax.vv v16, v16, v17 \n\t"
914
+ "vfredmax.vs v17, v16, v17 \n\t"
915
+ "vfmv.f.s f10, v17 \n\t"
916
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
917
+ "fsw f10, (s1) \n\t"
918
+ "addi s1, s1, 4 \n\t"
919
+ "fdiv.s f11, %[FONE], f10 \n\t"
920
+ "vsetvli t0, zero, e32, m8 \n\t"
921
+ "vfmul.vf v16, v0, f11 \n\t"
922
+ "vfcvt.x.f.v v16, v16 \n\t"
923
+ "vsetvli t0, zero, e16, m4 \n\t"
924
+ "vnclip.wx v16, v16, zero \n\t"
925
+ "vsetvli t0, zero, e8, m2 \n\t"
926
+ "vnclip.wx v16, v16, zero \n\t"
927
+ "vse8.v v16, (s1) \n\t"
928
+ "addi s1, s1, 64 \n\t"
929
+ "bge %[K], t2, LOOP_K%= \n\t"
930
+ "TAIL%=: \n\t"
931
+ "blez %[K], END%= \n\t"
932
+ "vsetvli t0, t3, e32, m8 \n\t"
933
+ "vxor.vv v0, v0, v0 \n\t"
934
+ "vxor.vv v16, v16, v16 \n\t"
935
+ "jal x0, LOOP_K%= \n\t"
936
+ "END%=: \n\t"
937
+ : [K] "+r"(CountK)
938
+ : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
939
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11");
940
+ } else if (BlkLen == 128) {
941
+ __asm__ volatile(
942
+ "addi t2, zero, 128 \n\t"
943
+ "addi a1, %[SRC], 0 \n\t"
944
+ "addi a2, %[SRC], 256 \n\t"
945
+ "blt %[K], t2, TAIL%= \n\t"
946
+ "LOOP_K%=: \n\t"
947
+ "vsetvli t1, zero, e32, m8 \n\t"
948
+ "vle32.v v0, (a1) \n\t"
949
+ "addi a1, a1, 512 \n\t"
950
+ "vle32.v v8, (a2) \n\t"
951
+ "addi a2, a2, 512 \n\t"
952
+ "sub %[K], %[K], t2 \n\t"
953
+ "QUANT%=: \n\t"
954
+ "vfabs.v v16, v0 \n\t"
955
+ "vfabs.v v24, v8 \n\t"
956
+ "vfmax.vv v24, v16, v24 \n\t"
957
+ "vsetvli t1, zero, e32, m4 \n\t"
958
+ "vfmax.vv v28, v24, v28 \n\t"
959
+ "vsetvli t0, zero, e32, m2 \n\t"
960
+ "vfmax.vv v30, v28, v30 \n\t"
961
+ "vsetvli t0, zero, e32, m1 \n\t"
962
+ "vfmax.vv v30, v30, v31 \n\t"
963
+ "vfredmax.vs v31, v30, v31 \n\t"
964
+ "vfmv.f.s f10, v31 \n\t"
965
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
966
+ "fsw f10, (%[DST]) \n\t"
967
+ "addi %[DST], %[DST], 4 \n\t"
968
+ "fdiv.s f11, %[FONE], f10 \n\t"
969
+ "vsetvli t0, zero, e32, m8 \n\t"
970
+ "vfmul.vf v16, v0, f11 \n\t"
971
+ "vfmul.vf v24, v8, f11 \n\t"
972
+ "vfcvt.x.f.v v16, v16 \n\t"
973
+ "vfcvt.x.f.v v24, v24 \n\t"
974
+ "vsetvli t0, zero, e16, m4 \n\t"
975
+ "vnclip.wx v16, v16, zero \n\t"
976
+ "vnclip.wx v20, v24, zero \n\t"
977
+ "vsetvli t0, zero, e8, m4 \n\t"
978
+ "vnclip.wx v16, v16, zero \n\t"
979
+ "vse8.v v16, (%[DST]) \n\t"
980
+ "addi %[DST], %[DST], 128 \n\t"
981
+ "bge %[K], t2, LOOP_K%= \n\t"
982
+ "TAIL%=: \n\t"
983
+ "blez %[K], END%= \n\t"
984
+ "vsetvli t1, zero, e32, m8 \n\t"
985
+ "vxor.vv v0, v0, v0 \n\t"
986
+ "vxor.vv v8, v8, v8 \n\t"
987
+ "vsetvli t0, %[K], e32, m8 \n\t"
988
+ "vle32.v v0, (a1) \n\t"
989
+ "sub %[K], %[K], t0 \n\t"
990
+ "vsetvli t0, %[K], e32, m8 \n\t"
991
+ "vle32.v v8, (a2) \n\t"
992
+ "sub %[K], %[K], t0 \n\t"
993
+ "vsetvli t1, zero, e32, m8 \n\t"
994
+ "jal x0, QUANT%= \n\t"
995
+ "END%=: \n\t"
996
+
997
+ : [DST] "+r"(DST), [K] "+r"(CountK)
998
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC)
999
+ : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11");
1000
+ } else {
1001
+ float buffer[8] = { 0.0f };
1002
+ size_t cnt = BlkLen / 256;
1003
+
1004
+ __asm__ volatile(
1005
+ "slli t3, %[BLK], 2 \n\t"
1006
+ "blt %[K], %[BLK], LOOP_TAIL%= \n\t"
1007
+ "LOOP_MAIN%=: \n\t"
1008
+ "vsetvli t0, zero, e32, m1 \n\t"
1009
+ "vxor.vv v31, v31, v31 \n\t"
1010
+ "vse32.v v31, (%[BUFFER]) \n\t"
1011
+ "addi t6, %[CNT], 0 \n\t"
1012
+ "LOOP_CMP%=: \n\t"
1013
+ "addi t6, t6, -1 \n\t"
1014
+ "vsetvli t0, zero, e32, m8 \n\t"
1015
+ "vle32.v v0, (%[SRC]) \n\t"
1016
+ "addi %[SRC], %[SRC], 256 \n\t"
1017
+ "vle32.v v8, (%[SRC]) \n\t"
1018
+ "addi %[SRC], %[SRC], 256 \n\t"
1019
+ "vle32.v v16, (%[SRC]) \n\t"
1020
+ "addi %[SRC], %[SRC], 256 \n\t"
1021
+ "vle32.v v24, (%[SRC]) \n\t"
1022
+ "addi %[SRC], %[SRC], 256 \n\t"
1023
+ "vfabs.v v0, v0 \n\t"
1024
+ "vfabs.v v8, v8 \n\t"
1025
+ "vfabs.v v16, v16 \n\t"
1026
+ "vfabs.v v24, v24 \n\t"
1027
+ "vfmax.vv v8, v0, v8 \n\t"
1028
+ "vfmax.vv v16, v16, v24 \n\t"
1029
+ "vfmax.vv v0, v0, v16 \n\t"
1030
+ "vsetvli t0, zero, e32, m4 \n\t"
1031
+ "vfmax.vv v0, v0, v4 \n\t"
1032
+ "vsetvli t0, zero, e32, m2 \n\t"
1033
+ "vfmax.vv v0, v0, v2 \n\t"
1034
+ "vsetvli t0, zero, e32, m1 \n\t"
1035
+ "vfmax.vv v0, v0, v1 \n\t"
1036
+ "vle32.v v30, (%[BUFFER]) \n\t"
1037
+ "vfmax.vv v31, v30, v0 \n\t"
1038
+ "vse32.v v31, (%[BUFFER]) \n\t"
1039
+ "bnez t6, LOOP_CMP%= \n\t"
1040
+ "sub %[SRC], %[SRC], t3 \n\t"
1041
+ "addi t6, %[CNT], 0 \n\t"
1042
+ "flw f0, (%[BUFFER]) \n\t"
1043
+ "flw f1, 4(%[BUFFER]) \n\t"
1044
+ "flw f2, 8(%[BUFFER]) \n\t"
1045
+ "flw f3, 12(%[BUFFER]) \n\t"
1046
+ "flw f4, 16(%[BUFFER]) \n\t"
1047
+ "flw f5, 20(%[BUFFER]) \n\t"
1048
+ "flw f6, 24(%[BUFFER]) \n\t"
1049
+ "flw f7, 28(%[BUFFER]) \n\t"
1050
+ "fmax.s f1, f0, f1 \n\t"
1051
+ "fmax.s f3, f2, f3 \n\t"
1052
+ "fmax.s f5, f4, f5 \n\t"
1053
+ "fmax.s f7, f6, f7 \n\t"
1054
+ "fmax.s f3, f1, f3 \n\t"
1055
+ "fmax.s f7, f5, f7 \n\t"
1056
+ "fmax.s f10, f3, f7 \n\t"
1057
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
1058
+ "fsw f10, (%[DST]) \n\t"
1059
+ "addi %[DST], %[DST], 4 \n\t"
1060
+ "fdiv.s f11, %[FONE], f10 \n\t"
1061
+ "addi t6, %[CNT], 0 \n\t"
1062
+ "LOOP_QUANT%=: \n\t"
1063
+ "addi t6, t6, -1 \n\t"
1064
+ "vsetvli t0, zero, e32, m8 \n\t"
1065
+ "vle32.v v0, (%[SRC]) \n\t"
1066
+ "addi %[SRC], %[SRC], 256 \n\t"
1067
+ "vle32.v v8, (%[SRC]) \n\t"
1068
+ "addi %[SRC], %[SRC], 256 \n\t"
1069
+ "vle32.v v16, (%[SRC]) \n\t"
1070
+ "addi %[SRC], %[SRC], 256 \n\t"
1071
+ "vle32.v v24, (%[SRC]) \n\t"
1072
+ "addi %[SRC], %[SRC], 256 \n\t"
1073
+ "vsetvli t0, zero, e32, m8 \n\t"
1074
+ "vfmul.vf v0, v0, f11 \n\t"
1075
+ "vfmul.vf v8, v8, f11 \n\t"
1076
+ "vfmul.vf v16, v16, f11 \n\t"
1077
+ "vfmul.vf v24, v24, f11 \n\t"
1078
+ "vfcvt.x.f.v v0, v0 \n\t"
1079
+ "vfcvt.x.f.v v8, v8 \n\t"
1080
+ "vfcvt.x.f.v v16, v16 \n\t"
1081
+ "vfcvt.x.f.v v24, v24 \n\t"
1082
+ "vsetvli t0, zero, e16, m4 \n\t"
1083
+ "vnclip.wx v0, v0, zero \n\t"
1084
+ "vnclip.wx v4, v8, zero \n\t"
1085
+ "vnclip.wx v8, v16, zero \n\t"
1086
+ "vnclip.wx v12, v24, zero \n\t"
1087
+ "vsetvli t0, zero, e8, m4 \n\t"
1088
+ "vnclip.wx v0, v0, zero \n\t"
1089
+ "vnclip.wx v4, v8, zero \n\t"
1090
+ "vse8.v v0, (%[DST]) \n\t"
1091
+ "addi %[DST], %[DST], 128 \n\t"
1092
+ "vse8.v v4, (%[DST]) \n\t"
1093
+ "addi %[DST], %[DST], 128 \n\t"
1094
+ "bnez t6, LOOP_QUANT%= \n\t"
1095
+ "sub %[K], %[K], %[BLK] \n\t"
1096
+ "bge %[K], %[BLK], LOOP_MAIN%= \n\t"
1097
+ "blez %[K], END%= \n\t"
1098
+ "LOOP_TAIL%=: \n\t"
1099
+ "vsetvli t0, zero, e32, m1 \n\t"
1100
+ "vxor.vv v31, v31, v31 \n\t"
1101
+ "vse32.v v31, (%[BUFFER]) \n\t"
1102
+ "addi t6, %[K], 0 \n\t"
1103
+ "addi s1, %[SRC], 0 \n\t"
1104
+ "TAIL_CMP%=: \n\t"
1105
+ "vsetvli t0, zero, e32, m8 \n\t"
1106
+ "vxor.vv v0, v0, v0 \n\t"
1107
+ "vsetvli t0, t6, e32, m8 \n\t"
1108
+ "vle32.v v0, (%[SRC]) \n\t"
1109
+ "addi %[SRC], %[SRC], 256 \n\t"
1110
+ "sub t6, t6, t0 \n\t"
1111
+ "vfabs.v v0, v0 \n\t"
1112
+ "vsetvli t0, zero, e32, m4 \n\t"
1113
+ "vfmax.vv v0, v0, v4 \n\t"
1114
+ "vsetvli t0, zero, e32, m2 \n\t"
1115
+ "vfmax.vv v0, v0, v2 \n\t"
1116
+ "vsetvli t0, zero, e32, m1 \n\t"
1117
+ "vfmax.vv v0, v0, v1 \n\t"
1118
+ "vle32.v v30, (%[BUFFER]) \n\t"
1119
+ "vfmax.vv v31, v30, v0 \n\t"
1120
+ "vse32.v v31, (%[BUFFER]) \n\t"
1121
+ "bnez t6, TAIL_CMP%= \n\t"
1122
+ "addi t6, %[K], 0 \n\t"
1123
+ "flw f0, (%[BUFFER]) \n\t"
1124
+ "flw f1, 4(%[BUFFER]) \n\t"
1125
+ "flw f2, 8(%[BUFFER]) \n\t"
1126
+ "flw f3, 12(%[BUFFER]) \n\t"
1127
+ "flw f4, 16(%[BUFFER]) \n\t"
1128
+ "flw f5, 20(%[BUFFER]) \n\t"
1129
+ "flw f6, 24(%[BUFFER]) \n\t"
1130
+ "flw f7, 28(%[BUFFER]) \n\t"
1131
+ "fmax.s f1, f0, f1 \n\t"
1132
+ "fmax.s f3, f2, f3 \n\t"
1133
+ "fmax.s f5, f4, f5 \n\t"
1134
+ "fmax.s f7, f6, f7 \n\t"
1135
+ "fmax.s f3, f1, f3 \n\t"
1136
+ "fmax.s f7, f5, f7 \n\t"
1137
+ "fmax.s f10, f3, f7 \n\t"
1138
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
1139
+ "fsw f10, (%[DST]) \n\t"
1140
+ "addi %[DST], %[DST], 4 \n\t"
1141
+ "fdiv.s f11, %[FONE], f10 \n\t"
1142
+ "addi t6, %[K], 0 \n\t"
1143
+ "TAIL_QUANT%=: \n\t"
1144
+ "vsetvli t0, zero, e32, m8 \n\t"
1145
+ "vxor.vv v0, v0, v0 \n\t"
1146
+ "vsetvli t1, t6, e32, m8 \n\t"
1147
+ "vle32.v v0, (s1) \n\t"
1148
+ "addi s1, s1, 256 \n\t"
1149
+ "sub t6, t6, t1 \n\t"
1150
+ "vsetvli t0, zero, e32, m8 \n\t"
1151
+ "vfmul.vf v0, v0, f11 \n\t"
1152
+ "vfcvt.x.f.v v0, v0 \n\t"
1153
+ "vsetvli t0, zero, e16, m4 \n\t"
1154
+ "vnclip.wx v0, v0, zero \n\t"
1155
+ "vsetvli t0, t1, e8, m2 \n\t"
1156
+ "vnclip.wx v0, v0, zero \n\t"
1157
+ "vse8.v v0, (%[DST]) \n\t"
1158
+ "addi %[DST], %[DST], 64 \n\t"
1159
+ "bnez t6, TAIL_QUANT%= \n\t"
1160
+ "END%=: \n\t"
1161
+ : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
1162
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer),
1163
+ [CNT] "r"(cnt)
1164
+ : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6");
1165
+ }
1166
+ }
1167
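+
+  // Sizing sketch for callers, assuming the layout implied by QuantA_offset above
+  // (a 4-byte fp32 scale plus BlkLen int8 values per block, tail block zero-padded):
+  //
+  //     const size_t blocks = (CountK + BlkLen - 1) / BlkLen;
+  //     const size_t bytes  = blocks * (sizeof(float) + BlkLen);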
+
1168
+ } // namespace ime1
1169
+
1170
+ namespace {
1171
+ #define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \
1172
+ "vmadot v16, v14, v0 \n\t" \
1173
+ "vmadot v18, v14, v1 \n\t" \
1174
+ "vmadot v20, v14, v2 \n\t" \
1175
+ "vmadot v22, v14, v3 \n\t" \
1176
+ "vmadot v16, v15, v4 \n\t" \
1177
+ "vmadot v18, v15, v5 \n\t" \
1178
+ "vmadot v20, v15, v6 \n\t" \
1179
+ "vmadot v22, v15, v7 \n\t"
1180
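+
+  // vmadot is not a base RVV 1.0 instruction; it is presumably a vendor integer
+  // matrix-multiply-accumulate extension (cf. the ime1 namespace above) that sums
+  // int8 products into int32 accumulators, which is why the kernels convert the
+  // accumulators with vfcvt.f.x.v before scaling them with vfmacc.vv.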
+
1181
+ #define SQ4BIT_KERNEL_ACC_1X4X4 \
1182
+ "vfcvt.f.x.v v16, v16 \n\t" \
1183
+ "vfcvt.f.x.v v18, v18 \n\t" \
1184
+ "vfcvt.f.x.v v20, v20 \n\t" \
1185
+ "vfcvt.f.x.v v22, v22 \n\t" \
1186
+ "addi s2, s1, 16 \n\t" \
1187
+ "addi s3, s1, 32 \n\t" \
1188
+ "addi s4, s1, 48 \n\t" \
1189
+ "addi s6, s5, 12 \n\t" \
1190
+ "vfmacc.vv v28, v16, v24 \n\t" \
1191
+ "vfmacc.vv v29, v18, v25 \n\t" \
1192
+ "vfmacc.vv v30, v20, v26 \n\t" \
1193
+ "vfmacc.vv v31, v22, v27 \n\t"
1194
+
1195
+ #define SQ4BIT_KERNEL_ACC_F16_1X4X4 \
1196
+ "vfcvt.f.x.v v16, v16 \n\t" \
1197
+ "vfcvt.f.x.v v18, v18 \n\t" \
1198
+ "vfcvt.f.x.v v20, v20 \n\t" \
1199
+ "vfcvt.f.x.v v22, v22 \n\t" \
1200
+ "addi s2, s1, 8 \n\t" \
1201
+ "addi s3, s1, 16 \n\t" \
1202
+ "addi s4, s1, 24 \n\t" \
1203
+ "addi s6, s5, 12 \n\t" \
1204
+ "vfmacc.vv v28, v16, v24 \n\t" \
1205
+ "vfmacc.vv v29, v18, v25 \n\t" \
1206
+ "vfmacc.vv v30, v20, v26 \n\t" \
1207
+ "vfmacc.vv v31, v22, v27 \n\t"
1208
+
1209
+ #define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \
1210
+ "vle8.v v4, (s1) \n\t" \
1211
+ "addi s1, s1, 128 \n\t" \
1212
+ "vle8.v v5, (s2) \n\t" \
1213
+ "addi s2, s2, 128 \n\t" \
1214
+ "vle8.v v6, (s3) \n\t" \
1215
+ "addi s3, s3, 128 \n\t" \
1216
+ "vle8.v v7, (s4) \n\t" \
1217
+ "addi s4, s4, 128 \n\t" \
1218
+ "vsetvli t0, zero, e8, mf4 \n\t" \
1219
+ "vle8.v v14, (s5) \n\t" \
1220
+ "addi s5, s5, 16 \n\t" \
1221
+ "vle8.v v15, (s6) \n\t" \
1222
+ "addi s6, s6, 16 \n\t" \
1223
+ "addi t5, t5, -1 \n\t" \
1224
+ "vsetvli t0, zero, e8, m1 \n\t" \
1225
+ "vand.vi v0, v4, 15 \n\t" \
1226
+ "vand.vi v1, v5, 15 \n\t" \
1227
+ "vand.vi v2, v6, 15 \n\t" \
1228
+ "vand.vi v3, v7, 15 \n\t" \
1229
+ "vsrl.vi v4, v4, 4 \n\t" \
1230
+ "vsrl.vi v5, v5, 4 \n\t" \
1231
+ "vsrl.vi v6, v6, 4 \n\t" \
1232
+ "vsrl.vi v7, v7, 4 \n\t"
1233
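+
+  // Scalar sketch of the unpack above (assumption: the low nibble holds the first
+  // of each pair of packed 4-bit weights):
+  //
+  //     const uint8_t b  = packed[i];
+  //     const int     lo = b & 0x0F;  // vand.vi ..., 15
+  //     const int     hi = b >> 4;    // vsrl.vi ..., 4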
+
1234
+ #define SQ4BIT_KERNEL_LOAD_ZP_16X1 \
1235
+ "vsetvli t0, zero, e8, mf2 \n\t" \
1236
+ "vle8.v v1, (s7) \n\t" \
1237
+ "vsetvli t0, zero, e8, m1 \n\t" \
1238
+ "vrgather.vv v8, v1, v13 \n\t" \
1239
+ "vadd.vi v13, v13, 4 \n\t" \
1240
+ "vrgather.vv v9, v1, v13 \n\t" \
1241
+ "vadd.vi v13, v13, 4 \n\t" \
1242
+ "vrgather.vv v10, v1, v13 \n\t" \
1243
+ "vadd.vi v13, v13, 4 \n\t" \
1244
+ "vrgather.vv v11, v1, v13 \n\t" \
1245
+ "vadd.vi v13, v13, -12 \n\t"
1246
+
1247
+  // used by the M4 kernel
1248
+ #define LOAD_B_16x8x2 \
1249
+ "vsetvli t0, zero, e8, m1 \n\t" \
1250
+ "vle8.v v6, (s1) \n\t" \
1251
+ "addi s1, s1, 32*4 \n\t" \
1252
+ "vle8.v v7, (s2) \n\t" \
1253
+ "addi s2, s2, 32*4 \n\t" \
1254
+ "vle8.v v8, (s3) \n\t" \
1255
+ "addi s3, s3, 32*4 \n\t" \
1256
+ "vle8.v v9, (s4) \n\t" \
1257
+ "addi s4, s4, 32*4 \n\t" \
1258
+ \
1259
+ "vand.vi v2, v6, 15 \n\t" \
1260
+ "vand.vi v3, v7, 15 \n\t" \
1261
+ "vand.vi v4, v8, 15 \n\t" \
1262
+ "vand.vi v5, v9, 15 \n\t" \
1263
+ \
1264
+ "vsrl.vi v6, v6, 4 \n\t" \
1265
+ "vsrl.vi v7, v7, 4 \n\t" \
1266
+ "vsrl.vi v8, v8, 4 \n\t" \
1267
+ "vsrl.vi v9, v9, 4 \n\t"
1268
+
1269
+ // [s2|s5, s3, s4, s6]
1270
+ #define LOAD_SCALE_4x16_FP16 \
1271
+ "addi s2, s5, -8 \n\t" \
1272
+ "addi s3, s5, 8 \n\t" \
1273
+ "addi s4, s5, 16 \n\t" \
1274
+ "addi s6, s5, 24 \n\t" \
1275
+ "li t1, 0xf0 \n\t" \
1276
+ "vmv.s.x v0, t1 \n\t" \
1277
+ "vsetvli t0, zero, e16, mf4 \n\t" \
1278
+ "vle16.v v9, (s5) \n\t" \
1279
+ "vle16.v v11, (s3) \n\t" \
1280
+ "vle16.v v13, (s4) \n\t" \
1281
+ "vle16.v v15, (s6) \n\t" \
1282
+ "vsetvli t0, zero, e16, mf2 \n\t" \
1283
+ "vle16.v v9, (s2), v0.t \n\t" \
1284
+ "vle16.v v11, (s5), v0.t \n\t" \
1285
+ "vle16.v v13, (s3), v0.t \n\t" \
1286
+ "vle16.v v15, (s4), v0.t \n\t" \
1287
+ "vfwcvt.f.f.v v8, v9 \n\t" \
1288
+ "vfwcvt.f.f.v v10, v11 \n\t" \
1289
+ "vfwcvt.f.f.v v12, v13 \n\t" \
1290
+ "vfwcvt.f.f.v v14, v15 \n\t" \
1291
+ "vsetvli t0, zero, e32, m1 \n\t" \
1292
+ "vmv.v.v v9, v8 \n\t" \
1293
+ "vmv.v.v v11, v10 \n\t" \
1294
+ "vmv.v.v v13, v12 \n\t" \
1295
+ "vmv.v.v v15, v14 \n\t" \
1296
+ "li t1, 0xf0 \n\t" \
1297
+ "vmv.s.x v0, t1 \n\t" \
1298
+ "vsetvli t0, zero, e32, mf2 \n\t" \
1299
+ "vfmul.vf v8, v8, f1 \n\t" \
1300
+ "vfmul.vf v10, v10, f1 \n\t" \
1301
+ "vfmul.vf v12, v12, f1 \n\t" \
1302
+ "vfmul.vf v14, v14, f1 \n\t" \
1303
+ "vfmul.vf v9, v9, f3 \n\t" \
1304
+ "vfmul.vf v11, v11, f3 \n\t" \
1305
+ "vfmul.vf v13, v13, f3 \n\t" \
1306
+ "vfmul.vf v15, v15, f3 \n\t" \
1307
+ "vsetvli t0, zero, e32, m1 \n\t" \
1308
+ "vfmul.vf v8, v8, f2, v0.t \n\t" \
1309
+ "vfmul.vf v10, v10, f2, v0.t \n\t" \
1310
+ "vfmul.vf v12, v12, f2, v0.t \n\t" \
1311
+ "vfmul.vf v14, v14, f2, v0.t \n\t" \
1312
+ "vfmul.vf v9, v9, f4, v0.t \n\t" \
1313
+ "vfmul.vf v11, v11, f4, v0.t \n\t" \
1314
+ "vfmul.vf v13, v13, f4, v0.t \n\t" \
1315
+ "vfmul.vf v15, v15, f4, v0.t \n\t"
1316
+
1317
+ // [s2|s5, s3, s4, s6]
1318
+ #define LOAD_SCALE_4x16 \
1319
+ "addi s2, s5, -16 \n\t" \
1320
+ "addi s3, s5, 16 \n\t" \
1321
+ "addi s4, s5, 32 \n\t" \
1322
+ "addi s6, s5, 48 \n\t" \
1323
+ "li t1, 0xf0 \n\t" \
1324
+ "vmv.s.x v0, t1 \n\t" \
1325
+ "vsetvli t0, zero, e32, mf2 \n\t" \
1326
+ "vle32.v v8, (s5) \n\t" \
1327
+ "vle32.v v10, (s3) \n\t" \
1328
+ "vle32.v v12, (s4) \n\t" \
1329
+ "vle32.v v14, (s6) \n\t" \
1330
+ "vsetvli t0, zero, e32, m1 \n\t" \
1331
+ "vle32.v v8, (s2), v0.t \n\t" \
1332
+ "vle32.v v10, (s5), v0.t \n\t" \
1333
+ "vle32.v v12, (s3), v0.t \n\t" \
1334
+ "vle32.v v14, (s4), v0.t \n\t" \
1335
+ "vmv.v.v v9, v8 \n\t" \
1336
+ "vmv.v.v v11, v10 \n\t" \
1337
+ "vmv.v.v v13, v12 \n\t" \
1338
+ "vmv.v.v v15, v14 \n\t" \
1339
+ "vsetvli t0, zero, e32, mf2 \n\t" \
1340
+ "vfmul.vf v8, v8, f1 \n\t" \
1341
+ "vfmul.vf v10, v10, f1 \n\t" \
1342
+ "vfmul.vf v12, v12, f1 \n\t" \
1343
+ "vfmul.vf v14, v14, f1 \n\t" \
1344
+ "vfmul.vf v9, v9, f3 \n\t" \
1345
+ "vfmul.vf v11, v11, f3 \n\t" \
1346
+ "vfmul.vf v13, v13, f3 \n\t" \
1347
+ "vfmul.vf v15, v15, f3 \n\t" \
1348
+ "vsetvli t0, zero, e32, m1 \n\t" \
1349
+ "vfmul.vf v8, v8, f2, v0.t \n\t" \
1350
+ "vfmul.vf v10, v10, f2, v0.t \n\t" \
1351
+ "vfmul.vf v12, v12, f2, v0.t \n\t" \
1352
+ "vfmul.vf v14, v14, f2, v0.t \n\t" \
1353
+ "vfmul.vf v9, v9, f4, v0.t \n\t" \
1354
+ "vfmul.vf v11, v11, f4, v0.t \n\t" \
1355
+ "vfmul.vf v13, v13, f4, v0.t \n\t" \
1356
+ "vfmul.vf v15, v15, f4, v0.t \n\t"
1357
+
1358
+  // [s1|BIAS, s2, s3, s4]
1359
+ #define LOAD_BIAS \
1360
+ "vsetvli t0, zero, e32, mf2 \n\t" \
1361
+ "li t1, 0xf0 \n\t" \
1362
+ "vmv.s.x v0, t1 \n\t" \
1363
+ "addi s1, %[BIAS], -16 \n\t" \
1364
+ "addi s2, %[BIAS], 16 \n\t" \
1365
+ "addi s3, %[BIAS], 32 \n\t" \
1366
+ "addi s4, %[BIAS], 48 \n\t" \
1367
+ \
1368
+ "vle32.v v24, (%[BIAS]) \n\t" \
1369
+ "vle32.v v26, (s2) \n\t" \
1370
+ "vle32.v v28, (s3) \n\t" \
1371
+ "vle32.v v30, (s4) \n\t" \
1372
+ "vsetvli t0, zero, e32, m1 \n\t" \
1373
+ "vle32.v v24, (s1), v0.t \n\t" \
1374
+ "vle32.v v26, (%[BIAS]), v0.t \n\t" \
1375
+ "vle32.v v28, (s2), v0.t \n\t" \
1376
+ "vle32.v v30, (s3), v0.t \n\t" \
1377
+ "vmv.v.v v25, v24 \n\t" \
1378
+ "vmv.v.v v27, v26 \n\t" \
1379
+ "vmv.v.v v29, v28 \n\t" \
1380
+ "vmv.v.v v31, v30 \n\t"
1381
+
1382
+ #define SQ4BIT_KERNEL_COMP_4x16x16 \
1383
+ "vmadot v16, v10, v2 \n\t" \
1384
+ "vmadot v18, v10, v3 \n\t" \
1385
+ "vmadot v20, v10, v4 \n\t" \
1386
+ "vmadot v22, v10, v5 \n\t" \
1387
+ "vmadot v16, v11, v6 \n\t" \
1388
+ "vmadot v18, v11, v7 \n\t" \
1389
+ "vmadot v20, v11, v8 \n\t" \
1390
+ "vmadot v22, v11, v9 \n\t"
1391
+
1392
+ #define SAVE_RESULT_4x16 \
1393
+ "addi a1, %[C], 0 \n\t" \
1394
+ "add a2, %[C], %[LDC] \n\t" \
1395
+ "add a3, a2, %[LDC] \n\t" \
1396
+ "add a4, a3, %[LDC] \n\t" \
1397
+ "addi a2, a2, -16 \n\t" \
1398
+ "addi a4, a4, -16 \n\t" \
1399
+ "li t1, 0xf0 \n\t" \
1400
+ "vmv.s.x v0, t1 \n\t" \
1401
+ "vsetvli t0, zero, e32, mf2 \n\t" \
1402
+ \
1403
+ "vse32.v v24, (a1) \n\t" \
1404
+ "addi a1, a1, 16 \n\t" \
1405
+ "vse32.v v25, (a3) \n\t" \
1406
+ "addi a3, a3, 16 \n\t" \
1407
+ \
1408
+ "vse32.v v26, (a1) \n\t" \
1409
+ "addi a1, a1, 16 \n\t" \
1410
+ "vse32.v v27, (a3) \n\t" \
1411
+ "addi a3, a3, 16 \n\t" \
1412
+ \
1413
+ "vse32.v v28, (a1) \n\t" \
1414
+ "addi a1, a1, 16 \n\t" \
1415
+ "vse32.v v29, (a3) \n\t" \
1416
+ "addi a3, a3, 16 \n\t" \
1417
+ \
1418
+ "vse32.v v30, (a1) \n\t" \
1419
+ "vse32.v v31, (a3) \n\t" \
1420
+ "vsetvli t0, zero, e32, m1 \n\t" \
1421
+ \
1422
+ "vse32.v v24, (a2), v0.t \n\t" \
1423
+ "addi a2, a2, 16 \n\t" \
1424
+ "vse32.v v25, (a4), v0.t \n\t" \
1425
+ "addi a4, a4, 16 \n\t" \
1426
+ \
1427
+ "vse32.v v26, (a2), v0.t \n\t" \
1428
+ "addi a2, a2, 16 \n\t" \
1429
+ "vse32.v v27, (a4), v0.t \n\t" \
1430
+ "addi a4, a4, 16 \n\t" \
1431
+ \
1432
+ "vse32.v v28, (a2), v0.t \n\t" \
1433
+ "addi a2, a2, 16 \n\t" \
1434
+ "vse32.v v29, (a4), v0.t \n\t" \
1435
+ "addi a4, a4, 16 \n\t" \
1436
+ \
1437
+ "vse32.v v30, (a2), v0.t \n\t" \
1438
+ "vse32.v v31, (a4), v0.t \n\t"
1439
+
1440
+ #define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \
1441
+ "vsetvli t0, zero, e8, mf2 \n\t" \
1442
+ "vle8.v v11, (s6) \n\t" \
1443
+ "vsetvli t0, zero, e8, m1 \n\t" \
1444
+ "vrgather.vv v12, v11, v1 \n\t" \
1445
+ "vadd.vi v1, v1, 4 \n\t" \
1446
+ "vrgather.vv v13, v11, v1 \n\t" \
1447
+ "vadd.vi v1, v1, 4 \n\t" \
1448
+ "vrgather.vv v14, v11, v1 \n\t" \
1449
+ "vadd.vi v1, v1, 4 \n\t" \
1450
+ "vrgather.vv v15, v11, v1 \n\t" \
1451
+ "vadd.vi v1, v1, -12 \n\t"
1452
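+
+  // The vrgather.vv sequence appears to broadcast each of the per-column zero
+  // points across an 8-lane run (the index vector is prepared as runs of 0,1,2,3
+  // and bumped with vadd.vi), so the vsub.vv instructions in the kernels below can
+  // subtract a column's zero point from all of its unpacked 4-bit values.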
+
1453
+ template <bool HasZeroPoint>
1454
+ void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
1455
+ const std::byte * QuantA,
1456
+ const std::byte * QuantBData,
1457
+ const float * QuantBScale,
1458
+ const std::byte * QuantBZeroPoint,
1459
+ float * C,
1460
+ size_t CountN,
1461
+ size_t BlockCountK,
1462
+ const float * Bias,
1463
+ const size_t ldc) {
1464
+ GGML_UNUSED(QuantBScale);
1465
+ GGML_UNUSED(QuantBZeroPoint);
1466
+ size_t LDC = ldc * sizeof(float);
1467
+ const size_t INNER = BlkLen / 16;
1468
+ float tmp[4 * 16];
1469
+
1470
+ if constexpr (HasZeroPoint) {
1471
+ for (size_t n = 0; n < CountN; n += 16) {
1472
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
1473
+              const std::byte * QuantBDataPtr = QuantBData +             //
1474
+ n * BlockCountK * BlkLen / 2 + // b data
1475
+ n * BlockCountK * sizeof(uint8_t) + // zp
1476
+ n * BlockCountK * sizeof(_Float16); // scale
1477
+ float * CPtr = C + n;
1478
+ if (NBLKS < 16) {
1479
+ CPtr = tmp;
1480
+ LDC = 16 * sizeof(float);
1481
+ }
1482
+ if (Bias != nullptr) {
1483
+ const float * bias = Bias + n;
1484
+ if (NBLKS < 16) {
1485
+ __asm__ volatile(
1486
+ "vsetvli t0, %[N], e32, m2 \n\t"
1487
+ "vle32.v v0, (%[SRC]) \n\t"
1488
+ "vse32.v v0, (%[DST]) \n\t"
1489
+ :
1490
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
1491
+ : "cc", "t0");
1492
+ bias = tmp;
1493
+ }
1494
+ __asm__ volatile(LOAD_BIAS
1495
+
1496
+ "addi t3, %[BlockCountK], 0 \n\t"
1497
+
1498
+ "vsetvli t0, zero, e8, m1 \n\t"
1499
+ "li s1, 24 \n\t"
1500
+ "vmv.v.i v1, 3 \n\t"
1501
+ "vsetvli t0, s1, e8, m1 \n\t"
1502
+ "vmv.v.i v1, 2 \n\t"
1503
+ "vsetvli t0, zero, e8, mf2 \n\t"
1504
+ "vmv.v.i v1, 1 \n\t"
1505
+ "vsetvli t0, zero, e8, mf4 \n\t"
1506
+ "vmv.v.i v1, 0 \n\t"
1507
+
1508
+ "addi a1, %[A], 0 \n\t"
1509
+ "addi s1, %[B], 0 \n\t"
1510
+
1511
+ "BLOCK_COUNTK_LOOP%=: \n\t"
1512
+ // scale offset
1513
+ "addi s5, s1, 0 \n\t"
1514
+ // zp offset
1515
+ "addi s6, s1, 32 \n\t"
1516
+ "addi s1, s6, 16 \n\t"
1517
+ "addi s2, s1, 32 \n\t"
1518
+ "addi s3, s1, 32*2 \n\t"
1519
+ "addi s4, s1, 32*3 \n\t"
1520
+
1521
+ "vsetvli t0, zero, e32, m8 \n\t"
1522
+ "vxor.vv v16, v16, v16 \n\t"
1523
+ // load a scale
1524
+ "flw f1, (a1) \n\t"
1525
+ "flw f2, 4(a1) \n\t"
1526
+ "flw f3, 8(a1) \n\t"
1527
+ "flw f4, 12(a1) \n\t"
1528
+ "addi a1, a1, 16 \n\t"
1529
+ "addi t2, %[INNER], 0 \n\t"
1530
+
1531
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1532
+
1533
+ "BLOCK_INNER_LOOP%=: \n\t"
1534
+
1535
+ LOAD_B_16x8x2
1536
+
1537
+ "vle8.v v10, (a1) \n\t"
1538
+ "addi a1, a1, 32 \n\t"
1539
+ "vle8.v v11, (a1) \n\t"
1540
+ "addi a1, a1, 32 \n\t"
1541
+ "vsub.vv v2, v2, v12 \n\t"
1542
+ "vsub.vv v6, v6, v12 \n\t"
1543
+ "vsub.vv v3, v3, v13 \n\t"
1544
+ "vsub.vv v7, v7, v13 \n\t"
1545
+ "vsub.vv v4, v4, v14 \n\t"
1546
+ "vsub.vv v8, v8, v14 \n\t"
1547
+ "vsub.vv v5, v5, v15 \n\t"
1548
+ "vsub.vv v9, v9, v15 \n\t"
1549
+
1550
+ SQ4BIT_KERNEL_COMP_4x16x16
1551
+
1552
+ "addi t2, t2, -1 \n\t"
1553
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1554
+
1555
+ LOAD_SCALE_4x16_FP16
1556
+
1557
+ "vsetvli t0, zero, e32, m8 \n\t"
1558
+ "vfcvt.f.x.v v16, v16 \n\t"
1559
+ "vfmacc.vv v24, v16, v8 \n\t"
1560
+ "addi t3, t3, -1 \n\t"
1561
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1562
+
1563
+ "RESULT_SAVE%=: \n\t"
1564
+
1565
+ SAVE_RESULT_4x16
1566
+
1567
+ :
1568
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1569
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
1570
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
1571
+ "s2", "s3", "s4", "s5", "s6");
1572
+
1573
+ } else {
1574
+ __asm__ volatile(
1575
+ "vsetvli t0, zero, e32, m8 \n\t"
1576
+ "vxor.vv v24, v24, v24 \n\t"
1577
+ "addi t3, %[BlockCountK], 0 \n\t"
1578
+ "vsetvli t0, zero, e8, m1 \n\t"
1579
+ "li s1, 24 \n\t"
1580
+ "vmv.v.i v1, 3 \n\t"
1581
+ "vsetvli t0, s1, e8, m1 \n\t"
1582
+ "vmv.v.i v1, 2 \n\t"
1583
+ "vsetvli t0, zero, e8, mf2 \n\t"
1584
+ "vmv.v.i v1, 1 \n\t"
1585
+ "vsetvli t0, zero, e8, mf4 \n\t"
1586
+ "vmv.v.i v1, 0 \n\t"
1587
+ "addi a1, %[A], 0 \n\t"
1588
+ "addi s1, %[B], 0 \n\t"
1589
+ "BLOCK_COUNTK_LOOP%=: \n\t"
1590
+ // scale offset
1591
+ "addi s5, s1, 0 \n\t"
1592
+ // zp offset
1593
+ "addi s6, s1, 32 \n\t"
1594
+ "addi s1, s6, 16 \n\t"
1595
+ "addi s2, s1, 32 \n\t"
1596
+ "addi s3, s1, 32*2 \n\t"
1597
+ "addi s4, s1, 32*3 \n\t"
1598
+
1599
+ "vsetvli t0, zero, e32, m8 \n\t"
1600
+ "vxor.vv v16, v16, v16 \n\t"
1601
+ // load a scale
1602
+ "flw f1, (a1) \n\t"
1603
+ "flw f2, 4(a1) \n\t"
1604
+ "flw f3, 8(a1) \n\t"
1605
+ "flw f4, 12(a1) \n\t"
1606
+ "addi a1, a1, 16 \n\t"
1607
+ "addi t2, %[INNER], 0 \n\t"
1608
+
1609
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1610
+
1611
+ "BLOCK_INNER_LOOP%=: \n\t"
1612
+
1613
+ LOAD_B_16x8x2
1614
+
1615
+ "vle8.v v10, (a1) \n\t"
1616
+ "addi a1, a1, 32 \n\t"
1617
+ "vle8.v v11, (a1) \n\t"
1618
+ "addi a1, a1, 32 \n\t"
1619
+ "vsub.vv v2, v2, v12 \n\t"
1620
+ "vsub.vv v6, v6, v12 \n\t"
1621
+ "vsub.vv v3, v3, v13 \n\t"
1622
+ "vsub.vv v7, v7, v13 \n\t"
1623
+ "vsub.vv v4, v4, v14 \n\t"
1624
+ "vsub.vv v8, v8, v14 \n\t"
1625
+ "vsub.vv v5, v5, v15 \n\t"
1626
+ "vsub.vv v9, v9, v15 \n\t"
1627
+
1628
+ SQ4BIT_KERNEL_COMP_4x16x16
1629
+
1630
+ "addi t2, t2, -1 \n\t"
1631
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1632
+
1633
+ LOAD_SCALE_4x16_FP16
1634
+
1635
+ "vsetvli t0, zero, e32, m8 \n\t"
1636
+ "vfcvt.f.x.v v16, v16 \n\t"
1637
+ "vfmacc.vv v24, v16, v8 \n\t"
1638
+ "addi t3, t3, -1 \n\t"
1639
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1640
+
1641
+ "RESULT_SAVE%=: \n\t"
1642
+
1643
+ SAVE_RESULT_4x16
1644
+
1645
+ :
1646
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1647
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
1648
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
1649
+ "s4", "s5", "s6");
1650
+ }
1651
+ }
1652
+ } else {
1653
+ for (size_t n = 0; n < CountN; n += 16) {
1654
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
1655
+              const std::byte * QuantBDataPtr = QuantBData +             //
1656
+ n * BlockCountK * BlkLen / 2 + // b data
1657
+ n * BlockCountK * sizeof(_Float16); // scale
1658
+ float * CPtr = C + n;
1659
+ if (NBLKS < 16) {
1660
+ CPtr = tmp;
1661
+ LDC = 16 * sizeof(float);
1662
+ }
1663
+ if (Bias != nullptr) {
1664
+ const float * bias = Bias + n;
1665
+ if (NBLKS < 16) {
1666
+ __asm__ volatile(
1667
+ "vsetvli t0, %[N], e32, m2 \n\t"
1668
+ "vle32.v v0, (%[SRC]) \n\t"
1669
+ "vse32.v v0, (%[DST]) \n\t"
1670
+ :
1671
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
1672
+ : "cc", "t0");
1673
+ bias = tmp;
1674
+ }
1675
+ __asm__ volatile(LOAD_BIAS
1676
+
1677
+ "addi t3, %[BlockCountK], 0 \n\t"
1678
+ "addi a1, %[A], 0 \n\t"
1679
+ "addi s1, %[B], 0 \n\t"
1680
+ "BLOCK_COUNTK_LOOP%=: \n\t"
1681
+ "addi s5, s1, 0 \n\t"
1682
+ "addi s1, s5, 32 \n\t"
1683
+ "addi s2, s1, 32 \n\t"
1684
+ "addi s3, s1, 32*2 \n\t"
1685
+ "addi s4, s1, 32*3 \n\t"
1686
+ "vsetvli t0, zero, e32, m8 \n\t"
1687
+ "vxor.vv v16, v16, v16 \n\t"
1688
+ // load a scale
1689
+ "flw f1, (a1) \n\t"
1690
+ "flw f2, 4(a1) \n\t"
1691
+ "flw f3, 8(a1) \n\t"
1692
+ "flw f4, 12(a1) \n\t"
1693
+ "addi a1, a1, 16 \n\t"
1694
+ "addi t2, %[INNER], 0 \n\t"
1695
+ "BLOCK_INNER_LOOP%=: \n\t"
1696
+
1697
+ LOAD_B_16x8x2
1698
+
1699
+ "vsetvli t0, zero, e8, m1 \n\t"
1700
+ "vle8.v v10, (a1) \n\t"
1701
+ "addi a1, a1, 32 \n\t"
1702
+ "vle8.v v11, (a1) \n\t"
1703
+ "addi a1, a1, 32 \n\t"
1704
+ "vadd.vi v2, v2, -8 \n\t"
1705
+ "vadd.vi v3, v3, -8 \n\t"
1706
+ "vadd.vi v4, v4, -8 \n\t"
1707
+ "vadd.vi v5, v5, -8 \n\t"
1708
+ "vadd.vi v6, v6, -8 \n\t"
1709
+ "vadd.vi v7, v7, -8 \n\t"
1710
+ "vadd.vi v8, v8, -8 \n\t"
1711
+ "vadd.vi v9, v9, -8 \n\t"
1712
+
1713
+ SQ4BIT_KERNEL_COMP_4x16x16
1714
+
1715
+ "addi t2, t2, -1 \n\t"
1716
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1717
+
1718
+ LOAD_SCALE_4x16_FP16
1719
+
1720
+ "vsetvli t0, zero, e32, m8 \n\t"
1721
+ "vfcvt.f.x.v v16, v16 \n\t"
1722
+ "vfmacc.vv v24, v16, v8 \n\t"
1723
+ "addi t3, t3, -1 \n\t"
1724
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1725
+ "RESULT_SAVE%=: \n\t"
1726
+
1727
+ SAVE_RESULT_4x16
1728
+
1729
+ :
1730
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1731
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
1732
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
1733
+ "s2", "s3", "s4", "s5", "s6");
1734
+
1735
+ } else {
1736
+ __asm__ volatile(
1737
+ "vsetvli t0, zero, e32, m8 \n\t"
1738
+ "vxor.vv v24, v24, v24 \n\t"
1739
+ "addi t3, %[BlockCountK], 0 \n\t"
1740
+ "addi a1, %[A], 0 \n\t"
1741
+ "addi s1, %[B], 0 \n\t"
1742
+ "BLOCK_COUNTK_LOOP%=: \n\t"
1743
+ "addi s5, s1, 0 \n\t"
1744
+ "addi s1, s5, 32 \n\t"
1745
+ "addi s2, s1, 32 \n\t"
1746
+ "addi s3, s1, 32*2 \n\t"
1747
+ "addi s4, s1, 32*3 \n\t"
1748
+ "vsetvli t0, zero, e32, m8 \n\t"
1749
+ "vxor.vv v16, v16, v16 \n\t"
1750
+ // load a scale
1751
+ "flw f1, (a1) \n\t"
1752
+ "flw f2, 4(a1) \n\t"
1753
+ "flw f3, 8(a1) \n\t"
1754
+ "flw f4, 12(a1) \n\t"
1755
+ "addi a1, a1, 16 \n\t"
1756
+ "addi t2, %[INNER], 0 \n\t"
1757
+ "BLOCK_INNER_LOOP%=: \n\t"
1758
+
1759
+ LOAD_B_16x8x2
1760
+
1761
+ "vsetvli t0, zero, e8, m1 \n\t"
1762
+ "vle8.v v10, (a1) \n\t"
1763
+ "addi a1, a1, 32 \n\t"
1764
+ "vle8.v v11, (a1) \n\t"
1765
+ "addi a1, a1, 32 \n\t"
1766
+ "vadd.vi v2, v2, -8 \n\t"
1767
+ "vadd.vi v3, v3, -8 \n\t"
1768
+ "vadd.vi v4, v4, -8 \n\t"
1769
+ "vadd.vi v5, v5, -8 \n\t"
1770
+ "vadd.vi v6, v6, -8 \n\t"
1771
+ "vadd.vi v7, v7, -8 \n\t"
1772
+ "vadd.vi v8, v8, -8 \n\t"
1773
+ "vadd.vi v9, v9, -8 \n\t"
1774
+
1775
+ SQ4BIT_KERNEL_COMP_4x16x16
1776
+
1777
+ "addi t2, t2, -1 \n\t"
1778
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1779
+
1780
+ LOAD_SCALE_4x16_FP16
1781
+
1782
+ "vsetvli t0, zero, e32, m8 \n\t"
1783
+ "vfcvt.f.x.v v16, v16 \n\t"
1784
+ "vfmacc.vv v24, v16, v8 \n\t"
1785
+ "addi t3, t3, -1 \n\t"
1786
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1787
+ "RESULT_SAVE%=: \n\t"
1788
+
1789
+ SAVE_RESULT_4x16
1790
+
1791
+ :
1792
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1793
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
1794
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
1795
+ "s4", "s5", "s6");
1796
+ }
1797
+ }
1798
+ }
1799
+ if (CountN % 16 != 0) {
1800
+      // store output from tmp to C when NBLKS is less than 16.
1801
+ float * CPtr = C + CountN / 16 * 16;
1802
+ const size_t N = CountN % 16;
1803
+ LDC = ldc * sizeof(float);
1804
+ __asm__ volatile(
1805
+ "vsetvli t0, %[N], e32, m2 \n\t"
1806
+ "vle32.v v0, (%[SRC]) \n\t"
1807
+ "addi s2, %[SRC], 64 \n\t"
1808
+ "addi s3, %[SRC], 64*2 \n\t"
1809
+ "addi s4, %[SRC], 64*3 \n\t"
1810
+ "vle32.v v2, (s2) \n\t"
1811
+ "vle32.v v4, (s3) \n\t"
1812
+ "vle32.v v6, (s4) \n\t"
1813
+ "add t2, %[DST], %[LDC] \n\t"
1814
+ "add t3, t2, %[LDC] \n\t"
1815
+ "add t4, t3, %[LDC] \n\t"
1816
+ "vse32.v v0, (%[DST]) \n\t"
1817
+ "vse32.v v2, (t2) \n\t"
1818
+ "vse32.v v4, (t3) \n\t"
1819
+ "vse32.v v6, (t4) \n\t"
1820
+ :
1821
+ : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
1822
+ : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
1823
+ }
1824
+ }
1825
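+
+  // Layout read back from the s-register offsets above (an inference, not a
+  // documented format): per 16-column group and per K block, B stores 16 fp16
+  // scales (32 bytes), then, when HasZeroPoint, 16 uint8 zero points (16 bytes),
+  // then the packed 4-bit data in interleaved 32-byte chunks. The fp32-scale
+  // variant below is laid out the same way except the scales occupy 64 bytes.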
+
1826
+ template <bool HasZeroPoint>
1827
+ void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen,
1828
+ const std::byte * QuantA,
1829
+ const std::byte * QuantBData,
1830
+ const float * QuantBScale,
1831
+ const std::byte * QuantBZeroPoint,
1832
+ float * C,
1833
+ size_t CountN,
1834
+ size_t BlockCountK,
1835
+ const float * Bias,
1836
+ const size_t ldc) {
1837
+ GGML_UNUSED(QuantBScale);
1838
+ GGML_UNUSED(QuantBZeroPoint);
1839
+ size_t LDC = ldc * sizeof(float);
1840
+ const size_t INNER = BlkLen / 16;
1841
+ float tmp[4 * 16];
1842
+
1843
+ if constexpr (HasZeroPoint) {
1844
+ for (size_t n = 0; n < CountN; n += 16) {
1845
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
1846
+              const std::byte * QuantBDataPtr = QuantBData +             //
1847
+ n * BlockCountK * BlkLen / 2 + // b data
1848
+ n * BlockCountK * sizeof(uint8_t) + // zp
1849
+ n * BlockCountK * sizeof(float); // scale
1850
+ float * CPtr = C + n;
1851
+ if (NBLKS < 16) {
1852
+ CPtr = tmp;
1853
+ LDC = 16 * sizeof(float);
1854
+ }
1855
+ if (Bias != nullptr) {
1856
+ const float * bias = Bias + n;
1857
+ if (NBLKS < 16) {
1858
+ __asm__ volatile(
1859
+ "vsetvli t0, %[N], e32, m2 \n\t"
1860
+ "vle32.v v0, (%[SRC]) \n\t"
1861
+ "vse32.v v0, (%[DST]) \n\t"
1862
+ :
1863
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
1864
+ : "cc", "t0");
1865
+ bias = tmp;
1866
+ }
1867
+
+ __asm__ volatile(LOAD_BIAS
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 64 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 64 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ } else {
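+ // No-zero-point path: the 4-bit weights are re-centered with the implicit zero-point 8 (vadd.vi ..., -8) instead of a per-block subtraction.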
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+ __asm__ volatile(LOAD_BIAS
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 64 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 64 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ }
+ if (CountN % 16 != 0) {
+ // store output from tmp to C when NBLKS is less than 16.
+ float * CPtr = C + CountN / 16 * 16;
+ const size_t N = CountN % 16;
+ LDC = ldc * sizeof(float);
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi s2, %[SRC], 64 \n\t"
+ "addi s3, %[SRC], 64*2 \n\t"
+ "addi s4, %[SRC], 64*3 \n\t"
+ "vle32.v v2, (s2) \n\t"
+ "vle32.v v4, (s3) \n\t"
+ "vle32.v v6, (s4) \n\t"
+ "add t2, %[DST], %[LDC] \n\t"
+ "add t3, t2, %[LDC] \n\t"
+ "add t4, t3, %[LDC] \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ "vse32.v v2, (t2) \n\t"
+ "vse32.v v4, (t3) \n\t"
+ "vse32.v v6, (t4) \n\t"
+ :
+ : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
+ : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
+ }
+ }
+
+ template <bool HasZeroPoint>
+ void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias) {
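+ // Single-row (M = 1) variant for fp16 B scales: per-block scales are loaded as
+ // 16-bit floats and widened to fp32 (vfwcvt) before being multiplied by the A scale.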
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ size_t INNER = BlkLen / 16;
+
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+ // zp offset
+ "addi s7, %[B], 32 \n\t"
+ // a offset
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 48 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 72 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 120 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+ "addi s7, s1, 32 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
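+ // Tail store: each vsetvli clamps vl to the remaining column count, so a
+ // partial stripe writes exactly NBLKS floats across the four result vectors.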
+ "ST_TAIL%=: \n\t"
2325
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2326
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2327
+ "vse32.v v28, (%[C]) \n\t"
2328
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2329
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2330
+ "vse32.v v29, (s1) \n\t"
2331
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2332
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2333
+ "vse32.v v30, (s2) \n\t"
2334
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2335
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2336
+ "vse32.v v31, (s3) \n\t"
2337
+ "END%=: \n\t"
2338
+
2339
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
2340
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2341
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2342
+ } else {
2343
+ __asm__ volatile(
2344
+ "vsetvli t0, zero, e32, m4 \n\t"
2345
+ "vxor.vv v28, v28, v28 \n\t"
2346
+
2347
+ "vsetvli t0, zero, e8, m1 \n\t"
2348
+ "vmv.v.i v13, 3 \n\t"
2349
+ "li s1, 24 \n\t"
2350
+ "vsetvli t0, s1, e8, m1 \n\t"
2351
+ "vmv.v.i v13, 2 \n\t"
2352
+ "vsetvli t0, zero, e8, mf2 \n\t"
2353
+ "vmv.v.i v13, 1 \n\t"
2354
+ "vsetvli t0, zero, e8, mf4 \n\t"
2355
+ "vmv.v.i v13, 0 \n\t"
2356
+
2357
+ "addi s1, %[B], 0 \n\t"
2358
+ "addi s2, %[B], 8 \n\t"
2359
+ "addi s3, %[B], 16 \n\t"
2360
+ "addi s4, %[B], 24 \n\t"
2361
+
2362
+ "addi s7, %[B], 32 \n\t"
2363
+
2364
+ "addi s5, %[A], 0 \n\t"
2365
+ "addi s6, %[A], 12 \n\t"
2366
+ "LOOP_K%=: \n\t"
2367
+ "vsetvli t0, zero, e16, mf4 \n\t"
2368
+ "vle16.v v4, (s1) \n\t"
2369
+ "addi s1, s1, 48 \n\t"
2370
+ "vle16.v v5, (s2) \n\t"
2371
+ "addi s2, s2, 72 \n\t"
2372
+ "vle16.v v6, (s3) \n\t"
2373
+ "addi s3, s3, 96 \n\t"
2374
+ "vle16.v v7, (s4) \n\t"
2375
+ "addi s4, s4, 120 \n\t"
2376
+ "flw f1, (s5) \n\t"
2377
+ "addi s5, s5, 4 \n\t"
2378
+
2379
+ "vfwcvt.f.f.v v8, v4 \n\t"
2380
+ "vfwcvt.f.f.v v9, v5 \n\t"
2381
+ "vfwcvt.f.f.v v10, v6 \n\t"
2382
+ "vfwcvt.f.f.v v11, v7 \n\t"
2383
+ "vsetvli t0, zero, e32, mf2 \n\t"
2384
+
2385
+ "addi t5, %[INNER], 0 \n\t"
2386
+ "vxor.vv v16, v16, v16 \n\t"
2387
+ "vxor.vv v18, v18, v18 \n\t"
2388
+ "vxor.vv v20, v20, v20 \n\t"
2389
+ "vxor.vv v22, v22, v22 \n\t"
2390
+ "vfmul.vf v24, v8, f1 \n\t"
2391
+ "vfmul.vf v25, v9, f1 \n\t"
2392
+ "vfmul.vf v26, v10, f1 \n\t"
2393
+ "vfmul.vf v27, v11, f1 \n\t"
2394
+ "addi %[CNT], %[CNT], -1 \n\t"
2395
+
2396
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
2397
+
2398
+ "LOOP_INNER%=: \n\t"
2399
+
2400
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2401
+
2402
+ "vsub.vv v0, v0, v8 \n\t"
2403
+ "vsub.vv v4, v4, v8 \n\t"
2404
+ "vsub.vv v1, v1, v9 \n\t"
2405
+ "vsub.vv v5, v5, v9 \n\t"
2406
+ "vsub.vv v2, v2, v10 \n\t"
2407
+ "vsub.vv v6, v6, v10 \n\t"
2408
+ "vsub.vv v3, v3, v11 \n\t"
2409
+ "vsub.vv v7, v7, v11 \n\t"
2410
+
2411
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2412
+
2413
+ "bnez t5, LOOP_INNER%= \n\t"
2414
+ "vsetvli t0, zero, e32, mf2 \n\t"
2415
+
2416
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
2417
+ "addi s7, s1, 32 \n\t"
2418
+
2419
+ "bnez %[CNT], LOOP_K%= \n\t"
2420
+ "addi t3, zero, 16 \n\t"
2421
+ "addi s1, %[C], 16 \n\t"
2422
+ "addi s2, %[C], 32 \n\t"
2423
+ "addi s3, %[C], 48 \n\t"
2424
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2425
+ "vse32.v v28, (%[C]) \n\t"
2426
+ "vse32.v v29, (s1) \n\t"
2427
+ "vse32.v v30, (s2) \n\t"
2428
+ "vse32.v v31, (s3) \n\t"
2429
+ "jal x0, END%= \n\t"
2430
+
2431
+ "ST_TAIL%=: \n\t"
2432
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2433
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2434
+ "vse32.v v28, (%[C]) \n\t"
2435
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2436
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2437
+ "vse32.v v29, (s1) \n\t"
2438
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2439
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2440
+ "vse32.v v30, (s2) \n\t"
2441
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2442
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2443
+ "vse32.v v31, (s3) \n\t"
2444
+ "END%=: \n\t"
2445
+
2446
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
2447
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2448
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2449
+ }
2450
+ }
2451
+ } else {
2452
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 32 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 56 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 80 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 104 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 32 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 56 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 80 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 104 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ }
+ }
+ }
+ }
+
+ template <bool HasZeroPoint>
+ void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias) {
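+ // Single-row (M = 1) variant for fp32 B scales, loaded directly with vle32.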
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ const size_t INNER = BlkLen / 16;
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+
+ // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+ // zp offset
+ "addi s7, %[B], 64 \n\t"
+ // a offset
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "LOOP_K%=: \n\t"
+
+ // load scale
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 80 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 96 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 112 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 128 \n\t"
+
+ // load a scale
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+
+ // a scale * b scale
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+ "addi s7, s1, 64 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+
+ "addi s7, %[B], 64 \n\t"
+
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ "LOOP_K%=: \n\t"
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 80 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 96 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 112 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 128 \n\t"
+
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+ "addi s7, s1, 64 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+ }
+ }
+ } else {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "LOOP_K%=: \n\t"
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 64 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 80 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 112 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "LOOP_K%=: \n\t"
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 64 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 80 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 112 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ }
+ }
+ }
+ }
+
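+ // The dispatchers below key off the packed scale width: a 4-byte stride selects the
+ // fp32-scale kernels, a 2-byte stride selects the fp16 (*_ScaleFp16_Impl) kernels.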
+ template <bool HasZeroPoint>
+ inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountM,
+ size_t CountN,
+ size_t BlockStrideQuantB,
+ const float * Bias,
+ const size_t ldc,
+ const size_t scalestride) {
+ if (scalestride == 4) {
+ SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
+ CountN, BlockStrideQuantB, Bias, ldc);
+
+ } else if (scalestride == 2) {
+ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(
+ BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);
+ }
+ }
+
+ template <bool HasZeroPoint>
+ inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountM,
+ size_t CountN,
+ size_t BlockStrideQuantB,
+ const float * Bias,
+ const size_t ldc,
+ const size_t scalestride) {
+ if (scalestride == 4) {
+ SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
+ CountN, BlockStrideQuantB, Bias);
+ } else if (scalestride == 2) {
+ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,
+ QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);
+ }
+ }
+
+ } // namespace
+
+ namespace ime1 {
+ size_t gemm_kernel_i8i4(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountM,
+ size_t CountN,
+ size_t CountK,
+ size_t BlockCountK,
+ size_t ldc,
+ const float * Bias,
+ const size_t ScaleStride) {
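+ // Returns the number of rows of A handled by this call: 4 when CountM >= 4, otherwise 1.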
+ GGML_UNUSED(CountK);
+ if (CountM >= 4) {
+ if (QuantBZeroPoint != nullptr) {
+ SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
+ C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
+ } else {
+ SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
+ QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
+ ldc, ScaleStride);
+ }
+ return 4;
+ } else {
+ if (QuantBZeroPoint != nullptr) {
+ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
+ C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
+ } else {
+ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
+ QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
+ ldc, ScaleStride);
+ }
+ return 1;
+ }
+ }
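+ // Usage sketch (hypothetical caller, not part of this change): rows are consumed
+ // in chunks of the kernel's return value, e.g.
+ //   for (size_t m = 0; m < CountM; ) {
+ //       size_t rows = gemm_kernel_i8i4(BlkLen, QuantA /* row m */, QuantBData,
+ //                                      QuantBScale, QuantBZeroPoint, C + m * ldc,
+ //                                      CountM - m, CountN, CountK, BlockCountK,
+ //                                      ldc, Bias, ScaleStride);
+ //       m += rows;
+ //   }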
+ } // namespace ime1
+ } // namespace sqnbitgemm_spacemit_ime