whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017):
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -3,19 +3,32 @@
3
3
 
4
4
  #include "ime.h"
5
5
 
6
+ #include "binary-ops.h"
7
+ #include "common.h"
6
8
  #include "ggml-backend-impl.h"
7
9
  #include "ggml-common.h"
8
10
  #include "ggml-cpu.h"
11
+ #include "ime_env.h"
9
12
  #include "ime_kernels.h"
13
+ #include "ops.h"
14
+ #include "repack.h"
15
+ #include "rvv_kernels.h"
16
+ #include "spine_mem_pool.h"
10
17
  #include "traits.h"
18
+ #include "vec.h"
19
+
20
+ #include <fcntl.h>
21
+ #include <sys/mman.h>
22
+ #include <unistd.h>
11
23
 
12
24
  #include <algorithm>
25
+ #include <atomic>
13
26
  #include <cassert>
27
+ #include <cerrno>
14
28
  #include <cmath>
15
29
  #include <cstdio> // for GGML_ASSERT
16
30
  #include <stdexcept>
17
31
  #include <thread>
18
-
19
32
  // clang-format off
20
33
  #if defined(__riscv)
21
34
 
@@ -25,13 +38,17 @@
25
38
  #include <riscv_vector.h>
26
39
  #endif
27
40
 
28
- #if !defined(__riscv_zfh)
29
- #error "riscv zfh extension not enabled"
41
+ #if !defined(__riscv_zfh) || !defined(__riscv_zvfh)
42
+ #error "riscv zfh extension not enabled, GGML_RV_ZFH and GGML_RV_ZVFH must be defined to 1"
30
43
  #endif
31
44
 
32
- #if defined(RISCV64_SPACEMIT_IME1)
45
+ #if !defined(__riscv_zba)
46
+ #error "riscv zba extension not enabled, GGML_RV_ZBA must be defined to 1"
47
+ #endif
48
+
49
+ #if defined(RISCV64_SPACEMIT_IME1) || defined(RISCV64_SPACEMIT_IME2)
33
50
  #else
34
- #error "RISCV64_SPACEMIT_IME1 not defined"
51
+ #error "RISCV64_SPACEMIT_IME1 or RISCV64_SPACEMIT_IME2 not defined"
35
52
  #endif
36
53
 
37
54
  #else
@@ -46,382 +63,490 @@
46
63
  #pragma GCC diagnostic ignored "-Wunused-parameter"
47
64
  #endif
48
65
 
49
- #if defined(RISCV64_SPACEMIT_IME1)
50
- #define QGEMM_STRIDEN_THREAD_ALIGN 16
51
- #else
52
- #define QGEMM_STRIDEN_THREAD_ALIGN 32
53
- #endif
54
-
55
66
  // clang-format on
56
67
 
57
- struct qnbitgemm_spacemit_ime_args {
58
- const float * a_ptr = nullptr;
59
- size_t lda = 0;
60
- const std::byte * packed_quant_b_data = nullptr;
61
- const float * quant_b_scale = nullptr;
62
- const void * quant_b_zp = nullptr;
63
- const float * quant_b_blksum = nullptr;
64
- const float * bias = nullptr;
65
- float * c_ptr = nullptr;
66
- size_t ldc = 0;
67
- };
68
-
69
- constexpr size_t div_round_up(size_t up, size_t down) {
70
- return (up + down - 1) / down;
71
- }
72
-
73
- constexpr size_t q8_blk_size(size_t blk_len) {
74
- const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t);
75
- // Currently, the strictest alignment requirement of a block is for a float.
76
- // Ensure contiguous blocks are suitably aligned.
77
- assert(blk_size % alignof(float) == 0);
78
- return blk_size;
68
+ extern "C" {
69
+ extern void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value);
70
+ extern int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value);
79
71
  }
80
72
 
81
73
  namespace ggml::cpu::riscv64_spacemit {
82
74
 
83
- const int num_ai_cores = std::thread::hardware_concurrency() / 2;
84
-
85
- } // namespace ggml::cpu::riscv64_spacemit
75
+ struct TLSContext {
76
+ int cpu_id{ -1 };
77
+ cpu_set_t cpuset;
78
+ void * tcm_buffer{ nullptr };
79
+ size_t tcm_buffer_size{ 0 };
80
+ };
86
81
 
87
- static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len,
88
- const size_t gemm_k,
89
- const qnbitgemm_spacemit_ime_args * gemm_args,
90
- void * const per_gemm_ws,
91
- const size_t m_start,
92
- const size_t m_count,
93
- const size_t n_start,
94
- const size_t n_count) {
95
- constexpr size_t scale_stride = sizeof(uint16_t);
96
- constexpr size_t blk_bitwidth = 4;
82
+ thread_local TLSContext tls_context;
83
+
84
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> constexpr size_t get_repacked_block_type_size() {
85
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) {
86
+ return sizeof(block_q8_0);
87
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) {
88
+ return sizeof(block_q4_0) * INTER_SIZE / QK4_0;
89
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_1> || std::is_same_v<BLOC_TYPE, block_q4_K>) {
90
+ return (sizeof(block_q4_0) + sizeof(uint8_t)) * INTER_SIZE / QK4_1;
91
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) {
92
+ return sizeof(spacemit_kernels::nrow_block_q2_k<1>);
93
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) {
94
+ return sizeof(spacemit_kernels::nrow_block_q3_k<1>);
95
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) {
96
+ return sizeof(spacemit_kernels::nrow_block_mxfp4<1>);
97
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K>) {
98
+ return sizeof(spacemit_kernels::nrow_block_q5_1<1>);
99
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_0>) {
100
+ return sizeof(spacemit_kernels::nrow_block_q5_0<1>);
101
+ } else {
102
+ assert(false);
103
+ return 0;
104
+ }
105
+ }
97
106
 
98
- const size_t k_blks = div_round_up(gemm_k, blk_len);
107
+ template <typename BLOC_TYPE> constexpr bool block_type_has_zp() {
108
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0> ||
109
+ std::is_same_v<BLOC_TYPE, block_q3_K> || std::is_same_v<BLOC_TYPE, block_q4_0> ||
110
+ std::is_same_v<BLOC_TYPE, block_mxfp4> || std::is_same_v<BLOC_TYPE, block_q5_0>) {
111
+ return false;
112
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_1> || std::is_same_v<BLOC_TYPE, block_q4_K> ||
113
+ std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q5_1> ||
114
+ std::is_same_v<BLOC_TYPE, block_q5_K>) {
115
+ return true;
116
+ } else {
117
+ assert(false);
118
+ return false;
119
+ }
120
+ }
99
121
 
100
- const size_t lda = k_blks * q8_blk_size(blk_len);
101
- const size_t ldc = gemm_args->ldc;
102
- const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8);
103
- const std::byte * quant_a_ptr = static_cast<const std::byte *>(per_gemm_ws) + m_start * lda;
122
+ class tensor_traits_base : public ggml::cpu::tensor_traits {
123
+ public:
124
+ virtual int repack(ggml_tensor * t, const void * data, size_t data_size) = 0;
125
+ };
104
126
 
105
- const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0;
106
- const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride);
107
- const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride;
127
+ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
128
+ bool work_size(int /* n_threads */, const ggml_tensor * op, size_t & size) override {
129
+ switch (op->op) {
130
+ case GGML_OP_MUL_MAT:
131
+ {
132
+ int64_t src1_nelements = ggml_nelements(op->src[1]);
133
+
134
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q3_K>) {
135
+ size =
136
+ spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K);
137
+ } else if constexpr (INTER_SIZE == QK4_0) {
138
+ size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) *
139
+ spacemit_kernels::q8_blk_size(QK4_0, true);
140
+ } else if constexpr (INTER_SIZE == 256) {
141
+ size = spacemit_kernels::div_round_up(src1_nelements, 256) *
142
+ spacemit_kernels::q8_hp_blk_size(256, true, true);
143
+ } else {
144
+ GGML_ABORT("unsupported block type");
145
+ }
108
146
 
109
- float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start;
147
+ size = GGML_PAD(size, sizeof(int64_t));
110
148
 
111
- size_t count_n = 0;
112
- const size_t compute_block_count_n = m_count == 1 ? n_count : 16;
113
- for (size_t n = 0; n < n_count; n += count_n) {
114
- count_n = std::min(n_count - n, compute_block_count_n);
149
+ return true;
150
+ }
151
+ case GGML_OP_MUL_MAT_ID:
152
+ {
153
+ int64_t src1_nelements = ggml_nelements(op->src[1]);
154
+
155
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q3_K>) {
156
+ size =
157
+ spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K);
158
+ } else if constexpr (INTER_SIZE == QK4_0) {
159
+ size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) *
160
+ spacemit_kernels::q8_blk_size(QK4_0, true);
161
+ } else if constexpr (INTER_SIZE == 256) {
162
+ size = spacemit_kernels::div_round_up(src1_nelements, 256) *
163
+ spacemit_kernels::q8_hp_blk_size(256, true, true);
164
+ } else {
165
+ GGML_ABORT("unsupported block type");
166
+ }
115
167
 
116
- const std::byte * a_row = quant_a_ptr;
117
- const std::byte * b_col = packed_quant_b_data + n * packed_b_stride;
118
- const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr;
119
- float * c_blk = c_ptr + n;
168
+ size = GGML_PAD(size, sizeof(int64_t));
120
169
 
121
- int32_t rows_remaining = m_count;
170
+ const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
171
+ const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
122
172
 
123
- while (rows_remaining > 0) {
124
- const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4(
125
- blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr,
126
- scale_stride);
173
+ const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
174
+ size += sizeof_mmid_row_mapping * ne02 * (ne12 + 1) + (ne02 + 1) * sizeof(int64_t);
127
175
 
128
- c_blk += rows_handled * ldc;
129
- a_row += rows_handled * lda;
176
+ size = GGML_PAD(size, sizeof(int64_t));
130
177
 
131
- rows_remaining -= rows_handled;
178
+ return true;
179
+ }
180
+ default:
181
+ // GGML_ABORT("fatal error");
182
+ break;
132
183
  }
184
+ return false;
133
185
  }
134
- }
135
186
 
136
- template <int K> constexpr int QK_0() {
137
- if constexpr (K == 4) {
138
- return QK4_0;
139
- }
140
- if constexpr (K == 8) {
141
- return QK8_0;
187
+ bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override {
188
+ switch (op->op) {
189
+ case GGML_OP_MUL_MAT:
190
+ switch (op->src[0]->type) {
191
+ case GGML_TYPE_Q2_K:
192
+ case GGML_TYPE_Q3_K:
193
+ case GGML_TYPE_Q4_0:
194
+ case GGML_TYPE_Q4_1:
195
+ case GGML_TYPE_Q4_K:
196
+ case GGML_TYPE_Q6_K:
197
+ case GGML_TYPE_Q8_0:
198
+ case GGML_TYPE_Q5_1:
199
+ case GGML_TYPE_Q5_K:
200
+ //case GGML_TYPE_MXFP4:
201
+ forward_mul_mat(params, op);
202
+ return true;
203
+ default:
204
+ // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT");
205
+ return false;
206
+ }
207
+ break;
208
+ case GGML_OP_MUL_MAT_ID:
209
+ switch (op->src[0]->type) {
210
+ case GGML_TYPE_Q2_K:
211
+ case GGML_TYPE_Q3_K:
212
+ case GGML_TYPE_Q4_0:
213
+ case GGML_TYPE_Q4_1:
214
+ case GGML_TYPE_Q4_K:
215
+ case GGML_TYPE_Q6_K:
216
+ case GGML_TYPE_Q8_0:
217
+ case GGML_TYPE_Q5_1:
218
+ case GGML_TYPE_Q5_K:
219
+ //case GGML_TYPE_MXFP4:
220
+ forward_mul_mat_id(params, op);
221
+ return true;
222
+ default:
223
+ // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT_ID");
224
+ return false;
225
+ }
226
+ break;
227
+ default:
228
+ // GGML_ABORT("fatal error");
229
+ break;
230
+ }
231
+ return false;
142
232
  }
143
- return -1;
144
- }
145
233
 
146
- template <int K, int N> struct block {
147
- ggml_half d[N]; // deltas for N qK_0 blocks
148
- uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
149
- };
234
+ void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
235
+ constexpr size_t a_blk_len = INTER_SIZE;
236
+ constexpr size_t b_blk_len = INTER_SIZE;
150
237
 
151
- template <int K, int N> struct block_with_zp {
152
- ggml_half d[N]; // deltas for N qK_1 blocks
153
- uint8_t zp[N]; // zero points for N qK_1 blocks
154
- uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_1 blocks
155
- };
238
+ const ggml_tensor * src0 = op->src[0];
239
+ const ggml_tensor * src1 = op->src[1];
240
+ ggml_tensor * dst = op;
156
241
 
157
- // control size
158
- static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding");
159
- static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t),
160
- "wrong block_with_zp<4,16> size/padding");
161
- static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding");
242
+ GGML_TENSOR_BINARY_OP_LOCALS
162
243
 
