@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -0,0 +1,91 @@
1
+ if (GGML_STATIC)
2
+ set(BLA_STATIC ON)
3
+ endif()
4
+ #if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22)
5
+ # set(BLA_SIZEOF_INTEGER 8)
6
+ #endif()
7
+
8
+ set(BLA_VENDOR ${GGML_BLAS_VENDOR})
9
+ find_package(BLAS)
10
+
11
+ if (BLAS_FOUND)
12
+ message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
13
+
14
+ add_library(ggml-blas
15
+ ggml-blas.cpp
16
+ )
17
+
18
+ target_link_libraries(ggml-blas PRIVATE ggml-base)
19
+ target_include_directories(ggml-blas PRIVATE . ..)
20
+
21
+ if (${GGML_BLAS_VENDOR} MATCHES "Apple")
22
+ add_compile_definitions(ACCELERATE_NEW_LAPACK)
23
+ add_compile_definitions(ACCELERATE_LAPACK_ILP64)
24
+ add_compile_definitions(GGML_BLAS_USE_ACCELERATE)
25
+ elseif ("${BLAS_INCLUDE_DIRS}" STREQUAL "")
26
+ # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
27
+ # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
28
+ find_package(PkgConfig REQUIRED)
29
+ if (${GGML_BLAS_VENDOR} MATCHES "Generic")
30
+ pkg_check_modules(DepBLAS blas)
31
+ elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS")
32
+ # As of openblas v0.3.22, the 64-bit is named openblas64.pc
33
+ pkg_check_modules(DepBLAS openblas64)
34
+ if (NOT DepBLAS_FOUND)
35
+ pkg_check_modules(DepBLAS openblas)
36
+ endif()
37
+ elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
38
+ add_compile_definitions(GGML_BLAS_USE_BLIS)
39
+ pkg_check_modules(DepBLAS blis)
40
+ elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
41
+ pkg_check_modules(DepBLAS blas-atlas)
42
+ elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
43
+ pkg_check_modules(DepBLAS flexiblas_api)
44
+ elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
45
+ add_compile_definitions(GGML_BLAS_USE_MKL)
46
+ # all Intel* libraries share the same include path
47
+ pkg_check_modules(DepBLAS mkl-sdl)
48
+ elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
49
+ # this doesn't provide pkg-config
50
+ # suggest to assign BLAS_INCLUDE_DIRS on your own
51
+ if ("${NVHPC_VERSION}" STREQUAL "")
52
+ message(WARNING "Better to set NVHPC_VERSION")
53
+ else()
54
+ set(DepBLAS_FOUND ON)
55
+ set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
56
+ endif()
57
+ endif()
58
+ if (DepBLAS_FOUND)
59
+ set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
60
+ else()
61
+ message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically"
62
+ " detected by pkgconfig, trying to find cblas.h from possible paths...")
63
+ find_path(BLAS_INCLUDE_DIRS
64
+ NAMES cblas.h
65
+ HINTS
66
+ /usr/include
67
+ /usr/local/include
68
+ /usr/include/openblas
69
+ /opt/homebrew/opt/openblas/include
70
+ /usr/local/opt/openblas/include
71
+ /usr/include/x86_64-linux-gnu/openblas/include
72
+ )
73
+ endif()
74
+ endif()
75
+
76
+ message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
77
+
78
+ #add_compile_options(${BLAS_LINKER_FLAGS})
79
+ target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})
80
+
81
+ if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
82
+ add_compile_definitions(GGML_BLAS_USE_MKL)
83
+ endif()
84
+
85
+ target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES})
86
+ target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
87
+ else()
88
+ message(ERROR "BLAS not found, please refer to "
89
+ "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
90
+ " to set correct GGML_BLAS_VENDOR")
91
+ endif()
@@ -1,10 +1,12 @@
1
+ #include "ggml-impl.h"
1
2
  #include "ggml-blas.h"
2
3
  #include "ggml-backend-impl.h"
3
4
 
4
5
  #include <future>
5
6
  #include <vector>
7
+ #include <cstring>
6
8
 
7
- #if defined(GGML_USE_ACCELERATE)
9
+ #if defined(GGML_BLAS_USE_ACCELERATE)
8
10
  # include <Accelerate/Accelerate.h>
9
11
  #elif defined(GGML_BLAS_USE_MKL)
10
12
  # include <mkl.h>
@@ -25,30 +27,6 @@ struct ggml_backend_blas_context {
25
27
  #endif
26
28
  };
27
29
 
