@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -464,9 +464,11 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_src = ggml_cann_create_tensor(src);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
-    const float eps = 1e-6f;  // TODO: make this a parameter
     int n_groups = dst->op_params[0];
 
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
     void* workspaceAddr = nullptr;
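Note on the hunk above: the group-norm epsilon is no longer hardcoded to 1e-6; it is read back out of the tensor's op_params, where the graph-building side packs it next to the group count. A minimal sketch of that round trip, assuming ggml's usual convention of storing the float's bit pattern inside the int32 op_params array (fake_tensor is a hypothetical stand-in for ggml_tensor):

```cpp
#include <cstdint>
#include <cstring>

struct fake_tensor {             // hypothetical stand-in for ggml_tensor
    int32_t op_params[16] = {};  // ggml packs per-op arguments here
};

// graph-building side: store the group count and epsilon
void set_group_norm_params(fake_tensor& t, int32_t n_groups, float eps) {
    t.op_params[0] = n_groups;
    std::memcpy(&t.op_params[1], &eps, sizeof(float));  // raw float bits
}

// backend side: recover epsilon exactly as the new CANN code does
float get_group_norm_eps(const fake_tensor& t) {
    float eps;
    std::memcpy(&eps, &t.op_params[1], sizeof(float));
    return eps;
}
```

memcpy is used rather than a pointer cast so the float read stays well-defined under strict aliasing.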
@@ -910,6 +912,13 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 ((ggml_tensor*)dst->extra)->ne);
             return;
         }
+        if (dst->type == GGML_TYPE_Q4_0) {
+            aclrtlaunch_ascendc_quantize_f16_to_q4_0(
+                24, ctx.stream(), src->data, dst->data,
+                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
+                ((ggml_tensor*)dst->extra)->ne);
+            return;
+        }
         if (dst->type == GGML_TYPE_F16) {
             if (ggml_are_same_shape(src, dst)) {
                 cann_copy(ctx, acl_src, acl_dst);
@@ -971,6 +980,13 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 ((ggml_tensor*)dst->extra)->ne);
             return;
         }
+        if (dst->type == GGML_TYPE_Q4_0) {
+            aclrtlaunch_ascendc_quantize_f32_to_q4_0(
+                24, ctx.stream(), src->data, dst->data,
+                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
+                ((ggml_tensor*)dst->extra)->ne);
+            return;
+        }
         if (dst->type == GGML_TYPE_F32) {
             if (ggml_are_same_shape(src, dst)) {
                 cann_copy(ctx, acl_src, acl_dst);
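These two hunks route F16-to-Q4_0 and F32-to-Q4_0 copies to AscendC device kernels (the leading 24 in each launch appears to be the kernel's block/core count). For orientation, here is a reference sketch of what quantizing one Q4_0 block involves, mirroring ggml's CPU reference implementation; note that the CANN buffer layout itself differs in that scales are stored in a separate region at the end of the weight:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

constexpr int QK4_0 = 32;  // Q4_0 block size, as in ggml

// Quantize one block: 32 floats -> 16 bytes of packed nibbles + 1 scale.
void quantize_block_q4_0(const float* x, uint8_t* qs, float* d_out) {
    float amax = 0.0f, max = 0.0f;  // absolute max and the signed value at it
    for (int i = 0; i < QK4_0; i++) {
        if (std::fabs(x[i]) > amax) { amax = std::fabs(x[i]); max = x[i]; }
    }
    const float d  = max / -8.0f;   // scale maps values into [-8, 7]
    const float id = d ? 1.0f / d : 0.0f;
    *d_out = d;
    for (int i = 0; i < QK4_0 / 2; i++) {
        // two 4-bit values per byte, shifted by +8 into unsigned range
        const uint8_t x0 = (uint8_t)std::min(15.0f, x[i] * id + 8.5f);
        const uint8_t x1 = (uint8_t)std::min(15.0f, x[i + QK4_0 / 2] * id + 8.5f);
        qs[i] = x0 | (x1 << 4);
    }
}
```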
@@ -1312,6 +1328,111 @@ aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize,
 #ifdef __cplusplus
 }
 #endif
+
+static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
+                                             ggml_tensor* dst,
+                                             ggml_tensor* src1,
+                                             aclTensor* tmp_cast_tensor,
+                                             aclTensor* tmp_im2col_tensor) {
+    // Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
+    int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]};
+    size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]};
+    aclTensor* acl_dst =
+        ggml_cann_create_tensor(dst, dst_ne, dst_nb, GGML_MAX_DIMS - 1);
+
+    int64_t permute_dim[] = {0, 2, 1};
+    if (src1->type != dst->type) {
+        aclnn_permute(ctx, tmp_cast_tensor, acl_dst, permute_dim, 3);
+    } else {
+        aclnn_permute(ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3);
+    }
+
+    // release
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+static void ggml_cann_im2col_1d_post_process(
+    ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1,
+    aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor,
+    const std::vector<int64_t>& im2col_op_params) {
+    // get params
+    const int64_t KH = im2col_op_params[0];
+    const int64_t KW = im2col_op_params[1];
+    const int64_t IW = im2col_op_params[2];
+    const int64_t IC = im2col_op_params[3];
+    const int64_t N = im2col_op_params[4];
+    const int64_t OH = im2col_op_params[5];
+    const int64_t OW = im2col_op_params[6];
+    const int64_t s0 = im2col_op_params[7];
+    const int64_t p0 = im2col_op_params[8];
+    const int64_t d0 = im2col_op_params[9];
+    const int64_t n_bytes_factor = im2col_op_params[10];
+
+    // Permute: [N, IC * KH * KW, OW * OH] ->
+    // [N, OW * OH * n_bytes_factor, IC * KH * KW]
+    aclTensor* tmp_permute_tensor = nullptr;
+    ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool());
+    tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);
+    void* tmp_permute_buffer = tmp_permute_allocator.get();
+
+    int64_t tmp_permute_ne[] = {IC * KH * KW, OW * OH * n_bytes_factor, N};
+    size_t tmp_permute_nb[GGML_MAX_DIMS - 1];
+    tmp_permute_nb[0] = ggml_type_size(dst->type);
+    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
+        tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
+    }
+
+    tmp_permute_tensor = ggml_cann_create_tensor(
+        tmp_permute_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_permute_ne, tmp_permute_nb,
+        GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
+
+    int64_t permute_dim[] = {0, 2, 1};
+    if (src1->type != dst->type) {
+        aclnn_permute(ctx, tmp_cast_tensor, tmp_permute_tensor, permute_dim, 3);
+    } else {
+        aclnn_permute(ctx, tmp_im2col_tensor, tmp_permute_tensor, permute_dim,
+                      3);
+    }
+
+    // number of times the kernel moves in W dimension
+    const int n_step_w = (IW + 2 * p0 - d0 * (KW - 1) - 1) / s0 + 1;
+    size_t offset;
+    void *cur_dst_buffer = dst->data, *cur_permute_buffer = tmp_permute_buffer;
+
+    // memory copy with offset to restore 1D im2col from 2d
+    if (IC > 1) {
+        offset = IC * KH * KW * n_step_w * ggml_type_size(dst->type);
+        size_t size_cpy = KH * KW * ggml_type_size(dst->type);
+
+        for (int c = 0; c < IC; c++) {
+            cur_permute_buffer = (char*)tmp_permute_buffer + offset +
+                                 KH * KW * c * ggml_type_size(dst->type);
+            cur_dst_buffer = (char*)dst->data +
+                             c * KH * KW * n_step_w * ggml_type_size(dst->type);
+
+            for (int i = 0; i < n_step_w; i++) {
+                ACL_CHECK(aclrtMemcpyAsync(
+                    cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
+                    ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+                cur_dst_buffer =
+                    (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
+                cur_permute_buffer = (char*)cur_permute_buffer +
+                                     KH * KW * IC * ggml_type_size(dst->type);
+            }
+        }
+    } else {
+        offset = KH * KW * n_step_w *
+                 ggml_type_size(dst->type);  // equal to ggml_nbytes(dst)
+        ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
+                                   (char*)tmp_permute_buffer + offset, offset,
+                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+    }
+
+    // release
+    ACL_CHECK(aclDestroyTensor(tmp_permute_tensor));
+}
+
 void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];  // kernel
     ggml_tensor* src1 = dst->src[1];  // input
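The 1D post-process above reconstructs a true conv-1D layout by copying KH * KW-sized slices, one per kernel position; n_step_w is the standard convolution output-length formula, which also appears verbatim in the code. A quick self-contained check of that formula (the test values are illustrative, not taken from the diff):

```cpp
#include <cassert>
#include <cstdint>

// out = floor((IW + 2*p0 - d0*(KW - 1) - 1) / s0) + 1
int64_t conv1d_out_len(int64_t IW, int64_t KW, int64_t s0, int64_t p0, int64_t d0) {
    return (IW + 2 * p0 - d0 * (KW - 1) - 1) / s0 + 1;
}

int main() {
    assert(conv1d_out_len(10, 3, /*s0=*/1, /*p0=*/0, /*d0=*/1) == 8);
    assert(conv1d_out_len(10, 3, /*s0=*/2, /*p0=*/1, /*d0=*/1) == 5);
    assert(conv1d_out_len(10, 3, /*s0=*/1, /*p0=*/0, /*d0=*/2) == 6);  // dilated
    return 0;
}
```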
@@ -1320,21 +1441,23 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
 
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    // aclnnIm2col only works on 2D. set s1, p1, d1 to 1 to perform 2D
+    // im2col and do post-processing to restore it to 1D.
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
     const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t s1 = is_2D ? ((const int32_t*)(dst->op_params))[1] : 1;
     const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t p1 = is_2D ? ((const int32_t*)(dst->op_params))[3] : 1;
     const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+    const int32_t d1 = is_2D ? ((const int32_t*)(dst->op_params))[5] : 1;
 
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int64_t N = is_2D ? ne13 : ne12;
-    const int64_t IC = is_2D ? ne12 : ne11;
-
-    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t N = ne13;
+    const int64_t IC = ne12;
+    const int64_t KH = ne01;
     const int64_t KW = ne00;
+    const int64_t IW = ne10;
 
     const int64_t OH = is_2D ? ne2 : 1;
     const int64_t OW = ne1;
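The reordering in this hunk only moves the is_2D read ahead of the stride/padding/dilation reads, so the 1D case can substitute the neutral value 1 for the H-dimension parameters and reuse the 2D aclnnIm2col path. In ggml, ggml_im2col packs its arguments into dst->op_params in the order s0, s1, p0, p1, d0, d1, is_2D; a small decode sketch under that assumption:

```cpp
#include <cstdint>

struct im2col_params {
    int32_t s0, s1, p0, p1, d0, d1;
    bool is_2D;
};

// Decode op_params the way the hunk above does: read is_2D first, then
// force the H-dimension stride/padding/dilation to 1 in 1D mode.
im2col_params decode_im2col(const int32_t* op_params) {
    im2col_params r;
    r.is_2D = op_params[6] == 1;
    r.s0 = op_params[0];
    r.s1 = r.is_2D ? op_params[1] : 1;
    r.p0 = op_params[2];
    r.p1 = r.is_2D ? op_params[3] : 1;
    r.d0 = op_params[4];
    r.d1 = r.is_2D ? op_params[5] : 1;
    return r;
}
```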
@@ -1342,9 +1465,12 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
     GGML_ASSERT(nb10 == sizeof(float));
 
-    // im2col: [N,C,H,W] -> [N, IC * KH * KW, OW * OH]
+    // memory allocated increased to 3x when is_2D == false
+    const int64_t n_bytes_factor = is_2D ? 1 : 3;
+
+    // im2col: [N,C,H,W] -> [N, IC * KH * KW, OW * OH * n_bytes_factor]
     aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
-    int64_t tmp_im2col_ne[] = {OW * OH, IC * KH * KW, N};
+    int64_t tmp_im2col_ne[] = {OW * OH * n_bytes_factor, IC * KH * KW, N};
     size_t tmp_im2col_nb[GGML_MAX_DIMS - 1];
 
     tmp_im2col_nb[0] = ggml_type_size(src1->type);
@@ -1356,8 +1482,10 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // If dst is f16, tmp_buffer is f32, we need alloc src.typesize *
     // dst.elemcount.
     ggml_cann_pool_alloc im2col_allocator(
-        ctx.pool(), ggml_nelements(dst) * ggml_element_size(src1));
+        ctx.pool(),
+        ggml_nelements(dst) * ggml_element_size(src1) * n_bytes_factor);
     void* tmp_im2col_buffer = im2col_allocator.get();
+
     aclTensor* tmp_im2col_tensor = ggml_cann_create_tensor(
         tmp_im2col_buffer, ggml_cann_type_mapping(src1->type),
         ggml_type_size(src1->type), tmp_im2col_ne, tmp_im2col_nb,
@@ -1380,8 +1508,9 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                                      paddings, strides, tmp_im2col_tensor,
                                      &workspaceSize, &executor));
 
+    ggml_cann_pool_alloc workspace_allocator(ctx.pool());
     if (workspaceSize > 0) {
-        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspace_allocator.alloc(workspaceSize);
         workspaceAddr = workspace_allocator.get();
     }
 
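This hunk is a lifetime fix, not a cosmetic one: ggml_cann_pool_alloc is an RAII helper that returns its buffer to the pool when destroyed, so the old in-branch declaration released the workspace at the closing brace, before the subsequent aclnnIm2col call used workspaceAddr. A minimal sketch of the pattern, with Buffer as a hypothetical stand-in for the pool allocator:

```cpp
#include <cstddef>
#include <new>

struct Buffer {  // hypothetical stand-in for ggml_cann_pool_alloc
    void* p = nullptr;
    void alloc(std::size_t n) { p = ::operator new(n); }
    void* get() const { return p; }
    ~Buffer() { ::operator delete(p); }  // buffer released here
};

void* broken(std::size_t ws) {
    void* addr = nullptr;
    if (ws > 0) {
        Buffer b;
        b.alloc(ws);
        addr = b.get();
    }             // b destroyed at the closing brace -> addr dangles
    return addr;  // any later use is a use-after-free
}

void* fixed(Buffer& b, std::size_t ws) {  // allocator owned by the caller
    void* addr = nullptr;
    if (ws > 0) {
        b.alloc(ws);  // same allocation, but b outlives the branch
        addr = b.get();
    }
    return addr;
}
```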
@@ -1391,9 +1520,10 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // Cast if dst is f16.
     aclTensor* tmp_cast_tensor = nullptr;
     ggml_cann_pool_alloc tmp_cast_allocator(ctx.pool());
+    void* tmp_cast_buffer = nullptr;
     if (src1->type != dst->type) {
-        tmp_cast_allocator.alloc(ggml_nbytes(dst));
-        void* tmp_cast_buffer = tmp_cast_allocator.get();
+        tmp_cast_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);
+        tmp_cast_buffer = tmp_cast_allocator.get();
         size_t temp_cast_nb[GGML_MAX_DIMS - 1];
         temp_cast_nb[0] = ggml_type_size(dst->type);
         for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
@@ -1408,24 +1538,21 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                    ggml_cann_type_mapping(dst->type));
     }
 
-    // Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
-    int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]};
-    size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]};
-    aclTensor* acl_dst =
-        ggml_cann_create_tensor(dst, dst_ne, dst_nb, GGML_MAX_DIMS - 1);
-
-    int64_t permute_dim[] = {0, 2, 1};
-    if (src1->type != dst->type) {
-        aclnn_permute(ctx, tmp_cast_tensor, acl_dst, permute_dim, 3);
+    // post-processing
+    if (is_2D) {
+        ggml_cann_im2col_2d_post_process(ctx, dst, src1, tmp_cast_tensor,
+                                         tmp_im2col_tensor);
     } else {
-        aclnn_permute(ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3);
+        std::vector<int64_t> im2col_op_params = {
+            KH, KW, IW, IC, N, OH, OW, s0, p0, d0, n_bytes_factor};
+        ggml_cann_im2col_1d_post_process(ctx, dst, src1, tmp_cast_tensor,
+                                         tmp_im2col_tensor, im2col_op_params);
     }
 
     // release
     ACL_CHECK(aclDestroyTensor(acl_src1));
     ACL_CHECK(aclDestroyTensor(tmp_im2col_tensor));
     ACL_CHECK(aclDestroyTensor(tmp_cast_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
     ACL_CHECK(aclDestroyIntArray(kernel_size));
     ACL_CHECK(aclDestroyIntArray(dilations));
     ACL_CHECK(aclDestroyIntArray(paddings));
@@ -2352,21 +2479,33 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
  * @param dst The destination tensor where the result of the matrix
  * multiplication will be stored.
  */
-static void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx,
-                                   ggml_tensor* dst) {
+static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
+                                    ggml_tensor* dst,
+                                    const enum ggml_type type) {
     ggml_tensor* src0 = dst->src[0];  // weight
     ggml_tensor* src1 = dst->src[1];  // input
 
     // The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
     // is regarded as batch. weight need transpose.
     int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
-    size_t weight_elem_size = sizeof(uint8_t);
-    size_t weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
+    float weight_elem_size;
+    if (type == GGML_TYPE_Q4_0) {
+        weight_elem_size = float(sizeof(uint8_t)) / 2;
+    }
+    else if (type == GGML_TYPE_Q8_0) {
+        weight_elem_size = float(sizeof(uint8_t));
+    }
+    else {
+        GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT");
+    }
+    float weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
+
     // size of one matrix is element_size * height * width.
     size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
     size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
 
     // scale stored at the end of weight. Also need transpose.
+    GGML_ASSERT(QK4_0 == QK8_0);
     int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
     size_t scale_elem_size = sizeof(uint16_t);
     size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
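Generalizing from Q8_0 to Q4_0 changes weight_elem_size (and weight_nb) from size_t to float because Q4_0 packs two quants per byte, i.e. 0.5 bytes per element. The added GGML_ASSERT(QK4_0 == QK8_0) holds because both formats use 32-element blocks, so the scale dimensions computed with ne0 / QK8_0 stay valid for either type. A small sketch of the size arithmetic (the matrix dimensions are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne0 = 4096, ne1 = 4096;  // hypothetical weight matrix
    const int     QK  = 32;                // block size for both Q4_0 and Q8_0

    const float q4_elem = 0.5f;  // two 4-bit quants per byte
    const float q8_elem = 1.0f;  // one byte per quant

    // quantized data, plus one fp16 scale per 32-element block at the end
    std::printf("Q4_0 quants: %lld B, scales: %lld B\n",
                (long long)(q4_elem * ne0 * ne1),
                (long long)(ne0 / QK * ne1 * (int64_t)sizeof(uint16_t)));
    std::printf("Q8_0 quants: %lld B, scales: %lld B\n",
                (long long)(q8_elem * ne0 * ne1),
                (long long)(ne0 / QK * ne1 * (int64_t)sizeof(uint16_t)));
    return 0;
}
```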
@@ -2381,10 +2520,10 @@ static void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx,
     size_t input_nb[] = {input_elem_size, input_elem_size * src1->ne[0]};
     size_t input_stride = input_elem_size * src1->ne[0] * src1->ne[1];
 
+    ggml_cann_pool_alloc input_alloctor(ctx.pool());
     if (src1->type != GGML_TYPE_F16) {
         aclTensor* acl_src1_tensor = ggml_cann_create_tensor(src1);
-        ggml_cann_pool_alloc input_alloctor(
-            ctx.pool(), ggml_nelements(src1) * input_elem_size);
+        input_alloctor.alloc(ggml_nelements(src1) * input_elem_size);
         input_buffer = input_alloctor.get();
 
         int64_t* input_cast_ne = src1->ne;
@@ -2430,8 +2569,9 @@ static void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx,
                 (char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
                 input_elem_size, input_ne, input_nb, 2);
             aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
-                (char*)src0->data + batch0 * weight_stride, ACL_INT8,
-                weight_elem_size, weight_ne, weight_nb, 2);
+                (char*)src0->data + batch0 * weight_stride,
+                ggml_cann_type_mapping(type), weight_elem_size, weight_ne,
+                weight_nb, 2);
             aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
                 scale_offset + batch0 * scale_stride, ACL_FLOAT16,
                 scale_elem_size, scale_ne, scale_nb, 2);
@@ -2485,11 +2625,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         case GGML_TYPE_F16:
             ggml_cann_mat_mul_fp(ctx, dst);
             break;
-        // case GGML_TYPE_Q4_0:
-        //     ggml_cann_mul_mat_q4_0(ctx, dst);
-        //     break;
+        case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q8_0:
-            ggml_cann_mul_mat_q8_0(ctx, dst);
+            ggml_cann_mul_mat_quant(ctx, dst, type);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -2743,7 +2881,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
                              beta_slow, corr_dims);
 
-    const bool is_neox = mode & 2;
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
     // init cos/sin cache
     ggml_cann_pool_alloc sin_allocator(
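Same value, clearer intent: ggml.h defines GGML_ROPE_TYPE_NEOX as 2, so the bit test is unchanged but no longer a magic number. Since mode is a bitfield, this is a flag test rather than an equality comparison:

```cpp
#define GGML_ROPE_TYPE_NEOX 2  // from ggml.h

static inline bool rope_is_neox(int mode) {
    return (mode & GGML_ROPE_TYPE_NEOX) != 0;  // other mode bits may also be set
}
```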
package/src/llama.cpp/ggml/src/ggml-cann/common.h
@@ -227,6 +227,7 @@ struct ggml_backend_cann_context {
      * @brief Destructor for cleaning up resources.
      */
     ~ggml_backend_cann_context() {
+        ggml_cann_set_device(device);
         if (copy_event != nullptr) {
             ACL_CHECK(aclrtDestroyEvent(copy_event));
         }
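A plausible reading of this one-line fix: ACL runtime calls act on whichever device is currently selected, so a destructor that frees events must first make its own device current; otherwise, when several per-device contexts are torn down in a row, a context's resources could be destroyed while another device is still selected. A hypothetical mock of the pattern (cann_set_device and destroy_event stand in for ggml_cann_set_device and aclrtDestroyEvent):

```cpp
#include <cstdio>

static int g_current_device = 0;  // mock of the process-wide device selection
static void cann_set_device(int device) { g_current_device = device; }
static void destroy_event(void* /*event*/) {
    std::printf("event destroyed with device %d current\n", g_current_device);
}

struct cann_context {  // hypothetical stand-in for ggml_backend_cann_context
    int   device;
    void* copy_event = nullptr;
    ~cann_context() {
        cann_set_device(device);  // the fix: pin our own device first
        if (copy_event != nullptr) {
            destroy_event(copy_event);  // now runs on the owning device
        }
    }
};

int main() {
    int dummy = 0;
    cann_context c0{0, &dummy};
    cann_context c1{1, &dummy};
    // destructors run c1 then c0; each re-selects its own device before
    // freeing its event, so neither is freed on the wrong device
    return 0;
}
```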