@fugood/llama.node 0.2.2 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (320)
  1. package/CMakeLists.txt +5 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +8 -1
  17. package/package.json +1 -1
  18. package/patches/llama.patch +12 -12
  19. package/src/DetokenizeWorker.cpp +1 -1
  20. package/src/LlamaContext.cpp +33 -1
  21. package/src/LlamaContext.h +1 -0
  22. package/src/LoadSessionWorker.cpp +1 -0
  23. package/src/llama.cpp/.github/workflows/bench.yml +310 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +1315 -0
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +23 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +116 -0
  27. package/src/llama.cpp/.github/workflows/editorconfig.yml +27 -0
  28. package/src/llama.cpp/.github/workflows/gguf-publish.yml +44 -0
  29. package/src/llama.cpp/.github/workflows/labeler.yml +17 -0
  30. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +65 -0
  31. package/src/llama.cpp/.github/workflows/nix-ci.yml +72 -0
  32. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +22 -0
  33. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +36 -0
  34. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +35 -0
  35. package/src/llama.cpp/.github/workflows/python-lint.yml +23 -0
  36. package/src/llama.cpp/.github/workflows/python-type-check.yml +38 -0
  37. package/src/llama.cpp/.github/workflows/server.yml +183 -0
  38. package/src/llama.cpp/CMakeLists.txt +91 -1245
  39. package/src/llama.cpp/cmake/arm64-windows-llvm.cmake +1 -1
  40. package/src/llama.cpp/cmake/build-info.cmake +58 -0
  41. package/src/llama.cpp/cmake/git-vars.cmake +22 -0
  42. package/src/llama.cpp/common/CMakeLists.txt +4 -3
  43. package/src/llama.cpp/common/build-info.cpp.in +4 -0
  44. package/src/llama.cpp/common/common.cpp +1116 -877
  45. package/src/llama.cpp/common/common.h +191 -77
  46. package/src/llama.cpp/common/grammar-parser.cpp +118 -31
  47. package/src/llama.cpp/common/json-schema-to-grammar.cpp +346 -65
  48. package/src/llama.cpp/common/log.h +1 -1
  49. package/src/llama.cpp/common/ngram-cache.h +10 -3
  50. package/src/llama.cpp/common/sampling.cpp +19 -10
  51. package/src/llama.cpp/docs/build.md +353 -0
  52. package/src/llama.cpp/examples/CMakeLists.txt +22 -22
  53. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +6 -6
  55. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/batched/batched.cpp +52 -55
  57. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +20 -72
  59. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/chat-13B.bat +57 -0
  61. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/{finetune → cvector-generator}/CMakeLists.txt +2 -2
  63. package/src/llama.cpp/examples/cvector-generator/completions.txt +582 -0
  64. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +503 -0
  65. package/src/llama.cpp/examples/cvector-generator/mean.hpp +48 -0
  66. package/src/llama.cpp/examples/cvector-generator/negative.txt +4 -0
  67. package/src/llama.cpp/examples/cvector-generator/pca.hpp +325 -0
  68. package/src/llama.cpp/examples/cvector-generator/positive.txt +4 -0
  69. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +35 -0
  70. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  71. package/src/llama.cpp/examples/embedding/embedding.cpp +94 -46
  72. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +2 -2
  73. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +4 -6
  74. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/export-lora/export-lora.cpp +344 -386
  76. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +2 -2
  77. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +30 -25
  78. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/gguf/gguf.cpp +5 -0
  80. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +15 -0
  81. package/src/llama.cpp/examples/gguf-hash/deps/rotate-bits/rotate-bits.h +46 -0
  82. package/src/llama.cpp/examples/gguf-hash/deps/sha1/sha1.c +295 -0
  83. package/src/llama.cpp/examples/gguf-hash/deps/sha1/sha1.h +52 -0
  84. package/src/llama.cpp/examples/gguf-hash/deps/sha256/sha256.c +221 -0
  85. package/src/llama.cpp/examples/gguf-hash/deps/sha256/sha256.h +24 -0
  86. package/src/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.c +42 -0
  87. package/src/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.h +7093 -0
  88. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +693 -0
  89. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  90. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +3 -3
  91. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  92. package/src/llama.cpp/examples/gritlm/gritlm.cpp +6 -2
  93. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  94. package/src/llama.cpp/examples/imatrix/imatrix.cpp +137 -176
  95. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/infill/infill.cpp +38 -153
  97. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +175 -94
  98. package/src/llama.cpp/examples/llama.android/app/build.gradle.kts +65 -0
  99. package/src/llama.cpp/examples/llama.android/build.gradle.kts +6 -0
  100. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +68 -0
  101. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +11 -7
  102. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +2 -2
  103. package/src/llama.cpp/examples/llama.android/settings.gradle.kts +18 -0
  104. package/src/llama.cpp/examples/llava/CMakeLists.txt +6 -5
  105. package/src/llama.cpp/examples/llava/android/build_64.sh +8 -0
  106. package/src/llama.cpp/examples/llava/clip.cpp +23 -14
  107. package/src/llama.cpp/examples/llava/llava-cli.cpp +8 -6
  108. package/src/llama.cpp/examples/llava/requirements.txt +3 -2
  109. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  110. package/src/llama.cpp/examples/lookahead/lookahead.cpp +2 -1
  111. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  112. package/src/llama.cpp/examples/lookup/lookup-create.cpp +2 -0
  113. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  114. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -2
  115. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  116. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  117. package/src/llama.cpp/examples/main/main.cpp +98 -75
  118. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +4 -5
  119. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  120. package/src/llama.cpp/examples/parallel/parallel.cpp +2 -1
  121. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  122. package/src/llama.cpp/examples/passkey/passkey.cpp +23 -43
  123. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  124. package/src/llama.cpp/examples/perplexity/perplexity.cpp +13 -10
  125. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  126. package/src/llama.cpp/examples/quantize/quantize.cpp +37 -34
  127. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +1 -1
  129. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  130. package/src/llama.cpp/examples/retrieval/retrieval.cpp +26 -77
  131. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +14 -7
  133. package/src/llama.cpp/examples/server/CMakeLists.txt +26 -2
  134. package/src/llama.cpp/examples/server/server.cpp +274 -671
  135. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  136. package/src/llama.cpp/examples/server/utils.hpp +28 -29
  137. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  138. package/src/llama.cpp/examples/simple/simple.cpp +21 -29
  139. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  140. package/src/llama.cpp/examples/speculative/speculative.cpp +2 -1
  141. package/src/llama.cpp/examples/sycl/CMakeLists.txt +1 -1
  142. package/src/llama.cpp/examples/sycl/build.sh +23 -0
  143. package/src/llama.cpp/examples/sycl/run-llama2.sh +36 -0
  144. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +33 -0
  145. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +9 -0
  146. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  147. package/src/llama.cpp/examples/tokenize/tokenize.cpp +16 -2
  148. package/src/llama.cpp/ggml/CMakeLists.txt +253 -0
  149. package/src/llama.cpp/{cmake → ggml/cmake}/FindSIMD.cmake +6 -6
  150. package/src/llama.cpp/{ggml-backend.h → ggml/include/ggml-backend.h} +22 -17
  151. package/src/llama.cpp/ggml/include/ggml-blas.h +23 -0
  152. package/src/llama.cpp/ggml/include/ggml-cann.h +125 -0
  153. package/src/llama.cpp/{ggml-cuda.h → ggml/include/ggml-cuda.h} +3 -0
  154. package/src/llama.cpp/{ggml-metal.h → ggml/include/ggml-metal.h} +1 -2
  155. package/src/llama.cpp/{ggml-sycl.h → ggml/include/ggml-sycl.h} +3 -10
  156. package/src/llama.cpp/{ggml.h → ggml/include/ggml.h} +80 -85
  157. package/src/llama.cpp/ggml/src/CMakeLists.txt +1329 -0
  158. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2193 -0
  159. package/src/llama.cpp/ggml/src/ggml-aarch64.h +39 -0
  160. package/src/llama.cpp/{ggml-alloc.c → ggml/src/ggml-alloc.c} +100 -49
  161. package/src/llama.cpp/{ggml-backend-impl.h → ggml/src/ggml-backend-impl.h} +20 -8
  162. package/src/llama.cpp/{ggml-backend.c → ggml/src/ggml-backend.c} +307 -167
  163. package/src/llama.cpp/ggml/src/ggml-blas.cpp +367 -0
  164. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +198 -0
  165. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +230 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +2944 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +592 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/common.h +282 -0
  169. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +32 -0
  170. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +17 -0
  171. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +223 -0
  172. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +186 -0
  173. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +180 -0
  174. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +193 -0
  175. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +191 -0
  176. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +208 -0
  177. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +206 -0
  178. package/src/llama.cpp/ggml/src/ggml-cann.cpp +2023 -0
  179. package/src/llama.cpp/{ggml-common.h → ggml/src/ggml-common.h} +41 -7
  180. package/src/llama.cpp/{ggml-impl.h → ggml/src/ggml-impl.h} +113 -9
  181. package/src/llama.cpp/{ggml-kompute.cpp → ggml/src/ggml-kompute.cpp} +33 -18
  182. package/src/llama.cpp/{ggml-quants.c → ggml/src/ggml-quants.c} +1460 -940
  183. package/src/llama.cpp/{ggml-quants.h → ggml/src/ggml-quants.h} +19 -20
  184. package/src/llama.cpp/{ggml-rpc.cpp → ggml/src/ggml-rpc.cpp} +95 -72
  185. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +27 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +53 -0
  187. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +355 -0
  188. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +195 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +21 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +547 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +27 -0
  192. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +698 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +1023 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.hpp +27 -0
  195. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +3011 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +3031 -0
  197. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.hpp +33 -0
  198. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1027 -0
  199. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.hpp +27 -0
  200. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +374 -0
  201. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +35 -0
  202. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +66 -0
  203. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +275 -0
  204. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +22 -0
  205. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +251 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.hpp +24 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +1140 -0
  208. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +5314 -0
  209. package/src/llama.cpp/{ggml-vulkan.cpp → ggml/src/ggml-vulkan.cpp} +1781 -1868
  210. package/src/llama.cpp/{ggml.c → ggml/src/ggml.c} +1245 -2087
  211. package/src/llama.cpp/{sgemm.cpp → ggml/src/llamafile/sgemm.cpp} +21 -24
  212. package/src/llama.cpp/{sgemm.h → ggml/src/llamafile/sgemm.h} +1 -1
  213. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +5 -0
  214. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +552 -0
  215. package/src/llama.cpp/{llama.h → include/llama.h} +175 -100
  216. package/src/llama.cpp/models/.editorconfig +1 -0
  217. package/src/llama.cpp/models/ggml-vocab-aquila.gguf +0 -0
  218. package/src/llama.cpp/models/ggml-vocab-baichuan.gguf +0 -0
  219. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf +0 -0
  220. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf.inp +112 -0
  221. package/src/llama.cpp/models/ggml-vocab-bert-bge.gguf.out +46 -0
  222. package/src/llama.cpp/models/ggml-vocab-command-r.gguf +0 -0
  223. package/src/llama.cpp/models/ggml-vocab-command-r.gguf.inp +112 -0
  224. package/src/llama.cpp/models/ggml-vocab-command-r.gguf.out +46 -0
  225. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf +0 -0
  226. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.inp +112 -0
  227. package/src/llama.cpp/models/ggml-vocab-deepseek-coder.gguf.out +46 -0
  228. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf +0 -0
  229. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.inp +112 -0
  230. package/src/llama.cpp/models/ggml-vocab-deepseek-llm.gguf.out +46 -0
  231. package/src/llama.cpp/models/ggml-vocab-falcon.gguf +0 -0
  232. package/src/llama.cpp/models/ggml-vocab-falcon.gguf.inp +112 -0
  233. package/src/llama.cpp/models/ggml-vocab-falcon.gguf.out +46 -0
  234. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf +0 -0
  235. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf.inp +112 -0
  236. package/src/llama.cpp/models/ggml-vocab-gpt-2.gguf.out +46 -0
  237. package/src/llama.cpp/models/ggml-vocab-gpt-neox.gguf +0 -0
  238. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf +0 -0
  239. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-llama-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf +0 -0
  242. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf.inp +112 -0
  243. package/src/llama.cpp/models/ggml-vocab-llama-spm.gguf.out +46 -0
  244. package/src/llama.cpp/models/ggml-vocab-mpt.gguf +0 -0
  245. package/src/llama.cpp/models/ggml-vocab-mpt.gguf.inp +112 -0
  246. package/src/llama.cpp/models/ggml-vocab-mpt.gguf.out +46 -0
  247. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf +0 -0
  248. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf.inp +112 -0
  249. package/src/llama.cpp/models/ggml-vocab-phi-3.gguf.out +46 -0
  250. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf +0 -0
  251. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf.inp +112 -0
  252. package/src/llama.cpp/models/ggml-vocab-qwen2.gguf.out +46 -0
  253. package/src/llama.cpp/models/ggml-vocab-refact.gguf +0 -0
  254. package/src/llama.cpp/models/ggml-vocab-refact.gguf.inp +112 -0
  255. package/src/llama.cpp/models/ggml-vocab-refact.gguf.out +46 -0
  256. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf +0 -0
  257. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf.inp +112 -0
  258. package/src/llama.cpp/models/ggml-vocab-starcoder.gguf.out +46 -0
  259. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  260. package/src/llama.cpp/requirements/requirements-all.txt +12 -0
  261. package/src/llama.cpp/requirements/requirements-compare-llama-bench.txt +2 -0
  262. package/src/llama.cpp/requirements/requirements-convert_hf_to_gguf.txt +3 -0
  263. package/src/llama.cpp/requirements/requirements-convert_hf_to_gguf_update.txt +3 -0
  264. package/src/llama.cpp/requirements/{requirements-convert.txt → requirements-convert_legacy_llama.txt} +1 -1
  265. package/src/llama.cpp/requirements/requirements-convert_llama_ggml_to_gguf.txt +1 -0
  266. package/src/llama.cpp/requirements/requirements-convert_lora_to_gguf.txt +2 -0
  267. package/src/llama.cpp/requirements/requirements-pydantic.txt +3 -0
  268. package/src/llama.cpp/requirements/requirements-test-tokenizer-random.txt +1 -0
  269. package/src/llama.cpp/requirements.txt +5 -4
  270. package/src/llama.cpp/scripts/build-info.sh +30 -0
  271. package/src/llama.cpp/scripts/install-oneapi.bat +19 -0
  272. package/src/llama.cpp/src/CMakeLists.txt +33 -0
  273. package/src/llama.cpp/src/llama-grammar.cpp +539 -0
  274. package/src/llama.cpp/src/llama-grammar.h +39 -0
  275. package/src/llama.cpp/src/llama-impl.h +26 -0
  276. package/src/llama.cpp/src/llama-sampling.cpp +635 -0
  277. package/src/llama.cpp/src/llama-sampling.h +56 -0
  278. package/src/llama.cpp/src/llama-vocab.cpp +1721 -0
  279. package/src/llama.cpp/src/llama-vocab.h +130 -0
  280. package/src/llama.cpp/{llama.cpp → src/llama.cpp} +5979 -5260
  281. package/src/llama.cpp/{unicode-data.cpp → src/unicode-data.cpp} +851 -802
  282. package/src/llama.cpp/{unicode.cpp → src/unicode.cpp} +52 -30
  283. package/src/llama.cpp/{unicode.h → src/unicode.h} +5 -1
  284. package/src/llama.cpp/tests/CMakeLists.txt +19 -20
  285. package/src/llama.cpp/tests/test-backend-ops.cpp +245 -67
  286. package/src/llama.cpp/tests/test-chat-template.cpp +57 -3
  287. package/src/llama.cpp/tests/test-double-float.cpp +2 -2
  288. package/src/llama.cpp/tests/test-grad0.cpp +2 -2
  289. package/src/llama.cpp/tests/test-grammar-integration.cpp +978 -31
  290. package/src/llama.cpp/tests/test-grammar-parser.cpp +423 -158
  291. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +508 -135
  292. package/src/llama.cpp/tests/test-llama-grammar.cpp +15 -9
  293. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -1
  294. package/src/llama.cpp/tests/test-quantize-perf.cpp +1 -1
  295. package/src/llama.cpp/tests/test-rope.cpp +3 -4
  296. package/src/llama.cpp/tests/test-sampling.cpp +5 -5
  297. package/src/llama.cpp/tests/test-tokenizer-0.cpp +6 -6
  298. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +20 -15
  299. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +22 -11
  300. package/bin/darwin/arm64/default.metallib +0 -0
  301. package/bin/darwin/x64/default.metallib +0 -0
  302. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +0 -5
  303. package/src/llama.cpp/examples/beam-search/beam-search.cpp +0 -188
  304. package/src/llama.cpp/examples/finetune/finetune.cpp +0 -1862
  305. package/src/llama.cpp/examples/llama.android/llama/CMakeLists.txt +0 -55
  306. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +0 -5
  307. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +0 -1253
  308. package/src/llama.cpp/ggml-opencl.cpp +0 -2305
  309. package/src/llama.cpp/ggml-opencl.h +0 -36
  310. package/src/llama.cpp/ggml-sycl.cpp +0 -17340
  311. package/src/llama.cpp/ggml-vulkan-shaders.hpp +0 -81211
  312. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf-update.txt +0 -2
  313. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +0 -2
  314. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +0 -1
  315. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +0 -24
  316. /package/src/llama.cpp/{ggml-alloc.h → ggml/include/ggml-alloc.h} +0 -0
  317. /package/src/llama.cpp/{ggml-kompute.h → ggml/include/ggml-kompute.h} +0 -0
  318. /package/src/llama.cpp/{ggml-rpc.h → ggml/include/ggml-rpc.h} +0 -0
  319. /package/src/llama.cpp/{ggml-vulkan.h → ggml/include/ggml-vulkan.h} +0 -0
  320. /package/src/llama.cpp/{unicode-data.h → src/unicode-data.h} +0 -0