163
- using block_q4_0x16 = block<4, 16>;
164
- using block_q4_1x16 = block_with_zp<4, 16>;
165
- using block_q8_0x16 = block<8, 16>;
244
+ int ith = params->ith;
245
+ int nth = params->nth;
166
246
 
167
- static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) {
168
- block_q4_0x16 out;
169
- GGML_ASSERT(QK4_0 / blck_size_interleave == 2);
247
+ [[maybe_unused]] const enum ggml_type type = src0->type;
170
248
 
171
- for (int i = 0; i < 16; i++) {
172
- out.d[i] = in[i].d;
173
- }
249
+ void * w_data = (void *) src0->data;
250
+ const float * feature = (const float *) src1->data;
251
+ float * output = (float *) dst->data;
174
252
 
175
- for (int i = 0; i < 16; i++) {
176
- // [0, 15], in.d & 0x0F
177
- for (int j = 0; j < QK4_0 / 4; j++) {
178
- //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
179
- //dst [b0 b8] ......... [b7 b15]
180
- out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4);
253
+ const int64_t gemm_m = ne11 * ne12 * ne13;
254
+ const int64_t gemm_k = ne10;
255
+ const int64_t gemm_n = ne01;
256
+
257
+ spacemit_kernels::quantize_a_row_def quantize_a_row_i8;
258
+ spacemit_kernels::quantize_a_row_def quantize_a_4row_i8;
259
+ spacemit_kernels::gemm_kernel_quantize_def gemm_kernel;
260
+ bool set_kernel_impl = false;
261
+
262
+ int64_t block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len);
263
+
264
+ #if defined(RISCV64_SPACEMIT_IME2)
265
+ if (!set_kernel_impl && (global_spine_env_info.use_ime2)) {
266
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
267
+ quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8;
268
+ block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true);
269
+
270
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) {
271
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8;
272
+ set_kernel_impl = true;
273
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
274
+ std::is_same_v<BLOC_TYPE, block_q4_K>) {
275
+ if constexpr (INTER_SIZE == 256) {
276
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp;
277
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp;
278
+ quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8_hp;
279
+ block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true);
280
+ set_kernel_impl = true;
281
+ } else {
282
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4;
283
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
284
+ quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8;
285
+ block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true);
286
+ set_kernel_impl = true;
287
+ }
288
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) {
289
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
290
+ quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k;
291
+ block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
292
+
293
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k;
294
+ set_kernel_impl = true;
295
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) {
296
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
297
+ quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k;
298
+ block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
299
+
300
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k;
301
+ set_kernel_impl = true;
302
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) {
303
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4;
304
+ set_kernel_impl = true;
305
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K> ||
306
+ std::is_same_v<BLOC_TYPE, block_q5_0>) {
307
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5;
308
+ set_kernel_impl = true;
309
+ }
181
310
  }
182
- }
311
+ #endif
183
312
 
184
- for (int i = 0; i < 16; i++) {
185
- // [16, 31], in.d & 0xF0
186
- for (int j = 0; j < QK4_0 / 4; j++) {
187
- //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
188
- //dst [b16 b24] ......... [b23 b31]
189
- out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0);
313
+ #if defined(RISCV64_SPACEMIT_IME1)
314
+ if (!set_kernel_impl && (global_spine_env_info.use_ime1)) {
315
+ quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8;
316
+ quantize_a_4row_i8 = spacemit_kernels::ime1::quantize_a_4row_i8;
317
+
318
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
319
+ std::is_same_v<BLOC_TYPE, block_q4_K>) {
320
+ gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4;
321
+ set_kernel_impl = true;
322
+ }
323
+ }
324
+ #endif
325
+ if (!set_kernel_impl) {
326
+ GGML_ABORT("no kernel implementation found for the block type");
190
327
  }
191
- }
192
328
 
193
- return out;
194
- }
329
+ const int64_t a_k_blks = spacemit_kernels::div_round_up(gemm_k, a_blk_len);
330
+ const int64_t b_k_blks = spacemit_kernels::div_round_up(gemm_k, b_blk_len);
195
331
 
196
- static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) {
197
- block_q4_1x16 out;
198
- GGML_ASSERT(QK4_1 / blck_size_interleave == 2);
199
-
200
- for (int i = 0; i < 16; i++) {
201
- float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
202
- float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m);
203
- float mid = -std::nearbyintf(m / d);
204
- mid = std::min(15.0f, std::max(0.0f, mid));
205
- out.d[i] = GGML_FP32_TO_FP16(d);
206
- out.zp[i] = static_cast<uint8_t>(mid);
207
- }
332
+ const int64_t row_stride_a = a_k_blks * block_stride_a;
333
+ const int64_t gemm_workspace_size = GGML_PAD(gemm_m * row_stride_a, alignof(int64_t));
208
334
 
209
- for (int i = 0; i < 16; i++) {
210
- // [0, 15], in.d & 0x0F
211
- for (int j = 0; j < QK4_1 / 4; j++) {
212
- //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
213
- //dst [b0 b8] ......... [b7 b15]
214
- out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4);
335
+ if (ith == 0 && params->wsize < gemm_workspace_size) {
336
+ GGML_ABORT("wsize less than gemm_workspace_size");
215
337
  }
216
- }
217
338
 
218
- for (int i = 0; i < 16; i++) {
219
- // [16, 31], in.d & 0xF0
220
- for (int j = 0; j < QK4_1 / 4; j++) {
221
- //src [b0 b16] ......... [b8 b24] ......... [b15 b31]
222
- //dst [b16 b24] ......... [b23 b31]
223
- out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0);
224
- }
225
- }
339
+ uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
226
340
 
227
- return out;
228
- }
341
+ void * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer;
342
+ const int64_t tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size;
229
343
 
230
- static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t,
231
- int interleave_block,
232
- const void * GGML_RESTRICT data,
233
- size_t data_size) {
234
- GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
235
- GGML_ASSERT(interleave_block == 16);
344
+ auto * quant_a_buffer = reinterpret_cast<uint8_t *>(ws_ptr);
236
345
 
237
- constexpr int nrows_interleaved = 16;
346
+ constexpr int64_t row_align = 4;
347
+ const int64_t row_blks = spacemit_kernels::div_round_up(gemm_m, row_align);
238
348
 
239
- block_q4_0x16 * dst = (block_q4_0x16 *) t->data;
240
- const block_q4_0 * src = (const block_q4_0 *) data;
241
- block_q4_0 dst_tmp[16];
242
- int nrow = ggml_nrows(t);
243
- int nblocks = t->ne[0] / QK4_0;
349
+ const int64_t row_stride_b = b_k_blks * get_repacked_block_type_size<BLOC_TYPE, INTER_SIZE, NB_COLS>();
350
+ const int64_t per_mb_rows_wsize = row_align * row_stride_a;
351
+ const int64_t per_nb_cols_wsize = NB_COLS * row_stride_b;
244
352
 
245
- GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
353
+ const int64_t barrier_idx = static_cast<int64_t>(ith / 2);
246
354
 
247
- if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) {
248
- return -1;
249
- }
355
+ GGML_ASSERT(global_spine_env_info.init_barrier != nullptr);
356
+ GGML_ASSERT(barrier_idx < spine_init_barrier_count);
357
+ spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx];
250
358
 
251
- for (int b = 0; b < nrow; b += nrows_interleaved) {
252
- for (int64_t x = 0; x < nblocks; x++) {
253
- for (int i = 0; i < nrows_interleaved; i++) {
254
- dst_tmp[i] = src[x + i * nblocks];
359
+ if (gemm_m == 1) {
360
+ int task_per_thread = spacemit_kernels::div_round_up(a_k_blks, nth);
361
+ int a_blk_start = ith * task_per_thread;
362
+ int a_blk_end = std::min(a_blk_start + task_per_thread, (int) a_k_blks);
363
+ if (a_blk_start < a_blk_end) {
364
+ quantize_a_row_i8(a_blk_len, feature + a_blk_start * a_blk_len, (a_blk_end - a_blk_start) * a_blk_len,
365
+ quant_a_buffer + a_blk_start * block_stride_a);
366
+ }
367
+ } else {
368
+ int task_per_thread = spacemit_kernels::div_round_up(row_blks, nth);
369
+ int m_row_blk_start = ith * task_per_thread;
370
+ int m_row_blk_end = std::min(m_row_blk_start + task_per_thread, (int) row_blks);
371
+ for (int m_row_blk = m_row_blk_start; m_row_blk < m_row_blk_end; m_row_blk++) {
372
+ int m_idx = m_row_blk * row_align;
373
+ int rows_tobe_handled = (gemm_m - m_idx) > row_align ? row_align : (gemm_m - m_idx);
374
+
375
+ if (rows_tobe_handled == row_align && quantize_a_4row_i8 != nullptr) {
376
+ const float * a_row_ptr = feature + m_idx * gemm_k;
377
+ auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a;
378
+ quantize_a_4row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr);
379
+ } else {
380
+ while (rows_tobe_handled) {
381
+ const float * a_row_ptr = feature + m_idx * gemm_k;
382
+ auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a;
383
+ quantize_a_row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr);
384
+ rows_tobe_handled -= 1;
385
+ m_idx += 1;
386
+ }
387
+ }
255
388
  }
256
- *dst++ = make_block_q4_0x16(dst_tmp, interleave_block);
257
389
  }
258
- src += nrows_interleaved * nblocks;
259
- }
260
- return 0;
261
390
 
262
- GGML_UNUSED(data_size);
263
- }
391
+ ggml_barrier(params->threadpool);
264
392
 
265
- static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t,
266
- int interleave_block,
267
- const void * GGML_RESTRICT data,
268
- size_t data_size) {
269
- GGML_ASSERT(t->type == GGML_TYPE_Q4_1);
270
- GGML_ASSERT(interleave_block == 16);
393
+ const int64_t gemm_m_stride = gemm_n / gemm_m > 64 ? gemm_m : 16;
394
+ const int64_t gemm_m_blocked = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride);
395
+ const int64_t max_gemm_n_stride = spacemit_kernels::div_round_up(gemm_n * gemm_m_blocked, nth);
271
396
 
272
- constexpr int nrows_interleaved = 16;
397
+ int64_t gemm_n_stride = gemm_n;
398
+ if (max_gemm_n_stride < gemm_n) {
399
+ gemm_n_stride =
400
+ std::min(gemm_n_stride, spacemit_kernels::div_round_up(max_gemm_n_stride, NB_COLS) * NB_COLS);
401
+ }
273
402
 
274
- block_q4_1x16 * dst = (block_q4_1x16 *) t->data;
275
- const block_q4_1 * src = (const block_q4_1 *) data;
276
- block_q4_1 dst_tmp[16];
277
- int nrow = ggml_nrows(t);
278
- int nblocks = t->ne[0] / QK4_1;
403
+ if (gemm_n_stride == gemm_n && tcm_buffer != nullptr && per_mb_rows_wsize <= tcm_buffer_size) {
404
+ for (int64_t m_start = ith * row_align; m_start < gemm_m; m_start += row_align * nth) {
405
+ uint8_t * b_col = reinterpret_cast<uint8_t *>(w_data);
406
+ uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
279
407
 
280
- GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1));
408
+ int64_t m_row_real = std::min(gemm_m - m_start, row_align);
281
409
 
282
- if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) {
283
- return -1;
284
- }
410
+ spacemit_kernels::rvv::memcpy1d(tcm_buffer, quant_a_buffer + m_start * row_stride_a,
411
+ m_row_real * row_stride_a);
285
412
 
