whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630)
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -22,24 +22,24 @@
22
22
 
23
23
  #include "ggml-cann.h"
24
24
 
25
+ #include "ggml-backend-impl.h"
26
+ #include "ggml-cann/aclnn_ops.h"
27
+ #include "ggml-cann/common.h"
28
+ #include "ggml-impl.h"
29
+ #include "ggml.h"
30
+
25
31
  #include <acl/acl.h>
26
- #include <stdarg.h>
27
32
  #include <aclnnop/aclnn_trans_matmul_weight.h>
33
+ #include <stdarg.h>
28
34
 
35
+ #include <chrono>
29
36
  #include <cmath>
30
37
  #include <cstdio>
31
38
  #include <cstring>
32
39
  #include <mutex>
40
+ #include <optional>
33
41
  #include <queue>
34
- #include <chrono>
35
42
  #include <unordered_set>
36
- #include <optional>
37
-
38
- #include "ggml-impl.h"
39
- #include "ggml-backend-impl.h"
40
- #include "ggml-cann/aclnn_ops.h"
41
- #include "ggml-cann/common.h"
42
- #include "ggml.h"
43
43
 
44
44
  #define GGML_COMMON_DECL_C
45
45
 
@@ -56,32 +56,41 @@
56
56
  * @param line The line number where the error occurred.
57
57
  * @param msg The error message.
58
58
  */