@@ -0,0 +1,2023 @@
1
+ /*
2
+ * Copyright (c) 2023-2024 The ggml authors
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ * of this software and associated documentation files (the "Software"), to
6
+ * deal in the Software without restriction, including without limitation the
7
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8
+ * sell copies of the Software, and to permit persons to whom the Software is
9
+ * furnished to do so, subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in
12
+ * all copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20
+ * IN THE SOFTWARE.
21
+ */
22
+
23
+ #include "ggml-cann.h"
24
+
25
+ #include <acl/acl.h>
26
+ #include <stdarg.h>
27
+
28
+ #include <cmath>
29
+ #include <cstdio>
30
+ #include <cstring>
31
+ #include <mutex>
32
+
33
+ #include "ggml-backend-impl.h"
34
+ #include "ggml-cann/aclnn_ops.h"
35
+ #include "ggml-cann/common.h"
36
+
37
+ #define GGML_COMMON_DECL_C
38
+
39
+ #include "ggml-common.h"
40
+
41
+ /**
42
+ * @brief Default logging callback for GGML.
43
+ *
44
+ * This function is the default logging callback that logs messages to stderr.
45
+ *
46
+ * @param level The log level.
47
+ * @param msg The log message.
48
+ * @param user_data User data passed to the callback.
49
+ */
50
+ static void ggml_cann_default_log_callback(enum ggml_log_level level,
51
+ const char* msg, void* user_data) {
52
+ GGML_UNUSED(level);
53
+ GGML_UNUSED(user_data);
54
+ fprintf(stderr, "%s", msg);
55
+ }
56
+
57
// Currently installed CANN logging callback; defaults to writing to stderr.
ggml_log_callback ggml_cann_log_callback = ggml_cann_default_log_callback;
// Opaque pointer forwarded to the callback on every log call.
void* ggml_cann_log_user_data = NULL;

/**
 * @brief Install a custom logging callback for the CANN backend.
 *
 * @param log_callback The callback to invoke for each log message.
 * @param user_data Opaque pointer passed through to the callback unchanged.
 */
GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback,
                                                 void* user_data) {
    ggml_cann_log_callback = log_callback;
    ggml_cann_log_user_data = user_data;
}
65
+
66
+ #define GGML_CANN_LOG_INFO(...) ggml_cann_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
67
+ #define GGML_CANN_LOG_WARN(...) ggml_cann_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
68
+ #define GGML_CANN_LOG_ERROR(...) \
69
+ ggml_cann_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
70
+
71
+ GGML_ATTRIBUTE_FORMAT(2, 3)
72
+
73
+ /**
74
+ * @brief Log a message using the current logging callback.
75
+ *
76
+ * This function formats a log message and passes it to the current logging
77
+ * callback.
78
+ *
79
+ * @param level The log level.
80
+ * @param format The format string for the log message.
81
+ * @param ... The arguments for the format string.
82
+ */
83
+ static void ggml_cann_log(enum ggml_log_level level, const char* format, ...) {
84
+ if (ggml_cann_log_callback != NULL) {
85
+ va_list args;
86
+ va_start(args, format);
87
+ char buffer[128];
88
+ int len = vsnprintf(buffer, 128, format, args);
89
+ if (len < 128) {
90
+ ggml_cann_log_callback(level, buffer, ggml_cann_log_user_data);
91
+ } else {
92
+ // vsnprintf adds a null terminator
93
+ std::vector<char> buffer2(len + 1);
94
+ va_end(args);
95
+ va_start(args, format);
96
+ vsnprintf(&buffer2[0], buffer2.size(), format, args);
97
+ ggml_cann_log_callback(level, buffer2.data(),
98
+ ggml_cann_log_user_data);
99
+ }
100
+ va_end(args);
101
+ }
102
+ }
103
+
104
+ /**
105
+ * @brief Handles CANN errors by printing an error message and aborting.
106
+ *
107
+ * @param stmt The statement that caused the error.
108
+ * @param func The function in which the error occurred.
109
+ * @param file The file in which the error occurred.
110
+ * @param line The line number where the error occurred.
111
+ * @param msg The error message.
112
+ */
113
+ [[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
114
+ const char* file, int line, const char* msg) {
115
+ int32_t id = -1;
116
+ aclrtGetDevice(&id);
117
+
118
+ GGML_CANN_LOG_ERROR("CANN error: %s\n", msg);
119
+ GGML_CANN_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
120
+ file, line);
121
+ GGML_CANN_LOG_ERROR(" %s\n", stmt);
122
+ // abort with GGML_ASSERT to get a stack trace
123
+ GGML_ABORT("CANN error");
124
+ }
125
+
126
/**
 * @brief Sets the device to be used by CANN.
 *
 * Calls aclrtSetDevice unconditionally; the early-out for the case where the
 * requested device is already current is intentionally disabled (see the TODO
 * below).
 *
 * @param device The device ID to set.
 */
