whispercpp 1.3.3 → 1.3.4

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp (new file)
@@ -0,0 +1,3196 @@
1
+ #include "ggml.h"
2
+ #include "ime_kernels.h"
3
+
4
+ #include <algorithm>
5
+ #include <cmath>
6
+
7
+ // clang-format off
8
+ #if defined(__GNUC__)
9
+ #pragma GCC diagnostic ignored "-Woverlength-strings"
10
+ #pragma GCC diagnostic ignored "-Wcast-qual"
11
+ #pragma GCC diagnostic ignored "-Wunused-parameter"
12
+ #endif
13
+ // clang-format on
14
+ namespace sqnbitgemm_spacemit_ime {
15
+
16
+ #define QUANTIZEM4ROW_KERNEL \
17
+ "vmv.s.x v16, zero \n\t" \
18
+ "vfabs.v v8, v0 \n\t" \
19
+ "vfredmax.vs v16, v8, v16 \n\t" \
20
+ "vfmv.f.s f10, v16 \n\t" \
21
+ "fmul.s f10, f10, %[RMAXREC] \n\t" \
22
+ "fsw f10, (a1) \n\t" \
23
+ "fdiv.s f11, %[FONE], f10 \n\t" \
24
+ "vfmul.vf v16, v0, f11 \n\t" \
25
+ "vfcvt.x.f.v v16, v16 \n\t" \
26
+ "vsetvli t0, zero, e16, mf2 \n\t" \
27
+ "vnclip.wx v16, v16, zero \n\t" \
28
+ "vnclip.wx v17, v17, zero \n\t" \
29
+ "vnclip.wx v18, v18, zero \n\t" \
30
+ "vnclip.wx v19, v19, zero \n\t" \
31
+ "vnclip.wx v20, v20, zero \n\t" \
32
+ "vnclip.wx v21, v21, zero \n\t" \
33
+ "vnclip.wx v22, v22, zero \n\t" \
34
+ "vnclip.wx v23, v23, zero \n\t" \
35
+ "vsetvli t0, zero, e8, mf4 \n\t" \
36
+ "vnclip.wx v24, v16, zero \n\t" \
37
+ "vnclip.wx v25, v17, zero \n\t" \
38
+ "vnclip.wx v26, v18, zero \n\t" \
39
+ "vnclip.wx v27, v19, zero \n\t" \
40
+ "vnclip.wx v28, v20, zero \n\t" \
41
+ "vnclip.wx v29, v21, zero \n\t" \
42
+ "vnclip.wx v30, v22, zero \n\t" \
43
+ "vnclip.wx v31, v23, zero \n\t"
44
+
45
+ #define QUANTIZEM4ROW_STORE \
46
+ "addi t1, %[BlkLen], 0 \n\t" \
47
+ "vsetvli t0, t1, e8, mf4 \n\t" \
48
+ "vse8.v v24, (s1) \n\t" \
49
+ "addi s1, s1, 32 \n\t" \
50
+ "sub t1, t1, t0 \n\t" \
51
+ "vsetvli t0, t1, e8, mf4 \n\t" \
52
+ "vse8.v v25, (s1) \n\t" \
53
+ "addi s1, s1, 32 \n\t" \
54
+ "sub t1, t1, t0 \n\t" \
55
+ "vsetvli t0, t1, e8, mf4 \n\t" \
56
+ "vse8.v v26, (s1) \n\t" \
57
+ "addi s1, s1, 32 \n\t" \
58
+ "sub t1, t1, t0 \n\t" \
59
+ "vsetvli t0, t1, e8, mf4 \n\t" \
60
+ "vse8.v v27, (s1) \n\t" \
61
+ "addi s1, s1, 32 \n\t" \
62
+ "sub t1, t1, t0 \n\t" \
63
+ "vsetvli t0, t1, e8, mf4 \n\t" \
64
+ "vse8.v v28, (s1) \n\t" \
65
+ "addi s1, s1, 32 \n\t" \
66
+ "sub t1, t1, t0 \n\t" \
67
+ "vsetvli t0, t1, e8, mf4 \n\t" \
68
+ "vse8.v v29, (s1) \n\t" \
69
+ "addi s1, s1, 32 \n\t" \
70
+ "sub t1, t1, t0 \n\t" \
71
+ "vsetvli t0, t1, e8, mf4 \n\t" \
72
+ "vse8.v v30, (s1) \n\t" \
73
+ "addi s1, s1, 32 \n\t" \
74
+ "sub t1, t1, t0 \n\t" \
75
+ "vsetvli t0, t1, e8, mf4 \n\t" \
76
+ "vse8.v v31, (s1) \n\t"
77
+
78
+ namespace ime1 {
79
+ void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
80
+ constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
81
+ const float fone = 1.0f;
82
+
83
+ if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {
84
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
85
+ const float * SRC = A + row_index * CountK;
86
+ std::byte * DST = QuantA + row_index * sizeof(float);
87
+
88
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
89
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
90
+ __asm__ volatile(
91
+ "vsetvli t0, zero, e32, m8 \n\t"
92
+ "addi t2, %[CountK], 0 \n\t"
93
+ "addi a1, %[DST], 0 \n\t"
94
+ "blt t2, %[BlkLen], TAIL%= \n\t"
95
+
96
+ "LOOP%=: \n\t"
97
+ "vsetvli t0, %[BlkLen], e32, m8 \n\t"
98
+ "vle32.v v0, (%[SRC]) \n\t"
99
+ "sub t2, t2, t0 \n\t"
100
+ "slli t1, t0, 2 \n\t"
101
+ "add %[SRC], %[SRC], t1 \n\t"
102
+ "add s1, a1, %[OFFSET] \n\t"
103
+
104
+ QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
105
+
106
+ "add a1, a1, %[STRIDE] \n\t"
107
+ "bge t2, %[BlkLen], LOOP%= \n\t"
108
+
109
+ "TAIL%=: \n\t"
110
+ "blez t2, QUIT%= \n\t"
111
+ "vsetvli t0, zero, e32, m8 \n\t"
112
+ "vxor.vv v16, v16, v16 \n\t"
113
+ "vxor.vv v24, v24, v24 \n\t"
114
+ "vsetvli t0, t2, e32, m8 \n\t"
115
+ "vle32.v v0, (%[SRC]) \n\t"
116
+ "add s1, a1, %[OFFSET] \n\t"
117
+
118
+ QUANTIZEM4ROW_KERNEL
119
+
120
+ "addi t3, %[BlkLen], 0 \n\t"
121
+ "addi s2, s1, 0 \n\t"
122
+ "vsetvli t0, zero, e8, mf4 \n\t"
123
+ "vxor.vv v8, v8, v8 \n\t"
124
+ "SET_ZERO%=: \n\t"
125
+ "vse8.v v8, (s2) \n\t"
126
+ "addi s2, s2, 32 \n\t"
127
+ "addi t3, t3, -8 \n\t"
128
+ "bnez t3, SET_ZERO%= \n\t"
129
+
130
+ QUANTIZEM4ROW_STORE
131
+
132
+ "QUIT%=: \n\t"
133
+ : [SRC] "+r"(SRC)
134
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
135
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
136
+ : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
137
+ }
138
+ } else if (BlkLen == 128) {
139
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
140
+ const float * SRC = A + row_index * CountK;
141
+ std::byte * DST = QuantA + row_index * sizeof(float);
142
+
143
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
144
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
145
+ __asm__ volatile(
146
+ "vsetvli t0, zero, e32, m8 \n\t"
147
+ "li t6, 32 \n\t"
148
+ "addi t2, %[CountK], 0 \n\t"
149
+ "addi a1, %[DST], 0 \n\t"
150
+ "add s1, a1, %[OFFSET] \n\t"
151
+ "blt t2, %[BlkLen], TAIL%= \n\t"
152
+
153
+ "LOOP%=: \n\t"
154
+ "vsetvli t0, zero, e32, m8 \n\t"
155
+ "vle32.v v0, (%[SRC]) \n\t"
156
+ "addi %[SRC], %[SRC], 256 \n\t"
157
+ "vle32.v v8, (%[SRC]) \n\t"
158
+ "addi %[SRC], %[SRC], 256 \n\t"
159
+ "addi t2, t2, -128 \n\t"
160
+
161
+ "QUANTIZE%=: \n\t"
162
+ "add s1, a1, %[OFFSET] \n\t"
163
+ "vfabs.v v16, v0 \n\t"
164
+ "vfabs.v v24, v8 \n\t"
165
+ "vfmax.vv v16, v24, v16 \n\t"
166
+ "vfredmax.vs v24, v16, v24 \n\t"
167
+ "vfmv.f.s f10, v24 \n\t"
168
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
169
+ "fsw f10, (a1) \n\t"
170
+ "fdiv.s f11, %[FONE], f10 \n\t"
171
+ "vfmul.vf v16, v0, f11 \n\t"
172
+ "vfmul.vf v24, v8, f11 \n\t"
173
+ "vfcvt.x.f.v v16, v16 \n\t"
174
+ "vfcvt.x.f.v v24, v24 \n\t"
175
+ "vsetvli t0, zero, e16, m4 \n\t"
176
+ "vnclip.wx v16, v16, zero \n\t"
177
+ "vnclip.wx v20, v24, zero \n\t"
178
+ "vsetvli t0, zero, e8, m4 \n\t"
179
+ "vnclip.wx v16, v16, zero \n\t"
180
+ "vsetvli t0, zero, e64, m4 \n\t"
181
+ "vsse64.v v16, (s1), t6 \n\t"
182
+ "add a1, a1, %[STRIDE] \n\t"
183
+ "bge t2, %[BlkLen], LOOP%= \n\t"
184
+
185
+ "TAIL%=: \n\t"
186
+ "blez t2, QUIT%= \n\t"
187
+ "vsetvli t0, zero, e32, m8 \n\t"
188
+ "vxor.vv v0, v0, v0 \n\t"
189
+ "vxor.vv v8, v8, v8 \n\t"
190
+ "vxor.vv v16, v16, v16 \n\t"
191
+ "vxor.vv v24, v24, v24 \n\t"
192
+ "vsetvli t0, t2, e32, m8 \n\t"
193
+ "sub t2, t2, t0 \n\t"
194
+ "vle32.v v0, (%[SRC]) \n\t"
195
+ "addi %[SRC], %[SRC], 256 \n\t"
196
+ "vsetvli t0, t2, e32, m8 \n\t"
197
+ "vle32.v v8, (%[SRC]) \n\t"
198
+ "sub t2, t2, t2 \n\t"
199
+ "vsetvli t0, zero, e32, m8 \n\t"
200
+ "jal x0, QUANTIZE%= \n\t"
201
+
202
+ "QUIT%=: \n\t"
203
+ : [SRC] "+r"(SRC)
204
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
205
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
206
+ : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
207
+ }
208
+ } else if (BlkLen == 256) {
209
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
210
+ const float * SRC = A + row_index * CountK;
+ std::byte * DST = QuantA + row_index * sizeof(float);
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "li t6, 32 \n\t"
+ "addi t2, %[CountK], 0 \n\t"
+ "addi a1, %[DST], 0 \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "blt t2, %[BlkLen], TAIL%= \n\t"
+
+ "LOOP%=: \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], -768 \n\t"
+ "addi t2, t2, -256 \n\t"
+ "vfabs.v v0, v0 \n\t"
+ "vfabs.v v8, v8 \n\t"
+ "vfabs.v v16, v16 \n\t"
+ "vfabs.v v24, v24 \n\t"
+ "vfmax.vv v8, v0, v8 \n\t"
+ "vfmax.vv v24, v24, v16 \n\t"
+ "vfmax.vv v8, v8, v24 \n\t"
+ "vfredmax.vs v24, v8, v24 \n\t"
+ "vfmv.f.s f10, v24 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+
+ "QUANTIZE%=: \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (a1) \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vfmul.vf v0, v0, f11 \n\t"
+ "vfmul.vf v8, v8, f11 \n\t"
+ "vfmul.vf v16, v16, f11 \n\t"
+ "vfmul.vf v24, v24, f11 \n\t"
+ "vfcvt.x.f.v v0, v0 \n\t"
+ "vfcvt.x.f.v v8, v8 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vnclip.wx v4, v8, zero \n\t"
+ "vnclip.wx v8, v16, zero \n\t"
+ "vnclip.wx v12, v24, zero \n\t"
+ "vsetvli t0, zero, e8, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vnclip.wx v4, v8, zero \n\t"
+ "vsetvli t0, zero, e64, m8 \n\t"
+ "vsse64.v v0, (s1), t6 \n\t"
+ "add a1, a1, %[STRIDE] \n\t"
+ "bge t2, %[BlkLen], LOOP%= \n\t"
+
+ "TAIL%=: \n\t"
+ "blez t2, QUIT%= \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v8, v8, v8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t1, t2, 0 \n\t"
+ "vsetvli t0, t1, e32, m8 \n\t"
+ "sub t1, t1, t0 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vsetvli t0, t1, e32, m8 \n\t"
+ "sub t1, t1, t0 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vsetvli t0, t1, e32, m8 \n\t"
+ "sub t1, t1, t0 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vsetvli t0, t1, e32, m8 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], -768 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfabs.v v0, v0 \n\t"
+ "vfabs.v v8, v8 \n\t"
+ "vfabs.v v16, v16 \n\t"
+ "vfabs.v v24, v24 \n\t"
+ "vfmax.vv v8, v0, v8 \n\t"
+ "vfmax.vv v24, v16, v24 \n\t"
+ "vfmax.vv v8, v8, v24 \n\t"
+ "vfredmax.vs v24, v8, v24 \n\t"
+ "vfmv.f.s f10, v24 \n\t"
+ "add s1, a1, %[OFFSET] \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (a1) \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vsetvli t0, zero, e64, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vsse64.v v0, (s1), t6 \n\t"
+
+ "TAIL_LOOP%=: \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vsetvli t0, t2, e32, m1 \n\t"
+ "sub t2, t2, t0 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 32 \n\t"
+ "vfmul.vf v1, v0, f11 \n\t"
+ "vfcvt.x.f.v v2, v1 \n\t"
+ "vsetvli t0, zero, e16, mf2 \n\t"
+ "vnclip.wx v3, v2, zero \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vnclip.wx v3, v3, zero \n\t"
+ "vse8.v v3, (s1) \n\t"
+ "addi s1, s1, 32 \n\t"
+ "bnez t2, TAIL_LOOP%= \n\t"
+
+ "QUIT%=: \n\t"
+ : [SRC] "+r"(SRC)
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+ : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
+ }
+ }
+ }
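The offset and stride arithmetic above implies an interleaved activation layout for this four-row path: each block stores the four per-row fp32 scales first, then the four rows' int8 data interleaved in 8-byte groups, which is what the strided `vsse64.v v0, (s1), t6` store (t6 = 32) produces. A hypothetical addressing helper under that reading, assuming VLEN = 256 as the 32-byte vector accesses throughout these kernels suggest (illustrative sketch, not part of the package):

    // Byte offset of element k of row r inside block b of the interleaved
    // activation buffer; one fp32 scale per row precedes the data.
    static size_t quant_a4_elem_offset(size_t b, size_t r, size_t k, size_t BlkLen) {
        const size_t block = 4 * (sizeof(float) + BlkLen);  // matches `stride`
        return b * block + 4 * sizeof(float)  // skip the four per-row scales
             + (k / 8) * 32                   // 8-byte groups, four rows interleaved
             + r * 8 + (k % 8);               // this row's lane within the group
    }
    // The row-r scale of block b sits at b * block + r * sizeof(float).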
+
+ void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
+ const float * SRC = A;
+ std::byte * DST = QuantA;
+ constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
+ const float fone = 1.0f;
+ std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
+ size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
+
+ if (CountK <= BlkLen) {
+ float max_abs_A = 0.0f;
+ for (size_t k = 0; k < CountK; k++) {
+ max_abs_A = std::max(max_abs_A, fabsf(A[k]));
+ }
+ float scale_A = max_abs_A * range_max_reciprocal;
+
+ ((float *) QuantA)[0] = scale_A;
+
+ auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
+
+ for (size_t k = 0; k < CountK; k++) {
+ QuantAData_offset[k] =
+ (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
+ (float) std::numeric_limits<int8_t>::max());
+ }
+ for (size_t k = CountK; k < BlkLen; k++) {
+ QuantAData_offset[k] = 0;
+ }
+
+ return;
+ }
+
+ if (BlkLen != 32 && BlkLen != 64 && BlkLen != 128) { // zero-fill the tail padding for block sizes without a dedicated path
+ __asm__ volatile(
+ "vsetvli t0, zero, e8, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "LOOP%=: \n\t"
+ "vsetvli t0, %[CNT], e8, m8 \n\t"
+ "vse8.v v24, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 128 \n\t"
+ "sub %[CNT], %[CNT], t0 \n\t"
+ "bnez %[CNT], LOOP%= \n\t"
+ : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
+ :
+ : "cc", "t0");
+ }
+ if (BlkLen == 16) {
+ float buffer[64] = { 0.0f };
+ __asm__ volatile(
+ "addi t3, zero, 16*8 \n\t"
+ "addi t2, zero, 16 \n\t"
+ "blt %[K], t3, LOOP_K%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_MAIN%=: \n\t"
+ "vsetvli t1, zero, e32, m2 \n\t"
+ "addi %[K], %[K], -128 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v2, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v4, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v6, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v10, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v12, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "vle32.v v14, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "addi a1, %[BUFFER], 0 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vfabs.v v18, v2 \n\t"
+ "vfabs.v v20, v4 \n\t"
+ "vfabs.v v22, v6 \n\t"
+ "vfabs.v v24, v8 \n\t"
+ "vfabs.v v26, v10 \n\t"
+ "vfabs.v v28, v12 \n\t"
+ "vfabs.v v30, v14 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vfmax.vv v18, v18, v19 \n\t"
+ "vfmax.vv v20, v20, v21 \n\t"
+ "vfmax.vv v22, v22, v23 \n\t"
+ "vfmax.vv v24, v24, v25 \n\t"
+ "vfmax.vv v26, v26, v27 \n\t"
+ "vfmax.vv v28, v28, v29 \n\t"
+ "vfmax.vv v30, v30, v31 \n\t"
+ "vse32.v v16, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v18, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v20, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v22, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v24, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v26, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v28, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vse32.v v30, (a1) \n\t"
+ "addi a1, %[BUFFER], 0 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f10, f3, f7 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f10, %[FONE], f10 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f11, f3, f7 \n\t"
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
+ "fsw f11, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f11, %[FONE], f11 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f12, f3, f7 \n\t"
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
+ "fsw f12, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f12, %[FONE], f12 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f13, f3, f7 \n\t"
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
+ "fsw f13, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f13, %[FONE], f13 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f14, f3, f7 \n\t"
+ "fmul.s f14, f14, %[RMAXREC] \n\t"
+ "fsw f14, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f14, %[FONE], f14 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f15, f3, f7 \n\t"
+ "fmul.s f15, f15, %[RMAXREC] \n\t"
+ "fsw f15, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f15, %[FONE], f15 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f16, f3, f7 \n\t"
+ "fmul.s f16, f16, %[RMAXREC] \n\t"
+ "fsw f16, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "fdiv.s f16, %[FONE], f16 \n\t"
+ "flw f0, (a1) \n\t"
+ "flw f1, 4(a1) \n\t"
+ "flw f2, 8(a1) \n\t"
+ "flw f3, 12(a1) \n\t"
+ "flw f4, 16(a1) \n\t"
+ "flw f5, 20(a1) \n\t"
+ "flw f6, 24(a1) \n\t"
+ "flw f7, 28(a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f17, f3, f7 \n\t"
+ "fmul.s f17, f17, %[RMAXREC] \n\t"
+ "fsw f17, (%[DST]) \n\t"
+ "addi %[DST], %[DST], -136 \n\t"
+ "fdiv.s f17, %[FONE], f17 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmul.vf v16, v0, f10 \n\t"
+ "vfmul.vf v18, v2, f11 \n\t"
+ "vfmul.vf v20, v4, f12 \n\t"
+ "vfmul.vf v22, v6, f13 \n\t"
+ "vfmul.vf v24, v8, f14 \n\t"
+ "vfmul.vf v26, v10, f15 \n\t"
+ "vfmul.vf v28, v12, f16 \n\t"
+ "vfmul.vf v30, v14, f17 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v18, v18 \n\t"
+ "vfcvt.x.f.v v20, v20 \n\t"
+ "vfcvt.x.f.v v22, v22 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vfcvt.x.f.v v26, v26 \n\t"
+ "vfcvt.x.f.v v28, v28 \n\t"
+ "vfcvt.x.f.v v30, v30 \n\t"
+ "vsetvli t0, zero, e16, m1 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v18, v18, zero \n\t"
+ "vnclip.wx v20, v20, zero \n\t"
+ "vnclip.wx v22, v22, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vnclip.wx v26, v26, zero \n\t"
+ "vnclip.wx v28, v28, zero \n\t"
+ "vnclip.wx v30, v30, zero \n\t"
+ "vsetvli t0, t1, e8, mf2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v18, v18, zero \n\t"
+ "vnclip.wx v20, v20, zero \n\t"
+ "vnclip.wx v22, v22, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vnclip.wx v26, v26, zero \n\t"
+ "vnclip.wx v28, v28, zero \n\t"
+ "vnclip.wx v30, v30, zero \n\t"
+ "vse8.v v16, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v18, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v20, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v22, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v24, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v26, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v28, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 20 \n\t"
+ "vse8.v v30, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 16 \n\t"
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t1, %[K], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 64 \n\t"
+ "sub %[K], %[K], t1 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vse32.v v16, (%[BUFFER]) \n\t"
+ "flw f0, (%[BUFFER]) \n\t"
+ "flw f1, 4(%[BUFFER]) \n\t"
+ "flw f2, 8(%[BUFFER]) \n\t"
+ "flw f3, 12(%[BUFFER]) \n\t"
+ "flw f4, 16(%[BUFFER]) \n\t"
+ "flw f5, 20(%[BUFFER]) \n\t"
+ "flw f6, 24(%[BUFFER]) \n\t"
+ "flw f7, 28(%[BUFFER]) \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f10, f3, f7 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmul.vf v16, v0, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vsetvli t0, zero, e16, m1 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vsetvli t0, t1, e8, mf2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vse8.v v16, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 16 \n\t"
+ "bge %[K], t2, LOOP_K%= \n\t"
+ "TAIL%=: \n\t"
+ "blez %[K], END%= \n\t"
+ "vsetvli t0, t3, e32, m2 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "jal x0, LOOP_K%= \n\t"
+ "END%=: \n\t"
+ : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer)
+ : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12",
+ "f13", "f14", "f15", "f16", "f17");
+ } else if (BlkLen == 32) {
+ __asm__ volatile(
+ "addi t3, zero, 32*4 \n\t"
+ "addi t2, zero, 32 \n\t"
+
+ "addi a1, %[SRC], 0 \n\t"
+ "addi a2, %[SRC], 128 \n\t"
+ "addi a3, %[SRC], 256 \n\t"
+ "addi a4, %[SRC], 384 \n\t"
+
+ "addi s1, %[DST], 0 \n\t"
+ "addi s2, %[DST], 36 \n\t"
+ "addi s3, %[DST], 72 \n\t"
+ "addi s4, %[DST], 108 \n\t"
+ "blt %[K], t3, LOOP_K%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+
+ "LOOP_MAIN%=: \n\t"
+ "vsetvli t1, zero, e32, m4 \n\t"
+ "addi %[K], %[K], -128 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "addi a1, a1, 512 \n\t"
+ "vle32.v v4, (a2) \n\t"
+ "addi a2, a2, 512 \n\t"
+ "vle32.v v8, (a3) \n\t"
+ "addi a3, a3, 512 \n\t"
+ "vle32.v v12, (a4) \n\t"
+ "addi a4, a4, 512 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vfabs.v v20, v4 \n\t"
+ "vfabs.v v24, v8 \n\t"
+ "vfabs.v v28, v12 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v16, v16, v18 \n\t"
+ "vfmax.vv v20, v20, v22 \n\t"
+ "vfmax.vv v24, v24, v26 \n\t"
+ "vfmax.vv v28, v28, v30 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vfmax.vv v20, v20, v21 \n\t"
+ "vfmax.vv v24, v24, v25 \n\t"
+ "vfmax.vv v28, v28, v29 \n\t"
+
+ "vfredmax.vs v17, v16, v17 \n\t"
+ "vfredmax.vs v21, v20, v21 \n\t"
+ "vfredmax.vs v25, v24, v25 \n\t"
+ "vfredmax.vs v29, v28, v29 \n\t"
+ "vfmv.f.s f10, v17 \n\t"
+ "vfmv.f.s f11, v21 \n\t"
+ "vfmv.f.s f12, v25 \n\t"
+ "vfmv.f.s f13, v29 \n\t"
+
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
+ "fsw f10, (s1) \n\t"
+ "addi s1, s1, 4 \n\t"
+
+ "fsw f11, (s2) \n\t"
+ "addi s2, s2, 4 \n\t"
+ "fsw f12, (s3) \n\t"
+ "addi s3, s3, 4 \n\t"
+ "fsw f13, (s4) \n\t"
+ "addi s4, s4, 4 \n\t"
+ "fdiv.s f10, %[FONE], f10 \n\t"
+ "fdiv.s f11, %[FONE], f11 \n\t"
+ "fdiv.s f12, %[FONE], f12 \n\t"
+ "fdiv.s f13, %[FONE], f13 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmul.vf v16, v0, f10 \n\t"
+ "vfmul.vf v20, v4, f11 \n\t"
+ "vfmul.vf v24, v8, f12 \n\t"
+ "vfmul.vf v28, v12, f13 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v20, v20 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vfcvt.x.f.v v28, v28 \n\t"
+ "vsetvli t0, zero, e16, m2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v20, v20, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vnclip.wx v28, v28, zero \n\t"
+ "vsetvli t0, t1, e8, m1 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v20, v20, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vnclip.wx v28, v28, zero \n\t"
+ "vse8.v v16, (s1) \n\t"
+ "addi s1, s1, 140 \n\t"
+ "vse8.v v20, (s2) \n\t"
+ "addi s2, s2, 140 \n\t"
+ "vse8.v v24, (s3) \n\t"
+ "addi s3, s3, 140 \n\t"
+ "vse8.v v28, (s4) \n\t"
+ "addi s4, s4, 140 \n\t"
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t1, %[K], e32, m4 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "addi a1, a1, 128 \n\t"
+ "sub %[K], %[K], t1 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v16, v16, v18 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vfredmax.vs v17, v16, v17 \n\t"
+ "vfmv.f.s f10, v17 \n\t"
+
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (s1) \n\t"
+ "addi s1, s1, 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmul.vf v16, v0, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vsetvli t0, zero, e16, m2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vse8.v v16, (s1) \n\t"
+ "addi s1, s1, 32 \n\t"
+ "bge %[K], t2, LOOP_K%= \n\t"
+ "TAIL%=: \n\t"
+ "blez %[K], END%= \n\t"
+ "vsetvli t0, t3, e32, m4 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "jal x0, LOOP_K%= \n\t"
+ "END%=: \n\t"
+ : [K] "+r"(CountK)
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
+ } else if (BlkLen == 64) {
+ __asm__ volatile(
+ "addi t3, zero, 64*2 \n\t"
+ "addi t2, zero, 64 \n\t"
+ "addi a1, %[SRC], 0 \n\t"
+ "addi a2, %[SRC], 256 \n\t"
+ "addi s1, %[DST], 0 \n\t"
+ "addi s2, %[DST], 68 \n\t"
+ "blt %[K], t3, LOOP_K%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_MAIN%=: \n\t"
+ "vsetvli t1, zero, e32, m8 \n\t"
+ "addi %[K], %[K], -128 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "addi a1, a1, 512 \n\t"
+ "vle32.v v8, (a2) \n\t"
+ "addi a2, a2, 512 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vfabs.v v24, v8 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmax.vv v16, v16, v20 \n\t"
+ "vfmax.vv v24, v24, v28 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v16, v16, v18 \n\t"
+ "vfmax.vv v24, v24, v26 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vfmax.vv v24, v24, v25 \n\t"
+ "vfredmax.vs v17, v16, v17 \n\t"
+ "vfredmax.vs v25, v24, v25 \n\t"
+ "vfmv.f.s f10, v17 \n\t"
+ "vfmv.f.s f11, v25 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
+ "fsw f10, (s1) \n\t"
+ "addi s1, s1, 4 \n\t"
+ "fsw f11, (s2) \n\t"
+ "addi s2, s2, 4 \n\t"
+ "fdiv.s f10, %[FONE], f10 \n\t"
+ "fdiv.s f11, %[FONE], f11 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfmul.vf v16, v0, f10 \n\t"
+ "vfmul.vf v24, v8, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vsetvli t0, t1, e8, m2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v24, v24, zero \n\t"
+ "vse8.v v16, (s1) \n\t"
+ "addi s1, s1, 132 \n\t"
+ "vse8.v v24, (s2) \n\t"
+ "addi s2, s2, 132 \n\t"
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t1, %[K], e32, m8 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "addi a1, a1, 256 \n\t"
+ "sub %[K], %[K], t1 \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmax.vv v16, v16, v20 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v16, v16, v18 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v16, v16, v17 \n\t"
+ "vfredmax.vs v17, v16, v17 \n\t"
+ "vfmv.f.s f10, v17 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (s1) \n\t"
+ "addi s1, s1, 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfmul.vf v16, v0, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vsetvli t0, zero, e8, m2 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vse8.v v16, (s1) \n\t"
+ "addi s1, s1, 64 \n\t"
+ "bge %[K], t2, LOOP_K%= \n\t"
+ "TAIL%=: \n\t"
+ "blez %[K], END%= \n\t"
+ "vsetvli t0, t3, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "jal x0, LOOP_K%= \n\t"
+ "END%=: \n\t"
+ : [K] "+r"(CountK)
+ : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11");
+ } else if (BlkLen == 128) {
+ __asm__ volatile(
+ "addi t2, zero, 128 \n\t"
+ "addi a1, %[SRC], 0 \n\t"
+ "addi a2, %[SRC], 256 \n\t"
+ "blt %[K], t2, TAIL%= \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t1, zero, e32, m8 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "addi a1, a1, 512 \n\t"
+ "vle32.v v8, (a2) \n\t"
+ "addi a2, a2, 512 \n\t"
+ "sub %[K], %[K], t2 \n\t"
+ "QUANT%=: \n\t"
+ "vfabs.v v16, v0 \n\t"
+ "vfabs.v v24, v8 \n\t"
+ "vfmax.vv v24, v16, v24 \n\t"
+ "vsetvli t1, zero, e32, m4 \n\t"
+ "vfmax.vv v28, v24, v28 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v30, v28, v30 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v30, v30, v31 \n\t"
+ "vfredmax.vs v31, v30, v31 \n\t"
+ "vfmv.f.s f10, v31 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfmul.vf v16, v0, f11 \n\t"
+ "vfmul.vf v24, v8, f11 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vnclip.wx v20, v24, zero \n\t"
+ "vsetvli t0, zero, e8, m4 \n\t"
+ "vnclip.wx v16, v16, zero \n\t"
+ "vse8.v v16, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 128 \n\t"
+ "bge %[K], t2, LOOP_K%= \n\t"
+ "TAIL%=: \n\t"
+ "blez %[K], END%= \n\t"
+ "vsetvli t1, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vxor.vv v8, v8, v8 \n\t"
+ "vsetvli t0, %[K], e32, m8 \n\t"
+ "vle32.v v0, (a1) \n\t"
+ "sub %[K], %[K], t0 \n\t"
+ "vsetvli t0, %[K], e32, m8 \n\t"
+ "vle32.v v8, (a2) \n\t"
+ "sub %[K], %[K], t0 \n\t"
+ "vsetvli t1, zero, e32, m8 \n\t"
+ "jal x0, QUANT%= \n\t"
+ "END%=: \n\t"
+
+ : [DST] "+r"(DST), [K] "+r"(CountK)
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC)
+ : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11");
+ } else {
+ float buffer[8] = { 0.0f };
+ size_t cnt = BlkLen / 256;
+
+ __asm__ volatile(
+ "slli t3, %[BLK], 2 \n\t"
+ "blt %[K], %[BLK], LOOP_TAIL%= \n\t"
+ "LOOP_MAIN%=: \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vxor.vv v31, v31, v31 \n\t"
+ "vse32.v v31, (%[BUFFER]) \n\t"
+ "addi t6, %[CNT], 0 \n\t"
+ "LOOP_CMP%=: \n\t"
+ "addi t6, t6, -1 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vfabs.v v0, v0 \n\t"
+ "vfabs.v v8, v8 \n\t"
+ "vfabs.v v16, v16 \n\t"
+ "vfabs.v v24, v24 \n\t"
+ "vfmax.vv v8, v0, v8 \n\t"
+ "vfmax.vv v16, v16, v24 \n\t"
+ "vfmax.vv v0, v0, v16 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmax.vv v0, v0, v4 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v0, v0, v2 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v0, v0, v1 \n\t"
+ "vle32.v v30, (%[BUFFER]) \n\t"
+ "vfmax.vv v31, v30, v0 \n\t"
+ "vse32.v v31, (%[BUFFER]) \n\t"
+ "bnez t6, LOOP_CMP%= \n\t"
+ "sub %[SRC], %[SRC], t3 \n\t"
+ "addi t6, %[CNT], 0 \n\t"
+ "flw f0, (%[BUFFER]) \n\t"
+ "flw f1, 4(%[BUFFER]) \n\t"
+ "flw f2, 8(%[BUFFER]) \n\t"
+ "flw f3, 12(%[BUFFER]) \n\t"
+ "flw f4, 16(%[BUFFER]) \n\t"
+ "flw f5, 20(%[BUFFER]) \n\t"
+ "flw f6, 24(%[BUFFER]) \n\t"
+ "flw f7, 28(%[BUFFER]) \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f10, f3, f7 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "addi t6, %[CNT], 0 \n\t"
+ "LOOP_QUANT%=: \n\t"
+ "addi t6, t6, -1 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v8, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v16, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vle32.v v24, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfmul.vf v0, v0, f11 \n\t"
+ "vfmul.vf v8, v8, f11 \n\t"
+ "vfmul.vf v16, v16, f11 \n\t"
+ "vfmul.vf v24, v24, f11 \n\t"
+ "vfcvt.x.f.v v0, v0 \n\t"
+ "vfcvt.x.f.v v8, v8 \n\t"
+ "vfcvt.x.f.v v16, v16 \n\t"
+ "vfcvt.x.f.v v24, v24 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vnclip.wx v4, v8, zero \n\t"
+ "vnclip.wx v8, v16, zero \n\t"
+ "vnclip.wx v12, v24, zero \n\t"
+ "vsetvli t0, zero, e8, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vnclip.wx v4, v8, zero \n\t"
+ "vse8.v v0, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 128 \n\t"
+ "vse8.v v4, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 128 \n\t"
+ "bnez t6, LOOP_QUANT%= \n\t"
+ "sub %[K], %[K], %[BLK] \n\t"
+ "bge %[K], %[BLK], LOOP_MAIN%= \n\t"
+ "blez %[K], END%= \n\t"
+ "LOOP_TAIL%=: \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vxor.vv v31, v31, v31 \n\t"
+ "vse32.v v31, (%[BUFFER]) \n\t"
+ "addi t6, %[K], 0 \n\t"
+ "addi s1, %[SRC], 0 \n\t"
+ "TAIL_CMP%=: \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vsetvli t0, t6, e32, m8 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi %[SRC], %[SRC], 256 \n\t"
+ "sub t6, t6, t0 \n\t"
+ "vfabs.v v0, v0 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vfmax.vv v0, v0, v4 \n\t"
+ "vsetvli t0, zero, e32, m2 \n\t"
+ "vfmax.vv v0, v0, v2 \n\t"
+ "vsetvli t0, zero, e32, m1 \n\t"
+ "vfmax.vv v0, v0, v1 \n\t"
+ "vle32.v v30, (%[BUFFER]) \n\t"
+ "vfmax.vv v31, v30, v0 \n\t"
+ "vse32.v v31, (%[BUFFER]) \n\t"
+ "bnez t6, TAIL_CMP%= \n\t"
+ "addi t6, %[K], 0 \n\t"
+ "flw f0, (%[BUFFER]) \n\t"
+ "flw f1, 4(%[BUFFER]) \n\t"
+ "flw f2, 8(%[BUFFER]) \n\t"
+ "flw f3, 12(%[BUFFER]) \n\t"
+ "flw f4, 16(%[BUFFER]) \n\t"
+ "flw f5, 20(%[BUFFER]) \n\t"
+ "flw f6, 24(%[BUFFER]) \n\t"
+ "flw f7, 28(%[BUFFER]) \n\t"
+ "fmax.s f1, f0, f1 \n\t"
+ "fmax.s f3, f2, f3 \n\t"
+ "fmax.s f5, f4, f5 \n\t"
+ "fmax.s f7, f6, f7 \n\t"
+ "fmax.s f3, f1, f3 \n\t"
+ "fmax.s f7, f5, f7 \n\t"
+ "fmax.s f10, f3, f7 \n\t"
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
+ "fsw f10, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 4 \n\t"
+ "fdiv.s f11, %[FONE], f10 \n\t"
+ "addi t6, %[K], 0 \n\t"
+ "TAIL_QUANT%=: \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v0, v0, v0 \n\t"
+ "vsetvli t1, t6, e32, m8 \n\t"
+ "vle32.v v0, (s1) \n\t"
+ "addi s1, s1, 256 \n\t"
+ "sub t6, t6, t1 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfmul.vf v0, v0, f11 \n\t"
+ "vfcvt.x.f.v v0, v0 \n\t"
+ "vsetvli t0, zero, e16, m4 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vsetvli t0, t1, e8, m2 \n\t"
+ "vnclip.wx v0, v0, zero \n\t"
+ "vse8.v v0, (%[DST]) \n\t"
+ "addi %[DST], %[DST], 64 \n\t"
+ "bnez t6, TAIL_QUANT%= \n\t"
+ "END%=: \n\t"
+ : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer),
+ [CNT] "r"(cnt)
+ : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6");
+ }
+ }
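Each block of the single-row buffer written by quantize_a_row_i8 is one fp32 scale followed by BlkLen int8 values, with the final partial block zero-padded. A scalar sketch of the per-block math, assuming round-to-nearest conversion (the default frm used by `vfcvt.x.f.v`) and the headers <algorithm>, <cmath>, <cstring>, <cstdint> (illustrative reference, not package code):

    static void quantize_block_ref(const float * x, size_t n, size_t BlkLen, std::byte * out) {
        float amax = 0.0f;
        for (size_t k = 0; k < n; ++k) {
            amax = std::max(amax, fabsf(x[k]));          // block absmax
        }
        const float scale = amax / 127.0f;               // fmul.s f10, f10, %[RMAXREC]
        memcpy(out, &scale, sizeof(float));              // fsw f10, (DST)
        int8_t * q = (int8_t *) (out + sizeof(float));
        const float inv = scale != 0.0f ? 1.0f / scale : 0.0f;  // fdiv.s f11, %[FONE], f10
        for (size_t k = 0; k < BlkLen; ++k) {            // lanes k >= n stay zero (padding)
            const float v = k < n ? roundf(x[k] * inv) : 0.0f;
            q[k] = (int8_t) std::clamp(v, -128.0f, 127.0f);  // vnclip saturation range
        }
    }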
+
+ } // namespace ime1
+
+ namespace {
+ #define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \
+ "vmadot v16, v14, v0 \n\t" \
+ "vmadot v18, v14, v1 \n\t" \
+ "vmadot v20, v14, v2 \n\t" \
+ "vmadot v22, v14, v3 \n\t" \
+ "vmadot v16, v15, v4 \n\t" \
+ "vmadot v18, v15, v5 \n\t" \
+ "vmadot v20, v15, v6 \n\t" \
+ "vmadot v22, v15, v7 \n\t"
+
+ #define SQ4BIT_KERNEL_ACC_1X4X4 \
+ "vfcvt.f.x.v v16, v16 \n\t" \
+ "vfcvt.f.x.v v18, v18 \n\t" \
+ "vfcvt.f.x.v v20, v20 \n\t" \
+ "vfcvt.f.x.v v22, v22 \n\t" \
+ "addi s2, s1, 16 \n\t" \
+ "addi s3, s1, 32 \n\t" \
+ "addi s4, s1, 48 \n\t" \
+ "addi s6, s5, 12 \n\t" \
+ "vfmacc.vv v28, v16, v24 \n\t" \
+ "vfmacc.vv v29, v18, v25 \n\t" \
+ "vfmacc.vv v30, v20, v26 \n\t" \
+ "vfmacc.vv v31, v22, v27 \n\t"
+
+ #define SQ4BIT_KERNEL_ACC_F16_1X4X4 \
+ "vfcvt.f.x.v v16, v16 \n\t" \
+ "vfcvt.f.x.v v18, v18 \n\t" \
+ "vfcvt.f.x.v v20, v20 \n\t" \
+ "vfcvt.f.x.v v22, v22 \n\t" \
+ "addi s2, s1, 8 \n\t" \
+ "addi s3, s1, 16 \n\t" \
+ "addi s4, s1, 24 \n\t" \
+ "addi s6, s5, 12 \n\t" \
+ "vfmacc.vv v28, v16, v24 \n\t" \
+ "vfmacc.vv v29, v18, v25 \n\t" \
+ "vfmacc.vv v30, v20, v26 \n\t" \
+ "vfmacc.vv v31, v22, v27 \n\t"
+
+ #define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \
+ "vle8.v v4, (s1) \n\t" \
+ "addi s1, s1, 128 \n\t" \
+ "vle8.v v5, (s2) \n\t" \
+ "addi s2, s2, 128 \n\t" \
+ "vle8.v v6, (s3) \n\t" \
+ "addi s3, s3, 128 \n\t" \
+ "vle8.v v7, (s4) \n\t" \
+ "addi s4, s4, 128 \n\t" \
+ "vsetvli t0, zero, e8, mf4 \n\t" \
+ "vle8.v v14, (s5) \n\t" \
+ "addi s5, s5, 16 \n\t" \
+ "vle8.v v15, (s6) \n\t" \
+ "addi s6, s6, 16 \n\t" \
+ "addi t5, t5, -1 \n\t" \
+ "vsetvli t0, zero, e8, m1 \n\t" \
+ "vand.vi v0, v4, 15 \n\t" \
+ "vand.vi v1, v5, 15 \n\t" \
+ "vand.vi v2, v6, 15 \n\t" \
+ "vand.vi v3, v7, 15 \n\t" \
+ "vsrl.vi v4, v4, 4 \n\t" \
+ "vsrl.vi v5, v5, 4 \n\t" \
+ "vsrl.vi v6, v6, 4 \n\t" \
+ "vsrl.vi v7, v7, 4 \n\t"
+
+ #define SQ4BIT_KERNEL_LOAD_ZP_16X1 \
+ "vsetvli t0, zero, e8, mf2 \n\t" \
+ "vle8.v v1, (s7) \n\t" \
+ "vsetvli t0, zero, e8, m1 \n\t" \
+ "vrgather.vv v8, v1, v13 \n\t" \
+ "vadd.vi v13, v13, 4 \n\t" \
+ "vrgather.vv v9, v1, v13 \n\t" \
+ "vadd.vi v13, v13, 4 \n\t" \
+ "vrgather.vv v10, v1, v13 \n\t" \
+ "vadd.vi v13, v13, 4 \n\t" \
+ "vrgather.vv v11, v1, v13 \n\t" \
+ "vadd.vi v13, v13, -12 \n\t"
+
+ // used by the M4 kernel
+ #define LOAD_B_16x8x2 \
+ "vsetvli t0, zero, e8, m1 \n\t" \
+ "vle8.v v6, (s1) \n\t" \
+ "addi s1, s1, 32*4 \n\t" \
+ "vle8.v v7, (s2) \n\t" \
+ "addi s2, s2, 32*4 \n\t" \
+ "vle8.v v8, (s3) \n\t" \
+ "addi s3, s3, 32*4 \n\t" \
+ "vle8.v v9, (s4) \n\t" \
+ "addi s4, s4, 32*4 \n\t" \
+ \
+ "vand.vi v2, v6, 15 \n\t" \
+ "vand.vi v3, v7, 15 \n\t" \
+ "vand.vi v4, v8, 15 \n\t" \
+ "vand.vi v5, v9, 15 \n\t" \
+ \
+ "vsrl.vi v6, v6, 4 \n\t" \
+ "vsrl.vi v7, v7, 4 \n\t" \
+ "vsrl.vi v8, v8, 4 \n\t" \
+ "vsrl.vi v9, v9, 4 \n\t"
+
+ // [s2|s5, s3, s4, s6]
+ #define LOAD_SCALE_4x16_FP16 \
+ "addi s2, s5, -8 \n\t" \
+ "addi s3, s5, 8 \n\t" \
+ "addi s4, s5, 16 \n\t" \
+ "addi s6, s5, 24 \n\t" \
+ "li t1, 0xf0 \n\t" \
+ "vmv.s.x v0, t1 \n\t" \
+ "vsetvli t0, zero, e16, mf4 \n\t" \
+ "vle16.v v9, (s5) \n\t" \
+ "vle16.v v11, (s3) \n\t" \
+ "vle16.v v13, (s4) \n\t" \
+ "vle16.v v15, (s6) \n\t" \
+ "vsetvli t0, zero, e16, mf2 \n\t" \
+ "vle16.v v9, (s2), v0.t \n\t" \
+ "vle16.v v11, (s5), v0.t \n\t" \
+ "vle16.v v13, (s3), v0.t \n\t" \
+ "vle16.v v15, (s4), v0.t \n\t" \
+ "vfwcvt.f.f.v v8, v9 \n\t" \
+ "vfwcvt.f.f.v v10, v11 \n\t" \
+ "vfwcvt.f.f.v v12, v13 \n\t" \
+ "vfwcvt.f.f.v v14, v15 \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ "vmv.v.v v9, v8 \n\t" \
+ "vmv.v.v v11, v10 \n\t" \
+ "vmv.v.v v13, v12 \n\t" \
+ "vmv.v.v v15, v14 \n\t" \
+ "li t1, 0xf0 \n\t" \
+ "vmv.s.x v0, t1 \n\t" \
+ "vsetvli t0, zero, e32, mf2 \n\t" \
+ "vfmul.vf v8, v8, f1 \n\t" \
+ "vfmul.vf v10, v10, f1 \n\t" \
+ "vfmul.vf v12, v12, f1 \n\t" \
+ "vfmul.vf v14, v14, f1 \n\t" \
+ "vfmul.vf v9, v9, f3 \n\t" \
+ "vfmul.vf v11, v11, f3 \n\t" \
+ "vfmul.vf v13, v13, f3 \n\t" \
+ "vfmul.vf v15, v15, f3 \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ "vfmul.vf v8, v8, f2, v0.t \n\t" \
+ "vfmul.vf v10, v10, f2, v0.t \n\t" \
+ "vfmul.vf v12, v12, f2, v0.t \n\t" \
+ "vfmul.vf v14, v14, f2, v0.t \n\t" \
+ "vfmul.vf v9, v9, f4, v0.t \n\t" \
+ "vfmul.vf v11, v11, f4, v0.t \n\t" \
+ "vfmul.vf v13, v13, f4, v0.t \n\t" \
+ "vfmul.vf v15, v15, f4, v0.t \n\t"
+
+ // [s2|s5, s3, s4, s6]
+ #define LOAD_SCALE_4x16 \
+ "addi s2, s5, -16 \n\t" \
+ "addi s3, s5, 16 \n\t" \
+ "addi s4, s5, 32 \n\t" \
+ "addi s6, s5, 48 \n\t" \
+ "li t1, 0xf0 \n\t" \
+ "vmv.s.x v0, t1 \n\t" \
+ "vsetvli t0, zero, e32, mf2 \n\t" \
+ "vle32.v v8, (s5) \n\t" \
+ "vle32.v v10, (s3) \n\t" \
+ "vle32.v v12, (s4) \n\t" \
+ "vle32.v v14, (s6) \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ "vle32.v v8, (s2), v0.t \n\t" \
+ "vle32.v v10, (s5), v0.t \n\t" \
+ "vle32.v v12, (s3), v0.t \n\t" \
+ "vle32.v v14, (s4), v0.t \n\t" \
+ "vmv.v.v v9, v8 \n\t" \
+ "vmv.v.v v11, v10 \n\t" \
+ "vmv.v.v v13, v12 \n\t" \
+ "vmv.v.v v15, v14 \n\t" \
+ "vsetvli t0, zero, e32, mf2 \n\t" \
+ "vfmul.vf v8, v8, f1 \n\t" \
+ "vfmul.vf v10, v10, f1 \n\t" \
+ "vfmul.vf v12, v12, f1 \n\t" \
+ "vfmul.vf v14, v14, f1 \n\t" \
+ "vfmul.vf v9, v9, f3 \n\t" \
+ "vfmul.vf v11, v11, f3 \n\t" \
+ "vfmul.vf v13, v13, f3 \n\t" \
+ "vfmul.vf v15, v15, f3 \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ "vfmul.vf v8, v8, f2, v0.t \n\t" \
+ "vfmul.vf v10, v10, f2, v0.t \n\t" \
+ "vfmul.vf v12, v12, f2, v0.t \n\t" \
+ "vfmul.vf v14, v14, f2, v0.t \n\t" \
+ "vfmul.vf v9, v9, f4, v0.t \n\t" \
+ "vfmul.vf v11, v11, f4, v0.t \n\t" \
+ "vfmul.vf v13, v13, f4, v0.t \n\t" \
+ "vfmul.vf v15, v15, f4, v0.t \n\t"
+
+ // [s1 | BIAS, s2, s3, s4]
+ #define LOAD_BIAS \
+ "vsetvli t0, zero, e32, mf2 \n\t" \
+ "li t1, 0xf0 \n\t" \
+ "vmv.s.x v0, t1 \n\t" \
+ "addi s1, %[BIAS], -16 \n\t" \
+ "addi s2, %[BIAS], 16 \n\t" \
+ "addi s3, %[BIAS], 32 \n\t" \
+ "addi s4, %[BIAS], 48 \n\t" \
+ \
+ "vle32.v v24, (%[BIAS]) \n\t" \
+ "vle32.v v26, (s2) \n\t" \
+ "vle32.v v28, (s3) \n\t" \
+ "vle32.v v30, (s4) \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ "vle32.v v24, (s1), v0.t \n\t" \
+ "vle32.v v26, (%[BIAS]), v0.t \n\t" \
+ "vle32.v v28, (s2), v0.t \n\t" \
+ "vle32.v v30, (s3), v0.t \n\t" \
+ "vmv.v.v v25, v24 \n\t" \
+ "vmv.v.v v27, v26 \n\t" \
+ "vmv.v.v v29, v28 \n\t" \
+ "vmv.v.v v31, v30 \n\t"
+
+ #define SQ4BIT_KERNEL_COMP_4x16x16 \
+ "vmadot v16, v10, v2 \n\t" \
+ "vmadot v18, v10, v3 \n\t" \
+ "vmadot v20, v10, v4 \n\t" \
+ "vmadot v22, v10, v5 \n\t" \
+ "vmadot v16, v11, v6 \n\t" \
+ "vmadot v18, v11, v7 \n\t" \
+ "vmadot v20, v11, v8 \n\t" \
+ "vmadot v22, v11, v9 \n\t"
+
+ #define SAVE_RESULT_4x16 \
+ "addi a1, %[C], 0 \n\t" \
+ "add a2, %[C], %[LDC] \n\t" \
+ "add a3, a2, %[LDC] \n\t" \
+ "add a4, a3, %[LDC] \n\t" \
+ "addi a2, a2, -16 \n\t" \
+ "addi a4, a4, -16 \n\t" \
+ "li t1, 0xf0 \n\t" \
+ "vmv.s.x v0, t1 \n\t" \
+ "vsetvli t0, zero, e32, mf2 \n\t" \
+ \
+ "vse32.v v24, (a1) \n\t" \
+ "addi a1, a1, 16 \n\t" \
+ "vse32.v v25, (a3) \n\t" \
+ "addi a3, a3, 16 \n\t" \
+ \
+ "vse32.v v26, (a1) \n\t" \
+ "addi a1, a1, 16 \n\t" \
+ "vse32.v v27, (a3) \n\t" \
+ "addi a3, a3, 16 \n\t" \
+ \
+ "vse32.v v28, (a1) \n\t" \
+ "addi a1, a1, 16 \n\t" \
+ "vse32.v v29, (a3) \n\t" \
+ "addi a3, a3, 16 \n\t" \
+ \
+ "vse32.v v30, (a1) \n\t" \
+ "vse32.v v31, (a3) \n\t" \
+ "vsetvli t0, zero, e32, m1 \n\t" \
+ \
+ "vse32.v v24, (a2), v0.t \n\t" \
+ "addi a2, a2, 16 \n\t" \
+ "vse32.v v25, (a4), v0.t \n\t" \
+ "addi a4, a4, 16 \n\t" \
+ \
+ "vse32.v v26, (a2), v0.t \n\t" \
+ "addi a2, a2, 16 \n\t" \
+ "vse32.v v27, (a4), v0.t \n\t" \
+ "addi a4, a4, 16 \n\t" \
+ \
+ "vse32.v v28, (a2), v0.t \n\t" \
+ "addi a2, a2, 16 \n\t" \
+ "vse32.v v29, (a4), v0.t \n\t" \
+ "addi a4, a4, 16 \n\t" \
+ \
+ "vse32.v v30, (a2), v0.t \n\t" \
+ "vse32.v v31, (a4), v0.t \n\t"
+
+ #define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \
+ "vsetvli t0, zero, e8, mf2 \n\t" \
+ "vle8.v v11, (s6) \n\t" \
+ "vsetvli t0, zero, e8, m1 \n\t" \
+ "vrgather.vv v12, v11, v1 \n\t" \
+ "vadd.vi v1, v1, 4 \n\t" \
+ "vrgather.vv v13, v11, v1 \n\t" \
+ "vadd.vi v1, v1, 4 \n\t" \
+ "vrgather.vv v14, v11, v1 \n\t" \
+ "vadd.vi v1, v1, 4 \n\t" \
+ "vrgather.vv v15, v11, v1 \n\t" \
+ "vadd.vi v1, v1, -12 \n\t"
+
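`vmadot` is not a standard RVV 1.0 instruction; the `ime1` namespace suggests a vendor integer matrix extension (SpacemiT IME). Under that assumption, a plausible scalar model of one `vmadot vd, vs1, vs2` step as used by SQ4BIT_KERNEL_COMP_4x16x16 is a 4x4 int32 tile update from int8 operands (hypothetical model for orientation only; the vendor ISA manual is authoritative):

    // Hypothetical semantics: acc is a 4x4 int32 tile held in vd; a and b are
    // 4x8 int8 operand tiles taken from vs1 and vs2 (assumes <cstdint>).
    static void vmadot_model(int32_t acc[4][4], const int8_t a[4][8], const int8_t b[4][8]) {
        for (int i = 0; i < 4; ++i)
            for (int j = 0; j < 4; ++j)
                for (int k = 0; k < 8; ++k)
                    acc[i][j] += (int32_t) a[i][k] * (int32_t) b[j][k];
    }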
+ template <bool HasZeroPoint>
+ void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias,
+ const size_t ldc) {
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ size_t LDC = ldc * sizeof(float);
+ const size_t INNER = BlkLen / 16;
+ float tmp[4 * 16];
+
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+ __asm__ volatile(LOAD_BIAS
+
+ "addi t3, %[BlockCountK], 0 \n\t"
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 32 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 32 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ } else {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+ __asm__ volatile(LOAD_BIAS
+
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 32 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 32 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ }
+ if (CountN % 16 != 0) {
+ // store the output from tmp to C when NBLKS is less than 16.
+ float * CPtr = C + CountN / 16 * 16;
+ const size_t N = CountN % 16;
+ LDC = ldc * sizeof(float);
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi s2, %[SRC], 64 \n\t"
+ "addi s3, %[SRC], 64*2 \n\t"
+ "addi s4, %[SRC], 64*3 \n\t"
+ "vle32.v v2, (s2) \n\t"
+ "vle32.v v4, (s3) \n\t"
+ "vle32.v v6, (s4) \n\t"
+ "add t2, %[DST], %[LDC] \n\t"
+ "add t3, t2, %[LDC] \n\t"
+ "add t4, t3, %[LDC] \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ "vse32.v v2, (t2) \n\t"
+ "vse32.v v4, (t3) \n\t"
+ "vse32.v v6, (t4) \n\t"
+ :
+ : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
+ : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
+ }
+ }
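The pointer staggering at the top of BLOCK_COUNTK_LOOP (s5 = s1 for scales, s6 = s1 + 32 for zero points, data from s6 + 16) implies a per-block layout for packed B in this fp16-scale kernel of 32 bytes of fp16 scales, 16 zero-point bytes, then 8*BlkLen bytes of 4-bit data covering the 16 columns; the fp32-scale kernel below staggers by 64 scale bytes instead. A hypothetical size helper consistent with both and with the QuantBDataPtr arithmetic (assumption, not an API of the package):

    // Bytes per 16-column, BlkLen-deep block of packed B (illustrative only).
    static size_t packed_b_block_bytes(size_t BlkLen, bool fp16_scales, bool has_zero_point) {
        const size_t scales = 16 * (fp16_scales ? sizeof(_Float16) : sizeof(float));
        const size_t zp     = has_zero_point ? 16 * sizeof(uint8_t) : 0;
        const size_t data   = 16 * BlkLen / 2;  // two 4-bit weights per byte
        return scales + zp + data;
    }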
+
+ template <bool HasZeroPoint>
+ void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias,
+ const size_t ldc) {
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ size_t LDC = ldc * sizeof(float);
+ const size_t INNER = BlkLen / 16;
+ float tmp[4 * 16];
+
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+
+ __asm__ volatile(LOAD_BIAS
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 64 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 64 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
1972
+ "flw f2, 4(a1) \n\t"
1973
+ "flw f3, 8(a1) \n\t"
1974
+ "flw f4, 12(a1) \n\t"
1975
+ "addi a1, a1, 16 \n\t"
1976
+ "addi t2, %[INNER], 0 \n\t"
1977
+
1978
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
1979
+
1980
+ "BLOCK_INNER_LOOP%=: \n\t"
1981
+
1982
+ LOAD_B_16x8x2
1983
+
1984
+ "vle8.v v10, (a1) \n\t"
1985
+ "addi a1, a1, 32 \n\t"
1986
+ "vle8.v v11, (a1) \n\t"
1987
+ "addi a1, a1, 32 \n\t"
1988
+ "vsub.vv v2, v2, v12 \n\t"
1989
+ "vsub.vv v6, v6, v12 \n\t"
1990
+ "vsub.vv v3, v3, v13 \n\t"
1991
+ "vsub.vv v7, v7, v13 \n\t"
1992
+ "vsub.vv v4, v4, v14 \n\t"
1993
+ "vsub.vv v8, v8, v14 \n\t"
1994
+ "vsub.vv v5, v5, v15 \n\t"
1995
+ "vsub.vv v9, v9, v15 \n\t"
1996
+
1997
+ SQ4BIT_KERNEL_COMP_4x16x16
1998
+
1999
+ "addi t2, t2, -1 \n\t"
2000
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2001
+
2002
+ LOAD_SCALE_4x16
2003
+
2004
+ "vsetvli t0, zero, e32, m8 \n\t"
2005
+ "vfcvt.f.x.v v16, v16 \n\t"
2006
+ "vfmacc.vv v24, v16, v8 \n\t"
2007
+ "addi t3, t3, -1 \n\t"
2008
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2009
+
2010
+ "RESULT_SAVE%=: \n\t"
2011
+
2012
+ SAVE_RESULT_4x16
2013
+
2014
+ :
2015
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2016
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
2017
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
2018
+ "s4", "s5", "s6");
2019
+ }
2020
+ }
2021
+ } else {
2022
+ for (size_t n = 0; n < CountN; n += 16) {
2023
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
2024
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2025
+ n * BlockCountK * BlkLen / 2 + // b data
2026
+ n * BlockCountK * sizeof(float); // scale
2027
+ float * CPtr = C + n;
2028
+ if (NBLKS < 16) {
2029
+ CPtr = tmp;
2030
+ LDC = 16 * sizeof(float);
2031
+ }
2032
+ if (Bias != nullptr) {
2033
+ const float * bias = Bias + n;
2034
+ if (NBLKS < 16) {
2035
+ __asm__ volatile(
2036
+ "vsetvli t0, %[N], e32, m2 \n\t"
2037
+ "vle32.v v0, (%[SRC]) \n\t"
2038
+ "vse32.v v0, (%[DST]) \n\t"
2039
+ :
2040
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
2041
+ : "cc", "t0");
2042
+ bias = tmp;
2043
+ }
2044
+ __asm__ volatile(LOAD_BIAS
2045
+ "addi t3, %[BlockCountK], 0 \n\t"
2046
+ "addi a1, %[A], 0 \n\t"
2047
+ "addi s1, %[B], 0 \n\t"
2048
+ "BLOCK_COUNTK_LOOP%=: \n\t"
2049
+ "addi s5, s1, 0 \n\t"
2050
+ "addi s1, s5, 64 \n\t"
2051
+ "addi s2, s1, 32 \n\t"
2052
+ "addi s3, s1, 32*2 \n\t"
2053
+ "addi s4, s1, 32*3 \n\t"
2054
+ "vsetvli t0, zero, e32, m8 \n\t"
2055
+ "vxor.vv v16, v16, v16 \n\t"
2056
+ // load a scale
2057
+ "flw f1, (a1) \n\t"
2058
+ "flw f2, 4(a1) \n\t"
2059
+ "flw f3, 8(a1) \n\t"
2060
+ "flw f4, 12(a1) \n\t"
2061
+ "addi a1, a1, 16 \n\t"
2062
+ "addi t2, %[INNER], 0 \n\t"
2063
+ "BLOCK_INNER_LOOP%=: \n\t"
2064
+
2065
+ LOAD_B_16x8x2
2066
+
2067
+ "vsetvli t0, zero, e8, m1 \n\t"
2068
+ "vle8.v v10, (a1) \n\t"
2069
+ "addi a1, a1, 32 \n\t"
2070
+ "vle8.v v11, (a1) \n\t"
2071
+ "addi a1, a1, 32 \n\t"
2072
+ "vadd.vi v2, v2, -8 \n\t"
2073
+ "vadd.vi v3, v3, -8 \n\t"
2074
+ "vadd.vi v4, v4, -8 \n\t"
2075
+ "vadd.vi v5, v5, -8 \n\t"
2076
+ "vadd.vi v6, v6, -8 \n\t"
2077
+ "vadd.vi v7, v7, -8 \n\t"
2078
+ "vadd.vi v8, v8, -8 \n\t"
2079
+ "vadd.vi v9, v9, -8 \n\t"
2080
+
2081
+ SQ4BIT_KERNEL_COMP_4x16x16
2082
+
2083
+ "addi t2, t2, -1 \n\t"
2084
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2085
+
2086
+ LOAD_SCALE_4x16
2087
+
2088
+ "vsetvli t0, zero, e32, m8 \n\t"
2089
+ "vfcvt.f.x.v v16, v16 \n\t"
2090
+ "vfmacc.vv v24, v16, v8 \n\t"
2091
+ "addi t3, t3, -1 \n\t"
2092
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2093
+
2094
+ "RESULT_SAVE%=: \n\t"
2095
+
2096
+ SAVE_RESULT_4x16
2097
+
2098
+ :
2099
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2100
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
2101
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
2102
+ "s2", "s3", "s4", "s5", "s6");
2103
+
2104
+ } else {
2105
+ __asm__ volatile(
2106
+ "vsetvli t0, zero, e32, m8 \n\t"
2107
+ "vxor.vv v24, v24, v24 \n\t"
2108
+ "addi t3, %[BlockCountK], 0 \n\t"
2109
+ "addi a1, %[A], 0 \n\t"
2110
+ "addi s1, %[B], 0 \n\t"
2111
+ "BLOCK_COUNTK_LOOP%=: \n\t"
2112
+ "addi s5, s1, 0 \n\t"
2113
+ "addi s1, s5, 64 \n\t"
2114
+ "addi s2, s1, 32 \n\t"
2115
+ "addi s3, s1, 32*2 \n\t"
2116
+ "addi s4, s1, 32*3 \n\t"
2117
+ "vsetvli t0, zero, e32, m8 \n\t"
2118
+ "vxor.vv v16, v16, v16 \n\t"
2119
+ // load a scale
2120
+ "flw f1, (a1) \n\t"
2121
+ "flw f2, 4(a1) \n\t"
2122
+ "flw f3, 8(a1) \n\t"
2123
+ "flw f4, 12(a1) \n\t"
2124
+ "addi a1, a1, 16 \n\t"
2125
+ "addi t2, %[INNER], 0 \n\t"
2126
+ "BLOCK_INNER_LOOP%=: \n\t"
2127
+
2128
+ LOAD_B_16x8x2
2129
+
2130
+ "vsetvli t0, zero, e8, m1 \n\t"
2131
+ "vle8.v v10, (a1) \n\t"
2132
+
2133
+ "addi a1, a1, 32 \n\t"
2134
+ "vle8.v v11, (a1) \n\t"
2135
+ "addi a1, a1, 32 \n\t"
2136
+ "vadd.vi v2, v2, -8 \n\t"
2137
+ "vadd.vi v3, v3, -8 \n\t"
2138
+ "vadd.vi v4, v4, -8 \n\t"
2139
+ "vadd.vi v5, v5, -8 \n\t"
2140
+ "vadd.vi v6, v6, -8 \n\t"
2141
+ "vadd.vi v7, v7, -8 \n\t"
2142
+ "vadd.vi v8, v8, -8 \n\t"
2143
+ "vadd.vi v9, v9, -8 \n\t"
2144
+
2145
+ SQ4BIT_KERNEL_COMP_4x16x16
2146
+
2147
+ "addi t2, t2, -1 \n\t"
2148
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
2149
+
2150
+ LOAD_SCALE_4x16
2151
+
2152
+ "vsetvli t0, zero, e32, m8 \n\t"
2153
+ "vfcvt.f.x.v v16, v16 \n\t"
2154
+ "vfmacc.vv v24, v16, v8 \n\t"
2155
+ "addi t3, t3, -1 \n\t"
2156
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
2157
+
2158
+ "RESULT_SAVE%=: \n\t"
2159
+
2160
+ SAVE_RESULT_4x16
2161
+
2162
+ :
2163
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
2164
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
2165
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
2166
+ "s4", "s5", "s6");
2167
+ }
2168
+ }
2169
+ }
2170
+ if (CountN % 16 != 0) {
2171
+ // store output from tmp to C when NBLKS is less than 16.
2172
+ float * CPtr = C + CountN / 16 * 16;
2173
+ const size_t N = CountN % 16;
2174
+ LDC = ldc * sizeof(float);
2175
+ __asm__ volatile(
2176
+ "vsetvli t0, %[N], e32, m2 \n\t"
2177
+ "vle32.v v0, (%[SRC]) \n\t"
2178
+ "addi s2, %[SRC], 64 \n\t"
2179
+ "addi s3, %[SRC], 64*2 \n\t"
2180
+ "addi s4, %[SRC], 64*3 \n\t"
2181
+ "vle32.v v2, (s2) \n\t"
2182
+ "vle32.v v4, (s3) \n\t"
2183
+ "vle32.v v6, (s4) \n\t"
2184
+ "add t2, %[DST], %[LDC] \n\t"
2185
+ "add t3, t2, %[LDC] \n\t"
2186
+ "add t4, t3, %[LDC] \n\t"
2187
+ "vse32.v v0, (%[DST]) \n\t"
2188
+ "vse32.v v2, (t2) \n\t"
2189
+ "vse32.v v4, (t3) \n\t"
2190
+ "vse32.v v6, (t4) \n\t"
2191
+ :
2192
+ : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
2193
+ : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
2194
+ }
2195
+ }
2196
+
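+ // Scalar reference for the tail store above (a sketch only; `tmp` holds the
+ // 4x16 f32 tile the kernel writes when fewer than 16 columns remain):
+ //
+ //   const size_t N = CountN % 16;
+ //   float * CPtr = C + CountN / 16 * 16;
+ //   for (size_t m = 0; m < 4; ++m) {        // four rows of the M4 kernel
+ //       for (size_t j = 0; j < N; ++j) {    // only the valid tail columns
+ //           CPtr[m * ldc + j] = tmp[m * 16 + j];
+ //       }
+ //   }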
2197
+ template <bool HasZeroPoint>
2198
+ void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
2199
+ const std::byte * QuantA,
2200
+ const std::byte * QuantBData,
2201
+ const float * QuantBScale,
2202
+ const std::byte * QuantBZeroPoint,
2203
+ float * C,
2204
+ size_t CountN,
2205
+ size_t BlockCountK,
2206
+ const float * Bias) {
2207
+ GGML_UNUSED(QuantBScale);
2208
+ GGML_UNUSED(QuantBZeroPoint);
2209
+ size_t INNER = BlkLen / 16;
2210
+
2211
+ if constexpr (HasZeroPoint) {
2212
+ for (size_t n = 0; n < CountN; n += 16) {
2213
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2214
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2215
+ n * BlockCountK * BlkLen / 2 + // b data
2216
+ n * BlockCountK * sizeof(uint8_t) + // zp
2217
+ n * BlockCountK * sizeof(_Float16); // scale
2218
+ float * CPtr = C + n;
2219
+ size_t cnt = BlockCountK;
2220
+ if (Bias != nullptr) {
2221
+ const float * bias = Bias + n;
2222
+ __asm__ volatile(
2223
+ "addi t3, %[NBLKS], 0 \n\t"
2224
+ "vsetvli t0, zero, e8, m1 \n\t"
2225
+
2226
+ "vmv.v.i v13, 3 \n\t"
2227
+ "li s1, 24 \n\t"
2228
+ "vsetvli t0, s1, e8, m1 \n\t"
2229
+ "vmv.v.i v13, 2 \n\t"
2230
+ "vsetvli t0, zero, e8, mf2 \n\t"
2231
+ "vmv.v.i v13, 1 \n\t"
2232
+ "vsetvli t0, zero, e8, mf4 \n\t"
2233
+ "vmv.v.i v13, 0 \n\t"
2234
+ "addi s1, %[B], 0 \n\t"
2235
+ "addi s2, %[B], 8 \n\t"
2236
+ "addi s3, %[B], 16 \n\t"
2237
+ "addi s4, %[B], 24 \n\t"
2238
+ // zp offset
2239
+ "addi s7, %[B], 32 \n\t"
2240
+ // a offset
2241
+ "addi s5, %[A], 0 \n\t"
2242
+ "addi s6, %[A], 12 \n\t"
2243
+
2244
+ "vsetvli t0, t3, e32, mf2 \n\t"
2245
+ "vle32.v v28, (%[BIAS]) \n\t"
2246
+ "sub t3, t3, t0 \n\t"
2247
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2248
+ "vsetvli t0, t3, e32, mf2 \n\t"
2249
+ "vle32.v v29, (%[BIAS]) \n\t"
2250
+ "sub t3, t3, t0 \n\t"
2251
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2252
+ "vsetvli t0, t3, e32, mf2 \n\t"
2253
+ "vle32.v v30, (%[BIAS]) \n\t"
2254
+ "sub t3, t3, t0 \n\t"
2255
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2256
+ "vsetvli t0, t3, e32, mf2 \n\t"
2257
+ "vle32.v v31, (%[BIAS]) \n\t"
2258
+
2259
+ "LOOP_K%=: \n\t"
2260
+ "vsetvli t0, zero, e16, mf4 \n\t"
2261
+
2262
+ "vle16.v v4, (s1) \n\t"
2263
+ "addi s1, s1, 48 \n\t"
2264
+ "vle16.v v5, (s2) \n\t"
2265
+ "addi s2, s2, 72 \n\t"
2266
+ "vle16.v v6, (s3) \n\t"
2267
+ "addi s3, s3, 96 \n\t"
2268
+ "vle16.v v7, (s4) \n\t"
2269
+ "addi s4, s4, 120 \n\t"
2270
+ "flw f1, (s5) \n\t"
2271
+ "addi s5, s5, 4 \n\t"
2272
+ "vfwcvt.f.f.v v8, v4 \n\t"
2273
+ "vfwcvt.f.f.v v9, v5 \n\t"
2274
+ "vfwcvt.f.f.v v10, v6 \n\t"
2275
+ "vfwcvt.f.f.v v11, v7 \n\t"
2276
+
2277
+ "vsetvli t0, zero, e32, mf2 \n\t"
2278
+ "addi t5, %[INNER], 0 \n\t"
2279
+ "vxor.vv v16, v16, v16 \n\t"
2280
+ "vxor.vv v18, v18, v18 \n\t"
2281
+ "vxor.vv v20, v20, v20 \n\t"
2282
+ "vxor.vv v22, v22, v22 \n\t"
2283
+ "vfmul.vf v24, v8, f1 \n\t"
2284
+ "vfmul.vf v25, v9, f1 \n\t"
2285
+ "vfmul.vf v26, v10, f1 \n\t"
2286
+ "vfmul.vf v27, v11, f1 \n\t"
2287
+ "addi %[CNT], %[CNT], -1 \n\t"
2288
+
2289
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
2290
+
2291
+ "LOOP_INNER%=: \n\t"
2292
+
2293
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2294
+
2295
+ "vsub.vv v0, v0, v8 \n\t"
2296
+ "vsub.vv v4, v4, v8 \n\t"
2297
+ "vsub.vv v1, v1, v9 \n\t"
2298
+ "vsub.vv v5, v5, v9 \n\t"
2299
+ "vsub.vv v2, v2, v10 \n\t"
2300
+ "vsub.vv v6, v6, v10 \n\t"
2301
+ "vsub.vv v3, v3, v11 \n\t"
2302
+ "vsub.vv v7, v7, v11 \n\t"
2303
+
2304
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2305
+
2306
+ "bnez t5, LOOP_INNER%= \n\t"
2307
+ "vsetvli t0, zero, e32, mf2 \n\t"
2308
+
2309
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
2310
+ "addi s7, s1, 32 \n\t"
2311
+
2312
+ "bnez %[CNT], LOOP_K%= \n\t"
2313
+ "addi t3, zero, 16 \n\t"
2314
+ "addi s1, %[C], 16 \n\t"
2315
+ "addi s2, %[C], 32 \n\t"
2316
+ "addi s3, %[C], 48 \n\t"
2317
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2318
+ "vse32.v v28, (%[C]) \n\t"
2319
+ "vse32.v v29, (s1) \n\t"
2320
+ "vse32.v v30, (s2) \n\t"
2321
+ "vse32.v v31, (s3) \n\t"
2322
+ "jal x0, END%= \n\t"
2323
+
2324
+ "ST_TAIL%=: \n\t"
2325
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2326
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2327
+ "vse32.v v28, (%[C]) \n\t"
2328
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2329
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2330
+ "vse32.v v29, (s1) \n\t"
2331
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2332
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2333
+ "vse32.v v30, (s2) \n\t"
2334
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2335
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2336
+ "vse32.v v31, (s3) \n\t"
2337
+ "END%=: \n\t"
2338
+
2339
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
2340
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2341
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2342
+ } else {
2343
+ __asm__ volatile(
2344
+ "vsetvli t0, zero, e32, m4 \n\t"
2345
+ "vxor.vv v28, v28, v28 \n\t"
2346
+
2347
+ "vsetvli t0, zero, e8, m1 \n\t"
2348
+ "vmv.v.i v13, 3 \n\t"
2349
+ "li s1, 24 \n\t"
2350
+ "vsetvli t0, s1, e8, m1 \n\t"
2351
+ "vmv.v.i v13, 2 \n\t"
2352
+ "vsetvli t0, zero, e8, mf2 \n\t"
2353
+ "vmv.v.i v13, 1 \n\t"
2354
+ "vsetvli t0, zero, e8, mf4 \n\t"
2355
+ "vmv.v.i v13, 0 \n\t"
2356
+
2357
+ "addi s1, %[B], 0 \n\t"
2358
+ "addi s2, %[B], 8 \n\t"
2359
+ "addi s3, %[B], 16 \n\t"
2360
+ "addi s4, %[B], 24 \n\t"
2361
+
2362
+ "addi s7, %[B], 32 \n\t"
2363
+
2364
+ "addi s5, %[A], 0 \n\t"
2365
+ "addi s6, %[A], 12 \n\t"
2366
+ "LOOP_K%=: \n\t"
2367
+ "vsetvli t0, zero, e16, mf4 \n\t"
2368
+ "vle16.v v4, (s1) \n\t"
2369
+ "addi s1, s1, 48 \n\t"
2370
+ "vle16.v v5, (s2) \n\t"
2371
+ "addi s2, s2, 72 \n\t"
2372
+ "vle16.v v6, (s3) \n\t"
2373
+ "addi s3, s3, 96 \n\t"
2374
+ "vle16.v v7, (s4) \n\t"
2375
+ "addi s4, s4, 120 \n\t"
2376
+ "flw f1, (s5) \n\t"
2377
+ "addi s5, s5, 4 \n\t"
2378
+
2379
+ "vfwcvt.f.f.v v8, v4 \n\t"
2380
+ "vfwcvt.f.f.v v9, v5 \n\t"
2381
+ "vfwcvt.f.f.v v10, v6 \n\t"
2382
+ "vfwcvt.f.f.v v11, v7 \n\t"
2383
+ "vsetvli t0, zero, e32, mf2 \n\t"
2384
+
2385
+ "addi t5, %[INNER], 0 \n\t"
2386
+ "vxor.vv v16, v16, v16 \n\t"
2387
+ "vxor.vv v18, v18, v18 \n\t"
2388
+ "vxor.vv v20, v20, v20 \n\t"
2389
+ "vxor.vv v22, v22, v22 \n\t"
2390
+ "vfmul.vf v24, v8, f1 \n\t"
2391
+ "vfmul.vf v25, v9, f1 \n\t"
2392
+ "vfmul.vf v26, v10, f1 \n\t"
2393
+ "vfmul.vf v27, v11, f1 \n\t"
2394
+ "addi %[CNT], %[CNT], -1 \n\t"
2395
+
2396
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
2397
+
2398
+ "LOOP_INNER%=: \n\t"
2399
+
2400
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2401
+
2402
+ "vsub.vv v0, v0, v8 \n\t"
2403
+ "vsub.vv v4, v4, v8 \n\t"
2404
+ "vsub.vv v1, v1, v9 \n\t"
2405
+ "vsub.vv v5, v5, v9 \n\t"
2406
+ "vsub.vv v2, v2, v10 \n\t"
2407
+ "vsub.vv v6, v6, v10 \n\t"
2408
+ "vsub.vv v3, v3, v11 \n\t"
2409
+ "vsub.vv v7, v7, v11 \n\t"
2410
+
2411
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2412
+
2413
+ "bnez t5, LOOP_INNER%= \n\t"
2414
+ "vsetvli t0, zero, e32, mf2 \n\t"
2415
+
2416
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
2417
+ "addi s7, s1, 32 \n\t"
2418
+
2419
+ "bnez %[CNT], LOOP_K%= \n\t"
2420
+ "addi t3, zero, 16 \n\t"
2421
+ "addi s1, %[C], 16 \n\t"
2422
+ "addi s2, %[C], 32 \n\t"
2423
+ "addi s3, %[C], 48 \n\t"
2424
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2425
+ "vse32.v v28, (%[C]) \n\t"
2426
+ "vse32.v v29, (s1) \n\t"
2427
+ "vse32.v v30, (s2) \n\t"
2428
+ "vse32.v v31, (s3) \n\t"
2429
+ "jal x0, END%= \n\t"
2430
+
2431
+ "ST_TAIL%=: \n\t"
2432
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2433
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2434
+ "vse32.v v28, (%[C]) \n\t"
2435
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2436
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2437
+ "vse32.v v29, (s1) \n\t"
2438
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2439
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2440
+ "vse32.v v30, (s2) \n\t"
2441
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2442
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2443
+ "vse32.v v31, (s3) \n\t"
2444
+ "END%=: \n\t"
2445
+
2446
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
2447
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2448
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2449
+ }
2450
+ }
2451
+ } else {
2452
+ for (size_t n = 0; n < CountN; n += 16) {
2453
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2454
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2455
+ n * BlockCountK * BlkLen / 2 + // b data
2456
+ n * BlockCountK * sizeof(_Float16); // scale
2457
+ float * CPtr = C + n;
2458
+ size_t cnt = BlockCountK;
2459
+ if (Bias != nullptr) {
2460
+ const float * bias = Bias + n;
2461
+ __asm__ volatile(
2462
+ "addi t3, %[NBLKS], 0 \n\t"
2463
+ "addi s1, %[B], 0 \n\t"
2464
+ "addi s2, %[B], 8 \n\t"
2465
+ "addi s3, %[B], 16 \n\t"
2466
+ "addi s4, %[B], 24 \n\t"
2467
+ "addi s5, %[A], 0 \n\t"
2468
+ "addi s6, %[A], 12 \n\t"
2469
+ "vsetvli t0, t3, e32, mf2 \n\t"
2470
+ "vle32.v v28, (%[BIAS]) \n\t"
2471
+ "sub t3, t3, t0 \n\t"
2472
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2473
+ "vsetvli t0, t3, e32, mf2 \n\t"
2474
+ "vle32.v v29, (%[BIAS]) \n\t"
2475
+ "sub t3, t3, t0 \n\t"
2476
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2477
+ "vsetvli t0, t3, e32, mf2 \n\t"
2478
+ "vle32.v v30, (%[BIAS]) \n\t"
2479
+ "sub t3, t3, t0 \n\t"
2480
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2481
+ "vsetvli t0, t3, e32, mf2 \n\t"
2482
+ "vle32.v v31, (%[BIAS]) \n\t"
2483
+
2484
+ "LOOP_K%=: \n\t"
2485
+ "vsetvli t0, zero, e16, mf4 \n\t"
2486
+
2487
+ "vle16.v v4, (s1) \n\t"
2488
+ "addi s1, s1, 32 \n\t"
2489
+ "vle16.v v5, (s2) \n\t"
2490
+ "addi s2, s2, 56 \n\t"
2491
+ "vle16.v v6, (s3) \n\t"
2492
+ "addi s3, s3, 80 \n\t"
2493
+ "vle16.v v7, (s4) \n\t"
2494
+ "addi s4, s4, 104 \n\t"
2495
+ "flw f1, (s5) \n\t"
2496
+ "addi s5, s5, 4 \n\t"
2497
+ "vfwcvt.f.f.v v8, v4 \n\t"
2498
+ "vfwcvt.f.f.v v9, v5 \n\t"
2499
+ "vfwcvt.f.f.v v10, v6 \n\t"
2500
+ "vfwcvt.f.f.v v11, v7 \n\t"
2501
+
2502
+ "vsetvli t0, zero, e32, mf2 \n\t"
2503
+ "addi t5, %[INNER], 0 \n\t"
2504
+ "vxor.vv v16, v16, v16 \n\t"
2505
+ "vxor.vv v18, v18, v18 \n\t"
2506
+ "vxor.vv v20, v20, v20 \n\t"
2507
+ "vxor.vv v22, v22, v22 \n\t"
2508
+ "vfmul.vf v24, v8, f1 \n\t"
2509
+ "vfmul.vf v25, v9, f1 \n\t"
2510
+ "vfmul.vf v26, v10, f1 \n\t"
2511
+ "vfmul.vf v27, v11, f1 \n\t"
2512
+ "addi %[CNT], %[CNT], -1 \n\t"
2513
+ "vsetvli t0, zero, e8, m1 \n\t"
2514
+ "LOOP_INNER%=: \n\t"
2515
+
2516
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2517
+
2518
+ "vadd.vi v0, v0, -8 \n\t"
2519
+ "vadd.vi v1, v1, -8 \n\t"
2520
+ "vadd.vi v2, v2, -8 \n\t"
2521
+ "vadd.vi v3, v3, -8 \n\t"
2522
+ "vadd.vi v4, v4, -8 \n\t"
2523
+ "vadd.vi v5, v5, -8 \n\t"
2524
+ "vadd.vi v6, v6, -8 \n\t"
2525
+ "vadd.vi v7, v7, -8 \n\t"
2526
+
2527
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2528
+
2529
+ "bnez t5, LOOP_INNER%= \n\t"
2530
+ "vsetvli t0, zero, e32, mf2 \n\t"
2531
+
2532
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
2533
+
2534
+ "bnez %[CNT], LOOP_K%= \n\t"
2535
+ "addi t3, zero, 16 \n\t"
2536
+ "addi s1, %[C], 16 \n\t"
2537
+ "addi s2, %[C], 32 \n\t"
2538
+ "addi s3, %[C], 48 \n\t"
2539
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2540
+ "vse32.v v28, (%[C]) \n\t"
2541
+ "vse32.v v29, (s1) \n\t"
2542
+ "vse32.v v30, (s2) \n\t"
2543
+ "vse32.v v31, (s3) \n\t"
2544
+ "jal x0, END%= \n\t"
2545
+
2546
+ "ST_TAIL%=: \n\t"
2547
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2548
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2549
+ "vse32.v v28, (%[C]) \n\t"
2550
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2551
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2552
+ "vse32.v v29, (s1) \n\t"
2553
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2554
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2555
+ "vse32.v v30, (s2) \n\t"
2556
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2557
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2558
+ "vse32.v v31, (s3) \n\t"
2559
+ "END%=: \n\t"
2560
+
2561
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
2562
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2563
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
2564
+ } else {
2565
+ __asm__ volatile(
2566
+ "vsetvli t0, zero, e32, m4 \n\t"
2567
+ "vxor.vv v28, v28, v28 \n\t"
2568
+ "addi s1, %[B], 0 \n\t"
2569
+ "addi s2, %[B], 8 \n\t"
2570
+ "addi s3, %[B], 16 \n\t"
2571
+ "addi s4, %[B], 24 \n\t"
2572
+
2573
+ "addi s5, %[A], 0 \n\t"
2574
+ "addi s6, %[A], 12 \n\t"
2575
+ "LOOP_K%=: \n\t"
2576
+ "vsetvli t0, zero, e16, mf4 \n\t"
2577
+ "vle16.v v4, (s1) \n\t"
2578
+ "addi s1, s1, 32 \n\t"
2579
+ "vle16.v v5, (s2) \n\t"
2580
+ "addi s2, s2, 56 \n\t"
2581
+ "vle16.v v6, (s3) \n\t"
2582
+ "addi s3, s3, 80 \n\t"
2583
+ "vle16.v v7, (s4) \n\t"
2584
+ "addi s4, s4, 104 \n\t"
2585
+ "flw f1, (s5) \n\t"
2586
+ "addi s5, s5, 4 \n\t"
2587
+
2588
+ "vfwcvt.f.f.v v8, v4 \n\t"
2589
+ "vfwcvt.f.f.v v9, v5 \n\t"
2590
+ "vfwcvt.f.f.v v10, v6 \n\t"
2591
+ "vfwcvt.f.f.v v11, v7 \n\t"
2592
+ "vsetvli t0, zero, e32, mf2 \n\t"
2593
+
2594
+ "addi t5, %[INNER], 0 \n\t"
2595
+ "vxor.vv v16, v16, v16 \n\t"
2596
+ "vxor.vv v18, v18, v18 \n\t"
2597
+ "vxor.vv v20, v20, v20 \n\t"
2598
+ "vxor.vv v22, v22, v22 \n\t"
2599
+ "vfmul.vf v24, v8, f1 \n\t"
2600
+ "vfmul.vf v25, v9, f1 \n\t"
2601
+ "vfmul.vf v26, v10, f1 \n\t"
2602
+ "vfmul.vf v27, v11, f1 \n\t"
2603
+ "addi %[CNT], %[CNT], -1 \n\t"
2604
+ "vsetvli t0, zero, e8, m1 \n\t"
2605
+ "LOOP_INNER%=: \n\t"
2606
+
2607
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2608
+
2609
+ "vadd.vi v0, v0, -8 \n\t"
2610
+ "vadd.vi v1, v1, -8 \n\t"
2611
+ "vadd.vi v2, v2, -8 \n\t"
2612
+ "vadd.vi v3, v3, -8 \n\t"
2613
+ "vadd.vi v4, v4, -8 \n\t"
2614
+ "vadd.vi v5, v5, -8 \n\t"
2615
+ "vadd.vi v6, v6, -8 \n\t"
2616
+ "vadd.vi v7, v7, -8 \n\t"
2617
+
2618
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2619
+
2620
+ "bnez t5, LOOP_INNER%= \n\t"
2621
+ "vsetvli t0, zero, e32, mf2 \n\t"
2622
+
2623
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
2624
+
2625
+ "bnez %[CNT], LOOP_K%= \n\t"
2626
+ "addi t3, zero, 16 \n\t"
2627
+ "addi s1, %[C], 16 \n\t"
2628
+ "addi s2, %[C], 32 \n\t"
2629
+ "addi s3, %[C], 48 \n\t"
2630
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2631
+ "vse32.v v28, (%[C]) \n\t"
2632
+ "vse32.v v29, (s1) \n\t"
2633
+ "vse32.v v30, (s2) \n\t"
2634
+ "vse32.v v31, (s3) \n\t"
2635
+ "jal x0, END%= \n\t"
2636
+
2637
+ "ST_TAIL%=: \n\t"
2638
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2639
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2640
+ "vse32.v v28, (%[C]) \n\t"
2641
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2642
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2643
+ "vse32.v v29, (s1) \n\t"
2644
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2645
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2646
+ "vse32.v v30, (s2) \n\t"
2647
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2648
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2649
+ "vse32.v v31, (s3) \n\t"
2650
+ "END%=: \n\t"
2651
+
2652
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
2653
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2654
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
2655
+ }
2656
+ }
2657
+ }
2658
+ }
2659
+
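+ // The QuantBDataPtr arithmetic above, spelled out as a helper (a sketch with
+ // illustrative names; n steps in groups of 16 columns): each 16-column group
+ // packs, per K-block, BlkLen/2 bytes of 4-bit data, optionally one
+ // zero-point byte, and one _Float16 scale, so column group n starts at:
+ //
+ //   size_t packed_b_offset(size_t n, size_t BlockCountK, size_t BlkLen,
+ //                          bool has_zp) {
+ //       return n * BlockCountK * (BlkLen / 2)                    // b data
+ //            + (has_zp ? n * BlockCountK * sizeof(uint8_t) : 0)  // zp
+ //            + n * BlockCountK * sizeof(_Float16);               // scale
+ //   }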
2660
+ template <bool HasZeroPoint>
2661
+ void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen,
2662
+ const std::byte * QuantA,
2663
+ const std::byte * QuantBData,
2664
+ const float * QuantBScale,
2665
+ const std::byte * QuantBZeroPoint,
2666
+ float * C,
2667
+ size_t CountN,
2668
+ size_t BlockCountK,
2669
+ const float * Bias) {
2670
+ GGML_UNUSED(QuantBScale);
2671
+ GGML_UNUSED(QuantBZeroPoint);
2672
+ const size_t INNER = BlkLen / 16;
2673
+ if constexpr (HasZeroPoint) {
2674
+ for (size_t n = 0; n < CountN; n += 16) {
2675
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2676
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2677
+ n * BlockCountK * BlkLen / 2 + // b data
2678
+ n * BlockCountK * sizeof(uint8_t) + // zp
2679
+ n * BlockCountK * sizeof(float); // scale
2680
+ float * CPtr = C + n;
2681
+ size_t cnt = BlockCountK;
2682
+ if (Bias != nullptr) {
2683
+ const float * bias = Bias + n;
2684
+ __asm__ volatile(
2685
+ "addi t3, %[NBLKS], 0 \n\t"
2686
+ "vsetvli t0, zero, e8, m1 \n\t"
2687
+ "vmv.v.i v13, 3 \n\t"
2688
+ "li s1, 24 \n\t"
2689
+ "vsetvli t0, s1, e8, m1 \n\t"
2690
+ "vmv.v.i v13, 2 \n\t"
2691
+ "vsetvli t0, zero, e8, mf2 \n\t"
2692
+ "vmv.v.i v13, 1 \n\t"
2693
+ "vsetvli t0, zero, e8, mf4 \n\t"
2694
+ "vmv.v.i v13, 0 \n\t"
2695
+ "vsetvli t0, zero, e32, m4 \n\t"
2696
+ "vxor.vv v28, v28, v28 \n\t"
2697
+
2698
+ // scale offset: scale0.0, scale1.0, scale2.0, scale3.0, ..., scale15.0
2699
+ "addi s1, %[B], 0 \n\t"
2700
+ "addi s2, %[B], 16 \n\t"
2701
+ "addi s3, %[B], 32 \n\t"
2702
+ "addi s4, %[B], 48 \n\t"
2703
+ // zp offset
2704
+ "addi s7, %[B], 64 \n\t"
2705
+ // a offset
2706
+ "addi s5, %[A], 0 \n\t"
2707
+ "addi s6, %[A], 12 \n\t"
2708
+
2709
+ "vsetvli t0, t3, e32, mf2 \n\t"
2710
+ "vle32.v v28, (%[BIAS]) \n\t"
2711
+ "sub t3, t3, t0 \n\t"
2712
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2713
+ "vsetvli t0, t3, e32, mf2 \n\t"
2714
+ "vle32.v v29, (%[BIAS]) \n\t"
2715
+ "sub t3, t3, t0 \n\t"
2716
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2717
+ "vsetvli t0, t3, e32, mf2 \n\t"
2718
+ "vle32.v v30, (%[BIAS]) \n\t"
2719
+ "sub t3, t3, t0 \n\t"
2720
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2721
+ "vsetvli t0, t3, e32, mf2 \n\t"
2722
+ "vle32.v v31, (%[BIAS]) \n\t"
2723
+ "vsetvli t0, zero, e32, mf2 \n\t"
2724
+ "LOOP_K%=: \n\t"
2725
+
2726
+ // load scale
2727
+ "vle32.v v8, (s1) \n\t"
2728
+ "addi s1, s1, 80 \n\t"
2729
+ "vle32.v v9, (s2) \n\t"
2730
+ "addi s2, s2, 96 \n\t"
2731
+ "vle32.v v10, (s3) \n\t"
2732
+ "addi s3, s3, 112 \n\t"
2733
+ "vle32.v v11, (s4) \n\t"
2734
+ "addi s4, s4, 128 \n\t"
2735
+
2736
+ // load a scale
2737
+ "flw f1, (s5) \n\t"
2738
+ "addi s5, s5, 4 \n\t"
2739
+
2740
+ "addi t5, %[INNER], 0 \n\t"
2741
+ "vxor.vv v16, v16, v16 \n\t"
2742
+ "vxor.vv v18, v18, v18 \n\t"
2743
+ "vxor.vv v20, v20, v20 \n\t"
2744
+ "vxor.vv v22, v22, v22 \n\t"
2745
+
2746
+ // a scale * b scale
2747
+ "vfmul.vf v24, v8, f1 \n\t"
2748
+ "vfmul.vf v25, v9, f1 \n\t"
2749
+ "vfmul.vf v26, v10, f1 \n\t"
2750
+ "vfmul.vf v27, v11, f1 \n\t"
2751
+ "addi %[CNT], %[CNT], -1 \n\t"
2752
+
2753
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
2754
+
2755
+ "LOOP_INNER%=: \n\t"
2756
+
2757
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2758
+
2759
+ "vsub.vv v0, v0, v8 \n\t"
2760
+ "vsub.vv v4, v4, v8 \n\t"
2761
+ "vsub.vv v1, v1, v9 \n\t"
2762
+ "vsub.vv v5, v5, v9 \n\t"
2763
+ "vsub.vv v2, v2, v10 \n\t"
2764
+ "vsub.vv v6, v6, v10 \n\t"
2765
+ "vsub.vv v3, v3, v11 \n\t"
2766
+ "vsub.vv v7, v7, v11 \n\t"
2767
+
2768
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2769
+
2770
+ "bnez t5, LOOP_INNER%= \n\t"
2771
+ "vsetvli t0, zero, e32, mf2 \n\t"
2772
+
2773
+ SQ4BIT_KERNEL_ACC_1X4X4
2774
+ "addi s7, s1, 64 \n\t"
2775
+
2776
+ "bnez %[CNT], LOOP_K%= \n\t"
2777
+
2778
+ "addi t3, zero, 16 \n\t"
2779
+ "addi s1, %[C], 16 \n\t"
2780
+ "addi s2, %[C], 32 \n\t"
2781
+ "addi s3, %[C], 48 \n\t"
2782
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2783
+ "vse32.v v28, (%[C]) \n\t"
2784
+ "vse32.v v29, (s1) \n\t"
2785
+ "vse32.v v30, (s2) \n\t"
2786
+ "vse32.v v31, (s3) \n\t"
2787
+ "jal x0, END%= \n\t"
2788
+
2789
+ "ST_TAIL%=: \n\t"
2790
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2791
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2792
+ "vse32.v v28, (%[C]) \n\t"
2793
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2794
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2795
+ "vse32.v v29, (s1) \n\t"
2796
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2797
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2798
+ "vse32.v v30, (s2) \n\t"
2799
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2800
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2801
+ "vse32.v v31, (s3) \n\t"
2802
+ "END%=: \n\t"
2803
+
2804
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
2805
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2806
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2807
+ } else {
2808
+ __asm__ volatile(
2809
+ "vsetvli t0, zero, e32, m4 \n\t"
2810
+ "vxor.vv v28, v28, v28 \n\t"
2811
+
2812
+ "vsetvli t0, zero, e8, m1 \n\t"
2813
+ "vmv.v.i v13, 3 \n\t"
2814
+ "li s1, 24 \n\t"
2815
+ "vsetvli t0, s1, e8, m1 \n\t"
2816
+ "vmv.v.i v13, 2 \n\t"
2817
+ "vsetvli t0, zero, e8, mf2 \n\t"
2818
+ "vmv.v.i v13, 1 \n\t"
2819
+ "vsetvli t0, zero, e8, mf4 \n\t"
2820
+ "vmv.v.i v13, 0 \n\t"
2821
+ "addi s1, %[B], 0 \n\t"
2822
+ "addi s2, %[B], 16 \n\t"
2823
+ "addi s3, %[B], 32 \n\t"
2824
+ "addi s4, %[B], 48 \n\t"
2825
+
2826
+ "addi s7, %[B], 64 \n\t"
2827
+
2828
+ "addi s5, %[A], 0 \n\t"
2829
+ "addi s6, %[A], 12 \n\t"
2830
+ "vsetvli t0, zero, e32, mf2 \n\t"
2831
+
2832
+ "LOOP_K%=: \n\t"
2833
+ "vle32.v v8, (s1) \n\t"
2834
+ "addi s1, s1, 80 \n\t"
2835
+ "vle32.v v9, (s2) \n\t"
2836
+ "addi s2, s2, 96 \n\t"
2837
+ "vle32.v v10, (s3) \n\t"
2838
+ "addi s3, s3, 112 \n\t"
2839
+ "vle32.v v11, (s4) \n\t"
2840
+ "addi s4, s4, 128 \n\t"
2841
+
2842
+ "flw f1, (s5) \n\t"
2843
+ "addi s5, s5, 4 \n\t"
2844
+
2845
+ "addi t5, %[INNER], 0 \n\t"
2846
+ "vxor.vv v16, v16, v16 \n\t"
2847
+ "vxor.vv v18, v18, v18 \n\t"
2848
+ "vxor.vv v20, v20, v20 \n\t"
2849
+ "vxor.vv v22, v22, v22 \n\t"
2850
+
2851
+ "vfmul.vf v24, v8, f1 \n\t"
2852
+ "vfmul.vf v25, v9, f1 \n\t"
2853
+ "vfmul.vf v26, v10, f1 \n\t"
2854
+ "vfmul.vf v27, v11, f1 \n\t"
2855
+ "addi %[CNT], %[CNT], -1 \n\t"
2856
+
2857
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
2858
+
2859
+ "LOOP_INNER%=: \n\t"
2860
+
2861
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2862
+
2863
+ "vsub.vv v0, v0, v8 \n\t"
2864
+ "vsub.vv v4, v4, v8 \n\t"
2865
+ "vsub.vv v1, v1, v9 \n\t"
2866
+ "vsub.vv v5, v5, v9 \n\t"
2867
+ "vsub.vv v2, v2, v10 \n\t"
2868
+ "vsub.vv v6, v6, v10 \n\t"
2869
+ "vsub.vv v3, v3, v11 \n\t"
2870
+ "vsub.vv v7, v7, v11 \n\t"
2871
+
2872
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2873
+
2874
+ "bnez t5, LOOP_INNER%= \n\t"
2875
+ "vsetvli t0, zero, e32, mf2 \n\t"
2876
+
2877
+ SQ4BIT_KERNEL_ACC_1X4X4
2878
+ "addi s7, s1, 64 \n\t"
2879
+
2880
+ "bnez %[CNT], LOOP_K%= \n\t"
2881
+
2882
+ "addi t3, zero, 16 \n\t"
2883
+ "addi s1, %[C], 16 \n\t"
2884
+ "addi s2, %[C], 32 \n\t"
2885
+ "addi s3, %[C], 48 \n\t"
2886
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2887
+ "vse32.v v28, (%[C]) \n\t"
2888
+ "vse32.v v29, (s1) \n\t"
2889
+ "vse32.v v30, (s2) \n\t"
2890
+ "vse32.v v31, (s3) \n\t"
2891
+ "jal x0, END%= \n\t"
2892
+
2893
+ "ST_TAIL%=: \n\t"
2894
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2895
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2896
+ "vse32.v v28, (%[C]) \n\t"
2897
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2898
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2899
+ "vse32.v v29, (s1) \n\t"
2900
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2901
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2902
+ "vse32.v v30, (s2) \n\t"
2903
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
2904
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
2905
+ "vse32.v v31, (s3) \n\t"
2906
+ "END%=: \n\t"
2907
+
2908
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
2909
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
2910
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
2911
+ }
2912
+ }
2913
+ } else {
2914
+ for (size_t n = 0; n < CountN; n += 16) {
2915
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
2916
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
2917
+ n * BlockCountK * BlkLen / 2 + // b data
2918
+ n * BlockCountK * sizeof(float); // scale
2919
+ float * CPtr = C + n;
2920
+ size_t cnt = BlockCountK;
2921
+ if (Bias != nullptr) {
2922
+ const float * bias = Bias + n;
2923
+ __asm__ volatile(
2924
+ "addi t3, %[NBLKS], 0 \n\t"
2925
+ "addi s1, %[B], 0 \n\t"
2926
+ "addi s2, %[B], 16 \n\t"
2927
+ "addi s3, %[B], 32 \n\t"
2928
+ "addi s4, %[B], 48 \n\t"
2929
+ "addi s5, %[A], 0 \n\t"
2930
+ "addi s6, %[A], 12 \n\t"
2931
+ "vsetvli t0, t3, e32, mf2 \n\t"
2932
+ "vle32.v v28, (%[BIAS]) \n\t"
2933
+ "sub t3, t3, t0 \n\t"
2934
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2935
+ "vsetvli t0, t3, e32, mf2 \n\t"
2936
+ "vle32.v v29, (%[BIAS]) \n\t"
2937
+ "sub t3, t3, t0 \n\t"
2938
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2939
+ "vsetvli t0, t3, e32, mf2 \n\t"
2940
+ "vle32.v v30, (%[BIAS]) \n\t"
2941
+ "sub t3, t3, t0 \n\t"
2942
+ "addi %[BIAS], %[BIAS], 16 \n\t"
2943
+ "vsetvli t0, t3, e32, mf2 \n\t"
2944
+ "vle32.v v31, (%[BIAS]) \n\t"
2945
+ "vsetvli t0, zero, e32, mf2 \n\t"
2946
+ "LOOP_K%=: \n\t"
2947
+ "vle32.v v8, (s1) \n\t"
2948
+ "addi s1, s1, 64 \n\t"
2949
+ "vle32.v v9, (s2) \n\t"
2950
+ "addi s2, s2, 80 \n\t"
2951
+ "vle32.v v10, (s3) \n\t"
2952
+ "addi s3, s3, 96 \n\t"
2953
+ "vle32.v v11, (s4) \n\t"
2954
+ "addi s4, s4, 112 \n\t"
2955
+ "flw f1, (s5) \n\t"
2956
+ "addi s5, s5, 4 \n\t"
2957
+
2958
+ "addi t5, %[INNER], 0 \n\t"
2959
+ "vxor.vv v16, v16, v16 \n\t"
2960
+ "vxor.vv v18, v18, v18 \n\t"
2961
+ "vxor.vv v20, v20, v20 \n\t"
2962
+ "vxor.vv v22, v22, v22 \n\t"
2963
+ "vfmul.vf v24, v8, f1 \n\t"
2964
+ "vfmul.vf v25, v9, f1 \n\t"
2965
+ "vfmul.vf v26, v10, f1 \n\t"
2966
+ "vfmul.vf v27, v11, f1 \n\t"
2967
+ "addi %[CNT], %[CNT], -1 \n\t"
2968
+ "vsetvli t0, zero, e8, m1 \n\t"
2969
+ "LOOP_INNER%=: \n\t"
2970
+
2971
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
2972
+
2973
+ "vadd.vi v0, v0, -8 \n\t"
2974
+ "vadd.vi v1, v1, -8 \n\t"
2975
+ "vadd.vi v2, v2, -8 \n\t"
2976
+ "vadd.vi v3, v3, -8 \n\t"
2977
+ "vadd.vi v4, v4, -8 \n\t"
2978
+ "vadd.vi v5, v5, -8 \n\t"
2979
+ "vadd.vi v6, v6, -8 \n\t"
2980
+ "vadd.vi v7, v7, -8 \n\t"
2981
+
2982
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
2983
+
2984
+ "bnez t5, LOOP_INNER%= \n\t"
2985
+ "vsetvli t0, zero, e32, mf2 \n\t"
2986
+
2987
+ SQ4BIT_KERNEL_ACC_1X4X4
2988
+
2989
+ "bnez %[CNT], LOOP_K%= \n\t"
2990
+ "addi t3, zero, 16 \n\t"
2991
+ "addi s1, %[C], 16 \n\t"
2992
+ "addi s2, %[C], 32 \n\t"
2993
+ "addi s3, %[C], 48 \n\t"
2994
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
2995
+ "vse32.v v28, (%[C]) \n\t"
2996
+ "vse32.v v29, (s1) \n\t"
2997
+ "vse32.v v30, (s2) \n\t"
2998
+ "vse32.v v31, (s3) \n\t"
2999
+ "jal x0, END%= \n\t"
3000
+
3001
+ "ST_TAIL%=: \n\t"
3002
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3003
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
3004
+ "vse32.v v28, (%[C]) \n\t"
3005
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3006
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
3007
+ "vse32.v v29, (s1) \n\t"
3008
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3009
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
3010
+ "vse32.v v30, (s2) \n\t"
3011
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3012
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
3013
+ "vse32.v v31, (s3) \n\t"
3014
+ "END%=: \n\t"
3015
+
3016
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
3017
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
3018
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
3019
+ } else {
3020
+ __asm__ volatile(
3021
+ "vsetvli t0, zero, e32, m4 \n\t"
3022
+ "vxor.vv v28, v28, v28 \n\t"
3023
+ "addi s1, %[B], 0 \n\t"
3024
+ "addi s2, %[B], 16 \n\t"
3025
+ "addi s3, %[B], 32 \n\t"
3026
+ "addi s4, %[B], 48 \n\t"
3027
+
3028
+ "addi s5, %[A], 0 \n\t"
3029
+ "addi s6, %[A], 12 \n\t"
3030
+ "vsetvli t0, zero, e32, mf2 \n\t"
3031
+ "LOOP_K%=: \n\t"
3032
+ "vle32.v v8, (s1) \n\t"
3033
+ "addi s1, s1, 64 \n\t"
3034
+ "vle32.v v9, (s2) \n\t"
3035
+ "addi s2, s2, 80 \n\t"
3036
+ "vle32.v v10, (s3) \n\t"
3037
+ "addi s3, s3, 96 \n\t"
3038
+ "vle32.v v11, (s4) \n\t"
3039
+ "addi s4, s4, 112 \n\t"
3040
+ "flw f1, (s5) \n\t"
3041
+ "addi s5, s5, 4 \n\t"
3042
+
3043
+ "addi t5, %[INNER], 0 \n\t"
3044
+ "vxor.vv v16, v16, v16 \n\t"
3045
+ "vxor.vv v18, v18, v18 \n\t"
3046
+ "vxor.vv v20, v20, v20 \n\t"
3047
+ "vxor.vv v22, v22, v22 \n\t"
3048
+ "vfmul.vf v24, v8, f1 \n\t"
3049
+ "vfmul.vf v25, v9, f1 \n\t"
3050
+ "vfmul.vf v26, v10, f1 \n\t"
3051
+ "vfmul.vf v27, v11, f1 \n\t"
3052
+ "addi %[CNT], %[CNT], -1 \n\t"
3053
+ "vsetvli t0, zero, e8, m1 \n\t"
3054
+ "LOOP_INNER%=: \n\t"
3055
+
3056
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
3057
+
3058
+ "vadd.vi v0, v0, -8 \n\t"
3059
+ "vadd.vi v1, v1, -8 \n\t"
3060
+ "vadd.vi v2, v2, -8 \n\t"
3061
+ "vadd.vi v3, v3, -8 \n\t"
3062
+ "vadd.vi v4, v4, -8 \n\t"
3063
+ "vadd.vi v5, v5, -8 \n\t"
3064
+ "vadd.vi v6, v6, -8 \n\t"
3065
+ "vadd.vi v7, v7, -8 \n\t"
3066
+
3067
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
3068
+
3069
+ "bnez t5, LOOP_INNER%= \n\t"
3070
+ "vsetvli t0, zero, e32, mf2 \n\t"
3071
+
3072
+ SQ4BIT_KERNEL_ACC_1X4X4
3073
+
3074
+ "bnez %[CNT], LOOP_K%= \n\t"
3075
+ "addi t3, zero, 16 \n\t"
3076
+ "addi s1, %[C], 16 \n\t"
3077
+ "addi s2, %[C], 32 \n\t"
3078
+ "addi s3, %[C], 48 \n\t"
3079
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
3080
+ "vse32.v v28, (%[C]) \n\t"
3081
+ "vse32.v v29, (s1) \n\t"
3082
+ "vse32.v v30, (s2) \n\t"
3083
+ "vse32.v v31, (s3) \n\t"
3084
+ "jal x0, END%= \n\t"
3085
+
3086
+ "ST_TAIL%=: \n\t"
3087
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3088
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
3089
+ "vse32.v v28, (%[C]) \n\t"
3090
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3091
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
3092
+ "vse32.v v29, (s1) \n\t"
3093
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3094
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
3095
+ "vse32.v v30, (s2) \n\t"
3096
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3097
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
3098
+ "vse32.v v31, (s3) \n\t"
3099
+ "END%=: \n\t"
3100
+
3101
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
3102
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
3103
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
3104
+ }
3105
+ }
3106
+ }
3107
+ }
3108
+
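+ // The HasZeroPoint branches above differ only in the value subtracted from
+ // each unpacked 4-bit nibble; a scalar sketch of that step:
+ //
+ //   int dequant_nibble(int q, bool has_zp, int zp) {
+ //       return has_zp ? q - zp   // per-block zero point (vsub.vv)
+ //                     : q - 8;   // fixed midpoint (vadd.vi ..., -8)
+ //   }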
3109
+ template <bool HasZeroPoint>
3110
+ inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
3111
+ const std::byte * QuantA,
3112
+ const std::byte * QuantBData,
3113
+ const float * QuantBScale,
3114
+ const std::byte * QuantBZeroPoint,
3115
+ float * C,
3116
+ size_t CountM,
3117
+ size_t CountN,
3118
+ size_t BlockStrideQuantB,
3119
+ const float * Bias,
3120
+ const size_t ldc,
3121
+ const size_t scalestride) {
3122
+ if (scalestride == 4) {
3123
+ SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
3124
+ CountN, BlockStrideQuantB, Bias, ldc);
3125
+
3126
+ } else if (scalestride == 2) {
3127
+ SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(
3128
+ BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);
3129
+ }
3130
+ }
3131
+
3132
+ template <bool HasZeroPoint>
3133
+ inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
3134
+ const std::byte * QuantA,
3135
+ const std::byte * QuantBData,
3136
+ const float * QuantBScale,
3137
+ const std::byte * QuantBZeroPoint,
3138
+ float * C,
3139
+ size_t CountM,
3140
+ size_t CountN,
3141
+ size_t BlockStrideQuantB,
3142
+ const float * Bias,
3143
+ const size_t ldc,
3144
+ const size_t scalestride) {
3145
+ if (scalestride == 4) {
3146
+ SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
3147
+ CountN, BlockStrideQuantB, Bias);
3148
+ } else if (scalestride == 2) {
3149
+ SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,
3150
+ QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);
3151
+ }
3152
+ }
3153
+
3154
+ } // namespace
3155
+
3156
+ namespace ime1 {
3157
+ size_t gemm_kernel_i8i4(size_t BlkLen,
3158
+ const std::byte * QuantA,
3159
+ const std::byte * QuantBData,
3160
+ const float * QuantBScale,
3161
+ const std::byte * QuantBZeroPoint,
3162
+ float * C,
3163
+ size_t CountM,
3164
+ size_t CountN,
3165
+ size_t CountK,
3166
+ size_t BlockCountK,
3167
+ size_t ldc,
3168
+ const float * Bias,
3169
+ const size_t ScaleStride) {
3170
+ GGML_UNUSED(CountM);
3171
+ GGML_UNUSED(CountK);
3172
+ GGML_UNUSED(ldc);
3173
+ if (CountM >= 4) {
3174
+ if (QuantBZeroPoint != nullptr) {
3175
+ SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
3176
+ C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
3177
+ } else {
3178
+ SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
3179
+ QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
3180
+ ldc, ScaleStride);
3181
+ }
3182
+ return 4;
3183
+ } else {
3184
+ if (QuantBZeroPoint != nullptr) {
3185
+ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
3186
+ C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
3187
+ } else {
3188
+ SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
3189
+ QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
3190
+ ldc, ScaleStride);
3191
+ }
3192
+ return 1;
3193
+ }
3194
+ }
3195
+ } // namespace ime1
3196
+ } // namespace sqnbitgemm_spacemit_ime
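+ // A possible driver loop over M (a sketch; `lda` is an assumed per-row
+ // stride of the quantized A panel, not something this code defines): the
+ // kernel returns how many rows it consumed, 4 for the M4 path and 1
+ // otherwise, so a caller advances by the return value:
+ //
+ //   size_t m = 0;
+ //   while (m < CountM) {
+ //       m += sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4(
+ //           BlkLen, QuantA + m * lda, QuantBData, QuantBScale,
+ //           QuantBZeroPoint, C + m * ldc, CountM - m, CountN, CountK,
+ //           BlockCountK, ldc, Bias, ScaleStride);
+ //   }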