whispercpp 1.3.6 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -25,6 +25,7 @@
25
25
  #include "ggml-impl.h"
26
26
  #include "ggml.h"
27
27
 
28
+
28
29
  #include <aclnnop/aclnn_add.h>
29
30
  #include <aclnnop/aclnn_add_rms_norm.h>
30
31
  #include <aclnnop/aclnn_addcdiv.h>
@@ -45,7 +46,9 @@
45
46
  #include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
46
47
  #include <aclnnop/aclnn_ger.h>
47
48
  #include <aclnnop/aclnn_group_norm.h>
49
+ #include <aclnnop/aclnn_gather_v2.h>
48
50
  #include <aclnnop/aclnn_grouped_matmul_v3.h>
51
+ #include <aclnnop/aclnn_scatter.h>
49
52
  #include <aclnnop/aclnn_gt_scalar.h>
50
53
  #include <aclnnop/aclnn_im2col.h>
51
54
  #include <aclnnop/aclnn_index_copy.h>
@@ -62,6 +65,7 @@
62
65
  #include <aclnnop/aclnn_permute.h>
63
66
  #include <aclnnop/aclnn_pow.h>
64
67
  #include <aclnnop/aclnn_pow_tensor_tensor.h>
68
+ #include <aclnnop/aclnn_recurrent_gated_delta_rule.h>
65
69
  #include <aclnnop/aclnn_reduce_sum.h>
66
70
  #include <aclnnop/aclnn_reflection_pad1d.h>
67
71
  #include <aclnnop/aclnn_repeat.h>
@@ -69,11 +73,15 @@
69
73
  #include <aclnnop/aclnn_rms_norm.h>
70
74
  #include <aclnnop/aclnn_roll.h>
71
75
  #include <aclnnop/aclnn_softmax.h>
76
+ #include <aclnnop/aclnn_softmax_cross_entropy_with_logits.h>
72
77
  #include <aclnnop/aclnn_sub.h>
73
78
  #include <aclnnop/aclnn_sum.h>
74
79
  #include <aclnnop/aclnn_threshold.h>
75
80
  #include <aclnnop/aclnn_tril.h>
81
+ #include <aclnnop/aclnn_triangular_solve.h>
76
82
  #include <aclnnop/aclnn_triu.h>
83
+ #include <aclnnop/aclnn_logical_not.h>
84
+ #include <aclnnop/aclnn_masked_fill_scalar.h>
77
85
  #include <aclnnop/aclnn_upsample_nearest_2d.h>
78
86
  #include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
79
87
  #include <aclnnop/aclnn_zero.h>
@@ -151,6 +159,107 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
151
159
  GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst.get(), acl_src1.get());
152
160
  }
153
161
 
