whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -3,19 +3,32 @@
3
3
 
4
4
  #include "ime.h"
5
5
 
6
+ #include "binary-ops.h"
7
+ #include "common.h"
6
8
  #include "ggml-backend-impl.h"
7
9
  #include "ggml-common.h"
8
10
  #include "ggml-cpu.h"
11
+ #include "ime_env.h"
9
12
  #include "ime_kernels.h"
13
+ #include "ops.h"
14
+ #include "repack.h"
15
+ #include "rvv_kernels.h"
16
+ #include "spine_mem_pool.h"
10
17
  #include "traits.h"
18
+ #include "vec.h"
19
+
20
+ #include <fcntl.h>
21
+ #include <sys/mman.h>
22
+ #include <unistd.h>
11
23
 
12
24
  #include <algorithm>
25
+ #include <atomic>
13
26
  #include <cassert>
27
+ #include <cerrno>
14
28
  #include <cmath>
15
29
  #include <cstdio> // for GGML_ASSERT
16
30
  #include <stdexcept>
17
31
  #include <thread>
18
-
19
32
  // clang-format off
20
33
  #if defined(__riscv)
21
34
 
@@ -25,13 +38,17 @@
25
38
  #include <riscv_vector.h>
26
39
  #endif
27
40
 
28
- #if !defined(__riscv_zfh)
29
- #error "riscv zfh extension not enabled"
41
+ #if !defined(__riscv_zfh) || !defined(__riscv_zvfh)
42
+ #error "riscv zfh extension not enabled, GGML_RV_ZFH and GGML_RV_ZVFH must be defined to 1"
30
43
  #endif
31
44
 
32
- #if defined(RISCV64_SPACEMIT_IME1)
45
+ #if !defined(__riscv_zba)
46
+ #error "riscv zba extension not enabled, GGML_RV_ZBA must be defined to 1"
47
+ #endif
48
+
49
+ #if defined(RISCV64_SPACEMIT_IME1) || defined(RISCV64_SPACEMIT_IME2)
33
50
  #else
34
- #error "RISCV64_SPACEMIT_IME1 not defined"
51
+ #error "RISCV64_SPACEMIT_IME1 or RISCV64_SPACEMIT_IME2 not defined"
35
52
  #endif
36
53
 
37
54
  #else
@@ -46,382 +63,490 @@
46
63
  #pragma GCC diagnostic ignored "-Wunused-parameter"
47
64
  #endif
48
65
 
49
- #if defined(RISCV64_SPACEMIT_IME1)
50
- #define QGEMM_STRIDEN_THREAD_ALIGN 16
51
- #else
52
- #define QGEMM_STRIDEN_THREAD_ALIGN 32
53
- #endif
54
-
55
66
  // clang-format on
56
67
 
57
- struct qnbitgemm_spacemit_ime_args {
58
- const float * a_ptr = nullptr;
59
- size_t lda = 0;
60
- const std::byte * packed_quant_b_data = nullptr;
61
- const float * quant_b_scale = nullptr;
62
- const void * quant_b_zp = nullptr;
63
- const float * quant_b_blksum = nullptr;
64
- const float * bias = nullptr;
65
- float * c_ptr = nullptr;
66
- size_t ldc = 0;
67
- };
68
-
69
- constexpr size_t div_round_up(size_t up, size_t down) {
70
- return (up + down - 1) / down;
71
- }
72
-
73
- constexpr size_t q8_blk_size(size_t blk_len) {
74
- const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t);
75
- // Currently, the strictest alignment requirement of a block is for a float.
76
- // Ensure contiguous blocks are suitably aligned.
77
- assert(blk_size % alignof(float) == 0);
78
- return blk_size;
68
+ extern "C" {
69
+ extern void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value);
70
+ extern int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value);
79
71
  }
80
72
 
81
73
  namespace ggml::cpu::riscv64_spacemit {
82
74
 
83
- const int num_ai_cores = std::thread::hardware_concurrency() / 2;
84
-
85
- } // namespace ggml::cpu::riscv64_spacemit
75
+ struct TLSContext {
76
+ int cpu_id{ -1 };
77
+ cpu_set_t cpuset;
78
+ void * tcm_buffer{ nullptr };
79
+ size_t tcm_buffer_size{ 0 };
80
+ };
86
81
 
87
- static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len,
88
- const size_t gemm_k,
89
- const qnbitgemm_spacemit_ime_args * gemm_args,
90
- void * const per_gemm_ws,
91
- const size_t m_start,
92
- const size_t m_count,
93
- const size_t n_start,
94
- const size_t n_count) {
95
- constexpr size_t scale_stride = sizeof(uint16_t);
96
- constexpr size_t blk_bitwidth = 4;
82
+ thread_local TLSContext tls_context;
83
+
84
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> constexpr size_t get_repacked_block_type_size() {
85
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) {
86
+ return sizeof(block_q8_0);
87
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) {
88
+ return sizeof(block_q4_0) * INTER_SIZE / QK4_0;
89
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_1> || std::is_same_v<BLOC_TYPE, block_q4_K>) {
90
+ return (sizeof(block_q4_0) + sizeof(uint8_t)) * INTER_SIZE / QK4_1;
91
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) {
92
+ return sizeof(spacemit_kernels::nrow_block_q2_k<1>);
93
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) {
94
+ return sizeof(spacemit_kernels::nrow_block_q3_k<1>);
95
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) {
96
+ return sizeof(spacemit_kernels::nrow_block_mxfp4<1>);
97
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K>) {
98
+ return sizeof(spacemit_kernels::nrow_block_q5_1<1>);
99
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_0>) {
100
+ return sizeof(spacemit_kernels::nrow_block_q5_0<1>);
101
+ } else {
102
+ assert(false);
103
+ return 0;
104
+ }
105
+ }
97
106
 
98
- const size_t k_blks = div_round_up(gemm_k, blk_len);
107
+ template <typename BLOC_TYPE> constexpr bool block_type_has_zp() {
108
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0> ||
109
+ std::is_same_v<BLOC_TYPE, block_q3_K> || std::is_same_v<BLOC_TYPE, block_q4_0> ||
110
+ std::is_same_v<BLOC_TYPE, block_mxfp4> || std::is_same_v<BLOC_TYPE, block_q5_0>) {
111
+ return false;
112
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_1> || std::is_same_v<BLOC_TYPE, block_q4_K> ||
113
+ std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q5_1> ||
114
+ std::is_same_v<BLOC_TYPE, block_q5_K>) {
115
+ return true;
116
+ } else {
117
+ assert(false);
118
+ return false;
119
+ }
120
+ }
99
121
 
100
- const size_t lda = k_blks * q8_blk_size(blk_len);
101
- const size_t ldc = gemm_args->ldc;
102
- const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8);
103
- const std::byte * quant_a_ptr = static_cast<const std::byte *>(per_gemm_ws) + m_start * lda;
122
+ class tensor_traits_base : public ggml::cpu::tensor_traits {
123
+ public:
124
+ virtual int repack(ggml_tensor * t, const void * data, size_t data_size) = 0;
125
+ };
104
126
 
105
- const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0;
106
- const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride);
107
- const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride;
127
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
128
+ bool work_size(int /* n_threads */, const ggml_tensor * op, size_t & size) override {
129
+ switch (op->op) {
130
+ case GGML_OP_MUL_MAT:
131
+ {
132
+ int64_t src1_nelements = ggml_nelements(op->src[1]);
133
+
134
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q3_K>) {
135
+ size =
136
+ spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K);
137
+ } else if constexpr (INTER_SIZE == QK4_0) {
138
+ size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) *
139
+ spacemit_kernels::q8_blk_size(QK4_0, true);
140
+ } else if constexpr (INTER_SIZE == 256) {
141
+ size = spacemit_kernels::div_round_up(src1_nelements, 256) *
142
+ spacemit_kernels::q8_hp_blk_size(256, true, true);
143
+ } else {
144
+ GGML_ABORT("unsupported block type");
145
+ }
108
146
 
109
- float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start;
147
+ size = GGML_PAD(size, sizeof(int64_t));
110
148
 
111
- size_t count_n = 0;
112
- const size_t compute_block_count_n = m_count == 1 ? n_count : 16;
113
- for (size_t n = 0; n < n_count; n += count_n) {
114
- count_n = std::min(n_count - n, compute_block_count_n);
149
+ return true;
150
+ }
151
+ case GGML_OP_MUL_MAT_ID:
152
+ {
153
+ int64_t src1_nelements = ggml_nelements(op->src[1]);
154
+
155
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q3_K>) {
156
+ size =
157
+ spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K);
158
+ } else if constexpr (INTER_SIZE == QK4_0) {
159
+ size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) *
160
+ spacemit_kernels::q8_blk_size(QK4_0, true);
161
+ } else if constexpr (INTER_SIZE == 256) {
162
+ size = spacemit_kernels::div_round_up(src1_nelements, 256) *
163
+ spacemit_kernels::q8_hp_blk_size(256, true, true);
164
+ } else {
165
+ GGML_ABORT("unsupported block type");
166
+ }
115
167
 
116
- const std::byte * a_row = quant_a_ptr;
117
- const std::byte * b_col = packed_quant_b_data + n * packed_b_stride;
118
- const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr;
119
- float * c_blk = c_ptr + n;
168
+ size = GGML_PAD(size, sizeof(int64_t));
120
169
 
121
- int32_t rows_remaining = m_count;
170
+ const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
171
+ const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
122
172
 
123
- while (rows_remaining > 0) {
124
- const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4(
125
- blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr,
126
- scale_stride);
173
+ const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
174
+ size += sizeof_mmid_row_mapping * ne02 * (ne12 + 1) + (ne02 + 1) * sizeof(int64_t);
127
175
 
128
- c_blk += rows_handled * ldc;
129
- a_row += rows_handled * lda;
176
+ size = GGML_PAD(size, sizeof(int64_t));
130
177
 
131
- rows_remaining -= rows_handled;
178
+ return true;
179
+ }
180
+ default:
181
+ // GGML_ABORT("fatal error");
182
+ break;
132
183
  }
184
+ return false;
133
185
  }
134
- }
135
186
 
136
- template <int K> constexpr int QK_0() {
137
- if constexpr (K == 4) {
138
- return QK4_0;
139
- }
140
- if constexpr (K == 8) {
141
- return QK8_0;
187
+ bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override {
188
+ switch (op->op) {
189
+ case GGML_OP_MUL_MAT:
190
+ switch (op->src[0]->type) {
191
+ case GGML_TYPE_Q2_K:
192
+ case GGML_TYPE_Q3_K:
193
+ case GGML_TYPE_Q4_0:
194
+ case GGML_TYPE_Q4_1:
195
+ case GGML_TYPE_Q4_K:
196
+ case GGML_TYPE_Q6_K:
197
+ case GGML_TYPE_Q8_0:
198
+ case GGML_TYPE_Q5_1:
199
+ case GGML_TYPE_Q5_K:
200
+ //case GGML_TYPE_MXFP4:
201
+ forward_mul_mat(params, op);
202
+ return true;
203
+ default:
204
+ // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT");
205
+ return false;
206
+ }
207
+ break;
208
+ case GGML_OP_MUL_MAT_ID:
209
+ switch (op->src[0]->type) {
210
+ case GGML_TYPE_Q2_K:
211
+ case GGML_TYPE_Q3_K:
212
+ case GGML_TYPE_Q4_0:
213
+ case GGML_TYPE_Q4_1:
214
+ case GGML_TYPE_Q4_K:
215
+ case GGML_TYPE_Q6_K:
216
+ case GGML_TYPE_Q8_0:
217
+ case GGML_TYPE_Q5_1:
218
+ case GGML_TYPE_Q5_K:
219
+ //case GGML_TYPE_MXFP4:
220
+ forward_mul_mat_id(params, op);
221
+ return true;
222
+ default:
223
+ // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT_ID");
224
+ return false;
225
+ }
226
+ break;
227
+ default:
228
+ // GGML_ABORT("fatal error");
229
+ break;
230
+ }
231
+ return false;
142
232
  }
143
- return -1;
144
- }
145
233
 
146
- template <int K, int N> struct block {
147
- ggml_half d[N]; // deltas for N qK_0 blocks
148
- uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
149
- };
234
+ void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
235
+ constexpr size_t a_blk_len = INTER_SIZE;
236
+ constexpr size_t b_blk_len = INTER_SIZE;
150
237
 
151
- template <int K, int N> struct block_with_zp {
152
- ggml_half d[N]; // deltas for N qK_1 blocks
153
- uint8_t zp[N]; // zero points for N qK_1 blocks
154
- uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_1 blocks
155
- };
238
+ const ggml_tensor * src0 = op->src[0];
239
+ const ggml_tensor * src1 = op->src[1];
240
+ ggml_tensor * dst = op;
156
241
 
157
- // control size
158
- static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding");
159
- static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t),
160
- "wrong block_with_zp<4,16> size/padding");
161
- static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding");
242
+ GGML_TENSOR_BINARY_OP_LOCALS
162
243
 
