@fugood/llama.node 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286)
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -22,11 +22,14 @@
22
22
 
23
23
  #include "aclnn_ops.h"
24
24
 
25
+ #include <aclnnop/aclnn_addcdiv.h>
25
26
  #include <aclnnop/aclnn_avgpool2d.h>
27
+ #include <aclnnop/aclnn_batch_matmul.h>
26
28
  #include <aclnnop/aclnn_cast.h>
27
29
  #include <aclnnop/aclnn_constant_pad_nd.h>
28
30
  #include <aclnnop/aclnn_copy.h>
29
31
  #include <aclnnop/aclnn_cos.h>
32
+ #include <aclnnop/aclnn_div.h>
30
33
  #include <aclnnop/aclnn_exp.h>
31
34
  #include <aclnnop/aclnn_fill_scalar.h>
32
35
  #include <aclnnop/aclnn_group_norm.h>
@@ -34,6 +37,7 @@
34
37
  #include <aclnnop/aclnn_layer_norm.h>
35
38
  #include <aclnnop/aclnn_matmul.h>
36
39
  #include <aclnnop/aclnn_max_pool.h>
40
+ #include <aclnnop/aclnn_mm.h>
37
41
  #include <aclnnop/aclnn_permute.h>
38
42
  #include <aclnnop/aclnn_pow_tensor_tensor.h>
39
43
  #include <aclnnop/aclnn_reduce_sum.h>
@@ -53,6 +57,7 @@
53
57
  #include <exception>
54
58
  #include <vector>
55
59
 
60
+ #include "ggml-impl.h"
56
61
  #include "kernels/ascendc_kernels.h"
57
62
 
58
63
  #define GGML_COMMON_DECL_C
@@ -241,10 +246,14 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
241
246
  aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
242
247
  aclTensor* acl_dst = ggml_cann_create_tensor(dst);
243
248
 
244
- int64_t concat_dim = 1;
249
+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
250
+
251
+ GGML_ASSERT(dim >= 0 && dim < 4);
252
+ int32_t acl_dim = 3 - dim;
253
+
245
254
  aclTensor* tensors[] = {acl_src0, acl_src1};
246
255
  aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
247
- aclnn_concat(ctx, tensorList, acl_dst, concat_dim);
256
+ aclnn_concat(ctx, tensorList, acl_dst, acl_dim);
248
257
 
249
258
  ACL_CHECK(aclDestroyTensorList(tensorList));
250
259
  ACL_CHECK(aclDestroyTensor(acl_dst));
@@ -1096,9 +1105,9 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
1096
1105
  }
1097
1106
 
1098
1107
  /**
1099
- * @brief Creates an ACL tensor initialized with ones using a provided buffer.
1108
+ * @brief Creates an ACL tensor initialized with value using a provided buffer.
1100
1109
  *
1101
- * This function initializes a tensor with ones using the specified buffer and
1110
+ * This function initializes a tensor with value using the specified buffer and
1102
1111
  * tensor parameters.
1103
1112
  *
1104
1113
  * @param ctx The context for the CANN backend operations.
@@ -1111,12 +1120,12 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
1111
1120
  * @param type_size The size of each element in the tensor data type.
1112
1121
  * @param value The value to be used for initializing the tensor (default
1113
1122
  * is 1.0).
1114
- * @return An ACL tensor initialized with ones.
1123
+ * @return An ACL tensor initialized with value.
1115
1124
  */