162
+ // Fused SwiGLU using aclnnSwiGlu: splits input along innermost dim, applies
163
+ // SiLU to left half, multiplies by right half.
164
+ //
165
+ // Falls back to the generic two-kernel path when src[1] != nullptr (two
166
+ // independent halves) or swapped != 0 (reversed activation order), as
167
+ // aclnnSwiGlu only handles the single interleaved tensor in standard order.
168
+ //
169
+ // CANN tiling for SwiGlu requires (storageShapeDim + viewDims) to be even.
170
+ // aclCreateTensor always uses storageShapeDim=1, so viewDims must be odd.
171
+ // We use a 3D view (1+3=4, even) to satisfy this constraint while preserving
172
+ // correct split semantics along the innermost (ne[0]) dimension.
173
+ void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
174
+ auto silu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
175
+ GGML_CANN_CALL_ACLNN_OP(ctx, Silu, acl_src, acl_dst);
176
+ };
177
+
178
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
179
+ if (dst->src[1] != nullptr || swapped != 0) {
180
+ ggml_cann_op_unary_gated(silu_fn, ctx, dst);
181
+ return;
182
+ }
183
+
184
+ // aclnnSwiGlu requires the split dim (src->ne[0]) to be even; fall back otherwise.
185
+ if (dst->src[0]->ne[0] % 2 != 0) {
186
+ ggml_cann_op_unary_gated(silu_fn, ctx, dst);
187
+ return;
188
+ }
189
+
190
+ ggml_tensor * src0 = dst->src[0];
191
+ size_t elem_size = ggml_element_size(src0);
192
+
193
+ // src0 GGML: [2*ne0, ne1, ne2, ne3] → 3D view [2*ne0, ne1, ne2*ne3]
194
+ // CANN reversed: [ne2*ne3, ne1, 2*ne0], split along CANN dim 2 (last).
195
+ int64_t ne0_x2 = src0->ne[0];
196
+ int64_t ne1 = src0->ne[1];
197
+ int64_t ne23 = src0->ne[2] * src0->ne[3];
198
+ int64_t src3d_ne[] = { ne0_x2, ne1, ne23 };
199
+ size_t src3d_nb[] = { (size_t)src0->nb[0], (size_t)src0->nb[1], (size_t)src0->nb[2] };
200
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type),
201
+ elem_size, src3d_ne, src3d_nb, 3);
202
+
203
+ // dst GGML: [ne0, ne1, ne2, ne3] → 3D view [ne0, ne1, ne2*ne3]
204
+ int64_t ne0 = dst->ne[0];
205
+ int64_t dst3d_ne[] = { ne0, ne1, ne23 };
206
+ size_t dst3d_nb[] = { (size_t)dst->nb[0], (size_t)dst->nb[1], (size_t)dst->nb[2] };
207
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
208
+ elem_size, dst3d_ne, dst3d_nb, 3);
209
+
210
+ // CANN tensor [ne23, ne1, 2*ne0]: split along CANN dim 2 (last) = 2*ne0.
211
+ GGML_CANN_CALL_ACLNN_OP(ctx, SwiGlu, acl_src.get(), (int64_t)2, acl_dst.get());
212
+ }
213
+
214
+ // Fused GeGLU using aclnnGeGluV3: splits input along ne[0] (CANN last dim),
215
+ // activates the LEFT half with GELU, multiplies by right half.
216
+ // approximate: 0=tanh, 1=none(erf). activateLeft=true matches GGML convention.
217
+ // outGelu is a required-but-discard output buffer.
218
+ //
219
+ // Falls back to the generic two-kernel path when src[1] != nullptr (two
220
+ // independent halves) or swapped != 0 (reversed activation order), as
221
+ // aclnnGeGluV3 only handles the single interleaved tensor in standard order.
222
+ void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate) {
223
+ auto gelu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
224
+ GGML_CANN_CALL_ACLNN_OP(ctx, Gelu, acl_src, acl_dst);
225
+ };
226
+
227
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
228
+ if (dst->src[1] != nullptr || swapped != 0) {
229
+ ggml_cann_op_unary_gated(gelu_fn, ctx, dst);
230
+ return;
231
+ }
232
+
233
+ // aclnnGeGluV3 requires the split dim (src->ne[0]) to be even; fall back otherwise.
234
+ if (dst->src[0]->ne[0] % 2 != 0) {
235
+ ggml_cann_op_unary_gated(gelu_fn, ctx, dst);
236
+ return;
237
+ }
238
+
239
+ ggml_tensor * src0 = dst->src[0];
240
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
241
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
242
+
243
+ // Allocate a temporary buffer for the required outGelu output (same shape as dst).
244
+ // Build contiguous strides since the pool allocation is a fresh buffer.
245
+ size_t elem_size = ggml_element_size(dst);
246
+ int64_t ne[GGML_MAX_DIMS] = { dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3] };
247
+ size_t nb[GGML_MAX_DIMS];
248
+ nb[0] = elem_size;
249
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
250
+ nb[i] = nb[i - 1] * ne[i - 1];
251
+ }
252
+ size_t gelu_out_size = nb[GGML_MAX_DIMS - 1] * ne[GGML_MAX_DIMS - 1];
253
+ ggml_cann_pool_alloc gelu_out_alloc(ctx.pool(), gelu_out_size);
254
+
255
+ acl_tensor_ptr acl_gelu_out = ggml_cann_create_tensor(
256
+ gelu_out_alloc.get(), ggml_cann_type_mapping(dst->type), elem_size, ne, nb, GGML_MAX_DIMS);
257
+ // V3 adds activateLeft param; true → Gelu(left)*right, matching GGML convention.
258
+ // GGML dim 0 → CANN last dim (index GGML_MAX_DIMS-1 = 3 for 4D tensor).
259
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeGluV3, acl_src.get(), (int64_t)(GGML_MAX_DIMS - 1), approximate, true,
260
+ acl_dst.get(), acl_gelu_out.get());
261
+ }
262
+
154
263
  /**
155
264
  * @brief Repeats elements of a tensor along each dimension according to the
156
265
  * specified repeat array.
@@ -434,6 +543,9 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
434
543
  void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
435
544
  ggml_tensor * src = dst->src[0];
436
545
 
546
+ float eps;
547
+ memcpy(&eps, dst->op_params, sizeof(float));
548
+
437
549
  acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
438
550
  acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
439
551
 
@@ -442,21 +554,33 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
442
554
  ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes);
443
555
  void * buffer = temp_buffer_allocator.get();
444
556
 
445
- int64_t div_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };
446
- size_t div_nb[GGML_MAX_DIMS];
447
- div_nb[0] = sizeof(float);
557
+ int64_t norm_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] };
558
+ size_t norm_nb[GGML_MAX_DIMS];
559
+ norm_nb[0] = sizeof(float);
448
560
  for (int i = 1; i < GGML_MAX_DIMS; ++i) {
449
- div_nb[i] = div_nb[i - 1] * div_ne[i - 1];
561
+ norm_nb[i] = norm_nb[i - 1] * norm_ne[i - 1];
450
562
  }
451
- acl_tensor_ptr acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS);
563
+ acl_tensor_ptr acl_norm = ggml_cann_create_tensor(buffer, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS);
452
564
 
453
565
  std::vector<int64_t> norm_dims = { 3 };
454
566
  acl_int_array_ptr dims_array = ggml_cann_create_int_array(norm_dims.data(), norm_dims.size());
455
567
 
456
568
  float p_value = 2.0f;
457
569
  acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT);
458
- GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get());
459
- GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get());
570
+ GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_norm.get());
571
+
572
+ ggml_cann_pool_alloc clamp_buffer_allocator(ctx.pool());
573
+ acl_tensor_ptr acl_clamped;
574
+
575
+ if (eps > 0.0f) {
576
+ void * clamp_buf = clamp_buffer_allocator.alloc(n_bytes);
577
+ acl_clamped = ggml_cann_create_tensor(clamp_buf, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS);
578
+ acl_scalar_ptr eps_scalar = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT);
579
+ GGML_CANN_CALL_ACLNN_OP(ctx, ClampMin, acl_norm.get(), eps_scalar.get(), acl_clamped.get());
580
+ }
581
+
582
+ aclTensor * acl_div_input = acl_clamped ? acl_clamped.get() : acl_norm.get();
583
+ GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div_input, acl_dst.get());
460
584
  }
461
585
 
462
586
  void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -472,56 +596,30 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor *
472
596
  logits_nb[1] = logits_nb[0] * logits_ne[0];
473
597
  acl_tensor_ptr acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
474
598
 
475
- size_t log_softmax_type_size = sizeof(float);
476
- int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size;
477
- ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes);
478
- void * log_softmax_buffer = log_softmax_allocator.get();
479
-
480
- int64_t log_softmax_ne[] = { nc, nr };
481
- size_t log_softmax_nb[2];
482
- log_softmax_nb[0] = log_softmax_type_size;
483
- log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0];
484
- acl_tensor_ptr acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size,
485
- log_softmax_ne, log_softmax_nb, 2);
486
-
487
- GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits.get(), 1, acl_log_softmax.get());
488
-
489
599
  int64_t labels_ne[] = { nc, nr };
490
600
  size_t labels_nb[2];
491
601
  labels_nb[0] = ggml_type_size(src1->type);
492
602
  labels_nb[1] = labels_nb[0] * labels_ne[0];
493
603
  acl_tensor_ptr acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2);
494
604
 
495
- size_t mul_type_size = sizeof(float);
496
- int64_t mul_n_bytes = nr * nc * mul_type_size;
497
- ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes);
498
- void * mul_buffer = mul_allocator.get();
605
+ size_t loss_per_sample_type_size = sizeof(float);
606
+ int64_t loss_per_sample_n_bytes = nr * loss_per_sample_type_size;
607
+ ggml_cann_pool_alloc loss_per_sample_allocator(ctx.pool(), loss_per_sample_n_bytes);
608
+ void * loss_per_sample_buffer = loss_per_sample_allocator.get();
499
609
 
500
- int64_t mul_ne[] = { nc, nr };
501
- size_t mul_nb[2];
502
- mul_nb[0] = mul_type_size;
503
- mul_nb[1] = mul_nb[0] * mul_ne[0];
504
- acl_tensor_ptr acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2);
610
+ int64_t loss_per_sample_ne[] = { nr };
611
+ size_t loss_per_sample_nb[1];
612
+ loss_per_sample_nb[0] = loss_per_sample_type_size;
613
+ acl_tensor_ptr acl_loss_per_sample = ggml_cann_create_tensor(
614
+ loss_per_sample_buffer, ACL_FLOAT, loss_per_sample_type_size, loss_per_sample_ne, loss_per_sample_nb, 1);
505
615
 
506
- GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax.get(), acl_labels.get(), acl_mul_result.get());
616
+ size_t backprop_n_bytes = nr * nc * sizeof(float);
617
+ ggml_cann_pool_alloc backprop_allocator(ctx.pool(), backprop_n_bytes);
618
+ void * backprop_buffer = backprop_allocator.get();
619
+ acl_tensor_ptr acl_backprop = ggml_cann_create_tensor(backprop_buffer, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2);
507
620
 
508
- size_t sum_per_sample_type_size = sizeof(float);
509
- int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size;
510
- ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes);
511
- void * sum_per_sample_buffer = sum_per_sample_allocator.get();
512
-
513
- int64_t sum_per_sample_ne[] = { nr };
514
- size_t sum_per_sample_nb[1];
515
- sum_per_sample_nb[0] = sum_per_sample_type_size;
516
- acl_tensor_ptr acl_sum_per_sample = ggml_cann_create_tensor(
517
- sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1);
518
-
519
- std::vector<int64_t> sum_dims = { 1 };
520
- acl_int_array_ptr dims_array = ggml_cann_create_int_array(sum_dims.data(), sum_dims.size());
521
- bool keep_dims = false;
522
-
523
- GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result.get(), dims_array.get(), keep_dims, ACL_FLOAT,
524
- acl_sum_per_sample.get());
621
+ GGML_CANN_CALL_ACLNN_OP(ctx, SoftmaxCrossEntropyWithLogits, acl_logits.get(), acl_labels.get(),
622
+ acl_loss_per_sample.get(), acl_backprop.get());
525
623
 
526
624
  size_t total_sum_type_size = sizeof(float);
527
625
  int64_t total_sum_n_bytes = 1 * total_sum_type_size;
@@ -537,11 +635,12 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor *
537
635
 
538
636
  std::vector<int64_t> total_sum_dims = { 0 };
539
637
  acl_int_array_ptr total_sum_dims_array = ggml_cann_create_int_array(total_sum_dims.data(), total_sum_dims.size());
638
+ bool keep_dims = false;
540
639
 
541
- GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,
640
+ GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_loss_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT,
542
641
  acl_total_sum.get());
543
642
 
544
- float value = -1.0f / static_cast<float>(nr);
643
+ float value = 1.0f / static_cast<float>(nr);
545
644
  acl_scalar_ptr scale_factor = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT);
546
645
  acl_tensor_ptr acl_dst =
547
646
  ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1);
@@ -579,6 +678,33 @@ void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
579
678
  acl_mean_out.get(), acl_rstd_out.get());
580
679
  }
581
680
 
681
+ void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
682
+ ggml_tensor * src0 = dst->src[0];
683
+ ggml_tensor * src1 = dst->src[1];
684
+
685
+ size_t nb1 = ((int32_t *) dst->op_params)[0];
686
+ size_t nb2 = ((int32_t *) dst->op_params)[1];
687
+ size_t nb3 = ((int32_t *) dst->op_params)[2];
688
+ size_t offset = ((int32_t *) dst->op_params)[3];
689
+ bool inplace = (bool) ((int32_t *) dst->op_params)[4];
690
+
691
+ size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 };
692
+
693
+ // Create a view of dst at the target offset with src1's dimensions
694
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
695
+ acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1);
696
+
697
+ if (!inplace) {
698
+ // First copy src0 to dst entirely
699
+ size_t cpy_size = ggml_nbytes(dst);
700
+ ACL_CHECK(
701
+ aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
702
+ }
703
+
704
+ // Copy src1 into the target region of dst
705
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst.get(), acl_src1.get());
706
+ }
707
+
582
708
  void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
583
709
  ggml_tensor * src0 = dst->src[0];
584
710
  ggml_tensor * src1 = dst->src[1];
@@ -642,6 +768,113 @@ void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
642
768
  aclnn_reduce_sum(ctx, dst, reduce_dims, 4);
643
769
  }
644
770
 
771
+ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
772
+ ggml_tensor * src = dst->src[0];
773
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src);
774
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
775
+ // GGML cumsum operates along dim 0 (innermost / ne[0]).
776
+ // ggml_cann_create_tensor reverses dimensions to [ne3,ne2,ne1,ne0],
777
+ // so GGML dim 0 maps to CANN dim 3 (the last dim of the 4-D tensor).
778
+ GGML_CANN_CALL_ACLNN_OP(ctx, Cumsum, acl_src.get(), (int64_t)3,
779
+ ggml_cann_type_mapping(dst->type), acl_dst.get());
780
+ }
781
+
782
+ void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
783
+ ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower triangular
784
+ ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3]
785
+
786
+ acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0);
787
+ acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1);
788
+ acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst);
789
+
790
+ // mOut: triangular copy of A (required output), same shape as A.
791
+ const size_t a_bytes = ggml_nbytes(src0);
792
+ ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes);
793
+ acl_tensor_ptr acl_m = ggml_cann_create_tensor(
794
+ m_alloc.get(), ggml_cann_type_mapping(src0->type),
795
+ ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
796
+
797
+ // Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false.
798
+ GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve,
799
+ acl_b.get(), acl_a.get(), false, false, false,
800
+ acl_x.get(), acl_m.get());
801
+ }
802
+
803
+ void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
804
+ ggml_tensor * src = dst->src[0];
805
+
806
+ GGML_ASSERT(src->ne[1] == 1);
807
+
808
+ const int64_t N = src->ne[0];
809
+ const int64_t n_batch = src->ne[2] * src->ne[3];
810
+ const size_t nb_f32 = sizeof(float);
811
+
812
+ // Fill dst with zeros.
813
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
814
+ {
815
+ float zero = 0.0f;
816
+ acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero, ACL_FLOAT);
817
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_zero.get());
818
+ }
819
+
820
+ // Copy src vector onto the diagonal of dst via strided views.
821
+ // src viewed as [N, n_batch], contiguous strides.
822
+ int64_t ne_vec[2] = { N, n_batch };
823
+ size_t nb_src_vec[2] = { nb_f32, N * nb_f32 };
824
+ // dst diagonal view: stride (N+1)*4 steps along the diagonal.
825
+ size_t nb_dst_diag[2] = { (N + 1) * nb_f32, N * N * nb_f32 };
826
+
827
+ acl_tensor_ptr acl_src_vec = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne_vec, nb_src_vec, 2);
828
+ acl_tensor_ptr acl_dst_diag = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_dst_diag, 2);
829
+
830
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst_diag.get(), acl_src_vec.get());
831
+ }
832
+
833
+ void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
834
+ float c = ggml_get_op_params_f32(dst, 0);
835
+
836
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
837
+ acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT);
838
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get());
839
+ }
840
+
841
+ void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
842
+ ggml_tensor * src = dst->src[0];
843
+
844
+ const int64_t S = src->ne[0];
845
+ const int64_t n_batch = src->ne[2] * src->ne[3];
846
+ const size_t nb_f32 = sizeof(float);
847
+
848
+ int64_t ne3d[3] = { S, S, n_batch };
849
+ size_t nb3d[3] = { nb_f32, S * nb_f32, S * S * nb_f32 };
850
+
851
+ const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
852
+
853
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3);
854
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3);
855
+
856
+ switch (ttype) {
857
+ case GGML_TRI_TYPE_LOWER:
858
+ // Tril(-1): preserve row > col (strict lower), zero upper + diagonal.
859
+ GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)-1, acl_dst.get());
860
+ break;
861
+ case GGML_TRI_TYPE_UPPER_DIAG:
862
+ // Triu(0): preserve row <= col (upper + diagonal), zero strict lower.
863
+ GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)0, acl_dst.get());
864
+ break;
865
+ case GGML_TRI_TYPE_UPPER:
866
+ // Triu(1): preserve row < col (strict upper), zero lower + diagonal.
867
+ GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)1, acl_dst.get());
868
+ break;
869
+ case GGML_TRI_TYPE_LOWER_DIAG:
870
+ // Tril(0): preserve row >= col (lower + diagonal), zero strict upper.
871
+ GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)0, acl_dst.get());
872
+ break;
873
+ default:
874
+ GGML_ABORT("unsupported tri type");
875
+ }
876
+ }
877
+
645
878
  void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
646
879
  ggml_tensor * src = dst->src[0];
647
880
  acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
@@ -1544,8 +1777,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx,
1544
1777
  end = 2 * ((n_head - 1) - n_head_log2) + 1;
1545
1778
  step = 2;
1546
1779
  count = n_head - n_head_log2;
1547
- aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step,
1548
- dtype);
1780
+ aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1,
1781
+ step, dtype);
1549
1782
  }
1550
1783
  }
1551
1784
 
@@ -1685,150 +1918,90 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1685
1918
  aclnn_softmax(ctx, softmax_tensor.get(), 3, acl_dst.get());
1686
1919
  }
1687
1920
 
1688
- /**
1689
- * @brief Performs index select operation on a 4D tensor using the CANN backend.
1690
- *
1691
- * This function applies the `IndexSelect` operation along a specific dimension
1692
- * of the source tensor (`src_buffer`) using the indices from the index tensor (`index`).
1693
- * It iterates over the last two dimensions of the source tensor, creates the corresponding
1694
- * CANN tensors for the source, index, and output slices, and executes the `IndexSelect`
1695
- * operation for each slice.
1696
- *
1697
- * @param ctx The context for CANN backend operations.
1698
- * @param src_buffer The source buffer containing the 4D input tensor data.
1699
- * @param src_ne The dimensions of the source tensor.
1700
- * @param src_nb The strides (byte offsets) of the source tensor.
1701
- * @param dst_buffer The destination buffer where the output tensor data will be written.
1702
- * @param dst_ne The dimensions of the destination tensor.
1703
- * @param dst_nb The strides (byte offsets) of the destination tensor.
1704
- * @param index The index tensor specifying the indices to select from the source tensor.
1705
- * @param type The data type of the source and destination tensors.
1706
- */
1707
- static void aclnn_index_select_4d(ggml_backend_cann_context & ctx,
1708
- void * src_buffer,
1709
- int64_t * src_ne,
1710
- size_t * src_nb,
1711
- void * dst_buffer,
1712
- int64_t * dst_ne,
1713
- size_t * dst_nb,
1714
- ggml_tensor * index,
1715
- ggml_type type) {
1716
- for (int64_t i = 0; i < src_ne[3]; i++) {
1717
- for (int64_t j = 0; j < src_ne[2]; j++) {
1718
- // src
1719
- acl_tensor_ptr acl_src_tensor =
1720
- ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],
1721
- ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);
1722
-
1723
- // index
1724
- acl_tensor_ptr acl_index = ggml_cann_create_tensor(
1725
- (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
1726
- ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);
1727
-
1728
- // out
1729
- acl_tensor_ptr acl_out =
1730
- ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],
1731
- ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);
1732
- GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor.get(), 0, acl_index.get(), acl_out.get());
1733
- }
1734
- }
1735
- }
1736
-
1737
- /**
1738
- * @brief Performs inplace index copy operation on a 4D tensor using the CANN backend.
1739
- *
1740
- * This function applies the `IndexCopy` operation along a specific dimension of the
1741
- * destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`)
1742
- * to positions specified by the index tensor (`index`).
1743
- * It iterates over the last two dimensions of the tensors, creates the corresponding
1744
- * CANN tensors for source, index, and destination slices, and performs the index copy
1745
- * operation for each slice.
1746
- *
1747
- * @param ctx The context for CANN backend operations.
1748
- * @param src_buffer The source buffer containing the 4D input tensor data to be copied.
1749
- * @param src_ne The dimensions of the source tensor.
1750
- * @param src_nb The strides (byte offsets) of the source tensor.
1751
- * @param dst_buffer The destination buffer where values will be copied to.
1752
- * @param dst_ne The dimensions of the destination tensor.
1753
- * @param dst_nb The strides (byte offsets) of the destination tensor.
1754
- * @param index The index tensor specifying target positions in the destination tensor.
1755
- * @param type The data type of the source and destination tensors.
1756
- */
1757
- static void aclnn_index_copy_4d(ggml_backend_cann_context & ctx,
1758
- void * src_buffer,
1759
- int64_t * src_ne,
1760
- size_t * src_nb,
1761
- void * dst_buffer,
1762
- int64_t * dst_ne,
1763
- size_t * dst_nb,
1764
- ggml_tensor * index,
1765
- ggml_type type) {
1766
- for (int64_t i = 0; i < src_ne[3]; i++) {
1767
- for (int64_t j = 0; j < src_ne[2]; j++) {
1768
- // src
1769
- acl_tensor_ptr acl_src_tensor =
1770
- ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2],
1771
- ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2);
1772
-
1773
- // index
1774
- acl_tensor_ptr acl_index = ggml_cann_create_tensor(
1775
- (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
1776
- ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1);
1777
-
1778
- // out
1779
- acl_tensor_ptr acl_out =
1780
- ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2],
1781
- ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2);
1782
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out.get(), 0, acl_index.get(), acl_src_tensor.get());
1783
- }
1784
- }
1785
- }
1786
1921
 