59
- [[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
60
- const char* file, int line, const char* msg) {
59
+ [[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
61
60
  int32_t id = -1;
62
61
  aclrtGetDevice(&id);
63
62
 
64
63
  GGML_LOG_ERROR("CANN error: %s\n", msg);
65
- GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
66
- file, line);
64
+ GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line);
67
65
  GGML_LOG_ERROR(" %s\n", stmt);
68
66
  // abort with GGML_ASSERT to get a stack trace
69
67
  GGML_ABORT("CANN error");
70
68
  }
71
69
 
70
+ // Thread-local variable to record the current device of this thread.
71
+ thread_local int g_current_cann_device = -1;
72
+
72
73
  /**
73
- * @brief Sets the device to be used by CANN.
74
+ * @brief Set the CANN device to be used.
74
75
  *
75
- * @param device The device ID to set.
76
+ * @param device The target device ID to set.
76
77
  */
77
78
  void ggml_cann_set_device(const int32_t device) {
78
- int current_device = -1;
79
- aclrtGetDevice(&current_device);
79
+ // int current_device = -1;
80
+ // Note: In some CANN versions, if no device has been set yet,
81
+ // aclrtGetDevice(&current_device) may return 0 by default.
82
+ // aclrtGetDevice(&current_device);
80
83
 
81
- if (device == current_device) {
82
- return;
84
+ // If the current device is already the target one, no need to switch.
85
+ if (device == g_current_cann_device) {
86
+ return;
83
87
  }
88
+
89
+ // Switch to the new device.
84
90
  ACL_CHECK(aclrtSetDevice(device));
91
+
92
+ // Update the global device record.
93
+ g_current_cann_device = device;
85
94
  }
86
95
 
87
96
  /**
@@ -96,12 +105,14 @@ int32_t ggml_cann_get_device() {
96
105
  }
97
106
 
98
107
  /**
99
- * @brief Get the value of the specified environment variable (name).
108
+ * @brief Get the value of the specified environment variable (name) as lowercase.
100
109
  * if not empty, return a std::string object
101
110
  */
102
- std::optional<std::string> get_env(const std::string& name) {
103
- const char* val = std::getenv(name.c_str());
104
- if (!val) return std::nullopt;
111
+ std::optional<std::string> get_env_as_lowercase(const std::string & name) {
112
+ const char * val = std::getenv(name.c_str());
113
+ if (!val) {
114
+ return std::nullopt;
115
+ }
105
116
  std::string res = std::string(val);
106
117
  std::transform(res.begin(), res.end(), res.begin(), ::tolower);
107
118
  return res;
@@ -110,8 +121,8 @@ std::optional<std::string> get_env(const std::string& name) {
110
121
  /**
111
122
  * @brief Verify whether the environment variable is a valid value.
112
123
  */
113
- bool parse_bool(const std::string& value) {
114
- std::unordered_set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"};
124
+ bool parse_bool(const std::string & value) {
125
+ static const std::unordered_set<std::string> valid_values = { "on", "1", "yes", "y", "enable", "true" };
115
126
  return valid_values.find(value) != valid_values.end();
116
127
  }
117
128
 
@@ -125,7 +136,7 @@ bool parse_bool(const std::string& value) {
125
136
  * @param value The string to parse.
126
137
  * @return The parsed integer, or 0 if conversion fails.
127
138
  */
128
- int parse_integer(const std::string& value) {
139
+ int parse_integer(const std::string & value) {
129
140
  try {
130
141
  return std::stoi(value);
131
142
  } catch (...) {
@@ -144,11 +155,10 @@ int parse_integer(const std::string& value) {
144
155
  static ggml_cann_device_info ggml_cann_init() {
145
156
  ggml_cann_device_info info = {};
146
157
 
147
- aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
158
+ aclError err = aclrtGetDeviceCount((uint32_t *) &info.device_count);
148
159
 
149
160
  if (err != ACL_SUCCESS) {
150
- GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n",
151
- __func__, aclGetRecentErrMsg());
161
+ GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n", __func__, aclGetRecentErrMsg());
152
162
  return info;
153
163
  }
154
164
 
@@ -156,16 +166,15 @@ static ggml_cann_device_info ggml_cann_init() {
156
166
 
157
167
  for (int id = 0; id < info.device_count; ++id) {
158
168
  aclrtPhysicalMemProp prop = {};
159
- prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
160
- prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
161
- prop.memAttr = ACL_HBM_MEM_HUGE;
162
- prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
163
- prop.location.id = id;
164
- prop.reserve = 0;
165
- err = aclrtMemGetAllocationGranularity(
166
- &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
167
- &info.devices[id].vmm_granularity);
168
- info.devices[id].vmm = err == ACL_SUCCESS;
169
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
170
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
171
+ prop.memAttr = ACL_HBM_MEM_HUGE;
172
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
173
+ prop.location.id = id;
174
+ prop.reserve = 0;
175
+ err = aclrtMemGetAllocationGranularity(&prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
176
+ &info.devices[id].vmm_granularity);
177
+ info.devices[id].vmm = err == ACL_SUCCESS;
169
178
 
170
179
  size_t free, total;
171
180
  ggml_backend_cann_get_device_memory(id, &free, &total);
@@ -185,7 +194,7 @@ static ggml_cann_device_info ggml_cann_init() {
185
194
  *
186
195
  * @return A reference to the structure containing the device information.
187
196
  */
188
- const ggml_cann_device_info& ggml_cann_info() {
197
+ const ggml_cann_device_info & ggml_cann_info() {
189
198
  static ggml_cann_device_info info = ggml_cann_init();
190
199
  return info;
191
200
  }
@@ -205,7 +214,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
205
214
  /**
206
215
  * @brief The minimum free margin for a buffer.
207
216
  */
208
- static const size_t min_free_margin = 1ull << 20; // 1MB
217
+ static const size_t min_free_margin = 1ull << 20; // 1MB
209
218
 
210
219
  /**
211
220
  * @brief The alignment for buffer allocation.
@@ -226,22 +235,18 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
226
235
  * @brief Structure representing a CANN buffer.
227
236
  */
228
237
  struct ggml_cann_buffer {
229
- void* ptr = nullptr; ///< Pointer to the buffer.
230
- size_t size = 0; ///< Size of the buffer.
231
- std::chrono::steady_clock::time_point last_used; ///< Last used time.
238
+ void * ptr = nullptr; ///< Pointer to the buffer.
239
+ size_t size = 0; ///< Size of the buffer.
240
+ std::chrono::steady_clock::time_point last_used; ///< Last used time.
232
241
 
233
- bool operator>(const ggml_cann_buffer& other) const {
234
- return size > other.size;
235
- }
242
+ bool operator>(const ggml_cann_buffer & other) const { return size > other.size; }
236
243
  };
237
244
 
238
245
  /**
239
246
  * @brief Array of CANN buffers in the pool.
240
247
  */
241
- std::unordered_map<void*, size_t> buffer_pool;
242
- std::priority_queue<ggml_cann_buffer,
243
- std::vector<ggml_cann_buffer>,
244
- std::greater<>> free_buffers ;
248
+ std::unordered_map<void *, size_t> buffer_pool;
249
+ std::priority_queue<ggml_cann_buffer, std::vector<ggml_cann_buffer>, std::greater<>> free_buffers;
245
250
 
246
251
  /**
247
252
  * @brief Total size of all buffers in the pool.
@@ -254,7 +259,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
254
259
  * @param device The device ID to associate with this buffer pool.
255
260
  */
256
261
  explicit ggml_cann_pool_buf_prio(int device) : device(device) {
257
- disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
262
+ disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
258
263
  }
259
264
 
260
265
  /**
@@ -262,7 +267,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
262
267
  */
263
268
  ~ggml_cann_pool_buf_prio() {
264
269
  ggml_cann_set_device(device);
265
- for (auto& [b_ptr, b_size] : buffer_pool) {
270
+ for (auto & [b_ptr, b_size] : buffer_pool) {
266
271
  aclrtFree(b_ptr);
267
272
  pool_size -= b_size;
268
273
  }
@@ -278,14 +283,14 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
278
283
  * the allocated buffer.
279
284
  * @return A pointer to the allocated buffer.
280
285
  */
281
- void* alloc(size_t size, size_t* actual_size) override {
286
+ void * alloc(size_t size, size_t * actual_size) override {
282
287
  size = GGML_PAD(size, alignment);
283
288
  if (size == 0) {
284
289
  size = alignment;
285
290
  }
286
291
 
287
- void* ptr = nullptr;
288
- auto now = std::chrono::steady_clock::now();
292
+ void * ptr = nullptr;
293
+ auto now = std::chrono::steady_clock::now();
289
294
 
290
295
  std::vector<ggml_cann_buffer> free_buffers_rest;
291
296
  free_buffers_rest.reserve(free_buffers.size());
@@ -298,24 +303,22 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
298
303
  const size_t margin = b.size - size;
299
304
  if (margin <= max_reuse_margin) {
300
305
  *actual_size = b.size;
301
- ptr = b.ptr;
306
+ ptr = b.ptr;
302
307
  #ifdef DEBUG_CANN_MALLOC
303
308
  GGML_LOG_INFO(
304
309
  "cann pool[%d]: reused %p, "
305
310
  "pool_size = %5u MB, "
306
311
  "size = %5u MB, "
307
312
  "margin = %5u MB\n",
308
- device, b.ptr,
309
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
310
- (uint32_t)(GGML_PAD(size, 1048576) / 1048576),
311
- (uint32_t)(GGML_PAD(margin, 1048576) / 1048576));
313
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
314
+ (uint32_t) (GGML_PAD(size, 1048576) / 1048576),
315
+ (uint32_t) (GGML_PAD(margin, 1048576) / 1048576));
312
316
  #endif
313
317
  break;
314
318
  }
315
319
  }
316
320
 
317
- bool should_clean = !disable_clean &&
318
- b.size > min_free_margin &&
321
+ bool should_clean = !disable_clean && b.size > min_free_margin &&
319
322
  std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
320
323
  if (should_clean) {
321
324
  // free the buffer if the size is needed to be freed
@@ -327,20 +330,20 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
327
330
  "cann pool[%d]: clean %p, "
328
331
  "pool_size = %5u MB, "
329
332
  "size = %5u MB\n",
330
- device, b.ptr,
331
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
332
- (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
333
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
334
+ (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
333
335
  #endif
334
336
  continue;
335
337
  }
336
338
  free_buffers_rest.push_back(b);
337
339
  }
338
- for (ggml_cann_buffer &b : free_buffers_rest) {
340
+ for (ggml_cann_buffer & b : free_buffers_rest) {
339
341
  free_buffers.push(std::move(b));
340
342
  }
341
343
 
342
344
  #ifdef DEBUG_CANN_MALLOC
343
- GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
345
+ GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device,
346
+ (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
344
347
  #endif
345
348
  if (ptr != nullptr) {
346
349
  return ptr;
@@ -356,8 +359,8 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
356
359
  "cann pool[%d]: allocate %p, "
357
360
  "pool_size = %5u MB, "
358
361
  "size = %5u MB\n",
359
- device, ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
360
- (uint32_t)(GGML_PAD(size, 1048576) / 1048576));
362
+ device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
363
+ (uint32_t) (GGML_PAD(size, 1048576) / 1048576));
361
364
  #endif
362
365
  buffer_pool.emplace(ptr, size);
363
366
  return ptr;
@@ -369,7 +372,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
369
372
  * @param ptr Pointer to the buffer to free.
370
373
  * @param size Size of the buffer to free.
371
374
  */
372
- void free(void* ptr, size_t size) override {
375
+ void free(void * ptr, size_t size) override {
373
376
  GGML_UNUSED(size);
374
377
  auto it = buffer_pool.find(ptr);
375
378
  if (it == buffer_pool.end()) {
@@ -377,13 +380,12 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
377
380
  }
378
381
 
379
382
  auto now = std::chrono::steady_clock::now();
380
- free_buffers.emplace(ggml_cann_buffer{ptr, it->second, now});
383
+ free_buffers.emplace(ggml_cann_buffer{ ptr, it->second, now });
381
384
  #ifdef DEBUG_CANN_MALLOC
382
385
  GGML_LOG_INFO(
383
386
  "cann pool[%d]: return %p, "
384
387
  "pool_size = %5u MB\n",
385
- device, ptr,
386
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
388
+ device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
387
389
  #endif
388
390
  }
389
391
  };
@@ -402,7 +404,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
402
404
  /**
403
405
  * @brief The minimum free margin for a buffer.
404
406
  */
405
- static const size_t min_free_margin = 1ull << 20; // 1MB
407
+ static const size_t min_free_margin = 1ull << 20; // 1MB
406
408
 
407
409
  /**
408
410
  * @brief The alignment for buffer allocation.
@@ -428,10 +430,10 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
428
430
  * @brief Structure representing a CANN buffer.
429
431
  */
430
432
  struct ggml_cann_buffer {
431
- void* ptr = nullptr; ///< Pointer to the buffer memory.
432
- size_t size = 0; ///< Size of the buffer.
433
- bool used = false; ///< Whether the buffer is currently in use.
434
- std::chrono::steady_clock::time_point last_used; ///< Last used time.
433
+ void * ptr = nullptr; ///< Pointer to the buffer memory.
434
+ size_t size = 0; ///< Size of the buffer.
435
+ bool used = false; ///< Whether the buffer is currently in use.
436
+ std::chrono::steady_clock::time_point last_used; ///< Last used time.
435
437
  };
436
438
 
437
439
  /**
@@ -450,7 +452,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
450
452
  * @param device The device ID to associate with this buffer pool.
451
453
  */
452
454
  explicit ggml_cann_pool_buf(int device) : device(device) {
453
- disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
455
+ disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
454
456
  }
455
457
 
456
458
  /**
@@ -459,7 +461,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
459
461
  ~ggml_cann_pool_buf() {
460
462
  ggml_cann_set_device(device);
461
463
  for (int i = 0; i < MAX_BUFFERS; ++i) {
462
- ggml_cann_buffer& b = buffer_pool[i];
464
+ ggml_cann_buffer & b = buffer_pool[i];
463
465
  if (b.ptr != nullptr) {
464
466
  aclrtFree(b.ptr);
465
467
  pool_size -= b.size;
@@ -476,18 +478,18 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
476
478
  * the allocated buffer.
477
479
  * @return A pointer to the allocated buffer.
478
480
  */
479
- void* alloc(size_t size, size_t* actual_size) override {
481
+ void * alloc(size_t size, size_t * actual_size) override {
480
482
  size = GGML_PAD(size, alignment);
481
483
  if (size == 0) {
482
484
  size = alignment;
483
485
  }
484
486
 
485
- void* ptr = nullptr;
486
- auto now = std::chrono::steady_clock::now();
487
+ void * ptr = nullptr;
488
+ auto now = std::chrono::steady_clock::now();
487
489
 
488
490
  int i = 0;
489
491
  for (; i < MAX_BUFFERS; ++i) {
490
- ggml_cann_buffer& b = buffer_pool[i];
492
+ ggml_cann_buffer & b = buffer_pool[i];
491
493
  if (b.ptr == nullptr) {
492
494
  break;
493
495
  }
@@ -499,25 +501,23 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
499
501
  const size_t margin = b.size - size;
500
502
  if (margin <= max_reuse_margin) {
501
503
  *actual_size = b.size;
502
- b.used = true;
503
- ptr = b.ptr;
504
+ b.used = true;
505
+ ptr = b.ptr;
504
506
  #ifdef DEBUG_CANN_MALLOC
505
507
  GGML_LOG_INFO(
506
508
  "cann pool[%d]: reused %p, "
507
509
  "pool_size = %5u MB, "
508
510
  "size = %5u MB, "
509
511
  "margin = %5u MB\n",
510
- device, b.ptr,
511
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
512
- (uint32_t)(GGML_PAD(size, 1048576) / 1048576),
513
- (uint32_t)(GGML_PAD(margin, 1048576) / 1048576));
512
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
513
+ (uint32_t) (GGML_PAD(size, 1048576) / 1048576),
514
+ (uint32_t) (GGML_PAD(margin, 1048576) / 1048576));
514
515
  #endif
515
516
  break;
516
517
  }
517
518
  }
518
519
 
519
- bool should_clean = !disable_clean &&
520
- b.size > min_free_margin &&
520
+ bool should_clean = !disable_clean && b.size > min_free_margin &&
521
521
  std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
522
522
  if (should_clean) {
523
523
  // free the buffer if the size is needed to be freed
@@ -528,9 +528,8 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
528
528
  "cann pool[%d]: clean %p, "
529
529
  "pool_size = %5u MB, "
530
530
  "size = %5u MB\n",
531
- device, b.ptr,
532
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
533
- (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
531
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
532
+ (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
534
533
  #endif
535
534
  b.ptr = nullptr;
536
535
  }
@@ -541,13 +540,13 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
541
540
 
542
541
  if (i < MAX_BUFFERS) {
543
542
  // allocate a new buffer if no buffer can be reused
544
- ggml_cann_buffer& b = buffer_pool[i];
543
+ ggml_cann_buffer & b = buffer_pool[i];
545
544
  ggml_cann_set_device(device);
546
545
  ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
547
546
  pool_size += size;
548
547
  *actual_size = size;
549
- b.size = size;
550
- b.used = true;
548
+ b.size = size;
549
+ b.used = true;
551
550
  if (i >= MAX_BUFFERS - 8) {
552
551
  GGML_LOG_WARN("cann pool[%d]: slots almost full\n", device);
553
552
  }
@@ -556,9 +555,8 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
556
555
  "cann pool[%d]: allocate %p, "
557
556
  "pool_size = %5u MB, "
558
557
  "size = %5u MB\n",
559
- device, b.ptr,
560
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
561
- (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
558
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
559
+ (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
562
560
  #endif
563
561
  return b.ptr;
564
562
  }
@@ -572,21 +570,20 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
572
570
  * @param ptr Pointer to the buffer to free.
573
571
  * @param size Size of the buffer to free.
574
572
  */
575
- void free(void* ptr, size_t size) override {
573
+ void free(void * ptr, size_t size) override {
576
574
  GGML_UNUSED(size);
577
575
  for (int i = 0; i < MAX_BUFFERS; ++i) {
578
- ggml_cann_buffer& b = buffer_pool[i];
576
+ ggml_cann_buffer & b = buffer_pool[i];
579
577
  if (b.ptr != ptr) {
580
578
  continue;
581
579
  }
582
- b.used = false;
580
+ b.used = false;
583
581
  b.last_used = std::chrono::steady_clock::now();
584
582
  #ifdef DEBUG_CANN_MALLOC
585
583
  GGML_LOG_INFO(
586
584
  "cann pool[%d]: return %p, "
587
585
  "pool_size = %5u MB\n",
588
- device, b.ptr,
589
- (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
586
+ device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
590
587
  #endif
591
588
  return;
592
589
  }
@@ -614,7 +611,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
614
611
  /**
615
612
  * @brief Pointer to the start of the virtual memory pool.
616
613
  */
617
- void* pool_addr = 0;
614
+ void * pool_addr = 0;
618
615
 
619
616
  /**
620
617
  * @brief Amount of virtual memory used in the pool.
@@ -639,7 +636,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
639
636
  /**
640
637
  * @brief Offsets for the mapped memory regions.
641
638
  */
642
- std::vector<void*> map_offsets;
639
+ std::vector<void *> map_offsets;
643
640
 
644
641
  /**
645
642
  * @brief Constructor to initialize the buffer pool with virtual memory for
@@ -647,11 +644,10 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
647
644
  *
648
645
  * @param device The device ID to associate with this buffer pool.
649
646
  */
650
- explicit ggml_cann_pool_vmm(int device)
651
- : device(device) {
652
- auto dev = ggml_cann_info().devices[device];
647
+ explicit ggml_cann_pool_vmm(int device) : device(device) {
648
+ auto dev = ggml_cann_info().devices[device];
653
649
  granularity = dev.vmm_granularity;
654
- max_size = dev.total_vram;
650
+ max_size = dev.total_vram;
655
651
  }
656
652
 
657
653
  /**
@@ -659,10 +655,10 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
659
655
  */
660
656
  ~ggml_cann_pool_vmm() {
661
657
  if (pool_addr != 0) {
662
- for (auto& offset : map_offsets) {
658
+ for (auto & offset : map_offsets) {
663
659
  ACL_CHECK(aclrtUnmapMem(offset));
664
660
  }
665
- for (auto& handle : handles) {
661
+ for (auto & handle : handles) {
666
662
  ACL_CHECK(aclrtFreePhysical(handle));
667
663
  }
668
664
  ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
@@ -677,11 +673,11 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
677
673
  * the allocated buffer.
678
674
  * @return A pointer to the allocated buffer.
679
675
  */
680
- void* alloc(size_t size, size_t* actual_size) override {
676
+ void * alloc(size_t size, size_t * actual_size) override {
681
677
  // round up the allocation size to the alignment to ensure that all
682
678
  // allocations are aligned for all data types
683
679
  const size_t alignment = 128;
684
- size = GGML_PAD(size, alignment);
680
+ size = GGML_PAD(size, alignment);
685
681
  if (size == 0) {
686
682
  size = alignment;
687
683
  }
@@ -691,53 +687,51 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
691
687
  if (size > avail) {
692
688
  // round up to the next multiple of the granularity
693
689
  size_t reserve_size = size - avail;
694
- reserve_size = GGML_PAD(reserve_size, granularity);
690
+ reserve_size = GGML_PAD(reserve_size, granularity);
695
691
 
696
692
  GGML_ASSERT(pool_size + reserve_size <= max_size);
697
693
 
698
694
  // allocate more physical memory
699
695
  aclrtPhysicalMemProp prop = {};
700
- prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
701
- prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
702
- prop.memAttr = ACL_HBM_MEM_HUGE;
703
- prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
704
- prop.location.id = device;
705
- prop.reserve = 0;
696
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
697
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
698
+ prop.memAttr = ACL_HBM_MEM_HUGE;
699
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
700
+ prop.location.id = device;
701
+ prop.reserve = 0;
706
702
  aclrtDrvMemHandle handle;
707
703
  ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
708
704
 
709
705
  // reserve virtual address space (if not already reserved)
710
706
  if (pool_addr == 0) {
711
- ACL_CHECK(aclrtReserveMemAddress(
712
- &pool_addr, max_size, 0, NULL, 1));
707
+ ACL_CHECK(aclrtReserveMemAddress(&pool_addr, max_size, 0, NULL, 1));
713
708
  }
714
709
 
715
710
  // map at the end of the pool
716
- ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
717
- handle, 0));
711
+ ACL_CHECK(aclrtMapMem((char *) pool_addr + pool_size, reserve_size, 0, handle, 0));
718
712
 
719
713
  handles.push_back(handle);
720
- map_offsets.push_back((char*)pool_addr + pool_size);
714
+ map_offsets.push_back((char *) pool_addr + pool_size);
721
715
 
722
716
  // add to the pool
723
717
  pool_size += reserve_size;
724
718
 
725
719
  #ifdef DEBUG_CANN_MALLOC
726
- GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
727
- device, (unsigned long long) (pool_size/1024/1024),
728
- (unsigned long long) (reserve_size/1024/1024));
720
+ GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n", device,
721
+ (unsigned long long) (pool_size / 1024 / 1024),
722
+ (unsigned long long) (reserve_size / 1024 / 1024));
729
723
  #endif
730
724
  }
731
725
 
732
726
  GGML_ASSERT(pool_addr != 0);
733
727
 
734
- void* ptr = (void*)((char*)pool_addr + pool_used);
728
+ void * ptr = (void *) ((char *) pool_addr + pool_used);
735
729
  *actual_size = size;
736
730
  pool_used += size;
737
731
 
738
732
  #ifdef DEBUG_CANN_MALLOC
739
- GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
740
- (unsigned long long)size, (unsigned long long)ptr);
733
+ GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size,
734
+ (unsigned long long) ptr);
741
735
  #endif
742
736
  return ptr;
743
737
  }
@@ -748,16 +742,16 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
748
742
  * @param ptr Pointer to the buffer to free.
749
743
  * @param size Size of the buffer to free.
750
744
  */
751
- void free(void* ptr, size_t size) override {
745
+ void free(void * ptr, size_t size) override {
752
746
  #ifdef DEBUG_CANN_MALLOC
753
- GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
754
- (unsigned long long)size, (unsigned long long)ptr);
747
+ GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size,
748
+ (unsigned long long) ptr);
755
749
  #endif
756
750
 
757
751
  pool_used -= size;
758
752
 
759
753
  // all deallocations must be in reverse order of the allocations
760
- GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
754
+ GGML_ASSERT(ptr == (void *) ((char *) pool_addr + pool_used));
761
755
  }
762
756
  };
763
757
 
@@ -769,9 +763,8 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
769
763
  * @param device The device ID for which to create the pool.
770
764
  * @return A unique pointer to the created CANN pool.
771
765
  */
772
- std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
773
- int device) {
774
- std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
766
+ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {
767
+ std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or("");
775
768
 
776
769
  if (mem_pool_type == "prio") {
777
770
  GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
@@ -795,9 +788,8 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
795
788
  * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
796
789
  */
797
790
  struct ggml_backend_cann_buffer_context {
798
- int32_t device; ///< The device ID associated with this buffer context.
799
- void* dev_ptr =
800
- nullptr; ///< Pointer to the device memory allocated for the buffer.
791
+ int32_t device; ///< The device ID associated with this buffer context.
792
+ void * dev_ptr = nullptr; ///< Pointer to the device memory allocated for the buffer.
801
793
 
802
794
  /**
803
795
  * @brief Constructor to initialize the CANN buffer context.
@@ -805,9 +797,7 @@ struct ggml_backend_cann_buffer_context {
805
797
  * @param device The device ID associated with this buffer context.
806
798
  * @param dev_ptr Pointer to the device memory allocated for the buffer.
807
799
  */
808
- ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
809
- : device(device),
810
- dev_ptr(dev_ptr) {}
800
+ ggml_backend_cann_buffer_context(int32_t device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {}
811
801
 
812
802
  /**
813
803
  * @brief Destructor to free the device memory allocated for the buffer.
@@ -825,8 +815,8 @@ struct ggml_backend_cann_buffer_context {
825
815
  * @return true if the buffer is a CANN buffer, false otherwise.
826
816
  */
827
817
  static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
828
- static bool ggml_backend_buffer_is_cann(
829
- ggml_backend_buffer_t buffer) {
818
+
819
+ static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) {
830
820
  return ggml_backend_buft_is_cann(buffer->buft);
831
821
  }
832
822
 
@@ -838,10 +828,8 @@ static bool ggml_backend_buffer_is_cann(
838
828
  *
839
829
  * @param buffer The CANN buffer to free.
840
830
  */
841
- static void ggml_backend_cann_buffer_free_buffer(
842
- ggml_backend_buffer_t buffer) {
843
- ggml_backend_cann_buffer_context* ctx =
844
- (ggml_backend_cann_buffer_context*)buffer->context;
831
+ static void ggml_backend_cann_buffer_free_buffer(ggml_backend_buffer_t buffer) {
832
+ ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
845
833
  delete ctx;
846
834
  }
847
835
 
@@ -854,10 +842,8 @@ static void ggml_backend_cann_buffer_free_buffer(
854
842
  * @param buffer The CANN buffer whose base pointer is to be retrieved.
855
843
  * @return A pointer to the base of the device memory allocated for the buffer.
856
844
  */
857
- static void* ggml_backend_cann_buffer_get_base(
858
- ggml_backend_buffer_t buffer) {
859
- ggml_backend_cann_buffer_context* ctx =
860
- (ggml_backend_cann_buffer_context*)buffer->context;
845
+ static void * ggml_backend_cann_buffer_get_base(ggml_backend_buffer_t buffer) {
846
+ ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
861
847
  return ctx->dev_ptr;
862
848
  }
863
849
 
@@ -874,21 +860,17 @@ static void* ggml_backend_cann_buffer_get_base(
874
860
  * @param dst Pointer to the destination buffer where transformed data will be
875
861
  * stored.
876
862
  */
877
- static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
878
- const void* src,
879
- void* dst) {
863
+ static void ggml_backend_cann_transform_q4_0(ggml_tensor * tensor, const void * src, void * dst) {
864
+ int64_t n_elems = ggml_nelements(tensor);
865
+ int64_t groups = n_elems / QK4_0;
866
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
880
867
 
881
- int64_t n_elems = ggml_nelements(tensor);
882
- int64_t groups = n_elems / QK4_0;
883
- size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
884
-
885
- uint8_t* quant_offset = (uint8_t*)dst;
886
- uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
868
+ uint8_t * quant_offset = (uint8_t *) dst;
869
+ uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes);
887
870
 
888
871
  for (int i = 0; i < groups; i++) {
889
- const block_q4_0* group =
890
- (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
891
- *scale_offset = group->d;
872
+ const block_q4_0 * group = (const block_q4_0 *) ((const char *) src + i * sizeof(block_q4_0));
873
+ *scale_offset = group->d;
892
874
  scale_offset++;
893
875
 
894
876
  // 0-15
@@ -907,8 +889,7 @@ static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
907
889
  }
908
890
 
909
891
  // put (uint4b_t -8) into int4b_t
910
- for (quant_offset = (uint8_t*)dst;
911
- quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
892
+ for (quant_offset = (uint8_t *) dst; quant_offset < (uint8_t *) dst + quant_bytes; quant_offset++) {
912
893
  (*quant_offset) ^= 0x88;
913
894
  }
914
895
  }
@@ -926,29 +907,27 @@ static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
926
907
  * @param dst Pointer to the destination buffer where the Q4.0 formatted data
927
908
  * will be stored.
928
909
  */
929
- static void ggml_backend_cann_transform_back_q4_0(
930
- const ggml_tensor* tensor, void* src, void* dst) {
931
-
932
- int64_t n_elems = ggml_nelements(tensor);
933
- int64_t groups = n_elems / QK4_0;
934
- size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
910
+ static void ggml_backend_cann_transform_back_q4_0(const ggml_tensor * tensor, void * src, void * dst) {
911
+ int64_t n_elems = ggml_nelements(tensor);
912
+ int64_t groups = n_elems / QK4_0;
913
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
935
914
 
936
- uint8_t* quant_offset = (uint8_t*)src;
937
- uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
915
+ uint8_t * quant_offset = (uint8_t *) src;
916
+ uint16_t * scale_offset = (uint16_t *) ((char *) src + quant_bytes);
938
917
 
939
- for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
918
+ for (; quant_offset < (uint8_t *) src + quant_bytes; quant_offset++) {
940
919
  (*quant_offset) ^= 0x88;
941
920
  }
942
- quant_offset = (uint8_t*)src;
921
+ quant_offset = (uint8_t *) src;
943
922
 
944
923
  for (int i = 0; i < groups; i++) {
945
- block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
946
- group->d = *scale_offset;
924
+ block_q4_0 * group = (block_q4_0 *) ((char *) dst + i * sizeof(block_q4_0));
925
+ group->d = *scale_offset;
947
926
  scale_offset++;
948
927
 
949
928
  // 0-15
950
929
  for (int j = 0; j < QK4_0 / 2; j += 2) {
951
- group->qs[j] = ((*quant_offset) & 0x0F);
930
+ group->qs[j] = ((*quant_offset) & 0x0F);
952
931
  group->qs[j + 1] = ((*quant_offset) >> 4);
953
932
  quant_offset++;
954
933
  }
@@ -975,20 +954,17 @@ static void ggml_backend_cann_transform_back_q4_0(
975
954
  * @param dst Pointer to the destination buffer where transformed data will be
976
955
  * stored.
977
956
  */
978
- static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
979
- const void* src,
980
- void* dst) {
981
- int64_t n_elems = ggml_nelements(tensor);
982
- int64_t groups = n_elems / QK8_0;
983
- size_t quant_bytes = n_elems * sizeof(uint8_t);
957
+ static void ggml_backend_cann_transform_q8_0(ggml_tensor * tensor, const void * src, void * dst) {
958
+ int64_t n_elems = ggml_nelements(tensor);
959
+ int64_t groups = n_elems / QK8_0;
960
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
984
961
 
985
- uint8_t* quant_offset = (uint8_t*)dst;
986
- uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
962
+ uint8_t * quant_offset = (uint8_t *) dst;
963
+ uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes);
987
964
 
988
965
  for (int i = 0; i < groups; i++) {
989
- const block_q8_0* group =
990
- (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
991
- *scale_offset = group->d;
966
+ const block_q8_0 * group = (const block_q8_0 *) ((const char *) src + i * sizeof(block_q8_0));
967
+ *scale_offset = group->d;
992
968
  scale_offset++;
993
969
  size_t group_quant_size = QK8_0 * sizeof(uint8_t);
994
970
  memcpy(quant_offset, group->qs, group_quant_size);
@@ -1009,19 +985,17 @@ static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
1009
985
  * @param dst Pointer to the destination buffer where the Q8.0 formatted data
1010
986
  * will be stored.
1011
987
  */
1012
- static void ggml_backend_cann_transform_back_q8_0(
1013
- const ggml_tensor* tensor, const void* src, void* dst) {
1014
- int64_t n_elems = ggml_nelements(tensor);
1015
- int64_t groups = n_elems / QK8_0;
1016
- size_t quant_bytes = n_elems * sizeof(uint8_t);
988
+ static void ggml_backend_cann_transform_back_q8_0(const ggml_tensor * tensor, const void * src, void * dst) {
989
+ int64_t n_elems = ggml_nelements(tensor);
990
+ int64_t groups = n_elems / QK8_0;
991
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
1017
992
 
1018
- const uint8_t* quant_offset = (const uint8_t*)src;
1019
- const uint16_t* scale_offset =
1020
- (const uint16_t*)((const char*)src + quant_bytes);
993
+ const uint8_t * quant_offset = (const uint8_t *) src;
994
+ const uint16_t * scale_offset = (const uint16_t *) ((const char *) src + quant_bytes);
1021
995
 
1022
996
  for (int i = 0; i < groups; i++) {
1023
- block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
1024
- group->d = *scale_offset;
997
+ block_q8_0 * group = (block_q8_0 *) ((char *) dst + i * sizeof(block_q8_0));
998
+ group->d = *scale_offset;
1025
999
  scale_offset++;
1026
1000
  size_t group_quant_size = QK8_0 * sizeof(uint8_t);
1027
1001
  memcpy(group->qs, quant_offset, group_quant_size);
@@ -1041,8 +1015,7 @@ static void ggml_backend_cann_transform_back_q8_0(
1041
1015
  * @param dst Pointer to the destination buffer where transformed data will be
1042
1016
  * stored.
1043
1017
  */
1044
- static void ggml_backend_cann_transform(ggml_tensor* tensor,
1045
- const void* src, void* dst) {
1018
+ static void ggml_backend_cann_transform(ggml_tensor * tensor, const void * src, void * dst) {
1046
1019
  switch (tensor->type) {
1047
1020
  case GGML_TYPE_Q4_0:
1048
1021
  ggml_backend_cann_transform_q4_0(tensor, src, dst);
@@ -1067,8 +1040,7 @@ static void ggml_backend_cann_transform(ggml_tensor* tensor,
1067
1040
  * @param dst Pointer to the destination buffer where transformed tensor data
1068
1041
  * will be stored.
1069
1042
  */
1070
- static void ggml_backend_cann_transform_back(
1071
- const ggml_tensor* tensor, void* src, void* dst) {
1043
+ static void ggml_backend_cann_transform_back(const ggml_tensor * tensor, void * src, void * dst) {
1072
1044
  switch (tensor->type) {
1073
1045
  case GGML_TYPE_Q4_0:
1074
1046
  ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
@@ -1109,8 +1081,7 @@ static bool need_transform(ggml_type type) {
1109
1081
  * @param buffer The CANN buffer from which to initialize the tensor.
1110
1082
  * @param tensor Pointer to the tensor to be initialized.
1111
1083
  */
1112
- static enum ggml_status ggml_backend_cann_buffer_init_tensor(
1113
- ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
1084
+ static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
1114
1085
  if (tensor->view_src != NULL && tensor->view_offs == 0) {
1115
1086
  GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
1116
1087
  return GGML_STATUS_SUCCESS;
@@ -1121,13 +1092,11 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
1121
1092
  if (ggml_is_quantized(tensor->type)) {
1122
1093
  // Initialize padding to 0 to avoid possible NaN values
1123
1094
  size_t original_size = ggml_nbytes(tensor);
1124
- size_t padded_size =
1125
- ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
1095
+ size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
1126
1096
 
1127
1097
  if (padded_size > original_size && tensor->view_src == nullptr) {
1128
1098
  size_t memset_size = padded_size - original_size;
1129
- ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
1130
- memset_size, 0, memset_size));
1099
+ ACL_CHECK(aclrtMemset((char *) tensor->data + original_size, memset_size, 0, memset_size));
1131
1100
  }
1132
1101
  }
1133
1102
  return GGML_STATUS_SUCCESS;
@@ -1141,8 +1110,8 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
1141
1110
  * designed to be used with a global array, one per device.
1142
1111
  */
1143
1112
  struct ggml_cann_nz_workspace {
1144
- void* ptr; // Pointer to allocated device buffer
1145
- size_t allocated; // Size of currently allocated buffer in bytes
1113
+ void * ptr; // Pointer to allocated device buffer
1114
+ size_t allocated; // Size of currently allocated buffer in bytes
1146
1115
 
1147
1116
  /**
1148
1117
  * @brief Constructor. Initializes the workspace with no allocated memory.
@@ -1158,7 +1127,7 @@ struct ggml_cann_nz_workspace {
1158
1127
  void clear() {
1159
1128
  if (ptr) {
1160
1129
  ACL_CHECK(aclrtFree(ptr));
1161
- ptr = nullptr;
1130
+ ptr = nullptr;
1162
1131
  allocated = 0;
1163
1132
  }
1164
1133
  }
@@ -1185,7 +1154,7 @@ struct ggml_cann_nz_workspace {
1185
1154
  *
1186
1155
  * @return Pointer to the allocated buffer, or nullptr if not allocated.
1187
1156
  */
1188
- void* get() const { return ptr; }
1157
+ void * get() const { return ptr; }
1189
1158
  };
1190
1159
 
1191
1160
  /**
@@ -1207,22 +1176,19 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
1207
1176
  * @note The workspace buffer used in this function is managed globally and reused
1208
1177
  * across calls. This reduces overhead from repeated memory allocation and deallocation.
1209
1178
  */
1210
- static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) {
1211
- aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
1212
- tensor->nb, 2, ACL_FORMAT_ND, offset);
1213
- uint64_t workspaceSize = 0;
1214
- aclOpExecutor *executor;
1179
+ static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
1180
+ acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
1181
+ uint64_t workspaceSize = 0;
1182
+ aclOpExecutor * executor;
1215
1183
 
1216
1184
  // TransMatmulWeight
1217
- ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
1218
- &workspaceSize, &executor));
1185
+ ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
1219
1186
  // Avoid frequent malloc/free of the workspace.
1220
1187
  g_nz_workspaces[device].realloc(workspaceSize);
1221
1188
 
1222
- void* g_nz_workspace = g_nz_workspaces[device].get();
1189
+ void * g_nz_workspace = g_nz_workspaces[device].get();
1223
1190
 
1224
1191
  ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
1225
- ACL_CHECK(aclDestroyTensor(weightTransposed));
1226
1192
  }
1227
1193
 
1228
1194
  // TODO: need handle tensor which has paddings.
@@ -1238,11 +1204,12 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device)
1238
1204
  * @param offset Offset in the source data from where to start copying.
1239
1205
  * @param size Size of the data to be copied, in bytes.
1240
1206
  */
1241
- static void ggml_backend_cann_buffer_set_tensor(
1242
- ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
1243
- size_t offset, size_t size) {
1244
- ggml_backend_cann_buffer_context *ctx =
1245
- (ggml_backend_cann_buffer_context *)buffer->context;
1207
+ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
1208
+ ggml_tensor * tensor,
1209
+ const void * data,
1210
+ size_t offset,
1211
+ size_t size) {
1212
+ ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1246
1213
 
1247
1214
  ggml_cann_set_device(ctx->device);
1248
1215
  // TODO: refer to cann(#6017), it use thread's default stream.
@@ -1250,22 +1217,19 @@ static void ggml_backend_cann_buffer_set_tensor(
1250
1217
  // Why aclrtSynchronizeDevice?
1251
1218
 
1252
1219
  // Only check env once.
1253
- static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
1220
+ static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1254
1221
  if (!need_transform(tensor->type)) {
1255
- ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
1256
- ACL_MEMCPY_HOST_TO_DEVICE));
1257
- if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
1222
+ ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1223
+ if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
1258
1224
  GGML_ASSERT(tensor->ne[2] == 1);
1259
1225
  GGML_ASSERT(tensor->ne[3] == 1);
1260
1226
  weight_format_to_nz(tensor, offset, ctx->device);
1261
1227
  }
1262
1228
  } else {
1263
- void *transform_buffer = malloc(size);
1229
+ void * transform_buffer = malloc(size);
1264
1230
  ggml_backend_cann_transform(tensor, data, transform_buffer);
1265
1231
 
1266
- ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
1267
- transform_buffer, size,
1268
- ACL_MEMCPY_HOST_TO_DEVICE));
1232
+ ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE));
1269
1233
  free(transform_buffer);
1270
1234
  }
1271
1235
  }
@@ -1283,22 +1247,20 @@ static void ggml_backend_cann_buffer_set_tensor(
1283
1247
  * @param offset Offset in the destination buffer where to start copying.
1284
1248
  * @param size Size of the data to be copied, in bytes.
1285
1249
  */
1286
- static void ggml_backend_cann_buffer_get_tensor(
1287
- ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
1288
- size_t offset, size_t size) {
1289
- ggml_backend_cann_buffer_context* ctx =
1290
- (ggml_backend_cann_buffer_context*)buffer->context;
1250
+ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer,
1251
+ const ggml_tensor * tensor,
1252
+ void * data,
1253
+ size_t offset,
1254
+ size_t size) {
1255
+ ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1291
1256
 
1292
1257
  ggml_cann_set_device(ctx->device);
1293
1258
 
1294
1259
  if (!need_transform(tensor->type)) {
1295
- ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
1296
- ACL_MEMCPY_DEVICE_TO_HOST));
1260
+ ACL_CHECK(aclrtMemcpy(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST));
1297
1261
  } else {
1298
- void* transform_buffer = malloc(size);
1299
- ACL_CHECK(aclrtMemcpy(transform_buffer, size,
1300
- (char*)tensor->data + offset, size,
1301
- ACL_MEMCPY_DEVICE_TO_HOST));
1262
+ void * transform_buffer = malloc(size);
1263
+ ACL_CHECK(aclrtMemcpy(transform_buffer, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST));
1302
1264
  ggml_backend_cann_transform_back(tensor, transform_buffer, data);
1303
1265
  free(transform_buffer);
1304
1266
  }
@@ -1317,19 +1279,17 @@ static void ggml_backend_cann_buffer_get_tensor(
1317
1279
  * @param dst Pointer to the destination tensor where the data will be copied.
1318
1280
  * @return true if the copy operation succeeded, false otherwise.
1319
1281
  */
1320
- static bool ggml_backend_cann_buffer_cpy_tensor(
1321
- ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
1282
+ static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
1283
+ const ggml_tensor * src,
1284
+ ggml_tensor * dst) {
1322
1285
  if (ggml_backend_buffer_is_cann(src->buffer)) {
1323
- ggml_backend_cann_buffer_context* src_ctx =
1324
- (ggml_backend_cann_buffer_context*)src->buffer->context;
1325
- ggml_backend_cann_buffer_context* dst_ctx =
1326
- (ggml_backend_cann_buffer_context*)buffer->context;
1286
+ ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context;
1287
+ ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1327
1288
 
1328
1289
  size_t memcpy_size = ggml_nbytes(src);
1329
1290
  // Same device.
1330
1291
  if (src_ctx->device == dst_ctx->device) {
1331
- ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
1332
- (const char*)src->data, memcpy_size,
1292
+ ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size,
1333
1293
  ACL_MEMCPY_DEVICE_TO_DEVICE));
1334
1294
  return true;
1335
1295
  } else {
@@ -1339,13 +1299,11 @@ static bool ggml_backend_cann_buffer_cpy_tensor(
1339
1299
  #endif
1340
1300
  // Different device but can access by peer.
1341
1301
  int32_t canAccessPeer = 0;
1342
- ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
1343
- dst_ctx->device));
1302
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device, dst_ctx->device));
1344
1303
  if (canAccessPeer) {
1345
1304
  ggml_cann_set_device(src_ctx->device);
1346
1305
  ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
1347
- ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
1348
- (const char*)src->data, memcpy_size,
1306
+ ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size,
1349
1307
  ACL_MEMCPY_DEVICE_TO_DEVICE));
1350
1308
  return true;
1351
1309
  }
@@ -1363,10 +1321,8 @@ static bool ggml_backend_cann_buffer_cpy_tensor(
1363
1321
  * @param buffer The CANN buffer to be cleared.
1364
1322
  * @param value The value to which each byte in the buffer will be set.
1365
1323
  */
1366
- static void ggml_backend_cann_buffer_clear(
1367
- ggml_backend_buffer_t buffer, uint8_t value) {
1368
- ggml_backend_cann_buffer_context* ctx =
1369
- (ggml_backend_cann_buffer_context*)buffer->context;
1324
+ static void ggml_backend_cann_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1325
+ ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1370
1326
 
1371
1327
  ggml_cann_set_device(ctx->device);
1372
1328
  ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
@@ -1396,9 +1352,8 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
1396
1352
  * buffer type.
1397
1353
  */
1398
1354
  struct ggml_backend_cann_buffer_type_context {
1399
- int32_t
1400
- device; /**< Device identifier associated with the buffer context. */
1401
- std::string name; /**< Name associated with the buffer context. */
1355
+ int32_t device; /**< Device identifier associated with the buffer context. */
1356
+ std::string name; /**< Name associated with the buffer context. */
1402
1357
  };
1403
1358
 
1404
1359
  /**
@@ -1410,10 +1365,8 @@ struct ggml_backend_cann_buffer_type_context {
1410
1365
  * @param buft Pointer to the buffer type context.
1411
1366
  * @return Const pointer to the C-style string containing the name.
1412
1367
  */
1413
- static const char* ggml_backend_cann_buffer_type_name(
1414
- ggml_backend_buffer_type_t buft) {
1415
- ggml_backend_cann_buffer_type_context* buft_ctx =
1416
- (ggml_backend_cann_buffer_type_context*)buft->context;
1368
+ static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
1369
+ ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
1417
1370
 
1418
1371
  return buft_ctx->name.c_str();
1419
1372
  }
@@ -1428,34 +1381,27 @@ static const char* ggml_backend_cann_buffer_type_name(
1428
1381
  * @param size Size in bytes of the buffer to allocate.
1429
1382
  * @return Pointer to the allocated buffer, or nullptr if allocation fails.
1430
1383
  */
1431
- static ggml_backend_buffer_t
1432
- ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1433
- size_t size) {
1434
- ggml_backend_cann_buffer_type_context* buft_ctx =
1435
- (ggml_backend_cann_buffer_type_context*)buft->context;
1384
+ static ggml_backend_buffer_t ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1385
+ ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
1436
1386
 
1437
1387
  ggml_cann_set_device(buft_ctx->device);
1438
1388
 
1439
1389
  const size_t alignment = 128;
1440
- size = GGML_PAD(size, alignment);
1390
+ size = GGML_PAD(size, alignment);
1441
1391
  if (size == 0) {
1442
1392
  size = alignment;
1443
1393
  }
1444
- void* dev_ptr;
1394
+ void * dev_ptr;
1445
1395
  aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1446
1396
  if (err != ACL_SUCCESS) {
1447
- GGML_LOG_ERROR(
1448
- "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
1449
- __func__, size / 1024.0 / 1024.0, buft_ctx->device,
1450
- aclGetRecentErrMsg());
1397
+ GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n", __func__,
1398
+ size / 1024.0 / 1024.0, buft_ctx->device, aclGetRecentErrMsg());
1451
1399
  return nullptr;
1452
1400
  }
1453
1401
 
1454
- ggml_backend_cann_buffer_context* ctx =
1455
- new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1402
+ ggml_backend_cann_buffer_context * ctx = new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1456
1403
 
1457
- return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
1458
- ctx, size);
1404
+ return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface, ctx, size);
1459
1405
  }
1460
1406
 
1461
1407
  /**
@@ -1470,8 +1416,7 @@ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1470
1416
  * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
1471
1417
  * buffers).
1472
1418
  */
1473
- static size_t ggml_backend_cann_buffer_type_get_alignment(
1474
- ggml_backend_buffer_type_t buft) {
1419
+ static size_t ggml_backend_cann_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1475
1420
  return 128;
1476
1421
 
1477
1422
  GGML_UNUSED(buft);
@@ -1491,13 +1436,13 @@ static size_t ggml_backend_cann_buffer_type_get_alignment(
1491
1436
  * @return The total allocation size in bytes required for the tensor in the
1492
1437
  * CANN buffer.
1493
1438
  */
1494
- static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1495
- ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
1496
- size_t size = ggml_nbytes(tensor);
1497
- int64_t ne0 = tensor->ne[0];
1439
+ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
1440
+ const ggml_tensor * tensor) {
1441
+ size_t size = ggml_nbytes(tensor);
1442
+ int64_t ne0 = tensor->ne[0];
1498
1443
 
1499
1444
  // Only check env once.
1500
- static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
1445
+ static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1501
1446
 
1502
1447
  // last line must bigger than 32, because every single op deal at
1503
1448
  // least 32 bytes.
@@ -1507,19 +1452,17 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1507
1452
  // size += (line_size_align_32 - line_size);
1508
1453
  if (ggml_is_quantized(tensor->type)) {
1509
1454
  if (ne0 % MATRIX_ROW_PADDING != 0) {
1510
- size += ggml_row_size(
1511
- tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1455
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1512
1456
  }
1513
- } else if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
1457
+ } else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
1514
1458
  // NZ format weight are not support quantized yet.
1515
1459
  // If ND tensor transform to NZ, size may changed.
1516
- int64_t shape[] = {tensor->ne[1], tensor->ne[0]};
1460
+ int64_t shape[] = { tensor->ne[1], tensor->ne[0] };
1517
1461
  GGML_ASSERT(tensor->ne[2] == 1);
1518
1462
  GGML_ASSERT(tensor->ne[3] == 1);
1519
- const aclIntArray *acl_shape = aclCreateIntArray(shape, 2);
1520
- size_t new_size;
1521
- ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape,
1522
- ggml_cann_type_mapping(tensor->type), &new_size));
1463
+ const aclIntArray * acl_shape = aclCreateIntArray(shape, 2);
1464
+ size_t new_size;
1465
+ ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape, ggml_cann_type_mapping(tensor->type), &new_size));
1523
1466
  ACL_CHECK(aclDestroyIntArray(acl_shape));
1524
1467
  size = std::max(size, new_size);
1525
1468
  }
@@ -1560,17 +1503,15 @@ static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface
1560
1503
  * @return A pointer to the buffer type interface for the specified device, or
1561
1504
  * nullptr if the device index is out of range.
1562
1505
  */
1563
- ggml_backend_buffer_type_t
1564
- ggml_backend_cann_buffer_type(int32_t device) {
1565
- static std::mutex mutex;
1506
+ ggml_backend_buffer_type_t ggml_backend_cann_buffer_type(int32_t device) {
1507
+ static std::mutex mutex;
1566
1508
  std::lock_guard<std::mutex> lock(mutex);
1567
1509
 
1568
1510
  if (device >= ggml_backend_cann_get_device_count()) {
1569
1511
  return nullptr;
1570
1512
  }
1571
1513
 
1572
- static ggml_backend_buffer_type
1573
- ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1514
+ static ggml_backend_buffer_type ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1574
1515
 
1575
1516
  static bool ggml_backend_cann_buffer_type_initialized = false;
1576
1517
 
@@ -1580,8 +1521,7 @@ ggml_backend_cann_buffer_type(int32_t device) {
1580
1521
  /* .iface = */ ggml_backend_cann_buffer_type_interface,
1581
1522
  /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
1582
1523
  /* .context = */
1583
- new ggml_backend_cann_buffer_type_context{
1584
- i, "CANN" + std::to_string(i)},
1524
+ new ggml_backend_cann_buffer_type_context{ i, "CANN" + std::to_string(i) },
1585
1525
  };
1586
1526
  }
1587
1527
  ggml_backend_cann_buffer_type_initialized = true;
@@ -1645,16 +1585,16 @@ static void * ggml_cann_host_malloc(size_t size) {
1645
1585
  }
1646
1586
 
1647
1587
  const size_t alignment = 128;
1648
- size = GGML_PAD(size, alignment);
1588
+ size = GGML_PAD(size, alignment);
1649
1589
  if (size == 0) {
1650
1590
  size = alignment;
1651
1591
  }
1652
1592
 
1653
- void * hostPtr = nullptr;
1654
- aclError err = aclrtMallocHost((void **) &hostPtr, size);
1593
+ void * hostPtr = nullptr;
1594
+ aclError err = aclrtMallocHost((void **) &hostPtr, size);
1655
1595
  if (err != ACL_SUCCESS) {
1656
- GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
1657
- size / 1024.0 / 1024.0, aclGetRecentErrMsg());
1596
+ GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0,
1597
+ aclGetRecentErrMsg());
1658
1598
  return nullptr;
1659
1599
  }
1660
1600
  return hostPtr;
@@ -1667,7 +1607,8 @@ static void * ggml_cann_host_malloc(size_t size) {
1667
1607
  * @param size Size in bytes of the host buffer to allocate.
1668
1608
  * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
1669
1609
  */
1670
- static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1610
+ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1611
+ size_t size) {
1671
1612
  void * hostPtr = ggml_cann_host_malloc(size);
1672
1613
 
1673
1614
  if (hostPtr == nullptr) {
@@ -1676,8 +1617,8 @@ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggm
1676
1617
  }
1677
1618
 
1678
1619
  ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
1679
- buffer->buft = buft;
1680
- buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
1620
+ buffer->buft = buft;
1621
+ buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
1681
1622
 
1682
1623
  return buffer;
1683
1624
  }
@@ -1691,14 +1632,15 @@ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggm
1691
1632
  ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
1692
1633
  static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
1693
1634
  /* .iface = */ {
1694
- /* .get_name = */ ggml_backend_cann_host_buffer_type_name,
1695
- /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
1696
- /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1697
- /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1635
+ /* .get_name = */ ggml_backend_cann_host_buffer_type_name,
1636
+ /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
1637
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1638
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1698
1639
  /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
1699
- /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1700
- },
1701
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
1640
+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1641
+ },
1642
+ /* .device = */
1643
+ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
1702
1644
  /* .context = */ nullptr,
1703
1645
  };
1704
1646
 
@@ -1718,8 +1660,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
1718
1660
  * stored.
1719
1661
  * @return true if the computation was successful; false otherwise.
1720
1662
  */
1721
- static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1722
- struct ggml_tensor* dst) {
1663
+ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct ggml_tensor * dst) {
1723
1664
  switch (dst->op) {
1724
1665
  case GGML_OP_REPEAT:
1725
1666
  ggml_cann_repeat(ctx, dst);
@@ -1765,14 +1706,14 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1765
1706
  case GGML_UNARY_OP_SILU:
1766
1707
  GGML_CANN_CALL_OP_UNARY(Silu);
1767
1708
  break;
1768
- case GGML_UNARY_OP_GELU_QUICK: {
1769
- auto lambda = [](ggml_backend_cann_context& ctx,
1770
- aclTensor* acl_src,
1771
- aclTensor* acl_dst) {
1772
- GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1773
- };
1774
- ggml_cann_op_unary(lambda, ctx, dst);
1775
- } break;
1709
+ case GGML_UNARY_OP_GELU_QUICK:
1710
+ {
1711
+ auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
1712
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1713
+ };
1714
+ ggml_cann_op_unary(lambda, ctx, dst);
1715
+ }
1716
+ break;
1776
1717
  case GGML_UNARY_OP_TANH:
1777
1718
  GGML_CANN_CALL_OP_UNARY(Tanh);
1778
1719
  break;
@@ -1817,14 +1758,14 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1817
1758
  case GGML_GLU_OP_SWIGLU:
1818
1759
  GGML_CANN_CALL_OP_UNARY_GATED(Silu);
1819
1760
  break;
1820
- case GGML_GLU_OP_GEGLU_QUICK: {
1821
- auto lambda = [](ggml_backend_cann_context& ctx,
1822
- aclTensor* acl_src,
1823
- aclTensor* acl_dst) {
1824
- GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1825
- };
1826
- ggml_cann_op_unary_gated(lambda, ctx, dst);
1827
- } break;
1761
+ case GGML_GLU_OP_GEGLU_QUICK:
1762
+ {
1763
+ auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
1764
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1765
+ };
1766
+ ggml_cann_op_unary_gated(lambda, ctx, dst);
1767
+ }
1768
+ break;
1828
1769
  default:
1829
1770
  return false;
1830
1771
  }
@@ -1835,6 +1776,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1835
1776
  case GGML_OP_GROUP_NORM:
1836
1777
  ggml_cann_group_norm(ctx, dst);
1837
1778
  break;
1779
+ case GGML_OP_L2_NORM:
1780
+ ggml_cann_l2_norm(ctx, dst);
1781
+ break;
1782
+ case GGML_OP_CROSS_ENTROPY_LOSS:
1783
+ ggml_cann_cross_entropy_loss(ctx, dst);
1784
+ break;
1838
1785
  case GGML_OP_CONCAT:
1839
1786
  ggml_cann_concat(ctx, dst);
1840
1787
  break;
@@ -1939,6 +1886,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1939
1886
  case GGML_OP_FLASH_ATTN_EXT:
1940
1887
  ggml_cann_flash_attn_ext(ctx, dst);
1941
1888
  break;
1889
+ case GGML_OP_OUT_PROD:
1890
+ ggml_cann_out_prod(ctx, dst);
1891
+ break;
1892
+ case GGML_OP_SSM_CONV:
1893
+ ggml_cann_ssm_conv(ctx, dst);
1894
+ break;
1942
1895
  default:
1943
1896
  return false;
1944
1897
  }
@@ -1956,9 +1909,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1956
1909
  * @param backend Pointer to the CANN backend structure.
1957
1910
  * @return A pointer to a constant string representing the backend name.
1958
1911
  */
1959
- static const char* ggml_backend_cann_name(ggml_backend_t backend) {
1960
- ggml_backend_cann_context* cann_ctx =
1961
- (ggml_backend_cann_context*)backend->context;
1912
+ static const char * ggml_backend_cann_name(ggml_backend_t backend) {
1913
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1962
1914
 
1963
1915
  return cann_ctx->name.c_str();
1964
1916
  }
@@ -1972,8 +1924,7 @@ static const char* ggml_backend_cann_name(ggml_backend_t backend) {
1972
1924
  * @param backend Pointer to the CANN backend structure to be freed.
1973
1925
  */
1974
1926
  static void ggml_backend_cann_free(ggml_backend_t backend) {
1975
- ggml_backend_cann_context* cann_ctx =
1976
- (ggml_backend_cann_context*)backend->context;
1927
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1977
1928
  ACL_CHECK(aclrtSynchronizeDevice());
1978
1929
  ACL_CHECK(aclrtResetDevice(cann_ctx->device));
1979
1930
 
@@ -1981,7 +1932,6 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
1981
1932
  delete backend;
1982
1933
  }
1983
1934
 
1984
-
1985
1935
  /**
1986
1936
  * @brief Sets tensor data asynchronously in the CANN backend.
1987
1937
  *
@@ -1994,21 +1944,18 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
1994
1944
  * @param size Size of the data to copy in bytes.
1995
1945
  */
1996
1946
  static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
1997
- ggml_tensor *tensor,
1998
- const void *data,
1999
- size_t offset,
2000
- size_t size) {
2001
- ggml_backend_cann_context *cann_ctx =
2002
- (ggml_backend_cann_context *)backend->context;
2003
- ggml_backend_buffer_t buf =
2004
- tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
2005
-
2006
- GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
2007
- "unsupported buffer type");
1947
+ ggml_tensor * tensor,
1948
+ const void * data,
1949
+ size_t offset,
1950
+ size_t size) {
1951
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1952
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1953
+
1954
+ GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
2008
1955
  GGML_ASSERT(!ggml_is_quantized(tensor->type));
2009
1956
 
2010
- ggml_cann_async_memcpy(cann_ctx, (char *)tensor->data + offset, data, size,
2011
- ACL_MEMCPY_HOST_TO_DEVICE);
1957
+ ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
1958
+ cann_ctx->stream()));
2012
1959
  }
2013
1960
 
2014
1961
  /**
@@ -2022,21 +1969,19 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
2022
1969
  * @param offset Offset in bytes within the host data.
2023
1970
  * @param size Size of the data to copy in bytes.
2024
1971
  */
2025
- static void ggml_backend_cann_get_tensor_async(
2026
- ggml_backend_t backend, const ggml_tensor *tensor, void *data,
2027
- size_t offset, size_t size) {
2028
- ggml_backend_cann_context *cann_ctx =
2029
- (ggml_backend_cann_context *)backend->context;
2030
- ggml_backend_buffer_t buf =
2031
- tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1972
+ static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend,
1973
+ const ggml_tensor * tensor,
1974
+ void * data,
1975
+ size_t offset,
1976
+ size_t size) {
1977
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1978
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
2032
1979
 
2033
- GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
2034
- "unsupported buffer type");
1980
+ GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
2035
1981
  GGML_ASSERT(!ggml_is_quantized(tensor->type));
2036
1982
 
2037
- ggml_cann_async_memcpy(cann_ctx, data, (char *)tensor->data + offset, size,
2038
- ACL_MEMCPY_DEVICE_TO_HOST);
2039
-
1983
+ ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
1984
+ cann_ctx->stream()));
2040
1985
  }
2041
1986
 
2042
1987
  /**
@@ -2052,28 +1997,23 @@ static void ggml_backend_cann_get_tensor_async(
2052
1997
  * @param dst Pointer to the destination tensor to copy data to.
2053
1998
  * @return true if the copy operation succeeds, false otherwise.
2054
1999
  */
2055
- static bool ggml_backend_cann_cpy_tensor_async(
2056
- ggml_backend_t backend_src, ggml_backend_t backend_dst,
2057
- const ggml_tensor* src, ggml_tensor* dst) {
2058
- GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
2059
- ggml_backend_is_cann(backend_dst));
2000
+ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
2001
+ ggml_backend_t backend_dst,
2002
+ const ggml_tensor * src,
2003
+ ggml_tensor * dst) {
2004
+ GGML_ASSERT(ggml_backend_is_cann(backend_src) || ggml_backend_is_cann(backend_dst));
2060
2005
 
2061
- GGML_ASSERT(!is_matmul_weight((const ggml_tensor*)src));
2006
+ GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src));
2062
2007
 
2063
- if (!ggml_backend_buffer_is_cann(src->buffer) ||
2064
- !ggml_backend_buffer_is_cann(dst->buffer)) {
2008
+ if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) {
2065
2009
  return false;
2066
2010
  }
2067
2011
 
2068
- ggml_backend_buffer_t buf_src =
2069
- src->view_src ? src->view_src->buffer : src->buffer;
2070
- ggml_backend_buffer_t buf_dst =
2071
- dst->view_src ? dst->view_src->buffer : dst->buffer;
2012
+ ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
2013
+ ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
2072
2014
 
2073
- ggml_backend_cann_context* cann_ctx_src =
2074
- (ggml_backend_cann_context*)backend_src->context;
2075
- ggml_backend_cann_context* cann_ctx_dst =
2076
- (ggml_backend_cann_context*)backend_dst->context;
2015
+ ggml_backend_cann_context * cann_ctx_src = (ggml_backend_cann_context *) backend_src->context;
2016
+ ggml_backend_cann_context * cann_ctx_dst = (ggml_backend_cann_context *) backend_dst->context;
2077
2017
 
2078
2018
  size_t copy_size = ggml_nbytes(dst);
2079
2019
  if (copy_size == 0) {
@@ -2084,17 +2024,14 @@ static bool ggml_backend_cann_cpy_tensor_async(
2084
2024
  // TODO: Support 310p P2P copy
2085
2025
  return false;
2086
2026
  #endif
2087
- ggml_backend_cann_buffer_context* buf_ctx_src =
2088
- (ggml_backend_cann_buffer_context*)buf_src->context;
2089
- ggml_backend_cann_buffer_context* buf_ctx_dst =
2090
- (ggml_backend_cann_buffer_context*)buf_dst->context;
2027
+ ggml_backend_cann_buffer_context * buf_ctx_src = (ggml_backend_cann_buffer_context *) buf_src->context;
2028
+ ggml_backend_cann_buffer_context * buf_ctx_dst = (ggml_backend_cann_buffer_context *) buf_dst->context;
2091
2029
 
2092
2030
  GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
2093
2031
  GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
2094
2032
 
2095
2033
  int32_t canAccessPeer = 0;
2096
- ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
2097
- cann_ctx_dst->device));
2034
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device, cann_ctx_dst->device));
2098
2035
  if (!canAccessPeer) {
2099
2036
  return false;
2100
2037
  }
@@ -2105,9 +2042,7 @@ static bool ggml_backend_cann_cpy_tensor_async(
2105
2042
  ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
2106
2043
 
2107
2044
  // wait for task_queue empty to keep task order.
2108
- cann_ctx_src->task_queue.wait();
2109
- ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
2110
- ACL_MEMCPY_DEVICE_TO_DEVICE,
2045
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
2111
2046
  cann_ctx_src->stream()));
2112
2047
  // record event on src stream after the copy
2113
2048
  // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream
@@ -2122,8 +2057,7 @@ static bool ggml_backend_cann_cpy_tensor_async(
2122
2057
  ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream()));
2123
2058
  } else {
2124
2059
  // src and dst are on the same backend
2125
- ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
2126
- ACL_MEMCPY_DEVICE_TO_DEVICE,
2060
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
2127
2061
  cann_ctx_dst->stream()));
2128
2062
  }
2129
2063
 
@@ -2139,147 +2073,44 @@ static bool ggml_backend_cann_cpy_tensor_async(
2139
2073
  * @param backend Pointer to the CANN backend structure to synchronize.
2140
2074
  */
2141
2075
  static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
2142
- ggml_backend_cann_context* cann_ctx =
2143
- (ggml_backend_cann_context*)backend->context;
2144
- cann_ctx->task_queue.wait();
2076
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2145
2077
  ggml_cann_set_device(cann_ctx->device);
2146
2078
  ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
2147
2079
  }
2148
2080
 
2149
- #ifdef USE_ACL_GRAPH
2150
2081
  /**
2151
- * @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph.
2152
- *
2153
- * This function creates a new ggml_cann_graph object and fills its node properties
2154
- * (operation type, dimensions, strides, input sources, and operation parameters)
2155
- * based on the current ggml computation graph.
2156
- *
2157
- * Each node in the ggml graph is mapped to a property entry in the new CANN graph:
2158
- * - node address
2159
- * - operation type
2160
- * - shape (ne) and strides (nb)
2161
- * - source tensor addresses
2162
- * - operation parameters
2082
+ * @brief Check if CANN backend can fuse the specified operation sequence
2163
2083
  *
2164
- * After initialization, the new graph is pushed into the LRU cache owned by the
2165
- * CANN backend context. The cache takes ownership of the graph and manages its
2166
- * lifetime (including deletion upon eviction).
2084
+ * This function determines whether an operation sequence starting from the specified node
2085
+ * can be fused into an optimized operation in the CANN backend. Operation fusion can reduce
2086
+ * memory access overhead and improve computational efficiency.
2167
2087
  *
2168
- * @param cann_ctx The CANN backend context containing the graph cache.
2169
- * @param cgraph The current ggml computation graph.
2088
+ * @param cgraph Pointer to the computation graph
2089
+ * @param node_idx Index of the starting node in the computation graph
2090
+ * @param ops Sequence of operation types to check for fusion
2091
+ * @return true if the operations can be fused
2092
+ * @return false if the operations cannot be fused
2170
2093
  */
2171
- static void add_lru_matched_graph_node_properties(
2172
- ggml_backend_cann_context * cann_ctx,
2173
- ggml_cgraph * cgraph) {
2174
- // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache).
2175
- ggml_cann_graph * new_graph = new ggml_cann_graph();
2176
- new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2177
-
2178
- for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
2179
- ggml_tensor * node = cgraph->nodes[node_idx];
2180
- auto & prop = new_graph->ggml_graph_properties[node_idx];
2181
-
2182
- prop.node_address = node->data;
2183
- prop.node_op = node->op;
2184
-
2185
- std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
2186
- std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
2187
-
2188
- for (int src = 0; src < GGML_MAX_SRC; ++src) {
2189
- prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
2190
- }
2191
-
2192
- memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
2193
- }
2194
-
2195
- // Insert into the LRU cache (cache takes ownership and will delete it when evicted).
2196
- cann_ctx->graph_lru_cache.push(new_graph);
2197
- }
2198
-
2199
- /**
2200
- * @brief Check if a ggml tensor node matches a previously captured CANN graph node.
2201
- *
2202
- * This function compares all relevant fields (address, op type, shape, source inputs, op params)
2203
- * to determine whether the current node matches a previously recorded version.
2204
- *
2205
- * @param node The current ggml tensor node.
2206
- * @param graph_node_properties The stored properties of a CANN graph node.
2207
- * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
2208
- */
2209
- static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2210
- if (node->data != graph_node_properties->node_address &&
2211
- node->op != GGML_OP_VIEW) {
2212
- return false;
2213
- }
2214
- if (node->op != graph_node_properties->node_op) {
2094
+ static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
2095
+ int node_idx,
2096
+ std::initializer_list<enum ggml_op> ops) {
2097
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
2215
2098
  return false;
2216
2099
  }
2217
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
2218
- if (node->ne[i] != graph_node_properties->ne[i]) {
2219
- return false;
2220
- }
2221
- if (node->nb[i] != graph_node_properties->nb[i]) {
2222
- return false;
2223
- }
2224
- }
2225
- for (int i = 0; i < GGML_MAX_SRC; i++) {
2226
- if (node->src[i] &&
2227
- node->src[i]->data != graph_node_properties->src_address[i] &&
2228
- node->op != GGML_OP_VIEW
2229
- ) {
2230
- return false;
2231
- }
2232
- }
2233
- if (node->op == GGML_OP_SCALE &&
2234
- memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2235
- return false;
2236
- }
2237
- return true;
2238
- }
2239
2100
 
2240
- /**
2241
- * @brief Check whether there is a cached CANN graph that matches the current ggml graph.
2242
- *
2243
- * This function iterates through the cached CANN graphs stored in the LRU cache and
2244
- * compares them against the given ggml computation graph. A match requires that the
2245
- * number of nodes is the same and that each node’s properties (operation type,
2246
- * dimensions, strides, inputs, and operation parameters) are identical.
2247
- *
2248
- * If a matching graph is found, it is promoted to the front of the LRU cache and the
2249
- * function returns true. Otherwise, the function returns false, indicating that a new
2250
- * CANN graph needs to be captured.
2251
- *
2252
- * @param cann_ctx The CANN backend context containing the graph cache.
2253
- * @param cgraph The current ggml computation graph.
2254
- * @return true if a matching cached graph exists; false otherwise.
2255
- */
2256
- static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
2257
- ggml_cann_graph_lru_cache &lru_cache = cann_ctx->graph_lru_cache;
2258
- for (auto &graph_ptr : lru_cache.cache_list) {
2259
- // Skip graphs with a different number of nodes.
2260
- if (graph_ptr->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
2261
- continue;
2262
- }
2263
-
2264
- // Check if all nodes match.
2265
- bool all_match = true;
2266
- for (int i = 0; i < cgraph->n_nodes; ++i) {
2267
- if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) {
2268
- all_match = false;
2269
- break;
2270
- }
2271
- }
2272
-
2273
- if (all_match) {
2274
- // update cache_list && renturn graph_ptr
2275
- lru_cache.move_to_front(graph_ptr);
2276
- return true;
2101
+ // CANN backend supports fusing ADD + RMS_NORM operations
2102
+ if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {
2103
+ ggml_tensor * add_node = cgraph->nodes[node_idx];
2104
+ // TODO: support broadcast for ADD + RMS_NORM
2105
+ if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||
2106
+ add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {
2107
+ return false;
2277
2108
  }
2109
+ return true;
2278
2110
  }
2279
2111
 
2280
2112
  return false;
2281
2113
  }
2282
- #endif // USE_ACL_GRAPH
2283
2114
 
2284
2115
  /**
2285
2116
  * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
@@ -2289,26 +2120,37 @@ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph *
2289
2120
  *
2290
2121
  * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
2291
2122
  *
2292
- * @param cann_ctx The CANN backend context.
2293
- * @param cgraph The ggml computation graph.
2294
- * @param use_cann_graph Whether to use CANN graph execution.
2295
- * @param cann_graph_update_required Whether graph capture is needed due to graph changes.
2123
+ * @param cann_ctx The CANN backend context.
2124
+ * @param cgraph The ggml computation graph.
2125
+ * @param use_cann_graph Whether to use CANN graph execution.
2126
+ * @param cann_graph_capture_required Whether graph capture is needed due to graph changes.
2296
2127
  */
2297
- static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
2298
- bool & use_cann_graph, bool & cann_graph_update_required) {
2128
+ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx,
2129
+ ggml_cgraph * cgraph,
2130
+ bool use_cann_graph,
2131
+ bool cann_graph_capture_required) {
2299
2132
  #ifdef USE_ACL_GRAPH
2300
- ggml_cann_graph* matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
2301
- if (use_cann_graph && cann_graph_update_required) {
2133
+ if (use_cann_graph && cann_graph_capture_required) { // Begin CANN graph capture
2302
2134
  ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
2303
2135
  }
2304
- #endif // USE_ACL_GRAPH
2136
+ #endif // USE_ACL_GRAPH
2305
2137
  // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
2306
2138
  // With the use of CANN graphs, the execution will be performed by the graph launch.
2307
- if (!use_cann_graph || cann_graph_update_required) {
2139
+ static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
2140
+
2141
+ if (!use_cann_graph || cann_graph_capture_required) {
2308
2142
  for (int i = 0; i < cgraph->n_nodes; i++) {
2309
2143
  ggml_tensor * node = cgraph->nodes[i];
2144
+ if (opt_fusion) {
2145
+ if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
2146
+ ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
2147
+ i++;
2148
+ continue;
2149
+ }
2150
+ }
2310
2151
 
2311
- if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2152
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
2153
+ node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2312
2154
  continue;
2313
2155
  }
2314
2156
 
@@ -2321,18 +2163,20 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
2321
2163
  }
2322
2164
 
2323
2165
  #ifdef USE_ACL_GRAPH
2324
- if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
2325
- ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
2326
- }
2327
-
2328
2166
  if (use_cann_graph) {
2329
- // Execute graph
2167
+ GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty());
2168
+ ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
2169
+
2170
+ if (cann_graph_capture_required) { // End CANN graph capture
2171
+ ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
2172
+ }
2173
+
2174
+ // Execute CANN graph
2330
2175
  ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
2331
2176
  }
2332
- #endif // USE_ACL_GRAPH
2177
+ #endif // USE_ACL_GRAPH
2333
2178
  }
2334
2179
 
2335
-
2336
2180
  /**
2337
2181
  * @brief Computes a computational graph using a CANN backend.
2338
2182
  *
@@ -2345,21 +2189,19 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
2345
2189
  * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
2346
2190
  * completes successfully, otherwise an appropriate error status.
2347
2191
  */
2348
- static enum ggml_status ggml_backend_cann_graph_compute(
2349
- ggml_backend_t backend, ggml_cgraph* cgraph) {
2350
- ggml_backend_cann_context* cann_ctx =
2351
- (ggml_backend_cann_context*)backend->context;
2192
+ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
2193
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2352
2194
  ggml_cann_set_device(cann_ctx->device);
2353
2195
  g_nz_workspaces[cann_ctx->device].clear();
2354
2196
 
2355
2197
  // calculate rope cache for fist layer in current device.
2356
2198
  cann_ctx->rope_cache.cached = false;
2357
2199
 
2200
+ bool graph_capture_required = false;
2358
2201
  #ifdef USE_ACL_GRAPH
2359
2202
  bool use_cann_graph = true;
2360
- bool cann_graph_update_required = false;
2361
2203
 
2362
- static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
2204
+ static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
2363
2205
  if (!prefill_use_graph) {
2364
2206
  // Do not use acl_graph for prefill.
2365
2207
  for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -2380,22 +2222,17 @@ static enum ggml_status ggml_backend_cann_graph_compute(
2380
2222
 
2381
2223
  if (use_cann_graph) {
2382
2224
  // If no matching graph is found, the graph needs to be recaptured.
2383
- cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph);
2384
- if (cann_graph_update_required) {
2225
+ graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
2226
+ if (graph_capture_required) {
2385
2227
  // If no matching graph is found, add a new ACL graph.
2386
- add_lru_matched_graph_node_properties(cann_ctx, cgraph);
2228
+ ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
2229
+ cann_ctx->graph_lru_cache.push(new_graph);
2387
2230
  }
2388
2231
  }
2389
2232
  #else
2390
2233
  bool use_cann_graph = false;
2391
- bool cann_graph_update_required = false;
2392
2234
  #endif // USE_ACL_GRAPH
2393
- evaluate_and_capture_cann_graph(
2394
- cann_ctx,
2395
- cgraph,
2396
- use_cann_graph,
2397
- cann_graph_update_required
2398
- );
2235
+ evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required);
2399
2236
 
2400
2237
  return GGML_STATUS_SUCCESS;
2401
2238
  }
@@ -2412,8 +2249,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
2412
2249
  * @return bool Returns true if the operation is supported by the backend,
2413
2250
  * otherwise false.
2414
2251
  */
2415
- static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2416
- const ggml_tensor* op) {
2252
+ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2417
2253
  switch (op->op) {
2418
2254
  case GGML_OP_UNARY:
2419
2255
  switch (ggml_get_unary_op(op)) {
@@ -2448,24 +2284,24 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2448
2284
  return false;
2449
2285
  }
2450
2286
  break;
2451
- case GGML_OP_MUL_MAT: {
2452
- switch (op->src[0]->type) {
2453
- case GGML_TYPE_F16:
2454
- case GGML_TYPE_F32:
2455
- return true;
2456
- case GGML_TYPE_Q8_0:
2457
- case GGML_TYPE_Q4_0:
2287
+ case GGML_OP_MUL_MAT:
2288
+ {
2289
+ switch (op->src[0]->type) {
2290
+ case GGML_TYPE_F16:
2291
+ case GGML_TYPE_F32:
2292
+ return true;
2293
+ case GGML_TYPE_Q8_0:
2294
+ case GGML_TYPE_Q4_0:
2458
2295
  #ifdef ASCEND_310P
2459
- // Q4 && Q8 per group is not support on 310p device
2460
- return false;
2296
+ // Q4 && Q8 per group is not support on 310p device
2297
+ return false;
2461
2298
  #endif
2462
- // only support contiguous for quantized types.
2463
- return ggml_is_contiguous(op->src[0]) &&
2464
- ggml_is_contiguous(op->src[1]);
2465
- default:
2466
- return false;
2299
+ // only support contiguous for quantized types.
2300
+ return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
2301
+ default:
2302
+ return false;
2303
+ }
2467
2304
  }
2468
- }
2469
2305
  case GGML_OP_MUL_MAT_ID:
2470
2306
  switch (op->src[0]->type) {
2471
2307
  case GGML_TYPE_F16:
@@ -2478,101 +2314,109 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2478
2314
  return false;
2479
2315
  #endif
2480
2316
  // only support contiguous for quantized types.
2481
- return ggml_is_contiguous(op->src[0]) &&
2482
- ggml_is_contiguous(op->src[1]);
2317
+ return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
2483
2318
  default:
2484
2319
  return false;
2485
2320
  }
2486
2321
  // embedding
2487
- case GGML_OP_GET_ROWS: {
2488
- switch (op->src[0]->type) {
2489
- case GGML_TYPE_F32:
2490
- case GGML_TYPE_F16:
2491
- case GGML_TYPE_Q8_0:
2492
- return true;
2493
- default:
2494
- return false;
2495
- }
2496
- } break;
2497
- case GGML_OP_SET_ROWS: {
2498
- switch (op->type) {
2499
- case GGML_TYPE_F32:
2500
- case GGML_TYPE_F16:
2501
- return true;
2502
- default:
2503
- return false;
2322
+ case GGML_OP_GET_ROWS:
2323
+ {
2324
+ switch (op->src[0]->type) {
2325
+ case GGML_TYPE_F32:
2326
+ case GGML_TYPE_F16:
2327
+ case GGML_TYPE_Q8_0:
2328
+ return true;
2329
+ default:
2330
+ return false;
2331
+ }
2504
2332
  }
2505
- } break;
2506
- case GGML_OP_CPY: {
2507
- ggml_tensor *src = op->src[0];
2508
- if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
2509
- (src->type != GGML_TYPE_F32 &&
2510
- src->type != GGML_TYPE_F16)) {
2511
- // only support F32 and F16.
2512
- return false;
2333
+ break;
2334
+ case GGML_OP_SET_ROWS:
2335
+ {
2336
+ switch (op->type) {
2337
+ case GGML_TYPE_F32:
2338
+ case GGML_TYPE_F16:
2339
+ return true;
2340
+ default:
2341
+ return false;
2342
+ }
2513
2343
  }
2514
- return true;
2515
- } break;
2516
- case GGML_OP_CONT: {
2517
- // TODO: support GGML_TYPE_BF16
2518
- switch (op->src[0]->type) {
2519
- case GGML_TYPE_F32:
2520
- case GGML_TYPE_F16:
2521
- return true;
2522
- default:
2344
+ break;
2345
+ case GGML_OP_CPY:
2346
+ {
2347
+ ggml_tensor * src = op->src[0];
2348
+ if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
2349
+ (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {
2350
+ // only support F32 and F16.
2523
2351
  return false;
2352
+ }
2353
+ return true;
2524
2354
  }
2525
- }
2526
- case GGML_OP_ROPE: {
2527
- // TODO: with ops-test v == 1
2528
- // TODO: n_dims <= ne0
2529
- if (op->src[0]->ne[0] != op->op_params[1]) {
2530
- return false;
2531
- }
2532
-
2533
- const int mode = ((const int32_t *) op->op_params)[2];
2534
- if (mode & GGML_ROPE_TYPE_MROPE) {
2535
- return false;
2536
- }
2537
- if (mode & GGML_ROPE_TYPE_VISION) {
2538
- return false;
2355
+ break;
2356
+ case GGML_OP_CONT:
2357
+ {
2358
+ // TODO: support GGML_TYPE_BF16
2359
+ switch (op->src[0]->type) {
2360
+ case GGML_TYPE_F32:
2361
+ case GGML_TYPE_F16:
2362
+ return true;
2363
+ default:
2364
+ return false;
2365
+ }
2539
2366
  }
2367
+ case GGML_OP_ROPE:
2368
+ {
2369
+ if (op->src[0]->ne[0] > 896) {
2370
+ return false;
2371
+ }
2540
2372
  #ifdef ASCEND_310P
2541
- if(!ggml_is_contiguous(op->src[0])){
2542
- return false;
2543
- }
2373
+ // TODO: Support rope_dim < ne00(dim)
2374
+ if (op->src[0]->ne[0] != op->op_params[1]) {
2375
+ return false;
2376
+ }
2377
+ if (!ggml_is_contiguous(op->src[0])) {
2378
+ return false;
2379
+ }
2544
2380
  #endif
2545
- return true;
2546
- }
2547
- case GGML_OP_UPSCALE: {
2548
- // aclnnUpsampleNearest2dGetWorkspaceSize not support
2549
- // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal
2550
- if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
2551
- return false;
2381
+ return true;
2552
2382
  }
2553
- if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
2554
- return false;
2383
+ case GGML_OP_UPSCALE:
2384
+ {
2385
+ // aclnnUpsampleNearest2dGetWorkspaceSize not support
2386
+ // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal
2387
+ if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
2388
+ return false;
2389
+ }
2390
+ if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
2391
+ return false;
2392
+ }
2393
+ if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
2394
+ return false;
2395
+ }
2396
+ return true;
2555
2397
  }
2556
- return true;
2557
- }
2558
- case GGML_OP_POOL_2D: {
2559
- const int32_t * opts = (const int32_t *) op->op_params;
2398
+ case GGML_OP_POOL_2D:
2399
+ {
2400
+ const int32_t * opts = (const int32_t *) op->op_params;
2560
2401
  #ifdef ASCEND_310P
2561
- enum ggml_op_pool opt = static_cast<ggml_op_pool>(opts[0]);
2562
- if(opt == GGML_OP_POOL_MAX){
2563
- return false;
2564
- }
2402
+ enum ggml_op_pool opt = static_cast<ggml_op_pool>(opts[0]);
2403
+ if (opt == GGML_OP_POOL_MAX) {
2404
+ return false;
2405
+ }
2565
2406
  #endif
2566
- const int k0 = opts[1];
2567
- const int k1 = opts[2];
2568
- const int p0 = opts[5];
2569
- const int p1 = opts[6];
2570
- // value of paddingH should be at most half of kernelH
2571
- // value of paddingW should be at most half of kernelW
2572
- return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
2573
- }
2574
- case GGML_OP_DUP:
2407
+ const int k0 = opts[1];
2408
+ const int k1 = opts[2];
2409
+ const int p0 = opts[5];
2410
+ const int p1 = opts[6];
2411
+ // value of paddingH should be at most half of kernelH
2412
+ // value of paddingW should be at most half of kernelW
2413
+ return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
2414
+ }
2575
2415
  case GGML_OP_SUM:
2416
+ return ggml_is_contiguous_rows(op->src[0]);
2417
+ case GGML_OP_L2_NORM:
2418
+ case GGML_OP_CROSS_ENTROPY_LOSS:
2419
+ case GGML_OP_DUP:
2576
2420
  case GGML_OP_IM2COL:
2577
2421
  case GGML_OP_CONCAT:
2578
2422
  case GGML_OP_REPEAT:
@@ -2596,7 +2440,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2596
2440
  case GGML_OP_ARGSORT:
2597
2441
  case GGML_OP_ACC:
2598
2442
  case GGML_OP_GROUP_NORM:
2443
+ return true;
2599
2444
  case GGML_OP_PAD:
2445
+ // TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985
2446
+ return ggml_get_op_params_i32(op, 8) == 0;
2600
2447
  case GGML_OP_ARANGE:
2601
2448
  case GGML_OP_TIMESTEP_EMBEDDING:
2602
2449
  case GGML_OP_LEAKY_RELU:
@@ -2608,53 +2455,70 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2608
2455
  case GGML_OP_PAD_REFLECT_1D:
2609
2456
  case GGML_OP_COUNT_EQUAL:
2610
2457
  return true;
2458
+ case GGML_OP_OUT_PROD:
2459
+ {
2460
+ #ifdef ASCEND_310P
2461
+ // Ger is not supported on 310p device
2462
+ return false;
2463
+ #endif
2464
+ switch (op->src[0]->type) {
2465
+ case GGML_TYPE_F16:
2466
+ case GGML_TYPE_F32:
2467
+ return true;
2468
+ default:
2469
+ return false;
2470
+ }
2471
+ }
2611
2472
  case GGML_OP_CONV_TRANSPOSE_1D:
2612
- // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
2613
- return (op->src[0]->ne[0] - 1) <= 255;
2473
+ return true;
2614
2474
  case GGML_OP_SCALE:
2615
2475
  float bias;
2616
- memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float));
2617
- return bias == 0.0f; // TODO: support bias != 0.0f
2476
+ memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float));
2477
+ return bias == 0.0f; // TODO: support bias != 0.0f
2618
2478
  case GGML_OP_SOFT_MAX:
2619
2479
  // TODO: support attention sinks [TAG_ATTN_SINKS]
2620
2480
  if (op->src[2]) {
2621
2481
  return false;
2622
2482
  }
2623
2483
  return true;
2624
- case GGML_OP_FLASH_ATTN_EXT:{
2484
+ case GGML_OP_FLASH_ATTN_EXT:
2485
+ {
2625
2486
  #ifdef ASCEND_310P
2626
- // FA not support on 310p device
2627
- return false;
2628
- #endif
2629
- // derived from [ggml-cuda.cu]
2630
- if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
2631
- return false;
2632
- }
2633
- if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){
2634
- return false;
2635
- }
2636
- if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
2637
- return false;
2638
- }
2639
- // TODO: support attention sinks [TAG_ATTN_SINKS]
2640
- if (op->src[4]) {
2641
- return false;
2642
- }
2643
- if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
2644
- // different head sizes of K and V are not supported yet
2645
- return false;
2646
- }
2647
- if (op->src[0]->ne[0] % 16 != 0) {
2648
- // TODO: padding to support
2649
- return false;
2650
- }
2651
- float logitSoftcap = 0.0f;
2652
- memcpy(&logitSoftcap, (const float *)(op->op_params) + 2, sizeof(float));
2653
- if(logitSoftcap != 0.0f) {
2487
+ // FA not support on 310p device
2654
2488
  return false;
2489
+ #endif
2490
+ // derived from [ggml-cuda.cu]
2491
+ if (op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16) {
2492
+ return false;
2493
+ }
2494
+ if (op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 &&
2495
+ op->src[1]->type != GGML_TYPE_BF16) {
2496
+ return false;
2497
+ }
2498
+ if (op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16) {
2499
+ return false;
2500
+ }
2501
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
2502
+ if (op->src[4]) {
2503
+ return false;
2504
+ }
2505
+ if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
2506
+ // different head sizes of K and V are not supported yet
2507
+ return false;
2508
+ }
2509
+ if (op->src[0]->ne[0] % 16 != 0) {
2510
+ // TODO: padding to support
2511
+ return false;
2512
+ }
2513
+ float logitSoftcap = 0.0f;
2514
+ memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float));
2515
+ if (logitSoftcap != 0.0f) {
2516
+ return false;
2517
+ }
2518
+ return true;
2655
2519
  }
2520
+ case GGML_OP_SSM_CONV:
2656
2521
  return true;
2657
- }
2658
2522
  default:
2659
2523
  return false;
2660
2524
  }
@@ -2677,28 +2541,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
2677
2541
  return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
2678
2542
  }
2679
2543
 
2680
- /**
2681
- * @brief Determines if a tensor operation should be offloaded to the CANN
2682
- * backend.
2683
- *
2684
- * This function checks if a given tensor operation should be offloaded to the
2685
- * CANN backend based on the operation type and the size of the tensor. It
2686
- * returns true if the second dimension (ne[1]) of the tensor is greater than or
2687
- * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
2688
- *
2689
- * @param backend Pointer to the CANN backend.
2690
- * @param op Pointer to the tensor operation to check.
2691
- * @return bool Returns true if the operation should be offloaded, otherwise
2692
- * false.
2693
- */
2694
- static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
2695
- const ggml_tensor* op) {
2696
- const int min_batch_size = 32;
2697
- GGML_UNUSED(dev);
2698
-
2699
- return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
2700
- }
2701
-
2702
2544
  /**
2703
2545
  * @brief Records an event on the CANN backend stream.
2704
2546
  *
@@ -2708,9 +2550,8 @@ static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
2708
2550
  * @param event Pointer to the event structure to be recorded.
2709
2551
  */
2710
2552
  static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
2711
- ggml_backend_cann_context* cann_ctx =
2712
- (ggml_backend_cann_context*)backend->context;
2713
- ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
2553
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2554
+ ACL_CHECK(aclrtRecordEvent((aclrtEvent) event->context, cann_ctx->stream()));
2714
2555
  }
2715
2556
 
2716
2557
  /**
@@ -2723,13 +2564,10 @@ static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_
2723
2564
  * @param event Pointer to the event structure that the backend needs to wait
2724
2565
  * for.
2725
2566
  */
2726
- static void ggml_backend_cann_event_wait(ggml_backend_t backend,
2727
- ggml_backend_event_t event) {
2728
- ggml_backend_cann_context* cann_ctx =
2729
- (ggml_backend_cann_context*)backend->context;
2567
+ static void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
2568
+ ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2730
2569
  if (ggml_backend_is_cann(backend)) {
2731
- ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
2732
- (aclrtEvent)event->context));
2570
+ ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(), (aclrtEvent) event->context));
2733
2571
  } else {
2734
2572
  GGML_ABORT("fatal error");
2735
2573
  }
