whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -1,5 +1,5 @@
1
1
  /*
2
- * Copyright (c) 2023-2024 The ggml authors
2
+ * Copyright (c) 2023-2026 The ggml authors
3
3
  *
4
4
  * Permission is hereby granted, free of charge, to any person obtaining a copy
5
5
  * of this software and associated documentation files (the "Software"), to
@@ -25,6 +25,7 @@
25
25
  #include "ggml-impl.h"
26
26
  #include "ggml.h"
27
27
 
28
+
28
29
  #include <aclnnop/aclnn_add.h>
29
30
  #include <aclnnop/aclnn_add_rms_norm.h>
30
31
  #include <aclnnop/aclnn_addcdiv.h>
@@ -45,7 +46,9 @@
45
46
  #include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
46
47
  #include <aclnnop/aclnn_ger.h>
47
48
  #include <aclnnop/aclnn_group_norm.h>
49
+ #include <aclnnop/aclnn_gather_v2.h>
48
50
  #include <aclnnop/aclnn_grouped_matmul_v3.h>
51
+ #include <aclnnop/aclnn_scatter.h>
49
52
  #include <aclnnop/aclnn_gt_scalar.h>
50
53
  #include <aclnnop/aclnn_im2col.h>
51
54
  #include <aclnnop/aclnn_index_copy.h>
@@ -58,9 +61,11 @@
58
61
  #include <aclnnop/aclnn_mean.h>
59
62
  #include <aclnnop/aclnn_mm.h>
60
63
  #include <aclnnop/aclnn_mul.h>
64
+ #include <aclnnop/aclnn_mv.h>
61
65
  #include <aclnnop/aclnn_permute.h>
62
66
  #include <aclnnop/aclnn_pow.h>
63
67
  #include <aclnnop/aclnn_pow_tensor_tensor.h>
68
+ #include <aclnnop/aclnn_recurrent_gated_delta_rule.h>
64
69
  #include <aclnnop/aclnn_reduce_sum.h>
65
70
  #include <aclnnop/aclnn_reflection_pad1d.h>
66
71
  #include <aclnnop/aclnn_repeat.h>
@@ -68,11 +73,15 @@
68
73
  #include <aclnnop/aclnn_rms_norm.h>
69
74
  #include <aclnnop/aclnn_roll.h>
70
75
  #include <aclnnop/aclnn_softmax.h>
76
+ #include <aclnnop/aclnn_softmax_cross_entropy_with_logits.h>
71
77
  #include <aclnnop/aclnn_sub.h>
72
78
  #include <aclnnop/aclnn_sum.h>
73
79
  #include <aclnnop/aclnn_threshold.h>
74
80
  #include <aclnnop/aclnn_tril.h>
81
+ #include <aclnnop/aclnn_triangular_solve.h>
75
82
  #include <aclnnop/aclnn_triu.h>
83
+ #include <aclnnop/aclnn_logical_not.h>
84
+ #include <aclnnop/aclnn_masked_fill_scalar.h>
76
85
  #include <aclnnop/aclnn_upsample_nearest_2d.h>
77
86
  #include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
78
87
  #include <aclnnop/aclnn_zero.h>
@@ -150,6 +159,107 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
150
159
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst.get(), acl_src1.get());
151
160
  }
152
161
 
162
+ // Fused SwiGLU using aclnnSwiGlu: splits input along innermost dim, applies
163
+ // SiLU to left half, multiplies by right half.
164
+ //
165
+ // Falls back to the generic two-kernel path when src[1] != nullptr (two
166
+ // independent halves) or swapped != 0 (reversed activation order), as
167
+ // aclnnSwiGlu only handles the single interleaved tensor in standard order.
168
+ //
169
+ // CANN tiling for SwiGlu requires (storageShapeDim + viewDims) to be even.
170
+ // aclCreateTensor always uses storageShapeDim=1, so viewDims must be odd.
171
+ // We use a 3D view (1+3=4, even) to satisfy this constraint while preserving
172
+ // correct split semantics along the innermost (ne[0]) dimension.
173
+ void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
174
+ auto silu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
175
+ GGML_CANN_CALL_ACLNN_OP(ctx, Silu, acl_src, acl_dst);
176
+ };
177
+
178
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
179
+ if (dst->src[1] != nullptr || swapped != 0) {
180
+ ggml_cann_op_unary_gated(silu_fn, ctx, dst);
181
+ return;
182
+ }
183
+
184
+ // aclnnSwiGlu requires the split dim (src->ne[0]) to be even; fall back otherwise.
185
+ if (dst->src[0]->ne[0] % 2 != 0) {
186
+ ggml_cann_op_unary_gated(silu_fn, ctx, dst);
187
+ return;
188
+ }
189
+
190
+ ggml_tensor * src0 = dst->src[0];
191
+ size_t elem_size = ggml_element_size(src0);
192
+
193
+ // src0 GGML: [2*ne0, ne1, ne2, ne3] → 3D view [2*ne0, ne1, ne2*ne3]
194
+ // CANN reversed: [ne2*ne3, ne1, 2*ne0], split along CANN dim 2 (last).
195
+ int64_t ne0_x2 = src0->ne[0];
196
+ int64_t ne1 = src0->ne[1];
197
+ int64_t ne23 = src0->ne[2] * src0->ne[3];
198
+ int64_t src3d_ne[] = { ne0_x2, ne1, ne23 };
199
+ size_t src3d_nb[] = { (size_t)src0->nb[0], (size_t)src0->nb[1], (size_t)src0->nb[2] };
200
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type),
201
+ elem_size, src3d_ne, src3d_nb, 3);
202
+
203
+ // dst GGML: [ne0, ne1, ne2, ne3] → 3D view [ne0, ne1, ne2*ne3]
204
+ int64_t ne0 = dst->ne[0];
205
+ int64_t dst3d_ne[] = { ne0, ne1, ne23 };
206
+ size_t dst3d_nb[] = { (size_t)dst->nb[0], (size_t)dst->nb[1], (size_t)dst->nb[2] };
207
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
208
+ elem_size, dst3d_ne, dst3d_nb, 3);
209
+
210
+ // CANN tensor [ne23, ne1, 2*ne0]: split along CANN dim 2 (last) = 2*ne0.
211
+ GGML_CANN_CALL_ACLNN_OP(ctx, SwiGlu, acl_src.get(), (int64_t)2, acl_dst.get());
212
+ }
213
+
214
+ // Fused GeGLU using aclnnGeGluV3: splits input along ne[0] (CANN last dim),
215
+ // activates the LEFT half with GELU, multiplies by right half.
216
+ // approximate: 0=tanh, 1=none(erf). activateLeft=true matches GGML convention.
217
+ // outGelu is a required-but-discard output buffer.
218
+ //
219
+ // Falls back to the generic two-kernel path when src[1] != nullptr (two
220
+ // independent halves) or swapped != 0 (reversed activation order), as
221
+ // aclnnGeGluV3 only handles the single interleaved tensor in standard order.
222
+ void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate) {
223
+ auto gelu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
224
+ GGML_CANN_CALL_ACLNN_OP(ctx, Gelu, acl_src, acl_dst);
225
+ };
226
+
227
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
228
+ if (dst->src[1] != nullptr || swapped != 0) {
229
+ ggml_cann_op_unary_gated(gelu_fn, ctx, dst);
230
+ return;
231
+ }
232
+
233
+ // aclnnGeGluV3 requires the split dim (src->ne[0]) to be even; fall back otherwise.
234
+ if (dst->src[0]->ne[0] % 2 != 0) {
235
+ ggml_cann_op_unary_gated(gelu_fn, ctx, dst);
236
+ return;
237
+ }
238
+
239
+ ggml_tensor * src0 = dst->src[0];
240
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
241
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
242
+
243
+ // Allocate a temporary buffer for the required outGelu output (same shape as dst).
244
+ // Build contiguous strides since the pool allocation is a fresh buffer.
245
+ size_t elem_size = ggml_element_size(dst);
246
+ int64_t ne[GGML_MAX_DIMS] = { dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3] };
247
+ size_t nb[GGML_MAX_DIMS];
248
+ nb[0] = elem_size;
249
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
250
+ nb[i] = nb[i - 1] * ne[i - 1];
251
+ }
252
+ size_t gelu_out_size = nb[GGML_MAX_DIMS - 1] * ne[GGML_MAX_DIMS - 1];
253
+ ggml_cann_pool_alloc gelu_out_alloc(ctx.pool(), gelu_out_size);
254
+
255
+ acl_tensor_ptr acl_gelu_out = ggml_cann_create_tensor(
256
+ gelu_out_alloc.get(), ggml_cann_type_mapping(dst->type), elem_size, ne, nb, GGML_MAX_DIMS);
257
+ // V3 adds activateLeft param; true → Gelu(left)*right, matching GGML convention.
258
+ // GGML dim 0 → CANN last dim (index GGML_MAX_DIMS-1 = 3 for 4D tensor).
259
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeGluV3, acl_src.get(), (int64_t)(GGML_MAX_DIMS - 1), approximate, true,
260
+ acl_dst.get(), acl_gelu_out.get());
261
+ }
262
+
153
263
  /**
154
264
  * @brief Repeats elements of a tensor along each dimension according to the
155
265
  * specified repeat array.
@@ -433,6 +543,9 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
433
543
  void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
434
544
  ggml_tensor * src = dst->src[0];
435
545
 
546
+ float eps;
547
+ memcpy(&eps, dst->op_params, sizeof(float));
548
+
436
549
  acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
437
550
  acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
438
551
 
@@ -441,21 +554,33 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
441
554
  ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes);
442
555
  void * buffer = temp_buffer_allocator.get();
443
556
 
444
- int64_t div_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };
445
- size_t div_nb[GGML_MAX_DIMS];
446
- div_nb[0] = sizeof(float);
557
+ int64_t norm_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };
558
+ size_t norm_nb[GGML_MAX_DIMS];
559
+ norm_nb[0] = sizeof(float);
447
560
  for (int i = 1; i < GGML_MAX_DIMS; ++i) {
448
- div_nb[i] = div_nb[i - 1] * div_ne[i - 1];
561
+ norm_nb[i] = norm_nb[i - 1] * norm_ne[i - 1];
449
562
  }
450
- acl_tensor_ptr acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS);
563
+ acl_tensor_ptr acl_norm = ggml_cann_create_tensor(buffer, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS);
451
564
 
452
565
  std::vector<int64_t> norm_dims = { 3 };
453
566
  acl_int_array_ptr dims_array = ggml_cann_create_int_array(norm_dims.data(), norm_dims.size());
454
567
 
455
568
  float p_value = 2.0f;
456
569
  acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT);
457
- GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get());
458
- GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get());
570
+ GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_norm.get());
571
+
572
+ ggml_cann_pool_alloc clamp_buffer_allocator(ctx.pool());
573
+ acl_tensor_ptr acl_clamped;
574
+
575
+ if (eps > 0.0f) {
576
+ void * clamp_buf = clamp_buffer_allocator.alloc(n_bytes);
577
+ acl_clamped = ggml_cann_create_tensor(clamp_buf, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS);
578
+ acl_scalar_ptr eps_scalar = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT);
579
+ GGML_CANN_CALL_ACLNN_OP(ctx, ClampMin, acl_norm.get(), eps_scalar.get(), acl_clamped.get());
580
+ }
581
+
582
+ aclTensor * acl_div_input = acl_clamped ? acl_clamped.get() : acl_norm.get();
583
+ GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div_input, acl_dst.get());
459
584
  }
460
585
 
461
586
  void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -471,56 +596,30 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor *
471
596
  logits_nb[1] = logits_nb[0] * logits_ne[0];
472
597
  acl_tensor_ptr acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
473
598
 
474
- size_t log_softmax_type_size = sizeof(float);
475
- int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size;
476
- ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes);
477
- void * log_softmax_buffer = log_softmax_allocator.get();
478
-
479
- int64_t log_softmax_ne[] = { nc, nr };
480
- size_t log_softmax_nb[2];
481
- log_softmax_nb[0] = log_softmax_type_size;
482
- log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0];
483
- acl_tensor_ptr acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size,
484
- log_softmax_ne, log_softmax_nb, 2);
485
-
486
- GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits.get(), 1, acl_log_softmax.get());
487
-
488
599
  int64_t labels_ne[] = { nc, nr };
489
600
  size_t labels_nb[2];
490
601
  labels_nb[0] = ggml_type_size(src1->type);
491
602
  labels_nb[1] = labels_nb[0] * labels_ne[0];
492
603
  acl_tensor_ptr acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2);
493
604
 
494
- size_t mul_type_size = sizeof(float);
495
- int64_t mul_n_bytes = nr * nc * mul_type_size;
496
- ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes);
497
- void * mul_buffer = mul_allocator.get();
605
+ size_t loss_per_sample_type_size = sizeof(float);
606
+ int64_t loss_per_sample_n_bytes = nr * loss_per_sample_type_size;
607
+ ggml_cann_pool_alloc loss_per_sample_allocator(ctx.pool(), loss_per_sample_n_bytes);
608
+ void * loss_per_sample_buffer = loss_per_sample_allocator.get();
498
609
 
499
- int64_t mul_ne[] = { nc, nr };
500
- size_t mul_nb[2];
501
- mul_nb[0] = mul_type_size;
502
- mul_nb[1] = mul_nb[0] * mul_ne[0];
503
- acl_tensor_ptr acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2);
610
+ int64_t loss_per_sample_ne[] = { nr };
611
+ size_t loss_per_sample_nb[1];
612
+ loss_per_sample_nb[0] = loss_per_sample_type_size;
613
+ acl_tensor_ptr acl_loss_per_sample = ggml_cann_create_tensor(
614
+ loss_per_sample_buffer, ACL_FLOAT, loss_per_sample_type_size, loss_per_sample_ne, loss_per_sample_nb, 1);
504
615
 
505
- GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax.get(), acl_labels.get(), acl_mul_result.get());
616
+ size_t backprop_n_bytes = nr * nc * sizeof(float);
617
+ ggml_cann_pool_alloc backprop_allocator(ctx.pool(), backprop_n_bytes);
618
+ void * backprop_buffer = backprop_allocator.get();
619
+ acl_tensor_ptr acl_backprop = ggml_cann_create_tensor(backprop_buffer, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
506
620
 
507
- size_t sum_per_sample_type_size = sizeof(float);
508
- int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size;
509
- ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes);
510
- void * sum_per_sample_buffer = sum_per_sample_allocator.get();
511
-
512
- int64_t sum_per_sample_ne[] = { nr };
513
- size_t sum_per_sample_nb[1];
514
- sum_per_sample_nb[0] = sum_per_sample_type_size;
515
- acl_tensor_ptr acl_sum_per_sample = ggml_cann_create_tensor(
516
- sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1);
517
-
518
- std::vector<int64_t> sum_dims = { 1 };
519
- acl_int_array_ptr dims_array = ggml_cann_create_int_array(sum_dims.data(), sum_dims.size());
520
- bool keep_dims = false;
521
-
522
- GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result.get(), dims_array.get(), keep_dims, ACL_FLOAT,
523
- acl_sum_per_sample.get());
621
+ GGML_CANN_CALL_ACLNN_OP(ctx, SoftmaxCrossEntropyWithLogits, acl_logits.get(), acl_labels.get(),
622
+ acl_loss_per_sample.get(), acl_backprop.get());
524
623
 
525
624
  size_t total_sum_type_size = sizeof(float);
526
625
  int64_t total_sum_n_bytes = 1 * total_sum_type_size;
@@ -536,11 +635,12 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor *
536
635
 
537
636
  std::vector<int64_t> total_sum_dims = { 0 };
538
637
  acl_int_array_ptr total_sum_dims_array = ggml_cann_create_int_array(total_sum_dims.data(), total_sum_dims.size());
638
+ bool keep_dims = false;
539
639
 
540
- GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,
640
+ GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_loss_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,
541
641
  acl_total_sum.get());
542
642
 
543
- float value = -1.0f / static_cast<float>(nr);
643
+ float value = 1.0f / static_cast<float>(nr);
544
644
  acl_scalar_ptr scale_factor = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);
545
645
  acl_tensor_ptr acl_dst =
546
646
  ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1);
@@ -578,6 +678,33 @@ void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
578
678
  acl_mean_out.get(), acl_rstd_out.get());
579
679
  }
580
680
 
681
+ void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
682
+ ggml_tensor * src0 = dst->src[0];
683
+ ggml_tensor * src1 = dst->src[1];
684
+
685
+ size_t nb1 = ((int32_t *) dst->op_params)[0];
686
+ size_t nb2 = ((int32_t *) dst->op_params)[1];
687
+ size_t nb3 = ((int32_t *) dst->op_params)[2];
688
+ size_t offset = ((int32_t *) dst->op_params)[3];
689
+ bool inplace = (bool) ((int32_t *) dst->op_params)[4];
690
+
691
+ size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 };
692
+
693
+ // Create a view of dst at the target offset with src1's dimensions
694
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
695
+ acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1);
696
+
697
+ if (!inplace) {
698
+ // First copy src0 to dst entirely
699
+ size_t cpy_size = ggml_nbytes(dst);
700
+ ACL_CHECK(
701
+ aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
702
+ }
703
+
704
+ // Copy src1 into the target region of dst
705
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst.get(), acl_src1.get());
706
+ }
707
+
581
708
  void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
582
709
  ggml_tensor * src0 = dst->src[0];
583
710
  ggml_tensor * src1 = dst->src[1];
@@ -641,6 +768,113 @@ void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
641
768
  aclnn_reduce_sum(ctx, dst, reduce_dims, 4);
642
769
  }
643
770
 
771
+ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
772
+ ggml_tensor * src = dst->src[0];
773
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
774
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
775
+ // GGML cumsum operates along dim 0 (innermost / ne[0]).
776
+ // ggml_cann_create_tensor reverses dimensions to [ne3,ne2,ne1,ne0],
777
+ // so GGML dim 0 maps to CANN dim 3 (the last dim of the 4-D tensor).
778
+ GGML_CANN_CALL_ACLNN_OP(ctx, Cumsum, acl_src.get(), (int64_t)3,
779
+ ggml_cann_type_mapping(dst->type), acl_dst.get());
780
+ }
781
+
782
+ void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
783
+ ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower triangular
784
+ ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3]
785
+
786
+ acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0);
787
+ acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1);
788
+ acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst);
789
+
790
+ // mOut: triangular copy of A (required output), same shape as A.
791
+ const size_t a_bytes = ggml_nbytes(src0);
792
+ ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes);
793
+ acl_tensor_ptr acl_m = ggml_cann_create_tensor(
794
+ m_alloc.get(), ggml_cann_type_mapping(src0->type),
795
+ ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
796
+
797
+ // Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false.
798
+ GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve,
799
+ acl_b.get(), acl_a.get(), false, false, false,
800
+ acl_x.get(), acl_m.get());
801
+ }
802
+
803
+ void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
804
+ ggml_tensor * src = dst->src[0];
805
+
806
+ GGML_ASSERT(src->ne[1] == 1);
807
+
808
+ const int64_t N = src->ne[0];
809
+ const int64_t n_batch = src->ne[2] * src->ne[3];
810
+ const size_t nb_f32 = sizeof(float);
811
+
812
+ // Fill dst with zeros.
813
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
814
+ {
815
+ float zero = 0.0f;
816
+ acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero, ACL_FLOAT);
817
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_zero.get());
818
+ }
819
+
820
+ // Copy src vector onto the diagonal of dst via strided views.
821
+ // src viewed as [N, n_batch], contiguous strides.
822
+ int64_t ne_vec[2] = { N, n_batch };
823
+ size_t nb_src_vec[2] = { nb_f32, N * nb_f32 };
824
+ // dst diagonal view: stride (N+1)*4 steps along the diagonal.
825
+ size_t nb_dst_diag[2] = { (N + 1) * nb_f32, N * N * nb_f32 };
826
+
827
+ acl_tensor_ptr acl_src_vec = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne_vec, nb_src_vec, 2);
828
+ acl_tensor_ptr acl_dst_diag = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_dst_diag, 2);
829
+
830
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst_diag.get(), acl_src_vec.get());
831
+ }
832
+
833
+ void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
834
+ float c = ggml_get_op_params_f32(dst, 0);
835
+
836
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
837
+ acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT);
838
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get());
839
+ }
840
+
841
+ void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
842
+ ggml_tensor * src = dst->src[0];
843
+
844
+ const int64_t S = src->ne[0];
845
+ const int64_t n_batch = src->ne[2] * src->ne[3];
846
+ const size_t nb_f32 = sizeof(float);
847
+
848
+ int64_t ne3d[3] = { S, S, n_batch };
849
+ size_t nb3d[3] = { nb_f32, S * nb_f32, S * S * nb_f32 };
850
+
851
+ const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
852
+
853
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3);
854
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3);
855
+
856
+ switch (ttype) {
857
+ case GGML_TRI_TYPE_LOWER:
858
+ // Tril(-1): preserve row > col (strict lower), zero upper + diagonal.
859
+ GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)-1, acl_dst.get());
860
+ break;
861
+ case GGML_TRI_TYPE_UPPER_DIAG:
862
+ // Triu(0): preserve row <= col (upper + diagonal), zero strict lower.
863
+ GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)0, acl_dst.get());
864
+ break;
865
+ case GGML_TRI_TYPE_UPPER:
866
+ // Triu(1): preserve row < col (strict upper), zero lower + diagonal.
867
+ GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)1, acl_dst.get());
868
+ break;
869
+ case GGML_TRI_TYPE_LOWER_DIAG:
870
+ // Tril(0): preserve row >= col (lower + diagonal), zero strict upper.
871
+ GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)0, acl_dst.get());
872
+ break;
873
+ default:
874
+ GGML_ABORT("unsupported tri type");
875
+ }
876
+ }
877
+
644
878
  void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
645
879
  ggml_tensor * src = dst->src[0];
646
880
  acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
@@ -1543,8 +1777,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx,
1543
1777
  end = 2 * ((n_head - 1) - n_head_log2) + 1;
1544
1778
  step = 2;
1545
1779
  count = n_head - n_head_log2;
1546
- aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step,
1547
- dtype);
1780
+ aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1,
1781
+ step, dtype);
1548
1782
  }
1549
1783
  }
1550
1784
 
@@ -1684,150 +1918,90 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1684
1918
  aclnn_softmax(ctx, softmax_tensor.get(), 3, acl_dst.get());
1685
1919
  }
1686
1920
 
1687
- /**
1688
- * @brief Performs index select operation on a 4D tensor using the CANN backend.
1689
- *
1690
- * This function applies the `IndexSelect` operation along a specific dimension
1691
- * of the source tensor (`src_buffer`) using the indices from the index tensor (`index`).
1692
- * It iterates over the last two dimensions of the source tensor, creates the corresponding
1693
- * CANN tensors for the source, index, and output slices, and executes the `IndexSelect`
1694
- * operation for each slice.
1695
- *
1696
- * @param ctx The context for CANN backend operations.
1697
- * @param src_buffer The source buffer containing the 4D input tensor data.
1698
- * @param src_ne The dimensions of the source tensor.
1699
- * @param src_nb The strides (byte offsets) of the source tensor.
1700
- * @param dst_buffer The destination buffer where the output tensor data will be written.
1701
- * @param dst_ne The dimensions of the destination tensor.
1702
- * @param dst_nb The strides (byte offsets) of the destination tensor.
1703
- * @param index The index tensor specifying the indices to select from the source tensor.
1704
- * @param type The data type of the source and destination tensors.
1705
- */
1706
- static void aclnn_index_select_4d(ggml_backend_cann_context & ctx,
1707
- void * src_buffer,
1708
- int64_t * src_ne,
1709
- size_t * src_nb,
1710
- void * dst_buffer,
1711
- int64_t * dst_ne,
1712
- size_t * dst_nb,
1713
- ggml_tensor * index,
1714
- ggml_type type) {
1715
- for (int64_t i = 0; i < src_ne[3]; i++) {
1716
- for (int64_t j = 0; j < src_ne[2]; j++) {
1717
- // src
1718
- acl_tensor_ptr acl_src_tensor =
1719
- ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],
1720
- ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);
1721
-
1722
- // index
1723
- acl_tensor_ptr acl_index = ggml_cann_create_tensor(
1724
- (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
1725
- ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);
1726
-
1727
- // out
1728
- acl_tensor_ptr acl_out =
1729
- ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],
1730
- ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);
1731
- GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor.get(), 0, acl_index.get(), acl_out.get());
1732
- }
1733
- }
1734
- }
1735
-
1736
- /**
1737
- * @brief Performs inplace index copy operation on a 4D tensor using the CANN backend.
1738
- *
1739
- * This function applies the `IndexCopy` operation along a specific dimension of the
1740
- * destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`)
1741
- * to positions specified by the index tensor (`index`).
1742
- * It iterates over the last two dimensions of the tensors, creates the corresponding
1743
- * CANN tensors for source, index, and destination slices, and performs the index copy
1744
- * operation for each slice.
1745
- *
1746
- * @param ctx The context for CANN backend operations.
1747
- * @param src_buffer The source buffer containing the 4D input tensor data to be copied.
1748
- * @param src_ne The dimensions of the source tensor.
1749
- * @param src_nb The strides (byte offsets) of the source tensor.
1750
- * @param dst_buffer The destination buffer where values will be copied to.
1751
- * @param dst_ne The dimensions of the destination tensor.
1752
- * @param dst_nb The strides (byte offsets) of the destination tensor.
1753
- * @param index The index tensor specifying target positions in the destination tensor.
1754
- * @param type The data type of the source and destination tensors.
1755
- */
1756
- static void aclnn_index_copy_4d(ggml_backend_cann_context & ctx,
1757
- void * src_buffer,
1758
- int64_t * src_ne,
1759
- size_t * src_nb,
1760
- void * dst_buffer,
1761
- int64_t * dst_ne,
1762
- size_t * dst_nb,
1763
- ggml_tensor * index,
1764
- ggml_type type) {
1765
- for (int64_t i = 0; i < src_ne[3]; i++) {
1766
- for (int64_t j = 0; j < src_ne[2]; j++) {
1767
- // src
1768
- acl_tensor_ptr acl_src_tensor =
1769
- ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],
1770
- ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);
1771
-
1772
- // index
1773
- acl_tensor_ptr acl_index = ggml_cann_create_tensor(
1774
- (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
1775
- ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);
1776
-
1777
- // out
1778
- acl_tensor_ptr acl_out =
1779
- ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],
1780
- ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);
1781
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out.get(), 0, acl_index.get(), acl_src_tensor.get());
1782
- }
1783
- }
1784
- }
1785
1921
 