1787
1922
  void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1788
- ggml_tensor * src0 = dst->src[0]; // src
1923
+ ggml_tensor * src0 = dst->src[0]; // weight
1789
1924
  ggml_tensor * src1 = dst->src[1]; // index
1790
1925
 
1791
- GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
1926
+ GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16
1927
+ || dst->type == GGML_TYPE_BF16);
1928
+
1929
+ // n_idx: number of row indices per (i2, i3) batch slice.
1930
+ // ggml guarantees: src0->ne[2] == src1->ne[1], src0->ne[3] == src1->ne[2], src1->ne[3] == 1.
1931
+ const int64_t n_idx = src1->ne[0];
1932
+
1933
+ // Gather all (i2, i3) batch slices from src into dst.
1934
+ // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0].
1935
+ // GatherV2 with dim=0 gathers along ACL dim-0 == ggml ne[1] (the vocabulary / row axis).
1936
+ // nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape,
1937
+ // nb[2..3] for computing per-batch-slice base pointer offsets).
1938
+ auto gather_batched = [&](void * src_base, aclDataType acl_type, size_t type_size,
1939
+ const size_t * nb) {
1940
+ int64_t src_ne[2] = { src0->ne[0], src0->ne[1] };
1941
+ size_t src_nb_2d[2] = { nb[0], nb[1] };
1942
+ int64_t dst_ne[2] = { src0->ne[0], n_idx };
1943
+ size_t dst_nb_2d[2] = { dst->nb[0], dst->nb[1] };
1944
+ int64_t idx_ne[1] = { n_idx };
1945
+ size_t idx_nb[1] = { (size_t)ggml_element_size(src1) };
1946
+
1947
+ for (int64_t i3 = 0; i3 < src0->ne[3]; i3++) {
1948
+ for (int64_t i2 = 0; i2 < src0->ne[2]; i2++) {
1949
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(
1950
+ (char *)src_base + i3 * nb[3] + i2 * nb[2],
1951
+ acl_type, type_size, src_ne, src_nb_2d, 2);
1952
+ acl_tensor_ptr acl_idx = ggml_cann_create_tensor(
1953
+ (char *)src1->data + i3 * src1->nb[2] + i2 * src1->nb[1],
1954
+ ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1),
1955
+ idx_ne, idx_nb, 1);
1956
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
1957
+ (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2],
1958
+ acl_type, type_size, dst_ne, dst_nb_2d, 2);
1959
+ GGML_CANN_CALL_ACLNN_OP(ctx, GatherV2, acl_src.get(), 0, acl_idx.get(), acl_dst.get());
1960
+ }
1961
+ }
1962
+ };
1792
1963
 