286
- for (int b = 0; b < nrow; b += nrows_interleaved) {
287
- for (int64_t x = 0; x < nblocks; x++) {
288
- for (int i = 0; i < nrows_interleaved; i++) {
289
- dst_tmp[i] = src[x + i * nblocks];
413
+ int64_t n_blk_real = 0;
414
+ for (int64_t ni = 0; ni < gemm_n; ni += n_blk_real, b_col += n_blk_real * row_stride_b) {
415
+ n_blk_real = std::min(gemm_n - ni, (int64_t) NB_COLS);
416
+
417
+ uint8_t * a_row_ptr = (uint8_t *) tcm_buffer;
418
+ float * c_blk = output + m_start * gemm_n + ni;
419
+
420
+ int32_t rows_remaining = m_row_real;
421
+
422
+ while (rows_remaining > 0) {
423
+ auto rows_handled = gemm_kernel(b_blk_len, a_row_ptr, b_col, b_col_zp, c_blk, rows_remaining,
424
+ n_blk_real, b_k_blks, gemm_n);
425
+
426
+ c_blk += rows_handled * gemm_n;
427
+ a_row_ptr += rows_handled * row_stride_a;
428
+
429
+ rows_remaining -= rows_handled;
430
+ }
431
+ }
290
432
  }
291
- *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
292
- }
293
- src += nrows_interleaved * nblocks;
294
- }
295
- return 0;
433
+ } else if (tcm_buffer != nullptr && per_nb_cols_wsize <= tcm_buffer_size) {
434
+ uint8_t * a_row = quant_a_buffer;
435
+ uint8_t * b_col = reinterpret_cast<uint8_t *>(tcm_buffer);
436
+ if ((gemm_workspace_size + per_nb_cols_wsize) <= tcm_buffer_size) {
437
+ a_row = (uint8_t *) tcm_buffer;
438
+ b_col = reinterpret_cast<uint8_t *>(tcm_buffer) + gemm_workspace_size;
439
+ }
440
+ uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
296
441
 
297
- GGML_UNUSED(data_size);
298
- }
442
+ int64_t ni = ith * NB_COLS;
443
+ int64_t nb_real = std::min(gemm_n - ni, NB_COLS);
299
444
 
300
- static inline void get_scale_min_k4(int j,
301
- const uint8_t * GGML_RESTRICT q,
302
- uint8_t * GGML_RESTRICT d,
303
- uint8_t * GGML_RESTRICT m) {
304
- if (j < 4) {
305
- *d = q[j] & 63;
306
- *m = q[j + 4] & 63;
307
- } else {
308
- *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
309
- *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
310
- }
311
- }
445
+ if (ith % 2 == 0 && nb_real > 0) {
446
+ spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + ni * row_stride_b,
447
+ nb_real * row_stride_b);
448
+ if (a_row != quant_a_buffer) {
449
+ spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size);
450
+ }
451
+ }
312
452
 
313
- static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t,
314
- int interleave_block,
315
- const void * GGML_RESTRICT data,
316
- size_t data_size) {
317
- GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
318
- GGML_ASSERT(interleave_block == 16);
319
- GGML_ASSERT(QK_K / QK4_1 == 8);
453
+ spine_barrier_wait(cur_barrier);
320
454
 
321
- constexpr int nrows_interleaved = 16;
455
+ if (ith % 2 != 0 && nb_real > 0) {
456
+ if (a_row != quant_a_buffer) {
457
+ spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size);
458
+ }
459
+ spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + ni * row_stride_b,
460
+ nb_real * row_stride_b);
461
+ }
322
462
 
323
- block_q4_1x16 * dst = (block_q4_1x16 *) t->data;
324
- const block_q4_K * src = (const block_q4_K *) data;
325
- block_q4_1 dst_tmp[16];
326
- int nrow = ggml_nrows(t);
327
- int nblocks = t->ne[0] / QK_K;
463
+ for (; ni < gemm_n; ni += NB_COLS * nth) {
464
+ int64_t rows_remaining = gemm_m;
465
+ float * c_blk = output + ni;
466
+ auto * a_row_cur = a_row;
328
467
 
329
- if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) {
330
- return -1;
331
- }
468
+ if (ith % 2 != 0) {
469
+ spine_barrier_wait(cur_barrier);
470
+ }
332
471
 
333
- for (int b = 0; b < nrow; b += nrows_interleaved) {
334
- for (int64_t x = 0; x < nblocks; x++) {
335
- for (int j = 0; j < 8; j++) {
336
- for (int i = 0; i < nrows_interleaved; i++) {
337
- uint8_t sc, m;
338
- const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
339
- const float min =
340
- GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin);
341
- get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m);
342
- const float d1 = d * sc;
343
- const float m1 = min * m;
344
-
345
- dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1);
346
- dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1);
347
- // src -> [b0, b32] [b1, b33] ... [b31, b63]
348
- // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63]
349
- const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1;
350
- if (j % 2 == 0) {
351
- for (int ii = 0; ii < 16; ii++) {
352
- dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4);
353
- }
354
- } else {
355
- for (int ii = 0; ii < 16; ii++) {
356
- dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0);
357
- }
358
- }
472
+ while (rows_remaining > 0) {
473
+ auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining,
474
+ nb_real, b_k_blks, gemm_n);
475
+
476
+ c_blk += rows_handled * gemm_n;
477
+ a_row_cur += rows_handled * row_stride_a;
478
+
479
+ rows_remaining -= rows_handled;
480
+ }
481
+
482
+ if (ith % 2 == 0) {
483
+ spine_barrier_wait(cur_barrier);
484
+ }
485
+
486
+ const int64_t next_ni = ni + NB_COLS * nth;
487
+ if (next_ni < gemm_n) {
488
+ nb_real = std::min(gemm_n - next_ni, NB_COLS);
489
+ spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + next_ni * row_stride_b,
490
+ nb_real * row_stride_b);
359
491
  }
360
- *dst++ = make_block_q4_1x16(dst_tmp, interleave_block);
361
492
  }
362
- }
363
- src += nrows_interleaved * nblocks;
364
- }
365
- return 0;
493
+ } else {
494
+ const int64_t task_count_m = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride);
495
+ const int64_t task_count_n = spacemit_kernels::div_round_up(gemm_n, gemm_n_stride);
366
496
 
367
- GGML_UNUSED(data_size);
368
- }
497
+ int64_t task_count = task_count_m * task_count_n;
498
+ int64_t task_per_thread = (task_count + nth - 1) / nth;
499
+ int64_t start = ith * task_per_thread;
500
+ int64_t end = std::min((ith + 1) * task_per_thread, task_count);
501
+ for (int64_t compute_idx = start; compute_idx < end; compute_idx++) {
502
+ const auto tid_n = compute_idx / task_count_m;
503
+ const auto tid_m = compute_idx % task_count_m;
369
504
 
370
- namespace ggml::cpu::riscv64_spacemit {
505
+ const int64_t m_start = tid_m * gemm_m_stride;
506
+ const int64_t m_count = std::min(gemm_m - m_start, (int64_t) gemm_m_stride);
371
507
 
372
- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
373
- int repack(struct ggml_tensor *, const void *, size_t);
508
+ const int64_t n_start = tid_n * gemm_n_stride;
509
+ const int64_t n_count = std::min(gemm_n - n_start, (int64_t) gemm_n_stride);
374
510
 
375
- template <> int repack<block_q4_0, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
376
- return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size);
377
- }
511
+ const int64_t n_blk = m_count == 1 ? n_count : NB_COLS;
378
512
 
379
- template <> int repack<block_q4_1, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
380
- return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size);
381
- }
513
+ uint8_t * b_col = reinterpret_cast<uint8_t *>(w_data) + n_start * row_stride_b;
514
+ uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
382
515
 
383
- template <> int repack<block_q4_K, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
384
- return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size);
385
- }
516
+ int64_t n_blk_real = 0;
517
+ for (int64_t ni = 0; ni < n_count; ni += n_blk_real, b_col += n_blk_real * row_stride_b) {
518
+ n_blk_real = std::min(n_count - ni, n_blk);
386
519
 
387
- class tensor_traits_base : public ggml::cpu::tensor_traits {
388
- public:
389
- virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
390
- };
520
+ uint8_t * a_row = quant_a_buffer + m_start * row_stride_a;
391
521
 
392
- template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
393
- bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
394
- switch (op->op) {
395
- case GGML_OP_MUL_MAT:
396
- size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4;
397
- size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float));
398
- return true;
399
- default:
400
- // GGML_ABORT("fatal error");
401
- break;
402
- }
403
- return false;
404
- }
522
+ float * c_blk = output + m_start * gemm_n + n_start + ni;
405
523
 
406
- bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
407
- switch (op->op) {
408
- case GGML_OP_MUL_MAT:
409
- if (op->src[0]->type == GGML_TYPE_Q4_0 || //
410
- op->src[0]->type == GGML_TYPE_Q4_1 || //
411
- op->src[0]->type == GGML_TYPE_Q4_K) {
412
- forward_mul_mat_q4(params, op);
413
- return true;
524
+ int64_t rows_remaining = m_count;
525
+
526
+ uint8_t * b_col_cur = b_col;
527
+ uint8_t * b_col_zp_cur = b_col_zp;
528
+
529
+ while (rows_remaining > 0) {
530
+ auto rows_handled = gemm_kernel(b_blk_len, a_row, b_col_cur, b_col_zp_cur, c_blk,
531
+ rows_remaining, n_blk_real, b_k_blks, gemm_n);
532
+
533
+ c_blk += rows_handled * gemm_n;
534
+ a_row += rows_handled * row_stride_a;
535
+
536
+ rows_remaining -= rows_handled;
537
+ }
414
538
  }
415
- default:
416
- // GGML_ABORT("fatal error");
417
- break;
539
+ }
418
540
  }
419
- return false;
420
541
  }
421
542
 
422
- void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) {
543
+ void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
544
+ constexpr size_t a_blk_len = INTER_SIZE;
545
+ constexpr size_t b_blk_len = INTER_SIZE;
546
+
423
547
  const ggml_tensor * src0 = op->src[0];
424
548
  const ggml_tensor * src1 = op->src[1];
549
+ const ggml_tensor * ids = op->src[2];
425
550
  ggml_tensor * dst = op;
426
551
 
427
552
  GGML_TENSOR_BINARY_OP_LOCALS
@@ -429,133 +554,381 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
429
554
  int ith = params->ith;
430
555
  int nth = params->nth;
431
556
 
432
- [[maybe_unused]] const enum ggml_type type = src0->type;
557
+ // row groups
558
+ const int n_ids = ids->ne[0]; // n_expert_used
559
+ const int n_as = ne02; // n_expert
560
+
561
+ struct mmid_row_mapping {
562
+ int32_t i1;
563
+ int32_t i2;
564
+ };
565
+
566
+ spacemit_kernels::quantize_a_row_def quantize_a_row_i8;
567
+ spacemit_kernels::gemm_kernel_quantize_def gemm_kernel;
568
+ spacemit_kernels::moe_gemm_kernel_quantize_def moe_gemm_kernel_m2;
569
+ bool set_kernel_impl = false;
570
+ size_t block_stride_a = spacemit_kernels::q8_blk_size(QK4_0);
571
+
572
+ #if defined(RISCV64_SPACEMIT_IME2)
573
+ if (!set_kernel_impl && (global_spine_env_info.use_ime2)) {
574
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
575
+ block_stride_a = spacemit_kernels::q8_blk_size(QK4_0, true);
576
+
577
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) {
578
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8;
579
+ set_kernel_impl = true;
580
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
581
+ std::is_same_v<BLOC_TYPE, block_q4_K>) {
582
+ if constexpr (INTER_SIZE == 256) {
583
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp;
584
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp;
585
+ block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true);
586
+ set_kernel_impl = true;
587
+ } else {
588
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4;
589
+ moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i4;
590
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8;
591
+ block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true);
592
+ set_kernel_impl = true;
593
+ }
594
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) {
595
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
596
+ block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
597
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k;
598
+ set_kernel_impl = true;
599
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) {
600
+ quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k;
601
+ block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len);
602
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k;
603
+ set_kernel_impl = true;
604
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) {
605
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4;
606
+ moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8mxfp4;
607
+ set_kernel_impl = true;
608
+ } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K> ||
609
+ std::is_same_v<BLOC_TYPE, block_q5_0>) {
610
+ gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5;
611
+ moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i5;
612
+ set_kernel_impl = true;
613
+ }
614
+ }
615
+ #endif
433
616
 
