whispercpp 1.3.3 → 1.3.5

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (963)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +79 -25
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/CMakeLists.txt +1 -0
  23. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  24. data/ext/sources/examples/addon.node/index.js +7 -5
  25. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  26. data/ext/sources/examples/bench/bench.cpp +26 -16
  27. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  28. data/ext/sources/examples/cli/cli.cpp +122 -111
  29. data/ext/sources/examples/command/command.cpp +26 -24
  30. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  31. data/ext/sources/examples/common-ggml.cpp +2 -0
  32. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  34. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  35. data/ext/sources/examples/server/server.cpp +34 -24
  36. data/ext/sources/examples/server.py +6 -1
  37. data/ext/sources/examples/stream/stream.cpp +4 -2
  38. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  39. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  40. data/ext/sources/examples/talk-llama/CMakeLists.txt +7 -3
  41. data/ext/sources/examples/talk-llama/llama-adapter.cpp +113 -7
  42. data/ext/sources/examples/talk-llama/llama-adapter.h +13 -1
  43. data/ext/sources/examples/talk-llama/llama-arch.cpp +2136 -1491
  44. data/ext/sources/examples/talk-llama/llama-arch.h +125 -3
  45. data/ext/sources/examples/talk-llama/llama-batch.cpp +174 -100
  46. data/ext/sources/examples/talk-llama/llama-batch.h +46 -20
  47. data/ext/sources/examples/talk-llama/llama-chat.cpp +199 -8
  48. data/ext/sources/examples/talk-llama/llama-chat.h +11 -0
  49. data/ext/sources/examples/talk-llama/llama-context.cpp +1213 -413
  50. data/ext/sources/examples/talk-llama/llama-context.h +99 -36
  51. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -4
  52. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  53. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  54. data/ext/sources/examples/talk-llama/llama-graph.cpp +883 -294
  55. data/ext/sources/examples/talk-llama/llama-graph.h +361 -161
  56. data/ext/sources/examples/talk-llama/llama-hparams.cpp +144 -6
  57. data/ext/sources/examples/talk-llama/llama-hparams.h +100 -23
  58. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  59. data/ext/sources/examples/talk-llama/llama-impl.h +3 -1
  60. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +328 -0
  61. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  62. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2100 -0
  63. data/ext/sources/examples/talk-llama/llama-kv-cache.h +373 -27
  64. data/ext/sources/examples/talk-llama/llama-kv-cells.h +124 -30
  65. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +63 -41
  66. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  67. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +77 -35
  68. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +15 -16
  69. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  70. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  71. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  72. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +93 -9
  73. data/ext/sources/examples/talk-llama/llama-model-loader.h +9 -2
  74. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  75. data/ext/sources/examples/talk-llama/llama-model.cpp +3369 -10145
  76. data/ext/sources/examples/talk-llama/llama-model.h +104 -12
  77. data/ext/sources/examples/talk-llama/llama-quant.cpp +53 -30
  78. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1520 -324
  79. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  80. data/ext/sources/examples/talk-llama/llama-vocab.cpp +562 -39
  81. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -0
  82. data/ext/sources/examples/talk-llama/llama.cpp +794 -12
  83. data/ext/sources/examples/talk-llama/llama.h +246 -190
  84. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  85. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  86. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  88. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  89. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  90. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  91. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  92. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  93. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  94. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  95. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  96. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  97. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  98. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  99. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  100. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  101. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  102. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  103. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  104. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  105. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  106. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  107. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  108. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  109. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  110. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  111. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  112. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  113. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  114. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  115. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  116. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  117. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  118. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  119. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  120. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  121. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  122. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  123. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  124. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  125. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  126. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  127. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  128. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  129. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  130. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  131. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  132. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  133. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  134. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  135. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  136. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  137. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  156. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  158. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  159. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  160. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  161. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  162. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  163. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  166. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  168. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  169. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  171. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  172. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  173. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  174. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  178. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  179. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  180. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  181. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  182. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  183. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  184. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  185. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  186. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  187. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  188. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  189. data/ext/sources/examples/talk-llama/unicode.cpp +309 -16
  190. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  191. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  192. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  193. data/ext/sources/examples/whisper.wasm/index-tmpl.html +18 -17
  194. data/ext/sources/ggml/CMakeLists.txt +135 -79
  195. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +21 -2
  198. data/ext/sources/ggml/include/ggml-cpu.h +2 -1
  199. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  200. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  201. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  202. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  203. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  204. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +406 -23
  207. data/ext/sources/ggml/src/CMakeLists.txt +99 -13
  208. data/ext/sources/ggml/src/ggml-alloc.c +368 -161
  209. data/ext/sources/ggml/src/ggml-backend-impl.h +5 -5
  210. data/ext/sources/ggml/src/ggml-backend-reg.cpp +55 -14
  211. data/ext/sources/ggml/src/ggml-backend.cpp +290 -57
  212. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  213. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -13
  214. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  215. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +59 -45
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  217. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2586 -1917
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +348 -309
  219. data/ext/sources/ggml/src/ggml-cann/common.h +350 -133
  220. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +894 -625
  221. data/ext/sources/ggml/src/ggml-common.h +17 -0
  222. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +167 -75
  223. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +5 -2
  224. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  225. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +560 -622
  226. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1002 -270
  227. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +107 -587
  228. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  229. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  230. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +373 -486
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  232. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  233. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  234. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  235. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4682 -1660
  237. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +82 -4
  238. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  239. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +18 -9
  240. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +263 -111
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +39 -28
  242. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +683 -82
  243. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +38 -43
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +435 -119
  245. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  246. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1234 -1182
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  248. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +2167 -1480
  249. data/ext/sources/ggml/src/ggml-cpu/ops.h +10 -12
  250. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  251. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  252. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +1132 -81
  253. data/ext/sources/ggml/src/ggml-cpu/repack.h +36 -0
  254. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +120 -93
  255. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1025 -0
  256. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  258. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  259. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  260. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  261. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  262. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  263. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +294 -27
  264. data/ext/sources/ggml/src/ggml-cpu/vec.h +606 -48
  265. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +92 -17
  266. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  267. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  268. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  269. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  270. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  271. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  272. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  273. data/ext/sources/ggml/src/ggml-cuda/common.cuh +588 -128
  274. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  275. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  276. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  278. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +25 -0
  279. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  280. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +335 -485
  281. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  282. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  283. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  284. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  286. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  287. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  288. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +519 -378
  289. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +750 -637
  290. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +49 -0
  291. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1244 -0
  292. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +586 -0
  293. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +98 -61
  294. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  295. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +230 -197
  296. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  297. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  298. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  300. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1557 -294
  301. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  302. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  303. data/ext/sources/ggml/src/ggml-cuda/mean.cu +57 -2
  304. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +915 -69
  305. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +171 -0
  306. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +835 -0
  307. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  308. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  309. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +109 -67
  310. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1601 -733
  311. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +802 -0
  312. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +12 -0
  313. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +286 -149
  314. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  315. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  316. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  317. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  318. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/pad.cu +86 -32
  320. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  321. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  322. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +163 -10
  323. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  324. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  325. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  326. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/rope.cu +207 -98
  328. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  329. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  330. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +330 -0
  331. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  332. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  333. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  334. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  335. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +325 -61
  337. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  338. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  339. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -12
  340. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +291 -104
  341. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  342. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +40 -19
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  407. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  408. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +351 -0
  409. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +21 -0
  410. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  411. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  413. data/ext/sources/ggml/src/ggml-cuda/unary.cu +189 -5
  414. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +44 -0
  415. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +248 -6
  416. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  417. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +8 -0
  418. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +70 -37
  419. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +10 -3
  420. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  421. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  422. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  423. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  424. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  425. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  426. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  427. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  428. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  429. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  430. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  431. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  432. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  433. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  434. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  435. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  436. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  437. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  438. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  439. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  440. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  441. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  442. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  443. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  444. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  445. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  446. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  447. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  448. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  449. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  450. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  451. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +16 -13
  452. data/ext/sources/ggml/src/ggml-impl.h +186 -15
  453. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  454. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  455. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  456. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  457. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +609 -0
  458. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1743 -0
  459. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +273 -0
  460. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1686 -0
  461. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +356 -61
  462. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +4161 -0
  463. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +94 -0
  464. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +724 -0
  465. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +4495 -1876
  466. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +21 -9
  467. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +29 -0
  468. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4005 -427
  469. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  470. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  471. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  472. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  473. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +147 -0
  474. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  475. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  476. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  477. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  478. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +371 -0
  479. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  480. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  481. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  482. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  483. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  484. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  485. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  486. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  487. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  488. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  489. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +146 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +147 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +94 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +208 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +33 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  521. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  522. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  523. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  524. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +497 -195
  525. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  526. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  527. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +6 -5
  530. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  531. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +50 -30
  532. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  533. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +200 -99
  534. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  535. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  536. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -309
  537. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  538. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  539. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  540. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +77 -34
  541. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +397 -314
  542. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  543. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  544. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  545. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +643 -413
  546. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  547. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  548. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  549. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +223 -132
  550. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +230 -55
  551. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  552. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  553. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  554. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  555. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  556. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  557. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  558. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  559. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  560. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  561. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  562. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  563. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +65 -59
  564. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  565. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  566. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  567. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  568. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +330 -165
  569. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  570. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  571. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  572. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  573. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +60 -6
  574. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  575. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  576. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +7398 -2635
  577. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  578. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  579. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +43 -3
  580. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  581. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +15 -6
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +56 -39
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +347 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +5 -5
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +67 -13
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +158 -16
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +38 -3
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +2 -2
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +3 -2
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +7 -6
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +5 -3
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +4 -4
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +2 -2
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +4 -4
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +4 -4
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +2 -2
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +103 -36
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +220 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +139 -45
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +113 -38
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +75 -14
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +19 -17
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +21 -12
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +28 -18
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +4 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +33 -17
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +125 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +227 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +20 -14
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +143 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -556
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +230 -51
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +566 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +90 -223
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +195 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +41 -5
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +59 -9
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +104 -14
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -52
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -35
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -35
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +5 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +3 -3
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +30 -8
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +6 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +16 -6
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +5 -4
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +435 -24
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +148 -6
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl +25 -0
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +619 -177
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  745. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +80 -0
  746. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  747. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3087 -0
  748. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  749. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  750. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  751. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  752. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  753. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +147 -0
  754. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  755. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  756. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  757. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  758. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  759. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  760. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  761. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  762. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  763. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +123 -0
  764. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  765. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  766. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  767. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  768. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  769. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  770. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  771. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  772. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  773. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  774. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  775. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  776. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  777. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  778. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  779. data/ext/sources/ggml/src/ggml.c +901 -129
  780. data/ext/sources/ggml/src/gguf.cpp +8 -1
  781. data/ext/sources/include/whisper.h +1 -0
  782. data/ext/sources/src/CMakeLists.txt +3 -1
  783. data/ext/sources/src/whisper.cpp +124 -81
  784. data/ext/sources/tests/CMakeLists.txt +8 -1
  785. data/ext/sources/tests/test-vad-full.cpp +7 -5
  786. data/ext/sources/tests/test-vad.cpp +3 -3
  787. data/extsources.rb +1 -0
  788. data/lib/whisper/model/uri.rb +17 -18
  789. data/sig/whisper.rbs +126 -2
  790. data/test/test_params.rb +24 -8
  791. data/test/test_segment.rb +0 -1
  792. data/test/test_token.rb +70 -0
  793. data/test/test_vad.rb +1 -1
  794. data/test/test_vad_context.rb +50 -0
  795. data/test/test_vad_segment.rb +19 -0
  796. data/test/test_vad_segments.rb +16 -0
  797. data/test/test_whisper.rb +8 -1
  798. data/whispercpp.gemspec +1 -1
  799. metadata +439 -179
  800. data/ext/sources/build-xcframework.sh +0 -547
  801. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  802. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  803. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  804. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  805. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  806. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  807. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  808. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  809. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  810. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  811. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  812. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  813. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  814. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  815. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  816. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  817. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  818. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  819. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  820. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  821. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  822. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  823. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  824. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  825. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  826. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  827. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +0 -506
  828. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +0 -11
  829. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  830. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  831. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  832. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  833. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  834. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  835. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  836. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  837. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  838. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  839. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  840. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  841. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  842. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  843. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  844. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  845. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  846. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  847. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  848. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  849. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  850. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  851. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  852. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  853. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  854. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  855. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  856. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  857. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  858. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  859. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  860. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  861. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  862. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  863. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  864. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  865. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  866. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  867. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  868. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  869. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  870. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  871. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  872. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  873. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  874. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  875. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  876. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  877. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  878. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  879. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  880. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  881. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  882. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  883. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  884. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  885. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  886. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  887. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  888. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  889. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  890. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  891. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  892. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  893. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  894. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  895. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  896. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  897. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  898. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  899. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  900. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  901. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  902. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  903. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  904. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  905. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  906. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  907. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  908. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  909. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  910. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  911. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  912. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  913. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  914. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  915. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  916. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  917. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  918. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  919. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  920. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  921. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  922. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  923. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  924. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  925. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  926. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  927. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  928. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  929. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  930. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  931. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  932. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  933. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  934. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  935. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  936. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  937. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  938. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  939. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  940. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  941. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  942. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  943. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  944. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  945. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  946. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  947. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  948. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  949. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  950. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  951. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  952. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  953. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  954. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
  955. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +0 -162
  956. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +0 -118
  957. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -99
  958. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -58
  959. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  960. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  961. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  962. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  963. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
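The hunk below adds a 3196-line new file; judging from the ime_kernels.h include and the sqnbitgemm_spacemit_ime namespace, it is the SpaceMIT IME (RISC-V vector) quantized-GEMM kernel source. For orientation, here is a minimal scalar sketch of the symmetric per-block int8 quantization the assembly implements (per block: scale = max|x| / 127, then q = round(x / scale) clamped to int8). The function name and the zero-scale guard are illustrative additions; the kernels themselves divide by the scale unguarded:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <limits>

    // One block of n floats -> one float scale plus n int8 values.
    void quantize_block_i8(const float * x, std::size_t n, float & scale, std::int8_t * q) {
        float max_abs = 0.0f;
        for (std::size_t k = 0; k < n; k++) {
            max_abs = std::max(max_abs, std::fabs(x[k]));
        }
        scale = max_abs / 127.0f;                               // RMAXREC in the asm is 1/127
        const float inv = scale != 0.0f ? 1.0f / scale : 0.0f;  // zero guard is ours
        for (std::size_t k = 0; k < n; k++) {
            q[k] = (std::int8_t) std::clamp(std::roundf(x[k] * inv),
                                            (float) std::numeric_limits<std::int8_t>::lowest(),
                                            (float) std::numeric_limits<std::int8_t>::max());
        }
    }

The vector paths below batch this same computation across blocks and rows.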
@@ -0,0 +1,3196 @@
1
+ #include "ggml.h"
2
+ #include "ime_kernels.h"
3
+
4
+ #include <algorithm>
5
+ #include <cmath>
+ #include <cstddef>   // std::byte
+ #include <limits>    // std::numeric_limits
6
+
7
+ // clang-format off
8
+ #if defined(__GNUC__)
9
+ #pragma GCC diagnostic ignored "-Woverlength-strings"
10
+ #pragma GCC diagnostic ignored "-Wcast-qual"
11
+ #pragma GCC diagnostic ignored "-Wunused-parameter"
12
+ #endif
13
+ // clang-format on
14
+ namespace sqnbitgemm_spacemit_ime {
15
+
16
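+ // One block: reduce to max|x| (vfabs + vfredmax), store scale = max/127 at (a1),
+ // multiply by 1/scale, convert to int, and narrow 32 -> 16 -> 8 bits.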
+ #define QUANTIZEM4ROW_KERNEL \
17
+ "vmv.s.x v16, zero \n\t" \
18
+ "vfabs.v v8, v0 \n\t" \
19
+ "vfredmax.vs v16, v8, v16 \n\t" \
20
+ "vfmv.f.s f10, v16 \n\t" \
21
+ "fmul.s f10, f10, %[RMAXREC] \n\t" \
22
+ "fsw f10, (a1) \n\t" \
23
+ "fdiv.s f11, %[FONE], f10 \n\t" \
24
+ "vfmul.vf v16, v0, f11 \n\t" \
25
+ "vfcvt.x.f.v v16, v16 \n\t" \
26
+ "vsetvli t0, zero, e16, mf2 \n\t" \
27
+ "vnclip.wx v16, v16, zero \n\t" \
28
+ "vnclip.wx v17, v17, zero \n\t" \
29
+ "vnclip.wx v18, v18, zero \n\t" \
30
+ "vnclip.wx v19, v19, zero \n\t" \
31
+ "vnclip.wx v20, v20, zero \n\t" \
32
+ "vnclip.wx v21, v21, zero \n\t" \
33
+ "vnclip.wx v22, v22, zero \n\t" \
34
+ "vnclip.wx v23, v23, zero \n\t" \
35
+ "vsetvli t0, zero, e8, mf4 \n\t" \
36
+ "vnclip.wx v24, v16, zero \n\t" \
37
+ "vnclip.wx v25, v17, zero \n\t" \
38
+ "vnclip.wx v26, v18, zero \n\t" \
39
+ "vnclip.wx v27, v19, zero \n\t" \
40
+ "vnclip.wx v28, v20, zero \n\t" \
41
+ "vnclip.wx v29, v21, zero \n\t" \
42
+ "vnclip.wx v30, v22, zero \n\t" \
43
+ "vnclip.wx v31, v23, zero \n\t"
44
+
45
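+ // Store the eight narrowed vectors of one block to (s1) in 32-byte steps,
+ // clamping each store to the element count remaining in t1.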
+ #define QUANTIZEM4ROW_STORE \
46
+ "addi t1, %[BlkLen], 0 \n\t" \
47
+ "vsetvli t0, t1, e8, mf4 \n\t" \
48
+ "vse8.v v24, (s1) \n\t" \
49
+ "addi s1, s1, 32 \n\t" \
50
+ "sub t1, t1, t0 \n\t" \
51
+ "vsetvli t0, t1, e8, mf4 \n\t" \
52
+ "vse8.v v25, (s1) \n\t" \
53
+ "addi s1, s1, 32 \n\t" \
54
+ "sub t1, t1, t0 \n\t" \
55
+ "vsetvli t0, t1, e8, mf4 \n\t" \
56
+ "vse8.v v26, (s1) \n\t" \
57
+ "addi s1, s1, 32 \n\t" \
58
+ "sub t1, t1, t0 \n\t" \
59
+ "vsetvli t0, t1, e8, mf4 \n\t" \
60
+ "vse8.v v27, (s1) \n\t" \
61
+ "addi s1, s1, 32 \n\t" \
62
+ "sub t1, t1, t0 \n\t" \
63
+ "vsetvli t0, t1, e8, mf4 \n\t" \
64
+ "vse8.v v28, (s1) \n\t" \
65
+ "addi s1, s1, 32 \n\t" \
66
+ "sub t1, t1, t0 \n\t" \
67
+ "vsetvli t0, t1, e8, mf4 \n\t" \
68
+ "vse8.v v29, (s1) \n\t" \
69
+ "addi s1, s1, 32 \n\t" \
70
+ "sub t1, t1, t0 \n\t" \
71
+ "vsetvli t0, t1, e8, mf4 \n\t" \
72
+ "vse8.v v30, (s1) \n\t" \
73
+ "addi s1, s1, 32 \n\t" \
74
+ "sub t1, t1, t0 \n\t" \
75
+ "vsetvli t0, t1, e8, mf4 \n\t" \
76
+ "vse8.v v31, (s1) \n\t"
77
+
78
+ namespace ime1 {
79
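+ // Quantize 4 rows of A at once. Per block group, the four float scales come
+ // first (16 bytes), then the rows' int8 data interleaved in 8-byte chunks
+ // (see the offset/stride constants below); vector widths assume VLEN = 256.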
+ void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
80
+ constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
81
+ const float fone = 1.0f;
82
+
83
+ if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) {
84
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
85
+ const float * SRC = A + row_index * CountK;
86
+ std::byte * DST = QuantA + row_index * sizeof(float);
87
+
88
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
89
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
90
+ __asm__ volatile(
91
+ "vsetvli t0, zero, e32, m8 \n\t"
92
+ "addi t2, %[CountK], 0 \n\t"
93
+ "addi a1, %[DST], 0 \n\t"
94
+ "blt t2, %[BlkLen], TAIL%= \n\t"
95
+
96
+ "LOOP%=: \n\t"
97
+ "vsetvli t0, %[BlkLen], e32, m8 \n\t"
98
+ "vle32.v v0, (%[SRC]) \n\t"
99
+ "sub t2, t2, t0 \n\t"
100
+ "slli t1, t0, 2 \n\t"
101
+ "add %[SRC], %[SRC], t1 \n\t"
102
+ "add s1, a1, %[OFFSET] \n\t"
103
+
104
+ QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE
105
+
106
+ "add a1, a1, %[STRIDE] \n\t"
107
+ "bge t2, %[BlkLen], LOOP%= \n\t"
108
+
109
+ "TAIL%=: \n\t"
110
+ "blez t2, QUIT%= \n\t"
111
+ "vsetvli t0, zero, e32, m8 \n\t"
112
+ "vxor.vv v16, v16, v16 \n\t"
113
+ "vxor.vv v24, v24, v24 \n\t"
114
+ "vsetvli t0, t2, e32, m8 \n\t"
115
+ "vle32.v v0, (%[SRC]) \n\t"
116
+ "add s1, a1, %[OFFSET] \n\t"
117
+
118
+ QUANTIZEM4ROW_KERNEL
119
+
120
+ "addi t3, %[BlkLen], 0 \n\t"
121
+ "addi s2, s1, 0 \n\t"
122
+ "vsetvli t0, zero, e8, mf4 \n\t"
123
+ "vxor.vv v8, v8, v8 \n\t"
124
+ "SET_ZERO%=: \n\t"
125
+ "vse8.v v8, (s2) \n\t"
126
+ "addi s2, s2, 32 \n\t"
127
+ "addi t3, t3, -8 \n\t"
128
+ "bnez t3, SET_ZERO%= \n\t"
129
+
130
+ QUANTIZEM4ROW_STORE
131
+
132
+ "QUIT%=: \n\t"
133
+ : [SRC] "+r"(SRC)
134
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
135
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
136
+ : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11");
137
+ }
138
+ } else if (BlkLen == 128) {
139
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
140
+ const float * SRC = A + row_index * CountK;
141
+ std::byte * DST = QuantA + row_index * sizeof(float);
142
+
143
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
144
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
145
+ __asm__ volatile(
146
+ "vsetvli t0, zero, e32, m8 \n\t"
147
+ "li t6, 32 \n\t"
148
+ "addi t2, %[CountK], 0 \n\t"
149
+ "addi a1, %[DST], 0 \n\t"
150
+ "add s1, a1, %[OFFSET] \n\t"
151
+ "blt t2, %[BlkLen], TAIL%= \n\t"
152
+
153
+ "LOOP%=: \n\t"
154
+ "vsetvli t0, zero, e32, m8 \n\t"
155
+ "vle32.v v0, (%[SRC]) \n\t"
156
+ "addi %[SRC], %[SRC], 256 \n\t"
157
+ "vle32.v v8, (%[SRC]) \n\t"
158
+ "addi %[SRC], %[SRC], 256 \n\t"
159
+ "addi t2, t2, -128 \n\t"
160
+
161
+ "QUANTIZE%=: \n\t"
162
+ "add s1, a1, %[OFFSET] \n\t"
163
+ "vfabs.v v16, v0 \n\t"
164
+ "vfabs.v v24, v8 \n\t"
165
+ "vfmax.vv v16, v24, v16 \n\t"
166
+ "vfredmax.vs v24, v16, v24 \n\t"
167
+ "vfmv.f.s f10, v24 \n\t"
168
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
169
+ "fsw f10, (a1) \n\t"
170
+ "fdiv.s f11, %[FONE], f10 \n\t"
171
+ "vfmul.vf v16, v0, f11 \n\t"
172
+ "vfmul.vf v24, v8, f11 \n\t"
173
+ "vfcvt.x.f.v v16, v16 \n\t"
174
+ "vfcvt.x.f.v v24, v24 \n\t"
175
+ "vsetvli t0, zero, e16, m4 \n\t"
176
+ "vnclip.wx v16, v16, zero \n\t"
177
+ "vnclip.wx v20, v24, zero \n\t"
178
+ "vsetvli t0, zero, e8, m4 \n\t"
179
+ "vnclip.wx v16, v16, zero \n\t"
180
+ "vsetvli t0, zero, e64, m4 \n\t"
181
+ "vsse64.v v16, (s1), t6 \n\t"
182
+ "add a1, a1, %[STRIDE] \n\t"
183
+ "bge t2, %[BlkLen], LOOP%= \n\t"
184
+
185
+ "TAIL%=: \n\t"
186
+ "blez t2, QUIT%= \n\t"
187
+ "vsetvli t0, zero, e32, m8 \n\t"
188
+ "vxor.vv v0, v0, v0 \n\t"
189
+ "vxor.vv v8, v8, v8 \n\t"
190
+ "vxor.vv v16, v16, v16 \n\t"
191
+ "vxor.vv v24, v24, v24 \n\t"
192
+ "vsetvli t0, t2, e32, m8 \n\t"
193
+ "sub t2, t2, t0 \n\t"
194
+ "vle32.v v0, (%[SRC]) \n\t"
195
+ "addi %[SRC], %[SRC], 256 \n\t"
196
+ "vsetvli t0, t2, e32, m8 \n\t"
197
+ "vle32.v v8, (%[SRC]) \n\t"
198
+ "sub t2, t2, t2 \n\t"
199
+ "vsetvli t0, zero, e32, m8 \n\t"
200
+ "jal x0, QUANTIZE%= \n\t"
201
+
202
+ "QUIT%=: \n\t"
203
+ : [SRC] "+r"(SRC)
204
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
205
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
206
+ : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
207
+ }
208
+ } else if (BlkLen == 256) {
209
+ for (size_t row_index = 0; row_index < 4; ++row_index) {
210
+ const float * SRC = A + row_index * CountK;
211
+ std::byte * DST = QuantA + row_index * sizeof(float);
212
+ const size_t offset = (4 - row_index) * 4 + row_index * 8;
213
+ const size_t stride = 4 * (sizeof(float) + BlkLen);
214
+ __asm__ volatile(
215
+ "vsetvli t0, zero, e32, m8 \n\t"
216
+ "li t6, 32 \n\t"
217
+ "addi t2, %[CountK], 0 \n\t"
218
+ "addi a1, %[DST], 0 \n\t"
219
+ "add s1, a1, %[OFFSET] \n\t"
220
+ "blt t2, %[BlkLen], TAIL%= \n\t"
221
+
222
+ "LOOP%=: \n\t"
223
+ "vsetvli t0, zero, e32, m8 \n\t"
224
+ "vle32.v v0, (%[SRC]) \n\t"
225
+ "addi %[SRC], %[SRC], 256 \n\t"
226
+ "vle32.v v8, (%[SRC]) \n\t"
227
+ "addi %[SRC], %[SRC], 256 \n\t"
228
+ "vle32.v v16, (%[SRC]) \n\t"
229
+ "addi %[SRC], %[SRC], 256 \n\t"
230
+ "vle32.v v24, (%[SRC]) \n\t"
231
+ "addi %[SRC], %[SRC], -768 \n\t"
232
+ "addi t2, t2, -256 \n\t"
233
+ "vfabs.v v0, v0 \n\t"
234
+ "vfabs.v v8, v8 \n\t"
235
+ "vfabs.v v16, v16 \n\t"
236
+ "vfabs.v v24, v24 \n\t"
237
+ "vfmax.vv v8, v0, v8 \n\t"
238
+ "vfmax.vv v24, v24, v16 \n\t"
239
+ "vfmax.vv v8, v8, v24 \n\t"
240
+ "vfredmax.vs v24, v8, v24 \n\t"
241
+ "vfmv.f.s f10, v24 \n\t"
242
+ "vle32.v v0, (%[SRC]) \n\t"
243
+ "addi %[SRC], %[SRC], 256 \n\t"
244
+ "vle32.v v8, (%[SRC]) \n\t"
245
+ "addi %[SRC], %[SRC], 256 \n\t"
246
+ "vle32.v v16, (%[SRC]) \n\t"
247
+ "addi %[SRC], %[SRC], 256 \n\t"
248
+ "vle32.v v24, (%[SRC]) \n\t"
249
+ "addi %[SRC], %[SRC], 256 \n\t"
250
+
251
+ "QUANTIZE%=: \n\t"
252
+ "add s1, a1, %[OFFSET] \n\t"
253
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
254
+ "fsw f10, (a1) \n\t"
255
+ "fdiv.s f11, %[FONE], f10 \n\t"
256
+ "vfmul.vf v0, v0, f11 \n\t"
257
+ "vfmul.vf v8, v8, f11 \n\t"
258
+ "vfmul.vf v16, v16, f11 \n\t"
259
+ "vfmul.vf v24, v24, f11 \n\t"
260
+ "vfcvt.x.f.v v0, v0 \n\t"
261
+ "vfcvt.x.f.v v8, v8 \n\t"
262
+ "vfcvt.x.f.v v16, v16 \n\t"
263
+ "vfcvt.x.f.v v24, v24 \n\t"
264
+ "vsetvli t0, zero, e16, m4 \n\t"
265
+ "vnclip.wx v0, v0, zero \n\t"
266
+ "vnclip.wx v4, v8, zero \n\t"
267
+ "vnclip.wx v8, v16, zero \n\t"
268
+ "vnclip.wx v12, v24, zero \n\t"
269
+ "vsetvli t0, zero, e8, m4 \n\t"
270
+ "vnclip.wx v0, v0, zero \n\t"
271
+ "vnclip.wx v4, v8, zero \n\t"
272
+ "vsetvli t0, zero, e64, m8 \n\t"
273
+ "vsse64.v v0, (s1), t6 \n\t"
274
+ "add a1, a1, %[STRIDE] \n\t"
275
+ "bge t2, %[BlkLen], LOOP%= \n\t"
276
+
277
+ "TAIL%=: \n\t"
278
+ "blez t2, QUIT%= \n\t"
279
+ "vsetvli t0, zero, e32, m8 \n\t"
280
+ "vxor.vv v0, v0, v0 \n\t"
281
+ "vxor.vv v8, v8, v8 \n\t"
282
+ "vxor.vv v16, v16, v16 \n\t"
283
+ "vxor.vv v24, v24, v24 \n\t"
284
+ "addi t1, t2, 0 \n\t"
285
+ "vsetvli t0, t1, e32, m8 \n\t"
286
+ "sub t1, t1, t0 \n\t"
287
+ "vle32.v v0, (%[SRC]) \n\t"
288
+ "addi %[SRC], %[SRC], 256 \n\t"
289
+ "vsetvli t0, t1, e32, m8 \n\t"
290
+ "sub t1, t1, t0 \n\t"
291
+ "vle32.v v8, (%[SRC]) \n\t"
292
+ "addi %[SRC], %[SRC], 256 \n\t"
293
+ "vsetvli t0, t1, e32, m8 \n\t"
294
+ "sub t1, t1, t0 \n\t"
295
+ "vle32.v v16, (%[SRC]) \n\t"
296
+ "addi %[SRC], %[SRC], 256 \n\t"
297
+ "vsetvli t0, t1, e32, m8 \n\t"
298
+ "vle32.v v24, (%[SRC]) \n\t"
299
+ "addi %[SRC], %[SRC], -768 \n\t"
300
+ "vsetvli t0, zero, e32, m8 \n\t"
301
+ "vfabs.v v0, v0 \n\t"
302
+ "vfabs.v v8, v8 \n\t"
303
+ "vfabs.v v16, v16 \n\t"
304
+ "vfabs.v v24, v24 \n\t"
305
+ "vfmax.vv v8, v0, v8 \n\t"
306
+ "vfmax.vv v24, v16, v24 \n\t"
307
+ "vfmax.vv v8, v8, v24 \n\t"
308
+ "vfredmax.vs v24, v8, v24 \n\t"
309
+ "vfmv.f.s f10, v24 \n\t"
310
+ "add s1, a1, %[OFFSET] \n\t"
311
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
312
+ "fsw f10, (a1) \n\t"
313
+ "fdiv.s f11, %[FONE], f10 \n\t"
314
+ "vsetvli t0, zero, e64, m8 \n\t"
315
+ "vxor.vv v0, v0, v0 \n\t"
316
+ "vsse64.v v0, (s1), t6 \n\t"
317
+
318
+ "TAIL_LOOP%=: \n\t"
319
+ "vsetvli t0, zero, e32, m4 \n\t"
320
+ "vxor.vv v0, v0, v0 \n\t"
321
+ "vsetvli t0, t2, e32, m1 \n\t"
322
+ "sub t2, t2, t0 \n\t"
323
+ "vle32.v v0, (%[SRC]) \n\t"
324
+ "addi %[SRC], %[SRC], 32 \n\t"
325
+ "vfmul.vf v1, v0, f11 \n\t"
326
+ "vfcvt.x.f.v v2, v1 \n\t"
327
+ "vsetvli t0, zero, e16, mf2 \n\t"
328
+ "vnclip.wx v3, v2, zero \n\t"
329
+ "vsetvli t0, zero, e8, mf4 \n\t"
330
+ "vnclip.wx v3, v3, zero \n\t"
331
+ "vse8.v v3, (s1) \n\t"
332
+ "addi s1, s1, 32 \n\t"
333
+ "bnez t2, TAIL_LOOP%= \n\t"
334
+
335
+ "QUIT%=: \n\t"
336
+ : [SRC] "+r"(SRC)
337
+ : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride),
338
+ [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
339
+ : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11");
340
+ }
341
+ }
342
+ }
343
+
344
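+ // Single-row path: per block, one float scale followed by BlkLen int8 values;
+ // the final partial block is zero-padded to a full BlkLen.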
+ void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) {
345
+ const float * SRC = A;
346
+ std::byte * DST = QuantA;
347
+ constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1);
348
+ const float fone = 1.0f;
349
+ std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen);
350
+ size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK;
351
+
352
+ if (CountK <= BlkLen) {
353
+ float max_abs_A = 0.0f;
354
+ for (size_t k = 0; k < CountK; k++) {
355
+ max_abs_A = std::max(max_abs_A, fabsf(A[k]));
356
+ }
357
+ float scale_A = max_abs_A * range_max_reciprocal;
358
+
359
+ ((float *) QuantA)[0] = scale_A;
360
+
361
+ auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float));
362
+
363
+ for (size_t k = 0; k < CountK; k++) {
364
+ QuantAData_offset[k] =
365
+ (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(),
366
+ (float) std::numeric_limits<int8_t>::max());
367
+ }
368
+ for (size_t k = CountK; k < BlkLen; k++) {
369
+ QuantAData_offset[k] = 0;
370
+ }
371
+
372
+ return;
373
+ }
374
+
375
+ if (BlkLen != 32 && BlkLen != 64 && BlkLen != 128) {
376
+ __asm__ volatile(
377
+ "vsetvli t0, zero, e8, m8 \n\t"
378
+ "vxor.vv v24, v24, v24 \n\t"
379
+ "LOOP%=: \n\t"
380
+ "vsetvli t0, %[CNT], e8, m8 \n\t"
381
+ "vse8.v v24, (%[DST]) \n\t"
382
+ "addi %[DST], %[DST], 128 \n\t"
383
+ "sub %[CNT], %[CNT], t0 \n\t"
384
+ "bnez %[CNT], LOOP%= \n\t"
385
+ : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset)
386
+ :
387
+ : "cc", "t0");
388
+ }
389
+ if (BlkLen == 16) {
390
+ float buffer[64] = { 0.0f };
391
+ __asm__ volatile(
392
+ "addi t3, zero, 16*8 \n\t"
393
+ "addi t2, zero, 16 \n\t"
394
+ "blt %[K], t3, LOOP_K%= \n\t"
395
+ "blt %[K], t2, TAIL%= \n\t"
396
+ "LOOP_MAIN%=: \n\t"
397
+ "vsetvli t1, zero, e32, m2 \n\t"
398
+ "addi %[K], %[K], -128 \n\t"
399
+ "vle32.v v0, (%[SRC]) \n\t"
400
+ "addi %[SRC], %[SRC], 64 \n\t"
401
+ "vle32.v v2, (%[SRC]) \n\t"
402
+ "addi %[SRC], %[SRC], 64 \n\t"
403
+ "vle32.v v4, (%[SRC]) \n\t"
404
+ "addi %[SRC], %[SRC], 64 \n\t"
405
+ "vle32.v v6, (%[SRC]) \n\t"
406
+ "addi %[SRC], %[SRC], 64 \n\t"
407
+ "vle32.v v8, (%[SRC]) \n\t"
408
+ "addi %[SRC], %[SRC], 64 \n\t"
409
+ "vle32.v v10, (%[SRC]) \n\t"
410
+ "addi %[SRC], %[SRC], 64 \n\t"
411
+ "vle32.v v12, (%[SRC]) \n\t"
412
+ "addi %[SRC], %[SRC], 64 \n\t"
413
+ "vle32.v v14, (%[SRC]) \n\t"
414
+ "addi %[SRC], %[SRC], 64 \n\t"
415
+ "addi a1, %[BUFFER], 0 \n\t"
416
+ "vfabs.v v16, v0 \n\t"
417
+ "vfabs.v v18, v2 \n\t"
418
+ "vfabs.v v20, v4 \n\t"
419
+ "vfabs.v v22, v6 \n\t"
420
+ "vfabs.v v24, v8 \n\t"
421
+ "vfabs.v v26, v10 \n\t"
422
+ "vfabs.v v28, v12 \n\t"
423
+ "vfabs.v v30, v14 \n\t"
424
+ "vsetvli t0, zero, e32, m1 \n\t"
425
+ "vfmax.vv v16, v16, v17 \n\t"
426
+ "vfmax.vv v18, v18, v19 \n\t"
427
+ "vfmax.vv v20, v20, v21 \n\t"
428
+ "vfmax.vv v22, v22, v23 \n\t"
429
+ "vfmax.vv v24, v24, v25 \n\t"
430
+ "vfmax.vv v26, v26, v27 \n\t"
431
+ "vfmax.vv v28, v28, v29 \n\t"
432
+ "vfmax.vv v30, v30, v31 \n\t"
433
+ "vse32.v v16, (a1) \n\t"
434
+ "addi a1, a1, 32 \n\t"
435
+ "vse32.v v18, (a1) \n\t"
436
+ "addi a1, a1, 32 \n\t"
437
+ "vse32.v v20, (a1) \n\t"
438
+ "addi a1, a1, 32 \n\t"
439
+ "vse32.v v22, (a1) \n\t"
440
+ "addi a1, a1, 32 \n\t"
441
+ "vse32.v v24, (a1) \n\t"
442
+ "addi a1, a1, 32 \n\t"
443
+ "vse32.v v26, (a1) \n\t"
444
+ "addi a1, a1, 32 \n\t"
445
+ "vse32.v v28, (a1) \n\t"
446
+ "addi a1, a1, 32 \n\t"
447
+ "vse32.v v30, (a1) \n\t"
448
+ "addi a1, %[BUFFER], 0 \n\t"
449
+ "flw f0, (a1) \n\t"
450
+ "flw f1, 4(a1) \n\t"
451
+ "flw f2, 8(a1) \n\t"
452
+ "flw f3, 12(a1) \n\t"
453
+ "flw f4, 16(a1) \n\t"
454
+ "flw f5, 20(a1) \n\t"
455
+ "flw f6, 24(a1) \n\t"
456
+ "flw f7, 28(a1) \n\t"
457
+ "addi a1, a1, 32 \n\t"
458
+ "fmax.s f1, f0, f1 \n\t"
459
+ "fmax.s f3, f2, f3 \n\t"
460
+ "fmax.s f5, f4, f5 \n\t"
461
+ "fmax.s f7, f6, f7 \n\t"
462
+ "fmax.s f3, f1, f3 \n\t"
463
+ "fmax.s f7, f5, f7 \n\t"
464
+ "fmax.s f10, f3, f7 \n\t"
465
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
466
+ "fsw f10, (%[DST]) \n\t"
467
+ "addi %[DST], %[DST], 20 \n\t"
468
+ "fdiv.s f10, %[FONE], f10 \n\t"
469
+ "flw f0, (a1) \n\t"
470
+ "flw f1, 4(a1) \n\t"
471
+ "flw f2, 8(a1) \n\t"
472
+ "flw f3, 12(a1) \n\t"
473
+ "flw f4, 16(a1) \n\t"
474
+ "flw f5, 20(a1) \n\t"
475
+ "flw f6, 24(a1) \n\t"
476
+ "flw f7, 28(a1) \n\t"
477
+ "addi a1, a1, 32 \n\t"
478
+ "fmax.s f1, f0, f1 \n\t"
479
+ "fmax.s f3, f2, f3 \n\t"
480
+ "fmax.s f5, f4, f5 \n\t"
481
+ "fmax.s f7, f6, f7 \n\t"
482
+ "fmax.s f3, f1, f3 \n\t"
483
+ "fmax.s f7, f5, f7 \n\t"
484
+ "fmax.s f11, f3, f7 \n\t"
485
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
486
+ "fsw f11, (%[DST]) \n\t"
487
+ "addi %[DST], %[DST], 20 \n\t"
488
+ "fdiv.s f11, %[FONE], f11 \n\t"
489
+ "flw f0, (a1) \n\t"
490
+ "flw f1, 4(a1) \n\t"
491
+ "flw f2, 8(a1) \n\t"
492
+ "flw f3, 12(a1) \n\t"
493
+ "flw f4, 16(a1) \n\t"
494
+ "flw f5, 20(a1) \n\t"
495
+ "flw f6, 24(a1) \n\t"
496
+ "flw f7, 28(a1) \n\t"
497
+ "addi a1, a1, 32 \n\t"
498
+ "fmax.s f1, f0, f1 \n\t"
499
+ "fmax.s f3, f2, f3 \n\t"
500
+ "fmax.s f5, f4, f5 \n\t"
501
+ "fmax.s f7, f6, f7 \n\t"
502
+ "fmax.s f3, f1, f3 \n\t"
503
+ "fmax.s f7, f5, f7 \n\t"
504
+ "fmax.s f12, f3, f7 \n\t"
505
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
506
+ "fsw f12, (%[DST]) \n\t"
507
+ "addi %[DST], %[DST], 20 \n\t"
508
+ "fdiv.s f12, %[FONE], f12 \n\t"
509
+ "flw f0, (a1) \n\t"
510
+ "flw f1, 4(a1) \n\t"
511
+ "flw f2, 8(a1) \n\t"
512
+ "flw f3, 12(a1) \n\t"
513
+ "flw f4, 16(a1) \n\t"
514
+ "flw f5, 20(a1) \n\t"
515
+ "flw f6, 24(a1) \n\t"
516
+ "flw f7, 28(a1) \n\t"
517
+ "addi a1, a1, 32 \n\t"
518
+ "fmax.s f1, f0, f1 \n\t"
519
+ "fmax.s f3, f2, f3 \n\t"
520
+ "fmax.s f5, f4, f5 \n\t"
521
+ "fmax.s f7, f6, f7 \n\t"
522
+ "fmax.s f3, f1, f3 \n\t"
523
+ "fmax.s f7, f5, f7 \n\t"
524
+ "fmax.s f13, f3, f7 \n\t"
525
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
526
+ "fsw f13, (%[DST]) \n\t"
527
+ "addi %[DST], %[DST], 20 \n\t"
528
+ "fdiv.s f13, %[FONE], f13 \n\t"
529
+ "flw f0, (a1) \n\t"
530
+ "flw f1, 4(a1) \n\t"
531
+ "flw f2, 8(a1) \n\t"
532
+ "flw f3, 12(a1) \n\t"
533
+ "flw f4, 16(a1) \n\t"
534
+ "flw f5, 20(a1) \n\t"
535
+ "flw f6, 24(a1) \n\t"
536
+ "flw f7, 28(a1) \n\t"
537
+ "addi a1, a1, 32 \n\t"
538
+ "fmax.s f1, f0, f1 \n\t"
539
+ "fmax.s f3, f2, f3 \n\t"
540
+ "fmax.s f5, f4, f5 \n\t"
541
+ "fmax.s f7, f6, f7 \n\t"
542
+ "fmax.s f3, f1, f3 \n\t"
543
+ "fmax.s f7, f5, f7 \n\t"
544
+ "fmax.s f14, f3, f7 \n\t"
545
+ "fmul.s f14, f14, %[RMAXREC] \n\t"
546
+ "fsw f14, (%[DST]) \n\t"
547
+ "addi %[DST], %[DST], 20 \n\t"
548
+ "fdiv.s f14, %[FONE], f14 \n\t"
549
+ "flw f0, (a1) \n\t"
550
+ "flw f1, 4(a1) \n\t"
551
+ "flw f2, 8(a1) \n\t"
552
+ "flw f3, 12(a1) \n\t"
553
+ "flw f4, 16(a1) \n\t"
554
+ "flw f5, 20(a1) \n\t"
555
+ "flw f6, 24(a1) \n\t"
556
+ "flw f7, 28(a1) \n\t"
557
+ "addi a1, a1, 32 \n\t"
558
+ "fmax.s f1, f0, f1 \n\t"
559
+ "fmax.s f3, f2, f3 \n\t"
560
+ "fmax.s f5, f4, f5 \n\t"
561
+ "fmax.s f7, f6, f7 \n\t"
562
+ "fmax.s f3, f1, f3 \n\t"
563
+ "fmax.s f7, f5, f7 \n\t"
564
+ "fmax.s f15, f3, f7 \n\t"
565
+ "fmul.s f15, f15, %[RMAXREC] \n\t"
566
+ "fsw f15, (%[DST]) \n\t"
567
+ "addi %[DST], %[DST], 20 \n\t"
568
+ "fdiv.s f15, %[FONE], f15 \n\t"
569
+ "flw f0, (a1) \n\t"
570
+ "flw f1, 4(a1) \n\t"
571
+ "flw f2, 8(a1) \n\t"
572
+ "flw f3, 12(a1) \n\t"
573
+ "flw f4, 16(a1) \n\t"
574
+ "flw f5, 20(a1) \n\t"
575
+ "flw f6, 24(a1) \n\t"
576
+ "flw f7, 28(a1) \n\t"
577
+ "addi a1, a1, 32 \n\t"
578
+ "fmax.s f1, f0, f1 \n\t"
579
+ "fmax.s f3, f2, f3 \n\t"
580
+ "fmax.s f5, f4, f5 \n\t"
581
+ "fmax.s f7, f6, f7 \n\t"
582
+ "fmax.s f3, f1, f3 \n\t"
583
+ "fmax.s f7, f5, f7 \n\t"
584
+ "fmax.s f16, f3, f7 \n\t"
585
+ "fmul.s f16, f16, %[RMAXREC] \n\t"
586
+ "fsw f16, (%[DST]) \n\t"
587
+ "addi %[DST], %[DST], 20 \n\t"
588
+ "fdiv.s f16, %[FONE], f16 \n\t"
589
+ "flw f0, (a1) \n\t"
590
+ "flw f1, 4(a1) \n\t"
591
+ "flw f2, 8(a1) \n\t"
592
+ "flw f3, 12(a1) \n\t"
593
+ "flw f4, 16(a1) \n\t"
594
+ "flw f5, 20(a1) \n\t"
595
+ "flw f6, 24(a1) \n\t"
596
+ "flw f7, 28(a1) \n\t"
597
+ "addi a1, a1, 32 \n\t"
598
+ "fmax.s f1, f0, f1 \n\t"
599
+ "fmax.s f3, f2, f3 \n\t"
600
+ "fmax.s f5, f4, f5 \n\t"
601
+ "fmax.s f7, f6, f7 \n\t"
602
+ "fmax.s f3, f1, f3 \n\t"
603
+ "fmax.s f7, f5, f7 \n\t"
604
+ "fmax.s f17, f3, f7 \n\t"
605
+ "fmul.s f17, f17, %[RMAXREC] \n\t"
606
+ "fsw f17, (%[DST]) \n\t"
607
+ "addi %[DST], %[DST], -136 \n\t"
608
+ "fdiv.s f17, %[FONE], f17 \n\t"
609
+ "vsetvli t0, zero, e32, m2 \n\t"
610
+ "vfmul.vf v16, v0, f10 \n\t"
611
+ "vfmul.vf v18, v2, f11 \n\t"
612
+ "vfmul.vf v20, v4, f12 \n\t"
613
+ "vfmul.vf v22, v6, f13 \n\t"
614
+ "vfmul.vf v24, v8, f14 \n\t"
615
+ "vfmul.vf v26, v10, f15 \n\t"
616
+ "vfmul.vf v28, v12, f16 \n\t"
617
+ "vfmul.vf v30, v14, f17 \n\t"
618
+ "vfcvt.x.f.v v16, v16 \n\t"
619
+ "vfcvt.x.f.v v18, v18 \n\t"
620
+ "vfcvt.x.f.v v20, v20 \n\t"
621
+ "vfcvt.x.f.v v22, v22 \n\t"
622
+ "vfcvt.x.f.v v24, v24 \n\t"
623
+ "vfcvt.x.f.v v26, v26 \n\t"
624
+ "vfcvt.x.f.v v28, v28 \n\t"
625
+ "vfcvt.x.f.v v30, v30 \n\t"
626
+ "vsetvli t0, zero, e16, m1 \n\t"
627
+ "vnclip.wx v16, v16, zero \n\t"
628
+ "vnclip.wx v18, v18, zero \n\t"
629
+ "vnclip.wx v20, v20, zero \n\t"
630
+ "vnclip.wx v22, v22, zero \n\t"
631
+ "vnclip.wx v24, v24, zero \n\t"
632
+ "vnclip.wx v26, v26, zero \n\t"
633
+ "vnclip.wx v28, v28, zero \n\t"
634
+ "vnclip.wx v30, v30, zero \n\t"
635
+ "vsetvli t0, t1, e8, mf2 \n\t"
636
+ "vnclip.wx v16, v16, zero \n\t"
637
+ "vnclip.wx v18, v18, zero \n\t"
638
+ "vnclip.wx v20, v20, zero \n\t"
639
+ "vnclip.wx v22, v22, zero \n\t"
640
+ "vnclip.wx v24, v24, zero \n\t"
641
+ "vnclip.wx v26, v26, zero \n\t"
642
+ "vnclip.wx v28, v28, zero \n\t"
643
+ "vnclip.wx v30, v30, zero \n\t"
644
+ "vse8.v v16, (%[DST]) \n\t"
645
+ "addi %[DST], %[DST], 20 \n\t"
646
+ "vse8.v v18, (%[DST]) \n\t"
647
+ "addi %[DST], %[DST], 20 \n\t"
648
+ "vse8.v v20, (%[DST]) \n\t"
649
+ "addi %[DST], %[DST], 20 \n\t"
650
+ "vse8.v v22, (%[DST]) \n\t"
651
+ "addi %[DST], %[DST], 20 \n\t"
652
+ "vse8.v v24, (%[DST]) \n\t"
653
+ "addi %[DST], %[DST], 20 \n\t"
654
+ "vse8.v v26, (%[DST]) \n\t"
655
+ "addi %[DST], %[DST], 20 \n\t"
656
+ "vse8.v v28, (%[DST]) \n\t"
657
+ "addi %[DST], %[DST], 20 \n\t"
658
+ "vse8.v v30, (%[DST]) \n\t"
659
+ "addi %[DST], %[DST], 16 \n\t"
660
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
661
+ "blt %[K], t2, TAIL%= \n\t"
662
+ "LOOP_K%=: \n\t"
663
+ "vsetvli t1, %[K], e32, m2 \n\t"
664
+ "vle32.v v0, (%[SRC]) \n\t"
665
+ "addi %[SRC], %[SRC], 64 \n\t"
666
+ "sub %[K], %[K], t1 \n\t"
667
+ "vfabs.v v16, v0 \n\t"
668
+ "vsetvli t0, zero, e32, m1 \n\t"
669
+ "vfmax.vv v16, v16, v17 \n\t"
670
+ "vse32.v v16, (%[BUFFER]) \n\t"
671
+ "flw f0, (%[BUFFER]) \n\t"
672
+ "flw f1, 4(%[BUFFER]) \n\t"
673
+ "flw f2, 8(%[BUFFER]) \n\t"
674
+ "flw f3, 12(%[BUFFER]) \n\t"
675
+ "flw f4, 16(%[BUFFER]) \n\t"
676
+ "flw f5, 20(%[BUFFER]) \n\t"
677
+ "flw f6, 24(%[BUFFER]) \n\t"
678
+ "flw f7, 28(%[BUFFER]) \n\t"
679
+ "fmax.s f1, f0, f1 \n\t"
680
+ "fmax.s f3, f2, f3 \n\t"
681
+ "fmax.s f5, f4, f5 \n\t"
682
+ "fmax.s f7, f6, f7 \n\t"
683
+ "fmax.s f3, f1, f3 \n\t"
684
+ "fmax.s f7, f5, f7 \n\t"
685
+ "fmax.s f10, f3, f7 \n\t"
686
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
687
+ "fsw f10, (%[DST]) \n\t"
688
+ "addi %[DST], %[DST], 4 \n\t"
689
+ "fdiv.s f11, %[FONE], f10 \n\t"
690
+ "vsetvli t0, zero, e32, m2 \n\t"
691
+ "vfmul.vf v16, v0, f11 \n\t"
692
+ "vfcvt.x.f.v v16, v16 \n\t"
693
+ "vsetvli t0, zero, e16, m1 \n\t"
694
+ "vnclip.wx v16, v16, zero \n\t"
695
+ "vsetvli t0, t1, e8, mf2 \n\t"
696
+ "vnclip.wx v16, v16, zero \n\t"
697
+ "vse8.v v16, (%[DST]) \n\t"
698
+ "addi %[DST], %[DST], 16 \n\t"
699
+ "bge %[K], t2, LOOP_K%= \n\t"
700
+ "TAIL%=: \n\t"
701
+ "blez %[K], END%= \n\t"
702
+ "vsetvli t0, t3, e32, m2 \n\t"
703
+ "vxor.vv v16, v16, v16 \n\t"
704
+ "jal x0, LOOP_K%= \n\t"
705
+ "END%=: \n\t"
706
+ : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
707
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer)
708
+ : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12",
709
+ "f13", "f14", "f15", "f16", "f17");
710
+ } else if (BlkLen == 32) {
711
+ __asm__ volatile(
712
+ "addi t3, zero, 32*4 \n\t"
713
+ "addi t2, zero, 32 \n\t"
714
+
715
+ "addi a1, %[SRC], 0 \n\t"
716
+ "addi a2, %[SRC], 128 \n\t"
717
+ "addi a3, %[SRC], 256 \n\t"
718
+ "addi a4, %[SRC], 384 \n\t"
719
+
720
+ "addi s1, %[DST], 0 \n\t"
721
+ "addi s2, %[DST], 36 \n\t"
722
+ "addi s3, %[DST], 72 \n\t"
723
+ "addi s4, %[DST], 108 \n\t"
724
+ "blt %[K], t3, LOOP_K%= \n\t"
725
+ "blt %[K], t2, TAIL%= \n\t"
726
+
727
+ "LOOP_MAIN%=: \n\t"
728
+ "vsetvli t1, zero, e32, m4 \n\t"
729
+ "addi %[K], %[K], -128 \n\t"
730
+ "vle32.v v0, (a1) \n\t"
731
+ "addi a1, a1, 512 \n\t"
732
+ "vle32.v v4, (a2) \n\t"
733
+ "addi a2, a2, 512 \n\t"
734
+ "vle32.v v8, (a3) \n\t"
735
+ "addi a3, a3, 512 \n\t"
736
+ "vle32.v v12, (a4) \n\t"
737
+ "addi a4, a4, 512 \n\t"
738
+ "vfabs.v v16, v0 \n\t"
739
+ "vfabs.v v20, v4 \n\t"
740
+ "vfabs.v v24, v8 \n\t"
741
+ "vfabs.v v28, v12 \n\t"
742
+ "vsetvli t0, zero, e32, m2 \n\t"
743
+ "vfmax.vv v16, v16, v18 \n\t"
744
+ "vfmax.vv v20, v20, v22 \n\t"
745
+ "vfmax.vv v24, v24, v26 \n\t"
746
+ "vfmax.vv v28, v28, v30 \n\t"
747
+ "vsetvli t0, zero, e32, m1 \n\t"
748
+ "vfmax.vv v16, v16, v17 \n\t"
749
+ "vfmax.vv v20, v20, v21 \n\t"
750
+ "vfmax.vv v24, v24, v25 \n\t"
751
+ "vfmax.vv v28, v28, v29 \n\t"
752
+
753
+ "vfredmax.vs v17, v16, v17 \n\t"
754
+ "vfredmax.vs v21, v20, v21 \n\t"
755
+ "vfredmax.vs v25, v24, v25 \n\t"
756
+ "vfredmax.vs v29, v28, v29 \n\t"
757
+ "vfmv.f.s f10, v17 \n\t"
758
+ "vfmv.f.s f11, v21 \n\t"
759
+ "vfmv.f.s f12, v25 \n\t"
760
+ "vfmv.f.s f13, v29 \n\t"
761
+
762
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
763
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
764
+ "fmul.s f12, f12, %[RMAXREC] \n\t"
765
+ "fmul.s f13, f13, %[RMAXREC] \n\t"
766
+ "fsw f10, (s1) \n\t"
767
+ "addi s1, s1, 4 \n\t"
768
+
769
+ "fsw f11, (s2) \n\t"
770
+ "addi s2, s2, 4 \n\t"
771
+ "fsw f12, (s3) \n\t"
772
+ "addi s3, s3, 4 \n\t"
773
+ "fsw f13, (s4) \n\t"
774
+ "addi s4, s4, 4 \n\t"
775
+ "fdiv.s f10, %[FONE], f10 \n\t"
776
+ "fdiv.s f11, %[FONE], f11 \n\t"
777
+ "fdiv.s f12, %[FONE], f12 \n\t"
778
+ "fdiv.s f13, %[FONE], f13 \n\t"
779
+ "vsetvli t0, zero, e32, m4 \n\t"
780
+ "vfmul.vf v16, v0, f10 \n\t"
781
+ "vfmul.vf v20, v4, f11 \n\t"
782
+ "vfmul.vf v24, v8, f12 \n\t"
783
+ "vfmul.vf v28, v12, f13 \n\t"
784
+ "vfcvt.x.f.v v16, v16 \n\t"
785
+ "vfcvt.x.f.v v20, v20 \n\t"
786
+ "vfcvt.x.f.v v24, v24 \n\t"
787
+ "vfcvt.x.f.v v28, v28 \n\t"
788
+ "vsetvli t0, zero, e16, m2 \n\t"
789
+ "vnclip.wx v16, v16, zero \n\t"
790
+ "vnclip.wx v20, v20, zero \n\t"
791
+ "vnclip.wx v24, v24, zero \n\t"
792
+ "vnclip.wx v28, v28, zero \n\t"
793
+ "vsetvli t0, t1, e8, m1 \n\t"
794
+ "vnclip.wx v16, v16, zero \n\t"
795
+ "vnclip.wx v20, v20, zero \n\t"
796
+ "vnclip.wx v24, v24, zero \n\t"
797
+ "vnclip.wx v28, v28, zero \n\t"
798
+ "vse8.v v16, (s1) \n\t"
799
+ "addi s1, s1, 140 \n\t"
800
+ "vse8.v v20, (s2) \n\t"
801
+ "addi s2, s2, 140 \n\t"
802
+ "vse8.v v24, (s3) \n\t"
803
+ "addi s3, s3, 140 \n\t"
804
+ "vse8.v v28, (s4) \n\t"
805
+ "addi s4, s4, 140 \n\t"
806
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
807
+ "blt %[K], t2, TAIL%= \n\t"
808
+ "LOOP_K%=: \n\t"
809
+ "vsetvli t1, %[K], e32, m4 \n\t"
810
+ "vle32.v v0, (a1) \n\t"
811
+ "addi a1, a1, 128 \n\t"
812
+ "sub %[K], %[K], t1 \n\t"
813
+ "vfabs.v v16, v0 \n\t"
814
+ "vsetvli t0, zero, e32, m2 \n\t"
815
+ "vfmax.vv v16, v16, v18 \n\t"
816
+ "vsetvli t0, zero, e32, m1 \n\t"
817
+ "vfmax.vv v16, v16, v17 \n\t"
818
+ "vfredmax.vs v17, v16, v17 \n\t"
819
+ "vfmv.f.s f10, v17 \n\t"
820
+
821
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
822
+ "fsw f10, (s1) \n\t"
823
+ "addi s1, s1, 4 \n\t"
824
+ "fdiv.s f11, %[FONE], f10 \n\t"
825
+ "vsetvli t0, zero, e32, m4 \n\t"
826
+ "vfmul.vf v16, v0, f11 \n\t"
827
+ "vfcvt.x.f.v v16, v16 \n\t"
828
+ "vsetvli t0, zero, e16, m2 \n\t"
829
+ "vnclip.wx v16, v16, zero \n\t"
830
+ "vsetvli t0, zero, e8, m1 \n\t"
831
+ "vnclip.wx v16, v16, zero \n\t"
832
+ "vse8.v v16, (s1) \n\t"
833
+ "addi s1, s1, 32 \n\t"
834
+ "bge %[K], t2, LOOP_K%= \n\t"
835
+ "TAIL%=: \n\t"
836
+ "blez %[K], END%= \n\t"
837
+ "vsetvli t0, t3, e32, m4 \n\t"
838
+ "vxor.vv v0, v0, v0 \n\t"
839
+ "vxor.vv v16, v16, v16 \n\t"
840
+ "jal x0, LOOP_K%= \n\t"
841
+ "END%=: \n\t"
842
+ : [K] "+r"(CountK)
843
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST)
844
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13");
845
+ } else if (BlkLen == 64) {
846
+ __asm__ volatile(
847
+ "addi t3, zero, 64*2 \n\t"
848
+ "addi t2, zero, 64 \n\t"
849
+ "addi a1, %[SRC], 0 \n\t"
850
+ "addi a2, %[SRC], 256 \n\t"
851
+ "addi s1, %[DST], 0 \n\t"
852
+ "addi s2, %[DST], 68 \n\t"
853
+ "blt %[K], t3, LOOP_K%= \n\t"
854
+ "blt %[K], t2, TAIL%= \n\t"
855
+ "LOOP_MAIN%=: \n\t"
856
+ "vsetvli t1, zero, e32, m8 \n\t"
857
+ "addi %[K], %[K], -128 \n\t"
858
+ "vle32.v v0, (a1) \n\t"
859
+ "addi a1, a1, 512 \n\t"
860
+ "vle32.v v8, (a2) \n\t"
861
+ "addi a2, a2, 512 \n\t"
862
+ "vfabs.v v16, v0 \n\t"
863
+ "vfabs.v v24, v8 \n\t"
864
+ "vsetvli t0, zero, e32, m4 \n\t"
865
+ "vfmax.vv v16, v16, v20 \n\t"
866
+ "vfmax.vv v24, v24, v28 \n\t"
867
+ "vsetvli t0, zero, e32, m2 \n\t"
868
+ "vfmax.vv v16, v16, v18 \n\t"
869
+ "vfmax.vv v24, v24, v26 \n\t"
870
+ "vsetvli t0, zero, e32, m1 \n\t"
871
+ "vfmax.vv v16, v16, v17 \n\t"
872
+ "vfmax.vv v24, v24, v25 \n\t"
873
+ "vfredmax.vs v17, v16, v17 \n\t"
874
+ "vfredmax.vs v25, v24, v25 \n\t"
875
+ "vfmv.f.s f10, v17 \n\t"
876
+ "vfmv.f.s f11, v25 \n\t"
877
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
878
+ "fmul.s f11, f11, %[RMAXREC] \n\t"
879
+ "fsw f10, (s1) \n\t"
880
+ "addi s1, s1, 4 \n\t"
881
+ "fsw f11, (s2) \n\t"
882
+ "addi s2, s2, 4 \n\t"
883
+ "fdiv.s f10, %[FONE], f10 \n\t"
884
+ "fdiv.s f11, %[FONE], f11 \n\t"
885
+ "vsetvli t0, zero, e32, m8 \n\t"
886
+ "vfmul.vf v16, v0, f10 \n\t"
887
+ "vfmul.vf v24, v8, f11 \n\t"
888
+ "vfcvt.x.f.v v16, v16 \n\t"
889
+ "vfcvt.x.f.v v24, v24 \n\t"
890
+ "vsetvli t0, zero, e16, m4 \n\t"
891
+ "vnclip.wx v16, v16, zero \n\t"
892
+ "vnclip.wx v24, v24, zero \n\t"
893
+ "vsetvli t0, t1, e8, m2 \n\t"
894
+ "vnclip.wx v16, v16, zero \n\t"
895
+ "vnclip.wx v24, v24, zero \n\t"
896
+ "vse8.v v16, (s1) \n\t"
897
+ "addi s1, s1, 132 \n\t"
898
+ "vse8.v v24, (s2) \n\t"
899
+ "addi s2, s2, 132 \n\t"
900
+ "bge %[K], t3, LOOP_MAIN%= \n\t"
901
+ "blt %[K], t2, TAIL%= \n\t"
902
+ "LOOP_K%=: \n\t"
903
+ "vsetvli t1, %[K], e32, m8 \n\t"
904
+ "vle32.v v0, (a1) \n\t"
905
+ "addi a1, a1, 256 \n\t"
906
+ "sub %[K], %[K], t1 \n\t"
907
+ "vfabs.v v16, v0 \n\t"
908
+ "vsetvli t0, zero, e32, m4 \n\t"
909
+ "vfmax.vv v16, v16, v20 \n\t"
910
+ "vsetvli t0, zero, e32, m2 \n\t"
911
+ "vfmax.vv v16, v16, v18 \n\t"
912
+ "vsetvli t0, zero, e32, m1 \n\t"
913
+ "vfmax.vv v16, v16, v17 \n\t"
914
+ "vfredmax.vs v17, v16, v17 \n\t"
915
+ "vfmv.f.s f10, v17 \n\t"
916
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
917
+ "fsw f10, (s1) \n\t"
918
+ "addi s1, s1, 4 \n\t"
919
+ "fdiv.s f11, %[FONE], f10 \n\t"
920
+ "vsetvli t0, zero, e32, m8 \n\t"
921
+ "vfmul.vf v16, v0, f11 \n\t"
922
+ "vfcvt.x.f.v v16, v16 \n\t"
923
+ "vsetvli t0, zero, e16, m4 \n\t"
924
+ "vnclip.wx v16, v16, zero \n\t"
925
+ "vsetvli t0, zero, e8, m2 \n\t"
926
+ "vnclip.wx v16, v16, zero \n\t"
927
+ "vse8.v v16, (s1) \n\t"
928
+ "addi s1, s1, 64 \n\t"
929
+ "bge %[K], t2, LOOP_K%= \n\t"
930
+ "TAIL%=: \n\t"
931
+ "blez %[K], END%= \n\t"
932
+ "vsetvli t0, t3, e32, m8 \n\t"
933
+ "vxor.vv v0, v0, v0 \n\t"
934
+ "vxor.vv v16, v16, v16 \n\t"
935
+ "jal x0, LOOP_K%= \n\t"
936
+ "END%=: \n\t"
937
+ : [K] "+r"(CountK)
938
+ : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal)
939
+ : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11");
940
+ } else if (BlkLen == 128) {
941
+ __asm__ volatile(
942
+ "addi t2, zero, 128 \n\t"
943
+ "addi a1, %[SRC], 0 \n\t"
944
+ "addi a2, %[SRC], 256 \n\t"
945
+ "blt %[K], t2, TAIL%= \n\t"
946
+ "LOOP_K%=: \n\t"
947
+ "vsetvli t1, zero, e32, m8 \n\t"
948
+ "vle32.v v0, (a1) \n\t"
949
+ "addi a1, a1, 512 \n\t"
950
+ "vle32.v v8, (a2) \n\t"
951
+ "addi a2, a2, 512 \n\t"
952
+ "sub %[K], %[K], t2 \n\t"
953
+ "QUANT%=: \n\t"
954
+ "vfabs.v v16, v0 \n\t"
955
+ "vfabs.v v24, v8 \n\t"
956
+ "vfmax.vv v24, v16, v24 \n\t"
957
+ "vsetvli t1, zero, e32, m4 \n\t"
958
+ "vfmax.vv v28, v24, v28 \n\t"
959
+ "vsetvli t0, zero, e32, m2 \n\t"
960
+ "vfmax.vv v30, v28, v30 \n\t"
961
+ "vsetvli t0, zero, e32, m1 \n\t"
962
+ "vfmax.vv v30, v30, v31 \n\t"
963
+ "vfredmax.vs v31, v30, v31 \n\t"
964
+ "vfmv.f.s f10, v31 \n\t"
965
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
966
+ "fsw f10, (%[DST]) \n\t"
967
+ "addi %[DST], %[DST], 4 \n\t"
968
+ "fdiv.s f11, %[FONE], f10 \n\t"
969
+ "vsetvli t0, zero, e32, m8 \n\t"
970
+ "vfmul.vf v16, v0, f11 \n\t"
971
+ "vfmul.vf v24, v8, f11 \n\t"
972
+ "vfcvt.x.f.v v16, v16 \n\t"
973
+ "vfcvt.x.f.v v24, v24 \n\t"
974
+ "vsetvli t0, zero, e16, m4 \n\t"
975
+ "vnclip.wx v16, v16, zero \n\t"
976
+ "vnclip.wx v20, v24, zero \n\t"
977
+ "vsetvli t0, zero, e8, m4 \n\t"
978
+ "vnclip.wx v16, v16, zero \n\t"
979
+ "vse8.v v16, (%[DST]) \n\t"
980
+ "addi %[DST], %[DST], 128 \n\t"
981
+ "bge %[K], t2, LOOP_K%= \n\t"
982
+ "TAIL%=: \n\t"
983
+ "blez %[K], END%= \n\t"
984
+ "vsetvli t1, zero, e32, m8 \n\t"
985
+ "vxor.vv v0, v0, v0 \n\t"
986
+ "vxor.vv v8, v8, v8 \n\t"
987
+ "vsetvli t0, %[K], e32, m8 \n\t"
988
+ "vle32.v v0, (a1) \n\t"
989
+ "sub %[K], %[K], t0 \n\t"
990
+ "vsetvli t0, %[K], e32, m8 \n\t"
991
+ "vle32.v v8, (a2) \n\t"
992
+ "sub %[K], %[K], t0 \n\t"
993
+ "vsetvli t1, zero, e32, m8 \n\t"
994
+ "jal x0, QUANT%= \n\t"
995
+ "END%=: \n\t"
996
+
997
+ : [DST] "+r"(DST), [K] "+r"(CountK)
998
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC)
999
+ : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11");
1000
+ } else {
1001
+ float buffer[8] = { 0.0f };
1002
+ size_t cnt = BlkLen / 256;
1003
+
1004
+ __asm__ volatile(
1005
+ "slli t3, %[BLK], 2 \n\t"
1006
+ "blt %[K], %[BLK], LOOP_TAIL%= \n\t"
1007
+ "LOOP_MAIN%=: \n\t"
1008
+ "vsetvli t0, zero, e32, m1 \n\t"
1009
+ "vxor.vv v31, v31, v31 \n\t"
1010
+ "vse32.v v31, (%[BUFFER]) \n\t"
1011
+ "addi t6, %[CNT], 0 \n\t"
1012
+ "LOOP_CMP%=: \n\t"
1013
+ "addi t6, t6, -1 \n\t"
1014
+ "vsetvli t0, zero, e32, m8 \n\t"
1015
+ "vle32.v v0, (%[SRC]) \n\t"
1016
+ "addi %[SRC], %[SRC], 256 \n\t"
1017
+ "vle32.v v8, (%[SRC]) \n\t"
1018
+ "addi %[SRC], %[SRC], 256 \n\t"
1019
+ "vle32.v v16, (%[SRC]) \n\t"
1020
+ "addi %[SRC], %[SRC], 256 \n\t"
1021
+ "vle32.v v24, (%[SRC]) \n\t"
1022
+ "addi %[SRC], %[SRC], 256 \n\t"
1023
+ "vfabs.v v0, v0 \n\t"
1024
+ "vfabs.v v8, v8 \n\t"
1025
+ "vfabs.v v16, v16 \n\t"
1026
+ "vfabs.v v24, v24 \n\t"
1027
+ "vfmax.vv v8, v0, v8 \n\t"
1028
+ "vfmax.vv v16, v16, v24 \n\t"
1029
+ "vfmax.vv v0, v0, v16 \n\t"
1030
+ "vsetvli t0, zero, e32, m4 \n\t"
1031
+ "vfmax.vv v0, v0, v4 \n\t"
1032
+ "vsetvli t0, zero, e32, m2 \n\t"
1033
+ "vfmax.vv v0, v0, v2 \n\t"
1034
+ "vsetvli t0, zero, e32, m1 \n\t"
1035
+ "vfmax.vv v0, v0, v1 \n\t"
1036
+ "vle32.v v30, (%[BUFFER]) \n\t"
1037
+ "vfmax.vv v31, v30, v0 \n\t"
1038
+ "vse32.v v31, (%[BUFFER]) \n\t"
1039
+ "bnez t6, LOOP_CMP%= \n\t"
1040
+ "sub %[SRC], %[SRC], t3 \n\t"
1041
+ "addi t6, %[CNT], 0 \n\t"
1042
+ "flw f0, (%[BUFFER]) \n\t"
1043
+ "flw f1, 4(%[BUFFER]) \n\t"
1044
+ "flw f2, 8(%[BUFFER]) \n\t"
1045
+ "flw f3, 12(%[BUFFER]) \n\t"
1046
+ "flw f4, 16(%[BUFFER]) \n\t"
1047
+ "flw f5, 20(%[BUFFER]) \n\t"
1048
+ "flw f6, 24(%[BUFFER]) \n\t"
1049
+ "flw f7, 28(%[BUFFER]) \n\t"
1050
+ "fmax.s f1, f0, f1 \n\t"
1051
+ "fmax.s f3, f2, f3 \n\t"
1052
+ "fmax.s f5, f4, f5 \n\t"
1053
+ "fmax.s f7, f6, f7 \n\t"
1054
+ "fmax.s f3, f1, f3 \n\t"
1055
+ "fmax.s f7, f5, f7 \n\t"
1056
+ "fmax.s f10, f3, f7 \n\t"
1057
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
1058
+ "fsw f10, (%[DST]) \n\t"
1059
+ "addi %[DST], %[DST], 4 \n\t"
1060
+ "fdiv.s f11, %[FONE], f10 \n\t"
1061
+ "addi t6, %[CNT], 0 \n\t"
1062
+ "LOOP_QUANT%=: \n\t"
1063
+ "addi t6, t6, -1 \n\t"
1064
+ "vsetvli t0, zero, e32, m8 \n\t"
1065
+ "vle32.v v0, (%[SRC]) \n\t"
1066
+ "addi %[SRC], %[SRC], 256 \n\t"
1067
+ "vle32.v v8, (%[SRC]) \n\t"
1068
+ "addi %[SRC], %[SRC], 256 \n\t"
1069
+ "vle32.v v16, (%[SRC]) \n\t"
1070
+ "addi %[SRC], %[SRC], 256 \n\t"
1071
+ "vle32.v v24, (%[SRC]) \n\t"
1072
+ "addi %[SRC], %[SRC], 256 \n\t"
1073
+ "vsetvli t0, zero, e32, m8 \n\t"
1074
+ "vfmul.vf v0, v0, f11 \n\t"
1075
+ "vfmul.vf v8, v8, f11 \n\t"
1076
+ "vfmul.vf v16, v16, f11 \n\t"
1077
+ "vfmul.vf v24, v24, f11 \n\t"
1078
+ "vfcvt.x.f.v v0, v0 \n\t"
1079
+ "vfcvt.x.f.v v8, v8 \n\t"
1080
+ "vfcvt.x.f.v v16, v16 \n\t"
1081
+ "vfcvt.x.f.v v24, v24 \n\t"
1082
+ "vsetvli t0, zero, e16, m4 \n\t"
1083
+ "vnclip.wx v0, v0, zero \n\t"
1084
+ "vnclip.wx v4, v8, zero \n\t"
1085
+ "vnclip.wx v8, v16, zero \n\t"
1086
+ "vnclip.wx v12, v24, zero \n\t"
1087
+ "vsetvli t0, zero, e8, m4 \n\t"
1088
+ "vnclip.wx v0, v0, zero \n\t"
1089
+ "vnclip.wx v4, v8, zero \n\t"
1090
+ "vse8.v v0, (%[DST]) \n\t"
1091
+ "addi %[DST], %[DST], 128 \n\t"
1092
+ "vse8.v v4, (%[DST]) \n\t"
1093
+ "addi %[DST], %[DST], 128 \n\t"
1094
+ "bnez t6, LOOP_QUANT%= \n\t"
1095
+ "sub %[K], %[K], %[BLK] \n\t"
1096
+ "bge %[K], %[BLK], LOOP_MAIN%= \n\t"
1097
+ "blez %[K], END%= \n\t"
1098
+ "LOOP_TAIL%=: \n\t"
1099
+ "vsetvli t0, zero, e32, m1 \n\t"
1100
+ "vxor.vv v31, v31, v31 \n\t"
1101
+ "vse32.v v31, (%[BUFFER]) \n\t"
1102
+ "addi t6, %[K], 0 \n\t"
1103
+ "addi s1, %[SRC], 0 \n\t"
1104
+ "TAIL_CMP%=: \n\t"
1105
+ "vsetvli t0, zero, e32, m8 \n\t"
1106
+ "vxor.vv v0, v0, v0 \n\t"
1107
+ "vsetvli t0, t6, e32, m8 \n\t"
1108
+ "vle32.v v0, (%[SRC]) \n\t"
1109
+ "addi %[SRC], %[SRC], 256 \n\t"
1110
+ "sub t6, t6, t0 \n\t"
1111
+ "vfabs.v v0, v0 \n\t"
1112
+ "vsetvli t0, zero, e32, m4 \n\t"
1113
+ "vfmax.vv v0, v0, v4 \n\t"
1114
+ "vsetvli t0, zero, e32, m2 \n\t"
1115
+ "vfmax.vv v0, v0, v2 \n\t"
1116
+ "vsetvli t0, zero, e32, m1 \n\t"
1117
+ "vfmax.vv v0, v0, v1 \n\t"
1118
+ "vle32.v v30, (%[BUFFER]) \n\t"
1119
+ "vfmax.vv v31, v30, v0 \n\t"
1120
+ "vse32.v v31, (%[BUFFER]) \n\t"
1121
+ "bnez t6, TAIL_CMP%= \n\t"
1122
+ "addi t6, %[K], 0 \n\t"
1123
+ "flw f0, (%[BUFFER]) \n\t"
1124
+ "flw f1, 4(%[BUFFER]) \n\t"
1125
+ "flw f2, 8(%[BUFFER]) \n\t"
1126
+ "flw f3, 12(%[BUFFER]) \n\t"
1127
+ "flw f4, 16(%[BUFFER]) \n\t"
1128
+ "flw f5, 20(%[BUFFER]) \n\t"
1129
+ "flw f6, 24(%[BUFFER]) \n\t"
1130
+ "flw f7, 28(%[BUFFER]) \n\t"
1131
+ "fmax.s f1, f0, f1 \n\t"
1132
+ "fmax.s f3, f2, f3 \n\t"
1133
+ "fmax.s f5, f4, f5 \n\t"
1134
+ "fmax.s f7, f6, f7 \n\t"
1135
+ "fmax.s f3, f1, f3 \n\t"
1136
+ "fmax.s f7, f5, f7 \n\t"
1137
+ "fmax.s f10, f3, f7 \n\t"
1138
+ "fmul.s f10, f10, %[RMAXREC] \n\t"
1139
+ "fsw f10, (%[DST]) \n\t"
1140
+ "addi %[DST], %[DST], 4 \n\t"
1141
+ "fdiv.s f11, %[FONE], f10 \n\t"
1142
+ "addi t6, %[K], 0 \n\t"
1143
+ "TAIL_QUANT%=: \n\t"
1144
+ "vsetvli t0, zero, e32, m8 \n\t"
1145
+ "vxor.vv v0, v0, v0 \n\t"
1146
+ "vsetvli t1, t6, e32, m8 \n\t"
1147
+ "vle32.v v0, (s1) \n\t"
1148
+ "addi s1, s1, 256 \n\t"
1149
+ "sub t6, t6, t1 \n\t"
1150
+ "vsetvli t0, zero, e32, m8 \n\t"
1151
+ "vfmul.vf v0, v0, f11 \n\t"
1152
+ "vfcvt.x.f.v v0, v0 \n\t"
1153
+ "vsetvli t0, zero, e16, m4 \n\t"
1154
+ "vnclip.wx v0, v0, zero \n\t"
1155
+ "vsetvli t0, t1, e8, m2 \n\t"
1156
+ "vnclip.wx v0, v0, zero \n\t"
1157
+ "vse8.v v0, (%[DST]) \n\t"
1158
+ "addi %[DST], %[DST], 64 \n\t"
1159
+ "bnez t6, TAIL_QUANT%= \n\t"
1160
+ "END%=: \n\t"
1161
+ : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK)
1162
+ : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer),
1163
+ [CNT] "r"(cnt)
1164
+ : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6");
1165
+ }
1166
+ }
1167
+
1168
+ } // namespace ime1
1169
+
1170
+ namespace {
1171
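+ // 4-bit-weight x 8-bit-activation tile kernels built on vmadot, the SpaceMIT
+ // IME matrix multiply-accumulate instruction (int8 inputs, int32 accumulators).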
+ #define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \
1172
+ "vmadot v16, v14, v0 \n\t" \
1173
+ "vmadot v18, v14, v1 \n\t" \
1174
+ "vmadot v20, v14, v2 \n\t" \
1175
+ "vmadot v22, v14, v3 \n\t" \
1176
+ "vmadot v16, v15, v4 \n\t" \
1177
+ "vmadot v18, v15, v5 \n\t" \
1178
+ "vmadot v20, v15, v6 \n\t" \
1179
+ "vmadot v22, v15, v7 \n\t"
1180
+
1181
+ #define SQ4BIT_KERNEL_ACC_1X4X4 \
1182
+ "vfcvt.f.x.v v16, v16 \n\t" \
1183
+ "vfcvt.f.x.v v18, v18 \n\t" \
1184
+ "vfcvt.f.x.v v20, v20 \n\t" \
1185
+ "vfcvt.f.x.v v22, v22 \n\t" \
1186
+ "addi s2, s1, 16 \n\t" \
1187
+ "addi s3, s1, 32 \n\t" \
1188
+ "addi s4, s1, 48 \n\t" \
1189
+ "addi s6, s5, 12 \n\t" \
1190
+ "vfmacc.vv v28, v16, v24 \n\t" \
1191
+ "vfmacc.vv v29, v18, v25 \n\t" \
1192
+ "vfmacc.vv v30, v20, v26 \n\t" \
1193
+ "vfmacc.vv v31, v22, v27 \n\t"
1194
+
1195
+ #define SQ4BIT_KERNEL_ACC_F16_1X4X4 \
1196
+ "vfcvt.f.x.v v16, v16 \n\t" \
1197
+ "vfcvt.f.x.v v18, v18 \n\t" \
1198
+ "vfcvt.f.x.v v20, v20 \n\t" \
1199
+ "vfcvt.f.x.v v22, v22 \n\t" \
1200
+ "addi s2, s1, 8 \n\t" \
1201
+ "addi s3, s1, 16 \n\t" \
1202
+ "addi s4, s1, 24 \n\t" \
1203
+ "addi s6, s5, 12 \n\t" \
1204
+ "vfmacc.vv v28, v16, v24 \n\t" \
1205
+ "vfmacc.vv v29, v18, v25 \n\t" \
1206
+ "vfmacc.vv v30, v20, v26 \n\t" \
1207
+ "vfmacc.vv v31, v22, v27 \n\t"
1208
+
1209
+ #define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \
1210
+ "vle8.v v4, (s1) \n\t" \
1211
+ "addi s1, s1, 128 \n\t" \
1212
+ "vle8.v v5, (s2) \n\t" \
1213
+ "addi s2, s2, 128 \n\t" \
1214
+ "vle8.v v6, (s3) \n\t" \
1215
+ "addi s3, s3, 128 \n\t" \
1216
+ "vle8.v v7, (s4) \n\t" \
1217
+ "addi s4, s4, 128 \n\t" \
1218
+ "vsetvli t0, zero, e8, mf4 \n\t" \
1219
+ "vle8.v v14, (s5) \n\t" \
1220
+ "addi s5, s5, 16 \n\t" \
1221
+ "vle8.v v15, (s6) \n\t" \
1222
+ "addi s6, s6, 16 \n\t" \
1223
+ "addi t5, t5, -1 \n\t" \
1224
+ "vsetvli t0, zero, e8, m1 \n\t" \
1225
+ "vand.vi v0, v4, 15 \n\t" \
1226
+ "vand.vi v1, v5, 15 \n\t" \
1227
+ "vand.vi v2, v6, 15 \n\t" \
1228
+ "vand.vi v3, v7, 15 \n\t" \
1229
+ "vsrl.vi v4, v4, 4 \n\t" \
1230
+ "vsrl.vi v5, v5, 4 \n\t" \
1231
+ "vsrl.vi v6, v6, 4 \n\t" \
1232
+ "vsrl.vi v7, v7, 4 \n\t"
1233
+
1234
+ #define SQ4BIT_KERNEL_LOAD_ZP_16X1 \
1235
+ "vsetvli t0, zero, e8, mf2 \n\t" \
1236
+ "vle8.v v1, (s7) \n\t" \
1237
+ "vsetvli t0, zero, e8, m1 \n\t" \
1238
+ "vrgather.vv v8, v1, v13 \n\t" \
1239
+ "vadd.vi v13, v13, 4 \n\t" \
1240
+ "vrgather.vv v9, v1, v13 \n\t" \
1241
+ "vadd.vi v13, v13, 4 \n\t" \
1242
+ "vrgather.vv v10, v1, v13 \n\t" \
1243
+ "vadd.vi v13, v13, 4 \n\t" \
1244
+ "vrgather.vv v11, v1, v13 \n\t" \
1245
+ "vadd.vi v13, v13, -12 \n\t"
1246
+
1247
+ // using for M4Kernel
1248
+ #define LOAD_B_16x8x2 \
1249
+ "vsetvli t0, zero, e8, m1 \n\t" \
1250
+ "vle8.v v6, (s1) \n\t" \
1251
+ "addi s1, s1, 32*4 \n\t" \
1252
+ "vle8.v v7, (s2) \n\t" \
1253
+ "addi s2, s2, 32*4 \n\t" \
1254
+ "vle8.v v8, (s3) \n\t" \
1255
+ "addi s3, s3, 32*4 \n\t" \
1256
+ "vle8.v v9, (s4) \n\t" \
1257
+ "addi s4, s4, 32*4 \n\t" \
1258
+ \
1259
+ "vand.vi v2, v6, 15 \n\t" \
1260
+ "vand.vi v3, v7, 15 \n\t" \
1261
+ "vand.vi v4, v8, 15 \n\t" \
1262
+ "vand.vi v5, v9, 15 \n\t" \
1263
+ \
1264
+ "vsrl.vi v6, v6, 4 \n\t" \
1265
+ "vsrl.vi v7, v7, 4 \n\t" \
1266
+ "vsrl.vi v8, v8, 4 \n\t" \
1267
+ "vsrl.vi v9, v9, 4 \n\t"
1268
+
1269
+ // [s2|s5, s3, s4, s6]
1270
+ #define LOAD_SCALE_4x16_FP16 \
1271
+ "addi s2, s5, -8 \n\t" \
1272
+ "addi s3, s5, 8 \n\t" \
1273
+ "addi s4, s5, 16 \n\t" \
1274
+ "addi s6, s5, 24 \n\t" \
1275
+ "li t1, 0xf0 \n\t" \
1276
+ "vmv.s.x v0, t1 \n\t" \
1277
+ "vsetvli t0, zero, e16, mf4 \n\t" \
1278
+ "vle16.v v9, (s5) \n\t" \
1279
+ "vle16.v v11, (s3) \n\t" \
1280
+ "vle16.v v13, (s4) \n\t" \
1281
+ "vle16.v v15, (s6) \n\t" \
1282
+ "vsetvli t0, zero, e16, mf2 \n\t" \
1283
+ "vle16.v v9, (s2), v0.t \n\t" \
1284
+ "vle16.v v11, (s5), v0.t \n\t" \
1285
+ "vle16.v v13, (s3), v0.t \n\t" \
1286
+ "vle16.v v15, (s4), v0.t \n\t" \
1287
+ "vfwcvt.f.f.v v8, v9 \n\t" \
1288
+ "vfwcvt.f.f.v v10, v11 \n\t" \
1289
+ "vfwcvt.f.f.v v12, v13 \n\t" \
1290
+ "vfwcvt.f.f.v v14, v15 \n\t" \
1291
+ "vsetvli t0, zero, e32, m1 \n\t" \
1292
+ "vmv.v.v v9, v8 \n\t" \
1293
+ "vmv.v.v v11, v10 \n\t" \
1294
+ "vmv.v.v v13, v12 \n\t" \
1295
+ "vmv.v.v v15, v14 \n\t" \
1296
+ "li t1, 0xf0 \n\t" \
1297
+ "vmv.s.x v0, t1 \n\t" \
1298
+ "vsetvli t0, zero, e32, mf2 \n\t" \
1299
+ "vfmul.vf v8, v8, f1 \n\t" \
1300
+ "vfmul.vf v10, v10, f1 \n\t" \
1301
+ "vfmul.vf v12, v12, f1 \n\t" \
1302
+ "vfmul.vf v14, v14, f1 \n\t" \
1303
+ "vfmul.vf v9, v9, f3 \n\t" \
1304
+ "vfmul.vf v11, v11, f3 \n\t" \
1305
+ "vfmul.vf v13, v13, f3 \n\t" \
1306
+ "vfmul.vf v15, v15, f3 \n\t" \
1307
+ "vsetvli t0, zero, e32, m1 \n\t" \
1308
+ "vfmul.vf v8, v8, f2, v0.t \n\t" \
1309
+ "vfmul.vf v10, v10, f2, v0.t \n\t" \
1310
+ "vfmul.vf v12, v12, f2, v0.t \n\t" \
1311
+ "vfmul.vf v14, v14, f2, v0.t \n\t" \
1312
+ "vfmul.vf v9, v9, f4, v0.t \n\t" \
1313
+ "vfmul.vf v11, v11, f4, v0.t \n\t" \
1314
+ "vfmul.vf v13, v13, f4, v0.t \n\t" \
1315
+ "vfmul.vf v15, v15, f4, v0.t \n\t"
1316
+
1317
+ // [s2|s5, s3, s4, s6]
1318
+ #define LOAD_SCALE_4x16 \
1319
+ "addi s2, s5, -16 \n\t" \
1320
+ "addi s3, s5, 16 \n\t" \
1321
+ "addi s4, s5, 32 \n\t" \
1322
+ "addi s6, s5, 48 \n\t" \
1323
+ "li t1, 0xf0 \n\t" \
1324
+ "vmv.s.x v0, t1 \n\t" \
1325
+ "vsetvli t0, zero, e32, mf2 \n\t" \
1326
+ "vle32.v v8, (s5) \n\t" \
1327
+ "vle32.v v10, (s3) \n\t" \
1328
+ "vle32.v v12, (s4) \n\t" \
1329
+ "vle32.v v14, (s6) \n\t" \
1330
+ "vsetvli t0, zero, e32, m1 \n\t" \
1331
+ "vle32.v v8, (s2), v0.t \n\t" \
1332
+ "vle32.v v10, (s5), v0.t \n\t" \
1333
+ "vle32.v v12, (s3), v0.t \n\t" \
1334
+ "vle32.v v14, (s4), v0.t \n\t" \
1335
+ "vmv.v.v v9, v8 \n\t" \
1336
+ "vmv.v.v v11, v10 \n\t" \
1337
+ "vmv.v.v v13, v12 \n\t" \
1338
+ "vmv.v.v v15, v14 \n\t" \
1339
+ "vsetvli t0, zero, e32, mf2 \n\t" \
1340
+ "vfmul.vf v8, v8, f1 \n\t" \
1341
+ "vfmul.vf v10, v10, f1 \n\t" \
1342
+ "vfmul.vf v12, v12, f1 \n\t" \
1343
+ "vfmul.vf v14, v14, f1 \n\t" \
1344
+ "vfmul.vf v9, v9, f3 \n\t" \
1345
+ "vfmul.vf v11, v11, f3 \n\t" \
1346
+ "vfmul.vf v13, v13, f3 \n\t" \
1347
+ "vfmul.vf v15, v15, f3 \n\t" \
1348
+ "vsetvli t0, zero, e32, m1 \n\t" \
1349
+ "vfmul.vf v8, v8, f2, v0.t \n\t" \
1350
+ "vfmul.vf v10, v10, f2, v0.t \n\t" \
1351
+ "vfmul.vf v12, v12, f2, v0.t \n\t" \
1352
+ "vfmul.vf v14, v14, f2, v0.t \n\t" \
1353
+ "vfmul.vf v9, v9, f4, v0.t \n\t" \
1354
+ "vfmul.vf v11, v11, f4, v0.t \n\t" \
1355
+ "vfmul.vf v13, v13, f4, v0.t \n\t" \
1356
+ "vfmul.vf v15, v15, f4, v0.t \n\t"
1357
+
1358
+ //[s1| BIAS, s2, s3, s4]
1359
+ #define LOAD_BIAS \
1360
+ "vsetvli t0, zero, e32, mf2 \n\t" \
1361
+ "li t1, 0xf0 \n\t" \
1362
+ "vmv.s.x v0, t1 \n\t" \
1363
+ "addi s1, %[BIAS], -16 \n\t" \
1364
+ "addi s2, %[BIAS], 16 \n\t" \
1365
+ "addi s3, %[BIAS], 32 \n\t" \
1366
+ "addi s4, %[BIAS], 48 \n\t" \
1367
+ \
1368
+ "vle32.v v24, (%[BIAS]) \n\t" \
1369
+ "vle32.v v26, (s2) \n\t" \
1370
+ "vle32.v v28, (s3) \n\t" \
1371
+ "vle32.v v30, (s4) \n\t" \
1372
+ "vsetvli t0, zero, e32, m1 \n\t" \
1373
+ "vle32.v v24, (s1), v0.t \n\t" \
1374
+ "vle32.v v26, (%[BIAS]), v0.t \n\t" \
1375
+ "vle32.v v28, (s2), v0.t \n\t" \
1376
+ "vle32.v v30, (s3), v0.t \n\t" \
1377
+ "vmv.v.v v25, v24 \n\t" \
1378
+ "vmv.v.v v27, v26 \n\t" \
1379
+ "vmv.v.v v29, v28 \n\t" \
1380
+ "vmv.v.v v31, v30 \n\t"
1381
+
1382
+ #define SQ4BIT_KERNEL_COMP_4x16x16 \
1383
+ "vmadot v16, v10, v2 \n\t" \
1384
+ "vmadot v18, v10, v3 \n\t" \
1385
+ "vmadot v20, v10, v4 \n\t" \
1386
+ "vmadot v22, v10, v5 \n\t" \
1387
+ "vmadot v16, v11, v6 \n\t" \
1388
+ "vmadot v18, v11, v7 \n\t" \
1389
+ "vmadot v20, v11, v8 \n\t" \
1390
+ "vmadot v22, v11, v9 \n\t"
1391
+
1392
+ #define SAVE_RESULT_4x16 \
1393
+ "addi a1, %[C], 0 \n\t" \
1394
+ "add a2, %[C], %[LDC] \n\t" \
1395
+ "add a3, a2, %[LDC] \n\t" \
1396
+ "add a4, a3, %[LDC] \n\t" \
1397
+ "addi a2, a2, -16 \n\t" \
1398
+ "addi a4, a4, -16 \n\t" \
1399
+ "li t1, 0xf0 \n\t" \
1400
+ "vmv.s.x v0, t1 \n\t" \
1401
+ "vsetvli t0, zero, e32, mf2 \n\t" \
1402
+ \
1403
+ "vse32.v v24, (a1) \n\t" \
1404
+ "addi a1, a1, 16 \n\t" \
1405
+ "vse32.v v25, (a3) \n\t" \
1406
+ "addi a3, a3, 16 \n\t" \
1407
+ \
1408
+ "vse32.v v26, (a1) \n\t" \
1409
+ "addi a1, a1, 16 \n\t" \
1410
+ "vse32.v v27, (a3) \n\t" \
1411
+ "addi a3, a3, 16 \n\t" \
1412
+ \
1413
+ "vse32.v v28, (a1) \n\t" \
1414
+ "addi a1, a1, 16 \n\t" \
1415
+ "vse32.v v29, (a3) \n\t" \
1416
+ "addi a3, a3, 16 \n\t" \
1417
+ \
1418
+ "vse32.v v30, (a1) \n\t" \
1419
+ "vse32.v v31, (a3) \n\t" \
1420
+ "vsetvli t0, zero, e32, m1 \n\t" \
1421
+ \
1422
+ "vse32.v v24, (a2), v0.t \n\t" \
1423
+ "addi a2, a2, 16 \n\t" \
1424
+ "vse32.v v25, (a4), v0.t \n\t" \
1425
+ "addi a4, a4, 16 \n\t" \
1426
+ \
1427
+ "vse32.v v26, (a2), v0.t \n\t" \
1428
+ "addi a2, a2, 16 \n\t" \
1429
+ "vse32.v v27, (a4), v0.t \n\t" \
1430
+ "addi a4, a4, 16 \n\t" \
1431
+ \
1432
+ "vse32.v v28, (a2), v0.t \n\t" \
1433
+ "addi a2, a2, 16 \n\t" \
1434
+ "vse32.v v29, (a4), v0.t \n\t" \
1435
+ "addi a4, a4, 16 \n\t" \
1436
+ \
1437
+ "vse32.v v30, (a2), v0.t \n\t" \
1438
+ "vse32.v v31, (a4), v0.t \n\t"
1439
+
1440
+ #define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \
1441
+ "vsetvli t0, zero, e8, mf2 \n\t" \
1442
+ "vle8.v v11, (s6) \n\t" \
1443
+ "vsetvli t0, zero, e8, m1 \n\t" \
1444
+ "vrgather.vv v12, v11, v1 \n\t" \
1445
+ "vadd.vi v1, v1, 4 \n\t" \
1446
+ "vrgather.vv v13, v11, v1 \n\t" \
1447
+ "vadd.vi v1, v1, 4 \n\t" \
1448
+ "vrgather.vv v14, v11, v1 \n\t" \
1449
+ "vadd.vi v1, v1, 4 \n\t" \
1450
+ "vrgather.vv v15, v11, v1 \n\t" \
1451
+ "vadd.vi v1, v1, -12 \n\t"
1452
+
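For orientation (not part of the diff): the LOAD_* macros above unpack two 4-bit weights from each byte, the low nibble with `vand.vi ..., 15` and the high nibble with `vsrl.vi ..., 4`; the kernels later re-center both with either the per-block zero point (`vsub.vv`) or the constant 8 (`vadd.vi ..., -8`). The scalar equivalent:

// Scalar sketch of the nibble unpack performed by LOAD_B_16x8x2 and
// SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 for every packed byte of B.
#include <cstdint>

static inline void unpack_nibbles(uint8_t packed, int & lo, int & hi) {
    lo = packed & 15;   // low 4 bits  (vand.vi v.., v.., 15)
    hi = packed >> 4;   // high 4 bits (vsrl.vi v.., v.., 4)
}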
1453
+ template <bool HasZeroPoint>
+ void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias,
+ const size_t ldc) {
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ size_t LDC = ldc * sizeof(float);
+ const size_t INNER = BlkLen / 16;
+ float tmp[4 * 16];
+
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+ __asm__ volatile(LOAD_BIAS
+
+ "addi t3, %[BlockCountK], 0 \n\t"
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 32 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 32 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ } else {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+ __asm__ volatile(LOAD_BIAS
+
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 32 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 32 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16_FP16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ }
+ if (CountN % 16 != 0) {
+ // store the output from tmp to C when NBLKS is less than 16.
+ float * CPtr = C + CountN / 16 * 16;
+ const size_t N = CountN % 16;
+ LDC = ldc * sizeof(float);
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi s2, %[SRC], 64 \n\t"
+ "addi s3, %[SRC], 64*2 \n\t"
+ "addi s4, %[SRC], 64*3 \n\t"
+ "vle32.v v2, (s2) \n\t"
+ "vle32.v v4, (s3) \n\t"
+ "vle32.v v6, (s4) \n\t"
+ "add t2, %[DST], %[LDC] \n\t"
+ "add t3, t2, %[LDC] \n\t"
+ "add t4, t3, %[LDC] \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ "vse32.v v2, (t2) \n\t"
+ "vse32.v v4, (t3) \n\t"
+ "vse32.v v6, (t4) \n\t"
+ :
+ : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
+ : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
+ }
+ }
+
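For reference (an assumption inferred from the pointer arithmetic above, not stated in the diff): B appears to be packed in groups of 16 columns, and each group stores, per K-block, the scales, then the zero points if present, then the 4-bit weight data. The per-group base offset both M4 kernels compute is:

// Sketch of the packed-B group offset; scale_size is sizeof(float) for the
// fp32-scale kernel or sizeof(_Float16) for the fp16-scale kernel.
#include <cstddef>
#include <cstdint>

static size_t packed_b_group_offset(size_t n, size_t BlockCountK, size_t BlkLen,
                                    bool has_zero_point, size_t scale_size) {
    size_t off = n * BlockCountK * BlkLen / 2;        // 4-bit weight data
    if (has_zero_point) {
        off += n * BlockCountK * sizeof(uint8_t);     // per-block zero points
    }
    off += n * BlockCountK * scale_size;              // per-block scales
    return off;
}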
1826
+ template <bool HasZeroPoint>
+ void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias,
+ const size_t ldc) {
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ size_t LDC = ldc * sizeof(float);
+ const size_t INNER = BlkLen / 16;
+ float tmp[4 * 16];
+
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+
+ __asm__ volatile(LOAD_BIAS
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 64 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "li s1, 24 \n\t"
+ "vmv.v.i v1, 3 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v1, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v1, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v1, 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ // scale offset
+ "addi s5, s1, 0 \n\t"
+ // zp offset
+ "addi s6, s1, 64 \n\t"
+ "addi s1, s6, 16 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1_v2
+
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vsub.vv v2, v2, v12 \n\t"
+ "vsub.vv v6, v6, v12 \n\t"
+ "vsub.vv v3, v3, v13 \n\t"
+ "vsub.vv v7, v7, v13 \n\t"
+ "vsub.vv v4, v4, v14 \n\t"
+ "vsub.vv v8, v8, v14 \n\t"
+ "vsub.vv v5, v5, v15 \n\t"
+ "vsub.vv v9, v9, v15 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ } else {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ if (NBLKS < 16) {
+ CPtr = tmp;
+ LDC = 16 * sizeof(float);
+ }
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ if (NBLKS < 16) {
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ :
+ : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS)
+ : "cc", "t0");
+ bias = tmp;
+ }
+ __asm__ volatile(LOAD_BIAS
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 64 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1",
+ "s2", "s3", "s4", "s5", "s6");
+
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v24, v24, v24 \n\t"
+ "addi t3, %[BlockCountK], 0 \n\t"
+ "addi a1, %[A], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "BLOCK_COUNTK_LOOP%=: \n\t"
+ "addi s5, s1, 0 \n\t"
+ "addi s1, s5, 64 \n\t"
+ "addi s2, s1, 32 \n\t"
+ "addi s3, s1, 32*2 \n\t"
+ "addi s4, s1, 32*3 \n\t"
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ // load a scale
+ "flw f1, (a1) \n\t"
+ "flw f2, 4(a1) \n\t"
+ "flw f3, 8(a1) \n\t"
+ "flw f4, 12(a1) \n\t"
+ "addi a1, a1, 16 \n\t"
+ "addi t2, %[INNER], 0 \n\t"
+ "BLOCK_INNER_LOOP%=: \n\t"
+
+ LOAD_B_16x8x2
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vle8.v v10, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vle8.v v11, (a1) \n\t"
+ "addi a1, a1, 32 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+ "vadd.vi v8, v8, -8 \n\t"
+ "vadd.vi v9, v9, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_4x16x16
+
+ "addi t2, t2, -1 \n\t"
+ "bnez t2, BLOCK_INNER_LOOP%= \n\t"
+
+ LOAD_SCALE_4x16
+
+ "vsetvli t0, zero, e32, m8 \n\t"
+ "vfcvt.f.x.v v16, v16 \n\t"
+ "vfmacc.vv v24, v16, v8 \n\t"
+ "addi t3, t3, -1 \n\t"
+ "bnez t3, BLOCK_COUNTK_LOOP%= \n\t"
+
+ "RESULT_SAVE%=: \n\t"
+
+ SAVE_RESULT_4x16
+
+ :
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC),
+ [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr)
+ : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3",
+ "s4", "s5", "s6");
+ }
+ }
+ }
+ if (CountN % 16 != 0) {
+ // store the output from tmp to C when NBLKS is less than 16.
+ float * CPtr = C + CountN / 16 * 16;
+ const size_t N = CountN % 16;
+ LDC = ldc * sizeof(float);
+ __asm__ volatile(
+ "vsetvli t0, %[N], e32, m2 \n\t"
+ "vle32.v v0, (%[SRC]) \n\t"
+ "addi s2, %[SRC], 64 \n\t"
+ "addi s3, %[SRC], 64*2 \n\t"
+ "addi s4, %[SRC], 64*3 \n\t"
+ "vle32.v v2, (s2) \n\t"
+ "vle32.v v4, (s3) \n\t"
+ "vle32.v v6, (s4) \n\t"
+ "add t2, %[DST], %[LDC] \n\t"
+ "add t3, t2, %[LDC] \n\t"
+ "add t4, t3, %[LDC] \n\t"
+ "vse32.v v0, (%[DST]) \n\t"
+ "vse32.v v2, (t2) \n\t"
+ "vse32.v v4, (t3) \n\t"
+ "vse32.v v6, (t4) \n\t"
+ :
+ : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC)
+ : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4");
+ }
+ }
+
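A note on the partial-tile path shared by both M4 kernels above: when fewer than 16 columns remain, the kernel writes into the 4x16 `tmp` scratch tile with a stride of 16 floats, then the trailing block copies only the valid columns to C with the real row stride. A scalar sketch of that copy (equivalent to the masked vle32/vse32 block, written with memcpy for clarity):

// Copy the first n_cols columns of each of the 4 scratch rows into C.
#include <cstddef>
#include <cstring>

static void copy_partial_tile(const float tmp[4 * 16], float * c,
                              size_t ldc /* in elements */, size_t n_cols) {
    for (size_t r = 0; r < 4; ++r) {
        std::memcpy(c + r * ldc, tmp + r * 16, n_cols * sizeof(float));
    }
}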
2197
+ template <bool HasZeroPoint>
+ void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias) {
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ size_t INNER = BlkLen / 16;
+
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+ // zp offset
+ "addi s7, %[B], 32 \n\t"
+ // a offset
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 48 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 72 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 120 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+ "addi s7, s1, 32 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+
+ "addi s7, %[B], 32 \n\t"
+
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 48 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 72 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 120 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+ "addi s7, s1, 32 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+ }
+ }
+ } else {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(_Float16); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 32 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 56 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 80 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 104 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 8 \n\t"
+ "addi s3, %[B], 16 \n\t"
+ "addi s4, %[B], 24 \n\t"
+
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "LOOP_K%=: \n\t"
+ "vsetvli t0, zero, e16, mf4 \n\t"
+ "vle16.v v4, (s1) \n\t"
+ "addi s1, s1, 32 \n\t"
+ "vle16.v v5, (s2) \n\t"
+ "addi s2, s2, 56 \n\t"
+ "vle16.v v6, (s3) \n\t"
+ "addi s3, s3, 80 \n\t"
+ "vle16.v v7, (s4) \n\t"
+ "addi s4, s4, 104 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "vfwcvt.f.f.v v8, v4 \n\t"
+ "vfwcvt.f.f.v v9, v5 \n\t"
+ "vfwcvt.f.f.v v10, v6 \n\t"
+ "vfwcvt.f.f.v v11, v7 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_F16_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ }
+ }
+ }
+ }
+
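For context (a reference sketch, not code from the diff): per output column, the M1 kernels accumulate, for every K-block, an int32 dot product of int8 activations against the re-centered 4-bit weights, scaled by the product of the activation scale and the per-block weight scale (the `vfmul.vf v24..v27` by `f1` above, followed by `vfcvt` and `vfmacc` in the ACC macros). Written out in scalar form, assuming the weights have already been unpacked and re-centered:

// Scalar reference for one output element of the M1 kernels.
#include <cstdint>
#include <cstddef>

static float m1_ref_dot(const int8_t * a, const float * a_scales,
                        const int8_t * b_q,       // unpacked, re-centered weights
                        const float * b_scales,   // per-block weight scales
                        size_t block_count_k, size_t blk_len) {
    float acc = 0.0f;
    for (size_t kb = 0; kb < block_count_k; ++kb) {
        int32_t dot = 0;                          // the vmadot accumulator
        for (size_t i = 0; i < blk_len; ++i) {
            dot += (int32_t) a[kb * blk_len + i] * b_q[kb * blk_len + i];
        }
        acc += (float) dot * a_scales[kb] * b_scales[kb];   // vfcvt.f.x.v + vfmacc.vv
    }
    return acc;
}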
2660
+ template <bool HasZeroPoint>
+ void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen,
+ const std::byte * QuantA,
+ const std::byte * QuantBData,
+ const float * QuantBScale,
+ const std::byte * QuantBZeroPoint,
+ float * C,
+ size_t CountN,
+ size_t BlockCountK,
+ const float * Bias) {
+ GGML_UNUSED(QuantBScale);
+ GGML_UNUSED(QuantBZeroPoint);
+ const size_t INNER = BlkLen / 16;
+ if constexpr (HasZeroPoint) {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(uint8_t) + // zp
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+
+ // scale offset: scale0.0, scale1.0, scale2.0, ..., scale15.0
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+ // zp offset
+ "addi s7, %[B], 64 \n\t"
+ // a offset
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "LOOP_K%=: \n\t"
+
+ // load scale
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 80 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 96 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 112 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 128 \n\t"
+
+ // load a scale
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+
+ // a scale * b scale
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+ "addi s7, s1, 64 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "vmv.v.i v13, 3 \n\t"
+ "li s1, 24 \n\t"
+ "vsetvli t0, s1, e8, m1 \n\t"
+ "vmv.v.i v13, 2 \n\t"
+ "vsetvli t0, zero, e8, mf2 \n\t"
+ "vmv.v.i v13, 1 \n\t"
+ "vsetvli t0, zero, e8, mf4 \n\t"
+ "vmv.v.i v13, 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+
+ "addi s7, %[B], 64 \n\t"
+
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ "LOOP_K%=: \n\t"
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 80 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 96 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 112 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 128 \n\t"
+
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+
+ SQ4BIT_KERNEL_LOAD_ZP_16X1
+
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vsub.vv v0, v0, v8 \n\t"
+ "vsub.vv v4, v4, v8 \n\t"
+ "vsub.vv v1, v1, v9 \n\t"
+ "vsub.vv v5, v5, v9 \n\t"
+ "vsub.vv v2, v2, v10 \n\t"
+ "vsub.vv v6, v6, v10 \n\t"
+ "vsub.vv v3, v3, v11 \n\t"
+ "vsub.vv v7, v7, v11 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+ "addi s7, s1, 64 \n\t"
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7");
+ }
+ }
+ } else {
+ for (size_t n = 0; n < CountN; n += 16) {
+ size_t nblks = (CountN - n) > 16 ? 16 : CountN - n;
+ std::byte * QuantBDataPtr = (std::byte *) QuantBData + //
+ n * BlockCountK * BlkLen / 2 + // b data
+ n * BlockCountK * sizeof(float); // scale
+ float * CPtr = C + n;
+ size_t cnt = BlockCountK;
+ if (Bias != nullptr) {
+ const float * bias = Bias + n;
+ __asm__ volatile(
+ "addi t3, %[NBLKS], 0 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v28, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v29, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v30, (%[BIAS]) \n\t"
+ "sub t3, t3, t0 \n\t"
+ "addi %[BIAS], %[BIAS], 16 \n\t"
+ "vsetvli t0, t3, e32, mf2 \n\t"
+ "vle32.v v31, (%[BIAS]) \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "LOOP_K%=: \n\t"
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 64 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 80 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 112 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "END%=: \n\t"
+
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias)
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
+ } else {
+ __asm__ volatile(
+ "vsetvli t0, zero, e32, m4 \n\t"
+ "vxor.vv v28, v28, v28 \n\t"
+ "addi s1, %[B], 0 \n\t"
+ "addi s2, %[B], 16 \n\t"
+ "addi s3, %[B], 32 \n\t"
+ "addi s4, %[B], 48 \n\t"
+
+ "addi s5, %[A], 0 \n\t"
+ "addi s6, %[A], 12 \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+ "LOOP_K%=: \n\t"
+ "vle32.v v8, (s1) \n\t"
+ "addi s1, s1, 64 \n\t"
+ "vle32.v v9, (s2) \n\t"
+ "addi s2, s2, 80 \n\t"
+ "vle32.v v10, (s3) \n\t"
+ "addi s3, s3, 96 \n\t"
+ "vle32.v v11, (s4) \n\t"
+ "addi s4, s4, 112 \n\t"
+ "flw f1, (s5) \n\t"
+ "addi s5, s5, 4 \n\t"
+
+ "addi t5, %[INNER], 0 \n\t"
+ "vxor.vv v16, v16, v16 \n\t"
+ "vxor.vv v18, v18, v18 \n\t"
+ "vxor.vv v20, v20, v20 \n\t"
+ "vxor.vv v22, v22, v22 \n\t"
+ "vfmul.vf v24, v8, f1 \n\t"
+ "vfmul.vf v25, v9, f1 \n\t"
+ "vfmul.vf v26, v10, f1 \n\t"
+ "vfmul.vf v27, v11, f1 \n\t"
+ "addi %[CNT], %[CNT], -1 \n\t"
+ "vsetvli t0, zero, e8, m1 \n\t"
+ "LOOP_INNER%=: \n\t"
+
+ SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4
+
+ "vadd.vi v0, v0, -8 \n\t"
+ "vadd.vi v1, v1, -8 \n\t"
+ "vadd.vi v2, v2, -8 \n\t"
+ "vadd.vi v3, v3, -8 \n\t"
+ "vadd.vi v4, v4, -8 \n\t"
+ "vadd.vi v5, v5, -8 \n\t"
+ "vadd.vi v6, v6, -8 \n\t"
+ "vadd.vi v7, v7, -8 \n\t"
+
+ SQ4BIT_KERNEL_COMP_1x8x2_4X8X4
+
+ "bnez t5, LOOP_INNER%= \n\t"
+ "vsetvli t0, zero, e32, mf2 \n\t"
+
+ SQ4BIT_KERNEL_ACC_1X4X4
+
+ "bnez %[CNT], LOOP_K%= \n\t"
+ "addi t3, zero, 16 \n\t"
+ "addi s1, %[C], 16 \n\t"
+ "addi s2, %[C], 32 \n\t"
+ "addi s3, %[C], 48 \n\t"
+ "blt %[NBLKS], t3, ST_TAIL%= \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vse32.v v29, (s1) \n\t"
+ "vse32.v v30, (s2) \n\t"
+ "vse32.v v31, (s3) \n\t"
+ "jal x0, END%= \n\t"
+
+ "ST_TAIL%=: \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
+ "vse32.v v28, (%[C]) \n\t"
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
3095
+ "vse32.v v30, (s2) \n\t"
3096
+ "vsetvli t0, %[NBLKS], e32, mf2 \n\t"
3097
+ "sub %[NBLKS], %[NBLKS], t0 \n\t"
3098
+ "vse32.v v31, (s3) \n\t"
3099
+ "END%=: \n\t"
3100
+
3101
+ : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks)
3102
+ : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr)
3103
+ : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6");
3104
+ }
3105
+ }
3106
+ }
3107
+ }
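
Note (editor's illustration, not part of the shipped diff): the eight `vadd.vi vN, vN, -8` instructions in the inner loops recenter the unpacked 4-bit weights. Each packed byte holds two offset-binary nibbles in 0..15; subtracting 8 maps them onto the signed range [-8, 7] that the int8 dot product expects. A scalar C++ sketch of the same transform, with a hypothetical helper name:

    #include <cstdint>

    // Scalar equivalent of the vectorized unpack + recenter step
    // (illustration only; not from the package source).
    inline void unpack_q4(uint8_t packed, int8_t & lo, int8_t & hi) {
        lo = static_cast<int8_t>(packed & 0x0F) - 8;  // low nibble  -> [-8, 7]
        hi = static_cast<int8_t>(packed >> 4) - 8;    // high nibble -> [-8, 7]
    }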
+
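Note (editor's illustration): every store epilogue above uses the same strip-mining idiom. The main path stores four full vectors when a whole 16-column tile remains; otherwise `ST_TAIL%=` repeatedly clamps the active lane count with `vsetvli t0, %[NBLKS], ...`, stores that many floats, and subtracts the granted count from NBLKS, so a zero-length store becomes a no-op once the tail is exhausted. A sketch using the standard RVV intrinsics, assuming a configuration where e32/mf2 yields four lanes per register, as the 16-byte column offsets suggest:

    #include <riscv_vector.h>

    // Hypothetical equivalent of the ST_TAIL%= path: store up to 16 floats
    // spread across four vector registers, writing exactly `nblks` in total.
    static void store_tail(float * c, vfloat32mf2_t v28, vfloat32mf2_t v29,
                           vfloat32mf2_t v30, vfloat32mf2_t v31, size_t nblks) {
        size_t vl;
        vl = __riscv_vsetvl_e32mf2(nblks);        // lanes = min(nblks, VLMAX)
        __riscv_vse32_v_f32mf2(c + 0, v28, vl);
        nblks -= vl;
        vl = __riscv_vsetvl_e32mf2(nblks);        // returns 0 once nblks is spent
        __riscv_vse32_v_f32mf2(c + 4, v29, vl);
        nblks -= vl;
        vl = __riscv_vsetvl_e32mf2(nblks);
        __riscv_vse32_v_f32mf2(c + 8, v30, vl);
        nblks -= vl;
        vl = __riscv_vsetvl_e32mf2(nblks);
        __riscv_vse32_v_f32mf2(c + 12, v31, vl);  // no-op when vl == 0
    }
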
+ template <bool HasZeroPoint>
+ inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
+                                                          const std::byte * QuantA,
+                                                          const std::byte * QuantBData,
+                                                          const float * QuantBScale,
+                                                          const std::byte * QuantBZeroPoint,
+                                                          float * C,
+                                                          size_t CountM,
+                                                          size_t CountN,
+                                                          size_t BlockStrideQuantB,
+                                                          const float * Bias,
+                                                          const size_t ldc,
+                                                          const size_t scalestride) {
+     if (scalestride == 4) {
+         SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
+                                                        CountN, BlockStrideQuantB, Bias, ldc);
+
+     } else if (scalestride == 2) {
+         SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(
+             BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc);
+     }
+ }
+
+ template <bool HasZeroPoint>
+ inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen,
+                                                          const std::byte * QuantA,
+                                                          const std::byte * QuantBData,
+                                                          const float * QuantBScale,
+                                                          const std::byte * QuantBZeroPoint,
+                                                          float * C,
+                                                          size_t CountM,
+                                                          size_t CountN,
+                                                          size_t BlockStrideQuantB,
+                                                          const float * Bias,
+                                                          const size_t ldc,
+                                                          const size_t scalestride) {
+     if (scalestride == 4) {
+         SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C,
+                                                        CountN, BlockStrideQuantB, Bias);
+     } else if (scalestride == 2) {
+         SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale,
+                                                                  QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias);
+     }
+ }
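
Note (editor's illustration): both dispatchers key off `scalestride`, the byte stride of the per-block scale packed alongside the B data. A stride of 4 selects the fp32-scale kernels (`*_Impl`), 2 selects the fp16-scale variants (`*_ScaleFp16_Impl`), and any other value falls through without computing. A hypothetical sketch of reading one scale under each layout; `_Float16` support is assumed from the RISC-V toolchain, and the helper name is invented:

    #include <cstddef>
    #include <cstring>

    // Illustration only: fetch one per-block scale as float for either layout.
    inline float load_block_scale(const std::byte * p, size_t scalestride) {
        if (scalestride == 4) {            // fp32 scale, as in *_Impl
            float s;
            std::memcpy(&s, p, sizeof(s));
            return s;
        }
        _Float16 h;                        // fp16 scale, as in *_ScaleFp16_Impl
        std::memcpy(&h, p, sizeof(h));
        return static_cast<float>(h);
    }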
+
+ }  // namespace
+
+ namespace ime1 {
+ size_t gemm_kernel_i8i4(size_t BlkLen,
+                         const std::byte * QuantA,
+                         const std::byte * QuantBData,
+                         const float * QuantBScale,
+                         const std::byte * QuantBZeroPoint,
+                         float * C,
+                         size_t CountM,
+                         size_t CountN,
+                         size_t CountK,
+                         size_t BlockCountK,
+                         size_t ldc,
+                         const float * Bias,
+                         const size_t ScaleStride) {
+     GGML_UNUSED(CountM);
+     GGML_UNUSED(CountK);
+     GGML_UNUSED(ldc);
+     if (CountM >= 4) {
+         if (QuantBZeroPoint != nullptr) {
+             SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
+                                                                C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
+         } else {
+             SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
+                                                                 QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
+                                                                 ldc, ScaleStride);
+         }
+         return 4;
+     } else {
+         if (QuantBZeroPoint != nullptr) {
+             SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint,
+                                                                C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride);
+         } else {
+             SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale,
+                                                                 QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias,
+                                                                 ldc, ScaleStride);
+         }
+         return 1;
+     }
+ }
+ }  // namespace ime1
+ }  // namespace sqnbitgemm_spacemit_ime
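
Note (editor's illustration): `gemm_kernel_i8i4` reports how many rows of A it consumed, 4 when `CountM >= 4` and 1 otherwise, so a caller can step through M by the return value. A hypothetical driver loop; the quantized-A row stride `lda_q` is an assumption for illustration, not part of this interface:

    // Illustration only; lda_q = bytes per quantized row of A (hypothetical).
    size_t m = 0;
    while (m < CountM) {
        size_t rows = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4(
            BlkLen, QuantA + m * lda_q, QuantBData, QuantBScale, QuantBZeroPoint,
            C + m * ldc, CountM - m, CountN, CountK, BlockCountK, ldc, Bias,
            ScaleStride);
        m += rows;  // advance by the rows the kernel actually processed
    }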