1793
1964
  switch (src0->type) {
1965
+ case GGML_TYPE_BF16:
1794
1966
  case GGML_TYPE_F16:
1795
1967
  case GGML_TYPE_F32:
1796
1968
  if (src0->type == dst->type) {
1797
- aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1,
1798
- dst->type);
1969
+ gather_batched(src0->data,
1970
+ ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),
1971
+ src0->nb);
1799
1972
  } else {
1800
- acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
1801
- ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
1802
- void * src_trans_buffer = src_buffer_allocator.get();
1803
- size_t src_trans_nb[GGML_MAX_DIMS];
1804
- src_trans_nb[0] = dst->nb[0];
1973
+ // Cast src0 to dst type, then gather.
1974
+ ggml_cann_pool_alloc src_cast_allocator(ctx.pool(),
1975
+ ggml_nelements(src0) * ggml_element_size(dst));
1976
+ size_t src_cast_nb[GGML_MAX_DIMS];
1977
+ src_cast_nb[0] = ggml_type_size(dst->type);
1805
1978
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
1806
- src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
1979
+ src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1];
1807
1980
  }
1808
- acl_tensor_ptr src_trans_tensor =
1809
- ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type),
1810
- ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
1811
- aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
1812
- aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
1813
- dst->type);
1981
+ acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
1982
+ acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor(
1983
+ src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
1984
+ src0->ne, src_cast_nb, GGML_MAX_DIMS);
1985
+ aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type));
1986
+
1987
+ gather_batched(src_cast_allocator.get(),
1988
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
1989
+ src_cast_nb);
1814
1990
  }