1116
- static aclTensor* aclnn_ones(ggml_backend_cann_context& ctx, void* buffer,
1117
- size_t n_bytes, int64_t* ne, int64_t dims,
1118
- aclDataType type, size_t type_size,
1119
- float value = 1.0f) {
1125
+ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer,
1126
+ size_t n_bytes, int64_t* ne, int64_t dims,
1127
+ aclDataType type, size_t type_size,
1128
+ float value = 1.0f) {
1120
1129
  aclTensor* acl_tensor =
1121
1130
  aclnn_zero(ctx, buffer, n_bytes, ne, dims, type, type_size);
1122
1131
  float alpha_host = 1.0f;
@@ -1158,7 +1167,7 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1158
1167
  size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
1159
1168
  ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
1160
1169
 
1161
- aclTensor* acl_gamma = aclnn_ones(
1170
+ aclTensor* acl_gamma = aclnn_values(
1162
1171
  ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1,
1163
1172
  ggml_cann_type_mapping(src->type), ggml_element_size(src));
1164
1173
 
@@ -1202,9 +1211,9 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
1202
1211
  ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
1203
1212
 
1204
1213
  aclTensor* mask_tensor =
1205
- aclnn_ones(ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne,
1206
- GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
1207
- ggml_element_size(src), value);
1214
+ aclnn_values(ctx, one_tensor_allocator.get(), one_tensor_n_bytes,
1215
+ src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
1216
+ ggml_element_size(src), value);
1208
1217
 
1209
1218
  uint64_t workspaceSize = 0;
1210
1219
  aclOpExecutor* executor;
@@ -1437,10 +1446,6 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1437
1446
  ggml_tensor* src0 = dst->src[0]; // kernel
1438
1447
  ggml_tensor* src1 = dst->src[1]; // input
1439
1448
 
1440
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
1441
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
1442
- GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
1443
-
1444
1449
  GGML_TENSOR_BINARY_OP_LOCALS;
1445
1450
 
1446
1451
  // aclnnIm2col only works on 2D. set s1, p1, d1 to 1 to perform 2D
@@ -1462,9 +1467,6 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
1462
1467
  const int64_t OH = is_2D ? ne2 : 1;
1463
1468
  const int64_t OW = ne1;
1464
1469
 
1465
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
1466
- GGML_ASSERT(nb10 == sizeof(float));
1467
-
1468
1470
  // memory allocated increased to 3x when is_2D == false
1469
1471
  const int64_t n_bytes_factor = is_2D ? 1 : 3;
1470
1472
 
@@ -1768,6 +1770,92 @@ static void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
1768
1770
  ACL_CHECK(aclnnSin(workspaceAddr, workspaceSize, executor, ctx.stream()));
1769
1771
  }
1770
1772
 
1773
+ /**
1774
+ * @brief Performs element-wise division of tensor1 by tensor2, multiplies the
1775
+ result by the scalar value and adds it to self.
1776
+ *
1777
+ * Performs element-wise division of tensor1 by tensor2,
1778
+ * multiplies the result by the scalar value and adds it to self.
1779
+ * The operation is defined as:
1780
+ * \f[
1781
+ * \text{out}_i = \text{self}_i + \text{value} \times
1782
+ \frac{\text{tensor1}_i}{\text{tensor2}_i}
1783
+ * \f]
1784
+
1785
+ * @param ctx The context for the CANN backend operations.
1786
+ * @param acl_self The source tensor on which the addcdiv function will be
1787
+ applied.
1788
+ * @param tensor1 Numerator tensor.
1789
+ * @param tensor2 Denominator tensor.
1790
+ * @param value The value to be used for coefficient.
1791
+ */
1792
+ static void aclnn_inplace_addcdiv(ggml_backend_cann_context& ctx,
1793
+ aclTensor* acl_self, aclTensor* tensor1,
1794
+ aclTensor* tensor2, float value) {
1795
+ uint64_t workspaceSize = 0;
1796
+ aclOpExecutor* executor;
1797
+ void* workspaceAddr = nullptr;
1798
+ aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
1799
+
1800
+ ACL_CHECK(aclnnInplaceAddcdivGetWorkspaceSize(
1801
+ acl_self, tensor1, tensor2, acl_value, &workspaceSize, &executor));
1802
+ if (workspaceSize > 0) {
1803
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
1804
+ workspaceAddr = workspace_allocator.get();
1805
+ }
1806
+
1807
+ ACL_CHECK(aclnnInplaceAddcdiv(workspaceAddr, workspaceSize, executor,
1808
+ ctx.stream()));
1809
+ }
1810
+
1811
+ /**
1812
+ * @brief Matrix division, optionally in-place.
1813
+ *
1814
+ * This function divides each element of the source tensor `acl_src` by the
1815
+ * tensor `acl_other` and stores the result in the destination tensor `acl_dst`.
1816
+ * If `inplace` is true, `acl_dst` will not be used and the operation is
1817
+ * performed in-place on `acl_src`. The operation is defined as: \f[
1818
+ * \text{dst}_i = \frac{\text{acl_src}_i}{\text{acl_other}_i}
1819
+ * \f]
1820
+ *
1821
+ * @param ctx The context for the CANN backend operations.
1822
+ * @param acl_src Numerator tensor.
1823
+ * @param acl_other Denominator tensor.
1824
+ * @param acl_dst The destination tensor where the result will be stored if
1825
+ * `inplace` is false.
1826
+ * @param inplace Flag indicating whether to perform the operation in-place on
1827
+ * `acl_src`.
1828
+ */
1829
+ static void aclnn_div_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_src,
1830
+ aclTensor* acl_other, aclTensor* acl_dst,
1831
+ bool inplace) {
1832
+ uint64_t workspaceSize = 0;
1833
+ aclOpExecutor* executor;
1834
+ void* workspaceAddr = nullptr;
1835
+
1836
+ if (inplace) {
1837
+ ACL_CHECK(aclnnInplaceDivGetWorkspaceSize(acl_src, acl_other,
1838
+ &workspaceSize, &executor));
1839
+ if (workspaceSize > 0) {
1840
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
1841
+ workspaceAddr = workspace_allocator.get();
1842
+ }
1843
+
1844
+ ACL_CHECK(aclnnInplaceDiv(workspaceAddr, workspaceSize, executor,
1845
+ ctx.stream()));
1846
+ } else {
1847
+ ACL_CHECK(aclnnDivGetWorkspaceSize(acl_src, acl_other, acl_dst,
1848
+ &workspaceSize, &executor));
1849
+ if (workspaceSize > 0) {
1850
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
1851
+ workspaceAddr = workspace_allocator.get();
1852
+ }
1853
+
1854
+ ACL_CHECK(
1855
+ aclnnDiv(workspaceAddr, workspaceSize, executor, ctx.stream()));
1856
+ }
1857
+ }
1858
+
1771
1859
  void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
1772
1860
  ggml_tensor* dst) {
1773
1861
  const ggml_tensor* src = dst->src[0];
@@ -2311,7 +2399,16 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2311
2399
  ctx.stream()));
2312
2400
 