163
- using block_q4_0x16 = block<4, 16>;
164
- using block_q4_1x16 = block_with_zp<4, 16>;
165
- using block_q8_0x16 = block<8, 16>;
244
+ int ith = params->ith;
245
+ int nth = params->nth;
166
246
 
167
- static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) {
168
- block_q4_0x16 out;
169
- GGML_ASSERT(QK4_0 / blck_size_interleave == 2);
247
+ [[maybe_unused]] const enum ggml_type type = src0->type;
170
248
 
171
- for (int i = 0; i < 16; i++) {
172
- out.d[i] = in[i].d;
173
- }
249
+ void * w_data = (void *) src0->data;
250
+ const float * feature = (const float *) src1->data;
251
+ float * output = (float *) dst->data;
174
252
 
175
- for (int i = 0; i < 16; i++) {
176
- // [0, 15], in.d & 0x0F
177
- for (int j = 0; j < QK4_0 / 4; j++) {
178
- //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
179
- //dst [b0 b8] ......... [b7 b15]
180
- out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4);
253
+ const int64_t gemm_m = ne11 * ne12 * ne13;
254
+ const int64_t gemm_k = ne10;
255
+ const int64_t gemm_n = ne01;
256
+
257
+ spacemit_kernels::quantize_a_row_def quantize_a_row_i8;
258
+ spacemit_kernels::quantize_a_row_def quantize_a_4row_i8;
259
+ spacemit_kernels::gemm_kernel_quantize_def gemm_kernel;
260
+ bool set_kernel_impl = false;
261
+
262
+ int64_t block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len);
263
+
264
+ #if defined(RISCV64_SPACEMIT_IME2)
265
+ if (!set_kernel_impl && (global_spine_env_info.use_ime2)) {
266
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
267
+ quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8;
268
+ block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true);
269
+
270
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) {
271
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8;
272
+ set_kernel_impl = true;
273
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
274
+ std::is_same_v<BLOC_TYPE, block_q4_K>) {
275
+ if constexpr (INTER_SIZE == 256) {
276
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp;
277
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp;
278
+ quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8_hp;
279
+ block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true);
280
+ set_kernel_impl = true;
281
+ } else {
282
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4;
283
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
284
+ quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8;
285
+ block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true);
286
+ set_kernel_impl = true;
287
+ }
288
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) {
289
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
290
+ quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k;
291
+ block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
292
+
293
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k;
294
+ set_kernel_impl = true;
295
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) {
296
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
297
+ quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k;
298
+ block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
299
+
300
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k;
301
+ set_kernel_impl = true;
302
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) {
303
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4;
304
+ set_kernel_impl = true;
305
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K> ||
306
+ std::is_same_v<BLOC_TYPE, block_q5_0>) {
307
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5;
308
+ set_kernel_impl = true;
309
+ }
181
310
  }
182
- }
311
+ #endif
183
312
 
184
- for (int i = 0; i < 16; i++) {
185
- // [16, 31], in.d & 0xF0
186
- for (int j = 0; j < QK4_0 / 4; j++) {
187
- //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
188
- //dst [b16 b24] ......... [b23 b31]
189
- out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0);
313
+ #if defined(RISCV64_SPACEMIT_IME1)
314
+ if (!set_kernel_impl && (global_spine_env_info.use_ime1)) {
315
+ quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8;
316
+ quantize_a_4row_i8 = spacemit_kernels::ime1::quantize_a_4row_i8;
317
+
318
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
319
+ std::is_same_v<BLOC_TYPE, block_q4_K>) {
320
+ gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4;
321
+ set_kernel_impl = true;
322
+ }
323
+ }
324
+ #endif
325
+ if (!set_kernel_impl) {
326
+ GGML_ABORT("no kernel implementation found for the block type");
190
327
  }
191
- }
192
328
 
193
- return out;
194
- }
329
+ const int64_t a_k_blks = spacemit_kernels::div_round_up(gemm_k, a_blk_len);
330
+ const int64_t b_k_blks = spacemit_kernels::div_round_up(gemm_k, b_blk_len);
195
331
 
196
- static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) {
197
- block_q4_1x16 out;
198
- GGML_ASSERT(QK4_1 / blck_size_interleave == 2);
199
-
200
- for (int i = 0; i < 16; i++) {
201
- float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
202
- float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m);
203
- float mid = -std::nearbyintf(m / d);
204
- mid = std::min(15.0f, std::max(0.0f, mid));
205
- out.d[i] = GGML_FP32_TO_FP16(d);
206
- out.zp[i] = static_cast<uint8_t>(mid);
207
- }
332
+ const int64_t row_stride_a = a_k_blks * block_stride_a;
333
+ const int64_t gemm_workspace_size = GGML_PAD(gemm_m * row_stride_a, alignof(int64_t));
208
334
 
209
- for (int i = 0; i < 16; i++) {
210
- // [0, 15], in.d & 0x0F
211
- for (int j = 0; j < QK4_1 / 4; j++) {
212
- //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
213
- //dst [b0 b8] ......... [b7 b15]
214
- out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4);
335
+ if (ith == 0 && params->wsize < gemm_workspace_size) {
336
+ GGML_ABORT("wsize less than gemm_workspace_size");
215
337
  }
216
- }
217
338
 
218
- for (int i = 0; i < 16; i++) {
219
- // [16, 31], in.d & 0xF0
220
- for (int j = 0; j < QK4_1 / 4; j++) {
221
- //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
222
- //dst [b16 b24] ......... [b23 b31]
223
- out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0);
224
- }
225
- }
339
+ uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
226
340
 
227
- return out;
228
- }
341
+ void * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer;
342
+ const int64_t tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size;
229
343
 
230
- static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t,
231
- int interleave_block,
232
- const void * GGML_RESTRICT data,
233
- size_t data_size) {
234
- GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
235
- GGML_ASSERT(interleave_block == 16);
344
+ auto * quant_a_buffer = reinterpret_cast<uint8_t *>(ws_ptr);
236
345
 
237
- constexpr int nrows_interleaved = 16;
346
+ constexpr int64_t row_align = 4;
347
+ const int64_t row_blks = spacemit_kernels::div_round_up(gemm_m, row_align);
238
348
 
239
- block_q4_0x16 * dst = (block_q4_0x16 *) t->data;
240
- const block_q4_0 * src = (const block_q4_0 *) data;
241
- block_q4_0 dst_tmp[16];
242
- int nrow = ggml_nrows(t);
243
- int nblocks = t->ne[0] / QK4_0;
349
+ const int64_t row_stride_b = b_k_blks * get_repacked_block_type_size<BLOC_TYPE, INTER_SIZE, NB_COLS>();
350
+ const int64_t per_mb_rows_wsize = row_align * row_stride_a;
351
+ const int64_t per_nb_cols_wsize = NB_COLS * row_stride_b;
244
352
 
245
- GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
353
+ const int64_t barrier_idx = static_cast<int64_t>(ith / 2);
246
354
 
247
- if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) {
248
- return -1;
249
- }
355
+ GGML_ASSERT(global_spine_env_info.init_barrier != nullptr);
356
+ GGML_ASSERT(barrier_idx < spine_init_barrier_count);
357
+ spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx];
250
358
 
251
- for (int b = 0; b < nrow; b += nrows_interleaved) {
252
- for (int64_t x = 0; x < nblocks; x++) {
253
- for (int i = 0; i < nrows_interleaved; i++) {
254
- dst_tmp[i] = src[x + i * nblocks];
359
+ if (gemm_m == 1) {
360
+ int task_per_thread = spacemit_kernels::div_round_up(a_k_blks, nth);
361
+ int a_blk_start = ith * task_per_thread;
362
+ int a_blk_end = std::min(a_blk_start + task_per_thread, (int) a_k_blks);
363
+ if (a_blk_start < a_blk_end) {
364
+ quantize_a_row_i8(a_blk_len, feature + a_blk_start * a_blk_len, (a_blk_end - a_blk_start) * a_blk_len,
365
+ quant_a_buffer + a_blk_start * block_stride_a);
366
+ }
367
+ } else {
368
+ int task_per_thread = spacemit_kernels::div_round_up(row_blks, nth);
369
+ int m_row_blk_start = ith * task_per_thread;
370
+ int m_row_blk_end = std::min(m_row_blk_start + task_per_thread, (int) row_blks);
371
+ for (int m_row_blk = m_row_blk_start; m_row_blk < m_row_blk_end; m_row_blk++) {
372
+ int m_idx = m_row_blk * row_align;
373
+ int rows_tobe_handled = (gemm_m - m_idx) > row_align ? row_align : (gemm_m - m_idx);
374
+
375
+ if (rows_tobe_handled == row_align && quantize_a_4row_i8 != nullptr) {
376
+ const float * a_row_ptr = feature + m_idx * gemm_k;
377
+ auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a;
378
+ quantize_a_4row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr);
379
+ } else {
380
+ while (rows_tobe_handled) {
381
+ const float * a_row_ptr = feature + m_idx * gemm_k;
382
+ auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a;
383
+ quantize_a_row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr);
384
+ rows_tobe_handled -= 1;
385
+ m_idx += 1;
386
+ }
387
+ }
255
388
  }
256
- *dst++ = make_block_q4_0x16(dst_tmp, interleave_block);
257
389
  }
258
- src += nrows_interleaved * nblocks;
259
- }
260
- return 0;
261
390
 
262
- GGML_UNUSED(data_size);
263
- }
391
+ ggml_barrier(params->threadpool);
264
392
 
265
- static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t,
266
- int interleave_block,
267
- const void * GGML_RESTRICT data,
268
- size_t data_size) {
269
- GGML_ASSERT(t->type == GGML_TYPE_Q4_1);
270
- GGML_ASSERT(interleave_block == 16);
393
+ const int64_t gemm_m_stride = gemm_n / gemm_m > 64 ? gemm_m : 16;
394
+ const int64_t gemm_m_blocked = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride);
395
+ const int64_t max_gemm_n_stride = spacemit_kernels::div_round_up(gemm_n * gemm_m_blocked, nth);
271
396
 
272
- constexpr int nrows_interleaved = 16;
397
+ int64_t gemm_n_stride = gemm_n;
398
+ if (max_gemm_n_stride < gemm_n) {
399
+ gemm_n_stride =
400
+ std::min(gemm_n_stride, spacemit_kernels::div_round_up(max_gemm_n_stride, NB_COLS) * NB_COLS);
401
+ }
273
402
 
274
- block_q4_1x16 * dst = (block_q4_1x16 *) t->data;
275
- const block_q4_1 * src = (const block_q4_1 *) data;
276
- block_q4_1 dst_tmp[16];
277
- int nrow = ggml_nrows(t);
278
- int nblocks = t->ne[0] / QK4_1;
403
+ if (gemm_n_stride == gemm_n && tcm_buffer != nullptr && per_mb_rows_wsize <= tcm_buffer_size) {
404
+ for (int64_t m_start = ith * row_align; m_start < gemm_m; m_start += row_align * nth) {
405
+ uint8_t * b_col = reinterpret_cast<uint8_t *>(w_data);
406
+ uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
279
407
 
280
- GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1));
408
+ int64_t m_row_real = std::min(gemm_m - m_start, row_align);
281
409
 
282
- if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) {
283
- return -1;
284
- }
410
+ spacemit_kernels::rvv::memcpy1d(tcm_buffer, quant_a_buffer + m_start * row_stride_a,
411
+ m_row_real * row_stride_a);
285
412
 
