@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
@@ -0,0 +1,91 @@
1
+ if (GGML_STATIC)
2
+ set(BLA_STATIC ON)
3
+ endif()
4
+ #if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22)
5
+ # set(BLA_SIZEOF_INTEGER 8)
6
+ #endif()
7
+
8
+ set(BLA_VENDOR ${GGML_BLAS_VENDOR})
9
+ find_package(BLAS)
10
+
11
+ if (BLAS_FOUND)
12
+ message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
13
+
14
+ add_library(ggml-blas
15
+ ggml-blas.cpp
16
+ )
17
+
18
+ target_link_libraries(ggml-blas PRIVATE ggml-base)
19
+ target_include_directories(ggml-blas PRIVATE . ..)
20
+
21
+ if (${GGML_BLAS_VENDOR} MATCHES "Apple")
22
+ target_compile_definitions(ggml-blas PRIVATE ACCELERATE_NEW_LAPACK) # scoped to the ggml-blas target rather than the whole directory
23
+ target_compile_definitions(ggml-blas PRIVATE ACCELERATE_LAPACK_ILP64)
24
+ target_compile_definitions(ggml-blas PRIVATE GGML_BLAS_USE_ACCELERATE) # selects the Accelerate header include in ggml-blas.cpp
25
+ elseif ("${BLAS_INCLUDE_DIRS}" STREQUAL "")
26
+ # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
27
+ # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
28
+ find_package(PkgConfig REQUIRED)
29
+ if (${GGML_BLAS_VENDOR} MATCHES "Generic")
30
+ pkg_check_modules(DepBLAS blas)
31
+ elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS")
32
+ # As of openblas v0.3.22, the 64-bit is named openblas64.pc
33
+ pkg_check_modules(DepBLAS openblas64)
34
+ if (NOT DepBLAS_FOUND)
35
+ pkg_check_modules(DepBLAS openblas)
36
+ endif()
37
+ elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
38
+ target_compile_definitions(ggml-blas PRIVATE GGML_BLAS_USE_BLIS) # target-scoped; makes ggml-blas.cpp report the backend as "BLIS"
39
+ pkg_check_modules(DepBLAS blis)
40
+ elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
41
+ pkg_check_modules(DepBLAS blas-atlas)
42
+ elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
43
+ pkg_check_modules(DepBLAS flexiblas_api)
44
+ elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
45
+ target_compile_definitions(ggml-blas PRIVATE GGML_BLAS_USE_MKL) # target-scoped; makes ggml-blas.cpp include <mkl.h>
46
+ # all Intel* libraries share the same include path
47
+ pkg_check_modules(DepBLAS mkl-sdl)
48
+ elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
49
+ # this doesn't provide pkg-config
50
+ # suggest to assign BLAS_INCLUDE_DIRS on your own
51
+ if ("${NVHPC_VERSION}" STREQUAL "")
52
+ message(WARNING "Better to set NVHPC_VERSION")
53
+ else()
54
+ set(DepBLAS_FOUND ON)
55
+ set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
56
+ endif()
57
+ endif()
58
+ if (DepBLAS_FOUND)
59
+ set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
60
+ else()
61
+ message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically"
62
+ " detected by pkgconfig, trying to find cblas.h from possible paths...")
63
+ find_path(BLAS_INCLUDE_DIRS
64
+ NAMES cblas.h
65
+ HINTS
66
+ /usr/include
67
+ /usr/local/include
68
+ /usr/include/openblas
69
+ /opt/homebrew/opt/openblas/include
70
+ /usr/local/opt/openblas/include
71
+ /usr/include/x86_64-linux-gnu/openblas/include
72
+ )
73
+ endif()
74
+ endif()
75
+
76
+ message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
77
+
78
+ #add_compile_options(${BLAS_LINKER_FLAGS})
79
+ target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})
80
+
81
+ if ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND ("${GGML_BLAS_VENDOR}" MATCHES "Generic" OR "${GGML_BLAS_VENDOR}" MATCHES "Intel")) # quoted: an empty or multi-entry value must stay a single operand
82
+ target_compile_definitions(ggml-blas PRIVATE GGML_BLAS_USE_MKL) # MKL headers found via the include path; scoped to ggml-blas only
83
+ endif()
84
+
85
+ target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES})
86
+ target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
87
+ else()
88
+ message(FATAL_ERROR "BLAS not found, please refer to "
89
+ "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
90
+ " to set correct GGML_BLAS_VENDOR") # FATAL_ERROR aborts configuration; a bare ERROR is not a message() mode and was only printed as text
91
+ endif()
@@ -4,8 +4,9 @@
4
4
 