2313
2401
  switch (src0->type) {
2314
- case GGML_TYPE_F32:
2402
+ case GGML_TYPE_F32: {
2403
+ #ifdef ASCEND_310P
2404
+ // Special operation for get_row_f32 kernel of 310P: clear the
2405
+ // content of dest data buffer when row is not aligned to 32 bytes
2406
+ if ((src0->ne[0] % 8) != 0) {
2407
+ size_t dst_len = src1->ne[0] * src1->ne[1] * src1->ne[2] *
2408
+ src0->ne[0] * ggml_type_size(GGML_TYPE_F32);
2409
+ ACL_CHECK(aclrtMemset((char*)dst->data, dst_len, 0, dst_len));
2410
+ }
2411
+ #endif
2315
2412
  aclrtlaunch_ascendc_get_row_f32(
2316
2413
  24, ctx.stream(), src0->data, src1->data, dst->data,
2317
2414
  ((ggml_tensor*)src0->extra)->ne,
@@ -2320,7 +2417,19 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2320
2417
  ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
2321
2418
  ((ggml_tensor*)dst->extra)->nb);
2322
2419
  break;
2323
- case GGML_TYPE_F16:
2420
+ }
2421
+ case GGML_TYPE_F16: {
2422
+ #ifdef ASCEND_310P
2423
+ // Special operation for get_row_f16 kernel of 310P: clear the
2424
+ // content of dest data buffer when row is not aligned to 32 bytes
2425
+ if ((src0->ne[0] % 16) != 0) {
2426
+ size_t dst_len =
2427
+ src1->ne[0] * src1->ne[1] * src1->ne[2] * src0->ne[0] *
2428
+ ggml_type_size(
2429
+ GGML_TYPE_F32); // out is also f32, even input is f16
2430
+ ACL_CHECK(aclrtMemset((char*)dst->data, dst_len, 0, dst_len));
2431
+ }
2432
+ #endif
2324
2433
  aclrtlaunch_ascendc_get_row_f16(
2325
2434
  24, ctx.stream(), src0->data, src1->data, dst->data,
2326
2435
  ((ggml_tensor*)src0->extra)->ne,
@@ -2329,6 +2438,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2329
2438
  ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
2330
2439
  ((ggml_tensor*)dst->extra)->nb);
2331
2440
  break;
2441
+ }
2332
2442
  case GGML_TYPE_Q4_0:
2333
2443
  aclrtlaunch_ascendc_get_row_q4_0(
2334
2444
  24, ctx.stream(), src0->data, src1->data, dst->data,
@@ -2407,7 +2517,6 @@ static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
2407
2517
  aclTensor* acl_weight, aclTensor* acl_dst) {
2408
2518
  int8_t cube_math_type = 1; // ALLOW_FP32_DOWN_PRECISION, when input is
2409
2519
  // fp32, atlas a2 will transpose it to HFLOAT32.
2410
-
2411
2520
  uint64_t workspaceSize = 0;
2412
2521
  aclOpExecutor* executor;
2413
2522
  void* workspaceAddr = nullptr;
@@ -2425,6 +2534,81 @@ static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
2425
2534
  aclnnMatmul(workspaceAddr, workspaceSize, executor, ctx.stream()));
2426
2535
  }
2427
2536
 