286
- for (int b = 0; b < nrow; b += nrows_interleaved) {
287
- for (int64_t x = 0; x < nblocks; x++) {
288
- for (int i = 0; i < nrows_interleaved; i++) {
289
- dst_tmp[i] = src[x + i * nblocks];
413
+ int64_t n_blk_real = 0;
414
+ for (int64_t ni = 0; ni < gemm_n; ni += n_blk_real, b_col += n_blk_real * row_stride_b) {
415
+ n_blk_real = std::min(gemm_n - ni, (int64_t) NB_COLS);
416
+
417
+ uint8_t * a_row_ptr = (uint8_t *) tcm_buffer;
418
+ float * c_blk = output + m_start * gemm_n + ni;
419
+
420
+ int32_t rows_remaining = m_row_real;
421
+
422
+ while (rows_remaining > 0) {
423
+ auto rows_handled = gemm_kernel(b_blk_len, a_row_ptr, b_col, b_col_zp, c_blk, rows_remaining,
424
+ n_blk_real, b_k_blks, gemm_n);
425
+
426
+ c_blk += rows_handled * gemm_n;
427
+ a_row_ptr += rows_handled * row_stride_a;
428
+
429
+ rows_remaining -= rows_handled;
430
+ }
431
+ }
290
432
  }
291
- *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
292
- }
293
- src += nrows_interleaved * nblocks;
294
- }
295
- return 0;
433
+ } else if (tcm_buffer != nullptr && per_nb_cols_wsize <= tcm_buffer_size) {
434
+ uint8_t * a_row = quant_a_buffer;
435
+ uint8_t * b_col = reinterpret_cast<uint8_t *>(tcm_buffer);
436
+ if ((gemm_workspace_size + per_nb_cols_wsize) <= tcm_buffer_size) {
437
+ a_row = (uint8_t *) tcm_buffer;
438
+ b_col = reinterpret_cast<uint8_t *>(tcm_buffer) + gemm_workspace_size;
439
+ }
440
+ uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
296
441
 
297
- GGML_UNUSED(data_size);
298
- }
442
+ int64_t ni = ith * NB_COLS;
443
+ int64_t nb_real = std::min(gemm_n - ni, NB_COLS);
299
444
 
300
- static inline void get_scale_min_k4(int j,
301
- const uint8_t * GGML_RESTRICT q,
302
- uint8_t * GGML_RESTRICT d,
303
- uint8_t * GGML_RESTRICT m) {
304
- if (j < 4) {
305
- *d = q[j] & 63;
306
- *m = q[j + 4] & 63;
307
- } else {
308
- *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
309
- *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
310
- }
311
- }
445
+ if (ith % 2 == 0 && nb_real > 0) {
446
+ spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + ni * row_stride_b,
447
+ nb_real * row_stride_b);
448
+ if (a_row != quant_a_buffer) {
449
+ spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size);
450
+ }
451
+ }
312
452
 
313
- static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t,
314
- int interleave_block,
315
- const void * GGML_RESTRICT data,
316
- size_t data_size) {
317
- GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
318
- GGML_ASSERT(interleave_block == 16);
319
- GGML_ASSERT(QK_K / QK4_1 == 8);
453
+ spine_barrier_wait(cur_barrier);
320
454
 
321
- constexpr int nrows_interleaved = 16;
455
+ if (ith % 2 != 0 && nb_real > 0) {
456
+ if (a_row != quant_a_buffer) {
457
+ spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size);
458
+ }
459
+ spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + ni * row_stride_b,
460
+ nb_real * row_stride_b);
461
+ }
322
462
 
323
- block_q4_1x16 * dst = (block_q4_1x16 *) t->data;
324
- const block_q4_K * src = (const block_q4_K *) data;
325
- block_q4_1 dst_tmp[16];
326
- int nrow = ggml_nrows(t);
327
- int nblocks = t->ne[0] / QK_K;
463
+ for (; ni < gemm_n; ni += NB_COLS * nth) {
464
+ int64_t rows_remaining = gemm_m;
465
+ float * c_blk = output + ni;
466
+ auto * a_row_cur = a_row;
328
467
 
329
- if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) {
330
- return -1;
331
- }
468
+ if (ith % 2 != 0) {
469
+ spine_barrier_wait(cur_barrier);
470
+ }
332
471
 
333
- for (int b = 0; b < nrow; b += nrows_interleaved) {
334
- for (int64_t x = 0; x < nblocks; x++) {
335
- for (int j = 0; j < 8; j++) {
336
- for (int i = 0; i < nrows_interleaved; i++) {
337
- uint8_t sc, m;
338
- const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
339
- const float min =
340
- GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin);
341
- get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m);
342
- const float d1 = d * sc;
343
- const float m1 = min * m;
344
-
345
- dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1);
346
- dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1);
347
- // src -> [b0, b32] [b1, b33] ... [b31, b63]
348
- // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63]
349
- const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1;
350
- if (j % 2 == 0) {
351
- for (int ii = 0; ii < 16; ii++) {
352
- dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4);
353
- }
354
- } else {
355
- for (int ii = 0; ii < 16; ii++) {
356
- dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0);
357
- }
358
- }
472
+ while (rows_remaining > 0) {
473
+ auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining,
474
+ nb_real, b_k_blks, gemm_n);
475
+
476
+ c_blk += rows_handled * gemm_n;
477
+ a_row_cur += rows_handled * row_stride_a;
478
+
479
+ rows_remaining -= rows_handled;
480
+ }
481
+
482
+ if (ith % 2 == 0) {
483
+ spine_barrier_wait(cur_barrier);
484
+ }
485
+
486
+ const int64_t next_ni = ni + NB_COLS * nth;
487
+ if (next_ni < gemm_n) {
488
+ nb_real = std::min(gemm_n - next_ni, NB_COLS);
489
+ spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + next_ni * row_stride_b,
490
+ nb_real * row_stride_b);
359
491
  }
360
- *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
361
492
  }
362
- }
363
- src += nrows_interleaved * nblocks;
364
- }
365
- return 0;
493
+ } else {
494
+ const int64_t task_count_m = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride);
495
+ const int64_t task_count_n = spacemit_kernels::div_round_up(gemm_n, gemm_n_stride);
366
496
 
367
- GGML_UNUSED(data_size);
368
- }
497
+ int64_t task_count = task_count_m * task_count_n;
498
+ int64_t task_per_thread = (task_count + nth - 1) / nth;
499
+ int64_t start = ith * task_per_thread;
500
+ int64_t end = std::min((ith + 1) * task_per_thread, task_count);
501
+ for (int64_t compute_idx = start; compute_idx < end; compute_idx++) {
502
+ const auto tid_n = compute_idx / task_count_m;
503
+ const auto tid_m = compute_idx % task_count_m;
369
504
 
370
- namespace ggml::cpu::riscv64_spacemit {
505
+ const int64_t m_start = tid_m * gemm_m_stride;
506
+ const int64_t m_count = std::min(gemm_m - m_start, (int64_t) gemm_m_stride);
371
507
 
372
- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
373
- int repack(struct ggml_tensor *, const void *, size_t);
508
+ const int64_t n_start = tid_n * gemm_n_stride;
509
+ const int64_t n_count = std::min(gemm_n - n_start, (int64_t) gemm_n_stride);
374
510
 
375
- template <> int repack<block_q4_0, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
376
- return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size);
377
- }
511
+ const int64_t n_blk = m_count == 1 ? n_count : NB_COLS;
378
512
 
379
- template <> int repack<block_q4_1, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
380
- return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size);
381
- }
513
+ uint8_t * b_col = reinterpret_cast<uint8_t *>(w_data) + n_start * row_stride_b;
514
+ uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
382
515
 
383
- template <> int repack<block_q4_K, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
384
- return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size);
385
- }
516
+ int64_t n_blk_real = 0;
517
+ for (int64_t ni = 0; ni < n_count; ni += n_blk_real, b_col += n_blk_real * row_stride_b) {
518
+ n_blk_real = std::min(n_count - ni, n_blk);
386
519
 
387
- class tensor_traits_base : public ggml::cpu::tensor_traits {
388
- public:
389
- virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
390
- };
520
+ uint8_t * a_row = quant_a_buffer + m_start * row_stride_a;
391
521
 
392
- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
393
- bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
394
- switch (op->op) {
395
- case GGML_OP_MUL_MAT:
396
- size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4;
397
- size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float));
398
- return true;
399
- default:
400
- // GGML_ABORT("fatal error");
401
- break;
402
- }
403
- return false;
404
- }
522
+ float * c_blk = output + m_start * gemm_n + n_start + ni;
405
523
 
406
- bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
407
- switch (op->op) {
408
- case GGML_OP_MUL_MAT:
409
- if (op->src[0]->type == GGML_TYPE_Q4_0 || //
410
- op->src[0]->type == GGML_TYPE_Q4_1 || //
411
- op->src[0]->type == GGML_TYPE_Q4_K) {
412
- forward_mul_mat_q4(params, op);
413
- return true;
524
+ int64_t rows_remaining = m_count;
525
+
526
+ uint8_t * b_col_cur = b_col;
527
+ uint8_t * b_col_zp_cur = b_col_zp;
528
+
529
+ while (rows_remaining > 0) {
530
+ auto rows_handled = gemm_kernel(b_blk_len, a_row, b_col_cur, b_col_zp_cur, c_blk,
531
+ rows_remaining, n_blk_real, b_k_blks, gemm_n);
532
+
533
+ c_blk += rows_handled * gemm_n;
534
+ a_row += rows_handled * row_stride_a;
535
+
536
+ rows_remaining -= rows_handled;
537
+ }
414
538
  }
415
- default:
416
- // GGML_ABORT("fatal error");
417
- break;
539
+ }
418
540
  }
419
- return false;
420
541
  }
421
542
 
422
- void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) {
543
+ void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
544
+ constexpr size_t a_blk_len = INTER_SIZE;
545
+ constexpr size_t b_blk_len = INTER_SIZE;
546
+
423
547
  const ggml_tensor * src0 = op->src[0];
424
548
  const ggml_tensor * src1 = op->src[1];
549
+ const ggml_tensor * ids = op->src[2];
425
550
  ggml_tensor * dst = op;
426
551
 
427
552
  GGML_TENSOR_BINARY_OP_LOCALS
@@ -429,133 +554,381 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
429
554
  int ith = params->ith;
430
555
  int nth = params->nth;
431
556
 
432
- [[maybe_unused]] const enum ggml_type type = src0->type;
557
+ // row groups
558
+ const int n_ids = ids->ne[0]; // n_expert_used
559
+ const int n_as = ne02; // n_expert
560
+
561
+ struct mmid_row_mapping {
562
+ int32_t i1;
563
+ int32_t i2;
564
+ };
565
+
566
+ spacemit_kernels::quantize_a_row_def quantize_a_row_i8;
567
+ spacemit_kernels::gemm_kernel_quantize_def gemm_kernel;
568
+ spacemit_kernels::moe_gemm_kernel_quantize_def moe_gemm_kernel_m2;
569
+ bool set_kernel_impl = false;
570
+ size_t block_stride_a = spacemit_kernels::q8_blk_size(QK4_0);
571
+
572
+ #if defined(RISCV64_SPACEMIT_IME2)
573
+ if (!set_kernel_impl && (global_spine_env_info.use_ime2)) {
574
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
575
+ block_stride_a = spacemit_kernels::q8_blk_size(QK4_0, true);
576
+
577
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) {
578
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8;
579
+ set_kernel_impl = true;
580
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
581
+ std::is_same_v<BLOC_TYPE, block_q4_K>) {
582
+ if constexpr (INTER_SIZE == 256) {
583
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp;
584
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp;
585
+ block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true);
586
+ set_kernel_impl = true;
587
+ } else {
588
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4;
589
+ moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i4;
590
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
591
+ block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true);
592
+ set_kernel_impl = true;
593
+ }
594
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) {
595
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
596
+ block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
597
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k;
598
+ set_kernel_impl = true;
599
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) {
600
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
601
+ block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
602
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k;
603
+ set_kernel_impl = true;
604
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) {
605
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4;
606
+ moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8mxfp4;
607
+ set_kernel_impl = true;
608
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K> ||
609
+ std::is_same_v<BLOC_TYPE, block_q5_0>) {
610
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5;
611
+ moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i5;
612
+ set_kernel_impl = true;
613
+ }
614
+ }
615
+ #endif
433
616
 
