whispercpp 1.3.4 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (891)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +158 -44
  4. data/ext/extconf.rb +3 -2
  5. data/ext/ruby_whisper.c +34 -6
  6. data/ext/ruby_whisper.h +67 -0
  7. data/ext/ruby_whisper_context.c +236 -144
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +12 -13
  10. data/ext/ruby_whisper_params.c +47 -24
  11. data/ext/ruby_whisper_segment.c +84 -20
  12. data/ext/ruby_whisper_token.c +371 -0
  13. data/ext/ruby_whisper_transcribe.cpp +5 -2
  14. data/ext/ruby_whisper_vad_context.c +122 -0
  15. data/ext/ruby_whisper_vad_context_detect.cpp +51 -0
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +138 -0
  18. data/ext/ruby_whisper_vad_segments.c +105 -0
  19. data/ext/sources/CMakeLists.txt +4 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  22. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  23. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  24. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  25. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  26. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  27. data/ext/sources/examples/bench/bench.cpp +23 -18
  28. data/ext/sources/examples/cli/cli.cpp +129 -112
  29. data/ext/sources/examples/common-ggml.cpp +2 -0
  30. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  31. data/ext/sources/examples/miniaudio.h +4507 -2131
  32. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  33. data/ext/sources/examples/server/server.cpp +28 -15
  34. data/ext/sources/examples/talk-llama/CMakeLists.txt +8 -3
  35. data/ext/sources/examples/talk-llama/llama-adapter.cpp +5 -2
  36. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -0
  37. data/ext/sources/examples/talk-llama/llama-arch.cpp +2378 -1988
  38. data/ext/sources/examples/talk-llama/llama-arch.h +109 -2
  39. data/ext/sources/examples/talk-llama/llama-batch.cpp +78 -34
  40. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  41. data/ext/sources/examples/talk-llama/llama-chat.cpp +100 -4
  42. data/ext/sources/examples/talk-llama/llama-chat.h +5 -0
  43. data/ext/sources/examples/talk-llama/llama-context.cpp +1088 -403
  44. data/ext/sources/examples/talk-llama/llama-context.h +70 -23
  45. data/ext/sources/examples/talk-llama/llama-cparams.h +6 -0
  46. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  47. data/ext/sources/examples/talk-llama/llama-grammar.cpp +295 -60
  48. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  49. data/ext/sources/examples/talk-llama/llama-graph.cpp +925 -155
  50. data/ext/sources/examples/talk-llama/llama-graph.h +234 -23
  51. data/ext/sources/examples/talk-llama/llama-hparams.cpp +79 -38
  52. data/ext/sources/examples/talk-llama/llama-hparams.h +118 -18
  53. data/ext/sources/examples/talk-llama/llama-impl.cpp +11 -7
  54. data/ext/sources/examples/talk-llama/llama-impl.h +14 -2
  55. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +8 -4
  56. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +405 -140
  57. data/ext/sources/examples/talk-llama/llama-kv-cache.h +24 -10
  58. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  59. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  60. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  61. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  62. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +42 -31
  63. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  64. data/ext/sources/examples/talk-llama/llama-mmap.cpp +197 -45
  65. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  66. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +606 -116
  67. data/ext/sources/examples/talk-llama/llama-model-loader.h +41 -5
  68. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +61 -44
  69. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  70. data/ext/sources/examples/talk-llama/llama-model.cpp +2756 -13643
  71. data/ext/sources/examples/talk-llama/llama-model.h +112 -18
  72. data/ext/sources/examples/talk-llama/llama-quant.cpp +582 -365
  73. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +1409 -199
  74. data/ext/sources/examples/talk-llama/llama-sampler.h +42 -0
  75. data/ext/sources/examples/talk-llama/llama-vocab.cpp +248 -82
  76. data/ext/sources/examples/talk-llama/llama-vocab.h +50 -40
  77. data/ext/sources/examples/talk-llama/llama.cpp +802 -21
  78. data/ext/sources/examples/talk-llama/llama.h +210 -39
  79. data/ext/sources/examples/talk-llama/models/afmoe.cpp +190 -0
  80. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  81. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  82. data/ext/sources/examples/talk-llama/models/arctic.cpp +137 -0
  83. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  84. data/ext/sources/examples/talk-llama/models/baichuan.cpp +123 -0
  85. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +143 -0
  86. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +133 -0
  87. data/ext/sources/examples/talk-llama/models/bert.cpp +184 -0
  88. data/ext/sources/examples/talk-llama/models/bitnet.cpp +145 -0
  89. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  90. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  91. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  92. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  93. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  94. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  95. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  96. data/ext/sources/examples/talk-llama/models/dbrx.cpp +122 -0
  97. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  98. data/ext/sources/examples/talk-llama/models/deepseek.cpp +142 -0
  99. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +262 -0
  100. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  101. data/ext/sources/examples/talk-llama/models/dots1.cpp +132 -0
  102. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  103. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +148 -0
  104. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  105. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  106. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  107. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  108. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  109. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +111 -0
  110. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  111. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  112. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  113. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  114. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  115. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  116. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  117. data/ext/sources/examples/talk-llama/models/glm4.cpp +157 -0
  118. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  119. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  120. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +195 -0
  121. data/ext/sources/examples/talk-llama/models/granite.cpp +210 -0
  122. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  123. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +139 -0
  124. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  125. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +153 -0
  126. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  127. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  128. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  129. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  130. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  131. data/ext/sources/examples/talk-llama/models/lfm2.cpp +196 -0
  132. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  133. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  134. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  135. data/ext/sources/examples/talk-llama/models/llama.cpp +175 -0
  136. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  137. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +289 -0
  138. data/ext/sources/examples/talk-llama/models/mamba.cpp +54 -0
  139. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +129 -0
  140. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +200 -0
  141. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +123 -0
  142. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  143. data/ext/sources/examples/talk-llama/models/models.h +704 -0
  144. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +109 -0
  145. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  146. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +162 -0
  147. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  148. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  149. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  150. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  151. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  153. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  154. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  155. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  156. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  157. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  158. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  159. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  160. data/ext/sources/examples/talk-llama/models/plamo2.cpp +320 -0
  161. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  162. data/ext/sources/examples/talk-llama/models/plm.cpp +169 -0
  163. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  164. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  165. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  166. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  167. data/ext/sources/examples/talk-llama/models/qwen3.cpp +120 -0
  168. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  169. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  170. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +131 -0
  171. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +525 -0
  172. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +140 -0
  173. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +132 -0
  174. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  175. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  176. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +164 -0
  177. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  178. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  179. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +137 -0
  180. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  181. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  182. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  183. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  184. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  185. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  186. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  187. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  188. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  189. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  190. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  191. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  192. data/ext/sources/examples/talk-llama/unicode.cpp +121 -79
  193. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  194. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  195. data/ext/sources/ggml/CMakeLists.txt +90 -56
  196. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  197. data/ext/sources/ggml/include/ggml-backend.h +5 -2
  198. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  199. data/ext/sources/ggml/include/ggml-cpu.h +6 -0
  200. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  201. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  202. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  203. data/ext/sources/ggml/include/ggml-rpc.h +14 -12
  204. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  205. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  206. data/ext/sources/ggml/include/ggml.h +246 -21
  207. data/ext/sources/ggml/src/CMakeLists.txt +85 -11
  208. data/ext/sources/ggml/src/ggml-alloc.c +128 -50
  209. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  210. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  211. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  212. data/ext/sources/ggml/src/ggml-backend-reg.cpp +54 -88
  213. data/ext/sources/ggml/src/ggml-backend.cpp +76 -23
  214. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +18 -4
  215. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +11 -11
  216. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +58 -46
  217. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +139 -48
  218. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2427 -1785
  219. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -362
  220. data/ext/sources/ggml/src/ggml-cann/common.h +285 -211
  221. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +663 -831
  222. data/ext/sources/ggml/src/ggml-common.h +11 -0
  223. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +170 -95
  224. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -18
  225. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  226. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  227. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  228. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +513 -27
  229. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +4192 -992
  230. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  231. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  232. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1761 -49
  233. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  234. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  235. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  236. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  237. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +124 -24
  238. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +157 -28
  239. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  240. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  241. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +8 -3
  242. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +251 -80
  243. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +19 -0
  244. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +587 -119
  245. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  246. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +1093 -194
  247. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1284 -203
  248. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  249. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1519 -527
  250. data/ext/sources/ggml/src/ggml-cpu/ops.h +6 -4
  251. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  252. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  253. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3632 -781
  254. data/ext/sources/ggml/src/ggml-cpu/repack.h +129 -4
  255. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  256. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +152 -46
  257. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  258. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +152 -1
  259. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  260. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +140 -0
  261. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  262. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  263. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  264. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +132 -6
  265. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  266. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +33 -31
  267. data/ext/sources/ggml/src/ggml-cuda/common.cuh +474 -85
  268. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  269. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  270. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  271. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +342 -246
  272. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  273. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  274. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  275. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  276. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +98 -74
  278. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +973 -665
  279. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  280. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1255 -0
  281. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +33 -40
  282. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +40 -18
  283. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  284. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +206 -45
  285. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  286. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  287. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  288. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  289. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1688 -302
  290. data/ext/sources/ggml/src/ggml-cuda/mean.cu +12 -10
  291. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +908 -48
  292. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +88 -20
  293. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +502 -90
  294. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  295. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  296. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  297. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +532 -193
  298. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +460 -104
  299. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +5 -2
  300. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +360 -122
  301. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +2 -1
  302. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  303. data/ext/sources/ggml/src/ggml-cuda/pad.cu +73 -39
  304. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +152 -1
  305. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  306. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  307. data/ext/sources/ggml/src/ggml-cuda/rope.cu +364 -149
  308. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  309. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  310. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  311. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  312. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +163 -41
  313. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  314. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  315. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +68 -50
  316. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  317. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  318. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  319. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  320. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  321. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  322. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  323. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  324. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  325. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  326. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  327. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  328. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  329. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  330. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  331. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  332. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  333. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +22 -4
  334. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +95 -0
  335. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  336. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +275 -119
  337. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -7
  338. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  339. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  340. data/ext/sources/ggml/src/ggml-cuda/unary.cu +160 -11
  341. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +38 -0
  342. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  343. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  344. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  345. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -1
  346. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  347. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +117 -0
  348. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3325 -0
  349. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +46 -0
  350. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +813 -0
  351. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  352. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +891 -0
  353. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  354. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  355. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +713 -0
  356. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  357. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +63 -0
  358. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +182 -0
  359. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  360. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  361. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  362. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  363. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +155 -0
  364. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +63 -0
  365. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  366. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  367. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  368. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  369. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  370. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  371. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  372. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  373. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  374. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  375. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  376. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  377. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  378. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  379. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +26 -0
  380. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1199 -0
  381. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2670 -0
  382. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +497 -0
  383. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  384. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +419 -0
  385. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  386. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  387. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +382 -0
  388. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +293 -0
  389. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  390. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  391. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  392. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  393. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  394. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  395. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +14 -13
  396. data/ext/sources/ggml/src/ggml-impl.h +129 -6
  397. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  398. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +15 -4
  399. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  400. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +173 -34
  401. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +912 -344
  402. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +124 -59
  403. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +588 -144
  404. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +396 -23
  405. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1724 -421
  406. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +16 -3
  407. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +333 -114
  408. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +3050 -1539
  409. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  410. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +30 -1
  411. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +4279 -497
  412. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  413. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  414. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  415. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +267 -0
  416. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  417. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  418. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +113 -0
  419. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  420. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  421. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  422. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  423. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  424. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  425. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  426. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  427. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  428. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +140 -0
  429. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  430. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  431. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  432. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  433. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  434. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  435. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  436. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  437. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  438. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  439. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  440. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  441. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  442. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  443. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  444. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  445. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  446. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  447. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  448. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  449. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +116 -0
  450. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  451. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  452. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  453. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  454. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  455. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  456. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +39 -0
  457. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  458. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  459. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  460. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  461. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  462. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  463. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  464. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  465. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  466. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  467. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  468. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  469. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  470. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  471. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  472. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  473. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  474. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  475. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  476. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  477. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  478. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  479. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  480. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  481. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  482. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  483. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  484. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  485. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  486. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  487. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  488. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  489. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  490. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  491. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  492. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  493. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  494. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  495. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  496. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  497. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  498. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  499. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  500. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  501. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  502. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  503. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  504. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  505. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  506. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  507. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  508. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +59 -87
  509. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +81 -0
  510. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  511. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +7 -0
  512. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -29
  513. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  514. data/ext/sources/ggml/src/ggml-sycl/common.hpp +427 -20
  515. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  516. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +103 -1
  517. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  518. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  519. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  520. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  521. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  522. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +867 -50
  523. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +401 -358
  524. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +12 -2
  525. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  526. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  527. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  528. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  529. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  530. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  531. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  532. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +645 -155
  534. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  535. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +221 -66
  536. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  537. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  538. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  539. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  540. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  541. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  542. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +5 -0
  543. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  544. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  545. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  547. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  548. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +457 -281
  549. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  550. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  551. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  552. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  553. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  554. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  555. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  556. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  557. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  558. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  559. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  560. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  561. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  562. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  563. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  564. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  565. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  566. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  567. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  568. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  569. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  570. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  571. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  572. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  573. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  574. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  575. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  576. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  577. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  578. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  579. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  580. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  581. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  582. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  583. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  584. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  585. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  586. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  587. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  588. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  589. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  590. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  591. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  592. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  593. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  594. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  595. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  596. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  597. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  598. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  599. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  600. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  601. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +71 -0
  602. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  603. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  604. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  605. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  606. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  607. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  608. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  609. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  610. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  611. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  612. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  613. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  614. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  615. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  616. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  617. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  618. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  619. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  620. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  621. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  622. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  623. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  624. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  625. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  626. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  627. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  628. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  629. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  630. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  631. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  632. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  633. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  634. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  635. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  636. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  637. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  638. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  639. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  640. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  641. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  642. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  643. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  644. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  645. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +39 -19
  646. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5994 -3055
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +18 -10
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  663. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  664. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  665. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  666. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  667. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  668. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  669. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  670. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  671. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  672. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  673. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  674. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  675. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  676. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  677. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  678. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  679. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  680. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  681. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  682. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  683. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  684. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  685. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  686. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  687. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  688. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  689. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  690. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  691. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  692. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  693. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  694. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  695. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  696. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  697. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  698. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  699. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  700. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  701. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  702. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  703. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +386 -160
  704. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +82 -20
  705. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +400 -174
  706. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +123 -37
  707. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  708. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +10 -9
  709. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  710. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  711. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  712. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  713. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  714. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  715. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  716. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  717. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  718. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  719. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  720. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  721. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  722. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  723. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  724. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  725. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  726. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  727. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  728. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +13 -10
  729. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  730. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  731. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  732. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  733. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +77 -29
  734. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  735. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  736. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  737. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  738. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  739. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  740. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  741. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  742. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  743. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  744. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  745. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  746. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  747. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  748. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  749. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  750. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  751. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +88 -105
  752. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +41 -26
  753. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  754. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +74 -0
  755. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +92 -230
  756. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  757. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  758. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  759. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  760. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  761. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  762. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  763. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  764. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  765. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  766. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  767. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  768. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  769. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  770. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +49 -4
  771. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  772. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  773. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  774. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +207 -0
  775. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  776. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +8 -49
  777. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +8 -32
  778. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +8 -32
  779. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +33 -0
  780. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +8 -38
  781. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  782. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  783. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  784. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  785. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  786. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  787. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  788. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  789. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  790. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  791. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  792. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  793. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  794. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  795. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  796. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  798. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +50 -0
  799. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  800. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  801. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  802. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  803. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  804. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  805. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  806. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  807. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  808. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  809. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  810. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  811. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  812. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  813. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  814. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  815. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +384 -180
  816. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  817. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  818. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1374 -0
  819. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2544 -726
  820. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  821. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  822. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  823. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  824. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  825. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  826. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  827. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +107 -0
  828. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  829. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +73 -15
  830. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +636 -0
  831. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  832. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  833. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +72 -261
  834. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +766 -0
  835. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +147 -0
  836. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +196 -0
  837. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  838. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  839. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  840. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  841. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  842. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +63 -0
  843. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  844. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  845. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  846. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  847. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  848. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +91 -0
  849. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +469 -0
  850. data/ext/sources/ggml/src/ggml.c +590 -64
  851. data/ext/sources/ggml/src/gguf.cpp +229 -44
  852. data/ext/sources/include/whisper.h +1 -0
  853. data/ext/sources/src/CMakeLists.txt +3 -1
  854. data/ext/sources/src/whisper.cpp +106 -62
  855. data/ext/sources/tests/CMakeLists.txt +2 -2
  856. data/ext/sources/tests/test-vad-full.cpp +4 -2
  857. data/ext/sources/tests/test-vad.cpp +1 -1
  858. data/extsources.rb +1 -0
  859. data/lib/whisper/model/uri.rb +17 -18
  860. data/sig/whisper.rbs +162 -4
  861. data/test/test_context_params.rb +82 -0
  862. data/test/test_params.rb +16 -8
  863. data/test/test_segment.rb +0 -1
  864. data/test/test_token.rb +81 -0
  865. data/test/test_vad.rb +1 -1
  866. data/test/test_vad_context.rb +100 -0
  867. data/test/test_vad_segment.rb +19 -0
  868. data/test/test_vad_segments.rb +16 -0
  869. data/test/test_whisper.rb +27 -0
  870. data/whispercpp.gemspec +1 -1
  871. metadata +502 -37
  872. data/ext/sources/build-xcframework.sh +0 -571
  873. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -32
  874. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  875. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  876. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  877. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  878. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  879. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  880. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  881. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  882. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  883. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  884. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  885. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  886. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  887. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  888. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  889. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  890. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  891. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -1,5 +1,5 @@