1786
1922
  void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1787
- ggml_tensor * src0 = dst->src[0]; // src
1923
+ ggml_tensor * src0 = dst->src[0]; // weight
1788
1924
  ggml_tensor * src1 = dst->src[1]; // index
1789
1925
 
1790
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1926
+ GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
1927
+ || dst->type == GGML_TYPE_BF16);
1928
+
1929
+ // n_idx: number of row indices per (i2, i3) batch slice.
1930
+ // ggml guarantees: src0->ne[2] == src1->ne[1], src0->ne[3] == src1->ne[2], src1->ne[3] == 1.
1931
+ const int64_t n_idx = src1->ne[0];
1932
+
1933
+ // Gather all (i2, i3) batch slices from src into dst.
1934
+ // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0].
1935
+ // GatherV2 with dim=0 gathers along ACL dim-0 == ggml ne[1] (the vocabulary / row axis).
1936
+ // nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape,
1937
+ // nb[2..3] for computing per-batch-slice base pointer offsets).
1938
+ auto gather_batched = [&](void * src_base, aclDataType acl_type, size_t type_size,
1939
+ const size_t * nb) {
1940
+ int64_t src_ne[2] = { src0->ne[0], src0->ne[1] };
1941
+ size_t src_nb_2d[2] = { nb[0], nb[1] };
1942
+ int64_t dst_ne[2] = { src0->ne[0], n_idx };
1943
+ size_t dst_nb_2d[2] = { dst->nb[0], dst->nb[1] };
1944
+ int64_t idx_ne[1] = { n_idx };
1945
+ size_t idx_nb[1] = { (size_t)ggml_element_size(src1) };
1946
+
1947
+ for (int64_t i3 = 0; i3 < src0->ne[3]; i3++) {
1948
+ for (int64_t i2 = 0; i2 < src0->ne[2]; i2++) {
1949
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(
1950
+ (char *)src_base + i3 * nb[3] + i2 * nb[2],
1951
+ acl_type, type_size, src_ne, src_nb_2d, 2);
1952
+ acl_tensor_ptr acl_idx = ggml_cann_create_tensor(
1953
+ (char *)src1->data + i3 * src1->nb[2] + i2 * src1->nb[1],
1954
+ ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1),
1955
+ idx_ne, idx_nb, 1);
1956
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
1957
+ (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2],
1958
+ acl_type, type_size, dst_ne, dst_nb_2d, 2);
1959
+ GGML_CANN_CALL_ACLNN_OP(ctx, GatherV2, acl_src.get(), 0, acl_idx.get(), acl_dst.get());
1960
+ }
1961
+ }
1962
+ };
1791
1963
 