434
- void * w_data = (void *) src0->data;
435
- const float * feature = (const float *) src1->data;
436
- float * output = (float *) dst->data;
617
+ #if defined(RISCV64_SPACEMIT_IME1)
618
+ if (!set_kernel_impl && (global_spine_env_info.use_ime1)) {
619
+ quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8;
620
+
621
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
622
+ std::is_same_v<BLOC_TYPE, block_q4_K>) {
623
+ gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4;
624
+ set_kernel_impl = true;
625
+ }
626
+ }
627
+ #endif
628
+ if (!set_kernel_impl) {
629
+ GGML_ABORT("no kernel implementation found for the block type");
630
+ }
437
631
 
438
- const size_t batch_feature = ne12 * ne13;
439
- [[maybe_unused]] const size_t batch_weight = ne02 * ne03;
440
- const size_t gemm_m = ne11;
441
- const size_t gemm_k = ne10;
442
- const size_t gemm_n = ne01;
632
+ const size_t a_k_blks = spacemit_kernels::div_round_up(ne10, a_blk_len);
633
+ const size_t b_k_blks = spacemit_kernels::div_round_up(ne10, b_blk_len);
443
634
 
444
- GGML_ASSERT(batch_weight == 1);
635
+ const size_t nbw1 = a_k_blks * block_stride_a;
636
+ const size_t nbw2 = ne11 * nbw1;
637
+ const size_t nbw3 = nbw2 * ne12;
638
+ const size_t gemm_workspace_size = GGML_PAD(nbw3, alignof(int64_t));
445
639
 
446
- const size_t block_count_k = div_round_up(gemm_k, QK4_0);
447
- const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0);
448
- const size_t per_gemm_workspace_stride =
449
- div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t);
450
- const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride;
451
- const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1;
640
+ const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
641
+ auto * quant_a_buffer = reinterpret_cast<uint8_t *>(ws_ptr);
452
642
 
453
- if (ith == 0 && params->wsize < desired_wsize) {
454
- throw std::runtime_error("wsize less than desired_wsize");
643
+ if (ne11 == 1) {
644
+ for (int64_t ii = ith; ii < ne12 * a_k_blks; ii += nth) {
645
+ int64_t i12 = ii / a_k_blks;
646
+ int64_t ak_blk_id = ii % a_k_blks;
647
+ quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12) + ak_blk_id * a_blk_len,
648
+ a_blk_len, quant_a_buffer + i12 * nbw2 + ak_blk_id * block_stride_a);
649
+ }
650
+ } else {
651
+ for (int64_t ii = ith; ii < ne12 * ne11; ii += nth) {
652
+ int64_t i12 = ii / ne11;
653
+ int64_t i11 = ii % ne11;
654
+ quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12 + i11 * nb11), ne10,
655
+ quant_a_buffer + i12 * nbw2 + i11 * nbw1);
656
+ }
455
657
  }
456
658
 
457
- std::vector<qnbitgemm_spacemit_ime_args> qnbitgemm_args(batch_feature);
659
+ #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) *ne12 + (i1)]
458
660
 
459
- for (size_t i = 0; i < batch_feature; i++) {
460
- qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i;
461
- qnbitgemm_args[i].lda = gemm_k;
462
- qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data;
463
- qnbitgemm_args[i].quant_b_scale = nullptr;
661
+ int64_t * matrix_row_counts = (int64_t *) (ws_ptr + gemm_workspace_size);
662
+ int32_t * valid_ep_count = (int32_t *) (matrix_row_counts + n_as);
663
+ int32_t * valid_act_count = (int32_t *) (valid_ep_count + 1);
664
+ int64_t * valid_matrix_row_counts = (int64_t *) (valid_act_count + 1);
665
+ mmid_row_mapping * matrix_rows = (mmid_row_mapping *) (valid_matrix_row_counts + n_as);
464
666
 
465
- if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) {
466
- qnbitgemm_args[i].quant_b_zp = nullptr;
467
- } else {
468
- qnbitgemm_args[i].quant_b_zp = w_data;
667
+ if (ith == 0) {
668
+ // initialize matrix_row_counts
669
+ memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
670
+
671
+ // group rows by src0 matrix
672
+ for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
673
+ for (int32_t id = 0; id < n_ids; ++id) {
674
+ const int32_t i02 =
675
+ *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
676
+
677
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
678
+
679
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
680
+ matrix_row_counts[i02] += 1;
681
+ }
469
682
  }
470
683
 
471
- qnbitgemm_args[i].bias = nullptr;
472
- qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i;
473
- qnbitgemm_args[i].ldc = gemm_n;
684
+ int32_t valid_ep_count_t = 0;
685
+ int32_t valid_act_count_t = 0;
686
+ for (int cur_a = 0; cur_a < n_as; ++cur_a) {
687
+ const int64_t cne1 = matrix_row_counts[cur_a];
688
+ if (cne1 == 0) {
689
+ continue;
690
+ }
691
+ valid_matrix_row_counts[valid_ep_count_t] = cur_a;
692
+ valid_act_count_t += cne1;
693
+ valid_ep_count_t += 1;
694
+ }
695
+ valid_ep_count[0] = valid_ep_count_t;
696
+ valid_act_count[0] = valid_act_count_t;
474
697
  }
475
698
 
476
- const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
477
- void * ws = reinterpret_cast<void *>((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1)));
478
- const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0);
699
+ const int64_t barrier_idx = static_cast<int64_t>(ith / 2);
479
700
 
480
- {
481
- constexpr size_t block_size_m = 4;
482
- size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m);
483
- int32_t task_count = batch_feature * per_gemm_block_count_m;
484
- int32_t task_per_thread = (task_count + nth - 1) / nth;
485
- int32_t start = ith * task_per_thread;
486
- int32_t end = std::min((ith + 1) * task_per_thread, task_count);
487
- for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
488
- int32_t gemm_idx = compute_idx / per_gemm_block_count_m;
489
- int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m;
490
- int32_t m_idx = block_idx_in_gemm * block_size_m;
491
- const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx];
492
- int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);
493
-
494
- if (rows_tobe_handled == block_size_m) {
495
- const float * a_row_ptr = data.a_ptr + m_idx * data.lda;
496
- std::byte * quant_a_row_ptr =
497
- static_cast<std::byte *>(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
498
- sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
499
- } else {
500
- while (rows_tobe_handled) {
501
- const float * a_row_ptr = data.a_ptr + m_idx * data.lda;
502
- std::byte * quant_a_row_ptr = static_cast<std::byte *>(ws) +
503
- gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
504
- sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
505
- rows_tobe_handled -= 1;
506
- m_idx += 1;
701
+ GGML_ASSERT(global_spine_env_info.init_barrier != nullptr);
702
+ GGML_ASSERT(barrier_idx < spine_init_barrier_count);
703
+ spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx];
704
+
705
+ ggml_barrier(params->threadpool);
706
+
707
+ const size_t row_stride_b = b_k_blks * get_repacked_block_type_size<BLOC_TYPE, INTER_SIZE, NB_COLS>();
708
+ const size_t expert_b_stride = ne01 * row_stride_b;
709
+ const size_t per_nb_cols_wsize = NB_COLS * row_stride_b;
710
+
711
+ std::array<const uint8_t *, 2> src_workspaces;
712
+ std::array<float *, 2> dst_workspaces;
713
+
714
+ auto * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer;
715
+ const auto tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size;
716
+
717
+ const auto valid_ep_count_t = valid_ep_count[0];
718
+ const auto valid_act_count_t = valid_act_count[0];
719
+
720
+ int nth_es = 1;
721
+ int nth_n = nth;
722
+
723
+ int ith_es = ith % nth_es;
724
+ int ith_n = (ith / nth_es) % nth_n;
725
+
726
+ if (valid_ep_count_t % nth == 0 && tcm_buffer != nullptr && valid_ep_count_t == n_as &&
727
+ valid_act_count_t == n_as && per_nb_cols_wsize <= tcm_buffer_size) {
728
+ for (int64_t valid_id = ith; valid_id < valid_ep_count_t; valid_id += nth) {
729
+ const int64_t cur_a = valid_matrix_row_counts[valid_id];
730
+
731
+ auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride;
732
+
733
+ mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, 0);
734
+ const int id = row_mapping.i1;
735
+ const int64_t i11 = id % ne11;
736
+ const int64_t i12 = row_mapping.i2;
737
+ const int64_t i1 = id;
738
+ const int64_t i2 = i12;
739
+
740
+ auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
741
+ float * c_blk = (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2));
742
+
743
+ uint8_t * a_row = src1_col;
744
+ uint8_t * b_col = reinterpret_cast<uint8_t *>(tcm_buffer);
745
+ if ((nbw1 + per_nb_cols_wsize) <= tcm_buffer_size) {
746
+ a_row = (uint8_t *) tcm_buffer;
747
+ b_col = reinterpret_cast<uint8_t *>(tcm_buffer) + nbw1;
748
+ }
749
+ uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
750
+
751
+ if (ith % 2 == 0) {
752
+ spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(src0_cur), per_nb_cols_wsize);
753
+
754
+ if (a_row != src1_col) {
755
+ spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1);
756
+ }
757
+ }
758
+
759
+ spine_barrier_wait(cur_barrier);
760
+
761
+ if (ith % 2 != 0) {
762
+ if (a_row != src1_col) {
763
+ spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1);
764
+ }
765
+
766
+ spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(src0_cur), per_nb_cols_wsize);
767
+ }
768
+
769
+ int64_t nb_real = std::min(ne01, NB_COLS);
770
+ for (int64_t ni = 0; ni < ne01; ni += NB_COLS) {
771
+ if (ith % 2 != 0) {
772
+ spine_barrier_wait(cur_barrier);
773
+ }
774
+
775
+ gemm_kernel(b_blk_len, a_row, b_col, b_col_zp, c_blk + ni, 1, nb_real, b_k_blks, ne01);
776
+
777
+ if (ith % 2 == 0) {
778
+ spine_barrier_wait(cur_barrier);
779
+ }
780
+
781
+ const int64_t next_ni = ni + NB_COLS;
782
+ if (next_ni < ne01) {
783
+ nb_real = std::min(ne01 - next_ni, NB_COLS);
784
+ spacemit_kernels::rvv::memcpy1d(
785
+ b_col, reinterpret_cast<uint8_t *>(src0_cur) + next_ni * row_stride_b, per_nb_cols_wsize);
507
786
  }
508
787
  }
509
788
  }
510
- }
789
+ } else {
790
+ for (int64_t valid_id = ith_es; valid_id < valid_ep_count_t; valid_id += nth_es) {
791
+ const int64_t cur_a = valid_matrix_row_counts[valid_id];
792
+ const int64_t cne1 = matrix_row_counts[cur_a];
511
793
 
512
- ggml_barrier(params->threadpool);
794
+ int64_t src1_cur_start = 0;
795
+ int64_t src1_cur_end = cne1;
513
796
 
514
- if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) {
515
- return;
516
- }
517
- nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores });
518
-
519
- size_t threads_per_gemm = nth / batch_feature;
520
- constexpr size_t gemm_m_stride = 128;
521
- size_t nc = gemm_n;
522
- const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride);
523
- const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm);
524
- if (max_nc < nc) {
525
- nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN);
526
- }
527
- const size_t gemm_n_stride = nc;
528
- const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride);
529
- const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride);
530
- threads_per_gemm = thread_count_m * thread_count_n;
797
+ int64_t src0_cur_start = (ith_n * ne01) / nth_n;
798
+ int64_t src0_cur_end = MIN(((ith_n + 1) * ne01) / nth_n, ne01);
531
799
 