1
1
  /*
2
- * Copyright (c) 2023-2024 The ggml authors
2
+ * Copyright (c) 2023-2026 The ggml authors
3
3
  *
4
4
  * Permission is hereby granted, free of charge, to any person obtaining a copy
5
5
  * of this software and associated documentation files (the "Software"), to
@@ -22,24 +22,24 @@
22
22
 
23
23
  #include "ggml-cann.h"
24
24
 
25
+ #include "ggml-backend-impl.h"
26
+ #include "ggml-cann/aclnn_ops.h"
27
+ #include "ggml-cann/common.h"
28
+ #include "ggml-impl.h"
29
+ #include "ggml.h"
30
+
25
31
  #include <acl/acl.h>
26
- #include <stdarg.h>
27
32
  #include <aclnnop/aclnn_trans_matmul_weight.h>
33
+ #include <stdarg.h>
28
34
 
35
+ #include <chrono>
29
36
  #include <cmath>
30
37
  #include <cstdio>
31
38
  #include <cstring>
32
39
  #include <mutex>
40
+ #include <optional>
33
41
  #include <queue>
34
- #include <chrono>
35
42
  #include <unordered_set>
36
- #include <optional>
37
-
38
- #include "ggml-impl.h"
39
- #include "ggml-backend-impl.h"
40
- #include "ggml-cann/aclnn_ops.h"
41
- #include "ggml-cann/common.h"
42
- #include "ggml.h"
43
43
 
44
44
  #define GGML_COMMON_DECL_C
45
45
 
@@ -56,52 +56,52 @@
56
56
  * @param line The line number where the error occurred.
57
57
  * @param msg The error message.
58
58
  */