2537
+ /**
2538
+ * @brief Performs matrix multiplication of two 2D tensors.
2539
+ *
2540
+ * This function computes the matrix multiplication of the input tensor
2541
+ * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
2542
+ * destination tensor `acl_dst`.
2543
+ * The operation is defined as:
2544
+ * \f[
2545
+ * \text {acl_dst}=\text {acl_input@acl_weight}
2546
+ * \f]
2547
+ *
2548
+ * @param ctx The context for the CANN backend operations.
2549
+ * @param acl_input The input tensor for the matrix multiplication.
2550
+ * @param acl_weight The weight tensor for the matrix multiplication.
2551
+ * @param acl_dst The destination tensor where the result of the matrix
2552
+ * multiplication will be stored.
2553
+ */
2554
+ static void aclnn_mat_mul_2d(ggml_backend_cann_context& ctx,
2555
+ aclTensor* acl_input, aclTensor* acl_weight,
2556
+ aclTensor* acl_dst) {
2557
+ int8_t cube_math_type = 2;
2558
+ uint64_t workspaceSize = 0;
2559
+ aclOpExecutor* executor;
2560
+ void* workspaceAddr = nullptr;
2561
+
2562
+ ACL_CHECK(aclnnMmGetWorkspaceSize(acl_input, acl_weight, acl_dst,
2563
+ cube_math_type, &workspaceSize,
2564
+ &executor));
2565
+
2566
+ if (workspaceSize > 0) {
2567
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
2568
+ workspaceAddr = workspace_allocator.get();
2569
+ }
2570
+
2571
+ ACL_CHECK(aclnnMm(workspaceAddr, workspaceSize, executor, ctx.stream()));
2572
+ }
2573
+
2574
+ /**
2575
+ * @brief Performs matrix multiplication of two 3D tensors.
2576
+ *
2577
+ * This function computes the matrix multiplication of the input tensor
2578
+ * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
2579
+ * destination tensor `acl_dst`.
2580
+ * The operation is defined as:
2581
+ * \f[
2582
+ * \text {acl_dst}=\text {acl_input@acl_weight}
2583
+ * \f]
2584
+ *
2585
+ * @param ctx The context for the CANN backend operations.
2586
+ * @param acl_input The input tensor for the matrix multiplication.
2587
+ * @param acl_weight The weight tensor for the matrix multiplication.
2588
+ * @param acl_dst The destination tensor where the result of the matrix
2589
+ * multiplication will be stored.
2590
+ */
2591
+ static void aclnn_mat_mul_3d(ggml_backend_cann_context& ctx,
2592
+ aclTensor* acl_input, aclTensor* acl_weight,
2593
+ aclTensor* acl_dst) {
2594
+ int8_t cube_math_type = 2;
2595
+ uint64_t workspaceSize = 0;
2596
+ aclOpExecutor* executor;
2597
+ void* workspaceAddr = nullptr;
2598
+
2599
+ ACL_CHECK(aclnnBatchMatMulGetWorkspaceSize(acl_input, acl_weight, acl_dst,
2600
+ cube_math_type, &workspaceSize,
2601
+ &executor));
2602
+
2603
+ if (workspaceSize > 0) {
2604
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
2605
+ workspaceAddr = workspace_allocator.get();
2606
+ }
2607
+
2608
+ ACL_CHECK(
2609
+ aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
2610
+ }
2611
+
2428
2612
  /**
2429
2613
  * @brief Performs matrix multiplication with floating-point precision on
2430
2614
  * tensors using the CANN backend.
@@ -2446,20 +2630,39 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
2446
2630
  // broadcast, when weight ne2 or ne3 is not 1, weight need repeat.
2447
2631
  BCAST_MUL_MAT_SHAPE(input, weight, dst);
2448
2632
 
2449
- // transpose weight: [1,2,3,4] -> [1,2,4,3]
2633
+ int64_t n_dims = bcast_dims;
2634
+ if (bcast_input_ne[3] == bcast_weight_ne[3] && bcast_input_ne[3] == 1) {
2635
+ if (bcast_input_ne[2] == 1 && bcast_weight_ne[2] == 1) {
2636
+ n_dims = 2;
2637
+ } else if (bcast_input_ne[2] == 1) {
2638
+ n_dims = 3;
2639
+ }
2640
+ }
2641
+
2642
+ aclTensor* acl_input_tensor =
2643
+ ggml_cann_create_tensor(input, bcast_input_ne, bcast_input_nb, n_dims);
2450
2644
  int64_t transpose_ne[] = {bcast_weight_ne[1], bcast_weight_ne[0],
2451
2645
  bcast_weight_ne[2], bcast_weight_ne[3],
2452
2646
  bcast_weight_ne[4], bcast_weight_ne[5]};
2453
2647
  size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0],
2454
2648
  bcast_weight_nb[2], bcast_weight_nb[3],
2455
2649
  bcast_weight_nb[4], bcast_weight_nb[5]};
2456
-
2457
2650
  aclTensor* acl_weight_tensor =
2458
- ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, bcast_dims);
2459
- aclTensor* acl_input_tensor =
2460
- ggml_cann_create_tensor(input, BCAST_MUL_MAT_PARAM(input));
2461
- aclTensor* acl_dst = ggml_cann_create_tensor(dst, BCAST_MUL_MAT_PARAM(dst));
2462
- aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
2651
+ ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims);
2652
+ aclTensor* acl_dst =
2653
+ ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims);
2654
+
2655
+ switch (n_dims) {
2656
+ case 2:
2657
+ aclnn_mat_mul_2d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
2658
+ break;
2659
+ case 3:
2660
+ aclnn_mat_mul_3d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
2661
+ break;
2662
+ default:
2663
+ aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
2664
+ break;
2665
+ }
2463
2666
 
2464
2667
  ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
2465
2668
  ACL_CHECK(aclDestroyTensor(acl_input_tensor));
@@ -2480,51 +2683,47 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
2480
2683
  * multiplication will be stored.
2481
2684
  */
2482
2685
  static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
2483
- ggml_tensor* dst,
2484
- const enum ggml_type type) {
2686
+ ggml_tensor* dst,
2687
+ const enum ggml_type type) {
2485
2688
  ggml_tensor* src0 = dst->src[0]; // weight
2486
2689
  ggml_tensor* src1 = dst->src[1]; // input
2487
2690
 
2488
- // The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
2489
- // is regarded as batch. weight need transpose.
2490
- int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
2691
+ // The shape of the weight is NCHW.
2692
+ // Matrix multiplication uses HW dims.
2693
+ // HC is regarded as batch.
2694
+ // weight need transpose.
2491
2695
  float weight_elem_size;
2492
2696
  if (type == GGML_TYPE_Q4_0) {
2493
2697
  weight_elem_size = float(sizeof(uint8_t)) / 2;
2494
- }
2495
- else if (type == GGML_TYPE_Q8_0) {
2698
+ } else if (type == GGML_TYPE_Q8_0) {
2496
2699
  weight_elem_size = float(sizeof(uint8_t));
2497
- }
2498
- else {
2700
+ } else {
2499
2701
  GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT");
2500
2702
  }
2501
- float weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
2502
-
2503
- // size of one matrix is element_size * height * width.
2504
- size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
2703
+ float weight_nb[] = {src0->ne[0] * weight_elem_size, weight_elem_size};
2704
+ size_t weight_stride = src0->ne[1] * src0->ne[0] * weight_elem_size;
2505
2705
  size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
2506
2706
 
2507
2707
  // scale stored at the end of weight. Also need transpose.
2508
- GGML_ASSERT(QK4_0 == QK8_0);
2509
- int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
2510
2708
  size_t scale_elem_size = sizeof(uint16_t);
2511
2709
  size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
2512
2710
  scale_elem_size};
2513
- size_t scale_stride = scale_elem_size * src0->ne[0] * src0->ne[1] / QK8_0;
2711
+ size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
2514
2712
  char* scale_offset = (char*)src0->data + weight_size;
2515
2713
 
2516
2714
  // input