28
- // helper function to determine if it is better to use BLAS or not
29
- // for large matrices, BLAS is faster
30
- static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
31
- const struct ggml_tensor * src0 = dst->src[0];
32
- const struct ggml_tensor * src1 = dst->src[1];
33
-
34
- const int64_t ne10 = src1->ne[0];
35
-
36
- const int64_t ne0 = dst->ne[0];
37
- const int64_t ne1 = dst->ne[1];
38
-
39
- // TODO: find the optimal values for these
40
- if (ggml_is_contiguous(src0) &&
41
- ggml_is_contiguous(src1) &&
42
- src1->type == GGML_TYPE_F32 &&
43
- (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
44
-
45
- /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
46
- return true;
47
- }
48
-
49
- return false;
50
- }
51
-
52
30
  static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
53
31
  const struct ggml_tensor * src0 = dst->src[0];
54
32
  const struct ggml_tensor * src1 = dst->src[1];
@@ -87,8 +65,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
87
65
 
88
66
  // convert src0 to float
89
67
  if (type != GGML_TYPE_F32) {
90
- ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
91
- ggml_to_float_t const to_float = type_traits.to_float;
68
+ const auto * type_traits = ggml_get_type_traits(type);
69
+ ggml_to_float_t const to_float = type_traits->to_float;
92
70
 
93
71
  for (int64_t i03 = 0; i03 < ne03; i03++) {
94
72
  for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -234,25 +212,19 @@ static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct g
234
212
 
235
213
  // backend interface
236
214
 
237
- GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) {
215
+ static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
238
216
  return "BLAS";
239
217
 
240
218
  GGML_UNUSED(backend);
241
219
  }
242
220
 
243
- GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) {
221
+ static void ggml_backend_blas_free(ggml_backend_t backend) {
244
222
  ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
245
223
  delete ctx;
246
224
  delete backend;
247
225
  }
248
226
 
249
- GGML_CALL static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) {
250
- return ggml_backend_cpu_buffer_type();
251
-
252
- GGML_UNUSED(backend);
253
- }
254
-
255
- GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
227
+ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
256
228
  ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
257
229
 
258
230
  for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -284,31 +256,9 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t
284
256
  GGML_UNUSED(backend);
285
257
  }
286
258
 
287
- GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
288
- const struct ggml_tensor * src0 = op->src[0];
289
- const struct ggml_tensor * src1 = op->src[1];
290
-
291
- return (op->op == GGML_OP_MUL_MAT && ggml_backend_blas_use_blas(op)) ||
292
- (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
293
- op->src[1]->type == GGML_TYPE_F32 &&
294
- ggml_is_matrix(src0) &&
295
- ggml_is_matrix(src1) &&
296
- ggml_is_contiguous(src0) &&
297
- (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
298
-
299
- GGML_UNUSED(backend);
300
- }
301
-
302
- GGML_CALL static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
303
- return ggml_backend_buft_is_host(buft);
304
-
305
- GGML_UNUSED(backend);
306
- }
307
-
308
259
  static struct ggml_backend_i blas_backend_i = {
309
- /* .get_name = */ ggml_backend_blas_name,
260
+ /* .get_name = */ ggml_backend_blas_get_name,
310
261
  /* .free = */ ggml_backend_blas_free,
311
- /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
312
262
  /* .set_tensor_async = */ NULL,
313
263
  /* .get_tensor_async = */ NULL,
314
264
  /* .cpy_tensor_async = */ NULL,
@@ -318,14 +268,8 @@ static struct ggml_backend_i blas_backend_i = {
318
268
  /* .graph_plan_update = */ NULL,
319
269
  /* .graph_plan_compute = */ NULL,
320
270
  /* .graph_compute = */ ggml_backend_blas_graph_compute,
321
- /* .supports_op = */ ggml_backend_blas_supports_op,
322
- /* .supports_buft = */ ggml_backend_blas_supports_buft,
323
- /* .offload_op = */ NULL,
324
- /* .event_new = */ NULL,
325
- /* .event_free = */ NULL,
326
271
  /* .event_record = */ NULL,
327
272
  /* .event_wait = */ NULL,
328
- /* .event_synchronize = */ NULL,
329
273
  };
330
274
 
331
275
  static ggml_guid_t ggml_backend_blas_guid(void) {
@@ -339,23 +283,24 @@ ggml_backend_t ggml_backend_blas_init(void) {
339
283
  ggml_backend_t backend = new ggml_backend {
340
284
  /* .guid = */ ggml_backend_blas_guid(),
341
285
  /* .interface = */ blas_backend_i,
286
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
342
287
  /* .context = */ ctx,
343
288
  };
344
289
 
345
- #if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
290
+ #if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
346
291
  if (openblas_get_parallel() != OPENBLAS_OPENMP) {
347
- fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
292
+ GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
348
293
  }
349
294
  #endif
350
295
 
351
- #if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
352
- fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
296
+ #if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
297
+ GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
353
298
  #endif
354
299
 
355
300
  return backend;
356
301
  }
357
302
 
358
- GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend) {
303
+ bool ggml_backend_is_blas(ggml_backend_t backend) {
359
304
  return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
360
305
  }
361
306
 
@@ -365,3 +310,205 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads)
365
310
  ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
366
311
  ctx->n_threads = n_threads;
367
312
  }