@@ -2768,30 +2606,31 @@ static const ggml_backend_i ggml_backend_cann_interface = {
2768
2606
  * @return A pointer to the static GUID.
2769
2607
  */
2770
2608
  static ggml_guid_t ggml_backend_cann_guid() {
2771
- static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
2772
- 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
2609
+ static ggml_guid guid = { 0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
2610
+ 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64 };
2773
2611
  return &guid;
2774
2612
  }
2775
2613
 
2776
2614
  // backend device
2777
2615
  struct ggml_backend_cann_device_context {
2778
- int device;
2616
+ int device;
2779
2617
  std::string name;
2780
2618
  std::string description;
2619
+ int op_offload_min_batch_size;
2781
2620
  };
2782
2621
 
2783
2622
  static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
2784
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
2623
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2785
2624
  return ctx->name.c_str();
2786
2625
  }
2787
2626
 
2788
- static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
2789
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
2627
+ static const char * ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
2628
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2790
2629
  return ctx->description.c_str();
2791
2630
  }
2792
2631
 
2793
2632
  static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2794
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
2633
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2795
2634
  ggml_backend_cann_get_device_memory(ctx->device, free, total);
2796
2635
  }
2797
2636
 
@@ -2818,7 +2657,7 @@ static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_back
2818
2657
 