2517
- void* input_buffer;
2518
2715
  size_t input_elem_size = sizeof(uint16_t);
2519
2716
  int64_t input_ne[] = {src1->ne[0], src1->ne[1]};
2520
- size_t input_nb[] = {input_elem_size, input_elem_size * src1->ne[0]};
2521
- size_t input_stride = input_elem_size * src1->ne[0] * src1->ne[1];
2522
-
2717
+ size_t input_nb[] = {input_elem_size, input_ne[0] * input_elem_size};
2718
+ size_t input_stride = input_ne[0] * input_ne[1] * input_elem_size;
2523
2719
  ggml_cann_pool_alloc input_alloctor(ctx.pool());
2720
+ void* input_buffer = src1->data;
2721
+
2722
+ // case in
2524
2723
  if (src1->type != GGML_TYPE_F16) {
2525
2724
  aclTensor* acl_src1_tensor = ggml_cann_create_tensor(src1);
2526
- input_alloctor.alloc(ggml_nelements(src1) * input_elem_size);
2527
- input_buffer = input_alloctor.get();
2725
+ input_buffer =
2726
+ input_alloctor.alloc(ggml_nelements(src1) * input_elem_size);
2528
2727
 
2529
2728
  int64_t* input_cast_ne = src1->ne;
2530
2729
  size_t input_cast_nb[GGML_MAX_DIMS];
@@ -2537,85 +2736,136 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
2537
2736
  input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
2538
2737
  input_cast_nb, GGML_MAX_DIMS);
2539
2738
  aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
2739
+
2540
2740
  ACL_CHECK(aclDestroyTensor(acl_input_tensor));
2541
2741
  ACL_CHECK(aclDestroyTensor(acl_src1_tensor));
2542
- } else {
2543
- input_buffer = src1->data;
2544
2742
  }
2545
2743
 
2546
2744
  // output
2547
2745
  size_t output_elem_size = sizeof(uint16_t);
2548
- int64_t output_ne[] = {dst->ne[0], dst->ne[1]};
2549
- size_t output_nb[] = {output_elem_size, output_elem_size * dst->ne[0]};
2550
- ggml_cann_pool_alloc output_alloctor(
2551
- ctx.pool(), ggml_nelements(dst) * output_elem_size);
2552
- void* output_buffer = output_alloctor.get();
2553
- size_t output_stride = output_elem_size * dst->ne[0] * dst->ne[1];
2746
+ size_t output_nb[] = {output_elem_size, dst->ne[0] * output_elem_size};
2747
+ ggml_cann_pool_alloc output_allocator(ctx.pool());
2748
+ void* output_buffer =
2749
+ output_allocator.alloc(ggml_nelements(dst) * output_elem_size);
2750
+ size_t output_stride = dst->ne[0] * dst->ne[1] * output_elem_size;
2554
2751
 
2555
2752
  // aclnn
2753
+ int64_t max_elem_size = 65535;
2754
+ int64_t split_size = (src0->ne[1] / max_elem_size) + 1;
2755
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool());
2756
+ aclOpExecutor* executor = nullptr;
2556
2757
  uint64_t workspaceSize = 0;
2557
- aclOpExecutor* executor;
2558
2758
  void* workspaceAddr = nullptr;
2559
-
2560
2759
  for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) {
2561
2760
  for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) {
2562
2761
  int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]);
2563
2762
  int64_t c0 = c1 / (src1->ne[2] / src0->ne[2]);
2564
2763
 
2565
- int64_t batch1 = n1 * src1->ne[2] + c1;
2566
- int64_t batch0 = n0 * src0->ne[2] + c0;
2764
+ int64_t batch1 = (n1 * src1->ne[2]) + c1;
2765
+ int64_t batch0 = (n0 * src0->ne[2]) + c0;
2567
2766
 
2568
2767
  aclTensor* acl_input_tensor = ggml_cann_create_tensor(
2569
2768
  (char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
2570
2769
  input_elem_size, input_ne, input_nb, 2);
2770
+
2771
+ // first split
2772
+ int64_t weight_ne_offset = 0;
2773
+ int64_t weight_ne[2] = {
2774
+ max_elem_size > src0->ne[1] ? src0->ne[1] : max_elem_size,
2775
+ src0->ne[0]};
2776
+ int64_t scale_ne_offset = 0;
2777
+ int64_t scale_ne[2] = {weight_ne[0], weight_ne[1] / QK8_0};
2778
+ int64_t output_ne_offset = 0;
2779
+ int64_t output_ne[2] = {weight_ne[0], dst->ne[1]};
2780
+
2571
2781
  aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
2572
2782
  (char*)src0->data + batch0 * weight_stride,
2573
2783
  ggml_cann_type_mapping(type), weight_elem_size, weight_ne,
2574
- weight_nb, 2);
2784
+ weight_nb, 2, ACL_FORMAT_ND, weight_ne_offset);
2575
2785
  aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
2576
2786
  scale_offset + batch0 * scale_stride, ACL_FLOAT16,
2577
- scale_elem_size, scale_ne, scale_nb, 2);
2787
+ scale_elem_size, scale_ne, scale_nb, 2, ACL_FORMAT_ND,
2788
+ scale_ne_offset);
2578
2789
  aclTensor* acl_output_tensor = ggml_cann_create_tensor(
2579
2790
  (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
2580
- output_elem_size, output_ne, output_nb, 2);
2791
+ output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
2792
+ output_ne_offset);
2581
2793
 