532
- {
533
- int task_count = batch_feature * threads_per_gemm;
534
- int task_per_thread = (task_count + nth - 1) / nth;
535
- int start = ith * task_per_thread;
536
- int end = std::min((ith + 1) * task_per_thread, task_count);
537
- for (int compute_idx = start; compute_idx < end; compute_idx++) {
538
- const auto gemm_i = compute_idx / threads_per_gemm;
539
- const auto blk_i = compute_idx % threads_per_gemm;
540
- const auto * data = &qnbitgemm_args[gemm_i];
800
+ if (src1_cur_start >= src1_cur_end || src0_cur_start >= src0_cur_end) {
801
+ continue;
802
+ }
803
+
804
+ src0_cur_start =
805
+ (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
806
+ src0_cur_end =
807
+ (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
808
+
809
+ auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride + src0_cur_start * row_stride_b;
810
+ uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? src0_cur : nullptr;
811
+
812
+ size_t extra_tcm_buffer_size = tcm_buffer_size;
813
+ void * extra_tcm_buffer = tcm_buffer;
814
+ if (tcm_buffer != nullptr && (src1_cur_end - src1_cur_start) >= 4 &&
815
+ (src0_cur_end - src0_cur_start) * row_stride_b <= tcm_buffer_size) {
816
+ spacemit_kernels::rvv::memcpy1d(tcm_buffer, src0_cur,
817
+ (src0_cur_end - src0_cur_start) * row_stride_b);
818
+ src0_cur = reinterpret_cast<uint8_t *>(tcm_buffer);
819
+ b_col_zp = block_type_has_zp<BLOC_TYPE>() ? src0_cur : nullptr;
820
+ extra_tcm_buffer_size -= (src0_cur_end - src0_cur_start) * row_stride_b;
821
+ extra_tcm_buffer = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(tcm_buffer) +
822
+ (src0_cur_end - src0_cur_start) * row_stride_b);
823
+ }
541
824
 
542
- const auto tid_n = blk_i / thread_count_m;
543
- const auto tid_m = blk_i % thread_count_m;
825
+ int ir1 = src1_cur_start;
544
826
 
545
- const size_t m_start = tid_m * gemm_m_stride;
546
- const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride);
827
+ if (extra_tcm_buffer_size >= nbw1 && extra_tcm_buffer != nullptr) {
828
+ int64_t quant_a_tile_size = extra_tcm_buffer_size / nbw1;
829
+ do {
830
+ quant_a_tile_size = MIN(quant_a_tile_size, src1_cur_end - ir1);
547
831
 
548
- const size_t n_start = tid_n * gemm_n_stride;
549
- const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride);
832
+ uint8_t * quant_a_tile_buffer = reinterpret_cast<uint8_t *>(extra_tcm_buffer);
550
833
 
551
- void * per_gemm_ws = reinterpret_cast<std::byte *>(ws) + gemm_i * per_gemm_workspace_stride;
834
+ int iir1 = ir1;
835
+ for (; iir1 < (ir1 + quant_a_tile_size); ++iir1) {
836
+ mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, iir1);
552
837
 
553
- sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count);
838
+ const int id = row_mapping.i1; // selected expert index
839
+
840
+ const int64_t i11 = id % ne11;
841
+ const int64_t i12 = row_mapping.i2; // row index in src1
842
+
843
+ auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
844
+ spacemit_kernels::rvv::memcpy1d(quant_a_tile_buffer, src1_col, nbw1);
845
+ quant_a_tile_buffer = quant_a_tile_buffer + nbw1;
846
+ }
847
+
848
+ quant_a_tile_buffer = reinterpret_cast<uint8_t *>(extra_tcm_buffer);
849
+ iir1 = ir1;
850
+
851
+ if (moe_gemm_kernel_m2 != nullptr) {
852
+ for (; iir1 < (ir1 + quant_a_tile_size - 1); iir1 += 2, quant_a_tile_buffer += 2 * nbw1) {
853
+ mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1);
854
+ mmid_row_mapping row_mapping_1 = MMID_MATRIX_ROW(cur_a, iir1 + 1);
855
+
856
+ src_workspaces[0] = quant_a_tile_buffer;
857
+ src_workspaces[1] = quant_a_tile_buffer + nbw1;
858
+
859
+ dst_workspaces[0] =
860
+ (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) +
861
+ src0_cur_start;
862
+ dst_workspaces[1] = (float *) ((char *) dst->data +
863
+ ((row_mapping_1.i1) * nb1 + (row_mapping_1.i2) * nb2)) +
864
+ src0_cur_start;
865
+ moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp,
866
+ dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks,
867
+ ne01);
868
+ }
869
+ }
870
+
871
+ for (; iir1 < (ir1 + quant_a_tile_size); iir1++, quant_a_tile_buffer += nbw1) {
872
+ mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1);
873
+
874
+ gemm_kernel(
875
+ b_blk_len, quant_a_tile_buffer, src0_cur, b_col_zp,
876
+ (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) +
877
+ src0_cur_start,
878
+ 1, src0_cur_end - src0_cur_start, b_k_blks, ne01);
879
+ }
880
+
881
+ ir1 += quant_a_tile_size;
882
+ } while (ir1 < src1_cur_end);
883
+ } else {
884
+ if (moe_gemm_kernel_m2 != nullptr) {
885
+ for (; ir1 < src1_cur_end - 1; ir1 += 2) {
886
+ for (int iir1 = 0; iir1 < 2; ++iir1) {
887
+ mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1 + iir1);
888
+
889
+ const int id = row_mapping.i1; // selected expert index
890
+
891
+ const int64_t i11 = id % ne11;
892
+ const int64_t i12 = row_mapping.i2; // row index in src1
893
+
894
+ const int64_t i1 = id; // selected expert index
895
+ const int64_t i2 = i12; // row
896
+
897
+ src_workspaces[iir1] = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
898
+
899
+ dst_workspaces[iir1] =
900
+ (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start;
901
+ }
902
+
903
+ moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp,
904
+ dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, ne01);
905
+ }
906
+ }
907
+
908
+ for (; ir1 < src1_cur_end; ir1++) {
909
+ mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
910
+
911
+ const int id = row_mapping.i1; // selected expert index
912
+
913
+ const int64_t i11 = id % ne11;
914
+ const int64_t i12 = row_mapping.i2; // row index in src1
915
+
916
+ const int64_t i1 = id; // selected expert index
917
+ const int64_t i2 = i12; // row
918
+
919
+ auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
920
+
921
+ gemm_kernel(b_blk_len, src1_col, src0_cur, b_col_zp,
922
+ (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, 1,
923
+ src0_cur_end - src0_cur_start, b_k_blks, ne01);
924
+ }
925
+ }
554
926
  }
555
927
  }
928
+ #undef MMID_MATRIX_ROW
556
929
  }
557
930
 
558
- int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
931
+ int repack(ggml_tensor * t, const void * data, size_t data_size) override {
559
932
  GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
560
933
  (int) NB_COLS, (int) INTER_SIZE);
561
934
  return ggml::cpu::riscv64_spacemit::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
@@ -563,309 +936,464 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
563
936
  };
564
937
 