313
+
314
+ // device interface
315
+
316
+ static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
317
+ return "BLAS";
318
+
319
+ GGML_UNUSED(dev);
320
+ }
321
+
322
+ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
323
+ #if defined(GGML_BLAS_USE_ACCELERATE)
324
+ return "Accelerate";
325
+ #elif defined(GGML_BLAS_USE_MKL)
326
+ return "MKL";
327
+ #elif defined(GGML_BLAS_USE_BLIS)
328
+ return "BLIS";
329
+ #elif defined(GGML_BLAS_USE_NVPL)
330
+ return "NVPL";
331
+ #elif defined(OPENBLAS_VERSION)
332
+ return "OpenBLAS";
333
+ #else
334
+ return "BLAS";
335
+ #endif
336
+
337
+ GGML_UNUSED(dev);
338
+ }
339
+
340
+ static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
341
+ // TODO
342
+ *free = 0;
343
+ *total = 0;
344
+
345
+ GGML_UNUSED(dev);
346
+ }
347
+
348
+ static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
349
+ return GGML_BACKEND_DEVICE_TYPE_ACCEL;
350
+
351
+ GGML_UNUSED(dev);
352
+ }
353
+
354
+ static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
355
+ props->name = ggml_backend_blas_device_get_name(dev);
356
+ props->description = ggml_backend_blas_device_get_description(dev);
357
+ props->type = ggml_backend_blas_device_get_type(dev);
358
+ ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
359
+ props->caps = {
360
+ /* .async = */ false,
361
+ /* .host_buffer = */ false,
362
+ /* .buffer_from_host_ptr = */ true,
363
+ /* .events = */ false,
364
+ };
365
+ }
366
+
367
+ static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
368
+ return ggml_backend_blas_init();
369
+
370
+ GGML_UNUSED(dev);
371
+ GGML_UNUSED(params);
372
+ }
373
+
374
+ static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
375
+ return ggml_backend_cpu_buffer_type();
376
+
377
+ GGML_UNUSED(dev);
378
+ }
379
+
380
+ static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
381
+ return ggml_backend_cpu_buffer_from_ptr(ptr, size);
382
+
383
+ GGML_UNUSED(dev);
384
+ GGML_UNUSED(max_tensor_size);
385
+ }
386
+
387
+ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
388
+ const struct ggml_tensor * src0 = op->src[0];
389
+ const struct ggml_tensor * src1 = op->src[1];
390
+
391
+ switch (op->op) {
392
+ case GGML_OP_NONE:
393
+ case GGML_OP_RESHAPE:
394
+ case GGML_OP_VIEW:
395
+ case GGML_OP_PERMUTE:
396
+ case GGML_OP_TRANSPOSE:
397
+ return true;
398
+
399
+ case GGML_OP_MUL_MAT:
400
+ {
401
+ // BLAS usually is only faster for large matrices
402
+ const struct ggml_tensor * src0 = op->src[0];
403
+ const struct ggml_tensor * src1 = op->src[1];
404
+
405
+ const int64_t ne10 = src1->ne[0];
406
+
407
+ const int64_t ne0 = op->ne[0];
408
+ const int64_t ne1 = op->ne[1];
409
+
410
+ // TODO: find the optimal value
411
+ const int64_t min_batch = 32;
412
+
413
+ return ggml_is_contiguous(src0) &&
414
+ ggml_is_contiguous(src1) &&
415
+ src1->type == GGML_TYPE_F32 &&
416
+ (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
417
+ (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
418
+ }
419
+
420
+ case GGML_OP_OUT_PROD:
421
+ return op->src[0]->type == GGML_TYPE_F32 &&
422
+ op->src[1]->type == GGML_TYPE_F32 &&
423
+ ggml_is_matrix(src0) &&
424
+ ggml_is_matrix(src1) &&
425
+ ggml_is_contiguous(src0) &&
426
+ (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
427
+ (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
428
+
429
+ default:
430
+ return false;
431
+
432
+ }
433
+
434
+ GGML_UNUSED(dev);
435
+ }
436
+
437
+ static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
438
+ return ggml_backend_buft_is_host(buft);
439
+
440
+ GGML_UNUSED(dev);
441
+ }
442
+
443
+ static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
444
+ /* .get_name = */ ggml_backend_blas_device_get_name,
445
+ /* .get_description = */ ggml_backend_blas_device_get_description,
446
+ /* .get_memory = */ ggml_backend_blas_device_get_memory,
447
+ /* .get_type = */ ggml_backend_blas_device_get_type,
448
+ /* .get_props = */ ggml_backend_blas_device_get_props,
449
+ /* .init_backend = */ ggml_backend_blas_device_init_backend,
450
+ /* .get_buffer_type = */ ggml_backend_blas_device_get_buffer_type,
451
+ /* .get_host_buffer_type = */ NULL,
452
+ /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
453
+ /* .supports_op = */ ggml_backend_blas_device_supports_op,
454
+ /* .supports_buft = */ ggml_backend_blas_device_supports_buft,
455
+ /* .offload_op = */ NULL,
456
+ /* .event_new = */ NULL,
457
+ /* .event_free = */ NULL,
458
+ /* .event_synchronize = */ NULL,
459
+ };
460
+
461
+ // backend reg interface
462
+
463
+ static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
464
+ return "BLAS";
465
+
466
+ GGML_UNUSED(reg);
467
+ }
468
+
469
+ static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
470
+ return 1;
471
+
472
+ GGML_UNUSED(reg);
473
+ }
474
+
475
+ static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
476
+ GGML_ASSERT(index == 0);
477
+
478
+ static ggml_backend_device ggml_backend_blas_device = {
479
+ /* .iface = */ ggml_backend_blas_device_i,
480
+ /* .reg = */ reg,
481
+ /* .context = */ nullptr,
482
+ };
483
+
484
+ return &ggml_backend_blas_device;
485
+
486
+ GGML_UNUSED(reg);
487
+ GGML_UNUSED(index);
488
+ }
489
+
490
+ static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
491
+ if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
492
+ return (void *)ggml_backend_blas_set_n_threads;
493
+ }
494
+ return NULL;
495
+
496
+ GGML_UNUSED(reg);
497
+ GGML_UNUSED(name);
498
+ }
499
+
500
+ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
501
+ /* .get_name = */ ggml_backend_blas_reg_get_name,
502
+ /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
503
+ /* .get_device = */ ggml_backend_blas_reg_get_device,
504
+ /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
505
+ };
506
+
507
+ ggml_backend_reg_t ggml_backend_blas_reg(void) {
508
+ static struct ggml_backend_reg ggml_backend_blas_reg = {
509
+ /* .iface = */ ggml_backend_blas_reg_i,
510
+ /* .context = */ NULL,
511
+ };
512
+
513
+ return &ggml_backend_blas_reg;
514
+ }
@@ -0,0 +1,46 @@
1
+ if ("cann${CANN_INSTALL_DIR}" STREQUAL "cann" AND DEFINED ENV{ASCEND_TOOLKIT_HOME})
2
+ set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME})
3
+ message(STATUS "CANN: updated CANN_INSTALL_DIR from ASCEND_TOOLKIT_HOME=$ENV{ASCEND_TOOLKIT_HOME}")
4
+ endif()
5
+
6
+ if (CANN_INSTALL_DIR)
7
+ # Only Support Linux.
8
+ if (NOT UNIX)
9
+ message(FATAL_ERROR "CANN: CANN toolkit supports unix but not ${CMAKE_SYSTEM_NAME}")
10
+ endif()
11
+
12
+ # Supported platforms: x86-64, arm64
13
+ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
14
+ elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64")
15
+ else()
16
+ message(FATAL_ERROR "CANN: CANN toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}")
17
+ endif()
18
+
19
+ # Set header and libs
20
+ set(CANN_INCLUDE_DIRS
21
+ ${CANN_INSTALL_DIR}/include
22
+ ${CANN_INSTALL_DIR}/include/aclnn
23
+ ${CANN_INSTALL_DIR}/acllib/include
24
+ )
25
+
26
+ add_subdirectory(kernels)
27
+ list(APPEND CANN_LIBRARIES
28
+ ascendcl
29
+ nnopbase
30
+ opapi
31
+ acl_op_compiler
32
+ ascendc_kernels
33
+ )
34
+
35
+ file(GLOB GGML_SOURCES_CANN "*.cpp")
36
+
37
+ add_library(ggml-cann ${GGML_SOURCES_CANN})
38
+ target_link_libraries(ggml-cann PRIVATE ggml-base ${CANN_LIBRARIES})
39
+ target_include_directories(ggml-cann PRIVATE . .. ${CANN_INCLUDE_DIRS})
40
+ target_link_directories(ggml-cann PRIVATE ${CANN_INSTALL_DIR}/lib64)
41
+
42
+ message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
43
+ message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
44
+ else()
45
+ message(FATAL_ERROR "CANN: Can't find CANN_INSTALL_DIR, did you forget to source set_var.sh?")
46
+ endif()
@@ -37,6 +37,10 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
37
37
  return ACL_INT16;