2582
2794
  ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
2583
2795
  acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
2584
2796
  nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
2585
2797
  &workspaceSize, &executor));
2586
-
2587
- if (workspaceSize > 0 && workspaceAddr == nullptr) {
2588
- ggml_cann_pool_alloc workspace_allocator(ctx.pool(),
2589
- workspaceSize);
2590
- workspaceAddr = workspace_allocator.get();
2798
+ if (workspaceAddr == nullptr) {
2799
+ workspaceAddr = workspace_allocator.alloc(workspaceSize);
2591
2800
  }
2592
-
2593
2801
  ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
2594
2802
  workspaceAddr, workspaceSize, executor, ctx.stream()));
2595
2803
 
2596
- ACL_CHECK(aclDestroyTensor(acl_input_tensor));
2597
2804
  ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
2598
2805
  ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
2599
2806
  ACL_CHECK(aclDestroyTensor(acl_output_tensor));
2807
+
2808
+ // other splits
2809
+ for (int64_t split = 1; split < split_size; split++) {
2810
+ weight_ne_offset +=
2811
+ weight_elem_size * weight_ne[0] * weight_ne[1];
2812
+ weight_ne[0] = max_elem_size * (split + 1) > src0->ne[1]
2813
+ ? src0->ne[1] - (max_elem_size * split)
2814
+ : max_elem_size;
2815
+ scale_ne_offset += scale_elem_size * scale_ne[0] * scale_ne[1];
2816
+ scale_ne[0] = weight_ne[0];
2817
+ output_ne_offset +=
2818
+ output_elem_size * output_ne[0] * output_ne[1];
2819
+ output_ne[0] = weight_ne[0];
2820
+
2821
+ acl_weight_tensor = ggml_cann_create_tensor(
2822
+ (char*)src0->data + batch0 * weight_stride,
2823
+ ggml_cann_type_mapping(type), weight_elem_size, weight_ne,
2824
+ weight_nb, 2, ACL_FORMAT_ND, weight_ne_offset);
2825
+ acl_scale_tensor = ggml_cann_create_tensor(
2826
+ scale_offset + batch0 * scale_stride, ACL_FLOAT16,
2827
+ scale_elem_size, scale_ne, scale_nb, 2, ACL_FORMAT_ND,
2828
+ scale_ne_offset);
2829
+ acl_output_tensor = ggml_cann_create_tensor(
2830
+ (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
2831
+ output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
2832
+ output_ne_offset);
2833
+
2834
+ ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
2835
+ acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
2836
+ nullptr, nullptr, nullptr, nullptr, QK8_0,
2837
+ acl_output_tensor, &workspaceSize, &executor));
2838
+ ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
2839
+ workspaceAddr, workspaceSize, executor, ctx.stream()));
2840
+
2841
+ ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
2842
+ ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
2843
+ ACL_CHECK(aclDestroyTensor(acl_output_tensor));
2844
+ }
2845
+
2846
+ ACL_CHECK(aclDestroyTensor(acl_input_tensor));
2600
2847
  }
2601
2848
  }
2602
2849
 
2603
2850
  // cast out
2604
- int64_t* output_cast_ne = dst->ne;
2605
- size_t output_cast_nb[GGML_MAX_DIMS];
2606
- output_cast_nb[0] = sizeof(uint16_t);
2607
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
2608
- output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
2609
- }
2851
+ if (dst->type != GGML_TYPE_F16) {
2852
+ int64_t* output_cast_ne = dst->ne;
2853
+ size_t output_cast_nb[GGML_MAX_DIMS];
2854
+ output_cast_nb[0] = sizeof(uint16_t);
2855
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
2856
+ output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
2857
+ }
2610
2858
 
2611
- aclTensor* acl_output_tensor =
2612
- ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, output_elem_size,
2613
- output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
2614
- aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
2615
- aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ACL_FLOAT);
2859
+ aclTensor* acl_output_tensor = ggml_cann_create_tensor(
2860
+ output_buffer, ACL_FLOAT16, output_elem_size, output_cast_ne,
2861
+ output_cast_nb, GGML_MAX_DIMS);
2862
+ aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
2863
+ aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor,
2864
+ ggml_cann_type_mapping(dst->type));
2616
2865
 
2617
- ACL_CHECK(aclDestroyTensor(acl_output_tensor));
2618
- ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
2866
+ ACL_CHECK(aclDestroyTensor(acl_output_tensor));
2867
+ ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
2868
+ }
2619
2869
  }
2620
2870
 
2621
2871
  void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -2714,12 +2964,14 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
2714
2964
  static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
2715
2965
  aclTensor* acl_cos_repeat_tensor,
2716
2966
  aclTensor* acl_sin_repeat_tensor,
2717
- float theta_scale, bool is_neox) {
2967
+ float theta_scale, float freq_scale,
2968
+ float attn_factor, bool is_neox) {
2718
2969
  // int sin/cos cache, cache has different repeat method depond on
2719
2970
  // @param.is_neox
2720
2971
 
2721
2972
  ggml_tensor* src0 = dst->src[0]; // input
2722
2973
  ggml_tensor* src1 = dst->src[1]; // position
2974
+ ggml_tensor* src2 = dst->src[2]; // freq_factors
2723
2975
 
2724
2976
  // arange, [0,1,...,ne0/2]
2725
2977
  int64_t arange_length = src0->ne[0] / 2;
@@ -2748,11 +3000,26 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
2748
3000
  ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(),
2749
3001
  arange_length * sizeof(float_t));