2819
2658
  static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
2820
2659
  GGML_UNUSED(params);
2821
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
2660
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2822
2661
  return ggml_backend_cann_init(ctx->device);
2823
2662
  }
2824
2663
 
@@ -2835,19 +2674,17 @@ static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, cons
2835
2674
  * @return bool Returns true if the CANN backend supports the buffer type,
2836
2675
  * otherwise false.
2837
2676
  */
2838
- static bool ggml_backend_cann_supports_buft(
2839
- ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2677
+ static bool ggml_backend_cann_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2840
2678
  if (ggml_backend_buft_is_cann(buft)) {
2841
- ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
2842
- ggml_backend_cann_buffer_type_context * buft_ctx =
2843
- (ggml_backend_cann_buffer_type_context *)buft->context;
2679
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *) dev->context;
2680
+ ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
2844
2681
  return buft_ctx->device == dev_ctx->device;
2845
2682
  }
2846
2683
  return false;
2847
2684
  }
2848
2685
 
2849
2686
  static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
2850
- ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
2687
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2851
2688
  return ggml_backend_cann_buffer_type(ctx->device);
2852
2689
  }
2853
2690
 
@@ -2856,6 +2693,26 @@ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(
2856
2693
  return ggml_backend_cann_host_buffer_type();
2857
2694
  }