1792
1964
  switch (src0->type) {
1965
+ case GGML_TYPE_BF16:
1793
1966
  case GGML_TYPE_F16:
1794
1967
  case GGML_TYPE_F32:
1795
1968
  if (src0->type == dst->type) {
1796
- aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1,
1797
- dst->type);
1969
+ gather_batched(src0->data,
1970
+ ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),
1971
+ src0->nb);
1798
1972
  } else {
1799
- acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
1800
- ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
1801
- void * src_trans_buffer = src_buffer_allocator.get();
1802
- size_t src_trans_nb[GGML_MAX_DIMS];
1803
- src_trans_nb[0] = dst->nb[0];
1973
+ // Cast src0 to dst type, then gather.
1974
+ ggml_cann_pool_alloc src_cast_allocator(ctx.pool(),
1975
+ ggml_nelements(src0) * ggml_element_size(dst));
1976
+ size_t src_cast_nb[GGML_MAX_DIMS];
1977
+ src_cast_nb[0] = ggml_type_size(dst->type);
1804
1978
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
1805
- src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
1979
+ src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1];
1806
1980
  }
1807
- acl_tensor_ptr src_trans_tensor =
1808
- ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type),
1809
- ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
1810
- aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
1811
- aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
1812
- dst->type);
1981
+ acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
1982
+ acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor(
1983
+ src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
1984
+ src0->ne, src_cast_nb, GGML_MAX_DIMS);
1985
+ aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type));
1986
+
1987
+ gather_batched(src_cast_allocator.get(),
1988
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
1989
+ src_cast_nb);
1813
1990
  }
1814
1991
  break;
1815
1992
  case GGML_TYPE_Q8_0:
1816
1993
  {
1817
- // add 1 dim for bcast mul.
1994
+ // Dequantize Q8_0 to dst type, then gather.
1818
1995
  size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1];
1819
1996
  int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne;
1820
- int64_t scale_offset = 0;
1821
- // [3,4,5,64] -> [3,4,5,2,32]
1822
- weight_ne[0] = QK8_0;
1823
- weight_ne[1] = src0->ne[0] / QK8_0;
1824
- weight_nb[0] = sizeof(int8_t);
1825
- weight_nb[1] = weight_nb[0] * weight_ne[0];
1997
+ weight_ne[0] = QK8_0;
1998
+ weight_ne[1] = src0->ne[0] / QK8_0;
1999
+ weight_nb[0] = sizeof(int8_t);
2000
+ weight_nb[1] = weight_nb[0] * weight_ne[0];
1826
2001
  for (int i = 2; i < GGML_MAX_DIMS + 1; i++) {
1827
2002
  weight_ne[i] = src0->ne[i - 1];
1828
2003
  weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
1829
2004
  }
1830
- // [3,4,5,64] -> [3,4,5,2,1]
1831
2005
  scale_ne[0] = 1;
1832
2006
  scale_ne[1] = src0->ne[0] / QK8_0;
1833
2007
  scale_nb[0] = sizeof(uint16_t);
@@ -1836,31 +2010,33 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1836
2010
  scale_ne[i] = src0->ne[i - 1];
1837
2011
  scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
1838
2012
  }
1839
- // [3,4,5,64] -> [3,4,5,2,32]
1840
2013
  dequant_ne = weight_ne;
1841
2014
  dequant_nb[0] = ggml_type_size(dst->type);
1842
2015
  for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
1843
2016
  dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
1844
2017
  }
