@fugood/llama.node 0.3.16 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +238 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +6 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -192
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1003 -13519
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +96 -22
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +2 -292
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  130. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +204 -280
  131. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  133. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  135. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  136. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +646 -114
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +17 -8
  142. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  143. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  144. package/src/llama.cpp/include/llama.h +30 -11
  145. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  147. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  149. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  150. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  151. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  152. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  153. package/src/llama.cpp/src/llama-arch.cpp +160 -17
  154. package/src/llama.cpp/src/llama-arch.h +16 -0
  155. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  156. package/src/llama.cpp/src/llama-chat.h +6 -2
  157. package/src/llama.cpp/src/llama-context.cpp +108 -92
  158. package/src/llama.cpp/src/llama-context.h +1 -2
  159. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  160. package/src/llama.cpp/src/llama-graph.h +26 -6
  161. package/src/llama.cpp/src/llama-hparams.h +13 -0
  162. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  163. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  164. package/src/llama.cpp/src/llama-memory.h +1 -1
  165. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  166. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  167. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  168. package/src/llama.cpp/src/llama-model.cpp +1760 -534
  169. package/src/llama.cpp/src/llama-model.h +13 -1
  170. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  171. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  172. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  173. package/src/llama.cpp/src/llama.cpp +1 -1
  174. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  175. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  176. package/src/llama.cpp/tests/test-backend-ops.cpp +82 -43
  177. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  178. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  179. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  180. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  181. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  182. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  183. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  184. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  185. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  186. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  188. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  189. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  190. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  191. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  192. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  193. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp

@@ -29,6 +29,8 @@
  #include <cstdio>
  #include <cstring>
  #include <mutex>
+ #include <queue>
+ #include <chrono>

  #include "ggml-impl.h"
  #include "ggml-backend-impl.h"
@@ -119,9 +121,10 @@ static ggml_cann_device_info ggml_cann_init() {
  prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = id;
  prop.reserve = 0;
- ACL_CHECK(aclrtMemGetAllocationGranularity(
+ err = aclrtMemGetAllocationGranularity(
  &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
- &info.devices[id].vmm_granularity));
+ &info.devices[id].vmm_granularity);
+ info.devices[id].vmm = err == ACL_SUCCESS;

  size_t free, total;
  ggml_backend_cann_get_device_memory(id, &free, &total);
@@ -148,11 +151,223 @@ const ggml_cann_device_info& ggml_cann_info() {

  //#define DEBUG_CANN_MALLOC
  /**
- * @brief A pool of CANN buffers(legacy).
+ * @brief A pool of CANN buffers(priority segment buffer).
  *
  * This class manages a pool of CANN buffers for a specific device.
  */
- struct ggml_cann_pool_leg : public ggml_cann_pool {
+ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
+ /**
+ * @brief The maximum reuse margin for a buffer.
+ */
+ static const size_t max_reuse_margin = 1ull << 22; // 4MB
+
+ /**
+ * @brief The minimum free margin for a buffer.
+ */
+ static const size_t min_free_margin = 1ull << 20; // 1MB
+
+ /**
+ * @brief The alignment for buffer allocation.
+ */
+ static const size_t alignment = 128;
+
+ /**
+ * @brief The device ID associated with this buffer pool.
+ */
+ int device;
+
+ /**
+ * @brief Whether to disable clean during buffer allocation.
+ */
+ bool disable_clean = false;
+
+ /**
+ * @brief Structure representing a CANN buffer.
+ */
+ struct ggml_cann_buffer {
+ void* ptr = nullptr; ///< Pointer to the buffer.
+ size_t size = 0; ///< Size of the buffer.
+ std::chrono::steady_clock::time_point last_used; ///< Last used time.
+
+ bool operator>(const ggml_cann_buffer& other) const {
+ return size > other.size;
+ }
+ };
+
+ /**
+ * @brief Array of CANN buffers in the pool.
+ */
+ std::unordered_map<void*, size_t> buffer_pool;
+ std::priority_queue<ggml_cann_buffer,
+ std::vector<ggml_cann_buffer>,
+ std::greater<>> free_buffers ;
+
+ /**
+ * @brief Total size of all buffers in the pool.
+ */
+ size_t pool_size = 0;
+
+ /**
+ * @brief Constructor to initialize the buffer pool for a specific device.
+ *
+ * @param device The device ID to associate with this buffer pool.
+ */
+ explicit ggml_cann_pool_buf_prio(int device) : device(device) {
+ disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
+ }
+
+ /**
+ * @brief Destructor to free all buffers in the pool.
+ */
+ ~ggml_cann_pool_buf_prio() {
+ ggml_cann_set_device(device);
+ for (auto& [b_ptr, b_size] : buffer_pool) {
+ aclrtFree(b_ptr);
+ pool_size -= b_size;
+ }
+ buffer_pool.clear();
+ GGML_ASSERT(pool_size == 0);
+ }
+
+ /**
+ * @brief Allocate a buffer of the given size.
+ *
+ * @param size The size of the buffer to allocate.
+ * @param actual_size A pointer to a variable to receive the actual size of
+ * the allocated buffer.
+ * @return A pointer to the allocated buffer.
+ */
+ void* alloc(size_t size, size_t* actual_size) override {
+ size = GGML_PAD(size, alignment);
+ if (size == 0) {
+ size = alignment;
+ }
+
+ void* ptr = nullptr;
+ auto now = std::chrono::steady_clock::now();
+
+ std::vector<ggml_cann_buffer> free_buffers_rest;
+ free_buffers_rest.reserve(free_buffers.size());
+ while (!free_buffers.empty()) {
+ auto b = free_buffers.top();
+ free_buffers.pop();
+
+ if (b.size >= size) {
+ // reuse the buffer if the size is enough
+ const size_t margin = b.size - size;
+ if (margin <= max_reuse_margin) {
+ *actual_size = b.size;
+ ptr = b.ptr;
+ #ifdef DEBUG_CANN_MALLOC
+ GGML_LOG_INFO(
+ "cann pool[%d]: reused %p, "
+ "pool_size = %5u MB, "
+ "size = %5u MB, "
+ "margin = %5u MB\n",
+ device, b.ptr,
+ (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
+ (uint32_t)(GGML_PAD(size, 1048576) / 1048576),
+ (uint32_t)(GGML_PAD(margin, 1048576) / 1048576));
+ #endif
+ break;
+ }
+ }
+
+ bool should_clean = !disable_clean &&
+ b.size > min_free_margin &&
+ std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
+ if (should_clean) {
+ // free the buffer if the size is needed to be freed
+ ACL_CHECK(aclrtFree(b.ptr));
+ pool_size -= b.size;
+ buffer_pool.erase(b.ptr);
+ #ifdef DEBUG_CANN_MALLOC
+ GGML_LOG_INFO(
+ "cann pool[%d]: clean %p, "
+ "pool_size = %5u MB, "
+ "size = %5u MB\n",
+ device, b.ptr,
+ (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
+ (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
+ #endif
+ continue;
+ }
+ free_buffers_rest.push_back(b);
+ }
+ for (ggml_cann_buffer &b : free_buffers_rest) {
+ free_buffers.push(std::move(b));
+ }
+
+ #ifdef DEBUG_CANN_MALLOC
+ GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
+ #endif
+ if (ptr != nullptr) {
+ return ptr;
+ }
+
+ // allocate a new buffer if no buffer can be reused
+ ggml_cann_set_device(device);
+ ACL_CHECK(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
+ *actual_size = size;
+ pool_size += size;
+ #ifdef DEBUG_CANN_MALLOC
+ GGML_LOG_INFO(
+ "cann pool[%d]: allocate %p, "
+ "pool_size = %5u MB, "
+ "size = %5u MB\n",
+ device, ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
+ (uint32_t)(GGML_PAD(size, 1048576) / 1048576));
+ #endif
+ buffer_pool.emplace(ptr, size);
+ return ptr;
+ }
+
+ /**
+ * @brief Free a buffer and return it to the pool.
+ *
+ * @param ptr Pointer to the buffer to free.
+ * @param size Size of the buffer to free.
+ */
+ void free(void* ptr, size_t size) override {
+ GGML_UNUSED(size);
+ auto it = buffer_pool.find(ptr);
+ if (it == buffer_pool.end()) {
+ GGML_ABORT("cann pool[%d]: buffer %p not found in pool\n", device, ptr);
+ }
+
+ auto now = std::chrono::steady_clock::now();
+ free_buffers.emplace(ggml_cann_buffer{ptr, it->second, now});
+ #ifdef DEBUG_CANN_MALLOC
+ GGML_LOG_INFO(
+ "cann pool[%d]: return %p, "
+ "pool_size = %5u MB\n",
+ device, ptr,
+ (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
+ #endif
+ }
+ };
+
+ /**
+ * @brief A pool of CANN buffers(segment buffer).
+ *
+ * This class manages a pool of CANN buffers for a specific device.
+ */
+ struct ggml_cann_pool_buf : public ggml_cann_pool {
+ /**
+ * @brief The maximum reuse margin for a buffer.
+ */
+ static const size_t max_reuse_margin = 1ull << 22; // 4MB
+
+ /**
+ * @brief The minimum free margin for a buffer.
+ */
+ static const size_t min_free_margin = 1ull << 20; // 1MB
+
+ /**
+ * @brief The alignment for buffer allocation.
+ */
+ static const size_t alignment = 128;
+
  /**
  * @brief The maximum number of buffers in the pool.
  */
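The ggml_cann_pool_buf_prio added above keeps freed device buffers in a min-heap ordered by size and reuses the smallest free buffer that satisfies a request, as long as the wasted margin stays within max_reuse_margin (4 MB). The standalone sketch below illustrates only that reuse policy; it is not code from the package — plain malloc/free stand in for aclrtMalloc/aclrtFree, and the block/pool_alloc/pool_free names are made up for the example.

    // Sketch of the size-ordered reuse policy used by ggml_cann_pool_buf_prio.
    // Assumptions: malloc/free replace aclrtMalloc/aclrtFree; no alignment,
    // timing, or clean-up logic is modelled here.
    #include <cstddef>
    #include <cstdio>
    #include <cstdlib>
    #include <queue>
    #include <vector>

    struct block { void * ptr; size_t size; };
    struct by_size { bool operator()(const block & a, const block & b) const { return a.size > b.size; } };

    static std::priority_queue<block, std::vector<block>, by_size> free_blocks; // min-heap by size
    static const size_t max_reuse_margin = 1ull << 22;                          // 4 MB, as in the diff

    static void * pool_alloc(size_t size) {
        std::vector<block> skipped;   // blocks popped but kept, like free_buffers_rest above
        void * ptr = nullptr;
        while (!free_blocks.empty()) {
            block b = free_blocks.top();
            free_blocks.pop();
            if (b.size >= size && b.size - size <= max_reuse_margin) {
                ptr = b.ptr;          // smallest block that fits with acceptable waste
                break;
            }
            skipped.push_back(b);
        }
        for (const block & b : skipped) free_blocks.push(b);
        return ptr != nullptr ? ptr : malloc(size);
    }

    static void pool_free(void * ptr, size_t size) {
        free_blocks.push({ptr, size}); // return the block to the heap for later reuse
    }

    int main() {
        void * a = pool_alloc(1u << 20);
        pool_free(a, 1u << 20);
        void * b = pool_alloc(1u << 19); // reused: 0.5 MB margin <= 4 MB cap
        std::printf("reused: %s\n", a == b ? "yes" : "no");
        pool_free(b, 1u << 20);
        while (!free_blocks.empty()) { free(free_blocks.top().ptr); free_blocks.pop(); }
        return 0;
    }

Buffers that are too small, or that would waste more than the margin, go back onto the heap — which is what free_buffers_rest does in the real pool; the real implementation additionally frees buffers idle for more than roughly 100 ms unless GGML_CANN_DISABLE_BUF_POOL_CLEAN is set.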
@@ -163,12 +378,19 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
  */
  int device;

+ /**
+ * @brief Whether to disable clean during buffer allocation.
+ */
+ bool disable_clean = false;
+
  /**
  * @brief Structure representing a CANN buffer.
  */
  struct ggml_cann_buffer {
  void* ptr = nullptr; ///< Pointer to the buffer memory.
  size_t size = 0; ///< Size of the buffer.
+ bool used = false; ///< Whether the buffer is currently in use.
+ std::chrono::steady_clock::time_point last_used; ///< Last used time.
  };

  /**
@@ -186,17 +408,19 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
  *
  * @param device The device ID to associate with this buffer pool.
  */
- explicit ggml_cann_pool_leg(int device) : device(device) {}
+ explicit ggml_cann_pool_buf(int device) : device(device) {
+ disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr;
+ }

  /**
  * @brief Destructor to free all buffers in the pool.
  */
- ~ggml_cann_pool_leg() {
+ ~ggml_cann_pool_buf() {
  ggml_cann_set_device(device);
  for (int i = 0; i < MAX_BUFFERS; ++i) {
  ggml_cann_buffer& b = buffer_pool[i];
  if (b.ptr != nullptr) {
- ACL_CHECK(aclrtFree(b.ptr));
+ aclrtFree(b.ptr);
  pool_size -= b.size;
  }
  }
@@ -212,63 +436,93 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
  * @return A pointer to the allocated buffer.
  */
  void* alloc(size_t size, size_t* actual_size) override {
- const size_t alignment = 128;
  size = GGML_PAD(size, alignment);
  if (size == 0) {
  size = alignment;
  }
- #ifdef DEBUG_CANN_MALLOC
- int nnz = 0;
- size_t max_size = 0;
- #endif
- size_t best_diff = 1ull << 36;
- int ibest = -1;
- for (int i = 0; i < MAX_BUFFERS; ++i) {
+
+ void* ptr = nullptr;
+ auto now = std::chrono::steady_clock::now();
+
+ int i = 0;
+ for (; i < MAX_BUFFERS; ++i) {
  ggml_cann_buffer& b = buffer_pool[i];
- if (b.ptr != nullptr) {
+ if (b.ptr == nullptr) {
+ break;
+ }
+ if (b.used) {
+ continue;
+ }
+ if (b.size >= size) {
+ // reuse the buffer if the size is enough
+ const size_t margin = b.size - size;
+ if (margin <= max_reuse_margin) {
+ *actual_size = b.size;
+ b.used = true;
+ ptr = b.ptr;
  #ifdef DEBUG_CANN_MALLOC
- ++nnz;
- if (b.size > max_size) max_size = b.size;
+ GGML_LOG_INFO(
+ "cann pool[%d]: reused %p, "
+ "pool_size = %5u MB, "
+ "size = %5u MB, "
+ "margin = %5u MB\n",
+ device, b.ptr,
+ (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
+ (uint32_t)(GGML_PAD(size, 1048576) / 1048576),
+ (uint32_t)(GGML_PAD(margin, 1048576) / 1048576));
  #endif
- if (b.size >= size) {
- size_t diff = b.size - size;
- if (diff < best_diff) {
- best_diff = diff;
- ibest = i;
- if (!best_diff) {
- void* ptr = b.ptr;
- *actual_size = b.size;
- b.ptr = nullptr;
- b.size = 0;
- return ptr;
- }
- }
+ break;
  }
  }
+
+ bool should_clean = !disable_clean &&
+ b.size > min_free_margin &&
+ std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
+ if (should_clean) {
+ // free the buffer if the size is needed to be freed
+ ACL_CHECK(aclrtFree(b.ptr));
+ pool_size -= b.size;
+ #ifdef DEBUG_CANN_MALLOC
+ GGML_LOG_INFO(
+ "cann pool[%d]: clean %p, "
+ "pool_size = %5u MB, "
+ "size = %5u MB\n",
+ device, b.ptr,
+ (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
+ (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
+ #endif
+ b.ptr = nullptr;
+ }
  }
- if (ibest >= 0) {
- ggml_cann_buffer& b = buffer_pool[ibest];
- void* ptr = b.ptr;
- *actual_size = b.size;
- b.ptr = nullptr;
- b.size = 0;
+ if (ptr != nullptr) {
  return ptr;
  }
- void* ptr;
- ggml_cann_set_device(device);
- ACL_CHECK(
- aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
- *actual_size = size;
- pool_size += size;
+
+ if (i < MAX_BUFFERS) {
+ // allocate a new buffer if no buffer can be reused
+ ggml_cann_buffer& b = buffer_pool[i];
+ ggml_cann_set_device(device);
+ ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
+ pool_size += size;
+ *actual_size = size;
+ b.size = size;
+ b.used = true;
+ if (i >= MAX_BUFFERS - 8) {
+ GGML_LOG_WARN("cann pool[%d]: slots almost full\n", device);
+ }
  #ifdef DEBUG_CANN_MALLOC
- GGML_LOG_INFO(
- "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
- "requested %u MB\n",
- __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
- (uint32_t)(pool_size / 1024 / 1024),
- (uint32_t)(size / 1024 / 1024));
+ GGML_LOG_INFO(
+ "cann pool[%d]: allocate %p, "
+ "pool_size = %5u MB, "
+ "size = %5u MB\n",
+ device, b.ptr,
+ (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576),
+ (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576));
  #endif
- return ptr;
+ return b.ptr;
+ }
+
+ GGML_ABORT("cann pool[%d]: slots full\n", device);
  }

  /**
@@ -278,18 +532,24 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
  * @param size Size of the buffer to free.
  */
  void free(void* ptr, size_t size) override {
+ GGML_UNUSED(size);
  for (int i = 0; i < MAX_BUFFERS; ++i) {
  ggml_cann_buffer& b = buffer_pool[i];
- if (b.ptr == nullptr) {
- b.ptr = ptr;
- b.size = size;
- return;
+ if (b.ptr != ptr) {
+ continue;
  }
+ b.used = false;
+ b.last_used = std::chrono::steady_clock::now();
+ #ifdef DEBUG_CANN_MALLOC
+ GGML_LOG_INFO(
+ "cann pool[%d]: return %p, "
+ "pool_size = %5u MB\n",
+ device, b.ptr,
+ (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576));
+ #endif
+ return;
  }
- // memory should always buffered. these memory may still needed by
- // tasks in stream.
- // TODO, fix me.
- GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
+ GGML_ABORT("cann pool[%d]: slots full\n", device);
  }
  };

@@ -347,8 +607,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
  * @param device The device ID to associate with this buffer pool.
  */
  explicit ggml_cann_pool_vmm(int device)
- : device(device),
- granularity(ggml_cann_info().devices[device].vmm_granularity) {
+ : device(device) {
  auto dev = ggml_cann_info().devices[device];
  granularity = dev.vmm_granularity;
  max_size = dev.total_vram;
@@ -471,7 +730,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
  */
  std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
  int device) {
- return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
+ bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr);
+ if (!disable_vmm && ggml_cann_info().devices[device].vmm) {
+ GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
+ }
+ bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr);
+ if (enable_buf_prio) {
+ GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
+ }
+ GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);
+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
  }

  // cann buffer
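With this hunk, new_pool_for_device picks the memory pool at runtime: the VMM pool is used when the device reports VMM support and GGML_CANN_DISABLE_VMM_POOL is unset; otherwise GGML_CANN_ENABLE_BUF_PRIO_POOL selects the priority-queue pool, with the fixed-slot buffer pool as the fallback. A minimal usage sketch, assuming a POSIX environment (setenv; _putenv_s would be the rough Windows equivalent) and that the variables are set before the backend is created:

    // Hedged example, not from the package: choose the CANN pool via the
    // environment variables read by new_pool_for_device and the buffer pools.
    #include <cstdlib>

    int main() {
        setenv("GGML_CANN_DISABLE_VMM_POOL",       "1", 1); // skip the VMM pool
        setenv("GGML_CANN_ENABLE_BUF_PRIO_POOL",   "1", 1); // prefer the priority-queue pool
        setenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN", "1", 1); // keep idle buffers cached
        // ... initialize the llama.cpp CANN backend after this point ...
        return 0;
    }

Only the variable names and the selection order come from the hunk above; the value "1" is arbitrary, since the code only checks whether each variable is set.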
@@ -803,7 +1073,7 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
  return GGML_STATUS_SUCCESS;
  }

- // TODO: can backend doesn't support quantized yet. Just leave the code
+ // TODO: cann backend doesn't support quantized yet. Just leave the code
  // here.
  if (ggml_is_quantized(tensor->type)) {
  // Initialize padding to 0 to avoid possible NaN values
@@ -1020,8 +1290,11 @@ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,

  ggml_cann_set_device(buft_ctx->device);

- size = std::max(size, (size_t)1);
-
+ const size_t alignment = 128;
+ size = GGML_PAD(size, alignment);
+ if (size == 0) {
+ size = alignment;
+ }
  void* dev_ptr;
  aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
  if (err != ACL_SUCCESS) {
@@ -1300,47 +1573,69 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
  ggml_cann_dup(ctx, dst);
  break;
  case GGML_OP_ADD:
- ggml_cann_add(ctx, dst);
+ case GGML_OP_ADD1:
+ ggml_cann_binary_op<aclnn_add>(ctx, dst);
+ break;
+ case GGML_OP_SUB:
+ ggml_cann_binary_op<aclnn_sub>(ctx, dst);
  break;
  case GGML_OP_ACC:
  ggml_cann_acc(ctx, dst);
  break;
  case GGML_OP_MUL:
- ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
+ ggml_cann_binary_op<aclnn_mul>(ctx, dst);
  break;
  case GGML_OP_DIV:
- ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
+ ggml_cann_binary_op<aclnn_div>(ctx, dst);
  break;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(dst)) {
+ case GGML_UNARY_OP_ABS:
+ GGML_CANN_CALL_UNARY_OP(Abs);
+ break;
+ case GGML_UNARY_OP_NEG:
+ GGML_CANN_CALL_UNARY_OP(Neg);
+ break;
  case GGML_UNARY_OP_GELU:
- ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
- ctx, dst);
+ GGML_CANN_CALL_UNARY_OP(Gelu);
  break;
  case GGML_UNARY_OP_SILU:
- ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
- ctx, dst);
- break;
- // TODO: Use faster gelu??
- case GGML_UNARY_OP_GELU_QUICK:
- ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
- ctx, dst);
+ GGML_CANN_CALL_UNARY_OP(Silu);
  break;
+ case GGML_UNARY_OP_GELU_QUICK: {
+ auto lambda = [](ggml_backend_cann_context& ctx,
+ aclTensor* acl_src,
+ aclTensor* acl_dst) {
+ GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
+ };
+ ggml_cann_unary_op(lambda, ctx, dst);
+ } break;
  case GGML_UNARY_OP_TANH:
- ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
- ctx, dst);
+ GGML_CANN_CALL_UNARY_OP(Tanh);
  break;
  case GGML_UNARY_OP_RELU:
- ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
- ctx, dst);
+ GGML_CANN_CALL_UNARY_OP(Relu);
+ break;
+ case GGML_UNARY_OP_SIGMOID:
+ GGML_CANN_CALL_UNARY_OP(Sigmoid);
  break;
  case GGML_UNARY_OP_HARDSIGMOID:
- ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
- aclnnHardsigmoid>(ctx, dst);
+ GGML_CANN_CALL_UNARY_OP(Hardsigmoid);
  break;
  case GGML_UNARY_OP_HARDSWISH:
- ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
- aclnnHardswish>(ctx, dst);
+ GGML_CANN_CALL_UNARY_OP(Hardswish);
+ break;
+ case GGML_UNARY_OP_EXP:
+ GGML_CANN_CALL_UNARY_OP(Exp);
+ break;
+ case GGML_UNARY_OP_ELU:
+ ggml_cann_elu(ctx, dst);
+ break;
+ case GGML_UNARY_OP_SGN:
+ GGML_CANN_CALL_UNARY_OP(Sign);
+ break;
+ case GGML_UNARY_OP_STEP:
+ ggml_cann_step(ctx, dst);
  break;
  default:
  return false;
@@ -1382,7 +1677,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
  ggml_cann_scale(ctx, dst);
  break;
  case GGML_OP_SQR:
- ggml_cann_sqr(ctx, dst);
+ GGML_ASSERT(dst->src[1] == nullptr);
+ dst->src[1] = dst->src[0];
+ ggml_cann_binary_op<aclnn_mul>(ctx, dst);
+ break;
+ case GGML_OP_SQRT:
+ GGML_CANN_CALL_UNARY_OP(Sqrt);
  break;
  case GGML_OP_CLAMP:
  ggml_cann_clamp(ctx, dst);
@@ -1414,12 +1714,39 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
  case GGML_OP_POOL_2D:
  ggml_cann_pool2d(ctx, dst);
  break;
+ case GGML_OP_SUM:
+ ggml_cann_sum(ctx, dst);
+ break;
  case GGML_OP_SUM_ROWS:
  ggml_cann_sum_rows(ctx, dst);
  break;
  case GGML_OP_ARGSORT:
  ggml_cann_argsort(ctx, dst);
  break;
+ case GGML_OP_ARGMAX:
+ ggml_cann_argmax(ctx, dst);
+ break;
+ case GGML_OP_COS:
+ ggml_cann_unary_op<aclnn_cos>(ctx, dst);
+ break;
+ case GGML_OP_SIN:
+ ggml_cann_unary_op<aclnn_sin>(ctx, dst);
+ break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ ggml_cann_conv_transpose_1d(ctx, dst);
+ break;
+ case GGML_OP_LOG:
+ GGML_CANN_CALL_UNARY_OP(Log);
+ break;
+ case GGML_OP_MEAN:
+ ggml_cann_mean(ctx, dst);
+ break;
+ case GGML_OP_PAD_REFLECT_1D:
+ ggml_cann_pad_reflect_1d(ctx, dst);
+ break;
+ case GGML_OP_COUNT_EQUAL:
+ ggml_cann_count_equal(ctx, dst);
+ break;
  default:
  return false;
  }
@@ -1458,21 +1785,15 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
  ACL_CHECK(aclrtSynchronizeDevice());
  ACL_CHECK(aclrtResetDevice(cann_ctx->device));

- // finalize when last backend freed.
- if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
- ACL_CHECK(aclFinalize());
- }
-
  delete cann_ctx;
  delete backend;
  }

+
  /**
  * @brief Sets tensor data asynchronously in the CANN backend.
  *
- * This function asynchronously sets tensor data in the CANN backend. Depending
- * on the tensor type, it may perform data transformations before copying data
- * to the device.
+ * This function asynchronously sets tensor data in the CANN backend.
  *
  * @param backend Pointer to the CANN backend structure.
  * @param tensor Pointer to the tensor structure to set data for.
@@ -1487,23 +1808,28 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
  size_t size) {
  ggml_backend_cann_context *cann_ctx =
  (ggml_backend_cann_context *)backend->context;
+ ggml_backend_buffer_t buf =
+ tensor->view_src ? tensor->view_src->buffer : tensor->buffer;

- if (!need_transform(tensor->type)) {
- ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
- size, ACL_MEMCPY_HOST_TO_DEVICE,
- cann_ctx->stream()));
- } else {
- void *transform_buffer = malloc(size);
- ggml_backend_cann_transform(tensor, data, transform_buffer);
+ GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
+ "unsupported buffer type");
+ GGML_ASSERT(!ggml_is_quantized(tensor->type));

- ACL_CHECK(aclrtMemcpyAsync(
- (char *)tensor->data + offset, size, transform_buffer, size,
- ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
- ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
- free(transform_buffer);
- }
+ ggml_cann_async_memcpy(cann_ctx, (char *)tensor->data + offset, data, size,
+ ACL_MEMCPY_HOST_TO_DEVICE);
  }

+ /**
+ * @brief Gets tensor data asynchronously in the CANN backend.
+ *
+ * This function asynchronously gets tensor data in the CANN backend.
+ *
+ * @param backend Pointer to the CANN backend structure.
+ * @param tensor Pointer to the tensor structure to get data from.
+ * @param data Pointer to the host data to copy from the tensor.
+ * @param offset Offset in bytes within the host data.
+ * @param size Size of the data to copy in bytes.
+ */
  static void ggml_backend_cann_get_tensor_async(
  ggml_backend_t backend, const ggml_tensor *tensor, void *data,
  size_t offset, size_t size) {
@@ -1514,20 +1840,11 @@ static void ggml_backend_cann_get_tensor_async(

  GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
  "unsupported buffer type");
+ GGML_ASSERT(!ggml_is_quantized(tensor->type));
+
+ ggml_cann_async_memcpy(cann_ctx, data, (char *)tensor->data + offset, size,
+ ACL_MEMCPY_DEVICE_TO_HOST);

- if (!need_transform(tensor->type)) {
- ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
- size, ACL_MEMCPY_DEVICE_TO_HOST,
- cann_ctx->stream()));
- } else {
- void *transform_buffer = malloc(size);
- ACL_CHECK(aclrtMemcpyAsync(
- transform_buffer, size, (char *)tensor->data + offset, size,
- ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
- ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
- ggml_backend_cann_transform_back(tensor, transform_buffer, data);
- free(transform_buffer);
- }
  }

  /**
@@ -1587,6 +1904,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
  ggml_cann_set_device(cann_ctx_src->device);
  ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));

+ // wait for task_queue empty to keep task order.
+ cann_ctx_src->task_queue.wait();
  ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
  ACL_MEMCPY_DEVICE_TO_DEVICE,
  cann_ctx_src->stream()));
@@ -1614,9 +1933,8 @@
  static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
  ggml_backend_cann_context* cann_ctx =
  (ggml_backend_cann_context*)backend->context;
-
+ cann_ctx->task_queue.wait();
  ggml_cann_set_device(cann_ctx->device);
-
  ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
  }

@@ -1675,24 +1993,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
  switch (op->op) {
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_ABS:
+ case GGML_UNARY_OP_NEG:
  case GGML_UNARY_OP_GELU:
  case GGML_UNARY_OP_SILU:
  case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_SIGMOID:
  case GGML_UNARY_OP_HARDSIGMOID:
  case GGML_UNARY_OP_HARDSWISH:
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_EXP:
+ case GGML_UNARY_OP_ELU:
+ case GGML_UNARY_OP_SGN:
+ case GGML_UNARY_OP_STEP:
  return true;
  default:
  return false;
  }
  case GGML_OP_MUL_MAT: {
  switch (op->src[0]->type) {
- case GGML_TYPE_Q8_0:
  case GGML_TYPE_F16:
  case GGML_TYPE_F32:
- case GGML_TYPE_Q4_0:
  return true;
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ #ifdef ASCEND_310P
+ // Q4 && Q8 per group is not suppor on 310p device
+ return false;
+ #endif
+ // only support contiguous for quantized types.
+ return ggml_is_contiguous(op->src[0]) &&
+ ggml_is_contiguous(op->src[1]);
  default:
  return false;
  }
@@ -1704,7 +2036,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
  switch (op->src[0]->type) {
  case GGML_TYPE_F32:
  case GGML_TYPE_F16:
- case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q8_0:
  return true;
  default:
@@ -1712,16 +2043,21 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
  }
  } break;
  case GGML_OP_CPY: {
- switch (op->type) {
- case GGML_TYPE_F32:
- case GGML_TYPE_F16:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q4_0:
- return true;
- default:
- return false;
+ ggml_tensor *src = op->src[0];
+ if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
+ (src->type != GGML_TYPE_F32 &&
+ src->type != GGML_TYPE_F16)) {
+ // only support F32 and F16.
+ return false;
  }
- }
+
+ if (!ggml_are_same_shape(op, src) && !ggml_is_contiguous(op)) {
+ // unsupport dst is not contiguous.
+ return false;
+ }
+
+ return true;
+ } break;
  case GGML_OP_CONT: {
  // TODO: support GGML_TYPE_BF16
  switch (op->src[0]->type) {
@@ -1734,13 +2070,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
  }
  case GGML_OP_ROPE: {
  // TODO: with ops-test v == 1
- float * ext_factor = (float*)((int32_t*)op->op_params + 7);
+ float ext_factor = 0.0f;
+ memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
  // TODO: n_dims <= ne0
  if (op->src[0]->ne[0] != op->op_params[1]) {
  return false;
  }
  // TODO: ext_factor != 0
- if (*ext_factor != 0) {
+ if (ext_factor != 0) {
  return false;
  }

@@ -1752,6 +2089,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
  return false;
  }

+ if(!ggml_is_contiguous(op->src[0])){
+ return false;
+ }
  return true;
  }
  case GGML_OP_UPSCALE: {
@@ -1760,11 +2100,31 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
  if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
  return false;
  }
+ if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
+ return false;
+ }
  return true;
  }
+ case GGML_OP_POOL_2D: {
+ const int32_t * opts = (const int32_t *) op->op_params;
+ #ifdef ASCEND_310P
+ enum ggml_op_pool opt = static_cast<ggml_op_pool>(opts[0]);
+ if(opt == GGML_OP_POOL_MAX){
+ return false;
+ }
+ #endif
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
+ // value of paddingH should be at most half of kernelH
+ // value of paddingW should be at most half of kernelW
+ return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
+ }
+ case GGML_OP_SUM:
+ case GGML_OP_DUP:
  case GGML_OP_IM2COL:
  case GGML_OP_CONCAT:
- case GGML_OP_DUP:
  case GGML_OP_REPEAT:
  case GGML_OP_NONE:
  case GGML_OP_RESHAPE:
@@ -1773,15 +2133,17 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
  case GGML_OP_TRANSPOSE:
  case GGML_OP_NORM:
  case GGML_OP_ADD:
+ case GGML_OP_ADD1:
+ case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
  case GGML_OP_RMS_NORM:
  case GGML_OP_SCALE:
  case GGML_OP_SQR:
+ case GGML_OP_SQRT:
  case GGML_OP_CLAMP:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
- case GGML_OP_POOL_2D:
  case GGML_OP_SUM_ROWS:
  case GGML_OP_ARGSORT:
  case GGML_OP_ACC:
@@ -1790,6 +2152,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
  case GGML_OP_ARANGE:
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_LEAKY_RELU:
+ case GGML_OP_ARGMAX:
+ case GGML_OP_COS:
+ case GGML_OP_SIN:
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ case GGML_OP_LOG:
+ case GGML_OP_MEAN:
+ case GGML_OP_PAD_REFLECT_1D:
+ case GGML_OP_COUNT_EQUAL:
  return true;
  default:
  return false;