void ggml_cann_set_device(const int32_t device) {
    // TODO: uncomment these lines after empty context has fixed.
    // int current_device;
    // ACL_CHECK(aclrtGetDevice(&current_device));

    // if (device == current_device) {
    //     return;
    // }
    ACL_CHECK(aclrtSetDevice(device));
}
141
+
142
+ /**
143
+ * @brief Retrieves the current device ID.
144
+ *
145
+ * @return The current device ID.
146
+ */
147
+ int32_t ggml_cann_get_device() {
148
+ int32_t id;
149
+ ACL_CHECK(aclrtGetDevice(&id));
150
+ return id;
151
+ }
152
+
153
+ /**
154
+ * @brief Initialize the CANN device information.
155
+ *
156
+ * This function initializes the CANN device information by obtaining the
157
+ * device count and setting the memory allocation granularity for each device.
158
+ *
159
+ * @return A structure containing the device information.
160
+ */
161
+ static ggml_cann_device_info ggml_cann_init() {
162
+ ggml_cann_device_info info = {};
163
+
164
+ aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
165
+
166
+ if (err != ACL_SUCCESS) {
167
+ GGML_CANN_LOG_ERROR("%s: failed to initialize CANN: %s\n",
168
+ __func__, aclGetRecentErrMsg());
169
+ return info;
170
+ }
171
+
172
+ GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
173
+
174
+ for (int id = 0; id < info.device_count; ++id) {
175
+ aclrtPhysicalMemProp prop = {};
176
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
177
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
178
+ prop.memAttr = ACL_HBM_MEM_HUGE;
179
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
180
+ prop.location.id = id;
181
+ prop.reserve = 0;
182
+ ACL_CHECK(aclrtMemGetAllocationGranularity(
183
+ &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
184
+ &info.devices[id].vmm_granularity));
185
+ }
186
+
187
+ // TODO: add more device info later.
188
+ return info;
189
+ }
190
+
191
/**
 * @brief Retrieve the CANN device information.
 *
 * This function returns a reference to a structure containing the CANN device
 * information. The device information is initialized exactly once, on the
 * first call, and reused on subsequent calls.
 *
 * @return A reference to the structure containing the device information.
 */
const ggml_cann_device_info& ggml_cann_info() {
    // Function-local static: initialized on first use, thread-safe in C++11.
    static ggml_cann_device_info info = ggml_cann_init();
    return info;
}
204
+
205
+ //#define DEBUG_CANN_MALLOC
206
+ /**
207
+ * @brief A pool of CANN buffers(legacy).
208
+ *
209
+ * This class manages a pool of CANN buffers for a specific device.
210
+ */
211
+ struct ggml_cann_pool_leg : public ggml_cann_pool {
212
+ /**
213
+ * @brief The maximum number of buffers in the pool.
214
+ */
215
+ static const int MAX_BUFFERS = 256;
216
+
217
+ /**
218
+ * @brief The device ID associated with this buffer pool.
219
+ */
220
+ int device;
221
+
222
+ /**
223
+ * @brief Structure representing a CANN buffer.
224
+ */
225
+ struct ggml_cann_buffer {
226
+ void* ptr = nullptr; ///< Pointer to the buffer memory.
227
+ size_t size = 0; ///< Size of the buffer.
228
+ };
229
+
230
+ /**
231
+ * @brief Array of CANN buffers in the pool.
232
+ */
233
+ ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
234
+
235
+ /**
236
+ * @brief Total size of all buffers in the pool.
237
+ */
238
+ size_t pool_size = 0;
239
+
240
+ /**
241
+ * @brief Constructor to initialize the buffer pool for a specific device.
242
+ *
243
+ * @param device The device ID to associate with this buffer pool.
244
+ */
245
+ explicit ggml_cann_pool_leg(int device) : device(device) {}
246
+
247
+ /**
248
+ * @brief Destructor to free all buffers in the pool.
249
+ */
250
+ ~ggml_cann_pool_leg() {
251
+ ggml_cann_set_device(device);
252
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
253
+ ggml_cann_buffer& b = buffer_pool[i];
254
+ if (b.ptr != nullptr) {
255
+ ACL_CHECK(aclrtFree(b.ptr));
256
+ pool_size -= b.size;
257
+ }
258
+ }
259
+ GGML_ASSERT(pool_size == 0);
260
+ }
261
+
262
    /**
     * @brief Allocate a buffer of the given size.
     *
     * First performs a best-fit scan over the cached buffers (smallest buffer
     * that is still >= size wins; an exact fit returns immediately). Only when
     * no cached buffer is large enough does it fall back to a fresh
     * aclrtMalloc, over-allocating slightly so follow-up requests of similar
     * size can reuse the buffer.
     *
     * @param size The size of the buffer to allocate.
     * @param actual_size A pointer to a variable to receive the actual size of
     * the allocated buffer.
     * @return A pointer to the allocated buffer.
     */
    void* alloc(size_t size, size_t* actual_size) override {
#ifdef DEBUG_CANN_MALLOC
        int nnz = 0;
        size_t max_size = 0;
#endif
        // sentinel larger than any realistic buffer (64 GiB)
        size_t best_diff = 1ull << 36;
        int ibest = -1;
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            ggml_cann_buffer& b = buffer_pool[i];
            if (b.ptr != nullptr) {
#ifdef DEBUG_CANN_MALLOC
                ++nnz;
                if (b.size > max_size) max_size = b.size;
#endif
                if (b.size >= size) {
                    size_t diff = b.size - size;
                    if (diff < best_diff) {
                        best_diff = diff;
                        ibest = i;
                        if (!best_diff) {
                            // exact fit: hand it out immediately
                            void* ptr = b.ptr;
                            *actual_size = b.size;
                            b.ptr = nullptr;
                            b.size = 0;
                            return ptr;
                        }
                    }
                }
            }
        }
        if (ibest >= 0) {
            // best (tightest) fit found during the scan
            ggml_cann_buffer& b = buffer_pool[ibest];
            void* ptr = b.ptr;
            *actual_size = b.size;
            b.ptr = nullptr;
            b.size = 0;
            return ptr;
        }
        void* ptr;
        // over-allocate by 5% and round up to a multiple of 256 bytes so the
        // buffer can satisfy slightly larger future requests
        size_t look_ahead_size = (size_t)(1.05 * size);
        look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
        ggml_cann_set_device(device);
        ACL_CHECK(
            aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
        *actual_size = look_ahead_size;
        pool_size += look_ahead_size;
#ifdef DEBUG_CANN_MALLOC
        GGML_CANN_LOG_INFO(
            "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
            "requested %u MB\n",
            __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
            (uint32_t)(pool_size / 1024 / 1024),
            (uint32_t)(size / 1024 / 1024));
#endif
        return ptr;
    }
326
+
327
+ /**
328
+ * @brief Free a buffer and return it to the pool.
329
+ *
330
+ * @param ptr Pointer to the buffer to free.
331
+ * @param size Size of the buffer to free.
332
+ */
333
+ void free(void* ptr, size_t size) override {
334
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
335
+ ggml_cann_buffer& b = buffer_pool[i];
336
+ if (b.ptr == nullptr) {
337
+ b.ptr = ptr;
338
+ b.size = size;
339
+ return;
340
+ }
341
+ }
342
+ // memory should always buffered. these memory may still needed by
343
+ // tasks in stream.
344
+ // TODO, fix me.
345
+ GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
346
+ }
347
+ };
348
+
349
/**
 * @brief A pool of CANN buffers with virtual memory.
 *
 * This class manages a pool of CANN buffers with virtual memory for a specific
 * device. A large virtual address range is reserved once and physical memory
 * is mapped into it on demand; allocations are carved out bump-pointer style
 * and must be freed in reverse (stack) order.
 */
struct ggml_cann_pool_vmm : public ggml_cann_pool {
    /**
     * @brief The maximum size of the virtual memory pool.
     *
     * NOTE(review): 1ull << 35 is 32 GiB, matching the comment.
     */
    static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35;  // 32 GB

    /**
     * @brief The device ID associated with this buffer pool.
     */
    int device;

    /**
     * @brief Pointer to the start of the virtual memory pool.
     *
     * Remains 0 until the first allocation reserves the address range.
     */
    void* pool_addr = 0;

    /**
     * @brief Amount of virtual memory used in the pool (bump-pointer offset).
     */
    size_t pool_used = 0;

    /**
     * @brief Total size of the virtual memory pool (physically backed bytes).
     */
    size_t pool_size = 0;

    /**
     * @brief Allocation granularity for the virtual memory pool, queried from
     * device info at construction.
     */
    size_t granularity;

    /**
     * @brief Handles for the physical memory allocated.
     */
    std::vector<aclrtDrvMemHandle> handles;

    /**
     * @brief Offsets for the mapped memory regions (one per handle, in
     * mapping order).
     */
    std::vector<void*> map_offsets;

    /**
     * @brief Constructor to initialize the buffer pool with virtual memory for
     * a specific device.
     *
     * @param device The device ID to associate with this buffer pool.
     */
    explicit ggml_cann_pool_vmm(int device)
        : device(device),
          granularity(ggml_cann_info().devices[device].vmm_granularity) {}

    /**
     * @brief Destructor to free all buffers in the virtual memory pool.
     *
     * Unmaps every region, releases the physical handles, then releases the
     * reserved virtual address range.
     */
    ~ggml_cann_pool_vmm() {
        if (pool_addr != 0) {
            for (auto& offset : map_offsets) {
                ACL_CHECK(aclrtUnmapMem(offset));
            }
            for (auto& handle : handles) {
                ACL_CHECK(aclrtFreePhysical(handle));
            }
            ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
        }
    }

    /**
     * @brief Allocate a buffer of the given size in the virtual memory pool.
     *
     * Grows the physically-backed region (in granularity-sized chunks) when
     * the remaining capacity is insufficient, then bumps pool_used.
     *
     * @param size The size of the buffer to allocate.
     * @param actual_size A pointer to a variable to receive the actual size of
     * the allocated buffer.
     * @return A pointer to the allocated buffer.
     */
    void* alloc(size_t size, size_t* actual_size) override {
        // round up the allocation size to the alignment to ensure that all
        // allocations are aligned for all data types
        const size_t alignment = 128;
        size = alignment * ((size + alignment - 1) / alignment);

        size_t avail = pool_size - pool_used;

        if (size > avail) {
            // round up to the next multiple of the granularity
            size_t reserve_size = size - avail;
            reserve_size =
                granularity * ((reserve_size + granularity - 1) / granularity);

            GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);

            // allocate more physical memory
            aclrtPhysicalMemProp prop = {};
            prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
            prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
            prop.memAttr = ACL_HBM_MEM_HUGE;
            prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
            prop.location.id = device;
            prop.reserve = 0;
            aclrtDrvMemHandle handle;
            ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));

            // reserve virtual address space (if not already reserved)
            if (pool_addr == 0) {
                ACL_CHECK(aclrtReserveMemAddress(
                    &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
            }

            // map at the end of the pool
            ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
                                  handle, 0));

            handles.push_back(handle);
            map_offsets.push_back((char*)pool_addr + pool_size);

            // add to the pool
            pool_size += reserve_size;

            // GGML_CANN_LOG_INFO("cann pool[%d]: size increased to %llu MB (
            // reserved %llu MB)\n",
            //       device, (unsigned long long) (pool_size/1024/1024),
            //       (unsigned long long) (reserve_size/1024/1024));
        }

        GGML_ASSERT(pool_addr != 0);

        void* ptr = (void*)((char*)pool_addr + pool_used);
        *actual_size = size;
        pool_used += size;

#ifdef DEBUG_CANN_MALLOC
        GGML_CANN_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
                           (unsigned long long)size, (unsigned long long)ptr);
#endif
        return ptr;
    }

    /**
     * @brief Free a buffer and return it to the virtual memory pool.
     *
     * Bump-pointer deallocation: only the most recent allocation can be
     * freed, which the assert below enforces.
     *
     * @param ptr Pointer to the buffer to free.
     * @param size Size of the buffer to free.
     */
    void free(void* ptr, size_t size) override {
#ifdef DEBUG_CANN_MALLOC
        GGML_CANN_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
                           (unsigned long long)size, (unsigned long long)ptr);