1815
1991
  break;
1816
1992
  case GGML_TYPE_Q8_0:
1817
1993
  {
1818
- // add 1 dim for bcast mul.
1994
+ // Dequantize Q8_0 to dst type, then gather.
1819
1995
  size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1];
1820
1996
  int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne;
1821
- int64_t scale_offset = 0;
1822
- // [3,4,5,64] -> [3,4,5,2,32]
1823
- weight_ne[0] = QK8_0;
1824
- weight_ne[1] = src0->ne[0] / QK8_0;
1825
- weight_nb[0] = sizeof(int8_t);
1826
- weight_nb[1] = weight_nb[0] * weight_ne[0];
1997
+ weight_ne[0] = QK8_0;
1998
+ weight_ne[1] = src0->ne[0] / QK8_0;
1999
+ weight_nb[0] = sizeof(int8_t);
2000
+ weight_nb[1] = weight_nb[0] * weight_ne[0];
1827
2001
  for (int i = 2; i < GGML_MAX_DIMS + 1; i++) {
1828
2002
  weight_ne[i] = src0->ne[i - 1];
1829
2003
  weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
1830
2004
  }
1831
- // [3,4,5,64] -> [3,4,5,2,1]
1832
2005
  scale_ne[0] = 1;
1833
2006
  scale_ne[1] = src0->ne[0] / QK8_0;
1834
2007
  scale_nb[0] = sizeof(uint16_t);
@@ -1837,31 +2010,33 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1837
2010
  scale_ne[i] = src0->ne[i - 1];
1838
2011
  scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
1839
2012
  }
1840
- // [3,4,5,64] -> [3,4,5,2,32]
1841
2013
  dequant_ne = weight_ne;
1842
2014
  dequant_nb[0] = ggml_type_size(dst->type);
1843
2015
  for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
1844
2016
  dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
1845
2017
  }
1846
- scale_offset = ggml_nelements(src0) * sizeof(int8_t);
1847
- ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(),
1848
- ggml_nelements(src0) * ggml_type_size(dst->type));
1849
- acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),
1850
- weight_ne, weight_nb, GGML_MAX_DIMS + 1);
1851
- acl_tensor_ptr acl_scale_tensor =
1852
- ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
1853
- GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
1854
- acl_tensor_ptr dequant_tensor =
1855
- ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type),
1856
- ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
1857
- aclnn_mul(ctx, acl_weight_tensor.get(), acl_scale_tensor.get(), dequant_tensor.get());
1858
- dequant_nb[0] = ggml_type_size(dst->type);
2018
+ const int64_t scale_offset = ggml_nelements(src0) * sizeof(int8_t);
2019
+ ggml_cann_pool_alloc dequant_allocator(ctx.pool(),
2020
+ ggml_nelements(src0) * ggml_type_size(dst->type));
2021
+ acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t),
2022
+ weight_ne, weight_nb, GGML_MAX_DIMS + 1);
2023
+ acl_tensor_ptr acl_scale = ggml_cann_create_tensor(
2024
+ src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
2025
+ GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
2026
+ acl_tensor_ptr acl_dequant = ggml_cann_create_tensor(
2027
+ dequant_allocator.get(), ggml_cann_type_mapping(dst->type),
2028
+ ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
2029
+ aclnn_mul(ctx, acl_weight.get(), acl_scale.get(), acl_dequant.get());
2030
+
2031
+ // Reinterpret dequant buffer as 4D [src0->ne] with contiguous strides.
1859
2032
  dequant_ne = src0->ne;
2033
+ dequant_nb[0] = ggml_type_size(dst->type);
1860
2034
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
1861
2035
  dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
1862
2036
  }
1863
- aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne,
1864
- dst->nb, src1, dst->type);
2037
+ gather_batched(dequant_allocator.get(),
2038
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
2039
+ dequant_nb);
1865
2040
  break;
1866
2041
  }
1867
2042
  default:
@@ -1871,30 +2046,70 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1871
2046
  }
1872
2047
 
1873
2048
  void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1874
- ggml_tensor * src0 = dst->src[0]; // src
1875
- ggml_tensor * src1 = dst->src[1]; // index
2049
+ ggml_tensor * src0 = dst->src[0]; // source values
2050
+ ggml_tensor * src1 = dst->src[1]; // row indices
2051
+
2052
+ // n_idx: number of source rows to scatter per batch slice.
2053
+ // ggml guarantees: src0->ne[1] == src1->ne[0].
2054
+ const int64_t n_idx = src1->ne[0];
2055
+
2056
+ // Copy n_idx rows of src [ne0, n_idx] into dst [ne0, ne1] at positions given by a 1D index.
2057
+ // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0] for dst.
2058
+ // InplaceIndexCopy with dim=0 copies along ACL dim-0 == ggml ne[1] (the row axis).
2059
+ // src_nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape,
2060
+ // nb[2..3] for computing per-batch-slice base pointer offsets).
2061
+ auto scatter_batched = [&](void * src_base, aclDataType acl_type, size_t type_size,
2062
+ const size_t * src_nb) {
2063
+ int64_t d_ne[2] = { dst->ne[0], dst->ne[1] };
2064
+ size_t d_nb[2] = { dst->nb[0], dst->nb[1] };
2065
+ int64_t s_ne[2] = { dst->ne[0], n_idx };
2066
+ size_t s_nb_2d[2] = { src_nb[0], src_nb[1] };
2067
+ int64_t i_ne[1] = { n_idx };
2068
+ size_t i_nb[1] = { (size_t)ggml_element_size(src1) };
2069
+
2070
+ for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) {
2071
+ for (int64_t i2 = 0; i2 < dst->ne[2]; i2++) {
2072
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(
2073
+ (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2],
2074
+ acl_type, type_size, d_ne, d_nb, 2);
2075
+ acl_tensor_ptr acl_idx = ggml_cann_create_tensor(
2076
+ (char *)src1->data + (i3 % src1->ne[2]) * src1->nb[2] + (i2 % src1->ne[1]) * src1->nb[1],
2077
+ ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1),
2078
+ i_ne, i_nb, 1);
2079
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(
2080
+ (char *)src_base + i3 * src_nb[3] + i2 * src_nb[2],
2081
+ acl_type, type_size, s_ne, s_nb_2d, 2);
2082
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_dst.get(), 0, acl_idx.get(), acl_src.get());
2083
+ }
2084
+ }
2085
+ };
1876
2086
 