2750
3002
  void* theta_scale_buffer = theta_scale_allocator.get();
2751
- aclTensor* acl_theta_scale_tensor = aclnn_ones(
3003
+ aclTensor* acl_theta_scale_tensor = aclnn_values(
2752
3004
  ctx, theta_scale_buffer, arange_length * sizeof(float_t), arange_ne,
2753
3005
  GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), theta_scale);
2754
3006
  aclnn_pow_tensor_tensor(ctx, acl_theta_scale_tensor, acl_arange_tensor);
2755
3007
 
3008
+ // freq_scale
3009
+ if (freq_scale != 1) {
3010
+ aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
3011
+ }
3012
+
3013
+ // freq_factors
3014
+ if (src2) {
3015
+ aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
3016
+ src2->data, ggml_cann_type_mapping(src2->type),
3017
+ ggml_type_size(src2->type), arange_ne, arange_nb, GGML_MAX_DIMS);
3018
+ aclnn_div_tensor(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor,
3019
+ nullptr, true);
3020
+ ACL_CHECK(aclDestroyTensor(acl_freq_factors_tensor));
3021
+ }
3022
+
2756
3023
  // position
2757
3024
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
2758
3025
  int64_t position_length = src1->ne[0];
@@ -2816,6 +3083,12 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
2816
3083
  GGML_MAX_DIMS, ACL_FORMAT_ND);
2817
3084
  aclnn_cos(ctx, acl_permute_tensor, acl_cos_tensor);
2818
3085
 
3086
+ // attn_factor
3087
+ if (attn_factor != 1) {
3088
+ aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
3089
+ aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
3090
+ }
3091
+
2819
3092
  // repeat
2820
3093
  if (is_neox) {
2821
3094
  int64_t repeatsArray[] = {1, 1, 1, 2};
@@ -2841,15 +3114,27 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
2841
3114
  ACL_CHECK(aclDestroyTensor(acl_cos_tensor));
2842
3115
  }
2843
3116
 
3117
+ #ifdef __cplusplus
3118
+ extern "C" {
3119
+ #endif
3120
+ aclnnStatus aclnnRotaryPositionEmbeddingGetWorkspaceSize(
3121
+ const aclTensor* x, const aclTensor* cos, const aclTensor* sin,
3122
+ int64_t mode, const aclTensor* yOut, uint64_t* workspaceSize,
3123
+ aclOpExecutor** executor);
3124
+ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
3125
+ uint64_t workspaceSize,
3126
+ aclOpExecutor* executor,
3127
+ aclrtStream stream);
3128
+ #ifdef __cplusplus
3129
+ }
3130
+ #endif
3131
+
2844
3132
  void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2845
3133
  // TODO: use ascendc
2846
3134
  // Only test with LLAMA model.
2847
3135
  ggml_tensor* src0 = dst->src[0]; // input
2848
3136
  ggml_tensor* src2 = dst->src[2]; // freq_factors
2849
3137
 
2850
- // TODO: with freq_factors
2851
- GGML_ASSERT(src2 == NULL);
2852
-
2853
3138
  // param
2854
3139
  float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
2855
3140
  // const int n_past = ((int32_t *) dst->op_params)[0];
@@ -2867,13 +3152,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2867
3152
  memcpy(&beta_fast, (int32_t*)dst->op_params + 9, sizeof(float));
2868
3153
  memcpy(&beta_slow, (int32_t*)dst->op_params + 10, sizeof(float));
2869
3154
 
2870
- GGML_ASSERT(n_dims <= ne0);
3155
+ // TODO: n_dims <= ne0
3156
+ GGML_ASSERT(n_dims == ne0);
2871
3157
  GGML_ASSERT(n_dims % 2 == 0);
2872
-
2873
3158
  // TODO: ext_factor != 0
2874
3159
  GGML_ASSERT(ext_factor == 0);
2875
- // TODO: freq_scale != 1
2876
- GGML_ASSERT(freq_scale == 1);
2877
3160
 
2878
3161
  const float theta_scale = powf(freq_base, -2.0f / n_dims);
2879
3162
 
@@ -2904,7 +3187,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2904
3187
  ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
2905
3188
  sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
2906
3189
  aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
2907
- theta_scale, is_neox);
3190
+ theta_scale, freq_scale, attn_factor, is_neox);
3191
+
3192
+ aclTensor* acl_src = ggml_cann_create_tensor(src0);
3193
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
3194
+
3195
+ #ifdef ASCEND_310P
3196
+ // Special ROPE operation for 310P
2908
3197
 
2909
3198
  // roll input
2910
3199
  void* input_roll_buffer;
@@ -2947,7 +3236,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2947
3236
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
2948
3237
  minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
2949
3238
  }
