whispercpp 1.3.3 → 1.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (586)
  1. checksums.yaml +4 -4
  2. data/ext/ruby_whisper_params.c +55 -25
  3. data/ext/sources/CMakeLists.txt +1 -1
  4. data/ext/sources/bindings/javascript/package.json +1 -1
  5. data/ext/sources/build-xcframework.sh +24 -0
  6. data/ext/sources/examples/CMakeLists.txt +1 -0
  7. data/ext/sources/examples/addon.node/addon.cpp +19 -19
  8. data/ext/sources/examples/addon.node/index.js +7 -5
  9. data/ext/sources/examples/bench/bench.cpp +26 -16
  10. data/ext/sources/examples/bench.wasm/index-tmpl.html +10 -9
  11. data/ext/sources/examples/cli/cli.cpp +4 -2
  12. data/ext/sources/examples/command/command.cpp +26 -24
  13. data/ext/sources/examples/command.wasm/index-tmpl.html +5 -4
  14. data/ext/sources/examples/common-ggml.cpp +2 -0
  15. data/ext/sources/examples/lsp/lsp.cpp +19 -17
  16. data/ext/sources/examples/server/server.cpp +24 -13
  17. data/ext/sources/examples/server.py +6 -1
  18. data/ext/sources/examples/stream/stream.cpp +4 -2
  19. data/ext/sources/examples/stream.wasm/emscripten.cpp +6 -6
  20. data/ext/sources/examples/stream.wasm/index-tmpl.html +82 -5
  21. data/ext/sources/examples/talk-llama/CMakeLists.txt +2 -2
  22. data/ext/sources/examples/talk-llama/llama-adapter.cpp +101 -4
  23. data/ext/sources/examples/talk-llama/llama-adapter.h +6 -0
  24. data/ext/sources/examples/talk-llama/llama-arch.cpp +588 -15
  25. data/ext/sources/examples/talk-llama/llama-arch.h +58 -1
  26. data/ext/sources/examples/talk-llama/llama-batch.cpp +103 -71
  27. data/ext/sources/examples/talk-llama/llama-batch.h +31 -18
  28. data/ext/sources/examples/talk-llama/llama-chat.cpp +120 -5
  29. data/ext/sources/examples/talk-llama/llama-chat.h +7 -0
  30. data/ext/sources/examples/talk-llama/llama-context.cpp +460 -357
  31. data/ext/sources/examples/talk-llama/llama-context.h +44 -29
  32. data/ext/sources/examples/talk-llama/llama-cparams.h +4 -4
  33. data/ext/sources/examples/talk-llama/llama-graph.cpp +543 -271
  34. data/ext/sources/examples/talk-llama/llama-graph.h +278 -168
  35. data/ext/sources/examples/talk-llama/llama-hparams.cpp +118 -4
  36. data/ext/sources/examples/talk-llama/llama-hparams.h +61 -15
  37. data/ext/sources/examples/talk-llama/llama-impl.h +2 -0
  38. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +326 -0
  39. data/ext/sources/examples/talk-llama/{llama-kv-cache-unified-iswa.h → llama-kv-cache-iswa.h} +38 -29
  40. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +2020 -0
  41. data/ext/sources/examples/talk-llama/llama-kv-cache.h +358 -27
  42. data/ext/sources/examples/talk-llama/llama-kv-cells.h +80 -28
  43. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +56 -36
  44. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +30 -29
  45. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +48 -19
  46. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +13 -14
  47. data/ext/sources/examples/talk-llama/llama-memory.h +16 -10
  48. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +2 -0
  49. data/ext/sources/examples/talk-llama/llama-model-loader.h +3 -2
  50. data/ext/sources/examples/talk-llama/llama-model.cpp +7165 -2336
  51. data/ext/sources/examples/talk-llama/llama-model.h +60 -9
  52. data/ext/sources/examples/talk-llama/llama-quant.cpp +48 -10
  53. data/ext/sources/examples/talk-llama/llama-sampling.cpp +226 -126
  54. data/ext/sources/examples/talk-llama/llama-vocab.cpp +440 -13
  55. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -0
  56. data/ext/sources/examples/talk-llama/llama.cpp +65 -10
  57. data/ext/sources/examples/talk-llama/llama.h +95 -177
  58. data/ext/sources/examples/talk-llama/talk-llama.cpp +9 -6
  59. data/ext/sources/examples/talk-llama/unicode.cpp +207 -0
  60. data/ext/sources/examples/talk-llama/unicode.h +45 -0
  61. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +4 -2
  62. data/ext/sources/examples/whisper.wasm/index-tmpl.html +17 -16
  63. data/ext/sources/ggml/CMakeLists.txt +59 -31
  64. data/ext/sources/ggml/cmake/ggml-config.cmake.in +132 -93
  65. data/ext/sources/ggml/include/ggml-backend.h +17 -1
  66. data/ext/sources/ggml/include/ggml-cpu.h +1 -1
  67. data/ext/sources/ggml/include/ggml-metal.h +1 -6
  68. data/ext/sources/ggml/include/ggml-opt.h +25 -6
  69. data/ext/sources/ggml/include/ggml-webgpu.h +19 -0
  70. data/ext/sources/ggml/include/ggml-zdnn.h +17 -0
  71. data/ext/sources/ggml/include/ggml.h +221 -16
  72. data/ext/sources/ggml/src/CMakeLists.txt +17 -2
  73. data/ext/sources/ggml/src/ggml-alloc.c +265 -141
  74. data/ext/sources/ggml/src/ggml-backend-impl.h +4 -1
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +30 -13
  76. data/ext/sources/ggml/src/ggml-backend.cpp +221 -38
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -4
  79. data/ext/sources/ggml/src/ggml-cann/CMakeLists.txt +14 -0
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +3 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +903 -717
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +143 -25
  83. data/ext/sources/ggml/src/ggml-cann/common.h +143 -1
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +488 -69
  85. data/ext/sources/ggml/src/ggml-common.h +17 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +40 -18
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +4 -2
  88. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +132 -596
  89. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +14 -286
  90. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +103 -582
  91. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +162 -589
  92. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +265 -437
  93. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +3 -58
  94. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +521 -353
  95. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +54 -314
  96. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +184 -675
  97. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +4679 -1657
  98. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +32 -2
  99. data/ext/sources/ggml/src/ggml-cpu/common.h +14 -0
  100. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -6
  101. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +70 -42
  102. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +35 -28
  103. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +152 -18
  104. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +7 -1
  105. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +227 -97
  106. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +474 -1116
  107. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1587 -1177
  108. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -8
  109. data/ext/sources/ggml/src/ggml-cpu/quants.c +35 -0
  110. data/ext/sources/ggml/src/ggml-cpu/quants.h +8 -0
  111. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +458 -47
  112. data/ext/sources/ggml/src/ggml-cpu/repack.h +22 -0
  113. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +89 -60
  114. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1024 -0
  115. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +13 -0
  116. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +3196 -0
  117. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +26 -0
  118. data/ext/sources/ggml/src/ggml-cpu/traits.cpp +2 -2
  119. data/ext/sources/ggml/src/ggml-cpu/traits.h +1 -1
  120. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +170 -26
  121. data/ext/sources/ggml/src/ggml-cpu/vec.h +506 -63
  122. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +20 -16
  123. data/ext/sources/ggml/src/ggml-cuda/add-id.cu +58 -0
  124. data/ext/sources/ggml/src/ggml-cuda/add-id.cuh +3 -0
  125. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +330 -191
  126. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +2 -0
  127. data/ext/sources/ggml/src/ggml-cuda/common.cuh +250 -63
  128. data/ext/sources/ggml/src/ggml-cuda/conv-transpose-1d.cu +1 -4
  129. data/ext/sources/ggml/src/ggml-cuda/conv2d.cu +166 -0
  130. data/ext/sources/ggml/src/ggml-cuda/conv2d.cuh +5 -0
  131. data/ext/sources/ggml/src/ggml-cuda/convert.cu +95 -22
  132. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +15 -0
  133. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +217 -0
  134. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +64 -307
  135. data/ext/sources/ggml/src/ggml-cuda/cross-entropy-loss.cu +2 -14
  136. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +14 -40
  137. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +498 -367
  138. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +137 -91
  139. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +755 -0
  140. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +3 -0
  141. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +593 -0
  142. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +86 -50
  143. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +185 -198
  144. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  145. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +50 -39
  146. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +379 -107
  147. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +196 -35
  148. data/ext/sources/ggml/src/ggml-cuda/im2col.cuh +1 -0
  149. data/ext/sources/ggml/src/ggml-cuda/mean.cu +56 -2
  150. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +198 -45
  151. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +123 -0
  152. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +496 -0
  153. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +206 -57
  154. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +1262 -721
  155. data/ext/sources/ggml/src/ggml-cuda/{mmv.cu → mmvf.cu} +53 -53
  156. data/ext/sources/ggml/src/ggml-cuda/{mmv.cuh → mmvf.cuh} +3 -3
  157. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +64 -73
  158. data/ext/sources/ggml/src/ggml-cuda/norm.cu +284 -12
  159. data/ext/sources/ggml/src/ggml-cuda/norm.cuh +7 -0
  160. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cu +49 -0
  161. data/ext/sources/ggml/src/ggml-cuda/opt-step-sgd.cuh +5 -0
  162. data/ext/sources/ggml/src/ggml-cuda/pad.cu +46 -23
  163. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cu +91 -0
  164. data/ext/sources/ggml/src/ggml-cuda/pad_reflect_1d.cuh +5 -0
  165. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +12 -10
  166. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +53 -0
  167. data/ext/sources/ggml/src/ggml-cuda/roll.cu +67 -0
  168. data/ext/sources/ggml/src/ggml-cuda/roll.cuh +5 -0
  169. data/ext/sources/ggml/src/ggml-cuda/rope.cu +21 -27
  170. data/ext/sources/ggml/src/ggml-cuda/scale.cu +14 -11
  171. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +276 -0
  172. data/ext/sources/ggml/src/ggml-cuda/set-rows.cuh +7 -0
  173. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +34 -0
  174. data/ext/sources/ggml/src/ggml-cuda/softcap.cuh +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +126 -59
  176. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +10 -2
  177. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +322 -100
  178. data/ext/sources/ggml/src/ggml-cuda/sum.cu +6 -10
  179. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +21 -4
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu +7 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu +7 -0
  191. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu +7 -0
  192. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu +7 -0
  193. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu +7 -0
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu +7 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu +7 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu +7 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu +7 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu +7 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu +7 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu +7 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu +7 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu +7 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu +7 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu +7 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu +7 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu +7 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu +7 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu +7 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu +7 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +21 -18
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu +5 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu +5 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu +5 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu +5 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu +5 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu +5 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu +5 -0
  226. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu +5 -0
  227. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu +5 -0
  228. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu +5 -0
  229. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu +5 -0
  230. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu +5 -0
  231. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu +5 -0
  232. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu +5 -0
  233. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu +5 -0
  234. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +259 -0
  235. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +14 -0
  236. data/ext/sources/ggml/src/ggml-cuda/tsembd.cu +3 -3
  237. data/ext/sources/ggml/src/ggml-cuda/unary.cu +90 -0
  238. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +8 -0
  239. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +92 -6
  240. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +110 -22
  241. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  242. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +58 -36
  243. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +4 -3
  244. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +10 -2
  245. data/ext/sources/ggml/src/ggml-impl.h +119 -9
  246. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -7
  247. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +446 -0
  248. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.h +52 -0
  249. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +33 -0
  250. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +600 -0
  251. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +1376 -0
  252. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +226 -0
  253. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +1308 -0
  254. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +136 -63
  255. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +3158 -0
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +82 -0
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +718 -0
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +2854 -1503
  259. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +18 -8
  260. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +18 -0
  261. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +2510 -242
  262. data/ext/sources/ggml/src/ggml-opencl/kernels/add.cl +107 -0
  263. data/ext/sources/ggml/src/ggml-opencl/kernels/add_id.cl +42 -0
  264. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d.cl +185 -0
  265. data/ext/sources/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl +176 -0
  266. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +84 -0
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/div.cl +66 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f16.cl +370 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +370 -0
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl +373 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gelu.cl +27 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/glu.cl +177 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/group_norm.cl +49 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f16.cl +1 -1
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/im2col_f32.cl +1 -1
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/mul.cl +73 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl +130 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +132 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +133 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32.cl +189 -0
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_mxfp4_f32_flat.cl +176 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32.cl +140 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_id_q8_0_f32_flat.cl +222 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32.cl +144 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_mxfp4_f32_flat.cl +167 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32.cl +125 -0
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q8_0_f32_flat.cl +202 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/norm.cl +80 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +79 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +3 -2
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +189 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +34 -13
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +34 -13
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f16.cl +34 -13
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/softmax_f32.cl +34 -13
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/sub.cl +66 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +20 -0
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/tsembd.cl +2 -2
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/upscale.cl +2 -3
  300. data/ext/sources/ggml/src/ggml-opt.cpp +97 -41
  301. data/ext/sources/ggml/src/ggml-quants.c +111 -16
  302. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  303. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +67 -47
  304. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +2 -0
  305. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +15 -5
  306. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +6 -0
  307. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +25 -16
  308. data/ext/sources/ggml/src/ggml-sycl/conv.cpp +10 -4
  309. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +166 -99
  310. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +72 -306
  311. data/ext/sources/ggml/src/ggml-sycl/cpy.hpp +213 -1
  312. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +67 -49
  313. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +1 -31
  314. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +79 -29
  315. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  316. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +14 -26
  317. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +9 -6
  318. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +328 -323
  319. data/ext/sources/ggml/src/ggml-sycl/gla.cpp +2 -2
  320. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +2 -2
  321. data/ext/sources/ggml/src/ggml-sycl/mmq.cpp +80 -60
  322. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +201 -132
  323. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +74 -55
  324. data/ext/sources/ggml/src/ggml-sycl/quantize.hpp +133 -0
  325. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +8 -9
  326. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +35 -42
  327. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +234 -0
  328. data/ext/sources/ggml/src/ggml-sycl/set_rows.hpp +8 -0
  329. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +3 -3
  330. data/ext/sources/ggml/src/ggml-sycl/tsembd.cpp +12 -6
  331. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +2 -6
  332. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +16 -12
  333. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3492 -883
  334. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +41 -1
  335. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +42 -0
  336. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +13 -4
  337. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +39 -29
  338. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +349 -0
  339. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +2 -2
  340. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +66 -12
  341. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +154 -0
  342. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +21 -0
  343. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  344. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +2 -1
  345. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +6 -5
  346. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +4 -2
  347. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +32 -0
  348. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +1 -1
  349. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  350. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +1 -1
  351. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +1 -1
  352. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  353. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +21 -0
  354. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +69 -24
  355. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +60 -20
  356. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +98 -42
  357. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +64 -27
  358. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +74 -13
  359. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +27 -0
  360. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +11 -0
  361. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +39 -0
  362. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp +4 -17
  363. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +19 -10
  364. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +25 -15
  365. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp +4 -0
  366. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +22 -0
  367. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +22 -0
  368. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +18 -14
  369. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +126 -0
  370. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +65 -1
  371. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +11 -7
  372. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +140 -0
  373. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +144 -531
  374. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +206 -38
  375. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp +556 -0
  376. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +12 -5
  377. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +15 -9
  378. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +111 -0
  379. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +22 -0
  380. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +24 -3
  381. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +53 -3
  382. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +55 -11
  383. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +65 -0
  384. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +46 -0
  385. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +1 -4
  386. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -9
  387. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -9
  388. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -9
  389. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp +5 -0
  390. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +1 -1
  391. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +29 -7
  392. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +4 -0
  393. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +17 -0
  394. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +38 -5
  395. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +14 -0
  396. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +4 -3
  397. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +101 -9
  398. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +69 -5
  399. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp +25 -0
  400. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -77
  401. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +54 -0
  402. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1558 -0
  403. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +44 -0
  404. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +41 -0
  405. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +45 -0
  406. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +930 -0
  407. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +60 -0
  408. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +124 -0
  409. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +874 -0
  410. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/memset.wgsl +40 -0
  411. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +44 -0
  412. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +41 -0
  413. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +907 -0
  414. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +57 -0
  415. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +48 -0
  416. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +81 -0
  417. data/ext/sources/ggml/src/ggml-zdnn/CMakeLists.txt +36 -0
  418. data/ext/sources/ggml/src/ggml-zdnn/common.hpp +59 -0
  419. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +628 -0
  420. data/ext/sources/ggml/src/ggml-zdnn/mmf.cpp +80 -0
  421. data/ext/sources/ggml/src/ggml-zdnn/mmf.hpp +12 -0
  422. data/ext/sources/ggml/src/ggml-zdnn/utils.cpp +79 -0
  423. data/ext/sources/ggml/src/ggml-zdnn/utils.hpp +19 -0
  424. data/ext/sources/ggml/src/ggml.c +478 -98
  425. data/ext/sources/ggml/src/gguf.cpp +8 -1
  426. data/ext/sources/src/whisper.cpp +23 -46
  427. data/ext/sources/tests/CMakeLists.txt +8 -1
  428. data/ext/sources/tests/test-vad-full.cpp +3 -3
  429. data/ext/sources/tests/test-vad.cpp +2 -2
  430. data/lib/whisper/model/uri.rb +1 -1
  431. data/sig/whisper.rbs +7 -0
  432. data/test/test_params.rb +8 -0
  433. data/test/test_whisper.rb +1 -1
  434. data/whispercpp.gemspec +1 -1
  435. metadata +164 -157
  436. data/ext/sources/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +0 -279
  437. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.cpp +0 -1841
  438. data/ext/sources/examples/talk-llama/llama-kv-cache-unified.h +0 -303
  439. data/ext/sources/ggml/include/ggml-kompute.h +0 -50
  440. data/ext/sources/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  441. data/ext/sources/ggml/src/ggml-amx/common.h +0 -94
  442. data/ext/sources/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
  443. data/ext/sources/ggml/src/ggml-amx/mmq.cpp +0 -2510
  444. data/ext/sources/ggml/src/ggml-amx/mmq.h +0 -17
  445. data/ext/sources/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  446. data/ext/sources/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  447. data/ext/sources/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  448. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  449. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  450. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  451. data/ext/sources/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  452. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  453. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  454. data/ext/sources/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  455. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cu +0 -357
  456. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f16.cuh +0 -3
  457. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cu +0 -365
  458. data/ext/sources/ggml/src/ggml-cuda/fattn-tile-f32.cuh +0 -3
  459. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f16.cuh +0 -482
  460. data/ext/sources/ggml/src/ggml-cuda/fattn-vec-f32.cuh +0 -472
  461. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +0 -5
  462. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +0 -5
  463. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +0 -5
  464. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +0 -5
  465. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +0 -5
  466. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +0 -5
  467. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +0 -5
  468. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +0 -5
  469. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +0 -5
  470. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +0 -5
  471. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +0 -5
  472. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +0 -5
  473. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +0 -5
  474. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +0 -5
  475. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +0 -5
  476. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +0 -5
  477. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +0 -5
  478. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +0 -5
  479. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +0 -5
  480. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +0 -5
  481. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +0 -5
  482. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +0 -5
  483. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +0 -5
  484. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +0 -5
  485. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +0 -5
  486. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +0 -5
  487. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +0 -5
  488. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +0 -5
  489. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +0 -5
  490. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +0 -5
  491. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +0 -5
  492. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +0 -5
  493. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +0 -5
  494. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +0 -5
  495. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +0 -5
  496. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +0 -5
  497. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +0 -5
  498. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +0 -5
  499. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +0 -5
  500. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +0 -5
  501. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +0 -5
  502. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +0 -5
  503. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +0 -5
  504. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +0 -5
  505. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +0 -5
  506. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +0 -5
  507. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +0 -5
  508. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +0 -5
  509. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +0 -5
  510. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +0 -5
  511. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +0 -5
  512. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +0 -5
  513. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +0 -5
  514. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +0 -5
  515. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +0 -5
  516. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +0 -5
  517. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +0 -5
  518. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +0 -5
  519. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +0 -5
  520. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +0 -5
  521. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +0 -5
  522. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +0 -5
  523. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +0 -5
  524. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +0 -5
  525. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +0 -5
  526. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +0 -5
  527. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +0 -5
  528. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +0 -5
  529. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +0 -5
  530. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +0 -5
  531. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +0 -5
  532. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +0 -5
  533. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +0 -5
  534. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +0 -5
  535. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +0 -5
  536. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +0 -5
  537. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +0 -5
  538. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +0 -5
  539. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +0 -5
  540. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +0 -5
  541. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +0 -5
  542. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +0 -5
  543. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +0 -5
  544. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +0 -5
  545. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +0 -5
  546. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +0 -5
  547. data/ext/sources/ggml/src/ggml-kompute/CMakeLists.txt +0 -166
  548. data/ext/sources/ggml/src/ggml-kompute/ggml-kompute.cpp +0 -2251
  549. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/common.comp +0 -112
  550. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_add.comp +0 -58
  551. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp +0 -25
  552. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp +0 -52
  553. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp +0 -52
  554. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp +0 -52
  555. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp +0 -52
  556. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp +0 -30
  557. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp +0 -22
  558. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp +0 -17
  559. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp +0 -31
  560. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp +0 -31
  561. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp +0 -38
  562. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp +0 -39
  563. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp +0 -44
  564. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp +0 -52
  565. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp +0 -69
  566. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp +0 -51
  567. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp +0 -33
  568. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp +0 -35
  569. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp +0 -140
  570. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp +0 -106
  571. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp +0 -73
  572. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp +0 -52
  573. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp +0 -28
  574. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp +0 -84
  575. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp +0 -21
  576. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp +0 -53
  577. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp +0 -52
  578. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp +0 -52
  579. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp +0 -52
  580. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp +0 -52
  581. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp +0 -19
  582. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp +0 -23
  583. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp +0 -22
  584. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp +0 -72
  585. data/ext/sources/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp +0 -71
  586. data/ext/sources/ggml/src/ggml-metal/ggml-metal.m +0 -6280
