whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -1,8 +1,26 @@
1
+ #include "ggml-impl.h"
1
2
  #include "ggml.h"
2
3
  #include "ime_kernels.h"
4
+ #include "rvv_kernels.h"
3
5
 
4
6
  #include <algorithm>
5
7
  #include <cmath>
8
+ #include <stdexcept>
9
+
10
+ #if !defined(__riscv_v) || !defined(__riscv_v_intrinsic)
11
+ # error "riscv v extension or v_intrinsic not enabled"
12
+ #else
13
+ # include <riscv_vector.h>
14
+ #endif
15
+
16
+ #if !defined(__riscv_zfh)
17
+ # error "riscv zfh extension not enabled"
18
+ #endif
19
+
20
+ #if defined(RISCV64_SPACEMIT_IME1)
21
+ #else
22
+ # error "RISCV64_SPACEMIT_IME1 not defined"
23
+ #endif
6
24
 
7
25
  // clang-format off
8
26
  #if defined(__GNUC__)
@@ -11,7 +29,7 @@
11
29
  #pragma GCC diagnostic ignored "-Wunused-parameter"
12
30
  #endif
13
31
  // clang-format on
14
- namespace sqnbitgemm_spacemit_ime {
32
+ namespace spacemit_kernels {
15
33
 
16
34
  #define QUANTIZEM4ROW_KERNEL \
17
35
  "vmv.s.x v16, zero \n\t" \
@@ -76,1093 +94,208 @@ namespace sqnbitgemm_spacemit_ime {
76
94
  "vse8.v v31, (s1) \n\t"
77
95
 
78
96
  namespace ime1 {
79
- void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
97
+ void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) {
80
98
  constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
81
99
  const float fone = 1.0f;
82
100
 
83
- if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {
84
- for (size_t row_index = 0; row_index < 4; ++row_index) {
85
- const float * SRC = A + row_index * CountK;
86
- std::byte * DST = QuantA + row_index * sizeof(float);
101
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
102
+ const float * SRC = A + row_index * CountK;
103
+ uint8_t * DST = QuantA + row_index * sizeof(float);
87
104
 
88
- const size_t offset = (4 - row_index) * 4 + row_index * 8;
89
- const size_t stride = 4 * (sizeof(float) + BlkLen);
90
- __asm__ volatile(
91
- "vsetvli t0, zero, e32, m8 \n\t"
92
- "addi t2, %[CountK], 0 \n\t"
93
- "addi a1, %[DST], 0 \n\t"
94
- "blt t2, %[BlkLen], TAIL%= \n\t"
95
-
96
- "LOOP%=: \n\t"
97
- "vsetvli t0, %[BlkLen], e32, m8 \n\t"
98
- "vle32.v v0, (%[SRC]) \n\t"
99
- "sub t2, t2, t0 \n\t"
100
- "slli t1, t0, 2 \n\t"
101
- "add %[SRC], %[SRC], t1 \n\t"
102
- "add s1, a1, %[OFFSET] \n\t"
103
-
104
- QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
105
-
106
- "add a1, a1, %[STRIDE] \n\t"
107
- "bge t2, %[BlkLen], LOOP%= \n\t"
108
-
109
- "TAIL%=: \n\t"
110
- "blez t2, QUIT%= \n\t"
111
- "vsetvli t0, zero, e32, m8 \n\t"
112
- "vxor.vv v16, v16, v16 \n\t"
113
- "vxor.vv v24, v24, v24 \n\t"
114
- "vsetvli t0, t2, e32, m8 \n\t"
115
- "vle32.v v0, (%[SRC]) \n\t"
116
- "add s1, a1, %[OFFSET] \n\t"
117
-
118
- QUANTIZEM4ROW_KERNEL
119
-
120
- "addi t3, %[BlkLen], 0 \n\t"
121
- "addi s2, s1, 0 \n\t"
122
- "vsetvli t0, zero, e8, mf4 \n\t"
123
- "vxor.vv v8, v8, v8 \n\t"
124
- "SET_ZERO%=: \n\t"
125
- "vse8.v v8, (s2) \n\t"
126
- "addi s2, s2, 32 \n\t"
127
- "addi t3, t3, -8 \n\t"
128
- "bnez t3, SET_ZERO%= \n\t"
129
-
130
- QUANTIZEM4ROW_STORE
131
-
132
- "QUIT%=: \n\t"
133
- : [SRC] "+r"(SRC)
134
- : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
135
- [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
136
- : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
137
- }
138
- } else if (BlkLen == 128) {
139
- for (size_t row_index = 0; row_index < 4; ++row_index) {
140
- const float * SRC = A + row_index * CountK;
141
- std::byte * DST = QuantA + row_index * sizeof(float);
142
-
143
- const size_t offset = (4 - row_index) * 4 + row_index * 8;
144
- const size_t stride = 4 * (sizeof(float) + BlkLen);
145
- __asm__ volatile(
146
- "vsetvli t0, zero, e32, m8 \n\t"
147
- "li t6, 32 \n\t"
148
- "addi t2, %[CountK], 0 \n\t"
149
- "addi a1, %[DST], 0 \n\t"
150
- "add s1, a1, %[OFFSET] \n\t"
151
- "blt t2, %[BlkLen], TAIL%= \n\t"
152
-
153
- "LOOP%=: \n\t"
154
- "vsetvli t0, zero, e32, m8 \n\t"
155
- "vle32.v v0, (%[SRC]) \n\t"
156
- "addi %[SRC], %[SRC], 256 \n\t"
157
- "vle32.v v8, (%[SRC]) \n\t"
158
- "addi %[SRC], %[SRC], 256 \n\t"
159
- "addi t2, t2, -128 \n\t"
160
-
161
- "QUANTIZE%=: \n\t"
162
- "add s1, a1, %[OFFSET] \n\t"
163
- "vfabs.v v16, v0 \n\t"
164
- "vfabs.v v24, v8 \n\t"
165
- "vfmax.vv v16, v24, v16 \n\t"
166
- "vfredmax.vs v24, v16, v24 \n\t"
167
- "vfmv.f.s f10, v24 \n\t"
168
- "fmul.s f10, f10, %[RMAXREC] \n\t"
169
- "fsw f10, (a1) \n\t"
170
- "fdiv.s f11, %[FONE], f10 \n\t"
171
- "vfmul.vf v16, v0, f11 \n\t"
172
- "vfmul.vf v24, v8, f11 \n\t"
173
- "vfcvt.x.f.v v16, v16 \n\t"
174
- "vfcvt.x.f.v v24, v24 \n\t"
175
- "vsetvli t0, zero, e16, m4 \n\t"
176
- "vnclip.wx v16, v16, zero \n\t"
177
- "vnclip.wx v20, v24, zero \n\t"
178
- "vsetvli t0, zero, e8, m4 \n\t"
179
- "vnclip.wx v16, v16, zero \n\t"
180
- "vsetvli t0, zero, e64, m4 \n\t"
181
- "vsse64.v v16, (s1), t6 \n\t"
182
- "add a1, a1, %[STRIDE] \n\t"
183
- "bge t2, %[BlkLen], LOOP%= \n\t"
184
-
185
- "TAIL%=: \n\t"
186
- "blez t2, QUIT%= \n\t"
187
- "vsetvli t0, zero, e32, m8 \n\t"
188
- "vxor.vv v0, v0, v0 \n\t"
189
- "vxor.vv v8, v8, v8 \n\t"
190
- "vxor.vv v16, v16, v16 \n\t"
191
- "vxor.vv v24, v24, v24 \n\t"
192
- "vsetvli t0, t2, e32, m8 \n\t"
193
- "sub t2, t2, t0 \n\t"
194
- "vle32.v v0, (%[SRC]) \n\t"
195
- "addi %[SRC], %[SRC], 256 \n\t"
196
- "vsetvli t0, t2, e32, m8 \n\t"
197
- "vle32.v v8, (%[SRC]) \n\t"
198
- "sub t2, t2, t2 \n\t"
199
- "vsetvli t0, zero, e32, m8 \n\t"
200
- "jal x0, QUANTIZE%= \n\t"
201
-
202
- "QUIT%=: \n\t"
203
- : [SRC] "+r"(SRC)
204
- : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
205
- [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
206
- : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
207
- }
208
- } else if (BlkLen == 256) {
209
- for (size_t row_index = 0; row_index < 4; ++row_index) {
210
- const float * SRC = A + row_index * CountK;
211
- std::byte * DST = QuantA + row_index * sizeof(float);
212
- const size_t offset = (4 - row_index) * 4 + row_index * 8;
213
- const size_t stride = 4 * (sizeof(float) + BlkLen);
214
- __asm__ volatile(
215
- "vsetvli t0, zero, e32, m8 \n\t"
216
- "li t6, 32 \n\t"
217
- "addi t2, %[CountK], 0 \n\t"
218
- "addi a1, %[DST], 0 \n\t"
219
- "add s1, a1, %[OFFSET] \n\t"
220
- "blt t2, %[BlkLen], TAIL%= \n\t"
221
-
222
- "LOOP%=: \n\t"
223
- "vsetvli t0, zero, e32, m8 \n\t"
224
- "vle32.v v0, (%[SRC]) \n\t"
225
- "addi %[SRC], %[SRC], 256 \n\t"
226
- "vle32.v v8, (%[SRC]) \n\t"
227
- "addi %[SRC], %[SRC], 256 \n\t"
228
- "vle32.v v16, (%[SRC]) \n\t"
229
- "addi %[SRC], %[SRC], 256 \n\t"
230
- "vle32.v v24, (%[SRC]) \n\t"
231
- "addi %[SRC], %[SRC], -768 \n\t"
232
- "addi t2, t2, -256 \n\t"
233
- "vfabs.v v0, v0 \n\t"
234
- "vfabs.v v8, v8 \n\t"
235
- "vfabs.v v16, v16 \n\t"
236
- "vfabs.v v24, v24 \n\t"
237
- "vfmax.vv v8, v0, v8 \n\t"
238
- "vfmax.vv v24, v24, v16 \n\t"
239
- "vfmax.vv v8, v8, v24 \n\t"
240
- "vfredmax.vs v24, v8, v24 \n\t"
241
- "vfmv.f.s f10, v24 \n\t"
242
- "vle32.v v0, (%[SRC]) \n\t"
243
- "addi %[SRC], %[SRC], 256 \n\t"
244
- "vle32.v v8, (%[SRC]) \n\t"
245
- "addi %[SRC], %[SRC], 256 \n\t"
246
- "vle32.v v16, (%[SRC]) \n\t"
247
- "addi %[SRC], %[SRC], 256 \n\t"
248
- "vle32.v v24, (%[SRC]) \n\t"
249
- "addi %[SRC], %[SRC], 256 \n\t"
250
-
251
- "QUANTIZE%=: \n\t"
252
- "add s1, a1, %[OFFSET] \n\t"
253
- "fmul.s f10, f10, %[RMAXREC] \n\t"
254
- "fsw f10, (a1) \n\t"
255
- "fdiv.s f11, %[FONE], f10 \n\t"
256
- "vfmul.vf v0, v0, f11 \n\t"
257
- "vfmul.vf v8, v8, f11 \n\t"
258
- "vfmul.vf v16, v16, f11 \n\t"
259
- "vfmul.vf v24, v24, f11 \n\t"
260
- "vfcvt.x.f.v v0, v0 \n\t"
261
- "vfcvt.x.f.v v8, v8 \n\t"
262
- "vfcvt.x.f.v v16, v16 \n\t"
263
- "vfcvt.x.f.v v24, v24 \n\t"
264
- "vsetvli t0, zero, e16, m4 \n\t"
265
- "vnclip.wx v0, v0, zero \n\t"
266
- "vnclip.wx v4, v8, zero \n\t"
267
- "vnclip.wx v8, v16, zero \n\t"
268
- "vnclip.wx v12, v24, zero \n\t"
269
- "vsetvli t0, zero, e8, m4 \n\t"
270
- "vnclip.wx v0, v0, zero \n\t"
271
- "vnclip.wx v4, v8, zero \n\t"
272
- "vsetvli t0, zero, e64, m8 \n\t"
273
- "vsse64.v v0, (s1), t6 \n\t"
274
- "add a1, a1, %[STRIDE] \n\t"
275
- "bge t2, %[BlkLen], LOOP%= \n\t"
276
-
277
- "TAIL%=: \n\t"
278
- "blez t2, QUIT%= \n\t"
279
- "vsetvli t0, zero, e32, m8 \n\t"
280
- "vxor.vv v0, v0, v0 \n\t"
281
- "vxor.vv v8, v8, v8 \n\t"
282
- "vxor.vv v16, v16, v16 \n\t"
283
- "vxor.vv v24, v24, v24 \n\t"
284
- "addi t1, t2, 0 \n\t"
285
- "vsetvli t0, t1, e32, m8 \n\t"
286
- "sub t1, t1, t0 \n\t"
287
- "vle32.v v0, (%[SRC]) \n\t"
288
- "addi %[SRC], %[SRC], 256 \n\t"
289
- "vsetvli t0, t1, e32, m8 \n\t"
290
- "sub t1, t1, t0 \n\t"
291
- "vle32.v v8, (%[SRC]) \n\t"
292
- "addi %[SRC], %[SRC], 256 \n\t"
293
- "vsetvli t0, t1, e32, m8 \n\t"
294
- "sub t1, t1, t0 \n\t"
295
- "vle32.v v16, (%[SRC]) \n\t"
296
- "addi %[SRC], %[SRC], 256 \n\t"
297
- "vsetvli t0, t1, e32, m8 \n\t"
298
- "vle32.v v24, (%[SRC]) \n\t"
299
- "addi %[SRC], %[SRC], -768 \n\t"
300
- "vsetvli t0, zero, e32, m8 \n\t"
301
- "vfabs.v v0, v0 \n\t"
302
- "vfabs.v v8, v8 \n\t"
303
- "vfabs.v v16, v16 \n\t"
304
- "vfabs.v v24, v24 \n\t"
305
- "vfmax.vv v8, v0, v8 \n\t"
306
- "vfmax.vv v24, v16, v24 \n\t"
307
- "vfmax.vv v8, v8, v24 \n\t"
308
- "vfredmax.vs v24, v8, v24 \n\t"
309
- "vfmv.f.s f10, v24 \n\t"
310
- "add s1, a1, %[OFFSET] \n\t"
311
- "fmul.s f10, f10, %[RMAXREC] \n\t"
312
- "fsw f10, (a1) \n\t"
313
- "fdiv.s f11, %[FONE], f10 \n\t"
314
- "vsetvli t0, zero, e64, m8 \n\t"
315
- "vxor.vv v0, v0, v0 \n\t"
316
- "vsse64.v v0, (s1), t6 \n\t"
317
-
318
- "TAIL_LOOP%=: \n\t"
319
- "vsetvli t0, zero, e32, m4 \n\t"
320
- "vxor.vv v0, v0, v0 \n\t"
321
- "vsetvli t0, t2, e32, m1 \n\t"
322
- "sub t2, t2, t0 \n\t"
323
- "vle32.v v0, (%[SRC]) \n\t"
324
- "addi %[SRC], %[SRC], 32 \n\t"
325
- "vfmul.vf v1, v0, f11 \n\t"
326
- "vfcvt.x.f.v v2, v1 \n\t"
327
- "vsetvli t0, zero, e16, mf2 \n\t"
328
- "vnclip.wx v3, v2, zero \n\t"
329
- "vsetvli t0, zero, e8, mf4 \n\t"
330
- "vnclip.wx v3, v3, zero \n\t"
331
- "vse8.v v3, (s1) \n\t"
332
- "addi s1, s1, 32 \n\t"
333
- "bnez t2, TAIL_LOOP%= \n\t"
334
-
335
- "QUIT%=: \n\t"
336
- : [SRC] "+r"(SRC)
337
- : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
338
- [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
339
- : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
340
- }
105
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
106
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
107
+ __asm__ volatile(
108
+ "vsetvli t0, zero, e32, m8 \n\t"
109
+ "addi t2, %[CountK], 0 \n\t"
110
+ "addi a1, %[DST], 0 \n\t"
111
+ "blt t2, %[BlkLen], TAIL%= \n\t"
112
+
113
+ "LOOP%=: \n\t"
114
+ "vsetvli t0, %[BlkLen], e32, m8 \n\t"
115
+ "vle32.v v0, (%[SRC]) \n\t"
116
+ "sub t2, t2, t0 \n\t"
117
+ "slli t1, t0, 2 \n\t"
118
+ "add %[SRC], %[SRC], t1 \n\t"
119
+ "add s1, a1, %[OFFSET] \n\t"
120
+
121
+ QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
122
+
123
+ "add a1, a1, %[STRIDE] \n\t"
124
+ "bge t2, %[BlkLen], LOOP%= \n\t"
125
+
126
+ "TAIL%=: \n\t"
127
+ "blez t2, QUIT%= \n\t"
128
+ "vsetvli t0, zero, e32, m8 \n\t"
129
+ "vxor.vv v16, v16, v16 \n\t"
130
+ "vxor.vv v24, v24, v24 \n\t"
131
+ "vsetvli t0, t2, e32, m8 \n\t"
132
+ "vle32.v v0, (%[SRC]) \n\t"
133
+ "add s1, a1, %[OFFSET] \n\t"
134
+
135
+ QUANTIZEM4ROW_KERNEL
136
+
137
+ "addi t3, %[BlkLen], 0 \n\t"
138
+ "addi s2, s1, 0 \n\t"
139
+ "vsetvli t0, zero, e8, mf4 \n\t"
140
+ "vxor.vv v8, v8, v8 \n\t"
141
+ "SET_ZERO%=: \n\t"
142
+ "vse8.v v8, (s2) \n\t"
143
+ "addi s2, s2, 32 \n\t"
144
+ "addi t3, t3, -8 \n\t"
145
+ "bnez t3, SET_ZERO%= \n\t"
146
+
147
+ QUANTIZEM4ROW_STORE
148
+
149
+ "QUIT%=: \n\t"
150
+ : [SRC] "+r"(SRC)
151
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), [CountK] "r"(CountK),
152
+ [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
153
+ : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
341
154
  }
342
155
  }
343
156
 
344
- void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
157
+ void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) {
345
158
  const float * SRC = A;
346
- std::byte * DST = QuantA;
159
+ uint8_t * DST = QuantA;
347
160
  constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
348
161
  const float fone = 1.0f;
349
- std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
162
+ uint8_t * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
350
163
  size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
351
164
 
352
- if (CountK <= BlkLen) {
353
- float max_abs_A = 0.0f;
354
- for (size_t k = 0; k < CountK; k++) {
355
- max_abs_A = std::max(max_abs_A, fabsf(A[k]));
356
- }
357
- float scale_A = max_abs_A * range_max_reciprocal;
358
-
359
- ((float *) QuantA)[0] = scale_A;
360
-
361
- auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
362
-
363
- for (size_t k = 0; k < CountK; k++) {
364
- QuantAData_offset[k] =
365
- (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
366
- (float) std::numeric_limits<int8_t>::max());
367
- }
368
- for (size_t k = CountK; k < BlkLen; k++) {
369
- QuantAData_offset[k] = 0;
370
- }
371
-
372
- return;
373
- }
374
-
375
- if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) {
376
- __asm__ volatile(
377
- "vsetvli t0, zero, e8, m8 \n\t"
378
- "vxor.vv v24, v24, v24 \n\t"
379
- "LOOP%=: \n\t"
380
- "vsetvli t0, %[CNT], e8, m8 \n\t"
381
- "vse8.v v24, (%[DST]) \n\t"
382
- "addi %[DST], %[DST], 128 \n\t"
383
- "sub %[CNT], %[CNT], t0 \n\t"
384
- "bnez %[CNT], LOOP%= \n\t"
385
- : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
386
- :
387
- : "cc", "t0");
388
- }
389
- if (BlkLen == 16) {
390
- float buffer[64] = { 0.0f };
391
- __asm__ volatile(
392
- "addi t3, zero, 16*8 \n\t"
393
- "addi t2, zero, 16 \n\t"
394
- "blt %[K], t3, LOOP_K%= \n\t"
395
- "blt %[K], t2, TAIL%= \n\t"
396
- "LOOP_MAIN%=: \n\t"
397
- "vsetvli t1, zero, e32, m2 \n\t"
398
- "addi %[K], %[K], -128 \n\t"
399
- "vle32.v v0, (%[SRC]) \n\t"
400
- "addi %[SRC], %[SRC], 64 \n\t"
401
- "vle32.v v2, (%[SRC]) \n\t"
402
- "addi %[SRC], %[SRC], 64 \n\t"
403
- "vle32.v v4, (%[SRC]) \n\t"
404
- "addi %[SRC], %[SRC], 64 \n\t"
405
- "vle32.v v6, (%[SRC]) \n\t"
406
- "addi %[SRC], %[SRC], 64 \n\t"
407
- "vle32.v v8, (%[SRC]) \n\t"
408
- "addi %[SRC], %[SRC], 64 \n\t"
409
- "vle32.v v10, (%[SRC]) \n\t"
410
- "addi %[SRC], %[SRC], 64 \n\t"
411
- "vle32.v v12, (%[SRC]) \n\t"
412
- "addi %[SRC], %[SRC], 64 \n\t"
413
- "vle32.v v14, (%[SRC]) \n\t"
414
- "addi %[SRC], %[SRC], 64 \n\t"
415
- "addi a1, %[BUFFER], 0 \n\t"
416
- "vfabs.v v16, v0 \n\t"
417
- "vfabs.v v18, v2 \n\t"
418
- "vfabs.v v20, v4 \n\t"
419
- "vfabs.v v22, v6 \n\t"
420
- "vfabs.v v24, v8 \n\t"
421
- "vfabs.v v26, v10 \n\t"
422
- "vfabs.v v28, v12 \n\t"
423
- "vfabs.v v30, v14 \n\t"
424
- "vsetvli t0, zero, e32, m1 \n\t"
425
- "vfmax.vv v16, v16, v17 \n\t"
426
- "vfmax.vv v18, v18, v19 \n\t"
427
- "vfmax.vv v20, v20, v21 \n\t"
428
- "vfmax.vv v22, v22, v23 \n\t"
429
- "vfmax.vv v24, v24, v25 \n\t"
430
- "vfmax.vv v26, v26, v27 \n\t"
431
- "vfmax.vv v28, v28, v29 \n\t"
432
- "vfmax.vv v30, v30, v31 \n\t"
433
- "vse32.v v16, (a1) \n\t"
434
- "addi a1, a1, 32 \n\t"
435
- "vse32.v v18, (a1) \n\t"
436
- "addi a1, a1, 32 \n\t"
437
- "vse32.v v20, (a1) \n\t"
438
- "addi a1, a1, 32 \n\t"
439
- "vse32.v v22, (a1) \n\t"
440
- "addi a1, a1, 32 \n\t"
441
- "vse32.v v24, (a1) \n\t"
442
- "addi a1, a1, 32 \n\t"
443
- "vse32.v v26, (a1) \n\t"
444
- "addi a1, a1, 32 \n\t"
445
- "vse32.v v28, (a1) \n\t"
446
- "addi a1, a1, 32 \n\t"
447
- "vse32.v v30, (a1) \n\t"
448
- "addi a1, %[BUFFER], 0 \n\t"
449
- "flw f0, (a1) \n\t"
450
- "flw f1, 4(a1) \n\t"
451
- "flw f2, 8(a1) \n\t"
452
- "flw f3, 12(a1) \n\t"
453
- "flw f4, 16(a1) \n\t"
454
- "flw f5, 20(a1) \n\t"
455
- "flw f6, 24(a1) \n\t"
456
- "flw f7, 28(a1) \n\t"
457
- "addi a1, a1, 32 \n\t"
458
- "fmax.s f1, f0, f1 \n\t"
459
- "fmax.s f3, f2, f3 \n\t"
460
- "fmax.s f5, f4, f5 \n\t"
461
- "fmax.s f7, f6, f7 \n\t"
462
- "fmax.s f3, f1, f3 \n\t"
463
- "fmax.s f7, f5, f7 \n\t"
464
- "fmax.s f10, f3, f7 \n\t"
465
- "fmul.s f10, f10, %[RMAXREC] \n\t"
466
- "fsw f10, (%[DST]) \n\t"
467
- "addi %[DST], %[DST], 20 \n\t"
468
- "fdiv.s f10, %[FONE], f10 \n\t"
469
- "flw f0, (a1) \n\t"
470
- "flw f1, 4(a1) \n\t"
471
- "flw f2, 8(a1) \n\t"
472
- "flw f3, 12(a1) \n\t"
473
- "flw f4, 16(a1) \n\t"
474
- "flw f5, 20(a1) \n\t"
475
- "flw f6, 24(a1) \n\t"
476
- "flw f7, 28(a1) \n\t"
477
- "addi a1, a1, 32 \n\t"
478
- "fmax.s f1, f0, f1 \n\t"
479
- "fmax.s f3, f2, f3 \n\t"
480
- "fmax.s f5, f4, f5 \n\t"
481
- "fmax.s f7, f6, f7 \n\t"
482
- "fmax.s f3, f1, f3 \n\t"
483
- "fmax.s f7, f5, f7 \n\t"
484
- "fmax.s f11, f3, f7 \n\t"
485
- "fmul.s f11, f11, %[RMAXREC] \n\t"
486
- "fsw f11, (%[DST]) \n\t"
487
- "addi %[DST], %[DST], 20 \n\t"
488
- "fdiv.s f11, %[FONE], f11 \n\t"
489
- "flw f0, (a1) \n\t"
490
- "flw f1, 4(a1) \n\t"
491
- "flw f2, 8(a1) \n\t"
492
- "flw f3, 12(a1) \n\t"
493
- "flw f4, 16(a1) \n\t"
494
- "flw f5, 20(a1) \n\t"
495
- "flw f6, 24(a1) \n\t"
496
- "flw f7, 28(a1) \n\t"
497
- "addi a1, a1, 32 \n\t"
498
- "fmax.s f1, f0, f1 \n\t"
499
- "fmax.s f3, f2, f3 \n\t"
500
- "fmax.s f5, f4, f5 \n\t"
501
- "fmax.s f7, f6, f7 \n\t"
502
- "fmax.s f3, f1, f3 \n\t"
503
- "fmax.s f7, f5, f7 \n\t"
504
- "fmax.s f12, f3, f7 \n\t"
505
- "fmul.s f12, f12, %[RMAXREC] \n\t"
506
- "fsw f12, (%[DST]) \n\t"
507
- "addi %[DST], %[DST], 20 \n\t"
508
- "fdiv.s f12, %[FONE], f12 \n\t"
509
- "flw f0, (a1) \n\t"
510
- "flw f1, 4(a1) \n\t"
511
- "flw f2, 8(a1) \n\t"
512
- "flw f3, 12(a1) \n\t"
513
- "flw f4, 16(a1) \n\t"
514
- "flw f5, 20(a1) \n\t"
515
- "flw f6, 24(a1) \n\t"
516
- "flw f7, 28(a1) \n\t"
517
- "addi a1, a1, 32 \n\t"
518
- "fmax.s f1, f0, f1 \n\t"
519
- "fmax.s f3, f2, f3 \n\t"
520
- "fmax.s f5, f4, f5 \n\t"
521
- "fmax.s f7, f6, f7 \n\t"
522
- "fmax.s f3, f1, f3 \n\t"
523
- "fmax.s f7, f5, f7 \n\t"
524
- "fmax.s f13, f3, f7 \n\t"
525
- "fmul.s f13, f13, %[RMAXREC] \n\t"
526
- "fsw f13, (%[DST]) \n\t"
527
- "addi %[DST], %[DST], 20 \n\t"
528
- "fdiv.s f13, %[FONE], f13 \n\t"
529
- "flw f0, (a1) \n\t"
530
- "flw f1, 4(a1) \n\t"
531
- "flw f2, 8(a1) \n\t"
532
- "flw f3, 12(a1) \n\t"
533
- "flw f4, 16(a1) \n\t"
534
- "flw f5, 20(a1) \n\t"
535
- "flw f6, 24(a1) \n\t"
536
- "flw f7, 28(a1) \n\t"
537
- "addi a1, a1, 32 \n\t"
538
- "fmax.s f1, f0, f1 \n\t"
539
- "fmax.s f3, f2, f3 \n\t"
540
- "fmax.s f5, f4, f5 \n\t"
541
- "fmax.s f7, f6, f7 \n\t"
542
- "fmax.s f3, f1, f3 \n\t"
543
- "fmax.s f7, f5, f7 \n\t"
544
- "fmax.s f14, f3, f7 \n\t"
545
- "fmul.s f14, f14, %[RMAXREC] \n\t"
546
- "fsw f14, (%[DST]) \n\t"
547
- "addi %[DST], %[DST], 20 \n\t"
548
- "fdiv.s f14, %[FONE], f14 \n\t"
549
- "flw f0, (a1) \n\t"
550
- "flw f1, 4(a1) \n\t"
551
- "flw f2, 8(a1) \n\t"
552
- "flw f3, 12(a1) \n\t"
553
- "flw f4, 16(a1) \n\t"
554
- "flw f5, 20(a1) \n\t"
555
- "flw f6, 24(a1) \n\t"
556
- "flw f7, 28(a1) \n\t"
557
- "addi a1, a1, 32 \n\t"
558
- "fmax.s f1, f0, f1 \n\t"
559
- "fmax.s f3, f2, f3 \n\t"
560
- "fmax.s f5, f4, f5 \n\t"
561
- "fmax.s f7, f6, f7 \n\t"
562
- "fmax.s f3, f1, f3 \n\t"
563
- "fmax.s f7, f5, f7 \n\t"
564
- "fmax.s f15, f3, f7 \n\t"
565
- "fmul.s f15, f15, %[RMAXREC] \n\t"
566
- "fsw f15, (%[DST]) \n\t"
567
- "addi %[DST], %[DST], 20 \n\t"
568
- "fdiv.s f15, %[FONE], f15 \n\t"
569
- "flw f0, (a1) \n\t"
570
- "flw f1, 4(a1) \n\t"
571
- "flw f2, 8(a1) \n\t"
572
- "flw f3, 12(a1) \n\t"
573
- "flw f4, 16(a1) \n\t"
574
- "flw f5, 20(a1) \n\t"
575
- "flw f6, 24(a1) \n\t"
576
- "flw f7, 28(a1) \n\t"
577
- "addi a1, a1, 32 \n\t"
578
- "fmax.s f1, f0, f1 \n\t"
579
- "fmax.s f3, f2, f3 \n\t"
580
- "fmax.s f5, f4, f5 \n\t"
581
- "fmax.s f7, f6, f7 \n\t"
582
- "fmax.s f3, f1, f3 \n\t"
583
- "fmax.s f7, f5, f7 \n\t"
584
- "fmax.s f16, f3, f7 \n\t"
585
- "fmul.s f16, f16, %[RMAXREC] \n\t"
586
- "fsw f16, (%[DST]) \n\t"
587
- "addi %[DST], %[DST], 20 \n\t"
588
- "fdiv.s f16, %[FONE], f16 \n\t"
589
- "flw f0, (a1) \n\t"
590
- "flw f1, 4(a1) \n\t"
591
- "flw f2, 8(a1) \n\t"
592
- "flw f3, 12(a1) \n\t"
593
- "flw f4, 16(a1) \n\t"
594
- "flw f5, 20(a1) \n\t"
595
- "flw f6, 24(a1) \n\t"
596
- "flw f7, 28(a1) \n\t"
597
- "addi a1, a1, 32 \n\t"
598
- "fmax.s f1, f0, f1 \n\t"
599
- "fmax.s f3, f2, f3 \n\t"
600
- "fmax.s f5, f4, f5 \n\t"
601
- "fmax.s f7, f6, f7 \n\t"
602
- "fmax.s f3, f1, f3 \n\t"
603
- "fmax.s f7, f5, f7 \n\t"
604
- "fmax.s f17, f3, f7 \n\t"
605
- "fmul.s f17, f17, %[RMAXREC] \n\t"
606
- "fsw f17, (%[DST]) \n\t"
607
- "addi %[DST], %[DST], -136 \n\t"
608
- "fdiv.s f17, %[FONE], f17 \n\t"
609
- "vsetvli t0, zero, e32, m2 \n\t"
610
- "vfmul.vf v16, v0, f10 \n\t"
611
- "vfmul.vf v18, v2, f11 \n\t"
612
- "vfmul.vf v20, v4, f12 \n\t"
613
- "vfmul.vf v22, v6, f13 \n\t"
614
- "vfmul.vf v24, v8, f14 \n\t"
615
- "vfmul.vf v26, v10, f15 \n\t"
616
- "vfmul.vf v28, v12, f16 \n\t"
617
- "vfmul.vf v30, v14, f17 \n\t"
618
- "vfcvt.x.f.v v16, v16 \n\t"
619
- "vfcvt.x.f.v v18, v18 \n\t"
620
- "vfcvt.x.f.v v20, v20 \n\t"
621
- "vfcvt.x.f.v v22, v22 \n\t"
622
- "vfcvt.x.f.v v24, v24 \n\t"
623
- "vfcvt.x.f.v v26, v26 \n\t"
624
- "vfcvt.x.f.v v28, v28 \n\t"
625
- "vfcvt.x.f.v v30, v30 \n\t"
626
- "vsetvli t0, zero, e16, m1 \n\t"
627
- "vnclip.wx v16, v16, zero \n\t"
628
- "vnclip.wx v18, v18, zero \n\t"
629
- "vnclip.wx v20, v20, zero \n\t"
630
- "vnclip.wx v22, v22, zero \n\t"
631
- "vnclip.wx v24, v24, zero \n\t"
632
- "vnclip.wx v26, v26, zero \n\t"
633
- "vnclip.wx v28, v28, zero \n\t"
634
- "vnclip.wx v30, v30, zero \n\t"
635
- "vsetvli t0, t1, e8, mf2 \n\t"
636
- "vnclip.wx v16, v16, zero \n\t"
637
- "vnclip.wx v18, v18, zero \n\t"
638
- "vnclip.wx v20, v20, zero \n\t"
639
- "vnclip.wx v22, v22, zero \n\t"
640
- "vnclip.wx v24, v24, zero \n\t"
641
- "vnclip.wx v26, v26, zero \n\t"
642
- "vnclip.wx v28, v28, zero \n\t"
643
- "vnclip.wx v30, v30, zero \n\t"
644
- "vse8.v v16, (%[DST]) \n\t"
645
- "addi %[DST], %[DST], 20 \n\t"
646
- "vse8.v v18, (%[DST]) \n\t"
647
- "addi %[DST], %[DST], 20 \n\t"
648
- "vse8.v v20, (%[DST]) \n\t"
649
- "addi %[DST], %[DST], 20 \n\t"
650
- "vse8.v v22, (%[DST]) \n\t"
651
- "addi %[DST], %[DST], 20 \n\t"
652
- "vse8.v v24, (%[DST]) \n\t"
653
- "addi %[DST], %[DST], 20 \n\t"
654
- "vse8.v v26, (%[DST]) \n\t"
655
- "addi %[DST], %[DST], 20 \n\t"
656
- "vse8.v v28, (%[DST]) \n\t"
657
- "addi %[DST], %[DST], 20 \n\t"
658
- "vse8.v v30, (%[DST]) \n\t"
659
- "addi %[DST], %[DST], 16 \n\t"
660
- "bge %[K], t3, LOOP_MAIN%= \n\t"
661
- "blt %[K], t2, TAIL%= \n\t"
662
- "LOOP_K%=: \n\t"
663
- "vsetvli t1, %[K], e32, m2 \n\t"
664
- "vle32.v v0, (%[SRC]) \n\t"
665
- "addi %[SRC], %[SRC], 64 \n\t"
666
- "sub %[K], %[K], t1 \n\t"
667
- "vfabs.v v16, v0 \n\t"
668
- "vsetvli t0, zero, e32, m1 \n\t"
669
- "vfmax.vv v16, v16, v17 \n\t"
670
- "vse32.v v16, (%[BUFFER]) \n\t"
671
- "flw f0, (%[BUFFER]) \n\t"
672
- "flw f1, 4(%[BUFFER]) \n\t"
673
- "flw f2, 8(%[BUFFER]) \n\t"
674
- "flw f3, 12(%[BUFFER]) \n\t"
675
- "flw f4, 16(%[BUFFER]) \n\t"
676
- "flw f5, 20(%[BUFFER]) \n\t"
677
- "flw f6, 24(%[BUFFER]) \n\t"
678
- "flw f7, 28(%[BUFFER]) \n\t"
679
- "fmax.s f1, f0, f1 \n\t"
680
- "fmax.s f3, f2, f3 \n\t"
681
- "fmax.s f5, f4, f5 \n\t"
682
- "fmax.s f7, f6, f7 \n\t"
683
- "fmax.s f3, f1, f3 \n\t"
684
- "fmax.s f7, f5, f7 \n\t"
685
- "fmax.s f10, f3, f7 \n\t"
686
- "fmul.s f10, f10, %[RMAXREC] \n\t"
687
- "fsw f10, (%[DST]) \n\t"
688
- "addi %[DST], %[DST], 4 \n\t"
689
- "fdiv.s f11, %[FONE], f10 \n\t"
690
- "vsetvli t0, zero, e32, m2 \n\t"
691
- "vfmul.vf v16, v0, f11 \n\t"
692
- "vfcvt.x.f.v v16, v16 \n\t"
693
- "vsetvli t0, zero, e16, m1 \n\t"
694
- "vnclip.wx v16, v16, zero \n\t"
695
- "vsetvli t0, t1, e8, mf2 \n\t"
696
- "vnclip.wx v16, v16, zero \n\t"
697
- "vse8.v v16, (%[DST]) \n\t"
698
- "addi %[DST], %[DST], 16 \n\t"
699
- "bge %[K], t2, LOOP_K%= \n\t"
700
- "TAIL%=: \n\t"
701
- "blez %[K], END%= \n\t"
702
- "vsetvli t0, t3, e32, m2 \n\t"
703
- "vxor.vv v16, v16, v16 \n\t"
704
- "jal x0, LOOP_K%= \n\t"
705
- "END%=: \n\t"
706
- : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
707
- : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer)
708
- : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12",
709
- "f13", "f14", "f15", "f16", "f17");
710
- } else if (BlkLen == 32) {
711
- __asm__ volatile(
712
- "addi t3, zero, 32*4 \n\t"
713
- "addi t2, zero, 32 \n\t"
714
-
715
- "addi a1, %[SRC], 0 \n\t"
716
- "addi a2, %[SRC], 128 \n\t"
717
- "addi a3, %[SRC], 256 \n\t"
718
- "addi a4, %[SRC], 384 \n\t"
719
-
720
- "addi s1, %[DST], 0 \n\t"
721
- "addi s2, %[DST], 36 \n\t"
722
- "addi s3, %[DST], 72 \n\t"
723
- "addi s4, %[DST], 108 \n\t"
724
- "blt %[K], t3, LOOP_K%= \n\t"
725
- "blt %[K], t2, TAIL%= \n\t"
726
-
727
- "LOOP_MAIN%=: \n\t"
728
- "vsetvli t1, zero, e32, m4 \n\t"
729
- "addi %[K], %[K], -128 \n\t"
730
- "vle32.v v0, (a1) \n\t"
731
- "addi a1, a1, 512 \n\t"
732
- "vle32.v v4, (a2) \n\t"
733
- "addi a2, a2, 512 \n\t"
734
- "vle32.v v8, (a3) \n\t"
735
- "addi a3, a3, 512 \n\t"
736
- "vle32.v v12, (a4) \n\t"
737
- "addi a4, a4, 512 \n\t"
738
- "vfabs.v v16, v0 \n\t"
739
- "vfabs.v v20, v4 \n\t"
740
- "vfabs.v v24, v8 \n\t"
741
- "vfabs.v v28, v12 \n\t"
742
- "vsetvli t0, zero, e32, m2 \n\t"
743
- "vfmax.vv v16, v16, v18 \n\t"
744
- "vfmax.vv v20, v20, v22 \n\t"
745
- "vfmax.vv v24, v24, v26 \n\t"
746
- "vfmax.vv v28, v28, v30 \n\t"
747
- "vsetvli t0, zero, e32, m1 \n\t"
748
- "vfmax.vv v16, v16, v17 \n\t"
749
- "vfmax.vv v20, v20, v21 \n\t"
750
- "vfmax.vv v24, v24, v25 \n\t"
751
- "vfmax.vv v28, v28, v29 \n\t"
752
-
753
- "vfredmax.vs v17, v16, v17 \n\t"
754
- "vfredmax.vs v21, v20, v21 \n\t"
755
- "vfredmax.vs v25, v24, v25 \n\t"
756
- "vfredmax.vs v29, v28, v29 \n\t"
757
- "vfmv.f.s f10, v17 \n\t"
758
- "vfmv.f.s f11, v21 \n\t"
759
- "vfmv.f.s f12, v25 \n\t"
760
- "vfmv.f.s f13, v29 \n\t"
761
-
762
- "fmul.s f10, f10, %[RMAXREC] \n\t"
763
- "fmul.s f11, f11, %[RMAXREC] \n\t"
764
- "fmul.s f12, f12, %[RMAXREC] \n\t"
765
- "fmul.s f13, f13, %[RMAXREC] \n\t"
766
- "fsw f10, (s1) \n\t"
767
- "addi s1, s1, 4 \n\t"
768
-
769
- "fsw f11, (s2) \n\t"
770
- "addi s2, s2, 4 \n\t"
771
- "fsw f12, (s3) \n\t"
772
- "addi s3, s3, 4 \n\t"
773
- "fsw f13, (s4) \n\t"
774
- "addi s4, s4, 4 \n\t"
775
- "fdiv.s f10, %[FONE], f10 \n\t"
776
- "fdiv.s f11, %[FONE], f11 \n\t"
777
- "fdiv.s f12, %[FONE], f12 \n\t"
778
- "fdiv.s f13, %[FONE], f13 \n\t"
779
- "vsetvli t0, zero, e32, m4 \n\t"
780
- "vfmul.vf v16, v0, f10 \n\t"
781
- "vfmul.vf v20, v4, f11 \n\t"
782
- "vfmul.vf v24, v8, f12 \n\t"
783
- "vfmul.vf v28, v12, f13 \n\t"
784
- "vfcvt.x.f.v v16, v16 \n\t"
785
- "vfcvt.x.f.v v20, v20 \n\t"
786
- "vfcvt.x.f.v v24, v24 \n\t"
787
- "vfcvt.x.f.v v28, v28 \n\t"
788
- "vsetvli t0, zero, e16, m2 \n\t"
789
- "vnclip.wx v16, v16, zero \n\t"
790
- "vnclip.wx v20, v20, zero \n\t"
791
- "vnclip.wx v24, v24, zero \n\t"
792
- "vnclip.wx v28, v28, zero \n\t"
793
- "vsetvli t0, t1, e8, m1 \n\t"
794
- "vnclip.wx v16, v16, zero \n\t"
795
- "vnclip.wx v20, v20, zero \n\t"
796
- "vnclip.wx v24, v24, zero \n\t"
797
- "vnclip.wx v28, v28, zero \n\t"
798
- "vse8.v v16, (s1) \n\t"
799
- "addi s1, s1, 140 \n\t"
800
- "vse8.v v20, (s2) \n\t"
801
- "addi s2, s2, 140 \n\t"
802
- "vse8.v v24, (s3) \n\t"
803
- "addi s3, s3, 140 \n\t"
804
- "vse8.v v28, (s4) \n\t"
805
- "addi s4, s4, 140 \n\t"
806
- "bge %[K], t3, LOOP_MAIN%= \n\t"
807
- "blt %[K], t2, TAIL%= \n\t"
808
- "LOOP_K%=: \n\t"
809
- "vsetvli t1, %[K], e32, m4 \n\t"
810
- "vle32.v v0, (a1) \n\t"
811
- "addi a1, a1, 128 \n\t"
812
- "sub %[K], %[K], t1 \n\t"
813
- "vfabs.v v16, v0 \n\t"
814
- "vsetvli t0, zero, e32, m2 \n\t"
815
- "vfmax.vv v16, v16, v18 \n\t"
816
- "vsetvli t0, zero, e32, m1 \n\t"
817
- "vfmax.vv v16, v16, v17 \n\t"
818
- "vfredmax.vs v17, v16, v17 \n\t"
819
- "vfmv.f.s f10, v17 \n\t"
820
-
821
- "fmul.s f10, f10, %[RMAXREC] \n\t"
822
- "fsw f10, (s1) \n\t"
823
- "addi s1, s1, 4 \n\t"
824
- "fdiv.s f11, %[FONE], f10 \n\t"
825
- "vsetvli t0, zero, e32, m4 \n\t"
826
- "vfmul.vf v16, v0, f11 \n\t"
827
- "vfcvt.x.f.v v16, v16 \n\t"
828
- "vsetvli t0, zero, e16, m2 \n\t"
829
- "vnclip.wx v16, v16, zero \n\t"
830
- "vsetvli t0, zero, e8, m1 \n\t"
831
- "vnclip.wx v16, v16, zero \n\t"
832
- "vse8.v v16, (s1) \n\t"
833
- "addi s1, s1, 32 \n\t"
834
- "bge %[K], t2, LOOP_K%= \n\t"
835
- "TAIL%=: \n\t"
836
- "blez %[K], END%= \n\t"
837
- "vsetvli t0, t3, e32, m4 \n\t"
838
- "vxor.vv v0, v0, v0 \n\t"
839
- "vxor.vv v16, v16, v16 \n\t"
840
- "jal x0, LOOP_K%= \n\t"
841
- "END%=: \n\t"
842
- : [K] "+r"(CountK)
843
- : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
844
- : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
845
- } else if (BlkLen == 64) {
846
- __asm__ volatile(
847
- "addi t3, zero, 64*2 \n\t"
848
- "addi t2, zero, 64 \n\t"
849
- "addi a1, %[SRC], 0 \n\t"
850
- "addi a2, %[SRC], 256 \n\t"
851
- "addi s1, %[DST], 0 \n\t"
852
- "addi s2, %[DST], 68 \n\t"
853
- "blt %[K], t3, LOOP_K%= \n\t"
854
- "blt %[K], t2, TAIL%= \n\t"
855
- "LOOP_MAIN%=: \n\t"
856
- "vsetvli t1, zero, e32, m8 \n\t"
857
- "addi %[K], %[K], -128 \n\t"
858
- "vle32.v v0, (a1) \n\t"
859
- "addi a1, a1, 512 \n\t"
860
- "vle32.v v8, (a2) \n\t"
861
- "addi a2, a2, 512 \n\t"
862
- "vfabs.v v16, v0 \n\t"
863
- "vfabs.v v24, v8 \n\t"
864
- "vsetvli t0, zero, e32, m4 \n\t"
865
- "vfmax.vv v16, v16, v20 \n\t"
866
- "vfmax.vv v24, v24, v28 \n\t"
867
- "vsetvli t0, zero, e32, m2 \n\t"
868
- "vfmax.vv v16, v16, v18 \n\t"
869
- "vfmax.vv v24, v24, v26 \n\t"
870
- "vsetvli t0, zero, e32, m1 \n\t"
871
- "vfmax.vv v16, v16, v17 \n\t"
872
- "vfmax.vv v24, v24, v25 \n\t"
873
- "vfredmax.vs v17, v16, v17 \n\t"
874
- "vfredmax.vs v25, v24, v25 \n\t"
875
- "vfmv.f.s f10, v17 \n\t"
876
- "vfmv.f.s f11, v25 \n\t"
877
- "fmul.s f10, f10, %[RMAXREC] \n\t"
878
- "fmul.s f11, f11, %[RMAXREC] \n\t"
879
- "fsw f10, (s1) \n\t"
880
- "addi s1, s1, 4 \n\t"
881
- "fsw f11, (s2) \n\t"
882
- "addi s2, s2, 4 \n\t"
883
- "fdiv.s f10, %[FONE], f10 \n\t"
884
- "fdiv.s f11, %[FONE], f11 \n\t"
885
- "vsetvli t0, zero, e32, m8 \n\t"
886
- "vfmul.vf v16, v0, f10 \n\t"
887
- "vfmul.vf v24, v8, f11 \n\t"
888
- "vfcvt.x.f.v v16, v16 \n\t"
889
- "vfcvt.x.f.v v24, v24 \n\t"
890
- "vsetvli t0, zero, e16, m4 \n\t"
891
- "vnclip.wx v16, v16, zero \n\t"
892
- "vnclip.wx v24, v24, zero \n\t"
893
- "vsetvli t0, t1, e8, m2 \n\t"
894
- "vnclip.wx v16, v16, zero \n\t"
895
- "vnclip.wx v24, v24, zero \n\t"
896
- "vse8.v v16, (s1) \n\t"
897
- "addi s1, s1, 132 \n\t"
898
- "vse8.v v24, (s2) \n\t"
899
- "addi s2, s2, 132 \n\t"
900
- "bge %[K], t3, LOOP_MAIN%= \n\t"
901
- "blt %[K], t2, TAIL%= \n\t"
902
- "LOOP_K%=: \n\t"
903
- "vsetvli t1, %[K], e32, m8 \n\t"
904
- "vle32.v v0, (a1) \n\t"
905
- "addi a1, a1, 256 \n\t"
906
- "sub %[K], %[K], t1 \n\t"
907
- "vfabs.v v16, v0 \n\t"
908
- "vsetvli t0, zero, e32, m4 \n\t"
909
- "vfmax.vv v16, v16, v20 \n\t"
910
- "vsetvli t0, zero, e32, m2 \n\t"
911
- "vfmax.vv v16, v16, v18 \n\t"
912
- "vsetvli t0, zero, e32, m1 \n\t"
913
- "vfmax.vv v16, v16, v17 \n\t"
914
- "vfredmax.vs v17, v16, v17 \n\t"
915
- "vfmv.f.s f10, v17 \n\t"
916
- "fmul.s f10, f10, %[RMAXREC] \n\t"
917
- "fsw f10, (s1) \n\t"
918
- "addi s1, s1, 4 \n\t"
919
- "fdiv.s f11, %[FONE], f10 \n\t"
920
- "vsetvli t0, zero, e32, m8 \n\t"
921
- "vfmul.vf v16, v0, f11 \n\t"
922
- "vfcvt.x.f.v v16, v16 \n\t"
923
- "vsetvli t0, zero, e16, m4 \n\t"
924
- "vnclip.wx v16, v16, zero \n\t"
925
- "vsetvli t0, zero, e8, m2 \n\t"
926
- "vnclip.wx v16, v16, zero \n\t"
927
- "vse8.v v16, (s1) \n\t"
928
- "addi s1, s1, 64 \n\t"
929
- "bge %[K], t2, LOOP_K%= \n\t"
930
- "TAIL%=: \n\t"
931
- "blez %[K], END%= \n\t"
932
- "vsetvli t0, t3, e32, m8 \n\t"
933
- "vxor.vv v0, v0, v0 \n\t"
934
- "vxor.vv v16, v16, v16 \n\t"
935
- "jal x0, LOOP_K%= \n\t"
936
- "END%=: \n\t"
937
- : [K] "+r"(CountK)
938
- : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
939
- : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11");
940
- } else if (BlkLen == 128) {
941
- __asm__ volatile(
942
- "addi t2, zero, 128 \n\t"
943
- "addi a1, %[SRC], 0 \n\t"
944
- "addi a2, %[SRC], 256 \n\t"
945
- "blt %[K], t2, TAIL%= \n\t"
946
- "LOOP_K%=: \n\t"
947
- "vsetvli t1, zero, e32, m8 \n\t"
948
- "vle32.v v0, (a1) \n\t"
949
- "addi a1, a1, 512 \n\t"
950
- "vle32.v v8, (a2) \n\t"
951
- "addi a2, a2, 512 \n\t"
952
- "sub %[K], %[K], t2 \n\t"
953
- "QUANT%=: \n\t"
954
- "vfabs.v v16, v0 \n\t"
955
- "vfabs.v v24, v8 \n\t"
956
- "vfmax.vv v24, v16, v24 \n\t"
957
- "vsetvli t1, zero, e32, m4 \n\t"
958
- "vfmax.vv v28, v24, v28 \n\t"
959
- "vsetvli t0, zero, e32, m2 \n\t"
960
- "vfmax.vv v30, v28, v30 \n\t"
961
- "vsetvli t0, zero, e32, m1 \n\t"
962
- "vfmax.vv v30, v30, v31 \n\t"
963
- "vfredmax.vs v31, v30, v31 \n\t"
964
- "vfmv.f.s f10, v31 \n\t"
965
- "fmul.s f10, f10, %[RMAXREC] \n\t"
966
- "fsw f10, (%[DST]) \n\t"
967
- "addi %[DST], %[DST], 4 \n\t"
968
- "fdiv.s f11, %[FONE], f10 \n\t"
969
- "vsetvli t0, zero, e32, m8 \n\t"
970
- "vfmul.vf v16, v0, f11 \n\t"
971
- "vfmul.vf v24, v8, f11 \n\t"
972
- "vfcvt.x.f.v v16, v16 \n\t"
973
- "vfcvt.x.f.v v24, v24 \n\t"
974
- "vsetvli t0, zero, e16, m4 \n\t"
975
- "vnclip.wx v16, v16, zero \n\t"
976
- "vnclip.wx v20, v24, zero \n\t"
977
- "vsetvli t0, zero, e8, m4 \n\t"
978
- "vnclip.wx v16, v16, zero \n\t"
979
- "vse8.v v16, (%[DST]) \n\t"
980
- "addi %[DST], %[DST], 128 \n\t"
981
- "bge %[K], t2, LOOP_K%= \n\t"
982
- "TAIL%=: \n\t"
983
- "blez %[K], END%= \n\t"
984
- "vsetvli t1, zero, e32, m8 \n\t"
985
- "vxor.vv v0, v0, v0 \n\t"
986
- "vxor.vv v8, v8, v8 \n\t"
987
- "vsetvli t0, %[K], e32, m8 \n\t"
988
- "vle32.v v0, (a1) \n\t"
989
- "sub %[K], %[K], t0 \n\t"
990
- "vsetvli t0, %[K], e32, m8 \n\t"
991
- "vle32.v v8, (a2) \n\t"
992
- "sub %[K], %[K], t0 \n\t"
993
- "vsetvli t1, zero, e32, m8 \n\t"
994
- "jal x0, QUANT%= \n\t"
995
- "END%=: \n\t"
996
-
997
- : [DST] "+r"(DST), [K] "+r"(CountK)
998
- : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC)
999
- : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11");
1000
- } else {
1001
- float buffer[8] = { 0.0f };
1002
- size_t cnt = BlkLen / 256;
1003
-
1004
- __asm__ volatile(
1005
- "slli t3, %[BLK], 2 \n\t"
1006
- "blt %[K], %[BLK], LOOP_TAIL%= \n\t"
1007
- "LOOP_MAIN%=: \n\t"
1008
- "vsetvli t0, zero, e32, m1 \n\t"
1009
- "vxor.vv v31, v31, v31 \n\t"
1010
- "vse32.v v31, (%[BUFFER]) \n\t"
1011
- "addi t6, %[CNT], 0 \n\t"
1012
- "LOOP_CMP%=: \n\t"
1013
- "addi t6, t6, -1 \n\t"
1014
- "vsetvli t0, zero, e32, m8 \n\t"
1015
- "vle32.v v0, (%[SRC]) \n\t"
1016
- "addi %[SRC], %[SRC], 256 \n\t"
1017
- "vle32.v v8, (%[SRC]) \n\t"
1018
- "addi %[SRC], %[SRC], 256 \n\t"
1019
- "vle32.v v16, (%[SRC]) \n\t"
1020
- "addi %[SRC], %[SRC], 256 \n\t"
1021
- "vle32.v v24, (%[SRC]) \n\t"
1022
- "addi %[SRC], %[SRC], 256 \n\t"
1023
- "vfabs.v v0, v0 \n\t"
1024
- "vfabs.v v8, v8 \n\t"
1025
- "vfabs.v v16, v16 \n\t"
1026
- "vfabs.v v24, v24 \n\t"
1027
- "vfmax.vv v8, v0, v8 \n\t"
1028
- "vfmax.vv v16, v16, v24 \n\t"
1029
- "vfmax.vv v0, v0, v16 \n\t"
1030
- "vsetvli t0, zero, e32, m4 \n\t"
1031
- "vfmax.vv v0, v0, v4 \n\t"
1032
- "vsetvli t0, zero, e32, m2 \n\t"
1033
- "vfmax.vv v0, v0, v2 \n\t"
1034
- "vsetvli t0, zero, e32, m1 \n\t"
1035
- "vfmax.vv v0, v0, v1 \n\t"
1036
- "vle32.v v30, (%[BUFFER]) \n\t"
1037
- "vfmax.vv v31, v30, v0 \n\t"
1038
- "vse32.v v31, (%[BUFFER]) \n\t"
1039
- "bnez t6, LOOP_CMP%= \n\t"
1040
- "sub %[SRC], %[SRC], t3 \n\t"
1041
- "addi t6, %[CNT], 0 \n\t"
1042
- "flw f0, (%[BUFFER]) \n\t"
1043
- "flw f1, 4(%[BUFFER]) \n\t"
1044
- "flw f2, 8(%[BUFFER]) \n\t"
1045
- "flw f3, 12(%[BUFFER]) \n\t"
1046
- "flw f4, 16(%[BUFFER]) \n\t"
1047
- "flw f5, 20(%[BUFFER]) \n\t"
1048
- "flw f6, 24(%[BUFFER]) \n\t"
1049
- "flw f7, 28(%[BUFFER]) \n\t"
1050
- "fmax.s f1, f0, f1 \n\t"
1051
- "fmax.s f3, f2, f3 \n\t"
1052
- "fmax.s f5, f4, f5 \n\t"
1053
- "fmax.s f7, f6, f7 \n\t"
1054
- "fmax.s f3, f1, f3 \n\t"
1055
- "fmax.s f7, f5, f7 \n\t"
1056
- "fmax.s f10, f3, f7 \n\t"
1057
- "fmul.s f10, f10, %[RMAXREC] \n\t"
1058
- "fsw f10, (%[DST]) \n\t"
1059
- "addi %[DST], %[DST], 4 \n\t"
1060
- "fdiv.s f11, %[FONE], f10 \n\t"
1061
- "addi t6, %[CNT], 0 \n\t"
1062
- "LOOP_QUANT%=: \n\t"
1063
- "addi t6, t6, -1 \n\t"
1064
- "vsetvli t0, zero, e32, m8 \n\t"
1065
- "vle32.v v0, (%[SRC]) \n\t"
1066
- "addi %[SRC], %[SRC], 256 \n\t"
1067
- "vle32.v v8, (%[SRC]) \n\t"
1068
- "addi %[SRC], %[SRC], 256 \n\t"
1069
- "vle32.v v16, (%[SRC]) \n\t"
1070
- "addi %[SRC], %[SRC], 256 \n\t"
1071
- "vle32.v v24, (%[SRC]) \n\t"
1072
- "addi %[SRC], %[SRC], 256 \n\t"
1073
- "vsetvli t0, zero, e32, m8 \n\t"
1074
- "vfmul.vf v0, v0, f11 \n\t"
1075
- "vfmul.vf v8, v8, f11 \n\t"
1076
- "vfmul.vf v16, v16, f11 \n\t"
1077
- "vfmul.vf v24, v24, f11 \n\t"
1078
- "vfcvt.x.f.v v0, v0 \n\t"
1079
- "vfcvt.x.f.v v8, v8 \n\t"
1080
- "vfcvt.x.f.v v16, v16 \n\t"
1081
- "vfcvt.x.f.v v24, v24 \n\t"
1082
- "vsetvli t0, zero, e16, m4 \n\t"
1083
- "vnclip.wx v0, v0, zero \n\t"
1084
- "vnclip.wx v4, v8, zero \n\t"
1085
- "vnclip.wx v8, v16, zero \n\t"
1086
- "vnclip.wx v12, v24, zero \n\t"
1087
- "vsetvli t0, zero, e8, m4 \n\t"
1088
- "vnclip.wx v0, v0, zero \n\t"
1089
- "vnclip.wx v4, v8, zero \n\t"
1090
- "vse8.v v0, (%[DST]) \n\t"
1091
- "addi %[DST], %[DST], 128 \n\t"
1092
- "vse8.v v4, (%[DST]) \n\t"
1093
- "addi %[DST], %[DST], 128 \n\t"
1094
- "bnez t6, LOOP_QUANT%= \n\t"
1095
- "sub %[K], %[K], %[BLK] \n\t"
1096
- "bge %[K], %[BLK], LOOP_MAIN%= \n\t"
1097
- "blez %[K], END%= \n\t"
1098
- "LOOP_TAIL%=: \n\t"
1099
- "vsetvli t0, zero, e32, m1 \n\t"
1100
- "vxor.vv v31, v31, v31 \n\t"
1101
- "vse32.v v31, (%[BUFFER]) \n\t"
1102
- "addi t6, %[K], 0 \n\t"
1103
- "addi s1, %[SRC], 0 \n\t"
1104
- "TAIL_CMP%=: \n\t"
1105
- "vsetvli t0, zero, e32, m8 \n\t"
1106
- "vxor.vv v0, v0, v0 \n\t"
1107
- "vsetvli t0, t6, e32, m8 \n\t"
1108
- "vle32.v v0, (%[SRC]) \n\t"
1109
- "addi %[SRC], %[SRC], 256 \n\t"
1110
- "sub t6, t6, t0 \n\t"
1111
- "vfabs.v v0, v0 \n\t"
1112
- "vsetvli t0, zero, e32, m4 \n\t"
1113
- "vfmax.vv v0, v0, v4 \n\t"
1114
- "vsetvli t0, zero, e32, m2 \n\t"
1115
- "vfmax.vv v0, v0, v2 \n\t"
1116
- "vsetvli t0, zero, e32, m1 \n\t"
1117
- "vfmax.vv v0, v0, v1 \n\t"
1118
- "vle32.v v30, (%[BUFFER]) \n\t"
1119
- "vfmax.vv v31, v30, v0 \n\t"
1120
- "vse32.v v31, (%[BUFFER]) \n\t"
1121
- "bnez t6, TAIL_CMP%= \n\t"
1122
- "addi t6, %[K], 0 \n\t"
1123
- "flw f0, (%[BUFFER]) \n\t"
1124
- "flw f1, 4(%[BUFFER]) \n\t"
1125
- "flw f2, 8(%[BUFFER]) \n\t"
1126
- "flw f3, 12(%[BUFFER]) \n\t"
1127
- "flw f4, 16(%[BUFFER]) \n\t"
1128
- "flw f5, 20(%[BUFFER]) \n\t"
1129
- "flw f6, 24(%[BUFFER]) \n\t"
1130
- "flw f7, 28(%[BUFFER]) \n\t"
1131
- "fmax.s f1, f0, f1 \n\t"
1132
- "fmax.s f3, f2, f3 \n\t"
1133
- "fmax.s f5, f4, f5 \n\t"
1134
- "fmax.s f7, f6, f7 \n\t"
1135
- "fmax.s f3, f1, f3 \n\t"
1136
- "fmax.s f7, f5, f7 \n\t"
1137
- "fmax.s f10, f3, f7 \n\t"
1138
- "fmul.s f10, f10, %[RMAXREC] \n\t"
1139
- "fsw f10, (%[DST]) \n\t"
1140
- "addi %[DST], %[DST], 4 \n\t"
1141
- "fdiv.s f11, %[FONE], f10 \n\t"
1142
- "addi t6, %[K], 0 \n\t"
1143
- "TAIL_QUANT%=: \n\t"
1144
- "vsetvli t0, zero, e32, m8 \n\t"
1145
- "vxor.vv v0, v0, v0 \n\t"
1146
- "vsetvli t1, t6, e32, m8 \n\t"
1147
- "vle32.v v0, (s1) \n\t"
1148
- "addi s1, s1, 256 \n\t"
1149
- "sub t6, t6, t1 \n\t"
1150
- "vsetvli t0, zero, e32, m8 \n\t"
1151
- "vfmul.vf v0, v0, f11 \n\t"
1152
- "vfcvt.x.f.v v0, v0 \n\t"
1153
- "vsetvli t0, zero, e16, m4 \n\t"
1154
- "vnclip.wx v0, v0, zero \n\t"
1155
- "vsetvli t0, t1, e8, m2 \n\t"
1156
- "vnclip.wx v0, v0, zero \n\t"
1157
- "vse8.v v0, (%[DST]) \n\t"
1158
- "addi %[DST], %[DST], 64 \n\t"
1159
- "bnez t6, TAIL_QUANT%= \n\t"
1160
- "END%=: \n\t"
1161
- : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
1162
- : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer),
1163
- [CNT] "r"(cnt)
1164
- : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6");
1165
- }
165
+ __asm__ volatile(
166
+ "addi t3, zero, 32*4 \n\t"
167
+ "addi t2, zero, 32 \n\t"
168
+
169
+ "addi a1, %[SRC], 0 \n\t"
170
+ "addi a2, %[SRC], 128 \n\t"
171
+ "addi a3, %[SRC], 256 \n\t"
172
+ "addi a4, %[SRC], 384 \n\t"
173
+
174
+ "addi s1, %[DST], 0 \n\t"
175
+ "addi s2, %[DST], 36 \n\t"
176
+ "addi s3, %[DST], 72 \n\t"
177
+ "addi s4, %[DST], 108 \n\t"
178
+ "blt %[K], t3, LOOP_K%= \n\t"
179
+ "blt %[K], t2, TAIL%= \n\t"
180
+
181
+ "LOOP_MAIN%=: \n\t"
182
+ "vsetvli t1, zero, e32, m4 \n\t"
183
+ "addi %[K], %[K], -128 \n\t"
184
+ "vle32.v v0, (a1) \n\t"
185
+ "addi a1, a1, 512 \n\t"
186
+ "vle32.v v4, (a2) \n\t"
187
+ "addi a2, a2, 512 \n\t"
188
+ "vle32.v v8, (a3) \n\t"
189
+ "addi a3, a3, 512 \n\t"
190
+ "vle32.v v12, (a4) \n\t"
191
+ "addi a4, a4, 512 \n\t"
192
+ "vfabs.v v16, v0 \n\t"
193
+ "vfabs.v v20, v4 \n\t"
194
+ "vfabs.v v24, v8 \n\t"
195
+ "vfabs.v v28, v12 \n\t"
196
+ "vsetvli t0, zero, e32, m2 \n\t"
197
+ "vfmax.vv v16, v16, v18 \n\t"
198
+ "vfmax.vv v20, v20, v22 \n\t"
199
+ "vfmax.vv v24, v24, v26 \n\t"
200
+ "vfmax.vv v28, v28, v30 \n\t"
201
+ "vsetvli t0, zero, e32, m1 \n\t"
202
+ "vfmax.vv v16, v16, v17 \n\t"
203
+ "vfmax.vv v20, v20, v21 \n\t"
204
+ "vfmax.vv v24, v24, v25 \n\t"
205
+ "vfmax.vv v28, v28, v29 \n\t"
206
+
207
+ "vfredmax.vs v17, v16, v17 \n\t"
208
+ "vfredmax.vs v21, v20, v21 \n\t"
209
+ "vfredmax.vs v25, v24, v25 \n\t"
210
+ "vfredmax.vs v29, v28, v29 \n\t"
211
+ "vfmv.f.s f10, v17 \n\t"
212
+ "vfmv.f.s f11, v21 \n\t"
213
+ "vfmv.f.s f12, v25 \n\t"
214
+ "vfmv.f.s f13, v29 \n\t"
215
+
216
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
217
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
218
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
219
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
220
+ "fsw f10, (s1) \n\t"
221
+ "addi s1, s1, 4 \n\t"
222
+
223
+ "fsw f11, (s2) \n\t"
224
+ "addi s2, s2, 4 \n\t"
225
+ "fsw f12, (s3) \n\t"
226
+ "addi s3, s3, 4 \n\t"
227
+ "fsw f13, (s4) \n\t"
228
+ "addi s4, s4, 4 \n\t"
229
+ "fdiv.s f10, %[FONE], f10 \n\t"
230
+ "fdiv.s f11, %[FONE], f11 \n\t"
231
+ "fdiv.s f12, %[FONE], f12 \n\t"
232
+ "fdiv.s f13, %[FONE], f13 \n\t"
233
+ "vsetvli t0, zero, e32, m4 \n\t"
234
+ "vfmul.vf v16, v0, f10 \n\t"
235
+ "vfmul.vf v20, v4, f11 \n\t"
236
+ "vfmul.vf v24, v8, f12 \n\t"
237
+ "vfmul.vf v28, v12, f13 \n\t"
238
+ "vfcvt.x.f.v v16, v16 \n\t"
239
+ "vfcvt.x.f.v v20, v20 \n\t"
240
+ "vfcvt.x.f.v v24, v24 \n\t"
241
+ "vfcvt.x.f.v v28, v28 \n\t"
242
+ "vsetvli t0, zero, e16, m2 \n\t"
243
+ "vnclip.wx v16, v16, zero \n\t"
244
+ "vnclip.wx v20, v20, zero \n\t"
245
+ "vnclip.wx v24, v24, zero \n\t"
246
+ "vnclip.wx v28, v28, zero \n\t"
247
+ "vsetvli t0, t1, e8, m1 \n\t"
248
+ "vnclip.wx v16, v16, zero \n\t"
249
+ "vnclip.wx v20, v20, zero \n\t"
250
+ "vnclip.wx v24, v24, zero \n\t"
251
+ "vnclip.wx v28, v28, zero \n\t"
252
+ "vse8.v v16, (s1) \n\t"
253
+ "addi s1, s1, 140 \n\t"
254
+ "vse8.v v20, (s2) \n\t"
255
+ "addi s2, s2, 140 \n\t"
256
+ "vse8.v v24, (s3) \n\t"
257
+ "addi s3, s3, 140 \n\t"
258
+ "vse8.v v28, (s4) \n\t"
259
+ "addi s4, s4, 140 \n\t"
260
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
261
+ "blt %[K], t2, TAIL%= \n\t"
262
+ "LOOP_K%=: \n\t"
263
+ "vsetvli t1, %[K], e32, m4 \n\t"
264
+ "vle32.v v0, (a1) \n\t"
265
+ "addi a1, a1, 128 \n\t"
266
+ "sub %[K], %[K], t1 \n\t"
267
+ "vfabs.v v16, v0 \n\t"
268
+ "vsetvli t0, zero, e32, m2 \n\t"
269
+ "vfmax.vv v16, v16, v18 \n\t"
270
+ "vsetvli t0, zero, e32, m1 \n\t"
271
+ "vfmax.vv v16, v16, v17 \n\t"
272
+ "vfredmax.vs v17, v16, v17 \n\t"
273
+ "vfmv.f.s f10, v17 \n\t"
274
+
275
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
276
+ "fsw f10, (s1) \n\t"
277
+ "addi s1, s1, 4 \n\t"
278
+ "fdiv.s f11, %[FONE], f10 \n\t"
279
+ "vsetvli t0, zero, e32, m4 \n\t"
280
+ "vfmul.vf v16, v0, f11 \n\t"
281
+ "vfcvt.x.f.v v16, v16 \n\t"
282
+ "vsetvli t0, zero, e16, m2 \n\t"
283
+ "vnclip.wx v16, v16, zero \n\t"
284
+ "vsetvli t0, zero, e8, m1 \n\t"
285
+ "vnclip.wx v16, v16, zero \n\t"
286
+ "vse8.v v16, (s1) \n\t"
287
+ "addi s1, s1, 32 \n\t"
288
+ "bge %[K], t2, LOOP_K%= \n\t"
289
+ "TAIL%=: \n\t"
290
+ "blez %[K], END%= \n\t"
291
+ "vsetvli t0, t3, e32, m4 \n\t"
292
+ "vxor.vv v0, v0, v0 \n\t"
293
+ "vxor.vv v16, v16, v16 \n\t"
294
+ "jal x0, LOOP_K%= \n\t"
295
+ "END%=: \n\t"
296
+ : [K] "+r"(CountK)
297
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
298
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
1166
299
  }
1167
300
 
1168
301
  } // namespace ime1
@@ -1451,1746 +584,444 @@ namespace {
1451
584
  "vadd.vi v1, v1, -12 \n\t"
1452
585
 
1453
586
  template <bool HasZeroPoint>
1454
- void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
1455
- const std::byte * QuantA,
1456
- const std::byte * QuantBData,
1457
- const float * QuantBScale,
1458
- const std::byte * QuantBZeroPoint,
1459
- float * C,
1460
- size_t CountN,
1461
- size_t BlockCountK,
1462
- const float * Bias,
1463
- const size_t ldc) {
1464
- GGML_UNUSED(QuantBScale);
1465
- GGML_UNUSED(QuantBZeroPoint);
587
+ void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
588
+ const uint8_t * QuantA,
589
+ const uint8_t * QuantBData,
590
+ float * C,
591
+ size_t CountN,
592
+ size_t BlockCountK,
593
+ const size_t ldc) {
1466
594
  size_t LDC = ldc * sizeof(float);
1467
595
  const size_t INNER = BlkLen / 16;
1468
596
  float tmp[4 * 16];
1469
597
 
1470
598
  if constexpr (HasZeroPoint) {
1471
599
  for (size_t n = 0; n < CountN; n += 16) {
1472
- size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
1473
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
1474
- n * BlockCountK * BlkLen / 2 + // b data
1475
- n * BlockCountK * sizeof(uint8_t) + // zp
1476
- n * BlockCountK * sizeof(_Float16); // scale
600
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
601
+ uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
602
+ n * BlockCountK * BlkLen / 2 + // b data
603
+ n * BlockCountK * sizeof(uint8_t) + // zp
604
+ n * BlockCountK * sizeof(_Float16); // scale
1477
605
  float * CPtr = C + n;
1478
606
  if (NBLKS < 16) {
1479
607
  CPtr = tmp;
1480
608
  LDC = 16 * sizeof(float);
1481
609
  }
1482
- if (Bias != nullptr) {
1483
- const float * bias = Bias + n;
1484
- if (NBLKS < 16) {
1485
- __asm__ volatile(
1486
- "vsetvli t0, %[N], e32, m2 \n\t"
1487
- "vle32.v v0, (%[SRC]) \n\t"
1488
- "vse32.v v0, (%[DST]) \n\t"
1489
- :
1490
- : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
1491
- : "cc", "t0");
1492
- bias = tmp;
1493
- }
1494
- __asm__ volatile(LOAD_BIAS
1495
-
1496
- "addi t3, %[BlockCountK], 0 \n\t"
1497
-
1498
- "vsetvli t0, zero, e8, m1 \n\t"
1499
- "li s1, 24 \n\t"
1500
- "vmv.v.i v1, 3 \n\t"
1501
- "vsetvli t0, s1, e8, m1 \n\t"
1502
- "vmv.v.i v1, 2 \n\t"
1503
- "vsetvli t0, zero, e8, mf2 \n\t"
1504
- "vmv.v.i v1, 1 \n\t"
1505
- "vsetvli t0, zero, e8, mf4 \n\t"
1506
- "vmv.v.i v1, 0 \n\t"
1507
-
1508
- "addi a1, %[A], 0 \n\t"
1509
- "addi s1, %[B], 0 \n\t"
1510
-
1511
- "BLOCK_COUNTK_LOOP%=: \n\t"
1512
- // scale offset
1513
- "addi s5, s1, 0 \n\t"
1514
- // zp offset
1515
- "addi s6, s1, 32 \n\t"
1516
- "addi s1, s6, 16 \n\t"
1517
- "addi s2, s1, 32 \n\t"
1518
- "addi s3, s1, 32*2 \n\t"
1519
- "addi s4, s1, 32*3 \n\t"
1520
-
1521
- "vsetvli t0, zero, e32, m8 \n\t"
1522
- "vxor.vv v16, v16, v16 \n\t"
1523
- // load a scale
1524
- "flw f1, (a1) \n\t"
1525
- "flw f2, 4(a1) \n\t"
1526
- "flw f3, 8(a1) \n\t"
1527
- "flw f4, 12(a1) \n\t"
1528
- "addi a1, a1, 16 \n\t"
1529
- "addi t2, %[INNER], 0 \n\t"
1530
-
1531
- SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1532
-
1533
- "BLOCK_INNER_LOOP%=: \n\t"
1534
-
1535
- LOAD_B_16x8x2
1536
-
1537
- "vle8.v v10, (a1) \n\t"
1538
- "addi a1, a1, 32 \n\t"
1539
- "vle8.v v11, (a1) \n\t"
1540
- "addi a1, a1, 32 \n\t"
1541
- "vsub.vv v2, v2, v12 \n\t"
1542
- "vsub.vv v6, v6, v12 \n\t"
1543
- "vsub.vv v3, v3, v13 \n\t"
1544
- "vsub.vv v7, v7, v13 \n\t"
1545
- "vsub.vv v4, v4, v14 \n\t"
1546
- "vsub.vv v8, v8, v14 \n\t"
1547
- "vsub.vv v5, v5, v15 \n\t"
1548
- "vsub.vv v9, v9, v15 \n\t"
1549
-
1550
- SQ4BIT_KERNEL_COMP_4x16x16
1551
-
1552
- "addi t2, t2, -1 \n\t"
1553
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1554
-
1555
- LOAD_SCALE_4x16_FP16
1556
-
1557
- "vsetvli t0, zero, e32, m8 \n\t"
1558
- "vfcvt.f.x.v v16, v16 \n\t"
1559
- "vfmacc.vv v24, v16, v8 \n\t"
1560
- "addi t3, t3, -1 \n\t"
1561
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1562
-
1563
- "RESULT_SAVE%=: \n\t"
1564
-
1565
- SAVE_RESULT_4x16
1566
-
1567
- :
1568
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1569
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
1570
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
1571
- "s2", "s3", "s4", "s5", "s6");
1572
-
1573
- } else {
1574
- __asm__ volatile(
1575
- "vsetvli t0, zero, e32, m8 \n\t"
1576
- "vxor.vv v24, v24, v24 \n\t"
1577
- "addi t3, %[BlockCountK], 0 \n\t"
1578
- "vsetvli t0, zero, e8, m1 \n\t"
1579
- "li s1, 24 \n\t"
1580
- "vmv.v.i v1, 3 \n\t"
1581
- "vsetvli t0, s1, e8, m1 \n\t"
1582
- "vmv.v.i v1, 2 \n\t"
1583
- "vsetvli t0, zero, e8, mf2 \n\t"
1584
- "vmv.v.i v1, 1 \n\t"
1585
- "vsetvli t0, zero, e8, mf4 \n\t"
1586
- "vmv.v.i v1, 0 \n\t"
1587
- "addi a1, %[A], 0 \n\t"
1588
- "addi s1, %[B], 0 \n\t"
1589
- "BLOCK_COUNTK_LOOP%=: \n\t"
1590
- // scale offset
1591
- "addi s5, s1, 0 \n\t"
1592
- // zp offset
1593
- "addi s6, s1, 32 \n\t"
1594
- "addi s1, s6, 16 \n\t"
1595
- "addi s2, s1, 32 \n\t"
1596
- "addi s3, s1, 32*2 \n\t"
1597
- "addi s4, s1, 32*3 \n\t"
1598
-
1599
- "vsetvli t0, zero, e32, m8 \n\t"
1600
- "vxor.vv v16, v16, v16 \n\t"
1601
- // load a scale
1602
- "flw f1, (a1) \n\t"
1603
- "flw f2, 4(a1) \n\t"
1604
- "flw f3, 8(a1) \n\t"
1605
- "flw f4, 12(a1) \n\t"
1606
- "addi a1, a1, 16 \n\t"
1607
- "addi t2, %[INNER], 0 \n\t"
1608
-
1609
- SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1610
-
1611
- "BLOCK_INNER_LOOP%=: \n\t"
1612
-
1613
- LOAD_B_16x8x2
1614
-
1615
- "vle8.v v10, (a1) \n\t"
1616
- "addi a1, a1, 32 \n\t"
1617
- "vle8.v v11, (a1) \n\t"
1618
- "addi a1, a1, 32 \n\t"
1619
- "vsub.vv v2, v2, v12 \n\t"
1620
- "vsub.vv v6, v6, v12 \n\t"
1621
- "vsub.vv v3, v3, v13 \n\t"
1622
- "vsub.vv v7, v7, v13 \n\t"
1623
- "vsub.vv v4, v4, v14 \n\t"
1624
- "vsub.vv v8, v8, v14 \n\t"
1625
- "vsub.vv v5, v5, v15 \n\t"
1626
- "vsub.vv v9, v9, v15 \n\t"
1627
-
1628
- SQ4BIT_KERNEL_COMP_4x16x16
1629
-
1630
- "addi t2, t2, -1 \n\t"
1631
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1632
-
1633
- LOAD_SCALE_4x16_FP16
1634
-
1635
- "vsetvli t0, zero, e32, m8 \n\t"
1636
- "vfcvt.f.x.v v16, v16 \n\t"
1637
- "vfmacc.vv v24, v16, v8 \n\t"
1638
- "addi t3, t3, -1 \n\t"
1639
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1640
-
1641
- "RESULT_SAVE%=: \n\t"
1642
-
1643
- SAVE_RESULT_4x16
1644
-
1645
- :
1646
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1647
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
1648
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
1649
- "s4", "s5", "s6");
1650
- }
1651
- }
1652
- } else {
1653
- for (size_t n = 0; n < CountN; n += 16) {
1654
- size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
1655
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
1656
- n * BlockCountK * BlkLen / 2 + // b data
1657
- n * BlockCountK * sizeof(_Float16); // scale
1658
- float * CPtr = C + n;
1659
- if (NBLKS < 16) {
1660
- CPtr = tmp;
1661
- LDC = 16 * sizeof(float);
1662
- }
1663
- if (Bias != nullptr) {
1664
- const float * bias = Bias + n;
1665
- if (NBLKS < 16) {
1666
- __asm__ volatile(
1667
- "vsetvli t0, %[N], e32, m2 \n\t"
1668
- "vle32.v v0, (%[SRC]) \n\t"
1669
- "vse32.v v0, (%[DST]) \n\t"
1670
- :
1671
- : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
1672
- : "cc", "t0");
1673
- bias = tmp;
1674
- }
1675
- __asm__ volatile(LOAD_BIAS
1676
-
1677
- "addi t3, %[BlockCountK], 0 \n\t"
1678
- "addi a1, %[A], 0 \n\t"
1679
- "addi s1, %[B], 0 \n\t"
1680
- "BLOCK_COUNTK_LOOP%=: \n\t"
1681
- "addi s5, s1, 0 \n\t"
1682
- "addi s1, s5, 32 \n\t"
1683
- "addi s2, s1, 32 \n\t"
1684
- "addi s3, s1, 32*2 \n\t"
1685
- "addi s4, s1, 32*3 \n\t"
1686
- "vsetvli t0, zero, e32, m8 \n\t"
1687
- "vxor.vv v16, v16, v16 \n\t"
1688
- // load a scale
1689
- "flw f1, (a1) \n\t"
1690
- "flw f2, 4(a1) \n\t"
1691
- "flw f3, 8(a1) \n\t"
1692
- "flw f4, 12(a1) \n\t"
1693
- "addi a1, a1, 16 \n\t"
1694
- "addi t2, %[INNER], 0 \n\t"
1695
- "BLOCK_INNER_LOOP%=: \n\t"
1696
-
1697
- LOAD_B_16x8x2
1698
-
1699
- "vsetvli t0, zero, e8, m1 \n\t"
1700
- "vle8.v v10, (a1) \n\t"
1701
- "addi a1, a1, 32 \n\t"
1702
- "vle8.v v11, (a1) \n\t"
1703
- "addi a1, a1, 32 \n\t"
1704
- "vadd.vi v2, v2, -8 \n\t"
1705
- "vadd.vi v3, v3, -8 \n\t"
1706
- "vadd.vi v4, v4, -8 \n\t"
1707
- "vadd.vi v5, v5, -8 \n\t"
1708
- "vadd.vi v6, v6, -8 \n\t"
1709
- "vadd.vi v7, v7, -8 \n\t"
1710
- "vadd.vi v8, v8, -8 \n\t"
1711
- "vadd.vi v9, v9, -8 \n\t"
1712
-
1713
- SQ4BIT_KERNEL_COMP_4x16x16
1714
-
1715
- "addi t2, t2, -1 \n\t"
1716
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1717
-
1718
- LOAD_SCALE_4x16_FP16
1719
-
1720
- "vsetvli t0, zero, e32, m8 \n\t"
1721
- "vfcvt.f.x.v v16, v16 \n\t"
1722
- "vfmacc.vv v24, v16, v8 \n\t"
1723
- "addi t3, t3, -1 \n\t"
1724
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1725
- "RESULT_SAVE%=: \n\t"
1726
-
1727
- SAVE_RESULT_4x16
1728
-
1729
- :
1730
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1731
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
1732
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
1733
- "s2", "s3", "s4", "s5", "s6");
1734
-
1735
- } else {
1736
- __asm__ volatile(
1737
- "vsetvli t0, zero, e32, m8 \n\t"
1738
- "vxor.vv v24, v24, v24 \n\t"
1739
- "addi t3, %[BlockCountK], 0 \n\t"
1740
- "addi a1, %[A], 0 \n\t"
1741
- "addi s1, %[B], 0 \n\t"
1742
- "BLOCK_COUNTK_LOOP%=: \n\t"
1743
- "addi s5, s1, 0 \n\t"
1744
- "addi s1, s5, 32 \n\t"
1745
- "addi s2, s1, 32 \n\t"
1746
- "addi s3, s1, 32*2 \n\t"
1747
- "addi s4, s1, 32*3 \n\t"
1748
- "vsetvli t0, zero, e32, m8 \n\t"
1749
- "vxor.vv v16, v16, v16 \n\t"
1750
- // load a scale
1751
- "flw f1, (a1) \n\t"
1752
- "flw f2, 4(a1) \n\t"
1753
- "flw f3, 8(a1) \n\t"
1754
- "flw f4, 12(a1) \n\t"
1755
- "addi a1, a1, 16 \n\t"
1756
- "addi t2, %[INNER], 0 \n\t"
1757
- "BLOCK_INNER_LOOP%=: \n\t"
1758
-
1759
- LOAD_B_16x8x2
1760
-
1761
- "vsetvli t0, zero, e8, m1 \n\t"
1762
- "vle8.v v10, (a1) \n\t"
1763
- "addi a1, a1, 32 \n\t"
1764
- "vle8.v v11, (a1) \n\t"
1765
- "addi a1, a1, 32 \n\t"
1766
- "vadd.vi v2, v2, -8 \n\t"
1767
- "vadd.vi v3, v3, -8 \n\t"
1768
- "vadd.vi v4, v4, -8 \n\t"
1769
- "vadd.vi v5, v5, -8 \n\t"
1770
- "vadd.vi v6, v6, -8 \n\t"
1771
- "vadd.vi v7, v7, -8 \n\t"
1772
- "vadd.vi v8, v8, -8 \n\t"
1773
- "vadd.vi v9, v9, -8 \n\t"
1774
-
1775
- SQ4BIT_KERNEL_COMP_4x16x16
1776
-
1777
- "addi t2, t2, -1 \n\t"
1778
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1779
-
1780
- LOAD_SCALE_4x16_FP16
1781
-
1782
- "vsetvli t0, zero, e32, m8 \n\t"
1783
- "vfcvt.f.x.v v16, v16 \n\t"
1784
- "vfmacc.vv v24, v16, v8 \n\t"
1785
- "addi t3, t3, -1 \n\t"
1786
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1787
- "RESULT_SAVE%=: \n\t"
1788
-
1789
- SAVE_RESULT_4x16
1790
-
1791
- :
1792
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1793
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
1794
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
1795
- "s4", "s5", "s6");
1796
- }
1797
- }
1798
- }
1799
- if (CountN % 16 != 0) {
1800
- // store output from tmp to C when NBLKS is less than 16.
1801
- float * CPtr = C + CountN / 16 * 16;
1802
- const size_t N = CountN % 16;
1803
- LDC = ldc * sizeof(float);
1804
- __asm__ volatile(
1805
- "vsetvli t0, %[N], e32, m2 \n\t"
1806
- "vle32.v v0, (%[SRC]) \n\t"
1807
- "addi s2, %[SRC], 64 \n\t"
1808
- "addi s3, %[SRC], 64*2 \n\t"
1809
- "addi s4, %[SRC], 64*3 \n\t"
1810
- "vle32.v v2, (s2) \n\t"
1811
- "vle32.v v4, (s3) \n\t"
1812
- "vle32.v v6, (s4) \n\t"
1813
- "add t2, %[DST], %[LDC] \n\t"
1814
- "add t3, t2, %[LDC] \n\t"
1815
- "add t4, t3, %[LDC] \n\t"
1816
- "vse32.v v0, (%[DST]) \n\t"
1817
- "vse32.v v2, (t2) \n\t"
1818
- "vse32.v v4, (t3) \n\t"
1819
- "vse32.v v6, (t4) \n\t"
1820
- :
1821
- : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
1822
- : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
1823
- }
1824
- }
1825
610
 
1826
- template <bool HasZeroPoint>
1827
- void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen,
1828
- const std::byte * QuantA,
1829
- const std::byte * QuantBData,
1830
- const float * QuantBScale,
1831
- const std::byte * QuantBZeroPoint,
1832
- float * C,
1833
- size_t CountN,
1834
- size_t BlockCountK,
1835
- const float * Bias,
1836
- const size_t ldc) {
1837
- GGML_UNUSED(QuantBScale);
1838
- GGML_UNUSED(QuantBZeroPoint);
1839
- size_t LDC = ldc * sizeof(float);
1840
- const size_t INNER = BlkLen / 16;
1841
- float tmp[4 * 16];
1842
-
1843
- if constexpr (HasZeroPoint) {
1844
- for (size_t n = 0; n < CountN; n += 16) {
1845
- size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
1846
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
1847
- n * BlockCountK * BlkLen / 2 + // b data
1848
- n * BlockCountK * sizeof(uint8_t) + // zp
1849
- n * BlockCountK * sizeof(float); // scale
1850
- float * CPtr = C + n;
1851
- if (NBLKS < 16) {
1852
- CPtr = tmp;
1853
- LDC = 16 * sizeof(float);
1854
- }
1855
- if (Bias != nullptr) {
1856
- const float * bias = Bias + n;
1857
- if (NBLKS < 16) {
1858
- __asm__ volatile(
1859
- "vsetvli t0, %[N], e32, m2 \n\t"
1860
- "vle32.v v0, (%[SRC]) \n\t"
1861
- "vse32.v v0, (%[DST]) \n\t"
1862
- :
1863
- : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
1864
- : "cc", "t0");
1865
- bias = tmp;
1866
- }
1867
-
1868
- __asm__ volatile(LOAD_BIAS
1869
- "addi t3, %[BlockCountK], 0 \n\t"
1870
- "vsetvli t0, zero, e8, m1 \n\t"
1871
- "li s1, 24 \n\t"
1872
- "vmv.v.i v1, 3 \n\t"
1873
- "vsetvli t0, s1, e8, m1 \n\t"
1874
- "vmv.v.i v1, 2 \n\t"
1875
- "vsetvli t0, zero, e8, mf2 \n\t"
1876
- "vmv.v.i v1, 1 \n\t"
1877
- "vsetvli t0, zero, e8, mf4 \n\t"
1878
- "vmv.v.i v1, 0 \n\t"
1879
- "addi a1, %[A], 0 \n\t"
1880
- "addi s1, %[B], 0 \n\t"
1881
- "BLOCK_COUNTK_LOOP%=: \n\t"
1882
- // scale offset
1883
- "addi s5, s1, 0 \n\t"
1884
- // zp offset
1885
- "addi s6, s1, 64 \n\t"
1886
- "addi s1, s6, 16 \n\t"
1887
- "addi s2, s1, 32 \n\t"
1888
- "addi s3, s1, 32*2 \n\t"
1889
- "addi s4, s1, 32*3 \n\t"
1890
- "vsetvli t0, zero, e32, m8 \n\t"
1891
- "vxor.vv v16, v16, v16 \n\t"
1892
- // load a scale
1893
- "flw f1, (a1) \n\t"
1894
- "flw f2, 4(a1) \n\t"
1895
- "flw f3, 8(a1) \n\t"
1896
- "flw f4, 12(a1) \n\t"
1897
- "addi a1, a1, 16 \n\t"
1898
- "addi t2, %[INNER], 0 \n\t"
1899
-
1900
- SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1901
-
1902
- "BLOCK_INNER_LOOP%=: \n\t"
1903
-
1904
- LOAD_B_16x8x2
1905
-
1906
- "vle8.v v10, (a1) \n\t"
1907
- "addi a1, a1, 32 \n\t"
1908
- "vle8.v v11, (a1) \n\t"
1909
- "addi a1, a1, 32 \n\t"
1910
- "vsub.vv v2, v2, v12 \n\t"
1911
- "vsub.vv v6, v6, v12 \n\t"
1912
- "vsub.vv v3, v3, v13 \n\t"
1913
- "vsub.vv v7, v7, v13 \n\t"
1914
- "vsub.vv v4, v4, v14 \n\t"
1915
- "vsub.vv v8, v8, v14 \n\t"
1916
- "vsub.vv v5, v5, v15 \n\t"
1917
- "vsub.vv v9, v9, v15 \n\t"
1918
-
1919
- SQ4BIT_KERNEL_COMP_4x16x16
1920
-
1921
- "addi t2, t2, -1 \n\t"
1922
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1923
-
1924
- LOAD_SCALE_4x16
1925
-
1926
- "vsetvli t0, zero, e32, m8 \n\t"
1927
- "vfcvt.f.x.v v16, v16 \n\t"
1928
- "vfmacc.vv v24, v16, v8 \n\t"
1929
- "addi t3, t3, -1 \n\t"
1930
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1931
-
1932
- "RESULT_SAVE%=: \n\t"
1933
-
1934
- SAVE_RESULT_4x16
1935
-
1936
- :
1937
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1938
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
1939
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
1940
- "s2", "s3", "s4", "s5", "s6");
1941
-
1942
- } else {
1943
- __asm__ volatile(
1944
- "vsetvli t0, zero, e32, m8 \n\t"
1945
- "vxor.vv v24, v24, v24 \n\t"
1946
- "addi t3, %[BlockCountK], 0 \n\t"
1947
- "vsetvli t0, zero, e8, m1 \n\t"
1948
- "li s1, 24 \n\t"
1949
- "vmv.v.i v1, 3 \n\t"
1950
- "vsetvli t0, s1, e8, m1 \n\t"
1951
- "vmv.v.i v1, 2 \n\t"
1952
- "vsetvli t0, zero, e8, mf2 \n\t"
1953
- "vmv.v.i v1, 1 \n\t"
1954
- "vsetvli t0, zero, e8, mf4 \n\t"
1955
- "vmv.v.i v1, 0 \n\t"
1956
- "addi a1, %[A], 0 \n\t"
1957
- "addi s1, %[B], 0 \n\t"
1958
- "BLOCK_COUNTK_LOOP%=: \n\t"
1959
- // scale offset
1960
- "addi s5, s1, 0 \n\t"
1961
- // zp offset
1962
- "addi s6, s1, 64 \n\t"
1963
- "addi s1, s6, 16 \n\t"
1964
- "addi s2, s1, 32 \n\t"
1965
- "addi s3, s1, 32*2 \n\t"
1966
- "addi s4, s1, 32*3 \n\t"
1967
- "vsetvli t0, zero, e32, m8 \n\t"
1968
- "vxor.vv v16, v16, v16 \n\t"
1969
- // load a scale
1970
- // load a scale
1971
- "flw f1, (a1) \n\t"
1972
- "flw f2, 4(a1) \n\t"
1973
- "flw f3, 8(a1) \n\t"
1974
- "flw f4, 12(a1) \n\t"
1975
- "addi a1, a1, 16 \n\t"
1976
- "addi t2, %[INNER], 0 \n\t"
1977
-
1978
- SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1979
-
1980
- "BLOCK_INNER_LOOP%=: \n\t"
1981
-
1982
- LOAD_B_16x8x2
1983
-
1984
- "vle8.v v10, (a1) \n\t"
1985
- "addi a1, a1, 32 \n\t"
1986
- "vle8.v v11, (a1) \n\t"
1987
- "addi a1, a1, 32 \n\t"
1988
- "vsub.vv v2, v2, v12 \n\t"
1989
- "vsub.vv v6, v6, v12 \n\t"
1990
- "vsub.vv v3, v3, v13 \n\t"
1991
- "vsub.vv v7, v7, v13 \n\t"
1992
- "vsub.vv v4, v4, v14 \n\t"
1993
- "vsub.vv v8, v8, v14 \n\t"
1994
- "vsub.vv v5, v5, v15 \n\t"
1995
- "vsub.vv v9, v9, v15 \n\t"
1996
-
1997
- SQ4BIT_KERNEL_COMP_4x16x16
1998
-
1999
- "addi t2, t2, -1 \n\t"
2000
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2001
-
2002
- LOAD_SCALE_4x16
2003
-
2004
- "vsetvli t0, zero, e32, m8 \n\t"
2005
- "vfcvt.f.x.v v16, v16 \n\t"
2006
- "vfmacc.vv v24, v16, v8 \n\t"
2007
- "addi t3, t3, -1 \n\t"
2008
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2009
-
2010
- "RESULT_SAVE%=: \n\t"
2011
-
2012
- SAVE_RESULT_4x16
2013
-
2014
- :
2015
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2016
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
2017
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
2018
- "s4", "s5", "s6");
2019
- }
611
+ __asm__ volatile(
612
+ "vsetvli t0, zero, e32, m8 \n\t"
613
+ "vxor.vv v24, v24, v24 \n\t"
614
+ "addi t3, %[BlockCountK], 0 \n\t"
615
+ "vsetvli t0, zero, e8, m1 \n\t"
616
+ "li s1, 24 \n\t"
617
+ "vmv.v.i v1, 3 \n\t"
618
+ "vsetvli t0, s1, e8, m1 \n\t"
619
+ "vmv.v.i v1, 2 \n\t"
620
+ "vsetvli t0, zero, e8, mf2 \n\t"
621
+ "vmv.v.i v1, 1 \n\t"
622
+ "vsetvli t0, zero, e8, mf4 \n\t"
623
+ "vmv.v.i v1, 0 \n\t"
624
+ "addi a1, %[A], 0 \n\t"
625
+ "addi s1, %[B], 0 \n\t"
626
+ "BLOCK_COUNTK_LOOP%=: \n\t"
627
+ // scale offset
628
+ "addi s5, s1, 0 \n\t"
629
+ // zp offset
630
+ "addi s6, s1, 32 \n\t"
631
+ "addi s1, s6, 16 \n\t"
632
+ "addi s2, s1, 32 \n\t"
633
+ "addi s3, s1, 32*2 \n\t"
634
+ "addi s4, s1, 32*3 \n\t"
635
+
636
+ "vsetvli t0, zero, e32, m8 \n\t"
637
+ "vxor.vv v16, v16, v16 \n\t"
638
+ // load a scale
639
+ "flw f1, (a1) \n\t"
640
+ "flw f2, 4(a1) \n\t"
641
+ "flw f3, 8(a1) \n\t"
642
+ "flw f4, 12(a1) \n\t"
643
+ "addi a1, a1, 16 \n\t"
644
+ "addi t2, %[INNER], 0 \n\t"
645
+
646
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
647
+
648
+ "BLOCK_INNER_LOOP%=: \n\t"
649
+
650
+ LOAD_B_16x8x2
651
+
652
+ "vle8.v v10, (a1) \n\t"
653
+ "addi a1, a1, 32 \n\t"
654
+ "vle8.v v11, (a1) \n\t"
655
+ "addi a1, a1, 32 \n\t"
656
+ "vsub.vv v2, v2, v12 \n\t"
657
+ "vsub.vv v6, v6, v12 \n\t"
658
+ "vsub.vv v3, v3, v13 \n\t"
659
+ "vsub.vv v7, v7, v13 \n\t"
660
+ "vsub.vv v4, v4, v14 \n\t"
661
+ "vsub.vv v8, v8, v14 \n\t"
662
+ "vsub.vv v5, v5, v15 \n\t"
663
+ "vsub.vv v9, v9, v15 \n\t"
664
+
665
+ SQ4BIT_KERNEL_COMP_4x16x16
666
+
667
+ "addi t2, t2, -1 \n\t"
668
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
669
+
670
+ LOAD_SCALE_4x16_FP16
671
+
672
+ "vsetvli t0, zero, e32, m8 \n\t"
673
+ "vfcvt.f.x.v v16, v16 \n\t"
674
+ "vfmacc.vv v24, v16, v8 \n\t"
675
+ "addi t3, t3, -1 \n\t"
676
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
677
+
678
+ "RESULT_SAVE%=: \n\t"
679
+
680
+ SAVE_RESULT_4x16
681
+
682
+ :
683
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
684
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
685
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4",
686
+ "s5", "s6");
2020
687
  }
2021
688
  } else {
2022
689
  for (size_t n = 0; n < CountN; n += 16) {
2023
- size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
2024
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2025
- n * BlockCountK * BlkLen / 2 + // b data
2026
- n * BlockCountK * sizeof(float); // scale
690
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
691
+ uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
692
+ n * BlockCountK * BlkLen / 2 + // b data
693
+ n * BlockCountK * sizeof(_Float16); // scale
2027
694
  float * CPtr = C + n;
2028
695
  if (NBLKS < 16) {
2029
696
  CPtr = tmp;
2030
697
  LDC = 16 * sizeof(float);
2031
698
  }
2032
- if (Bias != nullptr) {
2033
- const float * bias = Bias + n;
2034
- if (NBLKS < 16) {
2035
- __asm__ volatile(
2036
- "vsetvli t0, %[N], e32, m2 \n\t"
2037
- "vle32.v v0, (%[SRC]) \n\t"
2038
- "vse32.v v0, (%[DST]) \n\t"
2039
- :
2040
- : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
2041
- : "cc", "t0");
2042
- bias = tmp;
2043
- }
2044
- __asm__ volatile(LOAD_BIAS
2045
- "addi t3, %[BlockCountK], 0 \n\t"
2046
- "addi a1, %[A], 0 \n\t"
2047
- "addi s1, %[B], 0 \n\t"
2048
- "BLOCK_COUNTK_LOOP%=: \n\t"
2049
- "addi s5, s1, 0 \n\t"
2050
- "addi s1, s5, 64 \n\t"
2051
- "addi s2, s1, 32 \n\t"
2052
- "addi s3, s1, 32*2 \n\t"
2053
- "addi s4, s1, 32*3 \n\t"
2054
- "vsetvli t0, zero, e32, m8 \n\t"
2055
- "vxor.vv v16, v16, v16 \n\t"
2056
- // load a scale
2057
- "flw f1, (a1) \n\t"
2058
- "flw f2, 4(a1) \n\t"
2059
- "flw f3, 8(a1) \n\t"
2060
- "flw f4, 12(a1) \n\t"
2061
- "addi a1, a1, 16 \n\t"
2062
- "addi t2, %[INNER], 0 \n\t"
2063
- "BLOCK_INNER_LOOP%=: \n\t"
2064
-
2065
- LOAD_B_16x8x2
2066
-
2067
- "vsetvli t0, zero, e8, m1 \n\t"
2068
- "vle8.v v10, (a1) \n\t"
2069
- "addi a1, a1, 32 \n\t"
2070
- "vle8.v v11, (a1) \n\t"
2071
- "addi a1, a1, 32 \n\t"
2072
- "vadd.vi v2, v2, -8 \n\t"
2073
- "vadd.vi v3, v3, -8 \n\t"
2074
- "vadd.vi v4, v4, -8 \n\t"
2075
- "vadd.vi v5, v5, -8 \n\t"
2076
- "vadd.vi v6, v6, -8 \n\t"
2077
- "vadd.vi v7, v7, -8 \n\t"
2078
- "vadd.vi v8, v8, -8 \n\t"
2079
- "vadd.vi v9, v9, -8 \n\t"
2080
-
2081
- SQ4BIT_KERNEL_COMP_4x16x16
2082
-
2083
- "addi t2, t2, -1 \n\t"
2084
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2085
-
2086
- LOAD_SCALE_4x16
2087
-
2088
- "vsetvli t0, zero, e32, m8 \n\t"
2089
- "vfcvt.f.x.v v16, v16 \n\t"
2090
- "vfmacc.vv v24, v16, v8 \n\t"
2091
- "addi t3, t3, -1 \n\t"
2092
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2093
-
2094
- "RESULT_SAVE%=: \n\t"
2095
-
2096
- SAVE_RESULT_4x16
2097
-
2098
- :
2099
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2100
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
2101
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
2102
- "s2", "s3", "s4", "s5", "s6");
2103
-
2104
- } else {
2105
- __asm__ volatile(
2106
- "vsetvli t0, zero, e32, m8 \n\t"
2107
- "vxor.vv v24, v24, v24 \n\t"
2108
- "addi t3, %[BlockCountK], 0 \n\t"
2109
- "addi a1, %[A], 0 \n\t"
2110
- "addi s1, %[B], 0 \n\t"
2111
- "BLOCK_COUNTK_LOOP%=: \n\t"
2112
- "addi s5, s1, 0 \n\t"
2113
- "addi s1, s5, 64 \n\t"
2114
- "addi s2, s1, 32 \n\t"
2115
- "addi s3, s1, 32*2 \n\t"
2116
- "addi s4, s1, 32*3 \n\t"
2117
- "vsetvli t0, zero, e32, m8 \n\t"
2118
- "vxor.vv v16, v16, v16 \n\t"
2119
- // load a scale
2120
- "flw f1, (a1) \n\t"
2121
- "flw f2, 4(a1) \n\t"
2122
- "flw f3, 8(a1) \n\t"
2123
- "flw f4, 12(a1) \n\t"
2124
- "addi a1, a1, 16 \n\t"
2125
- "addi t2, %[INNER], 0 \n\t"
2126
- "BLOCK_INNER_LOOP%=: \n\t"
2127
-
2128
- LOAD_B_16x8x2
2129
-
2130
- "vsetvli t0, zero, e8, m1 \n\t"
2131
- "vle8.v v10, (a1) \n\t"
2132
-
2133
- "addi a1, a1, 32 \n\t"
2134
- "vle8.v v11, (a1) \n\t"
2135
- "addi a1, a1, 32 \n\t"
2136
- "vadd.vi v2, v2, -8 \n\t"
2137
- "vadd.vi v3, v3, -8 \n\t"
2138
- "vadd.vi v4, v4, -8 \n\t"
2139
- "vadd.vi v5, v5, -8 \n\t"
2140
- "vadd.vi v6, v6, -8 \n\t"
2141
- "vadd.vi v7, v7, -8 \n\t"
2142
- "vadd.vi v8, v8, -8 \n\t"
2143
- "vadd.vi v9, v9, -8 \n\t"
2144
-
2145
- SQ4BIT_KERNEL_COMP_4x16x16
2146
-
2147
- "addi t2, t2, -1 \n\t"
2148
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2149
-
2150
- LOAD_SCALE_4x16
2151
-
2152
- "vsetvli t0, zero, e32, m8 \n\t"
2153
- "vfcvt.f.x.v v16, v16 \n\t"
2154
- "vfmacc.vv v24, v16, v8 \n\t"
2155
- "addi t3, t3, -1 \n\t"
2156
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2157
-
2158
- "RESULT_SAVE%=: \n\t"
2159
-
2160
- SAVE_RESULT_4x16
2161
-
2162
- :
2163
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2164
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
2165
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
2166
- "s4", "s5", "s6");
2167
- }
699
+
700
+ __asm__ volatile(
701
+ "vsetvli t0, zero, e32, m8 \n\t"
702
+ "vxor.vv v24, v24, v24 \n\t"
703
+ "addi t3, %[BlockCountK], 0 \n\t"
704
+ "addi a1, %[A], 0 \n\t"
705
+ "addi s1, %[B], 0 \n\t"
706
+ "BLOCK_COUNTK_LOOP%=: \n\t"
707
+ "addi s5, s1, 0 \n\t"
708
+ "addi s1, s5, 32 \n\t"
709
+ "addi s2, s1, 32 \n\t"
710
+ "addi s3, s1, 32*2 \n\t"
711
+ "addi s4, s1, 32*3 \n\t"
712
+ "vsetvli t0, zero, e32, m8 \n\t"
713
+ "vxor.vv v16, v16, v16 \n\t"
714
+ // load a scale
715
+ "flw f1, (a1) \n\t"
716
+ "flw f2, 4(a1) \n\t"
717
+ "flw f3, 8(a1) \n\t"
718
+ "flw f4, 12(a1) \n\t"
719
+ "addi a1, a1, 16 \n\t"
720
+ "addi t2, %[INNER], 0 \n\t"
721
+ "BLOCK_INNER_LOOP%=: \n\t"
722
+
723
+ LOAD_B_16x8x2
724
+
725
+ "vsetvli t0, zero, e8, m1 \n\t"
726
+ "vle8.v v10, (a1) \n\t"
727
+ "addi a1, a1, 32 \n\t"
728
+ "vle8.v v11, (a1) \n\t"
729
+ "addi a1, a1, 32 \n\t"
730
+ "vadd.vi v2, v2, -8 \n\t"
731
+ "vadd.vi v3, v3, -8 \n\t"
732
+ "vadd.vi v4, v4, -8 \n\t"
733
+ "vadd.vi v5, v5, -8 \n\t"
734
+ "vadd.vi v6, v6, -8 \n\t"
735
+ "vadd.vi v7, v7, -8 \n\t"
736
+ "vadd.vi v8, v8, -8 \n\t"
737
+ "vadd.vi v9, v9, -8 \n\t"
738
+
739
+ SQ4BIT_KERNEL_COMP_4x16x16
740
+
741
+ "addi t2, t2, -1 \n\t"
742
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
743
+
744
+ LOAD_SCALE_4x16_FP16
745
+
746
+ "vsetvli t0, zero, e32, m8 \n\t"
747
+ "vfcvt.f.x.v v16, v16 \n\t"
748
+ "vfmacc.vv v24, v16, v8 \n\t"
749
+ "addi t3, t3, -1 \n\t"
750
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
751
+ "RESULT_SAVE%=: \n\t"
752
+
753
+ SAVE_RESULT_4x16
754
+
755
+ :
756
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
757
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
758
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4",
759
+ "s5", "s6");
2168
760
  }
2169
761
  }
2170
- if (CountN % 16 != 0) {
2171
- // stroe output from tmp to C when NBLKS less than 16.
2172
- float * CPtr = C + CountN / 16 * 16;
2173
- const size_t N = CountN % 16;
2174
- LDC = ldc * sizeof(float);
2175
- __asm__ volatile(
2176
- "vsetvli t0, %[N], e32, m2 \n\t"
2177
- "vle32.v v0, (%[SRC]) \n\t"
2178
- "addi s2, %[SRC], 64 \n\t"
2179
- "addi s3, %[SRC], 64*2 \n\t"
2180
- "addi s4, %[SRC], 64*3 \n\t"
2181
- "vle32.v v2, (s2) \n\t"
2182
- "vle32.v v4, (s3) \n\t"
2183
- "vle32.v v6, (s4) \n\t"
2184
- "add t2, %[DST], %[LDC] \n\t"
2185
- "add t3, t2, %[LDC] \n\t"
2186
- "add t4, t3, %[LDC] \n\t"
2187
- "vse32.v v0, (%[DST]) \n\t"
2188
- "vse32.v v2, (t2) \n\t"
2189
- "vse32.v v4, (t3) \n\t"
2190
- "vse32.v v6, (t4) \n\t"
2191
- :
2192
- : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
2193
- : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
2194
- }
2195
762
  }
2196
763
 
2197
764
  template <bool HasZeroPoint>
2198
- void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
2199
- const std::byte * QuantA,
2200
- const std::byte * QuantBData,
2201
- const float * QuantBScale,
2202
- const std::byte * QuantBZeroPoint,
2203
- float * C,
2204
- size_t CountN,
2205
- size_t BlockCountK,
2206
- const float * Bias) {
2207
- GGML_UNUSED(QuantBScale);
2208
- GGML_UNUSED(QuantBZeroPoint);
765
+ void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
766
+ const uint8_t * QuantA,
767
+ const uint8_t * QuantBData,
768
+ float * C,
769
+ size_t CountN,
770
+ size_t BlockCountK,
771
+ const size_t ldc) {
772
+ GGML_UNUSED(ldc);
2209
773
  size_t INNER = BlkLen / 16;
2210
774
 
2211
775
  if constexpr (HasZeroPoint) {
2212
776
  for (size_t n = 0; n < CountN; n += 16) {
2213
- size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2214
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2215
- n * BlockCountK * BlkLen / 2 + // b data
2216
- n * BlockCountK * sizeof(uint8_t) + // zp
2217
- n * BlockCountK * sizeof(_Float16); // scale
2218
- float * CPtr = C + n;
2219
- size_t cnt = BlockCountK;
2220
- if (Bias != nullptr) {
2221
- const float * bias = Bias + n;
2222
- __asm__ volatile(
2223
- "addi t3, %[NBLKS], 0 \n\t"
2224
- "vsetvli t0, zero, e8, m1 \n\t"
2225
-
2226
- "vmv.v.i v13, 3 \n\t"
2227
- "li s1, 24 \n\t"
2228
- "vsetvli t0, s1, e8, m1 \n\t"
2229
- "vmv.v.i v13, 2 \n\t"
2230
- "vsetvli t0, zero, e8, mf2 \n\t"
2231
- "vmv.v.i v13, 1 \n\t"
2232
- "vsetvli t0, zero, e8, mf4 \n\t"
2233
- "vmv.v.i v13, 0 \n\t"
2234
- "addi s1, %[B], 0 \n\t"
2235
- "addi s2, %[B], 8 \n\t"
2236
- "addi s3, %[B], 16 \n\t"
2237
- "addi s4, %[B], 24 \n\t"
2238
- // zp offset
2239
- "addi s7, %[B], 32 \n\t"
2240
- // a offset
2241
- "addi s5, %[A], 0 \n\t"
2242
- "addi s6, %[A], 12 \n\t"
2243
-
2244
- "vsetvli t0, t3, e32, mf2 \n\t"
2245
- "vle32.v v28, (%[BIAS]) \n\t"
2246
- "sub t3, t3, t0 \n\t"
2247
- "addi %[BIAS], %[BIAS], 16 \n\t"
2248
- "vsetvli t0, t3, e32, mf2 \n\t"
2249
- "vle32.v v29, (%[BIAS]) \n\t"
2250
- "sub t3, t3, t0 \n\t"
2251
- "addi %[BIAS], %[BIAS], 16 \n\t"
2252
- "vsetvli t0, t3, e32, mf2 \n\t"
2253
- "vle32.v v30, (%[BIAS]) \n\t"
2254
- "sub t3, t3, t0 \n\t"
2255
- "addi %[BIAS], %[BIAS], 16 \n\t"
2256
- "vsetvli t0, t3, e32, mf2 \n\t"
2257
- "vle32.v v31, (%[BIAS]) \n\t"
2258
-
2259
- "LOOP_K%=: \n\t"
2260
- "vsetvli t0, zero, e16, mf4 \n\t"
2261
-
2262
- "vle16.v v4, (s1) \n\t"
2263
- "addi s1, s1, 48 \n\t"
2264
- "vle16.v v5, (s2) \n\t"
2265
- "addi s2, s2, 72 \n\t"
2266
- "vle16.v v6, (s3) \n\t"
2267
- "addi s3, s3, 96 \n\t"
2268
- "vle16.v v7, (s4) \n\t"
2269
- "addi s4, s4, 120 \n\t"
2270
- "flw f1, (s5) \n\t"
2271
- "addi s5, s5, 4 \n\t"
2272
- "vfwcvt.f.f.v v8, v4 \n\t"
2273
- "vfwcvt.f.f.v v9, v5 \n\t"
2274
- "vfwcvt.f.f.v v10, v6 \n\t"
2275
- "vfwcvt.f.f.v v11, v7 \n\t"
2276
-
2277
- "vsetvli t0, zero, e32, mf2 \n\t"
2278
- "addi t5, %[INNER], 0 \n\t"
2279
- "vxor.vv v16, v16, v16 \n\t"
2280
- "vxor.vv v18, v18, v18 \n\t"
2281
- "vxor.vv v20, v20, v20 \n\t"
2282
- "vxor.vv v22, v22, v22 \n\t"
2283
- "vfmul.vf v24, v8, f1 \n\t"
2284
- "vfmul.vf v25, v9, f1 \n\t"
2285
- "vfmul.vf v26, v10, f1 \n\t"
2286
- "vfmul.vf v27, v11, f1 \n\t"
2287
- "addi %[CNT], %[CNT], -1 \n\t"
2288
-
2289
- SQ4BIT_KERNEL_LOAD_ZP_16X1
2290
-
2291
- "LOOP_INNER%=: \n\t"
2292
-
2293
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2294
-
2295
- "vsub.vv v0, v0, v8 \n\t"
2296
- "vsub.vv v4, v4, v8 \n\t"
2297
- "vsub.vv v1, v1, v9 \n\t"
2298
- "vsub.vv v5, v5, v9 \n\t"
2299
- "vsub.vv v2, v2, v10 \n\t"
2300
- "vsub.vv v6, v6, v10 \n\t"
2301
- "vsub.vv v3, v3, v11 \n\t"
2302
- "vsub.vv v7, v7, v11 \n\t"
2303
-
2304
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2305
-
2306
- "bnez t5, LOOP_INNER%= \n\t"
2307
- "vsetvli t0, zero, e32, mf2 \n\t"
2308
-
2309
- SQ4BIT_KERNEL_ACC_F16_1X4X4
2310
- "addi s7, s1, 32 \n\t"
2311
-
2312
- "bnez %[CNT], LOOP_K%= \n\t"
2313
- "addi t3, zero, 16 \n\t"
2314
- "addi s1, %[C], 16 \n\t"
2315
- "addi s2, %[C], 32 \n\t"
2316
- "addi s3, %[C], 48 \n\t"
2317
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2318
- "vse32.v v28, (%[C]) \n\t"
2319
- "vse32.v v29, (s1) \n\t"
2320
- "vse32.v v30, (s2) \n\t"
2321
- "vse32.v v31, (s3) \n\t"
2322
- "jal x0, END%= \n\t"
2323
-
2324
- "ST_TAIL%=: \n\t"
2325
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2326
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2327
- "vse32.v v28, (%[C]) \n\t"
2328
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2329
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2330
- "vse32.v v29, (s1) \n\t"
2331
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2332
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2333
- "vse32.v v30, (s2) \n\t"
2334
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2335
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2336
- "vse32.v v31, (s3) \n\t"
2337
- "END%=: \n\t"
2338
-
2339
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
2340
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2341
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2342
- } else {
2343
- __asm__ volatile(
2344
- "vsetvli t0, zero, e32, m4 \n\t"
2345
- "vxor.vv v28, v28, v28 \n\t"
2346
-
2347
- "vsetvli t0, zero, e8, m1 \n\t"
2348
- "vmv.v.i v13, 3 \n\t"
2349
- "li s1, 24 \n\t"
2350
- "vsetvli t0, s1, e8, m1 \n\t"
2351
- "vmv.v.i v13, 2 \n\t"
2352
- "vsetvli t0, zero, e8, mf2 \n\t"
2353
- "vmv.v.i v13, 1 \n\t"
2354
- "vsetvli t0, zero, e8, mf4 \n\t"
2355
- "vmv.v.i v13, 0 \n\t"
2356
-
2357
- "addi s1, %[B], 0 \n\t"
2358
- "addi s2, %[B], 8 \n\t"
2359
- "addi s3, %[B], 16 \n\t"
2360
- "addi s4, %[B], 24 \n\t"
2361
-
2362
- "addi s7, %[B], 32 \n\t"
2363
-
2364
- "addi s5, %[A], 0 \n\t"
2365
- "addi s6, %[A], 12 \n\t"
2366
- "LOOP_K%=: \n\t"
2367
- "vsetvli t0, zero, e16, mf4 \n\t"
2368
- "vle16.v v4, (s1) \n\t"
2369
- "addi s1, s1, 48 \n\t"
2370
- "vle16.v v5, (s2) \n\t"
2371
- "addi s2, s2, 72 \n\t"
2372
- "vle16.v v6, (s3) \n\t"
2373
- "addi s3, s3, 96 \n\t"
2374
- "vle16.v v7, (s4) \n\t"
2375
- "addi s4, s4, 120 \n\t"
2376
- "flw f1, (s5) \n\t"
2377
- "addi s5, s5, 4 \n\t"
2378
-
2379
- "vfwcvt.f.f.v v8, v4 \n\t"
2380
- "vfwcvt.f.f.v v9, v5 \n\t"
2381
- "vfwcvt.f.f.v v10, v6 \n\t"
2382
- "vfwcvt.f.f.v v11, v7 \n\t"
2383
- "vsetvli t0, zero, e32, mf2 \n\t"
2384
-
2385
- "addi t5, %[INNER], 0 \n\t"
2386
- "vxor.vv v16, v16, v16 \n\t"
2387
- "vxor.vv v18, v18, v18 \n\t"
2388
- "vxor.vv v20, v20, v20 \n\t"
2389
- "vxor.vv v22, v22, v22 \n\t"
2390
- "vfmul.vf v24, v8, f1 \n\t"
2391
- "vfmul.vf v25, v9, f1 \n\t"
2392
- "vfmul.vf v26, v10, f1 \n\t"
2393
- "vfmul.vf v27, v11, f1 \n\t"
2394
- "addi %[CNT], %[CNT], -1 \n\t"
2395
-
2396
- SQ4BIT_KERNEL_LOAD_ZP_16X1
2397
-
2398
- "LOOP_INNER%=: \n\t"
2399
-
2400
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2401
-
2402
- "vsub.vv v0, v0, v8 \n\t"
2403
- "vsub.vv v4, v4, v8 \n\t"
2404
- "vsub.vv v1, v1, v9 \n\t"
2405
- "vsub.vv v5, v5, v9 \n\t"
2406
- "vsub.vv v2, v2, v10 \n\t"
2407
- "vsub.vv v6, v6, v10 \n\t"
2408
- "vsub.vv v3, v3, v11 \n\t"
2409
- "vsub.vv v7, v7, v11 \n\t"
2410
-
2411
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2412
-
2413
- "bnez t5, LOOP_INNER%= \n\t"
2414
- "vsetvli t0, zero, e32, mf2 \n\t"
2415
-
2416
- SQ4BIT_KERNEL_ACC_F16_1X4X4
2417
- "addi s7, s1, 32 \n\t"
2418
-
2419
- "bnez %[CNT], LOOP_K%= \n\t"
2420
- "addi t3, zero, 16 \n\t"
2421
- "addi s1, %[C], 16 \n\t"
2422
- "addi s2, %[C], 32 \n\t"
2423
- "addi s3, %[C], 48 \n\t"
2424
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2425
- "vse32.v v28, (%[C]) \n\t"
2426
- "vse32.v v29, (s1) \n\t"
2427
- "vse32.v v30, (s2) \n\t"
2428
- "vse32.v v31, (s3) \n\t"
2429
- "jal x0, END%= \n\t"
2430
-
2431
- "ST_TAIL%=: \n\t"
2432
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2433
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2434
- "vse32.v v28, (%[C]) \n\t"
2435
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2436
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2437
- "vse32.v v29, (s1) \n\t"
2438
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2439
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2440
- "vse32.v v30, (s2) \n\t"
2441
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2442
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2443
- "vse32.v v31, (s3) \n\t"
2444
- "END%=: \n\t"
2445
-
2446
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
2447
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2448
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2449
- }
2450
- }
2451
- } else {
2452
- for (size_t n = 0; n < CountN; n += 16) {
2453
- size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2454
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2455
- n * BlockCountK * BlkLen / 2 + // b data
2456
- n * BlockCountK * sizeof(_Float16); // scale
777
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
778
+ uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
779
+ n * BlockCountK * BlkLen / 2 + // b data
780
+ n * BlockCountK * sizeof(uint8_t) + // zp
781
+ n * BlockCountK * sizeof(_Float16); // scale
2457
782
  float * CPtr = C + n;
2458
783
  size_t cnt = BlockCountK;
2459
- if (Bias != nullptr) {
2460
- const float * bias = Bias + n;
2461
- __asm__ volatile(
2462
- "addi t3, %[NBLKS], 0 \n\t"
2463
- "addi s1, %[B], 0 \n\t"
2464
- "addi s2, %[B], 8 \n\t"
2465
- "addi s3, %[B], 16 \n\t"
2466
- "addi s4, %[B], 24 \n\t"
2467
- "addi s5, %[A], 0 \n\t"
2468
- "addi s6, %[A], 12 \n\t"
2469
- "vsetvli t0, t3, e32, mf2 \n\t"
2470
- "vle32.v v28, (%[BIAS]) \n\t"
2471
- "sub t3, t3, t0 \n\t"
2472
- "addi %[BIAS], %[BIAS], 16 \n\t"
2473
- "vsetvli t0, t3, e32, mf2 \n\t"
2474
- "vle32.v v29, (%[BIAS]) \n\t"
2475
- "sub t3, t3, t0 \n\t"
2476
- "addi %[BIAS], %[BIAS], 16 \n\t"
2477
- "vsetvli t0, t3, e32, mf2 \n\t"
2478
- "vle32.v v30, (%[BIAS]) \n\t"
2479
- "sub t3, t3, t0 \n\t"
2480
- "addi %[BIAS], %[BIAS], 16 \n\t"
2481
- "vsetvli t0, t3, e32, mf2 \n\t"
2482
- "vle32.v v31, (%[BIAS]) \n\t"
2483
-
2484
- "LOOP_K%=: \n\t"
2485
- "vsetvli t0, zero, e16, mf4 \n\t"
2486
-
2487
- "vle16.v v4, (s1) \n\t"
2488
- "addi s1, s1, 32 \n\t"
2489
- "vle16.v v5, (s2) \n\t"
2490
- "addi s2, s2, 56 \n\t"
2491
- "vle16.v v6, (s3) \n\t"
2492
- "addi s3, s3, 80 \n\t"
2493
- "vle16.v v7, (s4) \n\t"
2494
- "addi s4, s4, 104 \n\t"
2495
- "flw f1, (s5) \n\t"
2496
- "addi s5, s5, 4 \n\t"
2497
- "vfwcvt.f.f.v v8, v4 \n\t"
2498
- "vfwcvt.f.f.v v9, v5 \n\t"
2499
- "vfwcvt.f.f.v v10, v6 \n\t"
2500
- "vfwcvt.f.f.v v11, v7 \n\t"
2501
-
2502
- "vsetvli t0, zero, e32, mf2 \n\t"
2503
- "addi t5, %[INNER], 0 \n\t"
2504
- "vxor.vv v16, v16, v16 \n\t"
2505
- "vxor.vv v18, v18, v18 \n\t"
2506
- "vxor.vv v20, v20, v20 \n\t"
2507
- "vxor.vv v22, v22, v22 \n\t"
2508
- "vfmul.vf v24, v8, f1 \n\t"
2509
- "vfmul.vf v25, v9, f1 \n\t"
2510
- "vfmul.vf v26, v10, f1 \n\t"
2511
- "vfmul.vf v27, v11, f1 \n\t"
2512
- "addi %[CNT], %[CNT], -1 \n\t"
2513
- "vsetvli t0, zero, e8, m1 \n\t"
2514
- "LOOP_INNER%=: \n\t"
2515
-
2516
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2517
-
2518
- "vadd.vi v0, v0, -8 \n\t"
2519
- "vadd.vi v1, v1, -8 \n\t"
2520
- "vadd.vi v2, v2, -8 \n\t"
2521
- "vadd.vi v3, v3, -8 \n\t"
2522
- "vadd.vi v4, v4, -8 \n\t"
2523
- "vadd.vi v5, v5, -8 \n\t"
2524
- "vadd.vi v6, v6, -8 \n\t"
2525
- "vadd.vi v7, v7, -8 \n\t"
2526
-
2527
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2528
-
2529
- "bnez t5, LOOP_INNER%= \n\t"
2530
- "vsetvli t0, zero, e32, mf2 \n\t"
2531
-
2532
- SQ4BIT_KERNEL_ACC_F16_1X4X4
2533
-
2534
- "bnez %[CNT], LOOP_K%= \n\t"
2535
- "addi t3, zero, 16 \n\t"
2536
- "addi s1, %[C], 16 \n\t"
2537
- "addi s2, %[C], 32 \n\t"
2538
- "addi s3, %[C], 48 \n\t"
2539
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2540
- "vse32.v v28, (%[C]) \n\t"
2541
- "vse32.v v29, (s1) \n\t"
2542
- "vse32.v v30, (s2) \n\t"
2543
- "vse32.v v31, (s3) \n\t"
2544
- "jal x0, END%= \n\t"
2545
-
2546
- "ST_TAIL%=: \n\t"
2547
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2548
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2549
- "vse32.v v28, (%[C]) \n\t"
2550
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2551
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2552
- "vse32.v v29, (s1) \n\t"
2553
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2554
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2555
- "vse32.v v30, (s2) \n\t"
2556
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2557
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2558
- "vse32.v v31, (s3) \n\t"
2559
- "END%=: \n\t"
2560
-
2561
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
2562
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2563
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
2564
- } else {
2565
- __asm__ volatile(
2566
- "vsetvli t0, zero, e32, m4 \n\t"
2567
- "vxor.vv v28, v28, v28 \n\t"
2568
- "addi s1, %[B], 0 \n\t"
2569
- "addi s2, %[B], 8 \n\t"
2570
- "addi s3, %[B], 16 \n\t"
2571
- "addi s4, %[B], 24 \n\t"
2572
-
2573
- "addi s5, %[A], 0 \n\t"
2574
- "addi s6, %[A], 12 \n\t"
2575
- "LOOP_K%=: \n\t"
2576
- "vsetvli t0, zero, e16, mf4 \n\t"
2577
- "vle16.v v4, (s1) \n\t"
2578
- "addi s1, s1, 32 \n\t"
2579
- "vle16.v v5, (s2) \n\t"
2580
- "addi s2, s2, 56 \n\t"
2581
- "vle16.v v6, (s3) \n\t"
2582
- "addi s3, s3, 80 \n\t"
2583
- "vle16.v v7, (s4) \n\t"
2584
- "addi s4, s4, 104 \n\t"
2585
- "flw f1, (s5) \n\t"
2586
- "addi s5, s5, 4 \n\t"
2587
-
2588
- "vfwcvt.f.f.v v8, v4 \n\t"
2589
- "vfwcvt.f.f.v v9, v5 \n\t"
2590
- "vfwcvt.f.f.v v10, v6 \n\t"
2591
- "vfwcvt.f.f.v v11, v7 \n\t"
2592
- "vsetvli t0, zero, e32, mf2 \n\t"
2593
-
2594
- "addi t5, %[INNER], 0 \n\t"
2595
- "vxor.vv v16, v16, v16 \n\t"
2596
- "vxor.vv v18, v18, v18 \n\t"
2597
- "vxor.vv v20, v20, v20 \n\t"
2598
- "vxor.vv v22, v22, v22 \n\t"
2599
- "vfmul.vf v24, v8, f1 \n\t"
2600
- "vfmul.vf v25, v9, f1 \n\t"
2601
- "vfmul.vf v26, v10, f1 \n\t"
2602
- "vfmul.vf v27, v11, f1 \n\t"
2603
- "addi %[CNT], %[CNT], -1 \n\t"
2604
- "vsetvli t0, zero, e8, m1 \n\t"
2605
- "LOOP_INNER%=: \n\t"
2606
-
2607
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2608
-
2609
- "vadd.vi v0, v0, -8 \n\t"
2610
- "vadd.vi v1, v1, -8 \n\t"
2611
- "vadd.vi v2, v2, -8 \n\t"
2612
- "vadd.vi v3, v3, -8 \n\t"
2613
- "vadd.vi v4, v4, -8 \n\t"
2614
- "vadd.vi v5, v5, -8 \n\t"
2615
- "vadd.vi v6, v6, -8 \n\t"
2616
- "vadd.vi v7, v7, -8 \n\t"
2617
-
2618
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2619
-
2620
- "bnez t5, LOOP_INNER%= \n\t"
2621
- "vsetvli t0, zero, e32, mf2 \n\t"
2622
-
2623
- SQ4BIT_KERNEL_ACC_F16_1X4X4
2624
-
2625
- "bnez %[CNT], LOOP_K%= \n\t"
2626
- "addi t3, zero, 16 \n\t"
2627
- "addi s1, %[C], 16 \n\t"
2628
- "addi s2, %[C], 32 \n\t"
2629
- "addi s3, %[C], 48 \n\t"
2630
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2631
- "vse32.v v28, (%[C]) \n\t"
2632
- "vse32.v v29, (s1) \n\t"
2633
- "vse32.v v30, (s2) \n\t"
2634
- "vse32.v v31, (s3) \n\t"
2635
- "jal x0, END%= \n\t"
2636
-
2637
- "ST_TAIL%=: \n\t"
2638
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2639
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2640
- "vse32.v v28, (%[C]) \n\t"
2641
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2642
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2643
- "vse32.v v29, (s1) \n\t"
2644
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2645
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2646
- "vse32.v v30, (s2) \n\t"
2647
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2648
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2649
- "vse32.v v31, (s3) \n\t"
2650
- "END%=: \n\t"
2651
-
2652
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
2653
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2654
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
2655
- }
2656
- }
2657
- }
2658
- }
2659
784
 
2660
- template <bool HasZeroPoint>
2661
- void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen,
2662
- const std::byte * QuantA,
2663
- const std::byte * QuantBData,
2664
- const float * QuantBScale,
2665
- const std::byte * QuantBZeroPoint,
2666
- float * C,
2667
- size_t CountN,
2668
- size_t BlockCountK,
2669
- const float * Bias) {
2670
- GGML_UNUSED(QuantBScale);
2671
- GGML_UNUSED(QuantBZeroPoint);
2672
- const size_t INNER = BlkLen / 16;
2673
- if constexpr (HasZeroPoint) {
2674
- for (size_t n = 0; n < CountN; n += 16) {
2675
- size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2676
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2677
- n * BlockCountK * BlkLen / 2 + // b data
2678
- n * BlockCountK * sizeof(uint8_t) + // zp
2679
- n * BlockCountK * sizeof(float); // scale
2680
- float * CPtr = C + n;
2681
- size_t cnt = BlockCountK;
2682
- if (Bias != nullptr) {
2683
- const float * bias = Bias + n;
2684
- __asm__ volatile(
2685
- "addi t3, %[NBLKS], 0 \n\t"
2686
- "vsetvli t0, zero, e8, m1 \n\t"
2687
- "vmv.v.i v13, 3 \n\t"
2688
- "li s1, 24 \n\t"
2689
- "vsetvli t0, s1, e8, m1 \n\t"
2690
- "vmv.v.i v13, 2 \n\t"
2691
- "vsetvli t0, zero, e8, mf2 \n\t"
2692
- "vmv.v.i v13, 1 \n\t"
2693
- "vsetvli t0, zero, e8, mf4 \n\t"
2694
- "vmv.v.i v13, 0 \n\t"
2695
- "vsetvli t0, zero, e32, m4 \n\t"
2696
- "vxor.vv v28, v28, v28 \n\t"
2697
-
2698
- // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0
2699
- "addi s1, %[B], 0 \n\t"
2700
- "addi s2, %[B], 16 \n\t"
2701
- "addi s3, %[B], 32 \n\t"
2702
- "addi s4, %[B], 48 \n\t"
2703
- // zp offset
2704
- "addi s7, %[B], 64 \n\t"
2705
- // a offset
2706
- "addi s5, %[A], 0 \n\t"
2707
- "addi s6, %[A], 12 \n\t"
2708
-
2709
- "vsetvli t0, t3, e32, mf2 \n\t"
2710
- "vle32.v v28, (%[BIAS]) \n\t"
2711
- "sub t3, t3, t0 \n\t"
2712
- "addi %[BIAS], %[BIAS], 16 \n\t"
2713
- "vsetvli t0, t3, e32, mf2 \n\t"
2714
- "vle32.v v29, (%[BIAS]) \n\t"
2715
- "sub t3, t3, t0 \n\t"
2716
- "addi %[BIAS], %[BIAS], 16 \n\t"
2717
- "vsetvli t0, t3, e32, mf2 \n\t"
2718
- "vle32.v v30, (%[BIAS]) \n\t"
2719
- "sub t3, t3, t0 \n\t"
2720
- "addi %[BIAS], %[BIAS], 16 \n\t"
2721
- "vsetvli t0, t3, e32, mf2 \n\t"
2722
- "vle32.v v31, (%[BIAS]) \n\t"
2723
- "vsetvli t0, zero, e32, mf2 \n\t"
2724
- "LOOP_K%=: \n\t"
2725
-
2726
- // load scale
2727
- "vle32.v v8, (s1) \n\t"
2728
- "addi s1, s1, 80 \n\t"
2729
- "vle32.v v9, (s2) \n\t"
2730
- "addi s2, s2, 96 \n\t"
2731
- "vle32.v v10, (s3) \n\t"
2732
- "addi s3, s3, 112 \n\t"
2733
- "vle32.v v11, (s4) \n\t"
2734
- "addi s4, s4, 128 \n\t"
2735
-
2736
- // load a scale
2737
- "flw f1, (s5) \n\t"
2738
- "addi s5, s5, 4 \n\t"
2739
-
2740
- "addi t5, %[INNER], 0 \n\t"
2741
- "vxor.vv v16, v16, v16 \n\t"
2742
- "vxor.vv v18, v18, v18 \n\t"
2743
- "vxor.vv v20, v20, v20 \n\t"
2744
- "vxor.vv v22, v22, v22 \n\t"
2745
-
2746
- // a scale * b scale
2747
- "vfmul.vf v24, v8, f1 \n\t"
2748
- "vfmul.vf v25, v9, f1 \n\t"
2749
- "vfmul.vf v26, v10, f1 \n\t"
2750
- "vfmul.vf v27, v11, f1 \n\t"
2751
- "addi %[CNT], %[CNT], -1 \n\t"
2752
-
2753
- SQ4BIT_KERNEL_LOAD_ZP_16X1
2754
-
2755
- "LOOP_INNER%=: \n\t"
2756
-
2757
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2758
-
2759
- "vsub.vv v0, v0, v8 \n\t"
2760
- "vsub.vv v4, v4, v8 \n\t"
2761
- "vsub.vv v1, v1, v9 \n\t"
2762
- "vsub.vv v5, v5, v9 \n\t"
2763
- "vsub.vv v2, v2, v10 \n\t"
2764
- "vsub.vv v6, v6, v10 \n\t"
2765
- "vsub.vv v3, v3, v11 \n\t"
2766
- "vsub.vv v7, v7, v11 \n\t"
2767
-
2768
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2769
-
2770
- "bnez t5, LOOP_INNER%= \n\t"
2771
- "vsetvli t0, zero, e32, mf2 \n\t"
2772
-
2773
- SQ4BIT_KERNEL_ACC_1X4X4
2774
- "addi s7, s1, 64 \n\t"
2775
-
2776
- "bnez %[CNT], LOOP_K%= \n\t"
2777
-
2778
- "addi t3, zero, 16 \n\t"
2779
- "addi s1, %[C], 16 \n\t"
2780
- "addi s2, %[C], 32 \n\t"
2781
- "addi s3, %[C], 48 \n\t"
2782
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2783
- "vse32.v v28, (%[C]) \n\t"
2784
- "vse32.v v29, (s1) \n\t"
2785
- "vse32.v v30, (s2) \n\t"
2786
- "vse32.v v31, (s3) \n\t"
2787
- "jal x0, END%= \n\t"
2788
-
2789
- "ST_TAIL%=: \n\t"
2790
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2791
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2792
- "vse32.v v28, (%[C]) \n\t"
2793
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2794
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2795
- "vse32.v v29, (s1) \n\t"
2796
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2797
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2798
- "vse32.v v30, (s2) \n\t"
2799
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2800
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2801
- "vse32.v v31, (s3) \n\t"
2802
- "END%=: \n\t"
2803
-
2804
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
2805
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2806
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2807
- } else {
2808
- __asm__ volatile(
2809
- "vsetvli t0, zero, e32, m4 \n\t"
2810
- "vxor.vv v28, v28, v28 \n\t"
2811
-
2812
- "vsetvli t0, zero, e8, m1 \n\t"
2813
- "vmv.v.i v13, 3 \n\t"
2814
- "li s1, 24 \n\t"
2815
- "vsetvli t0, s1, e8, m1 \n\t"
2816
- "vmv.v.i v13, 2 \n\t"
2817
- "vsetvli t0, zero, e8, mf2 \n\t"
2818
- "vmv.v.i v13, 1 \n\t"
2819
- "vsetvli t0, zero, e8, mf4 \n\t"
2820
- "vmv.v.i v13, 0 \n\t"
2821
- "addi s1, %[B], 0 \n\t"
2822
- "addi s2, %[B], 16 \n\t"
2823
- "addi s3, %[B], 32 \n\t"
2824
- "addi s4, %[B], 48 \n\t"
2825
-
2826
- "addi s7, %[B], 64 \n\t"
2827
-
2828
- "addi s5, %[A], 0 \n\t"
2829
- "addi s6, %[A], 12 \n\t"
2830
- "vsetvli t0, zero, e32, mf2 \n\t"
2831
-
2832
- "LOOP_K%=: \n\t"
2833
- "vle32.v v8, (s1) \n\t"
2834
- "addi s1, s1, 80 \n\t"
2835
- "vle32.v v9, (s2) \n\t"
2836
- "addi s2, s2, 96 \n\t"
2837
- "vle32.v v10, (s3) \n\t"
2838
- "addi s3, s3, 112 \n\t"
2839
- "vle32.v v11, (s4) \n\t"
2840
- "addi s4, s4, 128 \n\t"
2841
-
2842
- "flw f1, (s5) \n\t"
2843
- "addi s5, s5, 4 \n\t"
2844
-
2845
- "addi t5, %[INNER], 0 \n\t"
2846
- "vxor.vv v16, v16, v16 \n\t"
2847
- "vxor.vv v18, v18, v18 \n\t"
2848
- "vxor.vv v20, v20, v20 \n\t"
2849
- "vxor.vv v22, v22, v22 \n\t"
2850
-
2851
- "vfmul.vf v24, v8, f1 \n\t"
2852
- "vfmul.vf v25, v9, f1 \n\t"
2853
- "vfmul.vf v26, v10, f1 \n\t"
2854
- "vfmul.vf v27, v11, f1 \n\t"
2855
- "addi %[CNT], %[CNT], -1 \n\t"
2856
-
2857
- SQ4BIT_KERNEL_LOAD_ZP_16X1
2858
-
2859
- "LOOP_INNER%=: \n\t"
2860
-
2861
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2862
-
2863
- "vsub.vv v0, v0, v8 \n\t"
2864
- "vsub.vv v4, v4, v8 \n\t"
2865
- "vsub.vv v1, v1, v9 \n\t"
2866
- "vsub.vv v5, v5, v9 \n\t"
2867
- "vsub.vv v2, v2, v10 \n\t"
2868
- "vsub.vv v6, v6, v10 \n\t"
2869
- "vsub.vv v3, v3, v11 \n\t"
2870
- "vsub.vv v7, v7, v11 \n\t"
2871
-
2872
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2873
-
2874
- "bnez t5, LOOP_INNER%= \n\t"
2875
- "vsetvli t0, zero, e32, mf2 \n\t"
2876
-
2877
- SQ4BIT_KERNEL_ACC_1X4X4
2878
- "addi s7, s1, 64 \n\t"
2879
-
2880
- "bnez %[CNT], LOOP_K%= \n\t"
2881
-
2882
- "addi t3, zero, 16 \n\t"
2883
- "addi s1, %[C], 16 \n\t"
2884
- "addi s2, %[C], 32 \n\t"
2885
- "addi s3, %[C], 48 \n\t"
2886
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2887
- "vse32.v v28, (%[C]) \n\t"
2888
- "vse32.v v29, (s1) \n\t"
2889
- "vse32.v v30, (s2) \n\t"
2890
- "vse32.v v31, (s3) \n\t"
2891
- "jal x0, END%= \n\t"
2892
-
2893
- "ST_TAIL%=: \n\t"
2894
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2895
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2896
- "vse32.v v28, (%[C]) \n\t"
2897
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2898
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2899
- "vse32.v v29, (s1) \n\t"
2900
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2901
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2902
- "vse32.v v30, (s2) \n\t"
2903
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2904
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2905
- "vse32.v v31, (s3) \n\t"
2906
- "END%=: \n\t"
2907
-
2908
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
2909
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2910
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2911
- }
785
+ __asm__ volatile(
786
+ "vsetvli t0, zero, e32, m4 \n\t"
787
+ "vxor.vv v28, v28, v28 \n\t"
788
+
789
+ "vsetvli t0, zero, e8, m1 \n\t"
790
+ "vmv.v.i v13, 3 \n\t"
791
+ "li s1, 24 \n\t"
792
+ "vsetvli t0, s1, e8, m1 \n\t"
793
+ "vmv.v.i v13, 2 \n\t"
794
+ "vsetvli t0, zero, e8, mf2 \n\t"
795
+ "vmv.v.i v13, 1 \n\t"
796
+ "vsetvli t0, zero, e8, mf4 \n\t"
797
+ "vmv.v.i v13, 0 \n\t"
798
+
799
+ "addi s1, %[B], 0 \n\t"
800
+ "addi s2, %[B], 8 \n\t"
801
+ "addi s3, %[B], 16 \n\t"
802
+ "addi s4, %[B], 24 \n\t"
803
+
804
+ "addi s7, %[B], 32 \n\t"
805
+
806
+ "addi s5, %[A], 0 \n\t"
807
+ "addi s6, %[A], 12 \n\t"
808
+ "LOOP_K%=: \n\t"
809
+ "vsetvli t0, zero, e16, mf4 \n\t"
810
+ "vle16.v v4, (s1) \n\t"
811
+ "addi s1, s1, 48 \n\t"
812
+ "vle16.v v5, (s2) \n\t"
813
+ "addi s2, s2, 72 \n\t"
814
+ "vle16.v v6, (s3) \n\t"
815
+ "addi s3, s3, 96 \n\t"
816
+ "vle16.v v7, (s4) \n\t"
817
+ "addi s4, s4, 120 \n\t"
818
+ "flw f1, (s5) \n\t"
819
+ "addi s5, s5, 4 \n\t"
820
+
821
+ "vfwcvt.f.f.v v8, v4 \n\t"
822
+ "vfwcvt.f.f.v v9, v5 \n\t"
823
+ "vfwcvt.f.f.v v10, v6 \n\t"
824
+ "vfwcvt.f.f.v v11, v7 \n\t"
825
+ "vsetvli t0, zero, e32, mf2 \n\t"
826
+
827
+ "addi t5, %[INNER], 0 \n\t"
828
+ "vxor.vv v16, v16, v16 \n\t"
829
+ "vxor.vv v18, v18, v18 \n\t"
830
+ "vxor.vv v20, v20, v20 \n\t"
831
+ "vxor.vv v22, v22, v22 \n\t"
832
+ "vfmul.vf v24, v8, f1 \n\t"
833
+ "vfmul.vf v25, v9, f1 \n\t"
834
+ "vfmul.vf v26, v10, f1 \n\t"
835
+ "vfmul.vf v27, v11, f1 \n\t"
836
+ "addi %[CNT], %[CNT], -1 \n\t"
837
+
838
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
839
+
840
+ "LOOP_INNER%=: \n\t"
841
+
842
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
843
+
844
+ "vsub.vv v0, v0, v8 \n\t"
845
+ "vsub.vv v4, v4, v8 \n\t"
846
+ "vsub.vv v1, v1, v9 \n\t"
847
+ "vsub.vv v5, v5, v9 \n\t"
848
+ "vsub.vv v2, v2, v10 \n\t"
849
+ "vsub.vv v6, v6, v10 \n\t"
850
+ "vsub.vv v3, v3, v11 \n\t"
851
+ "vsub.vv v7, v7, v11 \n\t"
852
+
853
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
854
+
855
+ "bnez t5, LOOP_INNER%= \n\t"
856
+ "vsetvli t0, zero, e32, mf2 \n\t"
857
+
858
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
859
+ "addi s7, s1, 32 \n\t"
860
+
861
+ "bnez %[CNT], LOOP_K%= \n\t"
862
+ "addi t3, zero, 16 \n\t"
863
+ "addi s1, %[C], 16 \n\t"
864
+ "addi s2, %[C], 32 \n\t"
865
+ "addi s3, %[C], 48 \n\t"
866
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
867
+ "vse32.v v28, (%[C]) \n\t"
868
+ "vse32.v v29, (s1) \n\t"
869
+ "vse32.v v30, (s2) \n\t"
870
+ "vse32.v v31, (s3) \n\t"
871
+ "jal x0, END%= \n\t"
872
+
873
+ "ST_TAIL%=: \n\t"
874
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
875
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
876
+ "vse32.v v28, (%[C]) \n\t"
877
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
878
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
879
+ "vse32.v v29, (s1) \n\t"
880
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
881
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
882
+ "vse32.v v30, (s2) \n\t"
883
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
884
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
885
+ "vse32.v v31, (s3) \n\t"
886
+ "END%=: \n\t"
887
+
888
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
889
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
890
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2912
891
  }
2913
892
  } else {
2914
893
  for (size_t n = 0; n < CountN; n += 16) {
2915
- size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2916
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2917
- n * BlockCountK * BlkLen / 2 + // b data
2918
- n * BlockCountK * sizeof(float); // scale
894
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
895
+ uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
896
+ n * BlockCountK * BlkLen / 2 + // b data
897
+ n * BlockCountK * sizeof(_Float16); // scale
2919
898
  float * CPtr = C + n;
2920
899
  size_t cnt = BlockCountK;
2921
- if (Bias != nullptr) {
2922
- const float * bias = Bias + n;
2923
- __asm__ volatile(
2924
- "addi t3, %[NBLKS], 0 \n\t"
2925
- "addi s1, %[B], 0 \n\t"
2926
- "addi s2, %[B], 16 \n\t"
2927
- "addi s3, %[B], 32 \n\t"
2928
- "addi s4, %[B], 48 \n\t"
2929
- "addi s5, %[A], 0 \n\t"
2930
- "addi s6, %[A], 12 \n\t"
2931
- "vsetvli t0, t3, e32, mf2 \n\t"
2932
- "vle32.v v28, (%[BIAS]) \n\t"
2933
- "sub t3, t3, t0 \n\t"
2934
- "addi %[BIAS], %[BIAS], 16 \n\t"
2935
- "vsetvli t0, t3, e32, mf2 \n\t"
2936
- "vle32.v v29, (%[BIAS]) \n\t"
2937
- "sub t3, t3, t0 \n\t"
2938
- "addi %[BIAS], %[BIAS], 16 \n\t"
2939
- "vsetvli t0, t3, e32, mf2 \n\t"
2940
- "vle32.v v30, (%[BIAS]) \n\t"
2941
- "sub t3, t3, t0 \n\t"
2942
- "addi %[BIAS], %[BIAS], 16 \n\t"
2943
- "vsetvli t0, t3, e32, mf2 \n\t"
2944
- "vle32.v v31, (%[BIAS]) \n\t"
2945
- "vsetvli t0, zero, e32, mf2 \n\t"
2946
- "LOOP_K%=: \n\t"
2947
- "vle32.v v8, (s1) \n\t"
2948
- "addi s1, s1, 64 \n\t"
2949
- "vle32.v v9, (s2) \n\t"
2950
- "addi s2, s2, 80 \n\t"
2951
- "vle32.v v10, (s3) \n\t"
2952
- "addi s3, s3, 96 \n\t"
2953
- "vle32.v v11, (s4) \n\t"
2954
- "addi s4, s4, 112 \n\t"
2955
- "flw f1, (s5) \n\t"
2956
- "addi s5, s5, 4 \n\t"
2957
-
2958
- "addi t5, %[INNER], 0 \n\t"
2959
- "vxor.vv v16, v16, v16 \n\t"
2960
- "vxor.vv v18, v18, v18 \n\t"
2961
- "vxor.vv v20, v20, v20 \n\t"
2962
- "vxor.vv v22, v22, v22 \n\t"
2963
- "vfmul.vf v24, v8, f1 \n\t"
2964
- "vfmul.vf v25, v9, f1 \n\t"
2965
- "vfmul.vf v26, v10, f1 \n\t"
2966
- "vfmul.vf v27, v11, f1 \n\t"
2967
- "addi %[CNT], %[CNT], -1 \n\t"
2968
- "vsetvli t0, zero, e8, m1 \n\t"
2969
- "LOOP_INNER%=: \n\t"
2970
-
2971
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2972
-
2973
- "vadd.vi v0, v0, -8 \n\t"
2974
- "vadd.vi v1, v1, -8 \n\t"
2975
- "vadd.vi v2, v2, -8 \n\t"
2976
- "vadd.vi v3, v3, -8 \n\t"
2977
- "vadd.vi v4, v4, -8 \n\t"
2978
- "vadd.vi v5, v5, -8 \n\t"
2979
- "vadd.vi v6, v6, -8 \n\t"
2980
- "vadd.vi v7, v7, -8 \n\t"
2981
-
2982
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2983
-
2984
- "bnez t5, LOOP_INNER%= \n\t"
2985
- "vsetvli t0, zero, e32, mf2 \n\t"
2986
-
2987
- SQ4BIT_KERNEL_ACC_1X4X4
2988
-
2989
- "bnez %[CNT], LOOP_K%= \n\t"
2990
- "addi t3, zero, 16 \n\t"
2991
- "addi s1, %[C], 16 \n\t"
2992
- "addi s2, %[C], 32 \n\t"
2993
- "addi s3, %[C], 48 \n\t"
2994
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2995
- "vse32.v v28, (%[C]) \n\t"
2996
- "vse32.v v29, (s1) \n\t"
2997
- "vse32.v v30, (s2) \n\t"
2998
- "vse32.v v31, (s3) \n\t"
2999
- "jal x0, END%= \n\t"
3000
-
3001
- "ST_TAIL%=: \n\t"
3002
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3003
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3004
- "vse32.v v28, (%[C]) \n\t"
3005
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3006
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3007
- "vse32.v v29, (s1) \n\t"
3008
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3009
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3010
- "vse32.v v30, (s2) \n\t"
3011
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3012
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3013
- "vse32.v v31, (s3) \n\t"
3014
- "END%=: \n\t"
3015
-
3016
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
3017
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
3018
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
3019
- } else {
3020
- __asm__ volatile(
3021
- "vsetvli t0, zero, e32, m4 \n\t"
3022
- "vxor.vv v28, v28, v28 \n\t"
3023
- "addi s1, %[B], 0 \n\t"
3024
- "addi s2, %[B], 16 \n\t"
3025
- "addi s3, %[B], 32 \n\t"
3026
- "addi s4, %[B], 48 \n\t"
3027
-
3028
- "addi s5, %[A], 0 \n\t"
3029
- "addi s6, %[A], 12 \n\t"
3030
- "vsetvli t0, zero, e32, mf2 \n\t"
3031
- "LOOP_K%=: \n\t"
3032
- "vle32.v v8, (s1) \n\t"
3033
- "addi s1, s1, 64 \n\t"
3034
- "vle32.v v9, (s2) \n\t"
3035
- "addi s2, s2, 80 \n\t"
3036
- "vle32.v v10, (s3) \n\t"
3037
- "addi s3, s3, 96 \n\t"
3038
- "vle32.v v11, (s4) \n\t"
3039
- "addi s4, s4, 112 \n\t"
3040
- "flw f1, (s5) \n\t"
3041
- "addi s5, s5, 4 \n\t"
3042
-
3043
- "addi t5, %[INNER], 0 \n\t"
3044
- "vxor.vv v16, v16, v16 \n\t"
3045
- "vxor.vv v18, v18, v18 \n\t"
3046
- "vxor.vv v20, v20, v20 \n\t"
3047
- "vxor.vv v22, v22, v22 \n\t"
3048
- "vfmul.vf v24, v8, f1 \n\t"
3049
- "vfmul.vf v25, v9, f1 \n\t"
3050
- "vfmul.vf v26, v10, f1 \n\t"
3051
- "vfmul.vf v27, v11, f1 \n\t"
3052
- "addi %[CNT], %[CNT], -1 \n\t"
3053
- "vsetvli t0, zero, e8, m1 \n\t"
3054
- "LOOP_INNER%=: \n\t"
3055
-
3056
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
3057
-
3058
- "vadd.vi v0, v0, -8 \n\t"
3059
- "vadd.vi v1, v1, -8 \n\t"
3060
- "vadd.vi v2, v2, -8 \n\t"
3061
- "vadd.vi v3, v3, -8 \n\t"
3062
- "vadd.vi v4, v4, -8 \n\t"
3063
- "vadd.vi v5, v5, -8 \n\t"
3064
- "vadd.vi v6, v6, -8 \n\t"
3065
- "vadd.vi v7, v7, -8 \n\t"
3066
-
3067
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
3068
-
3069
- "bnez t5, LOOP_INNER%= \n\t"
3070
- "vsetvli t0, zero, e32, mf2 \n\t"
3071
-
3072
- SQ4BIT_KERNEL_ACC_1X4X4
3073
-
3074
- "bnez %[CNT], LOOP_K%= \n\t"
3075
- "addi t3, zero, 16 \n\t"
3076
- "addi s1, %[C], 16 \n\t"
3077
- "addi s2, %[C], 32 \n\t"
3078
- "addi s3, %[C], 48 \n\t"
3079
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
3080
- "vse32.v v28, (%[C]) \n\t"
3081
- "vse32.v v29, (s1) \n\t"
3082
- "vse32.v v30, (s2) \n\t"
3083
- "vse32.v v31, (s3) \n\t"
3084
- "jal x0, END%= \n\t"
3085
-
3086
- "ST_TAIL%=: \n\t"
3087
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3088
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3089
- "vse32.v v28, (%[C]) \n\t"
3090
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3091
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3092
- "vse32.v v29, (s1) \n\t"
3093
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3094
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3095
- "vse32.v v30, (s2) \n\t"
3096
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3097
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3098
- "vse32.v v31, (s3) \n\t"
3099
- "END%=: \n\t"
3100
-
3101
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
3102
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
3103
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
3104
- }
3105
- }
3106
- }
3107
- }
3108
-
3109
- template <bool HasZeroPoint>
3110
- inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
3111
- const std::byte * QuantA,
3112
- const std::byte * QuantBData,
3113
- const float * QuantBScale,
3114
- const std::byte * QuantBZeroPoint,
3115
- float * C,
3116
- size_t CountM,
3117
- size_t CountN,
3118
- size_t BlockStrideQuantB,
3119
- const float * Bias,
3120
- const size_t ldc,
3121
- const size_t scalestride) {
3122
- if (scalestride == 4) {
3123
- SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
3124
- CountN, BlockStrideQuantB, Bias, ldc);
3125
-
3126
- } else if (scalestride == 2) {
3127
- SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(
3128
- BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);
3129
- }
3130
- }
3131
900
 
3132
- template <bool HasZeroPoint>
3133
- inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
3134
- const std::byte * QuantA,
3135
- const std::byte * QuantBData,
3136
- const float * QuantBScale,
3137
- const std::byte * QuantBZeroPoint,
3138
- float * C,
3139
- size_t CountM,
3140
- size_t CountN,
3141
- size_t BlockStrideQuantB,
3142
- const float * Bias,
3143
- const size_t ldc,
3144
- const size_t scalestride) {
3145
- if (scalestride == 4) {
3146
- SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
3147
- CountN, BlockStrideQuantB, Bias);
3148
- } else if (scalestride == 2) {
3149
- SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,
3150
- QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);
901
+ __asm__ volatile(
902
+ "vsetvli t0, zero, e32, m4 \n\t"
903
+ "vxor.vv v28, v28, v28 \n\t"
904
+ "addi s1, %[B], 0 \n\t"
905
+ "addi s2, %[B], 8 \n\t"
906
+ "addi s3, %[B], 16 \n\t"
907
+ "addi s4, %[B], 24 \n\t"
908
+
909
+ "addi s5, %[A], 0 \n\t"
910
+ "addi s6, %[A], 12 \n\t"
911
+ "LOOP_K%=: \n\t"
912
+ "vsetvli t0, zero, e16, mf4 \n\t"
913
+ "vle16.v v4, (s1) \n\t"
914
+ "addi s1, s1, 32 \n\t"
915
+ "vle16.v v5, (s2) \n\t"
916
+ "addi s2, s2, 56 \n\t"
917
+ "vle16.v v6, (s3) \n\t"
918
+ "addi s3, s3, 80 \n\t"
919
+ "vle16.v v7, (s4) \n\t"
920
+ "addi s4, s4, 104 \n\t"
921
+ "flw f1, (s5) \n\t"
922
+ "addi s5, s5, 4 \n\t"
923
+
924
+ "vfwcvt.f.f.v v8, v4 \n\t"
925
+ "vfwcvt.f.f.v v9, v5 \n\t"
926
+ "vfwcvt.f.f.v v10, v6 \n\t"
927
+ "vfwcvt.f.f.v v11, v7 \n\t"
928
+ "vsetvli t0, zero, e32, mf2 \n\t"
929
+
930
+ "addi t5, %[INNER], 0 \n\t"
931
+ "vxor.vv v16, v16, v16 \n\t"
932
+ "vxor.vv v18, v18, v18 \n\t"
933
+ "vxor.vv v20, v20, v20 \n\t"
934
+ "vxor.vv v22, v22, v22 \n\t"
935
+ "vfmul.vf v24, v8, f1 \n\t"
936
+ "vfmul.vf v25, v9, f1 \n\t"
937
+ "vfmul.vf v26, v10, f1 \n\t"
938
+ "vfmul.vf v27, v11, f1 \n\t"
939
+ "addi %[CNT], %[CNT], -1 \n\t"
940
+ "vsetvli t0, zero, e8, m1 \n\t"
941
+ "LOOP_INNER%=: \n\t"
942
+
943
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
944
+
945
+ "vadd.vi v0, v0, -8 \n\t"
946
+ "vadd.vi v1, v1, -8 \n\t"
947
+ "vadd.vi v2, v2, -8 \n\t"
948
+ "vadd.vi v3, v3, -8 \n\t"
949
+ "vadd.vi v4, v4, -8 \n\t"
950
+ "vadd.vi v5, v5, -8 \n\t"
951
+ "vadd.vi v6, v6, -8 \n\t"
952
+ "vadd.vi v7, v7, -8 \n\t"
953
+
954
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
955
+
956
+ "bnez t5, LOOP_INNER%= \n\t"
957
+ "vsetvli t0, zero, e32, mf2 \n\t"
958
+
959
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
960
+
961
+ "bnez %[CNT], LOOP_K%= \n\t"
962
+ "addi t3, zero, 16 \n\t"
963
+ "addi s1, %[C], 16 \n\t"
964
+ "addi s2, %[C], 32 \n\t"
965
+ "addi s3, %[C], 48 \n\t"
966
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
967
+ "vse32.v v28, (%[C]) \n\t"
968
+ "vse32.v v29, (s1) \n\t"
969
+ "vse32.v v30, (s2) \n\t"
970
+ "vse32.v v31, (s3) \n\t"
971
+ "jal x0, END%= \n\t"
972
+
973
+ "ST_TAIL%=: \n\t"
974
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
975
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
976
+ "vse32.v v28, (%[C]) \n\t"
977
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
978
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
979
+ "vse32.v v29, (s1) \n\t"
980
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
981
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
982
+ "vse32.v v30, (s2) \n\t"
983
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
984
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
985
+ "vse32.v v31, (s3) \n\t"
986
+ "END%=: \n\t"
987
+
988
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
989
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
990
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
991
+ }
3151
992
  }
3152
993
  }
3153
-
3154
994
  } // namespace
3155
995
 
3156
996
  namespace ime1 {
3157
- size_t gemm_kernel_i8i4(size_t BlkLen,
3158
- const std::byte * QuantA,
3159
- const std::byte * QuantBData,
3160
- const float * QuantBScale,
3161
- const std::byte * QuantBZeroPoint,
3162
- float * C,
3163
- size_t CountM,
3164
- size_t CountN,
3165
- size_t CountK,
3166
- size_t BlockCountK,
3167
- size_t ldc,
3168
- const float * Bias,
3169
- const size_t ScaleStride) {
3170
- GGML_UNUSED(CountM);
3171
- GGML_UNUSED(CountK);
3172
- GGML_UNUSED(ldc);
3173
- if (CountM >= 4) {
3174
- if (QuantBZeroPoint != nullptr) {
3175
- SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
3176
- C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
997
+ size_t gemm_kernel_i8i4(size_t blk_len,
998
+ const uint8_t * quant_a_ptr,
999
+ const uint8_t * quant_b_data,
1000
+ const uint8_t * quant_b_zp,
1001
+ float * c_ptr,
1002
+ size_t count_m,
1003
+ size_t count_n,
1004
+ size_t k_blks,
1005
+ size_t ldc) {
1006
+ if (count_m >= 4) {
1007
+ if (quant_b_zp != nullptr) {
1008
+ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<true>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks,
1009
+ ldc);
3177
1010
  } else {
3178
- SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
3179
- QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
3180
- ldc, ScaleStride);
1011
+ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<false>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n,
1012
+ k_blks, ldc);
3181
1013
  }
3182
1014
  return 4;
3183
1015
  } else {
3184
- if (QuantBZeroPoint != nullptr) {
3185
- SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
3186
- C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
1016
+ if (quant_b_zp != nullptr) {
1017
+ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<true>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks,
1018
+ ldc);
3187
1019
  } else {
3188
- SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
3189
- QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
3190
- ldc, ScaleStride);
1020
+ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<false>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n,
1021
+ k_blks, ldc);
3191
1022
  }
3192
1023
  return 1;
3193
1024
  }
3194
1025
  }
3195
1026
  } // namespace ime1
3196
- } // namespace sqnbitgemm_spacemit_ime
1027
+ } // namespace spacemit_kernels