whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -65,8 +65,13 @@
 #include <aclnnop/aclnn_eq_tensor.h>
 #include <aclnnop/aclnn_gt_scalar.h>
 #include <aclnnop/aclnn_pow.h>
-#include <aclnnop/aclnn_grouped_matmul_v2.h>
+#include <aclnnop/aclnn_grouped_matmul_v3.h>
 #include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
+#include <aclnnop/aclnn_zero.h>
+#include <aclnnop/aclnn_index_copy.h>
+#include <aclnnop/aclnn_index_select.h>
+#include <aclnnop/aclnn_clamp.h>
+#include <aclnnop/aclnn_threshold.h>
 #include <float.h>
 
 #include <cmath>
@@ -98,7 +103,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclT
     }
 }
 
-void ggml_cann_unary_op(
+void ggml_cann_op_unary(
     std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
     ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
@@ -110,6 +115,42 @@ void ggml_cann_unary_op(
     ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
+void ggml_cann_op_unary_gated(
+    std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
+    ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    aclTensor *acl_src0 = nullptr, *acl_src1 = nullptr;
+    if(src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+
+        acl_src0 = ggml_cann_create_tensor(src0);
+        acl_src1 = ggml_cann_create_tensor(src1);
+    } else {
+        int64_t ne[] = {src0->ne[0] / 2, src0->ne[1], src0->ne[2], src0->ne[3]};
+        size_t nb[] = {src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]};
+        acl_src0 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, 0);
+        acl_src1 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, ne[0] * ggml_element_size(src0));
+        if (swapped) {
+            std::swap(acl_src0, acl_src1);
+        }
+    }
+
+    unary_op(ctx, acl_src0, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
+
+    ggml_cann_release_resources(ctx, acl_src0, acl_dst);
+    if(src1)
+        ggml_cann_release_resources(ctx, acl_src1);
+}
+
 /**
  * @brief Repeats elements of a tensor along each dimension according to the
  * specified repeat array.
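Note on the new ggml_cann_op_unary_gated path above: when only a single fused input is present, it is viewed as two half-width tensors that keep the original strides, with the second view starting ne[0] (already halved) elements into each row, and the halves optionally swapped according to op param 1. Below is a minimal host-side sketch of that view arithmetic in plain C++; the buffer, sizes, and swapped flag are illustrative stand-ins, not CANN API calls.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // One fused row layout: [ gate | up ], 8 columns total, 2 rows.
    const size_t ne0_fused = 8, rows = 2, elem = sizeof(float);
    std::vector<float> buf(ne0_fused * rows);
    for (size_t i = 0; i < buf.size(); ++i) buf[i] = float(i);

    const size_t ne0_half  = ne0_fused / 2;    // ne[0] = src0->ne[0] / 2
    const size_t off_bytes = ne0_half * elem;  // byte offset of the second view
    float * half0 = buf.data();
    float * half1 = reinterpret_cast<float *>(reinterpret_cast<char *>(buf.data()) + off_bytes);

    const bool swapped = false;                // read from op param 1 in the real kernel
    if (swapped) std::swap(half0, half1);

    // The row stride (nb[1]) stays that of the fused tensor; only ne[0] is halved.
    std::printf("row0: half0[0]=%g half1[0]=%g\n", half0[0], half1[0]);
    std::printf("row1: half0[0]=%g half1[0]=%g\n", half0[ne0_fused], half1[ne0_fused]);
    return 0;
}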
@@ -548,9 +589,16 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // the position of elements in the array means which dirction to padding,
     // each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind,
     // dim2.front, dim2.behind, dim3.front, dim3.behind]
-    int64_t paddings[] = {
-        0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
-        0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
+    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
+    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
+    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
+    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
+    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
+    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
+
+    int64_t paddings[] = {lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3};
     aclnn_pad(ctx, acl_src, acl_dst, paddings);
     ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
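The pad hunk above moves ggml_cann_pad from deriving all padding implicitly (appended at the back of each dimension) to reading explicit front/back pads from the op params, laid out pairwise as [lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3]. A small sketch of that layout follows; the shapes are made up for illustration, only the ordering mirrors the code above.

#include <cstdint>
#include <cstdio>

int main() {
    // Example: pad a 30x7 tensor out to 32x8 (the old behaviour: everything at the back).
    const int64_t src_ne[4] = {30, 7, 1, 1};
    const int64_t dst_ne[4] = {32, 8, 1, 1};

    // Op params hold front/back pads pairwise per dimension.
    int32_t op_params[8];
    for (int i = 0; i < 4; ++i) {
        op_params[2 * i]     = 0;                               // lp_i: nothing in front
        op_params[2 * i + 1] = int32_t(dst_ne[i] - src_ne[i]);  // rp_i: the rest at the back
    }

    // ggml_cann_pad copies these into the paddings array handed to aclnn_pad.
    int64_t paddings[8];
    for (int i = 0; i < 8; ++i) paddings[i] = op_params[i];

    for (int i = 0; i < 4; ++i) {
        std::printf("dim%d: %lld -> %lld (front %lld, back %lld)\n", i,
                    (long long) src_ne[i], (long long) dst_ne[i],
                    (long long) paddings[2 * i], (long long) paddings[2 * i + 1]);
    }
    return 0;
}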
@@ -714,69 +762,55 @@ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
 void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];
 
-    aclTensor* acl_src = ggml_cann_create_tensor(src0);
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
     if (ggml_are_same_shape(src0, dst)) {
+        aclTensor* acl_src = ggml_cann_create_tensor(src0);
+        aclTensor* acl_dst = ggml_cann_create_tensor(dst);
         if (dst->type == src0->type) {
             cann_copy(ctx, acl_src, acl_dst);
         } else {
             aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
         }
+        ggml_cann_release_resources(ctx, acl_src, acl_dst);
     } else {
-        if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
-            if (dst->type == src0->type) {
-                size_t cpy_size = ggml_nbytes(dst);
-                ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE);
-                return;
-            } else {
-                ggml_cann_pool_alloc src_buffer_allocator(
-                    ctx.pool(),
-                    ggml_nelements(dst) * ggml_type_size(dst->type));
-                void* src_trans_buffer = src_buffer_allocator.get();
-                size_t src_trans_nb[GGML_MAX_DIMS];
-                src_trans_nb[0] = ggml_type_size(dst->type);
-                for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
-                }
-                aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                    src_trans_buffer, ggml_cann_type_mapping(dst->type),
-                    ggml_type_size(dst->type), src0->ne, src_trans_nb,
-                    GGML_MAX_DIMS);
-
-                aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
-                size_t cpy_size = ggml_nbytes(dst);
-                ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE);
-                ggml_cann_release_resources(ctx, src_trans_tensor);
-                return;
-            }
-        } else if (ggml_is_contiguous(dst)) {
-            ggml_cann_pool_alloc src_buffer_allocator(
-                ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type));
-            void* src_trans_buffer = src_buffer_allocator.get();
+        void* src_trans_buffer = src0->data;
+        ggml_cann_pool_alloc src_buffer_allocator;
+        if (!ggml_is_contiguous(src0)) {
+            aclTensor* acl_src = ggml_cann_create_tensor(src0);
+            src_buffer_allocator.alloc(ctx.pool(),
+                ggml_nelements(src0) * ggml_type_size(src0->type));
+            src_trans_buffer = src_buffer_allocator.get();
             size_t src_trans_nb[GGML_MAX_DIMS];
-            src_trans_nb[0] = ggml_type_size(dst->type);
+            src_trans_nb[0] = ggml_type_size(src0->type);
             for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
            }
             aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                src_trans_buffer, ggml_cann_type_mapping(dst->type),
-                ggml_type_size(dst->type), src0->ne, src_trans_nb,
+                src_trans_buffer, ggml_cann_type_mapping(src0->type),
+                ggml_type_size(src0->type), src0->ne, src_trans_nb,
                 GGML_MAX_DIMS);
+            cann_copy(ctx, acl_src, src_trans_tensor);
+            ggml_cann_release_resources(ctx, acl_src, src_trans_tensor);
+        }
 
-            aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+        size_t src_reshape_nb[GGML_MAX_DIMS];
+        src_reshape_nb[0] = ggml_type_size(src0->type);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            src_reshape_nb[i] = src_reshape_nb[i - 1] * dst->ne[i - 1];
+        }
 
-            size_t cpy_size = ggml_nbytes(dst);
-            ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE);
-            ggml_cann_release_resources(ctx, src_trans_tensor);
-            return;
+        aclTensor* trans_acl_src = ggml_cann_create_tensor(src_trans_buffer,
+            ggml_cann_type_mapping(src0->type),ggml_type_size(src0->type),
+            dst->ne, src_reshape_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
+        aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+        if (dst->type == src0->type) {
+            cann_copy(ctx, trans_acl_src, acl_dst);
         } else {
-            GGML_ABORT("Unsupport dst is not tontiguous.");
+            aclnn_cast(ctx, trans_acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
         }
+        ggml_cann_release_resources(ctx, trans_acl_src, acl_dst);
     }
-    ggml_cann_release_resources(ctx, acl_src, acl_dst);
+    return;
 }
 
 /**
@@ -804,10 +838,11 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
         nb[i] = nb[i - 1] * ne[i - 1];
     }
 
-    ggml_cann_async_memset(ctx, buffer, n_bytes, 0);
     aclTensor* zero =
         ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, zero);
     return zero;
+    GGML_UNUSED(n_bytes);
 }
 
 /**
@@ -841,6 +876,86 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer,
841
876
  return acl_tensor;
842
877
  }
843
878
 
879
+ /**
880
+ * @brief Fills a tensor with a scalar value.
881
+ *
882
+ * This function fills the destination tensor `acl_dst` with the scalar value
883
+ * `scalar`.
884
+ *
885
+ * @param ctx The context for the CANN backend operations.
886
+ * @param scalar The scalar value used to fill the tensor.
887
+ * @param acl_dst The destination tensor to be filled with the scalar value.
888
+ */
889
+ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
890
+ aclTensor* acl_dst) {
891
+ auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
892
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
893
+ ggml_cann_release_resources(ctx, acl_scalar);
894
+ }
895
+
896
+ /**
897
+ * @brief Get or expand a cached float32 tensor filled with a scalar value.
898
+ *
899
+ * This function manages cached device memory for float32 tensors. If the current
900
+ * cache size is insufficient for the requested tensor shape, the old memory will
901
+ * be released and new memory will be allocated. The allocated buffer is then
902
+ * initialized either with zeros (when @p value == 0.0f) or with the given scalar
903
+ * value using CANN operations. Finally, an aclTensor object is created from the
904
+ * cached memory and returned.
905
+ *
906
+ * @param ctx The CANN backend context that manages device memory.
907
+ * @param buffer A pointer to the cached device buffer (will be allocated
908
+ * or reallocated if necessary).
909
+ * @param cache_element The current number of cached elements. This will be
910
+ * updated when the cache is expanded.
911
+ * @param ne The tensor shape array (number of elements in each dimension).
912
+ * @param nb The stride size for each dimension.
913
+ * @param dims The number of tensor dimensions.
914
+ * @param value The scalar value used to fill the tensor (supports zero
915
+ * initialization via memset or arbitrary values via fill_scalar).
916
+ * @return An aclTensor pointer created from the cached buffer.
917
+ */
918
+ static aclTensor* get_f32_cache_acl_tensor(
919
+ ggml_backend_cann_context& ctx,
920
+ void** buffer,
921
+ int64_t &cache_element,
922
+ int64_t* ne,
923
+ size_t* nb,
924
+ int64_t dims,
925
+ float value) {
926
+ // Calculate total number of elements
927
+ int64_t n_element = 1;
928
+ for (int i = 0; i < dims; i++) {
929
+ n_element *= ne[i];
930
+ }
931
+ size_t size = n_element * sizeof(float);
932
+
933
+ // Allocate or expand cache if needed
934
+ if (cache_element < n_element) {
935
+ if (*buffer != nullptr) {
936
+ aclrtFree(*buffer);
937
+ *buffer = nullptr;
938
+ }
939
+
940
+ ACL_CHECK(aclrtMalloc(buffer, size, ACL_MEM_MALLOC_HUGE_FIRST));
941
+ cache_element = n_element;
942
+
943
+ // Initialize cache
944
+ if (value == 0.0f) {
945
+ ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
946
+ } else {
947
+ int64_t pool_ne[1] = { n_element };
948
+ size_t pool_nb[1] = { sizeof(float) };
949
+ aclTensor* acl_value = ggml_cann_create_tensor(
950
+ *buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
951
+ aclnn_fill_scalar(ctx, 1, acl_value);
952
+ ggml_cann_release_resources(ctx, acl_value);
953
+ }
954
+ }
955
+
956
+ return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
957
+ }
958
+
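get_f32_cache_acl_tensor above follows a grow-only policy: the device buffer is reallocated and refilled only when a larger element count is requested, otherwise the existing contents are reused. A host-side sketch of the same policy (plain C++ in place of the ACL runtime calls; names hypothetical):

    #include <cstdint>
    #include <cstdlib>

    // Grow-only cache of n floats filled with `value`; reallocates and refills
    // only when a larger tensor is requested, mirroring get_f32_cache_acl_tensor.
    static float* get_f32_cache(float** buffer, int64_t& cache_element,
                                int64_t n, float value) {
        if (cache_element < n) {
            std::free(*buffer);
            *buffer = (float*) std::malloc(n * sizeof(float));
            cache_element = n;
            for (int64_t i = 0; i < n; i++) {
                (*buffer)[i] = value;
            }
        }
        return *buffer; // caller uses the first n elements
    }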
844
959
  void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
845
960
  ggml_tensor* src = dst->src[0];
846
961
 
@@ -849,20 +964,40 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
849
964
 
850
965
  float eps;
851
966
  memcpy(&eps, dst->op_params, sizeof(float));
852
- size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
853
- ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
854
-
855
- aclTensor* acl_gamma = aclnn_values(
856
- ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1,
857
- ggml_cann_type_mapping(src->type), ggml_element_size(src));
858
-
859
- size_t zero_tensor_n_bytes =
860
- src->ne[1] * src->ne[2] * src->ne[3] * ggml_element_size(src);
861
- ggml_cann_pool_alloc zero_tensor_allocator(ctx.pool(), zero_tensor_n_bytes);
862
- aclTensor* acl_rstd =
863
- aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
864
- src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
865
- ggml_element_size(src));
967
+
968
+ // build gamma, one...
969
+ size_t acl_gamma_nb[GGML_MAX_DIMS];
970
+ acl_gamma_nb[0] = sizeof(float);
971
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
972
+ acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
973
+ }
974
+ aclTensor* acl_gamma = get_f32_cache_acl_tensor(
975
+ ctx,
976
+ &ctx.rms_norm_one_tensor_cache.cache,
977
+ ctx.rms_norm_one_tensor_cache.size,
978
+ src->ne,
979
+ acl_gamma_nb,
980
+ 1, // dims
981
+ 1.0f // value
982
+ );
983
+
984
+ // build rstd, zero...
985
+ int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
986
+ size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
987
+ acl_rstd_nb[0] = sizeof(float);
988
+ for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
989
+ acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
990
+ }
991
+ aclTensor* acl_rstd = get_f32_cache_acl_tensor(
992
+ ctx,
993
+ &ctx.rms_norm_zero_tensor_cache.cache,
994
+ ctx.rms_norm_zero_tensor_cache.size,
995
+ acl_rstd_ne,
996
+ acl_rstd_nb,
997
+ GGML_MAX_DIMS - 1,
998
+ 0.0f // value
999
+ );
1000
+
866
1001
  GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
867
1002
  ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
868
1003
  }
@@ -877,14 +1012,13 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
 
  const int n_past = ((int32_t*)dst->op_params)[0];
 
- size_t one_tensor_n_bytes = src->ne[0] * src->ne[1] * src->ne[2] *
- src->ne[3] * ggml_element_size(src);
- ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
+ ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), ggml_nbytes(src));
+ void* buffer = one_tensor_allocator.get();
+
+ aclTensor* mask_tensor = ggml_cann_create_tensor(buffer, ggml_cann_type_mapping(src->type),
+ ggml_type_size(src->type), src->ne, src->nb, GGML_MAX_DIMS);
 
- aclTensor* mask_tensor =
- aclnn_values(ctx, one_tensor_allocator.get(), one_tensor_n_bytes,
- src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
- ggml_element_size(src), value);
+ aclnn_fill_scalar(ctx, value, mask_tensor);
 
  aclScalar* alpha = nullptr;
  float alphaValue = 1.0f;
@@ -1133,12 +1267,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
 
  void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  aclTensor* acl_dst) {
- GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
+ if(acl_dst == nullptr) {
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);
+ } else {
+ GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
+ }
  }
 
  void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  aclTensor* acl_dst) {
- GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
+ if(acl_dst == nullptr) {
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);
+ } else {
+ GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
+ }
  }
 
  void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
@@ -1251,23 +1393,6 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
1251
1393
  tmp_permute_tensor, tmp_mul_tensor, acl_dst);
1252
1394
  }
1253
1395
 
1254
- /**
1255
- * @brief Fills a tensor with a scalar value.
1256
- *
1257
- * This function fills the destination tensor `acl_dst` with the scalar value
1258
- * `scalar`.
1259
- *
1260
- * @param ctx The context for the CANN backend operations.
1261
- * @param scalar The scalar value used to fill the tensor.
1262
- * @param acl_dst The destination tensor to be filled with the scalar value.
1263
- */
1264
- static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
1265
- aclTensor* acl_dst) {
1266
- auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
1267
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
1268
- ggml_cann_release_resources(ctx, acl_scalar);
1269
- }
1270
-
1271
1396
  /**
1272
1397
  * @brief Raises each element of a tensor to the power of the corresponding
1273
1398
  * element in another tensor.
@@ -1290,160 +1415,201 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
1290
1415
  }
1291
1416
 
1292
1417
  /**
1293
- * @brief Applies the Alibi (Attention with Linear Biases) mechanism to the
1294
- * @details This function implements the Alibi mechanism, which introduces
1295
- * learnable biases into the attention scores to simulate relative
1296
- * position encoding without the need for explicit positional
1297
- * embeddings.
1298
- *
1299
- * @param ctx The backend CANN context for executing operations.
1300
- * @param acl_src The source tensor representing the query or key.
1301
- * @param acl_position The position tensor containing relative positions.
1302
- * @param acl_dst The destination tensor where the result will be stored.
1303
- * @param n_head The number of attention heads.
1304
- * @param src_ne The dimensions of the source tensor.
1305
- * @param src_nb0 The byte size of the first dimension of the source
1306
- tensor.
1307
- * @param max_bias The maximum bias value used in the Alibi mechanism.
1308
- * @param dst The destination tensor object for additional metadata.
1309
- *
1310
- * The function performs the following steps:
1311
- * 1. Calculates the logarithm floor of the number of heads to determine the
1312
- base for bias calculation.
1313
- * 2. Initializes arrays with arithmetic sequences and fills them with bias
1314
- values.
1315
- * 3. Computes the bias tensor based on the calculated biases and arithmetic
1316
- sequences.
1317
- * 4. Reshapes the bias tensor to match the dimensions of the input tensors.
1318
- * 5. Multiplies the position tensor by the bias tensor.
1319
- * 6. Adds the result of the multiplication to the source tensor to produce the
1320
- final output.
1418
+ * @brief Generate a range of values and apply a scalar base exponentiation.
1419
+ *
1420
+ * This function creates an evenly spaced sequence from `start` to `stop` (exclusive),
1421
+ * with step size `step`, stores it in a temporary buffer, and then computes:
1422
+ *
1423
+ * @f[
1424
+ * slope[i] = m^{\left( start + i \cdot step \right)}, \quad 0 \le i < size
1425
+ * @f]
1426
+ *
1427
+ * The results are written to the provided @p slope_buffer.
1428
+ *
1429
+ * @param ctx CANN backend context for memory allocation and operator execution.
1430
+ * @param slope_buffer Pointer to the output buffer (float array) for the computed slope values.
1431
+ * @param m Scalar base for the exponentiation.
1432
+ * @param size Number of elements in the generated sequence.
1433
+ * @param start Starting exponent offset.
1434
+ * @param stop Stopping exponent offset (exclusive).
1435
+ * @param step Step size for the exponent increment.
1436
+ * @param dtype Data type for slope tensor.
1321
1437
  */
1322
- static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
1323
- aclTensor* acl_position, aclTensor* acl_dst,
1324
- const int n_head, int64_t* src_ne, const size_t src_nb0,
1325
- float max_bias, ggml_tensor* dst) {
1326
- const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
1327
- GGML_ASSERT(src_nb0 == sizeof(float));
1328
- GGML_ASSERT(n_head == src_ne[2]);
1329
-
1330
- const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
1331
-
1332
- float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
1333
- float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
1334
-
1335
- // init arange
1336
- ggml_cann_pool_alloc arange_allocator(ctx.pool(),
1337
- ne2_ne3 * ggml_type_size(dst->type));
1338
- void* tmp_arange_buffer = arange_allocator.get();
1438
+ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
1439
+ float m, int64_t size, float start, float stop, float step, ggml_type dtype){
1440
+ aclDataType acl_type = ggml_cann_type_mapping(dtype);
1441
+ size_t type_size = ggml_type_size(dtype);
1339
1442
 
1340
- // arange1: [1, ..., n_heads_log2_floor+1)
1341
- float start = 1;
1342
- float stop = n_heads_log2_floor + 1;
1343
- float step = 1;
1344
- int64_t n_elements_arange = n_heads_log2_floor;
1443
+ int64_t ne[] = {size};
1444
+ size_t nb[] = {type_size};
1345
1445
 
1346
- int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
1347
- size_t tmp_arange1_nb[] = {sizeof(dst->type)};
1348
- aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
1349
- tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
1350
- ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb,
1351
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
1446
+ ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * type_size);
1447
+ void* arange_buffer = arange_allocator.get();
1352
1448
 
1353
- aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
1354
-
1355
- aclTensor* tmp_arange2_tensor = nullptr;
1356
- if (n_heads_log2_floor < ne2_ne3) {
1357
- // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
1358
- start = 1;
1359
- stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
1360
- step = 2;
1361
- n_elements_arange = ne2_ne3 - n_heads_log2_floor;
1362
- int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
1363
- size_t tmp_arange2_nb[] = {sizeof(dst->type)};
1364
-
1365
- aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
1366
- (char*)tmp_arange_buffer +
1367
- n_heads_log2_floor * ggml_type_size(dst->type),
1368
- ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
1369
- tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
1370
- aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
1371
- n_elements_arange);
1372
- }
1449
+ aclTensor* arange_tensor = ggml_cann_create_tensor(
1450
+ arange_buffer, acl_type, type_size, ne, nb, 1);
1451
+ aclnn_arange(ctx, arange_tensor, start, stop, step, size);
1373
1452
 
1374
- // init mk_base
1375
- ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
1376
- ne2_ne3 * ggml_type_size(dst->type));
1377
- void* tmp_mk_base_buffer = mk_base_allocator.get();
1378
- int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
1379
- size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
1380
- aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
1381
- tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
1382
- ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb,
1383
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
1453
+ aclTensor* slope_tensor = ggml_cann_create_tensor(
1454
+ slope_buffer, acl_type, type_size, ne, nb, 1);
1384
1455
 
1385
- aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
1386
-
1387
- aclTensor* tmp_mk_base2_tensor = nullptr;
1388
- if (n_heads_log2_floor < ne2_ne3) {
1389
- int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
1390
- size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
1391
- aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
1392
- (char*)tmp_mk_base_buffer +
1393
- n_heads_log2_floor * ggml_type_size(dst->type),
1394
- ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
1395
- tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
1396
- aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
1397
- }
1456
+ aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
1398
1457
 
1399
- // init mk
1400
- int64_t tmp_mk_base_ne[] = {ne2_ne3};
1401
- size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
1402
- aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
1403
- tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
1404
- ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
1405
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
1406
- aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
1407
- tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
1408
- ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
1409
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
1410
- aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
1458
+ GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor);
1459
+ ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor);
1460
+ }
1411
1461
 
1412
- // reshape mk
1413
- int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
1414
- size_t tmp_mk_nb[GGML_MAX_DIMS];
1415
- tmp_mk_nb[0] = ggml_type_size(dst->type);
1416
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
1417
- tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
1462
+ /**
1463
+ * @brief Compute slope values for multiple attention heads based on ALiBi bias parameters.
1464
+ *
1465
+ * This function generates slope values for each attention head according to the ALiBi
1466
+ * (Attention with Linear Biases) method. It splits the computation into two ranges depending
1467
+ * on whether the head index is less than @p n_head_log2 or not, and uses different base values
1468
+ * (`m0` and `m1`) for the exponentiation.
1469
+ *
1470
+ * @f[
1471
+ * slope[h] =
1472
+ * \begin{cases}
1473
+ * m_0^{(h + 1)}, & h < n\_head\_log2 \\
1474
+ * m_1^{\left( 2 \cdot (h - n\_head\_log2) + 1 \right)}, & h \geq n\_head\_log2
1475
+ * \end{cases}
1476
+ * \quad , \quad \text{if } max\_bias > 0
1477
+ * @f]
1478
+ *
1479
+ * If @p max_bias <= 0, all slope values are set to 1.0.
1480
+ *
1481
+ * @param ctx CANN backend context for memory allocation and operator execution.
1482
+ * @param n_head Total number of attention heads.
1483
+ * @param slope_buffer Pointer to the output buffer (float array) for storing slopes.
1484
+ * @param max_bias Maximum bias value for slope computation.
1485
+ * @param dtype Data type for slope tensor.
1486
+ *
1487
+ */
1488
+ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
1489
+ void* slope_buffer, float max_bias, ggml_type dtype) {
1490
+ const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
1491
+
1492
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
1493
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1494
+
1495
+ // const float slope = (max_bias > 0.0f) ?
1496
+ // h < n_head_log2 ?
1497
+ // powf(m0, h + 1) :
1498
+ // powf(m1, 2*(h - n_head_log2) + 1) :
1499
+ // 1.0f;
1500
+ // arange1
1501
+ float start = 0 + 1;
1502
+ float end = (n_head_log2 - 1) + 1;
1503
+ float step = 1;
1504
+ float count = n_head_log2;
1505
+ // end needs to be +1 because aclnn uses a left-closed, right-open interval.
1506
+ aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step, dtype);
1507
+ if (n_head_log2 < n_head) {
1508
+ // arange2
1509
+ start = 2 * (n_head_log2 - n_head_log2) + 1;
1510
+ end = 2 * ((n_head - 1) - n_head_log2) + 1;
1511
+ step = 2;
1512
+ count = n_head - n_head_log2;
1513
+ aclnn_get_slope_inner(
1514
+ ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
1515
+ m1, count, start, end + 1, step, dtype);
1418
1516
  }
1419
- aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
1420
- tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
1421
- ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
1422
- ACL_FORMAT_ND);
1517
+ }
1423
1518
 
1424
- // acl_position * mk
1425
- int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
1426
- size_t tmp_output_nb[GGML_MAX_DIMS];
1427
- tmp_output_nb[0] = ggml_type_size(dst->type);
1428
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
1429
- tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1];
1519
+ /**
1520
+ * @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask.
1521
+ *
1522
+ * This function computes the ALiBi slopes for each attention head (if max_bias > 0),
1523
+ * multiplies them with the attention mask to produce bias tensors, and adds these biases
1524
+ * to the destination tensor (@p dst).
1525
+ *
1526
+ * The function performs necessary broadcasting of the mask and slope tensors to match
1527
+ * the shape of the destination tensor, then applies element-wise multiplication and addition
1528
+ * using CANN operators.
1529
+ *
1530
+ * @param ctx CANN backend context for memory management and operator execution.
1531
+ * @param mask Input attention mask tensor, assumed to be contiguous.
1532
+ * @param dst Destination tensor to which ALiBi biases will be added.
1533
+ * @param dst_ptr Pointer to the memory of the destination tensor.
1534
+ * @param max_bias Maximum bias value controlling the slope scaling.
1535
+ *
1536
+ * @note
1537
+ * - Write data into dst_ptr using only the shape information of the dst tensor.
1538
+ * - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting.
1539
+ */
1540
+ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
1541
+ ggml_tensor* dst, void* dst_ptr, float max_bias) {
1542
+ void* slope_buffer = nullptr;
1543
+ void* bias_buffer = nullptr;
1544
+
1545
+ if (max_bias > 0.0f) {
1546
+ int64_t n_heads = dst->ne[2];
1547
+ ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
1548
+ slope_buffer = slope_allocator.get();
1549
+ ggml_cann_pool_alloc bias_allocator(
1550
+ ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
1551
+ bias_buffer = bias_allocator.get();
1552
+ aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias, GGML_TYPE_F32);
1553
+ }
1554
+
1555
+ // broadcast for mask, slope and dst;
1556
+ int64_t nr2 = dst->ne[2] / mask->ne[2];
1557
+ int64_t nr3 = dst->ne[3] / mask->ne[3];
1558
+
1559
+ // broadcast the mask across rows
1560
+ int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 };
1561
+ size_t mask_nb[] = {
1562
+ mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2],
1563
+ mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3]
1564
+ };
1565
+
1566
+ int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 };
1567
+ size_t dst_nb[] = {
1568
+ dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2],
1569
+ dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3]
1570
+ };
1571
+
1572
+ // slope is a 1 dim tensor, slope.ne2 == dst.ne2
1573
+ int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 };
1574
+ size_t slope_nb[GGML_MAX_DIMS + 2];
1575
+ slope_nb[0] = sizeof(float);
1576
+ for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
1577
+ slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1];
1430
1578
  }
1431
- ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst));
1432
- void* tmp_output_buffer = output_allocator.get();
1433
- aclTensor* tmp_output_tensor = ggml_cann_create_tensor(
1434
- tmp_output_buffer, ggml_cann_type_mapping(dst->type),
1435
- ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS,
1436
- ACL_FORMAT_ND);
1437
- aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor);
1438
1579
 
1439
- // add
1440
- aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
1441
- ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
1442
- tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
1443
- tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor);
1580
+ aclTensor* acl_slope = ggml_cann_create_tensor(
1581
+ slope_buffer, ACL_FLOAT, sizeof(float),
1582
+ slope_ne, slope_nb, GGML_MAX_DIMS + 2);
1583
+ aclTensor* acl_mask = ggml_cann_create_tensor(
1584
+ mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2);
1585
+
1586
+ // write data into dst_ptr using only the shape information of the dst tensor.
1587
+ aclTensor* acl_dst = ggml_cann_create_tensor(
1588
+ dst_ptr, ggml_cann_type_mapping(dst->type),
1589
+ ggml_type_size(dst->type), dst_ne, dst_nb,
1590
+ GGML_MAX_DIMS + 2);
1591
+
1592
+ if (max_bias > 0.0f) {
1593
+ int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 };
1594
+ size_t bias_nb[GGML_MAX_DIMS + 2];
1595
+ bias_nb[0] = sizeof(float);
1596
+ for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
1597
+ bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1];
1598
+ }
1599
+ aclTensor* bias_tensor = ggml_cann_create_tensor(
1600
+ bias_buffer, ACL_FLOAT, sizeof(float),
1601
+ bias_ne, bias_nb, GGML_MAX_DIMS + 2);
1602
+
1603
+ aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor);
1604
+ aclnn_add(ctx, acl_dst, bias_tensor);
1605
+ ggml_cann_release_resources(ctx, bias_tensor);
1606
+ } else {
1607
+ aclnn_add(ctx, acl_dst, acl_mask);
1608
+ }
1609
+ ggml_cann_release_resources(ctx, acl_slope, acl_mask, acl_dst);
1444
1610
  }
1445
1611
 
1446
- void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1612
+ void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1447
1613
  ggml_cann_dup(ctx, dst);
1448
1614
  }
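For reference, the per-head slopes produced by aclnn_get_slope above correspond to the CPU-side formula quoted in its comment; a self-contained sketch (illustrative only, not part of the backend):

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // ALiBi slope per head: heads below n_head_log2 use base m0,
    // the remaining heads use base m1 with odd exponents.
    static std::vector<float> alibi_slopes(int n_head, float max_bias) {
        std::vector<float> slope(n_head, 1.0f);
        if (max_bias <= 0.0f) return slope;
        const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -max_bias / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        for (int h = 0; h < n_head; h++) {
            slope[h] = h < n_head_log2 ? powf(m0, (float) (h + 1))
                                       : powf(m1, (float) (2 * (h - n_head_log2) + 1));
        }
        return slope;
    }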
1449
1615
 
@@ -1461,165 +1627,135 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1461
1627
  * @param acl_dst The destination tensor where the softmax results will be
1462
1628
  * stored.
1463
1629
  */
1464
- static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
1465
- int64_t dim, aclTensor* acl_dst) {
1630
+ static void aclnn_softmax(ggml_backend_cann_context & ctx,
1631
+ aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) {
1466
1632
  GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);
1467
1633
  }
1468
1634
 
1469
- void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1635
+ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
1470
1636
  ggml_tensor* src0 = dst->src[0];
1471
1637
  ggml_tensor* src1 = dst->src[1]; // mask
1472
1638
 
1473
1639
  aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
1474
- aclTensor* acl_dst = ggml_cann_create_tensor(dst);
1640
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
1475
1641
 
1476
- float scale = 1.0f;
1642
+ float scale = 1.0f;
1477
1643
  float max_bias = 0.0f;
1478
1644
 
1479
- memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
1480
- memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));
1645
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
1646
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
1481
1647
 
1482
1648
  // input mul scale
1483
1649
  aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
1650
+ ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0));
1651
+ void* src_tensor_buffer = src_tensor_allocator.get();
1652
+ aclTensor* softmax_tensor = ggml_cann_create_tensor(
1653
+ src_tensor_buffer, ggml_cann_type_mapping(src0->type),
1654
+ ggml_element_size(src0), src0->ne, src0->nb,GGML_MAX_DIMS);
1484
1655
 
1485
- size_t n_bytes = ggml_nbytes(src0);
1486
- ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes);
1487
- void* input_mul_scale_buffer = mul_scale_allocator.get();
1488
- aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor(
1489
- input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne,
1490
- src0->nb, GGML_MAX_DIMS);
1491
-
1492
- bool inplace = false;
1493
- aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace);
1656
+ aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false);
1494
1657
 
1495
1658
  // mask
1496
- aclTensor* acl_src1_fp32_tensor = nullptr;
1497
- aclTensor* tmp_mask_tensor = nullptr;
1498
- ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool());
1499
1659
  if (src1) {
1500
- const bool use_f16 = src1->type == GGML_TYPE_F16;
1501
- if (use_f16) {
1502
- // cast to fp32
1503
- size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
1504
- size_t src1_fp32_nb[GGML_MAX_DIMS];
1505
- src1_fp32_nb[0] = sizeof(float_t);
1506
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
1507
- src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
1508
- }
1509
- src1_fp32_allocator.alloc(n_bytes);
1510
- void* src1_fp32_buffer = src1_fp32_allocator.get();
1511
- acl_src1_fp32_tensor = ggml_cann_create_tensor(
1512
- src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne,
1513
- src1_fp32_nb, GGML_MAX_DIMS);
1514
- aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
1515
- aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
1516
- ggml_cann_release_resources(ctx, acl_src1);
1517
- } else {
1518
- acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
1519
- }
1660
+ aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias);
1661
+ }
1662
+ // softmax
1663
+ aclnn_softmax(ctx, softmax_tensor, 3, acl_dst);
1664
+ ggml_cann_release_resources(ctx, acl_src0, acl_dst, acl_scale, softmax_tensor);
1665
+ }
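Put together, the reworked soft-max path computes softmax(x * scale + slope[h] * mask) along the innermost dimension, with slope[h] == 1 when max_bias is zero. A scalar reference for a single row (a sketch, not the backend implementation):

    #include <cmath>
    #include <cstdint>

    // One row of the fused path: scale, add the (slope-scaled) mask,
    // then a numerically stable softmax.
    static void softmax_row(const float* x, const float* mask, float slope,
                            float scale, float* out, int64_t n) {
        float max_val = -INFINITY;
        for (int64_t i = 0; i < n; i++) {
            out[i] = x[i] * scale + (mask ? slope * mask[i] : 0.0f);
            if (out[i] > max_val) max_val = out[i];
        }
        float sum = 0.0f;
        for (int64_t i = 0; i < n; i++) {
            out[i] = expf(out[i] - max_val);
            sum += out[i];
        }
        for (int64_t i = 0; i < n; i++) {
            out[i] /= sum;
        }
    }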
1520
1666
 
1521
- // broadcast the mask across rows, only use ne11 of ne01 in mask
1522
- if (src1->ne[1] != src0->ne[1]) {
1523
- // mask shape: [1,1,ne11,ne10]
1524
- int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
1525
- size_t tmp_mask_nb[GGML_MAX_DIMS];
1526
- tmp_mask_nb[0] = sizeof(float_t);
1527
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
1528
- tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
1529
- }
1530
- tmp_mask_tensor = ggml_cann_create_tensor(
1531
- src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb,
1532
- GGML_MAX_DIMS, ACL_FORMAT_ND);
1533
- }
1667
+ /**
1668
+ * @brief Performs index select operation on a 4D tensor using the CANN backend.
1669
+ *
1670
+ * This function applies the `IndexSelect` operation along a specific dimension
1671
+ * of the source tensor (`src_buffer`) using the indices from the index tensor (`index`).
1672
+ * It iterates over the last two dimensions of the source tensor, creates the corresponding
1673
+ * CANN tensors for the source, index, and output slices, and executes the `IndexSelect`
1674
+ * operation for each slice.
1675
+ *
1676
+ * @param ctx The context for CANN backend operations.
1677
+ * @param src_buffer The source buffer containing the 4D input tensor data.
1678
+ * @param src_ne The dimensions of the source tensor.
1679
+ * @param src_nb The strides (byte offsets) of the source tensor.
1680
+ * @param dst_buffer The destination buffer where the output tensor data will be written.
1681
+ * @param dst_ne The dimensions of the destination tensor.
1682
+ * @param dst_nb The strides (byte offsets) of the destination tensor.
1683
+ * @param index The index tensor specifying the indices to select from the source tensor.
1684
+ * @param type The data type of the source and destination tensors.
1685
+ */
1686
+ static void aclnn_index_select_4d(ggml_backend_cann_context& ctx,
1687
+ void* src_buffer,int64_t* src_ne, size_t* src_nb,
1688
+ void* dst_buffer, int64_t* dst_ne, size_t* dst_nb,
1689
+ ggml_tensor* index, ggml_type type) {
1690
+ for (int64_t i = 0; i < src_ne[3]; i++) {
1691
+ for (int64_t j = 0; j < src_ne[2]; j++) {
1692
+ // src
1693
+ aclTensor* acl_src_tensor = ggml_cann_create_tensor(
1694
+ (char*)src_buffer + i * src_nb[3] + j * src_nb[2],
1695
+ ggml_cann_type_mapping(type), ggml_type_size(type),
1696
+ src_ne, src_nb, 2);
1534
1697
 
1535
- // alibi
1536
- const int n_head = src0->ne[2];
1537
- const size_t src_nb0 = src0->nb[0];
1538
-
1539
- n_bytes = ggml_nbytes(dst);
1540
- ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes);
1541
- void* output_buffer = output_allocator.get();
1542
- aclTensor* alibi_output_tensor = ggml_cann_create_tensor(
1543
- output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne,
1544
- dst->nb, GGML_MAX_DIMS);
1545
- if (max_bias <= 0.0f) {
1546
- // slope = 1.0
1547
- if (tmp_mask_tensor) {
1548
- aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
1549
- alibi_output_tensor);
1550
- } else {
1551
- aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
1552
- alibi_output_tensor);
1553
- }
1554
- } else {
1555
- // slope != 1.0
1556
- if (tmp_mask_tensor) {
1557
- aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
1558
- alibi_output_tensor, n_head, src0->ne, src_nb0,
1559
- max_bias, dst);
1560
- } else {
1561
- aclnn_alibi(ctx, acl_input_mul_scale_tensor,
1562
- acl_src1_fp32_tensor, alibi_output_tensor, n_head,
1563
- src0->ne, src_nb0, max_bias, dst);
1564
- }
1565
- }
1698
+ // index
1699
+ aclTensor* acl_index = ggml_cann_create_tensor(
1700
+ (char*)index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
1701
+ ggml_cann_type_mapping(index->type), ggml_element_size(index),
1702
+ index->ne, index->nb, 1);
1566
1703
 
1567
- // softmax
1568
- aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
1569
- ggml_cann_release_resources(ctx, alibi_output_tensor);
1570
- } else {
1571
- aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
1704
+ // out
1705
+ aclTensor* acl_out = ggml_cann_create_tensor(
1706
+ (char*)dst_buffer + i * dst_nb[3] + j * dst_nb[2],
1707
+ ggml_cann_type_mapping(type), ggml_type_size(type),
1708
+ dst_ne, dst_nb, 2);
1709
+ GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor, 0, acl_index, acl_out);
1710
+ ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out);
1711
+ }
1572
1712
  }
1573
-
1574
- ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst,
1575
- acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor);
1576
1713
  }
1577
1714
 
1578
1715
  /**
1579
- * @brief Performs embedding operation on a 4D tensor using the CANN backend.
1716
+ * @brief Performs inplace index copy operation on a 4D tensor using the CANN backend.
1580
1717
  *
1581
- * This function extracts slices from the source tensor (`src_buffer`),
1582
- * index tensor (`index`), and destination tensor (`dst`), and performs an
1583
- * embedding operation on them. The embedding operation is applied by iterating
1584
- * over the last two dimensions of the source tensor, creating the necessary
1585
- * tensors for the source, index, and output, and executing the embedding operation.
1718
+ * This function applies the `IndexCopy` operation along a specific dimension of the
1719
+ * destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`)
1720
+ * to positions specified by the index tensor (`index`).
1721
+ * It iterates over the last two dimensions of the tensors, creates the corresponding
1722
+ * CANN tensors for source, index, and destination slices, and performs the index copy
1723
+ * operation for each slice.
1586
1724
  *
1587
1725
  * @param ctx The context for CANN backend operations.
1588
- * @param src_buffer The source buffer holding the data for the source tensor.
1726
+ * @param src_buffer The source buffer containing the 4D input tensor data to be copied.
1589
1727
  * @param src_ne The dimensions of the source tensor.
1590
1728
  * @param src_nb The strides (byte offsets) of the source tensor.
1591
- * @param index The index tensor used in the embedding operation.
1592
- * @param dst The destination tensor where the result will be stored.
1729
+ * @param dst_buffer The destination buffer where values will be copied to.
1730
+ * @param dst_ne The dimensions of the destination tensor.
1731
+ * @param dst_nb The strides (byte offsets) of the destination tensor.
1732
+ * @param index The index tensor specifying target positions in the destination tensor.
1733
+ * @param type The data type of the source and destination tensors.
1593
1734
  */
1594
- static void aclnn_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer,
1595
- int64_t* src_ne, size_t* src_nb, ggml_tensor* index,
1596
- ggml_tensor* dst) {
1735
+ static void aclnn_index_copy_4d(ggml_backend_cann_context& ctx,
1736
+ void* src_buffer,int64_t* src_ne, size_t* src_nb,
1737
+ void* dst_buffer, int64_t* dst_ne, size_t* dst_nb,
1738
+ ggml_tensor* index, ggml_type type) {
1597
1739
  for (int64_t i = 0; i < src_ne[3]; i++) {
1598
1740
  for (int64_t j = 0; j < src_ne[2]; j++) {
1599
1741
  // src
1600
- int64_t acl_src_ne[2] = {src_ne[0], src_ne[1]};
1601
- size_t acl_src_nb[2] = {src_nb[0], src_nb[1]};
1602
1742
  aclTensor* acl_src_tensor = ggml_cann_create_tensor(
1603
1743
  (char*)src_buffer + i * src_nb[3] + j * src_nb[2],
1604
- ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
1605
- acl_src_ne, acl_src_nb, 2);
1744
+ ggml_cann_type_mapping(type), ggml_type_size(type),
1745
+ src_ne, src_nb, 2);
1606
1746
 
1607
1747
  // index
1608
- int64_t acl_index_ne[1] = {index->ne[0]};
1609
- size_t acl_index_nb[1] = {index->nb[0]};
1610
1748
  aclTensor* acl_index = ggml_cann_create_tensor(
1611
- (char*)index->data + i * index->nb[2] + j * index->nb[1],
1749
+ (char*)index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1],
1612
1750
  ggml_cann_type_mapping(index->type), ggml_element_size(index),
1613
- acl_index_ne, acl_index_nb, 1);
1751
+ index->ne, index->nb, 1);
1614
1752
 
1615
1753
  // out
1616
- int64_t acl_out_ne[2] = {dst->ne[0], dst->ne[1]};
1617
- size_t acl_out_nb[2] = {dst->nb[0], dst->nb[1]};
1618
1754
  aclTensor* acl_out = ggml_cann_create_tensor(
1619
- (char*)dst->data + i * dst->nb[3] + j * dst->nb[2],
1620
- ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
1621
- acl_out_ne, acl_out_nb, 2);
1622
- GGML_CANN_CALL_ACLNN_OP(ctx, Embedding, acl_src_tensor, acl_index, acl_out);
1755
+ (char*)dst_buffer + i * dst_nb[3] + j * dst_nb[2],
1756
+ ggml_cann_type_mapping(type), ggml_type_size(type),
1757
+ dst_ne, dst_nb, 2);
1758
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out, 0, acl_index, acl_src_tensor);
1623
1759
  ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out);
1624
1760
  }
1625
1761
  }
@@ -1631,17 +1767,18 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1631
1767
 
1632
1768
  switch (src0->type) {
1633
1769
  case GGML_TYPE_F32: {
1634
- aclnn_embedding_4d(ctx, src0->data, src0->ne, src0->nb, src1,
1635
- dst);
1770
+ aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
1771
+ dst->data, dst->ne, dst->nb,
1772
+ src1, dst->type);
1636
1773
  break;
1637
1774
  }
1638
1775
  case GGML_TYPE_F16: {
1639
1776
  aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
1640
1777
  ggml_cann_pool_alloc src_buffer_allocator(
1641
- ctx.pool(), ggml_nelements(src0) * sizeof(float_t));
1778
+ ctx.pool(), ggml_nelements(src0) * sizeof(float));
1642
1779
  void* src_trans_buffer = src_buffer_allocator.get();
1643
1780
  size_t src_trans_nb[GGML_MAX_DIMS];
1644
- src_trans_nb[0] = sizeof(float_t);
1781
+ src_trans_nb[0] = sizeof(float);
1645
1782
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
1646
1783
  src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
1647
1784
  }
@@ -1649,8 +1786,9 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1649
1786
  src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
1650
1787
  src0->ne, src_trans_nb, GGML_MAX_DIMS);
1651
1788
  aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
1652
- aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne,
1653
- src_trans_nb, src1, dst);
1789
+ aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
1790
+ dst->data, dst->ne, dst->nb,
1791
+ src1, dst->type);
1654
1792
  ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
1655
1793
  break;
1656
1794
  }
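In ggml terms, the IndexSelect calls above implement the GET_ROWS gather: within each 2-D slice, output row r is a copy of input row index[r]. A scalar sketch (index type shown as int32_t for brevity):

    #include <cstdint>
    #include <cstring>

    // Row gather on one 2-D slice: dst[r] = src[index[r]].
    static void get_rows_2d(const float* src, int64_t row_size,
                            const int32_t* index, int64_t n_rows, float* dst) {
        for (int64_t r = 0; r < n_rows; r++) {
            std::memcpy(dst + r * row_size, src + index[r] * row_size,
                        row_size * sizeof(float));
        }
    }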
@@ -1684,14 +1822,14 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1684
1822
 
1685
1823
  // [3,4,5,64] -> [3,4,5,2,32]
1686
1824
  dequant_ne = weight_ne;
1687
- dequant_nb[0] = sizeof(float_t);
1825
+ dequant_nb[0] = sizeof(float);
1688
1826
  for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
1689
1827
  dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
1690
1828
  }
1691
1829
 
1692
1830
  scale_offset = ggml_nelements(src0) * sizeof(int8_t);
1693
1831
  ggml_cann_pool_alloc dequant_buffer_allocator(
1694
- ctx.pool(), ggml_nelements(src0) * sizeof(float_t));
1832
+ ctx.pool(), ggml_nelements(src0) * sizeof(float));
1695
1833
 
1696
1834
  aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
1697
1835
  src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
@@ -1700,18 +1838,20 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1700
1838
  src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
1701
1839
  GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
1702
1840
  aclTensor* dequant_tensor = ggml_cann_create_tensor(
1703
- dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float_t),
1841
+ dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
1704
1842
  dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
1705
1843
 
1706
1844
  aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
1707
- dequant_nb[0] = sizeof(float_t);
1845
+ dequant_nb[0] = sizeof(float);
1708
1846
  dequant_ne = src0->ne;
1709
1847
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
1710
1848
  dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
1711
1849
  }
1712
1850
 
1713
- aclnn_embedding_4d(ctx, dequant_buffer_allocator.get(),
1714
- dequant_ne, dequant_nb, src1, dst);
1851
+ aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
1852
+ dequant_ne, dequant_nb,
1853
+ dst->data, dst->ne, dst->nb,
1854
+ src1, dst->type);
1715
1855
 
1716
1856
  ggml_cann_release_resources(ctx, dequant_tensor);
1717
1857
  break;
@@ -1722,6 +1862,43 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  }
  }
 
+ void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src0 = dst->src[0]; // src
+ ggml_tensor* src1 = dst->src[1]; // index
+
+ switch (dst->type) {
+ case GGML_TYPE_F32: {
+ aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb,
+ dst->data, dst->ne, dst->nb,
+ src1, dst->type);
+ break;
+ }
+ case GGML_TYPE_F16: {
+ aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+ ggml_cann_pool_alloc src_buffer_allocator(
+ ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t));
+ void* src_trans_buffer = src_buffer_allocator.get();
+ size_t src_trans_nb[GGML_MAX_DIMS];
+ src_trans_nb[0] = sizeof(uint16_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+ }
+ aclTensor* src_trans_tensor = ggml_cann_create_tensor(
+ src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type),
+ src0->ne, src_trans_nb, GGML_MAX_DIMS);
+ aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+ aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
+ dst->data, dst->ne, dst->nb,
+ src1, dst->type);
+ ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
+ break;
+ }
+ default:
+ GGML_ABORT("Unsupported tensor type for GGML_OP_SET_ROWS");
+ break;
+ }
+ }
+
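The new GGML_OP_SET_ROWS handler is the inverse scatter of GET_ROWS: within each 2-D slice, source row r overwrites destination row index[r], which is what InplaceIndexCopy performs per slice. A scalar sketch (index type shown as int32_t for brevity):

    #include <cstdint>
    #include <cstring>

    // Row scatter on one 2-D slice: dst[index[r]] = src[r].
    static void set_rows_2d(const float* src, int64_t row_size,
                            const int32_t* index, int64_t n_rows, float* dst) {
        for (int64_t r = 0; r < n_rows; r++) {
            std::memcpy(dst + index[r] * row_size, src + r * row_size,
                        row_size * sizeof(float));
        }
    }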
  /**
  * @brief Repeats elements of a tensor along a specified dimension.
  *
@@ -1783,8 +1960,25 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
  size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0],
  bcast_weight_nb[2], bcast_weight_nb[3],
  bcast_weight_nb[4], bcast_weight_nb[5]};
- aclTensor* acl_weight_tensor =
- ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims);
+ aclTensor* acl_weight_tensor;
+
+ // Only check env once.
+ static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
+ if (weight_to_nz && is_matmul_weight(weight)) {
+ int64_t acl_stride[2] = {1, transpose_ne[1]};
+
+ // Reverse ne.
+ std::reverse(transpose_ne, transpose_ne + n_dims);
+
+ std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
+
+ acl_weight_tensor = aclCreateTensor(
+ transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
+ 0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
+ } else {
+ acl_weight_tensor =
+ ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
+ }
  aclTensor* acl_dst =
  ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims);
 
@@ -2050,63 +2244,190 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
2050
2244
  ggml_cann_release_resources(ctx, acl_index, acl_value);
2051
2245
  }
2052
2246
 
2247
+ /**
2248
+ * @brief Initializes and caches sine/cosine positional encoding values
2249
+ * (used in RoPE, Rotary Position Embedding) for attention layers.
2250
+ *
2251
+ * This function computes and caches the sin/cos values of
2252
+ * θ = position * theta_scale for RoPE encoding. The cache is shared
2253
+ * across attention layers, and only the first attention layer will
2254
+ * trigger initialization. The cache includes repeated sin/cos values
2255
+ * with different repeat methods depending on the @param is_neox flag.
2256
+ *
2257
+ * Steps performed by this function:
2258
+ * 1. Identify whether the target tensor belongs to Q/K in attention
2259
+ * and restrict computation to the first layer only.
2260
+ * 2. Initialize the theta scale array (arange → power → freq scaling).
2261
+ * 3. Allocate sin/cos caches if the max prompt length increases.
2262
+ * 4. Compute θ = position * theta_scale.
2263
+ * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
2264
+ * 6. Expand sin/cos values by repeat or repeat_interleave depending
2265
+ * on whether @param is_neox is enabled.
2266
+ *
2267
+ * @param ctx The CANN backend context, holding memory pool,
2268
+ * stream, and persistent buffers for rope init/cache.
2269
+ * @param dst The destination ggml_tensor whose computation
2270
+ * depends on the RoPE values (usually Qcur/Kcur).
2271
+ * @param theta_scale Scalar exponent base for computing theta scale values.
2272
+ * @param freq_scale Frequency scaling factor, applied to theta scale.
2273
+ * @param attn_factor Attention scaling factor, applied to sin/cos.
2274
+ * @param is_neox Whether to use Neox-style repeat strategy
2275
+ * (dim expansion vs repeat_interleave).
2276
+ */
2053
2277
  static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
2054
- aclTensor* acl_cos_repeat_tensor,
2055
- aclTensor* acl_sin_repeat_tensor,
2278
+ float* corr_dims, float ext_factor,
2056
2279
  float theta_scale, float freq_scale,
2057
2280
  float attn_factor, bool is_neox) {
2058
- // int sin/cos cache, cache has different repeat method depond on
2059
- // @param.is_neox
2060
-
2061
2281
  ggml_tensor* src0 = dst->src[0]; // input
2062
2282
  ggml_tensor* src1 = dst->src[1]; // position
2063
2283
  ggml_tensor* src2 = dst->src[2]; // freq_factors
2064
2284
 
2065
- GGML_TENSOR_BINARY_OP_LOCALS
2285
+ if(src2 == nullptr && ctx.rope_cache.cached
2286
+ && ctx.rope_cache.ext_factor == ext_factor
2287
+ && ctx.rope_cache.theta_scale == theta_scale
2288
+ && ctx.rope_cache.freq_scale == freq_scale
2289
+ && ctx.rope_cache.attn_factor == attn_factor
2290
+ && ctx.rope_cache.is_neox == is_neox) {
2291
+ // use cache.
2292
+ return;
2293
+ }
2066
2294
 
2067
- // theta_scale arange, [0,1,...,ne00/2 - 1]
2068
- int64_t theta_scale_length = ne00 / 2;
2069
- ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(),
2070
- theta_scale_length * sizeof(float_t));
2071
- void* theta_scale_buffer = theta_scale_allocator.get();
2295
+ int64_t theta_scale_length = src0->ne[0] / 2;
2072
2296
  int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
2073
- size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
2074
- theta_scale_length * sizeof(float_t)};
2297
+ size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float),
2298
+ theta_scale_length * sizeof(float)};
2075
2299
 
2076
- aclTensor* acl_theta_scale_tensor =
2077
- ggml_cann_create_tensor(theta_scale_buffer, ACL_FLOAT, sizeof(float_t),
2078
- theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2079
- float start = 0;
2080
- float step = 1;
2081
- float stop = ne00 / 2;
2082
- float n_elements = ne00 / 2;
2083
- aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
2084
-
2085
- // power
2086
- aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
2087
- GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
2088
- acl_theta_scale_tensor);
2089
-
2090
- // freq_scale
2091
- if (freq_scale != 1) {
2092
- aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
2300
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
2301
+ int64_t position_length = src1->ne[0];
2302
+ int64_t position_ne[] = {1, 1, position_length, 1};
2303
+ size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t),
2304
+ sizeof(int32_t) * position_length};
2305
+
2306
+ int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1};
2307
+ size_t theta_nb[GGML_MAX_DIMS];
2308
+ theta_nb[0] = sizeof(float);
2309
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
2310
+ theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
2311
+ }
2312
+
2313
+ // theta_scale arange, [0,1,...,ne00/2 - 1]
2314
+ aclTensor* acl_theta_scale_tensor = nullptr;
2315
+ // cache theta scale
2316
+ if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
2317
+ // theta_scale and freq_scale should not change during the current token inference process,
2318
+ // so we can directly use == here instead of comparing the absolute difference.
2319
+ ctx.rope_cache.theta_scale != theta_scale ||
2320
+ ctx.rope_cache.freq_scale != freq_scale) {
2321
+
2322
+ ctx.rope_cache.theta_scale_length = theta_scale_length;
2323
+
2324
+ if (ctx.rope_cache.theta_scale_cache != nullptr) {
2325
+ ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
2326
+ }
2327
+ ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
2328
+
2329
+ acl_theta_scale_tensor =
2330
+ ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
2331
+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2332
+
2333
+ float start = 0;
2334
+ float step = 1;
2335
+ float stop = theta_scale_length;
2336
+ float n_elements = theta_scale_length;
2337
+ aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
2338
+
2339
+ ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
2340
+ aclTensor* acl_yarn_ramp_tensor = nullptr;
2341
+ if (ext_factor != 0) {
2342
+ // -rope_yarn_ramp
2343
+ // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
2344
+ // return MIN(1, MAX(0, y)) - 1;
2345
+ yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
2346
+ void* yarn_ramp_buffer = yarn_ramp_allocator.get();
2347
+ acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float),
2348
+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2349
+ float zero_value = 0, one_value = 1;
2350
+ float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
2351
+ aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
2352
+ aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT);
2353
+ aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT);
2354
+ aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT);
2355
+ aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT);
2356
+
2357
+ GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor);
2358
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe);
2359
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor, zero, zero);
2360
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor, one);
2361
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor, one, one);
2362
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, ext_factor_sc);
2363
+
2364
+ // theta_interp = freq_scale * theta_extrap;
2365
+ // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
2366
+ // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
2367
+ // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
2368
+ // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
2369
+ //
2370
+ // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
2371
+ // cache freq_scale + (freq_scale - 1) * ramp_mix
2372
+ float freq_scale_1 = freq_scale - 1;
2373
+ aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT);
2374
+ aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT);
2375
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc);
2376
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one);
2377
+
2378
+ ggml_cann_release_resources(ctx, low, zero, one, denom_safe, ext_factor_sc, freq_scale_sc, freq_scale_1_sc);
2379
+ }
2380
+
2381
+ // power
2382
+ aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
2383
+ GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
2384
+ acl_theta_scale_tensor);
2385
+
2386
+ if (ext_factor != 0) {
2387
+ aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor);
2388
+ } else if (freq_scale != 1) {
2389
+ aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
2390
+ }
2391
+
2392
+ ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale);
2393
+ } else {
2394
+ // use cache
2395
+ acl_theta_scale_tensor =
2396
+ ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
2397
+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2093
2398
  }
2094
2399
 
2400
+ ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
2095
2401
  // freq_factors
2096
2402
  if (src2) {
2403
+ freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
2404
+ void* freq_fac_res_ptr = freq_fac_res_allocator.get();
2097
2405
  aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
2098
2406
  src2->data, ggml_cann_type_mapping(src2->type),
2099
2407
  ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2100
- aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
2101
- ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
2408
+ aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
2409
+ freq_fac_res_ptr, ACL_FLOAT, sizeof(float),
2410
+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2411
+ aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
2412
+ std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
2413
+ ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
2414
+ }
2415
+
2416
+ // init sin_repeat && cos_repeat, only to accelerate first layer on each device
2417
+ if (position_length > ctx.rope_cache.position_length) {
2418
+ ctx.rope_cache.position_length = position_length;
2419
+ if (ctx.rope_cache.sin_cache != nullptr) {
2420
+ ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));
2421
+ }
2422
+ if (ctx.rope_cache.cos_cache != nullptr) {
2423
+ ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));
2424
+ }
2425
+ int64_t repeat_theta_length = theta_scale_length * position_length * 2;
2426
+ ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
2427
+ ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
2102
2428
  }
2103
2429
 
2104
2430
  // position
2105
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
2106
- int64_t position_length = src1->ne[0];
2107
- int64_t position_ne[] = {1, 1, position_length, 1};
2108
- size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t),
2109
- sizeof(int32_t) * position_length};
2110
2431
  aclTensor* acl_position_tensor = ggml_cann_create_tensor(
2111
2432
  src1->data, ggml_cann_type_mapping(src1->type),
2112
2433
  ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
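The YaRN branch in the hunk above caches, per frequency index, the multiplier freq_scale * (1 - mix) + mix that is later applied to theta_extrap, with mix derived from the ramp quoted in the comments. A scalar sketch of that multiplier (illustrative only; low and high correspond to corr_dims[0] and corr_dims[1]):

    #include <algorithm>

    // Scalar form of the cached YaRN theta multiplier:
    // ramp = 1 - clamp((i0/2 - low) / max(0.001, high - low), 0, 1)
    // mix  = ramp * ext_factor
    // theta = theta_extrap * (freq_scale * (1 - mix) + mix)
    static float yarn_theta_multiplier(int i0, float low, float high,
                                       float freq_scale, float ext_factor) {
        float y    = (i0 / 2 - low) / std::max(0.001f, high - low);
        float ramp = 1.0f - std::min(1.0f, std::max(0.0f, y));
        float mix  = ramp * ext_factor;
        return freq_scale * (1.0f - mix) + mix;
    }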
@@ -2114,43 +2435,55 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
2114
2435
  // power * position
2115
2436
  int64_t theta_length = theta_scale_length * position_length;
2116
2437
  ggml_cann_pool_alloc theta_allocator(ctx.pool(),
2117
- theta_length * sizeof(float_t));
2438
+ theta_length * sizeof(float));
2118
2439
  void* theta_buffer = theta_allocator.get();
2119
- int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1};
2120
- size_t theta_nb[GGML_MAX_DIMS];
2121
- theta_nb[0] = sizeof(float_t);
2122
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
2123
- theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
2124
- }
2440
+
2125
2441
  aclTensor* acl_theta_tensor =
2126
- ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
2442
+ ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float),
2127
2443
  theta_ne, theta_nb, GGML_MAX_DIMS);
2128
2444
  aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
2129
- acl_theta_tensor);
2445
+ acl_theta_tensor);
2130
2446
 
2131
2447
  // sin/cos
2132
2448
  ggml_cann_pool_alloc sin_allocator(ctx.pool(),
2133
- theta_length * sizeof(float_t));
2449
+ theta_length * sizeof(float));
2134
2450
  void* sin_buffer = sin_allocator.get();
2135
2451
  aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
2136
- sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2452
+ sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb,
2137
2453
  GGML_MAX_DIMS, ACL_FORMAT_ND);
2138
2454
  aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
2139
2455
 
2140
2456
  ggml_cann_pool_alloc cos_allocator(ctx.pool(),
2141
- theta_length * sizeof(float_t));
2457
+ theta_length * sizeof(float));
2142
2458
  void* cos_buffer = cos_allocator.get();
2143
2459
  aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
2144
- cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
2460
+ cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb,
2145
2461
  GGML_MAX_DIMS, ACL_FORMAT_ND);
2146
2462
  aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
2147
2463
 
2464
+ if (ext_factor != 0) {
2465
+ attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
2466
+ }
2467
+
2148
2468
  // attn_factor
2149
2469
  if (attn_factor != 1) {
2150
2470
  aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
2151
2471
  aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
2152
2472
  }
2153
2473
 
2474
+ int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
2475
+ size_t sin_reshape_nb[GGML_MAX_DIMS];
2476
+ sin_reshape_nb[0] = sizeof(float);
2477
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
2478
+ sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
2479
+ }
2480
+ aclTensor* acl_sin_repeat_tensor =
2481
+ ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
2482
+ sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
2483
+ aclTensor* acl_cos_repeat_tensor =
2484
+ ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
2485
+ sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
2486
+
2154
2487
  // repeat
2155
2488
  if (is_neox) {
2156
2489
  int64_t repeatsArray[] = {1, 1, 1, 2};
@@ -2166,9 +2499,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
2166
2499
  num_repeats, output_size);
2167
2500
  }
2168
2501
 
2169
- // release
2502
+ // Mark the cache as valid: layers after the first reuse it instead of recomputing.
2503
+ ctx.rope_cache.cached = true;
2504
+ ctx.rope_cache.ext_factor = ext_factor;
2505
+ ctx.rope_cache.theta_scale = theta_scale;
2506
+ ctx.rope_cache.freq_scale = freq_scale;
2507
+ ctx.rope_cache.attn_factor = attn_factor;
2508
+ ctx.rope_cache.is_neox = is_neox;
2509
+
2170
2510
  ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
2171
- acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale);
2511
+ acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
2512
+ acl_cos_repeat_tensor);
2172
2513
  }
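Note for readers following the rope changes above: the sin/cos tables now live in ctx.rope_cache and are rebuilt only when position_length grows or one of the rope parameters changes, so every layer after the first reuses them. The diff also applies a correction when ext_factor != 0, multiplying attn_factor by 1 + 0.1·ln(1/freq_scale) (for freq_scale = 0.25 that is roughly 1.139). Below is a minimal host-side sketch of the same caching idea; the struct and function names are illustrative and not part of the CANN backend.

```cpp
// Illustrative sketch only: a parameter-keyed sin/cos cache shared by all layers.
#include <cmath>
#include <cstdint>
#include <vector>

struct rope_cache {
    bool    cached = false;
    float   theta_scale = 0.0f, freq_scale = 0.0f, attn_factor = 0.0f, ext_factor = 0.0f;
    bool    is_neox = false;
    int64_t position_length = 0;
    std::vector<float> sin_table, cos_table; // [position][dim/2]
};

void rope_cache_init(rope_cache & c, int64_t half_dims, int64_t n_pos,
                     float theta_scale, float freq_scale, float attn_factor,
                     float ext_factor, bool is_neox) {
    const bool reusable = c.cached &&
        c.theta_scale == theta_scale && c.freq_scale == freq_scale &&
        c.attn_factor == attn_factor && c.ext_factor  == ext_factor &&
        c.is_neox == is_neox && n_pos <= c.position_length;
    if (reusable) {
        return; // later layers take this path and skip the trig work
    }
    c.sin_table.assign(n_pos * half_dims, 0.0f);
    c.cos_table.assign(n_pos * half_dims, 0.0f);
    for (int64_t p = 0; p < n_pos; ++p) {
        for (int64_t i = 0; i < half_dims; ++i) {
            // theta = p * freq_scale * theta_scale^i (freq_scale placement is
            // simplified here; the backend folds it into its power table)
            const float theta = (float) p * freq_scale * std::pow(theta_scale, (float) i);
            c.sin_table[p * half_dims + i] = std::sin(theta) * attn_factor;
            c.cos_table[p * half_dims + i] = std::cos(theta) * attn_factor;
        }
    }
    c.cached          = true;
    c.theta_scale     = theta_scale;
    c.freq_scale      = freq_scale;
    c.attn_factor     = attn_factor;
    c.ext_factor      = ext_factor;
    c.is_neox         = is_neox;
    c.position_length = n_pos;
}
```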
2173
2514
 
2174
2515
  #ifdef __cplusplus
@@ -2187,8 +2528,6 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
2187
2528
  #endif
2188
2529
 
2189
2530
  void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2190
- // TODO: use ascendc
2191
- // Only test with LLAMA model.
2192
2531
  ggml_tensor* src0 = dst->src[0]; // input
2193
2532
 
2194
2533
  // param
@@ -2211,8 +2550,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2211
2550
  // TODO: n_dims <= ne0
2212
2551
  GGML_ASSERT(n_dims == ne0);
2213
2552
  GGML_ASSERT(n_dims % 2 == 0);
2214
- // TODO: ext_factor != 0
2215
- GGML_ASSERT(ext_factor == 0);
2216
2553
 
2217
2554
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
2218
2555
 
@@ -2222,28 +2559,22 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2222
2559
 
2223
2560
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
2224
2561
 
2225
- // init cos/sin cache
2226
- ggml_cann_pool_alloc sin_allocator(
2227
- ctx.pool(), ne00 * ne02 * sizeof(float_t));
2228
- ggml_cann_pool_alloc cos_allocator(
2229
- ctx.pool(), ne00 * ne02 * sizeof(float_t));
2230
- void* sin_buffer = sin_allocator.get();
2231
- void* cos_buffer = cos_allocator.get();
2562
+ // init the sin/cos tables in ctx.rope_cache
2563
+ aclnn_cache_init(ctx, dst, corr_dims, ext_factor,
2564
+ theta_scale, freq_scale, attn_factor, is_neox);
2232
2565
 
2233
2566
  int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
2234
2567
  size_t sin_reshape_nb[GGML_MAX_DIMS];
2235
- sin_reshape_nb[0] = sizeof(float_t);
2568
+ sin_reshape_nb[0] = sizeof(float);
2236
2569
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
2237
2570
  sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
2238
2571
  }
2239
2572
  aclTensor* acl_sin_reshape_tensor =
2240
- ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
2573
+ ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
2241
2574
  sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
2242
2575
  aclTensor* acl_cos_reshape_tensor =
2243
- ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
2576
+ ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
2244
2577
  sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
2245
- aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
2246
- theta_scale, freq_scale, attn_factor, is_neox);
2247
2578
 
2248
2579
  aclTensor* acl_src = ggml_cann_create_tensor(src0);
2249
2580
  aclTensor* acl_dst = ggml_cann_create_tensor(dst);
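The hunks that follow implement the rotation itself as dst = src·cos + roll(src)·sign·sin: the roll pairs neighbouring elements in normal mode and the two halves in NeoX mode, and the sign vector carries the minus signs. The sketch below is a hedged CPU reference of that formulation, not the backend code.

```cpp
// Illustrative CPU reference for the roll + sign-vector RoPE formulation.
#include <cmath>
#include <cstdint>
#include <vector>

// x: one head of dimension n_dims at position pos; rotated in place.
void rope_ref(std::vector<float> & x, int64_t pos, float theta_scale,
              float attn_factor, bool is_neox) {
    const int64_t n_dims = (int64_t) x.size();
    const int64_t half   = n_dims / 2;
    std::vector<float> out(n_dims);
    for (int64_t j = 0; j < n_dims; ++j) {
        // angle index: pairs (2i, 2i+1) share theta_i in normal mode,
        // indices i and i + n_dims/2 share theta_i in NeoX mode
        const int64_t i     = is_neox ? (j % half) : (j / 2);
        const float   theta = (float) pos * std::pow(theta_scale, (float) i);
        const float   c     = std::cos(theta) * attn_factor;
        const float   s     = std::sin(theta) * attn_factor;
        // rolled partner and sign vector
        const int64_t r    = is_neox ? (j + half) % n_dims
                                     : (j % 2 == 0 ? j + 1 : j - 1);
        const float   sign = is_neox ? (j < half ? -1.0f : 1.0f)
                                     : (j % 2 == 0 ? -1.0f : 1.0f);
        out[j] = x[j] * c + x[r] * sign * s;
    }
    x = out;
}
```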
@@ -2257,7 +2588,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2257
2588
  void* minus_one_scale_buffer = nullptr;
2258
2589
  ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
2259
2590
  ggml_cann_pool_alloc minus_one_scale_allocator(
2260
- ctx.pool(), sizeof(float_t) * src0->ne[0]);
2591
+ ctx.pool(), sizeof(float) * src0->ne[0]);
2261
2592
  if (!is_neox) {
2262
2593
  // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]
2263
2594
  input_roll_buffer = roll_allocator.get();
@@ -2287,13 +2618,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2287
2618
 
2288
2619
  int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
2289
2620
  size_t minus_one_nb[GGML_MAX_DIMS];
2290
- minus_one_nb[0] = sizeof(float_t);
2621
+ minus_one_nb[0] = sizeof(float);
2291
2622
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
2292
2623
  minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
2293
2624
  }
2294
2625
  acl_minus_one_tensor = aclnn_values(
2295
- ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
2296
- minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
2626
+ ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0],
2627
+ minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1);
2297
2628
  int64_t dim = 3;
2298
2629
  int64_t* index = new int64_t[src0->ne[0]];
2299
2630
  for (int i = 0; i < src0->ne[0]; i++) {
@@ -2321,22 +2652,22 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2321
2652
  minus_one_scale_buffer = minus_one_scale_allocator.get();
2322
2653
  int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
2323
2654
  size_t minus_one_nb[GGML_MAX_DIMS];
2324
- minus_one_nb[0] = sizeof(float_t);
2655
+ minus_one_nb[0] = sizeof(float);
2325
2656
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
2326
2657
  minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
2327
2658
  }
2328
2659
  acl_minus_one_tensor = aclnn_values(
2329
- ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
2330
- minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
2660
+ ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0],
2661
+ minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1);
2331
2662
  // -1 * first half
2332
2663
  int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
2333
2664
  size_t first_half_nb[GGML_MAX_DIMS];
2334
- first_half_nb[0] = sizeof(float_t);
2665
+ first_half_nb[0] = sizeof(float);
2335
2666
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
2336
2667
  first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
2337
2668
  }
2338
2669
  aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
2339
- minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne,
2670
+ minus_one_scale_buffer, ACL_FLOAT, sizeof(float), first_half_ne,
2340
2671
  first_half_nb, GGML_MAX_DIMS);
2341
2672
  bool inplace = true;
2342
2673
  float scale = -1;
@@ -2376,28 +2707,28 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2376
2707
  // TODO: ne0 != n_dims in mode2
2377
2708
  } else if (src0->type == GGML_TYPE_F16) {
2378
2709
  size_t input_fp32_nb[GGML_MAX_DIMS];
2379
- input_fp32_nb[0] = sizeof(float_t);
2710
+ input_fp32_nb[0] = sizeof(float);
2380
2711
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
2381
2712
  input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
2382
2713
  }
2383
2714
  ggml_cann_pool_alloc fp32_allocator1(
2384
- ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
2715
+ ctx.pool(), ggml_nelements(dst) * sizeof(float));
2385
2716
  void* input_fp32_buffer1 = fp32_allocator1.get();
2386
2717
  aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
2387
- input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne,
2718
+ input_fp32_buffer1, ACL_FLOAT, sizeof(float), dst->ne,
2388
2719
  input_fp32_nb, GGML_MAX_DIMS);
2389
2720
  ggml_cann_pool_alloc fp32_allocator2(
2390
- ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
2721
+ ctx.pool(), ggml_nelements(dst) * sizeof(float));
2391
2722
  void* input_fp32_buffer2 = fp32_allocator2.get();
2392
2723
  aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
2393
- input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne,
2724
+ input_fp32_buffer2, ACL_FLOAT, sizeof(float), dst->ne,
2394
2725
  input_fp32_nb, GGML_MAX_DIMS);
2395
2726
 
2396
2727
  ggml_cann_pool_alloc fp32_allocator(
2397
- ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
2728
+ ctx.pool(), ggml_nelements(dst) * sizeof(float));
2398
2729
  output_fp32_buffer = fp32_allocator.get();
2399
2730
  aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
2400
- output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
2731
+ output_fp32_buffer, ACL_FLOAT, sizeof(float), dst->ne,
2401
2732
  input_fp32_nb, GGML_MAX_DIMS);
2402
2733
  aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1);
2403
2734
  aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
@@ -2494,8 +2825,6 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
2494
2825
  aclIntArray *padding = aclCreateIntArray(paddingVal, 1);
2495
2826
  int64_t dilationVal[] = {1};
2496
2827
  aclIntArray *dilation = aclCreateIntArray(dilationVal, 1);
2497
- bool transposed = true;
2498
- int64_t groups = 1;
2499
2828
  int8_t cubeMathType = 0;
2500
2829
 
2501
2830
  #ifdef ASCEND_310P
@@ -2503,7 +2832,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
2503
2832
  #endif
2504
2833
 
2505
2834
  GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride,
2506
- padding, dilation, transposed, padding, groups, acl_dst, cubeMathType);
2835
+ padding, dilation, true, padding, 1, acl_dst, cubeMathType);
2507
2836
 
2508
2837
  ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation);
2509
2838
  }
@@ -2612,113 +2941,49 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
2612
2941
  */
2613
2942
  static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2614
2943
  //dst [M, K, N, 1]
2615
- ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
2616
- ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
2944
+ ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] -> [D, M, K, 1]
2945
+ ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 -> [D, 1, K, 1]
2617
2946
  ggml_tensor * ids = dst->src[2]; //ids [K, N]
2618
2947
 
2619
- GGML_TENSOR_BINARY_OP_LOCALS
2620
-
2621
- // copy index from npu to cpu
2622
- int64_t n_as = ne02; // A
2623
- int64_t n_ids = ids->ne[0]; // K
2624
-
2625
- std::vector<char> ids_host(ggml_nbytes(ids));
2626
- ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
2627
- ACL_MEMCPY_DEVICE_TO_HOST);
2628
- ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
2948
+ GGML_ASSERT(src0->ne[3] == 1);
2949
+ GGML_ASSERT(src1->ne[3] == 1);
2950
+ GGML_ASSERT(dst->ne[3] == 1);
2629
2951
 
2630
- char * src0_original = (char *) src0->data;
2631
- char * src1_original = (char *) src1->data;
2632
- char * dst_original = (char *) dst->data;
2633
- size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03};
2952
+ int64_t batch = src1->ne[2];
2953
+ GGML_ASSERT(batch == ids->ne[1]);
2634
2954
 
2635
- // src0 is F16, src1 is F32, dst is F32
2636
- ggml_cann_pool_alloc src0_cast_allocator;
2637
- if (src0->type == GGML_TYPE_F16) {
2638
- src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0));
2639
- void* src0_cast_buf = src0_cast_allocator.get();
2955
+ ggml_cann_pool_alloc export_allocator(ctx.pool(), src0->ne[0] * src0->ne[1] * ids->ne[0] * ggml_element_size(src0));
2956
+ void* export_ptr = export_allocator.get();
2957
+ for (int64_t i = 0; i < batch; i++) {
2958
+ aclTensor *select_index = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]);
2959
+ aclTensor *export_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3);
2640
2960
 
2641
- size_t cast_nb[GGML_MAX_DIMS];
2642
- cast_nb[0] = sizeof(float_t);
2643
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
2644
- cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1];
2961
+ int64_t select_export_ne[] = {src0->ne[0], src0->ne[1], ids->ne[0]};
2962
+ size_t select_export_nb[3];
2963
+ select_export_nb[0] = src0->nb[0];
2964
+ for (int k = 1;k < 3; k++) {
2965
+ select_export_nb[k] = select_export_nb[k-1] * select_export_ne[k-1];
2645
2966
  }
2646
2967
 
2647
- aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0);
2648
- aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf,
2649
- ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4);
2650
- GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast);
2651
- ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16);
2652
-
2653
- src0_original = (char *) src0_cast_buf;
2654
- memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb));
2655
- }
2656
-
2657
- std::vector<aclTensor*> src0_tensor_vec;
2658
- std::vector<aclTensor*> src1_tensor_vec;
2659
- std::vector<aclTensor*> dst_tensor_vec;
2660
- for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
2661
- for (int64_t id = 0; id < n_ids; id++) {
2662
- // src0_row [M, D] -> weight && permute
2663
- int64_t src0_ne[2] = {ne01, ne00};
2664
- size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
2665
- // src1_row [D, 1] -> input
2666
- int64_t src1_ne[2] = {ne10, 1};
2667
- size_t src1_nb[2] = {nb10, nb11};
2668
- // dst_row [M, 1] -> out
2669
- int64_t dst_ne[2] = {ne0, 1};
2670
- size_t dst_nb[2] = {nb0, nb1};
2671
-
2672
- // expert index
2673
- int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
2674
- GGML_ASSERT(i02 >= 0 && i02 < n_as);
2675
-
2676
- // If B = 1 (broadcast), always use 0; otherwise, use id.
2677
- int64_t i11 = (ne11 == 1 ? 0 : id);
2678
- int64_t i12 = iid1;
2679
-
2680
- int64_t i1 = id;
2681
- int64_t i2 = i12;
2968
+ aclTensor *select_export = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_export_ne, select_export_nb, 3);
2969
+ GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, export_weight, 0, select_index, select_export);
2682
2970
 
2683
- void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
2684
- void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
2685
- void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
2686
-
2687
- aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
2688
- ACL_FLOAT, sizeof(float),
2689
- src0_ne, src0_nb, 2);
2690
- aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
2691
- ACL_FLOAT, sizeof(float),
2692
- src1_ne, src1_nb, 2);
2693
- aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr,
2694
- ACL_FLOAT, sizeof(float),
2695
- dst_ne, dst_nb, 2);
2696
-
2697
- src0_tensor_vec.push_back(acl_src0);
2698
- src1_tensor_vec.push_back(acl_src1);
2699
- dst_tensor_vec.push_back(acl_dst);
2700
- }
2701
- }
2971
+ int64_t select_transpose_ne[] = {select_export_ne[1], select_export_ne[0], select_export_ne[2]};
2972
+ size_t select_transpose_nb[] = {select_export_nb[1], select_export_nb[0], select_export_nb[2]};
2973
+ aclTensor *select_export_transpose = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_transpose_ne, select_transpose_nb, 3);
2702
2974
 
2703
- size_t GROUP_SIZE = 128;
2704
- // GroupedMatmulV2 required tensor_list.size < 128
2705
- for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
2706
- // split and call GroupedMatmulV2
2707
- size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
2708
- std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
2709
- std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
2710
- std::vector<aclTensor*> dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end);
2975
+ int64_t active_tensor_ne[] = {src1->ne[0], 1, src1->ne[1]};
2976
+ size_t active_tensor_nb[] = {src1->nb[0], src1->nb[1], src1->nb[1]};
2977
+ aclTensor *active_tensor = ggml_cann_create_tensor(src1, active_tensor_ne, active_tensor_nb, 3, ACL_FORMAT_ND, i * src1->nb[2]);
2711
2978
 
2712
- aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size());
2713
- aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size());
2714
- aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size());
2979
+ int64_t dst_ne[] = {dst->ne[0], 1, dst->ne[1]};
2980
+ size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[1]};
2981
+ aclTensor *acl_dst = ggml_cann_create_tensor(dst, dst_ne,dst_nb, 3, ACL_FORMAT_ND, i * dst->nb[2]);
2715
2982
 
2716
- GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list,
2717
- nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
2983
+ GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, active_tensor, select_export_transpose, acl_dst, 2);
2718
2984
 
2719
- ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
2985
+ ggml_cann_release_resources(ctx, select_index, export_weight, select_export, active_tensor, acl_dst, select_export_transpose);
2720
2986
  }
2721
- return;
2722
2987
  }
2723
2988
 
2724
2989
  /**
@@ -2867,11 +3132,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2867
3132
 
2868
3133
  void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
2869
3134
 
2870
- ggml_tensor* src0 = dst->src[0]; // q, fp32
2871
- ggml_tensor* src1 = dst->src[1]; // k, fp16
2872
- ggml_tensor* src2 = dst->src[2]; // v, fp16
3135
+ ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
3136
+ ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
3137
+ ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
2873
3138
  ggml_tensor* src3 = dst->src[3]; // mask, fp16
2874
3139
 
3140
+ // B, N, S, D (uncont) -> B, S, N, D (cont)
3141
+ int64_t src0_bsnd_ne[GGML_MAX_DIMS];
3142
+ memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
3143
+ size_t src0_bsnd_nb[GGML_MAX_DIMS];
3144
+ memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
3145
+ int64_t src1_bsnd_ne[GGML_MAX_DIMS];
3146
+ memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
3147
+ size_t src1_bsnd_nb[GGML_MAX_DIMS];
3148
+ memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
3149
+ int64_t src2_bsnd_ne[GGML_MAX_DIMS];
3150
+ memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
3151
+ size_t src2_bsnd_nb[GGML_MAX_DIMS];
3152
+ memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
3153
+
3154
+ auto transpose12 = [](int64_t* ne, size_t* nb) {
3155
+ int64_t ne_tmp = ne[1];
3156
+ size_t nb_tmp = nb[1];
3157
+ ne[1] = ne[2];
3158
+ nb[1] = nb[2];
3159
+ ne[2] = ne_tmp;
3160
+ nb[2] = nb_tmp;
3161
+ };
3162
+
3163
+ transpose12(src0_bsnd_ne, src0_bsnd_nb);
3164
+ transpose12(src1_bsnd_ne, src1_bsnd_nb);
3165
+ transpose12(src2_bsnd_ne, src2_bsnd_nb);
3166
+
2875
3167
  float maxBias = 0.0f;
2876
3168
  float scaleValue = 1.0f;
2877
3169
  float logitSoftcap = 0.0f;
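The transpose12 helper introduced above converts the B,N,S,D views of Q/K/V into B,S,N,D purely by swapping an extent/stride pair; no data is moved, which is why the attention call further down can be given the "BSND" layout directly. A minimal standalone illustration of that kind of strided view follows; the struct and function names are hypothetical.

```cpp
// Swapping ne[i]/nb[i] pairs re-labels the axes of a strided view without
// touching the underlying buffer.
#include <cstddef>
#include <cstdint>
#include <utility>

struct view4d {
    int64_t ne[4]; // extents, innermost dimension first (ggml convention)
    size_t  nb[4]; // strides in bytes
};

// Same swap the diff performs on the Q/K/V descriptors.
inline void transpose12(view4d & v) {
    std::swap(v.ne[1], v.ne[2]);
    std::swap(v.nb[1], v.nb[2]);
}

// Byte offset of element (i0, i1, i2, i3); valid for the original view and the
// transposed one alike, because extents and strides always travel together.
inline size_t elem_offset(const view4d & v,
                          int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    return (size_t) i0 * v.nb[0] + (size_t) i1 * v.nb[1] +
           (size_t) i2 * v.nb[2] + (size_t) i3 * v.nb[3];
}
```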
@@ -2893,11 +3185,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
2893
3185
  void* src0_f16_buffer = nullptr;
2894
3186
 
2895
3187
  if(ggml_cann_type_mapping(src0->type) != faDataType){
2896
- aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
3188
+ aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
3189
+ src0_bsnd_nb, GGML_MAX_DIMS);
2897
3190
  src0_f16_buffer = src0_f16_allocator.alloc(
2898
3191
  ggml_nelements(src0) * faElemSize);
2899
3192
 
2900
- int64_t* src0_f16_ne = src0->ne;
3193
+ int64_t* src0_f16_ne = src0_bsnd_ne;
2901
3194
  size_t src0_f16_nb[GGML_MAX_DIMS];
2902
3195
  src0_f16_nb[0] = sizeof(uint16_t);
2903
3196
  for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -2911,20 +3204,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
2911
3204
  aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
2912
3205
  ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
2913
3206
  }else{
2914
- acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
3207
+ acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
3208
+ src0_bsnd_nb, GGML_MAX_DIMS);
2915
3209
  }
2916
3210
 
2917
3211
  // Step 2: create the acl tensors for src1 (Key), src2 (Value),
2918
3212
  // and the direct output from FusedInferAttention
2919
3213
 
2920
- acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
2921
- acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
3214
+ acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
3215
+ src1_bsnd_nb, GGML_MAX_DIMS);
3216
+ acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
3217
+ src2_bsnd_nb, GGML_MAX_DIMS);
2922
3218
 
2923
3219
  ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
2924
3220
  void* out_f16_buffer = out_f16_allocator.alloc(
2925
3221
  ggml_nelements(dst) * faElemSize);
2926
3222
 
2927
- int64_t* out_f16_ne = src0->ne;
3223
+ int64_t* out_f16_ne = src0_bsnd_ne;
2928
3224
  size_t out_f16_nb[GGML_MAX_DIMS];
2929
3225
  out_f16_nb[0] = faElemSize;
2930
3226
  for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -2938,168 +3234,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
2938
3234
 
2939
3235
  // Step 3: create the PSEShift tensor if needed
2940
3236
  // this tensor is considered as mask (f16) in the llama.cpp
2941
-
2942
3237
  aclTensor* bcast_pse_tensor = nullptr;
2943
- int64_t bcast_pse_ne[GGML_MAX_DIMS];
2944
- size_t bcast_pse_nb[GGML_MAX_DIMS];
2945
3238
  ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
2946
- void* bcast_pse_buffer = nullptr;
2947
-
2948
3239
  if(src3 != nullptr){
2949
- bcast_pse_buffer = bcast_pse_allocator.alloc(
2950
- ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
2951
-
2952
- if(src0->ne[1] > 1){
2953
- // Case 1: broadcast pse for prefill stage with multiple head
2954
- aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
2955
- bcast_pse_ne[0] = src3->ne[0];
2956
- bcast_pse_ne[1] = src3->ne[1];
2957
- bcast_pse_ne[2] = src0->ne[2];
2958
- bcast_pse_ne[3] = src3->ne[3];
3240
+ // Construct the truncated pse tensor (common for prefill/decode)
3241
+ int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
3242
+ src3->ne[0], // D
3243
+ src0->ne[1], // S (number of Q tokens)
3244
+ src3->ne[2], // mask N
3245
+ src3->ne[3] // B
3246
+ };
3247
+ size_t* trunc_pse_nb = src3->nb;
3248
+
3249
+ aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
3250
+ src3->data, ACL_FLOAT16, sizeof(uint16_t),
3251
+ trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
3252
+ );
2959
3253
 
3254
+ int64_t bcast_pse_ne[GGML_MAX_DIMS];
3255
+ size_t bcast_pse_nb[GGML_MAX_DIMS];
3256
+ bcast_pse_ne[0] = src3->ne[0]; // D
3257
+ bcast_pse_ne[1] = src0->ne[1]; // S
3258
+ bcast_pse_ne[2] = src0->ne[2]; // N (num_heads)
3259
+ bcast_pse_ne[3] = src3->ne[3]; // B
3260
+ if (maxBias == 0.0f) {
3261
+ // When maxBias == 0.0f, use nb = 0 reduce once repeat (Qwen2)
3262
+ // Construct the bcast tensor (simulate repeat on the head dimension using stride=0)
2960
3263
  bcast_pse_nb[0] = sizeof(uint16_t);
2961
- for(int i = 1; i < GGML_MAX_DIMS; ++i){
2962
- bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
2963
- }
3264
+ bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
3265
+ bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data
3266
+ bcast_pse_nb[3] = src3->nb[3];
2964
3267
 
2965
3268
  bcast_pse_tensor = ggml_cann_create_tensor(
2966
- bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
2967
- bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
2968
-
2969
- int64_t repeats[] = {1, src0->ne[2], 1, 1};
2970
- aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
2971
-
2972
- ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
2973
- }else{
2974
- // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
2975
- int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
2976
- size_t* trunc_pse_nb = src3->nb;
2977
-
2978
- aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
2979
3269
  src3->data, ACL_FLOAT16, sizeof(uint16_t),
2980
- trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
2981
-
2982
- bcast_pse_ne[0] = src3->ne[0];
2983
- bcast_pse_ne[1] = src0->ne[1];
2984
- bcast_pse_ne[2] = src0->ne[2];
2985
- bcast_pse_ne[3] = src3->ne[3];
3270
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
3271
+ );
2986
3272
 
3273
+ ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
3274
+ } else {
2987
3275
  bcast_pse_nb[0] = sizeof(uint16_t);
2988
- for(int i = 1; i < GGML_MAX_DIMS; ++i){
3276
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
2989
3277
  bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
2990
3278
  }
2991
3279
 
3280
+ void* bcast_pse_buffer = bcast_pse_allocator.alloc(
3281
+ ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)
3282
+ );
3283
+
2992
3284
  bcast_pse_tensor = ggml_cann_create_tensor(
2993
3285
  bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
2994
- bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
3286
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
3287
+ );
2995
3288
 
2996
3289
  int64_t repeats[] = {1, src0->ne[2], 1, 1};
2997
3290
  aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
2998
3291
 
2999
- ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
3000
- }
3001
-
3002
- // Compute the slope if needed. Derived from ggml_cann_softmax().
3003
- if(maxBias != 0.0f){
3004
3292
  // alibi
3005
- const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
3006
- const int64_t n_head = src0->ne[2];
3007
- const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
3008
- float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
3009
- float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
3010
- // init arange
3011
- ggml_cann_pool_alloc arange_allocator(ctx.pool(),
3012
- ne2_ne3 * faElemSize);
3013
- void* tmp_arange_buffer = arange_allocator.get();
3014
-
3015
- // arange1: [1, ..., n_heads_log2_floor+1)
3016
- float start = 1;
3017
- float stop = n_heads_log2_floor + 1;
3018
- float step = 1;
3019
- int64_t n_elements_arange = n_heads_log2_floor;
3020
-
3021
- int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
3022
- size_t tmp_arange1_nb[] = {faElemSize};
3023
- aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
3024
- tmp_arange_buffer, faDataType, faElemSize,
3025
- tmp_arange1_ne, tmp_arange1_nb,
3026
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3027
-
3028
- aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
3029
-
3030
- aclTensor* tmp_arange2_tensor = nullptr;
3031
- if (n_heads_log2_floor < ne2_ne3) {
3032
- // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
3033
- start = 1;
3034
- stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
3035
- step = 2;
3036
- n_elements_arange = ne2_ne3 - n_heads_log2_floor;
3037
- int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
3038
- size_t tmp_arange2_nb[] = {faElemSize};
3039
-
3040
- aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
3041
- (char*)tmp_arange_buffer +
3042
- n_heads_log2_floor * faElemSize,
3043
- faDataType, faElemSize,
3044
- tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3045
- aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
3046
- n_elements_arange);
3293
+ // Compute the slope if needed. Derived from ggml_cann_softmax().
3294
+ const int64_t n_heads = src0->ne[2];
3295
+ ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
3296
+ void* slope_buffer = slope_allocator.get();
3297
+ aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias, GGML_TYPE_F16);
3298
+
3299
+ int64_t slope_ne[] = {1, 1, n_heads, 1};
3300
+ size_t slope_nb[GGML_MAX_DIMS];
3301
+ slope_nb[0] = sizeof(uint16_t);
3302
+ for(int i = 1;i<GGML_MAX_DIMS;i++) {
3303
+ slope_nb[i] = slope_nb[i-1] * slope_ne[0];
3047
3304
  }
3048
3305
 
3049
- // init mk_base
3050
- ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
3051
- ne2_ne3 * faElemSize);
3052
- void* tmp_mk_base_buffer = mk_base_allocator.get();
3053
- int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
3054
- size_t tmp_mk_base1_nb[] = {faElemSize};
3055
- aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
3056
- tmp_mk_base_buffer, faDataType, faElemSize,
3057
- tmp_mk_base1_ne, tmp_mk_base1_nb,
3058
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3059
-
3060
- aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
3061
-
3062
- aclTensor* tmp_mk_base2_tensor = nullptr;
3063
- if (n_heads_log2_floor < ne2_ne3) {
3064
- int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
3065
- size_t tmp_mk_base2_nb[] = {faElemSize};
3066
- aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
3067
- (char*)tmp_mk_base_buffer +
3068
- n_heads_log2_floor * faElemSize,
3069
- faDataType, faElemSize,
3070
- tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3071
- aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
3072
- }
3306
+ aclTensor* slope_tensor = ggml_cann_create_tensor(
3307
+ slope_buffer, ACL_FLOAT16, sizeof(uint16_t),
3308
+ slope_ne, slope_nb, GGML_MAX_DIMS);
3309
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
3073
3310
 
3074
- // init mk
3075
- int64_t tmp_mk_base_ne[] = {ne2_ne3};
3076
- size_t tmp_mk_base_nb[] = {faElemSize};
3077
- aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
3078
- tmp_mk_base_buffer, faDataType, faElemSize,
3079
- tmp_mk_base_ne, tmp_mk_base_nb,
3080
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3081
- aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
3082
- tmp_arange_buffer, faDataType, faElemSize,
3083
- tmp_mk_base_ne, tmp_mk_base_nb,
3084
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
3085
- aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
3086
-
3087
- // reshape mk
3088
- int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
3089
- size_t tmp_mk_nb[GGML_MAX_DIMS];
3090
- tmp_mk_nb[0] = faElemSize;
3091
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
3092
- tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
3093
- }
3094
- aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
3095
- tmp_mk_base_buffer, faDataType, faElemSize,
3096
- tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
3097
- ACL_FORMAT_ND);
3098
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
3099
-
3100
- ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
3101
- tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
3102
- tmp_arange_tensor, tmp_mk_tensor);
3311
+ ggml_cann_release_resources(ctx, slope_tensor, acl_mask_f16_trunc_tensor);
3103
3312
  }
3104
3313
  }
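Two things happen in the mask handling above: when maxBias == 0 the mask is broadcast across the head dimension by giving it a zero stride instead of materialising a repeat, and when maxBias != 0 the per-head ALiBi slopes now come from aclnn_get_slope rather than the hand-rolled arange/pow sequence that was removed. That removed sequence follows the usual ALiBi slope schedule; a hedged CPU reference is sketched below (helper name is illustrative).

```cpp
// Illustrative ALiBi slope computation matching the removed arange/pow logic:
// the first 2^floor(log2(n_head)) heads use powers of m0, the rest use odd
// powers of m1.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> alibi_slopes(int64_t n_head, float max_bias) {
    const int64_t k  = (int64_t) 1 << (int64_t) std::floor(std::log2((double) n_head));
    const float   m0 = std::pow(2.0f, -max_bias          / (float) k);
    const float   m1 = std::pow(2.0f, -(max_bias / 2.0f) / (float) k);
    std::vector<float> slopes((size_t) n_head);
    for (int64_t h = 0; h < n_head; ++h) {
        slopes[(size_t) h] = h < k
            ? std::pow(m0, (float) (h + 1))               // exponents 1, 2, 3, ...
            : std::pow(m1, (float) (2 * (h - k) + 1));    // exponents 1, 3, 5, ...
    }
    return slopes;
}
```

Each slope then multiplies that head's copy of the mask before it reaches the attention kernel, which is what the InplaceMul on bcast_pse_tensor does above.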
3105
3314
 
@@ -3116,7 +3325,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
3116
3325
  // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
3117
3326
  int64_t preTokens = 65535;
3118
3327
  int64_t nextTokens = 65535;
3119
- char layout[5] = {'B', 'N', 'S', 'D', 0};
3328
+ char layout[5] = {'B', 'S', 'N', 'D', 0};
3120
3329
  int64_t sparseMode = 0;
3121
3330
  int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
3122
3331
  int64_t blockSize = 0;
@@ -3153,32 +3362,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
3153
3362
  );
3154
3363
 
3155
3364
  // Step 6: post-processing, permute and cast to f32
3156
-
3157
- int64_t new_dim[] = {0, 2, 1, 3};
3158
3365
  aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
3159
-
3160
- if(ggml_cann_type_mapping(dst->type) != faDataType){
3161
- ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
3162
- perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
3163
- void* perm_out_f16_buffer = perm_out_f16_allocator.get();
3164
-
3165
- int64_t* perm_out_f16_ne = dst->ne;
3166
- size_t perm_out_f16_nb[GGML_MAX_DIMS];
3167
- perm_out_f16_nb[0] = faElemSize;
3168
- for(int i = 1; i < GGML_MAX_DIMS; ++i){
3169
- perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
3170
- }
3171
- aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
3172
- perm_out_f16_buffer, faDataType, faElemSize,
3173
- perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
3174
- aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
3175
- aclnn_cast(ctx,
3176
- acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
3177
- ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
3178
- }else{
3179
- // only need to permute
3180
- aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
3181
- }
3366
+ // TODO: when dst is fp16, the cast is unnecessary
3367
+ aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
3182
3368
  ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
3183
3369
  acl_src1_f16_tensor,
3184
3370
  acl_src2_f16_tensor,