1845
- scale_offset = ggml_nelements(src0) * sizeof(int8_t);
1846
- ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(),
1847
- ggml_nelements(src0) * ggml_type_size(dst->type));
1848
- acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),
1849
- weight_ne, weight_nb, GGML_MAX_DIMS + 1);
1850
- acl_tensor_ptr acl_scale_tensor =
1851
- ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
1852
- GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
1853
- acl_tensor_ptr dequant_tensor =
1854
- ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type),
1855
- ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
1856
- aclnn_mul(ctx, acl_weight_tensor.get(), acl_scale_tensor.get(), dequant_tensor.get());
1857
- dequant_nb[0] = ggml_type_size(dst->type);
2018
+ const int64_t scale_offset = ggml_nelements(src0) * sizeof(int8_t);
2019
+ ggml_cann_pool_alloc dequant_allocator(ctx.pool(),
2020
+ ggml_nelements(src0) * ggml_type_size(dst->type));
2021
+ acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),
2022
+ weight_ne, weight_nb, GGML_MAX_DIMS + 1);
2023
+ acl_tensor_ptr acl_scale = ggml_cann_create_tensor(
2024
+ src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
2025
+ GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
2026
+ acl_tensor_ptr acl_dequant = ggml_cann_create_tensor(
2027
+ dequant_allocator.get(), ggml_cann_type_mapping(dst->type),
2028
+ ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
2029
+ aclnn_mul(ctx, acl_weight.get(), acl_scale.get(), acl_dequant.get());
2030
+
2031
+ // Reinterpret dequant buffer as 4D [src0->ne] with contiguous strides.
1858
2032
  dequant_ne = src0->ne;
2033
+ dequant_nb[0] = ggml_type_size(dst->type);
1859
2034
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
1860
2035
  dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
1861
2036
  }
1862
- aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne,
1863
- dst->nb, src1, dst->type);
2037
+ gather_batched(dequant_allocator.get(),
2038
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
2039
+ dequant_nb);
1864
2040
  break;
1865
2041
  }
1866
2042
  default:
@@ -1870,30 +2046,70 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1870
2046
  }
1871
2047
 
1872
2048
  void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1873
- ggml_tensor * src0 = dst->src[0]; // src
1874
- ggml_tensor * src1 = dst->src[1]; // index
2049
+ ggml_tensor * src0 = dst->src[0]; // source values
2050
+ ggml_tensor * src1 = dst->src[1]; // row indices
2051
+
2052
+ // n_idx: number of source rows to scatter per batch slice.
2053
+ // ggml guarantees: src0->ne[1] == src1->ne[0].
2054
+ const int64_t n_idx = src1->ne[0];
2055
+
2056
+ // Copy n_idx rows of src [ne0, n_idx] into dst [ne0, ne1] at positions given by a 1D index.
2057
+ // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0] for dst.
2058
+ // InplaceIndexCopy with dim=0 copies along ACL dim-0 == ggml ne[1] (the row axis).
2059
+ // src_nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape,
2060
+ // nb[2..3] for computing per-batch-slice base pointer offsets).
2061
+ auto scatter_batched = [&](void * src_base, aclDataType acl_type, size_t type_size,
2062
+ const size_t * src_nb) {
2063
+ int64_t d_ne[2] = { dst->ne[0], dst->ne[1] };
2064
+ size_t d_nb[2] = { dst->nb[0], dst->nb[1] };
2065
+ int64_t s_ne[2] = { dst->ne[0], n_idx };
2066
+ size_t s_nb_2d[2] = { src_nb[0], src_nb[1] };
2067
+ int64_t i_ne[1] = { n_idx };
2068
+ size_t i_nb[1] = { (size_t)ggml_element_size(src1) };
2069
+
2070
+ for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) {
2071
+ for (int64_t i2 = 0; i2 < dst->ne[2]; i2++) {
2072
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
2073
+ (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2],
2074
+ acl_type, type_size, d_ne, d_nb, 2);
2075
+ acl_tensor_ptr acl_idx = ggml_cann_create_tensor(
2076
+ (char *)src1->data + (i3 % src1->ne[2]) * src1->nb[2] + (i2 % src1->ne[1]) * src1->nb[1],
2077
+ ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1),
2078
+ i_ne, i_nb, 1);
2079
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(
2080
+ (char *)src_base + i3 * src_nb[3] + i2 * src_nb[2],
2081
+ acl_type, type_size, s_ne, s_nb_2d, 2);
2082
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_dst.get(), 0, acl_idx.get(), acl_src.get());
2083
+ }
2084
+ }
2085
+ };
1875
2086
 
1876
2087
  switch (dst->type) {
1877
2088
  case GGML_TYPE_F32:
1878
- {
1879
- aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type);
1880
- break;
1881
- }
2089
+ scatter_batched(src0->data,
2090
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
2091
+ src0->nb);
2092
+ break;
1882
2093
  case GGML_TYPE_F16:
2094
+ case GGML_TYPE_BF16:
1883
2095
  {
1884
- acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
1885
- ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
1886
- void * src_trans_buffer = src_buffer_allocator.get();
1887
- size_t src_trans_nb[GGML_MAX_DIMS];
1888
- src_trans_nb[0] = sizeof(uint16_t);
2096
+ // Cast src0 (F32) to dst type first.
2097
+ ggml_cann_pool_alloc src_cast_allocator(ctx.pool(),
2098
+ ggml_nelements(src0) * ggml_type_size(dst->type));
2099
+ size_t src_cast_nb[GGML_MAX_DIMS];
2100
+ src_cast_nb[0] = ggml_type_size(dst->type);
1889
2101
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
1890
- src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
2102
+ src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1];
1891
2103
  }
1892
- acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
1893
- src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
1894
- aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
1895
- aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
1896
- dst->type);
2104
+ acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
2105
+ acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor(
2106
+ src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
2107
+ src0->ne, src_cast_nb, GGML_MAX_DIMS);
2108
+ aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type));
2109
+
2110
+ scatter_batched(src_cast_allocator.get(),
2111
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
2112
+ src_cast_nb);
1897
2113
  break;
1898
2114
  }
1899
2115
  default:
@@ -1964,7 +2180,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
1964
2180
 
1965
2181
  // Only check env once.
1966
2182
  static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1967
- if (weight_to_nz && is_matmul_weight(weight)) {
2183
+ if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) {
1968
2184
  acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
1969
2185
  } else {
1970
2186
  acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@@ -2145,6 +2361,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2145
2361
  switch (type) {
2146
2362
  case GGML_TYPE_F32:
2147
2363
  case GGML_TYPE_F16:
2364
+ #ifndef ASCEND_310P
2365
+ case GGML_TYPE_BF16:
2366
+ #endif
2148
2367
  ggml_cann_mat_mul_fp(ctx, dst);
2149
2368
  break;
2150
2369
  case GGML_TYPE_Q4_0:
@@ -2338,20 +2557,21 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
2338
2557
 
2339
2558
  // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
2340
2559
  // TODO: acl_yarn_ramp_tensor use rope cache.
2341
- bool yarn_ramp_tensor_updated = false;
2342
- acl_tensor_ptr acl_yarn_ramp_tensor;
2560
+ bool yarn_ramp_tensor_updated = false;
2561
+ acl_tensor_ptr acl_yarn_ramp_tensor;
2343
2562
  if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
2344
2563
  ctx.rope_cache.freq_scale != freq_scale)) {
2345
2564
  yarn_ramp_tensor_updated = true;
2346
2565
  if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
2347
2566
  ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
2348
2567
  }
2349
- ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
2568
+ ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float),
2569
+ ACL_MEM_MALLOC_HUGE_FIRST));
2350
2570
  // -rope_yarn_ramp
2351
2571
  // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
2352
2572
  // return MIN(1, MAX(0, y)) - 1;
2353
- acl_yarn_ramp_tensor =
2354
- ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
2573
+ acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
2574
+ theta_scale_ne, theta_scale_nb, 1);
2355
2575
  float zero_value = 0, one_value = 1;
2356
2576
  float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
2357
2577
  acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
@@ -2382,8 +2602,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
2382
2602
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
2383
2603
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
2384
2604
  } else {
2385
- acl_yarn_ramp_tensor =
2386
- ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
2605
+ acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
2606
+ theta_scale_ne, theta_scale_nb, 1);
2387
2607
  }
2388
2608
  // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
2389
2609
  if (ext_factor != 0) {
@@ -2941,6 +3161,27 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2941
3161
  // Rotate full tensor (no tail), using trans tensors
2942
3162
  GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),
2943
3163
  acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());
3164
+ } else if (src0->data == dst->data && !ggml_is_contiguous(src0)) {
3165
+ // In-place on non-contiguous tensor: RotaryPositionEmbedding cannot safely
3166
+ // read and write the same non-contiguous buffer. Use contiguous temporaries.
3167
+ size_t contiguous_nb[GGML_MAX_DIMS];
3168
+ contiguous_nb[0] = sizeof(float);
3169
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
3170
+ contiguous_nb[i] = contiguous_nb[i - 1] * src0->ne[i - 1];
3171
+ }
3172
+ int64_t total_elements = ggml_nelements(src0);
3173
+ ggml_cann_pool_alloc inplace_src_alloc(ctx.pool(), total_elements * sizeof(float));
3174
+ ggml_cann_pool_alloc inplace_dst_alloc(ctx.pool(), total_elements * sizeof(float));
3175
+
3176
+ acl_tensor_ptr acl_src_contig = ggml_cann_create_tensor(inplace_src_alloc.get(), ACL_FLOAT, sizeof(float),
3177
+ src0->ne, contiguous_nb, GGML_MAX_DIMS);
3178
+ acl_tensor_ptr acl_dst_contig = ggml_cann_create_tensor(inplace_dst_alloc.get(), ACL_FLOAT, sizeof(float),
3179
+ dst->ne, contiguous_nb, GGML_MAX_DIMS);
3180
+
3181
+ cann_copy(ctx, acl_src.get(), acl_src_contig.get());
3182
+ GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_contig.get(), acl_cos_reshape_tensor.get(),
3183
+ acl_sin_reshape_tensor.get(), acl_mode, acl_dst_contig.get());
3184
+ cann_copy(ctx, acl_dst_contig.get(), acl_dst.get());
2944
3185
  } else {
2945
3186
  // Rotate full tensor (no tail), using original tensors
2946
3187
  GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
@@ -2982,6 +3223,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2982
3223
  }
2983
3224
  }
2984
3225
 