#endif

        pool_used -= size;

        // all deallocations must be in reverse order of the allocations
        GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
    }
};
509
+
510
+ /**
511
+ * @brief Create a new CANN pool for a specific device.
512
+ *
513
+ * Factory method to create a new CANN pool object based on the device type.
514
+ *
515
+ * @param device The device ID for which to create the pool.
516
+ * @return A unique pointer to the created CANN pool.
517
+ */
518
+ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
519
+ int device) {
520
+ // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
521
+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
522
+ }
523
+
524
+ // cann buffer
525
+ /**
526
+ * @brief Context for managing a CANN buffer associated with a specific device.
527
+ *
528
+ * This structure holds information about a CANN buffer, including the device
529
+ * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
530
+ */
531
+ struct ggml_backend_cann_buffer_context {
532
+ int32_t device; ///< The device ID associated with this buffer context.
533
+ void* dev_ptr =
534
+ nullptr; ///< Pointer to the device memory allocated for the buffer.
535
+
536
+ /**
537
+ * @brief Constructor to initialize the CANN buffer context.
538
+ *
539
+ * @param device The device ID associated with this buffer context.
540
+ * @param dev_ptr Pointer to the device memory allocated for the buffer.
541
+ */
542
+ ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
543
+ : device(device),
544
+ dev_ptr(dev_ptr) {}
545
+
546
+ /**
547
+ * @brief Destructor to free the device memory allocated for the buffer.
548
+ */
549
+ ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
550
+ };
551
+
552
+ /**
553
+ * @brief Retrieve the name associated with a CANN buffer.
554
+ *
555
+ * This function returns the name of a CANN buffer, which is stored in the
556
+ * context of the buffer.
557
+ *
558
+ * @param buffer The CANN buffer whose name is to be retrieved.
559
+ * @return A pointer to a C-string containing the name of the buffer.
560
+ */
561
+
562
+ GGML_CALL static const char* ggml_backend_cann_buffer_get_name(
563
+ ggml_backend_buffer_t buffer) {
564
+ return "CANN";
565
+
566
+ GGML_UNUSED(buffer);
567
+ }
568
+
569
+ /**
570
+ * @brief Check if a buffer is a CANN buffer.
571
+ *
572
+ * This function checks if a given buffer is a CANN buffer by comparing its
573
+ * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
574
+ *
575
+ * @param buffer The buffer to check.
576
+ * @return true if the buffer is a CANN buffer, false otherwise.
577
+ */
578
+ GGML_CALL static bool ggml_backend_buffer_is_cann(
579
+ ggml_backend_buffer_t buffer) {
580
+ return buffer->iface.get_name == ggml_backend_cann_buffer_get_name;
581
+ }
582
+
583
+ /**
584
+ * @brief Free resources associated with a CANN buffer.
585
+ *
586
+ * This function frees the resources associated with a CANN buffer, including
587
+ * its context.
588
+ *
589
+ * @param buffer The CANN buffer to free.
590
+ */
591
+ GGML_CALL static void ggml_backend_cann_buffer_free_buffer(
592
+ ggml_backend_buffer_t buffer) {
593
+ ggml_backend_cann_buffer_context* ctx =
594
+ (ggml_backend_cann_buffer_context*)buffer->context;
595
+ delete ctx;
596
+ }
597
+
598
+ /**
599
+ * @brief Retrieve the base pointer of a CANN buffer.
600
+ *
601
+ * This function returns the base pointer of a CANN buffer, which points to the
602
+ * device memory allocated for the buffer.
603
+ *
604
+ * @param buffer The CANN buffer whose base pointer is to be retrieved.
605
+ * @return A pointer to the base of the device memory allocated for the buffer.
606
+ */
607
+ GGML_CALL static void* ggml_backend_cann_buffer_get_base(
608
+ ggml_backend_buffer_t buffer) {
609
+ ggml_backend_cann_buffer_context* ctx =
610
+ (ggml_backend_cann_buffer_context*)buffer->context;
611
+ return ctx->dev_ptr;
612
+ }
613
+
614
/**
 * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
 * processing.
 *
 * The destination layout is: all packed 4-bit quants first (quant_bytes
 * bytes), followed by the fp16 scales of all groups. Within each group the
 * low nibbles of the source byte pairs (elements 0-15) are packed first,
 * then the high nibbles (elements 16-31). Finally every quant byte is
 * XOR'ed with 0x88, which shifts each unsigned 4-bit value by -8 into a
 * signed int4 representation.
 *
 * @param tensor Pointer to the tensor information.
 * @param src Pointer to the source data in Q4.0 format.
 * @param dst Pointer to the destination buffer where transformed data will be
 * stored.
 */
GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
                                                       const void* src,
                                                       void* dst) {
    GGML_ASSERT(tensor->op == GGML_OP_NONE);

    int64_t n_elems = ggml_nelements(tensor);
    int64_t groups = n_elems / QK4_0;
    // 4 bits per element => half a byte per element
    size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;

    uint8_t* quant_offset = (uint8_t*)dst;
    // scales are stored after the full quant region
    uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);

    for (int i = 0; i < groups; i++) {
        const block_q4_0* group =
            (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
        *scale_offset = group->d;
        scale_offset++;

        // 0-15: pack the low nibbles of each source byte pair
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            (*quant_offset) = (group->qs[j] & 0x0F);
            (*quant_offset) |= ((group->qs[j + 1] << 4));
            quant_offset++;
        }

        // 16-31: pack the high nibbles of each source byte pair
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            (*quant_offset) = (group->qs[j] >> 4);
            (*quant_offset) |= (group->qs[j + 1] & 0xF0);
            quant_offset++;
        }
    }

    // put (uint4b_t -8) into int4b_t
    for (quant_offset = (uint8_t*)dst;
         quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
        (*quant_offset) ^= 0x88;
    }
}
666
+
667
/**
 * @brief Transform CANN processed data back into quantized Q4.0 format.
 *
 * This function transforms CANN processed data back into quantized Q4.0 format.
 * It reverses the transformation performed by
 * ggml_backend_cann_transform_q4_0(), converting the data back into its
 * original quantized form.
 *
 * NOTE(review): the un-XOR pass below mutates `src` in place, so the caller's
 * source buffer is modified (it is a scratch/transform buffer at call sites).
 * The destination nibbles are written with `=` in the first loop and merged
 * with `|=` in the second, so dst does not need to be pre-zeroed.
 *
 * @param tensor Pointer to the tensor information.
 * @param src Pointer to the source buffer containing transformed data.
 * @param dst Pointer to the destination buffer where the Q4.0 formatted data
 * will be stored.
 */