59
- [[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
60
- const char* file, int line, const char* msg) {
59
+ [[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
61
60
  int32_t id = -1;
62
61
  aclrtGetDevice(&id);
63
62
 
64
63
  GGML_LOG_ERROR("CANN error: %s\n", msg);
65
- GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
66
- file, line);
64
+ GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line);
67
65
  GGML_LOG_ERROR(" %s\n", stmt);
68
66
  // abort with GGML_ASSERT to get a stack trace
69
67
  GGML_ABORT("CANN error");
70
68
  }
71
69
 
70
+ // Thread-local variable to record the current device of this thread.
71
+ thread_local int g_current_cann_device = -1;
72
+
72
73
  /**
73
- * @brief Sets the device to be used by CANN.
74
+ * @brief Set the CANN device to be used.
74
75
  *
75
- * @param device The device ID to set.
76
+ * @param device The target device ID to set.
76
77
  */
77
78
  void ggml_cann_set_device(const int32_t device) {
78
- int current_device = -1;
79
- aclrtGetDevice(&current_device);
79
+ // int current_device = -1;
80
+ // Note: In some CANN versions, if no device has been set yet,
81
+ // aclrtGetDevice(&current_device) may return 0 by default.
82
+ // aclrtGetDevice(&current_device);
80
83
 
81
- if (device == current_device) {
82
- return;
84
+ // If the current device is already the target one, no need to switch.
85
+ if (device == g_current_cann_device) {
86
+ return;
83
87
  }
88
+
89
+ // Switch to the new device.
84
90
  ACL_CHECK(aclrtSetDevice(device));
85
- }
86
91
 
87
- /**
88
- * @brief Retrieves the current device ID.
89
- *
90
- * @return The current device ID.
91
- */
92
- int32_t ggml_cann_get_device() {
93
- int32_t id;
94
- ACL_CHECK(aclrtGetDevice(&id));
95
- return id;
92
+ // Update the global device record.
93
+ g_current_cann_device = device;
96
94
  }
97
95
 
98
96
  /**
99
- * @brief Get the value of the specified environment variable (name).
97
+ * @brief Get the value of the specified environment variable (name) as lowercase.
100
98
  * if not empty, return a std::string object
101
99
  */
102
- std::optional<std::string> get_env(const std::string& name) {
103
- const char* val = std::getenv(name.c_str());
104
- if (!val) return std::nullopt;
100
+ std::optional<std::string> get_env_as_lowercase(const std::string & name) {
101
+ const char * val = std::getenv(name.c_str());
102
+ if (!val) {
103
+ return std::nullopt;
104
+ }
105
105
  std::string res = std::string(val);
106
106
  std::transform(res.begin(), res.end(), res.begin(), ::tolower);
107
107
  return res;
@@ -110,8 +110,8 @@ std::optional<std::string> get_env(const std::string& name) {
110
110
  /**
111
111
  * @brief Verify whether the environment variable is a valid value.
112
112
  */
113
- bool parse_bool(const std::string& value) {
114
- std::unordered_set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"};
113
+ bool parse_bool(const std::string & value) {
114
+ static const std::unordered_set<std::string> valid_values = { "on", "1", "yes", "y", "enable", "true" };
115
115
  return valid_values.find(value) != valid_values.end();
116
116
  }
117
117
 
@@ -125,7 +125,7 @@ bool parse_bool(const std::string& value) {
125
125
  * @param value The string to parse.
126
126
  * @return The parsed integer, or 0 if conversion fails.
127
127
  */
128
- int parse_integer(const std::string& value) {
128
+ int parse_integer(const std::string & value) {
129
129
  try {
130
130
  return std::stoi(value);
131
131
  } catch (...) {
@@ -144,11 +144,10 @@ int parse_integer(const std::string& value) {
144
144
  static ggml_cann_device_info ggml_cann_init() {
145
145
  ggml_cann_device_info info = {};
146
146
 
147
- aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
147
+ aclError err = aclrtGetDeviceCount((uint32_t *) &info.device_count);
148
148
 
149
149
  if (err != ACL_SUCCESS) {
150
- GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n",
151
- __func__, aclGetRecentErrMsg());
150
+ GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n", __func__, aclGetRecentErrMsg());
152
151
  return info;
153
152
  }
154
153
 
@@ -156,16 +155,15 @@ static ggml_cann_device_info ggml_cann_init() {
156
155
 
157
156
  for (int id = 0; id < info.device_count; ++id) {
158
157
  aclrtPhysicalMemProp prop = {};
159
- prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
160
- prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
161
- prop.memAttr = ACL_HBM_MEM_HUGE;
162
- prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
163
- prop.location.id = id;
164
- prop.reserve = 0;
165
- err = aclrtMemGetAllocationGranularity(
166
- &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
167
- &info.devices[id].vmm_granularity);
168
- info.devices[id].vmm = err == ACL_SUCCESS;
158
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
159
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
160
+ prop.memAttr = ACL_HBM_MEM_HUGE;
161
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
162
+ prop.location.id = id;
163
+ prop.reserve = 0;
164
+ err = aclrtMemGetAllocationGranularity(&prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
165
+ &info.devices[id].vmm_granularity);
166
+ info.devices[id].vmm = err == ACL_SUCCESS;
169
167
 
170
168
  size_t free, total;
171
169
  ggml_backend_cann_get_device_memory(id, &free, &total);
@@ -185,7 +183,7 @@ static ggml_cann_device_info ggml_cann_init() {
185
183
  *
186
184
  * @return A reference to the structure containing the device information.
187
185
  */
188
- const ggml_cann_device_info& ggml_cann_info() {
186
+ const ggml_cann_device_info & ggml_cann_info() {
189
187
  static ggml_cann_device_info info = ggml_cann_init();
190
188
  return info;
191
189
  }
@@ -205,7 +203,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
205
203
  /**
206
204
  * @brief The minimum free margin for a buffer.
207
205
  */
208
- static const size_t min_free_margin = 1ull << 20; // 1MB
206
+ static const size_t min_free_margin = 1ull << 20; // 1MB
209
207
 
210
208
  /**
211
209
  * @brief The alignment for buffer allocation.
@@ -226,22 +224,18 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
226
224
  * @brief Structure representing a CANN buffer.
227
225
  */
228
226
  struct ggml_cann_buffer {
229
- void* ptr = nullptr; ///< Pointer to the buffer.
230
- size_t size = 0; ///< Size of the buffer.
231
- std::chrono::steady_clock::time_point last_used; ///< Last used time.
227
+ void * ptr = nullptr; ///< Pointer to the buffer.
228
+ size_t size = 0; ///< Size of the buffer.
229
+ std::chrono::steady_clock::time_point last_used; ///< Last used time.
232
230
 
233
- bool operator>(const ggml_cann_buffer& other) const {
234
- return size > other.size;
235
- }
231
+ bool operator>(const ggml_cann_buffer & other) const { return size > other.size; }
236
232
  };
237
233
 
238
234
  /**
239
235
  * @brief Array of CANN buffers in the pool.
240
236
  */
241
- std::unordered_map<void*, size_t> buffer_pool;
242
- std::priority_queue<ggml_cann_buffer,
243
- std::vector<ggml_cann_buffer>,
244
- std::greater<>> free_buffers ;
237
+ std::unordered_map<void *, size_t> buffer_pool;
238
+ std::priority_queue<ggml_cann_buffer, std::vector<ggml_cann_buffer>, std::greater<>> free_buffers;
245
239
 
246
240
  /**
247
241
  * @brief Total size of all buffers in the pool.
@@ -254,7 +248,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
254
248
  * @param device The device ID to associate with this buffer pool.
255
249
  */
256
250
  explicit ggml_cann_pool_buf_prio(int device) : device(device) {
257
- disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
251
+ disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
258
252
  }
259
253
 
260
254
  /**
@@ -262,7 +256,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
262
256
  */
263
257
  ~ggml_cann_pool_buf_prio() {
264
258
  ggml_cann_set_device(device);
265
- for (auto& [b_ptr, b_size] : buffer_pool) {
259
+ for (auto & [b_ptr, b_size] : buffer_pool) {
266
260
  aclrtFree(b_ptr);
267
261
  pool_size -= b_size;
268
262
  }
@@ -278,14 +272,14 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
278
272
  * the allocated buffer.
279
273
  * @return A pointer to the allocated buffer.
280
274
  */
281
- void* alloc(size_t size, size_t* actual_size) override {
275
+ void * alloc(size_t size, size_t * actual_size) override {
282
276
  size = GGML_PAD(size, alignment);
283
277
  if (size == 0) {
284
278
  size = alignment;
285
279
  }
286
280
 
287
- void* ptr = nullptr;
288
- auto now = std::chrono::steady_clock::now();
281
+ void * ptr = nullptr;
282
+ auto now = std::chrono::steady_clock::now();
289
283
 
290
284
  std::vector<ggml_cann_buffer> free_buffers_rest;
291
285
  free_buffers_rest.reserve(free_buffers.size());
@@ -298,24 +292,22 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
298
292
  const size_t margin = b.size - size;
299
293
  if (margin <= max_reuse_margin) {
300
294
  *actual_size = b.size;
301
- ptr = b.ptr;
295
+ ptr = b.ptr;
302
296
  #ifdef DEBUG_CANN_MALLOC
303
297
  GGML_LOG_INFO(
304
298
  "cann pool[%d]: reused %p, "
305
299
  "pool_size = %5u MB, "
306
300
  "size = %5u MB, "
307
301
  "margin = %5u MB\n",
308
- device, b.ptr,
309
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
310
- (uint32_t)(GGML_PAD(size, 1048576) / 1048576),
311
- (uint32_t)(GGML_PAD(margin, 1048576) / 1048576));
302
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
303
+ (uint32_t) (GGML_PAD(size, 1048576) / 1048576),
304
+ (uint32_t) (GGML_PAD(margin, 1048576) / 1048576));
312
305
  #endif
313
306
  break;
314
307
  }
315
308
  }
316
309
 
317
- bool should_clean = !disable_clean &&
318
- b.size > min_free_margin &&
310
+ bool should_clean = !disable_clean && b.size > min_free_margin &&
319
311
  std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
320
312
  if (should_clean) {
321
313
  // free the buffer if the size is needed to be freed
@@ -327,20 +319,20 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
327
319
  "cann pool[%d]: clean %p, "
328
320
  "pool_size = %5u MB, "
329
321
  "size = %5u MB\n",
330
- device, b.ptr,
331
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
332
- (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
322
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
323
+ (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
333
324
  #endif
334
325
  continue;
335
326
  }
336
327
  free_buffers_rest.push_back(b);
337
328
  }
338
- for (ggml_cann_buffer &b : free_buffers_rest) {
329
+ for (ggml_cann_buffer & b : free_buffers_rest) {
339
330
  free_buffers.push(std::move(b));
340
331
  }
341
332
 
342
333
  #ifdef DEBUG_CANN_MALLOC
343
- GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
334
+ GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device,
335
+ (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
344
336
  #endif
345
337
  if (ptr != nullptr) {
346
338
  return ptr;
@@ -356,8 +348,8 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
356
348
  "cann pool[%d]: allocate %p, "
357
349
  "pool_size = %5u MB, "
358
350
  "size = %5u MB\n",
359
- device, ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
360
- (uint32_t)(GGML_PAD(size, 1048576) / 1048576));
351
+ device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
352
+ (uint32_t) (GGML_PAD(size, 1048576) / 1048576));
361
353
  #endif
362
354
  buffer_pool.emplace(ptr, size);
363
355
  return ptr;
@@ -369,7 +361,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
369
361
  * @param ptr Pointer to the buffer to free.
370
362
  * @param size Size of the buffer to free.
371
363
  */
372
- void free(void* ptr, size_t size) override {
364
+ void free(void * ptr, size_t size) override {
373
365
  GGML_UNUSED(size);
374
366
  auto it = buffer_pool.find(ptr);
375
367
  if (it == buffer_pool.end()) {
@@ -377,13 +369,12 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
377
369
  }
378
370
 
379
371
  auto now = std::chrono::steady_clock::now();
380
- free_buffers.emplace(ggml_cann_buffer{ptr, it->second, now});
372
+ free_buffers.emplace(ggml_cann_buffer{ ptr, it->second, now });
381
373
  #ifdef DEBUG_CANN_MALLOC
382
374
  GGML_LOG_INFO(
383
375
  "cann pool[%d]: return %p, "
384
376
  "pool_size = %5u MB\n",
385
- device, ptr,
386
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
377
+ device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
387
378
  #endif
388
379
  }
389
380
  };
@@ -402,7 +393,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
402
393
  /**
403
394
  * @brief The minimum free margin for a buffer.
404
395
  */
405
- static const size_t min_free_margin = 1ull << 20; // 1MB
396
+ static const size_t min_free_margin = 1ull << 20; // 1MB
406
397
 
407
398
  /**
408
399
  * @brief The alignment for buffer allocation.
@@ -428,10 +419,10 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
428
419
  * @brief Structure representing a CANN buffer.
429
420
  */
430
421
  struct ggml_cann_buffer {
431
- void* ptr = nullptr; ///< Pointer to the buffer memory.
432
- size_t size = 0; ///< Size of the buffer.
433
- bool used = false; ///< Whether the buffer is currently in use.
434
- std::chrono::steady_clock::time_point last_used; ///< Last used time.
422
+ void * ptr = nullptr; ///< Pointer to the buffer memory.
423
+ size_t size = 0; ///< Size of the buffer.
424
+ bool used = false; ///< Whether the buffer is currently in use.
425
+ std::chrono::steady_clock::time_point last_used; ///< Last used time.
435
426
  };
436
427
 
437
428
  /**
@@ -450,7 +441,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
450
441
  * @param device The device ID to associate with this buffer pool.
451
442
  */
452
443
  explicit ggml_cann_pool_buf(int device) : device(device) {
453
- disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
444
+ disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
454
445
  }
455
446
 
456
447
  /**
@@ -459,7 +450,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
459
450
  ~ggml_cann_pool_buf() {
460
451
  ggml_cann_set_device(device);
461
452
  for (int i = 0; i < MAX_BUFFERS; ++i) {
462
- ggml_cann_buffer& b = buffer_pool[i];
453
+ ggml_cann_buffer & b = buffer_pool[i];
463
454
  if (b.ptr != nullptr) {
464
455
  aclrtFree(b.ptr);
465
456
  pool_size -= b.size;
@@ -476,18 +467,18 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
476
467
  * the allocated buffer.
477
468
  * @return A pointer to the allocated buffer.
478
469
  */
479
- void* alloc(size_t size, size_t* actual_size) override {
470
+ void * alloc(size_t size, size_t * actual_size) override {
480
471
  size = GGML_PAD(size, alignment);
481
472
  if (size == 0) {
482
473
  size = alignment;
483
474
  }
484
475
 
485
- void* ptr = nullptr;
486
- auto now = std::chrono::steady_clock::now();
476
+ void * ptr = nullptr;
477
+ auto now = std::chrono::steady_clock::now();
487
478
 
488
479
  int i = 0;
489
480
  for (; i < MAX_BUFFERS; ++i) {
490
- ggml_cann_buffer& b = buffer_pool[i];
481
+ ggml_cann_buffer & b = buffer_pool[i];
491
482
  if (b.ptr == nullptr) {
492
483
  break;
493
484
  }
@@ -499,25 +490,23 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
499
490
  const size_t margin = b.size - size;
500
491
  if (margin <= max_reuse_margin) {
501
492
  *actual_size = b.size;
502
- b.used = true;
503
- ptr = b.ptr;
493
+ b.used = true;
494
+ ptr = b.ptr;
504
495
  #ifdef DEBUG_CANN_MALLOC
505
496
  GGML_LOG_INFO(
506
497
  "cann pool[%d]: reused %p, "
507
498
  "pool_size = %5u MB, "
508
499
  "size = %5u MB, "
509
500
  "margin = %5u MB\n",
510
- device, b.ptr,
511
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
512
- (uint32_t)(GGML_PAD(size, 1048576) / 1048576),
513
- (uint32_t)(GGML_PAD(margin, 1048576) / 1048576));
501
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
502
+ (uint32_t) (GGML_PAD(size, 1048576) / 1048576),
503
+ (uint32_t) (GGML_PAD(margin, 1048576) / 1048576));
514
504
  #endif
515
505
  break;
516
506
  }
517
507
  }
518
508
 
519
- bool should_clean = !disable_clean &&
520
- b.size > min_free_margin &&
509
+ bool should_clean = !disable_clean && b.size > min_free_margin &&
521
510
  std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
522
511
  if (should_clean) {
523
512
  // free the buffer if the size is needed to be freed
@@ -528,9 +517,8 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
528
517
  "cann pool[%d]: clean %p, "
529
518
  "pool_size = %5u MB, "
530
519
  "size = %5u MB\n",
531
- device, b.ptr,
532
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
533
- (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
520
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
521
+ (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
534
522
  #endif
535
523
  b.ptr = nullptr;
536
524
  }
@@ -541,13 +529,13 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
541
529
 
542
530
  if (i < MAX_BUFFERS) {
543
531
  // allocate a new buffer if no buffer can be reused
544
- ggml_cann_buffer& b = buffer_pool[i];
532
+ ggml_cann_buffer & b = buffer_pool[i];
545
533
  ggml_cann_set_device(device);
546
534
  ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
547
535
  pool_size += size;
548
536
  *actual_size = size;
549
- b.size = size;
550
- b.used = true;
537
+ b.size = size;
538
+ b.used = true;
551
539
  if (i >= MAX_BUFFERS - 8) {
552
540
  GGML_LOG_WARN("cann pool[%d]: slots almost full\n", device);
553
541
  }
@@ -556,9 +544,8 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
556
544
  "cann pool[%d]: allocate %p, "
557
545
  "pool_size = %5u MB, "
558
546
  "size = %5u MB\n",
559
- device, b.ptr,
560
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
561
- (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
547
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
548
+ (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
562
549
  #endif
563
550
  return b.ptr;
564
551
  }
@@ -572,21 +559,20 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
572
559
  * @param ptr Pointer to the buffer to free.
573
560
  * @param size Size of the buffer to free.
574
561
  */
575
- void free(void* ptr, size_t size) override {
562
+ void free(void * ptr, size_t size) override {
576
563
  GGML_UNUSED(size);
577
564
  for (int i = 0; i < MAX_BUFFERS; ++i) {
578
- ggml_cann_buffer& b = buffer_pool[i];
565
+ ggml_cann_buffer & b = buffer_pool[i];
579
566
  if (b.ptr != ptr) {
580
567
  continue;
581
568
  }
582
- b.used = false;
569
+ b.used = false;
583
570
  b.last_used = std::chrono::steady_clock::now();
584
571
  #ifdef DEBUG_CANN_MALLOC
585
572
  GGML_LOG_INFO(
586
573
  "cann pool[%d]: return %p, "
587
574
  "pool_size = %5u MB\n",
588
- device, b.ptr,
589
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
575
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
590
576
  #endif
591
577
  return;
592
578
  }
@@ -614,7 +600,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
614
600
  /**
615
601
  * @brief Pointer to the start of the virtual memory pool.
616
602
  */
617
- void* pool_addr = 0;
603
+ void * pool_addr = 0;
618
604
 
619
605
  /**
620
606
  * @brief Amount of virtual memory used in the pool.
@@ -639,7 +625,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
639
625
  /**
640
626
  * @brief Offsets for the mapped memory regions.
641
627
  */
642
- std::vector<void*> map_offsets;
628
+ std::vector<void *> map_offsets;
643
629
 
644
630
  /**
645
631
  * @brief Constructor to initialize the buffer pool with virtual memory for
@@ -647,11 +633,10 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
647
633
  *
648
634
  * @param device The device ID to associate with this buffer pool.
649
635
  */
650
- explicit ggml_cann_pool_vmm(int device)
651
- : device(device) {
652
- auto dev = ggml_cann_info().devices[device];
636
+ explicit ggml_cann_pool_vmm(int device) : device(device) {
637
+ auto dev = ggml_cann_info().devices[device];
653
638
  granularity = dev.vmm_granularity;
654
- max_size = dev.total_vram;
639
+ max_size = dev.total_vram;
655
640
  }
656
641
 
657
642
  /**
@@ -659,10 +644,10 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
659
644
  */
660
645
  ~ggml_cann_pool_vmm() {
661
646
  if (pool_addr != 0) {
662
- for (auto& offset : map_offsets) {
647
+ for (auto & offset : map_offsets) {
663
648
  ACL_CHECK(aclrtUnmapMem(offset));
664
649
  }
665
- for (auto& handle : handles) {
650
+ for (auto & handle : handles) {
666
651
  ACL_CHECK(aclrtFreePhysical(handle));
667
652
  }
668
653
  ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
@@ -677,11 +662,11 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
677
662
  * the allocated buffer.
678
663
  * @return A pointer to the allocated buffer.
679
664
  */
680
- void* alloc(size_t size, size_t* actual_size) override {
665
+ void * alloc(size_t size, size_t * actual_size) override {
681
666
  // round up the allocation size to the alignment to ensure that all
682
667
  // allocations are aligned for all data types
683
668
  const size_t alignment = 128;
684
- size = GGML_PAD(size, alignment);
669
+ size = GGML_PAD(size, alignment);
685
670
  if (size == 0) {
686
671
  size = alignment;
687
672
  }
@@ -691,53 +676,51 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
691
676
  if (size > avail) {
692
677
  // round up to the next multiple of the granularity
693
678
  size_t reserve_size = size - avail;
694
- reserve_size = GGML_PAD(reserve_size, granularity);
679
+ reserve_size = GGML_PAD(reserve_size, granularity);
695
680
 
696
681
  GGML_ASSERT(pool_size + reserve_size <= max_size);
697
682
 
698
683
  // allocate more physical memory
699
684
  aclrtPhysicalMemProp prop = {};
700
- prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
701
- prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
702
- prop.memAttr = ACL_HBM_MEM_HUGE;
703
- prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
704
- prop.location.id = device;
705
- prop.reserve = 0;
685
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
686
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
687
+ prop.memAttr = ACL_HBM_MEM_HUGE;
688
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
689
+ prop.location.id = device;
690
+ prop.reserve = 0;
706
691
  aclrtDrvMemHandle handle;
707
692
  ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
708
693
 
709
694
  // reserve virtual address space (if not already reserved)
710
695
  if (pool_addr == 0) {
711
- ACL_CHECK(aclrtReserveMemAddress(
712
- &pool_addr, max_size, 0, NULL, 1));
696
+ ACL_CHECK(aclrtReserveMemAddress(&pool_addr, max_size, 0, NULL, 1));
713
697
  }
714
698
 
715
699
  // map at the end of the pool
716
- ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
717
- handle, 0));
700
+ ACL_CHECK(aclrtMapMem((char *) pool_addr + pool_size, reserve_size, 0, handle, 0));
718
701
 
719
702
  handles.push_back(handle);
720
- map_offsets.push_back((char*)pool_addr + pool_size);
703
+ map_offsets.push_back((char *) pool_addr + pool_size);
721
704
 
722
705
  // add to the pool
723
706
  pool_size += reserve_size;
724
707
 
725
708
  #ifdef DEBUG_CANN_MALLOC
726
- GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
727
- device, (unsigned long long) (pool_size/1024/1024),
728
- (unsigned long long) (reserve_size/1024/1024));
709
+ GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n", device,
710
+ (unsigned long long) (pool_size / 1024 / 1024),
711
+ (unsigned long long) (reserve_size / 1024 / 1024));
729
712
  #endif
730
713
  }
731
714
 
732
715
  GGML_ASSERT(pool_addr != 0);
733
716
 
734
- void* ptr = (void*)((char*)pool_addr + pool_used);
717
+ void * ptr = (void *) ((char *) pool_addr + pool_used);
735
718
  *actual_size = size;
736
719
  pool_used += size;
737
720
 
738
721
  #ifdef DEBUG_CANN_MALLOC
739
- GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
740
- (unsigned long long)size, (unsigned long long)ptr);
722
+ GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size,
723
+ (unsigned long long) ptr);
741
724
  #endif
742
725
  return ptr;
743
726
  }
@@ -748,16 +731,16 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
748
731
  * @param ptr Pointer to the buffer to free.
749
732
  * @param size Size of the buffer to free.
750
733
  */
751
- void free(void* ptr, size_t size) override {
734
+ void free(void * ptr, size_t size) override {
752
735
  #ifdef DEBUG_CANN_MALLOC
753
- GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
754
- (unsigned long long)size, (unsigned long long)ptr);
736
+ GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size,
737
+ (unsigned long long) ptr);
755
738
  #endif
756
739
 
757
740
  pool_used -= size;
758
741
 
759
742
  // all deallocations must be in reverse order of the allocations
760
- GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
743
+ GGML_ASSERT(ptr == (void *) ((char *) pool_addr + pool_used));
761
744
  }
762
745
  };
763
746
 
@@ -769,9 +752,8 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
769
752
  * @param device The device ID for which to create the pool.
770
753
  * @return A unique pointer to the created CANN pool.
771
754
  */
772
- std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
773
- int device) {
774
- std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
755
+ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {
756
+ std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or("");
775
757
 
776
758
  if (mem_pool_type == "prio") {
777
759
  GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
@@ -795,9 +777,8 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
795
777
  * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
796
778
  */
797
779
  struct ggml_backend_cann_buffer_context {
798
- int32_t device; ///< The device ID associated with this buffer context.
799
- void* dev_ptr =
800
- nullptr; ///< Pointer to the device memory allocated for the buffer.
780
+ int32_t device; ///< The device ID associated with this buffer context.
781
+ void * dev_ptr = nullptr; ///< Pointer to the device memory allocated for the buffer.
801
782
 
802
783
  /**
803
784
  * @brief Constructor to initialize the CANN buffer context.
@@ -805,9 +786,7 @@ struct ggml_backend_cann_buffer_context {
805
786
  * @param device The device ID associated with this buffer context.
806
787
  * @param dev_ptr Pointer to the device memory allocated for the buffer.
807
788
  */
808
- ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
809
- : device(device),
810
- dev_ptr(dev_ptr) {}
789
+ ggml_backend_cann_buffer_context(int32_t device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {}
811
790
 
812
791
  /**
813
792
  * @brief Destructor to free the device memory allocated for the buffer.
@@ -815,19 +794,44 @@ struct ggml_backend_cann_buffer_context {
815
794
  ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
816
795
  };
817
796
 
797
+ // cann buffer type
798
+ /**
799
+ * @brief Structure representing context information for a specific backend
800
+ * buffer type.
801
+ */
802
+ struct ggml_backend_cann_buffer_type_context {
803
+ int32_t device; /**< Device identifier associated with the buffer context. */
804
+ std::string name; /**< Name associated with the buffer context. */
805
+ };
806
+
807
+ /**
808
+ * @brief Retrieves the name associated with a CANN buffer type.
809
+ *
810
+ * This function returns the descriptive name associated with the specified
811
+ * CANN buffer type context.
812
+ *
813
+ * @param buft Pointer to the buffer type context.
814
+ * @return Const pointer to the C-style string containing the name.
815
+ */
816
+ static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
817
+ ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
818
+
819
+ return buft_ctx->name.c_str();
820
+ }
821
+
818
822
  /**
819
- * @brief Check if a buffer is a CANN buffer.
823
+ * @brief Checks if the backend buffer type is associated with the CANN backend.
820
824
  *
821
- * This function checks if a given buffer is a CANN buffer by comparing its
822
- * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
825
+ * This function checks whether the provided backend buffer type is associated
826
+ * with the CANN backend based on the comparison of its name retrieval function
827
+ * pointer.
823
828
  *
824
- * @param buffer The buffer to check.
825
- * @return true if the buffer is a CANN buffer, false otherwise.
829
+ * @param buft Pointer to the backend buffer type to check.
830
+ * @return bool Returns true if the buffer type is associated with the CANN
831
+ * backend, otherwise false.
826
832
  */
827
- static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
828
- static bool ggml_backend_buffer_is_cann(
829
- ggml_backend_buffer_t buffer) {
830
- return ggml_backend_buft_is_cann(buffer->buft);
833
+ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
834
+ return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
831
835
  }
832
836
 
833
837
  /**
@@ -838,10 +842,8 @@ static bool ggml_backend_buffer_is_cann(
838
842
  *
839
843
  * @param buffer The CANN buffer to free.
840
844
  */
841
- static void ggml_backend_cann_buffer_free_buffer(
842
- ggml_backend_buffer_t buffer) {
843
- ggml_backend_cann_buffer_context* ctx =
844
- (ggml_backend_cann_buffer_context*)buffer->context;
845
+ static void ggml_backend_cann_buffer_free_buffer(ggml_backend_buffer_t buffer) {
846
+ ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
845
847
  delete ctx;
846
848
  }
847
849
 
@@ -854,10 +856,8 @@ static void ggml_backend_cann_buffer_free_buffer(
854
856
  * @param buffer The CANN buffer whose base pointer is to be retrieved.
855
857
  * @return A pointer to the base of the device memory allocated for the buffer.
856
858
  */
857
- static void* ggml_backend_cann_buffer_get_base(
858
- ggml_backend_buffer_t buffer) {
859
- ggml_backend_cann_buffer_context* ctx =
860
- (ggml_backend_cann_buffer_context*)buffer->context;
859
+ static void * ggml_backend_cann_buffer_get_base(ggml_backend_buffer_t buffer) {
860
+ ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
861
861
  return ctx->dev_ptr;
862
862
  }
863
863
 
@@ -874,21 +874,17 @@ static void* ggml_backend_cann_buffer_get_base(
874
874
  * @param dst Pointer to the destination buffer where transformed data will be
875
875
  * stored.
876
876
  */
877
- static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
878
- const void* src,
879
- void* dst) {
880
-
881
- int64_t n_elems = ggml_nelements(tensor);
882
- int64_t groups = n_elems / QK4_0;
883
- size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
877
+ static void ggml_backend_cann_transform_q4_0(ggml_tensor * tensor, const void * src, void * dst) {
878
+ int64_t n_elems = ggml_nelements(tensor);
879
+ int64_t groups = n_elems / QK4_0;
880
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
884
881
 
885
- uint8_t* quant_offset = (uint8_t*)dst;
886
- uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
882
+ uint8_t * quant_offset = (uint8_t *) dst;
883
+ uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes);
887
884
 
888
885
  for (int i = 0; i < groups; i++) {
889
- const block_q4_0* group =
890
- (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
891
- *scale_offset = group->d;
886
+ const block_q4_0 * group = (const block_q4_0 *) ((const char *) src + i * sizeof(block_q4_0));
887
+ *scale_offset = group->d;
892
888
  scale_offset++;
893
889
 
894
890
  // 0-15
@@ -907,8 +903,7 @@ static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
907
903
  }
908
904
 
909
905
  // put (uint4b_t -8) into int4b_t
910
- for (quant_offset = (uint8_t*)dst;
911
- quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
906
+ for (quant_offset = (uint8_t *) dst; quant_offset < (uint8_t *) dst + quant_bytes; quant_offset++) {
912
907
  (*quant_offset) ^= 0x88;
913
908
  }
914
909
  }
@@ -926,29 +921,27 @@ static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
926
921
  * @param dst Pointer to the destination buffer where the Q4.0 formatted data
927
922
  * will be stored.
928
923
  */
929
- static void ggml_backend_cann_transform_back_q4_0(
930
- const ggml_tensor* tensor, void* src, void* dst) {
931
-
932
- int64_t n_elems = ggml_nelements(tensor);
933
- int64_t groups = n_elems / QK4_0;
934
- size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
924
+ static void ggml_backend_cann_transform_back_q4_0(const ggml_tensor * tensor, void * src, void * dst) {
925
+ int64_t n_elems = ggml_nelements(tensor);
926
+ int64_t groups = n_elems / QK4_0;
927
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
935
928
 
936
- uint8_t* quant_offset = (uint8_t*)src;
937
- uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
929
+ uint8_t * quant_offset = (uint8_t *) src;
930
+ uint16_t * scale_offset = (uint16_t *) ((char *) src + quant_bytes);
938
931
 
939
- for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
932
+ for (; quant_offset < (uint8_t *) src + quant_bytes; quant_offset++) {
940
933
  (*quant_offset) ^= 0x88;
941
934
  }
942
- quant_offset = (uint8_t*)src;
935
+ quant_offset = (uint8_t *) src;
943
936
 
944
937
  for (int i = 0; i < groups; i++) {
945
- block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
946
- group->d = *scale_offset;
938
+ block_q4_0 * group = (block_q4_0 *) ((char *) dst + i * sizeof(block_q4_0));
939
+ group->d = *scale_offset;
947
940
  scale_offset++;
948
941
 
949
942
  // 0-15
950
943
  for (int j = 0; j < QK4_0 / 2; j += 2) {
951
- group->qs[j] = ((*quant_offset) & 0x0F);
944
+ group->qs[j] = ((*quant_offset) & 0x0F);
952
945
  group->qs[j + 1] = ((*quant_offset) >> 4);
953
946
  quant_offset++;
954
947
  }
@@ -975,20 +968,17 @@ static void ggml_backend_cann_transform_back_q4_0(
975
968
  * @param dst Pointer to the destination buffer where transformed data will be
976
969
  * stored.
977
970
  */
978
- static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
979
- const void* src,
980
- void* dst) {
981
- int64_t n_elems = ggml_nelements(tensor);
982
- int64_t groups = n_elems / QK8_0;
983
- size_t quant_bytes = n_elems * sizeof(uint8_t);
971
+ static void ggml_backend_cann_transform_q8_0(ggml_tensor * tensor, const void * src, void * dst) {
972
+ int64_t n_elems = ggml_nelements(tensor);
973
+ int64_t groups = n_elems / QK8_0;
974
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
984
975
 
985
- uint8_t* quant_offset = (uint8_t*)dst;
986
- uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
976
+ uint8_t * quant_offset = (uint8_t *) dst;
977
+ uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes);
987
978
 
988
979
  for (int i = 0; i < groups; i++) {
989
- const block_q8_0* group =
990
- (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
991
- *scale_offset = group->d;
980
+ const block_q8_0 * group = (const block_q8_0 *) ((const char *) src + i * sizeof(block_q8_0));
981
+ *scale_offset = group->d;
992
982
  scale_offset++;
993
983
  size_t group_quant_size = QK8_0 * sizeof(uint8_t);
994
984
  memcpy(quant_offset, group->qs, group_quant_size);
@@ -1009,19 +999,17 @@ static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
1009
999
  * @param dst Pointer to the destination buffer where the Q8.0 formatted data
1010
1000
  * will be stored.
1011
1001
  */
1012
- static void ggml_backend_cann_transform_back_q8_0(
1013
- const ggml_tensor* tensor, const void* src, void* dst) {
1014
- int64_t n_elems = ggml_nelements(tensor);
1015
- int64_t groups = n_elems / QK8_0;
1016
- size_t quant_bytes = n_elems * sizeof(uint8_t);
1002
+ static void ggml_backend_cann_transform_back_q8_0(const ggml_tensor * tensor, const void * src, void * dst) {
1003
+ int64_t n_elems = ggml_nelements(tensor);
1004
+ int64_t groups = n_elems / QK8_0;
1005
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
1017
1006
 
1018
- const uint8_t* quant_offset = (const uint8_t*)src;
1019
- const uint16_t* scale_offset =
1020
- (const uint16_t*)((const char*)src + quant_bytes);
1007
+ const uint8_t * quant_offset = (const uint8_t *) src;
1008
+ const uint16_t * scale_offset = (const uint16_t *) ((const char *) src + quant_bytes);
1021
1009
 
1022
1010
  for (int i = 0; i < groups; i++) {
1023
- block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
1024
- group->d = *scale_offset;
1011
+ block_q8_0 * group = (block_q8_0 *) ((char *) dst + i * sizeof(block_q8_0));
1012
+ group->d = *scale_offset;
1025
1013
  scale_offset++;
1026
1014
  size_t group_quant_size = QK8_0 * sizeof(uint8_t);
1027
1015
  memcpy(group->qs, quant_offset, group_quant_size);
@@ -1041,8 +1029,7 @@ static void ggml_backend_cann_transform_back_q8_0(
1041
1029
  * @param dst Pointer to the destination buffer where transformed data will be
1042
1030
  * stored.
1043
1031
  */
1044
- static void ggml_backend_cann_transform(ggml_tensor* tensor,
1045
- const void* src, void* dst) {
1032
+ static void ggml_backend_cann_transform(ggml_tensor * tensor, const void * src, void * dst) {
1046
1033
  switch (tensor->type) {
1047
1034
  case GGML_TYPE_Q4_0:
1048
1035
  ggml_backend_cann_transform_q4_0(tensor, src, dst);
@@ -1067,8 +1054,7 @@ static void ggml_backend_cann_transform(ggml_tensor* tensor,
1067
1054
  * @param dst Pointer to the destination buffer where transformed tensor data
1068
1055
  * will be stored.
1069
1056
  */
1070
- static void ggml_backend_cann_transform_back(
1071
- const ggml_tensor* tensor, void* src, void* dst) {
1057
+ static void ggml_backend_cann_transform_back(const ggml_tensor * tensor, void * src, void * dst) {
1072
1058
  switch (tensor->type) {
1073
1059
  case GGML_TYPE_Q4_0:
1074
1060
  ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
@@ -1109,8 +1095,7 @@ static bool need_transform(ggml_type type) {
1109
1095
  * @param buffer The CANN buffer from which to initialize the tensor.
1110
1096
  * @param tensor Pointer to the tensor to be initialized.
1111
1097
  */
1112
- static enum ggml_status ggml_backend_cann_buffer_init_tensor(
1113
- ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
1098
+ static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
1114
1099
  if (tensor->view_src != NULL && tensor->view_offs == 0) {
1115
1100
  GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
1116
1101
  return GGML_STATUS_SUCCESS;
@@ -1121,13 +1106,11 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
1121
1106
  if (ggml_is_quantized(tensor->type)) {
1122
1107
  // Initialize padding to 0 to avoid possible NaN values
1123
1108
  size_t original_size = ggml_nbytes(tensor);
1124
- size_t padded_size =
1125
- ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
1109
+ size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
1126
1110
 
1127
1111
  if (padded_size > original_size && tensor->view_src == nullptr) {
1128
1112
  size_t memset_size = padded_size - original_size;
1129
- ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
1130
- memset_size, 0, memset_size));
1113
+ ACL_CHECK(aclrtMemset((char *) tensor->data + original_size, memset_size, 0, memset_size));
1131
1114
  }
1132
1115
  }
1133
1116
  return GGML_STATUS_SUCCESS;
@@ -1141,8 +1124,8 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
1141
1124
  * designed to be used with a global array, one per device.
1142
1125
  */
1143
1126
  struct ggml_cann_nz_workspace {
1144
- void* ptr; // Pointer to allocated device buffer
1145
- size_t allocated; // Size of currently allocated buffer in bytes
1127
+ void * ptr; // Pointer to allocated device buffer
1128
+ size_t allocated; // Size of currently allocated buffer in bytes
1146
1129
 
1147
1130
  /**
1148
1131
  * @brief Constructor. Initializes the workspace with no allocated memory.
@@ -1158,7 +1141,7 @@ struct ggml_cann_nz_workspace {
1158
1141
  void clear() {
1159
1142
  if (ptr) {
1160
1143
  ACL_CHECK(aclrtFree(ptr));
1161
- ptr = nullptr;
1144
+ ptr = nullptr;
1162
1145
  allocated = 0;
1163
1146
  }
1164
1147
  }
@@ -1185,7 +1168,7 @@ struct ggml_cann_nz_workspace {
1185
1168
  *
1186
1169
  * @return Pointer to the allocated buffer, or nullptr if not allocated.
1187
1170
  */
1188
- void* get() const { return ptr; }
1171
+ void * get() const { return ptr; }
1189
1172
  };
1190
1173
 
1191
1174
  /**
@@ -1207,22 +1190,19 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
1207
1190
  * @note The workspace buffer used in this function is managed globally and reused
1208
1191
  * across calls. This reduces overhead from repeated memory allocation and deallocation.
1209
1192
  */
1210
- static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) {
1211
- aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
1212
- tensor->nb, 2, ACL_FORMAT_ND, offset);
1213
- uint64_t workspaceSize = 0;
1214
- aclOpExecutor *executor;
1193
+ static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
1194
+ acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
1195
+ uint64_t workspaceSize = 0;
1196
+ aclOpExecutor * executor;
1215
1197
 
1216
1198
  // TransMatmulWeight
1217
- ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
1218
- &workspaceSize, &executor));
1199
+ ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
1219
1200
  // Avoid frequent malloc/free of the workspace.
1220
1201
  g_nz_workspaces[device].realloc(workspaceSize);
1221
1202
 
1222
- void* g_nz_workspace = g_nz_workspaces[device].get();
1203
+ void * g_nz_workspace = g_nz_workspaces[device].get();
1223
1204
 
1224
1205
  ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
1225
- ACL_CHECK(aclDestroyTensor(weightTransposed));
1226
1206
  }
1227
1207
 
1228
1208
  // TODO: need handle tensor which has paddings.
@@ -1238,11 +1218,12 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device)
1238
1218
  * @param offset Offset in the source data from where to start copying.
1239
1219
  * @param size Size of the data to be copied, in bytes.
1240
1220
  */
1241
- static void ggml_backend_cann_buffer_set_tensor(
1242
- ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
1243
- size_t offset, size_t size) {
1244
- ggml_backend_cann_buffer_context *ctx =
1245
- (ggml_backend_cann_buffer_context *)buffer->context;
1221
+ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
1222
+ ggml_tensor * tensor,
1223
+ const void * data,
1224
+ size_t offset,
1225
+ size_t size) {
1226
+ ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1246
1227
 
1247
1228
  ggml_cann_set_device(ctx->device);
1248
1229
  // TODO: refer to cann(#6017), it use thread's default stream.
@@ -1250,22 +1231,19 @@ static void ggml_backend_cann_buffer_set_tensor(
1250
1231
  // Why aclrtSynchronizeDevice?
1251
1232
 
1252
1233
  // Only check env once.
1253
- static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
1234
+ static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1254
1235
  if (!need_transform(tensor->type)) {
1255
- ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
1256
- ACL_MEMCPY_HOST_TO_DEVICE));
1257
- if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
1236
+ ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1237
+ if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
1258
1238
  GGML_ASSERT(tensor->ne[2] == 1);
1259
1239
  GGML_ASSERT(tensor->ne[3] == 1);
1260
1240
  weight_format_to_nz(tensor, offset, ctx->device);
1261
1241
  }
1262
1242
  } else {
1263
- void *transform_buffer = malloc(size);
1243
+ void * transform_buffer = malloc(size);
1264
1244
  ggml_backend_cann_transform(tensor, data, transform_buffer);
1265
1245
 
1266
- ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
1267
- transform_buffer, size,
1268
- ACL_MEMCPY_HOST_TO_DEVICE));
1246
+ ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE));
1269
1247
  free(transform_buffer);
1270
1248
  }
1271
1249
  }
@@ -1283,22 +1261,20 @@ static void ggml_backend_cann_buffer_set_tensor(
1283
1261
  * @param offset Offset in the destination buffer where to start copying.
1284
1262
  * @param size Size of the data to be copied, in bytes.
1285
1263
  */
1286
- static void ggml_backend_cann_buffer_get_tensor(
1287
- ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
1288
- size_t offset, size_t size) {
1289
- ggml_backend_cann_buffer_context* ctx =
1290
- (ggml_backend_cann_buffer_context*)buffer->context;
1264
+ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer,
1265
+ const ggml_tensor * tensor,
1266
+ void * data,
1267
+ size_t offset,
1268
+ size_t size) {
1269
+ ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1291
1270
 
1292
1271
  ggml_cann_set_device(ctx->device);
1293
1272
 
1294
1273
  if (!need_transform(tensor->type)) {
1295
- ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
1296
- ACL_MEMCPY_DEVICE_TO_HOST));
1274
+ ACL_CHECK(aclrtMemcpy(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST));
1297
1275
  } else {
1298
- void* transform_buffer = malloc(size);
1299
- ACL_CHECK(aclrtMemcpy(transform_buffer, size,
1300
- (char*)tensor->data + offset, size,
1301
- ACL_MEMCPY_DEVICE_TO_HOST));
1276
+ void * transform_buffer = malloc(size);
1277
+ ACL_CHECK(aclrtMemcpy(transform_buffer, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST));
1302
1278
  ggml_backend_cann_transform_back(tensor, transform_buffer, data);
1303
1279
  free(transform_buffer);
1304
1280
  }
@@ -1317,19 +1293,17 @@ static void ggml_backend_cann_buffer_get_tensor(
1317
1293
  * @param dst Pointer to the destination tensor where the data will be copied.
1318
1294
  * @return true if the copy operation succeeded, false otherwise.
1319
1295
  */
1320
- static bool ggml_backend_cann_buffer_cpy_tensor(
1321
- ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
1322
- if (ggml_backend_buffer_is_cann(src->buffer)) {
1323
- ggml_backend_cann_buffer_context* src_ctx =
1324
- (ggml_backend_cann_buffer_context*)src->buffer->context;
1325
- ggml_backend_cann_buffer_context* dst_ctx =
1326
- (ggml_backend_cann_buffer_context*)buffer->context;
1296
+ static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
1297
+ const ggml_tensor * src,
1298
+ ggml_tensor * dst) {
1299
+ if (ggml_backend_buft_is_cann(src->buffer->buft)) {
1300
+ ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context;
1301
+ ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1327
1302
 
1328
1303
  size_t memcpy_size = ggml_nbytes(src);
1329
1304
  // Same device.
1330
1305
  if (src_ctx->device == dst_ctx->device) {
1331
- ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
1332
- (const char*)src->data, memcpy_size,
1306
+ ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size,
1333
1307
  ACL_MEMCPY_DEVICE_TO_DEVICE));
1334
1308
  return true;
1335
1309
  } else {
@@ -1339,13 +1313,11 @@ static bool ggml_backend_cann_buffer_cpy_tensor(
1339
1313
  #endif
1340
1314
  // Different device but can access by peer.
1341
1315
  int32_t canAccessPeer = 0;
1342
- ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
1343
- dst_ctx->device));
1316
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device, dst_ctx->device));
1344
1317
  if (canAccessPeer) {
1345
1318
  ggml_cann_set_device(src_ctx->device);
1346
1319
  ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
1347
- ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
1348
- (const char*)src->data, memcpy_size,
1320
+ ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size,
1349
1321
  ACL_MEMCPY_DEVICE_TO_DEVICE));
1350
1322
  return true;
1351
1323
  }
@@ -1363,10 +1335,8 @@ static bool ggml_backend_cann_buffer_cpy_tensor(
1363
1335
  * @param buffer The CANN buffer to be cleared.
1364
1336
  * @param value The value to which each byte in the buffer will be set.
1365
1337
  */
1366
- static void ggml_backend_cann_buffer_clear(
1367
- ggml_backend_buffer_t buffer, uint8_t value) {
1368
- ggml_backend_cann_buffer_context* ctx =
1369
- (ggml_backend_cann_buffer_context*)buffer->context;
1338
+ static void ggml_backend_cann_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1339
+ ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1370
1340
 
1371
1341
  ggml_cann_set_device(ctx->device);
1372
1342
  ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
@@ -1390,34 +1360,6 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
1390
1360
  /* .reset = */ NULL,
1391
1361
  };
1392
1362
 
1393
- // cann buffer type
1394
- /**
1395
- * @brief Structure representing context information for a specific backend
1396
- * buffer type.
1397
- */
1398
- struct ggml_backend_cann_buffer_type_context {
1399
- int32_t
1400
- device; /**< Device identifier associated with the buffer context. */
1401
- std::string name; /**< Name associated with the buffer context. */
1402
- };
1403
-
1404
- /**
1405
- * @brief Retrieves the name associated with a CANN buffer type.
1406
- *
1407
- * This function returns the descriptive name associated with the specified
1408
- * CANN buffer type context.
1409
- *
1410
- * @param buft Pointer to the buffer type context.
1411
- * @return Const pointer to the C-style string containing the name.
1412
- */
1413
- static const char* ggml_backend_cann_buffer_type_name(
1414
- ggml_backend_buffer_type_t buft) {
1415
- ggml_backend_cann_buffer_type_context* buft_ctx =
1416
- (ggml_backend_cann_buffer_type_context*)buft->context;
1417
-
1418
- return buft_ctx->name.c_str();
1419
- }
1420
-
1421
1363
  /**
1422
1364
  * @brief Allocates a new CANN buffer of the specified type and size.
1423
1365
  *
@@ -1428,34 +1370,27 @@ static const char* ggml_backend_cann_buffer_type_name(
1428
1370
  * @param size Size in bytes of the buffer to allocate.
1429
1371
  * @return Pointer to the allocated buffer, or nullptr if allocation fails.
1430
1372
  */
1431
- static ggml_backend_buffer_t
1432
- ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1433
- size_t size) {
1434
- ggml_backend_cann_buffer_type_context* buft_ctx =
1435
- (ggml_backend_cann_buffer_type_context*)buft->context;
1373
+ static ggml_backend_buffer_t ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1374
+ ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
1436
1375
 
1437
1376
  ggml_cann_set_device(buft_ctx->device);
1438
1377
 
1439
1378
  const size_t alignment = 128;
1440
- size = GGML_PAD(size, alignment);
1379
+ size = GGML_PAD(size, alignment);
1441
1380
  if (size == 0) {
1442
1381
  size = alignment;
1443
1382
  }
1444
- void* dev_ptr;
1383
+ void * dev_ptr;
1445
1384
  aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1446
1385
  if (err != ACL_SUCCESS) {
1447
- GGML_LOG_ERROR(
1448
- "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
1449
- __func__, size / 1024.0 / 1024.0, buft_ctx->device,
1450
- aclGetRecentErrMsg());
1386
+ GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n", __func__,
1387
+ size / 1024.0 / 1024.0, buft_ctx->device, aclGetRecentErrMsg());
1451
1388
  return nullptr;
1452
1389
  }
1453
1390
 
1454
- ggml_backend_cann_buffer_context* ctx =
1455
- new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1391
+ ggml_backend_cann_buffer_context * ctx = new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1456
1392
 
1457
- return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
1458
- ctx, size);
1393
+ return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface, ctx, size);
1459
1394
  }
1460
1395
 
1461
1396
  /**
@@ -1470,8 +1405,7 @@ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1470
1405
  * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
1471
1406
  * buffers).
1472
1407
  */
1473
- static size_t ggml_backend_cann_buffer_type_get_alignment(
1474
- ggml_backend_buffer_type_t buft) {
1408
+ static size_t ggml_backend_cann_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1475
1409
  return 128;
1476
1410
 
1477
1411
  GGML_UNUSED(buft);
@@ -1491,13 +1425,13 @@ static size_t ggml_backend_cann_buffer_type_get_alignment(
1491
1425
  * @return The total allocation size in bytes required for the tensor in the
1492
1426
  * CANN buffer.
1493
1427
  */
1494
- static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1495
- ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
1496
- size_t size = ggml_nbytes(tensor);
1497
- int64_t ne0 = tensor->ne[0];
1428
+ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
1429
+ const ggml_tensor * tensor) {
1430
+ size_t size = ggml_nbytes(tensor);
1431
+ int64_t ne0 = tensor->ne[0];
1498
1432
 
1499
1433
  // Only check env once.
1500
- static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
1434
+ static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1501
1435
 
1502
1436
  // last line must bigger than 32, because every single op deal at
1503
1437
  // least 32 bytes.
@@ -1507,19 +1441,17 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1507
1441
  // size += (line_size_align_32 - line_size);
1508
1442
  if (ggml_is_quantized(tensor->type)) {
1509
1443
  if (ne0 % MATRIX_ROW_PADDING != 0) {
1510
- size += ggml_row_size(
1511
- tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1444
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1512
1445
  }
1513
- } else if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
1446
+ } else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
1514
1447
  // NZ format weight are not support quantized yet.
1515
1448
  // If ND tensor transform to NZ, size may changed.
1516
- int64_t shape[] = {tensor->ne[1], tensor->ne[0]};
1449
+ int64_t shape[] = { tensor->ne[1], tensor->ne[0] };
1517
1450
  GGML_ASSERT(tensor->ne[2] == 1);
1518
1451
  GGML_ASSERT(tensor->ne[3] == 1);
1519
- const aclIntArray *acl_shape = aclCreateIntArray(shape, 2);
1520
- size_t new_size;
1521
- ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape,
1522
- ggml_cann_type_mapping(tensor->type), &new_size));
1452
+ const aclIntArray * acl_shape = aclCreateIntArray(shape, 2);
1453
+ size_t new_size;
1454
+ ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape, ggml_cann_type_mapping(tensor->type), &new_size));
1523
1455
  ACL_CHECK(aclDestroyIntArray(acl_shape));
1524
1456
  size = std::max(size, new_size);
1525
1457
  }
@@ -1560,17 +1492,15 @@ static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface
1560
1492
  * @return A pointer to the buffer type interface for the specified device, or
1561
1493
  * nullptr if the device index is out of range.
1562
1494
  */
1563
- ggml_backend_buffer_type_t
1564
- ggml_backend_cann_buffer_type(int32_t device) {
1565
- static std::mutex mutex;
1495
+ ggml_backend_buffer_type_t ggml_backend_cann_buffer_type(int32_t device) {
1496
+ static std::mutex mutex;
1566
1497
  std::lock_guard<std::mutex> lock(mutex);
1567
1498
 
1568
1499
  if (device >= ggml_backend_cann_get_device_count()) {
1569
1500
  return nullptr;
1570
1501
  }
1571
1502
 
1572
- static ggml_backend_buffer_type
1573
- ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1503
+ static ggml_backend_buffer_type ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1574
1504
 
1575
1505
  static bool ggml_backend_cann_buffer_type_initialized = false;
1576
1506
 
@@ -1580,8 +1510,7 @@ ggml_backend_cann_buffer_type(int32_t device) {
1580
1510
  /* .iface = */ ggml_backend_cann_buffer_type_interface,
1581
1511
  /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
1582
1512
  /* .context = */
1583
- new ggml_backend_cann_buffer_type_context{
1584
- i, "CANN" + std::to_string(i)},
1513
+ new ggml_backend_cann_buffer_type_context{ i, "CANN" + std::to_string(i) },
1585
1514
  };
1586
1515
  }
1587
1516
  ggml_backend_cann_buffer_type_initialized = true;
@@ -1645,16 +1574,16 @@ static void * ggml_cann_host_malloc(size_t size) {
1645
1574
  }
1646
1575
 
1647
1576
  const size_t alignment = 128;
1648
- size = GGML_PAD(size, alignment);
1577
+ size = GGML_PAD(size, alignment);
1649
1578
  if (size == 0) {
1650
1579
  size = alignment;
1651
1580
  }
1652
1581
 
1653
- void * hostPtr = nullptr;
1654
- aclError err = aclrtMallocHost((void **) &hostPtr, size);
1582
+ void * hostPtr = nullptr;
1583
+ aclError err = aclrtMallocHost((void **) &hostPtr, size);
1655
1584
  if (err != ACL_SUCCESS) {
1656
- GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
1657
- size / 1024.0 / 1024.0, aclGetRecentErrMsg());
1585
+ GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0,
1586
+ aclGetRecentErrMsg());
1658
1587
  return nullptr;
1659
1588
  }
1660
1589
  return hostPtr;
@@ -1667,7 +1596,8 @@ static void * ggml_cann_host_malloc(size_t size) {
1667
1596
  * @param size Size in bytes of the host buffer to allocate.
1668
1597
  * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
1669
1598
  */
1670
- static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1599
+ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1600
+ size_t size) {
1671
1601
  void * hostPtr = ggml_cann_host_malloc(size);
1672
1602
 
1673
1603
  if (hostPtr == nullptr) {
@@ -1676,8 +1606,8 @@ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggm
1676
1606
  }
1677
1607
 
1678
1608
  ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
1679
- buffer->buft = buft;
1680
- buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
1609
+ buffer->buft = buft;
1610
+ buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
1681
1611
 
1682
1612
  return buffer;
1683
1613
  }
@@ -1691,14 +1621,15 @@ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggm
1691
1621
  ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
1692
1622
  static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
1693
1623
  /* .iface = */ {
1694
- /* .get_name = */ ggml_backend_cann_host_buffer_type_name,
1695
- /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
1696
- /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1697
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1624
+ /* .get_name = */ ggml_backend_cann_host_buffer_type_name,
1625
+ /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
1626
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1627
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1698
1628
  /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
1699
- /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1700
- },
1701
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
1629
+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1630
+ },
1631
+ /* .device = */
1632
+ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
1702
1633
  /* .context = */ nullptr,
1703
1634
  };
1704
1635
 
@@ -1718,8 +1649,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
1718
1649
  * stored.
1719
1650
  * @return true if the computation was successful; false otherwise.
1720
1651
  */
1721
- static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1722
- struct ggml_tensor* dst) {
1652
+ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct ggml_tensor * dst) {
1723
1653
  switch (dst->op) {
1724
1654
  case GGML_OP_REPEAT:
1725
1655
  ggml_cann_repeat(ctx, dst);
@@ -1765,14 +1695,14 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1765
1695
  case GGML_UNARY_OP_SILU:
1766
1696
  GGML_CANN_CALL_OP_UNARY(Silu);
1767
1697
  break;
1768
- case GGML_UNARY_OP_GELU_QUICK: {
1769
- auto lambda = [](ggml_backend_cann_context& ctx,
1770
- aclTensor* acl_src,
1771
- aclTensor* acl_dst) {
1772
- GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1773
- };
1774
- ggml_cann_op_unary(lambda, ctx, dst);
1775
- } break;
1698
+ case GGML_UNARY_OP_GELU_QUICK:
1699
+ {
1700
+ auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
1701
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1702
+ };
1703
+ ggml_cann_op_unary(lambda, ctx, dst);
1704
+ }
1705
+ break;
1776
1706
  case GGML_UNARY_OP_TANH:
1777
1707
  GGML_CANN_CALL_OP_UNARY(Tanh);
1778
1708
  break;
@@ -1817,14 +1747,14 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1817
1747
  case GGML_GLU_OP_SWIGLU:
1818
1748
  GGML_CANN_CALL_OP_UNARY_GATED(Silu);
1819
1749
  break;
1820
- case GGML_GLU_OP_GEGLU_QUICK: {
1821
- auto lambda = [](ggml_backend_cann_context& ctx,
1822
- aclTensor* acl_src,
1823
- aclTensor* acl_dst) {
1824
- GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1825
- };
1826
- ggml_cann_op_unary_gated(lambda, ctx, dst);
1827
- } break;
1750
+ case GGML_GLU_OP_GEGLU_QUICK:
1751
+ {
1752
+ auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
1753
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1754
+ };
1755
+ ggml_cann_op_unary_gated(lambda, ctx, dst);
1756
+ }
1757
+ break;
1828
1758
  default:
1829
1759
  return false;
1830
1760
  }
@@ -1835,6 +1765,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1835
1765
  case GGML_OP_GROUP_NORM:
1836
1766
  ggml_cann_group_norm(ctx, dst);
1837
1767
  break;
1768
+ case GGML_OP_L2_NORM:
1769
+ ggml_cann_l2_norm(ctx, dst);
1770
+ break;
1771
+ case GGML_OP_CROSS_ENTROPY_LOSS:
1772
+ ggml_cann_cross_entropy_loss(ctx, dst);
1773
+ break;
1838
1774
  case GGML_OP_CONCAT:
1839
1775
  ggml_cann_concat(ctx, dst);
1840
1776
  break;
@@ -1939,6 +1875,15 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1939
1875
  case GGML_OP_FLASH_ATTN_EXT:
1940
1876
  ggml_cann_flash_attn_ext(ctx, dst);
1941
1877
  break;
1878
+ case GGML_OP_OUT_PROD:
1879
+ ggml_cann_out_prod(ctx, dst);
1880
+ break;
1881
+ case GGML_OP_GATED_LINEAR_ATTN:
1882
+ ggml_cann_gated_linear_attn(ctx, dst);
1883
+ break;
1884
+ case GGML_OP_SSM_CONV:
1885
+ ggml_cann_ssm_conv(ctx, dst);
1886
+ break;
1942
1887
  default:
1943
1888
  return false;
1944
1889
  }
@@ -1956,9 +1901,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1956
1901
  * @param backend Pointer to the CANN backend structure.
1957
1902
  * @return A pointer to a constant string representing the backend name.
1958
1903
  */
1959
- static const char* ggml_backend_cann_name(ggml_backend_t backend) {
1960
- ggml_backend_cann_context* cann_ctx =
1961
- (ggml_backend_cann_context*)backend->context;
1904
+ static const char * ggml_backend_cann_name(ggml_backend_t backend) {
1905
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1962
1906
 
1963
1907
  return cann_ctx->name.c_str();
1964
1908
  }
@@ -1972,8 +1916,7 @@ static const char* ggml_backend_cann_name(ggml_backend_t backend) {
1972
1916
  * @param backend Pointer to the CANN backend structure to be freed.
1973
1917
  */
1974
1918
  static void ggml_backend_cann_free(ggml_backend_t backend) {
1975
- ggml_backend_cann_context* cann_ctx =
1976
- (ggml_backend_cann_context*)backend->context;
1919
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1977
1920
  ACL_CHECK(aclrtSynchronizeDevice());
1978
1921
  ACL_CHECK(aclrtResetDevice(cann_ctx->device));
1979
1922
 
@@ -1981,7 +1924,6 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
1981
1924
  delete backend;
1982
1925
  }
1983
1926
 
1984
-
1985
1927
  /**
1986
1928
  * @brief Sets tensor data asynchronously in the CANN backend.
1987
1929
  *
@@ -1994,21 +1936,18 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
1994
1936
  * @param size Size of the data to copy in bytes.
1995
1937
  */
1996
1938
  static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
1997
- ggml_tensor *tensor,
1998
- const void *data,
1999
- size_t offset,
2000
- size_t size) {
2001
- ggml_backend_cann_context *cann_ctx =
2002
- (ggml_backend_cann_context *)backend->context;
2003
- ggml_backend_buffer_t buf =
2004
- tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
2005
-
2006
- GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
2007
- "unsupported buffer type");
1939
+ ggml_tensor * tensor,
1940
+ const void * data,
1941
+ size_t offset,
1942
+ size_t size) {
1943
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1944
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1945
+
1946
+ GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
2008
1947
  GGML_ASSERT(!ggml_is_quantized(tensor->type));
2009
1948
 
2010
- ggml_cann_async_memcpy(cann_ctx, (char *)tensor->data + offset, data, size,
2011
- ACL_MEMCPY_HOST_TO_DEVICE);
1949
+ ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
1950
+ cann_ctx->stream()));
2012
1951
  }
2013
1952
 
2014
1953
  /**
@@ -2022,21 +1961,19 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
2022
1961
  * @param offset Offset in bytes within the host data.
2023
1962
  * @param size Size of the data to copy in bytes.
2024
1963
  */
2025
- static void ggml_backend_cann_get_tensor_async(
2026
- ggml_backend_t backend, const ggml_tensor *tensor, void *data,
2027
- size_t offset, size_t size) {
2028
- ggml_backend_cann_context *cann_ctx =
2029
- (ggml_backend_cann_context *)backend->context;
2030
- ggml_backend_buffer_t buf =
2031
- tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1964
+ static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend,
1965
+ const ggml_tensor * tensor,
1966
+ void * data,
1967
+ size_t offset,
1968
+ size_t size) {
1969
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1970
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
2032
1971
 
2033
- GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
2034
- "unsupported buffer type");
1972
+ GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
2035
1973
  GGML_ASSERT(!ggml_is_quantized(tensor->type));
2036
1974
 
2037
- ggml_cann_async_memcpy(cann_ctx, data, (char *)tensor->data + offset, size,
2038
- ACL_MEMCPY_DEVICE_TO_HOST);
2039
-
1975
+ ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
1976
+ cann_ctx->stream()));
2040
1977
  }
2041
1978
 
2042
1979
  /**
@@ -2052,28 +1989,23 @@ static void ggml_backend_cann_get_tensor_async(
2052
1989
  * @param dst Pointer to the destination tensor to copy data to.
2053
1990
  * @return true if the copy operation succeeds, false otherwise.
2054
1991
  */
2055
- static bool ggml_backend_cann_cpy_tensor_async(
2056
- ggml_backend_t backend_src, ggml_backend_t backend_dst,
2057
- const ggml_tensor* src, ggml_tensor* dst) {
2058
- GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
2059
- ggml_backend_is_cann(backend_dst));
1992
+ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
1993
+ ggml_backend_t backend_dst,
1994
+ const ggml_tensor * src,
1995
+ ggml_tensor * dst) {
1996
+ GGML_ASSERT(ggml_backend_is_cann(backend_src) || ggml_backend_is_cann(backend_dst));
2060
1997
 
2061
- GGML_ASSERT(!is_matmul_weight((const ggml_tensor*)src));
1998
+ GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src));
2062
1999
 
2063
- if (!ggml_backend_buffer_is_cann(src->buffer) ||
2064
- !ggml_backend_buffer_is_cann(dst->buffer)) {
2000
+ if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) {
2065
2001
  return false;
2066
2002
  }
2067
2003
 
2068
- ggml_backend_buffer_t buf_src =
2069
- src->view_src ? src->view_src->buffer : src->buffer;
2070
- ggml_backend_buffer_t buf_dst =
2071
- dst->view_src ? dst->view_src->buffer : dst->buffer;
2004
+ ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
2005
+ ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
2072
2006
 
2073
- ggml_backend_cann_context* cann_ctx_src =
2074
- (ggml_backend_cann_context*)backend_src->context;
2075
- ggml_backend_cann_context* cann_ctx_dst =
2076
- (ggml_backend_cann_context*)backend_dst->context;
2007
+ ggml_backend_cann_context * cann_ctx_src = (ggml_backend_cann_context *) backend_src->context;
2008
+ ggml_backend_cann_context * cann_ctx_dst = (ggml_backend_cann_context *) backend_dst->context;
2077
2009
 
2078
2010
  size_t copy_size = ggml_nbytes(dst);
2079
2011
  if (copy_size == 0) {
@@ -2084,17 +2016,14 @@ static bool ggml_backend_cann_cpy_tensor_async(
2084
2016
  // TODO: Support 310p P2P copy
2085
2017
  return false;
2086
2018
  #endif
2087
- ggml_backend_cann_buffer_context* buf_ctx_src =
2088
- (ggml_backend_cann_buffer_context*)buf_src->context;
2089
- ggml_backend_cann_buffer_context* buf_ctx_dst =
2090
- (ggml_backend_cann_buffer_context*)buf_dst->context;
2019
+ ggml_backend_cann_buffer_context * buf_ctx_src = (ggml_backend_cann_buffer_context *) buf_src->context;
2020
+ ggml_backend_cann_buffer_context * buf_ctx_dst = (ggml_backend_cann_buffer_context *) buf_dst->context;
2091
2021
 
2092
2022
  GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
2093
2023
  GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
2094
2024
 
2095
2025
  int32_t canAccessPeer = 0;
2096
- ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
2097
- cann_ctx_dst->device));
2026
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device, cann_ctx_dst->device));
2098
2027
  if (!canAccessPeer) {
2099
2028
  return false;
2100
2029
  }
@@ -2105,9 +2034,7 @@ static bool ggml_backend_cann_cpy_tensor_async(
2105
2034
  ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
2106
2035
 
2107
2036
  // wait for task_queue empty to keep task order.
2108
- cann_ctx_src->task_queue.wait();
2109
- ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
2110
- ACL_MEMCPY_DEVICE_TO_DEVICE,
2037
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
2111
2038
  cann_ctx_src->stream()));
2112
2039
  // record event on src stream after the copy
2113
2040
  // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream
@@ -2122,8 +2049,7 @@ static bool ggml_backend_cann_cpy_tensor_async(
2122
2049
  ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream()));
2123
2050
  } else {
2124
2051
  // src and dst are on the same backend
2125
- ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
2126
- ACL_MEMCPY_DEVICE_TO_DEVICE,
2052
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
2127
2053
  cann_ctx_dst->stream()));
2128
2054
  }
2129
2055
 
@@ -2139,147 +2065,44 @@ static bool ggml_backend_cann_cpy_tensor_async(
2139
2065
  * @param backend Pointer to the CANN backend structure to synchronize.
2140
2066
  */
2141
2067
  static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
2142
- ggml_backend_cann_context* cann_ctx =
2143
- (ggml_backend_cann_context*)backend->context;
2144
- cann_ctx->task_queue.wait();
2068
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2145
2069
  ggml_cann_set_device(cann_ctx->device);
2146
2070
  ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
2147
2071
  }
2148
2072
 
2149
- #ifdef USE_ACL_GRAPH
2150
- /**
2151
- * @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph.
2152
- *
2153
- * This function creates a new ggml_cann_graph object and fills its node properties
2154
- * (operation type, dimensions, strides, input sources, and operation parameters)
2155
- * based on the current ggml computation graph.
2156
- *
2157
- * Each node in the ggml graph is mapped to a property entry in the new CANN graph:
2158
- * - node address
2159
- * - operation type
2160
- * - shape (ne) and strides (nb)
2161
- * - source tensor addresses
2162
- * - operation parameters
2163
- *
2164
- * After initialization, the new graph is pushed into the LRU cache owned by the
2165
- * CANN backend context. The cache takes ownership of the graph and manages its
2166
- * lifetime (including deletion upon eviction).
2167
- *
2168
- * @param cann_ctx The CANN backend context containing the graph cache.
2169
- * @param cgraph The current ggml computation graph.
2170
- */
2171
- static void add_lru_matched_graph_node_properties(
2172
- ggml_backend_cann_context * cann_ctx,
2173
- ggml_cgraph * cgraph) {
2174
- // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache).
2175
- ggml_cann_graph * new_graph = new ggml_cann_graph();
2176
- new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2177
-
2178
- for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
2179
- ggml_tensor * node = cgraph->nodes[node_idx];
2180
- auto & prop = new_graph->ggml_graph_properties[node_idx];
2181
-
2182
- prop.node_address = node->data;
2183
- prop.node_op = node->op;
2184
-
2185
- std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
2186
- std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
2187
-
2188
- for (int src = 0; src < GGML_MAX_SRC; ++src) {
2189
- prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
2190
- }
2191
-
2192
- memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
2193
- }
2194
-
2195
- // Insert into the LRU cache (cache takes ownership and will delete it when evicted).
2196
- cann_ctx->graph_lru_cache.push(new_graph);
2197
- }
2198
-
2199
2073
  /**
2200
- * @brief Check if a ggml tensor node matches a previously captured CANN graph node.
2074
+ * @brief Check if CANN backend can fuse the specified operation sequence
2201
2075
  *
2202
- * This function compares all relevant fields (address, op type, shape, source inputs, op params)
2203
- * to determine whether the current node matches a previously recorded version.
2076
+ * This function determines whether an operation sequence starting from the specified node
2077
+ * can be fused into an optimized operation in the CANN backend. Operation fusion can reduce
2078
+ * memory access overhead and improve computational efficiency.
2204
2079
  *
2205
- * @param node The current ggml tensor node.
2206
- * @param graph_node_properties The stored properties of a CANN graph node.
2207
- * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
2080
+ * @param cgraph Pointer to the computation graph
2081
+ * @param node_idx Index of the starting node in the computation graph
2082
+ * @param ops Sequence of operation types to check for fusion
2083
+ * @return true if the operations can be fused
2084
+ * @return false if the operations cannot be fused
2208
2085
  */
2209
- static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2210
- if (node->data != graph_node_properties->node_address &&
2211
- node->op != GGML_OP_VIEW) {
2086
+ static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
2087
+ int node_idx,
2088
+ std::initializer_list<enum ggml_op> ops) {
2089
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
2212
2090
  return false;
2213
2091
  }
2214
- if (node->op != graph_node_properties->node_op) {
2215
- return false;
2216
- }
2217
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
2218
- if (node->ne[i] != graph_node_properties->ne[i]) {
2219
- return false;
2220
- }
2221
- if (node->nb[i] != graph_node_properties->nb[i]) {
2222
- return false;
2223
- }
2224
- }
2225
- for (int i = 0; i < GGML_MAX_SRC; i++) {
2226
- if (node->src[i] &&
2227
- node->src[i]->data != graph_node_properties->src_address[i] &&
2228
- node->op != GGML_OP_VIEW
2229
- ) {
2230
- return false;
2231
- }
2232
- }
2233
- if (node->op == GGML_OP_SCALE &&
2234
- memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2235
- return false;
2236
- }
2237
- return true;
2238
- }
2239
2092
 
2240
- /**
2241
- * @brief Check whether there is a cached CANN graph that matches the current ggml graph.
2242
- *
2243
- * This function iterates through the cached CANN graphs stored in the LRU cache and
2244
- * compares them against the given ggml computation graph. A match requires that the
2245
- * number of nodes is the same and that each node’s properties (operation type,
2246
- * dimensions, strides, inputs, and operation parameters) are identical.
2247
- *
2248
- * If a matching graph is found, it is promoted to the front of the LRU cache and the
2249
- * function returns true. Otherwise, the function returns false, indicating that a new
2250
- * CANN graph needs to be captured.
2251
- *
2252
- * @param cann_ctx The CANN backend context containing the graph cache.
2253
- * @param cgraph The current ggml computation graph.
2254
- * @return true if a matching cached graph exists; false otherwise.
2255
- */
2256
- static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
2257
- ggml_cann_graph_lru_cache &lru_cache = cann_ctx->graph_lru_cache;
2258
- for (auto &graph_ptr : lru_cache.cache_list) {
2259
- // Skip graphs with a different number of nodes.
2260
- if (graph_ptr->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
2261
- continue;
2262
- }
2263
-
2264
- // Check if all nodes match.
2265
- bool all_match = true;
2266
- for (int i = 0; i < cgraph->n_nodes; ++i) {
2267
- if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) {
2268
- all_match = false;
2269
- break;
2270
- }
2271
- }
2272
-
2273
- if (all_match) {
2274
- // update cache_list && renturn graph_ptr
2275
- lru_cache.move_to_front(graph_ptr);
2276
- return true;
2093
+ // CANN backend supports fusing ADD + RMS_NORM operations
2094
+ if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {
2095
+ ggml_tensor * add_node = cgraph->nodes[node_idx];
2096
+ // TODO: support broadcast for ADD + RMS_NORM
2097
+ if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||
2098
+ add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {
2099
+ return false;
2277
2100
  }
2101
+ return true;
2278
2102
  }
2279
2103
 
2280
2104
  return false;
2281
2105
  }
2282
- #endif // USE_ACL_GRAPH
2283
2106
 
2284
2107
  /**
2285
2108
  * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
@@ -2289,26 +2112,41 @@ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph *
2289
2112
  *
2290
2113
  * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
2291
2114
  *
2292
- * @param cann_ctx The CANN backend context.
2293
- * @param cgraph The ggml computation graph.
2294
- * @param use_cann_graph Whether to use CANN graph execution.
2295
- * @param cann_graph_update_required Whether graph capture is needed due to graph changes.
2115
+ * @param cann_ctx The CANN backend context.
2116
+ * @param cgraph The ggml computation graph.
2117
+ * @param use_cann_graph Whether to use CANN graph execution.
2118
+ * @param cann_graph_capture_required Whether graph capture is needed due to graph changes.
2296
2119
  */
2297
- static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
2298
- bool & use_cann_graph, bool & cann_graph_update_required) {
2120
+ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx,
2121
+ ggml_cgraph * cgraph,
2122
+ bool use_cann_graph,
2123
+ bool cann_graph_capture_required) {
2299
2124
  #ifdef USE_ACL_GRAPH
2300
- ggml_cann_graph* matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
2301
- if (use_cann_graph && cann_graph_update_required) {
2125
+ if (use_cann_graph && cann_graph_capture_required) { // Begin CANN graph capture
2302
2126
  ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
2303
2127
  }
2304
- #endif // USE_ACL_GRAPH
2128
+ #endif // USE_ACL_GRAPH
2305
2129
  // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
2306
2130
  // With the use of CANN graphs, the execution will be performed by the graph launch.
2307
- if (!use_cann_graph || cann_graph_update_required) {
2131
+ static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
2132
+
2133
+ if (!use_cann_graph || cann_graph_capture_required) {
2308
2134
  for (int i = 0; i < cgraph->n_nodes; i++) {
2309
2135
  ggml_tensor * node = cgraph->nodes[i];
2136
+ if (opt_fusion) {
2137
+ if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
2138
+ ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
2139
+ i++;
2140
+ continue;
2141
+ }
2142
+ }
2310
2143
 
2311
- if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2144
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
2145
+ node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2146
+ continue;
2147
+ }
2148
+
2149
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
2312
2150
  continue;
2313
2151
  }
2314
2152
 
@@ -2321,18 +2159,20 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
2321
2159
  }
2322
2160
 
2323
2161
  #ifdef USE_ACL_GRAPH
2324
- if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
2325
- ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
2326
- }
2327
-
2328
2162
  if (use_cann_graph) {
2329
- // Execute graph
2163
+ GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty());
2164
+ ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
2165
+
2166
+ if (cann_graph_capture_required) { // End CANN graph capture
2167
+ ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
2168
+ }
2169
+
2170
+ // Execute CANN graph
2330
2171
  ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
2331
2172
  }
2332
- #endif // USE_ACL_GRAPH
2173
+ #endif // USE_ACL_GRAPH
2333
2174
  }
2334
2175
 
2335
-
2336
2176
  /**
2337
2177
  * @brief Computes a computational graph using a CANN backend.
2338
2178
  *
@@ -2345,21 +2185,19 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
2345
2185
  * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
2346
2186
  * completes successfully, otherwise an appropriate error status.
2347
2187
  */
2348
- static enum ggml_status ggml_backend_cann_graph_compute(
2349
- ggml_backend_t backend, ggml_cgraph* cgraph) {
2350
- ggml_backend_cann_context* cann_ctx =
2351
- (ggml_backend_cann_context*)backend->context;
2188
+ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
2189
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2352
2190
  ggml_cann_set_device(cann_ctx->device);
2353
2191
  g_nz_workspaces[cann_ctx->device].clear();
2354
2192
 
2355
2193
  // calculate rope cache for fist layer in current device.
2356
2194
  cann_ctx->rope_cache.cached = false;
2357
2195
 
2196
+ bool graph_capture_required = false;
2358
2197
  #ifdef USE_ACL_GRAPH
2359
2198
  bool use_cann_graph = true;
2360
- bool cann_graph_update_required = false;
2361
2199
 
2362
- static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
2200
+ static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
2363
2201
  if (!prefill_use_graph) {
2364
2202
  // Do not use acl_graph for prefill.
2365
2203
  for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -2380,22 +2218,17 @@ static enum ggml_status ggml_backend_cann_graph_compute(
2380
2218
 
2381
2219
  if (use_cann_graph) {
2382
2220
  // If no matching graph is found, the graph needs to be recaptured.
2383
- cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph);
2384
- if (cann_graph_update_required) {
2221
+ graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
2222
+ if (graph_capture_required) {
2385
2223
  // If no matching graph is found, add a new ACL graph.
2386
- add_lru_matched_graph_node_properties(cann_ctx, cgraph);
2224
+ ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
2225
+ cann_ctx->graph_lru_cache.push(new_graph);
2387
2226
  }
2388
2227
  }
2389
2228
  #else
2390
2229
  bool use_cann_graph = false;
2391
- bool cann_graph_update_required = false;
2392
2230
  #endif // USE_ACL_GRAPH
2393
- evaluate_and_capture_cann_graph(
2394
- cann_ctx,
2395
- cgraph,
2396
- use_cann_graph,
2397
- cann_graph_update_required
2398
- );
2231
+ evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required);
2399
2232
 
2400
2233
  return GGML_STATUS_SUCCESS;
2401
2234
  }
@@ -2412,8 +2245,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
2412
2245
  * @return bool Returns true if the operation is supported by the backend,
2413
2246
  * otherwise false.
2414
2247
  */
2415
- static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2416
- const ggml_tensor* op) {
2248
+ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2417
2249
  switch (op->op) {
2418
2250
  case GGML_OP_UNARY:
2419
2251
  switch (ggml_get_unary_op(op)) {
@@ -2448,24 +2280,24 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2448
2280
  return false;
2449
2281
  }
2450
2282
  break;
2451
- case GGML_OP_MUL_MAT: {
2452
- switch (op->src[0]->type) {
2453
- case GGML_TYPE_F16:
2454
- case GGML_TYPE_F32:
2455
- return true;
2456
- case GGML_TYPE_Q8_0:
2457
- case GGML_TYPE_Q4_0:
2283
+ case GGML_OP_MUL_MAT:
2284
+ {
2285
+ switch (op->src[0]->type) {
2286
+ case GGML_TYPE_F16:
2287
+ case GGML_TYPE_F32:
2288
+ return true;
2289
+ case GGML_TYPE_Q8_0:
2290
+ case GGML_TYPE_Q4_0:
2458
2291
  #ifdef ASCEND_310P
2459
- // Q4 && Q8 per group is not support on 310p device
2460
- return false;
2292
+ // Q4 && Q8 per group is not support on 310p device
2293
+ return false;
2461
2294
  #endif
2462
- // only support contiguous for quantized types.
2463
- return ggml_is_contiguous(op->src[0]) &&
2464
- ggml_is_contiguous(op->src[1]);
2465
- default:
2466
- return false;
2295
+ // only support contiguous for quantized types.
2296
+ return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
2297
+ default:
2298
+ return false;
2299
+ }
2467
2300
  }
2468
- }
2469
2301
  case GGML_OP_MUL_MAT_ID:
2470
2302
  switch (op->src[0]->type) {
2471
2303
  case GGML_TYPE_F16:
@@ -2478,101 +2310,109 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2478
2310
  return false;
2479
2311
  #endif
2480
2312
  // only support contiguous for quantized types.
2481
- return ggml_is_contiguous(op->src[0]) &&
2482
- ggml_is_contiguous(op->src[1]);
2313
+ return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
2483
2314
  default:
2484
2315
  return false;
2485
2316
  }
2486
2317
  // embedding
2487
- case GGML_OP_GET_ROWS: {
2488
- switch (op->src[0]->type) {
2489
- case GGML_TYPE_F32:
2490
- case GGML_TYPE_F16:
2491
- case GGML_TYPE_Q8_0:
2492
- return true;
2493
- default:
2494
- return false;
2495
- }
2496
- } break;
2497
- case GGML_OP_SET_ROWS: {
2498
- switch (op->type) {
2499
- case GGML_TYPE_F32:
2500
- case GGML_TYPE_F16:
2501
- return true;
2502
- default:
2503
- return false;
2318
+ case GGML_OP_GET_ROWS:
2319
+ {
2320
+ switch (op->src[0]->type) {
2321
+ case GGML_TYPE_F32:
2322
+ case GGML_TYPE_F16:
2323
+ case GGML_TYPE_Q8_0:
2324
+ return true;
2325
+ default:
2326
+ return false;
2327
+ }
2504
2328
  }
2505
- } break;
2506
- case GGML_OP_CPY: {
2507
- ggml_tensor *src = op->src[0];
2508
- if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
2509
- (src->type != GGML_TYPE_F32 &&
2510
- src->type != GGML_TYPE_F16)) {
2511
- // only support F32 and F16.
2512
- return false;
2329
+ break;
2330
+ case GGML_OP_SET_ROWS:
2331
+ {
2332
+ switch (op->type) {
2333
+ case GGML_TYPE_F32:
2334
+ case GGML_TYPE_F16:
2335
+ return true;
2336
+ default:
2337
+ return false;
2338
+ }
2513
2339
  }
2514
- return true;
2515
- } break;
2516
- case GGML_OP_CONT: {
2517
- // TODO: support GGML_TYPE_BF16
2518
- switch (op->src[0]->type) {
2519
- case GGML_TYPE_F32:
2520
- case GGML_TYPE_F16:
2521
- return true;
2522
- default:
2340
+ break;
2341
+ case GGML_OP_CPY:
2342
+ {
2343
+ ggml_tensor * src = op->src[0];
2344
+ if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
2345
+ (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {
2346
+ // only support F32 and F16.
2523
2347
  return false;
2348
+ }
2349
+ return true;
2524
2350
  }
2525
- }
2526
- case GGML_OP_ROPE: {
2527
- // TODO: with ops-test v == 1
2528
- // TODO: n_dims <= ne0
2529
- if (op->src[0]->ne[0] != op->op_params[1]) {
2530
- return false;
2531
- }
2532
-
2533
- const int mode = ((const int32_t *) op->op_params)[2];
2534
- if (mode & GGML_ROPE_TYPE_MROPE) {
2535
- return false;
2536
- }
2537
- if (mode & GGML_ROPE_TYPE_VISION) {
2538
- return false;
2351
+ break;
2352
+ case GGML_OP_CONT:
2353
+ {
2354
+ // TODO: support GGML_TYPE_BF16
2355
+ switch (op->src[0]->type) {
2356
+ case GGML_TYPE_F32:
2357
+ case GGML_TYPE_F16:
2358
+ return true;
2359
+ default:
2360
+ return false;
2361
+ }
2539
2362
  }
2363
+ case GGML_OP_ROPE:
2364
+ {
2365
+ if (op->src[0]->ne[0] > 896) {
2366
+ return false;
2367
+ }
2540
2368
  #ifdef ASCEND_310P
2541
- if(!ggml_is_contiguous(op->src[0])){
2542
- return false;
2543
- }
2369
+ // TODO: Support rope_dim < ne00(dim)
2370
+ if (op->src[0]->ne[0] != op->op_params[1]) {
2371
+ return false;
2372
+ }
2373
+ if (!ggml_is_contiguous(op->src[0])) {
2374
+ return false;
2375
+ }
2544
2376
  #endif
2545
- return true;
2546
- }
2547
- case GGML_OP_UPSCALE: {
2548
- // aclnnUpsampleNearest2dGetWorkspaceSize not support
2549
- // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal
2550
- if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
2551
- return false;
2377
+ return true;
2552
2378
  }
2553
- if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
2554
- return false;
2379
+ case GGML_OP_UPSCALE:
2380
+ {
2381
+ // aclnnUpsampleNearest2dGetWorkspaceSize not support
2382
+ // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal
2383
+ if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
2384
+ return false;
2385
+ }
2386
+ if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
2387
+ return false;
2388
+ }
2389
+ if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
2390
+ return false;
2391
+ }
2392
+ return true;
2555
2393
  }
2556
- return true;
2557
- }
2558
- case GGML_OP_POOL_2D: {
2559
- const int32_t * opts = (const int32_t *) op->op_params;
2394
+ case GGML_OP_POOL_2D:
2395
+ {
2396
+ const int32_t * opts = (const int32_t *) op->op_params;
2560
2397
  #ifdef ASCEND_310P
2561
- enum ggml_op_pool opt = static_cast<ggml_op_pool>(opts[0]);
2562
- if(opt == GGML_OP_POOL_MAX){
2563
- return false;
2564
- }
2398
+ enum ggml_op_pool opt = static_cast<ggml_op_pool>(opts[0]);
2399
+ if (opt == GGML_OP_POOL_MAX) {
2400
+ return false;
2401
+ }
2565
2402
  #endif
2566
- const int k0 = opts[1];
2567
- const int k1 = opts[2];
2568
- const int p0 = opts[5];
2569
- const int p1 = opts[6];
2570
- // value of paddingH should be at most half of kernelH
2571
- // value of paddingW should be at most half of kernelW
2572
- return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
2573
- }
2574
- case GGML_OP_DUP:
2403
+ const int k0 = opts[1];
2404
+ const int k1 = opts[2];
2405
+ const int p0 = opts[5];
2406
+ const int p1 = opts[6];
2407
+ // value of paddingH should be at most half of kernelH
2408
+ // value of paddingW should be at most half of kernelW
2409
+ return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
2410
+ }
2575
2411
  case GGML_OP_SUM:
2412
+ return ggml_is_contiguous_rows(op->src[0]);
2413
+ case GGML_OP_L2_NORM:
2414
+ case GGML_OP_CROSS_ENTROPY_LOSS:
2415
+ case GGML_OP_DUP:
2576
2416
  case GGML_OP_IM2COL:
2577
2417
  case GGML_OP_CONCAT:
2578
2418
  case GGML_OP_REPEAT:
@@ -2596,7 +2436,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2596
2436
  case GGML_OP_ARGSORT:
2597
2437
  case GGML_OP_ACC:
2598
2438
  case GGML_OP_GROUP_NORM:
2439
+ return true;
2599
2440
  case GGML_OP_PAD:
2441
+ // TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985
2442
+ return ggml_get_op_params_i32(op, 8) == 0;
2600
2443
  case GGML_OP_ARANGE:
2601
2444
  case GGML_OP_TIMESTEP_EMBEDDING:
2602
2445
  case GGML_OP_LEAKY_RELU:
@@ -2607,54 +2450,72 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2607
2450
  case GGML_OP_MEAN:
2608
2451
  case GGML_OP_PAD_REFLECT_1D:
2609
2452
  case GGML_OP_COUNT_EQUAL:
2453
+ case GGML_OP_GATED_LINEAR_ATTN:
2610
2454
  return true;
2455
+ case GGML_OP_OUT_PROD:
2456
+ {
2457
+ #ifdef ASCEND_310P
2458
+ // Ger is not supported on 310p device
2459
+ return false;
2460
+ #endif
2461
+ switch (op->src[0]->type) {
2462
+ case GGML_TYPE_F16:
2463
+ case GGML_TYPE_F32:
2464
+ return true;
2465
+ default:
2466
+ return false;
2467
+ }
2468
+ }
2611
2469
  case GGML_OP_CONV_TRANSPOSE_1D:
2612
- // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
2613
- return (op->src[0]->ne[0] - 1) <= 255;
2470
+ return true;
2614
2471
  case GGML_OP_SCALE:
2615
2472
  float bias;
2616
- memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float));
2617
- return bias == 0.0f; // TODO: support bias != 0.0f
2473
+ memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float));
2474
+ return bias == 0.0f; // TODO: support bias != 0.0f
2618
2475
  case GGML_OP_SOFT_MAX:
2619
2476
  // TODO: support attention sinks [TAG_ATTN_SINKS]
2620
2477
  if (op->src[2]) {
2621
2478
  return false;
2622
2479
  }
2623
2480
  return true;
2624
- case GGML_OP_FLASH_ATTN_EXT:{
2481
+ case GGML_OP_FLASH_ATTN_EXT:
2482
+ {
2625
2483
  #ifdef ASCEND_310P
2626
- // FA not support on 310p device
2627
- return false;
2628
- #endif
2629
- // derived from [ggml-cuda.cu]
2630
- if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
2631
- return false;
2632
- }
2633
- if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){
2634
- return false;
2635
- }
2636
- if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
2637
- return false;
2638
- }
2639
- // TODO: support attention sinks [TAG_ATTN_SINKS]
2640
- if (op->src[4]) {
2641
- return false;
2642
- }
2643
- if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
2644
- // different head sizes of K and V are not supported yet
2645
- return false;
2646
- }
2647
- if (op->src[0]->ne[0] % 16 != 0) {
2648
- // TODO: padding to support
2649
- return false;
2650
- }
2651
- float logitSoftcap = 0.0f;
2652
- memcpy(&logitSoftcap, (const float *)(op->op_params) + 2, sizeof(float));
2653
- if(logitSoftcap != 0.0f) {
2484
+ // FA not support on 310p device
2654
2485
  return false;
2486
+ #endif
2487
+ // derived from [ggml-cuda.cu]
2488
+ if (op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16) {
2489
+ return false;
2490
+ }
2491
+ if (op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 &&
2492
+ op->src[1]->type != GGML_TYPE_BF16) {
2493
+ return false;
2494
+ }
2495
+ if (op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16) {
2496
+ return false;
2497
+ }
2498
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
2499
+ if (op->src[4]) {
2500
+ return false;
2501
+ }
2502
+ if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
2503
+ // different head sizes of K and V are not supported yet
2504
+ return false;
2505
+ }
2506
+ if (op->src[0]->ne[0] % 16 != 0) {
2507
+ // TODO: padding to support
2508
+ return false;
2509
+ }
2510
+ float logitSoftcap = 0.0f;
2511
+ memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float));
2512
+ if (logitSoftcap != 0.0f) {
2513
+ return false;
2514
+ }
2515
+ return true;
2655
2516
  }
2517
+ case GGML_OP_SSM_CONV:
2656
2518
  return true;
2657
- }
2658
2519
  default:
2659
2520
  return false;
2660
2521
  }