565
938
  class tensor_traits_common : public tensor_traits_base {
566
- bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
939
+ bool work_size(int n_threads, const ggml_tensor * op, size_t & size) override {
567
940
  switch (op->op) {
568
- case GGML_OP_NORM:
569
- case GGML_OP_RMS_NORM:
570
- size = 0;
941
+ case GGML_OP_FLASH_ATTN_EXT:
942
+ {
943
+ const int n_tasks = n_threads;
944
+ const int64_t neq2 = op->src[0]->ne[2]; // number of query heads
945
+ const int64_t DK = op->src[1]->ne[0];
946
+ const int64_t DV = op->src[2]->ne[0]; // DV
947
+
948
+ // Tiled flash attention scratch (tile sizes defined in common.h)
949
+ // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
950
+ size_t prefill = sizeof(float) *
951
+ (GGML_FA_TILE_Q * DK + 2 * GGML_FA_TILE_Q * GGML_FA_TILE_KV + GGML_FA_TILE_Q * DV +
952
+ GGML_FA_TILE_KV * DV + GGML_FA_TILE_KV * DK) *
953
+ n_tasks;
954
+
955
+ // Decode path: n_kv_chunks = n_tasks (one chunk per thread)
956
+ // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
957
+ size_t n_chunks = n_tasks;
958
+ size_t decode = sizeof(float) * (neq2 * n_chunks * (2 + DV) + n_tasks * (DK + 2 * DV));
959
+
960
+ size = MAX(prefill, decode);
961
+ }
571
962
  return true;
572
963
  default:
573
- // GGML_ABORT("fatal error");
574
964
  break;
575
965
  }
576
966
  return false;
577
967
  }
578
968
 
579
- bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
969
+ bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override {
580
970
  switch (op->op) {
581
971
  case GGML_OP_NORM:
582
- forward_norm_f32(params, op);
583
- return true;
972
+ switch (op->src[0]->type) {
973
+ case GGML_TYPE_F32:
974
+ spacemit_kernels::rvv::forward_norm_f32(params, op);
975
+ return true;
976
+ default:
977
+ GGML_ABORT("fatal error");
978
+ }
584
979
  case GGML_OP_RMS_NORM:
585
- forward_rms_norm_f32(params, op);
980
+ switch (op->src[0]->type) {
981
+ case GGML_TYPE_F32:
982
+ spacemit_kernels::rvv::forward_rms_norm_f32(params, op);
983
+ return true;
984
+ default:
985
+ GGML_ABORT("fatal error");
986
+ }
987
+ case GGML_OP_ADD:
988
+ switch (op->src[0]->type) {
989
+ case GGML_TYPE_F32:
990
+ spacemit_kernels::rvv::forward_binary<GGML_OP_ADD, float>(params, op);
991
+ return true;
992
+ case GGML_TYPE_F16:
993
+ spacemit_kernels::rvv::forward_binary<GGML_OP_ADD, _Float16>(params, op);
994
+ return true;
995
+ default:
996
+ ggml_compute_forward_add(params, op);
997
+ return true;
998
+ }
999
+ case GGML_OP_SUB:
1000
+ switch (op->src[0]->type) {
1001
+ case GGML_TYPE_F32:
1002
+ spacemit_kernels::rvv::forward_binary<GGML_OP_SUB, float>(params, op);
1003
+ return true;
1004
+ case GGML_TYPE_F16:
1005
+ spacemit_kernels::rvv::forward_binary<GGML_OP_SUB, _Float16>(params, op);
1006
+ return true;
1007
+ default:
1008
+ ggml_compute_forward_sub(params, op);
1009
+ return true;
1010
+ }
1011
+ case GGML_OP_MUL:
1012
+ switch (op->src[0]->type) {
1013
+ case GGML_TYPE_F32:
1014
+ spacemit_kernels::rvv::forward_binary<GGML_OP_MUL, float>(params, op);
1015
+ return true;
1016
+ case GGML_TYPE_F16:
1017
+ spacemit_kernels::rvv::forward_binary<GGML_OP_MUL, _Float16>(params, op);
1018
+ return true;
1019
+ default:
1020
+ ggml_compute_forward_mul(params, op);
1021
+ return true;
1022
+ }
1023
+ case GGML_OP_DIV:
1024
+ switch (op->src[0]->type) {
1025
+ case GGML_TYPE_F32:
1026
+ spacemit_kernels::rvv::forward_binary<GGML_OP_DIV, float>(params, op);
1027
+ return true;
1028
+ case GGML_TYPE_F16:
1029
+ spacemit_kernels::rvv::forward_binary<GGML_OP_DIV, _Float16>(params, op);
1030
+ return true;
1031
+ default:
1032
+ ggml_compute_forward_div(params, op);
1033
+ return true;
1034
+ }
1035
+ case GGML_OP_FLASH_ATTN_EXT:
1036
+ forward_flash_attn_ext_f16(params, op);
1037
+ return true;
1038
+ case GGML_OP_CONT:
1039
+ {
1040
+ const ggml_tensor * src0 = op->src[0];
1041
+ if (op->type == src0->type && op->nb[0] != src0->nb[0] && op->nb[0] == src0->nb[1] &&
1042
+ op->ne[3] * op->ne[2] * op->nb[2] == src0->ne[3] * src0->ne[2] * src0->nb[2]) {
1043
+ spacemit_kernels::rvv::forward_cont_with_permute(params, op);
1044
+ } else {
1045
+ ggml_compute_forward_cont(params, op);
1046
+ }
1047
+ return true;
1048
+ }
1049
+ case GGML_OP_CPY:
1050
+ {
1051
+ const ggml_tensor * src0 = op->src[0];
1052
+ if (op->type == src0->type && op->nb[0] == src0->nb[1] && src0->nb[0] != src0->nb[1] &&
1053
+ ggml_nelements(src0) == ggml_nelements(op)) {
1054
+ spacemit_kernels::rvv::forward_cpy_with_permute(params, op);
1055
+ } else {
1056
+ ggml_compute_forward_cpy(params, op);
1057
+ }
1058
+ return true;
1059
+ }
1060
+ case GGML_OP_REPEAT:
1061
+ {
1062
+ const bool rows_equal = ggml_nrows(op->src[0]) == ggml_nrows(op);
1063
+ const bool broadcast_or_equal = op->src[0]->ne[0] == 1 || op->src[0]->ne[0] == op->ne[0];
1064
+
1065
+ if (rows_equal && broadcast_or_equal) {
1066
+ switch (op->src[0]->type) {
1067
+ case GGML_TYPE_F32:
1068
+ spacemit_kernels::rvv::forward_repeat_nrows<int32_t>(params, op);
1069
+ return true;
1070
+ case GGML_TYPE_F16:
1071
+ spacemit_kernels::rvv::forward_repeat_nrows<int16_t>(params, op);
1072
+ return true;
1073
+ default:
1074
+ break;
1075
+ }
1076
+ }
1077
+
1078
+ if (op->src[0]->ne[1] == 1 && op->src[0]->ne[0] == op->ne[0]) {
1079
+ switch (op->src[0]->type) {
1080
+ case GGML_TYPE_F32:
1081
+ spacemit_kernels::rvv::forward_repeat_dim1<int32_t>(params, op);
1082
+ return true;
1083
+ case GGML_TYPE_F16:
1084
+ spacemit_kernels::rvv::forward_repeat_dim1<int16_t>(params, op);
1085
+ return true;
1086
+ default:
1087
+ break;
1088
+ }
1089
+ }
1090
+
1091
+ ggml_compute_forward_repeat(params, op);
1092
+ }
1093
+ return true;
1094
+ case GGML_OP_SUM_ROWS:
1095
+ {
1096
+ if (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) {
1097
+ spacemit_kernels::rvv::forward_sum_rows<float>(params, op);
1098
+ } else {
1099
+ ggml_compute_forward_sum_rows(params, op);
1100
+ }
1101
+ }
1102
+ return true;
1103
+ case GGML_OP_GET_ROWS:
1104
+ {
1105
+ if (op->src[0]->type == op->type) {
1106
+ switch (op->src[0]->type) {
1107
+ case GGML_TYPE_F32:
1108
+ spacemit_kernels::rvv::forward_get_rows<int32_t>(params, op);
1109
+ return true;
1110
+ case GGML_TYPE_F16:
1111
+ spacemit_kernels::rvv::forward_get_rows<int16_t>(params, op);
1112
+ return true;
1113
+ default:
1114
+ break;
1115
+ }
1116
+ }
1117
+
1118
+ ggml_compute_forward_get_rows(params, op);
1119
+ }
586
1120
  return true;
1121
+ case GGML_OP_CONCAT:
1122
+ {
1123
+ const int32_t dim = ggml_get_op_params_i32(op, 0);
1124
+ if (dim == 0 && op->type == op->src[0]->type) {
1125
+ switch (op->src[0]->type) {
1126
+ case GGML_TYPE_F32:
1127
+ spacemit_kernels::rvv::forward_concat<int32_t>(params, op);
1128
+ return true;
1129
+ case GGML_TYPE_F16:
1130
+ spacemit_kernels::rvv::forward_concat<int16_t>(params, op);
1131
+ return true;
1132
+ default:
1133
+ break;
1134
+ }
1135
+ }
1136
+
1137
+ ggml_compute_forward_concat(params, op);
1138
+ }
1139
+ return true;
1140
+ // TODO For GGML_OP_GATED_DELTA_NET
1141
+ // case GGML_OP_GATED_DELTA_NET:
1142
+ // return true;
587
1143
  default:
588
- // GGML_ABORT("fatal error");
589
1144
  break;
590
1145
  }
591
1146
  return false;
592
1147
  }
593
1148
 
594
- void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
595
- const ggml_tensor * src0 = op->src[0];
596
- ggml_tensor * dst = op;
597
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
598
- GGML_ASSERT(src0->nb[0] == sizeof(float));
1149
+ void forward_flash_attn_ext_f16(const ggml_compute_params * params, ggml_tensor * dst) {
1150
+ const ggml_tensor * q = dst->src[0];
1151
+ const ggml_tensor * k = dst->src[1];
1152
+ const ggml_tensor * v = dst->src[2];
1153
+
1154
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
1155
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
1156
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
1157
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
1158
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
1159
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
1160
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
1161
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
1162
+
1163
+ const int64_t DK = nek0;
1164
+ const int64_t DV = nev0;
1165
+
1166
+ const bool supported_prec = (dst->op_params[3] == GGML_PREC_F32 || dst->op_params[3] == GGML_PREC_DEFAULT);
1167
+ const bool supported_types = (q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16);
1168
+ const bool supported_shape = (DK > 0 && DK <= 128 && DV > 0 && DV <= 128);
1169
+ const bool supported_vlen = (__riscv_vlenb() == 128);
1170
+
1171
+ if (!(supported_prec && supported_types && supported_shape && supported_vlen)) {
1172
+ ggml_compute_forward_flash_attn_ext(params, dst);
1173
+ return;
1174
+ }
1175
+
1176
+ // total rows in q
1177
+ const int64_t nr = neq1 * neq2 * neq3;
599
1178
 
1179
+ // rows per thread
600
1180
  const int ith = params->ith;
601
1181
  const int nth = params->nth;
602
1182
 
603
- GGML_TENSOR_UNARY_OP_LOCALS
1183
+ static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
1184
+ const bool use_tiled = !params->use_ref && (neq1 >= Q_TILE_SZ);
604
1185
 
605
- float epsilon;
606
- memcpy(&epsilon, dst->op_params, sizeof(float));
1186
+ // 4x chunks per thread
1187
+ // int nth_scaled = nth * 4;
1188
+ // int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
1189
+ // int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
607
1190
 
608
- GGML_ASSERT(epsilon > 0.0f);
1191
+ // if (nth == 1 || nchunk < nth) {
1192
+ // nchunk = nth;
1193
+ // }
609
1194
 
610
- auto * input = (float *) src0->data;
611
- auto * output = (float *) dst->data;
1195
+ int64_t nchunk = nth;
612
1196
 
613
- const auto hidden_size = ne00;
614
- const auto task_count = ne01 * ne02 * ne03;
615
- const auto task_per_thread = (task_count + nth - 1) / nth;
616
-
617
- const auto task_begin = ith * task_per_thread;
618
- const auto task_end = std::min((ith + 1) * task_per_thread, task_count);
1197
+ if (ith == 0) {
1198
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
1199
+ ggml_threadpool_chunk_set(params->threadpool, nth);
1200
+ }
619
1201
 
620
- for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
621
- auto offset = task_idx * hidden_size;
622
- auto * p_input = const_cast<float *>(input + offset);
1202
+ ggml_barrier(params->threadpool);
623
1203
 
624
- auto * p_output = output + offset;
625
- auto * p_temp_output = p_output;
626
- auto * p_gamma_data = (const float *) nullptr;
627
- auto * p_beta_data = (const float *) nullptr;
628
- size_t gvl = __riscv_vsetvlmax_e32m4();
629
- vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
630
- vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
631
- int64_t length = hidden_size;
632
- while (length > 0) {
633
- gvl = __riscv_vsetvl_e32m4(length);
634
- // load data
635
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);
1204
+ // The number of elements in each chunk
1205
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
636
1206
 
637
- sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl);
638
- sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);
1207
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
1208
+ int current_chunk = ith;
639
1209
 
640
- __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);
1210
+ while (current_chunk < nchunk) {
1211
+ const int64_t ir0 = dr * current_chunk;
1212
+ const int64_t ir1 = MIN(ir0 + dr, nr);
641
1213
 
642
- p_input += gvl;
643
- p_temp_output += gvl;
644
- length -= gvl;
1214
+ if (use_tiled) {
1215
+ spacemit_kernels::rvv::forward_flash_attn_ext_f16_tiled_vlen1024_vf16(
1216
+ params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer,
1217
+ ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size);
1218
+ } else {
1219
+ spacemit_kernels::rvv::forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(
1220
+ params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer,
1221
+ ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size);
645
1222
  }
646
1223
 
647
- gvl = __riscv_vsetvlmax_e32m1();
648
-
649
- float mean = 0.f;
650
- vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
651
- vfloat32m1_t mean_v =
652
- __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl);
653
- mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl);
654
- mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl);
655
- mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl);
656
- mean = __riscv_vfmv_f_s_f32m1_f32(mean_v);
657
- mean /= hidden_size;
658
-
659
- vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
660
- __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
661
- mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
662
- mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
663
- mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);
664
-
665
- float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
666
- mean_square /= hidden_size;
667
- mean_square = sqrt(mean_square - mean * mean + epsilon);
668
-
669
- mean_square = 1.0f / mean_square;
670
- length = hidden_size;
671
- p_temp_output = p_output;
672
-
673
- if (p_gamma_data == nullptr && p_beta_data == nullptr) {
674
- while (length > 0) {
675
- gvl = __riscv_vsetvl_e32m4(length);
676
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
677
- src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
678
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
679
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
680
- p_temp_output += gvl;
681
- p_output += gvl;
682
- length -= gvl;
683
- }
684
- } else if (p_beta_data == nullptr) {
685
- while (length > 0) {
686
- gvl = __riscv_vsetvl_e32m4(length);
687
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
688
- vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
689
- src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
690
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
691
- src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
692
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
693
- p_temp_output += gvl;
694
- p_output += gvl;
695
- p_gamma_data += gvl;
696
- length -= gvl;
697
- }
698
- } else if (p_gamma_data != nullptr) {
699
- while (length > 0) {
700
- gvl = __riscv_vsetvl_e32m4(length);
701
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
702
- vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
703
- src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
704
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
705
- src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
706
- vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl);
707
- src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
708
- p_beta_data += gvl;
709
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
710
- p_temp_output += gvl;
711
- p_output += gvl;
712
- p_gamma_data += gvl;
713
- length -= gvl;
714
- }
715
- }
1224
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
716
1225
  }
717
1226
  }
718
1227
 
719
- void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
720
- const ggml_tensor * src0 = op->src[0];
721
- ggml_tensor * dst = op;
722
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
723
- GGML_ASSERT(src0->nb[0] == sizeof(float));
724
-
725
- const int ith = params->ith;
726
- const int nth = params->nth;
727
-
728
- GGML_TENSOR_UNARY_OP_LOCALS
729
-
730
- float epsilon;
731
- memcpy(&epsilon, dst->op_params, sizeof(float));
732
-
733
- GGML_ASSERT(epsilon > 0.0f);
734
-
735
- auto * input = (float *) src0->data;
736
- auto * output = (float *) dst->data;
737
-
738
- const auto hidden_size = ne00;
739
- const auto task_count = ne01 * ne02 * ne03;
740
- const auto task_per_thread = (task_count + nth - 1) / nth;
741
-
742
- const auto task_begin = ith * task_per_thread;
743
- const auto task_end = std::min((ith + 1) * task_per_thread, task_count);
744
-
745
- for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
746
- auto offset = task_idx * hidden_size;
747
- auto * p_input = const_cast<float *>(input + offset);
748
- auto * p_output = output + offset;
749
- auto * p_temp_output = p_output;
750
- auto * p_gamma_data = (const float *) nullptr;
751
- auto * p_beta_data = (const float *) nullptr;
752
-
753
- size_t gvl = __riscv_vsetvlmax_e32m4();
754
- // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
755
- vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
756
- int64_t length = hidden_size;
757
- while (length > 0) {
758
- gvl = __riscv_vsetvl_e32m4(length);
759
- // load data
760
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);
1228
+ int repack(ggml_tensor * t, const void * data, size_t data_size) override {
1229
+ memcpy(t->data, data, data_size);
1230
+ return 0;
1231
+ }
1232
+ };
761
1233
 