5
5
  #include <future>
6
6
  #include <vector>
7
+ #include <cstring>
7
8
 
8
- #if defined(GGML_USE_ACCELERATE)
9
+ #if defined(GGML_BLAS_USE_ACCELERATE)
9
10
  # include <Accelerate/Accelerate.h>
10
11
  #elif defined(GGML_BLAS_USE_MKL)
11
12
  # include <mkl.h>
@@ -26,30 +27,6 @@ struct ggml_backend_blas_context {
26
27
  #endif
27
28
  };
28
29
 
29
- // helper function to determine if it is better to use BLAS or not
30
- // for large matrices, BLAS is faster
31
- static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
32
- const struct ggml_tensor * src0 = dst->src[0];
33
- const struct ggml_tensor * src1 = dst->src[1];
34
-
35
- const int64_t ne10 = src1->ne[0];
36
-
37
- const int64_t ne0 = dst->ne[0];
38
- const int64_t ne1 = dst->ne[1];
39
-
40
- // TODO: find the optimal values for these
41
- if (ggml_is_contiguous(src0) &&
42
- ggml_is_contiguous(src1) &&
43
- src1->type == GGML_TYPE_F32 &&
44
- (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
45
-
46
- /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
47
- return true;
48
- }
49
-
50
- return false;
51
- }
52
-
53
30
  static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
54
31
  const struct ggml_tensor * src0 = dst->src[0];
55
32
  const struct ggml_tensor * src1 = dst->src[1];
@@ -88,8 +65,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
88
65
 
89
66
  // convert src0 to float