2858
2695
 
2696
+ /**
2697
+ * @brief Determines if a tensor operation should be offloaded to the CANN
2698
+ * backend.
2699
+ *
2700
+ * This function checks if a given tensor operation should be offloaded to the
2701
+ * CANN backend based on the operation type and the size of the tensor. It
2702
+ * returns true if the second dimension (ne[1]) of the tensor is greater than or
2703
+ * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
2704
+ *
2705
+ * @param backend Pointer to the CANN backend.
2706
+ * @param op Pointer to the tensor operation to check.
2707
+ * @return bool Returns true if the operation should be offloaded, otherwise
2708
+ * false.
2709
+ */
2710
+ static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2711
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
2712
+
2713
+ return op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS;
2714
+ }
2715
+
2859
2716
  /**
2860
2717
  * @brief Creates a new event for the CANN backend device.
2861
2718
  *
@@ -2866,9 +2723,8 @@ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(
2866
2723
  * @param backend Pointer to the CANN backend.
2867
2724
  * @return ggml_backend_event_t Returns a pointer to the new event structure.
2868
2725
  */
2869
- static ggml_backend_event_t ggml_backend_cann_device_event_new(
2870
- ggml_backend_dev_t dev) {
2871
- ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
2726
+ static ggml_backend_event_t ggml_backend_cann_device_event_new(ggml_backend_dev_t dev) {
2727
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *) dev->context;
2872
2728
 
2873
2729
  ggml_cann_set_device(dev_ctx->device);
2874
2730
 
@@ -2890,7 +2746,7 @@ static ggml_backend_event_t ggml_backend_cann_device_event_new(
2890
2746
  * @param event Pointer to the event structure to be freed.
2891
2747
  */
2892
2748
  static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
2893
- ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
2749
+ ACL_CHECK(aclrtDestroyEvent((aclrtEvent) event->context));
2894
2750
 
2895
2751
  delete event;
2896
2752
  GGML_UNUSED(dev);
@@ -2904,7 +2760,7 @@ static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_bac
2904
2760
  * @param event Pointer to the event structure to be synchronized.
2905
2761
  */
2906
2762
  static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
2907
- ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
2763
+ ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent) event->context));
2908
2764
 