1877
2087
  switch (dst->type) {
1878
2088
  case GGML_TYPE_F32:
1879
- {
1880
- aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type);
1881
- break;
1882
- }
2089
+ scatter_batched(src0->data,
2090
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
2091
+ src0->nb);
2092
+ break;
1883
2093
  case GGML_TYPE_F16:
2094
+ case GGML_TYPE_BF16:
1884
2095
  {
1885
- acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
1886
- ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
1887
- void * src_trans_buffer = src_buffer_allocator.get();
1888
- size_t src_trans_nb[GGML_MAX_DIMS];
1889
- src_trans_nb[0] = sizeof(uint16_t);
2096
+ // Cast src0 (F32) to dst type first.
2097
+ ggml_cann_pool_alloc src_cast_allocator(ctx.pool(),
2098
+ ggml_nelements(src0) * ggml_type_size(dst->type));
2099
+ size_t src_cast_nb[GGML_MAX_DIMS];
2100
+ src_cast_nb[0] = ggml_type_size(dst->type);
1890
2101
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
1891
- src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
2102
+ src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1];
1892
2103
  }
1893
- acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor(
1894
- src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS);
1895
- aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type));
1896
- aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1,
1897
- dst->type);
2104
+ acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0);
2105
+ acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor(
2106
+ src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
2107
+ src0->ne, src_cast_nb, GGML_MAX_DIMS);
2108
+ aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type));
2109
+
2110
+ scatter_batched(src_cast_allocator.get(),
2111
+ ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
2112
+ src_cast_nb);
1898
2113
  break;
1899
2114
  }
1900
2115
  default:
@@ -1965,7 +2180,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
1965
2180
 
1966
2181
  // Only check env once.
1967
2182
  static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1968
- if (weight_to_nz && is_matmul_weight(weight)) {
2183
+ if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) {
1969
2184
  acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
1970
2185
  } else {
1971
2186
  acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
@@ -2146,6 +2361,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2146
2361
  switch (type) {
2147
2362
  case GGML_TYPE_F32:
2148
2363
  case GGML_TYPE_F16:
2364
+ #ifndef ASCEND_310P
2365
+ case GGML_TYPE_BF16:
2366
+ #endif
2149
2367
  ggml_cann_mat_mul_fp(ctx, dst);
2150
2368
  break;
2151
2369
  case GGML_TYPE_Q4_0:
@@ -2943,6 +3161,27 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2943
3161
  // Rotate full tensor (no tail), using trans tensors
2944
3162
  GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),
2945
3163
  acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());
3164
+ } else if (src0->data == dst->data && !ggml_is_contiguous(src0)) {
3165
+ // In-place on non-contiguous tensor: RotaryPositionEmbedding cannot safely
3166
+ // read and write the same non-contiguous buffer. Use contiguous temporaries.
3167
+ size_t contiguous_nb[GGML_MAX_DIMS];
3168
+ contiguous_nb[0] = sizeof(float);
3169
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
3170
+ contiguous_nb[i] = contiguous_nb[i - 1] * src0->ne[i - 1];
3171
+ }
3172
+ int64_t total_elements = ggml_nelements(src0);
3173
+ ggml_cann_pool_alloc inplace_src_alloc(ctx.pool(), total_elements * sizeof(float));
3174
+ ggml_cann_pool_alloc inplace_dst_alloc(ctx.pool(), total_elements * sizeof(float));
3175
+
3176
+ acl_tensor_ptr acl_src_contig = ggml_cann_create_tensor(inplace_src_alloc.get(), ACL_FLOAT, sizeof(float),
3177
+ src0->ne, contiguous_nb, GGML_MAX_DIMS);
3178
+ acl_tensor_ptr acl_dst_contig = ggml_cann_create_tensor(inplace_dst_alloc.get(), ACL_FLOAT, sizeof(float),
3179
+ dst->ne, contiguous_nb, GGML_MAX_DIMS);
3180
+
3181
+ cann_copy(ctx, acl_src.get(), acl_src_contig.get());
3182
+ GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_contig.get(), acl_cos_reshape_tensor.get(),
3183
+ acl_sin_reshape_tensor.get(), acl_mode, acl_dst_contig.get());
3184
+ cann_copy(ctx, acl_dst_contig.get(), acl_dst.get());
2946
3185
  } else {
2947
3186
  // Rotate full tensor (no tail), using original tensors
2948
3187
  GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
@@ -2984,6 +3223,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2984
3223
  }
2985
3224
  }
2986
3225
 
3226
+ void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3227
+ ggml_tensor * src0 = dst->src[0];
3228
+
3229
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
3230
+ int sections[4];
3231
+ const int n_dims = ((int32_t *) dst->op_params)[1];
3232
+ const int mode = ((int32_t *) dst->op_params)[2];
3233
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
3234
+
3235
+ GGML_TENSOR_UNARY_OP_LOCALS
3236
+
3237
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
3238
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
3239
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
3240
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
3241
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
3242
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
3243
+ memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
3244
+
3245
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
3246
+
3247
+ float corr_dims[2];
3248
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
3249
+
3250
+ bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
3251
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
3252
+ const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
3253
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
3254
+
3255
+ if (is_imrope || mrope_used) {
3256
+ is_neox = true;
3257
+ }
3258
+
3259
+ int64_t rope_dims = n_dims;
3260
+ if (is_vision) {
3261
+ rope_dims = src0->ne[0];
3262
+ }
3263
+
3264
+ // Run the full cache init on the non-captured stream. This performs all
3265
+ // host-to-device memcpy, aclrtMalloc/Free, and on-device computations
3266
+ // so that the memory pool is warmed up and cache metadata is populated.
3267
+ aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
3268
+ mrope_used, is_imrope, is_vision, rope_dims);
3269
+
3270
+ // Reset `cached` so that during graph capture the on-device computations
3271
+ // (sin/cos, position multiply, repeat, etc.) still execute and get recorded
3272
+ // into the captured graph. The cache metadata (theta_scale_length,
3273
+ // theta_scale, sections, position_length, etc.) remains set, which causes
3274
+ // all host-to-device copy and malloc/free branches to be skipped.
3275
+ ctx.rope_cache.cached = false;
3276
+ }
3277
+
2987
3278
  void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
2988
3279
  ggml_tensor * src0 = dst->src[0];
2989
3280
 
@@ -3179,29 +3470,50 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst
3179
3470
  int64_t paddingsArray[2] = { opts[0], opts[1] };
3180
3471
  acl_int_array_ptr paddings = ggml_cann_create_int_array(paddingsArray, 2);