3226
+ void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3227
+ ggml_tensor * src0 = dst->src[0];
3228
+
3229
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
3230
+ int sections[4];
3231
+ const int n_dims = ((int32_t *) dst->op_params)[1];
3232
+ const int mode = ((int32_t *) dst->op_params)[2];
3233
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
3234
+
3235
+ GGML_TENSOR_UNARY_OP_LOCALS
3236
+
3237
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
3238
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
3239
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
3240
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
3241
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
3242
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
3243
+ memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
3244
+
3245
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
3246
+
3247
+ float corr_dims[2];
3248
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
3249
+
3250
+ bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
3251
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
3252
+ const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
3253
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
3254
+
3255
+ if (is_imrope || mrope_used) {
3256
+ is_neox = true;
3257
+ }
3258
+
3259
+ int64_t rope_dims = n_dims;
3260
+ if (is_vision) {
3261
+ rope_dims = src0->ne[0];
3262
+ }
3263
+
3264
+ // Run the full cache init on the non-captured stream. This performs all
3265
+ // host-to-device memcpy, aclrtMalloc/Free, and on-device computations
3266
+ // so that the memory pool is warmed up and cache metadata is populated.
3267
+ aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
3268
+ mrope_used, is_imrope, is_vision, rope_dims);
3269
+
3270
+ // Reset `cached` so that during graph capture the on-device computations
3271
+ // (sin/cos, position multiply, repeat, etc.) still execute and get recorded
3272
+ // into the captured graph. The cache metadata (theta_scale_length,
3273
+ // theta_scale, sections, position_length, etc.) remains set, which causes
3274
+ // all host-to-device copy and malloc/free branches to be skipped.
3275
+ ctx.rope_cache.cached = false;
3276
+ }
3277
+
2985
3278
  void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2986
3279
  ggml_tensor * src0 = dst->src[0];
2987
3280
 
@@ -2991,20 +3284,20 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2991
3284
  GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get());
2992
3285
  }
2993
3286
 
2994
- void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
3287
+ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2995
3288
  ggml_tensor * src0 = dst->src[0];
2996
3289
  ggml_tensor * src1 = dst->src[1];
2997
3290
 
2998
3291
  // stride
2999
- int64_t s0 = ((const int32_t*)(dst->op_params))[0];
3292
+ int64_t s0 = ((const int32_t *) (dst->op_params))[0];
3000
3293
 
3001
- acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
3294
+ acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
3002
3295
  acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
3003
- acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
3296
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
3004
3297
 
3005
3298
  // get base information of input and kernel
3006
- int64_t input_len = *(src1->ne);
3007
- int64_t dst_len = *(dst->ne);
3299
+ int64_t input_len = *(src1->ne);
3300
+ int64_t dst_len = *(dst->ne);
3008
3301
  int64_t kernel_size = *(src0->ne);
3009
3302
 
3010
3303
  // set the max kernel size for each conv
@@ -3012,56 +3305,55 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
3012
3305
 
3013
3306
  // compute the partition of kernel
3014
3307
  int64_t part_num = 1;
3015
- part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
3308
+ part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
3016
3309
 
3017
3310
  int64_t strideVal[1];
3018
- strideVal[0] = s0;
3019
- acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
3020
- int64_t paddingVal[] = {0};
3021
- acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
3022
- int64_t dilationVal[] = {1};
3023
- acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
3024
- bool transposed = true;
3025
- int64_t groups = 1;
3026
- int8_t cubeMathType = 0;
3311
+ strideVal[0] = s0;
3312
+ acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
3313
+ int64_t paddingVal[] = { 0 };
3314
+ acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
3315
+ int64_t dilationVal[] = { 1 };
3316
+ acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
3317
+ bool transposed = true;
3318
+ int64_t groups = 1;
3319
+ int8_t cubeMathType = 0;
3027
3320
 
3028
3321
  #ifdef ASCEND_310P
3029
3322
  cubeMathType = 1;
3030
3323
  #endif
3031
3324
 
3032
3325
  auto weight_type = ggml_cann_type_mapping(src0->type);
3033
- auto dst_type = ggml_cann_type_mapping(dst->type);
3326
+ auto dst_type = ggml_cann_type_mapping(dst->type);
3034
3327
 
3035
3328
  // slice the kernel to make each conv available
3036
- int64_t slice_dim = -1;
3329
+ int64_t slice_dim = -1;
3037
3330
  int64_t slice_start = 0;
3038
- int64_t slice_end = max_kernel_size;
3039
- int64_t slice_step = 1;
3040
- int64_t interval = max_kernel_size;
3331
+ int64_t slice_end = max_kernel_size;
3332
+ int64_t slice_step = 1;
3333
+ int64_t interval = max_kernel_size;
3041
3334
 
3042
- int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
3335
+ int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
3043
3336
  int64_t right_pad_len = 0;
3044
3337
 
3045
- acl_scalar_ptr alpha = nullptr;
3046
- float alphaValue = 1.0;
3047
- alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
3338
+ acl_scalar_ptr alpha = nullptr;
3339
+ float alphaValue = 1.0;
3340
+ alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
3048
3341
 
3049
3342
  // set zero to destination
3050
3343
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
3051
3344
 
3052
- for(int k = 0; k < part_num; k++){
3053
-
3345
+ for (int k = 0; k < part_num; k++) {
3054
3346
  // create part kernel tensor and slice from big kernel
3055
3347
  slice_start = max_kernel_size * k;
3056
- if(k == part_num - 1){
3348
+ if (k == part_num - 1) {
3057
3349
  slice_end = kernel_size;
3058
- interval = kernel_size - max_kernel_size * k;
3059
- }else{
3060
- slice_end = max_kernel_size * (k+1);
3350
+ interval = kernel_size - max_kernel_size * k;
3351
+ } else {
3352
+ slice_end = max_kernel_size * (k + 1);
3061
3353
  }
3062
3354
 
3063
3355
  int64_t part_ne[4];
3064
- for(int i = 0; i < 4; i++) {
3356
+ for (int i = 0; i < 4; i++) {
3065
3357
  part_ne[i] = *(src0->ne + i);
3066
3358
  }
3067
3359
  part_ne[0] = interval;
@@ -3074,16 +3366,17 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
3074
3366
 
3075
3367
  ggml_cann_pool_alloc part_kernel_allocator;
3076
3368
  part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
3077
- void* part_kernel_buf = part_kernel_allocator.get();
3369
+ void * part_kernel_buf = part_kernel_allocator.get();
3078
3370
 
3079
- acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
3080
- ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
3371
+ acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, ggml_element_size(src0),
3372
+ part_ne, part_nb, 3, ACL_FORMAT_NCL);
3081
3373
 
3082
- GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
3374
+ GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step,
3375
+ part_kernel.get());
3083
3376
 
3084
3377
  // create the part conv result tensor
3085
3378
  int64_t part_dst_ne[4];
3086
- for(int i = 0; i < 4; i++){
3379
+ for (int i = 0; i < 4; i++) {
3087
3380
  part_dst_ne[i] = *(dst->ne + i);
3088
3381
  }
3089
3382
  part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;
@@ -3095,32 +3388,33 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
3095
3388
  }
3096
3389
  ggml_cann_pool_alloc part_dst_allocator;
3097
3390
  part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);
3098
- void* part_dst_buf = part_dst_allocator.get();
3391
+ void * part_dst_buf = part_dst_allocator.get();
3099
3392
 
3100
3393
  acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
3101
- part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
3394
+ part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
3102
3395
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());
3103
3396
 
3104
3397
  // compute part conv transpose 1d
3105
3398
  GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
3106
- padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
3399
+ padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(),
3400
+ cubeMathType);
3107
3401
 
3108
3402
  // compute the position of part result in final result
3109
3403
  int64_t global_start = slice_start;
3110
- int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
3404
+ int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
3111
3405
 
3112
- left_pad_len = global_start;
3406
+ left_pad_len = global_start;
3113
3407
  right_pad_len = dst_len - global_end;
3114
3408
 
3115
- std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len};
3116
- acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
3409
+ std::vector<int64_t> padDataVal = { left_pad_len, right_pad_len };
3410
+ acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
3117
3411
 
3118
- acl_scalar_ptr pad_value = nullptr;
3119
- float pad_valueVal = 0.0;
3120
- pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
3412
+ acl_scalar_ptr pad_value = nullptr;
3413
+ float pad_valueVal = 0.0;
3414
+ pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
3121
3415
 
3122
3416
  int64_t conv_result_ne[4];
3123
- for(int i = 0; i < 4; i++){
3417
+ for (int i = 0; i < 4; i++) {
3124
3418
  conv_result_ne[i] = *(dst->ne + i);
3125
3419
  }
3126
3420
 
@@ -3132,13 +3426,14 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
3132
3426
 
3133
3427
  ggml_cann_pool_alloc conv_result_allocator;
3134
3428
  conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]);
3135
- void* conv_result_buf = conv_result_allocator.get();
3429
+ void * conv_result_buf = conv_result_allocator.get();
3136
3430
 
3137
3431
  acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),
3138
- conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
3432
+ conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
3139
3433
 
3140
3434
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
3141
- GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get());
3435
+ GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(),
3436
+ conv_result.get());
3142
3437
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
3143
3438
  }
3144
3439
  }
@@ -3175,29 +3470,50 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst
3175
3470
  int64_t paddingsArray[2] = { opts[0], opts[1] };
3176
3471
  acl_int_array_ptr paddings = ggml_cann_create_int_array(paddingsArray, 2);
3177
3472
 
3178
- for (int64_t i = 0; i < src0->ne[3]; i++) {
3179
- acl_tensor_ptr acl_src =
3180
- ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type),
3181
- ggml_element_size(src0), src0->ne, src0->nb, 3);
3473
+ // Collapsing ne[2]*ne[3] into a single batch dimension requires that dim3
3474
+ // is contiguous with respect to dim2 in both src and dst.
3475
+ GGML_ASSERT(src0->nb[3] == src0->nb[2] * src0->ne[2]);
3476
+ GGML_ASSERT(dst->nb[3] == dst->nb[2] * dst->ne[2]);
3182
3477
 
3183
- acl_tensor_ptr acl_dst =
3184
- ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type),
3185
- ggml_element_size(dst), dst->ne, dst->nb, 3);
3478
+ int64_t src_ne_3d[3] = { src0->ne[0], src0->ne[1], src0->ne[2] * src0->ne[3] };
3479
+ int64_t dst_ne_3d[3] = { dst->ne[0], dst->ne[1], dst->ne[2] * dst->ne[3] };
3186
3480
 
3187
- GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());
3188
- }
3481
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type),
3482
+ ggml_element_size(src0), src_ne_3d, src0->nb, 3);
3483
+
3484
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
3485
+ ggml_element_size(dst), dst_ne_3d, dst->nb, 3);
3486
+
3487
+ GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());
3189
3488
  }
3190
3489
 
3191
3490
  void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3192
3491
  ggml_tensor * src0 = dst->src[0];
3193
3492
  ggml_tensor * src1 = dst->src[1];
3194
3493
 
3494
+ // Write element-wise equality (0 or 1) into a temporary buffer to avoid
3495
+ // modifying src0 in-place. Use the same type as src0 so ReduceSum can
3496
+ // consume it directly without a type cast.
3497
+ ggml_cann_pool_alloc eq_alloc(ctx.pool(), ggml_nelements(src0) * ggml_element_size(src0));
3498
+ size_t eq_nb[GGML_MAX_DIMS];
3499
+ eq_nb[0] = ggml_element_size(src0);
3500
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
3501
+ eq_nb[i] = eq_nb[i - 1] * src0->ne[i - 1];
3502
+ }
3503
+ acl_tensor_ptr acl_eq = ggml_cann_create_tensor(
3504
+ eq_alloc.get(), ggml_cann_type_mapping(src0->type), ggml_element_size(src0),
3505
+ src0->ne, eq_nb, GGML_MAX_DIMS);
3506
+
3195
3507
  acl_tensor_ptr acl_self = ggml_cann_create_tensor(src0);