@@ -2662,43 +2523,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2662
2523
  GGML_UNUSED(dev);
2663
2524
  }
2664
2525
 
2665
- /**
2666
- * @brief Checks if the backend buffer type is associated with the CANN backend.
2667
- *
2668
- * This function checks whether the provided backend buffer type is associated
2669
- * with the CANN backend based on the comparison of its name retrieval function
2670
- * pointer.
2671
- *
2672
- * @param buft Pointer to the backend buffer type to check.
2673
- * @return bool Returns true if the buffer type is associated with the CANN
2674
- * backend, otherwise false.
2675
- */
2676
- static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
2677
- return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
2678
- }
2679
-
2680
- /**
2681
- * @brief Determines if a tensor operation should be offloaded to the CANN
2682
- * backend.
2683
- *
2684
- * This function checks if a given tensor operation should be offloaded to the
2685
- * CANN backend based on the operation type and the size of the tensor. It
2686
- * returns true if the second dimension (ne[1]) of the tensor is greater than or
2687
- * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
2688
- *
2689
- * @param backend Pointer to the CANN backend.
2690
- * @param op Pointer to the tensor operation to check.
2691
- * @return bool Returns true if the operation should be offloaded, otherwise
2692
- * false.
2693
- */
2694
- static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
2695
- const ggml_tensor* op) {
2696
- const int min_batch_size = 32;
2697
- GGML_UNUSED(dev);
2698
-
2699
- return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
2700
- }
2701
-
2702
2526
  /**
2703
2527
  * @brief Records an event on the CANN backend stream.
2704
2528
  *
@@ -2708,9 +2532,8 @@ static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
2708
2532
  * @param event Pointer to the event structure to be recorded.
2709
2533
  */