2950
- acl_minus_one_tensor = aclnn_ones(
3239
+ acl_minus_one_tensor = aclnn_values(
2951
3240
  ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
2952
3241
  minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
2953
3242
  int64_t dim = 3;
@@ -2974,17 +3263,15 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2974
3263
 
2975
3264
  ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
2976
3265
  ACL_CHECK(aclDestroyTensor(acl_input_tensor));
2977
-
2978
3266
  // init [-1, -1, -1, 1, 1,1,...]
2979
3267
  minus_one_scale_buffer = minus_one_scale_allocator.get();
2980
-
2981
3268
  int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
2982
3269
  size_t minus_one_nb[GGML_MAX_DIMS];
2983
3270
  minus_one_nb[0] = sizeof(float_t);
2984
3271
  for (int i = 1; i < GGML_MAX_DIMS; i++) {
2985
3272
  minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
2986
3273
  }
2987
- acl_minus_one_tensor = aclnn_ones(
3274
+ acl_minus_one_tensor = aclnn_values(
2988
3275
  ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
2989
3276
  minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
2990
3277
  // -1 * first half
@@ -3026,14 +3313,12 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
3026
3313
  acl_input_roll_mul_scale_tensor);
3027
3314
 
3028
3315
  // output
3029
- aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
3030
- aclTensor* acl_dst = ggml_cann_create_tensor(dst);
3031
3316
  void* output_fp32_buffer;
3032
3317
  if (src0->type == GGML_TYPE_F32) {
3033
- aclnn_inplace_mul(ctx, acl_src0, acl_cos_reshape_tensor);
3318
+ aclnn_inplace_mul(ctx, acl_src, acl_cos_reshape_tensor);
3034
3319
  aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor,
3035
3320
  acl_sin_reshape_tensor);
3036
- aclnn_add(ctx, acl_src0, acl_input_roll_mul_scale_tensor, acl_dst);
3321
+ aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst);
3037
3322
  // TODO: ne0 != n_dims in mode2
3038
3323
  } else if (src0->type == GGML_TYPE_F16) {
3039
3324
  size_t input_fp32_nb[GGML_MAX_DIMS];
@@ -3060,7 +3345,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
3060
3345
  aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
3061
3346
  output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
3062
3347
  input_fp32_nb, GGML_MAX_DIMS);
3063
- aclnn_mul(ctx, acl_src0, acl_cos_reshape_tensor, input_fp32_tensor1);
3348
+ aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1);
3064
3349
  aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
3065
3350
  input_fp32_tensor2);
3066
3351
  aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2,
@@ -3070,13 +3355,73 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
3070
3355
  ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
3071
3356
  ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
3072
3357
  ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
3358
+ ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
3359
+ ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
3360
+ ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
3361
+ ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
3362
+ ACL_CHECK(aclDestroyTensor(acl_src));
3073
3363
  }
3364
+ return;
3365
+ #endif
3074
3366
 
3075
- ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
3367
+ // src0 == GGML_TYPE_F16
3368
+ // TODO: optimization this `if` code
3369
+ if (src0->type == GGML_TYPE_F16) {
3370
+ ggml_cann_pool_alloc sin_final_allocator(
3371
+ ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type));
3372
+ ggml_cann_pool_alloc cos_final_allocator(
3373
+ ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type));
3374
+ void* sin_final_buffer = sin_final_allocator.get();
3375
+ void* cos_final_buffer = cos_final_allocator.get();
3376
+
3377
+ int64_t sin_final_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
3378
+ size_t sin_final_nb[GGML_MAX_DIMS];
3379
+ sin_final_nb[0] = ggml_type_size(src0->type);
3380
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
3381
+ sin_final_nb[i] = sin_final_nb[i - 1] * sin_final_ne[i - 1];
3382
+ }
3383
+ aclTensor* acl_sin_final_tensor = ggml_cann_create_tensor(
3384
+ sin_final_buffer, ggml_cann_type_mapping(src0->type),
3385
+ ggml_type_size(src0->type), sin_final_ne, sin_final_nb,
3386
+ GGML_MAX_DIMS);
3387
+ aclTensor* acl_cos_final_tensor = ggml_cann_create_tensor(
3388
+ cos_final_buffer, ggml_cann_type_mapping(src0->type),
3389
+ ggml_type_size(src0->type), sin_final_ne, sin_final_nb,
3390
+ GGML_MAX_DIMS);
3391
+
3392
+ aclnn_cast(ctx, acl_sin_reshape_tensor, acl_sin_final_tensor,
3393
+ ggml_cann_type_mapping(src0->type));
3394
+ aclnn_cast(ctx, acl_cos_reshape_tensor, acl_cos_final_tensor,
3395
+ ggml_cann_type_mapping(src0->type));
3396
+ ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
3397
+ ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
3398
+ acl_sin_reshape_tensor = acl_sin_final_tensor;
3399
+ acl_cos_reshape_tensor = acl_cos_final_tensor;
3400
+ }
3401
+
3402
+ uint64_t workspaceSize = 0;
3403
+ aclOpExecutor* executor;
3404
+
3405
+ void* workspaceAddr = nullptr;
3406
+
3407
+ int acl_mode = mode;
3408
+ if (mode == 0) {
3409
+ acl_mode = 1;
3410
+ }
3411
+
3412
+ ACL_CHECK(aclnnRotaryPositionEmbeddingGetWorkspaceSize(
3413
+ acl_src, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode,
3414
+ acl_dst, &workspaceSize, &executor));
3415
+ if (workspaceSize > 0) {
3416
+ ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
3417
+ workspaceAddr = workspace_allocator.get();
3418
+ }
3419
+
3420
+ ACL_CHECK(aclnnRotaryPositionEmbedding(workspaceAddr, workspaceSize,
3421
+ executor, ctx.stream()));
3422
+
3423
+ ACL_CHECK(aclDestroyTensor(acl_src));
3076
3424
  ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
3077
- ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
3078
- ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
3079
- ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
3080
- ACL_CHECK(aclDestroyTensor(acl_src0));
3425
+ ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
3081
3426
  ACL_CHECK(aclDestroyTensor(acl_dst));
3082
3427
  }