3181
3472
 
3182
- for (int64_t i = 0; i < src0->ne[3]; i++) {
3183
- acl_tensor_ptr acl_src =
3184
- ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type),
3185
- ggml_element_size(src0), src0->ne, src0->nb, 3);
3473
+ // Collapsing ne[2]*ne[3] into a single batch dimension requires that dim3
3474
+ // is contiguous with respect to dim2 in both src and dst.
3475
+ GGML_ASSERT(src0->nb[3] == src0->nb[2] * src0->ne[2]);
3476
+ GGML_ASSERT(dst->nb[3] == dst->nb[2] * dst->ne[2]);
3186
3477
 
3187
- acl_tensor_ptr acl_dst =
3188
- ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type),
3189
- ggml_element_size(dst), dst->ne, dst->nb, 3);
3478
+ int64_t src_ne_3d[3] = { src0->ne[0], src0->ne[1], src0->ne[2] * src0->ne[3] };
3479
+ int64_t dst_ne_3d[3] = { dst->ne[0], dst->ne[1], dst->ne[2] * dst->ne[3] };
3190
3480
 
3191
- GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());
3192
- }
3481
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type),
3482
+ ggml_element_size(src0), src_ne_3d, src0->nb, 3);
3483
+
3484
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
3485
+ ggml_element_size(dst), dst_ne_3d, dst->nb, 3);
3486
+
3487
+ GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get());
3193
3488
  }
3194
3489
 
3195
3490
  void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3196
3491
  ggml_tensor * src0 = dst->src[0];
3197
3492
  ggml_tensor * src1 = dst->src[1];
3198
3493
 
3494
+ // Write element-wise equality (0 or 1) into a temporary buffer to avoid
3495
+ // modifying src0 in-place. Use the same type as src0 so ReduceSum can
3496
+ // consume it directly without a type cast.
3497
+ ggml_cann_pool_alloc eq_alloc(ctx.pool(), ggml_nelements(src0) * ggml_element_size(src0));
3498
+ size_t eq_nb[GGML_MAX_DIMS];
3499
+ eq_nb[0] = ggml_element_size(src0);
3500
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
3501
+ eq_nb[i] = eq_nb[i - 1] * src0->ne[i - 1];
3502
+ }
3503
+ acl_tensor_ptr acl_eq = ggml_cann_create_tensor(
3504
+ eq_alloc.get(), ggml_cann_type_mapping(src0->type), ggml_element_size(src0),
3505
+ src0->ne, eq_nb, GGML_MAX_DIMS);
3506
+
3199
3507
  acl_tensor_ptr acl_self = ggml_cann_create_tensor(src0);
3200
3508
  acl_tensor_ptr acl_other = ggml_cann_create_tensor(src1);
3201
-
3202
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self.get(), acl_other.get());
3203
-
3204
- ggml_cann_sum(ctx, dst);
3509
+ GGML_CANN_CALL_ACLNN_OP(ctx, EqTensor, acl_self.get(), acl_other.get(), acl_eq.get());
3510
+
3511
+ // Sum the 0/1 values into dst.
3512
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
3513
+ int64_t dims[4] = { 0, 1, 2, 3 };
3514
+ acl_int_array_ptr dims_arr = ggml_cann_create_int_array(dims, 4);
3515
+ GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_eq.get(), dims_arr.get(), true,
3516
+ ggml_cann_type_mapping(dst->type), acl_dst.get());
3205
3517
  }
3206
3518
 
3207
3519
  void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
@@ -3217,6 +3529,27 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3217
3529
  GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src.get(), alpha.get(), acl_dst.get());
3218
3530
  }
3219
3531
 