@@ -24,6 +24,7 @@
24
24
 
25
25
  #include <acl/acl.h>
26
26
  #include <stdarg.h>
27
+ #include <aclnnop/aclnn_trans_matmul_weight.h>
27
28
 
28
29
  #include <cmath>
29
30
  #include <cstdio>
@@ -74,13 +75,12 @@
74
75
  * @param device The device ID to set.
75
76
  */
76
77
  void ggml_cann_set_device(const int32_t device) {
77
- // TODO: uncomment these lines after empty context has fixed.
78
- // int current_device;
79
- // ACL_CHECK(aclrtGetDevice(&current_device));
78
+ int current_device = -1;
79
+ aclrtGetDevice(&current_device);
80
80
 
81
- // if (device == current_device) {
82
- // return;
83
- // }
81
+ if (device == current_device) {
82
+ return;
83
+ }
84
84
  ACL_CHECK(aclrtSetDevice(device));
85
85
  }
86
86
 
@@ -115,6 +115,24 @@ bool parse_bool(const std::string& value) {
115
115
  return valid_values.find(value) != valid_values.end();
116
116
  }
117
117
 
118
+ /**
119
+ * @brief Parse a string as an integer, returning 0 if invalid.
120
+ *
121
+ * This function attempts to convert the input string `value` to an `int`.
122
+ * If the string is not a valid integer or is out of the `int` range,
123
+ * it returns 0.
124
+ *
125
+ * @param value The string to parse.
126
+ * @return The parsed integer, or 0 if conversion fails.
127
+ */
128
+ int parse_integer(const std::string& value) {
129
+ try {
130
+ return std::stoi(value);
131
+ } catch (...) {
132
+ return 0;
133
+ }
134
+ }
135
+
118
136
  /**
119
137
  * @brief Initialize the CANN device information.
120
138
  *
@@ -1115,6 +1133,98 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
1115
1133
  return GGML_STATUS_SUCCESS;
1116
1134
  }
1117
1135
 
1136
+ /**
1137
+ * @brief Workspace for caching NZ buffers per device.
1138
+ *
1139
+ * This struct manages a device buffer used in NZ computations. It supports
1140
+ * allocation, reallocation, and clearing of cached memory. The struct is
1141
+ * designed to be used with a global array, one per device.
1142
+ */
1143
+ struct ggml_cann_nz_workspace {
1144
+ void* ptr; // Pointer to allocated device buffer
1145
+ size_t allocated; // Size of currently allocated buffer in bytes
1146
+
1147
+ /**
1148
+ * @brief Constructor. Initializes the workspace with no allocated memory.
1149
+ */
1150
+ ggml_cann_nz_workspace() : ptr(nullptr), allocated(0) {}
1151
+
1152
+ /**
1153
+ * @brief Free cached memory and reset the workspace.
1154
+ *
1155
+ * If a buffer has been allocated, this function releases it using
1156
+ * aclrtFree and resets internal state.
1157
+ */
1158
+ void clear() {
1159
+ if (ptr) {
1160
+ ACL_CHECK(aclrtFree(ptr));
1161
+ ptr = nullptr;
1162
+ allocated = 0;
1163
+ }
1164
+ }
1165
+
1166
+ /**
1167
+ * @brief Allocate or reallocate the workspace buffer.
1168
+ *
1169
+ * If the requested size is larger than the currently allocated size,
1170
+ * the old buffer will be freed and a new buffer of the requested size
1171
+ * will be allocated on the device.
1172
+ *
1173
+ * @param new_size Size in bytes to allocate for the workspace.
1174
+ */
1175
+ void realloc(size_t new_size) {
1176
+ if (new_size > allocated) {
1177
+ clear();
1178
+ ACL_CHECK(aclrtMalloc(&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1179
+ allocated = new_size;
1180
+ }
1181
+ }
1182
+
1183
+ /**
1184
+ * @brief Get the device buffer pointer.
1185
+ *
1186
+ * @return Pointer to the allocated buffer, or nullptr if not allocated.
1187
+ */
1188
+ void* get() const { return ptr; }
1189
+ };
1190
+
1191
+ /**
1192
+ * @brief Global array of NZ workspaces, one per device.
1193
+ */
1194
+ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
1195
+
1196
+ /**
1197
+ * @brief Convert tensor weights to NZ format using Ascend CANN API.
1198
+ *
1199
+ * This function creates a transposed tensor descriptor and performs the
1200
+ * TransMatmulWeight operation. Converting tensor formats can significantly
1201
+ * improve performance on certain hardware.
1202
+ *
1203
+ * @param tensor Pointer to the input ggml_tensor containing the weights.
1204
+ * @param offset Byte offset within the tensor data buffer where weights start.
1205
+ * @param device device id.
1206
+ *
1207
+ * @note The workspace buffer used in this function is managed globally and reused
1208
+ * across calls. This reduces overhead from repeated memory allocation and deallocation.
1209
+ */
1210
+ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) {
1211
+ aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
1212
+ tensor->nb, 2, ACL_FORMAT_ND, offset);
1213
+ uint64_t workspaceSize = 0;
1214
+ aclOpExecutor *executor;
1215
+
1216
+ // TransMatmulWeight
1217
+ ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
1218
+ &workspaceSize, &executor));
1219
+ // Avoid frequent malloc/free of the workspace.
1220
+ g_nz_workspaces[device].realloc(workspaceSize);
1221
+
1222
+ void* g_nz_workspace = g_nz_workspaces[device].get();
1223
+
1224
+ ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
1225
+ ACL_CHECK(aclDestroyTensor(weightTransposed));
1226
+ }
1227
+
1118
1228
  // TODO: need handle tensor which has paddings.