3196
3508
  acl_tensor_ptr acl_other = ggml_cann_create_tensor(src1);
3197
-
3198
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self.get(), acl_other.get());
3199
-
3200
- ggml_cann_sum(ctx, dst);
3509
+ GGML_CANN_CALL_ACLNN_OP(ctx, EqTensor, acl_self.get(), acl_other.get(), acl_eq.get());
3510
+
3511
+ // Sum the 0/1 values into dst.
3512
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
3513
+ int64_t dims[4] = { 0, 1, 2, 3 };
3514
+ acl_int_array_ptr dims_arr = ggml_cann_create_int_array(dims, 4);
3515
+ GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_eq.get(), dims_arr.get(), true,
3516
+ ggml_cann_type_mapping(dst->type), acl_dst.get());
3201
3517
  }
3202
3518
 
3203
3519
  void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -3213,6 +3529,27 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3213
3529
  GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src.get(), alpha.get(), acl_dst.get());
3214
3530
  }
3215
3531
 
3532
+ void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3533
+ ggml_tensor * src0 = dst->src[0];
3534
+
3535
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
3536
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
3537
+
3538
+ float beta_val = 1.0f;
3539
+ float threshold_val = 20.0f;
3540
+ acl_scalar_ptr beta = ggml_cann_create_scalar(&beta_val, ACL_FLOAT);
3541
+ acl_scalar_ptr threshold = ggml_cann_create_scalar(&threshold_val, ACL_FLOAT);
3542
+
3543
+ GGML_CANN_CALL_ACLNN_OP(ctx, Softplus, acl_src.get(), beta.get(), threshold.get(), acl_dst.get());
3544
+ }
3545
+
3546
+ void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3547
+ auto gelu_quick_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
3548
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
3549
+ };
3550
+ ggml_cann_op_unary_gated(gelu_quick_fn, ctx, dst);
3551
+ }
3552
+
3216
3553
  /**
3217
3554
  * @brief Performs expert-specific matrix multiplication (MoE) with
3218
3555
  * floating-point precision using the CANN backend.
@@ -3282,130 +3619,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor
3282
3619
  }
3283
3620
 
3284
3621
  /**
3285
- * @brief Performs expert-specific matrix multiplication (MoE) with
3286
- * quantized precision using the CANN backend.
3622
+ * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE)
3623
+ * models using the CANN backend.
3287
3624
  *
3288
- * This function executes a matrix multiplication operation tailored for
3289
- * Mixture of Experts (MoE) models, where the input tensor is multiplied
3290
- * with expert-specific quantized weight matrices. It leverages the CANN
3291
- * backend to perform efficient low-precision computations and stores the
3292
- * quantized result in the destination tensor `dst`.
3625
+ * This function implements MUL_MAT_ID operation for quantized weight matrices
3626
+ * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on
3627
+ * the provided expert indices, and computes matrix multiplication using CANN's
3628
+ * WeightQuantBatchMatmulV2 operator.
3293
3629
  *
3294
- * Quantization techniques reduce memory footprint and improve performance
3295
- * by using lower-bit representations (e.g., int8) instead of floating-point.
3296
- * This function is designed to work with such formats and may incorporate
3297
- * optimizations like identity-based fast paths or routing masks for sparse
3298
- * expert selection.
3630
+ * The function performs the following steps:
3631
+ * 1. Converts input/output tensors to F16 format if necessary
3632
+ * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices
3633
+ * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2
3634
+ * 4. Converts output back to the target type if needed
3299
3635
  *
3300
- * @param ctx The context for executing CANN backend operations.
3301
- * @param dst The destination tensor where the quantized MoE multiplication result
3302
- * will be stored.
3636
+ * Tensor shapes:
3637
+ * - dst: [M, K, N, 1] - output tensor
3638
+ * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0)
3639
+ * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast)
3640
+ * - ids: [K, N] - expert indices for routing
3641
+ *
3642
+ * @param ctx The CANN backend context for operation execution.
3643
+ * @param dst The destination tensor where the multiplication result will be stored.
3303
3644
  *
3304
- * @note This function assumes quantized data types and is designed for
3305
- * MoE architectures with potential sparse expert routing.
3645
+ * @note Only Q4_0 and Q8_0 quantization formats are supported.
3646
+ * @note The function handles automatic type conversion to/from F16 as needed by the hardware.
3306
3647
  */
3307
3648
  static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3308
- // TODO: Use aclnnGroupedMatMul
3309
- //dst [M, K, N, 1]
3310
- ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
3311
- ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
3312
- ggml_tensor * ids = dst->src[2]; //ids [K, N]
3649
+ // dst: [M, K, N, 1]
3650
+ // src0: [D, M, A, 1] - quantized weights
3651
+ // src1: [D, B, N, 1] - input activations, B = K or B = 1
3652
+ // ids: [K, N] - expert indices
3653
+ ggml_tensor * src0 = dst->src[0];
3654
+ ggml_tensor * src1 = dst->src[1];
3655
+ ggml_tensor * ids = dst->src[2];
3313
3656
 
3314
- GGML_TENSOR_BINARY_OP_LOCALS
3657
+ GGML_ASSERT(src0->ne[3] == 1);
3658
+ GGML_ASSERT(src1->ne[3] == 1);
3659
+ GGML_ASSERT(dst->ne[3] == 1);
3660
+ GGML_ASSERT(src1->ne[2] == ids->ne[1]);
3661
+
3662
+ const int64_t n_batches = ids->ne[1];
3663
+ const int64_t n_select_experts = ids->ne[0];
3664
+ const enum ggml_type type = src0->type;
3665
+
3666
+ const int32_t group_size = QK8_0; // Both Q4_0 and Q8_0 use group size of 32
3667
+ GGML_ASSERT(group_size == QK4_0);
3668
+
3669
+ // Calculate element size for quantized weights
3670
+ const float weight_elem_size =
3671
+ (type == GGML_TYPE_Q4_0) ? 0.5f :
3672
+ (type == GGML_TYPE_Q8_0) ? 1.0f :
3673
+ (GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f);
3674
+
3675
+ // Calculate scale offset in memory
3676
+ const size_t weight_size = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size;
3677
+ const size_t scale_elem_size = sizeof(uint16_t);
3678
+ char * scale_data = (char *) src0->data + weight_size;
3679
+
3680
+ // Allocate buffers for selected expert weights and scales
3681
+ const size_t selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size;
3682
+ ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size);
3683
+ void * selected_weight_buffer = selected_weight_alloc.get();
3684
+
3685
+ const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size;
3686
+ ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size);
3687
+ void * selected_scale_buffer = selected_scale_alloc.get();
3688
+
3689
+ // Helper lambda to allocate and cast tensor to F16 if needed
3690
+ constexpr size_t f16_elem_size = sizeof(uint16_t);
3691
+ auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator,
3692
+ bool need_cast = false) -> void * {
3693
+ if (tensor->type == GGML_TYPE_F16) {
3694
+ return tensor->data;
3695
+ }
3696
+
3697
+ size_t total_size = f16_elem_size;
3698
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
3699
+ total_size *= tensor->ne[i];
3700
+ }
3701
+ void * buffer = allocator.alloc(total_size);
3702
+
3703
+ if (need_cast == false) {
3704
+ return buffer;
3705
+ }
3315
3706
 
3316
- // copy index from npu to cpu
3317
- int64_t n_as = ne02; // A
3318
- int64_t n_ids = ids->ne[0]; // K
3707
+ int64_t ne[GGML_MAX_DIMS];
3708
+ size_t nb[GGML_MAX_DIMS] = { f16_elem_size };
3709
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
3710
+ ne[i] = tensor->ne[i];
3711
+ if (i > 0) {
3712
+ nb[i] = nb[i - 1] * ne[i - 1];
3713
+ }
3714
+ }
3319
3715
 
3320
- std::vector<char> ids_host(ggml_nbytes(ids));
3321
- ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids),
3322
- ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
3323
- ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
3716
+ acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor);
3717
+ acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
3718
+ aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16);
3324
3719
 
3325
- char * src0_original = (char *) src0->data;
3326
- char * src1_original = (char *) src1->data;
3327
- char * dst_original = (char *) dst->data;
3720
+ return buffer;
3721
+ };
3328
3722
 
3329
- ggml_tensor src0_row = *src0;
3330
- ggml_tensor src1_row = *src1;
3331
- ggml_tensor dst_row = *dst;
3723
+ // Prepare input and output buffers
3724
+ ggml_cann_pool_alloc input_alloc(ctx.pool());
3725
+ void * input_buffer = prepare_f16_buffer(src1, input_alloc, true);
3332
3726
 
3333
- const enum ggml_type type = dst->src[0]->type;
3334
- float weight_elem_size;
3335
- if (type == GGML_TYPE_Q4_0) {
3336
- weight_elem_size = float(sizeof(uint8_t)) / 2;
3337
- } else if (type == GGML_TYPE_Q8_0) {
3338
- weight_elem_size = float(sizeof(uint8_t));
3339
- } else {
3340
- GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
3341
- }
3727
+ ggml_cann_pool_alloc output_alloc(ctx.pool());
3728
+ void * output_buffer = prepare_f16_buffer(dst, output_alloc, false);
3342
3729
 
3343
- // src0_row [D, M, 1, 1] weight without permute
3344
- src0_row.ne[2] = 1;
3345
- src0_row.ne[3] = 1;
3346
- src0_row.nb[0] = weight_elem_size;
3347
- src0_row.nb[1] = weight_elem_size * ne00;
3348
- src0_row.nb[2] = weight_elem_size * ne00;
3349
- src0_row.nb[3] = weight_elem_size * ne00;
3350
- size_t weight_stride = ne00 * ne01 * weight_elem_size;
3351
- size_t weight_size = weight_stride * ne02 * ne03;
3730
+ // Process each batch
3731
+ for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
3732
+ // Create index tensor for current batch
3733
+ const size_t index_offset = batch_idx * ids->nb[1];
3734
+ acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset);
3352
3735
 
3353
- // scale [D, M, 1, 1] -> scale && permute
3354
- size_t scale_elem_size = sizeof(uint16_t);
3355
- size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
3736
+ // Select quantized weights using expert indices
3737
+ // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte
3738
+ const int64_t weight_d = (type == GGML_TYPE_Q4_0) ? src0->ne[0] / 2 : src0->ne[0];
3739
+ const int64_t weight_m = src0->ne[1];
3740
+ const int64_t weight_n_experts = src0->ne[2];
3741
+
3742
+ int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts };
3743
+ size_t weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) };
3744
+
3745
+ acl_tensor_ptr all_weights =
3746
+ ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3);
3747
+
3748
+ int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts };
3749
+ size_t selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t),
3750
+ weight_d * weight_m * sizeof(int8_t) };
3751
+
3752
+ acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t),
3753
+ selected_weight_ne, selected_weight_nb, 3);
3754
+
3755
+ GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get());
3356
3756
 