38
38
  case GGML_TYPE_I32:
39
39
  return ACL_INT32;
40
+ case GGML_TYPE_Q4_0:
41
+ return ACL_INT4;
42
+ case GGML_TYPE_Q8_0:
43
+ return ACL_INT8;
40
44
  default:
41
45
  return ACL_DT_UNDEFINED;
42
46
  }
@@ -89,33 +93,6 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
89
93
  return false;
90
94
  }
91
95
 
92
- aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
93
- size_t type_size, int64_t* ne, size_t* nb,
94
- int64_t dims, aclFormat format,
95
- size_t offset) {
96
- int64_t tmp_ne[GGML_MAX_DIMS * 2];
97
- int64_t tmp_stride[GGML_MAX_DIMS * 2];
98
-
99
- memcpy(tmp_ne, ne, dims * sizeof(int64_t));
100
- for (int i = 0; i < dims; i++) {
101
- tmp_stride[i] = nb[i] / type_size;
102
- }
103
-
104
- std::reverse(tmp_ne, tmp_ne + dims);
105
- std::reverse(tmp_stride, tmp_stride + dims);
106
-
107
- int64_t acl_storage_len = 0;
108
- for (int i = 0; i < dims; i++) {
109
- acl_storage_len += (ne[i] - 1) * nb[i];
110
- }
111
-
112
- aclTensor* acl_tensor =
113
- aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size,
114
- format, &acl_storage_len, 1, data_ptr);
115
-
116
- return acl_tensor;
117
- }
118
-
119
96
  int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