GGML_CALL static void ggml_backend_cann_transform_back_q4_0(
    const ggml_tensor* tensor, void* src, void* dst) {
    GGML_ASSERT(tensor->op == GGML_OP_NONE);

    int64_t n_elems = ggml_nelements(tensor);
    int64_t groups = n_elems / QK4_0;
    size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;

    uint8_t* quant_offset = (uint8_t*)src;
    uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);

    // undo the int4 bias: (int4b_t + 8) back to uint4b_t
    for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
        (*quant_offset) ^= 0x88;
    }
    quant_offset = (uint8_t*)src;

    for (int i = 0; i < groups; i++) {
        block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
        group->d = *scale_offset;
        scale_offset++;

        // 0-15: restore the low nibbles of each destination byte pair
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            group->qs[j] = ((*quant_offset) & 0x0F);
            group->qs[j + 1] = ((*quant_offset) >> 4);
            quant_offset++;
        }

        // 16-31: merge the high nibbles into the same destination bytes
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            group->qs[j] |= ((*quant_offset) << 4);
            group->qs[j + 1] |= ((*quant_offset) & 0xF0);
            quant_offset++;
        }
    }
}
716
+
717
+ /**
718
+ * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
719
+ * processing.
720
+ *
721
+ * This function transforms quantized Q8.0 tensor data into a format suitable
722
+ * for CANN processing. It extracts quantization values and scales from the
723
+ * source data and prepares them in a format expected by CANN operations.
724
+ *
725
+ * @param tensor Pointer to the tensor information.
726
+ * @param src Pointer to the source data in Q8.0 format.
727
+ * @param dst Pointer to the destination buffer where transformed data will be
728
+ * stored.
729
+ */
730
+ GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
731
+ const void* src,
732
+ void* dst) {
733
+ int64_t n_elems = ggml_nelements(tensor);
734
+ int64_t groups = n_elems / QK8_0;
735
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
736
+
737
+ uint8_t* quant_offset = (uint8_t*)dst;
738
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
739
+
740
+ for (int i = 0; i < groups; i++) {
741
+ const block_q8_0* group =
742
+ (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
743
+ *scale_offset = group->d;
744
+ scale_offset++;
745
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
746
+ memcpy(quant_offset, group->qs, group_quant_size);
747
+ quant_offset += group_quant_size;
748
+ }
749
+ }
750
+
751
+ /**
752
+ * @brief Transform CANN processed data back into quantized Q8.0 format.
753
+ *
754
+ * This function transforms CANN processed data back into quantized Q8.0 format.
755
+ * It reverses the transformation performed by
756
+ * ggml_backend_cann_transform_q8_0(), converting the data back into its
757
+ * original quantized form.
758
+ *
759
+ * @param tensor Pointer to the tensor information.
760
+ * @param src Pointer to the source buffer containing transformed data.
761
+ * @param dst Pointer to the destination buffer where the Q8.0 formatted data
762
+ * will be stored.
763
+ */
764
+ GGML_CALL static void ggml_backend_cann_transform_back_q8_0(
765
+ const ggml_tensor* tensor, const void* src, void* dst) {
766
+ int64_t n_elems = ggml_nelements(tensor);
767
+ int64_t groups = n_elems / QK8_0;
768
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
769
+
770
+ const uint8_t* quant_offset = (const uint8_t*)src;
771
+ const uint16_t* scale_offset =
772
+ (const uint16_t*)((const char*)src + quant_bytes);
773
+
774
+ for (int i = 0; i < groups; i++) {
775
+ block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
776
+ group->d = *scale_offset;
777
+ scale_offset++;
778
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
779
+ memcpy(group->qs, quant_offset, group_quant_size);
780
+ quant_offset += group_quant_size;
781
+ }
782
+ }
783
+
784
+ /**
785
+ * @brief Transform tensor data based on its type for CANN processing.
786
+ *
787
+ * This function transforms tensor data based on its quantization type for CANN
788
+ * processing. It dispatches the transformation based on the tensor's type to
789
+ * specialized functions handling Q4.0 and Q8.0 formats.
790
+ *
791
+ * @param tensor Pointer to the tensor information.
792
+ * @param src Pointer to the source data to be transformed.
793
+ * @param dst Pointer to the destination buffer where transformed data will be
794
+ * stored.
795
+ */
796
+ GGML_CALL static void ggml_backend_cann_transform(ggml_tensor* tensor,
797
+ const void* src, void* dst) {
798
+ switch (tensor->type) {
799
+ case GGML_TYPE_Q4_0:
800
+ ggml_backend_cann_transform_q4_0(tensor, src, dst);
801
+ break;
802
+ case GGML_TYPE_Q8_0:
803
+ ggml_backend_cann_transform_q8_0(tensor, src, dst);
804
+ break;
805
+ default:
806
+ break;
807
+ }
808
+ }
809
+
810
+ /**
811
+ * @brief Transform CANN processed data back into tensor data based on its type.
812
+ *
813
+ * This function transforms CANN processed data back into tensor data based on
814
+ * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
815
+ * transformation based on the tensor's type to specialized functions.
816
+ *
817
+ * @param tensor Pointer to the tensor information.
818
+ * @param src Pointer to the source data containing CANN processed data.
819
+ * @param dst Pointer to the destination buffer where transformed tensor data
820
+ * will be stored.
821
+ */
822
+ GGML_CALL static void ggml_backend_cann_transform_back(
823
+ const ggml_tensor* tensor, void* src, void* dst) {
824
+ switch (tensor->type) {
825
+ case GGML_TYPE_Q4_0:
826
+ ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
827
+ break;
828
+ case GGML_TYPE_Q8_0:
829
+ ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
830
+ break;
831
+ default:
832
+ break;
833
+ }
834
+ }
835
+
836
+ /**
837
+ * @brief Check if transformation is needed for a given tensor type.
838
+ *
839
+ * This function checks if transformation is needed for a given tensor type
840
+ * to prepare data for CANN processing.
841
+ *
842
+ * @param type The tensor type to check.
843
+ * @return true if transformation is needed, false otherwise.
844
+ */
845
+ GGML_CALL static bool need_transform(ggml_type type) {
846
+ switch (type) {
847
+ case GGML_TYPE_Q4_0:
848
+ case GGML_TYPE_Q8_0:
849
+ return true;
850
+ default:
851
+ return false;
852
+ }
853
+ }
854
+
855
/**
 * @brief Initialize a tensor using data from a CANN buffer.
 *
 * This function initializes a tensor using data from a CANN buffer.
 * It handles special cases such as views and quantization.
 *
 * @param buffer The CANN buffer from which to initialize the tensor.
 * @param tensor Pointer to the tensor to be initialized.
 */
GGML_CALL static void ggml_backend_cann_buffer_init_tensor(
    ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
    // a view at offset 0 shares its source's storage directly; nothing to do
    // beyond checking both live in the same buffer type
    if (tensor->view_src != NULL && tensor->view_offs == 0) {
        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
        return;
    }

    // TODO: can backend doesn't support quantized yet. Just leave the code
    // here.
    if (ggml_is_quantized(tensor->type)) {
        // Initialize padding to 0 to avoid possible NaN values
        size_t original_size = ggml_nbytes(tensor);
        size_t padded_size =
            ggml_backend_buft_get_alloc_size(buffer->buft, tensor);

        if (padded_size > original_size && tensor->view_src == nullptr) {
            // zero only the padding region past the real tensor data
            size_t memset_size = padded_size - original_size;
            ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
                                  memset_size, 0, memset_size));
        }
    }
}
886
+
887
+ // TODO: need handle tensor which has paddings.
888
+ /**
889
+ * @brief Set tensor data in a CANN buffer.
890
+ *
891
+ * This function sets tensor data in a CANN buffer, handling transformations
892
+ * if needed based on the tensor's type.
893
+ *
894
+ * @param buffer The CANN buffer where the tensor data will be set.
895
+ * @param tensor Pointer to the tensor whose data will be set.
896
+ * @param data Pointer to the source data to be copied into the tensor.
897
+ * @param offset Offset in the source data from where to start copying.
898
+ * @param size Size of the data to be copied, in bytes.
899
+ */
900
+ GGML_CALL static void ggml_backend_cann_buffer_set_tensor(
901
+ ggml_backend_buffer_t buffer, ggml_tensor* tensor, const void* data,
902
+ size_t offset, size_t size) {
903
+ // GGML_ASSERT(size == ggml_nbytes(tensor));
904
+ ggml_backend_cann_buffer_context* ctx =
905
+ (ggml_backend_cann_buffer_context*)buffer->context;
906
+
907
+ ggml_cann_set_device(ctx->device);
908
+ // TODO: refer to cann(#6017), it use thread's default stream.
909
+ // For acl, synchronous functions use this default stream.
910
+ // Why aclrtSynchronizeDevice?
911
+
912
+ if (!need_transform(tensor->type)) {
913
+ ACL_CHECK(aclrtMemcpy(tensor->data, size, (const char*)data + offset,
914
+ size, ACL_MEMCPY_HOST_TO_DEVICE));
915
+ } else {
916
+ void* transform_buffer = malloc(size);
917
+ ggml_backend_cann_transform(tensor, (const char*)data + offset,
918
+ transform_buffer);
919
+
920
+ #ifndef NDEBUG
921
+ void* check_buffer = malloc(size);
922
+ ggml_backend_cann_transform_back(tensor, transform_buffer,
923
+ check_buffer);
924
+ GGML_ASSERT(memcmp((const char*)data + offset, check_buffer, size) ==
925
+ 0);
926
+ free(check_buffer);
927
+ #endif
928
+ ACL_CHECK(aclrtMemcpy(tensor->data, size, transform_buffer, size,
929
+ ACL_MEMCPY_HOST_TO_DEVICE));
930
+ free(transform_buffer);
931
+ }
932
+ }
933
+
934
+ /**
935
+ * @brief Get tensor data from a CANN buffer.
936
+ *
937
+ * This function retrieves tensor data from a CANN buffer, handling
938
+ * transformations if needed based on the tensor's type.
939
+ *
940
+ * @param buffer The CANN buffer from which to retrieve tensor data.
941
+ * @param tensor Pointer to the tensor whose data will be retrieved.
942
+ * @param data Pointer to the destination buffer where the tensor data will be
943
+ * copied.
944
+ * @param offset Offset in the destination buffer where to start copying.
945
+ * @param size Size of the data to be copied, in bytes.
946
+ */
947
+ GGML_CALL static void ggml_backend_cann_buffer_get_tensor(
948
+ ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
949
+ size_t offset, size_t size) {
950
+ GGML_ASSERT(size == ggml_nbytes(tensor));
951
+ ggml_backend_cann_buffer_context* ctx =
952
+ (ggml_backend_cann_buffer_context*)buffer->context;
953
+
954
+ ggml_cann_set_device(ctx->device);
955
+
956
+ if (!need_transform(tensor->type)) {
957
+ ACL_CHECK(aclrtMemcpy((char*)data + offset, size, tensor->data, size,
958
+ ACL_MEMCPY_DEVICE_TO_HOST));
959
+ } else {
960
+ void* transform_buffer = malloc(size);
961
+ ACL_CHECK(aclrtMemcpy(transform_buffer, size, tensor->data, size,
962
+ ACL_MEMCPY_DEVICE_TO_HOST));
963
+ ggml_backend_cann_transform_back(tensor, transform_buffer,
964
+ (char*)data + offset);
965
+ free(transform_buffer);
966
+ }
967
+ }
968
+
969
/**
 * @brief Copy tensor data between CANN buffers if possible.
 *
 * This function copies tensor data between CANN buffers if the source and
 * destination buffers are CANN buffers and they meet the necessary conditions
 * (same device or devices can access each other).
 *
 * @param buffer The destination CANN buffer where the tensor data will be
 * copied.
 * @param src Pointer to the source tensor whose data will be copied.
 * @param dst Pointer to the destination tensor where the data will be copied.
 * @return true if the copy operation succeeded, false otherwise.
 */
GGML_CALL static bool ggml_backend_cann_buffer_cpy_tensor(
    ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
    if (ggml_backend_buffer_is_cann(src->buffer)) {
        ggml_backend_cann_buffer_context* src_ctx =
            (ggml_backend_cann_buffer_context*)src->buffer->context;
        ggml_backend_cann_buffer_context* dst_ctx =
            (ggml_backend_cann_buffer_context*)buffer->context;

        size_t memcpy_size = ggml_nbytes(src);
        // Same device.
        if (src_ctx->device == dst_ctx->device) {
            ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
                                  (const char*)src->data, memcpy_size,
                                  ACL_MEMCPY_DEVICE_TO_DEVICE));
            return true;
        } else {
            // Different device but can access by peer.
            int32_t canAccessPeer = 0;
            ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
                                               dst_ctx->device));
            if (canAccessPeer) {
                // NOTE(review): peer access is enabled on every copy; the ACL
                // runtime presumably tolerates re-enabling — confirm.
                ggml_cann_set_device(src_ctx->device);
                ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
                ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
                                      (const char*)src->data, memcpy_size,
                                      ACL_MEMCPY_DEVICE_TO_DEVICE));
                return true;
            }
        }
    }
    // not a CANN source buffer, or devices cannot reach each other
    return false;
}
1014
+
1015
+ /**
1016
+ * @brief Clear a CANN buffer by setting all its memory to a specified value.
1017
+ *
1018
+ * This function clears a CANN buffer by setting all its memory to a specified
1019
+ * value.
1020
+ *
1021
+ * @param buffer The CANN buffer to be cleared.
1022
+ * @param value The value to which each byte in the buffer will be set.
1023
+ */
1024
+ GGML_CALL static void ggml_backend_cann_buffer_clear(
1025
+ ggml_backend_buffer_t buffer, uint8_t value) {
1026
+ ggml_backend_cann_buffer_context* ctx =
1027
+ (ggml_backend_cann_buffer_context*)buffer->context;
1028
+
1029
+ ggml_cann_set_device(ctx->device);
1030
+ ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
1031
+ }
1032
+
1033
/**
 * @brief Interface for a CANN buffer in the backend.
 *
 * This structure defines function pointers to operations that can be performed
 * on a CANN buffer within the backend.
 */
static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
    /* .get_name        = */ ggml_backend_cann_buffer_get_name,
    /* .free_buffer     = */ ggml_backend_cann_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_cann_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_cann_buffer_init_tensor,
    /* .set_tensor      = */ ggml_backend_cann_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_cann_buffer_get_tensor,
    /* .cpy_tensor      = */ ggml_backend_cann_buffer_cpy_tensor,
    /* .clear           = */ ggml_backend_cann_buffer_clear,
    /* .reset           = */ NULL,  // reset is not supported by this backend
};
1050
+
1051
// cann buffer type
/**
 * @brief Structure representing context information for a specific backend
 * buffer type.
 */
struct ggml_backend_cann_buffer_type_context {
    int32_t device;   /**< Device identifier associated with the buffer context. */
    std::string name; /**< Name associated with the buffer context. */
};
1061
+
1062
+ /**
1063
+ * @brief Retrieves the name associated with a CANN buffer type.
1064
+ *
1065
+ * This function returns the descriptive name associated with the specified
1066
+ * CANN buffer type context.
1067
+ *
1068
+ * @param buft Pointer to the buffer type context.
1069
+ * @return Const pointer to the C-style string containing the name.
1070
+ */
1071
+ GGML_CALL static const char* ggml_backend_cann_buffer_type_name(
1072
+ ggml_backend_buffer_type_t buft) {
1073
+ return "CANN";
1074
+
1075
+ GGML_UNUSED(buft);
1076
+ }
1077
+
1078
+ /**
1079
+ * @brief Allocates a new CANN buffer of the specified type and size.
1080
+ *
1081
+ * This function allocates a new CANN buffer on the specified device with the
1082
+ * given size.
1083
+ *
1084
+ * @param buft Pointer to the buffer type context.
1085
+ * @param size Size in bytes of the buffer to allocate.
1086
+ * @return Pointer to the allocated buffer, or nullptr if allocation fails.
1087
+ */
1088
+ GGML_CALL static ggml_backend_buffer_t
1089
+ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1090
+ size_t size) {
1091
+ ggml_backend_cann_buffer_type_context* buft_ctx =
1092
+ (ggml_backend_cann_buffer_type_context*)buft->context;
1093
+
1094
+ ggml_cann_set_device(buft_ctx->device);
1095
+
1096
+ size = std::max(size, (size_t)1);
1097
+
1098
+ void* dev_ptr;
1099
+ aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1100
+ if (err != ACL_SUCCESS) {
1101
+ GGML_CANN_LOG_ERROR(
1102
+ "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
1103
+ __func__, size / 1024.0 / 1024.0, buft_ctx->device,
1104
+ aclGetRecentErrMsg());
1105
+ return nullptr;
1106
+ }
1107
+
1108
+ ggml_backend_cann_buffer_context* ctx =
1109
+ new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1110
+
1111
+ return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
1112
+ ctx, size);
1113
+ }
1114
+
1115
+ /**
1116
+ * @brief Retrieves the memory alignment requirement for CANN buffers of this
1117
+ * type.
1118
+ *
1119
+ * This function returns the alignment requirement in bytes for memory allocated
1120
+ * by the CANN buffer type.
1121
+ *
1122
+ * @param buft Pointer to the buffer type context (unused in this
1123
+ * implementation).
1124
+ * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
1125
+ * buffers).
1126
+ */
1127
+ GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alignment(
1128
+ ggml_backend_buffer_type_t buft) {
1129
+ return 128;
1130
+
1131
+ GGML_UNUSED(buft);
1132
+ }
1133
+
1134
+ /**
1135
+ * @brief Calculates the allocation size required for a tensor in a CANN buffer.
1136
+ *
1137
+ * Computes the total allocation size needed for storing the tensor's data in a
1138
+ * CANN buffer, considering any necessary padding or adjustments for quantized
1139
+ * types.
1140
+ *
1141
+ * @param buft Pointer to the buffer type context (unused in this
1142
+ * implementation).
1143
+ * @param tensor Pointer to the tensor for which the allocation size is
1144
+ * calculated.
1145
+ * @return The total allocation size in bytes required for the tensor in the
1146
+ * CANN buffer.
1147
+ */
1148
+ GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1149
+ ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
1150
+ size_t size = ggml_nbytes(tensor);
1151
+ int64_t ne0 = tensor->ne[0];
1152
+
1153
+ // last line must bigger than 32, because every single op deal at
1154
+ // least 32 bytes.
1155
+ // TODO: quantized type?
1156
+ // int64_t line_size = ne0 * ggml_element_size(tensor);
1157
+ // int64_t line_size_align_32 = (line_size + 31) & ~31;
1158
+ // size += (line_size_align_32 - line_size);
1159
+
1160
+ // TODO: not support quantized yet.
1161
+ // TODO: consider un-continue tensor.
1162
+ if (ggml_is_quantized(tensor->type)) {
1163
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
1164
+ size += ggml_row_size(
1165
+ tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1166
+ }
1167
+ }
1168
+
1169
+ return size;
1170
+
1171
+ GGML_UNUSED(buft);
1172
+ }
1173
+
1174
/**
 * @brief Interface for managing CANN buffer types in the GGML backend.
 *
 * Provides function pointers for allocating, querying properties, and managing
 * memory for CANN buffer types in the GGML backend. Entries left NULL fall
 * back to the framework defaults.
 */