2710
2534
  static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
2711
- ggml_backend_cann_context* cann_ctx =
2712
- (ggml_backend_cann_context*)backend->context;
2713
- ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
2535
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2536
+ ACL_CHECK(aclrtRecordEvent((aclrtEvent) event->context, cann_ctx->stream()));
2714
2537
  }
2715
2538
 
2716
2539
  /**
@@ -2723,13 +2546,10 @@ static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_
2723
2546
  * @param event Pointer to the event structure that the backend needs to wait
2724
2547
  * for.
2725
2548
  */
2726
- static void ggml_backend_cann_event_wait(ggml_backend_t backend,
2727
- ggml_backend_event_t event) {
2728
- ggml_backend_cann_context* cann_ctx =
2729
- (ggml_backend_cann_context*)backend->context;
2549
+ static void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
2550
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2730
2551
  if (ggml_backend_is_cann(backend)) {
2731
- ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
2732
- (aclrtEvent)event->context));
2552
+ ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(), (aclrtEvent) event->context));
2733
2553
  } else {
2734
2554
  GGML_ABORT("fatal error");
2735
2555
  }
@@ -2768,30 +2588,31 @@ static const ggml_backend_i ggml_backend_cann_interface = {
2768
2588
  * @return A pointer to the static GUID.
2769
2589
  */
2770
2590
  static ggml_guid_t ggml_backend_cann_guid() {
2771
- static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
2772
- 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
2591
+ static ggml_guid guid = { 0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
2592
+ 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64 };
2773
2593
  return &guid;
2774
2594
  }
2775
2595
 
2776
2596
  // backend device
2777
2597
  struct ggml_backend_cann_device_context {
2778
- int device;
2598
+ int device;
2779
2599
  std::string name;
2780
2600
  std::string description;
2601
+ int op_offload_min_batch_size;
2781
2602
  };
2782
2603
 
2783
2604
  static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
2784
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
2605
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2785
2606
  return ctx->name.c_str();
2786
2607
  }