2909
2765
  GGML_UNUSED(dev);
2910
2766
  }
@@ -2915,10 +2771,10 @@ static const ggml_backend_device_i ggml_backend_cann_device_interface = {
2915
2771
  /* .get_memory = */ ggml_backend_cann_device_get_memory,
2916
2772
  /* .get_type = */ ggml_backend_cann_device_get_type,
2917
2773
  /* .get_props = */ ggml_backend_cann_device_get_props,
2918
- /* .init_backend = */ ggml_backend_cann_device_init, // called for every card
2774
+ /* .init_backend = */ ggml_backend_cann_device_init, // called for every card
2919
2775
  /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
2920
2776
  /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
2921
- /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
2777
+ /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
2922
2778
  /* .supports_op = */ ggml_backend_cann_supports_op,
2923
2779
  /* .supports_buft = */ ggml_backend_cann_supports_buft,
2924
2780
  /* .offload_op = */ ggml_backend_cann_offload_op,
@@ -2927,7 +2783,6 @@ static const ggml_backend_device_i ggml_backend_cann_device_interface = {
2927
2783
  /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
2928
2784
  };
2929
2785
 
2930
-
2931
2786
  // backend reg
2932
2787
  struct ggml_backend_cann_reg_context {
2933
2788
  std::vector<ggml_backend_dev_t> devices;
@@ -2939,12 +2794,12 @@ static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
2939
2794
  }
2940
2795
 
2941
2796
  static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
2942
- ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2797
+ ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context;
2943
2798
  return ctx->devices.size();
2944
2799
  }