434
- void * w_data = (void *) src0->data;
435
- const float * feature = (const float *) src1->data;
436
- float * output = (float *) dst->data;
617
+ #if defined(RISCV64_SPACEMIT_IME1)
618
+ if (!set_kernel_impl && (global_spine_env_info.use_ime1)) {
619
+ quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8;
620
+
621
+ if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> ||
622
+ std::is_same_v<BLOC_TYPE, block_q4_K>) {
623
+ gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4;
624
+ set_kernel_impl = true;
625
+ }
626
+ }
627
+ #endif
628
+ if (!set_kernel_impl) {
629
+ GGML_ABORT("no kernel implementation found for the block type");
630
+ }
437
631
 
438
- const size_t batch_feature = ne12 * ne13;
439
- [[maybe_unused]] const size_t batch_weight = ne02 * ne03;
440
- const size_t gemm_m = ne11;
441
- const size_t gemm_k = ne10;
442
- const size_t gemm_n = ne01;
632
+ const size_t a_k_blks = spacemit_kernels::div_round_up(ne10, a_blk_len);
633
+ const size_t b_k_blks = spacemit_kernels::div_round_up(ne10, b_blk_len);
443
634
 
444
- GGML_ASSERT(batch_weight == 1);
635
+ const size_t nbw1 = a_k_blks * block_stride_a;
636
+ const size_t nbw2 = ne11 * nbw1;
637
+ const size_t nbw3 = nbw2 * ne12;
638
+ const size_t gemm_workspace_size = GGML_PAD(nbw3, alignof(int64_t));
445
639
 
446
- const size_t block_count_k = div_round_up(gemm_k, QK4_0);
447
- const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0);
448
- const size_t per_gemm_workspace_stride =
449
- div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t);
450
- const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride;
451
- const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1;
640
+ const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
641
+ auto * quant_a_buffer = reinterpret_cast<uint8_t *>(ws_ptr);
452
642
 
453
- if (ith == 0 && params->wsize < desired_wsize) {
454
- throw std::runtime_error("wsize less than desired_wsize");
643
+ if (ne11 == 1) {
644
+ for (int64_t ii = ith; ii < ne12 * a_k_blks; ii += nth) {
645
+ int64_t i12 = ii / a_k_blks;
646
+ int64_t ak_blk_id = ii % a_k_blks;
647
+ quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12) + ak_blk_id * a_blk_len,
648
+ a_blk_len, quant_a_buffer + i12 * nbw2 + ak_blk_id * block_stride_a);
649
+ }
650
+ } else {
651
+ for (int64_t ii = ith; ii < ne12 * ne11; ii += nth) {
652
+ int64_t i12 = ii / ne11;
653
+ int64_t i11 = ii % ne11;
654
+ quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12 + i11 * nb11), ne10,
655
+ quant_a_buffer + i12 * nbw2 + i11 * nbw1);
656
+ }
455
657
  }
456
658
 
457
- std::vector<qnbitgemm_spacemit_ime_args> qnbitgemm_args(batch_feature);
659
+ #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) *ne12 + (i1)]
458
660
 
459
- for (size_t i = 0; i < batch_feature; i++) {
460
- qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i;
461
- qnbitgemm_args[i].lda = gemm_k;
462
- qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data;
463
- qnbitgemm_args[i].quant_b_scale = nullptr;
661
+ int64_t * matrix_row_counts = (int64_t *) (ws_ptr + gemm_workspace_size);
662
+ int32_t * valid_ep_count = (int32_t *) (matrix_row_counts + n_as);
663
+ int32_t * valid_act_count = (int32_t *) (valid_ep_count + 1);
664
+ int64_t * valid_matrix_row_counts = (int64_t *) (valid_act_count + 1);
665
+ mmid_row_mapping * matrix_rows = (mmid_row_mapping *) (valid_matrix_row_counts + n_as);
464
666
 
465
- if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) {
466
- qnbitgemm_args[i].quant_b_zp = nullptr;
467
- } else {
468
- qnbitgemm_args[i].quant_b_zp = w_data;
667
+ if (ith == 0) {
668
+ // initialize matrix_row_counts
669
+ memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
670
+
671
+ // group rows by src0 matrix
672
+ for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
673
+ for (int32_t id = 0; id < n_ids; ++id) {
674
+ const int32_t i02 =
675
+ *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
676
+
677
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
678
+
679
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
680
+ matrix_row_counts[i02] += 1;
681
+ }
469
682
  }
470
683
 
471
- qnbitgemm_args[i].bias = nullptr;
472
- qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i;
473
- qnbitgemm_args[i].ldc = gemm_n;
684
+ int32_t valid_ep_count_t = 0;
685
+ int32_t valid_act_count_t = 0;
686
+ for (int cur_a = 0; cur_a < n_as; ++cur_a) {
687
+ const int64_t cne1 = matrix_row_counts[cur_a];
688
+ if (cne1 == 0) {
689
+ continue;
690
+ }
691
+ valid_matrix_row_counts[valid_ep_count_t] = cur_a;
692
+ valid_act_count_t += cne1;
693
+ valid_ep_count_t += 1;
694
+ }
695
+ valid_ep_count[0] = valid_ep_count_t;
696
+ valid_act_count[0] = valid_act_count_t;
474
697
  }
475
698
 
476
- const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata);
477
- void * ws = reinterpret_cast<void *>((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1)));
478
- const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0);
699
+ const int64_t barrier_idx = static_cast<int64_t>(ith / 2);
479
700
 
480
- {
481
- constexpr size_t block_size_m = 4;
482
- size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m);
483
- int32_t task_count = batch_feature * per_gemm_block_count_m;
484
- int32_t task_per_thread = (task_count + nth - 1) / nth;
485
- int32_t start = ith * task_per_thread;
486
- int32_t end = std::min((ith + 1) * task_per_thread, task_count);
487
- for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
488
- int32_t gemm_idx = compute_idx / per_gemm_block_count_m;
489
- int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m;
490
- int32_t m_idx = block_idx_in_gemm * block_size_m;
491
- const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx];
492
- int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);
493
-
494
- if (rows_tobe_handled == block_size_m) {
495
- const float * a_row_ptr = data.a_ptr + m_idx * data.lda;
496
- std::byte * quant_a_row_ptr =
497
- static_cast<std::byte *>(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
498
- sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
499
- } else {
500
- while (rows_tobe_handled) {
501
- const float * a_row_ptr = data.a_ptr + m_idx * data.lda;
502
- std::byte * quant_a_row_ptr = static_cast<std::byte *>(ws) +
503
- gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride;
504
- sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr);
505
- rows_tobe_handled -= 1;
506
- m_idx += 1;
701
+ GGML_ASSERT(global_spine_env_info.init_barrier != nullptr);
702
+ GGML_ASSERT(barrier_idx < spine_init_barrier_count);
703
+ spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx];
704
+
705
+ ggml_barrier(params->threadpool);
706
+
707
+ const size_t row_stride_b = b_k_blks * get_repacked_block_type_size<BLOC_TYPE, INTER_SIZE, NB_COLS>();
708
+ const size_t expert_b_stride = ne01 * row_stride_b;
709
+ const size_t per_nb_cols_wsize = NB_COLS * row_stride_b;
710
+
711
+ std::array<const uint8_t *, 2> src_workspaces;
712
+ std::array<float *, 2> dst_workspaces;
713
+
714
+ auto * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer;
715
+ const auto tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size;
716
+
717
+ const auto valid_ep_count_t = valid_ep_count[0];
718
+ const auto valid_act_count_t = valid_act_count[0];
719
+
720
+ int nth_es = 1;
721
+ int nth_n = nth;
722
+
723
+ int ith_es = ith % nth_es;
724
+ int ith_n = (ith / nth_es) % nth_n;
725
+
726
+ if (valid_ep_count_t % nth == 0 && tcm_buffer != nullptr && valid_ep_count_t == n_as &&
727
+ valid_act_count_t == n_as && per_nb_cols_wsize <= tcm_buffer_size) {
728
+ for (int64_t valid_id = ith; valid_id < valid_ep_count_t; valid_id += nth) {
729
+ const int64_t cur_a = valid_matrix_row_counts[valid_id];
730
+
731
+ auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride;
732
+
733
+ mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, 0);
734
+ const int id = row_mapping.i1;
735
+ const int64_t i11 = id % ne11;
736
+ const int64_t i12 = row_mapping.i2;
737
+ const int64_t i1 = id;
738
+ const int64_t i2 = i12;
739
+
740
+ auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
741
+ float * c_blk = (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2));
742
+
743
+ uint8_t * a_row = src1_col;
744
+ uint8_t * b_col = reinterpret_cast<uint8_t *>(tcm_buffer);
745
+ if ((nbw1 + per_nb_cols_wsize) <= tcm_buffer_size) {
746
+ a_row = (uint8_t *) tcm_buffer;
747
+ b_col = reinterpret_cast<uint8_t *>(tcm_buffer) + nbw1;
748
+ }
749
+ uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr;
750
+
751
+ if (ith % 2 == 0) {
752
+ spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(src0_cur), per_nb_cols_wsize);
753
+
754
+ if (a_row != src1_col) {
755
+ spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1);
756
+ }
757
+ }
758
+
759
+ spine_barrier_wait(cur_barrier);
760
+
761
+ if (ith % 2 != 0) {
762
+ if (a_row != src1_col) {
763
+ spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1);
764
+ }
765
+
766
+ spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(src0_cur), per_nb_cols_wsize);
767
+ }
768
+
769
+ int64_t nb_real = std::min(ne01, NB_COLS);
770
+ for (int64_t ni = 0; ni < ne01; ni += NB_COLS) {
771
+ if (ith % 2 != 0) {
772
+ spine_barrier_wait(cur_barrier);
773
+ }
774
+
775
+ gemm_kernel(b_blk_len, a_row, b_col, b_col_zp, c_blk + ni, 1, nb_real, b_k_blks, ne01);
776
+
777
+ if (ith % 2 == 0) {
778
+ spine_barrier_wait(cur_barrier);
779
+ }
780
+
781
+ const int64_t next_ni = ni + NB_COLS;
782
+ if (next_ni < ne01) {
783
+ nb_real = std::min(ne01 - next_ni, NB_COLS);
784
+ spacemit_kernels::rvv::memcpy1d(
785
+ b_col, reinterpret_cast<uint8_t *>(src0_cur) + next_ni * row_stride_b, per_nb_cols_wsize);
507
786
  }
508
787
  }
509
788
  }
510
- }
789
+ } else {
790
+ for (int64_t valid_id = ith_es; valid_id < valid_ep_count_t; valid_id += nth_es) {
791
+ const int64_t cur_a = valid_matrix_row_counts[valid_id];
792
+ const int64_t cne1 = matrix_row_counts[cur_a];
511
793
 
512
- ggml_barrier(params->threadpool);
794
+ int64_t src1_cur_start = 0;
795
+ int64_t src1_cur_end = cne1;
513
796
 
514
- if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) {
515
- return;
516
- }
517
- nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores });
518
-
519
- size_t threads_per_gemm = nth / batch_feature;
520
- constexpr size_t gemm_m_stride = 128;
521
- size_t nc = gemm_n;
522
- const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride);
523
- const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm);
524
- if (max_nc < nc) {
525
- nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN);
526
- }
527
- const size_t gemm_n_stride = nc;
528
- const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride);
529
- const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride);
530
- threads_per_gemm = thread_count_m * thread_count_n;
797
+ int64_t src0_cur_start = (ith_n * ne01) / nth_n;
798
+ int64_t src0_cur_end = MIN(((ith_n + 1) * ne01) / nth_n, ne01);
531
799
 