90
67
  if (type != GGML_TYPE_F32) {
91
- ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
92
- ggml_to_float_t const to_float = type_traits.to_float;
68
+ const auto * type_traits = ggml_get_type_traits(type);
69
+ ggml_to_float_t const to_float = type_traits->to_float;
93
70
 
94
71
  for (int64_t i03 = 0; i03 < ne03; i03++) {
95
72
  for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -235,7 +212,7 @@ static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct g
235
212
 
236
213
  // backend interface
237
214
 
238
- static const char * ggml_backend_blas_name(ggml_backend_t backend) {
215
+ static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
239
216
  return "BLAS";
240
217
 
241
218
  GGML_UNUSED(backend);
@@ -247,12 +224,6 @@ static void ggml_backend_blas_free(ggml_backend_t backend) {
247
224
  delete backend;
248
225
  }
249
226
 
250
- static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) {
251
- return ggml_backend_cpu_buffer_type();
252
-
253
- GGML_UNUSED(backend);
254
- }
255
-
256
227
  static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
257
228
  ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
258
229
 
@@ -285,31 +256,9 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
285
256
  GGML_UNUSED(backend);
286
257
  }
287
258
 
288
- static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
289
- const struct ggml_tensor * src0 = op->src[0];
290
- const struct ggml_tensor * src1 = op->src[1];
291
-
292
- return (op->op == GGML_OP_MUL_MAT && ggml_backend_blas_use_blas(op)) ||
293
- (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
294
- op->src[1]->type == GGML_TYPE_F32 &&
295
- ggml_is_matrix(src0) &&
296
- ggml_is_matrix(src1) &&
297
- ggml_is_contiguous(src0) &&
298
- (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
299
-
300
- GGML_UNUSED(backend);
301
- }
302
-
303
- static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
304
- return ggml_backend_buft_is_host(buft);
305
-
306
- GGML_UNUSED(backend);
307
- }
308
-
309
259
  static struct ggml_backend_i blas_backend_i = {
310
- /* .get_name = */ ggml_backend_blas_name,
260
+ /* .get_name = */ ggml_backend_blas_get_name,
311
261
  /* .free = */ ggml_backend_blas_free,
312
- /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
313
262
  /* .set_tensor_async = */ NULL,
314
263
  /* .get_tensor_async = */ NULL,
315
264
  /* .cpy_tensor_async = */ NULL,
@@ -319,9 +268,6 @@ static struct ggml_backend_i blas_backend_i = {
319
268
  /* .graph_plan_update = */ NULL,
320
269
  /* .graph_plan_compute = */ NULL,
321
270
  /* .graph_compute = */ ggml_backend_blas_graph_compute,
322
- /* .supports_op = */ ggml_backend_blas_supports_op,
323
- /* .supports_buft = */ ggml_backend_blas_supports_buft,
324
- /* .offload_op = */ NULL,
325
271
  /* .event_record = */ NULL,
326
272
  /* .event_wait = */ NULL,
327
273
  };
@@ -337,18 +283,18 @@ ggml_backend_t ggml_backend_blas_init(void) {
337
283
  ggml_backend_t backend = new ggml_backend {
338
284
  /* .guid = */ ggml_backend_blas_guid(),
339
285
  /* .interface = */ blas_backend_i,
340
- /* .device = */ nullptr,
286
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
341
287
  /* .context = */ ctx,
342
288
  };
343
289
 
344
- #if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
290
+ #if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
345
291
  if (openblas_get_parallel() != OPENBLAS_OPENMP) {
346
- fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
292
+ GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
347
293
  }
348
294
  #endif
349
295
 
350
- #if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
351
- fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
296
+ #if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
297
+ GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
352
298
  #endif
353
299
 
354
300
  return backend;
@@ -364,3 +310,205 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads)
364
310
  ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
365
311
  ctx->n_threads = n_threads;
366
312
  }
313
+
314
+ // device interface
315
+
316
+ static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
317
+ return "BLAS";
318
+
319
+ GGML_UNUSED(dev);
320
+ }
321
+
322
+ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
323
+ #if defined(GGML_BLAS_USE_ACCELERATE)
324
+ return "Accelerate";
325
+ #elif defined(GGML_BLAS_USE_MKL)
326
+ return "MKL";
327
+ #elif defined(GGML_BLAS_USE_BLIS)
328
+ return "BLIS";
329
+ #elif defined(GGML_BLAS_USE_NVPL)
330
+ return "NVPL";
331
+ #elif defined(OPENBLAS_VERSION)
332
+ return "OpenBLAS";
333
+ #else
334
+ return "BLAS";
335
+ #endif
336
+
337
+ GGML_UNUSED(dev);
338
+ }
339
+
340
+ static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
341
+ // TODO
342
+ *free = 0;
343
+ *total = 0;
344
+
345
+ GGML_UNUSED(dev);
346
+ }
347
+
348
+ static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
349
+ return GGML_BACKEND_DEVICE_TYPE_ACCEL;
350
+
351
+ GGML_UNUSED(dev);
352
+ }
353
+
354
+ static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
355
+ props->name = ggml_backend_blas_device_get_name(dev);
356
+ props->description = ggml_backend_blas_device_get_description(dev);
357
+ props->type = ggml_backend_blas_device_get_type(dev);
358
+ ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
359
+ props->caps = {
360
+ /* .async = */ false,
361
+ /* .host_buffer = */ false,
362
+ /* .buffer_from_host_ptr = */ true,
363
+ /* .events = */ false,
364
+ };
365
+ }
366
+
367
+ static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
368
+ return ggml_backend_blas_init();
369
+
370
+ GGML_UNUSED(dev);
371
+ GGML_UNUSED(params);
372
+ }
373
+
374
+ static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
375
+ return ggml_backend_cpu_buffer_type();
376
+
377
+ GGML_UNUSED(dev);
378
+ }
379
+
380
+ static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
381
+ return ggml_backend_cpu_buffer_from_ptr(ptr, size);
382
+
383
+ GGML_UNUSED(dev);
384
+ GGML_UNUSED(max_tensor_size);
385
+ }
386
+
387
+ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { // returns true when this device accepts `op`; `dev` is unused
388
+ const struct ggml_tensor * src0 = op->src[0]; // first operand; used by the OUT_PROD case below
389
+ const struct ggml_tensor * src1 = op->src[1]; // second operand; used by the OUT_PROD case below
390
+
391
+ switch (op->op) {
392
+ case GGML_OP_NONE: // this group of ops is accepted unconditionally
393
+ case GGML_OP_RESHAPE:
394
+ case GGML_OP_VIEW:
395
+ case GGML_OP_PERMUTE:
396
+ case GGML_OP_TRANSPOSE:
397
+ return true;
398
+
399
+ case GGML_OP_MUL_MAT:
400
+ {
401
+ // BLAS usually is only faster for large matrices
402
+ const struct ggml_tensor * src0 = op->src[0]; // NOTE(review): shadows the function-scope src0 with the same value
403
+ const struct ggml_tensor * src1 = op->src[1]; // NOTE(review): shadows the function-scope src1 with the same value
404
+
405
+ const int64_t ne10 = src1->ne[0]; // first dimension of src1
406
+
407
+ const int64_t ne0 = op->ne[0]; // first dimension of the result
408
+ const int64_t ne1 = op->ne[1]; // second dimension of the result
409
+
410
+ // TODO: find the optimal value
411
+ const int64_t min_batch = 32; // every checked dimension must be at least this large
412
+
413
+ return ggml_is_contiguous(src0) &&
414
+ ggml_is_contiguous(src1) &&
415
+ src1->type == GGML_TYPE_F32 &&
416
+ (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
417
+ (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL); // non-F32 src0 needs a to_float conversion routine
418
+ }
419
+
420
+ case GGML_OP_OUT_PROD:
421
+ return op->src[0]->type == GGML_TYPE_F32 && // both operands must be F32
422
+ op->src[1]->type == GGML_TYPE_F32 &&
423
+ ggml_is_matrix(src0) &&
424
+ ggml_is_matrix(src1) &&
425
+ ggml_is_contiguous(src0) &&
426
+ (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
427
+ (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL); // NOTE(review): always true here, src0 is already required to be F32 above
428
+
429
+ default:
430
+ return false; // every other op is rejected
431
+
432
+ }
433
+
434
+ GGML_UNUSED(dev);
435
+ }
436
+
437
+ static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
438
+ return ggml_backend_buft_is_host(buft);
439
+
440
+ GGML_UNUSED(dev);
441
+ }
442
+
443
+ static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
444
+ /* .get_name = */ ggml_backend_blas_device_get_name,
445
+ /* .get_description = */ ggml_backend_blas_device_get_description,
446
+ /* .get_memory = */ ggml_backend_blas_device_get_memory,
447
+ /* .get_type = */ ggml_backend_blas_device_get_type,
448
+ /* .get_props = */ ggml_backend_blas_device_get_props,
449
+ /* .init_backend = */ ggml_backend_blas_device_init_backend,
450
+ /* .get_buffer_type = */ ggml_backend_blas_device_get_buffer_type,
451
+ /* .get_host_buffer_type = */ NULL,
452
+ /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
453
+ /* .supports_op = */ ggml_backend_blas_device_supports_op,
454
+ /* .supports_buft = */ ggml_backend_blas_device_supports_buft,
455
+ /* .offload_op = */ NULL,
456
+ /* .event_new = */ NULL,
457
+ /* .event_free = */ NULL,
458
+ /* .event_synchronize = */ NULL,
459
+ };
460
+
461
+ // backend reg interface
462
+
463
+ static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
464
+ return "BLAS";
465
+
466
+ GGML_UNUSED(reg);
467
+ }
468
+
469
+ static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
470
+ return 1;
471
+
472
+ GGML_UNUSED(reg);
473
+ }
474
+
475
+ static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
476
+ GGML_ASSERT(index == 0);
477
+
478
+ static ggml_backend_device ggml_backend_blas_device = {
479
+ /* .iface = */ ggml_backend_blas_device_i,
480
+ /* .reg = */ reg,
481
+ /* .context = */ nullptr,
482
+ };
483
+
484
+ return &ggml_backend_blas_device;
485
+
486
+ GGML_UNUSED(reg);
487
+ GGML_UNUSED(index);
488
+ }
489
+
490
+ static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
491
+ if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
492
+ return (void *)ggml_backend_blas_set_n_threads;
493
+ }
494
+ return NULL;
495
+
496
+ GGML_UNUSED(reg);
497
+ GGML_UNUSED(name);
498
+ }
499
+
500
+ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
501
+ /* .get_name = */ ggml_backend_blas_reg_get_name,
502
+ /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
503
+ /* .get_device = */ ggml_backend_blas_reg_get_device,
504
+ /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
505
+ };
506
+
507
+ ggml_backend_reg_t ggml_backend_blas_reg(void) {
508
+ static struct ggml_backend_reg ggml_backend_blas_reg = {
509
+ /* .iface = */ ggml_backend_blas_reg_i,
510
+ /* .context = */ NULL,
511
+ };
512
+
513
+ return &ggml_backend_blas_reg;
514
+ }
@@ -0,0 +1,46 @@
1
+ if ("cann${CANN_INSTALL_DIR}" STREQUAL "cann" AND DEFINED ENV{ASCEND_TOOLKIT_HOME})
2
+ set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME})
3
+ message(STATUS "CANN: updated CANN_INSTALL_DIR from ASCEND_TOOLKIT_HOME=$ENV{ASCEND_TOOLKIT_HOME}")
4
+ endif()
5
+
6
+ if (CANN_INSTALL_DIR)
7
+ # Only Support Linux.
8
+ if (NOT UNIX)
9
+ message(FATAL_ERROR "CANN: CANN toolkit supports unix but not ${CMAKE_SYSTEM_NAME}")
10
+ endif()
11
+
12
+ # Supported platforms: x86-64, arm64
13
+ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
14
+ elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64")
15
+ else()
16
+ message(FATAL_ERROR "CANN: CANN toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}")
17
+ endif()
18
+
19
+ # Set header and libs
20
+ set(CANN_INCLUDE_DIRS
21
+ ${CANN_INSTALL_DIR}/include
22
+ ${CANN_INSTALL_DIR}/include/aclnn
23
+ ${CANN_INSTALL_DIR}/acllib/include
24
+ )
25
+
26
+ add_subdirectory(kernels)
27
+ list(APPEND CANN_LIBRARIES
28
+ ascendcl
29
+ nnopbase
30
+ opapi
31
+ acl_op_compiler
32
+ ascendc_kernels
33
+ )
34
+
35
+ file(GLOB GGML_SOURCES_CANN "*.cpp")
36
+
37
+ add_library(ggml-cann ${GGML_SOURCES_CANN})
38
+ target_link_libraries(ggml-cann PRIVATE ggml-base ${CANN_LIBRARIES})
39
+ target_include_directories(ggml-cann PRIVATE . .. ${CANN_INCLUDE_DIRS})
40
+ target_link_directories(ggml-cann PRIVATE ${CANN_INSTALL_DIR}/lib64)
41
+
42
+ message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
43
+ message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
44
+ else()
45
+ message(FATAL_ERROR "CANN: Can't find CANN_INSTALL_DIR, did you forget to source set_var.sh?")
46
+ endif()