whispercpp 1.3.1 → 1.3.2

This diff shows the changes between publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the packages as they appear in their respective registries.
Files changed (797)
  1. checksums.yaml +4 -4
  2. data/.gitignore +4 -3
  3. data/README.md +92 -31
  4. data/Rakefile +26 -7
  5. data/ext/.gitignore +5 -7
  6. data/ext/dependencies.rb +61 -0
  7. data/ext/extconf.rb +21 -198
  8. data/ext/options.rb +221 -0
  9. data/ext/ruby_whisper.c +159 -0
  10. data/ext/ruby_whisper.h +17 -2
  11. data/ext/ruby_whisper_context.c +641 -0
  12. data/ext/ruby_whisper_error.c +52 -0
  13. data/ext/ruby_whisper_model.c +232 -0
  14. data/ext/ruby_whisper_params.c +1301 -0
  15. data/ext/ruby_whisper_segment.c +143 -0
  16. data/ext/ruby_whisper_transcribe.cpp +87 -0
  17. data/ext/ruby_whisper_vad_params.c +288 -0
  18. data/ext/sources/.dockerignore +3 -0
  19. data/ext/sources/.github/workflows/bindings-ruby.yml +21 -0
  20. data/ext/sources/CMakeGraphVizOptions.cmake +8 -0
  21. data/ext/sources/CMakeLists.txt +251 -0
  22. data/ext/sources/bindings/javascript/CMakeLists.txt +41 -0
  23. data/ext/sources/bindings/javascript/emscripten.cpp +93 -0
  24. data/ext/sources/bindings/javascript/libwhisper.worker.js +1 -0
  25. data/ext/sources/bindings/javascript/package-tmpl.json +26 -0
  26. data/ext/sources/bindings/javascript/package.json +26 -0
  27. data/ext/sources/bindings/javascript/whisper.js +19 -0
  28. data/ext/sources/build-xcframework.sh +547 -0
  29. data/ext/sources/ci/run.sh +336 -0
  30. data/ext/sources/close-issue.yml +28 -0
  31. data/ext/sources/cmake/DefaultTargetOptions.cmake +16 -0
  32. data/ext/sources/cmake/FindFFmpeg.cmake +163 -0
  33. data/ext/sources/cmake/build-info.cmake +60 -0
  34. data/ext/sources/cmake/git-vars.cmake +22 -0
  35. data/ext/sources/cmake/whisper-config.cmake.in +65 -0
  36. data/ext/sources/cmake/whisper.pc.in +10 -0
  37. data/ext/sources/examples/CMakeLists.txt +124 -0
  38. data/ext/sources/examples/addon.node/CMakeLists.txt +31 -0
  39. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +37 -0
  40. data/ext/sources/examples/addon.node/addon.cpp +438 -0
  41. data/ext/sources/examples/addon.node/index.js +54 -0
  42. data/ext/sources/examples/addon.node/package.json +16 -0
  43. data/ext/sources/examples/bench/CMakeLists.txt +8 -0
  44. data/ext/sources/examples/bench/bench.cpp +175 -0
  45. data/ext/sources/examples/bench.wasm/CMakeLists.txt +49 -0
  46. data/ext/sources/examples/bench.wasm/emscripten.cpp +87 -0
  47. data/ext/sources/examples/bench.wasm/index-tmpl.html +284 -0
  48. data/ext/sources/examples/cli/CMakeLists.txt +8 -0
  49. data/ext/sources/examples/cli/cli.cpp +1294 -0
  50. data/ext/sources/examples/coi-serviceworker.js +146 -0
  51. data/ext/sources/examples/command/CMakeLists.txt +10 -0
  52. data/ext/sources/examples/command/command.cpp +776 -0
  53. data/ext/sources/examples/command/commands.txt +9 -0
  54. data/ext/sources/examples/command.wasm/CMakeLists.txt +50 -0
  55. data/ext/sources/examples/command.wasm/emscripten.cpp +327 -0
  56. data/ext/sources/examples/command.wasm/index-tmpl.html +414 -0
  57. data/ext/sources/examples/common-ggml.cpp +238 -0
  58. data/ext/sources/examples/common-ggml.h +18 -0
  59. data/ext/sources/examples/common-sdl.cpp +227 -0
  60. data/ext/sources/examples/common-sdl.h +49 -0
  61. data/ext/sources/examples/common-whisper.cpp +168 -0
  62. data/ext/sources/examples/common-whisper.h +24 -0
  63. data/ext/sources/examples/common.cpp +675 -0
  64. data/ext/sources/examples/common.h +322 -0
  65. data/ext/sources/examples/deprecation-warning/CMakeLists.txt +6 -0
  66. data/ext/sources/examples/deprecation-warning/deprecation-warning.cpp +38 -0
  67. data/ext/sources/examples/ffmpeg-transcode.cpp +368 -0
  68. data/ext/sources/examples/generate-karaoke.sh +57 -0
  69. data/ext/sources/examples/grammar-parser.cpp +423 -0
  70. data/ext/sources/examples/grammar-parser.h +29 -0
  71. data/ext/sources/examples/helpers.js +191 -0
  72. data/ext/sources/examples/json.hpp +24596 -0
  73. data/ext/sources/examples/livestream.sh +112 -0
  74. data/ext/sources/examples/lsp/CMakeLists.txt +9 -0
  75. data/ext/sources/examples/lsp/lsp.cpp +467 -0
  76. data/ext/sources/examples/lsp/whisper.vim +362 -0
  77. data/ext/sources/examples/miniaudio.h +93468 -0
  78. data/ext/sources/examples/python/test_whisper_processor.py +7 -0
  79. data/ext/sources/examples/python/whisper_processor.py +54 -0
  80. data/ext/sources/examples/quantize/CMakeLists.txt +6 -0
  81. data/ext/sources/examples/quantize/quantize.cpp +223 -0
  82. data/ext/sources/examples/server/CMakeLists.txt +12 -0
  83. data/ext/sources/examples/server/bench.js +29 -0
  84. data/ext/sources/examples/server/httplib.h +10497 -0
  85. data/ext/sources/examples/server/server.cpp +1091 -0
  86. data/ext/sources/examples/server.py +115 -0
  87. data/ext/sources/examples/stb_vorbis.c +5584 -0
  88. data/ext/sources/examples/stream/CMakeLists.txt +10 -0
  89. data/ext/sources/examples/stream/stream.cpp +429 -0
  90. data/ext/sources/examples/stream.wasm/CMakeLists.txt +49 -0
  91. data/ext/sources/examples/stream.wasm/emscripten.cpp +216 -0
  92. data/ext/sources/examples/stream.wasm/index-tmpl.html +414 -0
  93. data/ext/sources/examples/sycl/CMakeLists.txt +9 -0
  94. data/ext/sources/examples/sycl/build.sh +22 -0
  95. data/ext/sources/examples/sycl/ls-sycl-device.cpp +11 -0
  96. data/ext/sources/examples/sycl/run-whisper.sh +17 -0
  97. data/ext/sources/examples/talk-llama/CMakeLists.txt +40 -0
  98. data/ext/sources/examples/talk-llama/eleven-labs.py +80 -0
  99. data/ext/sources/examples/talk-llama/llama-adapter.cpp +388 -0
  100. data/ext/sources/examples/talk-llama/llama-adapter.h +76 -0
  101. data/ext/sources/examples/talk-llama/llama-arch.cpp +1746 -0
  102. data/ext/sources/examples/talk-llama/llama-arch.h +437 -0
  103. data/ext/sources/examples/talk-llama/llama-batch.cpp +374 -0
  104. data/ext/sources/examples/talk-llama/llama-batch.h +89 -0
  105. data/ext/sources/examples/talk-llama/llama-chat.cpp +663 -0
  106. data/ext/sources/examples/talk-llama/llama-chat.h +58 -0
  107. data/ext/sources/examples/talk-llama/llama-context.cpp +2676 -0
  108. data/ext/sources/examples/talk-llama/llama-context.h +276 -0
  109. data/ext/sources/examples/talk-llama/llama-cparams.cpp +5 -0
  110. data/ext/sources/examples/talk-llama/llama-cparams.h +41 -0
  111. data/ext/sources/examples/talk-llama/llama-grammar.cpp +1229 -0
  112. data/ext/sources/examples/talk-llama/llama-grammar.h +173 -0
  113. data/ext/sources/examples/talk-llama/llama-graph.cpp +1618 -0
  114. data/ext/sources/examples/talk-llama/llama-graph.h +640 -0
  115. data/ext/sources/examples/talk-llama/llama-hparams.cpp +95 -0
  116. data/ext/sources/examples/talk-llama/llama-hparams.h +190 -0
  117. data/ext/sources/examples/talk-llama/llama-impl.cpp +167 -0
  118. data/ext/sources/examples/talk-llama/llama-impl.h +61 -0
  119. data/ext/sources/examples/talk-llama/llama-io.cpp +15 -0
  120. data/ext/sources/examples/talk-llama/llama-io.h +35 -0
  121. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2739 -0
  122. data/ext/sources/examples/talk-llama/llama-kv-cache.h +502 -0
  123. data/ext/sources/examples/talk-llama/llama-kv-cells.h +379 -0
  124. data/ext/sources/examples/talk-llama/llama-memory.cpp +1 -0
  125. data/ext/sources/examples/talk-llama/llama-memory.h +32 -0
  126. data/ext/sources/examples/talk-llama/llama-mmap.cpp +600 -0
  127. data/ext/sources/examples/talk-llama/llama-mmap.h +68 -0
  128. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +1138 -0
  129. data/ext/sources/examples/talk-llama/llama-model-loader.h +169 -0
  130. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +281 -0
  131. data/ext/sources/examples/talk-llama/llama-model-saver.h +37 -0
  132. data/ext/sources/examples/talk-llama/llama-model.cpp +13814 -0
  133. data/ext/sources/examples/talk-llama/llama-model.h +425 -0
  134. data/ext/sources/examples/talk-llama/llama-quant.cpp +966 -0
  135. data/ext/sources/examples/talk-llama/llama-quant.h +1 -0
  136. data/ext/sources/examples/talk-llama/llama-sampling.cpp +2575 -0
  137. data/ext/sources/examples/talk-llama/llama-sampling.h +32 -0
  138. data/ext/sources/examples/talk-llama/llama-vocab.cpp +3340 -0
  139. data/ext/sources/examples/talk-llama/llama-vocab.h +131 -0
  140. data/ext/sources/examples/talk-llama/llama.cpp +354 -0
  141. data/ext/sources/examples/talk-llama/llama.h +1377 -0
  142. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +23 -0
  143. data/ext/sources/examples/talk-llama/speak +40 -0
  144. data/ext/sources/examples/talk-llama/speak.bat +1 -0
  145. data/ext/sources/examples/talk-llama/speak.ps1 +14 -0
  146. data/ext/sources/examples/talk-llama/talk-llama.cpp +808 -0
  147. data/ext/sources/examples/talk-llama/unicode-data.cpp +7034 -0
  148. data/ext/sources/examples/talk-llama/unicode-data.h +20 -0
  149. data/ext/sources/examples/talk-llama/unicode.cpp +849 -0
  150. data/ext/sources/examples/talk-llama/unicode.h +66 -0
  151. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +8 -0
  152. data/ext/sources/examples/vad-speech-segments/speech.cpp +143 -0
  153. data/ext/sources/examples/wchess/CMakeLists.txt +10 -0
  154. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +19 -0
  155. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +803 -0
  156. data/ext/sources/examples/wchess/libwchess/Chessboard.h +33 -0
  157. data/ext/sources/examples/wchess/libwchess/WChess.cpp +193 -0
  158. data/ext/sources/examples/wchess/libwchess/WChess.h +63 -0
  159. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +117 -0
  160. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +8 -0
  161. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +249 -0
  162. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +50 -0
  163. data/ext/sources/examples/whisper.wasm/emscripten.cpp +118 -0
  164. data/ext/sources/examples/whisper.wasm/index-tmpl.html +658 -0
  165. data/ext/sources/ggml/CMakeLists.txt +390 -0
  166. data/ext/sources/ggml/cmake/BuildTypes.cmake +54 -0
  167. data/ext/sources/ggml/cmake/GitVars.cmake +22 -0
  168. data/ext/sources/ggml/cmake/common.cmake +26 -0
  169. data/ext/sources/ggml/cmake/ggml-config.cmake.in +152 -0
  170. data/ext/{ggml → sources/ggml}/include/ggml-alloc.h +1 -1
  171. data/ext/{ggml → sources/ggml}/include/ggml-backend.h +9 -7
  172. data/ext/{ggml → sources/ggml}/include/ggml-cpp.h +2 -1
  173. data/ext/{ggml → sources/ggml}/include/ggml-cpu.h +9 -1
  174. data/ext/{ggml → sources/ggml}/include/ggml-metal.h +1 -1
  175. data/ext/{ggml → sources/ggml}/include/ggml-opt.h +49 -28
  176. data/ext/{ggml → sources/ggml}/include/ggml-rpc.h +6 -1
  177. data/ext/{ggml → sources/ggml}/include/ggml-vulkan.h +0 -2
  178. data/ext/{ggml → sources/ggml}/include/ggml.h +182 -265
  179. data/ext/sources/ggml/include/gguf.h +202 -0
  180. data/ext/sources/ggml/src/CMakeLists.txt +346 -0
  181. data/ext/{ggml → sources/ggml}/src/ggml-alloc.c +34 -29
  182. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  183. data/ext/{ggml → sources/ggml}/src/ggml-backend-impl.h +1 -2
  184. data/ext/{ggml → sources/ggml}/src/ggml-backend-reg.cpp +87 -53
  185. data/ext/{ggml → sources/ggml}/src/ggml-backend.cpp +26 -14
  186. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  187. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +74 -0
  188. data/ext/sources/ggml/src/ggml-cann/Doxyfile +2579 -0
  189. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.cpp +10 -4
  190. data/ext/{ggml → sources/ggml}/src/ggml-cann/acl_tensor.h +5 -5
  191. data/ext/{ggml → sources/ggml}/src/ggml-cann/aclnn_ops.cpp +1272 -1506
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +1125 -0
  193. data/ext/{ggml → sources/ggml}/src/ggml-cann/common.h +135 -1
  194. data/ext/{ggml → sources/ggml}/src/ggml-cann/ggml-cann.cpp +564 -146
  195. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +30 -0
  196. data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/dup.cpp +3 -5
  197. data/ext/{ggml → sources/ggml}/src/ggml-common.h +12 -8
  198. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +504 -0
  199. data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.cpp +2 -1
  200. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  201. data/ext/sources/ggml/src/ggml-cpu/binary-ops.h +16 -0
  202. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  203. data/ext/sources/ggml/src/ggml-cpu/common.h +72 -0
  204. data/ext/{ggml → sources/ggml}/src/ggml-cpu/cpu-feats-x86.cpp +5 -1
  205. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +6431 -0
  206. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-impl.h +163 -41
  207. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.c +4029 -1117
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +3510 -0
  209. data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu.cpp +67 -18
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +337 -0
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +95 -0
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +482 -0
  213. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +3544 -0
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +8903 -0
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +110 -0
  218. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  219. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  220. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +28 -0
  221. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +252 -0
  222. data/ext/sources/ggml/src/ggml-cpu/vec.h +818 -0
  223. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +184 -0
  224. data/ext/sources/ggml/src/ggml-cuda/acc.cu +61 -0
  225. data/ext/sources/ggml/src/ggml-cuda/acc.cuh +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/arange.cu +34 -0
  227. data/ext/sources/ggml/src/ggml-cuda/arange.cuh +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +91 -0
  229. data/ext/sources/ggml/src/ggml-cuda/argmax.cuh +3 -0
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +104 -0
  231. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +3 -0
  232. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +363 -0
  233. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +9 -0
  234. data/ext/sources/ggml/src/ggml-cuda/clamp.cu +45 -0
  235. data/ext/sources/ggml/src/ggml-cuda/clamp.cuh +5 -0
  236. data/ext/sources/ggml/src/ggml-cuda/common.cuh +828 -0
  237. data/ext/sources/ggml/src/ggml-cuda/concat.cu +221 -0
  238. data/ext/sources/ggml/src/ggml-cuda/concat.cuh +5 -0
  239. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +89 -0
  240. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cuh +5 -0
  241. data/ext/sources/ggml/src/ggml-cuda/convert.cu +730 -0
  242. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +26 -0
  243. data/ext/sources/ggml/src/ggml-cuda/count-equal.cu +64 -0
  244. data/ext/sources/ggml/src/ggml-cuda/count-equal.cuh +5 -0
  245. data/ext/sources/ggml/src/ggml-cuda/cp-async.cuh +57 -0
  246. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +705 -0
  247. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +11 -0
  248. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +189 -0
  249. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cuh +7 -0
  250. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +103 -0
  251. data/ext/sources/ggml/src/ggml-cuda/diagmask.cu +40 -0
  252. data/ext/sources/ggml/src/ggml-cuda/diagmask.cuh +5 -0
  253. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +881 -0
  254. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +1471 -0
  255. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +357 -0
  256. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +3 -0
  257. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +365 -0
  258. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +3 -0
  259. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +482 -0
  260. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +472 -0
  261. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +634 -0
  262. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +3 -0
  263. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +346 -0
  264. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +3 -0
  265. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +275 -0
  266. data/ext/sources/ggml/src/ggml-cuda/getrows.cuh +15 -0
  267. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +3505 -0
  268. data/ext/sources/ggml/src/ggml-cuda/gla.cu +93 -0
  269. data/ext/sources/ggml/src/ggml-cuda/gla.cuh +3 -0
  270. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +103 -0
  271. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +5 -0
  272. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +396 -0
  273. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +324 -0
  274. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +3217 -0
  275. data/ext/sources/ggml/src/ggml-cuda/mmv.cu +336 -0
  276. data/ext/sources/ggml/src/ggml-cuda/mmv.cuh +12 -0
  277. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +595 -0
  278. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +12 -0
  279. data/ext/sources/ggml/src/ggml-cuda/norm.cu +458 -0
  280. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +11 -0
  281. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cu +78 -0
  282. data/ext/sources/ggml/src/ggml-cuda/opt-step-adamw.cuh +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +68 -0
  284. data/ext/sources/ggml/src/ggml-cuda/out-prod.cuh +3 -0
  285. data/ext/sources/ggml/src/ggml-cuda/pad.cu +49 -0
  286. data/ext/sources/ggml/src/ggml-cuda/pad.cuh +5 -0
  287. data/ext/sources/ggml/src/ggml-cuda/pool2d.cu +94 -0
  288. data/ext/sources/ggml/src/ggml-cuda/pool2d.cuh +5 -0
  289. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +190 -0
  290. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +27 -0
  291. data/ext/sources/ggml/src/ggml-cuda/rope.cu +456 -0
  292. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +7 -0
  293. data/ext/sources/ggml/src/ggml-cuda/scale.cu +31 -0
  294. data/ext/sources/ggml/src/ggml-cuda/scale.cuh +5 -0
  295. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +283 -0
  296. data/ext/sources/ggml/src/ggml-cuda/softmax.cuh +7 -0
  297. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +148 -0
  298. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +3 -0
  299. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +153 -0
  300. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cuh +3 -0
  301. data/ext/sources/ggml/src/ggml-cuda/sum.cu +45 -0
  302. data/ext/sources/ggml/src/ggml-cuda/sum.cuh +5 -0
  303. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +39 -0
  304. data/ext/sources/ggml/src/ggml-cuda/sumrows.cuh +5 -0
  305. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +5 -0
  306. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +10 -0
  307. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  308. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  309. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  310. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +5 -0
  311. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +10 -0
  312. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +10 -0
  313. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  314. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  315. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +5 -0
  316. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +10 -0
  317. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +5 -0
  334. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +5 -0
  335. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +5 -0
  336. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +5 -0
  337. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +5 -0
  338. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +5 -0
  339. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +5 -0
  341. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +5 -0
  342. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +5 -0
  343. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +5 -0
  344. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +5 -0
  345. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +5 -0
  346. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +5 -0
  347. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +5 -0
  348. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +5 -0
  349. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +5 -0
  350. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +5 -0
  351. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +5 -0
  352. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +5 -0
  353. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +5 -0
  354. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +5 -0
  355. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +5 -0
  356. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +5 -0
  357. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +5 -0
  358. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +5 -0
  359. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +5 -0
  360. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +5 -0
  361. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +5 -0
  362. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +5 -0
  363. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +5 -0
  364. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +5 -0
  365. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +5 -0
  366. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +5 -0
  367. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +5 -0
  368. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +5 -0
  369. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +5 -0
  370. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +5 -0
  371. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +5 -0
  372. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +5 -0
  373. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +5 -0
  374. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +5 -0
  375. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +5 -0
  376. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +5 -0
  377. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +5 -0
  378. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +5 -0
  379. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +5 -0
  380. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +5 -0
  381. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +5 -0
  382. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +5 -0
  383. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +5 -0
  384. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +5 -0
  385. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +5 -0
  386. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +5 -0
  387. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +5 -0
  388. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +5 -0
  389. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +5 -0
  390. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +5 -0
  391. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +5 -0
  392. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +5 -0
  393. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +5 -0
  394. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +5 -0
  395. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +5 -0
  396. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +5 -0
  397. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +5 -0
  398. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +5 -0
  399. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +5 -0
  400. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +5 -0
  401. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +5 -0
  402. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +5 -0
  403. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +5 -0
  404. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +5 -0
  405. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +5 -0
  406. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +5 -0
  407. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +5 -0
  408. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +5 -0
  409. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +5 -0
  410. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +78 -0
  411. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu +5 -0
  412. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu +5 -0
  413. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu +5 -0
  414. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu +5 -0
  415. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu +5 -0
  416. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu +5 -0
  417. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu +5 -0
  418. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu +5 -0
  419. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu +5 -0
  420. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu +5 -0
  421. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu +5 -0
  422. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu +5 -0
  423. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu +5 -0
  424. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu +5 -0
  425. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu +5 -0
  426. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu +5 -0
  427. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu +5 -0
  428. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu +5 -0
  429. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +47 -0
  430. data/ext/sources/ggml/src/ggml-cuda/tsembd.cuh +5 -0
  431. data/ext/sources/ggml/src/ggml-cuda/unary.cu +289 -0
  432. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +59 -0
  433. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +51 -0
  434. data/ext/sources/ggml/src/ggml-cuda/upscale.cuh +5 -0
  435. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +1135 -0
  436. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/cuda.h +1 -0
  437. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/hip.h +57 -0
  438. data/ext/{ggml → sources/ggml}/src/ggml-cuda/vendors/musa.h +7 -1
  439. data/ext/sources/ggml/src/ggml-cuda/wkv.cu +199 -0
  440. data/ext/sources/ggml/src/ggml-cuda/wkv.cuh +7 -0
  441. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +131 -0
  442. data/ext/{ggml → sources/ggml}/src/ggml-impl.h +64 -19
  443. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  444. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +112 -0
  445. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +58 -0
  446. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +25 -0
  447. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +52 -0
  448. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +52 -0
  449. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +52 -0
  450. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +52 -0
  451. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +30 -0
  452. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +22 -0
  453. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +17 -0
  454. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +31 -0
  455. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +31 -0
  456. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +38 -0
  457. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +39 -0
  458. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +44 -0
  459. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +52 -0
  460. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +69 -0
  461. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +51 -0
  462. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +33 -0
  463. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +35 -0
  464. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +140 -0
  465. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +106 -0
  466. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +73 -0
  467. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +52 -0
  468. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +28 -0
  469. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +84 -0
  470. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +21 -0
  471. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +53 -0
  472. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +52 -0
  473. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +52 -0
  474. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +52 -0
  475. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +52 -0
  476. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +19 -0
  477. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +23 -0
  478. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +22 -0
  479. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +72 -0
  480. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +71 -0
  481. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +120 -0
  482. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +622 -0
  483. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.m +2178 -1064
  484. data/ext/{ggml → sources/ggml}/src/ggml-metal/ggml-metal.metal +1575 -1218
  485. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +113 -0
  486. data/ext/sources/ggml/src/ggml-musa/mudnn.cu +112 -0
  487. data/ext/sources/ggml/src/ggml-musa/mudnn.cuh +12 -0
  488. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +96 -0
  489. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +5124 -0
  490. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +83 -0
  491. data/ext/sources/ggml/src/ggml-opencl/kernels/clamp.cl +20 -0
  492. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +184 -0
  493. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +118 -0
  494. data/ext/sources/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl +58 -0
  495. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +26 -0
  496. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +62 -0
  497. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +268 -0
  498. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +274 -0
  499. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +163 -0
  500. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +57 -0
  501. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +57 -0
  502. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +79 -0
  503. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +139 -0
  504. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl +118 -0
  505. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl +118 -0
  506. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl +94 -0
  507. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl +84 -0
  508. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl +118 -0
  509. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl +192 -0
  510. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl +307 -0
  511. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl +265 -0
  512. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl +272 -0
  513. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl +254 -0
  514. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +190 -0
  515. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +81 -0
  516. data/ext/sources/ggml/src/ggml-opencl/kernels/relu.cl +16 -0
  517. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +96 -0
  518. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +721 -0
  519. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +16 -0
  520. data/ext/sources/ggml/src/ggml-opencl/kernels/silu.cl +30 -0
  521. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +87 -0
  522. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +87 -0
  523. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +86 -0
  524. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +86 -0
  525. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +84 -0
  526. data/ext/{ggml → sources/ggml}/src/ggml-opt.cpp +373 -190
  527. data/ext/{ggml → sources/ggml}/src/ggml-quants.c +114 -120
  528. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  529. data/ext/{ggml → sources/ggml}/src/ggml-rpc/ggml-rpc.cpp +480 -73
  530. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +189 -0
  531. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +37 -0
  532. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +345 -0
  533. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  534. data/ext/{ggml → sources/ggml}/src/ggml-sycl/common.cpp +20 -32
  535. data/ext/sources/ggml/src/ggml-sycl/common.hpp +589 -0
  536. data/ext/{ggml → sources/ggml}/src/ggml-sycl/concat.cpp +32 -33
  537. data/ext/sources/ggml/src/ggml-sycl/concat.hpp +20 -0
  538. data/ext/{ggml → sources/ggml}/src/ggml-sycl/conv.cpp +4 -2
  539. data/ext/sources/ggml/src/ggml-sycl/conv.hpp +20 -0
  540. data/ext/{ggml → sources/ggml}/src/ggml-sycl/convert.cpp +104 -28
  541. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +34 -0
  542. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +700 -0
  543. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +11 -0
  544. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +791 -0
  545. data/ext/{ggml → sources/ggml}/src/ggml-sycl/dmmv.cpp +156 -17
  546. data/ext/sources/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  547. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +2957 -0
  548. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1511 -0
  549. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +75 -0
  550. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +99 -0
  551. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +309 -0
  552. data/ext/sources/ggml/src/ggml-sycl/getrows.hpp +20 -0
  553. data/ext/{ggml → sources/ggml}/src/ggml-sycl/ggml-sycl.cpp +1004 -1240
  554. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +106 -0
  555. data/ext/sources/ggml/src/ggml-sycl/gla.hpp +8 -0
  556. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +136 -0
  557. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +21 -0
  558. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmq.cpp +0 -1
  559. data/ext/sources/ggml/src/ggml-sycl/mmq.hpp +33 -0
  560. data/ext/{ggml → sources/ggml}/src/ggml-sycl/mmvq.cpp +261 -166
  561. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  562. data/ext/{ggml → sources/ggml}/src/ggml-sycl/norm.cpp +204 -81
  563. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +26 -0
  564. data/ext/{ggml → sources/ggml}/src/ggml-sycl/outprod.cpp +8 -17
  565. data/ext/sources/ggml/src/ggml-sycl/outprod.hpp +10 -0
  566. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +74 -0
  567. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +83 -0
  568. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +361 -0
  569. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +20 -0
  570. data/ext/{ggml → sources/ggml}/src/ggml-sycl/softmax.cpp +35 -25
  571. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +20 -0
  572. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  573. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  574. data/ext/{ggml → sources/ggml}/src/ggml-sycl/tsembd.cpp +3 -3
  575. data/ext/sources/ggml/src/ggml-sycl/tsembd.hpp +20 -0
  576. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +1215 -0
  577. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +293 -0
  578. data/ext/sources/ggml/src/ggml-sycl/wkv.hpp +10 -0
  579. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +196 -0
  580. data/ext/sources/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in +15 -0
  581. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/ggml-vulkan.cpp +3130 -1087
  582. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +39 -0
  583. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +29 -0
  584. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +29 -0
  585. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +51 -0
  586. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +69 -0
  587. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +17 -0
  588. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +41 -0
  589. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +49 -0
  590. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +105 -0
  591. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +23 -0
  592. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +51 -0
  593. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +242 -0
  594. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +17 -0
  595. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +31 -0
  596. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +20 -0
  597. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +462 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +699 -0
  599. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp +13 -0
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +42 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +35 -0
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +44 -0
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +43 -0
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +48 -0
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +39 -0
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +49 -0
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +32 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +34 -0
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +34 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +42 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +30 -0
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +32 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +68 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +34 -0
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +35 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +70 -0
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +33 -0
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +31 -0
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +34 -0
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +27 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +337 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +162 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +360 -0
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +267 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +59 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +25 -0
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +23 -0
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +64 -0
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp +9 -0
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp +76 -0
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +33 -0
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +41 -0
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +66 -0
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +100 -0
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +41 -0
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +22 -0
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +27 -0
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp +48 -0
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +169 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +118 -0
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +82 -0
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +79 -0
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +90 -0
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +87 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +87 -0
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +90 -0
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +88 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +118 -0
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +154 -0
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +130 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +132 -0
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +136 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +167 -0
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +130 -0
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +868 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +441 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +442 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +99 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +44 -0
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +42 -0
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +28 -0
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +74 -0
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +77 -0
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +21 -0
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +26 -0
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +37 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +52 -0
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +55 -0
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +58 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +60 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +43 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +43 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +47 -0
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +24 -0
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +20 -0
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +22 -0
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +26 -0
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +17 -0
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +173 -0
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +50 -0
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +17 -0
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +29 -0
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +37 -0
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +20 -0
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp +7 -0
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp +7 -0
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp +7 -0
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp +7 -0
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +41 -0
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +1373 -0
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +36 -0
  692. data/ext/{ggml → sources/ggml}/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -35
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +87 -0
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp +91 -0
  695. data/ext/{ggml → sources/ggml}/src/ggml.c +676 -1820
  696. data/ext/sources/ggml/src/gguf.cpp +1330 -0
  697. data/ext/{include → sources/include}/whisper.h +68 -2
  698. data/ext/sources/src/CMakeLists.txt +143 -0
  699. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.h +27 -15
  700. data/ext/{src → sources/src}/coreml/whisper-decoder-impl.m +35 -10
  701. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.h +21 -9
  702. data/ext/{src → sources/src}/coreml/whisper-encoder-impl.m +28 -3
  703. data/ext/sources/src/coreml/whisper-encoder.mm +73 -0
  704. data/ext/sources/src/whisper-arch.h +197 -0
  705. data/ext/{src → sources/src}/whisper.cpp +1905 -374
  706. data/ext/sources/tests/CMakeLists.txt +105 -0
  707. data/ext/sources/tests/earnings21/eval.mk +58 -0
  708. data/ext/sources/tests/earnings21/eval.py +68 -0
  709. data/ext/sources/tests/earnings21/normalizers/__init__.py +2 -0
  710. data/ext/sources/tests/earnings21/normalizers/basic.py +80 -0
  711. data/ext/sources/tests/earnings21/normalizers/english.json +1741 -0
  712. data/ext/sources/tests/earnings21/normalizers/english.py +550 -0
  713. data/ext/sources/tests/earnings21/requirements.txt +6 -0
  714. data/ext/sources/tests/en-0-ref.txt +1 -0
  715. data/ext/sources/tests/en-1-ref.txt +1 -0
  716. data/ext/sources/tests/en-2-ref.txt +1 -0
  717. data/ext/sources/tests/es-0-ref.txt +1 -0
  718. data/ext/sources/tests/librispeech/eval.mk +39 -0
  719. data/ext/sources/tests/librispeech/eval.py +47 -0
  720. data/ext/sources/tests/librispeech/normalizers/__init__.py +2 -0
  721. data/ext/sources/tests/librispeech/normalizers/basic.py +80 -0
  722. data/ext/sources/tests/librispeech/normalizers/english.json +1741 -0
  723. data/ext/sources/tests/librispeech/normalizers/english.py +550 -0
  724. data/ext/sources/tests/librispeech/requirements.txt +6 -0
  725. data/ext/sources/tests/run-tests.sh +130 -0
  726. data/ext/sources/tests/test-c.c +3 -0
  727. data/ext/sources/tests/test-vad-full.cpp +54 -0
  728. data/ext/sources/tests/test-vad.cpp +83 -0
  729. data/ext/sources/tests/test-whisper.js +58 -0
  730. data/extsources.rb +33 -5
  731. data/lib/whisper/model/uri.rb +149 -128
  732. data/sig/whisper.rbs +480 -0
  733. data/tests/helper.rb +28 -0
  734. data/tests/test_callback.rb +45 -3
  735. data/tests/test_error.rb +2 -2
  736. data/tests/test_model.rb +38 -0
  737. data/tests/test_package.rb +18 -3
  738. data/tests/test_params.rb +145 -8
  739. data/tests/test_segment.rb +10 -19
  740. data/tests/test_vad.rb +19 -0
  741. data/tests/test_vad_params.rb +103 -0
  742. data/tests/test_whisper.rb +37 -37
  743. data/whispercpp.gemspec +5 -4
  744. metadata +766 -111
  745. data/ext/cpu.mk +0 -9
  746. data/ext/examples/dr_wav.h +0 -8815
  747. data/ext/ggml/src/ggml-cann/aclnn_ops.h +0 -592
  748. data/ext/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -4262
  749. data/ext/ggml/src/ggml-cpu/ggml-cpu.c +0 -14123
  750. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.cpp +0 -1884
  751. data/ext/ggml/src/ggml-cpu/llamafile/sgemm.h +0 -14
  752. data/ext/ggml/src/ggml-metal/ggml-metal-impl.h +0 -288
  753. data/ext/ggml/src/ggml-sycl/element_wise.cpp +0 -1030
  754. data/ext/ggml/src/ggml-sycl/im2col.cpp +0 -126
  755. data/ext/ggml/src/ggml-sycl/rope.cpp +0 -276
  756. data/ext/ggml/src/ggml-sycl/wkv6.cpp +0 -141
  757. data/ext/metal-embed.mk +0 -17
  758. data/ext/metal.mk +0 -6
  759. data/ext/ruby_whisper.cpp +0 -1909
  760. data/ext/scripts/get-flags.mk +0 -38
  761. data/lib/whisper.rb +0 -2
  762. /data/ext/{ggml → sources/ggml}/include/ggml-blas.h +0 -0
  763. /data/ext/{ggml → sources/ggml}/include/ggml-cann.h +0 -0
  764. /data/ext/{ggml → sources/ggml}/include/ggml-cuda.h +0 -0
  765. /data/ext/{ggml → sources/ggml}/include/ggml-kompute.h +0 -0
  766. /data/ext/{ggml → sources/ggml}/include/ggml-opencl.h +0 -0
  767. /data/ext/{ggml → sources/ggml}/include/ggml-sycl.h +0 -0
  768. /data/ext/{ggml → sources/ggml}/src/ggml-amx/common.h +0 -0
  769. /data/ext/{ggml → sources/ggml}/src/ggml-amx/ggml-amx.cpp +0 -0
  770. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.cpp +0 -0
  771. /data/ext/{ggml → sources/ggml}/src/ggml-amx/mmq.h +0 -0
  772. /data/ext/{ggml → sources/ggml}/src/ggml-blas/ggml-blas.cpp +0 -0
  773. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/ascendc_kernels.h +0 -0
  774. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f16.cpp +0 -0
  775. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_f32.cpp +0 -0
  776. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -0
  777. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -0
  778. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -0
  779. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -0
  780. /data/ext/{ggml → sources/ggml}/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -0
  781. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/amx.h +0 -0
  782. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/common.h +0 -0
  783. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.cpp +0 -0
  784. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/amx/mmq.h +0 -0
  785. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-aarch64.h +0 -0
  786. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.cpp +0 -0
  787. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-hbm.h +0 -0
  788. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-quants.h +0 -0
  789. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.cpp +0 -0
  790. /data/ext/{ggml → sources/ggml}/src/ggml-cpu/ggml-cpu-traits.h +0 -0
  791. /data/ext/{ggml → sources/ggml}/src/ggml-kompute/ggml-kompute.cpp +0 -0
  792. /data/ext/{ggml → sources/ggml}/src/ggml-quants.h +0 -0
  793. /data/ext/{ggml → sources/ggml}/src/ggml-threading.cpp +0 -0
  794. /data/ext/{ggml → sources/ggml}/src/ggml-threading.h +0 -0
  795. /data/ext/{src → sources/src}/coreml/whisper-encoder.h +0 -0
  796. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.cpp +0 -0
  797. /data/ext/{src → sources/src}/openvino/whisper-openvino-encoder.h +0 -0
@@ -28,8 +28,8 @@
 #include <aclnnop/aclnn_cast.h>
 #include <aclnnop/aclnn_constant_pad_nd.h>
 #include <aclnnop/aclnn_copy.h>
-#include <aclnnop/aclnn_cos.h>
 #include <aclnnop/aclnn_div.h>
+#include <aclnnop/aclnn_embedding.h>
 #include <aclnnop/aclnn_exp.h>
 #include <aclnnop/aclnn_fill_scalar.h>
 #include <aclnnop/aclnn_group_norm.h>
@@ -44,12 +44,29 @@
 #include <aclnnop/aclnn_repeat.h>
 #include <aclnnop/aclnn_repeat_interleave.h>
 #include <aclnnop/aclnn_roll.h>
-#include <aclnnop/aclnn_sin.h>
 #include <aclnnop/aclnn_softmax.h>
 #include <aclnnop/aclnn_tril.h>
 #include <aclnnop/aclnn_triu.h>
 #include <aclnnop/aclnn_upsample_nearest_2d.h>
 #include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
+#include <aclnnop/aclnn_argmax.h>
+#include <aclnnop/aclnn_sum.h>
+#include <aclnnop/aclnn_rms_norm.h>
+#include <aclnnop/aclnn_im2col.h>
+#include <aclnnop/aclnn_add.h>
+#include <aclnnop/aclnn_sub.h>
+#include <aclnnop/aclnn_mul.h>
+#include <aclnnop/aclnn_div.h>
+#include <aclnnop/aclnn_convolution.h>
+#include <aclnnop/aclnn_elu.h>
+#include <aclnnop/aclnn_log.h>
+#include <aclnnop/aclnn_mean.h>
+#include <aclnnop/aclnn_reflection_pad1d.h>
+#include <aclnnop/aclnn_eq_tensor.h>
+#include <aclnnop/aclnn_gt_scalar.h>
+#include <aclnnop/aclnn_pow.h>
+#include <aclnnop/aclnn_grouped_matmul_v2.h>
+#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
 #include <float.h>
 
 #include <cmath>
@@ -58,12 +75,41 @@
 #include <vector>
 
 #include "ggml-impl.h"
-#include "kernels/ascendc_kernels.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
 #include "../ggml-common.h"
 
+
+void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
+                 aclTensor ** acl_src1, aclTensor ** acl_dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
+    // Need bcast
+    if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) {
+        BCAST_SHAPE(src0, src1)
+        *acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
+        *acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
+        *acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0));
+    } else {
+        *acl_src0 = ggml_cann_create_tensor(src0);
+        *acl_src1 = ggml_cann_create_tensor(src1);
+        *acl_dst = ggml_cann_create_tensor(dst);
+    }
+}
+
+void ggml_cann_unary_op(
+    std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
+    ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    unary_op(ctx, acl_src, acl_dst);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
+}
+
 /**
  * @brief Repeats elements of a tensor along each dimension according to the
  * specified repeat array.
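The ggml_cann_unary_op helper added above factors the create-tensors / dispatch / release sequence out of every elementwise kernel. A minimal sketch of a call site, assuming a hypothetical Exp-style operator dispatched through the GGML_CANN_CALL_ACLNN_OP macro used throughout the hunks below; the concrete call sites are outside this excerpt:

    // Hypothetical caller: run an elementwise Exp through the shared helper.
    // The lambda only names the operator; tensor creation and release are
    // owned by ggml_cann_unary_op itself.
    ggml_cann_unary_op(
        [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
            GGML_CANN_CALL_ACLNN_OP(ctx, Exp, acl_src, acl_dst);
        },
        ctx, dst);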
@@ -79,24 +125,26 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     // repeat tensor along each dim with repeat_array
     aclIntArray* repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS);
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnRepeatGetWorkspaceSize(acl_src, repeats, acl_dst,
-                                          &workspaceSize, &executor));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_src, repeats, acl_dst);
+    ggml_cann_release_resources(ctx, repeats);
+}
 
-    if (workspaceSize > 0) {
-        // Memory from allocator will "free" immediately, and this memory
-        // will be alloced to other pointers, but it won't access before
-        // this async task end because all tasks in same stream will execute
-        // in queue.
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-    ACL_CHECK(
-        aclnnRepeat(workspaceAddr, workspaceSize, executor, ctx.stream()));
-    ACL_CHECK(aclDestroyIntArray(repeats));
+/**
+ * @brief Casts the data type of a source tensor to a destination tensor.
+ *
+ * This function casts the data type of the source tensor `acl_src` to the
+ * specified data type `cast_data_type` and stores the result in the destination
+ * tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose data type will be casted.
+ * @param acl_dst The destination tensor where the casted result will be stored.
+ * @param cast_data_type The target data type to which the source tensor will be
+ * casted.
+ */
+static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                       aclTensor* acl_dst, aclDataType cast_data_type) {
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src, cast_data_type, acl_dst);
 }
 
 void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
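Nearly every hunk that follows repeats the rewrite shown here: the hand-rolled aclnn two-phase invocation (query workspace size, allocate from the context pool, execute on the stream, destroy handles) collapses into a single GGML_CANN_CALL_ACLNN_OP line plus one ggml_cann_release_resources call. The macro's definition is not part of this diff; a plausible sketch, reconstructed from the exact boilerplate it replaces, could look like this:

    // Sketch only; the shipped macro may differ. OP_NAME is token-pasted into
    // the aclnn<Op>GetWorkspaceSize / aclnn<Op> entry points, and the workspace
    // is drawn from the context's pool, which is safe because all tasks on the
    // same stream execute in queue order.
    #define GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, ...)                          \
        do {                                                                    \
            uint64_t workspaceSize = 0;                                         \
            aclOpExecutor* executor = nullptr;                                  \
            void* workspaceAddr = nullptr;                                      \
            ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(                         \
                __VA_ARGS__, &workspaceSize, &executor));                       \
            ggml_cann_pool_alloc workspace_allocator((ctx).pool());             \
            if (workspaceSize > 0) {                                            \
                workspace_allocator.alloc(workspaceSize);                       \
                workspaceAddr = workspace_allocator.get();                      \
            }                                                                   \
            ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor,    \
                                     (ctx).stream()));                          \
        } while (0)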
@@ -110,73 +158,78 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                              dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]};
 
     aclnn_repeat(ctx, acl_src, acl_dst, repeatsArray);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
-/**
- * @brief Adds two tensors element-wise and stores the result in a destination
- * tensor.
- *
- * This function performs the operation:
- * \f[
- *     dst = acl\_src0 + alpha \times acl\_src1
- * \f]
- * where alpha is a scalar value and defaults to 1.0f.
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_src0 The first source tensor.
- * @param acl_src1 The second source tensor.
- * @param acl_dst The destination tensor where the result will be stored.
- */
-static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
+void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
               aclTensor* acl_src1, aclTensor* acl_dst) {
-    aclScalar* alpha = nullptr;
    float alphaValue = 1.0f;
-    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
-
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst,
-                                       &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
+    aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+    if (acl_dst != nullptr)
+        GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst);
+    else
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_src0, acl_src1, alpha);
+    ggml_cann_release_resources(ctx, alpha);
+}
 
-    ACL_CHECK(aclDestroyScalar(alpha));
+void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
+               aclTensor* acl_src1, aclTensor* acl_dst) {
+    float alphaValue = 1.0f;
+    aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+    if (acl_dst != nullptr)
+        GGML_CANN_CALL_ACLNN_OP(ctx, Sub, acl_src0, acl_src1, alpha, acl_dst);
+    else
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSub, acl_src0, acl_src1, alpha);
+    ggml_cann_release_resources(ctx, alpha);
 }
 
-void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
-    ggml_tensor* src0 = dst->src[0];
-    ggml_tensor* src1 = dst->src[1];
-    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+               aclTensor* acl_other, aclTensor* acl_dst) {
+    if (acl_dst != nullptr)
+        GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_src, acl_other, acl_dst);
+    else
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_src, acl_other);
+}
 
-    aclTensor* acl_src0;
-    aclTensor* acl_src1;
-    aclTensor* acl_dst;
+void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+               aclTensor* acl_other, aclTensor* acl_dst) {
+    if (acl_dst != nullptr)
+        GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src, acl_other, acl_dst);
+    else
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDiv, acl_src, acl_other);
+}
 
-    // Need bcast
-    if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) {
-        BCAST_SHAPE(src0, src1)
-        acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
-        acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
-        acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0));
+/**
+ * @brief Multiplies elements of a tensor by a scalar value, optionally
+ * in-place.
+ *
+ * This function multiplies each element of the source tensor `acl_src` by the
+ * scalar `scale` and stores the result in the destination tensor `acl_dst`. If
+ * `inplace` is true, `acl_dst` will not be used and the operation is performed
+ * in-place on `acl_src`.
+ * The operation is defined as:
+ * \f[
+ *     \text {acl_dst }_i=\text {acl_src }_i \times \text {scale}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose elements will be multiplied.
+ * @param scale The scalar value by which each element of `acl_src` will be
+ * multiplied.
+ * @param acl_dst The destination tensor where the result will be stored if
+ * `inplace` is false.
+ * @param inplace Flag indicating whether to perform the operation in-place on
+ * `acl_src`.
+ */
+static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                       float scale, aclTensor* acl_dst, bool inplace) {
+    aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
+    if (inplace) {
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_src, acl_scale);
     } else {
-        acl_src0 = ggml_cann_create_tensor(src0);
-        acl_src1 = ggml_cann_create_tensor(src1);
-        acl_dst = ggml_cann_create_tensor(dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, acl_scale, acl_dst);
     }
-
-    aclnn_add(ctx, acl_src0, acl_src1, acl_dst);
-
-    ACL_CHECK(aclDestroyTensor(acl_src0));
-    ACL_CHECK(aclDestroyTensor(acl_src1));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_scale);
 }
 
 void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
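Note the convention the new aclnn_add / aclnn_sub / aclnn_mul / aclnn_div helpers share: a null destination selects the Inplace* variant of the operator, which writes back into the first operand. Two call shapes, assuming tensors created as in the surrounding code:

    // Out-of-place: acl_dst = acl_src0 + acl_src1 (alpha is fixed at 1.0f
    // inside the helper).
    aclnn_add(ctx, acl_src0, acl_src1, acl_dst);

    // In-place: acl_src0 += acl_src1; the null destination routes the call
    // to aclnnInplaceAdd.
    aclnn_add(ctx, acl_src0, acl_src1, nullptr);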
@@ -193,23 +246,8 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclScalar* acl_negative_slope =
         aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT);
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnLeakyReluGetWorkspaceSize(
-        acl_src, acl_negative_slope, acl_dst, &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnLeakyRelu(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyScalar(acl_negative_slope));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, LeakyRelu, acl_src, acl_negative_slope, acl_dst);
+    ggml_cann_release_resources(ctx, acl_negative_slope, acl_src, acl_dst);
 }
 
 /**
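The other half of the pattern: scattered ACL_CHECK(aclDestroy*(...)) calls give way to one variadic ggml_cann_release_resources(ctx, ...) that accepts any mix of handle types. Its definition is also outside this diff; a minimal sketch, assuming one overload per ACL handle type:

    // Sketch only; the real helper may defer destruction until the stream's
    // pending async work completes. Each overload maps a handle type to its
    // ACL destroy call, and a C++17 fold applies them in argument order.
    inline void release_one(aclTensor* t)     { ACL_CHECK(aclDestroyTensor(t)); }
    inline void release_one(aclScalar* s)     { ACL_CHECK(aclDestroyScalar(s)); }
    inline void release_one(aclIntArray* a)   { ACL_CHECK(aclDestroyIntArray(a)); }
    inline void release_one(aclTensorList* l) { ACL_CHECK(aclDestroyTensorList(l)); }

    template <typename... Args>
    void ggml_cann_release_resources(ggml_backend_cann_context& /*ctx*/, Args... args) {
        (release_one(args), ...);
    }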
@@ -225,18 +263,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 static void aclnn_concat(ggml_backend_cann_context& ctx,
                          aclTensorList* tensorList, aclTensor* acl_dst,
                          int64_t concat_dim) {
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnCatGetWorkspaceSize(tensorList, concat_dim, acl_dst,
-                                       &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(aclnnCat(workspaceAddr, workspaceSize, executor, ctx.stream()));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cat, tensorList, concat_dim, acl_dst);
 }
 
@@ -252,11 +279,10 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     int32_t acl_dim = 3 - dim;
 
     aclTensor* tensors[] = {acl_src0, acl_src1};
-    aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
-    aclnn_concat(ctx, tensorList, acl_dst, acl_dim);
+    aclTensorList* tensor_list = aclCreateTensorList(tensors, 2);
+    aclnn_concat(ctx, tensor_list, acl_dst, acl_dim);
 
-    ACL_CHECK(aclDestroyTensorList(tensorList));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, tensor_list, acl_dst);
 }
 
 /**
@@ -282,27 +308,12 @@ static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
     int64_t steps = (int64_t)std::ceil((stop - start) / step);
     GGML_ASSERT(n_elements == steps);
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
     aclScalar* acl_start = aclCreateScalar(&start, aclDataType::ACL_FLOAT);
     aclScalar* acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT);
     aclScalar* acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT);
 
-    ACL_CHECK(aclnnArangeGetWorkspaceSize(acl_start, acl_end, acl_step, acl_dst,
-                                          &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnArange(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyScalar(acl_start));
-    ACL_CHECK(aclDestroyScalar(acl_end));
-    ACL_CHECK(aclDestroyScalar(acl_step));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Arange, acl_start, acl_end, acl_step, acl_dst);
+    ggml_cann_release_resources(ctx, acl_start, acl_end, acl_step);
 }
 
 void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -319,18 +330,11 @@ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     memcpy(&step, (float*)dst->op_params + 2, sizeof(float));
 
     aclnn_arange(ctx, acl_dst, start, stop, step, n_elements);
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-}
-
-void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
-    dst->src[1] = dst->src[0];
-    ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
+    ggml_cann_release_resources(ctx, acl_dst);
 }
 
 void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
-    GGML_ASSERT(src->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     float min;
     float max;
@@ -343,23 +347,8 @@ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclScalar* acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT);
     aclScalar* acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT);
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnClampGetWorkspaceSize(acl_src, acl_min, acl_max, acl_dst,
-                                         &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(aclnnClamp(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyScalar(acl_min));
-    ACL_CHECK(aclDestroyScalar(acl_max));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_src, acl_min, acl_max, acl_dst);
+    ggml_cann_release_resources(ctx, acl_min, acl_max, acl_src, acl_dst);
 }
 
 void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -373,22 +362,8 @@ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_src = ggml_cann_create_tensor(src);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, scale, acl_dst, &workspaceSize,
-                                        &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyScalar(scale));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, scale, acl_dst);
+    ggml_cann_release_resources(ctx, scale, acl_src, acl_dst);
 }
 
 void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -403,36 +378,10 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* tmp_tensor =
         ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type),
                                 dst->ne, dst->nb, GGML_MAX_DIMS);
-
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnArgsortGetWorkspaceSize(
-        acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false), tmp_tensor,
-        &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnArgsort(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    workspaceSize = 0;
-    ACL_CHECK(aclnnCastGetWorkspaceSize(tmp_tensor,
-                                        ggml_cann_type_mapping(dst->type),
-                                        acl_dst, &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(tmp_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false),
+                            tmp_tensor);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cast, tmp_tensor, ggml_cann_type_mapping(dst->type), acl_dst);
+    ggml_cann_release_resources(ctx, acl_src, tmp_tensor, acl_dst);
 }
 
 void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -444,27 +393,11 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
     std::vector<int64_t> normData = {dst->ne[0]};
     aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size());
-    ACL_CHECK(aclnnLayerNormGetWorkspaceSize(acl_src, norm, nullptr, nullptr,
-                                             eps, acl_dst, nullptr, nullptr,
-                                             &workspaceSize, &executor));
-
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnLayerNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyIntArray(norm));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, LayerNorm, acl_src, norm, nullptr, nullptr,
+                            eps, acl_dst, nullptr, nullptr);
+    ggml_cann_release_resources(ctx, norm, acl_src, acl_dst);
 }
 
 void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -478,10 +411,6 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     float eps;
     memcpy(&eps, dst->op_params + 1, sizeof(float));
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
     int64_t N = src->ne[3];
     int64_t C = src->ne[2];
     int64_t HxW = src->ne[1] * src->ne[0];
@@ -498,22 +427,9 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_rstd_out = ggml_cann_create_tensor(
         (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);
 
-    ACL_CHECK(aclnnGroupNormGetWorkspaceSize(
-        acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst,
-        acl_mean_out, acl_rstd_out, &workspaceSize, &executor));
-
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnGroupNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(acl_mean_out));
-    ACL_CHECK(aclDestroyTensor(acl_rstd_out));
+    GGML_CANN_CALL_ACLNN_OP(ctx, GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps,
+                            acl_dst, acl_mean_out, acl_rstd_out);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_mean_out, acl_rstd_out);
 }
 
 void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -536,68 +452,52 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     float alphaValue = 1.0f;
     alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
     if (!inplace) {
         size_t cpy_size = ggml_nbytes(dst);
-        ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+        ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
+                               ACL_MEMCPY_DEVICE_TO_DEVICE);
         aclTensor* acl_src0 = ggml_cann_create_tensor(
             src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
-        ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst,
-                                           &workspaceSize, &executor));
-        if (workspaceSize > 0) {
-            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-            workspaceAddr = workspace_allocator.get();
-        }
-        ACL_CHECK(
-            aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
-        ACL_CHECK(aclDestroyTensor(acl_src0));
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst);
+        ggml_cann_release_resources(ctx, acl_src0);
     } else {
-        ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src1, alpha,
-                                                  &workspaceSize, &executor));
-        if (workspaceSize > 0) {
-            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-            workspaceAddr = workspace_allocator.get();
-        }
-        ACL_CHECK(aclnnInplaceAdd(workspaceAddr, workspaceSize, executor,
-                                  ctx.stream()));
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, acl_src1, alpha);
     }
-
-    ACL_CHECK(aclDestroyTensor(acl_src1));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src1, acl_dst);
 }
 
-void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+/**
+ * @brief Performs sum reduction on a given tensor along specified dimensions.
+ *
+ * This function reduces the input tensor by summing along the specified dimensions.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The destination tensor where the reduced result will be stored.
+ * @param dim An array of dimension indices.
+ * @param dim_size The number of dimensions.
+ */
+static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+                             int64_t* dim, size_t dim_size) {
+    GGML_ASSERT(dst->ne[0] == 1);
     ggml_tensor* src = dst->src[0];
-
     aclTensor* acl_src = ggml_cann_create_tensor(src);
-
-    GGML_ASSERT(dst->ne[0] == 1);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    aclIntArray* reduce_dims = aclCreateIntArray(dim, dim_size);
 
-    int64_t reduce_dims_host[] = {3};
-    aclIntArray* reduce_dims = aclCreateIntArray(reduce_dims_host, 1);
-
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnReduceSumGetWorkspaceSize(
-        acl_src, reduce_dims, true, ggml_cann_type_mapping(src->type), acl_dst,
-        &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
+    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_src, reduce_dims, true,
+                            ggml_cann_type_mapping(dst->type), acl_dst);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, reduce_dims);
+}
 
-    ACL_CHECK(
-        aclnnReduceSum(workspaceAddr, workspaceSize, executor, ctx.stream()));
+void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    int64_t reduce_dims[] = {3};
+    aclnn_reduce_sum(ctx, dst, reduce_dims, 1);
+}
 
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    int64_t reduce_dims[] = {0, 1, 2, 3};
+    aclnn_reduce_sum(ctx, dst, reduce_dims, 4);
 }
 
 void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
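Both sum entry points now funnel into aclnn_reduce_sum with keepdim=true. Note that the dimension indices are in ACL order, which runs opposite to ggml's ne[] indexing (the concat hunk above performs the same acl_dim = 3 - dim flip). A worked shape example, assuming an F32 source:

    // ggml_cann_sum_rows on src ne = {8, 4, 2, 1}: ACL dim 3 corresponds to
    // ggml ne[0], so ReduceSum(keepdim=true) over {3} yields dst ne =
    // {1, 4, 2, 1}: one sum per row, satisfying the dst->ne[0] == 1 assert.
    int64_t row_dims[] = {3};
    aclnn_reduce_sum(ctx, dst, row_dims, 1);

    // ggml_cann_sum reduces over all four ACL dims, so every extent collapses
    // to 1 and dst holds the grand total of the tensor.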
@@ -611,23 +511,8 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
     std::vector<int64_t> output_size{dst->ne[1], dst->ne[0]};
     auto output_size_array = aclCreateIntArray(output_size.data(), 2);
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnUpsampleNearest2dGetWorkspaceSize(
-        acl_src, output_size_array, acl_dst, &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(aclnnUpsampleNearest2d(workspaceAddr, workspaceSize, executor,
-                                     ctx.stream()));
-
-    ACL_CHECK(aclDestroyIntArray(output_size_array));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, UpsampleNearest2d, acl_src, output_size_array, acl_dst);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, output_size_array);
 }
 
 /**
@@ -650,23 +535,8 @@ static void aclnn_pad(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2);
     aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnConstantPadNdGetWorkspaceSize(
-        acl_src, acl_pad, acl_value, acl_dst, &workspaceSize, &executor));
-
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(aclnnConstantPadNd(workspaceAddr, workspaceSize, executor,
-                                 ctx.stream()));
-
-    ACL_CHECK(aclDestroyIntArray(acl_pad));
-    ACL_CHECK(aclDestroyScalar(acl_value));
+    GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_src, acl_pad, acl_value, acl_dst);
+    ggml_cann_release_resources(ctx, acl_pad, acl_value);
 }
 
 void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -682,9 +552,7 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
         0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
     aclnn_pad(ctx, acl_src, acl_dst, paddings);
-
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(acl_src));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 /**
@@ -730,28 +598,15 @@ static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx,
     bool count_include_pad = true;
     int64_t divisor_override = 0;
     int8_t cube_math_type = 0;
+#ifdef ASCEND_310P
+    cube_math_type = 1;
+#endif
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnAvgPool2dGetWorkspaceSize(
-        acl_src, kernel_size, strides, paddings_avg, ceil_mode,
-        count_include_pad, divisor_override, cube_math_type, acl_dst,
-        &workspaceSize, &executor));
-
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-    ACL_CHECK(
-        aclnnAvgPool2d(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyIntArray(kernel_size));
-    ACL_CHECK(aclDestroyIntArray(strides));
-    ACL_CHECK(aclDestroyIntArray(paddings_avg));
+    GGML_CANN_CALL_ACLNN_OP(ctx, AvgPool2d, acl_src, kernel_size, strides, paddings_avg,
+                            ceil_mode, count_include_pad, divisor_override,
+                            cube_math_type, acl_dst);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, kernel_size, strides,
+                                paddings_avg);
 }
 
 /**
@@ -819,29 +674,10 @@ static void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx,
 
     bool ceil_mode = false;
     int64_t auto_pads = 0;
-
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnMaxPoolGetWorkspaceSize(
-        tmp_tensor, kernel_size, strides, auto_pads, paddings_max, dilations,
-        ceil_mode, acl_dst, &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnMaxPool(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(tmp_tensor));
-    ACL_CHECK(aclDestroyIntArray(kernel_size));
-    ACL_CHECK(aclDestroyIntArray(strides));
-    ACL_CHECK(aclDestroyIntArray(paddings_max));
-    ACL_CHECK(aclDestroyIntArray(dilations));
+    GGML_CANN_CALL_ACLNN_OP(ctx, MaxPool, tmp_tensor, kernel_size, strides, auto_pads,
+                            paddings_max, dilations, ceil_mode, acl_dst);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, tmp_tensor, kernel_size,
+                                strides, paddings_max, dilations);
 }
 
 void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -872,207 +708,77 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  */
 static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                       aclTensor* acl_dst) {
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnInplaceCopyGetWorkspaceSize(acl_dst, acl_src, &workspaceSize,
-                                               &executor));
-
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnInplaceCopy(workspaceAddr, workspaceSize, executor, ctx.stream()));
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst, acl_src);
 }
 
 void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
-    ggml_tensor* src = dst->src[0];
+    ggml_tensor* src0 = dst->src[0];
 
-    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_src = ggml_cann_create_tensor(src0);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
-
-    ggml_cann_pool_alloc src_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
-    ggml_cann_pool_alloc dst_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
-    src->extra = src_extra_allocator.get();
-    dst->extra = dst_extra_allocator.get();
-    ACL_CHECK(aclrtMemcpyAsync(src->extra, sizeof(ggml_tensor), src,
-                               sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
-                               ctx.stream()));
-    ACL_CHECK(aclrtMemcpyAsync(dst->extra, sizeof(ggml_tensor), dst,
-                               sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
-                               ctx.stream()));
-
-    if ((dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32) &&
-        ggml_are_same_shape(src, dst)) {
-        cann_copy(ctx, acl_src, acl_dst);
-        ACL_CHECK(aclDestroyTensor(acl_src));
-        ACL_CHECK(aclDestroyTensor(acl_dst));
-        return;
-    }
-    // TODO: simplify
-    if (src->type == GGML_TYPE_F16) {
-        if (dst->type == GGML_TYPE_Q8_0) {
-            aclrtlaunch_ascendc_quantize_f16_q8_0(
-                24, ctx.stream(), src->data, dst->data,
-                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
-                ((ggml_tensor*)dst->extra)->ne);
-            return;
-        }
-        if (dst->type == GGML_TYPE_Q4_0) {
-            aclrtlaunch_ascendc_quantize_f16_to_q4_0(
-                24, ctx.stream(), src->data, dst->data,
-                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
-                ((ggml_tensor*)dst->extra)->ne);
-            return;
-        }
-        if (dst->type == GGML_TYPE_F16) {
-            if (ggml_are_same_shape(src, dst)) {
-                cann_copy(ctx, acl_src, acl_dst);
-                ACL_CHECK(aclDestroyTensor(acl_src));
-                ACL_CHECK(aclDestroyTensor(acl_dst));
-                return;
-            }
-            if (ggml_is_contiguous(dst)) {
-                const size_t src_type_size = ggml_type_size(src->type);
-                if (src->nb[0] == src_type_size) {
-                    // src0 is contigous on first dimension, copy by rows
-                    int64_t rows_num = ggml_nrows(src);
-
-                    aclrtlaunch_ascendc_dup_by_rows_fp16(
-                        rows_num, ctx.stream(), src->data, dst->data,
-                        ((ggml_tensor*)src->extra)->ne,
-                        ((ggml_tensor*)src->extra)->nb,
-                        ((ggml_tensor*)dst->extra)->ne,
-                        ((ggml_tensor*)dst->extra)->nb);
-                    return;
-                }
-                GGML_ABORT("fatal error");
-            }
-            GGML_ABORT("fatal error");
-        }
-        if (dst->type == GGML_TYPE_F32) {
-            if (ggml_are_same_shape(src, dst)) {
-                cann_copy(ctx, acl_src, acl_dst);
-                ACL_CHECK(aclDestroyTensor(acl_src));
-                ACL_CHECK(aclDestroyTensor(acl_dst));
-                return;
-            }
-            if (ggml_is_contiguous(dst)) {
-                const size_t src_type_size = ggml_type_size(src->type);
-                if (src->nb[0] == src_type_size) {
-                    // src0 is contigous on first dimension, copy by rows
-                    int64_t rows_num = ggml_nrows(src);
-                    aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32(
-                        rows_num, ctx.stream(), src->data, dst->data,
-                        ((ggml_tensor*)src->extra)->ne,
-                        ((ggml_tensor*)src->extra)->nb,
-                        ((ggml_tensor*)dst->extra)->ne,
-                        ((ggml_tensor*)dst->extra)->nb);
-                    return;
-                }
-                GGML_ABORT("fatal error");
-            }
-            GGML_ABORT("fatal error");
-        }
-        // TODO
-        GGML_ABORT("fatal error");
-    } else if (src->type == GGML_TYPE_F32) {
-        // TODO: if (src0->type == dst->type && ne00 == ne0 && nb00 == type_size
-        //       && nb0 == type_size)
-        if (dst->type == GGML_TYPE_Q8_0) {
-            aclrtlaunch_ascendc_quantize_f32_q8_0(
-                24, ctx.stream(), src->data, dst->data,
-                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
-                ((ggml_tensor*)dst->extra)->ne);
-            return;
-        }
-        if (dst->type == GGML_TYPE_Q4_0) {
-            aclrtlaunch_ascendc_quantize_f32_to_q4_0(
-                24, ctx.stream(), src->data, dst->data,
-                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
-                ((ggml_tensor*)dst->extra)->ne);
-            return;
+    if (ggml_are_same_shape(src0, dst)) {
+        if (dst->type == src0->type) {
+            cann_copy(ctx, acl_src, acl_dst);
+        } else {
+            aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
         }
-        if (dst->type == GGML_TYPE_F32) {
-            if (ggml_are_same_shape(src, dst)) {
-                cann_copy(ctx, acl_src, acl_dst);
-                ACL_CHECK(aclDestroyTensor(acl_src));
-                ACL_CHECK(aclDestroyTensor(acl_dst));
+    } else {
+        if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
+            if (dst->type == src0->type) {
+                size_t cpy_size = ggml_nbytes(dst);
+                ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
+                                       ACL_MEMCPY_DEVICE_TO_DEVICE);
                 return;
-            }
-            if (ggml_is_contiguous(dst)) {
-                const size_t src_type_size = ggml_type_size(src->type);
-                if (src->nb[0] == src_type_size) {
-                    // src0 is contigous on first dimension, copy by rows
-                    int64_t rows_num = ggml_nrows(src);
-                    aclrtlaunch_ascendc_dup_by_rows_fp32(
-                        rows_num, ctx.stream(), src->data, dst->data,
-                        ((ggml_tensor*)src->extra)->ne,
-                        ((ggml_tensor*)src->extra)->nb,
-                        ((ggml_tensor*)dst->extra)->ne,
-                        ((ggml_tensor*)dst->extra)->nb);
-                    return;
-                }
-                GGML_ABORT("fatal error");
             } else {
-                // TODO: dst not contiguous
-                GGML_ABORT("fatal error");
-            }
-        }
-        if (dst->type == GGML_TYPE_F16) {
-            if (ggml_are_same_shape(src, dst)) {
-                cann_copy(ctx, acl_src, acl_dst);
-                ACL_CHECK(aclDestroyTensor(acl_src));
-                ACL_CHECK(aclDestroyTensor(acl_dst));
+                ggml_cann_pool_alloc src_buffer_allocator(
+                    ctx.pool(),
+                    ggml_nelements(dst) * ggml_type_size(dst->type));
+                void* src_trans_buffer = src_buffer_allocator.get();
+                size_t src_trans_nb[GGML_MAX_DIMS];
+                src_trans_nb[0] = ggml_type_size(dst->type);
+                for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+                }
+                aclTensor* src_trans_tensor = ggml_cann_create_tensor(
+                    src_trans_buffer, ggml_cann_type_mapping(dst->type),
+                    ggml_type_size(dst->type), src0->ne, src_trans_nb,
+                    GGML_MAX_DIMS);
+
+                aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+                size_t cpy_size = ggml_nbytes(dst);
+                ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
+                                       ACL_MEMCPY_DEVICE_TO_DEVICE);
+                ggml_cann_release_resources(ctx, src_trans_tensor);
                 return;
             }
-            if (ggml_is_contiguous(dst)) {
-                const size_t src_type_size = ggml_type_size(src->type);
-                if (src->nb[0] == src_type_size) {
-                    // src0 is contigous on first dimension, copy by rows
-                    int64_t rows_num = ggml_nrows(src);
-                    aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16(
-                        rows_num, ctx.stream(), src->data, dst->data,
-                        ((ggml_tensor*)src->extra)->ne,
-                        ((ggml_tensor*)src->extra)->nb,
-                        ((ggml_tensor*)dst->extra)->ne,
-                        ((ggml_tensor*)dst->extra)->nb);
-                    return;
-                }
-                GGML_ABORT("fatal error");
+        } else if (ggml_is_contiguous(dst)) {
+            ggml_cann_pool_alloc src_buffer_allocator(
+                ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type));
+            void* src_trans_buffer = src_buffer_allocator.get();
+            size_t src_trans_nb[GGML_MAX_DIMS];
+            src_trans_nb[0] = ggml_type_size(dst->type);
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
             }
-        }
-        // TODO
-        GGML_ABORT("fatal error");
-    } else {
-        if (ggml_are_same_shape(src, dst)) {
-            cann_copy(ctx, acl_src, acl_dst);
-            ACL_CHECK(aclDestroyTensor(acl_src));
-            ACL_CHECK(aclDestroyTensor(acl_dst));
+            aclTensor* src_trans_tensor = ggml_cann_create_tensor(
+                src_trans_buffer, ggml_cann_type_mapping(dst->type),
+                ggml_type_size(dst->type), src0->ne, src_trans_nb,
+                GGML_MAX_DIMS);
+
+            aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+
+            size_t cpy_size = ggml_nbytes(dst);
+            ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
+                                   ACL_MEMCPY_DEVICE_TO_DEVICE);
+            ggml_cann_release_resources(ctx, src_trans_tensor);
             return;
+        } else {
+            GGML_ABORT("Unsupported: dst is not contiguous.");
         }
-        GGML_ABORT("fatal error");
     }
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
-#ifdef __cplusplus
-extern "C" {
-#endif
-aclnnStatus aclnnRmsNormGetWorkspaceSize(const aclTensor* x,
-                                         const aclTensor* gamma, double epsilon,
-                                         const aclTensor* yOut,
-                                         const aclTensor* rstdOout,
-                                         uint64_t* workspaceSize,
-                                         aclOpExecutor** executor);
-aclnnStatus aclnnRmsNorm(void* workspace, uint64_t workspaceSize,
-                         aclOpExecutor* executor, aclrtStream stream);
-#ifdef __cplusplus
-}
-#endif
-
 /**
  * @brief Creates an ACL tensor initialized with zeros using a provided buffer.
  *
@@ -1098,7 +804,7 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
         nb[i] = nb[i - 1] * ne[i - 1];
     }
 
-    ACL_CHECK(aclrtMemsetAsync(buffer, n_bytes, 0, n_bytes, ctx.stream()));
+    ggml_cann_async_memset(ctx, buffer, n_bytes, 0);
     aclTensor* zero =
         ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);
     return zero;
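The raw aclrtMemcpyAsync / aclrtMemsetAsync calls rewritten in the acc, dup, and zero-fill hunks all move behind ggml_cann_async_memcpy and ggml_cann_async_memset. Their definitions sit outside this diff; the call sites imply thin wrappers roughly like the following sketch:

    // Sketch inferred from the rewritten call sites; the shipped versions may
    // add richer error reporting. Both enqueue onto the context's stream, so
    // ordering against the aclnn ops above is preserved.
    static void ggml_cann_async_memcpy(ggml_backend_cann_context& ctx, void* dst,
                                       const void* src, size_t len,
                                       aclrtMemcpyKind kind) {
        ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
    }

    static void ggml_cann_async_memset(ggml_backend_cann_context& ctx, void* buf,
                                       size_t len, int32_t value) {
        ACL_CHECK(aclrtMemsetAsync(buf, len, value, len, ctx.stream()));
    }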
@@ -1131,21 +837,7 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer,
     float alpha_host = 1.0f;
     aclScalar* alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT);
     aclScalar* other = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
-
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnInplaceAddsGetWorkspaceSize(acl_tensor, other, alpha,
-                                               &workspaceSize, &executor));
-
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-    ACL_CHECK(
-        aclnnInplaceAdds(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_tensor, other, alpha);
     return acl_tensor;
 }
 
@@ -1157,13 +849,6 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
-
-    GGML_ASSERT(eps > 0.0f);
-
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
     size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
     ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
 
@@ -1178,22 +863,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
                src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
                ggml_element_size(src));
-
-    ACL_CHECK(aclnnRmsNormGetWorkspaceSize(
-        acl_src, acl_gamma, eps, acl_dst, acl_rstd, &workspaceSize, &executor));
-
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnRmsNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(acl_gamma));
-    ACL_CHECK(aclDestroyTensor(acl_rstd));
+    GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
 }
 
 // TODO: performance is low.
@@ -1215,75 +886,14 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
                ggml_element_size(src), value);
 
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnInplaceTriuGetWorkspaceSize(mask_tensor, n_past + 1,
-                                               &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnInplaceTriu(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclnnTrilGetWorkspaceSize(acl_src, n_past + 1, acl_dst,
-                                        &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(aclnnTril(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
     aclScalar* alpha = nullptr;
     float alphaValue = 1.0f;
     alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
 
-    ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, mask_tensor, alpha,
-                                              &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-    ACL_CHECK(
-        aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyScalar(alpha));
-    ACL_CHECK(aclDestroyTensor(mask_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-}
-
-/**
- * @brief Casts the data type of a source tensor to a destination tensor.
- *
- * This function casts the data type of the source tensor `acl_src` to the
- * specified data type `cast_data_type` and stores the result in the destination
- * tensor `acl_dst`.
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_src The source tensor whose data type will be casted.
- * @param acl_dst The destination tensor where the casted result will be stored.
- * @param cast_data_type The target data type to which the source tensor will be
- * casted.
- */
-static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
-                       aclTensor* acl_dst, aclDataType cast_data_type) {
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnCastGetWorkspaceSize(acl_src, cast_data_type, acl_dst,
-                                        &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream()));
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceTriu, mask_tensor, n_past + 1);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src, n_past + 1, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, mask_tensor, alpha);
+    ggml_cann_release_resources(ctx, alpha, acl_src, acl_dst, mask_tensor);
 }
 
 /**
@@ -1304,40 +914,10 @@ static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
 static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                           aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) {
     aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims);
-
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnPermuteGetWorkspaceSize(acl_src, acl_dims, acl_dst,
-                                           &workspaceSize, &executor));
-    if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnPermute(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
-    ACL_CHECK(aclDestroyIntArray(acl_dims));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Permute, acl_src, acl_dims, acl_dst);
+    ggml_cann_release_resources(ctx, acl_dims);
 }
 
-#ifdef __cplusplus
-extern "C" {
-#endif
-aclnnStatus aclnnIm2colGetWorkspaceSize(const aclTensor* self,
-                                        const aclIntArray* kernelSize,
-                                        const aclIntArray* dilation,
-                                        const aclIntArray* padding,
-                                        const aclIntArray* stride,
-                                        aclTensor* out, uint64_t* workspaceSize,
-                                        aclOpExecutor** executor);
-aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize,
-                        aclOpExecutor* executor, aclrtStream stream);
-#ifdef __cplusplus
-}
-#endif
-
 static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
                                              ggml_tensor* dst,
                                              ggml_tensor* src1,
@@ -1356,8 +936,7 @@ static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
         aclnn_permute(ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3);
     }
 
-    // release
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_dst);
 }
 
 static void ggml_cann_im2col_1d_post_process(
@@ -1379,7 +958,6 @@ static void ggml_cann_im2col_1d_post_process(
 
     // Permute: [N, IC * KH * KW, OW * OH] ->
     // [N, OW * OH * n_bytes_factor, IC * KH * KW]
-    aclTensor* tmp_permute_tensor = nullptr;
     ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool());
     tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);
     void* tmp_permute_buffer = tmp_permute_allocator.get();
@@ -1391,7 +969,7 @@ static void ggml_cann_im2col_1d_post_process(
         tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
     }
 
-    tmp_permute_tensor = ggml_cann_create_tensor(
+    aclTensor* tmp_permute_tensor = ggml_cann_create_tensor(
         tmp_permute_buffer, ggml_cann_type_mapping(dst->type),
         ggml_type_size(dst->type), tmp_permute_ne, tmp_permute_nb,
         GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
@@ -1421,9 +999,8 @@ static void ggml_cann_im2col_1d_post_process(
                 c * KH * KW * n_step_w * ggml_type_size(dst->type);
 
             for (int i = 0; i < n_step_w; i++) {
-                ACL_CHECK(aclrtMemcpyAsync(
-                    cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
-                    ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+                ggml_cann_async_memcpy(ctx, cur_dst_buffer, cur_permute_buffer, size_cpy,
+                                       ACL_MEMCPY_DEVICE_TO_DEVICE);
                 cur_dst_buffer =
                     (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
                 cur_permute_buffer = (char*)cur_permute_buffer +
@@ -1433,13 +1010,11 @@ static void ggml_cann_im2col_1d_post_process(
     } else {
         offset = KH * KW * n_step_w *
                  ggml_type_size(dst->type);  // equal to ggml_nbytes(dst)
-        ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
-                                   (char*)tmp_permute_buffer + offset, offset,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+        ggml_cann_async_memcpy(ctx, dst->data, (char*)tmp_permute_buffer + offset, offset,
+                               ACL_MEMCPY_DEVICE_TO_DEVICE);
     }
 
-    // release
-    ACL_CHECK(aclDestroyTensor(tmp_permute_tensor));
+    ggml_cann_release_resources(ctx, tmp_permute_tensor);
 }
 
 void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -1501,23 +1076,8 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     auto* dilations = aclCreateIntArray(dilation_size.data(), 2);
     auto* paddings = aclCreateIntArray(padding_dims.data(), 2);
    auto* strides = aclCreateIntArray(stride_dims.data(), 2);
-
-    uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
-    void* workspaceAddr = nullptr;
-
-    ACL_CHECK(aclnnIm2colGetWorkspaceSize(acl_src1, kernel_size, dilations,
-                                          paddings, strides, tmp_im2col_tensor,
-                                          &workspaceSize, &executor));
-
-    ggml_cann_pool_alloc workspace_allocator(ctx.pool());
-    if (workspaceSize > 0) {
-        workspace_allocator.alloc(workspaceSize);
-        workspaceAddr = workspace_allocator.get();
-    }
-
-    ACL_CHECK(
-        aclnnIm2col(workspaceAddr, workspaceSize, executor, ctx.stream()));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Im2col, acl_src1, kernel_size, dilations,
+                            paddings, strides, tmp_im2col_tensor);
 
     // Cast if dst is f16.
     aclTensor* tmp_cast_tensor = nullptr;
@@ -1536,8 +1096,7 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1536
1096
  tmp_cast_buffer, ggml_cann_type_mapping(dst->type),
1537
1097
  ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb,
1538
1098
  GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
1539
- aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor,
1540
- ggml_cann_type_mapping(dst->type));
1099
+ aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, ggml_cann_type_mapping(dst->type));
1541
1100
  }
1542
1101
 
1543
1102
  // post-processing
@@ -1551,14 +1110,8 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1551
1110
  tmp_im2col_tensor, im2col_op_params);
1552
1111
  }
1553
1112
 
1554
- // release
1555
- ACL_CHECK(aclDestroyTensor(acl_src1));
1556
- ACL_CHECK(aclDestroyTensor(tmp_im2col_tensor));
1557
- ACL_CHECK(aclDestroyTensor(tmp_cast_tensor));
1558
- ACL_CHECK(aclDestroyIntArray(kernel_size));
1559
- ACL_CHECK(aclDestroyIntArray(dilations));
1560
- ACL_CHECK(aclDestroyIntArray(paddings));
1561
- ACL_CHECK(aclDestroyIntArray(strides));
1113
+ ggml_cann_release_resources(ctx, acl_src1, tmp_im2col_tensor, tmp_cast_tensor,
1114
+ kernel_size, dilations, paddings, strides);
1562
1115
  }
1563
1116
 
1564
1117
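GGML_CANN_CALL_ACLNN_OP collapses the recurring three-step aclnn calling convention that the removed lines spell out: query the workspace size with aclnn<Op>GetWorkspaceSize, allocate the workspace from the context pool, then launch aclnn<Op> on the context stream. Based purely on that removed boilerplate, the macro plausibly expands to something like the following (token pasting assumed; the actual definition ships in this package's headers):

    // Illustrative expansion only, inferred from the code this diff removes.
    #define GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, ...)                             \
        do {                                                                       \
            uint64_t       workspaceSize = 0;                                      \
            aclOpExecutor* executor      = nullptr;                                \
            void*          workspaceAddr = nullptr;                                \
            ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__,                \
                                                       &workspaceSize, &executor));\
            ggml_cann_pool_alloc workspace_allocator(ctx.pool());                  \
            if (workspaceSize > 0) {                                               \
                workspaceAddr = workspace_allocator.alloc(workspaceSize);          \
            }                                                                      \
            ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor,       \
                                     ctx.stream()));                               \
        } while (0)

Whatever its exact form, it is why the dozens of wrapper functions below shrink to one-liners.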
  /**
@@ -1575,285 +1128,17 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  * @param acl_src The tensor on which the exponential function will be applied.
  */
  static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(
- aclnnInplaceExpGetWorkspaceSize(acl_src, &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(
- aclnnInplaceExp(workspaceAddr, workspaceSize, executor, ctx.stream()));
- }
-
- /**
- * @brief Multiplies elements of a tensor by a scalar value, optionally
- * in-place.
- *
- * This function multiplies each element of the source tensor `acl_src` by the
- * scalar `scale` and stores the result in the destination tensor `acl_dst`. If
- * `inplace` is true, `acl_dst` will not be used and the operation is performed
- * in-place on `acl_src`.
- * The operation is defined as:
- * \f[
- * \text {acl_dst }_i=\text {acl_src }_i \times \text {scale}
- * \f]
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_src The source tensor whose elements will be multiplied.
- * @param scale The scalar value by which each element of `acl_src` will be
- * multiplied.
- * @param acl_dst The destination tensor where the result will be stored if
- * `inplace` is false.
- * @param inplace Flag indicating whether to perform the operation in-place on
- * `acl_src`.
- */
- static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src,
- float scale, aclTensor* acl_dst, bool inplace) {
- aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
-
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- if (inplace) {
- ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale,
- &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnInplaceMuls(workspaceAddr, workspaceSize, executor,
- ctx.stream()));
- } else {
- ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, acl_scale, acl_dst,
- &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(
- aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream()));
- }
-
- ACL_CHECK(aclDestroyScalar(acl_scale));
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceExp, acl_src);
  }

- /**
- * @brief Performs an in-place element-wise multiplication of two tensors.
- *
- * This function performs an element-wise multiplication of the tensors
- * `acl_src` and `acl_other` and stores the result in `acl_src`.
- * The operation is defined as:
- * \f[
- * \text {acl_src }_i=\text {acl_src }_i \times \text {acl_other }_i
- * \f]
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_src The source tensor where the multiplication result will be
- * stored.
- * @param acl_other The tensor whose elements will be multiplied with `acl_src`.
- */
- static void aclnn_inplace_mul(ggml_backend_cann_context& ctx,
- aclTensor* acl_src, aclTensor* acl_other) {
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnInplaceMulGetWorkspaceSize(acl_src, acl_other,
- &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(
- aclnnInplaceMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
- }
-
- /**
- * @brief Performs element-wise multiplication of two tensors and stores the
- * result in a destination tensor.
- *
- * This function performs element-wise multiplication of the tensors `acl_src`
- * and `acl_other` and stores the result in the destination tensor `acl_dst`.
- * The operation is defined as:
- * \f[
- * \text {acl_dst }_i=\text {acl_src }_i \times \text {acl_other }_i
- * \f]
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_src The first tensor for element-wise multiplication.
- * @param acl_other The second tensor for element-wise multiplication.
- * @param acl_dst The destination tensor where the result will be stored.
- */
- static void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
- aclTensor* acl_other, aclTensor* acl_dst) {
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnMulGetWorkspaceSize(acl_src, acl_other, acl_dst,
- &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
- }
-
- /**
- * @brief Applies element-wise cosine function to the elements of a tensor.
- *
- * This function computes the cosine of each element in the source tensor
- * `acl_src` and stores the result in the destination tensor `acl_dst`. The
- * operation is defined as: \f[ \text {acl_dst }_i=\cos \left(\text {acl_src
- * }_i\right) \f]
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_src The source tensor on which the cosine function will be
- * applied.
- * @param acl_dst The destination tensor where the cosine results will be
- * stored.
- */
- static void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  aclTensor* acl_dst) {
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(
- aclnnCosGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnCos(workspaceAddr, workspaceSize, executor, ctx.stream()));
+ GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
  }

- /**
- * @brief Applies element-wise sine function to the elements of a tensor.
- *
- * This function computes the sine of each element in the source tensor
- `acl_src`
- * and stores the result in the destination tensor `acl_dst`.
- * The operation is defined as:
- * \f[
- * \text {acl_dst }_i=\sin \left(\text {acl_src }_i\right)
- * \f]
-
- * @param ctx The context for the CANN backend operations.
- * @param acl_src The source tensor on which the sine function will be applied.
- * @param acl_dst The destination tensor where the sine results will be stored.
- */
- static void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+ void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  aclTensor* acl_dst) {
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(
- aclnnSinGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnSin(workspaceAddr, workspaceSize, executor, ctx.stream()));
- }
-
- /**
- * @brief Performs element-wise division of tensor1 by tensor2 , multiplies the
- result by the scalar value and adds it to self .
- *
- * Performs element-wise division of tensor1 by tensor2,
- * multiplies the result by the scalar value and adds it to self .
- * The operation is defined as:
- * \f[
- * \text{out}_i = \text{selft}_i + \text{value} \times
- \frac{\text{tensor1}_i}{\text{tensor2}_i}
- * \f]
-
- * @param ctx The context for the CANN backend operations.
- * @param acl_self The source tensor on which the addcdiv function will be
- applied.
- * @param tensor1 Numerator tensor.
- * @param tensor2 Denominator tensor.
- * @param value The value to be used for coefficient.
- */
- static void aclnn_inplace_addcdiv(ggml_backend_cann_context& ctx,
- aclTensor* acl_self, aclTensor* tensor1,
- aclTensor* tensor2, float value) {
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
- aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
-
- ACL_CHECK(aclnnInplaceAddcdivGetWorkspaceSize(
- acl_self, tensor1, tensor2, acl_value, &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnInplaceAddcdiv(workspaceAddr, workspaceSize, executor,
- ctx.stream()));
- }
-
- /**
- * @brief Matrix division, optionally in-place.
- *
- * This function division each element of the source tensor `acl_src` by the
- * tensor `acl_other` and stores the result in the destination tensor `acl_dst`.
- * If `inplace` is true, `acl_dst` will not be used and the operation is
- * performed in-place on `acl_src`. The operation is defined as: \f[
- * \text{dst}_i = \frac{\text{acl_src}_i}{\text{acl_other}_i}
- * \f]
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_src Numerator tensor..
- * @param acl_other Denominator tensor.
- * @param acl_dst The destination tensor where the result will be stored if
- * `inplace` is false.
- * @param inplace Flag indicating whether to perform the operation in-place on
- * `acl_src`.
- */
- static void aclnn_div_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_src,
- aclTensor* acl_other, aclTensor* acl_dst,
- bool inplace) {
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- if (inplace) {
- ACL_CHECK(aclnnInplaceDivGetWorkspaceSize(acl_src, acl_other,
- &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnInplaceDiv(workspaceAddr, workspaceSize, executor,
- ctx.stream()));
- } else {
- ACL_CHECK(aclnnDivGetWorkspaceSize(acl_src, acl_other, acl_dst,
- &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(
- aclnnDiv(workspaceAddr, workspaceSize, executor, ctx.stream()));
- }
+ GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
  }

  void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
@@ -1902,13 +1187,13 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,

  ggml_cann_pool_alloc permute_allocator(ctx.pool(), ggml_nbytes(src));
  void* tmp_permute_buffer = permute_allocator.get();
- aclTensor* tmp_permute_tenosr = ggml_cann_create_tensor(
+ aclTensor* tmp_permute_tensor = ggml_cann_create_tensor(
  tmp_permute_buffer, ggml_cann_type_mapping(src->type),
  ggml_type_size(src->type), tmp_permute_ne, tmp_permute_nb,
  GGML_MAX_DIMS, ACL_FORMAT_ND);
  int64_t permute_dim[] = {0, 1, 3, 2};
  int64_t num_dims = 4;
- aclnn_permute(ctx, acl_src, tmp_permute_tenosr, permute_dim, num_dims);
+ aclnn_permute(ctx, acl_src, tmp_permute_tensor, permute_dim, num_dims);

  // timestep * freq
  int64_t tmp_mul_ne[] = {src->ne[1] * half, src->ne[0], src->ne[2],
@@ -1929,7 +1214,7 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
  tmp_mul_buffer, ggml_cann_type_mapping(src->type),
  ggml_type_size(src->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
  ACL_FORMAT_ND);
- aclnn_mul(ctx, tmp_permute_tenosr, tmp_arange_tensor, tmp_mul_tensor);
+ aclnn_mul(ctx, tmp_permute_tensor, tmp_arange_tensor, tmp_mul_tensor);

  // cos
  ggml_cann_pool_alloc cos_allocator(
@@ -1957,17 +1242,13 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
  int64_t concat_dim = 3;
  aclTensor* acl_dst = ggml_cann_create_tensor(dst);
  aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor};
- aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
- aclnn_concat(ctx, tensorList, acl_dst, concat_dim);
+ aclTensorList* tensor_list = aclCreateTensorList(tensors, 2);
+ aclnn_concat(ctx, tensor_list, acl_dst, concat_dim);

  // release
  // segmentation fault when delete both tensorList and his elements.
- ACL_CHECK(aclDestroyTensorList(tensorList));
- ACL_CHECK(aclDestroyTensor(acl_src));
- ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
- ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr));
- ACL_CHECK(aclDestroyTensor(tmp_mul_tensor));
- ACL_CHECK(aclDestroyTensor(acl_dst));
+ ggml_cann_release_resources(ctx, tensor_list, acl_src, tmp_arange_tensor,
+ tmp_permute_tensor, tmp_mul_tensor, acl_dst);
  }

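The retained comment above ("segmentation fault when delete both tensorList and his elements") encodes an ownership rule worth spelling out: after aclCreateTensorList, the list owns tmp_cos_tensor and tmp_sin_tensor, so only tensor_list is handed to ggml_cann_release_resources and the two element tensors are deliberately absent from that call. In sketch form:

    // Ownership after aclCreateTensorList (sketch): destroying the list also
    // destroys its elements, so freeing them separately would double-free.
    aclTensor*     tensors[]   = {tmp_cos_tensor, tmp_sin_tensor};
    aclTensorList* tensor_list = aclCreateTensorList(tensors, 2);
    aclnn_concat(ctx, tensor_list, acl_dst, concat_dim);
    ggml_cann_release_resources(ctx, tensor_list);  // frees both elements too
    // NOT: ggml_cann_release_resources(ctx, tensor_list, tmp_cos_tensor, tmp_sin_tensor);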
  /**
@@ -1983,21 +1264,8 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
  static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
  aclTensor* acl_dst) {
  auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
-
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnInplaceFillScalarGetWorkspaceSize(
- acl_dst, acl_scalar, &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnInplaceFillScalar(workspaceAddr, workspaceSize, executor,
- ctx.stream()));
- ACL_CHECK(aclDestroyScalar(acl_scalar));
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
+ ggml_cann_release_resources(ctx, acl_scalar);
  }

  /**
@@ -2018,19 +1286,7 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
  */
  static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
  aclTensor* acl_dst, aclTensor* acl_exp) {
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnInplacePowTensorTensorGetWorkspaceSize(
- acl_dst, acl_exp, &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnInplacePowTensorTensor(workspaceAddr, workspaceSize,
- executor, ctx.stream()));
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplacePowTensorTensor, acl_dst, acl_exp);
  }

  /**
@@ -2182,56 +1438,15 @@ static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,

  // add
  aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
-
- ACL_CHECK(aclDestroyTensor(tmp_arange1_tensor));
- ACL_CHECK(aclDestroyTensor(tmp_arange2_tensor));
- ACL_CHECK(aclDestroyTensor(tmp_mk_base1_tensor));
- ACL_CHECK(aclDestroyTensor(tmp_mk_base2_tensor));
- ACL_CHECK(aclDestroyTensor(tmp_mk_base_tensor));
- ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
- ACL_CHECK(aclDestroyTensor(tmp_mk_tensor));
- ACL_CHECK(aclDestroyTensor(tmp_output_tensor));
+ ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+ tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+ tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor);
  }

  void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  ggml_cann_dup(ctx, dst);
  }

- /**
- * @brief Performs element-wise addition of two tensors in place.
- *
- * This function adds the source tensor `acl_src` to the destination tensor
- * `acl_dst` element-wise and stores the result in the destination tensor
- * `acl_dst`.
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_src The source tensor to be added.
- * @param acl_dst The destination tensor which will hold the result of the
- * addition.
- */
- static void aclnn_inplace_add(ggml_backend_cann_context& ctx,
- aclTensor* acl_src, aclTensor* acl_dst) {
- aclScalar* alpha = nullptr;
- float alphaValue = 1.0f;
- alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
-
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src, alpha,
- &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(
- aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
- ACL_CHECK(aclDestroyScalar(alpha));
- }
-
  /**
  * @brief Applies the softmax function to a tensor along a specified dimension.
  *
@@ -2248,20 +1463,7 @@ static void aclnn_inplace_add(ggml_backend_cann_context& ctx,
  */
  static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  int64_t dim, aclTensor* acl_dst) {
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(acl_src, dim, acl_dst,
- &workspaceSize, &executor));
-
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- aclrtStream stream = ctx.stream();
- ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream));
+ GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);
  }

  void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -2311,8 +1513,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  src1_fp32_nb, GGML_MAX_DIMS);
  aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
  aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
-
- ACL_CHECK(aclDestroyTensor(acl_src1));
+ ggml_cann_release_resources(ctx, acl_src1);
  } else {
  acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
  }
@@ -2365,98 +1566,158 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {

  // softmax
  aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
- ACL_CHECK(aclDestroyTensor(alibi_output_tensor));
+ ggml_cann_release_resources(ctx, alibi_output_tensor);
  } else {
  aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
  }

- ACL_CHECK(aclDestroyTensor(acl_src0));
- ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor));
- ACL_CHECK(aclDestroyTensor(acl_dst));
- ACL_CHECK(aclDestroyScalar(acl_scale));
- ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor));
- ACL_CHECK(aclDestroyTensor(tmp_mask_tensor));
+ ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst,
+ acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor);
  }

- void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
- ggml_tensor* src0 = dst->src[0];
- ggml_tensor* src1 = dst->src[1];
+ /**
+ * @brief Performs embedding operation on a 4D tensor using the CANN backend.
+ *
+ * This function extracts slices from the source tensor (`src_buffer`),
+ * index tensor (`index`), and destination tensor (`dst`), and performs an
+ * embedding operation on them. The embedding operation is applied by iterating
+ * over the last two dimensions of the source tensor, creating the necessary
+ * tensors for the source, index, and output, and executing the embedding operation.
+ *
+ * @param ctx The context for CANN backend operations.
+ * @param src_buffer The source buffer holding the data for the source tensor.
+ * @param src_ne The dimensions of the source tensor.
+ * @param src_nb The strides (byte offsets) of the source tensor.
+ * @param index The index tensor used in the embedding operation.
+ * @param dst The destination tensor where the result will be stored.
+ */
+ static void aclnn_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer,
+ int64_t* src_ne, size_t* src_nb, ggml_tensor* index,
+ ggml_tensor* dst) {
+ for (int64_t i = 0; i < src_ne[3]; i++) {
+ for (int64_t j = 0; j < src_ne[2]; j++) {
+ // src
+ int64_t acl_src_ne[2] = {src_ne[0], src_ne[1]};
+ size_t acl_src_nb[2] = {src_nb[0], src_nb[1]};
+ aclTensor* acl_src_tensor = ggml_cann_create_tensor(
+ (char*)src_buffer + i * src_nb[3] + j * src_nb[2],
+ ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
+ acl_src_ne, acl_src_nb, 2);
+
+ // index
+ int64_t acl_index_ne[1] = {index->ne[0]};
+ size_t acl_index_nb[1] = {index->nb[0]};
+ aclTensor* acl_index = ggml_cann_create_tensor(
+ (char*)index->data + i * index->nb[2] + j * index->nb[1],
+ ggml_cann_type_mapping(index->type), ggml_element_size(index),
+ acl_index_ne, acl_index_nb, 1);
+
+ // out
+ int64_t acl_out_ne[2] = {dst->ne[0], dst->ne[1]};
+ size_t acl_out_nb[2] = {dst->nb[0], dst->nb[1]};
+ aclTensor* acl_out = ggml_cann_create_tensor(
+ (char*)dst->data + i * dst->nb[3] + j * dst->nb[2],
+ ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
+ acl_out_ne, acl_out_nb, 2);
+ GGML_CANN_CALL_ACLNN_OP(ctx, Embedding, acl_src_tensor, acl_index, acl_out);
+ ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out);
+ }
+ }
+ }

- ggml_cann_pool_alloc src0_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
- ggml_cann_pool_alloc src1_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
- ggml_cann_pool_alloc dst_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
- src0->extra = src0_extra_allocator.get();
- src1->extra = src1_extra_allocator.get();
- dst->extra = dst_extra_allocator.get();
- ACL_CHECK(aclrtMemcpyAsync(src0->extra, sizeof(ggml_tensor), src0,
- sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
- ctx.stream()));
- ACL_CHECK(aclrtMemcpyAsync(src1->extra, sizeof(ggml_tensor), src1,
- sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
- ctx.stream()));
- ACL_CHECK(aclrtMemcpyAsync(dst->extra, sizeof(ggml_tensor), dst,
- sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
- ctx.stream()));
+ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ ggml_tensor* src0 = dst->src[0]; // src
+ ggml_tensor* src1 = dst->src[1]; // index

  switch (src0->type) {
  case GGML_TYPE_F32: {
- #ifdef ASCEND_310P
- // Special operation for get_row_f32 kernel of 310P: clear the
- // content of dest data buffer when row is not aligned to 32 bytes
- if ((src0->ne[0] % 8) != 0) {
- size_t dst_len = src1->ne[0] * src1->ne[1] * src1->ne[2] *
- src0->ne[0] * ggml_type_size(GGML_TYPE_F32);
- ACL_CHECK(aclrtMemset((char*)dst->data, dst_len, 0, dst_len));
- }
- #endif
- aclrtlaunch_ascendc_get_row_f32(
- 24, ctx.stream(), src0->data, src1->data, dst->data,
- ((ggml_tensor*)src0->extra)->ne,
- ((ggml_tensor*)src0->extra)->nb,
- ((ggml_tensor*)src1->extra)->ne,
- ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
- ((ggml_tensor*)dst->extra)->nb);
+ aclnn_embedding_4d(ctx, src0->data, src0->ne, src0->nb, src1,
+ dst);
  break;
  }
  case GGML_TYPE_F16: {
- #ifdef ASCEND_310P
- // Special operation for get_row_f16 kernel of 310P: clear the
- // content of dest data buffer when row is not aligned to 32 bytes
- if ((src0->ne[0] % 16) != 0) {
- size_t dst_len =
- src1->ne[0] * src1->ne[1] * src1->ne[2] * src0->ne[0] *
- ggml_type_size(
- GGML_TYPE_F32); // out is also f32, even input is f16
- ACL_CHECK(aclrtMemset((char*)dst->data, dst_len, 0, dst_len));
+ aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+ ggml_cann_pool_alloc src_buffer_allocator(
+ ctx.pool(), ggml_nelements(src0) * sizeof(float_t));
+ void* src_trans_buffer = src_buffer_allocator.get();
+ size_t src_trans_nb[GGML_MAX_DIMS];
+ src_trans_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
  }
- #endif
- aclrtlaunch_ascendc_get_row_f16(
- 24, ctx.stream(), src0->data, src1->data, dst->data,
- ((ggml_tensor*)src0->extra)->ne,
- ((ggml_tensor*)src0->extra)->nb,
- ((ggml_tensor*)src1->extra)->ne,
- ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
- ((ggml_tensor*)dst->extra)->nb);
+ aclTensor* src_trans_tensor = ggml_cann_create_tensor(
+ src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
+ src0->ne, src_trans_nb, GGML_MAX_DIMS);
+ aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+ aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne,
+ src_trans_nb, src1, dst);
+ ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
  break;
  }
- case GGML_TYPE_Q4_0:
- aclrtlaunch_ascendc_get_row_q4_0(
- 24, ctx.stream(), src0->data, src1->data, dst->data,
- ((ggml_tensor*)src0->extra)->ne,
- ((ggml_tensor*)src1->extra)->ne,
- ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
- ((ggml_tensor*)dst->extra)->nb);
- break;
- case GGML_TYPE_Q8_0:
- aclrtlaunch_ascendc_get_row_q8_0(
- 24, ctx.stream(), src0->data, src1->data, dst->data,
- ((ggml_tensor*)src0->extra)->ne,
- ((ggml_tensor*)src1->extra)->ne,
- ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
- ((ggml_tensor*)dst->extra)->nb);
+ case GGML_TYPE_Q8_0: {
+ // add 1 dim for bcast mul.
+ size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
+ dequant_nb[GGML_MAX_DIMS + 1];
+ int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
+ *dequant_ne;
+ int64_t scale_offset = 0;
+
+ // [3,4,5,64] -> [3,4,5,2,32]
+ weight_ne[0] = QK8_0;
+ weight_ne[1] = src0->ne[0] / QK8_0;
+ weight_nb[0] = sizeof(int8_t);
+ weight_nb[1] = weight_nb[0] * weight_ne[0];
+ for (int i = 2; i < GGML_MAX_DIMS + 1; i++) {
+ weight_ne[i] = src0->ne[i - 1];
+ weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
+ }
+
+ // [3,4,5,64] -> [3,4,5,2,1]
+ scale_ne[0] = 1;
+ scale_ne[1] = src0->ne[0] / QK8_0;
+ scale_nb[0] = sizeof(uint16_t);
+ scale_nb[1] = scale_nb[0] * scale_ne[0];
+ for (int i = 2; i < GGML_MAX_DIMS + 1; i++) {
+ scale_ne[i] = src0->ne[i - 1];
+ scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
+ }
+
+ // [3,4,5,64] -> [3,4,5,2,32]
+ dequant_ne = weight_ne;
+ dequant_nb[0] = sizeof(float_t);
+ for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
+ dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
+ }
+
+ scale_offset = ggml_nelements(src0) * sizeof(int8_t);
+ ggml_cann_pool_alloc dequant_buffer_allocator(
+ ctx.pool(), ggml_nelements(src0) * sizeof(float_t));
+
+ aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
+ src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
+ GGML_MAX_DIMS + 1);
+ aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
+ src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
+ GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
+ aclTensor* dequant_tensor = ggml_cann_create_tensor(
+ dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float_t),
+ dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
+
+ aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
+ dequant_nb[0] = sizeof(float_t);
+ dequant_ne = src0->ne;
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
+ }
+
+ aclnn_embedding_4d(ctx, dequant_buffer_allocator.get(),
+ dequant_ne, dequant_nb, src1, dst);
+
+ ggml_cann_release_resources(ctx, dequant_tensor);
  break;
+ }
  default:
- GGML_ABORT("fatal error");
+ GGML_ABORT("Unsupported tensor type for GGML_OP_GET_ROWS");
  break;
  }
  }
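The new Q8_0 branch above avoids the old AscendC get_row kernel by dequantizing on device with one broadcast multiply: the int8 payload is viewed as [..., n_blocks, QK8_0] and the fp16 scales (stored after the payload, hence scale_offset) as [..., n_blocks, 1], so acl_weight * acl_scale spreads each block's scale across its weights. A scalar CPU reference of the same arithmetic (illustration only, not part of the package; QK8_0 is 32 in ggml):

    // Reference semantics of the broadcast dequant, per row of length ne0.
    // q holds ne0 int8 quants; s_f16 holds the ne0/QK8_0 half-precision
    // scales laid out after the quants (that is what scale_offset skips).
    void dequant_row_q8_0_ref(const int8_t* q, const ggml_fp16_t* s_f16,
                              float* out, int64_t ne0) {
        const int64_t n_blocks = ne0 / QK8_0;
        for (int64_t b = 0; b < n_blocks; ++b) {
            const float d = ggml_fp16_to_fp32(s_f16[b]);  // one scale per block
            for (int i = 0; i < QK8_0; ++i) {
                out[b * QK8_0 + i] = d * (float) q[b * QK8_0 + i];
            }
        }
    }

The dequantized buffer is then fed through the same aclnn_embedding_4d path as the float types.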
@@ -2480,133 +1741,8 @@ static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx,
  aclTensor* acl_src, aclTensor* acl_dst,
  int64_t dim, int64_t repeats,
  int64_t output_size) {
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnRepeatInterleaveIntWithDimGetWorkspaceSize(
- acl_src, repeats, dim, output_size, acl_dst, &workspaceSize,
- &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnRepeatInterleaveIntWithDim(workspaceAddr, workspaceSize,
- executor, ctx.stream()));
- }
-
- /**
- * @brief Performs matrix multiplication of two tensors.
- *
- * This function computes the matrix multiplication of the input tensor
- * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
- * destination tensor `acl_dst`.
- * The operation is defined as:
- * \f[
- * \text {acl_dst}=\text {acl_input@acl_weight}
- * \f]
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_input The input tensor for the matrix multiplication.
- * @param acl_weight The weight tensor for the matrix multiplication.
- * @param acl_dst The destination tensor where the result of the matrix
- * multiplication will be stored.
- */
- static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
- aclTensor* acl_weight, aclTensor* acl_dst) {
- int8_t cube_math_type = 1; // ALLOW_FP32_DOWN_PRECISION, when input is
- // fp32, atlas a2 will transpose it to HFLOAT32.
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnMatmulGetWorkspaceSize(acl_input, acl_weight, acl_dst,
- cube_math_type, &workspaceSize,
- &executor));
-
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(
- aclnnMatmul(workspaceAddr, workspaceSize, executor, ctx.stream()));
- }
-
- /**
- * @brief Performs matrix multiplication of two 2D tensors.
- *
- * This function computes the matrix multiplication of the input tensor
- * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
- * destination tensor `acl_dst`.
- * The operation is defined as:
- * \f[
- * \text {acl_dst}=\text {acl_input@acl_weight}
- * \f]
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_input The input tensor for the matrix multiplication.
- * @param acl_weight The weight tensor for the matrix multiplication.
- * @param acl_dst The destination tensor where the result of the matrix
- * multiplication will be stored.
- */
- static void aclnn_mat_mul_2d(ggml_backend_cann_context& ctx,
- aclTensor* acl_input, aclTensor* acl_weight,
- aclTensor* acl_dst) {
- int8_t cube_math_type = 2;
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnMmGetWorkspaceSize(acl_input, acl_weight, acl_dst,
- cube_math_type, &workspaceSize,
- &executor));
-
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnMm(workspaceAddr, workspaceSize, executor, ctx.stream()));
- }
-
- /**
- * @brief Performs matrix multiplication of two 3D tensors.
- *
- * This function computes the matrix multiplication of the input tensor
- * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
- * destination tensor `acl_dst`.
- * The operation is defined as:
- * \f[
- * \text {acl_dst}=\text {acl_input@acl_weight}
- * \f]
- *
- * @param ctx The context for the CANN backend operations.
- * @param acl_input The input tensor for the matrix multiplication.
- * @param acl_weight The weight tensor for the matrix multiplication.
- * @param acl_dst The destination tensor where the result of the matrix
- * multiplication will be stored.
- */
- static void aclnn_mat_mul_3d(ggml_backend_cann_context& ctx,
- aclTensor* acl_input, aclTensor* acl_weight,
- aclTensor* acl_dst) {
- int8_t cube_math_type = 2;
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnBatchMatMulGetWorkspaceSize(acl_input, acl_weight, acl_dst,
- cube_math_type, &workspaceSize,
- &executor));
-
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(
- aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
+ GGML_CANN_CALL_ACLNN_OP(ctx, RepeatInterleaveIntWithDim, acl_src, repeats, dim,
+ output_size, acl_dst);
  }

  /**
@@ -2654,19 +1790,19 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,

  switch (n_dims) {
  case 2:
- aclnn_mat_mul_2d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+ GGML_CANN_CALL_ACLNN_OP(ctx, Mm, acl_input_tensor, acl_weight_tensor, acl_dst, 2);
  break;
  case 3:
- aclnn_mat_mul_3d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+ GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, acl_input_tensor, acl_weight_tensor, acl_dst, 2);
  break;
  default:
- aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+ // ALLOW_FP32_DOWN_PRECISION, when input is
+ // fp32, atlas a2 will transpose it to HFLOAT32.
+ GGML_CANN_CALL_ACLNN_OP(ctx, Matmul, acl_input_tensor, acl_weight_tensor, acl_dst, 1);
  break;
  }

- ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
- ACL_CHECK(aclDestroyTensor(acl_input_tensor));
- ACL_CHECK(aclDestroyTensor(acl_dst));
+ ggml_cann_release_resources(ctx, acl_weight_tensor, acl_input_tensor, acl_dst);
  }

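With the dedicated 2D/3D wrappers deleted, the rank dispatch in ggml_cann_mat_mul_fp calls the three aclnn matmul ops directly, carrying over each wrapper's trailing cube_math_type argument (2 for the fixed-rank Mm/BatchMatMul, 1, i.e. ALLOW_FP32_DOWN_PRECISION, for the general Matmul). Written out as a plain helper for readability (a sketch mirroring the switch above, not code from this package):

    // Rank-based matmul dispatch, as in ggml_cann_mat_mul_fp above.
    static void mat_mul_dispatch(ggml_backend_cann_context& ctx, int64_t n_dims,
                                 aclTensor* input, aclTensor* weight, aclTensor* dst) {
        switch (n_dims) {
            case 2:  GGML_CANN_CALL_ACLNN_OP(ctx, Mm,          input, weight, dst, 2); break;
            case 3:  GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, input, weight, dst, 2); break;
            default: GGML_CANN_CALL_ACLNN_OP(ctx, Matmul,      input, weight, dst, 1); break;
        }
    }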
  /**
@@ -2736,9 +1872,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
  input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
  input_cast_nb, GGML_MAX_DIMS);
  aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
-
- ACL_CHECK(aclDestroyTensor(acl_input_tensor));
- ACL_CHECK(aclDestroyTensor(acl_src1_tensor));
+ ggml_cann_release_resources(ctx, acl_input_tensor, acl_src1_tensor);
  }

  // output
@@ -2753,9 +1887,6 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
  int64_t max_elem_size = 65535;
  int64_t split_size = (src0->ne[1] / max_elem_size) + 1;
  ggml_cann_pool_alloc workspace_allocator(ctx.pool());
- aclOpExecutor* executor = nullptr;
- uint64_t workspaceSize = 0;
- void* workspaceAddr = nullptr;
  for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) {
  for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) {
  int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]);
@@ -2790,20 +1921,15 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
  (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
  output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
  output_ne_offset);
-
- ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
- acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
- nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
- &workspaceSize, &executor));
- if (workspaceAddr == nullptr) {
- workspaceAddr = workspace_allocator.alloc(workspaceSize);
+ int64_t antiquantGroupSize = 0;
+ if (src0->ne[0] > QK8_0) {
+ antiquantGroupSize = QK8_0;
  }
- ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
- workspaceAddr, workspaceSize, executor, ctx.stream()));
-
- ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
- ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
- ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+ GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor,
+ acl_weight_tensor, acl_scale_tensor, nullptr,
+ nullptr, nullptr, nullptr, antiquantGroupSize,
+ acl_output_tensor);
+ ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor);

  // other splits
  for (int64_t split = 1; split < split_size; split++) {
@@ -2830,20 +1956,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
  (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
  output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
  output_ne_offset);
-
- ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
- acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
- nullptr, nullptr, nullptr, nullptr, QK8_0,
- acl_output_tensor, &workspaceSize, &executor));
- ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
- workspaceAddr, workspaceSize, executor, ctx.stream()));
-
- ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
- ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
- ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+ GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor,
+ acl_weight_tensor, acl_scale_tensor, nullptr,
+ nullptr, nullptr, nullptr, antiquantGroupSize,
+ acl_output_tensor);
+ ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor);
  }

- ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+ ggml_cann_release_resources(ctx, acl_input_tensor);
  }
  }

@@ -2860,11 +1980,9 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
  output_buffer, ACL_FLOAT16, output_elem_size, output_cast_ne,
  output_cast_nb, GGML_MAX_DIMS);
  aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
- aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor,
- ggml_cann_type_mapping(dst->type));
+ aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));

- ACL_CHECK(aclDestroyTensor(acl_output_tensor));
- ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
+ ggml_cann_release_resources(ctx, acl_output_tensor, acl_dst_tensor);
  }
  }

@@ -2880,7 +1998,7 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  ggml_cann_mul_mat_quant(ctx, dst, type);
  break;
  default:
- GGML_ABORT("fatal error");
+ GGML_ABORT("Unsupported type for mul_mat");
  break;
  }
  }
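One behavioral change hides in the quantized path above: the antiquant group size passed to WeightQuantBatchMatmulV2 is no longer hard-coded to QK8_0. It is now 0 when a weight row fits inside a single quantization block and QK8_0 otherwise; presumably 0 tells the op to skip per-group dequantization and use one scale for the whole row (an inference from the guard, not something this diff states). The selection, isolated:

    // Group-size selection as introduced above. A row no longer than one
    // block has a single scale, so per-group dequant is unnecessary.
    int64_t antiquantGroupSize = 0;      // 0: no grouping (assumed semantics)
    if (src0->ne[0] > QK8_0) {
        antiquantGroupSize = QK8_0;      // one scale group per QK8_0 weights
    }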
@@ -2905,22 +2023,8 @@ static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  aclTensor* acl_dst, int64_t* shifts, int64_t* dims) {
  aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1);
  aclIntArray* acl_dims = aclCreateIntArray(dims, 1);
-
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnRollGetWorkspaceSize(acl_src, acl_shifts, acl_dims, acl_dst,
- &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnRoll(workspaceAddr, workspaceSize, executor, ctx.stream()));
-
- ACL_CHECK(aclDestroyIntArray(acl_shifts));
- ACL_CHECK(aclDestroyIntArray(acl_dims));
+ GGML_CANN_CALL_ACLNN_OP(ctx, Roll, acl_src, acl_shifts, acl_dims, acl_dst);
+ ggml_cann_release_resources(ctx, acl_shifts, acl_dims);
  }

  /**
@@ -2942,23 +2046,8 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  float value) {
  aclIntArray* acl_index = aclCreateIntArray(index, index_num);
  aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
-
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
- void* workspaceAddr = nullptr;
-
- ACL_CHECK(aclnnInplaceIndexFillTensorGetWorkspaceSize(
- acl_src, dim, acl_index, acl_value, &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
- }
-
- ACL_CHECK(aclnnInplaceIndexFillTensor(workspaceAddr, workspaceSize,
- executor, ctx.stream()));
-
- ACL_CHECK(aclDestroyIntArray(acl_index));
- ACL_CHECK(aclDestroyScalar(acl_value));
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexFillTensor, acl_src, dim, acl_index, acl_value);
+ ggml_cann_release_resources(ctx, acl_index, acl_value);
  }

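aclnn_roll, now a two-liner, is what the RoPE path below uses to move one half of each row in front of the other. For reference, Roll's element-wise semantics along a single dimension (a CPU sketch, independent of the ACL implementation):

    // 1-D reference of the Roll op wrapped above: every element moves `shift`
    // positions toward higher indices, wrapping around the end of the row.
    void roll_1d_ref(const float* src, float* dst, int64_t n, int64_t shift) {
        const int64_t s = ((shift % n) + n) % n;  // normalize negative shifts
        for (int64_t i = 0; i < n; ++i) {
            dst[(i + s) % n] = src[i];
        }
    }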
  static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
@@ -2973,37 +2062,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
  ggml_tensor* src1 = dst->src[1]; // position
  ggml_tensor* src2 = dst->src[2]; // freq_factors

- // arange, [0,1,...,ne0/2]
- int64_t arange_length = src0->ne[0] / 2;
- ggml_cann_pool_alloc arange_allocator(ctx.pool(),
- arange_length * sizeof(float_t));
- void* arange_buffer = arange_allocator.get();
- int64_t arange_ne[] = {arange_length, 1, 1, 1};
- size_t arange_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
- arange_length * sizeof(float_t)};
-
- aclTensor* acl_arange_tensor =
- ggml_cann_create_tensor(arange_buffer, ACL_FLOAT, sizeof(float_t),
- arange_ne, arange_nb, GGML_MAX_DIMS);
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ // theta_scale arange, [0,1,...,ne00/2 - 1]
+ int64_t theta_scale_length = ne00 / 2;
+ ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(),
+ theta_scale_length * sizeof(float_t));
+ void* theta_scale_buffer = theta_scale_allocator.get();
+ int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
+ size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
+ theta_scale_length * sizeof(float_t)};
+
+ aclTensor* acl_theta_scale_tensor =
+ ggml_cann_create_tensor(theta_scale_buffer, ACL_FLOAT, sizeof(float_t),
+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
  float start = 0;
  float step = 1;
- float stop = src0->ne[0] / 2;
- float n_elements = src0->ne[0] / 2;
- aclnn_arange(ctx, acl_arange_tensor, start, stop, step, n_elements);
+ float stop = ne00 / 2;
+ float n_elements = ne00 / 2;
+ aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);

  // power
- // aclnnPowScalarTensor(): @param self is tensor which should be scalar, so
- // use aclnn_pow_tensor_tensor() until fixed. aclScalar* acl_theta_scale =
- // aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
- // aclnn_power_scalar_tensor(ctx, acl_theta_scale, acl_arange_tensor,
- // acl_power_tensor);
- ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(),
- arange_length * sizeof(float_t));
- void* theta_scale_buffer = theta_scale_allocator.get();
- aclTensor* acl_theta_scale_tensor = aclnn_values(
- ctx, theta_scale_buffer, arange_length * sizeof(float_t), arange_ne,
- GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), theta_scale);
- aclnn_pow_tensor_tensor(ctx, acl_theta_scale_tensor, acl_arange_tensor);
+ aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
+ GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
+ acl_theta_scale_tensor);

  // freq_scale
  if (freq_scale != 1) {
@@ -3014,29 +2096,27 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
  if (src2) {
  aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
  src2->data, ggml_cann_type_mapping(src2->type),
- ggml_type_size(src2->type), arange_ne, arange_nb, GGML_MAX_DIMS);
- aclnn_div_tensor(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor,
- nullptr, true);
- ACL_CHECK(aclDestroyTensor(acl_freq_factors_tensor));
+ ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+ aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
+ ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
  }

  // position
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
  int64_t position_length = src1->ne[0];
- int64_t position_ne[] = {1, position_length, 1, 1};
- size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t),
- sizeof(int32_t) * position_length,
+ int64_t position_ne[] = {1, 1, position_length, 1};
+ size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t),
  sizeof(int32_t) * position_length};
  aclTensor* acl_position_tensor = ggml_cann_create_tensor(
  src1->data, ggml_cann_type_mapping(src1->type),
  ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);

  // power * position
- int64_t theta_length = arange_length * position_length;
+ int64_t theta_length = theta_scale_length * position_length;
  ggml_cann_pool_alloc theta_allocator(ctx.pool(),
  theta_length * sizeof(float_t));
  void* theta_buffer = theta_allocator.get();
- int64_t theta_ne[] = {arange_length, position_length, 1, 1};
+ int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1};
  size_t theta_nb[GGML_MAX_DIMS];
  theta_nb[0] = sizeof(float_t);
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -3048,40 +2128,22 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
  aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
  acl_theta_tensor);

- // permute: [0,1,2,3]->[0,2,1,3]
- int64_t permute_ne[] = {arange_length, 1, position_length, 1};
- size_t permute_nb[GGML_MAX_DIMS];
- permute_nb[0] = sizeof(float_t);
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
- permute_nb[i] = permute_nb[i - 1] * permute_ne[i - 1];
- }
- ggml_cann_pool_alloc permute_allocator(ctx.pool(),
- theta_length * sizeof(float_t));
- void* permute_buffer = permute_allocator.get();
- aclTensor* acl_permute_tensor = ggml_cann_create_tensor(
- permute_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
- GGML_MAX_DIMS, ACL_FORMAT_ND);
- int64_t permute_dim[] = {0, 2, 1, 3};
- int64_t num_dims = 4;
- aclnn_permute(ctx, acl_theta_tensor, acl_permute_tensor, permute_dim,
- num_dims);
-
  // sin/cos
  ggml_cann_pool_alloc sin_allocator(ctx.pool(),
  theta_length * sizeof(float_t));
  void* sin_buffer = sin_allocator.get();
  aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
- sin_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
+ sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
  GGML_MAX_DIMS, ACL_FORMAT_ND);
- aclnn_sin(ctx, acl_permute_tensor, acl_sin_tensor);
+ aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);

  ggml_cann_pool_alloc cos_allocator(ctx.pool(),
  theta_length * sizeof(float_t));
  void* cos_buffer = cos_allocator.get();
  aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
- cos_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
+ cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
  GGML_MAX_DIMS, ACL_FORMAT_ND);
- aclnn_cos(ctx, acl_permute_tensor, acl_cos_tensor);
+ aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);

  // attn_factor
  if (attn_factor != 1) {
@@ -3097,7 +2159,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
  } else {
  int64_t num_repeats = 2;
  int64_t dim = 3;
- int64_t output_size = arange_length * num_repeats;
+ int64_t output_size = theta_scale_length * num_repeats;
  aclnn_repeat_interleave(ctx, acl_sin_tensor, acl_sin_repeat_tensor, dim,
  num_repeats, output_size);
  aclnn_repeat_interleave(ctx, acl_cos_tensor, acl_cos_repeat_tensor, dim,
@@ -3105,13 +2167,8 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
  }

  // release
- ACL_CHECK(aclDestroyTensor(acl_arange_tensor));
- ACL_CHECK(aclDestroyTensor(acl_theta_scale_tensor));
- ACL_CHECK(aclDestroyTensor(acl_position_tensor));
- ACL_CHECK(aclDestroyTensor(acl_theta_tensor));
- ACL_CHECK(aclDestroyTensor(acl_permute_tensor));
- ACL_CHECK(aclDestroyTensor(acl_sin_tensor));
- ACL_CHECK(aclDestroyTensor(acl_cos_tensor));
+ ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
+ acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale);
  }

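The reworked aclnn_cache_init builds the standard RoPE angle table directly on device: theta_scale^k now comes from a single PowScalarTensor call (replacing the old fill-then-pow workaround), is optionally divided by freq_factors and scaled by freq_scale, and is then multiplied by each position; sin/cos of the result, scaled by attn_factor, fill the cache, with the permute eliminated by laying the tensors out in the target shape from the start. A scalar reference of the values the cache ends up holding (sketch; ignores the freq_factors and ext_factor branches):

    #include <cmath>

    // CPU reference for the sin/cos cache computed above. `half` is ne00/2
    // and `theta_scale` is the per-dimension decay factor from the caller.
    void rope_cache_ref(float* sin_out, float* cos_out, const int32_t* pos,
                        int64_t n_pos, int64_t half, float theta_scale,
                        float freq_scale, float attn_factor) {
        for (int64_t p = 0; p < n_pos; ++p) {
            for (int64_t k = 0; k < half; ++k) {
                const float theta = freq_scale * pos[p] * std::pow(theta_scale, (float) k);
                sin_out[p * half + k] = std::sin(theta) * attn_factor;
                cos_out[p * half + k] = std::cos(theta) * attn_factor;
            }
        }
    }

Each angle is then repeat-interleaved twice along the last dimension so it lines up with the paired components that RoPE rotates together.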
  #ifdef __cplusplus
@@ -3133,7 +2190,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
3133
2190
  // TODO: use ascendc
3134
2191
  // Only test with LLAMA model.
3135
2192
  ggml_tensor* src0 = dst->src[0]; // input
3136
- ggml_tensor* src2 = dst->src[2]; // freq_factors
3137
2193
 
3138
2194
  // param
3139
2195
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -3168,13 +2224,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
3168
2224
 
3169
2225
  // init cos/sin cache
3170
2226
  ggml_cann_pool_alloc sin_allocator(
3171
- ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t));
2227
+ ctx.pool(), ne00 * ne02 * sizeof(float_t));
3172
2228
  ggml_cann_pool_alloc cos_allocator(
3173
- ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t));
2229
+ ctx.pool(), ne00 * ne02 * sizeof(float_t));
3174
2230
  void* sin_buffer = sin_allocator.get();
3175
2231
  void* cos_buffer = cos_allocator.get();
3176
2232
 
3177
- int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
2233
+ int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
3178
2234
  size_t sin_reshape_nb[GGML_MAX_DIMS];
3179
2235
  sin_reshape_nb[0] = sizeof(float_t);
3180
2236
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -3187,7 +2243,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
3187
2243
  ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
3188
2244
  sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
3189
2245
  aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
3190
- theta_scale, freq_scale, attn_factor, is_neox);
2246
+ theta_scale, freq_scale, attn_factor, is_neox);
3191
2247
 
3192
2248
  aclTensor* acl_src = ggml_cann_create_tensor(src0);
3193
2249
  aclTensor* acl_dst = ggml_cann_create_tensor(dst);
@@ -3224,8 +2280,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  int64_t shifts[] = {1};
  int64_t dims[] = {3};
  aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
- ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
- ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+ ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor);
 
  // init [-1, 1, -1, 1, ...]
  minus_one_scale_buffer = minus_one_scale_allocator.get();
@@ -3261,8 +2316,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  int64_t dims[] = {3};
  aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
 
- ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
- ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+ ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor);
  // init [-1, -1, -1, 1, 1,1,...]
  minus_one_scale_buffer = minus_one_scale_allocator.get();
  int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
@@ -3287,7 +2341,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  bool inplace = true;
  float scale = -1;
  aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
- ACL_CHECK(aclDestroyTensor(acl_first_half_tensor));
+ ggml_cann_release_resources(ctx, acl_first_half_tensor);
  }
 
  // TODO: n_dims < ne0
@@ -3315,8 +2369,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  // output
  void* output_fp32_buffer;
  if (src0->type == GGML_TYPE_F32) {
- aclnn_inplace_mul(ctx, acl_src, acl_cos_reshape_tensor);
- aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor,
+ aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor);
+ aclnn_mul(ctx, acl_input_roll_mul_scale_tensor,
  acl_sin_reshape_tensor);
  aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst);
  // TODO: ne0 != n_dims in mode2
@@ -3352,76 +2406,788 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  output_fp32_tensor);
  aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
 
- ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
- ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
- ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
- ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
- ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
- ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
- ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
- ACL_CHECK(aclDestroyTensor(acl_src));
+ ggml_cann_release_resources(ctx, input_fp32_tensor1, input_fp32_tensor2,
+ output_fp32_tensor, acl_sin_reshape_tensor,
+ acl_minus_one_tensor, acl_input_roll_mul_scale_tensor,
+ acl_input_roll_reshape_tensor, acl_src);
  }
  return;
  #endif
 
- // src0 == GGML_TYPE_F16
- // TODO: optimization this `if` code
+ // ggml_mode = 0 --> aclnn_mode = 1
+ int64_t acl_mode = mode == 0 ? 1 : mode;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: {
+ GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src,
+ acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst);
+ break;
+ }
+ case GGML_TYPE_F16: {
+ ggml_cann_pool_alloc src_trans_allocator(
+ ctx.pool(), ggml_nelements(src0) * sizeof(float));
+ void* src_trans_buffer = src_trans_allocator.get();
+ ggml_cann_pool_alloc dst_trans_allocator(
+ ctx.pool(), ggml_nelements(dst) * sizeof(float));
+ void* dst_trans_buffer = dst_trans_allocator.get();
+
+ size_t src_trans_nb[GGML_MAX_DIMS];
+ src_trans_nb[0] = sizeof(float);
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
+ }
+
+ aclTensor* acl_src_trans_tensor = ggml_cann_create_tensor(
+ src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb,
+ GGML_MAX_DIMS);
+ aclTensor* acl_dst_trans_tensor = ggml_cann_create_tensor(
+ dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb,
+ GGML_MAX_DIMS);
+
+ aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT);
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor,
+ acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode,
+ acl_dst_trans_tensor);
+
+ aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16);
+
+ ggml_cann_release_resources(ctx, acl_src_trans_tensor,
+ acl_dst_trans_tensor);
+ break;
+ }
+ default:
+ GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE");
+ break;
+ }
+ ggml_cann_release_resources(ctx, acl_cos_reshape_tensor,
+ acl_sin_reshape_tensor, acl_src, acl_dst);
+ }
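Two details of the rewritten GGML_OP_ROPE path are worth noting: ggml's mode 0 is remapped to the ACL kernel's mode 1, and the F16 case is computed by widening to F32 scratch buffers, running RotaryPositionEmbedding there, and narrowing the result back to F16. The standalone sketch below illustrates only that widen/compute/narrow pattern, with int8 standing in for F16 since no ACL types are available outside the backend:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int8_t> src = {10, 20, 30};           // narrow source ("F16")
        std::vector<float> wide(src.begin(), src.end());  // cast up ("F32" work buffer)
        for (float& x : wide) x *= 0.5f;                  // run the kernel in the wide type
        std::vector<int8_t> dst(src.size());
        for (size_t i = 0; i < wide.size(); ++i)
            dst[i] = static_cast<int8_t>(wide[i]);        // cast back down
        std::printf("%d %d %d\n", dst[0], dst[1], dst[2]); // 5 10 15
    }

The two pool allocations in the F16 branch above are exactly these two temporary wide buffers.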
+
+
+ void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+ ggml_tensor * src0 = dst->src[0];
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src0);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3);
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src, 3, false, acl_dst);
+
+ ggml_cann_release_resources(ctx, acl_src, acl_dst);
+ }
+
+ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+ ggml_tensor * src0 = dst->src[0];
+ ggml_tensor * src1 = dst->src[1];
+
+ // stride
+ int64_t s0 = ((const int32_t*)(dst->op_params))[0];
+
+ aclTensor* acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
+ aclTensor* acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
+
+ int64_t strideVal[1];
+ strideVal[0] = s0;
+ aclIntArray *stride = aclCreateIntArray(strideVal, 1);
+ int64_t paddingVal[] = {0};
+ aclIntArray *padding = aclCreateIntArray(paddingVal, 1);
+ int64_t dilationVal[] = {1};
+ aclIntArray *dilation = aclCreateIntArray(dilationVal, 1);
+ bool transposed = true;
+ int64_t groups = 1;
+ int8_t cubeMathType = 0;
+
+ #ifdef ASCEND_310P
+ cubeMathType = 1;
+ #endif
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride,
+ padding, dilation, transposed, padding, groups, acl_dst, cubeMathType);
+
+ ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation);
+ }
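For reference, and using the standard transposed-convolution arithmetic rather than anything stated in this diff: with padding = 0, dilation = 1, and groups = 1 fixed as above, the output length of a 1-D transposed convolution reduces to (L_in - 1) * stride + kernel. A standalone check of the general formula:

    #include <cstdio>

    // Standard formula: L_out = (L_in - 1)*stride - 2*padding + dilation*(kernel - 1) + 1
    static long conv_transpose_1d_out_len(long l_in, long kernel, long stride,
                                          long padding, long dilation) {
        return (l_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + 1;
    }

    int main() {
        // With padding = 0 and dilation = 1, as in the code above:
        std::printf("%ld\n", conv_transpose_1d_out_len(8, 3, 2, 0, 1)); // 17
    }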
+
+ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+ ggml_tensor * src0 = dst->src[0];
+
+ aclTensor* acl_input = ggml_cann_create_tensor(src0);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ float alphaValue = 1.0f;
+ aclScalar* alpha = nullptr;
+ alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, Elu, acl_input, alpha, alpha, alpha,
+ acl_dst);
+
+ ggml_cann_release_resources(ctx, acl_input, acl_dst, alpha);
+ }
+
+ void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+ ggml_tensor * src0 = dst->src[0];
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src0);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ int64_t reduceDimValue[] = {3};
+ aclIntArray* reduceDim = aclCreateIntArray(reduceDimValue, 1);
+ bool keepDim = true;
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst);
+
+ ggml_cann_release_resources(ctx, acl_src, acl_dst, reduceDim);
+ }
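Mean here reduces over a single axis (reduceDim = {3}) while keeping that axis in the output shape (keepDim = true). A minimal CPU restatement of a keep-dim mean, analogous but reducing the rows of a small row-major matrix:

    #include <cstdio>

    int main() {
        const int rows = 2, cols = 3;
        float m[rows][cols] = {{1, 2, 3}, {4, 5, 6}};
        float mean[1][cols] = {};                 // reduced axis kept with size 1
        for (int c = 0; c < cols; ++c) {
            for (int r = 0; r < rows; ++r) mean[0][c] += m[r][c];
            mean[0][c] /= rows;
            std::printf("%.1f ", mean[0][c]);     // 2.5 3.5 4.5
        }
    }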
+
+ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+ ggml_tensor * src0 = dst->src[0];
+ int32_t *opts = (int32_t *) dst->op_params;
+ int64_t paddingsArray[2] = {opts[0], opts[1]};
+ aclIntArray* paddings = aclCreateIntArray(paddingsArray, 2);
+
+ for (int64_t i = 0; i < src0->ne[3]; i++) {
+ aclTensor* acl_src = ggml_cann_create_tensor(
+ (char*)src0->data + i * src0->ne[3],
+ ggml_cann_type_mapping(src0->type), ggml_element_size(src0),
+ src0->ne, src0->nb, 3);
+
+ aclTensor* acl_dst = ggml_cann_create_tensor(
+ (char*)dst->data + i * src0->ne[3],
+ ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
+ dst->ne, dst->nb, 3);
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src, paddings, acl_dst);
+
+ ggml_cann_release_resources(ctx, acl_src, acl_dst);
+ }
+ ggml_cann_release_resources(ctx, paddings);
+ }
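Reflection padding mirrors interior elements around each edge without repeating the edge element itself. A standalone CPU sketch of the same padding rule (an illustration of what ReflectionPad1d computes, not the ACL call; it assumes each pad width is smaller than the row length):

    #include <cstdio>
    #include <vector>

    static std::vector<int> reflect_pad_1d(const std::vector<int>& x, int left, int right) {
        std::vector<int> out;
        for (int i = left; i >= 1; --i) out.push_back(x[i]);          // mirrored head, edge excluded
        out.insert(out.end(), x.begin(), x.end());                    // original row
        int n = (int)x.size();
        for (int i = 1; i <= right; ++i) out.push_back(x[n - 1 - i]); // mirrored tail, edge excluded
        return out;
    }

    int main() {
        for (int v : reflect_pad_1d({1, 2, 3, 4}, 2, 1)) std::printf("%d ", v); // 3 2 1 2 3 4 3
    }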
+
+ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+ ggml_tensor * src0 = dst->src[0];
+ ggml_tensor * src1 = dst->src[1];
+
+ aclTensor* acl_self = ggml_cann_create_tensor(src0);
+ aclTensor* acl_other = ggml_cann_create_tensor(src1);
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self, acl_other);
+
+ ggml_cann_sum(ctx, dst);
+
+ ggml_cann_release_resources(ctx, acl_self, acl_other);
+ }
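count_equal is implemented above as an in-place element-wise equality (writing a 0/1 mask into src0) followed by ggml_cann_sum over the mask. The same two steps restated on the CPU, fused into one loop:

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> a = {1, 2, 3, 4};
        std::vector<int> b = {1, 0, 3, 0};
        long count = 0;
        for (size_t i = 0; i < a.size(); ++i)
            count += (a[i] == b[i]);  // 0/1 mask and sum reduction in one pass
        std::printf("%ld\n", count);  // 2
    }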
+
+ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+ ggml_tensor * src0 = dst->src[0];
+
+ aclTensor* acl_src = ggml_cann_create_tensor(src0);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+ float alphaValue = 0.0f;
+ aclScalar* alpha = nullptr;
+ alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src, alpha, acl_dst);
+
+ ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
+ }
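The step activation is expressed above with a single greater-than-scalar comparison against 0, i.e. step(x) = 1 for x > 0 and 0 otherwise. A one-loop CPU restatement:

    #include <cstdio>

    int main() {
        float xs[] = {-1.5f, 0.0f, 0.7f};
        for (float x : xs)
            std::printf("step(%.1f) = %d\n", x, x > 0.0f ? 1 : 0); // 0, 0, 1
    }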
+
+ /**
+ * @brief Performs expert-specific matrix multiplication (MoE) with
+ * floating-point precision using the CANN backend.
+ *
+ * This function executes a matrix multiplication operation tailored for
+ * Mixture of Experts (MoE) models, where the input tensor is multiplied
+ * with expert-specific weight matrices. It uses the CANN backend for
+ * efficient computation and stores the result in the destination tensor `dst`.
+ * The operation may leverage identity-based optimizations or routing masks
+ * as part of sparse expert selection.
+ *
+ * @param ctx The context for executing CANN backend operations.
+ * @param dst The destination tensor where the MoE multiplication result
+ * will be stored.
+ *
+ * @note This function assumes floating-point data types and is designed for
+ * MoE architectures, possibly involving sparse expert routing.
+ */
+ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ //dst [M, K, N, 1]
+ ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
+ ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
+ ggml_tensor * ids = dst->src[2]; //ids [K, N]
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ // copy index from npu to cpu
+ int64_t n_as = ne02; // A
+ int64_t n_ids = ids->ne[0]; // K
+
+ std::vector<char> ids_host(ggml_nbytes(ids));
+ ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
+ ACL_MEMCPY_DEVICE_TO_HOST);
+ ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+ char * src0_original = (char *) src0->data;
+ char * src1_original = (char *) src1->data;
+ char * dst_original = (char *) dst->data;
+ size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03};
+
+ // src0 is F16, src1 is F32, dst is F32
+ ggml_cann_pool_alloc src0_cast_allocator;
  if (src0->type == GGML_TYPE_F16) {
- ggml_cann_pool_alloc sin_final_allocator(
- ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type));
- ggml_cann_pool_alloc cos_final_allocator(
- ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type));
- void* sin_final_buffer = sin_final_allocator.get();
- void* cos_final_buffer = cos_final_allocator.get();
-
- int64_t sin_final_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
- size_t sin_final_nb[GGML_MAX_DIMS];
- sin_final_nb[0] = ggml_type_size(src0->type);
+ src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0));
+ void* src0_cast_buf = src0_cast_allocator.get();
+
+ size_t cast_nb[GGML_MAX_DIMS];
+ cast_nb[0] = sizeof(float_t);
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
- sin_final_nb[i] = sin_final_nb[i - 1] * sin_final_ne[i - 1];
+ cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1];
  }
- aclTensor* acl_sin_final_tensor = ggml_cann_create_tensor(
- sin_final_buffer, ggml_cann_type_mapping(src0->type),
- ggml_type_size(src0->type), sin_final_ne, sin_final_nb,
- GGML_MAX_DIMS);
- aclTensor* acl_cos_final_tensor = ggml_cann_create_tensor(
- cos_final_buffer, ggml_cann_type_mapping(src0->type),
- ggml_type_size(src0->type), sin_final_ne, sin_final_nb,
- GGML_MAX_DIMS);
 
- aclnn_cast(ctx, acl_sin_reshape_tensor, acl_sin_final_tensor,
- ggml_cann_type_mapping(src0->type));
- aclnn_cast(ctx, acl_cos_reshape_tensor, acl_cos_final_tensor,
- ggml_cann_type_mapping(src0->type));
- ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
- ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
- acl_sin_reshape_tensor = acl_sin_final_tensor;
- acl_cos_reshape_tensor = acl_cos_final_tensor;
+ aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0);
+ aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf,
+ ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4);
+ GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast);
+ ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16);
+
+ src0_original = (char *) src0_cast_buf;
+ memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb));
+ }
+
+ std::vector<aclTensor*> src0_tensor_vec;
+ std::vector<aclTensor*> src1_tensor_vec;
+ std::vector<aclTensor*> dst_tensor_vec;
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+ for (int64_t id = 0; id < n_ids; id++) {
+ // src0_row [M, D] -> weight && permute
+ int64_t src0_ne[2] = {ne01, ne00};
+ size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
+ // src1_row [D, 1] -> input
+ int64_t src1_ne[2] = {ne10, 1};
+ size_t src1_nb[2] = {nb10, nb11};
+ // dst_row [M, 1] -> out
+ int64_t dst_ne[2] = {ne0, 1};
+ size_t dst_nb[2] = {nb0, nb1};
+
+ // expert index
+ int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+ // If B = 1 (broadcast), always use 0; otherwise, use id.
+ int64_t i11 = (ne11 == 1 ? 0 : id);
+ int64_t i12 = iid1;
+
+ int64_t i1 = id;
+ int64_t i2 = i12;
+
+ void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
+ void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
+ void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
+
+ aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
+ ACL_FLOAT, sizeof(float),
+ src0_ne, src0_nb, 2);
+ aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
+ ACL_FLOAT, sizeof(float),
+ src1_ne, src1_nb, 2);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr,
+ ACL_FLOAT, sizeof(float),
+ dst_ne, dst_nb, 2);
+
+ src0_tensor_vec.push_back(acl_src0);
+ src1_tensor_vec.push_back(acl_src1);
+ dst_tensor_vec.push_back(acl_dst);
+ }
+ }
+
+ size_t GROUP_SIZE = 128;
+ // GroupedMatmulV2 requires tensor_list.size < 128
+ for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
+ // split and call GroupedMatmulV2
+ size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
+ std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
+ std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
+ std::vector<aclTensor*> dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end);
+
+ aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size());
+ aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size());
+ aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size());
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list,
+ nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
+
+ ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
+ }
+ return;
+ }
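The comment above says GroupedMatmulV2 accepts fewer than 128 tensors per list, so the per-(token, expert) tensor vectors are consumed in windows of GROUP_SIZE, one grouped-matmul call per window. The windowing logic in isolation, with plain ints standing in for the aclTensor* vectors:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        const size_t GROUP_SIZE = 128;
        std::vector<int> work(300);  // stand-in for the per-row tensor vectors
        for (size_t i = 0; i < work.size(); i += GROUP_SIZE) {
            size_t end = std::min(i + GROUP_SIZE, work.size());
            std::printf("batch [%zu, %zu)\n", i, end);  // [0,128) [128,256) [256,300)
        }
    }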
+
+ /**
+ * @brief Performs expert-specific matrix multiplication (MoE) with
+ * quantized precision using the CANN backend.
+ *
+ * This function executes a matrix multiplication operation tailored for
+ * Mixture of Experts (MoE) models, where the input tensor is multiplied
+ * with expert-specific quantized weight matrices. It leverages the CANN
+ * backend to perform efficient low-precision computations and stores the
+ * quantized result in the destination tensor `dst`.
+ *
+ * Quantization techniques reduce memory footprint and improve performance
+ * by using lower-bit representations (e.g., int8) instead of floating-point.
+ * This function is designed to work with such formats and may incorporate
+ * optimizations like identity-based fast paths or routing masks for sparse
+ * expert selection.
+ *
+ * @param ctx The context for executing CANN backend operations.
+ * @param dst The destination tensor where the quantized MoE multiplication result
+ * will be stored.
+ *
+ * @note This function assumes quantized data types and is designed for
+ * MoE architectures with potential sparse expert routing.
+ */
+ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ // TODO: Use aclnnGroupedMatMul
+ //dst [M, K, N, 1]
+ ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
+ ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
+ ggml_tensor * ids = dst->src[2]; //ids [K, N]
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ // copy index from npu to cpu
+ int64_t n_as = ne02; // A
+ int64_t n_ids = ids->ne[0]; // K
+
+ std::vector<char> ids_host(ggml_nbytes(ids));
+ ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
+ ACL_MEMCPY_DEVICE_TO_HOST);
+ ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+ char * src0_original = (char *) src0->data;
+ char * src1_original = (char *) src1->data;
+ char * dst_original = (char *) dst->data;
+
+ ggml_tensor src0_row = *src0;
+ ggml_tensor src1_row = *src1;
+ ggml_tensor dst_row = *dst;
+
+ const enum ggml_type type = dst->src[0]->type;
+ float weight_elem_size;
+ if (type == GGML_TYPE_Q4_0) {
+ weight_elem_size = float(sizeof(uint8_t)) / 2;
+ } else if (type == GGML_TYPE_Q8_0) {
+ weight_elem_size = float(sizeof(uint8_t));
+ } else {
+ GGML_ABORT("MUL_MAT_ID only supports quant types Q4_0 and Q8_0");
  }
 
- uint64_t workspaceSize = 0;
- aclOpExecutor* executor;
+ // src0_row [D, M, 1, 1] weight without permute
+ src0_row.ne[2] = 1;
+ src0_row.ne[3] = 1;
+ src0_row.nb[0] = weight_elem_size;
+ src0_row.nb[1] = weight_elem_size * ne00;
+ src0_row.nb[2] = weight_elem_size * ne00;
+ src0_row.nb[3] = weight_elem_size * ne00;
+ size_t weight_stride = ne00 * ne01 * weight_elem_size;
+ size_t weight_size = weight_stride * ne02 * ne03;
 
- void* workspaceAddr = nullptr;
+ // scale [D, M, 1, 1] -> scale && permute
+ size_t scale_elem_size = sizeof(uint16_t);
+ size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
 
- int acl_mode = mode;
- if (mode == 0) {
- acl_mode = 1;
+ // src1_row [D, 1, 1, 1] -> input
+ src1_row.ne[1] = 1;
+ src1_row.ne[2] = 1;
+ src1_row.ne[3] = 1;
+ src1_row.nb[2] = nb11;
+ src1_row.nb[3] = nb11;
+
+ // dst_row [M, 1, 1, 1] -> out
+ dst_row.ne[1] = 1;
+ dst_row.ne[2] = 1;
+ dst_row.ne[3] = 1;
+ dst_row.nb[2] = nb1;
+ dst_row.nb[3] = nb1;
+
+ //create weight for one row
+ ggml_cann_pool_alloc weight_allocator(ctx.pool());
+ void* weight_buffer = weight_allocator.alloc(nb02);
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+ for (int64_t id = 0; id < n_ids; id++) {
+ // expert index
+ int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+ // If B = 1 (broadcast), always use 0; otherwise, use id.
+ int64_t i11 = (ne11 == 1 ? 0 : id);
+ int64_t i12 = iid1;
+
+ int64_t i1 = id;
+ int64_t i2 = i12;
+
+ void* src0_tmp_ptr = src0_original + i02*weight_stride;
+ void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride;
+ void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
+ void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
+
+ // mem cpy
+ ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride,
+ ACL_MEMCPY_DEVICE_TO_DEVICE);
+ void* scale_buffer = (char*)weight_buffer + weight_stride;
+ ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride,
+ ACL_MEMCPY_DEVICE_TO_DEVICE);
+
+ src0_row.data = weight_buffer;
+ src1_row.data = src1_tmp_ptr;
+ dst_row.data = dst_tmp_ptr;
+ dst_row.src[0] = &src0_row;
+ dst_row.src[1] = &src1_row;
+
+ ggml_cann_mul_mat(ctx, &dst_row);
+ }
  }
+ return;
+ }
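The quantized path walks the same routing table as the floating-point path: for each (token, slot) pair, ids names an expert, whose weight and scale blocks are then staged into a scratch buffer and multiplied through ggml_cann_mul_mat. A CPU sketch of just that index walk (the ids values below are invented for illustration):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
        const int n_as = 4, n_ids = 2, n_tokens = 3;
        std::vector<int> ids = {0, 2,  1, 1,  3, 0};  // [n_tokens][n_ids]
        for (int t = 0; t < n_tokens; ++t) {
            for (int s = 0; s < n_ids; ++s) {
                int expert = ids[t * n_ids + s];
                assert(expert >= 0 && expert < n_as);  // mirrors the GGML_ASSERT above
                std::printf("token %d slot %d -> expert %d\n", t, s, expert);
            }
        }
    }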
 
- ACL_CHECK(aclnnRotaryPositionEmbeddingGetWorkspaceSize(
- acl_src, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode,
- acl_dst, &workspaceSize, &executor));
- if (workspaceSize > 0) {
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
- workspaceAddr = workspace_allocator.get();
+ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+ const enum ggml_type type = dst->src[0]->type;
+ switch (type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ ggml_cann_mul_mat_id_fp(ctx, dst);
+ break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q8_0:
+ ggml_cann_mul_mat_id_quant(ctx, dst);
+ break;
+ default:
+ GGML_ABORT("Unsupported type for mul_mat_id");
+ break;
  }
+ }
+
+ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+
+ ggml_tensor* src0 = dst->src[0]; // q, fp32
+ ggml_tensor* src1 = dst->src[1]; // k, fp16
+ ggml_tensor* src2 = dst->src[2]; // v, fp16
+ ggml_tensor* src3 = dst->src[3]; // mask, fp16
+
+ float maxBias = 0.0f;
+ float scaleValue = 1.0f;
+ float logitSoftcap = 0.0f;
+ memcpy(&scaleValue, (float*)dst->op_params + 0, sizeof(float));
+ memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float));
+ memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float));
+
+ if(logitSoftcap == 0.0f){
+ size_t faElemSize = sizeof(uint16_t);
+ auto faDataType = ACL_FLOAT16; //ACL_BF16;
+
+ aclTensor* acl_src0_f16_tensor = nullptr;
+ aclTensor* acl_src1_f16_tensor = nullptr;
+ aclTensor* acl_src2_f16_tensor = nullptr;
+ aclTensor* acl_dst_f16_tensor = nullptr;
+
+ // Step 1: cast the src0 (Query) to fp16 if needed
+ ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
+ void* src0_f16_buffer = nullptr;
+
+ if(ggml_cann_type_mapping(src0->type) != faDataType){
+ aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+ src0_f16_buffer = src0_f16_allocator.alloc(
+ ggml_nelements(src0) * faElemSize);
+
+ int64_t* src0_f16_ne = src0->ne;
+ size_t src0_f16_nb[GGML_MAX_DIMS];
+ src0_f16_nb[0] = sizeof(uint16_t);
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
+ }
+
+ acl_src0_f16_tensor = ggml_cann_create_tensor(
+ src0_f16_buffer, faDataType, faElemSize,
+ src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
+ );
+ aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
+ ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
+ }else{
+ acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+ }
 
- ACL_CHECK(aclnnRotaryPositionEmbedding(workspaceAddr, workspaceSize,
- executor, ctx.stream()));
+ // Step 2: create the acl tensors for src1 (Key), src2 (Value),
+ // and the direct output from FusedInferAttention
 
- ACL_CHECK(aclDestroyTensor(acl_src));
- ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
- ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
- ACL_CHECK(aclDestroyTensor(acl_dst));
+ acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
+ acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+
+ ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+ void* out_f16_buffer = out_f16_allocator.alloc(
+ ggml_nelements(dst) * faElemSize);
+
+ int64_t* out_f16_ne = src0->ne;
+ size_t out_f16_nb[GGML_MAX_DIMS];
+ out_f16_nb[0] = faElemSize;
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+ }
+
+ acl_dst_f16_tensor = ggml_cann_create_tensor(
+ out_f16_buffer, faDataType, faElemSize,
+ out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+ );
+
+ // Step 3: create the PSEShift tensor if needed
+ // this tensor is considered as mask (f16) in the llama.cpp
+
+ aclTensor* bcast_pse_tensor = nullptr;
+ int64_t bcast_pse_ne[GGML_MAX_DIMS];
+ size_t bcast_pse_nb[GGML_MAX_DIMS];
+ ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
+ void* bcast_pse_buffer = nullptr;
+
+ if(src3 != nullptr){
+ bcast_pse_buffer = bcast_pse_allocator.alloc(
+ ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
+
+ if(src0->ne[1] > 1){
+ // Case 1: broadcast pse for prefill stage with multiple head
+ aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
+ bcast_pse_ne[0] = src3->ne[0];
+ bcast_pse_ne[1] = src3->ne[1];
+ bcast_pse_ne[2] = src0->ne[2];
+ bcast_pse_ne[3] = src3->ne[3];
+
+ bcast_pse_nb[0] = sizeof(uint16_t);
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+ }
+
+ bcast_pse_tensor = ggml_cann_create_tensor(
+ bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+ int64_t repeats[] = {1, src0->ne[2], 1, 1};
+ aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
+
+ ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
+ }else{
+ // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
+ int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
+ size_t* trunc_pse_nb = src3->nb;
+
+ aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+ src3->data, ACL_FLOAT16, sizeof(uint16_t),
+ trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
+
+ bcast_pse_ne[0] = src3->ne[0];
+ bcast_pse_ne[1] = src0->ne[1];
+ bcast_pse_ne[2] = src0->ne[2];
+ bcast_pse_ne[3] = src3->ne[3];
+
+ bcast_pse_nb[0] = sizeof(uint16_t);
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+ }
+
+ bcast_pse_tensor = ggml_cann_create_tensor(
+ bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+ int64_t repeats[] = {1, src0->ne[2], 1, 1};
+ aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
+
+ ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+ }
+
+ // Compute the slope if needed. Derived from ggml_cann_softmax().
+ if(maxBias != 0.0f){
+ // alibi
+ const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
+ const int64_t n_head = src0->ne[2];
+ const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
+ float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
+ float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
+ // init arange
+ ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+ ne2_ne3 * faElemSize);
+ void* tmp_arange_buffer = arange_allocator.get();
+
+ // arange1: [1, ..., n_heads_log2_floor+1)
+ float start = 1;
+ float stop = n_heads_log2_floor + 1;
+ float step = 1;
+ int64_t n_elements_arange = n_heads_log2_floor;
+
+ int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+ size_t tmp_arange1_nb[] = {faElemSize};
+ aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
+ tmp_arange_buffer, faDataType, faElemSize,
+ tmp_arange1_ne, tmp_arange1_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+ aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
+
+ aclTensor* tmp_arange2_tensor = nullptr;
+ if (n_heads_log2_floor < ne2_ne3) {
+ // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
+ start = 1;
+ stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+ step = 2;
+ n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+ int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+ size_t tmp_arange2_nb[] = {faElemSize};
+
+ aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
+ (char*)tmp_arange_buffer +
+ n_heads_log2_floor * faElemSize,
+ faDataType, faElemSize,
+ tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+ n_elements_arange);
+ }
+
+ // init mk_base
+ ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
+ ne2_ne3 * faElemSize);
+ void* tmp_mk_base_buffer = mk_base_allocator.get();
+ int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+ size_t tmp_mk_base1_nb[] = {faElemSize};
+ aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
+ tmp_mk_base_buffer, faDataType, faElemSize,
+ tmp_mk_base1_ne, tmp_mk_base1_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+ aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
+
+ aclTensor* tmp_mk_base2_tensor = nullptr;
+ if (n_heads_log2_floor < ne2_ne3) {
+ int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+ size_t tmp_mk_base2_nb[] = {faElemSize};
+ aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
+ (char*)tmp_mk_base_buffer +
+ n_heads_log2_floor * faElemSize,
+ faDataType, faElemSize,
+ tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
+ }
+
+ // init mk
+ int64_t tmp_mk_base_ne[] = {ne2_ne3};
+ size_t tmp_mk_base_nb[] = {faElemSize};
+ aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
+ tmp_mk_base_buffer, faDataType, faElemSize,
+ tmp_mk_base_ne, tmp_mk_base_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+ tmp_arange_buffer, faDataType, faElemSize,
+ tmp_mk_base_ne, tmp_mk_base_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+
+ // reshape mk
+ int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
+ size_t tmp_mk_nb[GGML_MAX_DIMS];
+ tmp_mk_nb[0] = faElemSize;
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+ }
+ aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
+ tmp_mk_base_buffer, faDataType, faElemSize,
+ tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
+ ACL_FORMAT_ND);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
+
+ ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+ tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+ tmp_arange_tensor, tmp_mk_tensor);
+ }
+ }
+
+ // Step 4: set the inputs for FusedInferAttention.
+ int kvTensorNum = 1;
+ aclTensor* acl_q_tensor = acl_src0_f16_tensor;
+ aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
+ aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
+ auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+ auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+
+ int64_t numHeads = src0->ne[2]; // N
+ int64_t numKeyValueHeads = src1->ne[2];
+ // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
+ int64_t preTokens = 65535;
+ int64_t nextTokens = 65535;
+ char layout[5] = {'B', 'N', 'S', 'D', 0};
+ int64_t sparseMode = 0;
+ int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
+ int64_t blockSize = 0;
+ int64_t antiquantMode = 0;
+ bool softmaxLseFlag = false;
+ int64_t keyAntiquantMode = 0;
+ int64_t valueAntiquantMode = 0;
+
+ // Step 5: launch the FusedInferAttentionScoreV2 kernel.
+ // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
+ acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
+ bcast_pse_tensor, nullptr, // pse, mask
+ nullptr, nullptr, // actSeqLen, actSeqLenkv
+ nullptr, nullptr, // deqScale1, quantScale1
+ nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2
+ nullptr, nullptr, // antiquantScale, antiquantOffset
+ nullptr, // blockTable
+ nullptr, nullptr, // qPadSize, kvPadSize
+ nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset
+ nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset
+ nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen
+ numHeads, scaleValue, // heads, scaleValue
+ preTokens, nextTokens, // preTokens, nextTokens
+ layout, // inputLayout
+ numKeyValueHeads, // numKVHeads
+ sparseMode, innerPrecise, // sparseMode, innerPrecise
+ blockSize, antiquantMode, // blockSize, antiquantMode
+ softmaxLseFlag, // softmaxLseFlag
+ keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
+ acl_dst_f16_tensor, // attentionOut
+ nullptr // softmaxLse
+ );
+
+ // Step 6: post-processing, permute and cast to f32
+
+ int64_t new_dim[] = {0, 2, 1, 3};
+ aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+ if(ggml_cann_type_mapping(dst->type) != faDataType){
+ ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
+ perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
+ void* perm_out_f16_buffer = perm_out_f16_allocator.get();
+
+ int64_t* perm_out_f16_ne = dst->ne;
+ size_t perm_out_f16_nb[GGML_MAX_DIMS];
+ perm_out_f16_nb[0] = faElemSize;
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
+ }
+ aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
+ perm_out_f16_buffer, faDataType, faElemSize,
+ perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
+ aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
+ aclnn_cast(ctx,
+ acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
+ ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
+ }else{
+ // only need to permute
+ aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
+ }
+ ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
+ acl_src1_f16_tensor,
+ acl_src2_f16_tensor,
+ acl_dst_f16_tensor,
+ acl_dst_tensor);
+ if(src3 != nullptr){
+ ggml_cann_release_resources(ctx, bcast_pse_tensor);
+ }
+ }else{
+ GGML_ABORT("Function is not implemented.");
+ }
  }
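When maxBias != 0 the mask branch above builds ALiBi slopes the same way ggml's softmax path does: heads up to the largest power of two n_heads_log2_floor get slopes m0^1 .. m0^k, and any remaining heads get m1 raised to the odd numbers 1, 3, 5, ... (the step-2 arange). A standalone recomputation of that schedule, independent of the ACL tensor plumbing:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int n_head = 12;
        const float max_bias = 8.0f;
        const int k = 1 << (int)std::floor(std::log2((double)n_head)); // n_heads_log2_floor
        const float m0 = std::pow(2.0f, -max_bias / k);
        const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / k);
        for (int h = 0; h < n_head; ++h) {
            float slope = (h < k) ? std::pow(m0, (float)(h + 1))          // m0^1 .. m0^k
                                  : std::pow(m1, (float)(2 * (h - k) + 1)); // m1^1, m1^3, ...
            std::printf("head %2d slope %.6f\n", h, slope);
        }
    }

In the backend code the two aclnn_arange calls generate the exponents, aclnn_fill_scalar seeds m0/m1, and aclnn_pow_tensor_tensor produces the per-head slopes that are multiplied into the broadcast PSE tensor.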