532
- {
533
- int task_count = batch_feature * threads_per_gemm;
534
- int task_per_thread = (task_count + nth - 1) / nth;
535
- int start = ith * task_per_thread;
536
- int end = std::min((ith + 1) * task_per_thread, task_count);
537
- for (int compute_idx = start; compute_idx < end; compute_idx++) {
538
- const auto gemm_i = compute_idx / threads_per_gemm;
539
- const auto blk_i = compute_idx % threads_per_gemm;
540
- const auto * data = &qnbitgemm_args[gemm_i];
800
+ if (src1_cur_start >= src1_cur_end || src0_cur_start >= src0_cur_end) {
801
+ continue;
802
+ }
803
+
804
+ src0_cur_start =
805
+ (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
806
+ src0_cur_end =
807
+ (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
808
+
809
+ auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride + src0_cur_start * row_stride_b;
810
+ uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? src0_cur : nullptr;
811
+
812
+ size_t extra_tcm_buffer_size = tcm_buffer_size;
813
+ void * extra_tcm_buffer = tcm_buffer;
814
+ if (tcm_buffer != nullptr && (src1_cur_end - src1_cur_start) >= 4 &&
815
+ (src0_cur_end - src0_cur_start) * row_stride_b <= tcm_buffer_size) {
816
+ spacemit_kernels::rvv::memcpy1d(tcm_buffer, src0_cur,
817
+ (src0_cur_end - src0_cur_start) * row_stride_b);
818
+ src0_cur = reinterpret_cast<uint8_t *>(tcm_buffer);
819
+ b_col_zp = block_type_has_zp<BLOC_TYPE>() ? src0_cur : nullptr;
820
+ extra_tcm_buffer_size -= (src0_cur_end - src0_cur_start) * row_stride_b;
821
+ extra_tcm_buffer = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(tcm_buffer) +
822
+ (src0_cur_end - src0_cur_start) * row_stride_b);
823
+ }
541
824
 
542
- const auto tid_n = blk_i / thread_count_m;
543
- const auto tid_m = blk_i % thread_count_m;
825
+ int ir1 = src1_cur_start;
544
826
 
545
- const size_t m_start = tid_m * gemm_m_stride;
546
- const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride);
827
+ if (extra_tcm_buffer_size >= nbw1 && extra_tcm_buffer != nullptr) {
828
+ int64_t quant_a_tile_size = extra_tcm_buffer_size / nbw1;
829
+ do {
830
+ quant_a_tile_size = MIN(quant_a_tile_size, src1_cur_end - ir1);
547
831
 
548
- const size_t n_start = tid_n * gemm_n_stride;
549
- const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride);
832
+ uint8_t * quant_a_tile_buffer = reinterpret_cast<uint8_t *>(extra_tcm_buffer);
550
833
 
551
- void * per_gemm_ws = reinterpret_cast<std::byte *>(ws) + gemm_i * per_gemm_workspace_stride;
834
+ int iir1 = ir1;
835
+ for (; iir1 < (ir1 + quant_a_tile_size); ++iir1) {
836
+ mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, iir1);
552
837
 
553
- sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count);
838
+ const int id = row_mapping.i1; // selected expert index
839
+
840
+ const int64_t i11 = id % ne11;
841
+ const int64_t i12 = row_mapping.i2; // row index in src1
842
+
843
+ auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
844
+ spacemit_kernels::rvv::memcpy1d(quant_a_tile_buffer, src1_col, nbw1);
845
+ quant_a_tile_buffer = quant_a_tile_buffer + nbw1;
846
+ }
847
+
848
+ quant_a_tile_buffer = reinterpret_cast<uint8_t *>(extra_tcm_buffer);
849
+ iir1 = ir1;
850
+
851
+ if (moe_gemm_kernel_m2 != nullptr) {
852
+ for (; iir1 < (ir1 + quant_a_tile_size - 1); iir1 += 2, quant_a_tile_buffer += 2 * nbw1) {
853
+ mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1);
854
+ mmid_row_mapping row_mapping_1 = MMID_MATRIX_ROW(cur_a, iir1 + 1);
855
+
856
+ src_workspaces[0] = quant_a_tile_buffer;
857
+ src_workspaces[1] = quant_a_tile_buffer + nbw1;
858
+
859
+ dst_workspaces[0] =
860
+ (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) +
861
+ src0_cur_start;
862
+ dst_workspaces[1] = (float *) ((char *) dst->data +
863
+ ((row_mapping_1.i1) * nb1 + (row_mapping_1.i2) * nb2)) +
864
+ src0_cur_start;
865
+ moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp,
866
+ dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks,
867
+ ne01);
868
+ }
869
+ }
870
+
871
+ for (; iir1 < (ir1 + quant_a_tile_size); iir1++, quant_a_tile_buffer += nbw1) {
872
+ mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1);
873
+
874
+ gemm_kernel(
875
+ b_blk_len, quant_a_tile_buffer, src0_cur, b_col_zp,
876
+ (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) +
877
+ src0_cur_start,
878
+ 1, src0_cur_end - src0_cur_start, b_k_blks, ne01);
879
+ }
880
+
881
+ ir1 += quant_a_tile_size;
882
+ } while (ir1 < src1_cur_end);
883
+ } else {
884
+ if (moe_gemm_kernel_m2 != nullptr) {
885
+ for (; ir1 < src1_cur_end - 1; ir1 += 2) {
886
+ for (int iir1 = 0; iir1 < 2; ++iir1) {
887
+ mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1 + iir1);
888
+
889
+ const int id = row_mapping.i1; // selected expert index
890
+
891
+ const int64_t i11 = id % ne11;
892
+ const int64_t i12 = row_mapping.i2; // row index in src1
893
+
894
+ const int64_t i1 = id; // selected expert index
895
+ const int64_t i2 = i12; // row
896
+
897
+ src_workspaces[iir1] = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
898
+
899
+ dst_workspaces[iir1] =
900
+ (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start;
901
+ }
902
+
903
+ moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp,
904
+ dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, ne01);
905
+ }
906
+ }
907
+
908
+ for (; ir1 < src1_cur_end; ir1++) {
909
+ mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
910
+
911
+ const int id = row_mapping.i1; // selected expert index
912
+
913
+ const int64_t i11 = id % ne11;
914
+ const int64_t i12 = row_mapping.i2; // row index in src1
915
+
916
+ const int64_t i1 = id; // selected expert index
917
+ const int64_t i2 = i12; // row
918
+
919
+ auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2);
920
+
921
+ gemm_kernel(b_blk_len, src1_col, src0_cur, b_col_zp,
922
+ (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, 1,
923
+ src0_cur_end - src0_cur_start, b_k_blks, ne01);
924
+ }
925
+ }
554
926
  }
555
927
  }
928
+ #undef MMID_MATRIX_ROW
556
929
  }
557
930
 
558
- int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
931
+ int repack(ggml_tensor * t, const void * data, size_t data_size) override {
559
932
  GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
560
933
  (int) NB_COLS, (int) INTER_SIZE);
561
934
  return ggml::cpu::riscv64_spacemit::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
@@ -563,309 +936,464 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
563
936
  };
564
937
 