3357
- // src1_row [D, 1, 1, 1] -> input
3358
- src1_row.ne[1] = 1;
3359
- src1_row.ne[2] = 1;
3360
- src1_row.ne[3] = 1;
3361
- src1_row.nb[2] = nb11;
3362
- src1_row.nb[3] = nb11;
3363
-
3364
- // dst_row [M, 1, 1, 1] -> out
3365
- dst_row.ne[1] = 1;
3366
- dst_row.ne[2] = 1;
3367
- dst_row.ne[3] = 1;
3368
- dst_row.nb[2] = nb1;
3369
- dst_row.nb[3] = nb1;
3370
-
3371
- //create weight for one row
3372
- ggml_cann_pool_alloc weight_allocator(ctx.pool());
3373
- void * weight_buffer = weight_allocator.alloc(nb02);
3374
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
3375
- for (int64_t id = 0; id < n_ids; id++) {
3376
- // expert index
3377
- int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]);
3378
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
3379
-
3380
- // If B = 1 (broadcast), always use 0; otherwise, use id.
3381
- int64_t i11 = (ne11 == 1 ? 0 : id);
3382
- int64_t i12 = iid1;
3383
-
3384
- int64_t i1 = id;
3385
- int64_t i2 = i12;
3386
-
3387
- void * src0_tmp_ptr = src0_original + i02 * weight_stride;
3388
- void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride;
3389
- void * src1_tmp_ptr = src1_original + i11 * nb11 + i12 * nb12;
3390
- void * dst_tmp_ptr = dst_original + i1 * nb1 + i2 * nb2;
3391
-
3392
- // mem cpy
3393
- ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride,
3394
- ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
3395
- void * scale_buffer = (char *) weight_buffer + weight_stride;
3396
- ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride,
3397
- ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
3398
-
3399
- src0_row.data = weight_buffer;
3400
- src1_row.data = src1_tmp_ptr;
3401
- dst_row.data = dst_tmp_ptr;
3402
- dst_row.src[0] = &src0_row;
3403
- dst_row.src[1] = &src1_row;
3404
-
3405
- ggml_cann_mul_mat(ctx, &dst_row);
3757
+ // Select scales using the same expert indices
3758
+ const int64_t scale_d = src0->ne[0] / group_size;
3759
+ int64_t scale_ne[3] = { scale_d, weight_m, weight_n_experts };
3760
+ size_t scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size };
3761
+
3762
+ acl_tensor_ptr all_scales =
3763
+ ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3);
3764
+
3765
+ int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts };
3766
+ size_t selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size,
3767
+ scale_d * weight_m * scale_elem_size };
3768
+
3769
+ acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size,
3770
+ selected_scale_ne, selected_scale_nb, 3);
3771
+
3772
+ GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get());
3773
+
3774
+ // Process each expert for current batch
3775
+ // IndexSelect output layout: [D, M, K] in contiguous format
3776
+ // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride
3777
+ for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) {
3778
+ // Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input
3779
+ const size_t input_offset =
3780
+ (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size;
3781
+ const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size;
3782
+
3783
+ // Create weight view for current expert: [D, M, K] -> [M, D]
3784
+ int64_t weight_view_ne[2] = { weight_m, src0->ne[0] };
3785
+ float weight_view_nb[2] = { src0->ne[0] * weight_elem_size, weight_elem_size };
3786
+ const size_t weight_view_offset = expert_idx * selected_weight_nb[2];
3787
+
3788
+ acl_tensor_ptr weight_view =
3789
+ ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size,
3790
+ weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset);
3791
+
3792
+ // Create scale view for current expert: [D, M, K] -> [M, D]
3793
+ int64_t scale_view_ne[2] = { weight_m, scale_d };
3794
+ size_t scale_view_nb[2] = { selected_scale_nb[1], selected_scale_nb[0] };
3795
+ const size_t scale_view_offset = expert_idx * selected_scale_nb[2];
3796
+
3797
+ acl_tensor_ptr scale_view =
3798
+ ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne,
3799
+ scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset);
3800
+
3801
+ // Create input activation tensor [D, 1]
3802
+ int64_t input_ne[2] = { src1->ne[0], 1 };
3803
+ size_t input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size };
3804
+
3805
+ acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne,
3806
+ input_nb, 2, ACL_FORMAT_ND, input_offset);
3807
+
3808
+ // Create output tensor [M, 1]
3809
+ int64_t output_ne[2] = { dst->ne[0], 1 };
3810
+ size_t output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size };
3811
+
3812
+ acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne,
3813
+ output_nb, 2, ACL_FORMAT_ND, output_offset);
3814
+
3815
+ // Perform quantized matrix multiplication
3816
+ GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(),
3817
+ scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size,
3818
+ output_tensor.get());
3406
3819
  }
3407
3820
  }
3408
- return;
3821
+
3822
+ // Cast output back to original type if we used a temporary F16 buffer
3823
+ if (dst->type != GGML_TYPE_F16) {
3824
+ int64_t ne[GGML_MAX_DIMS];
3825
+ size_t nb[GGML_MAX_DIMS] = { f16_elem_size };
3826
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
3827
+ ne[i] = dst->ne[i];
3828
+ if (i > 0) {
3829
+ nb[i] = nb[i - 1] * ne[i - 1];
3830
+ }
3831
+ }
3832
+
3833
+ acl_tensor_ptr f16_output =
3834
+ ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
3835
+ acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst);
3836
+
3837
+ aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type));
3838
+ }
3409
3839
  }
3410
3840
 
3411
3841
  void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -3502,6 +3932,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
3502
3932
  acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS);
3503
3933
  acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS);
3504
3934
 
3935
+ // Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16
3936
+ // (required by FusedInferAttentionScoreV2)
3937
+ const int64_t D = src0->ne[0];
3938
+ const int64_t D_padded = GGML_PAD(D, 16);
3939
+ const bool needs_padding = (D != D_padded);
3940
+
3941
+ ggml_cann_pool_alloc q_pad_allocator(ctx.pool());
3942
+ ggml_cann_pool_alloc k_pad_allocator(ctx.pool());
3943
+ ggml_cann_pool_alloc v_pad_allocator(ctx.pool());
3944
+
3945
+ if (needs_padding) {
3946
+ int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 };
3947
+
3948
+ auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne,
3949
+ ggml_cann_pool_alloc & allocator) {
3950
+ int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] };
3951
+ size_t pad_nb[GGML_MAX_DIMS];
3952
+ pad_nb[0] = faElemSize;
3953
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
3954
+ pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1];
3955
+ }
3956
+ int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3];
3957
+ void * buffer = allocator.alloc(nelements * faElemSize);
3958
+ acl_tensor_ptr padded =
3959
+ ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS);
3960
+ aclnn_pad(ctx, tensor.get(), padded.get(), paddings);
3961
+ tensor = std::move(padded);
3962
+ };
3963
+
3964
+ pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator);
3965
+ pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator);
3966
+ pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator);
3967
+
3968
+ src0_bsnd_ne[0] = D_padded;
3969
+ src1_bsnd_ne[0] = D_padded;
3970
+ src2_bsnd_ne[0] = D_padded;
3971
+ }
3972
+
3505
3973
  // Step 3: create the PSEShift tensor if needed
3506
3974
  // this tensor is considered as mask (f16) in the llama.cpp
3507
3975
  acl_tensor_ptr bcast_pse_tensor;
@@ -3591,17 +4059,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
3591
4059
 
3592
4060
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
3593
4061
  acl_tensor_ptr fa_dst_tensor;
3594
- acl_tensor_ptr acl_dst_tensor;
3595
4062
  ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
3596
- if (dst->type == GGML_TYPE_F32) {
3597
- void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
3598
-
4063
+ if (dst->type == GGML_TYPE_F32 || needs_padding) {
3599
4064
  int64_t * out_f16_ne = src0_bsnd_ne;
3600
4065
  size_t out_f16_nb[GGML_MAX_DIMS];
3601
4066
  out_f16_nb[0] = faElemSize;
3602
4067
  for (int i = 1; i < GGML_MAX_DIMS; ++i) {
3603
4068
  out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
3604
4069
  }
4070
+ int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3];
4071
+ void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize);
3605
4072
 
3606
4073
  fa_dst_tensor =
3607
4074
  ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS);
@@ -3633,8 +4100,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
3633
4100
  nullptr // softmaxLse
3634
4101
  );
3635
4102
 
3636
- if (dst->type == GGML_TYPE_F32) {
3637
- // Step 6: post-processing, permute and cast to f32
4103
+ // Step 6: post-processing — slice padded output and/or cast to f32
4104
+ if (needs_padding) {
4105
+ ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool());
4106
+
4107
+ if (dst->type == GGML_TYPE_F32) {
4108
+ int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] };
4109
+ size_t sliced_nb[GGML_MAX_DIMS];
4110
+ sliced_nb[0] = faElemSize;
4111
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
4112
+ sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1];
4113
+ }
4114
+ int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3];
4115
+ void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize);
4116
+ acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize,
4117
+ sliced_ne, sliced_nb, GGML_MAX_DIMS);
4118
+
4119
+ GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
4120
+ (int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get());
4121
+
4122
+ acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
4123
+ aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
4124
+ } else {
4125
+ acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
4126
+ GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
4127
+ (int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get());
4128
+ }
4129
+ } else if (dst->type == GGML_TYPE_F32) {
3638
4130
  acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
3639
4131
  aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
3640
4132
  }
@@ -3644,46 +4136,65 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
3644
4136
  }
3645
4137
 
3646
4138
  static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3647
- ggml_tensor * src0 = dst->src[0]; // weight
3648
- ggml_tensor * src1 = dst->src[1]; // input
4139
+ ggml_tensor * src0 = dst->src[0]; // weight [ne00=m, ne01=K, ne02, ne03]
4140
+ ggml_tensor * src1 = dst->src[1]; // input [ne10=n, ne11=K, ne12, ne13]
3649
4141
  GGML_TENSOR_BINARY_OP_LOCALS
3650
4142
 
3651
- acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
3652
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
4143
+ // dst[i,j] = sum_k src0[i,k] * src1[j,k] i.e. dst = src0 @ src1^T.
4144
+ //
4145
+ // ggml_cann_create_tensor reverses dimension order, so ACL sees:
4146
+ // acl_src0 slice: ggml[m,K] -> ACL[K,m]
4147
+ // acl_src1 slice: ggml[n,K] -> ACL[K,n]
4148
+ // acl_dst slice: ggml[m,n] -> ACL[n,m]
4149
+ //
4150
+ // Build a transposed view of src1 by swapping ne[0]/ne[1]:
4151
+ // src1_t: ggml[K,n] (swapped strides) -> ACL[n,K]
4152
+ //
4153
+ // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst ✓
4154
+ //
4155
+ // The outer batch loop is kept because src0 may have fewer batch slices than
4156
+ // dst (ne02 <= ne2, ne03 <= ne3): this is a strided-broadcast not supported
4157
+ // by standard CANN Matmul broadcasting.
4158
+
4159
+ const aclDataType src0_acl_type = ggml_cann_type_mapping(src0->type);
4160
+ const aclDataType src1_acl_type = ggml_cann_type_mapping(src1->type);
4161
+ const aclDataType dst_acl_type = ggml_cann_type_mapping(dst->type);
4162
+ const size_t src0_type_sz = ggml_type_size(src0->type);
4163
+ const size_t src1_type_sz = ggml_type_size(src1->type);
4164
+ const size_t dst_type_sz = ggml_type_size(dst->type);
3653
4165
 