762
- sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);
1234
+ // Impl By IME1
1235
+ static const tensor_traits<block_q4_0, 32, 16> q4_0_16x32_q8_0;
1236
+ static const tensor_traits<block_q4_1, 32, 16> q4_1_16x32_q8_0;
1237
+ static const tensor_traits<block_q4_K, 32, 16> q4_k_16x32_q8_0;
1238
+ // Impl By IME2
1239
+ static const tensor_traits<block_q2_K, 256, 32> q2_k_32x256_q8_0;
1240
+ static const tensor_traits<block_q3_K, 256, 32> q3_k_32x256_q8_0;
1241
+ static const tensor_traits<block_q4_0, 32, 32> q4_0_32x32_q8_0;
1242
+ static const tensor_traits<block_q4_1, 32, 32> q4_1_32x32_q8_0;
1243
+ static const tensor_traits<block_q4_0, 256, 32> q4_0_32x256_q8_0;
1244
+ static const tensor_traits<block_q4_1, 256, 32> q4_1_32x256_q8_0;
1245
+ static const tensor_traits<block_q4_K, 32, 32> q4_k_32x32_q8_0;
1246
+ static const tensor_traits<block_q6_K, 32, 32> q6_k_32x32_q8_0;
1247
+ static const tensor_traits<block_q8_0, 32, 32> q8_0_32x32_q8_0;
1248
+ static const tensor_traits<block_mxfp4, 32, 32> mxfp4_32x32_q8_0;
1249
+ static const tensor_traits<block_q5_K, 32, 32> q5_k_32x32_q8_0;
1250
+ static const tensor_traits<block_q5_1, 32, 32> q5_1_32x32_q8_0;
1251
+ static const tensor_traits<block_q5_0, 32, 32> q5_0_32x32_q8_0;
1252
+ // Impl By RVV
1253
+ static const tensor_traits_common rvv_impl;
763
1254
 
764
- __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);
1255
+ } // namespace ggml::cpu::riscv64_spacemit
765
1256
 
766
- p_input += gvl;
767
- p_temp_output += gvl;
768
- length -= gvl;
1257
+ static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const ggml_tensor * cur) {
1258
+ switch (cur->type) {
1259
+ case GGML_TYPE_Q2_K:
1260
+ {
1261
+ #if defined(RISCV64_SPACEMIT_IME2)
1262
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1263
+ return &ggml::cpu::riscv64_spacemit::q2_k_32x256_q8_0;
1264
+ }
1265
+ #endif
769
1266
  }
1267
+ break;
1268
+ case GGML_TYPE_Q3_K:
1269
+ {
1270
+ #if defined(RISCV64_SPACEMIT_IME2)
1271
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1272
+ return &ggml::cpu::riscv64_spacemit::q3_k_32x256_q8_0;
1273
+ }
1274
+ #endif
1275
+ }
1276
+ break;
1277
+ case GGML_TYPE_Q4_0:
1278
+ {
1279
+ #if defined(RISCV64_SPACEMIT_IME2)
1280
+ if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 &&
1281
+ (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1282
+ return &ggml::cpu::riscv64_spacemit::q4_0_32x256_q8_0;
1283
+ }
770
1284
 
771
- gvl = __riscv_vsetvlmax_e32m1();
772
-
773
- // float mean = 0.f;
774
- vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
775
-
776
- vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
777
- __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
778
- mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
779
- mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
780
- mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);
781
-
782
- float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
783
- mean_square /= hidden_size;
1285
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1286
+ return &ggml::cpu::riscv64_spacemit::q4_0_32x32_q8_0;
1287
+ }
1288
+ #endif
784
1289
 
785
- mean_square = sqrt(mean_square + epsilon);
1290
+ #if defined(RISCV64_SPACEMIT_IME1)
1291
+ if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) {
1292
+ return &ggml::cpu::riscv64_spacemit::q4_0_16x32_q8_0;
1293
+ }
1294
+ #endif
1295
+ }
1296
+ break;
1297
+ case GGML_TYPE_Q4_1:
1298
+ {
1299
+ #if defined(RISCV64_SPACEMIT_IME2)
1300
+ // TODO
1301
+ // if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 &&
1302
+ // (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1303
+ // return &ggml::cpu::riscv64_spacemit::q4_1_32x256_q8_0;
1304
+ // }
1305
+
1306
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1307
+ return &ggml::cpu::riscv64_spacemit::q4_1_32x32_q8_0;
1308
+ }
1309
+ #endif
786
1310
 
787
- mean_square = 1.0f / mean_square;
788
- length = hidden_size;
789
- p_temp_output = p_output;
1311
+ #if defined(RISCV64_SPACEMIT_IME1)
1312
+ if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) {
1313
+ return &ggml::cpu::riscv64_spacemit::q4_1_16x32_q8_0;
1314
+ }
1315
+ #endif
1316
+ }
1317
+ break;
1318
+ case GGML_TYPE_Q4_K:
1319
+ {
1320
+ #if defined(RISCV64_SPACEMIT_IME2)
1321
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1322
+ return &ggml::cpu::riscv64_spacemit::q4_k_32x32_q8_0;
1323
+ }
1324
+ #endif
790
1325
 
791
- if (p_gamma_data == nullptr && p_beta_data == nullptr) {
792
- while (length > 0) {
793
- gvl = __riscv_vsetvl_e32m4(length);
794
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
795
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
796
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
797
- p_temp_output += gvl;
798
- p_output += gvl;
799
- length -= gvl;
1326
+ #if defined(RISCV64_SPACEMIT_IME1)
1327
+ if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) {
1328
+ return &ggml::cpu::riscv64_spacemit::q4_k_16x32_q8_0;
800
1329
  }
801
- } else if (p_beta_data == nullptr) {
802
- while (length > 0) {
803
- gvl = __riscv_vsetvl_e32m4(length);
804
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
805
- vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
806
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
807
- src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
808
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
809
- p_temp_output += gvl;
810
- p_output += gvl;
811
- p_gamma_data += gvl;
812
- length -= gvl;
1330
+ #endif
1331
+ }
1332
+ break;
1333
+ case GGML_TYPE_Q6_K:
1334
+ {
1335
+ #if defined(RISCV64_SPACEMIT_IME2)
1336
+ if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1337
+ return &ggml::cpu::riscv64_spacemit::q6_k_32x32_q8_0;
813
1338
  }
814
- } else if (p_gamma_data != nullptr) {
815
- while (length > 0) {
816
- gvl = __riscv_vsetvl_e32m4(length);
817
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
818
- vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
819
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
820
- src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
821
- vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl);
822
- src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
823
- p_beta_data += gvl;
824
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
825
- p_temp_output += gvl;
826
- p_output += gvl;
827
- p_gamma_data += gvl;
828
- length -= gvl;
1339
+ #endif
1340
+ }
1341
+ break;
1342
+ case GGML_TYPE_Q8_0:
1343
+ {
1344
+ #if defined(RISCV64_SPACEMIT_IME2)
1345
+ if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1346
+ return &ggml::cpu::riscv64_spacemit::q8_0_32x32_q8_0;
829
1347
  }
1348
+ #endif
830
1349
  }
831
- }
832
- }
833
-
834
- int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
835
- memcpy(t->data, data, data_size);
836
- return 0;
837
- }
838
- };
839
-
840
- static const tensor_traits<block_q4_0, 8, 16> q4_0_16x8_q8_0;
841
- static const tensor_traits<block_q4_1, 8, 16> q4_1_16x8_q8_0;
842
- static const tensor_traits<block_q4_K, 8, 16> q4_k_16x8_q8_0;
843
- static const tensor_traits_common rvv_impl;
844
-
845
- } // namespace ggml::cpu::riscv64_spacemit
846
-
847
- static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) {
848
- if (cur->type == GGML_TYPE_Q4_0) {
849
- if (cur->ne[1] % 16 == 0) {
850
- return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0;
851
- }
852
- } else if (cur->type == GGML_TYPE_Q4_1) {
853
- if (cur->ne[1] % 16 == 0) {
854
- return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0;
855
- }
856
- } else if (cur->type == GGML_TYPE_Q4_K) {
857
- if (cur->ne[1] % 16 == 0) {
858
- return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0;
859
- }
860
- } else if (cur->type == GGML_TYPE_F32) {
861
- return &ggml::cpu::riscv64_spacemit::rvv_impl;
1350
+ break;
1351
+ case GGML_TYPE_MXFP4:
1352
+ {
1353
+ #if defined(RISCV64_SPACEMIT_IME2)
1354
+ // TODO
1355
+ // if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1356
+ // return &ggml::cpu::riscv64_spacemit::mxfp4_32x32_q8_0;
1357
+ // }
1358
+ #endif
1359
+ }
1360
+ break;
1361
+ case GGML_TYPE_Q5_K:
1362
+ {
1363
+ #if defined(RISCV64_SPACEMIT_IME2)
1364
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1365
+ return &ggml::cpu::riscv64_spacemit::q5_k_32x32_q8_0;
1366
+ }
1367
+ #endif
1368
+ }
1369
+ break;
1370
+ case GGML_TYPE_Q5_1:
1371
+ {
1372
+ #if defined(RISCV64_SPACEMIT_IME2)
1373
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1374
+ return &ggml::cpu::riscv64_spacemit::q5_1_32x32_q8_0;
1375
+ }
1376
+ #endif
1377
+ }
1378
+ break;
1379
+ case GGML_TYPE_Q5_0:
1380
+ {
1381
+ #if defined(RISCV64_SPACEMIT_IME2)
1382
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1383
+ return &ggml::cpu::riscv64_spacemit::q5_0_32x32_q8_0;
1384
+ }
1385
+ #endif
1386
+ }
1387
+ break;
1388
+ default:
1389
+ break;
862
1390
  }
863
1391
 
864
1392
  return nullptr;
865
1393
  }
866
1394
 
867
1395
  static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer,
868
- struct ggml_tensor * tensor) {
1396
+ ggml_tensor * tensor) {
869
1397
  tensor->extra =
870
1398
  (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_riscv64_spacemit_get_optimal_repack_type(tensor));
871
1399
 
@@ -874,8 +1402,46 @@ static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_ba
874
1402
  return GGML_STATUS_SUCCESS;
875
1403
  }
876
1404
 
1405
+ static void ggml_backend_riscv64_spacemit_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1406
+ GGML_ASSERT(buffer);
1407
+
1408
+ void * base = buffer->context;
1409
+ if (base == nullptr) {
1410
+ return;
1411
+ }
1412
+
1413
+ ggml::cpu::riscv64_spacemit::spine_mem_pool_free(base);
1414
+ }
1415
+
1416
+ static void * ggml_backend_riscv64_spacemit_buffer_get_base(ggml_backend_buffer_t buffer) {
1417
+ GGML_ASSERT(buffer);
1418
+
1419
+ void * base = buffer->context;
1420
+ GGML_ASSERT(base != nullptr);
1421
+ return base;
1422
+ }
1423
+
1424
+ static void ggml_backend_riscv64_spacemit_buffer_memset_tensor(ggml_backend_buffer_t buffer,
1425
+ ggml_tensor * tensor,
1426
+ uint8_t value,
1427
+ size_t offset,
1428
+ size_t size) {
1429
+ GGML_ASSERT(tensor);
1430
+ memset((char *) tensor->data + offset, value, size);
1431
+
1432
+ GGML_UNUSED(buffer);
1433
+ }
1434
+
1435
+ static void ggml_backend_riscv64_spacemit_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1436
+ GGML_ASSERT(buffer);
1437
+
1438
+ void * base = buffer->context;
1439
+ GGML_ASSERT(base != nullptr);
1440
+ memset(base, value, buffer->size);
1441
+ }
1442
+
877
1443
  static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer,
878
- struct ggml_tensor * tensor,
1444
+ ggml_tensor * tensor,
879
1445
  const void * data,
880
1446
  size_t offset,
881
1447
  size_t size) {
@@ -891,6 +1457,20 @@ static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_
891
1457
  GGML_UNUSED(buffer);
892
1458
  }
893
1459
 
1460
+ static const ggml_backend_buffer_i ggml_backend_riscv64_spacemit_buffer_i = {
1461
+ /* .free_buffer = */ ggml_backend_riscv64_spacemit_buffer_free_buffer,
1462
+ /* .get_base = */ ggml_backend_riscv64_spacemit_buffer_get_base,
1463
+ /* .init_tensor = */ ggml_backend_riscv64_spacemit_buffer_init_tensor,
1464
+ /* .memset_tensor = */ ggml_backend_riscv64_spacemit_buffer_memset_tensor,
1465
+ /* .set_tensor = */ ggml_backend_riscv64_spacemit_buffer_set_tensor,
1466
+ /* .get_tensor = */ nullptr,
1467
+ /* .set_tensor_2d = */ nullptr,
1468
+ /* .get_tensor_2d = */ nullptr,
1469
+ /* .cpy_tensor = */ nullptr,
1470
+ /* .clear = */ ggml_backend_riscv64_spacemit_buffer_clear,
1471
+ /* .reset = */ nullptr,
1472
+ };
1473
+
894
1474
  static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
