whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -1,8 +1,26 @@
1
+ #include "ggml-impl.h"
1
2
  #include "ggml.h"
2
3
  #include "ime_kernels.h"
4
+ #include "rvv_kernels.h"
3
5
 
4
6
  #include <algorithm>
5
7
  #include <cmath>
8
+ #include <stdexcept>
9
+
10
+ #if !defined(__riscv_v) || !defined(__riscv_v_intrinsic)
11
+ # error "riscv v extension or v_intrinsic not enabled"
12
+ #else
13
+ # include <riscv_vector.h>
14
+ #endif
15
+
16
+ #if !defined(__riscv_zfh)
17
+ # error "riscv zfh extension not enabled"
18
+ #endif
19
+
20
+ #if defined(RISCV64_SPACEMIT_IME1)
21
+ #else
22
+ # error "RISCV64_SPACEMIT_IME1 not defined"
23
+ #endif
6
24
 
7
25
  // clang-format off
8
26
  #if defined(__GNUC__)
@@ -11,7 +29,7 @@
11
29
  #pragma GCC diagnostic ignored "-Wunused-parameter"
12
30
  #endif
13
31
  // clang-format on
14
- namespace sqnbitgemm_spacemit_ime {
32
+ namespace spacemit_kernels {
15
33
 
16
34
  #define QUANTIZEM4ROW_KERNEL \
17
35
  "vmv.s.x v16, zero \n\t" \
@@ -76,1093 +94,208 @@ namespace sqnbitgemm_spacemit_ime {
76
94
  "vse8.v v31, (s1) \n\t"
77
95
 
78
96
  namespace ime1 {
79
- void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
97
+ void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) {
80
98
  constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
81
99
  const float fone = 1.0f;
82
100
 
83
- if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {
84
- for (size_t row_index = 0; row_index < 4; ++row_index) {
85
- const float * SRC = A + row_index * CountK;
86
- std::byte * DST = QuantA + row_index * sizeof(float);
101
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
102
+ const float * SRC = A + row_index * CountK;
103
+ uint8_t * DST = QuantA + row_index * sizeof(float);
87
104
 
88
- const size_t offset = (4 - row_index) * 4 + row_index * 8;
89
- const size_t stride = 4 * (sizeof(float) + BlkLen);
90
- __asm__ volatile(
91
- "vsetvli t0, zero, e32, m8 \n\t"
92
- "addi t2, %[CountK], 0 \n\t"
93
- "addi a1, %[DST], 0 \n\t"
94
- "blt t2, %[BlkLen], TAIL%= \n\t"
95
-
96
- "LOOP%=: \n\t"
97
- "vsetvli t0, %[BlkLen], e32, m8 \n\t"
98
- "vle32.v v0, (%[SRC]) \n\t"
99
- "sub t2, t2, t0 \n\t"
100
- "slli t1, t0, 2 \n\t"
101
- "add %[SRC], %[SRC], t1 \n\t"
102
- "add s1, a1, %[OFFSET] \n\t"
103
-
104
- QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
105
-
106
- "add a1, a1, %[STRIDE] \n\t"
107
- "bge t2, %[BlkLen], LOOP%= \n\t"
108
-
109
- "TAIL%=: \n\t"
110
- "blez t2, QUIT%= \n\t"
111
- "vsetvli t0, zero, e32, m8 \n\t"
112
- "vxor.vv v16, v16, v16 \n\t"
113
- "vxor.vv v24, v24, v24 \n\t"
114
- "vsetvli t0, t2, e32, m8 \n\t"
115
- "vle32.v v0, (%[SRC]) \n\t"
116
- "add s1, a1, %[OFFSET] \n\t"
117
-
118
- QUANTIZEM4ROW_KERNEL
119
-
120
- "addi t3, %[BlkLen], 0 \n\t"
121
- "addi s2, s1, 0 \n\t"
122
- "vsetvli t0, zero, e8, mf4 \n\t"
123
- "vxor.vv v8, v8, v8 \n\t"
124
- "SET_ZERO%=: \n\t"
125
- "vse8.v v8, (s2) \n\t"
126
- "addi s2, s2, 32 \n\t"
127
- "addi t3, t3, -8 \n\t"
128
- "bnez t3, SET_ZERO%= \n\t"
129
-
130
- QUANTIZEM4ROW_STORE
131
-
132
- "QUIT%=: \n\t"
133
- : [SRC] "+r"(SRC)
134
- : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
135
- [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
136
- : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
137
- }
138
- } else if (BlkLen == 128) {
139
- for (size_t row_index = 0; row_index < 4; ++row_index) {
140
- const float * SRC = A + row_index * CountK;
141
- std::byte * DST = QuantA + row_index * sizeof(float);
142
-
143
- const size_t offset = (4 - row_index) * 4 + row_index * 8;
144
- const size_t stride = 4 * (sizeof(float) + BlkLen);
145
- __asm__ volatile(
146
- "vsetvli t0, zero, e32, m8 \n\t"
147
- "li t6, 32 \n\t"
148
- "addi t2, %[CountK], 0 \n\t"
149
- "addi a1, %[DST], 0 \n\t"
150
- "add s1, a1, %[OFFSET] \n\t"
151
- "blt t2, %[BlkLen], TAIL%= \n\t"
152
-
153
- "LOOP%=: \n\t"
154
- "vsetvli t0, zero, e32, m8 \n\t"
155
- "vle32.v v0, (%[SRC]) \n\t"
156
- "addi %[SRC], %[SRC], 256 \n\t"
157
- "vle32.v v8, (%[SRC]) \n\t"
158
- "addi %[SRC], %[SRC], 256 \n\t"
159
- "addi t2, t2, -128 \n\t"
160
-
161
- "QUANTIZE%=: \n\t"
162
- "add s1, a1, %[OFFSET] \n\t"
163
- "vfabs.v v16, v0 \n\t"
164
- "vfabs.v v24, v8 \n\t"
165
- "vfmax.vv v16, v24, v16 \n\t"
166
- "vfredmax.vs v24, v16, v24 \n\t"
167
- "vfmv.f.s f10, v24 \n\t"
168
- "fmul.s f10, f10, %[RMAXREC] \n\t"
169
- "fsw f10, (a1) \n\t"
170
- "fdiv.s f11, %[FONE], f10 \n\t"
171
- "vfmul.vf v16, v0, f11 \n\t"
172
- "vfmul.vf v24, v8, f11 \n\t"
173
- "vfcvt.x.f.v v16, v16 \n\t"
174
- "vfcvt.x.f.v v24, v24 \n\t"
175
- "vsetvli t0, zero, e16, m4 \n\t"
176
- "vnclip.wx v16, v16, zero \n\t"
177
- "vnclip.wx v20, v24, zero \n\t"
178
- "vsetvli t0, zero, e8, m4 \n\t"
179
- "vnclip.wx v16, v16, zero \n\t"
180
- "vsetvli t0, zero, e64, m4 \n\t"
181
- "vsse64.v v16, (s1), t6 \n\t"
182
- "add a1, a1, %[STRIDE] \n\t"
183
- "bge t2, %[BlkLen], LOOP%= \n\t"
184
-
185
- "TAIL%=: \n\t"
186
- "blez t2, QUIT%= \n\t"
187
- "vsetvli t0, zero, e32, m8 \n\t"
188
- "vxor.vv v0, v0, v0 \n\t"
189
- "vxor.vv v8, v8, v8 \n\t"
190
- "vxor.vv v16, v16, v16 \n\t"
191
- "vxor.vv v24, v24, v24 \n\t"
192
- "vsetvli t0, t2, e32, m8 \n\t"
193
- "sub t2, t2, t0 \n\t"
194
- "vle32.v v0, (%[SRC]) \n\t"
195
- "addi %[SRC], %[SRC], 256 \n\t"
196
- "vsetvli t0, t2, e32, m8 \n\t"
197
- "vle32.v v8, (%[SRC]) \n\t"
198
- "sub t2, t2, t2 \n\t"
199
- "vsetvli t0, zero, e32, m8 \n\t"
200
- "jal x0, QUANTIZE%= \n\t"
201
-
202
- "QUIT%=: \n\t"
203
- : [SRC] "+r"(SRC)
204
- : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
205
- [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
206
- : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
207
- }
208
- } else if (BlkLen == 256) {
209
- for (size_t row_index = 0; row_index < 4; ++row_index) {
210
- const float * SRC = A + row_index * CountK;
211
- std::byte * DST = QuantA + row_index * sizeof(float);
212
- const size_t offset = (4 - row_index) * 4 + row_index * 8;
213
- const size_t stride = 4 * (sizeof(float) + BlkLen);
214
- __asm__ volatile(
215
- "vsetvli t0, zero, e32, m8 \n\t"
216
- "li t6, 32 \n\t"
217
- "addi t2, %[CountK], 0 \n\t"
218
- "addi a1, %[DST], 0 \n\t"
219
- "add s1, a1, %[OFFSET] \n\t"
220
- "blt t2, %[BlkLen], TAIL%= \n\t"
221
-
222
- "LOOP%=: \n\t"
223
- "vsetvli t0, zero, e32, m8 \n\t"
224
- "vle32.v v0, (%[SRC]) \n\t"
225
- "addi %[SRC], %[SRC], 256 \n\t"
226
- "vle32.v v8, (%[SRC]) \n\t"
227
- "addi %[SRC], %[SRC], 256 \n\t"
228
- "vle32.v v16, (%[SRC]) \n\t"
229
- "addi %[SRC], %[SRC], 256 \n\t"
230
- "vle32.v v24, (%[SRC]) \n\t"
231
- "addi %[SRC], %[SRC], -768 \n\t"
232
- "addi t2, t2, -256 \n\t"
233
- "vfabs.v v0, v0 \n\t"
234
- "vfabs.v v8, v8 \n\t"
235
- "vfabs.v v16, v16 \n\t"
236
- "vfabs.v v24, v24 \n\t"
237
- "vfmax.vv v8, v0, v8 \n\t"
238
- "vfmax.vv v24, v24, v16 \n\t"
239
- "vfmax.vv v8, v8, v24 \n\t"
240
- "vfredmax.vs v24, v8, v24 \n\t"
241
- "vfmv.f.s f10, v24 \n\t"
242
- "vle32.v v0, (%[SRC]) \n\t"
243
- "addi %[SRC], %[SRC], 256 \n\t"
244
- "vle32.v v8, (%[SRC]) \n\t"
245
- "addi %[SRC], %[SRC], 256 \n\t"
246
- "vle32.v v16, (%[SRC]) \n\t"
247
- "addi %[SRC], %[SRC], 256 \n\t"
248
- "vle32.v v24, (%[SRC]) \n\t"
249
- "addi %[SRC], %[SRC], 256 \n\t"
250
-
251
- "QUANTIZE%=: \n\t"
252
- "add s1, a1, %[OFFSET] \n\t"
253
- "fmul.s f10, f10, %[RMAXREC] \n\t"
254
- "fsw f10, (a1) \n\t"
255
- "fdiv.s f11, %[FONE], f10 \n\t"
256
- "vfmul.vf v0, v0, f11 \n\t"
257
- "vfmul.vf v8, v8, f11 \n\t"
258
- "vfmul.vf v16, v16, f11 \n\t"
259
- "vfmul.vf v24, v24, f11 \n\t"
260
- "vfcvt.x.f.v v0, v0 \n\t"
261
- "vfcvt.x.f.v v8, v8 \n\t"
262
- "vfcvt.x.f.v v16, v16 \n\t"
263
- "vfcvt.x.f.v v24, v24 \n\t"
264
- "vsetvli t0, zero, e16, m4 \n\t"
265
- "vnclip.wx v0, v0, zero \n\t"
266
- "vnclip.wx v4, v8, zero \n\t"
267
- "vnclip.wx v8, v16, zero \n\t"
268
- "vnclip.wx v12, v24, zero \n\t"
269
- "vsetvli t0, zero, e8, m4 \n\t"
270
- "vnclip.wx v0, v0, zero \n\t"
271
- "vnclip.wx v4, v8, zero \n\t"
272
- "vsetvli t0, zero, e64, m8 \n\t"
273
- "vsse64.v v0, (s1), t6 \n\t"
274
- "add a1, a1, %[STRIDE] \n\t"
275
- "bge t2, %[BlkLen], LOOP%= \n\t"
276
-
277
- "TAIL%=: \n\t"
278
- "blez t2, QUIT%= \n\t"
279
- "vsetvli t0, zero, e32, m8 \n\t"
280
- "vxor.vv v0, v0, v0 \n\t"
281
- "vxor.vv v8, v8, v8 \n\t"
282
- "vxor.vv v16, v16, v16 \n\t"
283
- "vxor.vv v24, v24, v24 \n\t"
284
- "addi t1, t2, 0 \n\t"
285
- "vsetvli t0, t1, e32, m8 \n\t"
286
- "sub t1, t1, t0 \n\t"
287
- "vle32.v v0, (%[SRC]) \n\t"
288
- "addi %[SRC], %[SRC], 256 \n\t"
289
- "vsetvli t0, t1, e32, m8 \n\t"
290
- "sub t1, t1, t0 \n\t"
291
- "vle32.v v8, (%[SRC]) \n\t"
292
- "addi %[SRC], %[SRC], 256 \n\t"
293
- "vsetvli t0, t1, e32, m8 \n\t"
294
- "sub t1, t1, t0 \n\t"
295
- "vle32.v v16, (%[SRC]) \n\t"
296
- "addi %[SRC], %[SRC], 256 \n\t"
297
- "vsetvli t0, t1, e32, m8 \n\t"
298
- "vle32.v v24, (%[SRC]) \n\t"
299
- "addi %[SRC], %[SRC], -768 \n\t"
300
- "vsetvli t0, zero, e32, m8 \n\t"
301
- "vfabs.v v0, v0 \n\t"
302
- "vfabs.v v8, v8 \n\t"
303
- "vfabs.v v16, v16 \n\t"
304
- "vfabs.v v24, v24 \n\t"
305
- "vfmax.vv v8, v0, v8 \n\t"
306
- "vfmax.vv v24, v16, v24 \n\t"
307
- "vfmax.vv v8, v8, v24 \n\t"
308
- "vfredmax.vs v24, v8, v24 \n\t"
309
- "vfmv.f.s f10, v24 \n\t"
310
- "add s1, a1, %[OFFSET] \n\t"
311
- "fmul.s f10, f10, %[RMAXREC] \n\t"
312
- "fsw f10, (a1) \n\t"
313
- "fdiv.s f11, %[FONE], f10 \n\t"
314
- "vsetvli t0, zero, e64, m8 \n\t"
315
- "vxor.vv v0, v0, v0 \n\t"
316
- "vsse64.v v0, (s1), t6 \n\t"
317
-
318
- "TAIL_LOOP%=: \n\t"
319
- "vsetvli t0, zero, e32, m4 \n\t"
320
- "vxor.vv v0, v0, v0 \n\t"
321
- "vsetvli t0, t2, e32, m1 \n\t"
322
- "sub t2, t2, t0 \n\t"
323
- "vle32.v v0, (%[SRC]) \n\t"
324
- "addi %[SRC], %[SRC], 32 \n\t"
325
- "vfmul.vf v1, v0, f11 \n\t"
326
- "vfcvt.x.f.v v2, v1 \n\t"
327
- "vsetvli t0, zero, e16, mf2 \n\t"
328
- "vnclip.wx v3, v2, zero \n\t"
329
- "vsetvli t0, zero, e8, mf4 \n\t"
330
- "vnclip.wx v3, v3, zero \n\t"
331
- "vse8.v v3, (s1) \n\t"
332
- "addi s1, s1, 32 \n\t"
333
- "bnez t2, TAIL_LOOP%= \n\t"
334
-
335
- "QUIT%=: \n\t"
336
- : [SRC] "+r"(SRC)
337
- : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
338
- [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
339
- : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
340
- }
105
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
106
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
107
+ __asm__ volatile(
108
+ "vsetvli t0, zero, e32, m8 \n\t"
109
+ "addi t2, %[CountK], 0 \n\t"
110
+ "addi a1, %[DST], 0 \n\t"
111
+ "blt t2, %[BlkLen], TAIL%= \n\t"
112
+
113
+ "LOOP%=: \n\t"
114
+ "vsetvli t0, %[BlkLen], e32, m8 \n\t"
115
+ "vle32.v v0, (%[SRC]) \n\t"
116
+ "sub t2, t2, t0 \n\t"
117
+ "slli t1, t0, 2 \n\t"
118
+ "add %[SRC], %[SRC], t1 \n\t"
119
+ "add s1, a1, %[OFFSET] \n\t"
120
+
121
+ QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
122
+
123
+ "add a1, a1, %[STRIDE] \n\t"
124
+ "bge t2, %[BlkLen], LOOP%= \n\t"
125
+
126
+ "TAIL%=: \n\t"
127
+ "blez t2, QUIT%= \n\t"
128
+ "vsetvli t0, zero, e32, m8 \n\t"
129
+ "vxor.vv v16, v16, v16 \n\t"
130
+ "vxor.vv v24, v24, v24 \n\t"
131
+ "vsetvli t0, t2, e32, m8 \n\t"
132
+ "vle32.v v0, (%[SRC]) \n\t"
133
+ "add s1, a1, %[OFFSET] \n\t"
134
+
135
+ QUANTIZEM4ROW_KERNEL
136
+
137
+ "addi t3, %[BlkLen], 0 \n\t"
138
+ "addi s2, s1, 0 \n\t"
139
+ "vsetvli t0, zero, e8, mf4 \n\t"
140
+ "vxor.vv v8, v8, v8 \n\t"
141
+ "SET_ZERO%=: \n\t"
142
+ "vse8.v v8, (s2) \n\t"
143
+ "addi s2, s2, 32 \n\t"
144
+ "addi t3, t3, -8 \n\t"
145
+ "bnez t3, SET_ZERO%= \n\t"
146
+
147
+ QUANTIZEM4ROW_STORE
148
+
149
+ "QUIT%=: \n\t"
150
+ : [SRC] "+r"(SRC)
151
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), [CountK] "r"(CountK),
152
+ [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
153
+ : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
341
154
  }
342
155
  }
343
156
 
344
- void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
157
+ void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) {
345
158
  const float * SRC = A;
346
- std::byte * DST = QuantA;
159
+ uint8_t * DST = QuantA;
347
160
  constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
348
161
  const float fone = 1.0f;
349
- std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
162
+ uint8_t * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
350
163
  size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
351
164
 
352
- if (CountK <= BlkLen) {
353
- float max_abs_A = 0.0f;
354
- for (size_t k = 0; k < CountK; k++) {
355
- max_abs_A = std::max(max_abs_A, fabsf(A[k]));
356
- }
357
- float scale_A = max_abs_A * range_max_reciprocal;
358
-
359
- ((float *) QuantA)[0] = scale_A;
360
-
361
- auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
362
-
363
- for (size_t k = 0; k < CountK; k++) {
364
- QuantAData_offset[k] =
365
- (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
366
- (float) std::numeric_limits<int8_t>::max());
367
- }
368
- for (size_t k = CountK; k < BlkLen; k++) {
369
- QuantAData_offset[k] = 0;
370
- }
371
-
372
- return;
373
- }
374
-
375
- if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) {
376
- __asm__ volatile(
377
- "vsetvli t0, zero, e8, m8 \n\t"
378
- "vxor.vv v24, v24, v24 \n\t"
379
- "LOOP%=: \n\t"
380
- "vsetvli t0, %[CNT], e8, m8 \n\t"
381
- "vse8.v v24, (%[DST]) \n\t"
382
- "addi %[DST], %[DST], 128 \n\t"
383
- "sub %[CNT], %[CNT], t0 \n\t"
384
- "bnez %[CNT], LOOP%= \n\t"
385
- : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
386
- :
387
- : "cc", "t0");
388
- }
389
- if (BlkLen == 16) {
390
- float buffer[64] = { 0.0f };
391
- __asm__ volatile(
392
- "addi t3, zero, 16*8 \n\t"
393
- "addi t2, zero, 16 \n\t"
394
- "blt %[K], t3, LOOP_K%= \n\t"
395
- "blt %[K], t2, TAIL%= \n\t"
396
- "LOOP_MAIN%=: \n\t"
397
- "vsetvli t1, zero, e32, m2 \n\t"
398
- "addi %[K], %[K], -128 \n\t"
399
- "vle32.v v0, (%[SRC]) \n\t"
400
- "addi %[SRC], %[SRC], 64 \n\t"
401
- "vle32.v v2, (%[SRC]) \n\t"
402
- "addi %[SRC], %[SRC], 64 \n\t"
403
- "vle32.v v4, (%[SRC]) \n\t"
404
- "addi %[SRC], %[SRC], 64 \n\t"
405
- "vle32.v v6, (%[SRC]) \n\t"
406
- "addi %[SRC], %[SRC], 64 \n\t"
407
- "vle32.v v8, (%[SRC]) \n\t"
408
- "addi %[SRC], %[SRC], 64 \n\t"
409
- "vle32.v v10, (%[SRC]) \n\t"
410
- "addi %[SRC], %[SRC], 64 \n\t"
411
- "vle32.v v12, (%[SRC]) \n\t"
412
- "addi %[SRC], %[SRC], 64 \n\t"
413
- "vle32.v v14, (%[SRC]) \n\t"
414
- "addi %[SRC], %[SRC], 64 \n\t"
415
- "addi a1, %[BUFFER], 0 \n\t"
416
- "vfabs.v v16, v0 \n\t"
417
- "vfabs.v v18, v2 \n\t"
418
- "vfabs.v v20, v4 \n\t"
419
- "vfabs.v v22, v6 \n\t"
420
- "vfabs.v v24, v8 \n\t"
421
- "vfabs.v v26, v10 \n\t"
422
- "vfabs.v v28, v12 \n\t"
423
- "vfabs.v v30, v14 \n\t"
424
- "vsetvli t0, zero, e32, m1 \n\t"
425
- "vfmax.vv v16, v16, v17 \n\t"
426
- "vfmax.vv v18, v18, v19 \n\t"
427
- "vfmax.vv v20, v20, v21 \n\t"
428
- "vfmax.vv v22, v22, v23 \n\t"
429
- "vfmax.vv v24, v24, v25 \n\t"
430
- "vfmax.vv v26, v26, v27 \n\t"
431
- "vfmax.vv v28, v28, v29 \n\t"
432
- "vfmax.vv v30, v30, v31 \n\t"
433
- "vse32.v v16, (a1) \n\t"
434
- "addi a1, a1, 32 \n\t"
435
- "vse32.v v18, (a1) \n\t"
436
- "addi a1, a1, 32 \n\t"
437
- "vse32.v v20, (a1) \n\t"
438
- "addi a1, a1, 32 \n\t"
439
- "vse32.v v22, (a1) \n\t"
440
- "addi a1, a1, 32 \n\t"
441
- "vse32.v v24, (a1) \n\t"
442
- "addi a1, a1, 32 \n\t"
443
- "vse32.v v26, (a1) \n\t"
444
- "addi a1, a1, 32 \n\t"
445
- "vse32.v v28, (a1) \n\t"
446
- "addi a1, a1, 32 \n\t"
447
- "vse32.v v30, (a1) \n\t"
448
- "addi a1, %[BUFFER], 0 \n\t"
449
- "flw f0, (a1) \n\t"
450
- "flw f1, 4(a1) \n\t"
451
- "flw f2, 8(a1) \n\t"
452
- "flw f3, 12(a1) \n\t"
453
- "flw f4, 16(a1) \n\t"
454
- "flw f5, 20(a1) \n\t"
455
- "flw f6, 24(a1) \n\t"
456
- "flw f7, 28(a1) \n\t"
457
- "addi a1, a1, 32 \n\t"
458
- "fmax.s f1, f0, f1 \n\t"
459
- "fmax.s f3, f2, f3 \n\t"
460
- "fmax.s f5, f4, f5 \n\t"
461
- "fmax.s f7, f6, f7 \n\t"
462
- "fmax.s f3, f1, f3 \n\t"
463
- "fmax.s f7, f5, f7 \n\t"
464
- "fmax.s f10, f3, f7 \n\t"
465
- "fmul.s f10, f10, %[RMAXREC] \n\t"
466
- "fsw f10, (%[DST]) \n\t"
467
- "addi %[DST], %[DST], 20 \n\t"
468
- "fdiv.s f10, %[FONE], f10 \n\t"
469
- "flw f0, (a1) \n\t"
470
- "flw f1, 4(a1) \n\t"
471
- "flw f2, 8(a1) \n\t"
472
- "flw f3, 12(a1) \n\t"
473
- "flw f4, 16(a1) \n\t"
474
- "flw f5, 20(a1) \n\t"
475
- "flw f6, 24(a1) \n\t"
476
- "flw f7, 28(a1) \n\t"
477
- "addi a1, a1, 32 \n\t"
478
- "fmax.s f1, f0, f1 \n\t"
479
- "fmax.s f3, f2, f3 \n\t"
480
- "fmax.s f5, f4, f5 \n\t"
481
- "fmax.s f7, f6, f7 \n\t"
482
- "fmax.s f3, f1, f3 \n\t"
483
- "fmax.s f7, f5, f7 \n\t"
484
- "fmax.s f11, f3, f7 \n\t"
485
- "fmul.s f11, f11, %[RMAXREC] \n\t"
486
- "fsw f11, (%[DST]) \n\t"
487
- "addi %[DST], %[DST], 20 \n\t"
488
- "fdiv.s f11, %[FONE], f11 \n\t"
489
- "flw f0, (a1) \n\t"
490
- "flw f1, 4(a1) \n\t"
491
- "flw f2, 8(a1) \n\t"
492
- "flw f3, 12(a1) \n\t"
493
- "flw f4, 16(a1) \n\t"
494
- "flw f5, 20(a1) \n\t"
495
- "flw f6, 24(a1) \n\t"
496
- "flw f7, 28(a1) \n\t"
497
- "addi a1, a1, 32 \n\t"
498
- "fmax.s f1, f0, f1 \n\t"
499
- "fmax.s f3, f2, f3 \n\t"
500
- "fmax.s f5, f4, f5 \n\t"
501
- "fmax.s f7, f6, f7 \n\t"
502
- "fmax.s f3, f1, f3 \n\t"
503
- "fmax.s f7, f5, f7 \n\t"
504
- "fmax.s f12, f3, f7 \n\t"
505
- "fmul.s f12, f12, %[RMAXREC] \n\t"
506
- "fsw f12, (%[DST]) \n\t"
507
- "addi %[DST], %[DST], 20 \n\t"
508
- "fdiv.s f12, %[FONE], f12 \n\t"
509
- "flw f0, (a1) \n\t"
510
- "flw f1, 4(a1) \n\t"
511
- "flw f2, 8(a1) \n\t"
512
- "flw f3, 12(a1) \n\t"
513
- "flw f4, 16(a1) \n\t"
514
- "flw f5, 20(a1) \n\t"
515
- "flw f6, 24(a1) \n\t"
516
- "flw f7, 28(a1) \n\t"
517
- "addi a1, a1, 32 \n\t"
518
- "fmax.s f1, f0, f1 \n\t"
519
- "fmax.s f3, f2, f3 \n\t"
520
- "fmax.s f5, f4, f5 \n\t"
521
- "fmax.s f7, f6, f7 \n\t"
522
- "fmax.s f3, f1, f3 \n\t"
523
- "fmax.s f7, f5, f7 \n\t"
524
- "fmax.s f13, f3, f7 \n\t"
525
- "fmul.s f13, f13, %[RMAXREC] \n\t"
526
- "fsw f13, (%[DST]) \n\t"
527
- "addi %[DST], %[DST], 20 \n\t"
528
- "fdiv.s f13, %[FONE], f13 \n\t"
529
- "flw f0, (a1) \n\t"
530
- "flw f1, 4(a1) \n\t"
531
- "flw f2, 8(a1) \n\t"
532
- "flw f3, 12(a1) \n\t"
533
- "flw f4, 16(a1) \n\t"
534
- "flw f5, 20(a1) \n\t"
535
- "flw f6, 24(a1) \n\t"
536
- "flw f7, 28(a1) \n\t"
537
- "addi a1, a1, 32 \n\t"
538
- "fmax.s f1, f0, f1 \n\t"
539
- "fmax.s f3, f2, f3 \n\t"
540
- "fmax.s f5, f4, f5 \n\t"
541
- "fmax.s f7, f6, f7 \n\t"
542
- "fmax.s f3, f1, f3 \n\t"
543
- "fmax.s f7, f5, f7 \n\t"
544
- "fmax.s f14, f3, f7 \n\t"
545
- "fmul.s f14, f14, %[RMAXREC] \n\t"
546
- "fsw f14, (%[DST]) \n\t"
547
- "addi %[DST], %[DST], 20 \n\t"
548
- "fdiv.s f14, %[FONE], f14 \n\t"
549
- "flw f0, (a1) \n\t"
550
- "flw f1, 4(a1) \n\t"
551
- "flw f2, 8(a1) \n\t"
552
- "flw f3, 12(a1) \n\t"
553
- "flw f4, 16(a1) \n\t"
554
- "flw f5, 20(a1) \n\t"
555
- "flw f6, 24(a1) \n\t"
556
- "flw f7, 28(a1) \n\t"
557
- "addi a1, a1, 32 \n\t"
558
- "fmax.s f1, f0, f1 \n\t"
559
- "fmax.s f3, f2, f3 \n\t"
560
- "fmax.s f5, f4, f5 \n\t"
561
- "fmax.s f7, f6, f7 \n\t"
562
- "fmax.s f3, f1, f3 \n\t"
563
- "fmax.s f7, f5, f7 \n\t"
564
- "fmax.s f15, f3, f7 \n\t"
565
- "fmul.s f15, f15, %[RMAXREC] \n\t"
566
- "fsw f15, (%[DST]) \n\t"
567
- "addi %[DST], %[DST], 20 \n\t"
568
- "fdiv.s f15, %[FONE], f15 \n\t"
569
- "flw f0, (a1) \n\t"
570
- "flw f1, 4(a1) \n\t"
571
- "flw f2, 8(a1) \n\t"
572
- "flw f3, 12(a1) \n\t"
573
- "flw f4, 16(a1) \n\t"
574
- "flw f5, 20(a1) \n\t"
575
- "flw f6, 24(a1) \n\t"
576
- "flw f7, 28(a1) \n\t"
577
- "addi a1, a1, 32 \n\t"
578
- "fmax.s f1, f0, f1 \n\t"
579
- "fmax.s f3, f2, f3 \n\t"
580
- "fmax.s f5, f4, f5 \n\t"
581
- "fmax.s f7, f6, f7 \n\t"
582
- "fmax.s f3, f1, f3 \n\t"
583
- "fmax.s f7, f5, f7 \n\t"
584
- "fmax.s f16, f3, f7 \n\t"
585
- "fmul.s f16, f16, %[RMAXREC] \n\t"
586
- "fsw f16, (%[DST]) \n\t"
587
- "addi %[DST], %[DST], 20 \n\t"
588
- "fdiv.s f16, %[FONE], f16 \n\t"
589
- "flw f0, (a1) \n\t"
590
- "flw f1, 4(a1) \n\t"
591
- "flw f2, 8(a1) \n\t"
592
- "flw f3, 12(a1) \n\t"
593
- "flw f4, 16(a1) \n\t"
594
- "flw f5, 20(a1) \n\t"
595
- "flw f6, 24(a1) \n\t"
596
- "flw f7, 28(a1) \n\t"
597
- "addi a1, a1, 32 \n\t"
598
- "fmax.s f1, f0, f1 \n\t"
599
- "fmax.s f3, f2, f3 \n\t"
600
- "fmax.s f5, f4, f5 \n\t"
601
- "fmax.s f7, f6, f7 \n\t"
602
- "fmax.s f3, f1, f3 \n\t"
603
- "fmax.s f7, f5, f7 \n\t"
604
- "fmax.s f17, f3, f7 \n\t"
605
- "fmul.s f17, f17, %[RMAXREC] \n\t"
606
- "fsw f17, (%[DST]) \n\t"
607
- "addi %[DST], %[DST], -136 \n\t"
608
- "fdiv.s f17, %[FONE], f17 \n\t"
609
- "vsetvli t0, zero, e32, m2 \n\t"
610
- "vfmul.vf v16, v0, f10 \n\t"
611
- "vfmul.vf v18, v2, f11 \n\t"
612
- "vfmul.vf v20, v4, f12 \n\t"
613
- "vfmul.vf v22, v6, f13 \n\t"
614
- "vfmul.vf v24, v8, f14 \n\t"
615
- "vfmul.vf v26, v10, f15 \n\t"
616
- "vfmul.vf v28, v12, f16 \n\t"
617
- "vfmul.vf v30, v14, f17 \n\t"
618
- "vfcvt.x.f.v v16, v16 \n\t"
619
- "vfcvt.x.f.v v18, v18 \n\t"
620
- "vfcvt.x.f.v v20, v20 \n\t"
621
- "vfcvt.x.f.v v22, v22 \n\t"
622
- "vfcvt.x.f.v v24, v24 \n\t"
623
- "vfcvt.x.f.v v26, v26 \n\t"
624
- "vfcvt.x.f.v v28, v28 \n\t"
625
- "vfcvt.x.f.v v30, v30 \n\t"
626
- "vsetvli t0, zero, e16, m1 \n\t"
627
- "vnclip.wx v16, v16, zero \n\t"
628
- "vnclip.wx v18, v18, zero \n\t"
629
- "vnclip.wx v20, v20, zero \n\t"
630
- "vnclip.wx v22, v22, zero \n\t"
631
- "vnclip.wx v24, v24, zero \n\t"
632
- "vnclip.wx v26, v26, zero \n\t"
633
- "vnclip.wx v28, v28, zero \n\t"
634
- "vnclip.wx v30, v30, zero \n\t"
635
- "vsetvli t0, t1, e8, mf2 \n\t"
636
- "vnclip.wx v16, v16, zero \n\t"
637
- "vnclip.wx v18, v18, zero \n\t"
638
- "vnclip.wx v20, v20, zero \n\t"
639
- "vnclip.wx v22, v22, zero \n\t"
640
- "vnclip.wx v24, v24, zero \n\t"
641
- "vnclip.wx v26, v26, zero \n\t"
642
- "vnclip.wx v28, v28, zero \n\t"
643
- "vnclip.wx v30, v30, zero \n\t"
644
- "vse8.v v16, (%[DST]) \n\t"
645
- "addi %[DST], %[DST], 20 \n\t"
646
- "vse8.v v18, (%[DST]) \n\t"
647
- "addi %[DST], %[DST], 20 \n\t"
648
- "vse8.v v20, (%[DST]) \n\t"
649
- "addi %[DST], %[DST], 20 \n\t"
650
- "vse8.v v22, (%[DST]) \n\t"
651
- "addi %[DST], %[DST], 20 \n\t"
652
- "vse8.v v24, (%[DST]) \n\t"
653
- "addi %[DST], %[DST], 20 \n\t"
654
- "vse8.v v26, (%[DST]) \n\t"
655
- "addi %[DST], %[DST], 20 \n\t"
656
- "vse8.v v28, (%[DST]) \n\t"
657
- "addi %[DST], %[DST], 20 \n\t"
658
- "vse8.v v30, (%[DST]) \n\t"
659
- "addi %[DST], %[DST], 16 \n\t"
660
- "bge %[K], t3, LOOP_MAIN%= \n\t"
661
- "blt %[K], t2, TAIL%= \n\t"
662
- "LOOP_K%=: \n\t"
663
- "vsetvli t1, %[K], e32, m2 \n\t"
664
- "vle32.v v0, (%[SRC]) \n\t"
665
- "addi %[SRC], %[SRC], 64 \n\t"
666
- "sub %[K], %[K], t1 \n\t"
667
- "vfabs.v v16, v0 \n\t"
668
- "vsetvli t0, zero, e32, m1 \n\t"
669
- "vfmax.vv v16, v16, v17 \n\t"
670
- "vse32.v v16, (%[BUFFER]) \n\t"
671
- "flw f0, (%[BUFFER]) \n\t"
672
- "flw f1, 4(%[BUFFER]) \n\t"
673
- "flw f2, 8(%[BUFFER]) \n\t"
674
- "flw f3, 12(%[BUFFER]) \n\t"
675
- "flw f4, 16(%[BUFFER]) \n\t"
676
- "flw f5, 20(%[BUFFER]) \n\t"
677
- "flw f6, 24(%[BUFFER]) \n\t"
678
- "flw f7, 28(%[BUFFER]) \n\t"
679
- "fmax.s f1, f0, f1 \n\t"
680
- "fmax.s f3, f2, f3 \n\t"
681
- "fmax.s f5, f4, f5 \n\t"
682
- "fmax.s f7, f6, f7 \n\t"
683
- "fmax.s f3, f1, f3 \n\t"
684
- "fmax.s f7, f5, f7 \n\t"
685
- "fmax.s f10, f3, f7 \n\t"
686
- "fmul.s f10, f10, %[RMAXREC] \n\t"
687
- "fsw f10, (%[DST]) \n\t"
688
- "addi %[DST], %[DST], 4 \n\t"
689
- "fdiv.s f11, %[FONE], f10 \n\t"
690
- "vsetvli t0, zero, e32, m2 \n\t"
691
- "vfmul.vf v16, v0, f11 \n\t"
692
- "vfcvt.x.f.v v16, v16 \n\t"
693
- "vsetvli t0, zero, e16, m1 \n\t"
694
- "vnclip.wx v16, v16, zero \n\t"
695
- "vsetvli t0, t1, e8, mf2 \n\t"
696
- "vnclip.wx v16, v16, zero \n\t"
697
- "vse8.v v16, (%[DST]) \n\t"
698
- "addi %[DST], %[DST], 16 \n\t"
699
- "bge %[K], t2, LOOP_K%= \n\t"
700
- "TAIL%=: \n\t"
701
- "blez %[K], END%= \n\t"
702
- "vsetvli t0, t3, e32, m2 \n\t"
703
- "vxor.vv v16, v16, v16 \n\t"
704
- "jal x0, LOOP_K%= \n\t"
705
- "END%=: \n\t"
706
- : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
707
- : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer)
708
- : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12",
709
- "f13", "f14", "f15", "f16", "f17");
710
- } else if (BlkLen == 32) {
711
- __asm__ volatile(
712
- "addi t3, zero, 32*4 \n\t"
713
- "addi t2, zero, 32 \n\t"
714
-
715
- "addi a1, %[SRC], 0 \n\t"
716
- "addi a2, %[SRC], 128 \n\t"
717
- "addi a3, %[SRC], 256 \n\t"
718
- "addi a4, %[SRC], 384 \n\t"
719
-
720
- "addi s1, %[DST], 0 \n\t"
721
- "addi s2, %[DST], 36 \n\t"
722
- "addi s3, %[DST], 72 \n\t"
723
- "addi s4, %[DST], 108 \n\t"
724
- "blt %[K], t3, LOOP_K%= \n\t"
725
- "blt %[K], t2, TAIL%= \n\t"
726
-
727
- "LOOP_MAIN%=: \n\t"
728
- "vsetvli t1, zero, e32, m4 \n\t"
729
- "addi %[K], %[K], -128 \n\t"
730
- "vle32.v v0, (a1) \n\t"
731
- "addi a1, a1, 512 \n\t"
732
- "vle32.v v4, (a2) \n\t"
733
- "addi a2, a2, 512 \n\t"
734
- "vle32.v v8, (a3) \n\t"
735
- "addi a3, a3, 512 \n\t"
736
- "vle32.v v12, (a4) \n\t"
737
- "addi a4, a4, 512 \n\t"
738
- "vfabs.v v16, v0 \n\t"
739
- "vfabs.v v20, v4 \n\t"
740
- "vfabs.v v24, v8 \n\t"
741
- "vfabs.v v28, v12 \n\t"
742
- "vsetvli t0, zero, e32, m2 \n\t"
743
- "vfmax.vv v16, v16, v18 \n\t"
744
- "vfmax.vv v20, v20, v22 \n\t"
745
- "vfmax.vv v24, v24, v26 \n\t"
746
- "vfmax.vv v28, v28, v30 \n\t"
747
- "vsetvli t0, zero, e32, m1 \n\t"
748
- "vfmax.vv v16, v16, v17 \n\t"
749
- "vfmax.vv v20, v20, v21 \n\t"
750
- "vfmax.vv v24, v24, v25 \n\t"
751
- "vfmax.vv v28, v28, v29 \n\t"
752
-
753
- "vfredmax.vs v17, v16, v17 \n\t"
754
- "vfredmax.vs v21, v20, v21 \n\t"
755
- "vfredmax.vs v25, v24, v25 \n\t"
756
- "vfredmax.vs v29, v28, v29 \n\t"
757
- "vfmv.f.s f10, v17 \n\t"
758
- "vfmv.f.s f11, v21 \n\t"
759
- "vfmv.f.s f12, v25 \n\t"
760
- "vfmv.f.s f13, v29 \n\t"
761
-
762
- "fmul.s f10, f10, %[RMAXREC] \n\t"
763
- "fmul.s f11, f11, %[RMAXREC] \n\t"
764
- "fmul.s f12, f12, %[RMAXREC] \n\t"
765
- "fmul.s f13, f13, %[RMAXREC] \n\t"
766
- "fsw f10, (s1) \n\t"
767
- "addi s1, s1, 4 \n\t"
768
-
769
- "fsw f11, (s2) \n\t"
770
- "addi s2, s2, 4 \n\t"
771
- "fsw f12, (s3) \n\t"
772
- "addi s3, s3, 4 \n\t"
773
- "fsw f13, (s4) \n\t"
774
- "addi s4, s4, 4 \n\t"
775
- "fdiv.s f10, %[FONE], f10 \n\t"
776
- "fdiv.s f11, %[FONE], f11 \n\t"
777
- "fdiv.s f12, %[FONE], f12 \n\t"
778
- "fdiv.s f13, %[FONE], f13 \n\t"
779
- "vsetvli t0, zero, e32, m4 \n\t"
780
- "vfmul.vf v16, v0, f10 \n\t"
781
- "vfmul.vf v20, v4, f11 \n\t"
782
- "vfmul.vf v24, v8, f12 \n\t"
783
- "vfmul.vf v28, v12, f13 \n\t"
784
- "vfcvt.x.f.v v16, v16 \n\t"
785
- "vfcvt.x.f.v v20, v20 \n\t"
786
- "vfcvt.x.f.v v24, v24 \n\t"
787
- "vfcvt.x.f.v v28, v28 \n\t"
788
- "vsetvli t0, zero, e16, m2 \n\t"
789
- "vnclip.wx v16, v16, zero \n\t"
790
- "vnclip.wx v20, v20, zero \n\t"
791
- "vnclip.wx v24, v24, zero \n\t"
792
- "vnclip.wx v28, v28, zero \n\t"
793
- "vsetvli t0, t1, e8, m1 \n\t"
794
- "vnclip.wx v16, v16, zero \n\t"
795
- "vnclip.wx v20, v20, zero \n\t"
796
- "vnclip.wx v24, v24, zero \n\t"
797
- "vnclip.wx v28, v28, zero \n\t"
798
- "vse8.v v16, (s1) \n\t"
799
- "addi s1, s1, 140 \n\t"
800
- "vse8.v v20, (s2) \n\t"
801
- "addi s2, s2, 140 \n\t"
802
- "vse8.v v24, (s3) \n\t"
803
- "addi s3, s3, 140 \n\t"
804
- "vse8.v v28, (s4) \n\t"
805
- "addi s4, s4, 140 \n\t"
806
- "bge %[K], t3, LOOP_MAIN%= \n\t"
807
- "blt %[K], t2, TAIL%= \n\t"
808
- "LOOP_K%=: \n\t"
809
- "vsetvli t1, %[K], e32, m4 \n\t"
810
- "vle32.v v0, (a1) \n\t"
811
- "addi a1, a1, 128 \n\t"
812
- "sub %[K], %[K], t1 \n\t"
813
- "vfabs.v v16, v0 \n\t"
814
- "vsetvli t0, zero, e32, m2 \n\t"
815
- "vfmax.vv v16, v16, v18 \n\t"
816
- "vsetvli t0, zero, e32, m1 \n\t"
817
- "vfmax.vv v16, v16, v17 \n\t"
818
- "vfredmax.vs v17, v16, v17 \n\t"
819
- "vfmv.f.s f10, v17 \n\t"
820
-
821
- "fmul.s f10, f10, %[RMAXREC] \n\t"
822
- "fsw f10, (s1) \n\t"
823
- "addi s1, s1, 4 \n\t"
824
- "fdiv.s f11, %[FONE], f10 \n\t"
825
- "vsetvli t0, zero, e32, m4 \n\t"
826
- "vfmul.vf v16, v0, f11 \n\t"
827
- "vfcvt.x.f.v v16, v16 \n\t"
828
- "vsetvli t0, zero, e16, m2 \n\t"
829
- "vnclip.wx v16, v16, zero \n\t"
830
- "vsetvli t0, zero, e8, m1 \n\t"
831
- "vnclip.wx v16, v16, zero \n\t"
832
- "vse8.v v16, (s1) \n\t"
833
- "addi s1, s1, 32 \n\t"
834
- "bge %[K], t2, LOOP_K%= \n\t"
835
- "TAIL%=: \n\t"
836
- "blez %[K], END%= \n\t"
837
- "vsetvli t0, t3, e32, m4 \n\t"
838
- "vxor.vv v0, v0, v0 \n\t"
839
- "vxor.vv v16, v16, v16 \n\t"
840
- "jal x0, LOOP_K%= \n\t"
841
- "END%=: \n\t"
842
- : [K] "+r"(CountK)
843
- : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
844
- : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
845
- } else if (BlkLen == 64) {
846
- __asm__ volatile(
847
- "addi t3, zero, 64*2 \n\t"
848
- "addi t2, zero, 64 \n\t"
849
- "addi a1, %[SRC], 0 \n\t"
850
- "addi a2, %[SRC], 256 \n\t"
851
- "addi s1, %[DST], 0 \n\t"
852
- "addi s2, %[DST], 68 \n\t"
853
- "blt %[K], t3, LOOP_K%= \n\t"
854
- "blt %[K], t2, TAIL%= \n\t"
855
- "LOOP_MAIN%=: \n\t"
856
- "vsetvli t1, zero, e32, m8 \n\t"
857
- "addi %[K], %[K], -128 \n\t"
858
- "vle32.v v0, (a1) \n\t"
859
- "addi a1, a1, 512 \n\t"
860
- "vle32.v v8, (a2) \n\t"
861
- "addi a2, a2, 512 \n\t"
862
- "vfabs.v v16, v0 \n\t"
863
- "vfabs.v v24, v8 \n\t"
864
- "vsetvli t0, zero, e32, m4 \n\t"
865
- "vfmax.vv v16, v16, v20 \n\t"
866
- "vfmax.vv v24, v24, v28 \n\t"
867
- "vsetvli t0, zero, e32, m2 \n\t"
868
- "vfmax.vv v16, v16, v18 \n\t"
869
- "vfmax.vv v24, v24, v26 \n\t"
870
- "vsetvli t0, zero, e32, m1 \n\t"
871
- "vfmax.vv v16, v16, v17 \n\t"
872
- "vfmax.vv v24, v24, v25 \n\t"
873
- "vfredmax.vs v17, v16, v17 \n\t"
874
- "vfredmax.vs v25, v24, v25 \n\t"
875
- "vfmv.f.s f10, v17 \n\t"
876
- "vfmv.f.s f11, v25 \n\t"
877
- "fmul.s f10, f10, %[RMAXREC] \n\t"
878
- "fmul.s f11, f11, %[RMAXREC] \n\t"
879
- "fsw f10, (s1) \n\t"
880
- "addi s1, s1, 4 \n\t"
881
- "fsw f11, (s2) \n\t"
882
- "addi s2, s2, 4 \n\t"
883
- "fdiv.s f10, %[FONE], f10 \n\t"
884
- "fdiv.s f11, %[FONE], f11 \n\t"
885
- "vsetvli t0, zero, e32, m8 \n\t"
886
- "vfmul.vf v16, v0, f10 \n\t"
887
- "vfmul.vf v24, v8, f11 \n\t"
888
- "vfcvt.x.f.v v16, v16 \n\t"
889
- "vfcvt.x.f.v v24, v24 \n\t"
890
- "vsetvli t0, zero, e16, m4 \n\t"
891
- "vnclip.wx v16, v16, zero \n\t"
892
- "vnclip.wx v24, v24, zero \n\t"
893
- "vsetvli t0, t1, e8, m2 \n\t"
894
- "vnclip.wx v16, v16, zero \n\t"
895
- "vnclip.wx v24, v24, zero \n\t"
896
- "vse8.v v16, (s1) \n\t"
897
- "addi s1, s1, 132 \n\t"
898
- "vse8.v v24, (s2) \n\t"
899
- "addi s2, s2, 132 \n\t"
900
- "bge %[K], t3, LOOP_MAIN%= \n\t"
901
- "blt %[K], t2, TAIL%= \n\t"
902
- "LOOP_K%=: \n\t"
903
- "vsetvli t1, %[K], e32, m8 \n\t"
904
- "vle32.v v0, (a1) \n\t"
905
- "addi a1, a1, 256 \n\t"
906
- "sub %[K], %[K], t1 \n\t"
907
- "vfabs.v v16, v0 \n\t"
908
- "vsetvli t0, zero, e32, m4 \n\t"
909
- "vfmax.vv v16, v16, v20 \n\t"
910
- "vsetvli t0, zero, e32, m2 \n\t"
911
- "vfmax.vv v16, v16, v18 \n\t"
912
- "vsetvli t0, zero, e32, m1 \n\t"
913
- "vfmax.vv v16, v16, v17 \n\t"
914
- "vfredmax.vs v17, v16, v17 \n\t"
915
- "vfmv.f.s f10, v17 \n\t"
916
- "fmul.s f10, f10, %[RMAXREC] \n\t"
917
- "fsw f10, (s1) \n\t"
918
- "addi s1, s1, 4 \n\t"
919
- "fdiv.s f11, %[FONE], f10 \n\t"
920
- "vsetvli t0, zero, e32, m8 \n\t"
921
- "vfmul.vf v16, v0, f11 \n\t"
922
- "vfcvt.x.f.v v16, v16 \n\t"
923
- "vsetvli t0, zero, e16, m4 \n\t"
924
- "vnclip.wx v16, v16, zero \n\t"
925
- "vsetvli t0, zero, e8, m2 \n\t"
926
- "vnclip.wx v16, v16, zero \n\t"
927
- "vse8.v v16, (s1) \n\t"
928
- "addi s1, s1, 64 \n\t"
929
- "bge %[K], t2, LOOP_K%= \n\t"
930
- "TAIL%=: \n\t"
931
- "blez %[K], END%= \n\t"
932
- "vsetvli t0, t3, e32, m8 \n\t"
933
- "vxor.vv v0, v0, v0 \n\t"
934
- "vxor.vv v16, v16, v16 \n\t"
935
- "jal x0, LOOP_K%= \n\t"
936
- "END%=: \n\t"
937
- : [K] "+r"(CountK)
938
- : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
939
- : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11");
940
- } else if (BlkLen == 128) {
941
- __asm__ volatile(
942
- "addi t2, zero, 128 \n\t"
943
- "addi a1, %[SRC], 0 \n\t"
944
- "addi a2, %[SRC], 256 \n\t"
945
- "blt %[K], t2, TAIL%= \n\t"
946
- "LOOP_K%=: \n\t"
947
- "vsetvli t1, zero, e32, m8 \n\t"
948
- "vle32.v v0, (a1) \n\t"
949
- "addi a1, a1, 512 \n\t"
950
- "vle32.v v8, (a2) \n\t"
951
- "addi a2, a2, 512 \n\t"
952
- "sub %[K], %[K], t2 \n\t"
953
- "QUANT%=: \n\t"
954
- "vfabs.v v16, v0 \n\t"
955
- "vfabs.v v24, v8 \n\t"
956
- "vfmax.vv v24, v16, v24 \n\t"
957
- "vsetvli t1, zero, e32, m4 \n\t"
958
- "vfmax.vv v28, v24, v28 \n\t"
959
- "vsetvli t0, zero, e32, m2 \n\t"
960
- "vfmax.vv v30, v28, v30 \n\t"
961
- "vsetvli t0, zero, e32, m1 \n\t"
962
- "vfmax.vv v30, v30, v31 \n\t"
963
- "vfredmax.vs v31, v30, v31 \n\t"
964
- "vfmv.f.s f10, v31 \n\t"
965
- "fmul.s f10, f10, %[RMAXREC] \n\t"
966
- "fsw f10, (%[DST]) \n\t"
967
- "addi %[DST], %[DST], 4 \n\t"
968
- "fdiv.s f11, %[FONE], f10 \n\t"
969
- "vsetvli t0, zero, e32, m8 \n\t"
970
- "vfmul.vf v16, v0, f11 \n\t"
971
- "vfmul.vf v24, v8, f11 \n\t"
972
- "vfcvt.x.f.v v16, v16 \n\t"
973
- "vfcvt.x.f.v v24, v24 \n\t"
974
- "vsetvli t0, zero, e16, m4 \n\t"
975
- "vnclip.wx v16, v16, zero \n\t"
976
- "vnclip.wx v20, v24, zero \n\t"
977
- "vsetvli t0, zero, e8, m4 \n\t"
978
- "vnclip.wx v16, v16, zero \n\t"
979
- "vse8.v v16, (%[DST]) \n\t"
980
- "addi %[DST], %[DST], 128 \n\t"
981
- "bge %[K], t2, LOOP_K%= \n\t"
982
- "TAIL%=: \n\t"
983
- "blez %[K], END%= \n\t"
984
- "vsetvli t1, zero, e32, m8 \n\t"
985
- "vxor.vv v0, v0, v0 \n\t"
986
- "vxor.vv v8, v8, v8 \n\t"
987
- "vsetvli t0, %[K], e32, m8 \n\t"
988
- "vle32.v v0, (a1) \n\t"
989
- "sub %[K], %[K], t0 \n\t"
990
- "vsetvli t0, %[K], e32, m8 \n\t"
991
- "vle32.v v8, (a2) \n\t"
992
- "sub %[K], %[K], t0 \n\t"
993
- "vsetvli t1, zero, e32, m8 \n\t"
994
- "jal x0, QUANT%= \n\t"
995
- "END%=: \n\t"
996
-
997
- : [DST] "+r"(DST), [K] "+r"(CountK)
998
- : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC)
999
- : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11");
1000
- } else {
1001
- float buffer[8] = { 0.0f };
1002
- size_t cnt = BlkLen / 256;
1003
-
1004
- __asm__ volatile(
1005
- "slli t3, %[BLK], 2 \n\t"
1006
- "blt %[K], %[BLK], LOOP_TAIL%= \n\t"
1007
- "LOOP_MAIN%=: \n\t"
1008
- "vsetvli t0, zero, e32, m1 \n\t"
1009
- "vxor.vv v31, v31, v31 \n\t"
1010
- "vse32.v v31, (%[BUFFER]) \n\t"
1011
- "addi t6, %[CNT], 0 \n\t"
1012
- "LOOP_CMP%=: \n\t"
1013
- "addi t6, t6, -1 \n\t"
1014
- "vsetvli t0, zero, e32, m8 \n\t"
1015
- "vle32.v v0, (%[SRC]) \n\t"
1016
- "addi %[SRC], %[SRC], 256 \n\t"
1017
- "vle32.v v8, (%[SRC]) \n\t"
1018
- "addi %[SRC], %[SRC], 256 \n\t"
1019
- "vle32.v v16, (%[SRC]) \n\t"
1020
- "addi %[SRC], %[SRC], 256 \n\t"
1021
- "vle32.v v24, (%[SRC]) \n\t"
1022
- "addi %[SRC], %[SRC], 256 \n\t"
1023
- "vfabs.v v0, v0 \n\t"
1024
- "vfabs.v v8, v8 \n\t"
1025
- "vfabs.v v16, v16 \n\t"
1026
- "vfabs.v v24, v24 \n\t"
1027
- "vfmax.vv v8, v0, v8 \n\t"
1028
- "vfmax.vv v16, v16, v24 \n\t"
1029
- "vfmax.vv v0, v0, v16 \n\t"
1030
- "vsetvli t0, zero, e32, m4 \n\t"
1031
- "vfmax.vv v0, v0, v4 \n\t"
1032
- "vsetvli t0, zero, e32, m2 \n\t"
1033
- "vfmax.vv v0, v0, v2 \n\t"
1034
- "vsetvli t0, zero, e32, m1 \n\t"
1035
- "vfmax.vv v0, v0, v1 \n\t"
1036
- "vle32.v v30, (%[BUFFER]) \n\t"
1037
- "vfmax.vv v31, v30, v0 \n\t"
1038
- "vse32.v v31, (%[BUFFER]) \n\t"
1039
- "bnez t6, LOOP_CMP%= \n\t"
1040
- "sub %[SRC], %[SRC], t3 \n\t"
1041
- "addi t6, %[CNT], 0 \n\t"
1042
- "flw f0, (%[BUFFER]) \n\t"
1043
- "flw f1, 4(%[BUFFER]) \n\t"
1044
- "flw f2, 8(%[BUFFER]) \n\t"
1045
- "flw f3, 12(%[BUFFER]) \n\t"
1046
- "flw f4, 16(%[BUFFER]) \n\t"
1047
- "flw f5, 20(%[BUFFER]) \n\t"
1048
- "flw f6, 24(%[BUFFER]) \n\t"
1049
- "flw f7, 28(%[BUFFER]) \n\t"
1050
- "fmax.s f1, f0, f1 \n\t"
1051
- "fmax.s f3, f2, f3 \n\t"
1052
- "fmax.s f5, f4, f5 \n\t"
1053
- "fmax.s f7, f6, f7 \n\t"
1054
- "fmax.s f3, f1, f3 \n\t"
1055
- "fmax.s f7, f5, f7 \n\t"
1056
- "fmax.s f10, f3, f7 \n\t"
1057
- "fmul.s f10, f10, %[RMAXREC] \n\t"
1058
- "fsw f10, (%[DST]) \n\t"
1059
- "addi %[DST], %[DST], 4 \n\t"
1060
- "fdiv.s f11, %[FONE], f10 \n\t"
1061
- "addi t6, %[CNT], 0 \n\t"
1062
- "LOOP_QUANT%=: \n\t"
1063
- "addi t6, t6, -1 \n\t"
1064
- "vsetvli t0, zero, e32, m8 \n\t"
1065
- "vle32.v v0, (%[SRC]) \n\t"
1066
- "addi %[SRC], %[SRC], 256 \n\t"
1067
- "vle32.v v8, (%[SRC]) \n\t"
1068
- "addi %[SRC], %[SRC], 256 \n\t"
1069
- "vle32.v v16, (%[SRC]) \n\t"
1070
- "addi %[SRC], %[SRC], 256 \n\t"
1071
- "vle32.v v24, (%[SRC]) \n\t"
1072
- "addi %[SRC], %[SRC], 256 \n\t"
1073
- "vsetvli t0, zero, e32, m8 \n\t"
1074
- "vfmul.vf v0, v0, f11 \n\t"
1075
- "vfmul.vf v8, v8, f11 \n\t"
1076
- "vfmul.vf v16, v16, f11 \n\t"
1077
- "vfmul.vf v24, v24, f11 \n\t"
1078
- "vfcvt.x.f.v v0, v0 \n\t"
1079
- "vfcvt.x.f.v v8, v8 \n\t"
1080
- "vfcvt.x.f.v v16, v16 \n\t"
1081
- "vfcvt.x.f.v v24, v24 \n\t"
1082
- "vsetvli t0, zero, e16, m4 \n\t"
1083
- "vnclip.wx v0, v0, zero \n\t"
1084
- "vnclip.wx v4, v8, zero \n\t"
1085
- "vnclip.wx v8, v16, zero \n\t"
1086
- "vnclip.wx v12, v24, zero \n\t"
1087
- "vsetvli t0, zero, e8, m4 \n\t"
1088
- "vnclip.wx v0, v0, zero \n\t"
1089
- "vnclip.wx v4, v8, zero \n\t"
1090
- "vse8.v v0, (%[DST]) \n\t"
1091
- "addi %[DST], %[DST], 128 \n\t"
1092
- "vse8.v v4, (%[DST]) \n\t"
1093
- "addi %[DST], %[DST], 128 \n\t"
1094
- "bnez t6, LOOP_QUANT%= \n\t"
1095
- "sub %[K], %[K], %[BLK] \n\t"
1096
- "bge %[K], %[BLK], LOOP_MAIN%= \n\t"
1097
- "blez %[K], END%= \n\t"
1098
- "LOOP_TAIL%=: \n\t"
1099
- "vsetvli t0, zero, e32, m1 \n\t"
1100
- "vxor.vv v31, v31, v31 \n\t"
1101
- "vse32.v v31, (%[BUFFER]) \n\t"
1102
- "addi t6, %[K], 0 \n\t"
1103
- "addi s1, %[SRC], 0 \n\t"
1104
- "TAIL_CMP%=: \n\t"
1105
- "vsetvli t0, zero, e32, m8 \n\t"
1106
- "vxor.vv v0, v0, v0 \n\t"
1107
- "vsetvli t0, t6, e32, m8 \n\t"
1108
- "vle32.v v0, (%[SRC]) \n\t"
1109
- "addi %[SRC], %[SRC], 256 \n\t"
1110
- "sub t6, t6, t0 \n\t"
1111
- "vfabs.v v0, v0 \n\t"
1112
- "vsetvli t0, zero, e32, m4 \n\t"
1113
- "vfmax.vv v0, v0, v4 \n\t"
1114
- "vsetvli t0, zero, e32, m2 \n\t"
1115
- "vfmax.vv v0, v0, v2 \n\t"
1116
- "vsetvli t0, zero, e32, m1 \n\t"
1117
- "vfmax.vv v0, v0, v1 \n\t"
1118
- "vle32.v v30, (%[BUFFER]) \n\t"
1119
- "vfmax.vv v31, v30, v0 \n\t"
1120
- "vse32.v v31, (%[BUFFER]) \n\t"
1121
- "bnez t6, TAIL_CMP%= \n\t"
1122
- "addi t6, %[K], 0 \n\t"
1123
- "flw f0, (%[BUFFER]) \n\t"
1124
- "flw f1, 4(%[BUFFER]) \n\t"
1125
- "flw f2, 8(%[BUFFER]) \n\t"
1126
- "flw f3, 12(%[BUFFER]) \n\t"
1127
- "flw f4, 16(%[BUFFER]) \n\t"
1128
- "flw f5, 20(%[BUFFER]) \n\t"
1129
- "flw f6, 24(%[BUFFER]) \n\t"
1130
- "flw f7, 28(%[BUFFER]) \n\t"
1131
- "fmax.s f1, f0, f1 \n\t"
1132
- "fmax.s f3, f2, f3 \n\t"
1133
- "fmax.s f5, f4, f5 \n\t"
1134
- "fmax.s f7, f6, f7 \n\t"
1135
- "fmax.s f3, f1, f3 \n\t"
1136
- "fmax.s f7, f5, f7 \n\t"
1137
- "fmax.s f10, f3, f7 \n\t"
1138
- "fmul.s f10, f10, %[RMAXREC] \n\t"
1139
- "fsw f10, (%[DST]) \n\t"
1140
- "addi %[DST], %[DST], 4 \n\t"
1141
- "fdiv.s f11, %[FONE], f10 \n\t"
1142
- "addi t6, %[K], 0 \n\t"
1143
- "TAIL_QUANT%=: \n\t"
1144
- "vsetvli t0, zero, e32, m8 \n\t"
1145
- "vxor.vv v0, v0, v0 \n\t"
1146
- "vsetvli t1, t6, e32, m8 \n\t"
1147
- "vle32.v v0, (s1) \n\t"
1148
- "addi s1, s1, 256 \n\t"
1149
- "sub t6, t6, t1 \n\t"
1150
- "vsetvli t0, zero, e32, m8 \n\t"
1151
- "vfmul.vf v0, v0, f11 \n\t"
1152
- "vfcvt.x.f.v v0, v0 \n\t"
1153
- "vsetvli t0, zero, e16, m4 \n\t"
1154
- "vnclip.wx v0, v0, zero \n\t"
1155
- "vsetvli t0, t1, e8, m2 \n\t"
1156
- "vnclip.wx v0, v0, zero \n\t"
1157
- "vse8.v v0, (%[DST]) \n\t"
1158
- "addi %[DST], %[DST], 64 \n\t"
1159
- "bnez t6, TAIL_QUANT%= \n\t"
1160
- "END%=: \n\t"
1161
- : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
1162
- : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer),
1163
- [CNT] "r"(cnt)
1164
- : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6");
1165
- }
165
+ __asm__ volatile(
166
+ "addi t3, zero, 32*4 \n\t"
167
+ "addi t2, zero, 32 \n\t"
168
+
169
+ "addi a1, %[SRC], 0 \n\t"
170
+ "addi a2, %[SRC], 128 \n\t"
171
+ "addi a3, %[SRC], 256 \n\t"
172
+ "addi a4, %[SRC], 384 \n\t"
173
+
174
+ "addi s1, %[DST], 0 \n\t"
175
+ "addi s2, %[DST], 36 \n\t"
176
+ "addi s3, %[DST], 72 \n\t"
177
+ "addi s4, %[DST], 108 \n\t"
178
+ "blt %[K], t3, LOOP_K%= \n\t"
179
+ "blt %[K], t2, TAIL%= \n\t"
180
+
181
+ "LOOP_MAIN%=: \n\t"
182
+ "vsetvli t1, zero, e32, m4 \n\t"
183
+ "addi %[K], %[K], -128 \n\t"
184
+ "vle32.v v0, (a1) \n\t"
185
+ "addi a1, a1, 512 \n\t"
186
+ "vle32.v v4, (a2) \n\t"
187
+ "addi a2, a2, 512 \n\t"
188
+ "vle32.v v8, (a3) \n\t"
189
+ "addi a3, a3, 512 \n\t"
190
+ "vle32.v v12, (a4) \n\t"
191
+ "addi a4, a4, 512 \n\t"
192
+ "vfabs.v v16, v0 \n\t"
193
+ "vfabs.v v20, v4 \n\t"
194
+ "vfabs.v v24, v8 \n\t"
195
+ "vfabs.v v28, v12 \n\t"
196
+ "vsetvli t0, zero, e32, m2 \n\t"
197
+ "vfmax.vv v16, v16, v18 \n\t"
198
+ "vfmax.vv v20, v20, v22 \n\t"
199
+ "vfmax.vv v24, v24, v26 \n\t"
200
+ "vfmax.vv v28, v28, v30 \n\t"
201
+ "vsetvli t0, zero, e32, m1 \n\t"
202
+ "vfmax.vv v16, v16, v17 \n\t"
203
+ "vfmax.vv v20, v20, v21 \n\t"
204
+ "vfmax.vv v24, v24, v25 \n\t"
205
+ "vfmax.vv v28, v28, v29 \n\t"
206
+
207
+ "vfredmax.vs v17, v16, v17 \n\t"
208
+ "vfredmax.vs v21, v20, v21 \n\t"
209
+ "vfredmax.vs v25, v24, v25 \n\t"
210
+ "vfredmax.vs v29, v28, v29 \n\t"
211
+ "vfmv.f.s f10, v17 \n\t"
212
+ "vfmv.f.s f11, v21 \n\t"
213
+ "vfmv.f.s f12, v25 \n\t"
214
+ "vfmv.f.s f13, v29 \n\t"
215
+
216
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
217
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
218
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
219
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
220
+ "fsw f10, (s1) \n\t"
221
+ "addi s1, s1, 4 \n\t"
222
+
223
+ "fsw f11, (s2) \n\t"
224
+ "addi s2, s2, 4 \n\t"
225
+ "fsw f12, (s3) \n\t"
226
+ "addi s3, s3, 4 \n\t"
227
+ "fsw f13, (s4) \n\t"
228
+ "addi s4, s4, 4 \n\t"
229
+ "fdiv.s f10, %[FONE], f10 \n\t"
230
+ "fdiv.s f11, %[FONE], f11 \n\t"
231
+ "fdiv.s f12, %[FONE], f12 \n\t"
232
+ "fdiv.s f13, %[FONE], f13 \n\t"
233
+ "vsetvli t0, zero, e32, m4 \n\t"
234
+ "vfmul.vf v16, v0, f10 \n\t"
235
+ "vfmul.vf v20, v4, f11 \n\t"
236
+ "vfmul.vf v24, v8, f12 \n\t"
237
+ "vfmul.vf v28, v12, f13 \n\t"
238
+ "vfcvt.x.f.v v16, v16 \n\t"
239
+ "vfcvt.x.f.v v20, v20 \n\t"
240
+ "vfcvt.x.f.v v24, v24 \n\t"
241
+ "vfcvt.x.f.v v28, v28 \n\t"
242
+ "vsetvli t0, zero, e16, m2 \n\t"
243
+ "vnclip.wx v16, v16, zero \n\t"
244
+ "vnclip.wx v20, v20, zero \n\t"
245
+ "vnclip.wx v24, v24, zero \n\t"
246
+ "vnclip.wx v28, v28, zero \n\t"
247
+ "vsetvli t0, t1, e8, m1 \n\t"
248
+ "vnclip.wx v16, v16, zero \n\t"
249
+ "vnclip.wx v20, v20, zero \n\t"
250
+ "vnclip.wx v24, v24, zero \n\t"
251
+ "vnclip.wx v28, v28, zero \n\t"
252
+ "vse8.v v16, (s1) \n\t"
253
+ "addi s1, s1, 140 \n\t"
254
+ "vse8.v v20, (s2) \n\t"
255
+ "addi s2, s2, 140 \n\t"
256
+ "vse8.v v24, (s3) \n\t"
257
+ "addi s3, s3, 140 \n\t"
258
+ "vse8.v v28, (s4) \n\t"
259
+ "addi s4, s4, 140 \n\t"
260
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
261
+ "blt %[K], t2, TAIL%= \n\t"
262
+ "LOOP_K%=: \n\t"
263
+ "vsetvli t1, %[K], e32, m4 \n\t"
264
+ "vle32.v v0, (a1) \n\t"
265
+ "addi a1, a1, 128 \n\t"
266
+ "sub %[K], %[K], t1 \n\t"
267
+ "vfabs.v v16, v0 \n\t"
268
+ "vsetvli t0, zero, e32, m2 \n\t"
269
+ "vfmax.vv v16, v16, v18 \n\t"
270
+ "vsetvli t0, zero, e32, m1 \n\t"
271
+ "vfmax.vv v16, v16, v17 \n\t"
272
+ "vfredmax.vs v17, v16, v17 \n\t"
273
+ "vfmv.f.s f10, v17 \n\t"
274
+
275
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
276
+ "fsw f10, (s1) \n\t"
277
+ "addi s1, s1, 4 \n\t"
278
+ "fdiv.s f11, %[FONE], f10 \n\t"
279
+ "vsetvli t0, zero, e32, m4 \n\t"
280
+ "vfmul.vf v16, v0, f11 \n\t"
281
+ "vfcvt.x.f.v v16, v16 \n\t"
282
+ "vsetvli t0, zero, e16, m2 \n\t"
283
+ "vnclip.wx v16, v16, zero \n\t"
284
+ "vsetvli t0, zero, e8, m1 \n\t"
285
+ "vnclip.wx v16, v16, zero \n\t"
286
+ "vse8.v v16, (s1) \n\t"
287
+ "addi s1, s1, 32 \n\t"
288
+ "bge %[K], t2, LOOP_K%= \n\t"
289
+ "TAIL%=: \n\t"
290
+ "blez %[K], END%= \n\t"
291
+ "vsetvli t0, t3, e32, m4 \n\t"
292
+ "vxor.vv v0, v0, v0 \n\t"
293
+ "vxor.vv v16, v16, v16 \n\t"
294
+ "jal x0, LOOP_K%= \n\t"
295
+ "END%=: \n\t"
296
+ : [K] "+r"(CountK)
297
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
298
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
1166
299
  }
1167
300
 
1168
301
  } // namespace ime1
@@ -1451,1746 +584,444 @@ namespace {
1451
584
  "vadd.vi v1, v1, -12 \n\t"
1452
585
 
1453
586
  template <bool HasZeroPoint>
1454
- void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
1455
- const std::byte * QuantA,
1456
- const std::byte * QuantBData,
1457
- const float * QuantBScale,
1458
- const std::byte * QuantBZeroPoint,
1459
- float * C,
1460
- size_t CountN,
1461
- size_t BlockCountK,
1462
- const float * Bias,
1463
- const size_t ldc) {
1464
- GGML_UNUSED(QuantBScale);
1465
- GGML_UNUSED(QuantBZeroPoint);
587
+ void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
588
+ const uint8_t * QuantA,
589
+ const uint8_t * QuantBData,
590
+ float * C,
591
+ size_t CountN,
592
+ size_t BlockCountK,
593
+ const size_t ldc) {
1466
594
  size_t LDC = ldc * sizeof(float);
1467
595
  const size_t INNER = BlkLen / 16;
1468
596
  float tmp[4 * 16];
1469
597
 
1470
598
  if constexpr (HasZeroPoint) {
1471
599
  for (size_t n = 0; n < CountN; n += 16) {
1472
- size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
1473
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
1474
- n * BlockCountK * BlkLen / 2 + // b data
1475
- n * BlockCountK * sizeof(uint8_t) + // zp
1476
- n * BlockCountK * sizeof(_Float16); // scale
600
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
601
+ uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
602
+ n * BlockCountK * BlkLen / 2 + // b data
603
+ n * BlockCountK * sizeof(uint8_t) + // zp
604
+ n * BlockCountK * sizeof(_Float16); // scale
1477
605
  float * CPtr = C + n;
1478
606
  if (NBLKS < 16) {
1479
607
  CPtr = tmp;
1480
608
  LDC = 16 * sizeof(float);
1481
609
  }
1482
- if (Bias != nullptr) {
1483
- const float * bias = Bias + n;
1484
- if (NBLKS < 16) {
1485
- __asm__ volatile(
1486
- "vsetvli t0, %[N], e32, m2 \n\t"
1487
- "vle32.v v0, (%[SRC]) \n\t"
1488
- "vse32.v v0, (%[DST]) \n\t"
1489
- :
1490
- : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
1491
- : "cc", "t0");
1492
- bias = tmp;
1493
- }
1494
- __asm__ volatile(LOAD_BIAS
1495
-
1496
- "addi t3, %[BlockCountK], 0 \n\t"
1497
-
1498
- "vsetvli t0, zero, e8, m1 \n\t"
1499
- "li s1, 24 \n\t"
1500
- "vmv.v.i v1, 3 \n\t"
1501
- "vsetvli t0, s1, e8, m1 \n\t"
1502
- "vmv.v.i v1, 2 \n\t"
1503
- "vsetvli t0, zero, e8, mf2 \n\t"
1504
- "vmv.v.i v1, 1 \n\t"
1505
- "vsetvli t0, zero, e8, mf4 \n\t"
1506
- "vmv.v.i v1, 0 \n\t"
1507
-
1508
- "addi a1, %[A], 0 \n\t"
1509
- "addi s1, %[B], 0 \n\t"
1510
-
1511
- "BLOCK_COUNTK_LOOP%=: \n\t"
1512
- // scale offset
1513
- "addi s5, s1, 0 \n\t"
1514
- // zp offset
1515
- "addi s6, s1, 32 \n\t"
1516
- "addi s1, s6, 16 \n\t"
1517
- "addi s2, s1, 32 \n\t"
1518
- "addi s3, s1, 32*2 \n\t"
1519
- "addi s4, s1, 32*3 \n\t"
1520
-
1521
- "vsetvli t0, zero, e32, m8 \n\t"
1522
- "vxor.vv v16, v16, v16 \n\t"
1523
- // load a scale
1524
- "flw f1, (a1) \n\t"
1525
- "flw f2, 4(a1) \n\t"
1526
- "flw f3, 8(a1) \n\t"
1527
- "flw f4, 12(a1) \n\t"
1528
- "addi a1, a1, 16 \n\t"
1529
- "addi t2, %[INNER], 0 \n\t"
1530
-
1531
- SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1532
-
1533
- "BLOCK_INNER_LOOP%=: \n\t"
1534
-
1535
- LOAD_B_16x8x2
1536
-
1537
- "vle8.v v10, (a1) \n\t"
1538
- "addi a1, a1, 32 \n\t"
1539
- "vle8.v v11, (a1) \n\t"
1540
- "addi a1, a1, 32 \n\t"
1541
- "vsub.vv v2, v2, v12 \n\t"
1542
- "vsub.vv v6, v6, v12 \n\t"
1543
- "vsub.vv v3, v3, v13 \n\t"
1544
- "vsub.vv v7, v7, v13 \n\t"
1545
- "vsub.vv v4, v4, v14 \n\t"
1546
- "vsub.vv v8, v8, v14 \n\t"
1547
- "vsub.vv v5, v5, v15 \n\t"
1548
- "vsub.vv v9, v9, v15 \n\t"
1549
-
1550
- SQ4BIT_KERNEL_COMP_4x16x16
1551
-
1552
- "addi t2, t2, -1 \n\t"
1553
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1554
-
1555
- LOAD_SCALE_4x16_FP16
1556
-
1557
- "vsetvli t0, zero, e32, m8 \n\t"
1558
- "vfcvt.f.x.v v16, v16 \n\t"
1559
- "vfmacc.vv v24, v16, v8 \n\t"
1560
- "addi t3, t3, -1 \n\t"
1561
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1562
-
1563
- "RESULT_SAVE%=: \n\t"
1564
-
1565
- SAVE_RESULT_4x16
1566
-
1567
- :
1568
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1569
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
1570
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
1571
- "s2", "s3", "s4", "s5", "s6");
1572
-
1573
- } else {
1574
- __asm__ volatile(
1575
- "vsetvli t0, zero, e32, m8 \n\t"
1576
- "vxor.vv v24, v24, v24 \n\t"
1577
- "addi t3, %[BlockCountK], 0 \n\t"
1578
- "vsetvli t0, zero, e8, m1 \n\t"
1579
- "li s1, 24 \n\t"
1580
- "vmv.v.i v1, 3 \n\t"
1581
- "vsetvli t0, s1, e8, m1 \n\t"
1582
- "vmv.v.i v1, 2 \n\t"
1583
- "vsetvli t0, zero, e8, mf2 \n\t"
1584
- "vmv.v.i v1, 1 \n\t"
1585
- "vsetvli t0, zero, e8, mf4 \n\t"
1586
- "vmv.v.i v1, 0 \n\t"
1587
- "addi a1, %[A], 0 \n\t"
1588
- "addi s1, %[B], 0 \n\t"
1589
- "BLOCK_COUNTK_LOOP%=: \n\t"
1590
- // scale offset
1591
- "addi s5, s1, 0 \n\t"
1592
- // zp offset
1593
- "addi s6, s1, 32 \n\t"
1594
- "addi s1, s6, 16 \n\t"
1595
- "addi s2, s1, 32 \n\t"
1596
- "addi s3, s1, 32*2 \n\t"
1597
- "addi s4, s1, 32*3 \n\t"
1598
-
1599
- "vsetvli t0, zero, e32, m8 \n\t"
1600
- "vxor.vv v16, v16, v16 \n\t"
1601
- // load a scale
1602
- "flw f1, (a1) \n\t"
1603
- "flw f2, 4(a1) \n\t"
1604
- "flw f3, 8(a1) \n\t"
1605
- "flw f4, 12(a1) \n\t"
1606
- "addi a1, a1, 16 \n\t"
1607
- "addi t2, %[INNER], 0 \n\t"
1608
-
1609
- SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1610
-
1611
- "BLOCK_INNER_LOOP%=: \n\t"
1612
-
1613
- LOAD_B_16x8x2
1614
-
1615
- "vle8.v v10, (a1) \n\t"
1616
- "addi a1, a1, 32 \n\t"
1617
- "vle8.v v11, (a1) \n\t"
1618
- "addi a1, a1, 32 \n\t"
1619
- "vsub.vv v2, v2, v12 \n\t"
1620
- "vsub.vv v6, v6, v12 \n\t"
1621
- "vsub.vv v3, v3, v13 \n\t"
1622
- "vsub.vv v7, v7, v13 \n\t"
1623
- "vsub.vv v4, v4, v14 \n\t"
1624
- "vsub.vv v8, v8, v14 \n\t"
1625
- "vsub.vv v5, v5, v15 \n\t"
1626
- "vsub.vv v9, v9, v15 \n\t"
1627
-
1628
- SQ4BIT_KERNEL_COMP_4x16x16
1629
-
1630
- "addi t2, t2, -1 \n\t"
1631
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1632
-
1633
- LOAD_SCALE_4x16_FP16
1634
-
1635
- "vsetvli t0, zero, e32, m8 \n\t"
1636
- "vfcvt.f.x.v v16, v16 \n\t"
1637
- "vfmacc.vv v24, v16, v8 \n\t"
1638
- "addi t3, t3, -1 \n\t"
1639
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1640
-
1641
- "RESULT_SAVE%=: \n\t"
1642
-
1643
- SAVE_RESULT_4x16
1644
-
1645
- :
1646
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1647
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
1648
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
1649
- "s4", "s5", "s6");
1650
- }
1651
- }
1652
- } else {
1653
- for (size_t n = 0; n < CountN; n += 16) {
1654
- size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
1655
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
1656
- n * BlockCountK * BlkLen / 2 + // b data
1657
- n * BlockCountK * sizeof(_Float16); // scale
1658
- float * CPtr = C + n;
1659
- if (NBLKS < 16) {
1660
- CPtr = tmp;
1661
- LDC = 16 * sizeof(float);
1662
- }
1663
- if (Bias != nullptr) {
1664
- const float * bias = Bias + n;
1665
- if (NBLKS < 16) {
1666
- __asm__ volatile(
1667
- "vsetvli t0, %[N], e32, m2 \n\t"
1668
- "vle32.v v0, (%[SRC]) \n\t"
1669
- "vse32.v v0, (%[DST]) \n\t"
1670
- :
1671
- : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
1672
- : "cc", "t0");
1673
- bias = tmp;
1674
- }
1675
- __asm__ volatile(LOAD_BIAS
1676
-
1677
- "addi t3, %[BlockCountK], 0 \n\t"
1678
- "addi a1, %[A], 0 \n\t"
1679
- "addi s1, %[B], 0 \n\t"
1680
- "BLOCK_COUNTK_LOOP%=: \n\t"
1681
- "addi s5, s1, 0 \n\t"
1682
- "addi s1, s5, 32 \n\t"
1683
- "addi s2, s1, 32 \n\t"
1684
- "addi s3, s1, 32*2 \n\t"
1685
- "addi s4, s1, 32*3 \n\t"
1686
- "vsetvli t0, zero, e32, m8 \n\t"
1687
- "vxor.vv v16, v16, v16 \n\t"
1688
- // load a scale
1689
- "flw f1, (a1) \n\t"
1690
- "flw f2, 4(a1) \n\t"
1691
- "flw f3, 8(a1) \n\t"
1692
- "flw f4, 12(a1) \n\t"
1693
- "addi a1, a1, 16 \n\t"
1694
- "addi t2, %[INNER], 0 \n\t"
1695
- "BLOCK_INNER_LOOP%=: \n\t"
1696
-
1697
- LOAD_B_16x8x2
1698
-
1699
- "vsetvli t0, zero, e8, m1 \n\t"
1700
- "vle8.v v10, (a1) \n\t"
1701
- "addi a1, a1, 32 \n\t"
1702
- "vle8.v v11, (a1) \n\t"
1703
- "addi a1, a1, 32 \n\t"
1704
- "vadd.vi v2, v2, -8 \n\t"
1705
- "vadd.vi v3, v3, -8 \n\t"
1706
- "vadd.vi v4, v4, -8 \n\t"
1707
- "vadd.vi v5, v5, -8 \n\t"
1708
- "vadd.vi v6, v6, -8 \n\t"
1709
- "vadd.vi v7, v7, -8 \n\t"
1710
- "vadd.vi v8, v8, -8 \n\t"
1711
- "vadd.vi v9, v9, -8 \n\t"
1712
-
1713
- SQ4BIT_KERNEL_COMP_4x16x16
1714
-
1715
- "addi t2, t2, -1 \n\t"
1716
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1717
-
1718
- LOAD_SCALE_4x16_FP16
1719
-
1720
- "vsetvli t0, zero, e32, m8 \n\t"
1721
- "vfcvt.f.x.v v16, v16 \n\t"
1722
- "vfmacc.vv v24, v16, v8 \n\t"
1723
- "addi t3, t3, -1 \n\t"
1724
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1725
- "RESULT_SAVE%=: \n\t"
1726
-
1727
- SAVE_RESULT_4x16
1728
-
1729
- :
1730
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1731
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
1732
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
1733
- "s2", "s3", "s4", "s5", "s6");
1734
-
1735
- } else {
1736
- __asm__ volatile(
1737
- "vsetvli t0, zero, e32, m8 \n\t"
1738
- "vxor.vv v24, v24, v24 \n\t"
1739
- "addi t3, %[BlockCountK], 0 \n\t"
1740
- "addi a1, %[A], 0 \n\t"
1741
- "addi s1, %[B], 0 \n\t"
1742
- "BLOCK_COUNTK_LOOP%=: \n\t"
1743
- "addi s5, s1, 0 \n\t"
1744
- "addi s1, s5, 32 \n\t"
1745
- "addi s2, s1, 32 \n\t"
1746
- "addi s3, s1, 32*2 \n\t"
1747
- "addi s4, s1, 32*3 \n\t"
1748
- "vsetvli t0, zero, e32, m8 \n\t"
1749
- "vxor.vv v16, v16, v16 \n\t"
1750
- // load a scale
1751
- "flw f1, (a1) \n\t"
1752
- "flw f2, 4(a1) \n\t"
1753
- "flw f3, 8(a1) \n\t"
1754
- "flw f4, 12(a1) \n\t"
1755
- "addi a1, a1, 16 \n\t"
1756
- "addi t2, %[INNER], 0 \n\t"
1757
- "BLOCK_INNER_LOOP%=: \n\t"
1758
-
1759
- LOAD_B_16x8x2
1760
-
1761
- "vsetvli t0, zero, e8, m1 \n\t"
1762
- "vle8.v v10, (a1) \n\t"
1763
- "addi a1, a1, 32 \n\t"
1764
- "vle8.v v11, (a1) \n\t"
1765
- "addi a1, a1, 32 \n\t"
1766
- "vadd.vi v2, v2, -8 \n\t"
1767
- "vadd.vi v3, v3, -8 \n\t"
1768
- "vadd.vi v4, v4, -8 \n\t"
1769
- "vadd.vi v5, v5, -8 \n\t"
1770
- "vadd.vi v6, v6, -8 \n\t"
1771
- "vadd.vi v7, v7, -8 \n\t"
1772
- "vadd.vi v8, v8, -8 \n\t"
1773
- "vadd.vi v9, v9, -8 \n\t"
1774
-
1775
- SQ4BIT_KERNEL_COMP_4x16x16
1776
-
1777
- "addi t2, t2, -1 \n\t"
1778
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1779
-
1780
- LOAD_SCALE_4x16_FP16
1781
-
1782
- "vsetvli t0, zero, e32, m8 \n\t"
1783
- "vfcvt.f.x.v v16, v16 \n\t"
1784
- "vfmacc.vv v24, v16, v8 \n\t"
1785
- "addi t3, t3, -1 \n\t"
1786
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1787
- "RESULT_SAVE%=: \n\t"
1788
-
1789
- SAVE_RESULT_4x16
1790
-
1791
- :
1792
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1793
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
1794
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
1795
- "s4", "s5", "s6");
1796
- }
1797
- }
1798
- }
1799
- if (CountN % 16 != 0) {
1800
- // stroe output from tmp to C when NBLKS less than 16.
1801
- float * CPtr = C + CountN / 16 * 16;
1802
- const size_t N = CountN % 16;
1803
- LDC = ldc * sizeof(float);
1804
- __asm__ volatile(
1805
- "vsetvli t0, %[N], e32, m2 \n\t"
1806
- "vle32.v v0, (%[SRC]) \n\t"
1807
- "addi s2, %[SRC], 64 \n\t"
1808
- "addi s3, %[SRC], 64*2 \n\t"
1809
- "addi s4, %[SRC], 64*3 \n\t"
1810
- "vle32.v v2, (s2) \n\t"
1811
- "vle32.v v4, (s3) \n\t"
1812
- "vle32.v v6, (s4) \n\t"
1813
- "add t2, %[DST], %[LDC] \n\t"
1814
- "add t3, t2, %[LDC] \n\t"
1815
- "add t4, t3, %[LDC] \n\t"
1816
- "vse32.v v0, (%[DST]) \n\t"
1817
- "vse32.v v2, (t2) \n\t"
1818
- "vse32.v v4, (t3) \n\t"
1819
- "vse32.v v6, (t4) \n\t"
1820
- :
1821
- : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
1822
- : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
1823
- }
1824
- }
1825
610
 
1826
- template <bool HasZeroPoint>
1827
- void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen,
1828
- const std::byte * QuantA,
1829
- const std::byte * QuantBData,
1830
- const float * QuantBScale,
1831
- const std::byte * QuantBZeroPoint,
1832
- float * C,
1833
- size_t CountN,
1834
- size_t BlockCountK,
1835
- const float * Bias,
1836
- const size_t ldc) {
1837
- GGML_UNUSED(QuantBScale);
1838
- GGML_UNUSED(QuantBZeroPoint);
1839
- size_t LDC = ldc * sizeof(float);
1840
- const size_t INNER = BlkLen / 16;
1841
- float tmp[4 * 16];
1842
-
1843
- if constexpr (HasZeroPoint) {
1844
- for (size_t n = 0; n < CountN; n += 16) {
1845
- size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
1846
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
1847
- n * BlockCountK * BlkLen / 2 + // b data
1848
- n * BlockCountK * sizeof(uint8_t) + // zp
1849
- n * BlockCountK * sizeof(float); // scale
1850
- float * CPtr = C + n;
1851
- if (NBLKS < 16) {
1852
- CPtr = tmp;
1853
- LDC = 16 * sizeof(float);
1854
- }
1855
- if (Bias != nullptr) {
1856
- const float * bias = Bias + n;
1857
- if (NBLKS < 16) {
1858
- __asm__ volatile(
1859
- "vsetvli t0, %[N], e32, m2 \n\t"
1860
- "vle32.v v0, (%[SRC]) \n\t"
1861
- "vse32.v v0, (%[DST]) \n\t"
1862
- :
1863
- : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
1864
- : "cc", "t0");
1865
- bias = tmp;
1866
- }
1867
-
1868
- __asm__ volatile(LOAD_BIAS
1869
- "addi t3, %[BlockCountK], 0 \n\t"
1870
- "vsetvli t0, zero, e8, m1 \n\t"
1871
- "li s1, 24 \n\t"
1872
- "vmv.v.i v1, 3 \n\t"
1873
- "vsetvli t0, s1, e8, m1 \n\t"
1874
- "vmv.v.i v1, 2 \n\t"
1875
- "vsetvli t0, zero, e8, mf2 \n\t"
1876
- "vmv.v.i v1, 1 \n\t"
1877
- "vsetvli t0, zero, e8, mf4 \n\t"
1878
- "vmv.v.i v1, 0 \n\t"
1879
- "addi a1, %[A], 0 \n\t"
1880
- "addi s1, %[B], 0 \n\t"
1881
- "BLOCK_COUNTK_LOOP%=: \n\t"
1882
- // scale offset
1883
- "addi s5, s1, 0 \n\t"
1884
- // zp offset
1885
- "addi s6, s1, 64 \n\t"
1886
- "addi s1, s6, 16 \n\t"
1887
- "addi s2, s1, 32 \n\t"
1888
- "addi s3, s1, 32*2 \n\t"
1889
- "addi s4, s1, 32*3 \n\t"
1890
- "vsetvli t0, zero, e32, m8 \n\t"
1891
- "vxor.vv v16, v16, v16 \n\t"
1892
- // load a scale
1893
- "flw f1, (a1) \n\t"
1894
- "flw f2, 4(a1) \n\t"
1895
- "flw f3, 8(a1) \n\t"
1896
- "flw f4, 12(a1) \n\t"
1897
- "addi a1, a1, 16 \n\t"
1898
- "addi t2, %[INNER], 0 \n\t"
1899
-
1900
- SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1901
-
1902
- "BLOCK_INNER_LOOP%=: \n\t"
1903
-
1904
- LOAD_B_16x8x2
1905
-
1906
- "vle8.v v10, (a1) \n\t"
1907
- "addi a1, a1, 32 \n\t"
1908
- "vle8.v v11, (a1) \n\t"
1909
- "addi a1, a1, 32 \n\t"
1910
- "vsub.vv v2, v2, v12 \n\t"
1911
- "vsub.vv v6, v6, v12 \n\t"
1912
- "vsub.vv v3, v3, v13 \n\t"
1913
- "vsub.vv v7, v7, v13 \n\t"
1914
- "vsub.vv v4, v4, v14 \n\t"
1915
- "vsub.vv v8, v8, v14 \n\t"
1916
- "vsub.vv v5, v5, v15 \n\t"
1917
- "vsub.vv v9, v9, v15 \n\t"
1918
-
1919
- SQ4BIT_KERNEL_COMP_4x16x16
1920
-
1921
- "addi t2, t2, -1 \n\t"
1922
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
1923
-
1924
- LOAD_SCALE_4x16
1925
-
1926
- "vsetvli t0, zero, e32, m8 \n\t"
1927
- "vfcvt.f.x.v v16, v16 \n\t"
1928
- "vfmacc.vv v24, v16, v8 \n\t"
1929
- "addi t3, t3, -1 \n\t"
1930
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
1931
-
1932
- "RESULT_SAVE%=: \n\t"
1933
-
1934
- SAVE_RESULT_4x16
1935
-
1936
- :
1937
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
1938
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
1939
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
1940
- "s2", "s3", "s4", "s5", "s6");
1941
-
1942
- } else {
1943
- __asm__ volatile(
1944
- "vsetvli t0, zero, e32, m8 \n\t"
1945
- "vxor.vv v24, v24, v24 \n\t"
1946
- "addi t3, %[BlockCountK], 0 \n\t"
1947
- "vsetvli t0, zero, e8, m1 \n\t"
1948
- "li s1, 24 \n\t"
1949
- "vmv.v.i v1, 3 \n\t"
1950
- "vsetvli t0, s1, e8, m1 \n\t"
1951
- "vmv.v.i v1, 2 \n\t"
1952
- "vsetvli t0, zero, e8, mf2 \n\t"
1953
- "vmv.v.i v1, 1 \n\t"
1954
- "vsetvli t0, zero, e8, mf4 \n\t"
1955
- "vmv.v.i v1, 0 \n\t"
1956
- "addi a1, %[A], 0 \n\t"
1957
- "addi s1, %[B], 0 \n\t"
1958
- "BLOCK_COUNTK_LOOP%=: \n\t"
1959
- // scale offset
1960
- "addi s5, s1, 0 \n\t"
1961
- // zp offset
1962
- "addi s6, s1, 64 \n\t"
1963
- "addi s1, s6, 16 \n\t"
1964
- "addi s2, s1, 32 \n\t"
1965
- "addi s3, s1, 32*2 \n\t"
1966
- "addi s4, s1, 32*3 \n\t"
1967
- "vsetvli t0, zero, e32, m8 \n\t"
1968
- "vxor.vv v16, v16, v16 \n\t"
1969
- // load a scale
1970
- // load a scale
1971
- "flw f1, (a1) \n\t"
1972
- "flw f2, 4(a1) \n\t"
1973
- "flw f3, 8(a1) \n\t"
1974
- "flw f4, 12(a1) \n\t"
1975
- "addi a1, a1, 16 \n\t"
1976
- "addi t2, %[INNER], 0 \n\t"
1977
-
1978
- SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1979
-
1980
- "BLOCK_INNER_LOOP%=: \n\t"
1981
-
1982
- LOAD_B_16x8x2
1983
-
1984
- "vle8.v v10, (a1) \n\t"
1985
- "addi a1, a1, 32 \n\t"
1986
- "vle8.v v11, (a1) \n\t"
1987
- "addi a1, a1, 32 \n\t"
1988
- "vsub.vv v2, v2, v12 \n\t"
1989
- "vsub.vv v6, v6, v12 \n\t"
1990
- "vsub.vv v3, v3, v13 \n\t"
1991
- "vsub.vv v7, v7, v13 \n\t"
1992
- "vsub.vv v4, v4, v14 \n\t"
1993
- "vsub.vv v8, v8, v14 \n\t"
1994
- "vsub.vv v5, v5, v15 \n\t"
1995
- "vsub.vv v9, v9, v15 \n\t"
1996
-
1997
- SQ4BIT_KERNEL_COMP_4x16x16
1998
-
1999
- "addi t2, t2, -1 \n\t"
2000
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2001
-
2002
- LOAD_SCALE_4x16
2003
-
2004
- "vsetvli t0, zero, e32, m8 \n\t"
2005
- "vfcvt.f.x.v v16, v16 \n\t"
2006
- "vfmacc.vv v24, v16, v8 \n\t"
2007
- "addi t3, t3, -1 \n\t"
2008
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2009
-
2010
- "RESULT_SAVE%=: \n\t"
2011
-
2012
- SAVE_RESULT_4x16
2013
-
2014
- :
2015
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2016
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
2017
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
2018
- "s4", "s5", "s6");
2019
- }
611
+ __asm__ volatile(
612
+ "vsetvli t0, zero, e32, m8 \n\t"
613
+ "vxor.vv v24, v24, v24 \n\t"
614
+ "addi t3, %[BlockCountK], 0 \n\t"
615
+ "vsetvli t0, zero, e8, m1 \n\t"
616
+ "li s1, 24 \n\t"
617
+ "vmv.v.i v1, 3 \n\t"
618
+ "vsetvli t0, s1, e8, m1 \n\t"
619
+ "vmv.v.i v1, 2 \n\t"
620
+ "vsetvli t0, zero, e8, mf2 \n\t"
621
+ "vmv.v.i v1, 1 \n\t"
622
+ "vsetvli t0, zero, e8, mf4 \n\t"
623
+ "vmv.v.i v1, 0 \n\t"
624
+ "addi a1, %[A], 0 \n\t"
625
+ "addi s1, %[B], 0 \n\t"
626
+ "BLOCK_COUNTK_LOOP%=: \n\t"
627
+ // scale offset
628
+ "addi s5, s1, 0 \n\t"
629
+ // zp offset
630
+ "addi s6, s1, 32 \n\t"
631
+ "addi s1, s6, 16 \n\t"
632
+ "addi s2, s1, 32 \n\t"
633
+ "addi s3, s1, 32*2 \n\t"
634
+ "addi s4, s1, 32*3 \n\t"
635
+
636
+ "vsetvli t0, zero, e32, m8 \n\t"
637
+ "vxor.vv v16, v16, v16 \n\t"
638
+ // load a scale
639
+ "flw f1, (a1) \n\t"
640
+ "flw f2, 4(a1) \n\t"
641
+ "flw f3, 8(a1) \n\t"
642
+ "flw f4, 12(a1) \n\t"
643
+ "addi a1, a1, 16 \n\t"
644
+ "addi t2, %[INNER], 0 \n\t"
645
+
646
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
647
+
648
+ "BLOCK_INNER_LOOP%=: \n\t"
649
+
650
+ LOAD_B_16x8x2
651
+
652
+ "vle8.v v10, (a1) \n\t"
653
+ "addi a1, a1, 32 \n\t"
654
+ "vle8.v v11, (a1) \n\t"
655
+ "addi a1, a1, 32 \n\t"
656
+ "vsub.vv v2, v2, v12 \n\t"
657
+ "vsub.vv v6, v6, v12 \n\t"
658
+ "vsub.vv v3, v3, v13 \n\t"
659
+ "vsub.vv v7, v7, v13 \n\t"
660
+ "vsub.vv v4, v4, v14 \n\t"
661
+ "vsub.vv v8, v8, v14 \n\t"
662
+ "vsub.vv v5, v5, v15 \n\t"
663
+ "vsub.vv v9, v9, v15 \n\t"
664
+
665
+ SQ4BIT_KERNEL_COMP_4x16x16
666
+
667
+ "addi t2, t2, -1 \n\t"
668
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
669
+
670
+ LOAD_SCALE_4x16_FP16
671
+
672
+ "vsetvli t0, zero, e32, m8 \n\t"
673
+ "vfcvt.f.x.v v16, v16 \n\t"
674
+ "vfmacc.vv v24, v16, v8 \n\t"
675
+ "addi t3, t3, -1 \n\t"
676
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
677
+
678
+ "RESULT_SAVE%=: \n\t"
679
+
680
+ SAVE_RESULT_4x16
681
+
682
+ :
683
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
684
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
685
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4",
686
+ "s5", "s6");
2020
687
  }
2021
688
  } else {
2022
689
  for (size_t n = 0; n < CountN; n += 16) {
2023
- size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
2024
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2025
- n * BlockCountK * BlkLen / 2 + // b data
2026
- n * BlockCountK * sizeof(float); // scale
690
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
691
+ uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
692
+ n * BlockCountK * BlkLen / 2 + // b data
693
+ n * BlockCountK * sizeof(_Float16); // scale
2027
694
  float * CPtr = C + n;
2028
695
  if (NBLKS < 16) {
2029
696
  CPtr = tmp;
2030
697
  LDC = 16 * sizeof(float);
2031
698
  }
2032
- if (Bias != nullptr) {
2033
- const float * bias = Bias + n;
2034
- if (NBLKS < 16) {
2035
- __asm__ volatile(
2036
- "vsetvli t0, %[N], e32, m2 \n\t"
2037
- "vle32.v v0, (%[SRC]) \n\t"
2038
- "vse32.v v0, (%[DST]) \n\t"
2039
- :
2040
- : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
2041
- : "cc", "t0");
2042
- bias = tmp;
2043
- }
2044
- __asm__ volatile(LOAD_BIAS
2045
- "addi t3, %[BlockCountK], 0 \n\t"
2046
- "addi a1, %[A], 0 \n\t"
2047
- "addi s1, %[B], 0 \n\t"
2048
- "BLOCK_COUNTK_LOOP%=: \n\t"
2049
- "addi s5, s1, 0 \n\t"
2050
- "addi s1, s5, 64 \n\t"
2051
- "addi s2, s1, 32 \n\t"
2052
- "addi s3, s1, 32*2 \n\t"
2053
- "addi s4, s1, 32*3 \n\t"
2054
- "vsetvli t0, zero, e32, m8 \n\t"
2055
- "vxor.vv v16, v16, v16 \n\t"
2056
- // load a scale
2057
- "flw f1, (a1) \n\t"
2058
- "flw f2, 4(a1) \n\t"
2059
- "flw f3, 8(a1) \n\t"
2060
- "flw f4, 12(a1) \n\t"
2061
- "addi a1, a1, 16 \n\t"
2062
- "addi t2, %[INNER], 0 \n\t"
2063
- "BLOCK_INNER_LOOP%=: \n\t"
2064
-
2065
- LOAD_B_16x8x2
2066
-
2067
- "vsetvli t0, zero, e8, m1 \n\t"
2068
- "vle8.v v10, (a1) \n\t"
2069
- "addi a1, a1, 32 \n\t"
2070
- "vle8.v v11, (a1) \n\t"
2071
- "addi a1, a1, 32 \n\t"
2072
- "vadd.vi v2, v2, -8 \n\t"
2073
- "vadd.vi v3, v3, -8 \n\t"
2074
- "vadd.vi v4, v4, -8 \n\t"
2075
- "vadd.vi v5, v5, -8 \n\t"
2076
- "vadd.vi v6, v6, -8 \n\t"
2077
- "vadd.vi v7, v7, -8 \n\t"
2078
- "vadd.vi v8, v8, -8 \n\t"
2079
- "vadd.vi v9, v9, -8 \n\t"
2080
-
2081
- SQ4BIT_KERNEL_COMP_4x16x16
2082
-
2083
- "addi t2, t2, -1 \n\t"
2084
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2085
-
2086
- LOAD_SCALE_4x16
2087
-
2088
- "vsetvli t0, zero, e32, m8 \n\t"
2089
- "vfcvt.f.x.v v16, v16 \n\t"
2090
- "vfmacc.vv v24, v16, v8 \n\t"
2091
- "addi t3, t3, -1 \n\t"
2092
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2093
-
2094
- "RESULT_SAVE%=: \n\t"
2095
-
2096
- SAVE_RESULT_4x16
2097
-
2098
- :
2099
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2100
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
2101
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
2102
- "s2", "s3", "s4", "s5", "s6");
2103
-
2104
- } else {
2105
- __asm__ volatile(
2106
- "vsetvli t0, zero, e32, m8 \n\t"
2107
- "vxor.vv v24, v24, v24 \n\t"
2108
- "addi t3, %[BlockCountK], 0 \n\t"
2109
- "addi a1, %[A], 0 \n\t"
2110
- "addi s1, %[B], 0 \n\t"
2111
- "BLOCK_COUNTK_LOOP%=: \n\t"
2112
- "addi s5, s1, 0 \n\t"
2113
- "addi s1, s5, 64 \n\t"
2114
- "addi s2, s1, 32 \n\t"
2115
- "addi s3, s1, 32*2 \n\t"
2116
- "addi s4, s1, 32*3 \n\t"
2117
- "vsetvli t0, zero, e32, m8 \n\t"
2118
- "vxor.vv v16, v16, v16 \n\t"
2119
- // load a scale
2120
- "flw f1, (a1) \n\t"
2121
- "flw f2, 4(a1) \n\t"
2122
- "flw f3, 8(a1) \n\t"
2123
- "flw f4, 12(a1) \n\t"
2124
- "addi a1, a1, 16 \n\t"
2125
- "addi t2, %[INNER], 0 \n\t"
2126
- "BLOCK_INNER_LOOP%=: \n\t"
2127
-
2128
- LOAD_B_16x8x2
2129
-
2130
- "vsetvli t0, zero, e8, m1 \n\t"
2131
- "vle8.v v10, (a1) \n\t"
2132
-
2133
- "addi a1, a1, 32 \n\t"
2134
- "vle8.v v11, (a1) \n\t"
2135
- "addi a1, a1, 32 \n\t"
2136
- "vadd.vi v2, v2, -8 \n\t"
2137
- "vadd.vi v3, v3, -8 \n\t"
2138
- "vadd.vi v4, v4, -8 \n\t"
2139
- "vadd.vi v5, v5, -8 \n\t"
2140
- "vadd.vi v6, v6, -8 \n\t"
2141
- "vadd.vi v7, v7, -8 \n\t"
2142
- "vadd.vi v8, v8, -8 \n\t"
2143
- "vadd.vi v9, v9, -8 \n\t"
2144
-
2145
- SQ4BIT_KERNEL_COMP_4x16x16
2146
-
2147
- "addi t2, t2, -1 \n\t"
2148
- "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2149
-
2150
- LOAD_SCALE_4x16
2151
-
2152
- "vsetvli t0, zero, e32, m8 \n\t"
2153
- "vfcvt.f.x.v v16, v16 \n\t"
2154
- "vfmacc.vv v24, v16, v8 \n\t"
2155
- "addi t3, t3, -1 \n\t"
2156
- "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2157
-
2158
- "RESULT_SAVE%=: \n\t"
2159
-
2160
- SAVE_RESULT_4x16
2161
-
2162
- :
2163
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2164
- [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
2165
- : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
2166
- "s4", "s5", "s6");
2167
- }
699
+
700
+ __asm__ volatile(
701
+ "vsetvli t0, zero, e32, m8 \n\t"
702
+ "vxor.vv v24, v24, v24 \n\t"
703
+ "addi t3, %[BlockCountK], 0 \n\t"
704
+ "addi a1, %[A], 0 \n\t"
705
+ "addi s1, %[B], 0 \n\t"
706
+ "BLOCK_COUNTK_LOOP%=: \n\t"
707
+ "addi s5, s1, 0 \n\t"
708
+ "addi s1, s5, 32 \n\t"
709
+ "addi s2, s1, 32 \n\t"
710
+ "addi s3, s1, 32*2 \n\t"
711
+ "addi s4, s1, 32*3 \n\t"
712
+ "vsetvli t0, zero, e32, m8 \n\t"
713
+ "vxor.vv v16, v16, v16 \n\t"
714
+ // load a scale
715
+ "flw f1, (a1) \n\t"
716
+ "flw f2, 4(a1) \n\t"
717
+ "flw f3, 8(a1) \n\t"
718
+ "flw f4, 12(a1) \n\t"
719
+ "addi a1, a1, 16 \n\t"
720
+ "addi t2, %[INNER], 0 \n\t"
721
+ "BLOCK_INNER_LOOP%=: \n\t"
722
+
723
+ LOAD_B_16x8x2
724
+
725
+ "vsetvli t0, zero, e8, m1 \n\t"
726
+ "vle8.v v10, (a1) \n\t"
727
+ "addi a1, a1, 32 \n\t"
728
+ "vle8.v v11, (a1) \n\t"
729
+ "addi a1, a1, 32 \n\t"
730
+ "vadd.vi v2, v2, -8 \n\t"
731
+ "vadd.vi v3, v3, -8 \n\t"
732
+ "vadd.vi v4, v4, -8 \n\t"
733
+ "vadd.vi v5, v5, -8 \n\t"
734
+ "vadd.vi v6, v6, -8 \n\t"
735
+ "vadd.vi v7, v7, -8 \n\t"
736
+ "vadd.vi v8, v8, -8 \n\t"
737
+ "vadd.vi v9, v9, -8 \n\t"
738
+
739
+ SQ4BIT_KERNEL_COMP_4x16x16
740
+
741
+ "addi t2, t2, -1 \n\t"
742
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
743
+
744
+ LOAD_SCALE_4x16_FP16
745
+
746
+ "vsetvli t0, zero, e32, m8 \n\t"
747
+ "vfcvt.f.x.v v16, v16 \n\t"
748
+ "vfmacc.vv v24, v16, v8 \n\t"
749
+ "addi t3, t3, -1 \n\t"
750
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
751
+ "RESULT_SAVE%=: \n\t"
752
+
753
+ SAVE_RESULT_4x16
754
+
755
+ :
756
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
757
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
758
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4",
759
+ "s5", "s6");
2168
760
  }
2169
761
  }
2170
- if (CountN % 16 != 0) {
2171
- // stroe output from tmp to C when NBLKS less than 16.
2172
- float * CPtr = C + CountN / 16 * 16;
2173
- const size_t N = CountN % 16;
2174
- LDC = ldc * sizeof(float);
2175
- __asm__ volatile(
2176
- "vsetvli t0, %[N], e32, m2 \n\t"
2177
- "vle32.v v0, (%[SRC]) \n\t"
2178
- "addi s2, %[SRC], 64 \n\t"
2179
- "addi s3, %[SRC], 64*2 \n\t"
2180
- "addi s4, %[SRC], 64*3 \n\t"
2181
- "vle32.v v2, (s2) \n\t"
2182
- "vle32.v v4, (s3) \n\t"
2183
- "vle32.v v6, (s4) \n\t"
2184
- "add t2, %[DST], %[LDC] \n\t"
2185
- "add t3, t2, %[LDC] \n\t"
2186
- "add t4, t3, %[LDC] \n\t"
2187
- "vse32.v v0, (%[DST]) \n\t"
2188
- "vse32.v v2, (t2) \n\t"
2189
- "vse32.v v4, (t3) \n\t"
2190
- "vse32.v v6, (t4) \n\t"
2191
- :
2192
- : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
2193
- : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
2194
- }
2195
762
  }
2196
763
 
2197
764
  template <bool HasZeroPoint>
2198
- void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
2199
- const std::byte * QuantA,
2200
- const std::byte * QuantBData,
2201
- const float * QuantBScale,
2202
- const std::byte * QuantBZeroPoint,
2203
- float * C,
2204
- size_t CountN,
2205
- size_t BlockCountK,
2206
- const float * Bias) {
2207
- GGML_UNUSED(QuantBScale);
2208
- GGML_UNUSED(QuantBZeroPoint);
765
+ void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
766
+ const uint8_t * QuantA,
767
+ const uint8_t * QuantBData,
768
+ float * C,
769
+ size_t CountN,
770
+ size_t BlockCountK,
771
+ const size_t ldc) {
772
+ GGML_UNUSED(ldc);
2209
773
  size_t INNER = BlkLen / 16;
2210
774
 
2211
775
  if constexpr (HasZeroPoint) {
2212
776
  for (size_t n = 0; n < CountN; n += 16) {
2213
- size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2214
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2215
- n * BlockCountK * BlkLen / 2 + // b data
2216
- n * BlockCountK * sizeof(uint8_t) + // zp
2217
- n * BlockCountK * sizeof(_Float16); // scale
2218
- float * CPtr = C + n;
2219
- size_t cnt = BlockCountK;
2220
- if (Bias != nullptr) {
2221
- const float * bias = Bias + n;
2222
- __asm__ volatile(
2223
- "addi t3, %[NBLKS], 0 \n\t"
2224
- "vsetvli t0, zero, e8, m1 \n\t"
2225
-
2226
- "vmv.v.i v13, 3 \n\t"
2227
- "li s1, 24 \n\t"
2228
- "vsetvli t0, s1, e8, m1 \n\t"
2229
- "vmv.v.i v13, 2 \n\t"
2230
- "vsetvli t0, zero, e8, mf2 \n\t"
2231
- "vmv.v.i v13, 1 \n\t"
2232
- "vsetvli t0, zero, e8, mf4 \n\t"
2233
- "vmv.v.i v13, 0 \n\t"
2234
- "addi s1, %[B], 0 \n\t"
2235
- "addi s2, %[B], 8 \n\t"
2236
- "addi s3, %[B], 16 \n\t"
2237
- "addi s4, %[B], 24 \n\t"
2238
- // zp offset
2239
- "addi s7, %[B], 32 \n\t"
2240
- // a offset
2241
- "addi s5, %[A], 0 \n\t"
2242
- "addi s6, %[A], 12 \n\t"
2243
-
2244
- "vsetvli t0, t3, e32, mf2 \n\t"
2245
- "vle32.v v28, (%[BIAS]) \n\t"
2246
- "sub t3, t3, t0 \n\t"
2247
- "addi %[BIAS], %[BIAS], 16 \n\t"
2248
- "vsetvli t0, t3, e32, mf2 \n\t"
2249
- "vle32.v v29, (%[BIAS]) \n\t"
2250
- "sub t3, t3, t0 \n\t"
2251
- "addi %[BIAS], %[BIAS], 16 \n\t"
2252
- "vsetvli t0, t3, e32, mf2 \n\t"
2253
- "vle32.v v30, (%[BIAS]) \n\t"
2254
- "sub t3, t3, t0 \n\t"
2255
- "addi %[BIAS], %[BIAS], 16 \n\t"
2256
- "vsetvli t0, t3, e32, mf2 \n\t"
2257
- "vle32.v v31, (%[BIAS]) \n\t"
2258
-
2259
- "LOOP_K%=: \n\t"
2260
- "vsetvli t0, zero, e16, mf4 \n\t"
2261
-
2262
- "vle16.v v4, (s1) \n\t"
2263
- "addi s1, s1, 48 \n\t"
2264
- "vle16.v v5, (s2) \n\t"
2265
- "addi s2, s2, 72 \n\t"
2266
- "vle16.v v6, (s3) \n\t"
2267
- "addi s3, s3, 96 \n\t"
2268
- "vle16.v v7, (s4) \n\t"
2269
- "addi s4, s4, 120 \n\t"
2270
- "flw f1, (s5) \n\t"
2271
- "addi s5, s5, 4 \n\t"
2272
- "vfwcvt.f.f.v v8, v4 \n\t"
2273
- "vfwcvt.f.f.v v9, v5 \n\t"
2274
- "vfwcvt.f.f.v v10, v6 \n\t"
2275
- "vfwcvt.f.f.v v11, v7 \n\t"
2276
-
2277
- "vsetvli t0, zero, e32, mf2 \n\t"
2278
- "addi t5, %[INNER], 0 \n\t"
2279
- "vxor.vv v16, v16, v16 \n\t"
2280
- "vxor.vv v18, v18, v18 \n\t"
2281
- "vxor.vv v20, v20, v20 \n\t"
2282
- "vxor.vv v22, v22, v22 \n\t"
2283
- "vfmul.vf v24, v8, f1 \n\t"
2284
- "vfmul.vf v25, v9, f1 \n\t"
2285
- "vfmul.vf v26, v10, f1 \n\t"
2286
- "vfmul.vf v27, v11, f1 \n\t"
2287
- "addi %[CNT], %[CNT], -1 \n\t"
2288
-
2289
- SQ4BIT_KERNEL_LOAD_ZP_16X1
2290
-
2291
- "LOOP_INNER%=: \n\t"
2292
-
2293
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2294
-
2295
- "vsub.vv v0, v0, v8 \n\t"
2296
- "vsub.vv v4, v4, v8 \n\t"
2297
- "vsub.vv v1, v1, v9 \n\t"
2298
- "vsub.vv v5, v5, v9 \n\t"
2299
- "vsub.vv v2, v2, v10 \n\t"
2300
- "vsub.vv v6, v6, v10 \n\t"
2301
- "vsub.vv v3, v3, v11 \n\t"
2302
- "vsub.vv v7, v7, v11 \n\t"
2303
-
2304
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2305
-
2306
- "bnez t5, LOOP_INNER%= \n\t"
2307
- "vsetvli t0, zero, e32, mf2 \n\t"
2308
-
2309
- SQ4BIT_KERNEL_ACC_F16_1X4X4
2310
- "addi s7, s1, 32 \n\t"
2311
-
2312
- "bnez %[CNT], LOOP_K%= \n\t"
2313
- "addi t3, zero, 16 \n\t"
2314
- "addi s1, %[C], 16 \n\t"
2315
- "addi s2, %[C], 32 \n\t"
2316
- "addi s3, %[C], 48 \n\t"
2317
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2318
- "vse32.v v28, (%[C]) \n\t"
2319
- "vse32.v v29, (s1) \n\t"
2320
- "vse32.v v30, (s2) \n\t"
2321
- "vse32.v v31, (s3) \n\t"
2322
- "jal x0, END%= \n\t"
2323
-
2324
- "ST_TAIL%=: \n\t"
2325
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2326
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2327
- "vse32.v v28, (%[C]) \n\t"
2328
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2329
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2330
- "vse32.v v29, (s1) \n\t"
2331
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2332
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2333
- "vse32.v v30, (s2) \n\t"
2334
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2335
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2336
- "vse32.v v31, (s3) \n\t"
2337
- "END%=: \n\t"
2338
-
2339
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
2340
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2341
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2342
- } else {
2343
- __asm__ volatile(
2344
- "vsetvli t0, zero, e32, m4 \n\t"
2345
- "vxor.vv v28, v28, v28 \n\t"
2346
-
2347
- "vsetvli t0, zero, e8, m1 \n\t"
2348
- "vmv.v.i v13, 3 \n\t"
2349
- "li s1, 24 \n\t"
2350
- "vsetvli t0, s1, e8, m1 \n\t"
2351
- "vmv.v.i v13, 2 \n\t"
2352
- "vsetvli t0, zero, e8, mf2 \n\t"
2353
- "vmv.v.i v13, 1 \n\t"
2354
- "vsetvli t0, zero, e8, mf4 \n\t"
2355
- "vmv.v.i v13, 0 \n\t"
2356
-
2357
- "addi s1, %[B], 0 \n\t"
2358
- "addi s2, %[B], 8 \n\t"
2359
- "addi s3, %[B], 16 \n\t"
2360
- "addi s4, %[B], 24 \n\t"
2361
-
2362
- "addi s7, %[B], 32 \n\t"
2363
-
2364
- "addi s5, %[A], 0 \n\t"
2365
- "addi s6, %[A], 12 \n\t"
2366
- "LOOP_K%=: \n\t"
2367
- "vsetvli t0, zero, e16, mf4 \n\t"
2368
- "vle16.v v4, (s1) \n\t"
2369
- "addi s1, s1, 48 \n\t"
2370
- "vle16.v v5, (s2) \n\t"
2371
- "addi s2, s2, 72 \n\t"
2372
- "vle16.v v6, (s3) \n\t"
2373
- "addi s3, s3, 96 \n\t"
2374
- "vle16.v v7, (s4) \n\t"
2375
- "addi s4, s4, 120 \n\t"
2376
- "flw f1, (s5) \n\t"
2377
- "addi s5, s5, 4 \n\t"
2378
-
2379
- "vfwcvt.f.f.v v8, v4 \n\t"
2380
- "vfwcvt.f.f.v v9, v5 \n\t"
2381
- "vfwcvt.f.f.v v10, v6 \n\t"
2382
- "vfwcvt.f.f.v v11, v7 \n\t"
2383
- "vsetvli t0, zero, e32, mf2 \n\t"
2384
-
2385
- "addi t5, %[INNER], 0 \n\t"
2386
- "vxor.vv v16, v16, v16 \n\t"
2387
- "vxor.vv v18, v18, v18 \n\t"
2388
- "vxor.vv v20, v20, v20 \n\t"
2389
- "vxor.vv v22, v22, v22 \n\t"
2390
- "vfmul.vf v24, v8, f1 \n\t"
2391
- "vfmul.vf v25, v9, f1 \n\t"
2392
- "vfmul.vf v26, v10, f1 \n\t"
2393
- "vfmul.vf v27, v11, f1 \n\t"
2394
- "addi %[CNT], %[CNT], -1 \n\t"
2395
-
2396
- SQ4BIT_KERNEL_LOAD_ZP_16X1
2397
-
2398
- "LOOP_INNER%=: \n\t"
2399
-
2400
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2401
-
2402
- "vsub.vv v0, v0, v8 \n\t"
2403
- "vsub.vv v4, v4, v8 \n\t"
2404
- "vsub.vv v1, v1, v9 \n\t"
2405
- "vsub.vv v5, v5, v9 \n\t"
2406
- "vsub.vv v2, v2, v10 \n\t"
2407
- "vsub.vv v6, v6, v10 \n\t"
2408
- "vsub.vv v3, v3, v11 \n\t"
2409
- "vsub.vv v7, v7, v11 \n\t"
2410
-
2411
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2412
-
2413
- "bnez t5, LOOP_INNER%= \n\t"
2414
- "vsetvli t0, zero, e32, mf2 \n\t"
2415
-
2416
- SQ4BIT_KERNEL_ACC_F16_1X4X4
2417
- "addi s7, s1, 32 \n\t"
2418
-
2419
- "bnez %[CNT], LOOP_K%= \n\t"
2420
- "addi t3, zero, 16 \n\t"
2421
- "addi s1, %[C], 16 \n\t"
2422
- "addi s2, %[C], 32 \n\t"
2423
- "addi s3, %[C], 48 \n\t"
2424
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2425
- "vse32.v v28, (%[C]) \n\t"
2426
- "vse32.v v29, (s1) \n\t"
2427
- "vse32.v v30, (s2) \n\t"
2428
- "vse32.v v31, (s3) \n\t"
2429
- "jal x0, END%= \n\t"
2430
-
2431
- "ST_TAIL%=: \n\t"
2432
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2433
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2434
- "vse32.v v28, (%[C]) \n\t"
2435
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2436
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2437
- "vse32.v v29, (s1) \n\t"
2438
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2439
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2440
- "vse32.v v30, (s2) \n\t"
2441
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2442
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2443
- "vse32.v v31, (s3) \n\t"
2444
- "END%=: \n\t"
2445
-
2446
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
2447
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2448
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2449
- }
2450
- }
2451
- } else {
2452
- for (size_t n = 0; n < CountN; n += 16) {
2453
- size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2454
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2455
- n * BlockCountK * BlkLen / 2 + // b data
2456
- n * BlockCountK * sizeof(_Float16); // scale
777
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
778
+ uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
779
+ n * BlockCountK * BlkLen / 2 + // b data
780
+ n * BlockCountK * sizeof(uint8_t) + // zp
781
+ n * BlockCountK * sizeof(_Float16); // scale
2457
782
  float * CPtr = C + n;
2458
783
  size_t cnt = BlockCountK;
2459
- if (Bias != nullptr) {
2460
- const float * bias = Bias + n;
2461
- __asm__ volatile(
2462
- "addi t3, %[NBLKS], 0 \n\t"
2463
- "addi s1, %[B], 0 \n\t"
2464
- "addi s2, %[B], 8 \n\t"
2465
- "addi s3, %[B], 16 \n\t"
2466
- "addi s4, %[B], 24 \n\t"
2467
- "addi s5, %[A], 0 \n\t"
2468
- "addi s6, %[A], 12 \n\t"
2469
- "vsetvli t0, t3, e32, mf2 \n\t"
2470
- "vle32.v v28, (%[BIAS]) \n\t"
2471
- "sub t3, t3, t0 \n\t"
2472
- "addi %[BIAS], %[BIAS], 16 \n\t"
2473
- "vsetvli t0, t3, e32, mf2 \n\t"
2474
- "vle32.v v29, (%[BIAS]) \n\t"
2475
- "sub t3, t3, t0 \n\t"
2476
- "addi %[BIAS], %[BIAS], 16 \n\t"
2477
- "vsetvli t0, t3, e32, mf2 \n\t"
2478
- "vle32.v v30, (%[BIAS]) \n\t"
2479
- "sub t3, t3, t0 \n\t"
2480
- "addi %[BIAS], %[BIAS], 16 \n\t"
2481
- "vsetvli t0, t3, e32, mf2 \n\t"
2482
- "vle32.v v31, (%[BIAS]) \n\t"
2483
-
2484
- "LOOP_K%=: \n\t"
2485
- "vsetvli t0, zero, e16, mf4 \n\t"
2486
-
2487
- "vle16.v v4, (s1) \n\t"
2488
- "addi s1, s1, 32 \n\t"
2489
- "vle16.v v5, (s2) \n\t"
2490
- "addi s2, s2, 56 \n\t"
2491
- "vle16.v v6, (s3) \n\t"
2492
- "addi s3, s3, 80 \n\t"
2493
- "vle16.v v7, (s4) \n\t"
2494
- "addi s4, s4, 104 \n\t"
2495
- "flw f1, (s5) \n\t"
2496
- "addi s5, s5, 4 \n\t"
2497
- "vfwcvt.f.f.v v8, v4 \n\t"
2498
- "vfwcvt.f.f.v v9, v5 \n\t"
2499
- "vfwcvt.f.f.v v10, v6 \n\t"
2500
- "vfwcvt.f.f.v v11, v7 \n\t"
2501
-
2502
- "vsetvli t0, zero, e32, mf2 \n\t"
2503
- "addi t5, %[INNER], 0 \n\t"
2504
- "vxor.vv v16, v16, v16 \n\t"
2505
- "vxor.vv v18, v18, v18 \n\t"
2506
- "vxor.vv v20, v20, v20 \n\t"
2507
- "vxor.vv v22, v22, v22 \n\t"
2508
- "vfmul.vf v24, v8, f1 \n\t"
2509
- "vfmul.vf v25, v9, f1 \n\t"
2510
- "vfmul.vf v26, v10, f1 \n\t"
2511
- "vfmul.vf v27, v11, f1 \n\t"
2512
- "addi %[CNT], %[CNT], -1 \n\t"
2513
- "vsetvli t0, zero, e8, m1 \n\t"
2514
- "LOOP_INNER%=: \n\t"
2515
-
2516
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2517
-
2518
- "vadd.vi v0, v0, -8 \n\t"
2519
- "vadd.vi v1, v1, -8 \n\t"
2520
- "vadd.vi v2, v2, -8 \n\t"
2521
- "vadd.vi v3, v3, -8 \n\t"
2522
- "vadd.vi v4, v4, -8 \n\t"
2523
- "vadd.vi v5, v5, -8 \n\t"
2524
- "vadd.vi v6, v6, -8 \n\t"
2525
- "vadd.vi v7, v7, -8 \n\t"
2526
-
2527
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2528
-
2529
- "bnez t5, LOOP_INNER%= \n\t"
2530
- "vsetvli t0, zero, e32, mf2 \n\t"
2531
-
2532
- SQ4BIT_KERNEL_ACC_F16_1X4X4
2533
-
2534
- "bnez %[CNT], LOOP_K%= \n\t"
2535
- "addi t3, zero, 16 \n\t"
2536
- "addi s1, %[C], 16 \n\t"
2537
- "addi s2, %[C], 32 \n\t"
2538
- "addi s3, %[C], 48 \n\t"
2539
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2540
- "vse32.v v28, (%[C]) \n\t"
2541
- "vse32.v v29, (s1) \n\t"
2542
- "vse32.v v30, (s2) \n\t"
2543
- "vse32.v v31, (s3) \n\t"
2544
- "jal x0, END%= \n\t"
2545
-
2546
- "ST_TAIL%=: \n\t"
2547
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2548
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2549
- "vse32.v v28, (%[C]) \n\t"
2550
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2551
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2552
- "vse32.v v29, (s1) \n\t"
2553
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2554
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2555
- "vse32.v v30, (s2) \n\t"
2556
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2557
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2558
- "vse32.v v31, (s3) \n\t"
2559
- "END%=: \n\t"
2560
-
2561
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
2562
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2563
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
2564
- } else {
2565
- __asm__ volatile(
2566
- "vsetvli t0, zero, e32, m4 \n\t"
2567
- "vxor.vv v28, v28, v28 \n\t"
2568
- "addi s1, %[B], 0 \n\t"
2569
- "addi s2, %[B], 8 \n\t"
2570
- "addi s3, %[B], 16 \n\t"
2571
- "addi s4, %[B], 24 \n\t"
2572
-
2573
- "addi s5, %[A], 0 \n\t"
2574
- "addi s6, %[A], 12 \n\t"
2575
- "LOOP_K%=: \n\t"
2576
- "vsetvli t0, zero, e16, mf4 \n\t"
2577
- "vle16.v v4, (s1) \n\t"
2578
- "addi s1, s1, 32 \n\t"
2579
- "vle16.v v5, (s2) \n\t"
2580
- "addi s2, s2, 56 \n\t"
2581
- "vle16.v v6, (s3) \n\t"
2582
- "addi s3, s3, 80 \n\t"
2583
- "vle16.v v7, (s4) \n\t"
2584
- "addi s4, s4, 104 \n\t"
2585
- "flw f1, (s5) \n\t"
2586
- "addi s5, s5, 4 \n\t"
2587
-
2588
- "vfwcvt.f.f.v v8, v4 \n\t"
2589
- "vfwcvt.f.f.v v9, v5 \n\t"
2590
- "vfwcvt.f.f.v v10, v6 \n\t"
2591
- "vfwcvt.f.f.v v11, v7 \n\t"
2592
- "vsetvli t0, zero, e32, mf2 \n\t"
2593
-
2594
- "addi t5, %[INNER], 0 \n\t"
2595
- "vxor.vv v16, v16, v16 \n\t"
2596
- "vxor.vv v18, v18, v18 \n\t"
2597
- "vxor.vv v20, v20, v20 \n\t"
2598
- "vxor.vv v22, v22, v22 \n\t"
2599
- "vfmul.vf v24, v8, f1 \n\t"
2600
- "vfmul.vf v25, v9, f1 \n\t"
2601
- "vfmul.vf v26, v10, f1 \n\t"
2602
- "vfmul.vf v27, v11, f1 \n\t"
2603
- "addi %[CNT], %[CNT], -1 \n\t"
2604
- "vsetvli t0, zero, e8, m1 \n\t"
2605
- "LOOP_INNER%=: \n\t"
2606
-
2607
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2608
-
2609
- "vadd.vi v0, v0, -8 \n\t"
2610
- "vadd.vi v1, v1, -8 \n\t"
2611
- "vadd.vi v2, v2, -8 \n\t"
2612
- "vadd.vi v3, v3, -8 \n\t"
2613
- "vadd.vi v4, v4, -8 \n\t"
2614
- "vadd.vi v5, v5, -8 \n\t"
2615
- "vadd.vi v6, v6, -8 \n\t"
2616
- "vadd.vi v7, v7, -8 \n\t"
2617
-
2618
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2619
-
2620
- "bnez t5, LOOP_INNER%= \n\t"
2621
- "vsetvli t0, zero, e32, mf2 \n\t"
2622
-
2623
- SQ4BIT_KERNEL_ACC_F16_1X4X4
2624
-
2625
- "bnez %[CNT], LOOP_K%= \n\t"
2626
- "addi t3, zero, 16 \n\t"
2627
- "addi s1, %[C], 16 \n\t"
2628
- "addi s2, %[C], 32 \n\t"
2629
- "addi s3, %[C], 48 \n\t"
2630
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2631
- "vse32.v v28, (%[C]) \n\t"
2632
- "vse32.v v29, (s1) \n\t"
2633
- "vse32.v v30, (s2) \n\t"
2634
- "vse32.v v31, (s3) \n\t"
2635
- "jal x0, END%= \n\t"
2636
-
2637
- "ST_TAIL%=: \n\t"
2638
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2639
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2640
- "vse32.v v28, (%[C]) \n\t"
2641
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2642
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2643
- "vse32.v v29, (s1) \n\t"
2644
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2645
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2646
- "vse32.v v30, (s2) \n\t"
2647
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2648
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2649
- "vse32.v v31, (s3) \n\t"
2650
- "END%=: \n\t"
2651
-
2652
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
2653
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2654
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
2655
- }
2656
- }
2657
- }
2658
- }
2659
784
 
2660
- template <bool HasZeroPoint>
2661
- void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen,
2662
- const std::byte * QuantA,
2663
- const std::byte * QuantBData,
2664
- const float * QuantBScale,
2665
- const std::byte * QuantBZeroPoint,
2666
- float * C,
2667
- size_t CountN,
2668
- size_t BlockCountK,
2669
- const float * Bias) {
2670
- GGML_UNUSED(QuantBScale);
2671
- GGML_UNUSED(QuantBZeroPoint);
2672
- const size_t INNER = BlkLen / 16;
2673
- if constexpr (HasZeroPoint) {
2674
- for (size_t n = 0; n < CountN; n += 16) {
2675
- size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2676
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2677
- n * BlockCountK * BlkLen / 2 + // b data
2678
- n * BlockCountK * sizeof(uint8_t) + // zp
2679
- n * BlockCountK * sizeof(float); // scale
2680
- float * CPtr = C + n;
2681
- size_t cnt = BlockCountK;
2682
- if (Bias != nullptr) {
2683
- const float * bias = Bias + n;
2684
- __asm__ volatile(
2685
- "addi t3, %[NBLKS], 0 \n\t"
2686
- "vsetvli t0, zero, e8, m1 \n\t"
2687
- "vmv.v.i v13, 3 \n\t"
2688
- "li s1, 24 \n\t"
2689
- "vsetvli t0, s1, e8, m1 \n\t"
2690
- "vmv.v.i v13, 2 \n\t"
2691
- "vsetvli t0, zero, e8, mf2 \n\t"
2692
- "vmv.v.i v13, 1 \n\t"
2693
- "vsetvli t0, zero, e8, mf4 \n\t"
2694
- "vmv.v.i v13, 0 \n\t"
2695
- "vsetvli t0, zero, e32, m4 \n\t"
2696
- "vxor.vv v28, v28, v28 \n\t"
2697
-
2698
- // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0
2699
- "addi s1, %[B], 0 \n\t"
2700
- "addi s2, %[B], 16 \n\t"
2701
- "addi s3, %[B], 32 \n\t"
2702
- "addi s4, %[B], 48 \n\t"
2703
- // zp offset
2704
- "addi s7, %[B], 64 \n\t"
2705
- // a offset
2706
- "addi s5, %[A], 0 \n\t"
2707
- "addi s6, %[A], 12 \n\t"
2708
-
2709
- "vsetvli t0, t3, e32, mf2 \n\t"
2710
- "vle32.v v28, (%[BIAS]) \n\t"
2711
- "sub t3, t3, t0 \n\t"
2712
- "addi %[BIAS], %[BIAS], 16 \n\t"
2713
- "vsetvli t0, t3, e32, mf2 \n\t"
2714
- "vle32.v v29, (%[BIAS]) \n\t"
2715
- "sub t3, t3, t0 \n\t"
2716
- "addi %[BIAS], %[BIAS], 16 \n\t"
2717
- "vsetvli t0, t3, e32, mf2 \n\t"
2718
- "vle32.v v30, (%[BIAS]) \n\t"
2719
- "sub t3, t3, t0 \n\t"
2720
- "addi %[BIAS], %[BIAS], 16 \n\t"
2721
- "vsetvli t0, t3, e32, mf2 \n\t"
2722
- "vle32.v v31, (%[BIAS]) \n\t"
2723
- "vsetvli t0, zero, e32, mf2 \n\t"
2724
- "LOOP_K%=: \n\t"
2725
-
2726
- // load scale
2727
- "vle32.v v8, (s1) \n\t"
2728
- "addi s1, s1, 80 \n\t"
2729
- "vle32.v v9, (s2) \n\t"
2730
- "addi s2, s2, 96 \n\t"
2731
- "vle32.v v10, (s3) \n\t"
2732
- "addi s3, s3, 112 \n\t"
2733
- "vle32.v v11, (s4) \n\t"
2734
- "addi s4, s4, 128 \n\t"
2735
-
2736
- // load a scale
2737
- "flw f1, (s5) \n\t"
2738
- "addi s5, s5, 4 \n\t"
2739
-
2740
- "addi t5, %[INNER], 0 \n\t"
2741
- "vxor.vv v16, v16, v16 \n\t"
2742
- "vxor.vv v18, v18, v18 \n\t"
2743
- "vxor.vv v20, v20, v20 \n\t"
2744
- "vxor.vv v22, v22, v22 \n\t"
2745
-
2746
- // a scale * b scale
2747
- "vfmul.vf v24, v8, f1 \n\t"
2748
- "vfmul.vf v25, v9, f1 \n\t"
2749
- "vfmul.vf v26, v10, f1 \n\t"
2750
- "vfmul.vf v27, v11, f1 \n\t"
2751
- "addi %[CNT], %[CNT], -1 \n\t"
2752
-
2753
- SQ4BIT_KERNEL_LOAD_ZP_16X1
2754
-
2755
- "LOOP_INNER%=: \n\t"
2756
-
2757
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2758
-
2759
- "vsub.vv v0, v0, v8 \n\t"
2760
- "vsub.vv v4, v4, v8 \n\t"
2761
- "vsub.vv v1, v1, v9 \n\t"
2762
- "vsub.vv v5, v5, v9 \n\t"
2763
- "vsub.vv v2, v2, v10 \n\t"
2764
- "vsub.vv v6, v6, v10 \n\t"
2765
- "vsub.vv v3, v3, v11 \n\t"
2766
- "vsub.vv v7, v7, v11 \n\t"
2767
-
2768
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2769
-
2770
- "bnez t5, LOOP_INNER%= \n\t"
2771
- "vsetvli t0, zero, e32, mf2 \n\t"
2772
-
2773
- SQ4BIT_KERNEL_ACC_1X4X4
2774
- "addi s7, s1, 64 \n\t"
2775
-
2776
- "bnez %[CNT], LOOP_K%= \n\t"
2777
-
2778
- "addi t3, zero, 16 \n\t"
2779
- "addi s1, %[C], 16 \n\t"
2780
- "addi s2, %[C], 32 \n\t"
2781
- "addi s3, %[C], 48 \n\t"
2782
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2783
- "vse32.v v28, (%[C]) \n\t"
2784
- "vse32.v v29, (s1) \n\t"
2785
- "vse32.v v30, (s2) \n\t"
2786
- "vse32.v v31, (s3) \n\t"
2787
- "jal x0, END%= \n\t"
2788
-
2789
- "ST_TAIL%=: \n\t"
2790
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2791
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2792
- "vse32.v v28, (%[C]) \n\t"
2793
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2794
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2795
- "vse32.v v29, (s1) \n\t"
2796
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2797
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2798
- "vse32.v v30, (s2) \n\t"
2799
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2800
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2801
- "vse32.v v31, (s3) \n\t"
2802
- "END%=: \n\t"
2803
-
2804
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
2805
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2806
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2807
- } else {
2808
- __asm__ volatile(
2809
- "vsetvli t0, zero, e32, m4 \n\t"
2810
- "vxor.vv v28, v28, v28 \n\t"
2811
-
2812
- "vsetvli t0, zero, e8, m1 \n\t"
2813
- "vmv.v.i v13, 3 \n\t"
2814
- "li s1, 24 \n\t"
2815
- "vsetvli t0, s1, e8, m1 \n\t"
2816
- "vmv.v.i v13, 2 \n\t"
2817
- "vsetvli t0, zero, e8, mf2 \n\t"
2818
- "vmv.v.i v13, 1 \n\t"
2819
- "vsetvli t0, zero, e8, mf4 \n\t"
2820
- "vmv.v.i v13, 0 \n\t"
2821
- "addi s1, %[B], 0 \n\t"
2822
- "addi s2, %[B], 16 \n\t"
2823
- "addi s3, %[B], 32 \n\t"
2824
- "addi s4, %[B], 48 \n\t"
2825
-
2826
- "addi s7, %[B], 64 \n\t"
2827
-
2828
- "addi s5, %[A], 0 \n\t"
2829
- "addi s6, %[A], 12 \n\t"
2830
- "vsetvli t0, zero, e32, mf2 \n\t"
2831
-
2832
- "LOOP_K%=: \n\t"
2833
- "vle32.v v8, (s1) \n\t"
2834
- "addi s1, s1, 80 \n\t"
2835
- "vle32.v v9, (s2) \n\t"
2836
- "addi s2, s2, 96 \n\t"
2837
- "vle32.v v10, (s3) \n\t"
2838
- "addi s3, s3, 112 \n\t"
2839
- "vle32.v v11, (s4) \n\t"
2840
- "addi s4, s4, 128 \n\t"
2841
-
2842
- "flw f1, (s5) \n\t"
2843
- "addi s5, s5, 4 \n\t"
2844
-
2845
- "addi t5, %[INNER], 0 \n\t"
2846
- "vxor.vv v16, v16, v16 \n\t"
2847
- "vxor.vv v18, v18, v18 \n\t"
2848
- "vxor.vv v20, v20, v20 \n\t"
2849
- "vxor.vv v22, v22, v22 \n\t"
2850
-
2851
- "vfmul.vf v24, v8, f1 \n\t"
2852
- "vfmul.vf v25, v9, f1 \n\t"
2853
- "vfmul.vf v26, v10, f1 \n\t"
2854
- "vfmul.vf v27, v11, f1 \n\t"
2855
- "addi %[CNT], %[CNT], -1 \n\t"
2856
-
2857
- SQ4BIT_KERNEL_LOAD_ZP_16X1
2858
-
2859
- "LOOP_INNER%=: \n\t"
2860
-
2861
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2862
-
2863
- "vsub.vv v0, v0, v8 \n\t"
2864
- "vsub.vv v4, v4, v8 \n\t"
2865
- "vsub.vv v1, v1, v9 \n\t"
2866
- "vsub.vv v5, v5, v9 \n\t"
2867
- "vsub.vv v2, v2, v10 \n\t"
2868
- "vsub.vv v6, v6, v10 \n\t"
2869
- "vsub.vv v3, v3, v11 \n\t"
2870
- "vsub.vv v7, v7, v11 \n\t"
2871
-
2872
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2873
-
2874
- "bnez t5, LOOP_INNER%= \n\t"
2875
- "vsetvli t0, zero, e32, mf2 \n\t"
2876
-
2877
- SQ4BIT_KERNEL_ACC_1X4X4
2878
- "addi s7, s1, 64 \n\t"
2879
-
2880
- "bnez %[CNT], LOOP_K%= \n\t"
2881
-
2882
- "addi t3, zero, 16 \n\t"
2883
- "addi s1, %[C], 16 \n\t"
2884
- "addi s2, %[C], 32 \n\t"
2885
- "addi s3, %[C], 48 \n\t"
2886
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2887
- "vse32.v v28, (%[C]) \n\t"
2888
- "vse32.v v29, (s1) \n\t"
2889
- "vse32.v v30, (s2) \n\t"
2890
- "vse32.v v31, (s3) \n\t"
2891
- "jal x0, END%= \n\t"
2892
-
2893
- "ST_TAIL%=: \n\t"
2894
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2895
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2896
- "vse32.v v28, (%[C]) \n\t"
2897
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2898
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2899
- "vse32.v v29, (s1) \n\t"
2900
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2901
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2902
- "vse32.v v30, (s2) \n\t"
2903
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2904
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
2905
- "vse32.v v31, (s3) \n\t"
2906
- "END%=: \n\t"
2907
-
2908
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
2909
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2910
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2911
- }
785
+ __asm__ volatile(
786
+ "vsetvli t0, zero, e32, m4 \n\t"
787
+ "vxor.vv v28, v28, v28 \n\t"
788
+
789
+ "vsetvli t0, zero, e8, m1 \n\t"
790
+ "vmv.v.i v13, 3 \n\t"
791
+ "li s1, 24 \n\t"
792
+ "vsetvli t0, s1, e8, m1 \n\t"
793
+ "vmv.v.i v13, 2 \n\t"
794
+ "vsetvli t0, zero, e8, mf2 \n\t"
795
+ "vmv.v.i v13, 1 \n\t"
796
+ "vsetvli t0, zero, e8, mf4 \n\t"
797
+ "vmv.v.i v13, 0 \n\t"
798
+
799
+ "addi s1, %[B], 0 \n\t"
800
+ "addi s2, %[B], 8 \n\t"
801
+ "addi s3, %[B], 16 \n\t"
802
+ "addi s4, %[B], 24 \n\t"
803
+
804
+ "addi s7, %[B], 32 \n\t"
805
+
806
+ "addi s5, %[A], 0 \n\t"
807
+ "addi s6, %[A], 12 \n\t"
808
+ "LOOP_K%=: \n\t"
809
+ "vsetvli t0, zero, e16, mf4 \n\t"
810
+ "vle16.v v4, (s1) \n\t"
811
+ "addi s1, s1, 48 \n\t"
812
+ "vle16.v v5, (s2) \n\t"
813
+ "addi s2, s2, 72 \n\t"
814
+ "vle16.v v6, (s3) \n\t"
815
+ "addi s3, s3, 96 \n\t"
816
+ "vle16.v v7, (s4) \n\t"
817
+ "addi s4, s4, 120 \n\t"
818
+ "flw f1, (s5) \n\t"
819
+ "addi s5, s5, 4 \n\t"
820
+
821
+ "vfwcvt.f.f.v v8, v4 \n\t"
822
+ "vfwcvt.f.f.v v9, v5 \n\t"
823
+ "vfwcvt.f.f.v v10, v6 \n\t"
824
+ "vfwcvt.f.f.v v11, v7 \n\t"
825
+ "vsetvli t0, zero, e32, mf2 \n\t"
826
+
827
+ "addi t5, %[INNER], 0 \n\t"
828
+ "vxor.vv v16, v16, v16 \n\t"
829
+ "vxor.vv v18, v18, v18 \n\t"
830
+ "vxor.vv v20, v20, v20 \n\t"
831
+ "vxor.vv v22, v22, v22 \n\t"
832
+ "vfmul.vf v24, v8, f1 \n\t"
833
+ "vfmul.vf v25, v9, f1 \n\t"
834
+ "vfmul.vf v26, v10, f1 \n\t"
835
+ "vfmul.vf v27, v11, f1 \n\t"
836
+ "addi %[CNT], %[CNT], -1 \n\t"
837
+
838
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
839
+
840
+ "LOOP_INNER%=: \n\t"
841
+
842
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
843
+
844
+ "vsub.vv v0, v0, v8 \n\t"
845
+ "vsub.vv v4, v4, v8 \n\t"
846
+ "vsub.vv v1, v1, v9 \n\t"
847
+ "vsub.vv v5, v5, v9 \n\t"
848
+ "vsub.vv v2, v2, v10 \n\t"
849
+ "vsub.vv v6, v6, v10 \n\t"
850
+ "vsub.vv v3, v3, v11 \n\t"
851
+ "vsub.vv v7, v7, v11 \n\t"
852
+
853
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
854
+
855
+ "bnez t5, LOOP_INNER%= \n\t"
856
+ "vsetvli t0, zero, e32, mf2 \n\t"
857
+
858
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
859
+ "addi s7, s1, 32 \n\t"
860
+
861
+ "bnez %[CNT], LOOP_K%= \n\t"
862
+ "addi t3, zero, 16 \n\t"
863
+ "addi s1, %[C], 16 \n\t"
864
+ "addi s2, %[C], 32 \n\t"
865
+ "addi s3, %[C], 48 \n\t"
866
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
867
+ "vse32.v v28, (%[C]) \n\t"
868
+ "vse32.v v29, (s1) \n\t"
869
+ "vse32.v v30, (s2) \n\t"
870
+ "vse32.v v31, (s3) \n\t"
871
+ "jal x0, END%= \n\t"
872
+
873
+ "ST_TAIL%=: \n\t"
874
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
875
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
876
+ "vse32.v v28, (%[C]) \n\t"
877
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
878
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
879
+ "vse32.v v29, (s1) \n\t"
880
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
881
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
882
+ "vse32.v v30, (s2) \n\t"
883
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
884
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
885
+ "vse32.v v31, (s3) \n\t"
886
+ "END%=: \n\t"
887
+
888
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
889
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
890
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2912
891
  }
2913
892
  } else {
2914
893
  for (size_t n = 0; n < CountN; n += 16) {
2915
- size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2916
- std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2917
- n * BlockCountK * BlkLen / 2 + // b data
2918
- n * BlockCountK * sizeof(float); // scale
894
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
895
+ uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + //
896
+ n * BlockCountK * BlkLen / 2 + // b data
897
+ n * BlockCountK * sizeof(_Float16); // scale
2919
898
  float * CPtr = C + n;
2920
899
  size_t cnt = BlockCountK;
2921
- if (Bias != nullptr) {
2922
- const float * bias = Bias + n;
2923
- __asm__ volatile(
2924
- "addi t3, %[NBLKS], 0 \n\t"
2925
- "addi s1, %[B], 0 \n\t"
2926
- "addi s2, %[B], 16 \n\t"
2927
- "addi s3, %[B], 32 \n\t"
2928
- "addi s4, %[B], 48 \n\t"
2929
- "addi s5, %[A], 0 \n\t"
2930
- "addi s6, %[A], 12 \n\t"
2931
- "vsetvli t0, t3, e32, mf2 \n\t"
2932
- "vle32.v v28, (%[BIAS]) \n\t"
2933
- "sub t3, t3, t0 \n\t"
2934
- "addi %[BIAS], %[BIAS], 16 \n\t"
2935
- "vsetvli t0, t3, e32, mf2 \n\t"
2936
- "vle32.v v29, (%[BIAS]) \n\t"
2937
- "sub t3, t3, t0 \n\t"
2938
- "addi %[BIAS], %[BIAS], 16 \n\t"
2939
- "vsetvli t0, t3, e32, mf2 \n\t"
2940
- "vle32.v v30, (%[BIAS]) \n\t"
2941
- "sub t3, t3, t0 \n\t"
2942
- "addi %[BIAS], %[BIAS], 16 \n\t"
2943
- "vsetvli t0, t3, e32, mf2 \n\t"
2944
- "vle32.v v31, (%[BIAS]) \n\t"
2945
- "vsetvli t0, zero, e32, mf2 \n\t"
2946
- "LOOP_K%=: \n\t"
2947
- "vle32.v v8, (s1) \n\t"
2948
- "addi s1, s1, 64 \n\t"
2949
- "vle32.v v9, (s2) \n\t"
2950
- "addi s2, s2, 80 \n\t"
2951
- "vle32.v v10, (s3) \n\t"
2952
- "addi s3, s3, 96 \n\t"
2953
- "vle32.v v11, (s4) \n\t"
2954
- "addi s4, s4, 112 \n\t"
2955
- "flw f1, (s5) \n\t"
2956
- "addi s5, s5, 4 \n\t"
2957
-
2958
- "addi t5, %[INNER], 0 \n\t"
2959
- "vxor.vv v16, v16, v16 \n\t"
2960
- "vxor.vv v18, v18, v18 \n\t"
2961
- "vxor.vv v20, v20, v20 \n\t"
2962
- "vxor.vv v22, v22, v22 \n\t"
2963
- "vfmul.vf v24, v8, f1 \n\t"
2964
- "vfmul.vf v25, v9, f1 \n\t"
2965
- "vfmul.vf v26, v10, f1 \n\t"
2966
- "vfmul.vf v27, v11, f1 \n\t"
2967
- "addi %[CNT], %[CNT], -1 \n\t"
2968
- "vsetvli t0, zero, e8, m1 \n\t"
2969
- "LOOP_INNER%=: \n\t"
2970
-
2971
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2972
-
2973
- "vadd.vi v0, v0, -8 \n\t"
2974
- "vadd.vi v1, v1, -8 \n\t"
2975
- "vadd.vi v2, v2, -8 \n\t"
2976
- "vadd.vi v3, v3, -8 \n\t"
2977
- "vadd.vi v4, v4, -8 \n\t"
2978
- "vadd.vi v5, v5, -8 \n\t"
2979
- "vadd.vi v6, v6, -8 \n\t"
2980
- "vadd.vi v7, v7, -8 \n\t"
2981
-
2982
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2983
-
2984
- "bnez t5, LOOP_INNER%= \n\t"
2985
- "vsetvli t0, zero, e32, mf2 \n\t"
2986
-
2987
- SQ4BIT_KERNEL_ACC_1X4X4
2988
-
2989
- "bnez %[CNT], LOOP_K%= \n\t"
2990
- "addi t3, zero, 16 \n\t"
2991
- "addi s1, %[C], 16 \n\t"
2992
- "addi s2, %[C], 32 \n\t"
2993
- "addi s3, %[C], 48 \n\t"
2994
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2995
- "vse32.v v28, (%[C]) \n\t"
2996
- "vse32.v v29, (s1) \n\t"
2997
- "vse32.v v30, (s2) \n\t"
2998
- "vse32.v v31, (s3) \n\t"
2999
- "jal x0, END%= \n\t"
3000
-
3001
- "ST_TAIL%=: \n\t"
3002
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3003
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3004
- "vse32.v v28, (%[C]) \n\t"
3005
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3006
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3007
- "vse32.v v29, (s1) \n\t"
3008
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3009
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3010
- "vse32.v v30, (s2) \n\t"
3011
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3012
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3013
- "vse32.v v31, (s3) \n\t"
3014
- "END%=: \n\t"
3015
-
3016
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
3017
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
3018
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
3019
- } else {
3020
- __asm__ volatile(
3021
- "vsetvli t0, zero, e32, m4 \n\t"
3022
- "vxor.vv v28, v28, v28 \n\t"
3023
- "addi s1, %[B], 0 \n\t"
3024
- "addi s2, %[B], 16 \n\t"
3025
- "addi s3, %[B], 32 \n\t"
3026
- "addi s4, %[B], 48 \n\t"
3027
-
3028
- "addi s5, %[A], 0 \n\t"
3029
- "addi s6, %[A], 12 \n\t"
3030
- "vsetvli t0, zero, e32, mf2 \n\t"
3031
- "LOOP_K%=: \n\t"
3032
- "vle32.v v8, (s1) \n\t"
3033
- "addi s1, s1, 64 \n\t"
3034
- "vle32.v v9, (s2) \n\t"
3035
- "addi s2, s2, 80 \n\t"
3036
- "vle32.v v10, (s3) \n\t"
3037
- "addi s3, s3, 96 \n\t"
3038
- "vle32.v v11, (s4) \n\t"
3039
- "addi s4, s4, 112 \n\t"
3040
- "flw f1, (s5) \n\t"
3041
- "addi s5, s5, 4 \n\t"
3042
-
3043
- "addi t5, %[INNER], 0 \n\t"
3044
- "vxor.vv v16, v16, v16 \n\t"
3045
- "vxor.vv v18, v18, v18 \n\t"
3046
- "vxor.vv v20, v20, v20 \n\t"
3047
- "vxor.vv v22, v22, v22 \n\t"
3048
- "vfmul.vf v24, v8, f1 \n\t"
3049
- "vfmul.vf v25, v9, f1 \n\t"
3050
- "vfmul.vf v26, v10, f1 \n\t"
3051
- "vfmul.vf v27, v11, f1 \n\t"
3052
- "addi %[CNT], %[CNT], -1 \n\t"
3053
- "vsetvli t0, zero, e8, m1 \n\t"
3054
- "LOOP_INNER%=: \n\t"
3055
-
3056
- SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
3057
-
3058
- "vadd.vi v0, v0, -8 \n\t"
3059
- "vadd.vi v1, v1, -8 \n\t"
3060
- "vadd.vi v2, v2, -8 \n\t"
3061
- "vadd.vi v3, v3, -8 \n\t"
3062
- "vadd.vi v4, v4, -8 \n\t"
3063
- "vadd.vi v5, v5, -8 \n\t"
3064
- "vadd.vi v6, v6, -8 \n\t"
3065
- "vadd.vi v7, v7, -8 \n\t"
3066
-
3067
- SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
3068
-
3069
- "bnez t5, LOOP_INNER%= \n\t"
3070
- "vsetvli t0, zero, e32, mf2 \n\t"
3071
-
3072
- SQ4BIT_KERNEL_ACC_1X4X4
3073
-
3074
- "bnez %[CNT], LOOP_K%= \n\t"
3075
- "addi t3, zero, 16 \n\t"
3076
- "addi s1, %[C], 16 \n\t"
3077
- "addi s2, %[C], 32 \n\t"
3078
- "addi s3, %[C], 48 \n\t"
3079
- "blt %[NBLKS], t3, ST_TAIL%= \n\t"
3080
- "vse32.v v28, (%[C]) \n\t"
3081
- "vse32.v v29, (s1) \n\t"
3082
- "vse32.v v30, (s2) \n\t"
3083
- "vse32.v v31, (s3) \n\t"
3084
- "jal x0, END%= \n\t"
3085
-
3086
- "ST_TAIL%=: \n\t"
3087
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3088
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3089
- "vse32.v v28, (%[C]) \n\t"
3090
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3091
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3092
- "vse32.v v29, (s1) \n\t"
3093
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3094
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3095
- "vse32.v v30, (s2) \n\t"
3096
- "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3097
- "sub %[NBLKS], %[NBLKS], t0 \n\t"
3098
- "vse32.v v31, (s3) \n\t"
3099
- "END%=: \n\t"
3100
-
3101
- : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
3102
- : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
3103
- : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
3104
- }
3105
- }
3106
- }
3107
- }
3108
-
3109
- template <bool HasZeroPoint>
3110
- inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
3111
- const std::byte * QuantA,
3112
- const std::byte * QuantBData,
3113
- const float * QuantBScale,
3114
- const std::byte * QuantBZeroPoint,
3115
- float * C,
3116
- size_t CountM,
3117
- size_t CountN,
3118
- size_t BlockStrideQuantB,
3119
- const float * Bias,
3120
- const size_t ldc,
3121
- const size_t scalestride) {
3122
- if (scalestride == 4) {
3123
- SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
3124
- CountN, BlockStrideQuantB, Bias, ldc);
3125
-
3126
- } else if (scalestride == 2) {
3127
- SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(
3128
- BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);
3129
- }
3130
- }
3131
900
 
3132
- template <bool HasZeroPoint>
3133
- inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
3134
- const std::byte * QuantA,
3135
- const std::byte * QuantBData,
3136
- const float * QuantBScale,
3137
- const std::byte * QuantBZeroPoint,
3138
- float * C,
3139
- size_t CountM,
3140
- size_t CountN,
3141
- size_t BlockStrideQuantB,
3142
- const float * Bias,
3143
- const size_t ldc,
3144
- const size_t scalestride) {
3145
- if (scalestride == 4) {
3146
- SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
3147
- CountN, BlockStrideQuantB, Bias);
3148
- } else if (scalestride == 2) {
3149
- SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,
3150
- QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);
901
+ __asm__ volatile(
902
+ "vsetvli t0, zero, e32, m4 \n\t"
903
+ "vxor.vv v28, v28, v28 \n\t"
904
+ "addi s1, %[B], 0 \n\t"
905
+ "addi s2, %[B], 8 \n\t"
906
+ "addi s3, %[B], 16 \n\t"
907
+ "addi s4, %[B], 24 \n\t"
908
+
909
+ "addi s5, %[A], 0 \n\t"
910
+ "addi s6, %[A], 12 \n\t"
911
+ "LOOP_K%=: \n\t"
912
+ "vsetvli t0, zero, e16, mf4 \n\t"
913
+ "vle16.v v4, (s1) \n\t"
914
+ "addi s1, s1, 32 \n\t"
915
+ "vle16.v v5, (s2) \n\t"
916
+ "addi s2, s2, 56 \n\t"
917
+ "vle16.v v6, (s3) \n\t"
918
+ "addi s3, s3, 80 \n\t"
919
+ "vle16.v v7, (s4) \n\t"
920
+ "addi s4, s4, 104 \n\t"
921
+ "flw f1, (s5) \n\t"
922
+ "addi s5, s5, 4 \n\t"
923
+
924
+ "vfwcvt.f.f.v v8, v4 \n\t"
925
+ "vfwcvt.f.f.v v9, v5 \n\t"
926
+ "vfwcvt.f.f.v v10, v6 \n\t"
927
+ "vfwcvt.f.f.v v11, v7 \n\t"
928
+ "vsetvli t0, zero, e32, mf2 \n\t"
929
+
930
+ "addi t5, %[INNER], 0 \n\t"
931
+ "vxor.vv v16, v16, v16 \n\t"
932
+ "vxor.vv v18, v18, v18 \n\t"
933
+ "vxor.vv v20, v20, v20 \n\t"
934
+ "vxor.vv v22, v22, v22 \n\t"
935
+ "vfmul.vf v24, v8, f1 \n\t"
936
+ "vfmul.vf v25, v9, f1 \n\t"
937
+ "vfmul.vf v26, v10, f1 \n\t"
938
+ "vfmul.vf v27, v11, f1 \n\t"
939
+ "addi %[CNT], %[CNT], -1 \n\t"
940
+ "vsetvli t0, zero, e8, m1 \n\t"
941
+ "LOOP_INNER%=: \n\t"
942
+
943
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
944
+
945
+ "vadd.vi v0, v0, -8 \n\t"
946
+ "vadd.vi v1, v1, -8 \n\t"
947
+ "vadd.vi v2, v2, -8 \n\t"
948
+ "vadd.vi v3, v3, -8 \n\t"
949
+ "vadd.vi v4, v4, -8 \n\t"
950
+ "vadd.vi v5, v5, -8 \n\t"
951
+ "vadd.vi v6, v6, -8 \n\t"
952
+ "vadd.vi v7, v7, -8 \n\t"
953
+
954
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
955
+
956
+ "bnez t5, LOOP_INNER%= \n\t"
957
+ "vsetvli t0, zero, e32, mf2 \n\t"
958
+
959
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
960
+
961
+ "bnez %[CNT], LOOP_K%= \n\t"
962
+ "addi t3, zero, 16 \n\t"
963
+ "addi s1, %[C], 16 \n\t"
964
+ "addi s2, %[C], 32 \n\t"
965
+ "addi s3, %[C], 48 \n\t"
966
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
967
+ "vse32.v v28, (%[C]) \n\t"
968
+ "vse32.v v29, (s1) \n\t"
969
+ "vse32.v v30, (s2) \n\t"
970
+ "vse32.v v31, (s3) \n\t"
971
+ "jal x0, END%= \n\t"
972
+
973
+ "ST_TAIL%=: \n\t"
974
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
975
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
976
+ "vse32.v v28, (%[C]) \n\t"
977
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
978
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
979
+ "vse32.v v29, (s1) \n\t"
980
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
981
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
982
+ "vse32.v v30, (s2) \n\t"
983
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
984
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
985
+ "vse32.v v31, (s3) \n\t"
986
+ "END%=: \n\t"
987
+
988
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
989
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
990
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
991
+ }
3151
992
  }
3152
993
  }
3153
-
3154
994
  } // namespace
3155
995
 
3156
996
  namespace ime1 {
3157
- size_t gemm_kernel_i8i4(size_t BlkLen,
3158
- const std::byte * QuantA,
3159
- const std::byte * QuantBData,
3160
- const float * QuantBScale,
3161
- const std::byte * QuantBZeroPoint,
3162
- float * C,
3163
- size_t CountM,
3164
- size_t CountN,
3165
- size_t CountK,
3166
- size_t BlockCountK,
3167
- size_t ldc,
3168
- const float * Bias,
3169
- const size_t ScaleStride) {
3170
- GGML_UNUSED(CountM);
3171
- GGML_UNUSED(CountK);
3172
- GGML_UNUSED(ldc);
3173
- if (CountM >= 4) {
3174
- if (QuantBZeroPoint != nullptr) {
3175
- SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
3176
- C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
997
+ size_t gemm_kernel_i8i4(size_t blk_len,
998
+ const uint8_t * quant_a_ptr,
999
+ const uint8_t * quant_b_data,
1000
+ const uint8_t * quant_b_zp,
1001
+ float * c_ptr,
1002
+ size_t count_m,
1003
+ size_t count_n,
1004
+ size_t k_blks,
1005
+ size_t ldc) {
1006
+ if (count_m >= 4) {
1007
+ if (quant_b_zp != nullptr) {
1008
+ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<true>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks,
1009
+ ldc);
3177
1010
  } else {
3178
- SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
3179
- QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
3180
- ldc, ScaleStride);
1011
+ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<false>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n,
1012
+ k_blks, ldc);
3181
1013
  }
3182
1014
  return 4;
3183
1015
  } else {
3184
- if (QuantBZeroPoint != nullptr) {
3185
- SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
3186
- C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
1016
+ if (quant_b_zp != nullptr) {
1017
+ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<true>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks,
1018
+ ldc);
3187
1019
  } else {
3188
- SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
3189
- QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
3190
- ldc, ScaleStride);
1020
+ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<false>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n,
1021
+ k_blks, ldc);
3191
1022
  }
3192
1023
  return 1;
3193
1024
  }
3194
1025
  }
3195
1026
  } // namespace ime1
3196
- } // namespace sqnbitgemm_spacemit_ime
1027
+ } // namespace spacemit_kernels