565
938
  class tensor_traits_common : public tensor_traits_base {
566
- bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
939
+ bool work_size(int n_threads, const ggml_tensor * op, size_t & size) override {
567
940
  switch (op->op) {
568
- case GGML_OP_NORM:
569
- case GGML_OP_RMS_NORM:
570
- size = 0;
941
+ case GGML_OP_FLASH_ATTN_EXT:
942
+ {
943
+ const int n_tasks = n_threads;
944
+ const int64_t neq2 = op->src[0]->ne[2]; // number of query heads
945
+ const int64_t DK = op->src[1]->ne[0];
946
+ const int64_t DV = op->src[2]->ne[0]; // DV
947
+
948
+ // Tiled flash attention scratch (tile sizes defined in common.h)
949
+ // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
950
+ size_t prefill = sizeof(float) *
951
+ (GGML_FA_TILE_Q * DK + 2 * GGML_FA_TILE_Q * GGML_FA_TILE_KV + GGML_FA_TILE_Q * DV +
952
+ GGML_FA_TILE_KV * DV + GGML_FA_TILE_KV * DK) *
953
+ n_tasks;
954
+
955
+ // Decode path: n_kv_chunks = n_tasks (one chunk per thread)
956
+ // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
957
+ size_t n_chunks = n_tasks;
958
+ size_t decode = sizeof(float) * (neq2 * n_chunks * (2 + DV) + n_tasks * (DK + 2 * DV));
959
+
960
+ size = MAX(prefill, decode);
961
+ }
571
962
  return true;
572
963
  default:
573
- // GGML_ABORT("fatal error");
574
964
  break;
575
965
  }
576
966
  return false;
577
967
  }
578
968
 
579
- bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
969
+ bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override {
580
970
  switch (op->op) {
581
971
  case GGML_OP_NORM:
582
- forward_norm_f32(params, op);
583
- return true;
972
+ switch (op->src[0]->type) {
973
+ case GGML_TYPE_F32:
974
+ spacemit_kernels::rvv::forward_norm_f32(params, op);
975
+ return true;
976
+ default:
977
+ GGML_ABORT("fatal error");
978
+ }
584
979
  case GGML_OP_RMS_NORM:
585
- forward_rms_norm_f32(params, op);
980
+ switch (op->src[0]->type) {
981
+ case GGML_TYPE_F32:
982
+ spacemit_kernels::rvv::forward_rms_norm_f32(params, op);
983
+ return true;
984
+ default:
985
+ GGML_ABORT("fatal error");
986
+ }
987
+ case GGML_OP_ADD:
988
+ switch (op->src[0]->type) {
989
+ case GGML_TYPE_F32:
990
+ spacemit_kernels::rvv::forward_binary<GGML_OP_ADD, float>(params, op);
991
+ return true;
992
+ case GGML_TYPE_F16:
993
+ spacemit_kernels::rvv::forward_binary<GGML_OP_ADD, _Float16>(params, op);
994
+ return true;
995
+ default:
996
+ ggml_compute_forward_add(params, op);
997
+ return true;
998
+ }
999
+ case GGML_OP_SUB:
1000
+ switch (op->src[0]->type) {
1001
+ case GGML_TYPE_F32:
1002
+ spacemit_kernels::rvv::forward_binary<GGML_OP_SUB, float>(params, op);
1003
+ return true;
1004
+ case GGML_TYPE_F16:
1005
+ spacemit_kernels::rvv::forward_binary<GGML_OP_SUB, _Float16>(params, op);
1006
+ return true;
1007
+ default:
1008
+ ggml_compute_forward_sub(params, op);
1009
+ return true;
1010
+ }
1011
+ case GGML_OP_MUL:
1012
+ switch (op->src[0]->type) {
1013
+ case GGML_TYPE_F32:
1014
+ spacemit_kernels::rvv::forward_binary<GGML_OP_MUL, float>(params, op);
1015
+ return true;
1016
+ case GGML_TYPE_F16:
1017
+ spacemit_kernels::rvv::forward_binary<GGML_OP_MUL, _Float16>(params, op);
1018
+ return true;
1019
+ default:
1020
+ ggml_compute_forward_mul(params, op);
1021
+ return true;
1022
+ }
1023
+ case GGML_OP_DIV:
1024
+ switch (op->src[0]->type) {
1025
+ case GGML_TYPE_F32:
1026
+ spacemit_kernels::rvv::forward_binary<GGML_OP_DIV, float>(params, op);
1027
+ return true;
1028
+ case GGML_TYPE_F16:
1029
+ spacemit_kernels::rvv::forward_binary<GGML_OP_DIV, _Float16>(params, op);
1030
+ return true;
1031
+ default:
1032
+ ggml_compute_forward_div(params, op);
1033
+ return true;
1034
+ }
1035
+ case GGML_OP_FLASH_ATTN_EXT:
1036
+ forward_flash_attn_ext_f16(params, op);
1037
+ return true;
1038
+ case GGML_OP_CONT:
1039
+ {
1040
+ const ggml_tensor * src0 = op->src[0];
1041
+ if (op->type == src0->type && op->nb[0] != src0->nb[0] && op->nb[0] == src0->nb[1] &&
1042
+ op->ne[3] * op->ne[2] * op->nb[2] == src0->ne[3] * src0->ne[2] * src0->nb[2]) {
1043
+ spacemit_kernels::rvv::forward_cont_with_permute(params, op);
1044
+ } else {
1045
+ ggml_compute_forward_cont(params, op);
1046
+ }
1047
+ return true;
1048
+ }
1049
+ case GGML_OP_CPY:
1050
+ {
1051
+ const ggml_tensor * src0 = op->src[0];
1052
+ if (op->type == src0->type && op->nb[0] == src0->nb[1] && src0->nb[0] != src0->nb[1] &&
1053
+ ggml_nelements(src0) == ggml_nelements(op)) {
1054
+ spacemit_kernels::rvv::forward_cpy_with_permute(params, op);
1055
+ } else {
1056
+ ggml_compute_forward_cpy(params, op);
1057
+ }
1058
+ return true;
1059
+ }
1060
+ case GGML_OP_REPEAT:
1061
+ {
1062
+ const bool rows_equal = ggml_nrows(op->src[0]) == ggml_nrows(op);
1063
+ const bool broadcast_or_equal = op->src[0]->ne[0] == 1 || op->src[0]->ne[0] == op->ne[0];
1064
+
1065
+ if (rows_equal && broadcast_or_equal) {
1066
+ switch (op->src[0]->type) {
1067
+ case GGML_TYPE_F32:
1068
+ spacemit_kernels::rvv::forward_repeat_nrows<int32_t>(params, op);
1069
+ return true;
1070
+ case GGML_TYPE_F16:
1071
+ spacemit_kernels::rvv::forward_repeat_nrows<int16_t>(params, op);
1072
+ return true;
1073
+ default:
1074
+ break;
1075
+ }
1076
+ }
1077
+
1078
+ if (op->src[0]->ne[1] == 1 && op->src[0]->ne[0] == op->ne[0]) {
1079
+ switch (op->src[0]->type) {
1080
+ case GGML_TYPE_F32:
1081
+ spacemit_kernels::rvv::forward_repeat_dim1<int32_t>(params, op);
1082
+ return true;
1083
+ case GGML_TYPE_F16:
1084
+ spacemit_kernels::rvv::forward_repeat_dim1<int16_t>(params, op);
1085
+ return true;
1086
+ default:
1087
+ break;
1088
+ }
1089
+ }
1090
+
1091
+ ggml_compute_forward_repeat(params, op);
1092
+ }
1093
+ return true;
1094
+ case GGML_OP_SUM_ROWS:
1095
+ {
1096
+ if (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) {
1097
+ spacemit_kernels::rvv::forward_sum_rows<float>(params, op);
1098
+ } else {
1099
+ ggml_compute_forward_sum_rows(params, op);
1100
+ }
1101
+ }
1102
+ return true;
1103
+ case GGML_OP_GET_ROWS:
1104
+ {
1105
+ if (op->src[0]->type == op->type) {
1106
+ switch (op->src[0]->type) {
1107
+ case GGML_TYPE_F32:
1108
+ spacemit_kernels::rvv::forward_get_rows<int32_t>(params, op);
1109
+ return true;
1110
+ case GGML_TYPE_F16:
1111
+ spacemit_kernels::rvv::forward_get_rows<int16_t>(params, op);
1112
+ return true;
1113
+ default:
1114
+ break;
1115
+ }
1116
+ }
1117
+
1118
+ ggml_compute_forward_get_rows(params, op);
1119
+ }
586
1120
  return true;
1121
+ case GGML_OP_CONCAT:
1122
+ {
1123
+ const int32_t dim = ggml_get_op_params_i32(op, 0);
1124
+ if (dim == 0 && op->type == op->src[0]->type) {
1125
+ switch (op->src[0]->type) {
1126
+ case GGML_TYPE_F32:
1127
+ spacemit_kernels::rvv::forward_concat<int32_t>(params, op);
1128
+ return true;
1129
+ case GGML_TYPE_F16:
1130
+ spacemit_kernels::rvv::forward_concat<int16_t>(params, op);
1131
+ return true;
1132
+ default:
1133
+ break;
1134
+ }
1135
+ }
1136
+
1137
+ ggml_compute_forward_concat(params, op);
1138
+ }
1139
+ return true;
1140
+ // TODO For GGML_OP_GATED_DELTA_NET
1141
+ // case GGML_OP_GATED_DELTA_NET:
1142
+ // return true;
587
1143
  default:
588
- // GGML_ABORT("fatal error");
589
1144
  break;
590
1145
  }
591
1146
  return false;
592
1147
  }
593
1148
 
594
- void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
595
- const ggml_tensor * src0 = op->src[0];
596
- ggml_tensor * dst = op;
597
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
598
- GGML_ASSERT(src0->nb[0] == sizeof(float));
1149
+ void forward_flash_attn_ext_f16(const ggml_compute_params * params, ggml_tensor * dst) {
1150
+ const ggml_tensor * q = dst->src[0];
1151
+ const ggml_tensor * k = dst->src[1];
1152
+ const ggml_tensor * v = dst->src[2];
1153
+
1154
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
1155
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
1156
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
1157
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
1158
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
1159
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
1160
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
1161
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
1162
+
1163
+ const int64_t DK = nek0;
1164
+ const int64_t DV = nev0;
1165
+
1166
+ const bool supported_prec = (dst->op_params[3] == GGML_PREC_F32 || dst->op_params[3] == GGML_PREC_DEFAULT);
1167
+ const bool supported_types = (q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16);
1168
+ const bool supported_shape = (DK > 0 && DK <= 128 && DV > 0 && DV <= 128);
1169
+ const bool supported_vlen = (__riscv_vlenb() == 128);
1170
+
1171
+ if (!(supported_prec && supported_types && supported_shape && supported_vlen)) {
1172
+ ggml_compute_forward_flash_attn_ext(params, dst);
1173
+ return;
1174
+ }
1175
+
1176
+ // total rows in q
1177
+ const int64_t nr = neq1 * neq2 * neq3;
599
1178
 
1179
+ // rows per thread
600
1180
  const int ith = params->ith;
601
1181
  const int nth = params->nth;
602
1182
 
603
- GGML_TENSOR_UNARY_OP_LOCALS
1183
+ static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
1184
+ const bool use_tiled = !params->use_ref && (neq1 >= Q_TILE_SZ);
604
1185
 
605
- float epsilon;
606
- memcpy(&epsilon, dst->op_params, sizeof(float));
1186
+ // 4x chunks per thread
1187
+ // int nth_scaled = nth * 4;
1188
+ // int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
1189
+ // int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
607
1190
 
608
- GGML_ASSERT(epsilon > 0.0f);
1191
+ // if (nth == 1 || nchunk < nth) {
1192
+ // nchunk = nth;
1193
+ // }
609
1194
 
610
- auto * input = (float *) src0->data;
611
- auto * output = (float *) dst->data;
1195
+ int64_t nchunk = nth;
612
1196
 
613
- const auto hidden_size = ne00;
614
- const auto task_count = ne01 * ne02 * ne03;
615
- const auto task_per_thread = (task_count + nth - 1) / nth;
616
-
617
- const auto task_begin = ith * task_per_thread;
618
- const auto task_end = std::min((ith + 1) * task_per_thread, task_count);
1197
+ if (ith == 0) {
1198
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
1199
+ ggml_threadpool_chunk_set(params->threadpool, nth);
1200
+ }
619
1201
 
620
- for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
621
- auto offset = task_idx * hidden_size;
622
- auto * p_input = const_cast<float *>(input + offset);
1202
+ ggml_barrier(params->threadpool);
623
1203
 
624
- auto * p_output = output + offset;
625
- auto * p_temp_output = p_output;
626
- auto * p_gamma_data = (const float *) nullptr;
627
- auto * p_beta_data = (const float *) nullptr;
628
- size_t gvl = __riscv_vsetvlmax_e32m4();
629
- vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
630
- vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
631
- int64_t length = hidden_size;
632
- while (length > 0) {
633
- gvl = __riscv_vsetvl_e32m4(length);
634
- // load data
635
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);
1204
+ // The number of elements in each chunk
1205
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
636
1206
 
637
- sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl);
638
- sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);
1207
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
1208
+ int current_chunk = ith;
639
1209
 
640
- __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);
1210
+ while (current_chunk < nchunk) {
1211
+ const int64_t ir0 = dr * current_chunk;
1212
+ const int64_t ir1 = MIN(ir0 + dr, nr);
641
1213
 
642
- p_input += gvl;
643
- p_temp_output += gvl;
644
- length -= gvl;
1214
+ if (use_tiled) {
1215
+ spacemit_kernels::rvv::forward_flash_attn_ext_f16_tiled_vlen1024_vf16(
1216
+ params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer,
1217
+ ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size);
1218
+ } else {
1219
+ spacemit_kernels::rvv::forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(
1220
+ params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer,
1221
+ ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size);
645
1222
  }
646
1223
 
647
- gvl = __riscv_vsetvlmax_e32m1();
648
-
649
- float mean = 0.f;
650
- vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
651
- vfloat32m1_t mean_v =
652
- __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl);
653
- mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl);
654
- mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl);
655
- mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl);
656
- mean = __riscv_vfmv_f_s_f32m1_f32(mean_v);
657
- mean /= hidden_size;
658
-
659
- vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
660
- __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
661
- mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
662
- mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
663
- mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);
664
-
665
- float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
666
- mean_square /= hidden_size;
667
- mean_square = sqrt(mean_square - mean * mean + epsilon);
668
-
669
- mean_square = 1.0f / mean_square;
670
- length = hidden_size;
671
- p_temp_output = p_output;
672
-
673
- if (p_gamma_data == nullptr && p_beta_data == nullptr) {
674
- while (length > 0) {
675
- gvl = __riscv_vsetvl_e32m4(length);
676
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
677
- src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
678
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
679
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
680
- p_temp_output += gvl;
681
- p_output += gvl;
682
- length -= gvl;
683
- }
684
- } else if (p_beta_data == nullptr) {
685
- while (length > 0) {
686
- gvl = __riscv_vsetvl_e32m4(length);
687
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
688
- vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
689
- src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
690
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
691
- src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
692
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
693
- p_temp_output += gvl;
694
- p_output += gvl;
695
- p_gamma_data += gvl;
696
- length -= gvl;
697
- }
698
- } else if (p_gamma_data != nullptr) {
699
- while (length > 0) {
700
- gvl = __riscv_vsetvl_e32m4(length);
701
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
702
- vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
703
- src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl);
704
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
705
- src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
706
- vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl);
707
- src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
708
- p_beta_data += gvl;
709
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
710
- p_temp_output += gvl;
711
- p_output += gvl;
712
- p_gamma_data += gvl;
713
- length -= gvl;
714
- }
715
- }
1224
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
716
1225
  }
717
1226
  }
718
1227
 
719
- void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) {
720
- const ggml_tensor * src0 = op->src[0];
721
- ggml_tensor * dst = op;
722
- GGML_ASSERT(ggml_are_same_shape(src0, dst));
723
- GGML_ASSERT(src0->nb[0] == sizeof(float));
724
-
725
- const int ith = params->ith;
726
- const int nth = params->nth;
727
-
728
- GGML_TENSOR_UNARY_OP_LOCALS
729
-
730
- float epsilon;
731
- memcpy(&epsilon, dst->op_params, sizeof(float));
732
-
733
- GGML_ASSERT(epsilon > 0.0f);
734
-
735
- auto * input = (float *) src0->data;
736
- auto * output = (float *) dst->data;
737
-
738
- const auto hidden_size = ne00;
739
- const auto task_count = ne01 * ne02 * ne03;
740
- const auto task_per_thread = (task_count + nth - 1) / nth;
741
-
742
- const auto task_begin = ith * task_per_thread;
743
- const auto task_end = std::min((ith + 1) * task_per_thread, task_count);
744
-
745
- for (auto task_idx = task_begin; task_idx < task_end; task_idx++) {
746
- auto offset = task_idx * hidden_size;
747
- auto * p_input = const_cast<float *>(input + offset);
748
- auto * p_output = output + offset;
749
- auto * p_temp_output = p_output;
750
- auto * p_gamma_data = (const float *) nullptr;
751
- auto * p_beta_data = (const float *) nullptr;
752
-
753
- size_t gvl = __riscv_vsetvlmax_e32m4();
754
- // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl);
755
- vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl);
756
- int64_t length = hidden_size;
757
- while (length > 0) {
758
- gvl = __riscv_vsetvl_e32m4(length);
759
- // load data
760
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl);
1228
+ int repack(ggml_tensor * t, const void * data, size_t data_size) override {
1229
+ memcpy(t->data, data, data_size);
1230
+ return 0;
1231
+ }
1232
+ };
761
1233
 