2787
2608
 
2788
- static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
2789
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
2609
+ static const char * ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
2610
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2790
2611
  return ctx->description.c_str();
2791
2612
  }
2792
2613
 
2793
2614
  static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2794
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
2615
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2795
2616
  ggml_backend_cann_get_device_memory(ctx->device, free, total);
2796
2617
  }
2797
2618
 
@@ -2818,7 +2639,7 @@ static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_back
2818
2639
 
2819
2640
  static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
2820
2641
  GGML_UNUSED(params);
2821
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
2642
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2822
2643
  return ggml_backend_cann_init(ctx->device);
2823
2644
  }
2824
2645
 
@@ -2835,19 +2656,17 @@ static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, cons
2835
2656
  * @return bool Returns true if the CANN backend supports the buffer type,
2836
2657
  * otherwise false.
2837
2658
  */
2838
- static bool ggml_backend_cann_supports_buft(
2839
- ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2659
+ static bool ggml_backend_cann_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2840
2660
  if (ggml_backend_buft_is_cann(buft)) {
2841
- ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
2842
- ggml_backend_cann_buffer_type_context * buft_ctx =
2843
- (ggml_backend_cann_buffer_type_context *)buft->context;
2661
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *) dev->context;
2662
+ ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
2844
2663
  return buft_ctx->device == dev_ctx->device;
2845
2664
  }