3532
+ void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3533
+ ggml_tensor * src0 = dst->src[0];
3534
+
3535
+ acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
3536
+ acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
3537
+
3538
+ float beta_val = 1.0f;
3539
+ float threshold_val = 20.0f;
3540
+ acl_scalar_ptr beta = ggml_cann_create_scalar(&beta_val, ACL_FLOAT);
3541
+ acl_scalar_ptr threshold = ggml_cann_create_scalar(&threshold_val, ACL_FLOAT);
3542
+
3543
+ GGML_CANN_CALL_ACLNN_OP(ctx, Softplus, acl_src.get(), beta.get(), threshold.get(), acl_dst.get());
3544
+ }
3545
+
3546
+ void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3547
+ auto gelu_quick_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
3548
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
3549
+ };
3550
+ ggml_cann_op_unary_gated(gelu_quick_fn, ctx, dst);
3551
+ }
3552
+
3220
3553
  /**
3221
3554
  * @brief Performs expert-specific matrix multiplication (MoE) with
3222
3555
  * floating-point precision using the CANN backend.
@@ -3599,6 +3932,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
3599
3932
  acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS);
3600
3933
  acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS);
3601
3934
 
3935
+ // Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16
3936
+ // (required by FusedInferAttentionScoreV2)
3937
+ const int64_t D = src0->ne[0];
3938
+ const int64_t D_padded = GGML_PAD(D, 16);
3939
+ const bool needs_padding = (D != D_padded);
3940
+
3941
+ ggml_cann_pool_alloc q_pad_allocator(ctx.pool());
3942
+ ggml_cann_pool_alloc k_pad_allocator(ctx.pool());
3943
+ ggml_cann_pool_alloc v_pad_allocator(ctx.pool());
3944
+
3945
+ if (needs_padding) {
3946
+ int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 };
3947
+
3948
+ auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne,
3949
+ ggml_cann_pool_alloc & allocator) {
3950
+ int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] };
3951
+ size_t pad_nb[GGML_MAX_DIMS];
3952
+ pad_nb[0] = faElemSize;
3953
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
3954
+ pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1];
3955
+ }
3956
+ int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3];
3957
+ void * buffer = allocator.alloc(nelements * faElemSize);
3958
+ acl_tensor_ptr padded =
3959
+ ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS);
3960
+ aclnn_pad(ctx, tensor.get(), padded.get(), paddings);
3961
+ tensor = std::move(padded);
3962
+ };
3963
+
3964
+ pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator);
3965
+ pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator);
3966
+ pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator);
3967
+
3968
+ src0_bsnd_ne[0] = D_padded;
3969
+ src1_bsnd_ne[0] = D_padded;
3970
+ src2_bsnd_ne[0] = D_padded;
3971
+ }
3972
+
3602
3973
  // Step 3: create the PSEShift tensor if needed
3603
3974
  // this tensor is considered as mask (f16) in the llama.cpp
3604
3975
  acl_tensor_ptr bcast_pse_tensor;
@@ -3688,17 +4059,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
3688
4059
 
3689
4060
  GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
3690
4061
  acl_tensor_ptr fa_dst_tensor;
3691
- acl_tensor_ptr acl_dst_tensor;
3692
4062
  ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
3693
- if (dst->type == GGML_TYPE_F32) {
3694
- void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
3695
-
4063
+ if (dst->type == GGML_TYPE_F32 || needs_padding) {
3696
4064
  int64_t * out_f16_ne = src0_bsnd_ne;
3697
4065
  size_t out_f16_nb[GGML_MAX_DIMS];
3698
4066
  out_f16_nb[0] = faElemSize;
3699
4067
  for (int i = 1; i < GGML_MAX_DIMS; ++i) {
3700
4068
  out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
3701
4069
  }
4070
+ int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3];
4071
+ void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize);
3702
4072
 
3703
4073
  fa_dst_tensor =
3704
4074
  ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS);
@@ -3730,8 +4100,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
3730
4100
  nullptr // softmaxLse
3731
4101
  );
3732
4102
 
3733
- if (dst->type == GGML_TYPE_F32) {
3734
- // Step 6: post-processing, permute and cast to f32
4103
+ // Step 6: post-processing — slice padded output and/or cast to f32
4104
+ if (needs_padding) {
4105
+ ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool());
4106
+
4107
+ if (dst->type == GGML_TYPE_F32) {
4108
+ int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] };
4109
+ size_t sliced_nb[GGML_MAX_DIMS];
4110
+ sliced_nb[0] = faElemSize;
4111
+ for (int i = 1; i < GGML_MAX_DIMS; ++i) {
4112
+ sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1];
4113
+ }
4114
+ int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3];
4115
+ void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize);
4116
+ acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize,
4117
+ sliced_ne, sliced_nb, GGML_MAX_DIMS);
4118
+
4119
+ GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
4120
+ (int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get());
4121
+
4122
+ acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
4123
+ aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
4124
+ } else {
4125
+ acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
4126
+ GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
4127
+ (int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get());
4128
+ }
4129
+ } else if (dst->type == GGML_TYPE_F32) {
3735
4130
  acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
3736
4131
  aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
3737
4132
  }
@@ -3741,46 +4136,65 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
3741
4136
  }
3742
4137
 
3743
4138
  static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
3744
- ggml_tensor * src0 = dst->src[0]; // weight
3745
- ggml_tensor * src1 = dst->src[1]; // input
4139
+ ggml_tensor * src0 = dst->src[0]; // weight [ne00=m, ne01=K, ne02, ne03]
4140
+ ggml_tensor * src1 = dst->src[1]; // input [ne10=n, ne11=K, ne12, ne13]
3746
4141
  GGML_TENSOR_BINARY_OP_LOCALS
3747
4142
 
3748
- acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
3749
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
4143
+ // dst[i,j] = sum_k src0[i,k] * src1[j,k] i.e. dst = src0 @ src1^T.
4144
+ //
4145
+ // ggml_cann_create_tensor reverses dimension order, so ACL sees:
4146
+ // acl_src0 slice: ggml[m,K] -> ACL[K,m]
4147
+ // acl_src1 slice: ggml[n,K] -> ACL[K,n]
4148
+ // acl_dst slice: ggml[m,n] -> ACL[n,m]
4149
+ //
4150
+ // Build a transposed view of src1 by swapping ne[0]/ne[1]:
4151
+ // src1_t: ggml[K,n] (swapped strides) -> ACL[n,K]
4152
+ //
4153
+ // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst ✓
4154
+ //
4155
+ // The outer batch loop is kept because src0 may have fewer batch slices than
4156
+ // dst (ne02 <= ne2, ne03 <= ne3): this is a strided-broadcast not supported
4157
+ // by standard CANN Matmul broadcasting.
4158
+
4159
+ const aclDataType src0_acl_type = ggml_cann_type_mapping(src0->type);
4160
+ const aclDataType src1_acl_type = ggml_cann_type_mapping(src1->type);
4161
+ const aclDataType dst_acl_type = ggml_cann_type_mapping(dst->type);
4162
+ const size_t src0_type_sz = ggml_type_size(src0->type);
4163
+ const size_t src1_type_sz = ggml_type_size(src1->type);
4164
+ const size_t dst_type_sz = ggml_type_size(dst->type);
3750
4165
 
3751
4166
  const int64_t dps2 = ne2 / ne02;
3752
4167
  const int64_t dps3 = ne3 / ne03;
4168
+
3753
4169
  for (int64_t i3 = 0; i3 < ne3; i3++) {
3754
4170
  for (int64_t i2 = 0; i2 < ne2; i2++) {
3755
4171
  const int64_t i02 = i2 / dps2;
3756
4172
  const int64_t i03 = i3 / dps3;
3757
4173
 
3758
- const int64_t i12 = i2;
3759
- const int64_t i13 = i3;
3760
- acl_tensor_ptr accumulator =
3761
- ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),
3762
- ggml_type_size(dst->type), dst->ne, dst->nb, 2);
3763
-
3764
- // The outer product needs to be accumulated in this dimension.
3765
- for (int64_t i1 = 0; i1 < ne11; i1++) {
3766
- acl_tensor_ptr acl_input = ggml_cann_create_tensor(
3767
- (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type),
3768
- ggml_type_size(src0->type), src1->ne, src1->nb, 1);
3769
-
3770
- acl_tensor_ptr acl_weight = ggml_cann_create_tensor(
3771
- (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),
3772
- ggml_type_size(src0->type), src0->ne, src0->nb, 1);
3773
-
3774
- ggml_cann_pool_alloc output_allocator(ctx.pool());
3775
- void * output_buffer = output_allocator.alloc(ggml_nbytes(dst));
3776
- acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type),
3777
- ggml_type_size(dst->type), dst->ne, dst->nb, 2);
3778
-
3779
- GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());
3780
- float alpha_value = 1.0f;
3781
- aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT);
3782
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);
3783
- }
4174
+ // src0 2D slice at [i02, i03]: ggml [m, K] -> ACL [K, m]
4175
+ int64_t src0_ne[2] = { ne00, ne01 };
4176
+ size_t src0_nb[2] = { nb00, nb01 };
4177
+ acl_tensor_ptr acl_src0_s = ggml_cann_create_tensor(
4178
+ (char *) src0->data + i02 * nb02 + i03 * nb03,
4179
+ src0_acl_type, src0_type_sz, src0_ne, src0_nb, 2);
4180
+
4181
+ // src1 transposed 2D slice at [i2, i3]: swap ne/nb -> ggml[K,n] -> ACL[n,K]
4182
+ int64_t src1_t_ne[2] = { ne11, ne10 };
4183
+ size_t src1_t_nb[2] = { nb11, nb10 };
4184
+ acl_tensor_ptr acl_src1_t = ggml_cann_create_tensor(
4185
+ (char *) src1->data + i2 * nb12 + i3 * nb13,
4186
+ src1_acl_type, src1_type_sz, src1_t_ne, src1_t_nb, 2);
4187
+
4188
+ // dst 2D slice at [i2, i3]: ggml [m, n] -> ACL [n, m]
4189
+ int64_t dst_ne[2] = { ne0, ne1 };
4190
+ size_t dst_nb[2] = { nb0, nb1 };
4191
+ acl_tensor_ptr acl_dst_s = ggml_cann_create_tensor(
4192
+ (char *) dst->data + i2 * nb2 + i3 * nb3,
4193
+ dst_acl_type, dst_type_sz, dst_ne, dst_nb, 2);
4194
+
4195
+ // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst_s ✓
4196
+ GGML_CANN_CALL_ACLNN_OP(ctx, Matmul,
4197
+ acl_src1_t.get(), acl_src0_s.get(), acl_dst_s.get(), (int8_t) 1);
3784
4198
  }
3785
4199
  }
3786
4200
  }
@@ -4019,3 +4433,4 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor *
4019
4433
  }
4020
4434
  }
4021
4435
  }
4436
+