762
- sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl);
1234
+ // Impl By IME1
1235
+ static const tensor_traits<block_q4_0, 32, 16> q4_0_16x32_q8_0;
1236
+ static const tensor_traits<block_q4_1, 32, 16> q4_1_16x32_q8_0;
1237
+ static const tensor_traits<block_q4_K, 32, 16> q4_k_16x32_q8_0;
1238
+ // Impl By IME2
1239
+ static const tensor_traits<block_q2_K, 256, 32> q2_k_32x256_q8_0;
1240
+ static const tensor_traits<block_q3_K, 256, 32> q3_k_32x256_q8_0;
1241
+ static const tensor_traits<block_q4_0, 32, 32> q4_0_32x32_q8_0;
1242
+ static const tensor_traits<block_q4_1, 32, 32> q4_1_32x32_q8_0;
1243
+ static const tensor_traits<block_q4_0, 256, 32> q4_0_32x256_q8_0;
1244
+ static const tensor_traits<block_q4_1, 256, 32> q4_1_32x256_q8_0;
1245
+ static const tensor_traits<block_q4_K, 32, 32> q4_k_32x32_q8_0;
1246
+ static const tensor_traits<block_q6_K, 32, 32> q6_k_32x32_q8_0;
1247
+ static const tensor_traits<block_q8_0, 32, 32> q8_0_32x32_q8_0;
1248
+ static const tensor_traits<block_mxfp4, 32, 32> mxfp4_32x32_q8_0;
1249
+ static const tensor_traits<block_q5_K, 32, 32> q5_k_32x32_q8_0;
1250
+ static const tensor_traits<block_q5_1, 32, 32> q5_1_32x32_q8_0;
1251
+ static const tensor_traits<block_q5_0, 32, 32> q5_0_32x32_q8_0;
1252
+ // Impl By RVV
1253
+ static const tensor_traits_common rvv_impl;
763
1254
 
764
- __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl);
1255
+ } // namespace ggml::cpu::riscv64_spacemit
765
1256
 
766
- p_input += gvl;
767
- p_temp_output += gvl;
768
- length -= gvl;
1257
+ static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const ggml_tensor * cur) {
1258
+ switch (cur->type) {
1259
+ case GGML_TYPE_Q2_K:
1260
+ {
1261
+ #if defined(RISCV64_SPACEMIT_IME2)
1262
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1263
+ return &ggml::cpu::riscv64_spacemit::q2_k_32x256_q8_0;
1264
+ }
1265
+ #endif
769
1266
  }
1267
+ break;
1268
+ case GGML_TYPE_Q3_K:
1269
+ {
1270
+ #if defined(RISCV64_SPACEMIT_IME2)
1271
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1272
+ return &ggml::cpu::riscv64_spacemit::q3_k_32x256_q8_0;
1273
+ }
1274
+ #endif
1275
+ }
1276
+ break;
1277
+ case GGML_TYPE_Q4_0:
1278
+ {
1279
+ #if defined(RISCV64_SPACEMIT_IME2)
1280
+ if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 &&
1281
+ (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1282
+ return &ggml::cpu::riscv64_spacemit::q4_0_32x256_q8_0;
1283
+ }
770
1284
 
771
- gvl = __riscv_vsetvlmax_e32m1();
772
-
773
- // float mean = 0.f;
774
- vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl);
775
-
776
- vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0),
777
- __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl);
778
- mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl);
779
- mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl);
780
- mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl);
781
-
782
- float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v);
783
- mean_square /= hidden_size;
1285
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1286
+ return &ggml::cpu::riscv64_spacemit::q4_0_32x32_q8_0;
1287
+ }
1288
+ #endif
784
1289
 
785
- mean_square = sqrt(mean_square + epsilon);
1290
+ #if defined(RISCV64_SPACEMIT_IME1)
1291
+ if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) {
1292
+ return &ggml::cpu::riscv64_spacemit::q4_0_16x32_q8_0;
1293
+ }
1294
+ #endif
1295
+ }
1296
+ break;
1297
+ case GGML_TYPE_Q4_1:
1298
+ {
1299
+ #if defined(RISCV64_SPACEMIT_IME2)
1300
+ // TODO
1301
+ // if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 &&
1302
+ // (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1303
+ // return &ggml::cpu::riscv64_spacemit::q4_1_32x256_q8_0;
1304
+ // }
1305
+
1306
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1307
+ return &ggml::cpu::riscv64_spacemit::q4_1_32x32_q8_0;
1308
+ }
1309
+ #endif
786
1310
 
787
- mean_square = 1.0f / mean_square;
788
- length = hidden_size;
789
- p_temp_output = p_output;
1311
+ #if defined(RISCV64_SPACEMIT_IME1)
1312
+ if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) {
1313
+ return &ggml::cpu::riscv64_spacemit::q4_1_16x32_q8_0;
1314
+ }
1315
+ #endif
1316
+ }
1317
+ break;
1318
+ case GGML_TYPE_Q4_K:
1319
+ {
1320
+ #if defined(RISCV64_SPACEMIT_IME2)
1321
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1322
+ return &ggml::cpu::riscv64_spacemit::q4_k_32x32_q8_0;
1323
+ }
1324
+ #endif
790
1325
 
791
- if (p_gamma_data == nullptr && p_beta_data == nullptr) {
792
- while (length > 0) {
793
- gvl = __riscv_vsetvl_e32m4(length);
794
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
795
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
796
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
797
- p_temp_output += gvl;
798
- p_output += gvl;
799
- length -= gvl;
1326
+ #if defined(RISCV64_SPACEMIT_IME1)
1327
+ if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) {
1328
+ return &ggml::cpu::riscv64_spacemit::q4_k_16x32_q8_0;
800
1329
  }
801
- } else if (p_beta_data == nullptr) {
802
- while (length > 0) {
803
- gvl = __riscv_vsetvl_e32m4(length);
804
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
805
- vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
806
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
807
- src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
808
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
809
- p_temp_output += gvl;
810
- p_output += gvl;
811
- p_gamma_data += gvl;
812
- length -= gvl;
1330
+ #endif
1331
+ }
1332
+ break;
1333
+ case GGML_TYPE_Q6_K:
1334
+ {
1335
+ #if defined(RISCV64_SPACEMIT_IME2)
1336
+ if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1337
+ return &ggml::cpu::riscv64_spacemit::q6_k_32x32_q8_0;
813
1338
  }
814
- } else if (p_gamma_data != nullptr) {
815
- while (length > 0) {
816
- gvl = __riscv_vsetvl_e32m4(length);
817
- vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl);
818
- vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl);
819
- src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl);
820
- src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl);
821
- vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl);
822
- src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl);
823
- p_beta_data += gvl;
824
- __riscv_vse32_v_f32m4(p_output, src_data, gvl);
825
- p_temp_output += gvl;
826
- p_output += gvl;
827
- p_gamma_data += gvl;
828
- length -= gvl;
1339
+ #endif
1340
+ }
1341
+ break;
1342
+ case GGML_TYPE_Q8_0:
1343
+ {
1344
+ #if defined(RISCV64_SPACEMIT_IME2)
1345
+ if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1346
+ return &ggml::cpu::riscv64_spacemit::q8_0_32x32_q8_0;
829
1347
  }
1348
+ #endif
830
1349
  }
831
- }
832
- }
833
-
834
- int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
835
- memcpy(t->data, data, data_size);
836
- return 0;
837
- }
838
- };
839
-
840
- static const tensor_traits<block_q4_0, 8, 16> q4_0_16x8_q8_0;
841
- static const tensor_traits<block_q4_1, 8, 16> q4_1_16x8_q8_0;
842
- static const tensor_traits<block_q4_K, 8, 16> q4_k_16x8_q8_0;
843
- static const tensor_traits_common rvv_impl;
844
-
845
- } // namespace ggml::cpu::riscv64_spacemit
846
-
847
- static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) {
848
- if (cur->type == GGML_TYPE_Q4_0) {
849
- if (cur->ne[1] % 16 == 0) {
850
- return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0;
851
- }
852
- } else if (cur->type == GGML_TYPE_Q4_1) {
853
- if (cur->ne[1] % 16 == 0) {
854
- return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0;
855
- }
856
- } else if (cur->type == GGML_TYPE_Q4_K) {
857
- if (cur->ne[1] % 16 == 0) {
858
- return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0;
859
- }
860
- } else if (cur->type == GGML_TYPE_F32) {
861
- return &ggml::cpu::riscv64_spacemit::rvv_impl;
1350
+ break;
1351
+ case GGML_TYPE_MXFP4:
1352
+ {
1353
+ #if defined(RISCV64_SPACEMIT_IME2)
1354
+ // TODO
1355
+ // if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1356
+ // return &ggml::cpu::riscv64_spacemit::mxfp4_32x32_q8_0;
1357
+ // }
1358
+ #endif
1359
+ }
1360
+ break;
1361
+ case GGML_TYPE_Q5_K:
1362
+ {
1363
+ #if defined(RISCV64_SPACEMIT_IME2)
1364
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1365
+ return &ggml::cpu::riscv64_spacemit::q5_k_32x32_q8_0;
1366
+ }
1367
+ #endif
1368
+ }
1369
+ break;
1370
+ case GGML_TYPE_Q5_1:
1371
+ {
1372
+ #if defined(RISCV64_SPACEMIT_IME2)
1373
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1374
+ return &ggml::cpu::riscv64_spacemit::q5_1_32x32_q8_0;
1375
+ }
1376
+ #endif
1377
+ }
1378
+ break;
1379
+ case GGML_TYPE_Q5_0:
1380
+ {
1381
+ #if defined(RISCV64_SPACEMIT_IME2)
1382
+ if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) {
1383
+ return &ggml::cpu::riscv64_spacemit::q5_0_32x32_q8_0;
1384
+ }
1385
+ #endif
1386
+ }
1387
+ break;
1388
+ default:
1389
+ break;
862
1390
  }
863
1391
 
864
1392
  return nullptr;
865
1393
  }
866
1394
 
867
1395
  static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer,
868
- struct ggml_tensor * tensor) {
1396
+ ggml_tensor * tensor) {
869
1397
  tensor->extra =
870
1398
  (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_riscv64_spacemit_get_optimal_repack_type(tensor));
871
1399
 
@@ -874,8 +1402,46 @@ static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_ba
874
1402
  return GGML_STATUS_SUCCESS;
875
1403
  }
876
1404
 
1405
+ static void ggml_backend_riscv64_spacemit_buffer_free_buffer(ggml_backend_buffer_t buffer) {
1406
+ GGML_ASSERT(buffer);
1407
+
1408
+ void * base = buffer->context;
1409
+ if (base == nullptr) {
1410
+ return;
1411
+ }
1412
+
1413
+ ggml::cpu::riscv64_spacemit::spine_mem_pool_free(base);
1414
+ }
1415
+
1416
+ static void * ggml_backend_riscv64_spacemit_buffer_get_base(ggml_backend_buffer_t buffer) {
1417
+ GGML_ASSERT(buffer);
1418
+
1419
+ void * base = buffer->context;
1420
+ GGML_ASSERT(base != nullptr);
1421
+ return base;
1422
+ }
1423
+
1424
+ static void ggml_backend_riscv64_spacemit_buffer_memset_tensor(ggml_backend_buffer_t buffer,
1425
+ ggml_tensor * tensor,
1426
+ uint8_t value,
1427
+ size_t offset,
1428
+ size_t size) {
1429
+ GGML_ASSERT(tensor);
1430
+ memset((char *) tensor->data + offset, value, size);
1431
+
1432
+ GGML_UNUSED(buffer);
1433
+ }
1434
+
1435
+ static void ggml_backend_riscv64_spacemit_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1436
+ GGML_ASSERT(buffer);
1437
+
1438
+ void * base = buffer->context;
1439
+ GGML_ASSERT(base != nullptr);
1440
+ memset(base, value, buffer->size);
1441
+ }
1442
+
877
1443
  static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer,
878
- struct ggml_tensor * tensor,
1444
+ ggml_tensor * tensor,
879
1445
  const void * data,
880
1446
  size_t offset,
881
1447
  size_t size) {
@@ -891,6 +1457,20 @@ static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_
891
1457
  GGML_UNUSED(buffer);
892
1458
  }
893
1459
 
1460
+ static const ggml_backend_buffer_i ggml_backend_riscv64_spacemit_buffer_i = {
1461
+ /* .free_buffer = */ ggml_backend_riscv64_spacemit_buffer_free_buffer,
1462
+ /* .get_base = */ ggml_backend_riscv64_spacemit_buffer_get_base,
1463
+ /* .init_tensor = */ ggml_backend_riscv64_spacemit_buffer_init_tensor,
1464
+ /* .memset_tensor = */ ggml_backend_riscv64_spacemit_buffer_memset_tensor,
1465
+ /* .set_tensor = */ ggml_backend_riscv64_spacemit_buffer_set_tensor,
1466
+ /* .get_tensor = */ nullptr,
1467
+ /* .set_tensor_2d = */ nullptr,
1468
+ /* .get_tensor_2d = */ nullptr,
1469
+ /* .cpy_tensor = */ nullptr,
1470
+ /* .clear = */ ggml_backend_riscv64_spacemit_buffer_clear,
1471
+ /* .reset = */ nullptr,
1472
+ };
1473
+
894
1474
  static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