2846
2665
  return false;
2847
2666
  }
2848
2667
 
2849
2668
  static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
2850
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
2669
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2851
2670
  return ggml_backend_cann_buffer_type(ctx->device);
2852
2671
  }
2853
2672
 
@@ -2856,6 +2675,26 @@ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(
2856
2675
  return ggml_backend_cann_host_buffer_type();
2857
2676
  }
2858
2677
 
2678
+ /**
2679
+ * @brief Determines if a tensor operation should be offloaded to the CANN
2680
+ * backend.
2681
+ *
2682
+ * This function checks if a given tensor operation should be offloaded to the
2683
+ * CANN backend based on the operation type and the size of the tensor. It
2684
+ * returns true if the second dimension (ne[1]) of the tensor is greater than or
2685
+ * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
2686
+ *
2687
+ * @param backend Pointer to the CANN backend.
2688
+ * @param op Pointer to the tensor operation to check.
2689
+ * @return bool Returns true if the operation should be offloaded, otherwise
2690
+ * false.
2691
+ */
2692
+ static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2693
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
2694
+
2695
+ return op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS;
2696
+ }
2697
+
2859
2698
  /**
2860
2699
  * @brief Creates a new event for the CANN backend device.
2861
2700
  *
@@ -2866,9 +2705,8 @@ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(
2866
2705
  * @param backend Pointer to the CANN backend.
2867
2706
  * @return ggml_backend_event_t Returns a pointer to the new event structure.
2868
2707
  */