3654
4166
  const int64_t dps2 = ne2 / ne02;
3655
4167
  const int64_t dps3 = ne3 / ne03;
4168
+
3656
4169
  for (int64_t i3 = 0; i3 < ne3; i3++) {
3657
4170
  for (int64_t i2 = 0; i2 < ne2; i2++) {
3658
4171
  const int64_t i02 = i2 / dps2;
3659
4172
  const int64_t i03 = i3 / dps3;
3660
4173
 
3661
- const int64_t i12 = i2;
3662
- const int64_t i13 = i3;
3663
- acl_tensor_ptr accumulator =
3664
- ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),
3665
- ggml_type_size(dst->type), dst->ne, dst->nb, 2);
3666
-
3667
- // The outer product needs to be accumulated in this dimension.
3668
- for (int64_t i1 = 0; i1 < ne11; i1++) {
3669
- acl_tensor_ptr acl_input = ggml_cann_create_tensor(
3670
- (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type),
3671
- ggml_type_size(src0->type), src1->ne, src1->nb, 1);
3672
-
3673
- acl_tensor_ptr acl_weight = ggml_cann_create_tensor(
3674
- (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),
3675
- ggml_type_size(src0->type), src0->ne, src0->nb, 1);
3676
-
3677
- ggml_cann_pool_alloc output_allocator(ctx.pool());
3678
- void * output_buffer = output_allocator.alloc(ggml_nbytes(dst));
3679
- acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type),
3680
- ggml_type_size(dst->type), dst->ne, dst->nb, 2);
3681
-
3682
- GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());
3683
- float alpha_value = 1.0f;
3684
- aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT);
3685
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);
3686
- }
4174
+ // src0 2D slice at [i02, i03]: ggml [m, K] -> ACL [K, m]
4175
+ int64_t src0_ne[2] = { ne00, ne01 };
4176
+ size_t src0_nb[2] = { nb00, nb01 };
4177
+ acl_tensor_ptr acl_src0_s = ggml_cann_create_tensor(
4178
+ (char *) src0->data + i02 * nb02 + i03 * nb03,
4179
+ src0_acl_type, src0_type_sz, src0_ne, src0_nb, 2);
4180
+
4181
+ // src1 transposed 2D slice at [i2, i3]: swap ne/nb -> ggml[K,n] -> ACL[n,K]
4182
+ int64_t src1_t_ne[2] = { ne11, ne10 };
4183
+ size_t src1_t_nb[2] = { nb11, nb10 };
4184
+ acl_tensor_ptr acl_src1_t = ggml_cann_create_tensor(
4185
+ (char *) src1->data + i2 * nb12 + i3 * nb13,
4186
+ src1_acl_type, src1_type_sz, src1_t_ne, src1_t_nb, 2);
4187
+
4188
+ // dst 2D slice at [i2, i3]: ggml [m, n] -> ACL [n, m]
4189
+ int64_t dst_ne[2] = { ne0, ne1 };
4190
+ size_t dst_nb[2] = { nb0, nb1 };
4191
+ acl_tensor_ptr acl_dst_s = ggml_cann_create_tensor(
4192
+ (char *) dst->data + i2 * nb2 + i3 * nb3,
4193
+ dst_acl_type, dst_type_sz, dst_ne, dst_nb, 2);
4194
+
4195
+ // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst_s ✓
4196
+ GGML_CANN_CALL_ACLNN_OP(ctx, Matmul,
4197
+ acl_src1_t.get(), acl_src0_s.get(), acl_dst_s.get(), (int8_t) 1);
3687
4198
  }
3688
4199
  }
3689
4200
  }
@@ -3742,15 +4253,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3742
4253
  // we want a view: ne_w = { nc, 1, nr } // [K, 1, C]
3743
4254
  // so that reversed dims -> [C, 1, K] which matches
3744
4255
  // [out_channels, in_channels/groups, kernel_size]
3745
- int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups]
4256
+ int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups]
3746
4257
  // Layout: src1 data is [K, C] with
3747
4258
  // offset(k, c) = k*nb0 + c*nb1
3748
4259
  // We want offset_w(k, 0, c) = k*nb0 + c*nb1,
3749
4260
  // so we can reuse nb0 and nb1, and set nb2 = nb1.
3750
- size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
4261
+ size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
3751
4262
 
3752
- acl_tensor_ptr acl_w = ggml_cann_create_tensor(
3753
- src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
4263
+ acl_tensor_ptr acl_w = ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type),
4264
+ ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
3754
4265
 
3755
4266
  // 3) Output: dst is { d_inner, n_t, n_s } (CLN)
3756
4267
  //
@@ -3768,11 +4279,12 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3768
4279
  // nb_y[0] = nr * sizeof(float); // step in L
3769
4280
  // nb_y[1] = sizeof(float); // step in C
3770
4281
  // nb_y[2] = nr * n_t * sizeof(float); // step in N
3771
- int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
3772
- size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t]
4282
+ int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
4283
+ size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float),
4284
+ dst->nb[3] }; // [nr, 1, nr * n_t]
3773
4285
 
3774
- acl_tensor_ptr acl_y = ggml_cann_create_tensor(
3775
- dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
4286
+ acl_tensor_ptr acl_y = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
4287
+ ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
3776
4288
 
3777
4289
  // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
3778
4290
  int64_t strideVal[1] = { 1 };
@@ -3791,22 +4303,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3791
4303
  cubeMathType = 1;
3792
4304
  #endif
3793
4305
 
3794
- GGML_CANN_CALL_ACLNN_OP(ctx,
3795
- Convolution,
4306
+ GGML_CANN_CALL_ACLNN_OP(ctx, Convolution,
3796
4307
  acl_x.get(), // input: N, C, L_in = ncs
3797
4308
  acl_w.get(), // weight: [C, 1, K] with groups=nr
3798
4309
  nullptr, // bias
3799
- stride.get(),
3800
- padding.get(),
3801
- dilation.get(),
3802
- transposed,
3803
- padding.get(), // output padding (unused for non-transposed)
3804
- groups,
3805
- acl_y.get(),
3806
- cubeMathType);
4310
+ stride.get(), padding.get(), dilation.get(), transposed,
4311
+ padding.get(), // output padding (unused for non-transposed)
4312
+ groups, acl_y.get(), cubeMathType);
3807
4313
  }
3808
4314
 
3809
-
3810
4315
  void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
3811
4316
  ggml_tensor * add_node,
3812
4317
  ggml_tensor * rms_norm_node) {
@@ -3860,3 +4365,72 @@ void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
3860
4365
  eps, // double type
3861
4366
  acl_yout.get(), acl_rstd.get(), acl_xout.get());
3862
4367
  }
4368
+
4369
+ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
4370
+ ggml_tensor * k = dst->src[0];
4371
+ ggml_tensor * v = dst->src[1];
4372
+ ggml_tensor * q = dst->src[2];
4373
+ ggml_tensor * g = dst->src[3];
4374
+ ggml_tensor * s = dst->src[4];
4375
+
4376
+ int64_t B = dst->src[4]->ne[1];
4377
+ int64_t T = dst->src[0]->ne[2];
4378
+ int64_t H = dst->src[0]->ne[1];
4379
+ int64_t C = dst->ne[0];
4380
+ int64_t D = C / H;
4381
+ int64_t L = T / B;
4382
+
4383
+ int64_t ne_qkg[2] = { 1, D };
4384
+ int64_t ne_s[2] = { D, D };
4385
+ int64_t ne_st[2] = { ne_s[1], ne_s[0] };
4386
+ int64_t ne_vo[2] = { D, 1 };
4387
+ int64_t ne_q[1] = { D };
4388
+ size_t nb_base = ggml_type_size(k->type);
4389
+ size_t nb_qkg[2] = { nb_base, nb_base };
4390
+ size_t nb_s[2] = { nb_base, D * nb_base };
4391
+ size_t nb_st[2] = { nb_s[1], nb_s[0] };
4392
+ size_t nb_vo[2] = { nb_base, D * nb_base };
4393
+ size_t nb_q[1] = { nb_base };
4394
+
4395
+ const float scale = ggml_get_op_params_f32(dst, 0);
4396
+
4397
+ acl_tensor_ptr acl_s = ggml_cann_create_tensor(s, s->ne, s->nb, 2, ACL_FORMAT_ND);
4398
+ acl_tensor_ptr new_state = ggml_cann_create_tensor(dst, s->ne, s->nb, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base);
4399
+ cann_copy(ctx, acl_s.get(), new_state.get());
4400
+
4401
+ for (int64_t b = 0; b < B; b++) {
4402
+ for (int64_t h = 0; h < H; h++) {
4403
+ size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
4404
+ // D * D
4405
+ acl_tensor_ptr acl_s_new =
4406
+ ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
4407
+ acl_tensor_ptr acl_s_new_t =
4408
+ ggml_cann_create_tensor(dst, ne_st, nb_st, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
4409
+ for (int64_t l = 0; l < L; l++) {
4410
+ size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
4411
+ // D * 1
4412
+ acl_tensor_ptr acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
4413
+ acl_tensor_ptr acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
4414
+ // D
4415
+ acl_tensor_ptr acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
4416
+ // 1 * D
4417
+ acl_tensor_ptr acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
4418
+ // D
4419
+ acl_tensor_ptr acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
4420
+ // k ⊗ v
4421
+ size_t buf_size = D * D * nb_base;
4422
+ ggml_cann_pool_alloc buffer_allocator(ctx.pool(), buf_size);
4423
+ acl_tensor_ptr tmp_tensor = ggml_cann_create_tensor(
4424
+ buffer_allocator.get(), ggml_cann_type_mapping(k->type), nb_base, ne_s, nb_s, 2);
4425
+ aclnn_mul(ctx, acl_k.get(), acl_v.get(), tmp_tensor.get());
4426
+ //s_new = g ⊗ s_old + k ⊗ v
4427
+ aclnn_mul(ctx, acl_s_new.get(), acl_g.get(), nullptr);
4428
+ aclnn_add(ctx, acl_s_new.get(), tmp_tensor.get(), nullptr);
4429
+ // compute output
4430
+ GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_new_t.get(), acl_q.get(), acl_o.get(), 1);
4431
+ aclnn_muls(ctx, acl_o.get(), scale, nullptr, true);
4432
+ }
4433
+ }
4434
+ }
4435
+ }
4436
+