895
1475
  return "CPU_RISCV64_SPACEMIT";
896
1476
 
@@ -899,18 +1479,12 @@ static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_
899
1479
 
900
1480
  static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
901
1481
  size_t size) {
902
- ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
903
-
904
- if (buffer == nullptr) {
1482
+ void * base = ggml::cpu::riscv64_spacemit::spine_mem_pool_alloc(size, 64);
1483
+ if (base == nullptr) {
905
1484
  return nullptr;
906
1485
  }
907
1486
 
908
- buffer->buft = buft;
909
- buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor;
910
- buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor;
911
- buffer->iface.get_tensor = nullptr;
912
- buffer->iface.cpy_tensor = nullptr;
913
- return buffer;
1487
+ return ggml_backend_buffer_init(buft, ggml_backend_riscv64_spacemit_buffer_i, base, size);
914
1488
  }
915
1489
 
916
1490
  static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
@@ -919,44 +1493,91 @@ static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_b
919
1493
  GGML_UNUSED(buft);
920
1494
  }
921
1495
 
922
- static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft,
923
- const struct ggml_tensor * tensor) {
1496
+ static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
924
1497
  for (int i = 0; i < GGML_MAX_DIMS; ++i) {
925
1498
  if (tensor->ne[i] <= 0) {
926
1499
  return 0;
927
1500
  }
928
1501
  }
929
1502
 
930
- size_t nbytes;
1503
+ GGML_UNUSED(buft);
1504
+
1505
+ const auto plain_nbytes = [&]() {
1506
+ size_t total = ggml_type_size(tensor->type);
1507
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
1508
+ total += (tensor->ne[i] - 1) * tensor->nb[i];
1509
+ }
1510
+ return total;
1511
+ };
1512
+
931
1513
  const size_t blck_size = ggml_blck_size(tensor->type);
932
1514
  if (blck_size == 1) {
933
- nbytes = ggml_type_size(tensor->type);
934
- for (int i = 0; i < GGML_MAX_DIMS; ++i) {
935
- nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
1515
+ return plain_nbytes();
1516
+ }
1517
+
1518
+ const size_t row_nbytes = tensor->ne[0] * tensor->nb[0] / blck_size;
1519
+
1520
+ const auto add_strided_nbytes = [&](size_t total, size_t src_block_size, size_t dst_block_size) {
1521
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
1522
+ total += (tensor->ne[i] - 1) * (tensor->nb[i] / src_block_size) * dst_block_size;
936
1523
  }
937
- } else {
938
- nbytes = tensor->ne[0] * tensor->nb[0] / blck_size;
939
- if (tensor->type == GGML_TYPE_Q4_K) {
940
- GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0);
941
- nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
942
- for (int i = 1; i < GGML_MAX_DIMS; ++i) {
943
- nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8;
944
- }
945
- } else {
946
- for (int i = 1; i < GGML_MAX_DIMS; ++i) {
947
- nbytes += (tensor->ne[i] - 1) * tensor->nb[i];
948
- }
1524
+ return total;
1525
+ };
1526
+
1527
+ const auto remap_block_nbytes = [&](size_t src_block_size, size_t dst_block_size, int64_t padded_rows = 0) {
1528
+ GGML_ASSERT(row_nbytes % src_block_size == 0);
1529
+
1530
+ size_t total =
1531
+ add_strided_nbytes((row_nbytes / src_block_size) * dst_block_size, src_block_size, dst_block_size);
1532
+
1533
+ if (padded_rows > 0 && tensor->ne[1] % padded_rows != 0) {
1534
+ total += (padded_rows - tensor->ne[1] % padded_rows) * (tensor->nb[1] / src_block_size) * dst_block_size;
949
1535
  }
1536
+
1537
+ return total;
1538
+ };
1539
+
1540
+ size_t nbytes = row_nbytes;
1541
+ switch (tensor->type) {
1542
+ case GGML_TYPE_Q4_K:
1543
+ nbytes = remap_block_nbytes(sizeof(block_q4_K), sizeof(block_q4_1) * 8);
1544
+ break;
1545
+ case GGML_TYPE_Q6_K:
1546
+ nbytes = remap_block_nbytes(sizeof(block_q6_K), sizeof(block_q8_0) * 8, 32);
1547
+ break;
1548
+ case GGML_TYPE_Q8_0:
1549
+ nbytes = remap_block_nbytes(sizeof(block_q8_0), sizeof(block_q8_0), 32);
1550
+ break;
1551
+ case GGML_TYPE_Q2_K:
1552
+ nbytes = remap_block_nbytes(sizeof(block_q2_K), sizeof(spacemit_kernels::nrow_block_q2_k<1>));
1553
+ break;
1554
+ case GGML_TYPE_Q3_K:
1555
+ nbytes = remap_block_nbytes(sizeof(block_q3_K), sizeof(spacemit_kernels::nrow_block_q3_k<1>));
1556
+ break;
1557
+ case GGML_TYPE_MXFP4:
1558
+ nbytes = remap_block_nbytes(sizeof(block_mxfp4), sizeof(spacemit_kernels::nrow_block_mxfp4<1>));
1559
+ break;
1560
+ case GGML_TYPE_Q5_K:
1561
+ nbytes = remap_block_nbytes(sizeof(block_q5_K), sizeof(spacemit_kernels::nrow_block_q5_1<1>) * 8);
1562
+ break;
1563
+ case GGML_TYPE_Q5_1:
1564
+ nbytes = remap_block_nbytes(sizeof(block_q5_1), sizeof(spacemit_kernels::nrow_block_q5_1<1>));
1565
+ break;
1566
+ case GGML_TYPE_Q5_0:
1567
+ nbytes = remap_block_nbytes(sizeof(block_q5_0), sizeof(spacemit_kernels::nrow_block_q5_0<1>));
1568
+ break;
1569
+ default:
1570
+ nbytes = add_strided_nbytes(row_nbytes, 1, 1);
1571
+ break;
950
1572
  }
951
1573
 
952
- GGML_UNUSED(buft);
953
1574
  return nbytes;
954
1575
  }
955
1576
 
956
1577
  namespace ggml::cpu::riscv64_spacemit {
957
1578
 
958
1579
  class extra_buffer_type : ggml::cpu::extra_buffer_type {
959
- bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
1580
+ bool supports_op(ggml_backend_dev_t, const ggml_tensor * op) override {
960
1581
  switch (op->op) {
961
1582
  case GGML_OP_MUL_MAT:
962
1583
  if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) &&
@@ -970,10 +1591,16 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
970
1591
  }
971
1592
  }
972
1593
  break;
973
- case GGML_OP_NORM:
974
- case GGML_OP_RMS_NORM:
975
- if (op->src[0]->type == GGML_TYPE_F32) {
976
- return true;
1594
+ case GGML_OP_MUL_MAT_ID:
1595
+ if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 3) &&
1596
+ op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() &&
1597
+ ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) {
1598
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
1599
+ return false;
1600
+ }
1601
+ if (op->src[1]->type == GGML_TYPE_F32) {
1602
+ return true;
1603
+ }
977
1604
  }
978
1605
  break;
979
1606
  default:
@@ -983,15 +1610,28 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
983
1610
  return false;
984
1611
  }
985
1612
 
986
- ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
1613
+ ggml::cpu::tensor_traits * get_tensor_traits(const ggml_tensor * op) override {
987
1614
  switch (op->op) {
988
1615
  case GGML_OP_MUL_MAT:
1616
+ case GGML_OP_MUL_MAT_ID:
989
1617
  if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) {
990
1618
  return (ggml::cpu::tensor_traits *) op->src[0]->extra;
991
1619
  }
992
1620
  break;
993
1621
  case GGML_OP_NORM:
994
1622
  case GGML_OP_RMS_NORM:
1623
+ case GGML_OP_ADD:
1624
+ case GGML_OP_SUB:
1625
+ case GGML_OP_MUL:
1626
+ case GGML_OP_DIV:
1627
+ case GGML_OP_FLASH_ATTN_EXT:
1628
+ case GGML_OP_CONT:
1629
+ case GGML_OP_CPY:
1630
+ case GGML_OP_REPEAT:
1631
+ case GGML_OP_SUM_ROWS:
1632
+ case GGML_OP_GET_ROWS:
1633
+ case GGML_OP_CONCAT:
1634
+ // case GGML_OP_GATED_DELTA_NET:
995
1635
  return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl);
996
1636
  default:
997
1637
  // GGML_ABORT("fatal error");
@@ -1005,7 +1645,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
1005
1645
  } // namespace ggml::cpu::riscv64_spacemit
1006
1646
 
1007
1647
  ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {
1008
- static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = {
1648
+ static ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = {
1009
1649
  /* .iface = */
1010
1650
  {
1011
1651
  /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name,
@@ -1023,3 +1663,78 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) {
1023
1663
 
1024
1664
  return &ggml_backend_cpu_buffer_type_riscv64_spacemit;
1025
1665
  }
1666
+
1667
+ extern "C" {
1668
+ static int bind_ai_thread() {
1669
+ int fd, bytes;
1670
+ char str[32];
1671
+
1672
+ fd = open("/proc/set_ai_thread", O_WRONLY);
1673
+ if (fd < 0) {
1674
+ GGML_LOG_ERROR("try open /proc/set_ai_thread failed\n");
1675
+ return -1;
1676
+ }
1677
+
1678
+ snprintf(str, 16, "%d", 0);
1679
+ bytes = write(fd, str, strlen(str));
1680
+ if (bytes < 0) {
1681
+ GGML_LOG_ERROR("try write /proc/set_ai_thread failed\n");
1682
+ close(fd);
1683
+ return -1;
1684
+ }
1685
+
1686
+ close(fd);
1687
+ return 0;
1688
+ }
1689
+
1690
+ void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n) {
1691
+ int cpu_id = sched_getcpu();
1692
+ if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2 &&
1693
+ !((1 << cpu_id) & ggml::cpu::riscv64_spacemit::global_spine_env_info.cpu_mask)) {
1694
+ GGML_PRINT_DEBUG("bind_ai_thread for thread %d, pid %d\n", thread_n, getpid());
1695
+ bind_ai_thread();
1696
+ }
1697
+
1698
+ if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_tcm &&
1699
+ ggml::cpu::riscv64_spacemit::tls_context.cpu_id == -1) {
1700
+ CPU_ZERO(&(ggml::cpu::riscv64_spacemit::tls_context.cpuset));
1701
+ pthread_t main_thread = pthread_self();
1702
+ const auto & perfer_core_ids = ggml::cpu::riscv64_spacemit::global_spine_env_info.perfer_core_ids;
1703
+ if (thread_n < 0 || static_cast<size_t>(thread_n) >= perfer_core_ids.size()) {
1704
+ GGML_ABORT("thread_n %d exceeds perfer_core_ids size %zu\n", thread_n, perfer_core_ids.size());
1705
+ }
1706
+ auto perfer_cpu_id = perfer_core_ids[static_cast<size_t>(thread_n)];
1707
+ CPU_SET(perfer_cpu_id, &(ggml::cpu::riscv64_spacemit::tls_context.cpuset));
1708
+ int s =
1709
+ pthread_setaffinity_np(main_thread, sizeof(cpu_set_t), &(ggml::cpu::riscv64_spacemit::tls_context.cpuset));
1710
+ if (s != 0) {
1711
+ GGML_ABORT("set thread affinity error for thread_n %d, cpu_id %d\n", thread_n, perfer_cpu_id);
1712
+ }
1713
+
1714
+ int ai_cpu_id = perfer_cpu_id - ggml::cpu::riscv64_spacemit::global_spine_env_info.aicpu_id_offset;
1715
+ ggml::cpu::riscv64_spacemit::tls_context.cpu_id = ai_cpu_id;
1716
+ ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer =
1717
+ ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_get(ai_cpu_id);
1718
+ ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size =
1719
+ ggml::cpu::riscv64_spacemit::global_spine_env_info.tcm_blk_size;
1720
+ }
1721
+
1722
+ if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) {
1723
+ void * rt =
1724
+ ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_wait(ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
1725
+ if (rt == nullptr) {
1726
+ GGML_ABORT("wait tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
1727
+ }
1728
+ }
1729
+ }
1730
+
1731
+ void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n) {
1732
+ if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) {
1733
+ auto rt = ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_release(
1734
+ ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
1735
+ if (rt != 0) {
1736
+ GGML_ABORT("release tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id);
1737
+ }
1738
+ }
1739
+ }
1740
+ }