1119
1229
  /**
1120
1230
  * @brief Set tensor data in a CANN buffer.
@@ -1139,9 +1249,16 @@ static void ggml_backend_cann_buffer_set_tensor(
1139
1249
  // For acl, synchronous functions use this default stream.
1140
1250
  // Why aclrtSynchronizeDevice?
1141
1251
 
1252
+ // Only check env once.
1253
+ static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
1142
1254
  if (!need_transform(tensor->type)) {
1143
1255
  ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
1144
1256
  ACL_MEMCPY_HOST_TO_DEVICE));
1257
+ if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
1258
+ GGML_ASSERT(tensor->ne[2] == 1);
1259
+ GGML_ASSERT(tensor->ne[3] == 1);
1260
+ weight_format_to_nz(tensor, offset, ctx->device);
1261
+ }
1145
1262
  } else {
1146
1263
  void *transform_buffer = malloc(size);
1147
1264
  ggml_backend_cann_transform(tensor, data, transform_buffer);
@@ -1216,6 +1333,10 @@ static bool ggml_backend_cann_buffer_cpy_tensor(
1216
1333
  ACL_MEMCPY_DEVICE_TO_DEVICE));
1217
1334
  return true;
1218
1335
  } else {
1336
+ #ifdef ASCEND_310P
1337
+ // TODO: Support 310p P2P copy
1338
+ return false;
1339
+ #endif
1219
1340
  // Different device but can access by peer.
1220
1341
  int32_t canAccessPeer = 0;
1221
1342
  ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
@@ -1375,20 +1496,32 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1375
1496
  size_t size = ggml_nbytes(tensor);
1376
1497
  int64_t ne0 = tensor->ne[0];
1377
1498
 
1499
+ // Only check env once.
1500
+ static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
1501
+
1378
1502
  // last line must bigger than 32, because every single op deal at
1379
1503
  // least 32 bytes.
1380
1504
  // TODO: quantized type?
1381
1505
  // int64_t line_size = ne0 * ggml_element_size(tensor);
1382
1506
  // int64_t line_size_align_32 = (line_size + 31) & ~31;
1383
1507
  // size += (line_size_align_32 - line_size);
1384
-
1385
- // TODO: not support quantized yet.
1386
- // TODO: consider un-continue tensor.
1387
1508
  if (ggml_is_quantized(tensor->type)) {
1388
1509
  if (ne0 % MATRIX_ROW_PADDING != 0) {
1389
1510
  size += ggml_row_size(
1390
1511
  tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1391
1512
  }
1513
+ } else if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
1514
+ // NZ format weight are not support quantized yet.
1515
+ // If ND tensor transform to NZ, size may changed.
1516
+ int64_t shape[] = {tensor->ne[1], tensor->ne[0]};
1517
+ GGML_ASSERT(tensor->ne[2] == 1);
1518
+ GGML_ASSERT(tensor->ne[3] == 1);
1519
+ const aclIntArray *acl_shape = aclCreateIntArray(shape, 2);
1520
+ size_t new_size;
1521
+ ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape,
1522
+ ggml_cann_type_mapping(tensor->type), &new_size));
1523
+ ACL_CHECK(aclDestroyIntArray(acl_shape));
1524
+ size = std::max(size, new_size);
1392
1525
  }
1393
1526
 
1394
1527
  return size;
@@ -1594,6 +1727,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1594
1727
  case GGML_OP_GET_ROWS:
1595
1728
  ggml_cann_get_rows(ctx, dst);
1596
1729
  break;
1730
+ case GGML_OP_SET_ROWS:
1731
+ ggml_cann_set_rows(ctx, dst);
1732
+ break;
1597
1733
  case GGML_OP_DUP:
1598
1734
  ggml_cann_dup(ctx, dst);
1599
1735
  break;
@@ -1616,16 +1752,18 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1616
1752
  case GGML_OP_UNARY:
1617
1753
  switch (ggml_get_unary_op(dst)) {
1618
1754
  case GGML_UNARY_OP_ABS:
1619
- GGML_CANN_CALL_UNARY_OP(Abs);
1755
+ GGML_CANN_CALL_OP_UNARY(Abs);
1620
1756
  break;
1621
1757
  case GGML_UNARY_OP_NEG:
1622
- GGML_CANN_CALL_UNARY_OP(Neg);
1758
+ GGML_CANN_CALL_OP_UNARY(Neg);
1623
1759
  break;
1624
1760
  case GGML_UNARY_OP_GELU:
1625
- GGML_CANN_CALL_UNARY_OP(Gelu);
1761
+ case GGML_UNARY_OP_GELU_ERF:
1762
+ // aclnnGelu internally uses the erf-based approximation.
1763
+ GGML_CANN_CALL_OP_UNARY(Gelu);
1626
1764
  break;
1627
1765
  case GGML_UNARY_OP_SILU:
1628
- GGML_CANN_CALL_UNARY_OP(Silu);
1766
+ GGML_CANN_CALL_OP_UNARY(Silu);
1629
1767
  break;
1630
1768
  case GGML_UNARY_OP_GELU_QUICK: {
1631
1769
  auto lambda = [](ggml_backend_cann_context& ctx,
@@ -1633,31 +1771,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1633
1771
  aclTensor* acl_dst) {
1634
1772
  GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1635
1773
  };
1636
- ggml_cann_unary_op(lambda, ctx, dst);
1774
+ ggml_cann_op_unary(lambda, ctx, dst);
1637
1775
  } break;
1638
1776
  case GGML_UNARY_OP_TANH:
1639
- GGML_CANN_CALL_UNARY_OP(Tanh);
1777
+ GGML_CANN_CALL_OP_UNARY(Tanh);
1640
1778
  break;
1641
1779
  case GGML_UNARY_OP_RELU:
1642
- GGML_CANN_CALL_UNARY_OP(Relu);
1780
+ GGML_CANN_CALL_OP_UNARY(Relu);
1643
1781
  break;
1644
1782
  case GGML_UNARY_OP_SIGMOID:
1645
- GGML_CANN_CALL_UNARY_OP(Sigmoid);
1783
+ GGML_CANN_CALL_OP_UNARY(Sigmoid);
1646
1784
  break;
1647
1785
  case GGML_UNARY_OP_HARDSIGMOID:
1648
- GGML_CANN_CALL_UNARY_OP(Hardsigmoid);
1786
+ GGML_CANN_CALL_OP_UNARY(Hardsigmoid);
1649
1787
  break;
1650
1788
  case GGML_UNARY_OP_HARDSWISH:
1651
- GGML_CANN_CALL_UNARY_OP(Hardswish);
1789
+ GGML_CANN_CALL_OP_UNARY(Hardswish);
1652
1790
  break;
1653
1791
  case GGML_UNARY_OP_EXP:
1654
- GGML_CANN_CALL_UNARY_OP(Exp);
1792
+ GGML_CANN_CALL_OP_UNARY(Exp);
1655
1793
  break;
1656
1794
  case GGML_UNARY_OP_ELU:
1657
1795
  ggml_cann_elu(ctx, dst);
1658
1796
  break;
1659
1797
  case GGML_UNARY_OP_SGN:
1660
- GGML_CANN_CALL_UNARY_OP(Sign);
1798
+ GGML_CANN_CALL_OP_UNARY(Sign);
1661
1799
  break;
1662
1800
  case GGML_UNARY_OP_STEP:
1663
1801
  ggml_cann_step(ctx, dst);
@@ -1666,6 +1804,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1666
1804
  return false;
1667
1805
  }
1668
1806
  break;
1807
+ case GGML_OP_GLU:
1808
+ switch (ggml_get_glu_op(dst)) {
1809
+ case GGML_GLU_OP_REGLU:
1810
+ GGML_CANN_CALL_OP_UNARY_GATED(Relu);
1811
+ break;
1812
+ case GGML_GLU_OP_GEGLU:
1813
+ case GGML_GLU_OP_GEGLU_ERF:
1814
+ // aclnnGelu internally uses the erf-based approximation.
1815
+ GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
1816
+ break;
1817
+ case GGML_GLU_OP_SWIGLU:
1818
+ GGML_CANN_CALL_OP_UNARY_GATED(Silu);
1819
+ break;
1820
+ case GGML_GLU_OP_GEGLU_QUICK: {
1821
+ auto lambda = [](ggml_backend_cann_context& ctx,
1822
+ aclTensor* acl_src,
1823
+ aclTensor* acl_dst) {
1824
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1825
+ };
1826
+ ggml_cann_op_unary_gated(lambda, ctx, dst);
1827
+ } break;
1828
+ default:
1829
+ return false;
1830
+ }
1831
+ break;
1669
1832
  case GGML_OP_NORM:
1670
1833
  ggml_cann_norm(ctx, dst);
1671
1834
  break;
@@ -1708,7 +1871,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1708
1871
  ggml_cann_binary_op<aclnn_mul>(ctx, dst);
1709
1872
  break;
1710
1873
  case GGML_OP_SQRT:
1711
- GGML_CANN_CALL_UNARY_OP(Sqrt);
1874
+ GGML_CANN_CALL_OP_UNARY(Sqrt);
1712
1875
  break;
1713
1876
  case GGML_OP_CLAMP:
1714
1877
  ggml_cann_clamp(ctx, dst);
@@ -1753,16 +1916,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1753
1916
  ggml_cann_argmax(ctx, dst);
1754
1917
  break;
1755
1918
  case GGML_OP_COS:
1756
- ggml_cann_unary_op<aclnn_cos>(ctx, dst);
1919
+ ggml_cann_op_unary<aclnn_cos>(ctx, dst);
1757
1920
  break;
1758
1921
  case GGML_OP_SIN:
1759
- ggml_cann_unary_op<aclnn_sin>(ctx, dst);
1922
+ ggml_cann_op_unary<aclnn_sin>(ctx, dst);
1760
1923
  break;
1761
1924
  case GGML_OP_CONV_TRANSPOSE_1D:
1762
1925
  ggml_cann_conv_transpose_1d(ctx, dst);
1763
1926
  break;
1764
1927
  case GGML_OP_LOG:
1765
- GGML_CANN_CALL_UNARY_OP(Log);
1928
+ GGML_CANN_CALL_OP_UNARY(Log);
1766
1929
  break;
1767
1930
  case GGML_OP_MEAN:
1768
1931
  ggml_cann_mean(ctx, dst);
@@ -1895,6 +2058,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
1895
2058
  GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
1896
2059
  ggml_backend_is_cann(backend_dst));
1897
2060
 
2061
+ GGML_ASSERT(!is_matmul_weight((const ggml_tensor*)src));
2062
+
1898
2063
  if (!ggml_backend_buffer_is_cann(src->buffer) ||
1899
2064
  !ggml_backend_buffer_is_cann(dst->buffer)) {
1900
2065
  return false;
@@ -1911,7 +2076,14 @@ static bool ggml_backend_cann_cpy_tensor_async(
1911
2076
  (ggml_backend_cann_context*)backend_dst->context;
1912
2077
 
1913
2078
  size_t copy_size = ggml_nbytes(dst);
2079
+ if (copy_size == 0) {
2080
+ return true;
2081
+ }
1914
2082
  if (backend_src != backend_dst) {
2083
+ #ifdef ASCEND_310P
2084
+ // TODO: Support 310p P2P copy
2085
+ return false;
2086
+ #endif
1915
2087
  ggml_backend_cann_buffer_context* buf_ctx_src =
1916
2088
  (ggml_backend_cann_buffer_context*)buf_src->context;
1917
2089
  ggml_backend_cann_buffer_context* buf_ctx_dst =
@@ -1928,7 +2100,6 @@ static bool ggml_backend_cann_cpy_tensor_async(
1928
2100
  }
1929
2101
 
1930
2102
  // need open both directions for memcpyasync between devices.
1931
- ggml_cann_set_device(cann_ctx_dst->device);
1932
2103
  ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
1933
2104
  ggml_cann_set_device(cann_ctx_src->device);
1934
2105
  ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
@@ -1938,9 +2109,17 @@ static bool ggml_backend_cann_cpy_tensor_async(
1938
2109
  ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1939
2110
  ACL_MEMCPY_DEVICE_TO_DEVICE,
1940
2111
  cann_ctx_src->stream()));
1941
-
1942
- //TODO: workaround for Event didn`t work here.
1943
- aclrtSynchronizeStream(cann_ctx_src->stream());
2112
+ // record event on src stream after the copy
2113
+ // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream
2114
+ // if (!cann_ctx_src->copy_event) {
2115
+ // ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC));
2116
+ // }
2117
+ // ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));
2118
+
2119
+ // // wait on dst stream for the copy to complete
2120
+ // ggml_cann_set_device(cann_ctx_dst->device);
2121
+ // ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event));
2122
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream()));
1944
2123
  } else {
1945
2124
  // src and dst are on the same backend
1946
2125
  ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
@@ -1967,6 +2146,193 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
1967
2146
  ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1968
2147
  }
1969
2148
 
2149
+ #ifdef USE_ACL_GRAPH
2150
+ /**
2151
+ * @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph.
2152
+ *
2153
+ * This function creates a new ggml_cann_graph object and fills its node properties
2154
+ * (operation type, dimensions, strides, input sources, and operation parameters)
2155
+ * based on the current ggml computation graph.
2156
+ *
2157
+ * Each node in the ggml graph is mapped to a property entry in the new CANN graph:
2158
+ * - node address
2159
+ * - operation type
2160
+ * - shape (ne) and strides (nb)
2161
+ * - source tensor addresses
2162
+ * - operation parameters
2163
+ *
2164
+ * After initialization, the new graph is pushed into the LRU cache owned by the
2165
+ * CANN backend context. The cache takes ownership of the graph and manages its
2166
+ * lifetime (including deletion upon eviction).
2167
+ *
2168
+ * @param cann_ctx The CANN backend context containing the graph cache.
2169
+ * @param cgraph The current ggml computation graph.
2170
+ */
2171
+ static void add_lru_matched_graph_node_properties(
2172
+ ggml_backend_cann_context * cann_ctx,
2173
+ ggml_cgraph * cgraph) {
2174
+ // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache).
2175
+ ggml_cann_graph * new_graph = new ggml_cann_graph();
2176
+ new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2177
+
2178
+ for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
2179
+ ggml_tensor * node = cgraph->nodes[node_idx];
2180
+ auto & prop = new_graph->ggml_graph_properties[node_idx];
2181
+
2182
+ prop.node_address = node->data;
2183
+ prop.node_op = node->op;
2184
+
2185
+ std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
2186
+ std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
2187
+
2188
+ for (int src = 0; src < GGML_MAX_SRC; ++src) {
2189
+ prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
2190
+ }
2191
+
2192
+ memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
2193
+ }
2194
+
2195
+ // Insert into the LRU cache (cache takes ownership and will delete it when evicted).
2196
+ cann_ctx->graph_lru_cache.push(new_graph);
2197
+ }
2198
+
2199
+ /**
2200
+ * @brief Check if a ggml tensor node matches a previously captured CANN graph node.
2201
+ *
2202
+ * This function compares all relevant fields (address, op type, shape, source inputs, op params)
2203
+ * to determine whether the current node matches a previously recorded version.
2204
+ *
2205
+ * @param node The current ggml tensor node.
2206
+ * @param graph_node_properties The stored properties of a CANN graph node.
2207
+ * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
2208
+ */
2209
+ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
2210
+ if (node->data != graph_node_properties->node_address &&
2211
+ node->op != GGML_OP_VIEW) {
2212
+ return false;
2213
+ }
2214
+ if (node->op != graph_node_properties->node_op) {
2215
+ return false;
2216
+ }
2217
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
2218
+ if (node->ne[i] != graph_node_properties->ne[i]) {
2219
+ return false;
2220
+ }
2221
+ if (node->nb[i] != graph_node_properties->nb[i]) {
2222
+ return false;
2223
+ }
2224
+ }
2225
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
2226
+ if (node->src[i] &&
2227
+ node->src[i]->data != graph_node_properties->src_address[i] &&
2228
+ node->op != GGML_OP_VIEW
2229
+ ) {
2230
+ return false;
2231
+ }
2232
+ }
2233
+ if (node->op == GGML_OP_SCALE &&
2234
+ memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2235
+ return false;
2236
+ }
2237
+ return true;
2238
+ }
2239
+
2240
+ /**
2241
+ * @brief Check whether there is a cached CANN graph that matches the current ggml graph.
2242
+ *
2243
+ * This function iterates through the cached CANN graphs stored in the LRU cache and
2244
+ * compares them against the given ggml computation graph. A match requires that the
2245
+ * number of nodes is the same and that each node’s properties (operation type,
2246
+ * dimensions, strides, inputs, and operation parameters) are identical.
2247
+ *
2248
+ * If a matching graph is found, it is promoted to the front of the LRU cache and the
2249
+ * function returns true. Otherwise, the function returns false, indicating that a new
2250
+ * CANN graph needs to be captured.
2251
+ *
2252
+ * @param cann_ctx The CANN backend context containing the graph cache.
2253
+ * @param cgraph The current ggml computation graph.
2254
+ * @return true if a matching cached graph exists; false otherwise.
2255
+ */
2256
+ static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
2257
+ ggml_cann_graph_lru_cache &lru_cache = cann_ctx->graph_lru_cache;
2258
+ for (auto &graph_ptr : lru_cache.cache_list) {
2259
+ // Skip graphs with a different number of nodes.
2260
+ if (graph_ptr->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
2261
+ continue;
2262
+ }
2263
+
2264
+ // Check if all nodes match.
2265
+ bool all_match = true;
2266
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
2267
+ if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) {
2268
+ all_match = false;
2269
+ break;
2270
+ }
2271
+ }
2272
+
2273
+ if (all_match) {
2274
+ // update cache_list && return graph_ptr
2275
+ lru_cache.move_to_front(graph_ptr);
2276
+ return true;
2277
+ }
2278
+ }
2279
+
2280
+ return false;
2281
+ }
2282
+ #endif // USE_ACL_GRAPH
2283
+
2284
+ /**
2285
+ * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
2286
+ *
2287
+ * If CANN graph execution is enabled and graph capture is required, this function begins
2288
+ * graph capture, runs the graph, ends capture, and stores the captured graph.
2289
+ *
2290
+ * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
2291
+ *
2292
+ * @param cann_ctx The CANN backend context.
2293
+ * @param cgraph The ggml computation graph.
2294
+ * @param use_cann_graph Whether to use CANN graph execution.
2295
+ * @param cann_graph_update_required Whether graph capture is needed due to graph changes.
2296
+ */
2297
+ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
2298
+ bool & use_cann_graph, bool & cann_graph_update_required) {
2299
+ #ifdef USE_ACL_GRAPH
2300
+ ggml_cann_graph* matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
2301
+ if (use_cann_graph && cann_graph_update_required) {
2302
+ ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
2303
+ }
2304
+ #endif // USE_ACL_GRAPH
2305
+ // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
2306
+ // With the use of CANN graphs, the execution will be performed by the graph launch.
2307
+ if (!use_cann_graph || cann_graph_update_required) {
2308
+ for (int i = 0; i < cgraph->n_nodes; i++) {
2309
+ ggml_tensor * node = cgraph->nodes[i];
2310
+
2311
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2312
+ continue;
2313
+ }
2314
+
2315
+ bool ok = ggml_cann_compute_forward(*cann_ctx, node);
2316
+ if (!ok) {
2317
+ GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
2318
+ }
2319
+ GGML_ASSERT(ok);
2320
+ }
2321
+ }
2322
+
2323
+ #ifdef USE_ACL_GRAPH
2324
+ if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
2325
+ ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
2326
+ }
2327
+
2328
+ if (use_cann_graph) {
2329
+ // Execute graph
2330
+ ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
2331
+ }
2332
+ #endif // USE_ACL_GRAPH
2333
+ }
2334
+
2335
+
1970
2336
  /**
1971
2337
  * @brief Computes a computational graph using a CANN backend.
1972
2338
  *
@@ -1983,24 +2349,53 @@ static enum ggml_status ggml_backend_cann_graph_compute(
1983
2349
  ggml_backend_t backend, ggml_cgraph* cgraph) {
1984
2350
  ggml_backend_cann_context* cann_ctx =
1985
2351
  (ggml_backend_cann_context*)backend->context;
1986
-
1987
2352
  ggml_cann_set_device(cann_ctx->device);
1988
-
1989
- for (int i = 0; i < cgraph->n_nodes; i++) {
1990
- ggml_tensor* node = cgraph->nodes[i];
1991
-
1992
- if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
1993
- continue;
2353
+ g_nz_workspaces[cann_ctx->device].clear();
2354
+
2355
+ // calculate rope cache for the first layer on the current device.
2356
+ cann_ctx->rope_cache.cached = false;
2357
+
2358
+ #ifdef USE_ACL_GRAPH
2359
+ bool use_cann_graph = true;
2360
+ bool cann_graph_update_required = false;
2361
+
2362
+ static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
2363
+ if (!prefill_use_graph) {
2364
+ // Do not use acl_graph for prefill.
2365
+ for (int i = 0; i < cgraph->n_nodes; i++) {
2366
+ ggml_tensor * node = cgraph->nodes[i];
2367
+ // TODO: Optimize here. Currently, we can only
2368
+ // get seq_len from FA's input.
2369
+ if (node->op == GGML_OP_FLASH_ATTN_EXT) {
2370
+ // Q -> src[0], shape: [B, S, N, D]
2371
+ use_cann_graph = (node->src[0]->ne[1] == 1);
2372
+ break;
2373
+ }
1994
2374
  }
2375
+ }
1995
2376
 
1996
- bool ok = ggml_cann_compute_forward(*cann_ctx, node);
2377
+ if (!cann_ctx->acl_graph_mode) {
2378
+ use_cann_graph = false;
2379
+ }
1997
2380
 
1998
- if (!ok) {
1999
- GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
2000
- node->name, ggml_op_name(node->op));
2381
+ if (use_cann_graph) {
2382
+ // If no matching graph is found, the graph needs to be recaptured.
2383
+ cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph);
2384
+ if (cann_graph_update_required) {
2385
+ // If no matching graph is found, add a new ACL graph.
2386
+ add_lru_matched_graph_node_properties(cann_ctx, cgraph);
2001
2387
  }
2002
- GGML_ASSERT(ok);
2003
2388
  }
2389
+ #else
2390
+ bool use_cann_graph = false;
2391
+ bool cann_graph_update_required = false;
2392
+ #endif // USE_ACL_GRAPH
2393
+ evaluate_and_capture_cann_graph(
2394
+ cann_ctx,
2395
+ cgraph,
2396
+ use_cann_graph,
2397
+ cann_graph_update_required
2398
+ );
2004
2399
 
2005
2400
  return GGML_STATUS_SUCCESS;
2006
2401
  }
@@ -2036,10 +2431,23 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2036
2431
  case GGML_UNARY_OP_ELU:
2037
2432
  case GGML_UNARY_OP_SGN:
2038
2433
  case GGML_UNARY_OP_STEP:
2434
+ case GGML_UNARY_OP_GELU_ERF:
2435
+ return true;
2436
+ default:
2437
+ return false;
2438
+ }
2439
+ case GGML_OP_GLU:
2440
+ switch (ggml_get_glu_op(op)) {
2441
+ case GGML_GLU_OP_REGLU:
2442
+ case GGML_GLU_OP_GEGLU:
2443
+ case GGML_GLU_OP_SWIGLU:
2444
+ case GGML_GLU_OP_GEGLU_ERF:
2445
+ case GGML_GLU_OP_GEGLU_QUICK:
2039
2446
  return true;
2040
2447
  default:
2041
2448
  return false;
2042
2449
  }
2450
+ break;
2043
2451
  case GGML_OP_MUL_MAT: {
2044
2452
  switch (op->src[0]->type) {
2045
2453
  case GGML_TYPE_F16:
@@ -2048,7 +2456,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2048
2456
  case GGML_TYPE_Q8_0:
2049
2457
  case GGML_TYPE_Q4_0:
2050
2458
  #ifdef ASCEND_310P
2051
- // Q4 && Q8 per group is not suppor on 310p device
2459
+ // Q4 && Q8 per group is not supported on the 310p device
2052
2460
  return false;
2053
2461
  #endif
2054
2462
  // only support contiguous for quantized types.
@@ -2066,7 +2474,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2066
2474
  case GGML_TYPE_Q8_0:
2067
2475
  case GGML_TYPE_Q4_0:
2068
2476
  #ifdef ASCEND_310P
2069
- // Q4 && Q8 per group is not suppor on 310p device
2477
+ // Q4 && Q8 per group is not support on 310p device
2070
2478
  return false;
2071
2479
  #endif
2072
2480
  // only support contiguous for quantized types.
@@ -2086,6 +2494,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2086
2494
  return false;
2087
2495
  }
2088
2496
  } break;
2497
+ case GGML_OP_SET_ROWS: {
2498
+ switch (op->type) {
2499
+ case GGML_TYPE_F32:
2500
+ case GGML_TYPE_F16:
2501
+ return true;
2502
+ default:
2503
+ return false;
2504
+ }
2505
+ } break;
2089
2506
  case GGML_OP_CPY: {
2090
2507
  ggml_tensor *src = op->src[0];
2091
2508
  if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
@@ -2094,12 +2511,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2094
2511
  // only support F32 and F16.
2095
2512
  return false;
2096
2513
  }
2097
-
2098
- if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
2099
- // unsupport dst is not contiguous.
2100
- return false;
2101
- }
2102
-
2103
2514
  return true;
2104
2515
  } break;
2105
2516
  case GGML_OP_CONT: {
@@ -2114,16 +2525,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2114
2525
  }
2115
2526
  case GGML_OP_ROPE: {
2116
2527
  // TODO: with ops-test v == 1
2117
- float ext_factor = 0.0f;
2118
- memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
2119
2528
  // TODO: n_dims <= ne0
2120
2529
  if (op->src[0]->ne[0] != op->op_params[1]) {
2121
2530
  return false;
2122
2531
  }
2123
- // TODO: ext_factor != 0
2124
- if (ext_factor != 0) {
2125
- return false;
2126
- }
2127
2532
 
2128
2533
  const int mode = ((const int32_t *) op->op_params)[2];
2129
2534
  if (mode & GGML_ROPE_TYPE_MROPE) {
@@ -2132,10 +2537,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2132
2537
  if (mode & GGML_ROPE_TYPE_VISION) {
2133
2538
  return false;
2134
2539
  }
2135
-
2540
+ #ifdef ASCEND_310P
2136
2541
  if(!ggml_is_contiguous(op->src[0])){
2137
2542
  return false;
2138
2543
  }
2544
+ #endif
2139
2545
  return true;
2140
2546
  }
2141
2547
  case GGML_OP_UPSCALE: {
@@ -2165,8 +2571,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2165
2571
  // value of paddingW should be at most half of kernelW
2166
2572
  return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
2167
2573
  }
2168
- case GGML_OP_SUM:
2169
2574
  case GGML_OP_DUP:
2575
+ case GGML_OP_SUM:
2170
2576
  case GGML_OP_IM2COL:
2171
2577
  case GGML_OP_CONCAT:
2172
2578
  case GGML_OP_REPEAT:
@@ -2182,12 +2588,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2182
2588
  case GGML_OP_MUL:
2183
2589
  case GGML_OP_DIV:
2184
2590
  case GGML_OP_RMS_NORM:
2185
- case GGML_OP_SCALE:
2186
2591
  case GGML_OP_SQR:
2187
2592
  case GGML_OP_SQRT:
2188
2593
  case GGML_OP_CLAMP:
2189
2594
  case GGML_OP_DIAG_MASK_INF:
2190
- case GGML_OP_SOFT_MAX:
2191
2595
  case GGML_OP_SUM_ROWS:
2192
2596
  case GGML_OP_ARGSORT:
2193
2597
  case GGML_OP_ACC:
@@ -2199,13 +2603,29 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2199
2603
  case GGML_OP_ARGMAX:
2200
2604
  case GGML_OP_COS:
2201
2605
  case GGML_OP_SIN:
2202
- case GGML_OP_CONV_TRANSPOSE_1D:
2203
2606
  case GGML_OP_LOG:
2204
2607
  case GGML_OP_MEAN:
2205
2608
  case GGML_OP_PAD_REFLECT_1D:
2206
2609
  case GGML_OP_COUNT_EQUAL:
2207
2610
  return true;
2611
+ case GGML_OP_CONV_TRANSPOSE_1D:
2612
+ // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
2613
+ return (op->src[0]->ne[0] - 1) <= 255;
2614
+ case GGML_OP_SCALE:
2615
+ float bias;
2616
+ memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float));
2617
+ return bias == 0.0f; // TODO: support bias != 0.0f
2618
+ case GGML_OP_SOFT_MAX:
2619
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
2620
+ if (op->src[2]) {
2621
+ return false;
2622
+ }
2623
+ return true;
2208
2624
  case GGML_OP_FLASH_ATTN_EXT:{
2625
+ #ifdef ASCEND_310P
2626
+ // FA is not supported on the 310p device
2627
+ return false;
2628
+ #endif
2209
2629
  // derived from [ggml-cuda.cu]
2210
2630
  if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
2211
2631
  return false;
@@ -2216,22 +2636,20 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
2216
2636
  if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
2217
2637
  return false;
2218
2638
  }
2219
- if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
2220
- // different head sizes of K and V are not supported yet
2221
- return false;
2222
- }
2223
- if (op->src[0]->ne[0] == 192) {
2639
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
2640
+ if (op->src[4]) {
2224
2641
  return false;
2225
2642
  }
2226
- if (op->src[0]->ne[0] == 576) {
2227
- // DeepSeek MLA
2643
+ if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
2644
+ // different head sizes of K and V are not supported yet
2228
2645
  return false;
2229
2646
  }
2230
- if (op->src[0]->ne[3] != 1) {
2647
+ if (op->src[0]->ne[0] % 16 != 0) {
2648
+ // TODO: padding to support
2231
2649
  return false;
2232
2650
  }
2233
2651
  float logitSoftcap = 0.0f;
2234
- memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
2652
+ memcpy(&logitSoftcap, (const float *)(op->op_params) + 2, sizeof(float));
2235
2653
  if(logitSoftcap != 0.0f) {
2236
2654
  return false;
2237
2655
  }
@@ -2338,6 +2756,7 @@ static const ggml_backend_i ggml_backend_cann_interface = {
2338
2756
  /* .graph_compute = */ ggml_backend_cann_graph_compute,
2339
2757
  /* .event_record = */ ggml_backend_cann_event_record,
2340
2758
  /* .event_wait = */ ggml_backend_cann_event_wait,
2759
+ /* .graph_optimize = */ NULL,
2341
2760
  };
2342
2761
 
2343
2762
  /**