2945
2800
 
2946
2801
  static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2947
- ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2802
+ ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context;
2948
2803
  GGML_ASSERT(index < ctx->devices.size());
2949
2804
  return ctx->devices[index];
2950
2805
  }
@@ -2966,34 +2821,32 @@ static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
2966
2821
  // backend registry, called only once for cann backend
2967
2822
  ggml_backend_reg_t ggml_backend_cann_reg() {
2968
2823
  static ggml_backend_reg reg;
2969
- static bool initialized = false;
2824
+ static bool initialized = false;
2970
2825
 
2971
2826
  {
2972
- static std::mutex mutex;
2827
+ static std::mutex mutex;
2973
2828
  std::lock_guard<std::mutex> lock(mutex);
2974
2829
  if (!initialized) {
2975
2830
  aclInit(nullptr);
2976
2831
  ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
2832
+ const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
2977
2833
 
2978
2834
  for (int i = 0; i < ggml_cann_info().device_count; i++) {
2979
- ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
2980
- dev_ctx->description = aclrtGetSocName();
2981
- dev_ctx->device = i;
2982
- dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2835
+ ggml_backend_cann_device_context * dev_ctx = new ggml_backend_cann_device_context();
2836
+ dev_ctx->description = aclrtGetSocName();
2837
+ dev_ctx->device = i;
2838
+ dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2839
+ dev_ctx->op_offload_min_batch_size = min_batch_size;
2983
2840
  ggml_cann_set_device(i);
2984
- ggml_backend_dev_t dev = new ggml_backend_device {
2985
- /* .iface = */ ggml_backend_cann_device_interface,
2986
- /* .reg = */ &reg,
2987
- /* .context = */ dev_ctx
2988
- };
2841
+ ggml_backend_dev_t dev = new ggml_backend_device{ /* .iface = */ ggml_backend_cann_device_interface,
2842
+ /* .reg = */ &reg,
2843
+ /* .context = */ dev_ctx };
2989
2844
  ctx->devices.push_back(dev);
2990
2845
  }
2991
2846
 
2992
- reg = ggml_backend_reg {
2993
- /* .api_version = */ GGML_BACKEND_API_VERSION,
2994
- /* .iface = */ ggml_backend_cann_reg_interface,
2995
- /* .context = */ ctx
2996
- };
2847
+ reg = ggml_backend_reg{ /* .api_version = */ GGML_BACKEND_API_VERSION,
2848
+ /* .iface = */ ggml_backend_cann_reg_interface,
2849
+ /* .context = */ ctx };
2997
2850
  }
2998
2851
 
2999
2852
  initialized = true;
@@ -3009,39 +2862,36 @@ ggml_backend_t ggml_backend_cann_init(int32_t device) {
3009
2862
  return nullptr;
3010
2863
  }
3011
2864
 
3012
- ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
2865
+ ggml_backend_cann_context * ctx = new ggml_backend_cann_context(device);
3013
2866
  if (ctx == nullptr) {
3014
2867
  GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
3015
2868
  return nullptr;
3016
2869
  }
3017
2870
  ggml_cann_set_device(ctx->device);
3018
2871
  ggml_backend_t cann_backend =
3019
- new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
3020
- /* .interface = */ ggml_backend_cann_interface,
3021
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
3022
- /* .context = */ ctx};
2872
+ new ggml_backend{ /* .guid = */ ggml_backend_cann_guid(),
2873
+ /* .interface = */ ggml_backend_cann_interface,
2874
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
2875
+ /* .context = */ ctx };
3023
2876
 
3024
2877
  return cann_backend;
3025
2878
  }
3026
2879
 
3027
2880
  bool ggml_backend_is_cann(ggml_backend_t backend) {
3028
- return backend != NULL &&
3029
- ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
2881
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
3030
2882
  }
3031
2883
 
3032
2884
  int32_t ggml_backend_cann_get_device_count() {
3033
2885
  return ggml_cann_info().device_count;
3034
2886
  }
3035
2887
 
3036
- void ggml_backend_cann_get_device_description(
3037
- int32_t device, char* description, size_t description_size) {
2888
+ void ggml_backend_cann_get_device_description(int32_t device, char * description, size_t description_size) {
3038
2889
  ggml_cann_set_device(device);
3039
- const char* soc_name = aclrtGetSocName();
2890
+ const char * soc_name = aclrtGetSocName();
3040
2891
  snprintf(description, description_size, "%s", soc_name);
3041
2892
  }
3042
2893
 
3043
- void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
3044
- size_t* total) {
2894
+ void ggml_backend_cann_get_device_memory(int32_t device, size_t * free, size_t * total) {
3045
2895
  ggml_cann_set_device(device);
3046
2896
  ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
3047
2897
  }