2869
- static ggml_backend_event_t ggml_backend_cann_device_event_new(
2870
- ggml_backend_dev_t dev) {
2871
- ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
2708
+ static ggml_backend_event_t ggml_backend_cann_device_event_new(ggml_backend_dev_t dev) {
2709
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *) dev->context;
2872
2710
 
2873
2711
  ggml_cann_set_device(dev_ctx->device);
2874
2712
 
@@ -2890,7 +2728,7 @@ static ggml_backend_event_t ggml_backend_cann_device_event_new(
2890
2728
  * @param event Pointer to the event structure to be freed.
2891
2729
  */
2892
2730
  static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
2893
- ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
2731
+ ACL_CHECK(aclrtDestroyEvent((aclrtEvent) event->context));
2894
2732
 
2895
2733
  delete event;
2896
2734
  GGML_UNUSED(dev);
@@ -2904,7 +2742,7 @@ static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_bac
2904
2742
  * @param event Pointer to the event structure to be synchronized.
2905
2743
  */
2906
2744
  static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
2907
- ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
2745
+ ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent) event->context));
2908
2746
 
2909
2747
  GGML_UNUSED(dev);
2910
2748
  }
@@ -2915,10 +2753,10 @@ static const ggml_backend_device_i ggml_backend_cann_device_interface = {
2915
2753
  /* .get_memory = */ ggml_backend_cann_device_get_memory,
2916
2754
  /* .get_type = */ ggml_backend_cann_device_get_type,
2917
2755
  /* .get_props = */ ggml_backend_cann_device_get_props,
2918
- /* .init_backend = */ ggml_backend_cann_device_init, // called for every card
2756
+ /* .init_backend = */ ggml_backend_cann_device_init, // called for every card
2919
2757
  /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
2920
2758
  /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
2921
- /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
2759
+ /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
2922
2760
  /* .supports_op = */ ggml_backend_cann_supports_op,
2923
2761
  /* .supports_buft = */ ggml_backend_cann_supports_buft,
2924
2762
  /* .offload_op = */ ggml_backend_cann_offload_op,
@@ -2927,7 +2765,6 @@ static const ggml_backend_device_i ggml_backend_cann_device_interface = {
2927
2765
  /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
2928
2766
  };
2929
2767
 
2930
-
2931
2768
  // backend reg
2932
2769
  struct ggml_backend_cann_reg_context {
2933
2770
  std::vector<ggml_backend_dev_t> devices;
@@ -2939,12 +2776,12 @@ static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
2939
2776
  }
2940
2777
 
2941
2778
  static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
2942
- ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2779
+ ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context;
2943
2780
  return ctx->devices.size();
2944
2781
  }
2945
2782
 
2946
2783
  static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2947
- ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2784
+ ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context;
2948
2785
  GGML_ASSERT(index < ctx->devices.size());
2949
2786
  return ctx->devices[index];
2950
2787
  }
@@ -2966,34 +2803,32 @@ static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
2966
2803
  // backend registry, called only once for cann backend
2967
2804
  ggml_backend_reg_t ggml_backend_cann_reg() {
2968
2805
  static ggml_backend_reg reg;
2969
- static bool initialized = false;
2806
+ static bool initialized = false;
2970
2807
 
2971
2808
  {
2972
- static std::mutex mutex;
2809
+ static std::mutex mutex;
2973
2810
  std::lock_guard<std::mutex> lock(mutex);
2974
2811
  if (!initialized) {
2975
2812
  aclInit(nullptr);
2976
2813
  ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
2814
+ const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
2977
2815
 
2978
2816
  for (int i = 0; i < ggml_cann_info().device_count; i++) {
2979
- ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
2980
- dev_ctx->description = aclrtGetSocName();
2981
- dev_ctx->device = i;
2982
- dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2817
+ ggml_backend_cann_device_context * dev_ctx = new ggml_backend_cann_device_context();
2818
+ dev_ctx->description = aclrtGetSocName();
2819
+ dev_ctx->device = i;
2820
+ dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2821
+ dev_ctx->op_offload_min_batch_size = min_batch_size;
2983
2822
  ggml_cann_set_device(i);
2984
- ggml_backend_dev_t dev = new ggml_backend_device {
2985
- /* .iface = */ ggml_backend_cann_device_interface,
2986
- /* .reg = */ &reg,
2987
- /* .context = */ dev_ctx
2988
- };
2823
+ ggml_backend_dev_t dev = new ggml_backend_device{ /* .iface = */ ggml_backend_cann_device_interface,
2824
+ /* .reg = */ &reg,
2825
+ /* .context = */ dev_ctx };
2989
2826
  ctx->devices.push_back(dev);
2990
2827
  }
2991
2828
 
2992
- reg = ggml_backend_reg {
2993
- /* .api_version = */ GGML_BACKEND_API_VERSION,
2994
- /* .iface = */ ggml_backend_cann_reg_interface,
2995
- /* .context = */ ctx
2996
- };
2829
+ reg = ggml_backend_reg{ /* .api_version = */ GGML_BACKEND_API_VERSION,
2830
+ /* .iface = */ ggml_backend_cann_reg_interface,
2831
+ /* .context = */ ctx };
2997
2832
  }
2998
2833
 
2999
2834
  initialized = true;
@@ -3009,39 +2844,36 @@ ggml_backend_t ggml_backend_cann_init(int32_t device) {
3009
2844
  return nullptr;
3010
2845
  }
3011
2846
 
3012
- ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
2847
+ ggml_backend_cann_context * ctx = new ggml_backend_cann_context(device);
3013
2848
  if (ctx == nullptr) {
3014
2849
  GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
3015
2850
  return nullptr;
3016
2851
  }
3017
2852
  ggml_cann_set_device(ctx->device);
3018
2853
  ggml_backend_t cann_backend =
3019
- new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
3020
- /* .interface = */ ggml_backend_cann_interface,
3021
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
3022
- /* .context = */ ctx};
2854
+ new ggml_backend{ /* .guid = */ ggml_backend_cann_guid(),
2855
+ /* .interface = */ ggml_backend_cann_interface,
2856
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
2857
+ /* .context = */ ctx };
3023
2858
 
3024
2859
  return cann_backend;
3025
2860
  }
3026
2861
 
3027
2862
  bool ggml_backend_is_cann(ggml_backend_t backend) {
3028
- return backend != NULL &&
3029
- ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
2863
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
3030
2864
  }
3031
2865
 
3032
2866
  int32_t ggml_backend_cann_get_device_count() {
3033
2867
  return ggml_cann_info().device_count;
3034
2868
  }
3035
2869
 
3036
- void ggml_backend_cann_get_device_description(
3037
- int32_t device, char* description, size_t description_size) {
2870
+ void ggml_backend_cann_get_device_description(int32_t device, char * description, size_t description_size) {
3038
2871
  ggml_cann_set_device(device);
3039
- const char* soc_name = aclrtGetSocName();
2872
+ const char * soc_name = aclrtGetSocName();
3040
2873
  snprintf(description, description_size, "%s", soc_name);
3041
2874
  }
3042
2875
 
3043
- void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
3044
- size_t* total) {
2876
+ void ggml_backend_cann_get_device_memory(int32_t device, size_t * free, size_t * total) {
3045
2877
  ggml_cann_set_device(device);
3046
2878
  ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
3047
2879
  }