static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_cann_buffer_type_name,
    /* .alloc_buffer     = */ ggml_backend_cann_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_cann_buffer_type_get_alignment,
    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
    /* .get_alloc_size   = */ ggml_backend_cann_buffer_type_get_alloc_size,
    /* .is_host          = */ NULL, // not a host-visible buffer type
};
1188
+
1189
+ /**
1190
+ * @brief Retrieves the CANN buffer type for a specified device.
1191
+ *
1192
+ * This function initializes and returns the buffer type interface associated
1193
+ * with the given device. It ensures thread-safe access using a mutex.
1194
+ *
1195
+ * @param device The device index for which to retrieve the buffer type.
1196
+ * @return A pointer to the buffer type interface for the specified device, or
1197
+ * nullptr if the device index is out of range.
1198
+ */
1199
+ GGML_CALL ggml_backend_buffer_type_t
1200
+ ggml_backend_cann_buffer_type(int32_t device) {
1201
+ static std::mutex mutex;
1202
+ std::lock_guard<std::mutex> lock(mutex);
1203
+
1204
+ if (device >= ggml_backend_cann_get_device_count()) {
1205
+ return nullptr;
1206
+ }
1207
+
1208
+ static ggml_backend_buffer_type
1209
+ ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1210
+
1211
+ static bool ggml_backend_cann_buffer_type_initialized = false;
1212
+
1213
+ if (!ggml_backend_cann_buffer_type_initialized) {
1214
+ for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
1215
+ ggml_backend_cann_buffer_types[i] = {
1216
+ /* .iface = */ ggml_backend_cann_buffer_type_interface,
1217
+ /* .context = */
1218
+ new ggml_backend_cann_buffer_type_context{
1219
+ i, "CANN" + std::to_string(i)},
1220
+ };
1221
+ }
1222
+ ggml_backend_cann_buffer_type_initialized = true;
1223
+ }
1224
+
1225
+ return &ggml_backend_cann_buffer_types[device];
1226
+ }
1227
+
1228
/**
 * @brief Computes the forward operation for a given tensor using CANN
 * operations.
 *
 * Dispatches on dst->op (and, for GGML_OP_UNARY, on the unary sub-op) to the
 * matching ggml_cann_* kernel. View/reshape-style ops are no-ops here.
 *
 * @param ctx The CANN context containing necessary resources and
 * configurations.
 * @param dst The destination tensor where the result of the computation will be
 * stored.
 * @return true if the computation was successful; false if the op (or unary
 * sub-op) is not supported by the CANN backend.
 */
static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                                      struct ggml_tensor* dst) {
    switch (dst->op) {
        case GGML_OP_REPEAT:
            ggml_cann_repeat(ctx, dst);
            break;
        case GGML_OP_GET_ROWS:
            ggml_cann_get_rows(ctx, dst);
            break;
        case GGML_OP_DUP:
            ggml_cann_dup(ctx, dst);
            break;
        case GGML_OP_ADD:
            ggml_cann_add(ctx, dst);
            break;
        case GGML_OP_ACC:
            ggml_cann_acc(ctx, dst);
            break;
        // MUL and DIV share one element-wise implementation, parameterized
        // by the aclnn workspace-size/launch function pair.
        case GGML_OP_MUL:
            ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
            break;
        case GGML_OP_DIV:
            ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
            break;
        case GGML_OP_UNARY:
            // Unary activations share one templated launcher.
            switch (ggml_get_unary_op(dst)) {
                case GGML_UNARY_OP_GELU:
                    ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
                        ctx, dst);
                    break;
                case GGML_UNARY_OP_SILU:
                    ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
                        ctx, dst);
                    break;
                // TODO: Use faster gelu??
                // GELU_QUICK falls back to the regular aclnn GELU kernel.
                case GGML_UNARY_OP_GELU_QUICK:
                    ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
                        ctx, dst);
                    break;
                case GGML_UNARY_OP_TANH:
                    ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
                        ctx, dst);
                    break;
                case GGML_UNARY_OP_RELU:
                    ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
                        ctx, dst);
                    break;
                case GGML_UNARY_OP_HARDSIGMOID:
                    ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
                                         aclnnHardsigmoid>(ctx, dst);
                    break;
                case GGML_UNARY_OP_HARDSWISH:
                    ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
                                         aclnnHardswish>(ctx, dst);
                    break;
                default:
                    return false;
            }
            break;
        case GGML_OP_NORM:
            ggml_cann_norm(ctx, dst);
            break;
        case GGML_OP_GROUP_NORM:
            ggml_cann_group_norm(ctx, dst);
            break;
        case GGML_OP_CONCAT:
            ggml_cann_concat(ctx, dst);
            break;
        case GGML_OP_UPSCALE:
            ggml_cann_upsample_nearest2d(ctx, dst);
            break;
        case GGML_OP_PAD:
            ggml_cann_pad(ctx, dst);
            break;
        case GGML_OP_ARANGE:
            ggml_cann_arange(ctx, dst);
            break;
        case GGML_OP_TIMESTEP_EMBEDDING:
            ggml_cann_timestep_embedding(ctx, dst);
            break;
        case GGML_OP_LEAKY_RELU:
            ggml_cann_leaky_relu(ctx, dst);
            break;
        case GGML_OP_RMS_NORM:
            ggml_cann_rms_norm(ctx, dst);
            break;
        case GGML_OP_MUL_MAT:
            ggml_cann_mul_mat(ctx, dst);
            break;
        // Mixture-of-experts matmul is not implemented on CANN.
        case GGML_OP_MUL_MAT_ID:
            return false;
        case GGML_OP_SCALE:
            ggml_cann_scale(ctx, dst);
            break;
        case GGML_OP_SQR:
            ggml_cann_sqr(ctx, dst);
            break;
        case GGML_OP_CLAMP:
            ggml_cann_clamp(ctx, dst);
            break;
        case GGML_OP_CPY:
            ggml_cann_cpy(ctx, dst);
            break;
        // CONT (make contiguous) reuses the duplicate kernel.
        case GGML_OP_CONT:
            ggml_cann_dup(ctx, dst);
            break;
        // Pure view/layout ops: no device work required.
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
            break;
        case GGML_OP_DIAG_MASK_INF:
            ggml_cann_diag_mask(ctx, dst, -INFINITY);
            break;
        case GGML_OP_SOFT_MAX:
            ggml_cann_softmax(ctx, dst);
            break;
        case GGML_OP_ROPE:
            ggml_cann_rope(ctx, dst);
            break;
        case GGML_OP_IM2COL:
            ggml_cann_im2col(ctx, dst);
            break;
        case GGML_OP_POOL_2D:
            ggml_cann_pool2d(ctx, dst);
            break;
        case GGML_OP_SUM_ROWS:
            ggml_cann_sum_rows(ctx, dst);
            break;
        case GGML_OP_ARGSORT:
            ggml_cann_argsort(ctx, dst);
            break;
        default:
            return false;
    }

    return true;
}
1380
+
1381
+ // backend
1382
+ /**
1383
+ * @brief Retrieves the name associated with the CANN backend.
1384
+ *
1385
+ * This function returns the name assigned to the CANN backend, which is stored
1386
+ * in the context of the provided backend structure.
1387
+ *
1388
+ * @param backend Pointer to the CANN backend structure.
1389
+ * @return A pointer to a constant string representing the backend name.
1390
+ */
1391
+ GGML_CALL static const char* ggml_backend_cann_name(ggml_backend_t backend) {
1392
+ ggml_backend_cann_context* cann_ctx =
1393
+ (ggml_backend_cann_context*)backend->context;
1394
+
1395
+ return cann_ctx->name.c_str();
1396
+ }
1397
+
1398
/**
 * @brief Frees resources associated with the CANN backend.
 *
 * Drains all outstanding device work, resets the device owned by this
 * backend, finalizes the ACL runtime when this is the last device, and
 * finally releases the context and backend objects.
 *
 * @param backend Pointer to the CANN backend structure to be freed.
 */
GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) {
    ggml_backend_cann_context* cann_ctx =
        (ggml_backend_cann_context*)backend->context;
    // Wait for in-flight work before tearing the device down.
    ACL_CHECK(aclrtSynchronizeDevice());
    ACL_CHECK(aclrtResetDevice(cann_ctx->device));

    // finalize when last backend freed.
    // NOTE(review): this assumes backends are freed in ascending device
    // order — freeing the highest-index backend first would finalize ACL
    // while other backends are still alive. Confirm the caller's ordering.
    if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
        ACL_CHECK(aclFinalize());
    }

    delete cann_ctx;
    delete backend;
}
1420
+
1421
+ /**
1422
+ * @brief Retrieves the default buffer type associated with the CANN backend.
1423
+ *
1424
+ * This function returns the buffer type specific to the device associated
1425
+ * with the CANN backend. It is used to allocate buffers for computations
1426
+ * performed by the backend.
1427
+ *
1428
+ * @param backend Pointer to the CANN backend structure.
1429
+ * @return Pointer to the buffer type structure for the CANN backend.
1430
+ */
1431
+ GGML_CALL static ggml_backend_buffer_type_t
1432
+ ggml_backend_cann_get_default_buffer_type(ggml_backend_t backend) {
1433
+ ggml_backend_cann_context* cann_ctx =
1434
+ (ggml_backend_cann_context*)backend->context;
1435
+
1436
+ return ggml_backend_cann_buffer_type(cann_ctx->device);
1437
+ }
1438
+
1439
+ /**
1440
+ * @brief Sets tensor data asynchronously in the CANN backend.
1441
+ *
1442
+ * This function asynchronously sets tensor data in the CANN backend. Depending
1443
+ * on the tensor type, it may perform data transformations before copying data
1444
+ * to the device.
1445
+ *
1446
+ * @param backend Pointer to the CANN backend structure.
1447
+ * @param tensor Pointer to the tensor structure to set data for.
1448
+ * @param data Pointer to the host data to copy to the tensor.
1449
+ * @param offset Offset in bytes within the host data.
1450
+ * @param size Size of the data to copy in bytes.
1451
+ */
1452
+ GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
1453
+ ggml_tensor* tensor,
1454
+ const void* data,
1455
+ size_t offset,
1456
+ size_t size) {
1457
+ ggml_backend_cann_context* cann_ctx =
1458
+ (ggml_backend_cann_context*)backend->context;
1459
+
1460
+ if (!need_transform(tensor->type)) {
1461
+ ACL_CHECK(aclrtMemcpyAsync(
1462
+ tensor->data, size, (const char*)data + offset, size,
1463
+ ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
1464
+ } else {
1465
+ void* transform_buffer = malloc(size);
1466
+ ggml_backend_cann_transform(tensor, (const char*)data + offset,
1467
+ transform_buffer);
1468
+
1469
+ #ifndef NDEBUG
1470
+ void* check_buffer = malloc(size);
1471
+ ggml_backend_cann_transform_back(tensor, transform_buffer,
1472
+ check_buffer);
1473
+ GGML_ASSERT(memcmp((const char*)data + offset, check_buffer, size));
1474
+ free(check_buffer);
1475
+ #endif
1476
+ ACL_CHECK(aclrtMemcpyAsync(tensor->data, size, transform_buffer, size,
1477
+ ACL_MEMCPY_HOST_TO_DEVICE,
1478
+ cann_ctx->stream()));
1479
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1480
+ free(transform_buffer);
1481
+ }
1482
+ }
1483
+
1484
/**
 * @brief Gets tensor data asynchronously from the CANN backend.
 *
 * Copies `size` bytes of tensor data from the device into host memory at
 * `data + offset`. Types that are stored in a device-specific layout (per
 * need_transform()) are staged through a host buffer and transformed back
 * after a synchronous copy.
 *
 * @param backend Pointer to the CANN backend structure.
 * @param tensor  Pointer to the tensor to read data from.
 * @param data    Pointer to the host destination buffer.
 * @param offset  Offset in bytes within the host buffer.
 * @param size    Size of the data to copy in bytes.
 */
GGML_CALL static void ggml_backend_cann_get_tensor_async(
    ggml_backend_t backend, const ggml_tensor* tensor, void* data,
    size_t offset, size_t size) {
    ggml_backend_cann_context* cann_ctx =
        (ggml_backend_cann_context*)backend->context;
    // For views, the storage is owned by the viewed tensor's buffer.
    ggml_backend_buffer_t buf =
        tensor->view_src ? tensor->view_src->buffer : tensor->buffer;

    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
                "unsupported buffer type");

    if (!need_transform(tensor->type)) {
        ACL_CHECK(aclrtMemcpyAsync((char*)data + offset, size, tensor->data,
                                   size, ACL_MEMCPY_DEVICE_TO_HOST,
                                   cann_ctx->stream()));
    } else {
        // NOTE(review): malloc result is not checked here — a failed
        // allocation would crash in the memcpy below.
        void* transform_buffer = malloc(size);
        ACL_CHECK(aclrtMemcpyAsync(transform_buffer, size, tensor->data, size,
                                   ACL_MEMCPY_DEVICE_TO_HOST,
                                   cann_ctx->stream()));
        // The transform reads the staging buffer on the host, so the copy
        // must complete first.
        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
        ggml_backend_cann_transform_back(tensor, transform_buffer,
                                         (char*)data + offset);
        free(transform_buffer);
    }
}
1510
+
1511
/**
 * @brief Asynchronously copies tensor data between CANN backends.
 *
 * This function copies tensor data asynchronously between two CANN backends. It
 * checks if both tensors reside in CANN buffers and whether the devices support
 * peer-to-peer access for direct copying. If not, it returns false.
 *
 * @param backend_src Pointer to the source CANN backend structure.
 * @param backend_dst Pointer to the destination CANN backend structure.
 * @param src Pointer to the source tensor to copy data from.
 * @param dst Pointer to the destination tensor to copy data to.
 * @return true if the copy operation succeeds, false otherwise.
 */
GGML_CALL static bool ggml_backend_cann_cpy_tensor_async(
    ggml_backend_t backend_src, ggml_backend_t backend_dst,
    const ggml_tensor* src, ggml_tensor* dst) {
    GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
                ggml_backend_is_cann(backend_dst));

    // Both endpoints must live in CANN buffers for a device-to-device copy.
    if (!ggml_backend_buffer_is_cann(src->buffer) ||
        !ggml_backend_buffer_is_cann(dst->buffer)) {
        return false;
    }

    // For views, the storage is owned by the viewed tensor's buffer.
    ggml_backend_buffer_t buf_src =
        src->view_src ? src->view_src->buffer : src->buffer;
    ggml_backend_buffer_t buf_dst =
        dst->view_src ? dst->view_src->buffer : dst->buffer;

    ggml_backend_cann_context* cann_ctx_src =
        (ggml_backend_cann_context*)backend_src->context;
    ggml_backend_cann_context* cann_ctx_dst =
        (ggml_backend_cann_context*)backend_dst->context;

    size_t copy_size = ggml_nbytes(dst);
    if (backend_src != backend_dst) {
        // Cross-device copy: requires peer-to-peer access in both directions.
        ggml_backend_cann_buffer_context* buf_ctx_src =
            (ggml_backend_cann_buffer_context*)buf_src->context;
        ggml_backend_cann_buffer_context* buf_ctx_dst =
            (ggml_backend_cann_buffer_context*)buf_dst->context;

        GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
        GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);

        int32_t canAccessPeer = 0;
        ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
                                           cann_ctx_dst->device));
        if (!canAccessPeer) {
            return false;
        }

        // need open both directions for memcpyasync between devices.
        ggml_cann_set_device(cann_ctx_dst->device);
        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
        ggml_cann_set_device(cann_ctx_src->device);
        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));

        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
                                   ACL_MEMCPY_DEVICE_TO_DEVICE,
                                   cann_ctx_src->stream()));

        //TODO: workaround for Event didn`t work here.
        aclrtSynchronizeStream(cann_ctx_src->stream());
    } else {
        // src and dst are on the same backend
        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
                                   ACL_MEMCPY_DEVICE_TO_DEVICE,
                                   cann_ctx_dst->stream()));
    }

    return true;
}
1583
+
1584
+ /**
1585
+ * @brief Synchronizes a CANN backend.
1586
+ *
1587
+ * This function synchronizes the specified CANN backend by waiting for all
1588
+ * operations in its associated stream to complete.
1589
+ *
1590
+ * @param backend Pointer to the CANN backend structure to synchronize.
1591
+ */
1592
+ GGML_CALL static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
1593
+ ggml_backend_cann_context* cann_ctx =
1594
+ (ggml_backend_cann_context*)backend->context;
1595
+
1596
+ ggml_cann_set_device(cann_ctx->device);
1597
+
1598
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1599
+ }
1600
+
1601
+ /**
1602
+ * @brief Computes a computational graph using a CANN backend.
1603
+ *
1604
+ * This function computes the operations defined in the computational graph
1605
+ * using the specified CANN backend.
1606
+ *
1607
+ * @param backend Pointer to the CANN backend structure to use for computation.
1608
+ * @param cgraph Pointer to the computational graph structure containing nodes
1609
+ * representing operations to be computed.
1610
+ * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
1611
+ * completes successfully, otherwise an appropriate error status.
1612
+ */
1613
+ GGML_CALL static enum ggml_status ggml_backend_cann_graph_compute(
1614
+ ggml_backend_t backend, ggml_cgraph* cgraph) {
1615
+ ggml_backend_cann_context* cann_ctx =
1616
+ (ggml_backend_cann_context*)backend->context;
1617
+
1618
+ ggml_cann_set_device(cann_ctx->device);
1619
+
1620
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1621
+ ggml_tensor* node = cgraph->nodes[i];
1622
+
1623
+ if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
1624
+ continue;
1625
+ }
1626
+
1627
+ bool ok = ggml_cann_compute_forward(*cann_ctx, node);
1628
+
1629
+ if (!ok) {
1630
+ GGML_CANN_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
1631
+ node->name, ggml_op_name(node->op));
1632
+ }
1633
+ GGML_ASSERT(ok);
1634
+ }
1635
+
1636
+ return GGML_STATUS_SUCCESS;
1637
+ }
1638
+
1639
/**
 * @brief Checks if the CANN backend supports a specific operation.
 *
 * Encodes the support matrix of the CANN backend: which ops are implemented,
 * and for type-sensitive ops (matmul, embedding lookup, copy), which source
 * or destination types are accepted.
 *
 * @param backend Pointer to the CANN backend structure to check support for
 * the operation.
 * @param op Pointer to the tensor representing the operation to check.
 * @return bool Returns true if the operation is supported by the backend,
 * otherwise false.
 */
GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
                                                    const ggml_tensor* op) {
    switch (op->op) {
        case GGML_OP_UNARY:
            // Only these activation functions have CANN kernels.
            switch (ggml_get_unary_op(op)) {
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_SILU:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_HARDSIGMOID:
                case GGML_UNARY_OP_HARDSWISH:
                case GGML_UNARY_OP_GELU_QUICK:
                case GGML_UNARY_OP_TANH:
                    return true;
                default:
                    return false;
            }
        case GGML_OP_MUL_MAT: {
            // Matmul support depends on the weight (src[0]) type.
            switch (op->src[0]->type) {
                // case GGML_TYPE_Q4_0:
                case GGML_TYPE_F16:
                case GGML_TYPE_F32:
                case GGML_TYPE_Q8_0:
                    return true;
                default:
                    return false;
            }
        }
        // Mixture-of-experts matmul is not implemented on CANN.
        case GGML_OP_MUL_MAT_ID:
            return false;
        // embedding
        case GGML_OP_GET_ROWS: {
            switch (op->src[0]->type) {
                case GGML_TYPE_F32:
                case GGML_TYPE_F16:
                case GGML_TYPE_Q4_0:
                case GGML_TYPE_Q8_0:
                    return true;
                default:
                    return false;
            }
        } break;
        case GGML_OP_CPY: {
            // Copy support depends on the destination type.
            switch (op->type) {
                case GGML_TYPE_F32:
                case GGML_TYPE_F16:
                case GGML_TYPE_Q8_0:
                    return true;
                default:
                    return false;
            }
        }
        // Ops supported for all types handled by the backend.
        case GGML_OP_DUP:
        case GGML_OP_REPEAT:
        case GGML_OP_CONCAT:
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_NORM:
        case GGML_OP_ADD:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_RMS_NORM:
        case GGML_OP_SCALE:
        case GGML_OP_SQR:
        case GGML_OP_CLAMP:
        case GGML_OP_CONT:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_ROPE:
        case GGML_OP_IM2COL:
        case GGML_OP_POOL_2D:
        case GGML_OP_SUM_ROWS:
        case GGML_OP_ARGSORT:
        case GGML_OP_ACC:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_UPSCALE:
        case GGML_OP_PAD:
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_LEAKY_RELU:
            return true;
        default:
            return false;
    }

    GGML_UNUSED(backend);
}
1740
+
1741
+ /**
1742
+ * @brief Checks if the backend buffer type is associated with the CANN backend.
1743
+ *
1744
+ * This function checks whether the provided backend buffer type is associated
1745
+ * with the CANN backend based on the comparison of its name retrieval function
1746
+ * pointer.
1747
+ *
1748
+ * @param buft Pointer to the backend buffer type to check.
1749
+ * @return bool Returns true if the buffer type is associated with the CANN
1750
+ * backend, otherwise false.
1751
+ */
1752
+ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
1753
+ return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
1754
+ }
1755
+
1756
+ /**
1757
+ * @brief Checks if the CANN backend supports a specific backend buffer type.
1758
+ *
1759
+ * This function determines whether the CANN backend supports the given backend
1760
+ * buffer type by comparing the device context of the backend and buffer type.
1761
+ * It returns true if the devices are same between the backend context and
1762
+ * buffer type context.
1763
+ *
1764
+ * @param backend Pointer to the CANN backend.
1765
+ * @param buft Pointer to the backend buffer type to check.
1766
+ * @return bool Returns true if the CANN backend supports the buffer type,
1767
+ * otherwise false.
1768
+ */
1769
+ GGML_CALL static bool ggml_backend_cann_supports_buft(
1770
+ ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
1771
+ if (ggml_backend_buft_is_cann(buft)) {
1772
+ ggml_backend_cann_context * cann_ctx =
1773
+ (ggml_backend_cann_context *)backend->context;
1774
+ ggml_backend_cann_buffer_type_context * buft_ctx =
1775
+ (ggml_backend_cann_buffer_type_context *)buft->context;
1776
+ return buft_ctx->device == cann_ctx->device;
1777
+ }
1778
+ return false;
1779
+ }
1780
+
1781
+ /**
1782
+ * @brief Determines if a tensor operation should be offloaded to the CANN
1783
+ * backend.
1784
+ *
1785
+ * This function checks if a given tensor operation should be offloaded to the
1786
+ * CANN backend based on the operation type and the size of the tensor. It
1787
+ * returns true if the second dimension (ne[1]) of the tensor is greater than or
1788
+ * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
1789
+ *
1790
+ * @param backend Pointer to the CANN backend.
1791
+ * @param op Pointer to the tensor operation to check.
1792
+ * @return bool Returns true if the operation should be offloaded, otherwise
1793
+ * false.
1794
+ */
1795
+ GGML_CALL static bool ggml_backend_cann_offload_op(ggml_backend_t backend,
1796
+ const ggml_tensor* op) {
1797
+ const int min_batch_size = 32;
1798
+ GGML_UNUSED(backend);
1799
+
1800
+ return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
1801
+ }
1802
+
1803
+ /**
1804
+ * @brief Creates a new event for the CANN backend.
1805
+ *
1806
+ * This function initializes a new event for the CANN backend by setting the
1807
+ * device and creating an ACL runtime event. The created event is then wrapped
1808
+ * in a ggml_backend_event structure and returned.
1809
+ *
1810
+ * @param backend Pointer to the CANN backend.
1811
+ * @return ggml_backend_event_t Returns a pointer to the new event structure.
1812
+ */
1813
+ static ggml_backend_event_t ggml_backend_cann_event_new(
1814
+ ggml_backend_t backend) {
1815
+ ggml_backend_cann_context* cann_ctx =
1816
+ (ggml_backend_cann_context*)backend->context;
1817
+
1818
+ ggml_cann_set_device(cann_ctx->device);
1819
+
1820
+ aclrtEvent event;
1821
+ ACL_CHECK(aclrtCreateEvent(&event));
1822
+
1823
+ return new ggml_backend_event{
1824
+ /* .backend = */ backend,
1825
+ /* .context = */ event,
1826
+ };
1827
+ }
1828
+
1829
+ /**
1830
+ * @brief Frees a CANN backend event.
1831
+ *
1832
+ * This function destroys the ACL runtime event associated with the given CANN
1833
+ * backend event and then deletes the event structure itself.
1834
+ *
1835
+ * @param event Pointer to the event structure to be freed.
1836
+ */
1837
+ static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
1838
+ ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
1839
+
1840
+ delete event;
1841
+ }
1842
+
1843
+ /**
1844
+ * @brief Records an event on the CANN backend stream.
1845
+ *
1846
+ * This function records the given event on the ACL runtime stream associated
1847
+ * with the backend context.
1848
+ *
1849
+ * @param event Pointer to the event structure to be recorded.
1850
+ */
1851
+ static void ggml_backend_cann_event_record(ggml_backend_event_t event) {
1852
+ ggml_backend_cann_context* cann_ctx =
1853
+ (ggml_backend_cann_context*)event->backend->context;
1854
+
1855
+ ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
1856
+ }
1857
+
1858
+ /**
1859
+ * @brief Waits for a recorded event to complete on the CANN backend stream.
1860
+ *
1861
+ * This function makes the given backend wait for the event to complete on its
1862
+ * ACL runtime stream.
1863
+ *
1864
+ * @param backend Pointer to the backend structure.
1865
+ * @param event Pointer to the event structure that the backend needs to wait
1866
+ * for.
1867
+ */
1868
+ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
1869
+ ggml_backend_event_t event) {
1870
+ ggml_backend_cann_context* cann_ctx =
1871
+ (ggml_backend_cann_context*)backend->context;
1872
+
1873
+ if (ggml_backend_is_cann(event->backend)) {
1874
+ ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
1875
+ (aclrtEvent)event->context));
1876
+ } else {
1877
+ GGML_ABORT("fatal error");
1878
+ }
1879
+ }
1880
+
1881
+ /**
1882
+ * @brief Synchronizes the given event on the CANN backend.
1883
+ *
1884
+ * This function waits for the specified event to complete on the ACL runtime.
1885
+ *
1886
+ * @param event Pointer to the event structure to be synchronized.
1887
+ */
1888
+ static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
1889
+ ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
1890
+ }
1891
+
1892
/**
 * @brief Structure defining the interface for the CANN backend.
 *
 * This structure contains function pointers for various operations
 * supported by the CANN backend, including name retrieval, memory
 * management, tensor operations, synchronization, and event handling.
 * Graph-plan entries are NULL: the backend computes graphs eagerly.
 */
static ggml_backend_i ggml_backend_cann_interface = {
    /* .get_name                = */ ggml_backend_cann_name,
    /* .free                    = */ ggml_backend_cann_free,
    /* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type,
    /* .set_tensor_async        = */ ggml_backend_cann_set_tensor_async,
    /* .get_tensor_async        = */ ggml_backend_cann_get_tensor_async,
    /* .cpy_tensor_async        = */ ggml_backend_cann_cpy_tensor_async,
    /* .synchronize             = */ ggml_backend_cann_synchronize,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_cann_graph_compute,
    /* .supports_op             = */ ggml_backend_cann_supports_op,
    /* .supports_buft           = */ ggml_backend_cann_supports_buft,
    /* .offload_op              = */ ggml_backend_cann_offload_op,
    /* .event_new               = */ ggml_backend_cann_event_new,
    /* .event_free              = */ ggml_backend_cann_event_free,
    /* .event_record            = */ ggml_backend_cann_event_record,
    /* .event_wait              = */ ggml_backend_cann_event_wait,
    /* .event_synchronize       = */ ggml_backend_cann_event_synchronize,
};
1921
+
1922
/**
 * @brief Return the hardcoded GUID for the CANN backend.
 *
 * This function returns a static GUID which uniquely identifies the CANN
 * backend. The bytes are fixed and must not change across versions, since
 * they are used by ggml_guid_matches() to recognize CANN backends.
 *
 * @return A pointer to the static GUID.
 */
static ggml_guid_t ggml_backend_cann_guid() {
    static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
                             0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
    return &guid;
}
1935
+
1936
/**
 * @brief Creates and initializes a CANN backend for the given device.
 *
 * Initializes the ACL runtime, validates the device index, allocates a
 * per-device context and wraps it in a ggml_backend object.
 *
 * @param device Device index to bind the backend to.
 * @return The initialized backend, or nullptr for an invalid device.
 */
GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) {
    // NOTE(review): aclInit() return value is unchecked and this runs on
    // every init call — confirm repeated initialization is harmless.
    aclInit(nullptr);
    if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
        GGML_CANN_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
        return nullptr;
    }

    ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
    // NOTE(review): plain `new` throws std::bad_alloc rather than returning
    // nullptr, so this branch is effectively dead unless built with
    // -fno-exceptions.
    if (ctx == nullptr) {
        GGML_CANN_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
        return nullptr;
    }

    ggml_backend_t cann_backend =
        new ggml_backend{/* .guid      = */ ggml_backend_cann_guid(),
                         /* .interface = */ ggml_backend_cann_interface,
                         /* .context   = */ ctx};

    return cann_backend;
}
1956
+
1957
+ GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend) {
1958
+ return backend != NULL &&
1959
+ ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
1960
+ }
1961
+
1962
/**
 * @brief Returns the number of CANN devices available on this system.
 *
 * @return Device count as reported by the cached ggml_cann_info().
 */
GGML_CALL int32_t ggml_backend_cann_get_device_count() {
    return ggml_cann_info().device_count;
}
1965
+
1966
/**
 * @brief Writes a human-readable description of a CANN device.
 *
 * Selects the device and copies its SoC name into the caller's buffer,
 * truncating to description_size (always NUL-terminated by snprintf).
 *
 * @param device           Device index to describe.
 * @param description      Destination buffer for the description string.
 * @param description_size Size of the destination buffer in bytes.
 */
GGML_CALL void ggml_backend_cann_get_device_description(
    int32_t device, char* description, size_t description_size) {
    ggml_cann_set_device(device);
    const char* soc_name = aclrtGetSocName();
    snprintf(description, description_size, "%s", soc_name);
}
1972
+
1973
/**
 * @brief Queries free and total HBM memory of a CANN device.
 *
 * @param device Device index to query.
 * @param free   Out: free device memory in bytes (must be non-NULL).
 * @param total  Out: total device memory in bytes (must be non-NULL).
 */
GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
                                                   size_t* total) {
    ggml_cann_set_device(device);
    ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
}
1978
+
1979
+ // backend registry
1980
+ /**
1981
+ * @brief Initializes a CANN backend based on the provided parameters.
1982
+ *
1983
+ * This function initializes a CANN backend using the device index and then
1984
+ * initializes the backend using `ggml_backend_cann_init`.
1985
+ *
1986
+ * @param params Parameters for initialization (unused in this implementation).
1987
+ * @param user_data User data containing the device index to initialize the
1988
+ * backend.
1989
+ * @return ggml_backend_t The initialized CANN backend.
1990
+ */
1991
+ GGML_CALL static ggml_backend_t ggml_backend_reg_cann_init(const char* params,
1992
+ void* user_data) {
1993
+ ggml_backend_t cann_backend =
1994
+ ggml_backend_cann_init((int)(intptr_t)user_data);
1995
+ return cann_backend;
1996
+
1997
+ GGML_UNUSED(params);
1998
+ }
1999
+
2000
+ extern "C" GGML_CALL int ggml_backend_cann_reg_devices();
2001
+
2002
+ /**
2003
+ * @brief Registers CANN (Ascend) devices as backend options.
2004
+ *
2005
+ * This function initializes ACL, retrieves the number of available CANN
2006
+ * devices, and registers each device as a backend option using
2007
+ * `ggml_backend_register`. Each device is given a unique name based on
2008
+ * `GGML_CANN_NAME` followed by its index.
2009
+ *
2010
+ * @return int The number of CANN devices registered.
2011
+ */
2012
+ GGML_CALL int ggml_backend_cann_reg_devices() {
2013
+ uint32_t device_count = ggml_backend_cann_get_device_count();
2014
+ // initialization
2015
+ for (uint32_t i = 0; i < device_count; i++) {
2016
+ char name[128];
2017
+ snprintf(name, sizeof(name), "CANN%d", i);
2018
+ ggml_backend_register(name, ggml_backend_reg_cann_init,
2019
+ ggml_backend_cann_buffer_type(i),
2020
+ (void*)(intptr_t)i);
2021
+ }
2022
+ return device_count;
2023
+ }