120
97
  const ggml_tensor* src1,
121
98
  int64_t* bcast_src0_ne,
@@ -23,6 +23,9 @@
23
23
  #ifndef CANN_ACL_TENSOR_H
24
24
  #define CANN_ACL_TENSOR_H
25
25
 
26
+ #include <algorithm>
27
+ #include <cstring>
28
+
26
29
  #include <aclnn/aclnn_base.h>
27
30
  #include "common.h"
28
31
 
@@ -65,7 +68,8 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null
65
68
  size_t offset = 0);
66
69
 
67
70
  /**
68
- * @brief Creates an ACL tensor from provided parameters.
71
+ * @brief Template for creating an ACL tensor from provided parameters. typename TYPE
72
+ * should be size_t or float.
69
73
  *
70
74
  * @details This function creates an ACL tensor using the provided data pointer,
71
75
  * data type, dimensions, strides, format, offset, and additional parameters.
@@ -83,10 +87,34 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null
83
87
  * @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
84
88
  * @return Pointer to the created ACL tensor.
85
89
  */
90
+ template<typename TYPE>
86
91
  aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
87
- size_t type_size, int64_t* ne, size_t* nb,
88
- int64_t dims, aclFormat format = ACL_FORMAT_ND,
89
- size_t offset = 0);
92
+ TYPE type_size, int64_t* ne, TYPE* nb,
93
+ int64_t dims,
94
+ aclFormat format = ACL_FORMAT_ND,
95
+ size_t offset = 0) {
96
+ int64_t tmp_ne[GGML_MAX_DIMS * 2];
97
+ int64_t tmp_stride[GGML_MAX_DIMS * 2];
98
+
99
+ memcpy(tmp_ne, ne, dims * sizeof(int64_t));
100
+ for (int i = 0; i < dims; i++) {
101
+ tmp_stride[i] = nb[i] / type_size;
102
+ }
103
+
104
+ std::reverse(tmp_ne, tmp_ne + dims);
105
+ std::reverse(tmp_stride, tmp_stride + dims);
106
+
107
+ int64_t acl_storage_len = 0;
108
+ for (int i = 0; i < dims; i++) {
109
+ acl_storage_len += (ne[i] - 1) * nb[i];
110
+ }
111
+
112
+ aclTensor* acl_tensor =
113
+ aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size,
114
+ format, &acl_storage_len, 1, data_ptr);
115
+
116
+ return acl_tensor;
117
+ }
90
118
 
91
119
  /**
92
120
  * @brief Checks if tensors require broadcasting based on their shapes.