895
1475
  return "CPU_RISCV64_SPACEMIT";
896
1476
 
@@ -899,18 +1479,12 @@ static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_
899
1479
 
900
1480
  static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
901
1481
  size_t size) {
902
- ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
903
-
904
- if (buffer == nullptr) {
1482
+ void * base = ggml::cpu::riscv64_spacemit::spine_mem_pool_alloc(size, 64);
1483
+ if (base == nullptr) {
905
1484
  return nullptr;
906
1485
  }
907
1486
 
908
- buffer->buft = buft;
909
- buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor;
910
- buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor;
911
- buffer->iface.get_tensor = nullptr;
912
- buffer->iface.cpy_tensor = nullptr;
913
- return buffer;
1487
+ return ggml_backend_buffer_init(buft, ggml_backend_riscv64_spacemit_buffer_i, base, size);
914
1488
  }
915
1489
 
916
1490
  static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
@@ -919,44 +1493,91 @@ static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_b
919
1493
  GGML_UNUSED(buft);
920
1494
  }
921
1495
 
922
- static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft,
923
- const struct ggml_tensor * tensor) {
1496
+ static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
924
1497
  for (int i = 0; i < GGML_MAX_DIMS; ++i) {
925
1498
  if (tensor->ne[i] <= 0) {
926
1499
  return 0;
927
1500
  }
928
1501
  }
929
1502
 
930
- size_t nbytes;
1503
+ GGML_UNUSED(buft);
1504
+
1505
+ const auto plain_nbytes = [&]() {
1506
+ size_t total = ggml_type_size(tensor->type);
1507
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
1508
+ total += (tensor->ne[i] - 1) * tensor->nb[i];
1509
+ }
1510
+ return total;
1511
+ };
1512
+
931
1513
  const size_t blck_size = ggml_blck_size(tensor->type);
932
1514
  if (blck_size == 1) {
933
- nbytes = ggml_type_size(tensor->type);
934
- for (int i = 0; i < GGML_MAX_DIMS; ++i) {
935
- nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
1515
+ return plain_nbytes();
1516
+ }
1517
+
1518
+ const size_t row_nbytes = tensor->ne[0] * tensor->nb[0] / blck_size;
1519
+
1520
+ const auto add_strided_nbytes = [&](size_t total, size_t src_block_size, size_t dst_block_size) {
1521
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
1522
+ total += (tensor->ne[i] - 1) * (tensor->nb[i] / src_block_size) * dst_block_size;
936
1523
  }
937
- } else {
938
- nbytes = tensor->ne[0] * tensor->nb[0] / blck_size;
939
- if (tensor->type == GGML_TYPE_Q4_K) {
940
- GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0);
941
- nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
942
- for (int i = 1; i < GGML_MAX_DIMS; ++i) {
943
- nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
944
- }
945
- } else {
946
- for (int i = 1; i < GGML_MAX_DIMS; ++i) {
947
- nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
948
- }
1524
+ return total;
1525
+ };
1526
+
1527
+ const auto remap_block_nbytes = [&](size_t src_block_size, size_t dst_block_size, int64_t padded_rows = 0) {
1528
+ GGML_ASSERT(row_nbytes % src_block_size == 0);
1529
+
1530
+ size_t total =
1531
+ add_strided_nbytes((row_nbytes / src_block_size) * dst_block_size, src_block_size, dst_block_size);
1532
+
1533
+ if (padded_rows > 0 && tensor->ne[1] % padded_rows != 0) {
1534
+ total += (padded_rows - tensor->ne[1] % padded_rows) * (tensor->nb[1] / src_block_size) * dst_block_size;
949
1535
  }
1536
+
1537
+ return total;
1538
+ };
1539
+
1540
+ size_t nbytes = row_nbytes;
1541
+ switch (tensor->type) {
1542
+ case GGML_TYPE_Q4_K:
1543
+ nbytes = remap_block_nbytes(sizeof(block_q4_K), sizeof(block_q4_1) * 8);
1544
+ break;
1545
+ case GGML_TYPE_Q6_K:
1546
+ nbytes = remap_block_nbytes(sizeof(block_q6_K), sizeof(block_q8_0) * 8, 32);
1547
+ break;
1548
+ case GGML_TYPE_Q8_0:
1549
+ nbytes = remap_block_nbytes(sizeof(block_q8_0), sizeof(block_q8_0), 32);
1550
+ break;
1551
+ case GGML_TYPE_Q2_K:
1552
+ nbytes = remap_block_nbytes(sizeof(block_q2_K), sizeof(spacemit_kernels::nrow_block_q2_k<1>));
1553
+ break;
1554
+ case GGML_TYPE_Q3_K:
1555
+ nbytes = remap_block_nbytes(sizeof(block_q3_K), sizeof(spacemit_kernels::nrow_block_q3_k<1>));
1556
+ break;
1557
+ case GGML_TYPE_MXFP4:
1558
+ nbytes = remap_block_nbytes(sizeof(block_mxfp4), sizeof(spacemit_kernels::nrow_block_mxfp4<1>));
1559
+ break;
1560
+ case GGML_TYPE_Q5_K:
1561
+ nbytes = remap_block_nbytes(sizeof(block_q5_K), sizeof(spacemit_kernels::nrow_block_q5_1<1>) * 8);
1562
+ break;
1563
+ case GGML_TYPE_Q5_1:
1564
+ nbytes = remap_block_nbytes(sizeof(block_q5_1), sizeof(spacemit_kernels::nrow_block_q5_1<1>));
1565
+ break;
1566
+ case GGML_TYPE_Q5_0:
1567
+ nbytes = remap_block_nbytes(sizeof(block_q5_0), sizeof(spacemit_kernels::nrow_block_q5_0<1>));
1568
+ break;
1569
+ default:
1570
+ nbytes = add_strided_nbytes(row_nbytes, 1, 1);
1571
+ break;
950
1572
  }
951
1573
 
952
- GGML_UNUSED(buft);
953
1574
  return nbytes;
954
1575
  }
955
1576
 
956
1577
  namespace ggml::cpu::riscv64_spacemit {
957
1578
 
958
1579
  class extra_buffer_type : ggml::cpu::extra_buffer_type {
959
- bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
1580
+ bool supports_op(ggml_backend_dev_t, const ggml_tensor * op) override {
960
1581
  switch (op->op) {
961
1582
  case GGML_OP_MUL_MAT:
962
1583
  if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) &&
@@ -970,10 +1591,16 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
970
1591
  }
971
1592
  }
972
1593
  break;
973
- case GGML_OP_NORM:
974
- case GGML_OP_RMS_NORM:
975
- if (op->src[0]->type == GGML_TYPE_F32) {
976
- return true;
1594
+ case GGML_OP_MUL_MAT_ID:
1595
+ if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 3) &&
1596
+ op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() &&
1597
+ ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) {
1598
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
1599
+ return false;
1600
+ }
1601
+ if (op->src[1]->type == GGML_TYPE_F32) {
1602
+ return true;
1603
+ }
977
1604
  }
978
1605
  break;
979
1606
  default:
@@ -983,15 +1610,28 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
983
1610
  return false;
984
1611
  }
985
1612
 
986
- ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
1613
+ ggml::cpu::tensor_traits * get_tensor_traits(const ggml_tensor * op) override {
987
1614
  switch (op->op) {
988
1615
  case GGML_OP_MUL_MAT:
1616
+ case GGML_OP_MUL_MAT_ID:
989
1617
  if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
990
1618
  return (ggml::cpu::tensor_traits *) op->src[0]->extra;
991
1619
  }
992
1620
  break;
993
1621
  case GGML_OP_NORM:
994
1622
  case GGML_OP_RMS_NORM:
1623
+ case GGML_OP_ADD:
1624
+ case GGML_OP_SUB:
1625
+ case GGML_OP_MUL:
1626
+ case GGML_OP_DIV:
1627
+ case GGML_OP_FLASH_ATTN_EXT:
1628
+ case GGML_OP_CONT:
1629
+ case GGML_OP_CPY:
1630
+ case GGML_OP_REPEAT:
1631
+ case GGML_OP_SUM_ROWS:
1632
+ case GGML_OP_GET_ROWS:
1633
+ case GGML_OP_CONCAT:
1634
+ // case GGML_OP_GATED_DELTA_NET:
995
1635
  return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl);
996
1636
  default:
997
1637
  // GGML_ABORT("fatal error");
@@ -1005,7 +1645,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
1005
1645
  } // namespace ggml::cpu::riscv64_spacemit
1006
1646
 
1007
1647
  ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {
1008
- static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = {
1648
+ static ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = {
1009
1649
  /* .iface = */
1010
1650
  {
1011
1651
  /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name,
@@ -1023,3 +1663,78 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {
1023
1663
 
1024
1664
  return &ggml_backend_cpu_buffer_type_riscv64_spacemit;
1025
1665
  }
1666
+
1667
+ extern "C" {
1668
+ static int bind_ai_thread() {
1669
+ int fd, bytes;
1670
+ char str[32];
1671
+
1672
+ fd = open("/proc/set_ai_thread", O_WRONLY);
1673
+ if (fd < 0) {
1674
+ GGML_LOG_ERROR("try open /proc/set_ai_thread failed\n");
1675
+ return -1;
1676
+ }
1677
+
1678
+ snprintf(str, 16, "%d", 0);
1679
+ bytes = write(fd, str, strlen(str));
1680
+ if (bytes < 0) {
1681
+ GGML_LOG_ERROR("try write /proc/set_ai_thread failed\n");
1682
+ close(fd);
1683
+ return -1;
1684
+ }
1685
+
1686
+ close(fd);
1687
+ return 0;
1688
+ }
1689
+
1690
+ void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n) {
1691
+ int cpu_id = sched_getcpu();
1692
+ if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2 &&
1693
+ !((1 << cpu_id) & ggml::cpu::riscv64_spacemit::global_spine_env_info.cpu_mask)) {
1694
+ GGML_PRINT_DEBUG("bind_ai_thread for thread %d, pid %d\n", thread_n, getpid());
1695
+ bind_ai_thread();
1696
+ }
1697
+
1698
+ if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_tcm &&
1699
+ ggml::cpu::riscv64_spacemit::tls_context.cpu_id == -1) {
1700
+ CPU_ZERO(&(ggml::cpu::riscv64_spacemit::tls_context.cpuset));
1701
+ pthread_t main_thread = pthread_self();
1702
+ const auto & perfer_core_ids = ggml::cpu::riscv64_spacemit::global_spine_env_info.perfer_core_ids;
1703
+ if (thread_n < 0 || static_cast<size_t>(thread_n) >= perfer_core_ids.size()) {
1704
+ GGML_ABORT("thread_n %d exceeds perfer_core_ids size %zu\n", thread_n, perfer_core_ids.size());
1705
+ }
1706
+ auto perfer_cpu_id = perfer_core_ids[static_cast<size_t>(thread_n)];
1707
+ CPU_SET(perfer_cpu_id, &(ggml::cpu::riscv64_spacemit::tls_context.cpuset));
1708
+ int s =
1709
+ pthread_setaffinity_np(main_thread, sizeof(cpu_set_t), &(ggml::cpu::riscv64_spacemit::tls_context.cpuset));
1710
+ if (s != 0) {
1711
+ GGML_ABORT("set thread affinity error for thread_n %d, cpu_id %d\n", thread_n, perfer_cpu_id);
1712
+ }
1713
+
1714
+ int ai_cpu_id = perfer_cpu_id - ggml::cpu::riscv64_spacemit::global_spine_env_info.aicpu_id_offset;
1715
+ ggml::cpu::riscv64_spacemit::tls_context.cpu_id = ai_cpu_id;
1716
+ ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer =
1717
+ ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_get(ai_cpu_id);
1718
+ ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size =
1719
+ ggml::cpu::riscv64_spacemit::global_spine_env_info.tcm_blk_size;
1720
+ }
1721
+
1722
+ if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) {
1723
+ void * rt =
1724
+ ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_wait(ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
1725
+ if (rt == nullptr) {
1726
+ GGML_ABORT("wait tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
1727
+ }
1728
+ }
1729
+ }
1730
+
1731
+ void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n) {
1732
+ if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) {
1733
+ auto rt = ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_release(
1734
+ ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
1735
+ if (rt != 0) {
1736
+ GGML_ABORT("release tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
1737
+ }
1738
+ }
1739
+ }
1740
+ }